import json import os import random import re import string from collections import defaultdict from datetime import datetime from typing import Any, DefaultDict, Dict, Iterable, List, Optional, Union, cast from dateutil.tz import tzutc from moto.core.base_backend import BackendDict, BaseBackend from moto.core.common_models import BaseModel, CloudFormationModel from moto.core.utils import camelcase_to_underscores from moto.sagemaker import validators from moto.utilities.paginator import paginate from moto.utilities.utils import ARN_PARTITION_REGEX, get_partition from .exceptions import ( AWSValidationException, ConflictException, MissingModel, ResourceInUseException, ResourceNotFound, ValidationError, ) from .utils import ( arn_formatter, filter_model_cards, get_pipeline_execution_from_arn, get_pipeline_from_name, get_pipeline_name_from_execution_arn, load_pipeline_definition_from_s3, validate_model_approval_status, ) PAGINATION_MODEL = { "list_experiments": { "input_token": "NextToken", "limit_key": "MaxResults", "limit_default": 100, "unique_attribute": "arn", }, "list_trials": { "input_token": "NextToken", "limit_key": "MaxResults", "limit_default": 100, "unique_attribute": "arn", }, "list_trial_components": { "input_token": "NextToken", "limit_key": "MaxResults", "limit_default": 100, "unique_attribute": "arn", }, "list_tags": { "input_token": "NextToken", "limit_key": "MaxResults", "limit_default": 50, "unique_attribute": "Key", }, "list_model_package_groups": { "input_token": "next_token", "limit_key": "max_results", "limit_default": 100, "unique_attribute": "arn", }, "list_model_packages": { "input_token": "next_token", "limit_key": "max_results", "limit_default": 100, "unique_attribute": "arn", }, "list_notebook_instances": { "input_token": "next_token", "limit_key": "max_results", "limit_default": 100, "unique_attribute": "arn", }, "list_clusters": { "input_token": "next_token", "limit_key": "max_results", "limit_default": 100, "unique_attribute": "arn", }, "list_cluster_nodes": { "input_token": "next_token", "limit_key": "max_results", "limit_default": 100, "unique_attribute": "arn", }, "list_auto_ml_jobs": { "input_token": "next_token", "limit_key": "max_results", "limit_default": 100, "unique_attribute": "arn", }, "list_endpoints": { "input_token": "next_token", "limit_key": "max_results", "limit_default": 100, "unique_attribute": "arn", }, "list_endpoint_configs": { "input_token": "next_token", "limit_key": "max_results", "limit_default": 100, "unique_attribute": "arn", }, "list_compilation_jobs": { "input_token": "next_token", "limit_key": "max_results", "limit_default": 100, "unique_attribute": "arn", }, "list_domains": { "input_token": "next_token", "limit_key": "max_results", "limit_default": 100, "unique_attribute": "arn", }, "list_model_explainability_job_definitions": { "input_token": "next_token", "limit_key": "max_results", "limit_default": 100, "unique_attribute": "arn", }, "list_hyper_parameter_tuning_jobs": { "input_token": "next_token", "limit_key": "max_results", "limit_default": 100, "unique_attribute": "arn", }, "list_model_quality_job_definitions": { "input_token": "next_token", "limit_key": "max_results", "limit_default": 100, "unique_attribute": "arn", }, "list_model_cards": { "input_token": "next_token", "limit_key": "max_results", "limit_default": 100, "unique_attribute": "arn", }, "list_model_card_versions": { "input_token": "next_token", "limit_key": "max_results", "limit_default": 100, "unique_attribute": "model_card_arn", }, "list_model_bias_job_definitions": { "input_token": "next_token", "limit_key": "max_results", "limit_default": 100, "unique_attribute": "arn", }, "list_data_quality_job_definitions": { "input_token": "next_token", "limit_key": "max_results", "limit_default": 100, "unique_attribute": "arn", }, } METRIC_INFO_TYPE = Dict[str, Union[str, int, float, datetime]] METRIC_STEP_TYPE = Dict[int, METRIC_INFO_TYPE] class BaseObject(BaseModel): def camelCase(self, key: str) -> str: words = [] for word in key.split("_"): words.append(word.title()) return "".join(words) def update(self, details_json: str) -> None: details = json.loads(details_json) for k in details.keys(): setattr(self, k, details[k]) def gen_response_object(self) -> Dict[str, Any]: response_object: Dict[str, Any] = dict() for key, value in self.__dict__.items(): if "_" in key: response_object[self.camelCase(key)] = value else: response_object[key[0].upper() + key[1:]] = value return response_object @property def response_object(self) -> Dict[str, Any]: # type: ignore[misc] return self.gen_response_object() class FakePipelineExecution(BaseObject): def __init__( self, pipeline_execution_arn: str, pipeline_execution_display_name: str, pipeline_parameters: List[Dict[str, str]], pipeline_execution_description: str, parallelism_configuration: Dict[str, int], pipeline_definition: str, client_request_token: str, ): self.arn = pipeline_execution_arn self.pipeline_execution_display_name = pipeline_execution_display_name self.pipeline_parameters = pipeline_parameters self.pipeline_execution_description = pipeline_execution_description self.pipeline_execution_status = "Succeeded" self.pipeline_execution_failure_reason = None self.parallelism_configuration = parallelism_configuration self.pipeline_definition_for_execution = pipeline_definition self.client_request_token = client_request_token now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.creation_time = now_string self.last_modified_time = now_string self.start_time = now_string fake_user_profile_name = "fake-user-profile-name" fake_domain_id = "fake-domain-id" fake_user_profile_arn = arn_formatter( "user-profile", f"{fake_domain_id}/{fake_user_profile_name}", pipeline_execution_arn.split(":")[4], pipeline_execution_arn.split(":")[3], ) self.created_by = { "UserProfileArn": fake_user_profile_arn, "UserProfileName": fake_user_profile_name, "DomainId": fake_domain_id, } self.last_modified_by = { "UserProfileArn": fake_user_profile_arn, "UserProfileName": fake_user_profile_name, "DomainId": fake_domain_id, } class FakePipeline(BaseObject): def __init__( self, pipeline_name: str, pipeline_display_name: str, pipeline_definition: str, pipeline_description: str, role_arn: str, tags: List[Dict[str, str]], account_id: str, region_name: str, parallelism_configuration: Dict[str, int], ): self.pipeline_name = pipeline_name self.arn = arn_formatter("pipeline", pipeline_name, account_id, region_name) self.pipeline_display_name = pipeline_display_name or pipeline_name self.pipeline_definition = pipeline_definition self.pipeline_description = pipeline_description self.pipeline_executions: Dict[str, FakePipelineExecution] = dict() self.role_arn = role_arn self.tags = tags or [] self.parallelism_configuration = parallelism_configuration now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.creation_time = now_string self.last_modified_time = now_string self.last_execution_time: Optional[str] = None self.pipeline_status = "Active" fake_user_profile_name = "fake-user-profile-name" fake_domain_id = "fake-domain-id" fake_user_profile_arn = arn_formatter( "user-profile", f"{fake_domain_id}/{fake_user_profile_name}", account_id, region_name, ) self.created_by = { "UserProfileArn": fake_user_profile_arn, "UserProfileName": fake_user_profile_name, "DomainId": fake_domain_id, } self.last_modified_by = { "UserProfileArn": fake_user_profile_arn, "UserProfileName": fake_user_profile_name, "DomainId": fake_domain_id, } class FakeProcessingJob(BaseObject): def __init__( self, app_specification: Dict[str, Any], experiment_config: Dict[str, str], network_config: Dict[str, Any], processing_inputs: List[Dict[str, Any]], processing_job_name: str, processing_output_config: Dict[str, Any], account_id: str, region_name: str, role_arn: str, tags: List[Dict[str, str]], stopping_condition: Dict[str, int], ): self.processing_job_name = processing_job_name self.arn = FakeProcessingJob.arn_formatter( processing_job_name, account_id, region_name ) now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.creation_time = now_string self.last_modified_time = now_string self.processing_end_time = now_string self.tags = tags or [] self.role_arn = role_arn self.app_specification = app_specification self.experiment_config = experiment_config self.network_config = network_config self.processing_inputs = processing_inputs self.processing_job_status = "Completed" self.processing_output_config = processing_output_config self.stopping_condition = stopping_condition @property def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() response = { k: v for k, v in response_object.items() if v is not None and v != [None] } response["ProcessingJobArn"] = response.pop("Arn") return response @property def response_create(self) -> Dict[str, str]: return {"ProcessingJobArn": self.arn} @staticmethod def arn_formatter(name: str, account_id: str, region: str) -> str: return arn_formatter("processing-job", name, account_id, region) class FakeTrainingJob(BaseObject): def __init__( self, account_id: str, region_name: str, training_job_name: str, hyper_parameters: Dict[str, str], algorithm_specification: Dict[str, Any], role_arn: str, input_data_config: List[Dict[str, Any]], output_data_config: Dict[str, str], resource_config: Dict[str, Any], vpc_config: Dict[str, List[str]], stopping_condition: Dict[str, int], tags: List[Dict[str, str]], enable_network_isolation: bool, enable_inter_container_traffic_encryption: bool, enable_managed_spot_training: bool, checkpoint_config: Dict[str, str], debug_hook_config: Dict[str, Any], debug_rule_configurations: List[Dict[str, Any]], tensor_board_output_config: Dict[str, str], experiment_config: Dict[str, str], ): self.training_job_name = training_job_name self.hyper_parameters = hyper_parameters self.algorithm_specification = algorithm_specification self.role_arn = role_arn self.input_data_config = input_data_config self.output_data_config = output_data_config self.resource_config = resource_config self.vpc_config = vpc_config self.stopping_condition = stopping_condition self.tags = tags or [] self.enable_network_isolation = enable_network_isolation self.enable_inter_container_traffic_encryption = ( enable_inter_container_traffic_encryption ) self.enable_managed_spot_training = enable_managed_spot_training self.checkpoint_config = checkpoint_config self.debug_hook_config = debug_hook_config self.debug_rule_configurations = debug_rule_configurations self.tensor_board_output_config = tensor_board_output_config self.experiment_config = experiment_config self.arn = FakeTrainingJob.arn_formatter( training_job_name, account_id, region_name ) self.creation_time = self.last_modified_time = datetime.now().strftime( "%Y-%m-%d %H:%M:%S" ) self.model_artifacts = { "S3ModelArtifacts": os.path.join( self.output_data_config["S3OutputPath"], self.training_job_name, "output", "model.tar.gz", ) } self.training_job_status = "Completed" self.secondary_status = "Completed" self.algorithm_specification["MetricDefinitions"] = [ { "Name": "test:dcg", "Regex": "#quality_metric: host=\\S+, test dcg <score>=(\\S+)", } ] now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.creation_time = now_string self.last_modified_time = now_string self.training_start_time = now_string self.training_end_time = now_string self.secondary_status_transitions = [ { "Status": "Starting", "StartTime": self.creation_time, "EndTime": self.creation_time, "StatusMessage": "Preparing the instances for training", } ] self.final_metric_data_list = [ { "MetricName": "train:progress", "Value": 100.0, "Timestamp": self.creation_time, } ] @property def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() response = { k: v for k, v in response_object.items() if v is not None and v != [None] } response["TrainingJobArn"] = response.pop("Arn") return response @property def response_create(self) -> Dict[str, str]: return {"TrainingJobArn": self.arn} @staticmethod def arn_formatter(name: str, account_id: str, region_name: str) -> str: return arn_formatter("training-job", name, account_id, region_name) class FakeEndpoint(BaseObject, CloudFormationModel): def __init__( self, account_id: str, region_name: str, endpoint_name: str, endpoint_config_name: str, production_variants: List[Dict[str, Any]], data_capture_config: Dict[str, Any], tags: List[Dict[str, str]], ): self.endpoint_name = endpoint_name self.arn = FakeEndpoint.arn_formatter(endpoint_name, account_id, region_name) self.endpoint_config_name = endpoint_config_name self.production_variants = self._process_production_variants( production_variants ) self.data_capture_config = data_capture_config self.tags = tags or [] self.endpoint_status = "InService" self.failure_reason = None self.creation_time = self.last_modified_time = datetime.now().strftime( "%Y-%m-%d %H:%M:%S" ) def _process_production_variants( self, production_variants: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: endpoint_variants = [] for production_variant in production_variants: temp_variant = {} # VariantName is the only required param temp_variant["VariantName"] = production_variant["VariantName"] if production_variant.get("InitialInstanceCount", None): temp_variant["CurrentInstanceCount"] = production_variant[ "InitialInstanceCount" ] temp_variant["DesiredInstanceCount"] = production_variant[ "InitialInstanceCount" ] if production_variant.get("InitialVariantWeight", None): temp_variant["CurrentWeight"] = production_variant[ "InitialVariantWeight" ] temp_variant["DesiredWeight"] = production_variant[ "InitialVariantWeight" ] if production_variant.get("ServerlessConfig", None): temp_variant["CurrentServerlessConfig"] = production_variant[ "ServerlessConfig" ] temp_variant["DesiredServerlessConfig"] = production_variant[ "ServerlessConfig" ] endpoint_variants.append(temp_variant) return endpoint_variants def summary(self) -> Dict[str, Any]: return { "EndpointName": self.endpoint_name, "EndpointArn": self.arn, "CreationTime": self.creation_time, "LastModifiedTime": self.last_modified_time, "EndpointStatus": self.endpoint_status, } @property def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() response = { k: v for k, v in response_object.items() if v is not None and v != [None] } response["EndpointArn"] = response.pop("Arn") return response @property def response_create(self) -> Dict[str, str]: return {"EndpointArn": self.arn} @staticmethod def arn_formatter(endpoint_name: str, account_id: str, region_name: str) -> str: return arn_formatter("endpoint", endpoint_name, account_id, region_name) @property def physical_resource_id(self) -> str: return self.arn @classmethod def has_cfn_attr(cls, attr: str) -> bool: return attr in ["EndpointName"] def get_cfn_attribute(self, attribute_name: str) -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpoint.html#aws-resource-sagemaker-endpoint-return-values from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "EndpointName": return self.endpoint_name raise UnformattedGetAttTemplateException() @staticmethod def cloudformation_name_type() -> str: return "" @staticmethod def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpoint.html return "AWS::SageMaker::Endpoint" @classmethod def create_from_cloudformation_json( # type: ignore[misc] cls, resource_name: str, cloudformation_json: Any, account_id: str, region_name: str, **kwargs: Any, ) -> "FakeEndpoint": sagemaker_backend = sagemaker_backends[account_id][region_name] # Get required properties from provided CloudFormation template properties = cloudformation_json["Properties"] endpoint_config_name = properties["EndpointConfigName"] endpoint = sagemaker_backend.create_endpoint( endpoint_name=resource_name, endpoint_config_name=endpoint_config_name, tags=properties.get("Tags", []), ) return endpoint @classmethod def update_from_cloudformation_json( # type: ignore[misc] cls, original_resource: Any, new_resource_name: str, cloudformation_json: Any, account_id: str, region_name: str, ) -> "FakeEndpoint": # Changes to the Endpoint will not change resource name cls.delete_from_cloudformation_json( original_resource.arn, cloudformation_json, account_id, region_name ) new_resource = cls.create_from_cloudformation_json( original_resource.endpoint_name, cloudformation_json, account_id, region_name, ) return new_resource @classmethod def delete_from_cloudformation_json( # type: ignore[misc] cls, resource_name: str, cloudformation_json: Any, account_id: str, region_name: str, ) -> None: # Get actual name because resource_name actually provides the ARN # since the Physical Resource ID is the ARN despite SageMaker # using the name for most of its operations. endpoint_name = resource_name.split("/")[-1] sagemaker_backends[account_id][region_name].delete_endpoint(endpoint_name) class FakeEndpointConfig(BaseObject, CloudFormationModel): def __init__( self, account_id: str, region_name: str, endpoint_config_name: str, production_variants: List[Dict[str, Any]], data_capture_config: Dict[str, Any], tags: List[Dict[str, Any]], kms_key_id: str, ): self.validate_production_variants(production_variants) self.endpoint_config_name = endpoint_config_name self.endpoint_config_arn = FakeEndpointConfig.arn_formatter( endpoint_config_name, account_id, region_name ) self.arn = (self.endpoint_config_arn,) self.production_variants = production_variants or [] self.data_capture_config = data_capture_config or {} self.tags = tags or [] self.kms_key_id = kms_key_id self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") def validate_production_variants( self, production_variants: List[Dict[str, Any]] ) -> None: for production_variant in production_variants: if "InstanceType" in production_variant.keys(): self.validate_instance_type(production_variant["InstanceType"]) elif "ServerlessConfig" in production_variant.keys(): self.validate_serverless_config(production_variant["ServerlessConfig"]) else: message = f"Invalid Keys for ProductionVariant: received {production_variant.keys()} but expected it to contain one of {['InstanceType', 'ServerlessConfig']}" raise ValidationError(message=message) def validate_serverless_config(self, serverless_config: Dict[str, Any]) -> None: VALID_SERVERLESS_MEMORY_SIZE = [1024, 2048, 3072, 4096, 5120, 6144] if not validators.is_one_of( serverless_config["MemorySizeInMB"], VALID_SERVERLESS_MEMORY_SIZE ): message = f"Value '{serverless_config['MemorySizeInMB']}' at 'MemorySizeInMB' failed to satisfy constraint: Member must satisfy enum value set: {VALID_SERVERLESS_MEMORY_SIZE}" raise ValidationError(message=message) def validate_instance_type(self, instance_type: str) -> None: VALID_INSTANCE_TYPES = [ "ml.r5d.12xlarge", "ml.r5.12xlarge", "ml.p2.xlarge", "ml.m5.4xlarge", "ml.m4.16xlarge", "ml.r5d.24xlarge", "ml.r5.24xlarge", "ml.p3.16xlarge", "ml.m5d.xlarge", "ml.m5.large", "ml.t2.xlarge", "ml.p2.16xlarge", "ml.m5d.12xlarge", "ml.inf1.2xlarge", "ml.m5d.24xlarge", "ml.c4.2xlarge", "ml.c5.2xlarge", "ml.c4.4xlarge", "ml.inf1.6xlarge", "ml.c5d.2xlarge", "ml.c5.4xlarge", "ml.g4dn.xlarge", "ml.g4dn.12xlarge", "ml.c5d.4xlarge", "ml.g4dn.2xlarge", "ml.c4.8xlarge", "ml.c4.large", "ml.c5d.xlarge", "ml.c5.large", "ml.g4dn.4xlarge", "ml.c5.9xlarge", "ml.g4dn.16xlarge", "ml.c5d.large", "ml.c5.xlarge", "ml.c5d.9xlarge", "ml.c4.xlarge", "ml.inf1.xlarge", "ml.g4dn.8xlarge", "ml.inf1.24xlarge", "ml.m5d.2xlarge", "ml.t2.2xlarge", "ml.c5d.18xlarge", "ml.m5d.4xlarge", "ml.t2.medium", "ml.c5.18xlarge", "ml.r5d.2xlarge", "ml.r5.2xlarge", "ml.p3.2xlarge", "ml.m5d.large", "ml.m5.xlarge", "ml.m4.10xlarge", "ml.t2.large", "ml.r5d.4xlarge", "ml.r5.4xlarge", "ml.m5.12xlarge", "ml.m4.xlarge", "ml.m5.24xlarge", "ml.m4.2xlarge", "ml.p2.8xlarge", "ml.m5.2xlarge", "ml.r5d.xlarge", "ml.r5d.large", "ml.r5.xlarge", "ml.r5.large", "ml.p3.8xlarge", "ml.m4.4xlarge", ] if not validators.is_one_of(instance_type, VALID_INSTANCE_TYPES): message = f"Value '{instance_type}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: {VALID_INSTANCE_TYPES}" raise ValidationError(message=message) def summary(self) -> Dict[str, Any]: return { "EndpointConfigName": self.endpoint_config_name, "EndpointConfigArn": self.endpoint_config_arn, "CreationTime": self.creation_time, } @property def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() return { k: v for k, v in response_object.items() if v is not None and v != [None] } @property def response_create(self) -> Dict[str, str]: return {"EndpointConfigArn": self.endpoint_config_arn} @staticmethod def arn_formatter( endpoint_config_name: str, account_id: str, region_name: str ) -> str: return arn_formatter( "endpoint-config", endpoint_config_name, account_id, region_name ) @property def physical_resource_id(self) -> str: return self.endpoint_config_arn @classmethod def has_cfn_attr(cls, attr: str) -> bool: return attr in ["EndpointConfigName"] def get_cfn_attribute(self, attribute_name: str) -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpointconfig.html#aws-resource-sagemaker-endpointconfig-return-values from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "EndpointConfigName": return self.endpoint_config_name raise UnformattedGetAttTemplateException() @staticmethod def cloudformation_name_type() -> str: return "" @staticmethod def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpointconfig.html return "AWS::SageMaker::EndpointConfig" @classmethod def create_from_cloudformation_json( # type: ignore[misc] cls, resource_name: str, cloudformation_json: Any, account_id: str, region_name: str, **kwargs: Any, ) -> "FakeEndpointConfig": sagemaker_backend = sagemaker_backends[account_id][region_name] # Get required properties from provided CloudFormation template properties = cloudformation_json["Properties"] production_variants = properties["ProductionVariants"] endpoint_config = sagemaker_backend.create_endpoint_config( endpoint_config_name=resource_name, production_variants=production_variants, data_capture_config=properties.get("DataCaptureConfig", {}), kms_key_id=properties.get("KmsKeyId"), tags=properties.get("Tags", []), ) return endpoint_config @classmethod def update_from_cloudformation_json( # type: ignore[misc] cls, original_resource: Any, new_resource_name: str, cloudformation_json: Any, account_id: str, region_name: str, ) -> "FakeEndpointConfig": # Most changes to the endpoint config will change resource name for EndpointConfigs cls.delete_from_cloudformation_json( original_resource.endpoint_config_arn, cloudformation_json, account_id, region_name, ) new_resource = cls.create_from_cloudformation_json( new_resource_name, cloudformation_json, account_id, region_name ) return new_resource @classmethod def delete_from_cloudformation_json( # type: ignore[misc] cls, resource_name: str, cloudformation_json: Any, account_id: str, region_name: str, ) -> None: # Get actual name because resource_name actually provides the ARN # since the Physical Resource ID is the ARN despite SageMaker # using the name for most of its operations. endpoint_config_name = resource_name.split("/")[-1] sagemaker_backends[account_id][region_name].delete_endpoint_config( endpoint_config_name ) class FakeTransformJob(BaseObject): def __init__( self, account_id: str, region_name: str, transform_job_name: str, model_name: str, max_concurrent_transforms: int, model_client_config: Dict[str, int], max_payload_in_mb: int, batch_strategy: str, environment: Dict[str, str], transform_input: Dict[str, Union[Dict[str, str], str]], transform_output: Dict[str, str], data_capture_config: Dict[str, Union[str, bool]], transform_resources: Dict[str, Union[str, int]], data_processing: Dict[str, str], tags: Dict[str, str], experiment_config: Dict[str, str], ): self.transform_job_name = transform_job_name self.model_name = model_name self.max_concurrent_transforms = max_concurrent_transforms self.model_client_config = model_client_config self.max_payload_in_mb = max_payload_in_mb self.batch_strategy = batch_strategy self.environment = environment self.transform_input = transform_input self.transform_output = transform_output self.data_capture_config = data_capture_config self.transform_resources = transform_resources self.data_processing = data_processing self.tags = tags self.experiment_config = experiment_config self.arn = FakeTransformJob.arn_formatter( transform_job_name, account_id, region_name ) self.transform_job_status = "Completed" self.failure_reason = "" self.labeling_job_arn = "" self.auto_ml_job_arn = "" now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.creation_time = now_string self.transform_start_time = now_string self.transform_end_time = now_string self.last_modified_time = now_string # Override title case def camelCase(self, key: str) -> str: words = [] for word in key.split("_"): if word == "mb": words.append("MB") else: words.append(word.title()) return "".join(words) @property def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() response = { k: v for k, v in response_object.items() if v is not None and v != [None] } return response @property def response_create(self) -> Dict[str, str]: return {"TransformJobArn": self.arn} @staticmethod def arn_formatter(name: str, account_id: str, region_name: str) -> str: return arn_formatter("transform-job", name, account_id, region_name) class Model(BaseObject, CloudFormationModel): def __init__( self, account_id: str, region_name: str, model_name: str, execution_role_arn: str, primary_container: Dict[str, Any], vpc_config: Dict[str, Any], containers: Optional[List[Dict[str, Any]]] = None, tags: Optional[List[Dict[str, str]]] = None, ): self.model_name = model_name self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.containers = containers or [] self.tags = tags or [] self.enable_network_isolation = False self.vpc_config = vpc_config self.primary_container = primary_container self.execution_role_arn = execution_role_arn or "arn:test" self.arn = arn_formatter("model", self.model_name, account_id, region_name) @property def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() response = { k: v for k, v in response_object.items() if v is not None and v != [None] } response["ModelArn"] = response.pop("Arn") return response @property def response_create(self) -> Dict[str, str]: return {"ModelArn": self.arn} @property def physical_resource_id(self) -> str: return self.arn @classmethod def has_cfn_attr(cls, attr: str) -> bool: return attr in ["ModelName"] def get_cfn_attribute(self, attribute_name: str) -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-model.html#aws-resource-sagemaker-model-return-values from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "ModelName": return self.model_name raise UnformattedGetAttTemplateException() @staticmethod def cloudformation_name_type() -> str: return "" @staticmethod def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-model.html return "AWS::SageMaker::Model" @classmethod def create_from_cloudformation_json( # type: ignore[misc] cls, resource_name: str, cloudformation_json: Any, account_id: str, region_name: str, **kwargs: Any, ) -> "Model": sagemaker_backend = sagemaker_backends[account_id][region_name] # Get required properties from provided CloudFormation template properties = cloudformation_json["Properties"] execution_role_arn = properties["ExecutionRoleArn"] primary_container = properties["PrimaryContainer"] model = sagemaker_backend.create_model( model_name=resource_name, execution_role_arn=execution_role_arn, primary_container=primary_container, vpc_config=properties.get("VpcConfig", {}), containers=properties.get("Containers", []), tags=properties.get("Tags", []), ) return model @classmethod def update_from_cloudformation_json( # type: ignore[misc] cls, original_resource: Any, new_resource_name: str, cloudformation_json: Any, account_id: str, region_name: str, ) -> "Model": # Most changes to the model will change resource name for Models cls.delete_from_cloudformation_json( original_resource.arn, cloudformation_json, account_id, region_name ) new_resource = cls.create_from_cloudformation_json( new_resource_name, cloudformation_json, account_id, region_name ) return new_resource @classmethod def delete_from_cloudformation_json( # type: ignore[misc] cls, resource_name: str, cloudformation_json: Any, account_id: str, region_name: str, ) -> None: # Get actual name because resource_name actually provides the ARN # since the Physical Resource ID is the ARN despite SageMaker # using the name for most of its operations. model_name = resource_name.split("/")[-1] sagemaker_backends[account_id][region_name].delete_model(model_name) class ModelPackageGroup(BaseObject): def __init__( self, model_package_group_name: str, model_package_group_description: str, account_id: str, region_name: str, tags: Optional[List[Dict[str, str]]] = None, ) -> None: model_package_group_arn = arn_formatter( region_name=region_name, account_id=account_id, _type="model-package-group", _id=model_package_group_name, ) fake_user_profile_name = "fake-user-profile-name" fake_domain_id = "fake-domain-id" fake_user_profile_arn = arn_formatter( _type="user-profile", _id=f"{fake_domain_id}/{fake_user_profile_name}", account_id=account_id, region_name=region_name, ) datetime_now = datetime.now(tzutc()) self.model_package_group_name = model_package_group_name self.arn = model_package_group_arn self.model_package_group_description = model_package_group_description self.creation_time = datetime_now self.created_by = { "UserProfileArn": fake_user_profile_arn, "UserProfileName": fake_user_profile_name, "DomainId": fake_domain_id, } self.model_package_group_status = "Completed" self.tags = tags def gen_response_object(self) -> Dict[str, Any]: response_object = super().gen_response_object() for k, v in response_object.items(): if isinstance(v, datetime): response_object[k] = v.isoformat() response_values = [ "ModelPackageGroupName", "Arn", "ModelPackageGroupDescription", "CreationTime", "ModelPackageGroupStatus", "Tags", ] response = {k: v for k, v in response_object.items() if k in response_values} response["ModelPackageGroupArn"] = response.pop("Arn") return response class FakeModelCard(BaseObject): def __init__( self, account_id: str, region_name: str, model_card_name: str, model_card_version: int, content: str, model_card_status: str, security_config: Optional[Dict[str, str]] = None, tags: Optional[List[Dict[str, Any]]] = None, creation_time: Optional[str] = None, last_modified_time: Optional[str] = None, ) -> None: datetime_now = str(datetime.now(tzutc())) self.arn = arn_formatter("model-card", model_card_name, account_id, region_name) self.model_card_name = model_card_name self.model_card_version = model_card_version self.content = content self.model_card_status = model_card_status self.creation_time = creation_time if creation_time else datetime_now self.last_modified_time = ( last_modified_time if last_modified_time else datetime_now ) self.security_config = security_config self.tags = tags def describe(self) -> Dict[str, Any]: return { "ModelCardArn": self.arn, "ModelCardName": self.model_card_name, "ModelCardVersion": self.model_card_version, "Content": self.content, "ModelCardStatus": self.model_card_status, "SecurityConfig": self.security_config, "CreationTime": self.creation_time, "CreatedBy": {}, "LastModifiedTime": self.creation_time, "LastModifiedBy": {}, } def summary(self) -> Dict[str, Any]: return { "ModelCardName": self.model_card_name, "ModelCardArn": self.arn, "ModelCardStatus": self.model_card_status, "CreationTime": self.creation_time, "LastModifiedTime": self.last_modified_time, } def version_summary(self) -> Dict[str, Any]: return { "ModelCardName": self.model_card_name, "ModelCardArn": self.arn, "ModelCardStatus": self.model_card_status, "ModelCardVersion": self.model_card_version, "CreationTime": self.creation_time, "LastModifiedTime": self.last_modified_time, } class FeatureGroup(BaseObject): def __init__( self, region_name: str, account_id: str, feature_group_name: str, record_identifier_feature_name: str, event_time_feature_name: str, feature_definitions: List[Dict[str, str]], offline_store_config: Dict[str, Any], role_arn: str, tags: Optional[List[Dict[str, str]]] = None, ) -> None: self.feature_group_name = feature_group_name self.record_identifier_feature_name = record_identifier_feature_name self.event_time_feature_name = event_time_feature_name self.feature_definitions = feature_definitions table_name = ( f"{feature_group_name.replace('-','_')}_{int(datetime.now().timestamp())}" ) offline_store_config["DataCatalogConfig"] = { "TableName": table_name, "Catalog": "AwsDataCatalog", "Database": "sagemaker_featurestore", } offline_store_config["S3StorageConfig"]["ResolvedOutputS3Uri"] = ( f'{offline_store_config["S3StorageConfig"]["S3Uri"]}/{account_id}/{region_name}/offline-store/{feature_group_name}-{int(datetime.now().timestamp())}/data' ) self.offline_store_config = offline_store_config self.role_arn = role_arn self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.arn = arn_formatter( region_name=region_name, account_id=account_id, _type="feature-group", _id=f"{self.feature_group_name.lower()}", ) self.tags = tags def describe(self) -> Dict[str, Any]: return { "FeatureGroupArn": self.arn, "FeatureGroupName": self.feature_group_name, "RecordIdentifierFeatureName": self.record_identifier_feature_name, "EventTimeFeatureName": self.event_time_feature_name, "FeatureDefinitions": self.feature_definitions, "CreationTime": self.creation_time, "OfflineStoreConfig": self.offline_store_config, "RoleArn": self.role_arn, "ThroughputConfig": {"ThroughputMode": "OnDemand"}, "FeatureGroupStatus": "Created", } class ModelPackage(BaseObject): def __init__( self, model_package_name: str, model_package_group_name: Optional[str], model_package_version: Optional[int], model_package_description: Optional[str], inference_specification: Any, source_algorithm_specification: Any, validation_specification: Any, certify_for_marketplace: bool, model_approval_status: Optional[str], metadata_properties: Any, model_metrics: Any, approval_description: Optional[str], customer_metadata_properties: Any, drift_check_baselines: Any, domain: str, task: str, sample_payload_url: str, additional_inference_specifications: List[Any], client_token: str, region_name: str, account_id: str, model_package_type: str, tags: Optional[List[Dict[str, str]]] = None, ) -> None: fake_user_profile_name = "fake-user-profile-name" fake_domain_id = "fake-domain-id" fake_user_profile_arn = arn_formatter( _type="user-profile", _id=f"{fake_domain_id}/{fake_user_profile_name}", account_id=account_id, region_name=region_name, ) model_package_arn = arn_formatter( region_name=region_name, account_id=account_id, _type="model-package", _id=f"{model_package_name.lower()}/{model_package_version}" if model_package_version else model_package_name.lower(), ) datetime_now = datetime.now(tzutc()) self.model_package_name = model_package_name self.model_package_group_name = model_package_group_name self.model_package_version = model_package_version self.arn = model_package_arn self.model_package_description = model_package_description self.creation_time = datetime_now self.inference_specification = inference_specification self.source_algorithm_specification = source_algorithm_specification self.validation_specification = validation_specification self.model_package_type = model_package_type self.model_package_status_details = { "ValidationStatuses": [ { "Name": model_package_arn, "Status": "Completed", } ], "ImageScanStatuses": [ { "Name": model_package_arn, "Status": "Completed", } ], } self.certify_for_marketplace = certify_for_marketplace self.model_approval_status: Optional[str] = None self.set_model_approval_status(model_approval_status) self.created_by = { "UserProfileArn": fake_user_profile_arn, "UserProfileName": fake_user_profile_name, "DomainId": fake_domain_id, } self.metadata_properties = metadata_properties self.model_metrics = model_metrics self.last_modified_time: Optional[datetime] = None self.approval_description = approval_description self.customer_metadata_properties = customer_metadata_properties self.drift_check_baselines = drift_check_baselines self.domain = domain self.task = task self.sample_payload_url = sample_payload_url self.additional_inference_specifications: Optional[List[Any]] = None self.add_additional_inference_specifications( additional_inference_specifications ) self.tags = tags self.model_package_status = "Completed" self.last_modified_by: Optional[Dict[str, str]] = None self.client_token = client_token def gen_response_object(self) -> Dict[str, Any]: response_object = super().gen_response_object() for k, v in response_object.items(): if isinstance(v, datetime): response_object[k] = v.isoformat() response_values = [ "ModelPackageName", "ModelPackageGroupName", "ModelPackageVersion", "Arn", "ModelPackageDescription", "CreationTime", "InferenceSpecification", "SourceAlgorithmSpecification", "ValidationSpecification", "ModelPackageStatus", "ModelPackageStatusDetails", "CertifyForMarketplace", "ModelApprovalStatus", "CreatedBy", "MetadataProperties", "ModelMetrics", "LastModifiedTime", "LastModifiedBy", "ApprovalDescription", "CustomerMetadataProperties", "DriftCheckBaselines", "Domain", "Task", "SamplePayloadUrl", "AdditionalInferenceSpecifications", "SkipModelValidation", ] if self.model_package_type == "Versioned": del response_object["ModelPackageName"] elif self.model_package_type == "Unversioned": del response_object["ModelPackageGroupName"] response = { k: v for k, v in response_object.items() if k in response_values if v is not None } response["ModelPackageArn"] = response.pop("Arn") return response def modifications_done(self) -> None: self.last_modified_time = datetime.now(tzutc()) self.last_modified_by = self.created_by def set_model_approval_status(self, model_approval_status: Optional[str]) -> None: if model_approval_status is not None: validate_model_approval_status(model_approval_status) self.model_approval_status = model_approval_status def remove_customer_metadata_property( self, customer_metadata_properties_to_remove: List[str] ) -> None: if customer_metadata_properties_to_remove is not None: for customer_metadata_property in customer_metadata_properties_to_remove: self.customer_metadata_properties.pop(customer_metadata_property, None) def add_additional_inference_specifications( self, additional_inference_specifications_to_add: Optional[List[Any]] ) -> None: self.validate_additional_inference_specifications( additional_inference_specifications_to_add ) if ( self.additional_inference_specifications is not None and additional_inference_specifications_to_add is not None ): self.additional_inference_specifications.extend( additional_inference_specifications_to_add ) else: self.additional_inference_specifications = ( additional_inference_specifications_to_add ) def validate_additional_inference_specifications( self, additional_inference_specifications: Optional[List[Dict[str, Any]]] ) -> None: specifications_to_validate = additional_inference_specifications or [] for additional_inference_specification in specifications_to_validate: if "SupportedTransformInstanceTypes" in additional_inference_specification: self.validate_supported_transform_instance_types( additional_inference_specification[ "SupportedTransformInstanceTypes" ] ) if ( "SupportedRealtimeInferenceInstanceTypes" in additional_inference_specification ): self.validate_supported_realtime_inference_instance_types( additional_inference_specification[ "SupportedRealtimeInferenceInstanceTypes" ] ) @staticmethod def validate_supported_transform_instance_types(instance_types: List[str]) -> None: VALID_TRANSFORM_INSTANCE_TYPES = [ "ml.m4.xlarge", "ml.m4.2xlarge", "ml.m4.4xlarge", "ml.m4.10xlarge", "ml.m4.16xlarge", "ml.c4.xlarge", "ml.c4.2xlarge", "ml.c4.4xlarge", "ml.c4.8xlarge", "ml.p2.xlarge", "ml.p2.8xlarge", "ml.p2.16xlarge", "ml.p3.2xlarge", "ml.p3.8xlarge", "ml.p3.16xlarge", "ml.c5.xlarge", "ml.c5.2xlarge", "ml.c5.4xlarge", "ml.c5.9xlarge", "ml.c5.18xlarge", "ml.m5.large", "ml.m5.xlarge", "ml.m5.2xlarge", "ml.m5.4xlarge", "ml.m5.12xlarge", "ml.m5.24xlarge", "ml.g4dn.xlarge", "ml.g4dn.2xlarge", "ml.g4dn.4xlarge", "ml.g4dn.8xlarge", "ml.g4dn.12xlarge", "ml.g4dn.16xlarge", ] for instance_type in instance_types: if not validators.is_one_of(instance_type, VALID_TRANSFORM_INSTANCE_TYPES): message = f"Value '{instance_type}' at 'SupportedTransformInstanceTypes' failed to satisfy constraint: Member must satisfy enum value set: {VALID_TRANSFORM_INSTANCE_TYPES}" raise ValidationError(message=message) @staticmethod def validate_supported_realtime_inference_instance_types( instance_types: List[str], ) -> None: VALID_REALTIME_INFERENCE_INSTANCE_TYPES = [ "ml.t2.medium", "ml.t2.large", "ml.t2.xlarge", "ml.t2.2xlarge", "ml.m4.xlarge", "ml.m4.2xlarge", "ml.m4.4xlarge", "ml.m4.10xlarge", "ml.m4.16xlarge", "ml.m5.large", "ml.m5.xlarge", "ml.m5.2xlarge", "ml.m5.4xlarge", "ml.m5.12xlarge", "ml.m5.24xlarge", "ml.m5d.large", "ml.m5d.xlarge", "ml.m5d.2xlarge", "ml.m5d.4xlarge", "ml.m5d.12xlarge", "ml.m5d.24xlarge", "ml.c4.large", "ml.c4.xlarge", "ml.c4.2xlarge", "ml.c4.4xlarge", "ml.c4.8xlarge", "ml.p2.xlarge", "ml.p2.8xlarge", "ml.p2.16xlarge", "ml.p3.2xlarge", "ml.p3.8xlarge", "ml.p3.16xlarge", "ml.c5.large", "ml.c5.xlarge", "ml.c5.2xlarge", "ml.c5.4xlarge", "ml.c5.9xlarge", "ml.c5.18xlarge", "ml.c5d.large", "ml.c5d.xlarge", "ml.c5d.2xlarge", "ml.c5d.4xlarge", "ml.c5d.9xlarge", "ml.c5d.18xlarge", "ml.g4dn.xlarge", "ml.g4dn.2xlarge", "ml.g4dn.4xlarge", "ml.g4dn.8xlarge", "ml.g4dn.12xlarge", "ml.g4dn.16xlarge", "ml.r5.large", "ml.r5.xlarge", "ml.r5.2xlarge", "ml.r5.4xlarge", "ml.r5.12xlarge", "ml.r5.24xlarge", "ml.r5d.large", "ml.r5d.xlarge", "ml.r5d.2xlarge", "ml.r5d.4xlarge", "ml.r5d.12xlarge", "ml.r5d.24xlarge", "ml.inf1.xlarge", "ml.inf1.2xlarge", "ml.inf1.6xlarge", "ml.inf1.24xlarge", "ml.c6i.large", "ml.c6i.xlarge", "ml.c6i.2xlarge", "ml.c6i.4xlarge", "ml.c6i.8xlarge", "ml.c6i.12xlarge", "ml.c6i.16xlarge", "ml.c6i.24xlarge", "ml.c6i.32xlarge", "ml.g5.xlarge", "ml.g5.2xlarge", "ml.g5.4xlarge", "ml.g5.8xlarge", "ml.g5.12xlarge", "ml.g5.16xlarge", "ml.g5.24xlarge", "ml.g5.48xlarge", "ml.p4d.24xlarge", "ml.c7g.large", "ml.c7g.xlarge", "ml.c7g.2xlarge", "ml.c7g.4xlarge", "ml.c7g.8xlarge", "ml.c7g.12xlarge", "ml.c7g.16xlarge", "ml.m6g.large", "ml.m6g.xlarge", "ml.m6g.2xlarge", "ml.m6g.4xlarge", "ml.m6g.8xlarge", "ml.m6g.12xlarge", "ml.m6g.16xlarge", "ml.m6gd.large", "ml.m6gd.xlarge", "ml.m6gd.2xlarge", "ml.m6gd.4xlarge", "ml.m6gd.8xlarge", "ml.m6gd.12xlarge", "ml.m6gd.16xlarge", "ml.c6g.large", "ml.c6g.xlarge", "ml.c6g.2xlarge", "ml.c6g.4xlarge", "ml.c6g.8xlarge", "ml.c6g.12xlarge", "ml.c6g.16xlarge", "ml.c6gd.large", "ml.c6gd.xlarge", "ml.c6gd.2xlarge", "ml.c6gd.4xlarge", "ml.c6gd.8xlarge", "ml.c6gd.12xlarge", "ml.c6gd.16xlarge", "ml.c6gn.large", "ml.c6gn.xlarge", "ml.c6gn.2xlarge", "ml.c6gn.4xlarge", "ml.c6gn.8xlarge", "ml.c6gn.12xlarge", "ml.c6gn.16xlarge", "ml.r6g.large", "ml.r6g.xlarge", "ml.r6g.2xlarge", "ml.r6g.4xlarge", "ml.r6g.8xlarge", "ml.r6g.12xlarge", "ml.r6g.16xlarge", "ml.r6gd.large", "ml.r6gd.xlarge", "ml.r6gd.2xlarge", "ml.r6gd.4xlarge", "ml.r6gd.8xlarge", "ml.r6gd.12xlarge", "ml.r6gd.16xlarge", "ml.p4de.24xlarge", "ml.trn1.2xlarge", "ml.trn1.32xlarge", "ml.inf2.xlarge", "ml.inf2.8xlarge", "ml.inf2.24xlarge", "ml.inf2.48xlarge", "ml.p5.48xlarge", ] for instance_type in instance_types: if not validators.is_one_of( instance_type, VALID_REALTIME_INFERENCE_INSTANCE_TYPES ): message = f"Value '{instance_type}' at 'SupportedRealtimeInferenceInstanceTypes' failed to satisfy constraint: Member must satisfy enum value set: {VALID_REALTIME_INFERENCE_INSTANCE_TYPES}" raise ValidationError(message=message) class Cluster(BaseObject): def __init__( self, cluster_name: str, region_name: str, account_id: str, instance_groups: List[Dict[str, Any]], vpc_config: Dict[str, List[str]], tags: Optional[List[Dict[str, str]]] = None, ): self.region_name = region_name self.account_id = account_id self.cluster_name = cluster_name if cluster_name in sagemaker_backends[account_id][region_name].clusters: raise ResourceInUseException( message=f"Resource Already Exists: Cluster with name {cluster_name} already exists. Choose a different name." ) self.instance_groups = instance_groups for instance_group in instance_groups: self.valid_cluster_node_instance_types(instance_group["InstanceType"]) if not instance_group["LifeCycleConfig"]["SourceS3Uri"].startswith( "s3://sagemaker-" ): raise ValidationError( message=f"Validation Error: SourceS3Uri {instance_group['LifeCycleConfig']['SourceS3Uri']} does not start with 's3://sagemaker'." ) self.vpc_config = vpc_config self.tags = tags or [] self.arn = arn_formatter("cluster", self.cluster_name, account_id, region_name) self.status = "InService" self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.failure_message = "" self.nodes: Dict[str, ClusterNode] = {} for instance_group in self.instance_groups: instance_group["CurrentCount"] = instance_group["InstanceCount"] instance_group["TargetCount"] = instance_group["InstanceCount"] del instance_group["InstanceCount"] def describe(self) -> Dict[str, Any]: return { "ClusterArn": self.arn, "ClusterName": self.cluster_name, "ClusterStatus": self.status, "CreationTime": self.creation_time, "FailureMessage": self.failure_message, "InstanceGroups": self.instance_groups, "VpcConfig": self.vpc_config, } def summary(self) -> Dict[str, Any]: return { "ClusterArn": self.arn, "ClusterName": self.cluster_name, "CreationTime": self.creation_time, "ClusterStatus": self.status, } def valid_cluster_node_instance_types(self, instance_type: str) -> None: VALID_CLUSTER_INSTANCE_TYPES = [ "ml.p4d.24xlarge", "ml.p4de.24xlarge", "ml.p5.48xlarge", "ml.trn1.32xlarge", "ml.trn1n.32xlarge", "ml.g5.xlarge", "ml.g5.2xlarge", "ml.g5.4xlarge", "ml.g5.8xlarge", "ml.g5.12xlarge", "ml.g5.16xlarge", "ml.g5.24xlarge", "ml.g5.48xlarge", "ml.c5.large", "ml.c5.xlarge", "ml.c5.2xlarge", "ml.c5.4xlarge", "ml.c5.9xlarge", "ml.c5.12xlarge", "ml.c5.18xlarge", "ml.c5.24xlarge", "ml.c5n.large", "ml.c5n.2xlarge", "ml.c5n.4xlarge", "ml.c5n.9xlarge", "ml.c5n.18xlarge", "ml.m5.large", "ml.m5.xlarge", "ml.m5.2xlarge", "ml.m5.4xlarge", "ml.m5.8xlarge", "ml.m5.12xlarge", "ml.m5.16xlarge", "ml.m5.24xlarge", "ml.t3.medium", "ml.t3.large", "ml.t3.xlarge", "ml.t3.2xlarge", ] if instance_type not in VALID_CLUSTER_INSTANCE_TYPES: message = f"Value '{instance_type}' at 'InstanceType' failed to satisfy constraint: Member must satisfy enum value set: {VALID_CLUSTER_INSTANCE_TYPES}" raise ValidationError(message=message) class ClusterNode(BaseObject): def __init__( self, region_name: str, account_id: str, cluster_name: str, instance_group_name: str, instance_type: str, life_cycle_config: Dict[str, Any], execution_role: str, node_id: str, threads_per_core: Optional[int] = None, ): self.region_name = region_name self.account_id = account_id self.cluster_name = cluster_name self.instance_group_name = ( instance_group_name # probably need to do something with this ) self.instance_id = node_id # generate instance id self.instance_type = instance_type self.life_cycle_config = life_cycle_config self.execution_role = execution_role self.threads_per_core = threads_per_core self.status = "Running" self.launch_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") def describe(self) -> Dict[str, Any]: return { "InstanceGroupName": self.instance_group_name, "InstanceId": self.instance_id, "InstanceStatus": {"Status": self.status, "Message": "message"}, "InstanceType": self.instance_type, "LaunchTime": self.launch_time, "LifeCycleConfig": self.life_cycle_config, "ThreadsPerCore": self.threads_per_core, } def summary(self) -> Dict[str, Any]: return { "InstanceGroupName": self.instance_group_name, "InstanceId": self.instance_id, "InstanceType": self.instance_type, "LaunchTime": self.launch_time, "InstanceStatus": {"Status": self.status, "Message": "message"}, } class CompilationJob(BaseObject): def __init__( self, compilation_job_name: str, role_arn: str, region_name: str, account_id: str, output_config: Dict[str, Any], stopping_condition: Dict[str, Any], model_package_version_arn: Optional[str], input_config: Optional[Dict[str, Any]], vpc_config: Optional[Dict[str, Any]], tags: Optional[List[Dict[str, str]]], ): self.compilation_job_name = compilation_job_name if ( compilation_job_name in sagemaker_backends[account_id][region_name].compilation_jobs ): raise ResourceInUseException( message=f"Resource Already Exists: Compilation job with name {compilation_job_name} already exists. Choose a different name." ) self.arn = arn_formatter( "compilation-job", self.compilation_job_name, account_id, region_name ) self.compilation_job_status = "COMPLETED" self.compilation_start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.compilation_end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.stopping_condition = stopping_condition self.inference_image = "InferenceImage" self.model_package_version_arn = model_package_version_arn self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.last_modified_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.failure_reason = "" self.model_artifacts = {"S3ModelArtifacts": output_config["S3OutputLocation"]} self.model_digests = { "ArtifactDigest": "786a02f742015903c6c6fd852552d272912f4740e15847618a86e217f71f5419d25e1031afee585313896444934eb04b903a685b1448b755d56f701afe9be2ce" } self.role_arn = role_arn self.input_config = input_config if input_config and model_package_version_arn: raise ValidationError( message="InputConfig and ModelPackageVersionArn cannot be specified at the same time." ) if not input_config and not model_package_version_arn: raise ValidationError( message="Either InputConfig or ModelPackageVersionArn must be specified." ) self.output_config = output_config self.vpc_config = vpc_config self.derived_information = {"DerivedDataInputConfig": "DerivedDataInputConfig"} self.tags = tags def describe(self) -> Dict[str, Any]: return { "CompilationJobName": self.compilation_job_name, "CompilationJobArn": self.arn, "CompilationJobStatus": self.compilation_job_status, "CompilationStartTime": self.compilation_start_time, "CompilationEndTime": self.compilation_end_time, "StoppingCondition": self.stopping_condition, "InferenceImage": self.inference_image, "ModelPackageVersionArn": self.model_package_version_arn, "CreationTime": self.creation_time, "LastModifiedTime": self.last_modified_time, "FailureReason": self.failure_reason, "ModelArtifacts": self.model_artifacts, "ModelDigests": self.model_digests, "RoleArn": self.role_arn, "InputConfig": self.input_config, "OutputConfig": self.output_config, "VpcConfig": self.vpc_config, "DerivedInformation": self.derived_information, } def summary(self) -> Dict[str, Any]: summary = { "CompilationJobName": self.compilation_job_name, "CompilationJobArn": self.arn, "CreationTime": self.creation_time, "CompilationStartTime": self.compilation_start_time, "CompilationEndTime": self.compilation_end_time, "LastModifiedTime": self.last_modified_time, "CompilationJobStatus": self.compilation_job_status, } if "TargetDevice" in self.output_config: summary["CompilationTargetDevice"] = self.output_config["TargetDevice"] else: summary["CompilationTargetPlatformOs"] = self.output_config[ "TargetPlatform" ]["Os"] summary["CompilationTargetPlatformArch"] = self.output_config[ "TargetPlatform" ]["Arch"] summary["CompilationTargetPlatformAccelerator"] = self.output_config[ "TargetPlatform" ]["Accelerator"] return summary class AutoMLJob(BaseObject): def __init__( self, auto_ml_job_name: str, auto_ml_job_input_data_config: List[Dict[str, Any]], output_data_config: Dict[str, Any], auto_ml_problem_type_config: Dict[str, Any], role_arn: str, region_name: str, account_id: str, security_config: Optional[Dict[str, Any]], auto_ml_job_objective: Optional[Dict[str, Any]], model_deploy_config: Optional[Dict[str, Any]], data_split_config: Optional[Dict[str, Any]], tags: Optional[List[Dict[str, str]]] = None, ): self.region_name = region_name self.account_id = account_id self.auto_ml_job_name = auto_ml_job_name if auto_ml_job_name in sagemaker_backends[account_id][region_name].auto_ml_jobs: raise ResourceInUseException( message=f"Resource Already Exists: Auto ML Job with name {auto_ml_job_name} already exists. Choose a different name." ) self.auto_ml_job_input_data_config = auto_ml_job_input_data_config self.output_data_config = output_data_config self.auto_ml_problem_type_config = auto_ml_problem_type_config self.role_arn = role_arn self.security_config = security_config self.auto_ml_job_objective = auto_ml_job_objective self.auto_ml_problem_type_resolved_attributes = { "SDK_UNKNOWN_MEMBER": {"name": "UnknownMemberName"} } if "ImageClassificationJobConfig" in self.auto_ml_problem_type_config: self.auto_ml_job_objective = ( {"MetricName": "Accuracy"} if self.auto_ml_job_objective is None else self.auto_ml_job_objective ) self.auto_ml_problem_type_config_name = "ImageClassification" elif "TextClassificationJobConfig" in self.auto_ml_problem_type_config: self.auto_ml_job_objective = ( {"MetricName": "Accuracy"} if self.auto_ml_job_objective is None else self.auto_ml_job_objective ) self.auto_ml_problem_type_config_name = "TextClassification" elif "TimeSeriesForecastingJobConfig" in self.auto_ml_problem_type_config: self.auto_ml_job_objective = ( {"MetricName": "AverageWeightedQuantileLoss"} if self.auto_ml_job_objective is None else self.auto_ml_job_objective ) self.auto_ml_problem_type_config_name = "TimeSeriesForecasting" elif "TabularJobConfig" in self.auto_ml_problem_type_config: self.auto_ml_problem_type_config_name = "Tabular" if ( self.auto_ml_problem_type_config["TabularJobConfig"]["ProblemType"] == "BinaryClassification" ): self.auto_ml_job_objective = ( {"MetricName": "F1"} if self.auto_ml_job_objective is None else self.auto_ml_job_objective ) self.auto_ml_problem_type_resolved_attributes = { "TabularResolvedAttributes": { "TabularProblemType": "BinaryClassification" } } if ( self.auto_ml_problem_type_config["TabularJobConfig"]["ProblemType"] == "MulticlassClassification" ): self.auto_ml_job_objective = ( {"MetricName": "Accuracy"} if self.auto_ml_job_objective is None else self.auto_ml_job_objective ) self.auto_ml_problem_type_resolved_attributes = { "TabularResolvedAttributes": { "TabularProblemType": "MulticlassClassification" } } if ( self.auto_ml_problem_type_config["TabularJobConfig"]["ProblemType"] == "Regression" ): self.auto_ml_job_objective = ( {"MetricName": "MSE"} if self.auto_ml_job_objective is None else self.auto_ml_job_objective ) self.auto_ml_problem_type_resolved_attributes = { "TabularResolvedAttributes": {"TabularProblemType": "Regression"} } elif "TextGenerationJobConfig" in self.auto_ml_problem_type_config: self.auto_ml_problem_type_config_name = "TextGeneration" self.auto_ml_job_objective = ( {"MetricName": ""} if self.auto_ml_job_objective is None else self.auto_ml_job_objective ) self.auto_ml_problem_type_resolved_attributes = { "TextGenerationResolvedAttributes": {"BaseModelName": "string"} } self.model_deploy_config = ( model_deploy_config if model_deploy_config else {"AutoGenerateEndpointName": False, "EndpointName": "EndpointName"} ) if ( "AutoGenerateEndpointName" in self.model_deploy_config and self.model_deploy_config["AutoGenerateEndpointName"] and "EndpointName" in self.model_deploy_config ): raise ValidationError( message="Validation Error: An EndpointName cannot be provided while AutoGenerateEndpoint name is True." ) self.output_data_config = output_data_config self.data_split_config = ( data_split_config if data_split_config else {"ValidationFraction": 0.2} ) self.tags = tags or [] self.arn = arn_formatter( "automl-job", self.auto_ml_job_name, account_id, region_name ) self.creation_time = str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")) self.end_time = str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")) self.last_modified_time = str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")) self.failure_reason = "" self.partial_failure_reasons = [{"PartialFailureMessage": ""}] self.best_candidate = { "CandidateName": "best_candidate", "FinalAutoMLJobObjectiveMetric": { "Type": "Maximize", "MetricName": "Accuracy", "Value": 123, "StandardMetricName": "Accuracy", }, "ObjectiveStatus": "Succeeded", "CandidateSteps": [ { "CandidateStepType": "AWS::SageMaker::TrainingJob", "CandidateStepArn": arn_formatter( "training-job", "candidate_step_name", account_id, region_name ), "CandidateStepName": "candidate_step_name", }, ], "CandidateStatus": "Completed", "InferenceContainers": [ { "Image": "string", "ModelDataUrl": "string", "Environment": {"string": "string"}, }, ], "CreationTime": str(datetime(2024, 1, 1)), "EndTime": str(datetime(2024, 1, 1)), "LastModifiedTime": str(datetime(2024, 1, 1)), "FailureReason": "string", "CandidateProperties": { "CandidateArtifactLocations": { "Explainability": "string", "ModelInsights": "string", "BacktestResults": "string", }, "CandidateMetrics": [ { "MetricName": "Accuracy", "Value": 123, "Set": "Train", "StandardMetricName": "Accuracy", }, ], }, "InferenceContainerDefinitions": { "string": [ { "Image": "string", "ModelDataUrl": "string", "Environment": {"string": "string"}, }, ] }, } self.auto_ml_job_status = "InProgress" self.auto_ml_job_secondary_status = "Completed" self.auto_ml_job_artifacts = { "CandidateDefinitionNotebookLocation": "candidate/notebook/location", "DataExplorationNotebookLocation": "data/notebook/location", } self.resolved_attributes = { "AutoMLJobObjective": self.auto_ml_job_objective, "CompletionCriteria": self.auto_ml_problem_type_config[ self.auto_ml_problem_type_config_name + "JobConfig" ]["CompletionCriteria"], "AutoMLProblemTypeResolvedAttributes": self.auto_ml_problem_type_resolved_attributes, } self.model_deploy_result = { "EndpointName": self.model_deploy_config["EndpointName"] if self.model_deploy_config else "endpoint_name", } def describe(self) -> Dict[str, Any]: return { "AutoMLJobName": self.auto_ml_job_name, "AutoMLJobArn": self.arn, "AutoMLJobInputDataConfig": self.auto_ml_job_input_data_config, "OutputDataConfig": self.output_data_config, "RoleArn": self.role_arn, "AutoMLJobObjective": self.auto_ml_job_objective, "AutoMLProblemTypeConfig": self.auto_ml_problem_type_config, "AutoMLProblemTypeConfigName": self.auto_ml_problem_type_config_name, "CreationTime": self.creation_time, "EndTime": self.end_time, "LastModifiedTime": self.last_modified_time, "FailureReason": self.failure_reason, "PartialFailureReasons": self.partial_failure_reasons, "BestCandidate": self.best_candidate, "AutoMLJobStatus": self.auto_ml_job_status, "AutoMLJobSecondaryStatus": self.auto_ml_job_secondary_status, "AutoMLJobArtifacts": self.auto_ml_job_artifacts, "ResolvedAttributes": self.resolved_attributes, "ModelDeployConfig": self.model_deploy_config, "ModelDeployResult": self.model_deploy_result, "DataSplitConfig": self.data_split_config, "SecurityConfig": self.security_config, } def summary(self) -> Dict[str, Any]: return { "AutoMLJobName": self.auto_ml_job_name, "AutoMLJobArn": self.arn, "AutoMLJobStatus": self.auto_ml_job_status, "AutoMLJobSecondaryStatus": self.auto_ml_job_secondary_status, "CreationTime": self.creation_time, "EndTime": self.end_time, "LastModifiedTime": self.last_modified_time, "FailureReason": self.failure_reason, "PartialFailureReasons": self.partial_failure_reasons, } class Domain(BaseObject): def __init__( self, domain_name: str, auth_mode: str, default_user_settings: Dict[str, Any], subnet_ids: List[str], vpc_id: str, account_id: str, region_name: str, domain_settings: Optional[Dict[str, Any]], tags: Optional[List[Dict[str, str]]], app_network_access_type: Optional[str], home_efs_file_system_kms_key_id: Optional[str], kms_key_id: Optional[str], app_security_group_management: Optional[str], default_space_settings: Optional[Dict[str, Any]], ): self.domain_name = domain_name if domain_name in sagemaker_backends[account_id][region_name].domains: raise ResourceInUseException( message=f"Resource Already Exists: Domain with name {domain_name} already exists. Choose a different name." ) self.auth_mode = auth_mode self.default_user_settings = default_user_settings self.subnet_ids = subnet_ids self.vpc_id = vpc_id self.account_id = account_id self.region_name = region_name self.domain_settings = domain_settings self.tags = tags self.app_network_access_type = ( app_network_access_type if app_network_access_type else "PublicInternetOnly" ) self.home_efs_file_system_kms_key_id = ( home_efs_file_system_kms_key_id if home_efs_file_system_kms_key_id else kms_key_id ) self.kms_key_id = kms_key_id self.app_security_group_management = app_security_group_management self.default_space_settings = default_space_settings self.id = f"d-{domain_name}" self.arn = arn_formatter("domain", self.id, account_id, region_name) self.home_efs_file_system_id = f"{domain_name}-efs-id" self.single_sign_on_managed_application_instance_id = f"{domain_name}-sso-id" self.single_sign_on_managed_application_arn = arn_formatter( "sso", f"application/{domain_name}/apl-{domain_name}", account_id, "" ) self.status = "InService" self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.last_modified_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.failure_reason = "" self.security_group_id_for_domain_boundary = f"sg-{domain_name}" self.url = f"{domain_name}.{region_name}.sagemaker.test.com" def describe(self) -> Dict[str, Any]: return { "DomainArn": self.arn, "DomainId": self.id, "DomainName": self.domain_name, "HomeEfsFileSystemId": self.home_efs_file_system_id, "SingleSignOnManagedApplicationInstanceId": self.single_sign_on_managed_application_instance_id, "SingleSignOnApplicationArn": self.single_sign_on_managed_application_arn, "Status": self.status, "CreationTime": self.creation_time, "LastModifiedTime": self.last_modified_time, "FailureReason": self.failure_reason, "SecurityGroupIdForDomainBoundary": self.security_group_id_for_domain_boundary, "AuthMode": self.auth_mode, "DefaultUserSettings": self.default_user_settings, "DomainSetting": self.domain_settings, "AppNetworkAccessType": self.app_network_access_type, "HomeEfsFileSystemKmsKeyId": self.home_efs_file_system_kms_key_id, "SubnetIds": self.subnet_ids, "Url": self.url, "VpcId": self.vpc_id, "KmsKeyId": self.kms_key_id, "AppSecurityGroupManagement": self.app_security_group_management, "DefaultSpaceSettings": self.default_space_settings, } def summary(self) -> Dict[str, Any]: return { "DomainArn": self.arn, "DomainId": self.id, "DomainName": self.domain_name, "Status": self.status, "CreationTime": self.creation_time, "LastModifiedTime": self.last_modified_time, "Url": self.url, } class ModelExplainabilityJobDefinition(BaseObject): def __init__( self, job_definition_name: str, model_explainability_baseline_config: Optional[Dict[str, Any]], model_explainability_app_specification: Dict[str, Any], model_explainability_job_input: Dict[str, Any], model_explainability_job_output_config: Dict[str, Any], job_resources: Dict[str, Any], network_config: Optional[Dict[str, Any]], role_arn: str, stopping_condition: Optional[Dict[str, Any]], region_name: str, account_id: str, tags: Optional[List[Dict[str, str]]], ): self.job_definition_name = job_definition_name if ( job_definition_name in sagemaker_backends[account_id][ region_name ].model_explainability_job_definitions ): raise ResourceInUseException( message=f"Resource Already Exists: ModelExplainabilityJobDefinition with name {job_definition_name} already exists. Choose a different name." ) self.model_explainability_baseline_config = model_explainability_baseline_config self.model_explainability_app_specification = ( model_explainability_app_specification ) self.model_explainability_job_input = model_explainability_job_input self.model_explainability_job_output_config = ( model_explainability_job_output_config ) self.job_resources = job_resources self.network_config = network_config self.role_arn = role_arn self.stopping_condition = stopping_condition self.region_name = region_name self.account_id = account_id self.tags = tags self.arn = arn_formatter( "model-explainability-job-definition", job_definition_name, self.account_id, self.region_name, ) self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.endpoint_name = model_explainability_job_input["EndpointInput"][ "EndpointName" ] def describe(self) -> Dict[str, Any]: return { "JobDefinitionArn": self.arn, "JobDefinitionName": self.job_definition_name, "CreationTime": self.creation_time, "ModelExplainabilityBaselineConfig": self.model_explainability_baseline_config, "ModelExplainabilityAppSpecification": self.model_explainability_app_specification, "ModelExplainabilityJobInput": self.model_explainability_job_input, "ModelExplainabilityJobOutputConfig": self.model_explainability_job_output_config, "JobResources": self.job_resources, "NetworkConfig": self.network_config, "RoleArn": self.role_arn, "StoppingConditions": self.stopping_condition, } def summary(self) -> Dict[str, Any]: return { "MonitoringJobDefinitionName": self.job_definition_name, "MonitoringJobDefinitionArn": self.arn, "CreationTime": self.creation_time, "EndpointName": self.endpoint_name, } class HyperParameterTuningJob(BaseObject): def __init__( self, hyper_parameter_tuning_job_name: str, hyper_parameter_tuning_job_config: Dict[str, Any], region_name: str, account_id: str, training_job_definition: Optional[Dict[str, Any]], training_job_definitions: Optional[List[Dict[str, Any]]], warm_start_config: Optional[Dict[str, Any]], tags: Optional[List[Dict[str, str]]], autotune: Optional[Dict[str, Any]], ): self.hyper_parameter_tuning_job_name = hyper_parameter_tuning_job_name if ( hyper_parameter_tuning_job_name in sagemaker_backends[account_id][region_name].hyper_parameter_tuning_jobs ): raise ResourceInUseException( message=f"Resource Already Exists: Hyper Parameter Tuning Job with name {hyper_parameter_tuning_job_name} already exists. Choose a different name." ) self.arn = arn_formatter( "hyper-parameter-tuning-job", self.hyper_parameter_tuning_job_name, account_id, region_name, ) self.hyper_parameter_tuning_job_config = hyper_parameter_tuning_job_config self.region_name = region_name self.account_id = account_id self.training_job_definition = training_job_definition self.training_job_definitions = training_job_definitions self.hyper_parameter_tuning_job_status = "Completed" self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.last_modified_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.hyper_parameter_tuning_end_time = datetime.now().strftime( "%Y-%m-%d %H:%M:%S" ) self.training_job_status_counters = { "Completed": 1, "InProgress": 0, "RetryableError": 0, "NonRetryableError": 0, "Stopped": 0, } self.objective_status_counters = { "Succeeded": 1, "Pending": 0, "Failed": 0, } self.best_training_job = { "TrainingJobDefinitionName": "string", "TrainingJobName": "FakeTrainingJobName", "TrainingJobArn": "FakeTrainingJobArn", "TuningJobName": "FakeTuningJobName", "CreationTime": str(datetime(2024, 1, 1)), "TrainingStartTime": str(datetime(2024, 1, 1)), "TrainingEndTime": str(datetime(2024, 1, 1)), "TrainingJobStatus": "Completed", "TunedHyperParameters": {"string": "TunedHyperParameters"}, "FailureReason": "string", "FinalHyperParameterTuningJobObjectiveMetric": { "Type": "Maximize", "MetricName": "Accuracy", "Value": 1, }, "ObjectiveStatus": "Succeeded", } self.OverallBestTrainingJob = { "TrainingJobDefinitionName": "FakeTrainingJobDefinitionName", "TrainingJobName": "FakeTrainingJobName", "TrainingJobArn": "FakeTrainingJobArn", "TuningJobName": "FakeTuningJobName", "CreationTime": str(datetime(2024, 1, 1)), "TrainingStartTime": str(datetime(2024, 1, 1)), "TrainingEndTime": str(datetime(2024, 1, 1)), "TrainingJobStatus": "Completed", "TunedHyperParameters": {"string": "FakeTunedHyperParameters"}, "FailureReason": "FakeFailureReason", "FinalHyperParameterTuningJobObjectiveMetric": { "Type": "Maximize", "MetricName": "Acccuracy", "Value": 1, }, "ObjectiveStatus": "Succeeded", } self.warm_start_config = warm_start_config self.failure_reason = "" self.tuning_job_completion_details = { "NumberOfTrainingJobsObjectiveNotImproving": 123, "ConvergenceDetectedTime": str(datetime(2024, 1, 1)), } self.consumed_resources = {"RuntimeInSeconds": 123} self.tags = tags self.autotune = autotune def describe(self) -> Dict[str, Any]: return { "HyperParameterTuningJobName": self.hyper_parameter_tuning_job_name, "HyperParameterTuningJobArn": self.arn, "HyperParameterTuningJobConfig": self.hyper_parameter_tuning_job_config, "TrainingJobDefinition": self.training_job_definition, "TrainingJobDefinitions": self.training_job_definitions, "HyperParameterTuningJobStatus": self.hyper_parameter_tuning_job_status, "CreationTime": self.creation_time, "HyperParameterTuningEndTime": self.hyper_parameter_tuning_end_time, "LastModifiedTime": self.last_modified_time, "TrainingJobStatusCounters": self.training_job_status_counters, "ObjectiveStatusCounters": self.objective_status_counters, "BestTrainingJob": self.best_training_job, "OverallBestTrainingJob": self.OverallBestTrainingJob, "WarmStartConfig": self.warm_start_config, "Autotune": self.autotune, "FailureReason": self.failure_reason, "TuningJobCompletionDetails": self.tuning_job_completion_details, "ConsumedResources": self.consumed_resources, } def summary(self) -> Dict[str, Any]: return { "HyperParameterTuningJobName": self.hyper_parameter_tuning_job_name, "HyperParameterTuningJobArn": self.arn, "HyperParameterTuningJobStatus": self.hyper_parameter_tuning_job_status, "Strategy": self.hyper_parameter_tuning_job_config["Strategy"], "CreationTime": self.creation_time, "HyperParameterTuningEndTime": self.hyper_parameter_tuning_end_time, "LastModifiedTime": self.last_modified_time, "TrainingJobStatusCounters": self.training_job_status_counters, "ObjectiveStatusCounters": self.objective_status_counters, "ResourceLimits": self.hyper_parameter_tuning_job_config["ResourceLimits"], } class ModelQualityJobDefinition(BaseObject): def __init__( self, job_definition_name: str, model_quality_baseline_config: Optional[Dict[str, Any]], model_quality_app_specification: Dict[str, Any], model_quality_job_input: Dict[str, Any], model_quality_job_output_config: Dict[str, Any], job_resources: Dict[str, Any], network_config: Optional[Dict[str, Any]], role_arn: str, stopping_condition: Optional[Dict[str, Any]], tags: Optional[List[Dict[str, str]]], region_name: str, account_id: str, ): self.region_name = region_name self.account_id = account_id self.job_definition_name = job_definition_name if ( job_definition_name in sagemaker_backends[account_id][region_name].model_quality_job_definitions ): raise ResourceInUseException( message=f"Resource Already Exists: Model Quality Job Definition with name {job_definition_name} already exists. Choose a different name." ) self.model_quality_baseline_config = model_quality_baseline_config self.model_quality_app_specification = model_quality_app_specification self.model_quality_job_input = model_quality_job_input self.model_quality_job_output_config = model_quality_job_output_config self.job_resources = job_resources self.network_config = network_config self.role_arn = role_arn self.stopping_condition = stopping_condition self.tags = tags or [] self.arn = arn_formatter( "model-quality-job-definition", self.job_definition_name, account_id, region_name, ) self.creation_time = str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")) self.endpoint_name = self.model_quality_job_input["EndpointInput"][ "EndpointName" ] def describe(self) -> Dict[str, Any]: return { "JobDefinitionArn": self.arn, "JobDefinitionName": self.job_definition_name, "CreationTime": self.creation_time, "ModelQualityBaselineConfig": self.model_quality_baseline_config, "ModelQualityAppSpecification": self.model_quality_app_specification, "ModelQualityJobInput": self.model_quality_job_input, "ModelQualityJobOutputConfig": self.model_quality_job_output_config, "JobResources": self.job_resources, "NetworkConfig": self.network_config, "RoleArn": self.role_arn, "StoppingCondition": self.stopping_condition, } def summary(self) -> Dict[str, Any]: return { "MonitoringJobDefinitionName": self.job_definition_name, "MonitoringJobDefinitionArn": self.arn, "CreationTime": self.creation_time, "EndpointName": self.endpoint_name, } class VpcConfig(BaseObject): def __init__(self, security_group_ids: List[str], subnets: List[str]): self.security_group_ids = security_group_ids self.subnets = subnets @property def response_object(self) -> Dict[str, List[str]]: response_object = self.gen_response_object() return { k: v for k, v in response_object.items() if v is not None and v != [None] } class Container(BaseObject): def __init__(self, **kwargs: Any): self.container_hostname = kwargs.get("container_hostname", "localhost") self.model_data_url = kwargs.get("data_url", "") self.model_package_name = kwargs.get("package_name", "pkg") self.image = kwargs.get("image", "") self.environment = kwargs.get("environment", {}) @property def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() return { k: v for k, v in response_object.items() if v is not None and v != [None] } class FakeSagemakerNotebookInstance(CloudFormationModel): def __init__( self, account_id: str, region_name: str, notebook_instance_name: str, instance_type: str, role_arn: str, subnet_id: Optional[str], security_group_ids: Optional[List[str]], kms_key_id: Optional[str], tags: Optional[List[Dict[str, str]]], lifecycle_config_name: Optional[str], direct_internet_access: str, volume_size_in_gb: int, accelerator_types: Optional[List[str]], default_code_repository: Optional[str], additional_code_repositories: Optional[List[str]], root_access: Optional[str], ): self.validate_volume_size_in_gb(volume_size_in_gb) self.validate_instance_type(instance_type) self.region_name = region_name self.notebook_instance_name = notebook_instance_name self.instance_type = instance_type self.role_arn = role_arn self.subnet_id = subnet_id self.security_group_ids = security_group_ids self.kms_key_id = kms_key_id self.tags = tags or [] self.lifecycle_config_name = lifecycle_config_name self.direct_internet_access = direct_internet_access self.volume_size_in_gb = volume_size_in_gb self.accelerator_types = accelerator_types self.default_code_repository = default_code_repository self.additional_code_repositories = additional_code_repositories self.root_access = root_access self.status = "Pending" self.creation_time = self.last_modified_time = datetime.now() self.arn = arn_formatter( "notebook-instance", notebook_instance_name, account_id, region_name ) self.start() def validate_volume_size_in_gb(self, volume_size_in_gb: int) -> None: if not validators.is_integer_between(volume_size_in_gb, mn=5, optional=True): message = "Invalid range for parameter VolumeSizeInGB, value: {}, valid range: 5-inf" raise ValidationError(message=message) def validate_instance_type(self, instance_type: str) -> None: VALID_INSTANCE_TYPES = [ "ml.p2.xlarge", "ml.m5.4xlarge", "ml.m4.16xlarge", "ml.t3.xlarge", "ml.p3.16xlarge", "ml.t2.xlarge", "ml.p2.16xlarge", "ml.c4.2xlarge", "ml.c5.2xlarge", "ml.c4.4xlarge", "ml.c5d.2xlarge", "ml.c5.4xlarge", "ml.c5d.4xlarge", "ml.c4.8xlarge", "ml.c5d.xlarge", "ml.c5.9xlarge", "ml.c5.xlarge", "ml.c5d.9xlarge", "ml.c4.xlarge", "ml.t2.2xlarge", "ml.c5d.18xlarge", "ml.t3.2xlarge", "ml.t3.medium", "ml.t2.medium", "ml.c5.18xlarge", "ml.p3.2xlarge", "ml.m5.xlarge", "ml.m4.10xlarge", "ml.t2.large", "ml.m5.12xlarge", "ml.m4.xlarge", "ml.t3.large", "ml.m5.24xlarge", "ml.m4.2xlarge", "ml.p2.8xlarge", "ml.m5.2xlarge", "ml.p3.8xlarge", "ml.m4.4xlarge", ] if not validators.is_one_of(instance_type, VALID_INSTANCE_TYPES): message = f"Value '{instance_type}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: {VALID_INSTANCE_TYPES}" raise ValidationError(message=message) @property def url(self) -> str: return ( f"{self.notebook_instance_name}.notebook.{self.region_name}.sagemaker.aws" ) def start(self) -> None: self.status = "InService" @property def is_deletable(self) -> bool: return self.status in ["Stopped", "Failed"] def stop(self) -> None: self.status = "Stopped" @property def physical_resource_id(self) -> str: return self.arn @classmethod def has_cfn_attr(cls, attr: str) -> bool: return attr in ["NotebookInstanceName"] def get_cfn_attribute(self, attribute_name: str) -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-notebookinstance.html#aws-resource-sagemaker-notebookinstance-return-values from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "NotebookInstanceName": return self.notebook_instance_name raise UnformattedGetAttTemplateException() @staticmethod def cloudformation_name_type() -> str: return "" @staticmethod def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-notebookinstance.html return "AWS::SageMaker::NotebookInstance" @classmethod def create_from_cloudformation_json( # type: ignore[misc] cls, resource_name: str, cloudformation_json: Any, account_id: str, region_name: str, **kwargs: Any, ) -> "FakeSagemakerNotebookInstance": # Get required properties from provided CloudFormation template properties = cloudformation_json["Properties"] instance_type = properties["InstanceType"] role_arn = properties["RoleArn"] notebook = sagemaker_backends[account_id][region_name].create_notebook_instance( notebook_instance_name=resource_name, instance_type=instance_type, role_arn=role_arn, ) return notebook @classmethod def update_from_cloudformation_json( # type: ignore[misc] cls, original_resource: Any, new_resource_name: str, cloudformation_json: Any, account_id: str, region_name: str, ) -> "FakeSagemakerNotebookInstance": # Operations keep same resource name so delete old and create new to mimic update cls.delete_from_cloudformation_json( original_resource.arn, cloudformation_json, account_id, region_name ) new_resource = cls.create_from_cloudformation_json( original_resource.notebook_instance_name, cloudformation_json, account_id, region_name, ) return new_resource @classmethod def delete_from_cloudformation_json( # type: ignore[misc] cls, resource_name: str, cloudformation_json: Any, account_id: str, region_name: str, ) -> None: # Get actual name because resource_name actually provides the ARN # since the Physical Resource ID is the ARN despite SageMaker # using the name for most of its operations. notebook_instance_name = resource_name.split("/")[-1] backend = sagemaker_backends[account_id][region_name] backend.stop_notebook_instance(notebook_instance_name) backend.delete_notebook_instance(notebook_instance_name) def to_dict(self) -> Dict[str, Any]: return { "NotebookInstanceArn": self.arn, "NotebookInstanceName": self.notebook_instance_name, "NotebookInstanceStatus": self.status, "Url": self.url, "InstanceType": self.instance_type, "SubnetId": self.subnet_id, "SecurityGroups": self.security_group_ids, "RoleArn": self.role_arn, "KmsKeyId": self.kms_key_id, # ToDo: NetworkInterfaceId "LastModifiedTime": str(self.last_modified_time), "CreationTime": str(self.creation_time), "NotebookInstanceLifecycleConfigName": self.lifecycle_config_name, "DirectInternetAccess": self.direct_internet_access, "VolumeSizeInGB": self.volume_size_in_gb, "AcceleratorTypes": self.accelerator_types, "DefaultCodeRepository": self.default_code_repository, "AdditionalCodeRepositories": self.additional_code_repositories, "RootAccess": self.root_access, } class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject, CloudFormationModel): def __init__( self, account_id: str, region_name: str, notebook_instance_lifecycle_config_name: str, on_create: List[Dict[str, str]], on_start: List[Dict[str, str]], ): self.region_name = region_name self.notebook_instance_lifecycle_config_name = ( notebook_instance_lifecycle_config_name ) self.on_create = on_create self.on_start = on_start self.creation_time = self.last_modified_time = datetime.now().strftime( "%Y-%m-%d %H:%M:%S" ) self.arn = FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter( self.notebook_instance_lifecycle_config_name, account_id, region_name ) @staticmethod def arn_formatter(name: str, account_id: str, region_name: str) -> str: return arn_formatter( "notebook-instance-lifecycle-config", name, account_id, region_name ) @property def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() response = { k: v for k, v in response_object.items() if v is not None and v != [None] } response["NotebookInstanceLifecycleConfigArn"] = response.pop("Arn") return response @property def physical_resource_id(self) -> str: return self.arn @classmethod def has_cfn_attr(cls, attr: str) -> bool: return attr in ["NotebookInstanceLifecycleConfigName"] def get_cfn_attribute(self, attribute_name: str) -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-notebookinstancelifecycleconfig.html#aws-resource-sagemaker-notebookinstancelifecycleconfig-return-values from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "NotebookInstanceLifecycleConfigName": return self.notebook_instance_lifecycle_config_name raise UnformattedGetAttTemplateException() @staticmethod def cloudformation_name_type() -> str: return "" @staticmethod def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-notebookinstancelifecycleconfig.html return "AWS::SageMaker::NotebookInstanceLifecycleConfig" @classmethod def create_from_cloudformation_json( # type: ignore[misc] cls, resource_name: str, cloudformation_json: Any, account_id: str, region_name: str, **kwargs: Any, ) -> "FakeSageMakerNotebookInstanceLifecycleConfig": properties = cloudformation_json["Properties"] config = sagemaker_backends[account_id][ region_name ].create_notebook_instance_lifecycle_config( notebook_instance_lifecycle_config_name=resource_name, on_create=properties.get("OnCreate"), on_start=properties.get("OnStart"), ) return config @classmethod def update_from_cloudformation_json( # type: ignore[misc] cls, original_resource: Any, new_resource_name: str, cloudformation_json: Any, account_id: str, region_name: str, ) -> "FakeSageMakerNotebookInstanceLifecycleConfig": # Operations keep same resource name so delete old and create new to mimic update cls.delete_from_cloudformation_json( original_resource.arn, cloudformation_json, account_id, region_name, ) new_resource = cls.create_from_cloudformation_json( original_resource.notebook_instance_lifecycle_config_name, cloudformation_json, account_id, region_name, ) return new_resource @classmethod def delete_from_cloudformation_json( # type: ignore[misc] cls, resource_name: str, cloudformation_json: Any, account_id: str, region_name: str, ) -> None: # Get actual name because resource_name actually provides the ARN # since the Physical Resource ID is the ARN despite SageMaker # using the name for most of its operations. config_name = resource_name.split("/")[-1] backend = sagemaker_backends[account_id][region_name] backend.delete_notebook_instance_lifecycle_config(config_name) class SageMakerModelBackend(BaseBackend): def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) self._models: Dict[str, Model] = {} self.notebook_instances: Dict[str, FakeSagemakerNotebookInstance] = {} self.endpoint_configs: Dict[str, FakeEndpointConfig] = {} self.endpoints: Dict[str, FakeEndpoint] = {} self.experiments: Dict[str, FakeExperiment] = {} self.pipelines: Dict[str, FakePipeline] = {} self.pipeline_executions: Dict[str, FakePipelineExecution] = {} self.processing_jobs: Dict[str, FakeProcessingJob] = {} self.trials: Dict[str, FakeTrial] = {} self.trial_components: Dict[str, FakeTrialComponent] = {} self.training_jobs: Dict[str, FakeTrainingJob] = {} self.transform_jobs: Dict[str, FakeTransformJob] = {} self.notebook_instance_lifecycle_configurations: Dict[ str, FakeSageMakerNotebookInstanceLifecycleConfig ] = {} self.model_cards: DefaultDict[str, List[FakeModelCard]] = defaultdict(list) self.model_package_groups: Dict[str, ModelPackageGroup] = {} self.model_packages: Dict[str, ModelPackage] = {} self.model_package_name_mapping: Dict[str, str] = {} self.feature_groups: Dict[str, FeatureGroup] = {} self.clusters: Dict[str, Cluster] = {} self.data_quality_job_definitions: Dict[str, FakeDataQualityJobDefinition] = {} self.model_bias_job_definitions: Dict[str, FakeModelBiasJobDefinition] = {} self.auto_ml_jobs: Dict[str, AutoMLJob] = {} self.compilation_jobs: Dict[str, CompilationJob] = {} self.domains: Dict[str, Domain] = {} self.model_explainability_job_definitions: Dict[ str, ModelExplainabilityJobDefinition ] = {} self.hyper_parameter_tuning_jobs: Dict[str, HyperParameterTuningJob] = {} self.model_quality_job_definitions: Dict[str, ModelQualityJobDefinition] = {} @staticmethod def default_vpc_endpoint_service( service_region: str, zones: List[str] ) -> List[Dict[str, str]]: """Default VPC endpoint services.""" api_service = BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "api.sagemaker", special_service_name="sagemaker.api" ) notebook_service_id = f"vpce-svc-{BaseBackend.vpce_random_number()}" studio_service_id = f"vpce-svc-{BaseBackend.vpce_random_number()}" notebook_service = { "AcceptanceRequired": False, "AvailabilityZones": zones, "BaseEndpointDnsNames": [ f"{notebook_service_id}.{service_region}.vpce.amazonaws.com", f"notebook.{service_region}.vpce.sagemaker.aws", ], "ManagesVpcEndpoints": False, "Owner": "amazon", "PrivateDnsName": f"*.notebook.{service_region}.sagemaker.aws", "PrivateDnsNameVerificationState": "verified", "PrivateDnsNames": [ {"PrivateDnsName": f"*.notebook.{service_region}.sagemaker.aws"} ], "ServiceId": notebook_service_id, "ServiceName": f"aws.sagemaker.{service_region}.notebook", "ServiceType": [{"ServiceType": "Interface"}], "Tags": [], "VpcEndpointPolicySupported": True, } studio_service = { "AcceptanceRequired": False, "AvailabilityZones": zones, "BaseEndpointDnsNames": [ f"{studio_service_id}.{service_region}.vpce.amazonaws.com", f"studio.{service_region}.vpce.sagemaker.aws", ], "ManagesVpcEndpoints": False, "Owner": "amazon", "PrivateDnsName": f"*.studio.{service_region}.sagemaker.aws", "PrivateDnsNameVerificationState": "verified", "PrivateDnsNames": [ {"PrivateDnsName": f"*.studio.{service_region}.sagemaker.aws"} ], "ServiceId": studio_service_id, "ServiceName": f"aws.sagemaker.{service_region}.studio", "ServiceType": [{"ServiceType": "Interface"}], "Tags": [], "VpcEndpointPolicySupported": True, } return api_service + [notebook_service, studio_service] def create_model( self, model_name: str, execution_role_arn: str, primary_container: Optional[Dict[str, Any]], vpc_config: Optional[Dict[str, Any]], containers: Optional[List[Dict[str, Any]]], tags: Optional[List[Dict[str, str]]], ) -> Model: model_obj = Model( account_id=self.account_id, region_name=self.region_name, model_name=model_name, execution_role_arn=execution_role_arn, primary_container=primary_container or {}, vpc_config=vpc_config or {}, containers=containers or [], tags=tags or [], ) self._models[model_name] = model_obj return model_obj def describe_model(self, model_name: str) -> Model: model = self._models.get(model_name) if model: return model arn = arn_formatter("model", model_name, self.account_id, self.region_name) raise ValidationError(message=f"Could not find model '{arn}'.") def list_models(self) -> Iterable[Model]: return self._models.values() def delete_model(self, model_name: str) -> None: for model in self._models.values(): if model.model_name == model_name: self._models.pop(model.model_name) break else: raise MissingModel(model=model_name) def create_experiment(self, experiment_name: str) -> Dict[str, str]: experiment = FakeExperiment( account_id=self.account_id, region_name=self.region_name, experiment_name=experiment_name, tags=[], ) self.experiments[experiment_name] = experiment return experiment.response_create def describe_experiment(self, experiment_name: str) -> Dict[str, Any]: experiment_data = self.experiments[experiment_name] return { "ExperimentName": experiment_data.experiment_name, "ExperimentArn": experiment_data.arn, "CreationTime": experiment_data.creation_time, "LastModifiedTime": experiment_data.last_modified_time, } def _get_resource_from_arn(self, arn: str) -> Any: resources = { "model": self._models, "notebook-instance": self.notebook_instances, "endpoint": self.endpoints, "endpoint-config": self.endpoint_configs, "training-job": self.training_jobs, "transform-job": self.transform_jobs, "experiment": self.experiments, "experiment-trial": self.trials, "experiment-trial-component": self.trial_components, "processing-job": self.processing_jobs, "pipeline": self.pipelines, "model-package-group": self.model_package_groups, "cluster": self.clusters, "data-quality-job-definition": self.data_quality_job_definitions, "model-bias-job-definition": self.model_bias_job_definitions, "automl-job": self.auto_ml_jobs, "compilation-job": self.compilation_jobs, "domain": self.domains, "model-explainability-job-definition": self.model_explainability_job_definitions, "hyper-parameter-tuning-job": self.hyper_parameter_tuning_jobs, "model-quality-job-definition": self.model_quality_job_definitions, "model-card": self.model_cards, } target_resource, target_name = arn.split(":")[-1].split("/") try: resource = resources.get(target_resource).get(target_name) # type: ignore except KeyError: message = f"Could not find {target_resource} with name {target_name}" raise ValidationError(message=message) if isinstance(resource, list): return resource[0] return resource def add_tags(self, arn: str, tags: List[Dict[str, str]]) -> List[Dict[str, str]]: resource = self._get_resource_from_arn(arn) resource.tags.extend(tags) return resource.tags @paginate(pagination_model=PAGINATION_MODEL) def list_tags(self, arn: str) -> List[Dict[str, str]]: resource = self._get_resource_from_arn(arn) return resource.tags def delete_tags(self, arn: str, tag_keys: List[str]) -> None: resource = self._get_resource_from_arn(arn) resource.tags = [tag for tag in resource.tags if tag["Key"] not in tag_keys] @paginate(pagination_model=PAGINATION_MODEL) def list_experiments(self) -> List["FakeExperiment"]: return list(self.experiments.values()) def search(self, resource: Any = None, search_expression: Any = None) -> Any: """ Only a few SearchExpressions are implemented. Please open a bug report if you find any issues. """ next_index = None valid_resources = { "Pipeline": self.pipelines.values(), "ModelPackageGroup": self.model_package_groups.values(), "TrainingJob": self.training_jobs.values(), "ExperimentTrialComponent": self.trial_components.values(), "FeatureGroup": self.feature_groups.values(), "Endpoint": self.endpoints.values(), "PipelineExecution": self.pipeline_executions.values(), "Project": [], "ExperimentTrial": self.trials.values(), "Image": [], "ImageVersion": [], "ModelPackage": self.model_packages.values(), "Experiment": self.experiments.values(), } if resource not in valid_resources: raise AWSValidationException( f"An error occurred (ValidationException) when calling the Search operation: 1 validation error detected: Value '{resource}' at 'resource' failed to satisfy constraint: Member must satisfy enum value set: {valid_resources}" ) def compare_value(actual: Any, expected: Any, operator: str) -> bool: # Defeault: operator == "Equals" if operator == "Contains": return expected in actual if operator == "NotEquals": return expected != actual return actual == expected def evaluate_search_expression(item: Any) -> bool: filters = None if search_expression is not None: filters = search_expression.get("Filters") if filters is not None: for f in filters: prop_key = camelcase_to_underscores(f["Name"]) if f["Name"].startswith("Tags."): key = f["Name"][5:] value = f["Value"] if f["Operator"] == "Equals": if not [ e for e in item.tags if e["Key"] == key and e["Value"] == value ]: return False return True elif f["Name"] == "TrialName": raise AWSValidationException( f"An error occurred (ValidationException) when calling the Search operation: Unknown property name: {f['Name']}" ) elif f["Name"] == "Parents.TrialName": trial_name = f["Value"] if getattr(item, "trial_name") != trial_name: return False elif hasattr(item, prop_key): if not compare_value( getattr(item, prop_key), f["Value"], f["Operator"] ): return False else: raise ValidationError( message=f"Unknown property name: {f['Name']}" ) return True result: Dict[str, Any] = { "Results": [], "NextToken": str(next_index) if next_index is not None else None, } # ResourceName, ResultName, Resources result_names = { "ExperimentTrial": "Trial", "ExperimentTrialComponent": "TrialComponent", } resources_found = [ x for x in valid_resources[resource] if evaluate_search_expression(x) ] result_name = result_names.get(resource, resource) for found in resources_found: result["Results"].append({result_name: found.gen_response_object()}) return result def delete_experiment(self, experiment_name: str) -> None: try: del self.experiments[experiment_name] except KeyError: arn = FakeTrial.arn_formatter( experiment_name, self.account_id, self.region_name ) raise ValidationError( message=f"Could not find experiment configuration '{arn}'." ) def create_trial(self, trial_name: str, experiment_name: str) -> Dict[str, str]: trial = FakeTrial( account_id=self.account_id, region_name=self.region_name, trial_name=trial_name, experiment_name=experiment_name, tags=[], trial_components=[], ) self.trials[trial_name] = trial return trial.response_create def describe_trial(self, trial_name: str) -> Dict[str, Any]: try: return self.trials[trial_name].response_object except KeyError: arn = FakeTrial.arn_formatter(trial_name, self.account_id, self.region_name) raise ValidationError(message=f"Could not find trial '{arn}'.") def delete_trial(self, trial_name: str) -> None: try: del self.trials[trial_name] except KeyError: arn = FakeTrial.arn_formatter(trial_name, self.account_id, self.region_name) raise ValidationError( message=f"Could not find trial configuration '{arn}'." ) @paginate(pagination_model=PAGINATION_MODEL) def list_trials( self, experiment_name: Optional[str] = None, trial_component_name: Optional[str] = None, ) -> List["FakeTrial"]: trials_fetched = list(self.trials.values()) def evaluate_filter_expression(trial_data: FakeTrial) -> bool: if experiment_name is not None: if trial_data.experiment_name != experiment_name: return False if trial_component_name is not None: if trial_component_name not in trial_data.trial_components: return False return True return [ trial_data for trial_data in trials_fetched if evaluate_filter_expression(trial_data) ] def create_trial_component( self, trial_component_name: str, trial_name: str, status: Dict[str, str], start_time: Optional[datetime], end_time: Optional[datetime], display_name: Optional[str], parameters: Optional[Dict[str, Dict[str, Union[str, float]]]], input_artifacts: Optional[Dict[str, Dict[str, str]]], output_artifacts: Optional[Dict[str, Dict[str, str]]], metadata_properties: Optional[Dict[str, str]], ) -> Dict[str, Any]: trial_component = FakeTrialComponent( account_id=self.account_id, region_name=self.region_name, display_name=display_name, start_time=start_time, end_time=end_time, parameters=parameters, input_artifacts=input_artifacts, output_artifacts=output_artifacts, metadata_properties=metadata_properties, trial_component_name=trial_component_name, trial_name=trial_name, status=status, tags=[], ) self.trial_components[trial_component_name] = trial_component return trial_component.response_create def delete_trial_component(self, trial_component_name: str) -> None: try: del self.trial_components[trial_component_name] except KeyError: arn = FakeTrial.arn_formatter( trial_component_name, self.account_id, self.region_name ) raise ValidationError( message=f"Could not find trial-component configuration '{arn}'." ) def describe_trial_component(self, trial_component_name: str) -> Dict[str, Any]: try: return self.trial_components[trial_component_name].response_object except KeyError: arn = FakeTrialComponent.arn_formatter( trial_component_name, self.account_id, self.region_name ) raise ValidationError(message=f"Could not find trial component '{arn}'.") def _update_trial_component_details( self, trial_component_name: str, details_json: str ) -> None: self.trial_components[trial_component_name].update(details_json) @paginate(pagination_model=PAGINATION_MODEL) def list_trial_components( self, trial_name: Optional[str] = None ) -> List["FakeTrialComponent"]: trial_components_fetched = list(self.trial_components.values()) return [ trial_component_data for trial_component_data in trial_components_fetched if trial_name is None or trial_component_data.trial_name == trial_name ] def associate_trial_component( self, trial_name: str, trial_component_name: str ) -> Dict[str, str]: if trial_name in self.trials.keys(): self.trials[trial_name].trial_components.extend([trial_component_name]) else: raise ResourceNotFound( message=f"Trial 'arn:{get_partition(self.region_name)}:sagemaker:{self.region_name}:{self.account_id}:experiment-trial/{trial_name}' does not exist." ) if trial_component_name in self.trial_components.keys(): self.trial_components[trial_component_name].trial_name = trial_name return { "TrialComponentArn": self.trial_components[trial_component_name].arn, "TrialArn": self.trials[trial_name].arn, } def disassociate_trial_component( self, trial_name: str, trial_component_name: str ) -> Dict[str, str]: if trial_component_name in self.trial_components.keys(): self.trial_components[trial_component_name].trial_name = None if trial_name in self.trials.keys(): self.trials[trial_name].trial_components = list( filter( lambda x: x != trial_component_name, self.trials[trial_name].trial_components, ) ) return { "TrialComponentArn": f"arn:{get_partition(self.region_name)}:sagemaker:{self.region_name}:{self.account_id}:experiment-trial-component/{trial_component_name}", "TrialArn": f"arn:{get_partition(self.region_name)}:sagemaker:{self.region_name}:{self.account_id}:experiment-trial/{trial_name}", } def update_trial_component( self, trial_component_name: str, status: Optional[Dict[str, str]], display_name: Optional[str], start_time: Optional[datetime], end_time: Optional[datetime], parameters: Optional[Dict[str, Dict[str, Union[str, float]]]], parameters_to_remove: Optional[List[str]], input_artifacts: Optional[Dict[str, Dict[str, str]]], input_artifacts_to_remove: Optional[List[str]], output_artifacts: Optional[Dict[str, Dict[str, str]]], output_artifacts_to_remove: Optional[List[str]], ) -> Dict[str, str]: try: trial_component = self.trial_components[trial_component_name] except KeyError: arn = FakeTrialComponent.arn_formatter( trial_component_name, self.account_id, self.region_name ) raise ValidationError(message=f"Could not find trial component '{arn}'") if status: trial_component.status = status if display_name: trial_component.display_name = display_name if start_time: trial_component.start_time = start_time if end_time: trial_component.end_time = end_time if parameters: trial_component.parameters = parameters if input_artifacts: trial_component.input_artifacts = input_artifacts if output_artifacts: trial_component.output_artifacts = output_artifacts trial_component.last_modified_time = datetime.now().strftime( "%Y-%m-%d %H:%M:%S" ) for parameter_to_remove in parameters_to_remove or []: trial_component.parameters.pop(parameter_to_remove) for input_artifact_to_remove in input_artifacts_to_remove or []: trial_component.input_artifacts.pop(input_artifact_to_remove) for output_artifact_to_remove in output_artifacts_to_remove or []: trial_component.output_artifacts.pop(output_artifact_to_remove) return { "TrialComponentArn": FakeTrialComponent.arn_formatter( trial_component_name, self.account_id, self.region_name ) } def create_notebook_instance( self, notebook_instance_name: str, instance_type: str, role_arn: str, subnet_id: Optional[str] = None, security_group_ids: Optional[List[str]] = None, kms_key_id: Optional[str] = None, tags: Optional[List[Dict[str, str]]] = None, lifecycle_config_name: Optional[str] = None, direct_internet_access: str = "Enabled", volume_size_in_gb: int = 5, accelerator_types: Optional[List[str]] = None, default_code_repository: Optional[str] = None, additional_code_repositories: Optional[List[str]] = None, root_access: Optional[str] = None, ) -> FakeSagemakerNotebookInstance: self._validate_unique_notebook_instance_name(notebook_instance_name) notebook_instance = FakeSagemakerNotebookInstance( account_id=self.account_id, region_name=self.region_name, notebook_instance_name=notebook_instance_name, instance_type=instance_type, role_arn=role_arn, subnet_id=subnet_id, security_group_ids=security_group_ids, kms_key_id=kms_key_id, tags=tags, lifecycle_config_name=lifecycle_config_name, direct_internet_access=direct_internet_access if direct_internet_access is not None else "Enabled", volume_size_in_gb=volume_size_in_gb if volume_size_in_gb is not None else 5, accelerator_types=accelerator_types, default_code_repository=default_code_repository, additional_code_repositories=additional_code_repositories, root_access=root_access, ) self.notebook_instances[notebook_instance_name] = notebook_instance return notebook_instance def _validate_unique_notebook_instance_name( self, notebook_instance_name: str ) -> None: if notebook_instance_name in self.notebook_instances: duplicate_arn = self.notebook_instances[notebook_instance_name].arn message = f"Cannot create a duplicate Notebook Instance ({duplicate_arn})" raise ValidationError(message=message) def get_notebook_instance( self, notebook_instance_name: str ) -> FakeSagemakerNotebookInstance: try: return self.notebook_instances[notebook_instance_name] except KeyError: raise ValidationError(message="RecordNotFound") def start_notebook_instance(self, notebook_instance_name: str) -> None: notebook_instance = self.get_notebook_instance(notebook_instance_name) notebook_instance.start() def stop_notebook_instance(self, notebook_instance_name: str) -> None: notebook_instance = self.get_notebook_instance(notebook_instance_name) notebook_instance.stop() def delete_notebook_instance(self, notebook_instance_name: str) -> None: notebook_instance = self.get_notebook_instance(notebook_instance_name) if not notebook_instance.is_deletable: message = f"Status ({notebook_instance.status}) not in ([Stopped, Failed]). Unable to transition to (Deleting) for Notebook Instance ({notebook_instance.arn})" raise ValidationError(message=message) del self.notebook_instances[notebook_instance_name] @paginate(pagination_model=PAGINATION_MODEL) def list_notebook_instances( self, sort_by: str, sort_order: str, name_contains: Optional[str], status: Optional[str], ) -> List[FakeSagemakerNotebookInstance]: """ The following parameters are not yet implemented: CreationTimeBefore, CreationTimeAfter, LastModifiedTimeBefore, LastModifiedTimeAfter, NotebookInstanceLifecycleConfigNameContains, DefaultCodeRepositoryContains, AdditionalCodeRepositoryEquals """ instances = list(self.notebook_instances.values()) if name_contains: instances = [ i for i in instances if name_contains in i.notebook_instance_name ] if status: instances = [i for i in instances if i.status == status] reverse = sort_order == "Descending" if sort_by == "Name": instances = sorted( instances, key=lambda x: x.notebook_instance_name, reverse=reverse ) if sort_by == "CreationTime": instances = sorted( instances, key=lambda x: x.creation_time, reverse=reverse ) if sort_by == "Status": instances = sorted(instances, key=lambda x: x.status, reverse=reverse) return instances def create_notebook_instance_lifecycle_config( self, notebook_instance_lifecycle_config_name: str, on_create: List[Dict[str, str]], on_start: List[Dict[str, str]], ) -> FakeSageMakerNotebookInstanceLifecycleConfig: if ( notebook_instance_lifecycle_config_name in self.notebook_instance_lifecycle_configurations ): arn = FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter( notebook_instance_lifecycle_config_name, self.account_id, self.region_name, ) message = f"Unable to create Notebook Instance Lifecycle Config {arn}. (Details: Notebook Instance Lifecycle Config already exists.)" raise ValidationError(message=message) lifecycle_config = FakeSageMakerNotebookInstanceLifecycleConfig( account_id=self.account_id, region_name=self.region_name, notebook_instance_lifecycle_config_name=notebook_instance_lifecycle_config_name, on_create=on_create, on_start=on_start, ) self.notebook_instance_lifecycle_configurations[ notebook_instance_lifecycle_config_name ] = lifecycle_config return lifecycle_config def describe_notebook_instance_lifecycle_config( self, notebook_instance_lifecycle_config_name: str ) -> Dict[str, Any]: try: return self.notebook_instance_lifecycle_configurations[ notebook_instance_lifecycle_config_name ].response_object except KeyError: arn = FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter( notebook_instance_lifecycle_config_name, self.account_id, self.region_name, ) message = f"Unable to describe Notebook Instance Lifecycle Config '{arn}'. (Details: Notebook Instance Lifecycle Config does not exist.)" raise ValidationError(message=message) def delete_notebook_instance_lifecycle_config( self, notebook_instance_lifecycle_config_name: str ) -> None: try: del self.notebook_instance_lifecycle_configurations[ notebook_instance_lifecycle_config_name ] except KeyError: arn = FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter( notebook_instance_lifecycle_config_name, self.account_id, self.region_name, ) message = f"Unable to delete Notebook Instance Lifecycle Config '{arn}'. (Details: Notebook Instance Lifecycle Config does not exist.)" raise ValidationError(message=message) def create_endpoint_config( self, endpoint_config_name: str, production_variants: List[Dict[str, Any]], data_capture_config: Dict[str, Any], tags: List[Dict[str, str]], kms_key_id: str, ) -> FakeEndpointConfig: endpoint_config = FakeEndpointConfig( account_id=self.account_id, region_name=self.region_name, endpoint_config_name=endpoint_config_name, production_variants=production_variants, data_capture_config=data_capture_config, tags=tags, kms_key_id=kms_key_id, ) self.validate_production_variants(production_variants) self.endpoint_configs[endpoint_config_name] = endpoint_config return endpoint_config def validate_production_variants( self, production_variants: List[Dict[str, Any]] ) -> None: for production_variant in production_variants: if production_variant["ModelName"] not in self._models: arn = arn_formatter( "model", production_variant["ModelName"], self.account_id, self.region_name, ) raise ValidationError(message=f"Could not find model '{arn}'.") def describe_endpoint_config(self, endpoint_config_name: str) -> Dict[str, Any]: try: return self.endpoint_configs[endpoint_config_name].response_object except KeyError: arn = FakeEndpointConfig.arn_formatter( endpoint_config_name, self.account_id, self.region_name ) raise ValidationError( message=f"Could not find endpoint configuration '{arn}'." ) def delete_endpoint_config(self, endpoint_config_name: str) -> None: try: del self.endpoint_configs[endpoint_config_name] except KeyError: arn = FakeEndpointConfig.arn_formatter( endpoint_config_name, self.account_id, self.region_name ) raise ValidationError( message=f"Could not find endpoint configuration '{arn}'." ) def create_endpoint( self, endpoint_name: str, endpoint_config_name: str, tags: List[Dict[str, str]] ) -> FakeEndpoint: try: endpoint_config = self.describe_endpoint_config(endpoint_config_name) except KeyError: arn = FakeEndpointConfig.arn_formatter( endpoint_config_name, self.account_id, self.region_name ) raise ValidationError(message=f"Could not find endpoint_config '{arn}'.") endpoint = FakeEndpoint( account_id=self.account_id, region_name=self.region_name, endpoint_name=endpoint_name, endpoint_config_name=endpoint_config_name, production_variants=endpoint_config["ProductionVariants"], data_capture_config=endpoint_config["DataCaptureConfig"], tags=tags, ) self.endpoints[endpoint_name] = endpoint return endpoint def describe_endpoint(self, endpoint_name: str) -> Dict[str, Any]: try: return self.endpoints[endpoint_name].response_object except KeyError: arn = FakeEndpoint.arn_formatter( endpoint_name, self.account_id, self.region_name ) raise ValidationError(message=f"Could not find endpoint '{arn}'.") def delete_endpoint(self, endpoint_name: str) -> None: try: del self.endpoints[endpoint_name] except KeyError: arn = FakeEndpoint.arn_formatter( endpoint_name, self.account_id, self.region_name ) raise ValidationError(message=f"Could not find endpoint '{arn}'.") def create_processing_job( self, app_specification: Dict[str, Any], experiment_config: Dict[str, str], network_config: Dict[str, Any], processing_inputs: List[Dict[str, Any]], processing_job_name: str, processing_output_config: Dict[str, Any], role_arn: str, tags: List[Dict[str, str]], stopping_condition: Dict[str, int], ) -> FakeProcessingJob: processing_job = FakeProcessingJob( app_specification=app_specification, experiment_config=experiment_config, network_config=network_config, processing_inputs=processing_inputs, processing_job_name=processing_job_name, processing_output_config=processing_output_config, account_id=self.account_id, region_name=self.region_name, role_arn=role_arn, stopping_condition=stopping_condition, tags=tags, ) self.processing_jobs[processing_job_name] = processing_job return processing_job def describe_processing_job(self, processing_job_name: str) -> Dict[str, Any]: try: return self.processing_jobs[processing_job_name].response_object except KeyError: arn = FakeProcessingJob.arn_formatter( processing_job_name, self.account_id, self.region_name ) raise ValidationError(message=f"Could not find processing job '{arn}'.") def create_pipeline( self, pipeline_name: str, pipeline_display_name: str, pipeline_definition: str, pipeline_definition_s3_location: Dict[str, Any], pipeline_description: str, role_arn: str, tags: List[Dict[str, str]], parallelism_configuration: Dict[str, int], ) -> FakePipeline: if not any([pipeline_definition, pipeline_definition_s3_location]): raise ValidationError( "An error occurred (ValidationException) when calling the CreatePipeline operation: Either " "Pipeline Definition or Pipeline Definition S3 location should be provided" ) if all([pipeline_definition, pipeline_definition_s3_location]): raise ValidationError( "An error occurred (ValidationException) when calling the CreatePipeline operation: " "Both Pipeline Definition and Pipeline Definition S3 Location shouldn't be present" ) if pipeline_name in self.pipelines: raise ValidationError( f"An error occurred (ValidationException) when calling the CreatePipeline operation: Pipeline names " f"must be unique within an AWS account and region. Pipeline with name ({pipeline_name}) already exists." ) if pipeline_definition_s3_location: pipeline_definition = load_pipeline_definition_from_s3( # type: ignore pipeline_definition_s3_location, account_id=self.account_id, partition=self.partition, ) pipeline = FakePipeline( pipeline_name, pipeline_display_name, pipeline_definition, pipeline_description, role_arn, tags, self.account_id, self.region_name, parallelism_configuration, ) self.pipelines[pipeline_name] = pipeline return pipeline def delete_pipeline(self, pipeline_name: str) -> str: pipeline = get_pipeline_from_name(self.pipelines, pipeline_name) del self.pipelines[pipeline.pipeline_name] return pipeline.arn def update_pipeline(self, pipeline_name: str, **kwargs: Any) -> str: pipeline = get_pipeline_from_name(self.pipelines, pipeline_name) if all( [ kwargs.get("pipeline_definition"), kwargs.get("pipeline_definition_s3_location"), ] ): raise ValidationError( "An error occurred (ValidationException) when calling the UpdatePipeline operation: " "Both Pipeline Definition and Pipeline Definition S3 Location shouldn't be present" ) for attr_key, attr_value in kwargs.items(): if attr_value: if attr_key == "pipeline_definition_s3_location": self.pipelines[ pipeline_name ].pipeline_definition = load_pipeline_definition_from_s3( # type: ignore attr_value, self.account_id, partition=self.partition, ) continue setattr(self.pipelines[pipeline_name], attr_key, attr_value) return pipeline.arn def start_pipeline_execution( self, pipeline_name: str, pipeline_execution_display_name: str, pipeline_parameters: List[Dict[str, Any]], pipeline_execution_description: str, parallelism_configuration: Dict[str, int], client_request_token: str, ) -> Dict[str, str]: pipeline = get_pipeline_from_name(self.pipelines, pipeline_name) execution_id = "".join( random.choices(string.ascii_lowercase + string.digits, k=12) ) pipeline_execution_arn = arn_formatter( _type="pipeline", _id=f"{pipeline.pipeline_name}/execution/{execution_id}", account_id=self.account_id, region_name=self.region_name, ) fake_pipeline_execution = FakePipelineExecution( pipeline_execution_arn=pipeline_execution_arn, pipeline_execution_display_name=pipeline_execution_display_name, pipeline_parameters=pipeline_parameters, pipeline_execution_description=pipeline_execution_description, pipeline_definition=pipeline.pipeline_definition, parallelism_configuration=parallelism_configuration or pipeline.parallelism_configuration, client_request_token=client_request_token, ) self.pipelines[pipeline_name].pipeline_executions[pipeline_execution_arn] = ( fake_pipeline_execution ) self.pipelines[ pipeline_name ].last_execution_time = fake_pipeline_execution.start_time return {"PipelineExecutionArn": pipeline_execution_arn} def list_pipeline_executions(self, pipeline_name: str) -> Dict[str, Any]: pipeline = get_pipeline_from_name(self.pipelines, pipeline_name) return { "PipelineExecutionSummaries": [ { "PipelineExecutionArn": arn, "StartTime": pipeline_execution.start_time, "PipelineExecutionStatus": pipeline_execution.pipeline_execution_status, "PipelineExecutionDescription": pipeline_execution.pipeline_execution_description, "PipelineExecutionDisplayName": pipeline_execution.pipeline_execution_display_name, "PipelineExecutionFailureReason": str( pipeline_execution.pipeline_execution_failure_reason ), } for arn, pipeline_execution in pipeline.pipeline_executions.items() ] } def describe_pipeline_definition_for_execution( self, pipeline_execution_arn: str ) -> Dict[str, Any]: pipeline_execution = get_pipeline_execution_from_arn( self.pipelines, pipeline_execution_arn ) return { "PipelineDefinition": str( pipeline_execution.pipeline_definition_for_execution ), "CreationTime": pipeline_execution.creation_time, } def list_pipeline_parameters_for_execution( self, pipeline_execution_arn: str ) -> Dict[str, Any]: pipeline_execution = get_pipeline_execution_from_arn( self.pipelines, pipeline_execution_arn ) return { "PipelineParameters": pipeline_execution.pipeline_parameters, } def describe_pipeline_execution( self, pipeline_execution_arn: str ) -> Dict[str, Any]: pipeline_execution = get_pipeline_execution_from_arn( self.pipelines, pipeline_execution_arn ) pipeline_name = get_pipeline_name_from_execution_arn(pipeline_execution_arn) pipeline = get_pipeline_from_name(self.pipelines, pipeline_name) return { "PipelineArn": pipeline.arn, "PipelineExecutionArn": pipeline_execution.arn, "PipelineExecutionDisplayName": pipeline_execution.pipeline_execution_display_name, "PipelineExecutionStatus": pipeline_execution.pipeline_execution_status, "PipelineExecutionDescription": pipeline_execution.pipeline_execution_description, "PipelineExperimentConfig": {}, "FailureReason": "", "CreationTime": pipeline_execution.creation_time, "LastModifiedTime": pipeline_execution.last_modified_time, "CreatedBy": pipeline_execution.created_by, "LastModifiedBy": pipeline_execution.last_modified_by, "ParallelismConfiguration": pipeline_execution.parallelism_configuration, } def describe_pipeline(self, pipeline_name: str) -> Dict[str, Any]: pipeline = get_pipeline_from_name(self.pipelines, pipeline_name) return { "PipelineArn": pipeline.arn, "PipelineName": pipeline.pipeline_name, "PipelineDisplayName": pipeline.pipeline_display_name, "PipelineDescription": pipeline.pipeline_description, "PipelineDefinition": pipeline.pipeline_definition, "RoleArn": pipeline.role_arn, "PipelineStatus": pipeline.pipeline_status, "CreationTime": pipeline.creation_time, "LastModifiedTime": pipeline.last_modified_time, "LastRunTime": pipeline.last_execution_time, "CreatedBy": pipeline.created_by, "LastModifiedBy": pipeline.last_modified_by, "ParallelismConfiguration": pipeline.parallelism_configuration, } def list_pipelines( self, pipeline_name_prefix: str, created_after: str, created_before: str, next_token: str, max_results: int, sort_by: str, sort_order: str, ) -> Dict[str, Any]: if next_token: try: starting_index = int(next_token) if starting_index > len(self.pipelines): raise ValueError # invalid next_token except ValueError: raise AWSValidationException('Invalid pagination token because "{0}".') else: starting_index = 0 if max_results: end_index = max_results + starting_index pipelines_fetched: Iterable[FakePipeline] = list(self.pipelines.values())[ starting_index:end_index ] if end_index >= len(self.pipelines): next_index = None else: next_index = end_index else: pipelines_fetched = list(self.pipelines.values()) next_index = None if pipeline_name_prefix is not None: pipelines_fetched = filter( lambda x: pipeline_name_prefix in x.pipeline_name, pipelines_fetched, ) def format_time(x: Any) -> str: return ( x if isinstance(x, str) else datetime.fromtimestamp(x).strftime("%Y-%m-%d %H:%M:%S") ) if created_after is not None: pipelines_fetched = filter( lambda x: x.creation_time > format_time(created_after), pipelines_fetched, ) if created_before is not None: pipelines_fetched = filter( lambda x: x.creation_time < format_time(created_before), pipelines_fetched, ) sort_key = "pipeline_name" if sort_by == "Name" else "creation_time" pipelines_fetched = sorted( pipelines_fetched, key=lambda pipeline_fetched: getattr(pipeline_fetched, sort_key), reverse=sort_order != "Ascending", ) pipeline_summaries = [ { "PipelineArn": pipeline_data.arn, "PipelineName": pipeline_data.pipeline_name, "PipelineDisplayName": pipeline_data.pipeline_display_name, "PipelineDescription": pipeline_data.pipeline_description, "RoleArn": pipeline_data.role_arn, "CreationTime": pipeline_data.creation_time, "LastModifiedTime": pipeline_data.last_modified_time, "LastExecutionTime": pipeline_data.last_execution_time, } for pipeline_data in pipelines_fetched ] return { "PipelineSummaries": pipeline_summaries, "NextToken": str(next_index) if next_index is not None else None, } def list_processing_jobs( self, next_token: str, max_results: int, creation_time_after: str, creation_time_before: str, last_modified_time_after: str, last_modified_time_before: str, name_contains: str, status_equals: str, ) -> Dict[str, Any]: if next_token: try: starting_index = int(next_token) if starting_index > len(self.processing_jobs): raise ValueError # invalid next_token except ValueError: raise AWSValidationException('Invalid pagination token because "{0}".') else: starting_index = 0 if max_results: end_index = max_results + starting_index processing_jobs_fetched: Iterable[FakeProcessingJob] = list( self.processing_jobs.values() )[starting_index:end_index] if end_index >= len(self.processing_jobs): next_index = None else: next_index = end_index else: processing_jobs_fetched = list(self.processing_jobs.values()) next_index = None if name_contains is not None: processing_jobs_fetched = filter( lambda x: name_contains in x.processing_job_name, processing_jobs_fetched, ) if creation_time_after is not None: processing_jobs_fetched = filter( lambda x: x.creation_time > creation_time_after, processing_jobs_fetched ) if creation_time_before is not None: processing_jobs_fetched = filter( lambda x: x.creation_time < creation_time_before, processing_jobs_fetched, ) if last_modified_time_after is not None: processing_jobs_fetched = filter( lambda x: x.last_modified_time > last_modified_time_after, processing_jobs_fetched, ) if last_modified_time_before is not None: processing_jobs_fetched = filter( lambda x: x.last_modified_time < last_modified_time_before, processing_jobs_fetched, ) if status_equals is not None: processing_jobs_fetched = filter( lambda x: x.processing_job_status == status_equals, processing_jobs_fetched, ) processing_job_summaries = [ { "ProcessingJobName": processing_job_data.processing_job_name, "ProcessingJobArn": processing_job_data.arn, "CreationTime": processing_job_data.creation_time, "ProcessingEndTime": processing_job_data.processing_end_time, "LastModifiedTime": processing_job_data.last_modified_time, "ProcessingJobStatus": processing_job_data.processing_job_status, } for processing_job_data in processing_jobs_fetched ] return { "ProcessingJobSummaries": processing_job_summaries, "NextToken": str(next_index) if next_index is not None else None, } def create_transform_job( self, transform_job_name: str, model_name: str, max_concurrent_transforms: int, model_client_config: Dict[str, int], max_payload_in_mb: int, batch_strategy: str, environment: Dict[str, str], transform_input: Dict[str, Union[Dict[str, str], str]], transform_output: Dict[str, str], data_capture_config: Dict[str, Union[str, bool]], transform_resources: Dict[str, Union[str, int]], data_processing: Dict[str, str], tags: Dict[str, str], experiment_config: Dict[str, str], ) -> FakeTransformJob: transform_job = FakeTransformJob( account_id=self.account_id, region_name=self.region_name, transform_job_name=transform_job_name, model_name=model_name, max_concurrent_transforms=max_concurrent_transforms, model_client_config=model_client_config, max_payload_in_mb=max_payload_in_mb, batch_strategy=batch_strategy, environment=environment, transform_input=transform_input, transform_output=transform_output, data_capture_config=data_capture_config, transform_resources=transform_resources, data_processing=data_processing, tags=tags, experiment_config=experiment_config, ) self.transform_jobs[transform_job_name] = transform_job return transform_job def list_transform_jobs( self, next_token: str, max_results: int, creation_time_after: str, creation_time_before: str, last_modified_time_after: str, last_modified_time_before: str, name_contains: str, status_equals: str, ) -> Dict[str, Any]: if next_token: try: starting_index = int(next_token) if starting_index > len(self.transform_jobs): raise ValueError # invalid next_token except ValueError: raise AWSValidationException('Invalid pagination token because "{0}".') else: starting_index = 0 if max_results: end_index = max_results + starting_index transform_jobs_fetched: Iterable[FakeTransformJob] = list( self.transform_jobs.values() )[starting_index:end_index] if end_index >= len(self.transform_jobs): next_index = None else: next_index = end_index else: transform_jobs_fetched = list(self.transform_jobs.values()) next_index = None if name_contains is not None: transform_jobs_fetched = filter( lambda x: name_contains in x.transform_job_name, transform_jobs_fetched ) if creation_time_after is not None: transform_jobs_fetched = filter( lambda x: x.creation_time > creation_time_after, transform_jobs_fetched ) if creation_time_before is not None: transform_jobs_fetched = filter( lambda x: x.creation_time < creation_time_before, transform_jobs_fetched ) if last_modified_time_after is not None: transform_jobs_fetched = filter( lambda x: x.last_modified_time > last_modified_time_after, transform_jobs_fetched, ) if last_modified_time_before is not None: transform_jobs_fetched = filter( lambda x: x.last_modified_time < last_modified_time_before, transform_jobs_fetched, ) if status_equals is not None: transform_jobs_fetched = filter( lambda x: x.transform_job_status == status_equals, transform_jobs_fetched, ) transform_job_summaries = [ { "TransformJobName": transform_job_data.transform_job_name, "TransformJobArn": transform_job_data.arn, "CreationTime": transform_job_data.creation_time, "TransformEndTime": transform_job_data.transform_end_time, "LastModifiedTime": transform_job_data.last_modified_time, "TransformJobStatus": transform_job_data.transform_job_status, } for transform_job_data in transform_jobs_fetched ] return { "TransformJobSummaries": transform_job_summaries, "NextToken": str(next_index) if next_index is not None else None, } def describe_transform_job(self, transform_job_name: str) -> Dict[str, Any]: try: return self.transform_jobs[transform_job_name].response_object except KeyError: arn = FakeTransformJob.arn_formatter( transform_job_name, self.account_id, self.region_name ) message = f"Could not find transform job '{arn}'." raise ValidationError(message=message) def create_training_job( self, training_job_name: str, hyper_parameters: Dict[str, str], algorithm_specification: Dict[str, Any], role_arn: str, input_data_config: List[Dict[str, Any]], output_data_config: Dict[str, str], resource_config: Dict[str, Any], vpc_config: Dict[str, List[str]], stopping_condition: Dict[str, int], tags: List[Dict[str, str]], enable_network_isolation: bool, enable_inter_container_traffic_encryption: bool, enable_managed_spot_training: bool, checkpoint_config: Dict[str, str], debug_hook_config: Dict[str, Any], debug_rule_configurations: List[Dict[str, Any]], tensor_board_output_config: Dict[str, str], experiment_config: Dict[str, str], ) -> FakeTrainingJob: training_job = FakeTrainingJob( account_id=self.account_id, region_name=self.region_name, training_job_name=training_job_name, hyper_parameters=hyper_parameters, algorithm_specification=algorithm_specification, role_arn=role_arn, input_data_config=input_data_config, output_data_config=output_data_config, resource_config=resource_config, vpc_config=vpc_config, stopping_condition=stopping_condition, tags=tags, enable_network_isolation=enable_network_isolation, enable_inter_container_traffic_encryption=enable_inter_container_traffic_encryption, enable_managed_spot_training=enable_managed_spot_training, checkpoint_config=checkpoint_config, debug_hook_config=debug_hook_config, debug_rule_configurations=debug_rule_configurations, tensor_board_output_config=tensor_board_output_config, experiment_config=experiment_config, ) self.training_jobs[training_job_name] = training_job return training_job def describe_training_job(self, training_job_name: str) -> Dict[str, Any]: try: return self.training_jobs[training_job_name].response_object except KeyError: arn = FakeTrainingJob.arn_formatter( training_job_name, self.account_id, self.region_name ) message = f"Could not find training job '{arn}'." raise ValidationError(message=message) def list_training_jobs( self, next_token: str, max_results: int, creation_time_after: str, creation_time_before: str, last_modified_time_after: str, last_modified_time_before: str, name_contains: str, status_equals: str, ) -> Dict[str, Any]: if next_token: try: starting_index = int(next_token) if starting_index > len(self.training_jobs): raise ValueError # invalid next_token except ValueError: raise AWSValidationException('Invalid pagination token because "{0}".') else: starting_index = 0 if max_results: end_index = max_results + starting_index training_jobs_fetched: Iterable[FakeTrainingJob] = list( self.training_jobs.values() )[starting_index:end_index] if end_index >= len(self.training_jobs): next_index = None else: next_index = end_index else: training_jobs_fetched = list(self.training_jobs.values()) next_index = None if name_contains is not None: training_jobs_fetched = filter( lambda x: name_contains in x.training_job_name, training_jobs_fetched ) if creation_time_after is not None: training_jobs_fetched = filter( lambda x: x.creation_time > creation_time_after, training_jobs_fetched ) if creation_time_before is not None: training_jobs_fetched = filter( lambda x: x.creation_time < creation_time_before, training_jobs_fetched ) if last_modified_time_after is not None: training_jobs_fetched = filter( lambda x: x.last_modified_time > last_modified_time_after, training_jobs_fetched, ) if last_modified_time_before is not None: training_jobs_fetched = filter( lambda x: x.last_modified_time < last_modified_time_before, training_jobs_fetched, ) if status_equals is not None: training_jobs_fetched = filter( lambda x: x.training_job_status == status_equals, training_jobs_fetched ) training_job_summaries = [ { "TrainingJobName": training_job_data.training_job_name, "TrainingJobArn": training_job_data.arn, "CreationTime": training_job_data.creation_time, "TrainingEndTime": training_job_data.training_end_time, "LastModifiedTime": training_job_data.last_modified_time, "TrainingJobStatus": training_job_data.training_job_status, } for training_job_data in training_jobs_fetched ] return { "TrainingJobSummaries": training_job_summaries, "NextToken": str(next_index) if next_index is not None else None, } def update_endpoint_weights_and_capacities( self, endpoint_name: str, desired_weights_and_capacities: List[Dict[str, Any]] ) -> str: # Validate inputs endpoint = self.endpoints.get(endpoint_name, None) if not endpoint: arn = FakeEndpoint.arn_formatter( endpoint_name, self.account_id, self.region_name ) raise AWSValidationException(f'Could not find endpoint "{arn}".') names_checked = [] for variant_config in desired_weights_and_capacities: name = variant_config.get("VariantName") if name in names_checked: raise AWSValidationException( f'The variant name "{name}" was non-unique within the request.' ) if not any( variant["VariantName"] == name for variant in endpoint.production_variants ): raise AWSValidationException( f'The variant name(s) "{name}" is/are not present within endpoint configuration "{endpoint.endpoint_config_name}".' ) names_checked.append(name) # Update endpoint variants endpoint.endpoint_status = "Updating" for variant_config in desired_weights_and_capacities: name = variant_config.get("VariantName") desired_weight = variant_config.get("DesiredWeight") desired_instance_count = variant_config.get("DesiredInstanceCount") for variant in endpoint.production_variants: if variant.get("VariantName") == name: variant["DesiredWeight"] = desired_weight variant["CurrentWeight"] = desired_weight variant["DesiredInstanceCount"] = desired_instance_count variant["CurrentInstanceCount"] = desired_instance_count break endpoint.endpoint_status = "InService" return endpoint.arn def create_model_package_group( self, model_package_group_name: str, model_package_group_description: str, tags: Optional[List[Dict[str, str]]] = None, ) -> str: self.model_package_groups[model_package_group_name] = ModelPackageGroup( model_package_group_name=model_package_group_name, model_package_group_description=model_package_group_description, account_id=self.account_id, region_name=self.region_name, tags=tags or [], ) return self.model_package_groups[model_package_group_name].arn def _get_versioned_or_not( self, model_package_type: Optional[str], model_package_version: Optional[int] ) -> bool: if model_package_type == "Versioned": return model_package_version is not None elif model_package_type == "Unversioned" or model_package_type is None: return model_package_version is None elif model_package_type == "Both": return True raise ValueError(f"Invalid model package type: {model_package_type}") @paginate(pagination_model=PAGINATION_MODEL) def list_model_package_groups( self, creation_time_after: Optional[int], creation_time_before: Optional[int], name_contains: Optional[str], sort_by: Optional[str], sort_order: Optional[str], ) -> List[ModelPackageGroup]: if isinstance(creation_time_before, int): creation_time_before_datetime = datetime.fromtimestamp( creation_time_before, tz=tzutc() ) if isinstance(creation_time_after, int): creation_time_after_datetime = datetime.fromtimestamp( creation_time_after, tz=tzutc() ) model_package_group_summary_list = list( filter( lambda x: ( creation_time_after is None or x.creation_time > creation_time_after_datetime ) and ( creation_time_before is None or x.creation_time < creation_time_before_datetime ) and ( name_contains is None or x.model_package_group_name.find(name_contains) != -1 ), self.model_package_groups.values(), ) ) model_package_group_summary_list = list( sorted( model_package_group_summary_list, key={ "Name": lambda x: x.model_package_group_name, "CreationTime": lambda x: x.creation_time, None: lambda x: x.creation_time, }[sort_by], reverse=sort_order == "Descending", ) ) return model_package_group_summary_list def describe_model_package_group( self, model_package_group_name: str ) -> ModelPackageGroup: model_package_group = self.model_package_groups.get(model_package_group_name) if model_package_group is None: model_package_group_arn = arn_formatter( region_name=self.region_name, account_id=self.account_id, _type="model-package-group", _id=f"{model_package_group_name}", ) raise ValidationError( f"ModelPackageGroup {model_package_group_arn} does not exist." ) return model_package_group @paginate(pagination_model=PAGINATION_MODEL) def list_model_packages( self, creation_time_after: Optional[int], creation_time_before: Optional[int], name_contains: Optional[str], model_approval_status: Optional[str], model_package_group_name: Optional[str], model_package_type: Optional[str], sort_by: Optional[str], sort_order: Optional[str], ) -> List[ModelPackage]: if isinstance(creation_time_before, int): creation_time_before_datetime = datetime.fromtimestamp( creation_time_before, tz=tzutc() ) if isinstance(creation_time_after, int): creation_time_after_datetime = datetime.fromtimestamp( creation_time_after, tz=tzutc() ) if model_package_group_name is not None: model_package_type = "Versioned" if re.match(ARN_PARTITION_REGEX, model_package_group_name): model_package_group_name = model_package_group_name.split("/")[-1] model_package_summary_list = list( filter( lambda x: ( creation_time_after is None or x.creation_time > creation_time_after_datetime ) and ( creation_time_before is None or x.creation_time < creation_time_before_datetime ) and ( name_contains is None or x.model_package_name.find(name_contains) != -1 ) and ( model_approval_status is None or x.model_approval_status == model_approval_status ) and ( model_package_group_name is None or x.model_package_group_name == model_package_group_name ) and self._get_versioned_or_not( model_package_type, x.model_package_version ), self.model_packages.values(), ) ) model_package_summary_list = list( sorted( model_package_summary_list, key={ "Name": lambda x: x.model_package_name, "CreationTime": lambda x: x.creation_time, None: lambda x: x.creation_time, }[sort_by], reverse=sort_order == "Descending", ) ) return model_package_summary_list def describe_model_package(self, model_package_name: str) -> ModelPackage: model_package_name_mapped = self.model_package_name_mapping.get( model_package_name, model_package_name ) model_package = self.model_packages.get(model_package_name_mapped) if model_package is None: raise ValidationError(f"Model package {model_package_name} not found") return model_package def update_model_package( self, model_package_arn: str, model_approval_status: Optional[str], approval_description: Optional[str], customer_metadata_properties: Optional[Dict[str, str]], customer_metadata_properties_to_remove: List[str], additional_inference_specifications_to_add: Optional[List[Any]], ) -> str: model_package_name_mapped = self.model_package_name_mapping.get( model_package_arn, model_package_arn ) model_package = self.model_packages.get(model_package_name_mapped) if model_package is None: raise ValidationError(f"Model package {model_package_arn} not found") model_package.set_model_approval_status(model_approval_status) model_package.approval_description = approval_description model_package.customer_metadata_properties = customer_metadata_properties model_package.remove_customer_metadata_property( customer_metadata_properties_to_remove ) model_package.add_additional_inference_specifications( additional_inference_specifications_to_add ) model_package.modifications_done() return model_package.arn def create_model_package( self, model_package_name: Optional[str], model_package_group_name: Optional[str], model_package_description: Optional[str], inference_specification: Any, validation_specification: Any, source_algorithm_specification: Any, certify_for_marketplace: bool, tags: Any, model_approval_status: Optional[str], metadata_properties: Any, model_metrics: Any, client_token: Any, customer_metadata_properties: Any, drift_check_baselines: Any, domain: Any, task: Any, sample_payload_url: Any, additional_inference_specifications: Any, ) -> str: model_package_version = None if model_package_group_name and model_package_name: raise AWSValidationException( "An error occurred (ValidationException) when calling the CreateModelPackage operation: Both ModelPackageName and ModelPackageGroupName are provided in the input. Cannot determine which one to use." ) elif not model_package_group_name and not model_package_name: raise AWSValidationException( "An error ocurred (ValidationException) when calling the CreateModelPackag operation: Missing ARN." ) elif model_package_group_name: model_package_type = "Versioned" model_package_name = model_package_group_name model_packages_for_group = [ x for x in self.model_packages.values() if x.model_package_group_name == model_package_group_name ] if model_package_group_name not in self.model_package_groups: raise AWSValidationException( "An error ocurred (ValidationException) when calling the CreateModelPackage operation: Model Package Group does not exist." ) model_package_version = len(model_packages_for_group) + 1 else: model_package_type = "Unversioned" model_package = ModelPackage( model_package_name=cast(str, model_package_name), model_package_group_name=model_package_group_name, model_package_description=model_package_description, inference_specification=inference_specification, validation_specification=validation_specification, source_algorithm_specification=source_algorithm_specification, certify_for_marketplace=certify_for_marketplace, tags=tags, model_approval_status=model_approval_status, metadata_properties=metadata_properties, model_metrics=model_metrics, customer_metadata_properties=customer_metadata_properties, drift_check_baselines=drift_check_baselines, domain=domain, task=task, sample_payload_url=sample_payload_url, additional_inference_specifications=additional_inference_specifications, model_package_version=model_package_version, approval_description=None, region_name=self.region_name, account_id=self.account_id, client_token=client_token, model_package_type=model_package_type, ) self.model_package_name_mapping[model_package.model_package_name] = ( model_package.arn ) self.model_package_name_mapping[model_package.arn] = model_package.arn self.model_packages[model_package.arn] = model_package return model_package.arn def create_feature_group( self, feature_group_name: str, record_identifier_feature_name: str, event_time_feature_name: str, feature_definitions: List[Dict[str, str]], offline_store_config: Dict[str, Any], role_arn: str, tags: Any, ) -> str: feature_group_arn = arn_formatter( region_name=self.region_name, account_id=self.account_id, _type="feature-group", _id=f"{feature_group_name.lower()}", ) if feature_group_arn in self.feature_groups: raise ResourceInUseException( message=f"An error occurred (ResourceInUse) when calling the CreateFeatureGroup operation: Resource Already Exists: FeatureGroup with name {feature_group_name} already exists. Choose a different name.\nInfo: Feature Group '{feature_group_name}' already exists." ) feature_group = FeatureGroup( feature_group_name=feature_group_name, record_identifier_feature_name=record_identifier_feature_name, event_time_feature_name=event_time_feature_name, feature_definitions=feature_definitions, offline_store_config=offline_store_config, role_arn=role_arn, region_name=self.region_name, account_id=self.account_id, tags=tags, ) self.feature_groups[feature_group.arn] = feature_group return feature_group.arn def describe_feature_group( self, feature_group_name: str, ) -> Dict[str, Any]: feature_group_arn = arn_formatter( region_name=self.region_name, account_id=self.account_id, _type="feature-group", _id=f"{feature_group_name.lower()}", ) feature_group = self.feature_groups[feature_group_arn] return feature_group.describe() def create_cluster( self, cluster_name: str, instance_groups: List[Dict[str, Any]], vpc_config: Dict[str, List[str]], tags: Any, ) -> str: cluster = Cluster( cluster_name=cluster_name, region_name=self.region_name, account_id=self.account_id, instance_groups=instance_groups, vpc_config=vpc_config, tags=tags, ) self.clusters[cluster_name] = cluster # create Cluster Nodes for instance_group in instance_groups: for i in range(instance_group["TargetCount"]): node_id = f"{instance_group['InstanceGroupName']}-{i}" fake_cluster_node = ClusterNode( region_name=self.region_name, account_id=self.account_id, cluster_name=cluster_name, instance_group_name=instance_group["InstanceGroupName"], instance_type=instance_group["InstanceType"], life_cycle_config=instance_group["LifeCycleConfig"], execution_role=instance_group["ExecutionRole"], node_id=node_id, threads_per_core=instance_group["ThreadsPerCore"], ) cluster.nodes[node_id] = fake_cluster_node return cluster.arn def describe_cluster(self, cluster_name: str) -> Dict[str, Any]: if cluster_name.startswith(f"arn:{self.partition}:sagemaker:"): cluster_name = (cluster_name.split(":")[-1]).split("/")[-1] cluster = self.clusters.get(cluster_name) if not cluster: raise ValidationError(message=f"Could not find cluster '{cluster_name}'.") return cluster.describe() def delete_cluster(self, cluster_name: str) -> str: if cluster_name.startswith(f"arn:{self.partition}:sagemaker:"): cluster_name = (cluster_name.split(":")[-1]).split("/")[-1] cluster = self.clusters.get(cluster_name) if not cluster: raise ValidationError(message=f"Could not find cluster '{cluster_name}'.") arn = cluster.arn del self.clusters[cluster_name] return arn def describe_cluster_node(self, cluster_name: str, node_id: str) -> Dict[str, Any]: if cluster_name.startswith(f"arn:{self.partition}:sagemaker:"): cluster_name = (cluster_name.split(":")[-1]).split("/")[-1] cluster = self.clusters.get(cluster_name) if not cluster: raise ValidationError(message=f"Could not find cluster '{cluster_name}'.") if node_id in cluster.nodes: return cluster.nodes[node_id].describe() else: raise ValidationError( message=f"Could not find node '{node_id}' in cluster '{cluster_name}'." ) @paginate(pagination_model=PAGINATION_MODEL) def list_clusters( self, creation_time_after: Optional[datetime], creation_time_before: Optional[datetime], name_contains: Optional[str], sort_by: Optional[str], sort_order: Optional[str], ) -> List[Cluster]: clusters = list(self.clusters.values()) if name_contains: clusters = [i for i in clusters if name_contains in i.cluster_name] if creation_time_before: clusters = [ i for i in clusters if i.creation_time < str(creation_time_before) ] if creation_time_after: clusters = [ i for i in clusters if i.creation_time > str(creation_time_after) ] reverse = sort_order == "Descending" if sort_by == "Name": clusters = sorted(clusters, key=lambda x: x.cluster_name, reverse=reverse) if sort_by == "CreationTime" or sort_by is None: clusters = sorted(clusters, key=lambda x: x.creation_time, reverse=reverse) return clusters @paginate(pagination_model=PAGINATION_MODEL) def list_cluster_nodes( self, cluster_name: str, creation_time_after: Optional[str], creation_time_before: Optional[str], instance_group_name_contains: Optional[str], sort_by: Optional[str], sort_order: Optional[str], ) -> List[ClusterNode]: if cluster_name.startswith(f"arn:{self.partition}:sagemaker:"): cluster_name = (cluster_name.split(":")[-1]).split("/")[-1] cluster = self.clusters.get(cluster_name) if not cluster: raise ValidationError(message=f"Could not find cluster '{cluster_name}'.") nodes_list = list(cluster.nodes.values()) if instance_group_name_contains: nodes_list = [ i for i in nodes_list if instance_group_name_contains in i.instance_group_name ] if creation_time_before: nodes_list = [ i for i in nodes_list if i.launch_time < str(creation_time_before) ] if creation_time_after: nodes_list = [ i for i in nodes_list if i.launch_time > str(creation_time_after) ] reverse = sort_order == "Descending" if sort_by == "Name": nodes_list = sorted( nodes_list, key=lambda x: x.instance_group_name, reverse=reverse ) if sort_by == "CreationTime" or sort_by is None: nodes_list = sorted( nodes_list, key=lambda x: x.launch_time, reverse=reverse ) return nodes_list def create_model_bias_job_definition( self, account_id: str, job_definition_name: str, tags: List[Dict[str, str]] = [], role_arn: str = "", job_resources: Optional[Dict[str, Any]] = None, stopping_condition: Optional[Dict[str, Any]] = None, environment: Optional[Dict[str, str]] = None, network_config: Optional[Dict[str, Any]] = None, model_bias_baseline_config: Optional[Dict[str, Any]] = None, model_bias_app_specification: Optional[Dict[str, Any]] = None, model_bias_job_input: Optional[Dict[str, Any]] = None, model_bias_job_output_config: Optional[Dict[str, Any]] = None, ) -> Dict[str, str]: job_definition = FakeModelBiasJobDefinition( account_id=account_id, region_name=self.region_name, job_definition_name=job_definition_name, tags=tags, role_arn=role_arn, job_resources=job_resources, stopping_condition=stopping_condition, environment=environment, network_config=network_config, model_bias_baseline_config=model_bias_baseline_config, model_bias_app_specification=model_bias_app_specification, model_bias_job_input=model_bias_job_input, model_bias_job_output_config=model_bias_job_output_config, ) self.model_bias_job_definitions[job_definition_name] = job_definition return job_definition.response_create @paginate(pagination_model=PAGINATION_MODEL) def list_model_bias_job_definitions(self) -> List[Dict[str, str]]: return [job.summary_object for job in self.model_bias_job_definitions.values()] def describe_model_bias_job_definition( self, job_definition_name: str ) -> Dict[str, Any]: job_definition = self.model_bias_job_definitions.get(job_definition_name) if job_definition is None: raise ResourceNotFound(f"Job definition {job_definition_name} not found") return job_definition.response_object def delete_model_bias_job_definition(self, job_definition_name: str) -> None: if job_definition_name in self.model_bias_job_definitions: del self.model_bias_job_definitions[job_definition_name] else: raise ResourceNotFound(f"Job definition {job_definition_name} not found") def create_auto_ml_job_v2( self, auto_ml_job_name: str, auto_ml_job_input_data_config: List[Dict[str, Any]], output_data_config: Dict[str, Any], auto_ml_problem_type_config: Dict[str, Any], role_arn: str, tags: Optional[List[Dict[str, str]]], security_config: Optional[Dict[str, Any]], auto_ml_job_objective: Optional[Dict[str, str]], model_deploy_config: Optional[Dict[str, Any]], data_split_config: Optional[Dict[str, Any]], ) -> str: auto_ml_job = AutoMLJob( auto_ml_job_name=auto_ml_job_name, auto_ml_job_input_data_config=auto_ml_job_input_data_config, output_data_config=output_data_config, auto_ml_problem_type_config=auto_ml_problem_type_config, role_arn=role_arn, region_name=self.region_name, account_id=self.account_id, tags=tags, security_config=security_config, auto_ml_job_objective=auto_ml_job_objective, model_deploy_config=model_deploy_config, data_split_config=data_split_config, ) self.auto_ml_jobs[auto_ml_job_name] = auto_ml_job return auto_ml_job.arn def describe_auto_ml_job_v2(self, auto_ml_job_name: str) -> Dict[str, Any]: if auto_ml_job_name not in self.auto_ml_jobs: raise ResourceNotFound( f"Could not find AutoML job with name {auto_ml_job_name}." ) auto_ml_job = self.auto_ml_jobs[auto_ml_job_name] return auto_ml_job.describe() @paginate(pagination_model=PAGINATION_MODEL) def list_auto_ml_jobs( self, creation_time_after: Optional[str], creation_time_before: Optional[str], last_modified_time_after: Optional[str], last_modified_time_before: Optional[str], name_contains: Optional[str], status_equals: Optional[str], sort_order: Optional[str], sort_by: Optional[str], ) -> List[AutoMLJob]: auto_ml_jobs = list(self.auto_ml_jobs.values()) if name_contains: auto_ml_jobs = [ i for i in auto_ml_jobs if name_contains in i.auto_ml_job_name ] if status_equals: auto_ml_jobs = [ i for i in auto_ml_jobs if status_equals == i.auto_ml_job_status ] if creation_time_before: auto_ml_jobs = [ i for i in auto_ml_jobs if i.creation_time < str(creation_time_before) ] if creation_time_after: auto_ml_jobs = [ i for i in auto_ml_jobs if i.creation_time > str(creation_time_after) ] if last_modified_time_before: auto_ml_jobs = [ i for i in auto_ml_jobs if i.last_modified_time < str(last_modified_time_before) ] if last_modified_time_after: auto_ml_jobs = [ i for i in auto_ml_jobs if i.last_modified_time > str(last_modified_time_after) ] reverse = sort_order == "Descending" if sort_by == "Status": auto_ml_jobs = sorted( auto_ml_jobs, key=lambda x: x.auto_ml_job_status, reverse=reverse ) if sort_by == "CreationTime": auto_ml_jobs = sorted( auto_ml_jobs, key=lambda x: x.creation_time, reverse=reverse ) if sort_by == "Name" or sort_by is None: auto_ml_jobs = sorted( auto_ml_jobs, key=lambda x: x.auto_ml_job_name, reverse=reverse ) return auto_ml_jobs def stop_auto_ml_job(self, auto_ml_job_name: str) -> None: if auto_ml_job_name not in self.auto_ml_jobs: raise ResourceNotFound( f"Could not find AutoML job with name {auto_ml_job_name}." ) auto_ml_job = self.auto_ml_jobs[auto_ml_job_name] auto_ml_job.auto_ml_job_status = "Stopped" auto_ml_job.auto_ml_job_secondary_status = "Stopped" @paginate(pagination_model=PAGINATION_MODEL) def list_endpoints( self, sort_by: Optional[str], sort_order: Optional[str], name_contains: Optional[str], creation_time_before: Optional[str], creation_time_after: Optional[str], last_modified_time_before: Optional[str], last_modified_time_after: Optional[str], status_equals: Optional[str], ) -> List[FakeEndpoint]: endpoints = list(self.endpoints.values()) if name_contains: endpoints = [i for i in endpoints if name_contains in i.endpoint_name] if status_equals: endpoints = [i for i in endpoints if status_equals == i.endpoint_status] if creation_time_before: endpoints = [ i for i in endpoints if i.creation_time < str(creation_time_before) ] if creation_time_after: endpoints = [ i for i in endpoints if i.creation_time > str(creation_time_after) ] if last_modified_time_before: endpoints = [ i for i in endpoints if i.last_modified_time < str(last_modified_time_before) ] if last_modified_time_after: endpoints = [ i for i in endpoints if i.last_modified_time > str(last_modified_time_after) ] reverse = sort_order == "Descending" if sort_by == "Name": endpoints = sorted( endpoints, key=lambda x: x.endpoint_name, reverse=reverse ) elif sort_by == "Status": endpoints = sorted( endpoints, key=lambda x: x.endpoint_status, reverse=reverse ) else: endpoints = sorted( endpoints, key=lambda x: x.creation_time, reverse=reverse ) return endpoints @paginate(pagination_model=PAGINATION_MODEL) def list_endpoint_configs( self, sort_by: Optional[str], sort_order: Optional[str], name_contains: Optional[str], creation_time_before: Optional[str], creation_time_after: Optional[str], ) -> List[FakeEndpointConfig]: endpoint_configs = list(self.endpoint_configs.values()) if name_contains: endpoint_configs = [ i for i in endpoint_configs if name_contains in i.endpoint_config_name ] if creation_time_before: endpoint_configs = [ i for i in endpoint_configs if i.creation_time < str(creation_time_before) ] if creation_time_after: endpoint_configs = [ i for i in endpoint_configs if i.creation_time > str(creation_time_after) ] reverse = sort_order == "Descending" if sort_by == "Name": endpoint_configs = sorted( endpoint_configs, key=lambda x: x.endpoint_config_name, reverse=reverse ) else: endpoint_configs = sorted( endpoint_configs, key=lambda x: x.creation_time, reverse=reverse ) return endpoint_configs def create_compilation_job( self, compilation_job_name: str, role_arn: str, output_config: Dict[str, Any], stopping_condition: Dict[str, Any], model_package_version_arn: Optional[str], input_config: Optional[Dict[str, Any]], vpc_config: Optional[Dict[str, Any]], tags: Optional[List[Dict[str, str]]], ) -> str: compilation_job = CompilationJob( compilation_job_name=compilation_job_name, role_arn=role_arn, region_name=self.region_name, account_id=self.account_id, model_package_version_arn=model_package_version_arn, input_config=input_config, output_config=output_config, vpc_config=vpc_config, stopping_condition=stopping_condition, tags=tags, ) self.compilation_jobs[compilation_job_name] = compilation_job return compilation_job.arn def describe_compilation_job(self, compilation_job_name: str) -> Dict[str, Any]: if compilation_job_name not in self.compilation_jobs: raise ResourceNotFound( message=f"Could not find compilation job '{compilation_job_name}'." ) compilation_job = self.compilation_jobs[compilation_job_name] return compilation_job.describe() @paginate(pagination_model=PAGINATION_MODEL) def list_compilation_jobs( self, creation_time_after: Optional[str], creation_time_before: Optional[str], last_modified_time_after: Optional[str], last_modified_time_before: Optional[str], name_contains: Optional[str], status_equals: Optional[str], sort_by: Optional[str], sort_order: Optional[str], ) -> List[CompilationJob]: compilation_jobs = list(self.compilation_jobs.values()) if name_contains: compilation_jobs = [ i for i in compilation_jobs if name_contains in i.compilation_job_name ] if creation_time_before: compilation_jobs = [ i for i in compilation_jobs if i.creation_time < str(creation_time_before) ] if creation_time_after: compilation_jobs = [ i for i in compilation_jobs if i.creation_time > str(creation_time_after) ] if last_modified_time_before: compilation_jobs = [ i for i in compilation_jobs if i.last_modified_time < str(last_modified_time_before) ] if creation_time_after: compilation_jobs = [ i for i in compilation_jobs if i.last_modified_time > str(last_modified_time_after) ] if status_equals: compilation_jobs = [ i for i in compilation_jobs if i.compilation_job_status == status_equals ] reverse = sort_order == "Descending" if sort_by == "Name": compilation_jobs = sorted( compilation_jobs, key=lambda x: x.compilation_job_name, reverse=reverse ) if sort_by == "Status": compilation_jobs = sorted( compilation_jobs, key=lambda x: x.compilation_job_status, reverse=reverse, ) if sort_by == "CreationTime" or sort_by is None: compilation_jobs = sorted( compilation_jobs, key=lambda x: x.creation_time, reverse=reverse ) return compilation_jobs def delete_compilation_job(self, compilation_job_name: str) -> None: if compilation_job_name not in self.compilation_jobs: raise ResourceNotFound( message=f"Could not find compilation job '{compilation_job_name}'." ) del self.compilation_jobs[compilation_job_name] def create_domain( self, domain_name: str, auth_mode: str, default_user_settings: Dict[str, Any], subnet_ids: List[str], vpc_id: str, domain_settings: Optional[Dict[str, Any]], tags: Optional[List[Dict[str, str]]], app_network_access_type: Optional[str], home_efs_file_system_kms_key_id: Optional[str], kms_key_id: Optional[str], app_security_group_management: Optional[str], default_space_settings: Optional[Dict[str, Any]], ) -> Dict[str, Any]: domain = Domain( domain_name=domain_name, auth_mode=auth_mode, default_user_settings=default_user_settings, subnet_ids=subnet_ids, vpc_id=vpc_id, domain_settings=domain_settings, tags=tags, app_network_access_type=app_network_access_type, home_efs_file_system_kms_key_id=home_efs_file_system_kms_key_id, kms_key_id=kms_key_id, app_security_group_management=app_security_group_management, default_space_settings=default_space_settings, region_name=self.region_name, account_id=self.account_id, ) self.domains[domain.id] = domain return {"DomainArn": domain.arn, "Url": domain.url} def describe_domain(self, domain_id: str) -> Dict[str, Any]: if domain_id not in self.domains: raise ValidationError(message=f"Could not find domain '{domain_id}'.") return self.domains[domain_id].describe() @paginate(pagination_model=PAGINATION_MODEL) def list_domains(self) -> List[Domain]: return list(self.domains.values()) def delete_domain( self, domain_id: str, retention_policy: Optional[Dict[str, str]] ) -> None: # 'retention_policy' parameter is not used if domain_id not in self.domains: raise ValidationError(message=f"Could not find domain '{domain_id}'.") del self.domains[domain_id] def create_model_explainability_job_definition( self, job_definition_name: str, model_explainability_baseline_config: Optional[Dict[str, Any]], model_explainability_app_specification: Dict[str, Any], model_explainability_job_input: Dict[str, Any], model_explainability_job_output_config: Dict[str, Any], job_resources: Dict[str, Any], network_config: Optional[Dict[str, Any]], role_arn: str, stopping_condition: Optional[Dict[str, Any]], tags: List[Dict[str, str]], ) -> str: model_explainability_job_definition = ModelExplainabilityJobDefinition( job_definition_name=job_definition_name, model_explainability_baseline_config=model_explainability_baseline_config, model_explainability_app_specification=model_explainability_app_specification, model_explainability_job_input=model_explainability_job_input, model_explainability_job_output_config=model_explainability_job_output_config, job_resources=job_resources, region_name=self.region_name, account_id=self.account_id, network_config=network_config, role_arn=role_arn, stopping_condition=stopping_condition, tags=tags, ) self.model_explainability_job_definitions[ model_explainability_job_definition.job_definition_name ] = model_explainability_job_definition return model_explainability_job_definition.arn def describe_model_explainability_job_definition( self, job_definition_name: str ) -> Dict[str, Any]: if job_definition_name not in self.model_explainability_job_definitions: raise ResourceNotFound( message=f"Could not find model explainability job definition with name '{job_definition_name}'." ) return self.model_explainability_job_definitions[job_definition_name].describe() @paginate(pagination_model=PAGINATION_MODEL) def list_model_explainability_job_definitions( self, endpoint_name: Optional[str], sort_by: Optional[str], sort_order: Optional[str], name_contains: Optional[str], creation_time_before: Optional[str], creation_time_after: Optional[str], ) -> List[ModelExplainabilityJobDefinition]: model_explainability_job_definitions = list( self.model_explainability_job_definitions.values() ) if endpoint_name: model_explainability_job_definitions = [ i for i in model_explainability_job_definitions if endpoint_name == i.endpoint_name ] if name_contains: model_explainability_job_definitions = [ i for i in model_explainability_job_definitions if name_contains in i.job_definition_name ] if creation_time_before: model_explainability_job_definitions = [ i for i in model_explainability_job_definitions if i.creation_time < str(creation_time_before) ] if creation_time_after: model_explainability_job_definitions = [ i for i in model_explainability_job_definitions if i.creation_time > str(creation_time_after) ] reverse = sort_order == "Descending" if sort_by == "Name": model_explainability_job_definitions = sorted( model_explainability_job_definitions, key=lambda x: x.job_definition_name, reverse=reverse, ) if sort_by == "CreationTime" or sort_by is None: model_explainability_job_definitions = sorted( model_explainability_job_definitions, key=lambda x: x.creation_time, reverse=reverse, ) return model_explainability_job_definitions def delete_model_explainability_job_definition( self, job_definition_name: str ) -> None: if job_definition_name not in self.model_explainability_job_definitions: raise ResourceNotFound( message=f"Could not find model explainability job definition with name '{job_definition_name}'." ) del self.model_explainability_job_definitions[job_definition_name] def create_hyper_parameter_tuning_job( self, hyper_parameter_tuning_job_name: str, hyper_parameter_tuning_job_config: Dict[str, Any], training_job_definition: Optional[Dict[str, Any]], training_job_definitions: Optional[List[Dict[str, Any]]], warm_start_config: Optional[Dict[str, Any]], tags: Optional[List[Dict[str, str]]], autotune: Optional[Dict[str, Any]], ) -> str: hyper_parameter_tuning_job = HyperParameterTuningJob( hyper_parameter_tuning_job_name=hyper_parameter_tuning_job_name, hyper_parameter_tuning_job_config=hyper_parameter_tuning_job_config, region_name=self.region_name, account_id=self.account_id, training_job_definition=training_job_definition, training_job_definitions=training_job_definitions, warm_start_config=warm_start_config, tags=tags, autotune=autotune, ) self.hyper_parameter_tuning_jobs[hyper_parameter_tuning_job_name] = ( hyper_parameter_tuning_job ) return hyper_parameter_tuning_job.arn def describe_hyper_parameter_tuning_job( self, hyper_parameter_tuning_job_name: str ) -> Dict[str, Any]: if hyper_parameter_tuning_job_name not in self.hyper_parameter_tuning_jobs: raise ResourceNotFound( message=f"Could not find hyper parameter tuning job '{hyper_parameter_tuning_job_name}'." ) return self.hyper_parameter_tuning_jobs[ hyper_parameter_tuning_job_name ].describe() @paginate(pagination_model=PAGINATION_MODEL) def list_hyper_parameter_tuning_jobs( self, sort_by: Optional[str], sort_order: Optional[str], name_contains: Optional[str], creation_time_after: Optional[str], creation_time_before: Optional[str], last_modified_time_after: Optional[str], last_modified_time_before: Optional[str], status_equals: Optional[str], ) -> List[HyperParameterTuningJob]: hyper_parameter_tuning_jobs = list(self.hyper_parameter_tuning_jobs.values()) if name_contains: hyper_parameter_tuning_jobs = [ i for i in hyper_parameter_tuning_jobs if name_contains in i.hyper_parameter_tuning_job_name ] if status_equals: hyper_parameter_tuning_jobs = [ i for i in hyper_parameter_tuning_jobs if status_equals == i.hyper_parameter_tuning_job_status ] if creation_time_before: hyper_parameter_tuning_jobs = [ i for i in hyper_parameter_tuning_jobs if i.creation_time < str(creation_time_before) ] if creation_time_after: hyper_parameter_tuning_jobs = [ i for i in hyper_parameter_tuning_jobs if i.creation_time > str(creation_time_after) ] if last_modified_time_before: hyper_parameter_tuning_jobs = [ i for i in hyper_parameter_tuning_jobs if i.last_modified_time < str(last_modified_time_before) ] if last_modified_time_after: hyper_parameter_tuning_jobs = [ i for i in hyper_parameter_tuning_jobs if i.last_modified_time > str(last_modified_time_after) ] reverse = sort_order == "Descending" if sort_by == "Name": hyper_parameter_tuning_jobs = sorted( hyper_parameter_tuning_jobs, key=lambda x: x.hyper_parameter_tuning_job_name, reverse=reverse, ) elif sort_by == "Status": hyper_parameter_tuning_jobs = sorted( hyper_parameter_tuning_jobs, key=lambda x: x.hyper_parameter_tuning_job_status, reverse=reverse, ) else: hyper_parameter_tuning_jobs = sorted( hyper_parameter_tuning_jobs, key=lambda x: x.creation_time, reverse=reverse, ) return hyper_parameter_tuning_jobs def delete_hyper_parameter_tuning_job( self, hyper_parameter_tuning_job_name: str ) -> None: if hyper_parameter_tuning_job_name not in self.hyper_parameter_tuning_jobs: raise ResourceNotFound( message=f"Could not find hyper parameter tuning job '{hyper_parameter_tuning_job_name}'." ) del self.hyper_parameter_tuning_jobs[hyper_parameter_tuning_job_name] def create_model_quality_job_definition( self, job_definition_name: str, model_quality_baseline_config: Optional[Dict[str, Any]], model_quality_app_specification: Dict[str, Any], model_quality_job_input: Dict[str, Any], model_quality_job_output_config: Dict[str, Any], job_resources: Dict[str, Any], network_config: Optional[Dict[str, Any]], role_arn: str, stopping_condition: Optional[Dict[str, Any]], tags: Optional[List[Dict[str, str]]], ) -> str: model_quality_job_definition = ModelQualityJobDefinition( job_definition_name=job_definition_name, model_quality_baseline_config=model_quality_baseline_config, model_quality_app_specification=model_quality_app_specification, model_quality_job_input=model_quality_job_input, model_quality_job_output_config=model_quality_job_output_config, job_resources=job_resources, network_config=network_config, role_arn=role_arn, stopping_condition=stopping_condition, region_name=self.region_name, account_id=self.account_id, tags=tags, ) self.model_quality_job_definitions[job_definition_name] = ( model_quality_job_definition ) return model_quality_job_definition.arn def describe_model_quality_job_definition( self, job_definition_name: str ) -> Dict[str, Any]: if job_definition_name not in self.model_quality_job_definitions: raise ResourceNotFound( message=f"Could not find model quality job definition '{job_definition_name}'." ) return self.model_quality_job_definitions[job_definition_name].describe() @paginate(pagination_model=PAGINATION_MODEL) def list_model_quality_job_definitions( self, endpoint_name: Optional[str], sort_by: Optional[str], sort_order: Optional[str], name_contains: Optional[str], creation_time_before: Optional[str], creation_time_after: Optional[str], ) -> List[ModelQualityJobDefinition]: model_quality_job_definitions = list( self.model_quality_job_definitions.values() ) if endpoint_name: model_quality_job_definitions = [ i for i in model_quality_job_definitions if endpoint_name == i.endpoint_name ] if name_contains: model_quality_job_definitions = [ i for i in model_quality_job_definitions if name_contains in i.job_definition_name ] if creation_time_before: model_quality_job_definitions = [ i for i in model_quality_job_definitions if i.creation_time < str(creation_time_before) ] if creation_time_after: model_quality_job_definitions = [ i for i in model_quality_job_definitions if i.creation_time > str(creation_time_after) ] reverse = sort_order == "Descending" if sort_by == "Name": model_quality_job_definitions = sorted( model_quality_job_definitions, key=lambda x: x.job_definition_name, reverse=reverse, ) if sort_by == "CreationTime" or sort_by is None: model_quality_job_definitions = sorted( model_quality_job_definitions, key=lambda x: x.creation_time, reverse=reverse, ) return model_quality_job_definitions def delete_model_quality_job_definition(self, job_definition_name: str) -> None: if job_definition_name not in self.model_quality_job_definitions: raise ResourceNotFound( message=f"Could not find model quality job definition '{job_definition_name}'." ) del self.model_quality_job_definitions[job_definition_name] def create_model_card( self, model_card_name: str, security_config: Optional[Dict[str, str]], content: str, model_card_status: str, tags: Optional[List[Dict[str, str]]], model_card_version: Optional[int] = None, creation_time: Optional[str] = None, last_modified_time: Optional[str] = None, ) -> str: if model_card_name in self.model_cards: raise ConflictException(f"Modelcard {model_card_name} already exists") if not model_card_version: model_card_version = 1 # implement here model_card = FakeModelCard( account_id=self.account_id, region_name=self.region_name, model_card_name=model_card_name, model_card_version=model_card_version, content=content, model_card_status=model_card_status, security_config=security_config, tags=tags, ) self.model_cards[model_card_name].append(model_card) return model_card.arn def update_model_card( self, model_card_name: str, content: str, model_card_status: str ) -> str: if model_card_name not in self.model_cards: raise ResourceNotFound(f"Modelcard {model_card_name} does not exist.") datetime_now = str(datetime.now(tzutc())) first_version = self.model_cards[model_card_name][0] creation_time = first_version.creation_time most_recent_version = self.model_cards[model_card_name][-1] next_version = most_recent_version.model_card_version + 1 security_config = most_recent_version.security_config tags = most_recent_version.tags model_card = FakeModelCard( account_id=self.account_id, region_name=self.region_name, model_card_name=model_card_name, model_card_version=next_version, security_config=security_config, content=content, model_card_status=model_card_status, tags=tags, creation_time=creation_time, last_modified_time=datetime_now, ) self.model_cards[model_card_name].append(model_card) return model_card.arn @paginate(pagination_model=PAGINATION_MODEL) def list_model_cards( self, creation_time_after: Optional[datetime], creation_time_before: Optional[datetime], name_contains: Optional[str], model_card_status: Optional[str], sort_by: Optional[str], sort_order: Optional[str], ) -> List[FakeModelCard]: model_cards = self.model_cards return filter_model_cards( model_cards, creation_time_after, creation_time_before, name_contains, model_card_status, sort_by, sort_order, ) @paginate(pagination_model=PAGINATION_MODEL) def list_model_card_versions( self, creation_time_after: Optional[datetime], creation_time_before: Optional[datetime], model_card_name: str, model_card_status: Optional[str], sort_by: Optional[str], sort_order: Optional[str], ) -> List[FakeModelCard]: if model_card_name not in self.model_cards: raise ResourceNotFound(f"Modelcard {model_card_name} does not exist") versions = self.model_cards[model_card_name] if creation_time_after: versions = [ v for v in versions if v.last_modified_time > str(creation_time_after) ] if creation_time_before: versions = [ v for v in versions if v.last_modified_time < str(creation_time_before) ] if model_card_status: versions = [v for v in versions if v.model_card_status == model_card_status] reverse = sort_order == "Descending" return sorted(versions, key=lambda x: x.model_card_version, reverse=reverse) def describe_model_card( self, model_card_name: str, model_card_version: int ) -> Dict[str, Any]: if model_card_name not in self.model_cards: raise ResourceNotFound(f"Modelcard {model_card_name} does not exist") versions = self.model_cards[model_card_name] if model_card_version: filtered = [ v for v in versions if v.model_card_version == model_card_version ] if filtered: version = filtered[0] return version.describe() else: raise ResourceNotFound( f"Modelcard with name {model_card_name} and version: {model_card_version} does not exist" ) return versions[-1].describe() def delete_model_card(self, model_card_name: str) -> None: if model_card_name not in self.model_cards: raise ResourceNotFound(f"Modelcard {model_card_name} does not exist") del self.model_cards[model_card_name] def create_data_quality_job_definition( self, account_id: str, job_definition_name: str, tags: List[Dict[str, str]] = [], role_arn: str = "", job_resources: Optional[Dict[str, Any]] = None, stopping_condition: Optional[Dict[str, Any]] = None, environment: Optional[Dict[str, str]] = None, network_config: Optional[Dict[str, Any]] = None, data_quality_baseline_config: Optional[Dict[str, Any]] = None, data_quality_app_specification: Optional[Dict[str, Any]] = None, data_quality_job_input: Optional[Dict[str, Any]] = None, data_quality_job_output_config: Optional[Dict[str, Any]] = None, ) -> Dict[str, str]: job_definition = FakeDataQualityJobDefinition( account_id=account_id, region_name=self.region_name, job_definition_name=job_definition_name, tags=tags, role_arn=role_arn, job_resources=job_resources, stopping_condition=stopping_condition, environment=environment, network_config=network_config, data_quality_baseline_config=data_quality_baseline_config, data_quality_app_specification=data_quality_app_specification, data_quality_job_input=data_quality_job_input, data_quality_job_output_config=data_quality_job_output_config, ) self.data_quality_job_definitions[job_definition_name] = job_definition return job_definition.response_create @paginate(pagination_model=PAGINATION_MODEL) def list_data_quality_job_definitions(self) -> List[Dict[str, str]]: return [ job.summary_object for job in self.data_quality_job_definitions.values() ] def describe_data_quality_job_definition( self, job_definition_name: str ) -> Dict[str, Any]: job_definition = self.data_quality_job_definitions.get(job_definition_name) if job_definition is None: raise ResourceNotFound(f"Job definition {job_definition_name} not found") return job_definition.response_object def delete_data_quality_job_definition(self, job_definition_name: str) -> None: if job_definition_name in self.data_quality_job_definitions: del self.data_quality_job_definitions[job_definition_name] else: raise ResourceNotFound(f"Job definition {job_definition_name} not found") class FakeDataQualityJobDefinition(BaseObject): def __init__( self, account_id: str, region_name: str, job_definition_name: str, tags: List[Dict[str, str]] = [], role_arn: str = "", job_resources: Optional[Dict[str, Any]] = None, stopping_condition: Optional[Dict[str, Any]] = None, environment: Optional[Dict[str, str]] = None, network_config: Optional[Dict[str, Any]] = None, data_quality_baseline_config: Optional[Dict[str, Any]] = None, data_quality_app_specification: Optional[Dict[str, Any]] = None, data_quality_job_input: Optional[Dict[str, Any]] = None, data_quality_job_output_config: Optional[Dict[str, Any]] = None, ): self.job_definition_name = job_definition_name self.arn = FakeDataQualityJobDefinition.arn_formatter( job_definition_name, account_id, region_name ) self.tags = tags self.role_arn = role_arn self.job_resources = job_resources or {} self.stopping_condition = stopping_condition or {} self.environment = environment or {} self.network_config = network_config or {} self.data_quality_baseline_config = data_quality_baseline_config or {} self.data_quality_app_specification = data_quality_app_specification or {} self.data_quality_job_input = data_quality_job_input or {} self.data_quality_job_output_config = data_quality_job_output_config or {} self.creation_time = self.last_modified_time = datetime.now().strftime( "%Y-%m-%d %H:%M:%S" ) @property def response_object(self) -> Dict[str, str]: response_object = self.gen_response_object() response = { k: v for k, v in response_object.items() if v is not None and v != [None] } response["JobDefinitionArn"] = response.pop("Arn") return response @property def response_create(self) -> Dict[str, str]: return {"JobDefinitionArn": self.arn} @staticmethod def arn_formatter(name: str, account_id: str, region: str) -> str: return arn_formatter("data-quality-job-definition", name, account_id, region) @property def summary_object(self) -> Dict[str, str]: return { "MonitoringJobDefinitionName": self.job_definition_name, "MonitoringJobDefinitionArn": self.arn, "CreationTime": self.creation_time, "EndpointName": "EndpointName", } class FakeExperiment(BaseObject): def __init__( self, account_id: str, region_name: str, experiment_name: str, tags: List[Dict[str, str]], ): self.experiment_name = experiment_name self.arn = arn_formatter("experiment", experiment_name, account_id, region_name) self.tags = tags self.creation_time = self.last_modified_time = datetime.now().strftime( "%Y-%m-%d %H:%M:%S" ) @property def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() return { k: v for k, v in response_object.items() if v is not None and v != [None] } @property def response_create(self) -> Dict[str, str]: return {"ExperimentArn": self.arn} class FakeTrial(BaseObject): def __init__( self, account_id: str, region_name: str, trial_name: str, experiment_name: str, tags: List[Dict[str, str]], trial_components: List[str], ): self.trial_name = trial_name self.arn = FakeTrial.arn_formatter(trial_name, account_id, region_name) self.tags = tags self.trial_components = trial_components self.experiment_name = experiment_name self.creation_time = self.last_modified_time = datetime.now().strftime( "%Y-%m-%d %H:%M:%S" ) @property def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() response = { k: v for k, v in response_object.items() if v is not None and v != [None] } response["TrialArn"] = response.pop("Arn") return response @property def response_create(self) -> Dict[str, str]: return {"TrialArn": self.arn} @staticmethod def arn_formatter(name: str, account_id: str, region: str) -> str: return arn_formatter("experiment-trial", name, account_id, region) class FakeTrialComponent(BaseObject): def __init__( self, account_id: str, region_name: str, trial_component_name: str, display_name: Optional[str], start_time: Optional[datetime], end_time: Optional[datetime], parameters: Optional[Dict[str, Dict[str, Union[str, float]]]], input_artifacts: Optional[Dict[str, Dict[str, str]]], output_artifacts: Optional[Dict[str, Dict[str, str]]], metadata_properties: Optional[Dict[str, str]], status: Optional[Dict[str, str]], trial_name: Optional[str], tags: List[Dict[str, str]], ): self.trial_component_name = trial_component_name self.display_name = ( display_name if display_name is not None else trial_component_name ) self.arn = FakeTrialComponent.arn_formatter( trial_component_name, account_id, region_name ) self.status = status self.tags = tags self.trial_name = trial_name self.start_time = start_time self.end_time = end_time now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.creation_time = self.last_modified_time = now_string self.created_by: Dict[str, Union[Dict[str, str], str]] = {} self.last_modified_by: Dict[str, Union[Dict[str, str], str]] = {} self.parameters = parameters if parameters is not None else {} self.input_artifacts = input_artifacts if input_artifacts is not None else {} self.output_artifacts = output_artifacts if output_artifacts is not None else {} self.metadata_properties = metadata_properties self.metrics: Dict[str, Dict[str, Union[str, int, METRIC_STEP_TYPE]]] = {} self.sources: List[Dict[str, str]] = [] @property def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() response_object["Metrics"] = self.gen_metrics_response_object() response = { k: v for k, v in response_object.items() if v is not None and v != [None] } response["TrialComponentArn"] = response.pop("Arn") return response def gen_metrics_response_object( self, ) -> List[Dict[str, Union[str, int, float, datetime]]]: metrics_names = self.metrics.keys() metrics_response_objects = [] for metrics_name in metrics_names: metrics_steps: METRIC_STEP_TYPE = cast( METRIC_STEP_TYPE, self.metrics[metrics_name]["Values"] ) max_step = max(list(metrics_steps.keys())) metrics_steps_values: List[float] = list( map( lambda metric: cast(float, metric["Value"]), list(metrics_steps.values()), ) ) count = len(metrics_steps_values) mean = sum(metrics_steps_values) / count std = ( sum(map(lambda value: (value - mean) ** 2, metrics_steps_values)) / count ) ** 0.5 timestamp_int: int = cast(int, self.metrics[metrics_name]["Timestamp"]) metrics_response_object = { "MetricName": metrics_name, "SourceArn": self.arn, "TimeStamp": datetime.fromtimestamp(timestamp_int, tz=tzutc()).strftime( "%Y-%m-%d %H:%M:%S" ), "Max": max(metrics_steps_values), "Min": min(metrics_steps_values), "Last": metrics_steps[max_step]["Value"], "Count": count, "Avg": mean, "StdDev": std, } metrics_response_objects.append(metrics_response_object) return metrics_response_objects @property def response_create(self) -> Dict[str, str]: return {"TrialComponentArn": self.arn} @staticmethod def arn_formatter( trial_component_name: str, account_id: str, region_name: str ) -> str: return arn_formatter( "experiment-trial-component", trial_component_name, account_id, region_name ) class FakeModelBiasJobDefinition(BaseObject): def __init__( self, account_id: str, region_name: str, job_definition_name: str, tags: List[Dict[str, str]] = [], role_arn: str = "", job_resources: Optional[Dict[str, Any]] = None, stopping_condition: Optional[Dict[str, Any]] = None, environment: Optional[Dict[str, str]] = None, network_config: Optional[Dict[str, Any]] = None, model_bias_baseline_config: Optional[Dict[str, Any]] = None, model_bias_app_specification: Optional[Dict[str, Any]] = None, model_bias_job_input: Optional[Dict[str, Any]] = None, model_bias_job_output_config: Optional[Dict[str, Any]] = None, ): self.job_definition_name = job_definition_name self.arn = FakeModelBiasJobDefinition.arn_formatter( job_definition_name, account_id, region_name ) self.tags = tags self.role_arn = role_arn self.job_resources = job_resources or {} self.stopping_condition = stopping_condition or {} self.environment = environment or {} self.network_config = network_config or {} self.model_bias_baseline_config = model_bias_baseline_config or {} self.model_bias_app_specification = model_bias_app_specification or {} self.model_bias_job_input = model_bias_job_input or {} self.model_bias_job_output_config = model_bias_job_output_config or {} self.creation_time = self.last_modified_time = datetime.now().strftime( "%Y-%m-%d %H:%M:%S" ) @property def response_object(self) -> Dict[str, str]: response_object = self.gen_response_object() response = { k: v for k, v in response_object.items() if v is not None and v != [None] } response["JobDefinitionArn"] = response.pop("Arn") return response @property def response_create(self) -> Dict[str, str]: return {"JobDefinitionArn": self.arn} @staticmethod def arn_formatter(name: str, account_id: str, region: str) -> str: return f"arn:{get_partition(region)}:sagemaker:{region}:{account_id}:model-bias-job-definition/{name}" @property def summary_object(self) -> Dict[str, str]: return { "MonitoringJobDefinitionName": self.job_definition_name, "MonitoringJobDefinitionArn": self.arn, "CreationTime": self.creation_time, "EndpointName": self.model_bias_job_input.get("EndpointInput", {}).get( "EndpointName", "EndpointName" ), } sagemaker_backends = BackendDict(SageMakerModelBackend, "sagemaker")
Memory