import json from typing import Any from moto.core.common_types import TYPE_RESPONSE from moto.core.responses import BaseResponse from moto.sagemaker.exceptions import AWSValidationException from .models import SageMakerModelBackend, sagemaker_backends def format_enum_error(value: str, attribute: str, allowed: Any) -> str: return f"Value '{value}' at '{attribute}' failed to satisfy constraint: Member must satisfy enum value set: {allowed}" class SageMakerResponse(BaseResponse): def __init__(self) -> None: super().__init__(service_name="sagemaker") @property def sagemaker_backend(self) -> SageMakerModelBackend: return sagemaker_backends[self.current_account][self.region] def describe_model(self) -> str: model_name = self._get_param("ModelName") model = self.sagemaker_backend.describe_model(model_name) return json.dumps(model.response_object) def create_model(self) -> str: model_name = self._get_param("ModelName") execution_role_arn = self._get_param("ExecutionRoleArn") primary_container = self._get_param("PrimaryContainer") vpc_config = self._get_param("VpcConfig") containers = self._get_param("Containers") tags = self._get_param("Tags") model = self.sagemaker_backend.create_model( model_name=model_name, execution_role_arn=execution_role_arn, primary_container=primary_container, vpc_config=vpc_config, containers=containers, tags=tags, ) return json.dumps(model.response_create) def delete_model(self) -> str: model_name = self._get_param("ModelName") self.sagemaker_backend.delete_model(model_name) return "{}" def list_models(self) -> str: models = self.sagemaker_backend.list_models() return json.dumps({"Models": [model.response_object for model in models]}) def create_notebook_instance(self) -> TYPE_RESPONSE: sagemaker_notebook = self.sagemaker_backend.create_notebook_instance( notebook_instance_name=self._get_param("NotebookInstanceName"), instance_type=self._get_param("InstanceType"), subnet_id=self._get_param("SubnetId"), security_group_ids=self._get_param("SecurityGroupIds"), role_arn=self._get_param("RoleArn"), kms_key_id=self._get_param("KmsKeyId"), tags=self._get_param("Tags"), lifecycle_config_name=self._get_param("LifecycleConfigName"), direct_internet_access=self._get_param("DirectInternetAccess"), volume_size_in_gb=self._get_param("VolumeSizeInGB"), accelerator_types=self._get_param("AcceleratorTypes"), default_code_repository=self._get_param("DefaultCodeRepository"), additional_code_repositories=self._get_param("AdditionalCodeRepositories"), root_access=self._get_param("RootAccess"), ) return 200, {}, json.dumps({"NotebookInstanceArn": sagemaker_notebook.arn}) def describe_notebook_instance(self) -> str: notebook_instance_name = self._get_param("NotebookInstanceName") notebook_instance = self.sagemaker_backend.get_notebook_instance( notebook_instance_name ) return json.dumps(notebook_instance.to_dict()) def start_notebook_instance(self) -> TYPE_RESPONSE: notebook_instance_name = self._get_param("NotebookInstanceName") self.sagemaker_backend.start_notebook_instance(notebook_instance_name) return 200, {}, json.dumps("{}") def stop_notebook_instance(self) -> TYPE_RESPONSE: notebook_instance_name = self._get_param("NotebookInstanceName") self.sagemaker_backend.stop_notebook_instance(notebook_instance_name) return 200, {}, json.dumps("{}") def delete_notebook_instance(self) -> TYPE_RESPONSE: notebook_instance_name = self._get_param("NotebookInstanceName") self.sagemaker_backend.delete_notebook_instance(notebook_instance_name) return 200, {}, json.dumps("{}") def list_notebook_instances(self) -> str: sort_by = self._get_param("SortBy", "Name") sort_order = self._get_param("SortOrder", "Ascending") name_contains = self._get_param("NameContains") status = self._get_param("StatusEquals") max_results = self._get_param("MaxResults") next_token = self._get_param("NextToken") instances, next_token = self.sagemaker_backend.list_notebook_instances( sort_by=sort_by, sort_order=sort_order, name_contains=name_contains, status=status, max_results=max_results, next_token=next_token, ) return json.dumps( { "NotebookInstances": [i.to_dict() for i in instances], "NextToken": next_token, } ) def list_tags(self) -> TYPE_RESPONSE: arn = self._get_param("ResourceArn") max_results = self._get_param("MaxResults") next_token = self._get_param("NextToken") paged_results, next_token = self.sagemaker_backend.list_tags( arn=arn, MaxResults=max_results, NextToken=next_token ) response = {"Tags": paged_results} if next_token: response["NextToken"] = next_token return 200, {}, json.dumps(response) def add_tags(self) -> TYPE_RESPONSE: arn = self._get_param("ResourceArn") tags = self._get_param("Tags") tags = self.sagemaker_backend.add_tags(arn, tags) return 200, {}, json.dumps({"Tags": tags}) def delete_tags(self) -> TYPE_RESPONSE: arn = self._get_param("ResourceArn") tag_keys = self._get_param("TagKeys") self.sagemaker_backend.delete_tags(arn, tag_keys) return 200, {}, json.dumps({}) def create_endpoint_config(self) -> TYPE_RESPONSE: endpoint_config = self.sagemaker_backend.create_endpoint_config( endpoint_config_name=self._get_param("EndpointConfigName"), production_variants=self._get_param("ProductionVariants"), data_capture_config=self._get_param("DataCaptureConfig"), tags=self._get_param("Tags"), kms_key_id=self._get_param("KmsKeyId"), ) return ( 200, {}, json.dumps({"EndpointConfigArn": endpoint_config.endpoint_config_arn}), ) def describe_endpoint_config(self) -> str: endpoint_config_name = self._get_param("EndpointConfigName") response = self.sagemaker_backend.describe_endpoint_config(endpoint_config_name) return json.dumps(response) def delete_endpoint_config(self) -> TYPE_RESPONSE: endpoint_config_name = self._get_param("EndpointConfigName") self.sagemaker_backend.delete_endpoint_config(endpoint_config_name) return 200, {}, json.dumps("{}") def create_endpoint(self) -> TYPE_RESPONSE: endpoint = self.sagemaker_backend.create_endpoint( endpoint_name=self._get_param("EndpointName"), endpoint_config_name=self._get_param("EndpointConfigName"), tags=self._get_param("Tags"), ) return 200, {}, json.dumps({"EndpointArn": endpoint.arn}) def describe_endpoint(self) -> str: endpoint_name = self._get_param("EndpointName") response = self.sagemaker_backend.describe_endpoint(endpoint_name) return json.dumps(response) def delete_endpoint(self) -> TYPE_RESPONSE: endpoint_name = self._get_param("EndpointName") self.sagemaker_backend.delete_endpoint(endpoint_name) return 200, {}, json.dumps("{}") def create_processing_job(self) -> TYPE_RESPONSE: processing_job = self.sagemaker_backend.create_processing_job( app_specification=self._get_param("AppSpecification"), experiment_config=self._get_param("ExperimentConfig"), network_config=self._get_param("NetworkConfig"), processing_inputs=self._get_param("ProcessingInputs"), processing_job_name=self._get_param("ProcessingJobName"), processing_output_config=self._get_param("ProcessingOutputConfig"), role_arn=self._get_param("RoleArn"), stopping_condition=self._get_param("StoppingCondition"), tags=self._get_param("Tags"), ) response = {"ProcessingJobArn": processing_job.arn} return 200, {}, json.dumps(response) def describe_processing_job(self) -> str: processing_job_name = self._get_param("ProcessingJobName") response = self.sagemaker_backend.describe_processing_job(processing_job_name) return json.dumps(response) def create_transform_job(self) -> TYPE_RESPONSE: transform_job = self.sagemaker_backend.create_transform_job( transform_job_name=self._get_param("TransformJobName"), model_name=self._get_param("ModelName"), max_concurrent_transforms=self._get_param("MaxConcurrentTransforms"), model_client_config=self._get_param("ModelClientConfig"), max_payload_in_mb=self._get_param("MaxPayloadInMB"), batch_strategy=self._get_param("BatchStrategy"), environment=self._get_param("Environment"), transform_input=self._get_param("TransformInput"), transform_output=self._get_param("TransformOutput"), data_capture_config=self._get_param("DataCaptureConfig"), transform_resources=self._get_param("TransformResources"), data_processing=self._get_param("DataProcessing"), tags=self._get_param("Tags"), experiment_config=self._get_param("ExperimentConfig"), ) response = { "TransformJobArn": transform_job.arn, } return 200, {}, json.dumps(response) def describe_transform_job(self) -> str: transform_job_name = self._get_param("TransformJobName") response = self.sagemaker_backend.describe_transform_job( transform_job_name=transform_job_name ) return json.dumps(response) def create_training_job(self) -> TYPE_RESPONSE: training_job = self.sagemaker_backend.create_training_job( training_job_name=self._get_param("TrainingJobName"), hyper_parameters=self._get_param("HyperParameters"), algorithm_specification=self._get_param("AlgorithmSpecification"), role_arn=self._get_param("RoleArn"), input_data_config=self._get_param("InputDataConfig"), output_data_config=self._get_param("OutputDataConfig"), resource_config=self._get_param("ResourceConfig"), vpc_config=self._get_param("VpcConfig"), stopping_condition=self._get_param("StoppingCondition"), tags=self._get_param("Tags"), enable_network_isolation=self._get_param("EnableNetworkIsolation", False), enable_inter_container_traffic_encryption=self._get_param( "EnableInterContainerTrafficEncryption", False ), enable_managed_spot_training=self._get_param( "EnableManagedSpotTraining", False ), checkpoint_config=self._get_param("CheckpointConfig"), debug_hook_config=self._get_param("DebugHookConfig"), debug_rule_configurations=self._get_param("DebugRuleConfigurations"), tensor_board_output_config=self._get_param("TensorBoardOutputConfig"), experiment_config=self._get_param("ExperimentConfig"), ) response = { "TrainingJobArn": training_job.arn, } return 200, {}, json.dumps(response) def describe_training_job(self) -> str: training_job_name = self._get_param("TrainingJobName") response = self.sagemaker_backend.describe_training_job(training_job_name) return json.dumps(response) def create_notebook_instance_lifecycle_config(self) -> TYPE_RESPONSE: lifecycle_configuration = ( self.sagemaker_backend.create_notebook_instance_lifecycle_config( notebook_instance_lifecycle_config_name=self._get_param( "NotebookInstanceLifecycleConfigName" ), on_create=self._get_param("OnCreate"), on_start=self._get_param("OnStart"), ) ) response = { "NotebookInstanceLifecycleConfigArn": lifecycle_configuration.arn, } return 200, {}, json.dumps(response) def describe_notebook_instance_lifecycle_config(self) -> str: response = self.sagemaker_backend.describe_notebook_instance_lifecycle_config( notebook_instance_lifecycle_config_name=self._get_param( "NotebookInstanceLifecycleConfigName" ) ) return json.dumps(response) def delete_notebook_instance_lifecycle_config(self) -> TYPE_RESPONSE: self.sagemaker_backend.delete_notebook_instance_lifecycle_config( notebook_instance_lifecycle_config_name=self._get_param( "NotebookInstanceLifecycleConfigName" ) ) return 200, {}, json.dumps("{}") def search(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.search( resource=self._get_param("Resource"), search_expression=self._get_param("SearchExpression"), ) return 200, {}, json.dumps(response) def list_experiments(self) -> TYPE_RESPONSE: MaxResults = self._get_param("MaxResults") NextToken = self._get_param("NextToken") paged_results, next_token = self.sagemaker_backend.list_experiments( MaxResults=MaxResults, NextToken=NextToken ) experiment_summaries = [ { "ExperimentName": experiment_data.experiment_name, "ExperimentArn": experiment_data.arn, "CreationTime": experiment_data.creation_time, "LastModifiedTime": experiment_data.last_modified_time, } for experiment_data in paged_results ] response = { "ExperimentSummaries": experiment_summaries, "NextToken": next_token, } return 200, {}, json.dumps(response) def delete_experiment(self) -> TYPE_RESPONSE: self.sagemaker_backend.delete_experiment( experiment_name=self._get_param("ExperimentName") ) return 200, {}, json.dumps({}) def create_experiment(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.create_experiment( experiment_name=self._get_param("ExperimentName") ) return 200, {}, json.dumps(response) def describe_experiment(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.describe_experiment( experiment_name=self._get_param("ExperimentName") ) return 200, {}, json.dumps(response) def list_trials(self) -> TYPE_RESPONSE: MaxResults = self._get_param("MaxResults") NextToken = self._get_param("NextToken") paged_results, next_token = self.sagemaker_backend.list_trials( NextToken=NextToken, MaxResults=MaxResults, experiment_name=self._get_param("ExperimentName"), trial_component_name=self._get_param("TrialComponentName"), ) trial_summaries = [ { "TrialName": trial_data.trial_name, "TrialArn": trial_data.arn, "CreationTime": trial_data.creation_time, "LastModifiedTime": trial_data.last_modified_time, } for trial_data in paged_results ] response = {"TrialSummaries": trial_summaries, "NextToken": next_token} return 200, {}, json.dumps(response) def create_trial(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.create_trial( trial_name=self._get_param("TrialName"), experiment_name=self._get_param("ExperimentName"), ) return 200, {}, json.dumps(response) def list_trial_components(self) -> TYPE_RESPONSE: MaxResults = self._get_param("MaxResults") NextToken = self._get_param("NextToken") paged_results, next_token = self.sagemaker_backend.list_trial_components( NextToken=NextToken, MaxResults=MaxResults, trial_name=self._get_param("TrialName"), ) trial_component_summaries = [ { "TrialComponentName": trial_component_data.trial_component_name, "TrialComponentArn": trial_component_data.arn, "CreationTime": trial_component_data.creation_time, "LastModifiedTime": trial_component_data.last_modified_time, } for trial_component_data in paged_results ] response = { "TrialComponentSummaries": trial_component_summaries, "NextToken": next_token, } return 200, {}, json.dumps(response) def create_trial_component(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.create_trial_component( trial_component_name=self._get_param("TrialComponentName"), start_time=self._get_param("StartTime"), end_time=self._get_param("EndTime"), display_name=self._get_param("DisplayName"), parameters=self._get_param("Parameters"), input_artifacts=self._get_param("InputArtifacts"), output_artifacts=self._get_param("OutputArtifacts"), metadata_properties=self._get_param("MetadataProperties"), status=self._get_param("Status"), trial_name=self._get_param("TrialName"), ) return 200, {}, json.dumps(response) def describe_trial(self) -> str: trial_name = self._get_param("TrialName") response = self.sagemaker_backend.describe_trial(trial_name) return json.dumps(response) def delete_trial(self) -> TYPE_RESPONSE: trial_name = self._get_param("TrialName") self.sagemaker_backend.delete_trial(trial_name) return 200, {}, json.dumps({}) def delete_trial_component(self) -> TYPE_RESPONSE: trial_component_name = self._get_param("TrialComponentName") self.sagemaker_backend.delete_trial_component(trial_component_name) return 200, {}, json.dumps({}) def describe_trial_component(self) -> str: trial_component_name = self._get_param("TrialComponentName") response = self.sagemaker_backend.describe_trial_component(trial_component_name) return json.dumps(response) def associate_trial_component(self) -> TYPE_RESPONSE: trial_name = self._get_param("TrialName") trial_component_name = self._get_param("TrialComponentName") response = self.sagemaker_backend.associate_trial_component( trial_name, trial_component_name ) return 200, {}, json.dumps(response) def disassociate_trial_component(self) -> TYPE_RESPONSE: trial_component_name = self._get_param("TrialComponentName") trial_name = self._get_param("TrialName") response = self.sagemaker_backend.disassociate_trial_component( trial_name, trial_component_name ) return 200, {}, json.dumps(response) def update_trial_component(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.update_trial_component( trial_component_name=self._get_param("TrialComponentName"), status=self._get_param("Status"), display_name=self._get_param("DisplayName"), start_time=self._get_param("StartTime"), end_time=self._get_param("EndTime"), parameters=self._get_param("Parameters"), parameters_to_remove=self._get_param("ParametersToRemove"), input_artifacts=self._get_param("InputArtifacts"), input_artifacts_to_remove=self._get_param("InputArtifactsToRemove"), output_artifacts=self._get_param("OutputArtifacts"), output_artifacts_to_remove=self._get_param("OutputArtifactsToRemove"), ) return 200, {}, json.dumps(response) def describe_pipeline(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.describe_pipeline( self._get_param("PipelineName") ) return 200, {}, json.dumps(response) def start_pipeline_execution(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.start_pipeline_execution( self._get_param("PipelineName"), self._get_param("PipelineExecutionDisplayName"), self._get_param("PipelineParameters"), self._get_param("PipelineExecutionDescription"), self._get_param("ParallelismConfiguration"), self._get_param("ClientRequestToken"), ) return 200, {}, json.dumps(response) def describe_pipeline_execution(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.describe_pipeline_execution( self._get_param("PipelineExecutionArn") ) return 200, {}, json.dumps(response) def describe_pipeline_definition_for_execution(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.describe_pipeline_definition_for_execution( self._get_param("PipelineExecutionArn") ) return 200, {}, json.dumps(response) def list_pipeline_parameters_for_execution(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.list_pipeline_parameters_for_execution( self._get_param("PipelineExecutionArn") ) return 200, {}, json.dumps(response) def list_pipeline_executions(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.list_pipeline_executions( self._get_param("PipelineName") ) return 200, {}, json.dumps(response) def create_pipeline(self) -> TYPE_RESPONSE: pipeline = self.sagemaker_backend.create_pipeline( pipeline_name=self._get_param("PipelineName"), pipeline_display_name=self._get_param("PipelineDisplayName"), pipeline_definition=self._get_param("PipelineDefinition"), pipeline_definition_s3_location=self._get_param( "PipelineDefinitionS3Location" ), pipeline_description=self._get_param("PipelineDescription"), role_arn=self._get_param("RoleArn"), tags=self._get_param("Tags"), parallelism_configuration=self._get_param("ParallelismConfiguration"), ) response = { "PipelineArn": pipeline.arn, } return 200, {}, json.dumps(response) def delete_pipeline(self) -> TYPE_RESPONSE: arn = self.sagemaker_backend.delete_pipeline( pipeline_name=self._get_param("PipelineName"), ) response = {"PipelineArn": arn} return 200, {}, json.dumps(response) def update_pipeline(self) -> TYPE_RESPONSE: arn = self.sagemaker_backend.update_pipeline( pipeline_name=self._get_param("PipelineName"), pipeline_display_name=self._get_param("PipelineDisplayName"), pipeline_definition=self._get_param("PipelineDefinition"), pipeline_definition_s3_location=self._get_param( "PipelineDefinitionS3Location" ), pipeline_description=self._get_param("PipelineDescription"), role_arn=self._get_param("RoleArn"), parallelism_configuration=self._get_param("ParallelismConfiguration"), ) response = {"PipelineArn": arn} return 200, {}, json.dumps(response) def list_pipelines(self) -> TYPE_RESPONSE: max_results_range = range(1, 101) allowed_sort_by = ("Name", "CreationTime") allowed_sort_order = ("Ascending", "Descending") pipeline_name_prefix = self._get_param("PipelineNamePrefix") created_after = self._get_param("CreatedAfter") created_before = self._get_param("CreatedBefore") sort_by = self._get_param("SortBy", "CreationTime") sort_order = self._get_param("SortOrder", "Descending") next_token = self._get_param("NextToken") max_results = self._get_param("MaxResults") errors = [] if max_results and max_results not in max_results_range: errors.append( f"Value '{max_results}' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to {max_results_range[-1]}" ) if sort_by not in allowed_sort_by: errors.append(format_enum_error(sort_by, "SortBy", allowed_sort_by)) if sort_order not in allowed_sort_order: errors.append( format_enum_error(sort_order, "SortOrder", allowed_sort_order) ) if errors: raise AWSValidationException( f"{len(errors)} validation errors detected: {';'.join(errors)}" ) response = self.sagemaker_backend.list_pipelines( pipeline_name_prefix=pipeline_name_prefix, created_after=created_after, created_before=created_before, next_token=next_token, max_results=max_results, sort_by=sort_by, sort_order=sort_order, ) return 200, {}, json.dumps(response) def list_processing_jobs(self) -> TYPE_RESPONSE: max_results_range = range(1, 101) allowed_sort_by = ["Name", "CreationTime", "Status"] allowed_sort_order = ["Ascending", "Descending"] allowed_status_equals = [ "Completed", "Stopped", "InProgress", "Stopping", "Failed", ] max_results = self._get_int_param("MaxResults") sort_by = self._get_param("SortBy", "CreationTime") sort_order = self._get_param("SortOrder", "Ascending") status_equals = self._get_param("StatusEquals") next_token = self._get_param("NextToken") errors = [] if max_results and max_results not in max_results_range: errors.append( f"Value '{max_results}' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to {max_results_range[-1]}" ) if sort_by not in allowed_sort_by: errors.append(format_enum_error(sort_by, "sortBy", allowed_sort_by)) if sort_order not in allowed_sort_order: errors.append( format_enum_error(sort_order, "sortOrder", allowed_sort_order) ) if status_equals and status_equals not in allowed_status_equals: errors.append( format_enum_error(status_equals, "statusEquals", allowed_status_equals) ) if errors != []: raise AWSValidationException( f"{len(errors)} validation errors detected: {';'.join(errors)}" ) response = self.sagemaker_backend.list_processing_jobs( next_token=next_token, max_results=max_results, creation_time_after=self._get_param("CreationTimeAfter"), creation_time_before=self._get_param("CreationTimeBefore"), last_modified_time_after=self._get_param("LastModifiedTimeAfter"), last_modified_time_before=self._get_param("LastModifiedTimeBefore"), name_contains=self._get_param("NameContains"), status_equals=status_equals, ) return 200, {}, json.dumps(response) def list_transform_jobs(self) -> TYPE_RESPONSE: max_results_range = range(1, 101) allowed_sort_by = ["Name", "CreationTime", "Status"] allowed_sort_order = ["Ascending", "Descending"] allowed_status_equals = [ "Completed", "Stopped", "InProgress", "Stopping", "Failed", ] max_results = self._get_int_param("MaxResults") sort_by = self._get_param("SortBy", "CreationTime") sort_order = self._get_param("SortOrder", "Ascending") status_equals = self._get_param("StatusEquals") next_token = self._get_param("NextToken") errors = [] if max_results and max_results not in max_results_range: errors.append( f"Value '{max_results}' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to {max_results_range[-1]}" ) if sort_by not in allowed_sort_by: errors.append(format_enum_error(sort_by, "sortBy", allowed_sort_by)) if sort_order not in allowed_sort_order: errors.append( format_enum_error(sort_order, "sortOrder", allowed_sort_order) ) if status_equals and status_equals not in allowed_status_equals: errors.append( format_enum_error(status_equals, "statusEquals", allowed_status_equals) ) if errors != []: raise AWSValidationException( f"{len(errors)} validation errors detected: {';'.join(errors)}" ) response = self.sagemaker_backend.list_transform_jobs( next_token=next_token, max_results=max_results, creation_time_after=self._get_param("CreationTimeAfter"), creation_time_before=self._get_param("CreationTimeBefore"), last_modified_time_after=self._get_param("LastModifiedTimeAfter"), last_modified_time_before=self._get_param("LastModifiedTimeBefore"), name_contains=self._get_param("NameContains"), status_equals=status_equals, ) return 200, {}, json.dumps(response) def list_training_jobs(self) -> TYPE_RESPONSE: max_results_range = range(1, 101) allowed_sort_by = ["Name", "CreationTime", "Status"] allowed_sort_order = ["Ascending", "Descending"] allowed_status_equals = [ "Completed", "Stopped", "InProgress", "Stopping", "Failed", ] max_results = self._get_int_param("MaxResults") sort_by = self._get_param("SortBy", "CreationTime") sort_order = self._get_param("SortOrder", "Ascending") status_equals = self._get_param("StatusEquals") next_token = self._get_param("NextToken") errors = [] if max_results and max_results not in max_results_range: errors.append( f"Value '{max_results}' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to {max_results_range[-1]}" ) if sort_by not in allowed_sort_by: errors.append(format_enum_error(sort_by, "sortBy", allowed_sort_by)) if sort_order not in allowed_sort_order: errors.append( format_enum_error(sort_order, "sortOrder", allowed_sort_order) ) if status_equals and status_equals not in allowed_status_equals: errors.append( format_enum_error(status_equals, "statusEquals", allowed_status_equals) ) if errors != []: raise AWSValidationException( f"{len(errors)} validation errors detected: {';'.join(errors)}" ) response = self.sagemaker_backend.list_training_jobs( next_token=next_token, max_results=max_results, creation_time_after=self._get_param("CreationTimeAfter"), creation_time_before=self._get_param("CreationTimeBefore"), last_modified_time_after=self._get_param("LastModifiedTimeAfter"), last_modified_time_before=self._get_param("LastModifiedTimeBefore"), name_contains=self._get_param("NameContains"), status_equals=status_equals, ) return 200, {}, json.dumps(response) def update_endpoint_weights_and_capacities(self) -> TYPE_RESPONSE: endpoint_name = self._get_param("EndpointName") desired_weights_and_capacities = self._get_param("DesiredWeightsAndCapacities") endpoint_arn = self.sagemaker_backend.update_endpoint_weights_and_capacities( endpoint_name=endpoint_name, desired_weights_and_capacities=desired_weights_and_capacities, ) return 200, {}, json.dumps({"EndpointArn": endpoint_arn}) def list_model_package_groups(self) -> str: creation_time_after = self._get_param("CreationTimeAfter") creation_time_before = self._get_param("CreationTimeBefore") max_results = self._get_param("MaxResults") name_contains = self._get_param("NameContains") next_token = self._get_param("NextToken") sort_by = self._get_param("SortBy") sort_order = self._get_param("SortOrder") ( model_package_group_summary_list, next_token, ) = self.sagemaker_backend.list_model_package_groups( creation_time_after=creation_time_after, creation_time_before=creation_time_before, max_results=max_results, name_contains=name_contains, next_token=next_token, sort_by=sort_by, sort_order=sort_order, ) model_package_group_summary_list_response_object = [ x.gen_response_object() for x in model_package_group_summary_list ] return json.dumps( dict( ModelPackageGroupSummaryList=model_package_group_summary_list_response_object, NextToken=next_token, ) ) def list_model_packages(self) -> str: creation_time_after = self._get_param("CreationTimeAfter") creation_time_before = self._get_param("CreationTimeBefore") max_results = self._get_param("MaxResults") name_contains = self._get_param("NameContains") model_approval_status = self._get_param("ModelApprovalStatus") model_package_group_name = self._get_param("ModelPackageGroupName") model_package_type = self._get_param("ModelPackageType", "Unversioned") next_token = self._get_param("NextToken") sort_by = self._get_param("SortBy") sort_order = self._get_param("SortOrder") ( model_package_summary_list, next_token, ) = self.sagemaker_backend.list_model_packages( creation_time_after=creation_time_after, creation_time_before=creation_time_before, max_results=max_results, name_contains=name_contains, model_approval_status=model_approval_status, model_package_group_name=model_package_group_name, model_package_type=model_package_type, next_token=next_token, sort_by=sort_by, sort_order=sort_order, ) model_package_summary_list_response_object = [ x.gen_response_object() for x in model_package_summary_list ] return json.dumps( dict( ModelPackageSummaryList=model_package_summary_list_response_object, NextToken=next_token, ) ) def describe_model_package(self) -> str: model_package_name = self._get_param("ModelPackageName") model_package = self.sagemaker_backend.describe_model_package( model_package_name=model_package_name, ) return json.dumps( model_package.gen_response_object(), ) def describe_model_package_group(self) -> str: model_package_group_name = self._get_param("ModelPackageGroupName") model_package_group = self.sagemaker_backend.describe_model_package_group( model_package_group_name=model_package_group_name, ) return json.dumps( model_package_group.gen_response_object(), ) def update_model_package(self) -> str: model_package_arn = self._get_param("ModelPackageArn") model_approval_status = self._get_param("ModelApprovalStatus") approval_dexcription = self._get_param("ApprovalDescription") customer_metadata_properties = self._get_param("CustomerMetadataProperties") customer_metadata_properties_to_remove = self._get_param( "CustomerMetadataPropertiesToRemove", [] ) additional_inference_specification_to_add = self._get_param( "AdditionalInferenceSpecificationsToAdd" ) updated_model_package_arn = self.sagemaker_backend.update_model_package( model_package_arn=model_package_arn, model_approval_status=model_approval_status, approval_description=approval_dexcription, customer_metadata_properties=customer_metadata_properties, customer_metadata_properties_to_remove=customer_metadata_properties_to_remove, additional_inference_specifications_to_add=additional_inference_specification_to_add, ) return json.dumps(dict(ModelPackageArn=updated_model_package_arn)) def create_model_package(self) -> str: model_package_name = self._get_param("ModelPackageName") model_package_group_name = self._get_param("ModelPackageGroupName") model_package_description = self._get_param("ModelPackageDescription") inference_specification = self._get_param("InferenceSpecification") validation_specification = self._get_param("ValidationSpecification") source_algorithm_specification = self._get_param("SourceAlgorithmSpecification") certify_for_marketplace = self._get_param("CertifyForMarketplace", False) tags = self._get_param("Tags") model_approval_status = self._get_param("ModelApprovalStatus") metadata_properties = self._get_param("MetadataProperties") model_metrics = self._get_param("ModelMetrics") client_token = self._get_param("ClientToken") customer_metadata_properties = self._get_param("CustomerMetadataProperties") drift_check_baselines = self._get_param("DriftCheckBaselines") domain = self._get_param("Domain") task = self._get_param("Task") sample_payload_url = self._get_param("SamplePayloadUrl") additional_inference_specifications = self._get_param( "AdditionalInferenceSpecifications", None ) model_package_arn = self.sagemaker_backend.create_model_package( model_package_name=model_package_name, model_package_group_name=model_package_group_name, model_package_description=model_package_description, inference_specification=inference_specification, validation_specification=validation_specification, source_algorithm_specification=source_algorithm_specification, certify_for_marketplace=certify_for_marketplace, tags=tags, model_approval_status=model_approval_status, metadata_properties=metadata_properties, model_metrics=model_metrics, customer_metadata_properties=customer_metadata_properties, drift_check_baselines=drift_check_baselines, domain=domain, task=task, sample_payload_url=sample_payload_url, additional_inference_specifications=additional_inference_specifications, client_token=client_token, ) return json.dumps(dict(ModelPackageArn=model_package_arn)) def create_model_package_group(self) -> str: model_package_group_name = self._get_param("ModelPackageGroupName") model_package_group_description = self._get_param( "ModelPackageGroupDescription" ) tags = self._get_param("Tags") model_package_group_arn = self.sagemaker_backend.create_model_package_group( model_package_group_name=model_package_group_name, model_package_group_description=model_package_group_description, tags=tags, ) return json.dumps(dict(ModelPackageGroupArn=model_package_group_arn)) def create_feature_group(self) -> str: feature_group_arn = self.sagemaker_backend.create_feature_group( feature_group_name=self._get_param("FeatureGroupName"), record_identifier_feature_name=self._get_param( "RecordIdentifierFeatureName" ), event_time_feature_name=self._get_param("EventTimeFeatureName"), feature_definitions=self._get_param("FeatureDefinitions"), offline_store_config=self._get_param("OfflineStoreConfig"), role_arn=self._get_param("RoleArn"), tags=self._get_param("Tags"), ) return json.dumps(dict(FeatureGroupArn=feature_group_arn)) def describe_feature_group(self) -> str: resp = self.sagemaker_backend.describe_feature_group( feature_group_name=self._get_param("FeatureGroupName"), ) return json.dumps(resp) def create_cluster(self) -> str: cluster_name = self._get_param("ClusterName") instance_groups = self._get_param("InstanceGroups") vpc_config = self._get_param("VpcConfig") tags = self._get_param("Tags") cluster_arn = self.sagemaker_backend.create_cluster( cluster_name=cluster_name, instance_groups=instance_groups, vpc_config=vpc_config, tags=tags, ) return json.dumps(dict(ClusterArn=cluster_arn)) def delete_cluster(self) -> str: cluster_name = self._get_param("ClusterName") cluster_arn = self.sagemaker_backend.delete_cluster( cluster_name=cluster_name, ) return json.dumps(dict(ClusterArn=cluster_arn)) def describe_cluster(self) -> str: cluster_name = self._get_param("ClusterName") cluster_description = self.sagemaker_backend.describe_cluster( cluster_name=cluster_name, ) return json.dumps(cluster_description) def describe_cluster_node(self) -> str: cluster_name = self._get_param("ClusterName") node_id = self._get_param("NodeId") node_details = self.sagemaker_backend.describe_cluster_node( cluster_name=cluster_name, node_id=node_id, ) return json.dumps(dict(NodeDetails=node_details)) def list_clusters(self) -> str: creation_time_after = self._get_param("CreationTimeAfter") creation_time_before = self._get_param("CreationTimeBefore") max_results = self._get_param("MaxResults") name_contains = self._get_param("NameContains") next_token = self._get_param("NextToken") sort_by = self._get_param("SortBy") sort_order = self._get_param("SortOrder") clusters, next_token = self.sagemaker_backend.list_clusters( creation_time_after=creation_time_after, creation_time_before=creation_time_before, max_results=max_results, name_contains=name_contains, next_token=next_token, sort_by=sort_by, sort_order=sort_order, ) cluster_summaries = [cluster.summary() for cluster in clusters] return json.dumps( dict(NextToken=next_token, ClusterSummaries=cluster_summaries) ) def list_cluster_nodes(self) -> str: cluster_name = self._get_param("ClusterName") creation_time_after = self._get_param("CreationTimeAfter") creation_time_before = self._get_param("CreationTimeBefore") instance_group_name_contains = self._get_param("InstanceGroupNameContains") max_results = self._get_param("MaxResults") next_token = self._get_param("NextToken") sort_by = self._get_param("SortBy") sort_order = self._get_param("SortOrder") cluster_nodes, next_token = self.sagemaker_backend.list_cluster_nodes( cluster_name=cluster_name, creation_time_after=creation_time_after, creation_time_before=creation_time_before, instance_group_name_contains=instance_group_name_contains, max_results=max_results, next_token=next_token, sort_by=sort_by, sort_order=sort_order, ) cluster_node_summaries = [node.summary() for node in cluster_nodes] return json.dumps( dict(NextToken=next_token, ClusterNodeSummaries=cluster_node_summaries) ) def create_model_bias_job_definition(self) -> str: account_id = self.current_account job_definition_name = self._get_param("JobDefinitionName") tags = self._get_param("Tags", []) role_arn = self._get_param("RoleArn") job_resources = self._get_param("JobResources") stopping_condition = self._get_param("StoppingCondition") environment = self._get_param("Environment", {}) network_config = self._get_param("NetworkConfig", {}) model_bias_baseline_config = self._get_param("ModelBiasBaselineConfig", {}) model_bias_app_specification = self._get_param("ModelBiasAppSpecification") model_bias_job_input = self._get_param("ModelBiasJobInput") model_bias_job_output_config = self._get_param("ModelBiasJobOutputConfig") response = self.sagemaker_backend.create_model_bias_job_definition( account_id=account_id, job_definition_name=job_definition_name, tags=tags, role_arn=role_arn, job_resources=job_resources, stopping_condition=stopping_condition, environment=environment, network_config=network_config, model_bias_baseline_config=model_bias_baseline_config, model_bias_app_specification=model_bias_app_specification, model_bias_job_input=model_bias_job_input, model_bias_job_output_config=model_bias_job_output_config, ) return json.dumps(response) def list_model_bias_job_definitions(self) -> str: result, next_token = self.sagemaker_backend.list_model_bias_job_definitions() return json.dumps({"JobDefinitionSummaries": result, "NextToken": next_token}) def describe_model_bias_job_definition(self) -> str: job_definition_name = self._get_param("JobDefinitionName") job_definition = self.sagemaker_backend.describe_model_bias_job_definition( job_definition_name ) return json.dumps(job_definition) def delete_model_bias_job_definition(self) -> str: job_definition_name = self._get_param("JobDefinitionName") self.sagemaker_backend.delete_model_bias_job_definition(job_definition_name) return json.dumps({}) def create_data_quality_job_definition(self) -> str: account_id = self.current_account job_definition_name = self._get_param("JobDefinitionName") tags = self._get_param("Tags", []) role_arn = self._get_param("RoleArn") job_resources = self._get_param("JobResources") stopping_condition = self._get_param("StoppingCondition") environment = self._get_param("Environment", {}) network_config = self._get_param("NetworkConfig", {}) data_quality_baseline_config = self._get_param("DataQualityBaselineConfig", {}) data_quality_app_specification = self._get_param("DataQualityAppSpecification") data_quality_job_input = self._get_param("DataQualityJobInput") data_quality_job_output_config = self._get_param("DataQualityJobOutputConfig") response = self.sagemaker_backend.create_data_quality_job_definition( account_id=account_id, job_definition_name=job_definition_name, tags=tags, role_arn=role_arn, job_resources=job_resources, stopping_condition=stopping_condition, environment=environment, network_config=network_config, data_quality_baseline_config=data_quality_baseline_config, data_quality_app_specification=data_quality_app_specification, data_quality_job_input=data_quality_job_input, data_quality_job_output_config=data_quality_job_output_config, ) return json.dumps(response) def list_data_quality_job_definitions(self) -> str: result, next_token = self.sagemaker_backend.list_data_quality_job_definitions() return json.dumps({"JobDefinitionSummaries": result, "NextToken": next_token}) def describe_data_quality_job_definition(self) -> str: job_definition_name = self._get_param("JobDefinitionName") job_definition = self.sagemaker_backend.describe_data_quality_job_definition( job_definition_name ) return json.dumps(job_definition) def delete_data_quality_job_definition(self) -> str: job_definition_name = self._get_param("JobDefinitionName") self.sagemaker_backend.delete_data_quality_job_definition(job_definition_name) return json.dumps({}) def create_auto_ml_job_v2(self) -> str: auto_ml_job_name = self._get_param("AutoMLJobName") auto_ml_job_input_data_config = self._get_param("AutoMLJobInputDataConfig") output_data_config = self._get_param("OutputDataConfig") auto_ml_problem_type_config = self._get_param("AutoMLProblemTypeConfig") role_arn = self._get_param("RoleArn") tags = self._get_param("Tags") security_config = self._get_param("SecurityConfig") auto_ml_job_objective = self._get_param("AutoMLJobObjective") model_deploy_config = self._get_param("ModelDeployConfig") data_split_config = self._get_param("DataSplitConfig") auto_ml_job_arn = self.sagemaker_backend.create_auto_ml_job_v2( auto_ml_job_name=auto_ml_job_name, auto_ml_job_input_data_config=auto_ml_job_input_data_config, output_data_config=output_data_config, auto_ml_problem_type_config=auto_ml_problem_type_config, role_arn=role_arn, tags=tags, security_config=security_config, auto_ml_job_objective=auto_ml_job_objective, model_deploy_config=model_deploy_config, data_split_config=data_split_config, ) return json.dumps(dict(AutoMLJobArn=auto_ml_job_arn)) def describe_auto_ml_job_v2(self) -> str: auto_ml_job_name = self._get_param("AutoMLJobName") auto_ml_job_description = self.sagemaker_backend.describe_auto_ml_job_v2( auto_ml_job_name=auto_ml_job_name, ) return json.dumps(auto_ml_job_description) def list_auto_ml_jobs(self) -> str: creation_time_after = self._get_param("CreationTimeAfter") creation_time_before = self._get_param("CreationTimeBefore") last_modified_time_after = self._get_param("LastModifiedTimeAfter") last_modified_time_before = self._get_param("LastModifiedTimeBefore") name_contains = self._get_param("NameContains") status_equals = self._get_param("StatusEquals") sort_order = self._get_param("SortOrder") sort_by = self._get_param("SortBy") max_results = self._get_param("MaxResults") next_token = self._get_param("NextToken") auto_ml_jobs, next_token = self.sagemaker_backend.list_auto_ml_jobs( creation_time_after=creation_time_after, creation_time_before=creation_time_before, last_modified_time_after=last_modified_time_after, last_modified_time_before=last_modified_time_before, name_contains=name_contains, status_equals=status_equals, sort_order=sort_order, sort_by=sort_by, max_results=max_results, next_token=next_token, ) auto_ml_job_summaries = [auto_ml_job.summary() for auto_ml_job in auto_ml_jobs] return json.dumps( dict(AutoMLJobSummaries=auto_ml_job_summaries, NextToken=next_token) ) def stop_auto_ml_job(self) -> str: auto_ml_job_name = self._get_param("AutoMLJobName") self.sagemaker_backend.stop_auto_ml_job( auto_ml_job_name=auto_ml_job_name, ) return json.dumps(dict()) def list_endpoints(self) -> str: sort_by = self._get_param("SortBy") sort_order = self._get_param("SortOrder") next_token = self._get_param("NextToken") max_results = self._get_param("MaxResults") name_contains = self._get_param("NameContains") creation_time_before = self._get_param("CreationTimeBefore") creation_time_after = self._get_param("CreationTimeAfter") last_modified_time_before = self._get_param("LastModifiedTimeBefore") last_modified_time_after = self._get_param("LastModifiedTimeAfter") status_equals = self._get_param("StatusEquals") endpoints, next_token = self.sagemaker_backend.list_endpoints( sort_by=sort_by, sort_order=sort_order, next_token=next_token, max_results=max_results, name_contains=name_contains, creation_time_before=creation_time_before, creation_time_after=creation_time_after, last_modified_time_before=last_modified_time_before, last_modified_time_after=last_modified_time_after, status_equals=status_equals, ) endpoint_summaries = [endpoint.summary() for endpoint in endpoints] return json.dumps(dict(Endpoints=endpoint_summaries, NextToken=next_token)) def list_endpoint_configs(self) -> str: sort_by = self._get_param("SortBy") sort_order = self._get_param("SortOrder") next_token = self._get_param("NextToken") max_results = self._get_param("MaxResults") name_contains = self._get_param("NameContains") creation_time_before = self._get_param("CreationTimeBefore") creation_time_after = self._get_param("CreationTimeAfter") endpoint_configs, next_token = self.sagemaker_backend.list_endpoint_configs( sort_by=sort_by, sort_order=sort_order, next_token=next_token, max_results=max_results, name_contains=name_contains, creation_time_before=creation_time_before, creation_time_after=creation_time_after, ) endpoint_summaries = [ endpoint_config.summary() for endpoint_config in endpoint_configs ] return json.dumps( dict(EndpointConfigs=endpoint_summaries, NextToken=next_token) ) def create_compilation_job(self) -> str: compilation_job_name = self._get_param("CompilationJobName") role_arn = self._get_param("RoleArn") model_package_version_arn = self._get_param("ModelPackageVersionArn") input_config = self._get_param("InputConfig") output_config = self._get_param("OutputConfig") vpc_config = self._get_param("VpcConfig") stopping_condition = self._get_param("StoppingCondition") tags = self._get_param("Tags") compilation_job_arn = self.sagemaker_backend.create_compilation_job( compilation_job_name=compilation_job_name, role_arn=role_arn, model_package_version_arn=model_package_version_arn, input_config=input_config, output_config=output_config, vpc_config=vpc_config, stopping_condition=stopping_condition, tags=tags, ) return json.dumps(dict(CompilationJobArn=compilation_job_arn)) def describe_compilation_job(self) -> str: compilation_job_name = self._get_param("CompilationJobName") compilation_job_description = self.sagemaker_backend.describe_compilation_job( compilation_job_name=compilation_job_name, ) return json.dumps(compilation_job_description) def list_compilation_jobs(self) -> str: next_token = self._get_param("NextToken") max_results = self._get_param("MaxResults") creation_time_after = self._get_param("CreationTimeAfter") creation_time_before = self._get_param("CreationTimeBefore") last_modified_time_after = self._get_param("LastModifiedTimeAfter") last_modified_time_before = self._get_param("LastModifiedTimeBefore") name_contains = self._get_param("NameContains") status_equals = self._get_param("StatusEquals") sort_by = self._get_param("SortBy") sort_order = self._get_param("SortOrder") compilation_jobs, next_token = self.sagemaker_backend.list_compilation_jobs( next_token=next_token, max_results=max_results, creation_time_after=creation_time_after, creation_time_before=creation_time_before, last_modified_time_after=last_modified_time_after, last_modified_time_before=last_modified_time_before, name_contains=name_contains, status_equals=status_equals, sort_by=sort_by, sort_order=sort_order, ) compilation_job_summaries = [x.summary() for x in compilation_jobs] return json.dumps( dict( CompilationJobSummaries=compilation_job_summaries, NextToken=next_token ) ) def delete_compilation_job(self) -> str: compilation_job_name = self._get_param("CompilationJobName") self.sagemaker_backend.delete_compilation_job( compilation_job_name=compilation_job_name, ) return json.dumps({}) def create_domain(self) -> str: domain_name = self._get_param("DomainName") auth_mode = self._get_param("AuthMode") default_user_settings = self._get_param("DefaultUserSettings") domain_settings = self._get_param("DomainSettings") subnet_ids = self._get_param("SubnetIds") vpc_id = self._get_param("VpcId") tags = self._get_param("Tags") app_network_access_type = self._get_param("AppNetworkAccessType") home_efs_file_system_kms_key_id = self._get_param("HomeEfsFileSystemKmsKeyId") kms_key_id = self._get_param("KmsKeyId") app_security_group_management = self._get_param("AppSecurityGroupManagement") default_space_settings = self._get_param("DefaultSpaceSettings") resp = self.sagemaker_backend.create_domain( domain_name=domain_name, auth_mode=auth_mode, default_user_settings=default_user_settings, domain_settings=domain_settings, subnet_ids=subnet_ids, vpc_id=vpc_id, tags=tags, app_network_access_type=app_network_access_type, home_efs_file_system_kms_key_id=home_efs_file_system_kms_key_id, kms_key_id=kms_key_id, app_security_group_management=app_security_group_management, default_space_settings=default_space_settings, ) return json.dumps(resp) def describe_domain(self) -> str: domain_id = self._get_param("DomainId") domain_description = self.sagemaker_backend.describe_domain( domain_id=domain_id, ) return json.dumps(domain_description) def list_domains(self) -> str: next_token = self._get_param("NextToken") max_results = self._get_param("MaxResults") domains, next_token = self.sagemaker_backend.list_domains( next_token=next_token, max_results=max_results, ) domain_summaries = [domain.summary() for domain in domains] return json.dumps(dict(Domains=domain_summaries, NextToken=next_token)) def delete_domain(self) -> str: domain_id = self._get_param("DomainId") retention_policy = self._get_param("RetentionPolicy") self.sagemaker_backend.delete_domain( domain_id=domain_id, retention_policy=retention_policy, ) return json.dumps(dict()) def create_model_explainability_job_definition(self) -> str: job_definition_name = self._get_param("JobDefinitionName") model_explainability_baseline_config = self._get_param( "ModelExplainabilityBaselineConfig" ) model_explainability_app_specification = self._get_param( "ModelExplainabilityAppSpecification" ) model_explainability_job_input = self._get_param("ModelExplainabilityJobInput") model_explainability_job_output_config = self._get_param( "ModelExplainabilityJobOutputConfig" ) job_resources = self._get_param("JobResources") network_config = self._get_param("NetworkConfig") role_arn = self._get_param("RoleArn") stopping_condition = self._get_param("StoppingCondition") tags = self._get_param("Tags") job_definition_arn = self.sagemaker_backend.create_model_explainability_job_definition( job_definition_name=job_definition_name, model_explainability_baseline_config=model_explainability_baseline_config, model_explainability_app_specification=model_explainability_app_specification, model_explainability_job_input=model_explainability_job_input, model_explainability_job_output_config=model_explainability_job_output_config, job_resources=job_resources, network_config=network_config, role_arn=role_arn, stopping_condition=stopping_condition, tags=tags, ) return json.dumps(dict(JobDefinitionArn=job_definition_arn)) def describe_model_explainability_job_definition(self) -> str: job_definition_name = self._get_param("JobDefinitionName") description = ( self.sagemaker_backend.describe_model_explainability_job_definition( job_definition_name=job_definition_name, ) ) return json.dumps(description) def list_model_explainability_job_definitions(self) -> str: endpoint_name = self._get_param("EndpointName") sort_by = self._get_param("SortBy") sort_order = self._get_param("SortOrder") next_token = self._get_param("NextToken") max_results = self._get_param("MaxResults") name_contains = self._get_param("NameContains") creation_time_before = self._get_param("CreationTimeBefore") creation_time_after = self._get_param("CreationTimeAfter") job_definitions, next_token = ( self.sagemaker_backend.list_model_explainability_job_definitions( endpoint_name=endpoint_name, sort_by=sort_by, sort_order=sort_order, next_token=next_token, max_results=max_results, name_contains=name_contains, creation_time_before=creation_time_before, creation_time_after=creation_time_after, ) ) job_definition_summaries = [job.summary() for job in job_definitions] return json.dumps( dict(JobDefinitionSummaries=job_definition_summaries, NextToken=next_token) ) def delete_model_explainability_job_definition(self) -> str: job_definition_name = self._get_param("JobDefinitionName") self.sagemaker_backend.delete_model_explainability_job_definition( job_definition_name=job_definition_name, ) return json.dumps(dict()) def create_hyper_parameter_tuning_job(self) -> str: hyper_parameter_tuning_job_name = self._get_param("HyperParameterTuningJobName") hyper_parameter_tuning_job_config = self._get_param( "HyperParameterTuningJobConfig" ) training_job_definition = self._get_param("TrainingJobDefinition") training_job_definitions = self._get_param("TrainingJobDefinitions") warm_start_config = self._get_param("WarmStartConfig") tags = self._get_param("Tags") autotune = self._get_param("Autotune") hyper_parameter_tuning_job_arn = ( self.sagemaker_backend.create_hyper_parameter_tuning_job( hyper_parameter_tuning_job_name=hyper_parameter_tuning_job_name, hyper_parameter_tuning_job_config=hyper_parameter_tuning_job_config, training_job_definition=training_job_definition, training_job_definitions=training_job_definitions, warm_start_config=warm_start_config, tags=tags, autotune=autotune, ) ) return json.dumps( dict(HyperParameterTuningJobArn=hyper_parameter_tuning_job_arn) ) def describe_hyper_parameter_tuning_job(self) -> str: hyper_parameter_tuning_job_name = self._get_param("HyperParameterTuningJobName") hyper_parameter_tuning_job_description = ( self.sagemaker_backend.describe_hyper_parameter_tuning_job( hyper_parameter_tuning_job_name=hyper_parameter_tuning_job_name, ) ) return json.dumps(hyper_parameter_tuning_job_description) def list_hyper_parameter_tuning_jobs(self) -> str: next_token = self._get_param("NextToken") max_results = self._get_param("MaxResults") sort_by = self._get_param("SortBy") sort_order = self._get_param("SortOrder") name_contains = self._get_param("NameContains") creation_time_after = self._get_param("CreationTimeAfter") creation_time_before = self._get_param("CreationTimeBefore") last_modified_time_after = self._get_param("LastModifiedTimeAfter") last_modified_time_before = self._get_param("LastModifiedTimeBefore") status_equals = self._get_param("StatusEquals") hyper_parameter_tuning_jobs, next_token = ( self.sagemaker_backend.list_hyper_parameter_tuning_jobs( next_token=next_token, max_results=max_results, sort_by=sort_by, sort_order=sort_order, name_contains=name_contains, creation_time_after=creation_time_after, creation_time_before=creation_time_before, last_modified_time_after=last_modified_time_after, last_modified_time_before=last_modified_time_before, status_equals=status_equals, ) ) hyper_parameter_tuning_job_summaries = [ job.summary() for job in hyper_parameter_tuning_jobs ] return json.dumps( dict( HyperParameterTuningJobSummaries=hyper_parameter_tuning_job_summaries, NextToken=next_token, ) ) def delete_hyper_parameter_tuning_job(self) -> str: hyper_parameter_tuning_job_name = self._get_param("HyperParameterTuningJobName") self.sagemaker_backend.delete_hyper_parameter_tuning_job( hyper_parameter_tuning_job_name=hyper_parameter_tuning_job_name, ) return json.dumps(dict()) def create_model_quality_job_definition(self) -> str: job_definition_name = self._get_param("JobDefinitionName") model_quality_baseline_config = self._get_param("ModelQualityBaselineConfig") model_quality_app_specification = self._get_param( "ModelQualityAppSpecification" ) model_quality_job_input = self._get_param("ModelQualityJobInput") model_quality_job_output_config = self._get_param("ModelQualityJobOutputConfig") job_resources = self._get_param("JobResources") network_config = self._get_param("NetworkConfig") role_arn = self._get_param("RoleArn") stopping_condition = self._get_param("StoppingCondition") tags = self._get_param("Tags") job_definition_arn = self.sagemaker_backend.create_model_quality_job_definition( job_definition_name=job_definition_name, model_quality_baseline_config=model_quality_baseline_config, model_quality_app_specification=model_quality_app_specification, model_quality_job_input=model_quality_job_input, model_quality_job_output_config=model_quality_job_output_config, job_resources=job_resources, network_config=network_config, role_arn=role_arn, stopping_condition=stopping_condition, tags=tags, ) return json.dumps(dict(JobDefinitionArn=job_definition_arn)) def describe_model_quality_job_definition(self) -> str: job_definition_name = self._get_param("JobDefinitionName") description = self.sagemaker_backend.describe_model_quality_job_definition( job_definition_name=job_definition_name, ) return json.dumps(description) def list_model_quality_job_definitions(self) -> str: endpoint_name = self._get_param("EndpointName") sort_by = self._get_param("SortBy") sort_order = self._get_param("SortOrder") next_token = self._get_param("NextToken") max_results = self._get_param("MaxResults") name_contains = self._get_param("NameContains") creation_time_before = self._get_param("CreationTimeBefore") creation_time_after = self._get_param("CreationTimeAfter") job_definitions, next_token = ( self.sagemaker_backend.list_model_quality_job_definitions( endpoint_name=endpoint_name, sort_by=sort_by, sort_order=sort_order, next_token=next_token, max_results=max_results, name_contains=name_contains, creation_time_before=creation_time_before, creation_time_after=creation_time_after, ) ) job_definition_summaries = [x.summary() for x in job_definitions] return json.dumps( dict(JobDefinitionSummaries=job_definition_summaries, NextToken=next_token) ) def delete_model_quality_job_definition(self) -> str: job_definition_name = self._get_param("JobDefinitionName") self.sagemaker_backend.delete_model_quality_job_definition( job_definition_name=job_definition_name, ) return json.dumps(dict()) def create_model_card(self) -> str: model_card_name = self._get_param("ModelCardName") security_config = self._get_param("SecurityConfig") content = self._get_param("Content") model_card_status = self._get_param("ModelCardStatus") tags = self._get_param("Tags") model_card_arn = self.sagemaker_backend.create_model_card( model_card_name=model_card_name, security_config=security_config, content=content, model_card_status=model_card_status, tags=tags, ) return json.dumps(dict(ModelCardArn=model_card_arn)) def list_model_cards(self) -> str: creation_time_after = self._get_param("CreationTimeAfter") creation_time_before = self._get_param("CreationTimeBefore") max_results = self._get_param("MaxResults") name_contains = self._get_param("NameContains") model_card_status = self._get_param("ModelCardStatus") next_token = self._get_param("NextToken") sort_by = self._get_param("SortBy") sort_order = self._get_param("SortOrder") model_cards, next_token = self.sagemaker_backend.list_model_cards( creation_time_after=creation_time_after, creation_time_before=creation_time_before, max_results=max_results, name_contains=name_contains, model_card_status=model_card_status, next_token=next_token, sort_by=sort_by, sort_order=sort_order, ) model_card_summaries = [model_card.summary() for model_card in model_cards] return json.dumps( dict(ModelCardSummaries=model_card_summaries, NextToken=next_token) ) def list_model_card_versions(self) -> str: creation_time_after = self._get_param("CreationTimeAfter") creation_time_before = self._get_param("CreationTimeBefore") max_results = self._get_param("MaxResults") model_card_name = self._get_param("ModelCardName") model_card_status = self._get_param("ModelCardStatus") next_token = self._get_param("NextToken") sort_by = self._get_param("SortBy") sort_order = self._get_param("SortOrder") model_card_versions, next_token = ( self.sagemaker_backend.list_model_card_versions( creation_time_after=creation_time_after, creation_time_before=creation_time_before, max_results=max_results, model_card_name=model_card_name, model_card_status=model_card_status, next_token=next_token, sort_by=sort_by, sort_order=sort_order, ) ) model_card_version_summaries = [ mcv.version_summary() for mcv in model_card_versions ] return json.dumps( dict( ModelCardVersionSummaryList=model_card_version_summaries, NextToken=next_token, ) ) def update_model_card(self) -> str: model_card_name = self._get_param("ModelCardName") content = self._get_param("Content") model_card_status = self._get_param("ModelCardStatus") model_card_arn = self.sagemaker_backend.update_model_card( model_card_name=model_card_name, content=content, model_card_status=model_card_status, ) return json.dumps(dict(ModelCardArn=model_card_arn)) def describe_model_card(self) -> str: model_card_name = self._get_param("ModelCardName") model_card_version = self._get_param("ModelCardVersion") model_card_description = self.sagemaker_backend.describe_model_card( model_card_name=model_card_name, model_card_version=model_card_version, ) return json.dumps(model_card_description) def delete_model_card(self) -> str: model_card_name = self._get_param("ModelCardName") self.sagemaker_backend.delete_model_card( model_card_name=model_card_name, ) return json.dumps(dict())
Memory