"""BedrockBackend class with methods for supported APIs."""
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from moto.bedrock.exceptions import (
ResourceInUseException,
ResourceNotFoundException,
TooManyTagsException,
ValidationException,
)
from moto.core.base_backend import BackendDict, BaseBackend
from moto.core.common_models import BaseModel
from moto.utilities.paginator import paginate
from moto.utilities.tagging_service import TaggingService
from moto.utilities.utils import get_partition
class ModelCustomizationJob(BaseModel):
    """In-memory model of a Bedrock model customization job.

    The constructor validates the S3 locations of the training, validation
    and output data configurations and derives the ARNs and bookkeeping
    fields that are returned by :meth:`to_dict`.
    """

    # Pattern every ``s3Uri`` member has to match.
    _S3_URI_PATTERN = r"s3://.*"
    # Format used for every timestamp stored on the job.
    _TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(
        self,
        job_name: str,
        custom_model_name: str,
        role_arn: str,
        base_model_identifier: str,
        training_data_config: Dict[str, str],
        output_data_config: Dict[str, str],
        hyper_parameters: Dict[str, str],
        region_name: str,
        account_id: str,
        client_request_token: Optional[str],
        customization_type: Optional[str],
        custom_model_kms_key_id: Optional[str],
        job_tags: Optional[List[Dict[str, str]]],
        custom_model_tags: Optional[List[Dict[str, str]]],
        validation_data_config: Optional[Dict[str, Any]],
        vpc_config: Optional[Dict[str, Any]],
    ):
        """Validate the data configurations and populate the job.

        Raises:
            ValidationException: if the training or output configuration, or
                any entry under ``validation_data_config["validators"]``, has
                no ``s3Uri`` matching ``s3://.*``.
        """
        self.job_name = job_name
        self.custom_model_name = custom_model_name
        self.role_arn = role_arn
        self.client_request_token = client_request_token
        self.base_model_identifier = base_model_identifier
        self.customization_type = customization_type
        self.custom_model_kms_key_id = custom_model_kms_key_id
        self.job_tags = job_tags
        self.custom_model_tags = custom_model_tags

        # Validation order (training, validation, output) decides which
        # error a caller sees first, so it is kept stable.
        self._validate_s3_uri(training_data_config, "training_data_config")
        self.training_data_config = training_data_config

        if validation_data_config:
            for validator in validation_data_config.get("validators") or []:
                # A validator without "s3Uri" is reported as a validation
                # error instead of surfacing as a KeyError.
                self._validate_s3_uri(validator, "validation_data_config")
        self.validation_data_config = validation_data_config

        self._validate_s3_uri(output_data_config, "output_data_config")
        self.output_data_config = output_data_config

        self.hyper_parameters = hyper_parameters
        self.vpc_config = vpc_config
        self.region_name = region_name
        self.account_id = account_id

        partition = get_partition(self.region_name)
        self.job_arn = f"arn:{partition}:bedrock:{self.region_name}:{self.account_id}:model-customization-job/{self.job_name}"
        self.output_model_name = f"{self.custom_model_name}-{self.job_name}"
        self.output_model_arn = f"arn:{partition}:bedrock:{self.region_name}:{self.account_id}:custom-model/{self.output_model_name}"
        self.status = "InProgress"
        self.failure_message = "Failure Message"

        # Take a single timestamp so the three time fields can never
        # disagree because the clock ticked between calls.
        now = datetime.now().strftime(self._TIME_FORMAT)
        self.creation_time = now
        self.last_modified_time = now
        self.end_time = now

        self.base_model_arn = f"arn:{partition}:bedrock:{self.region_name}::foundation-model/{self.base_model_identifier}"
        self.output_model_kms_key_arn = f"arn:{partition}:kms:{self.region_name}:{self.account_id}:key/{self.output_model_name}-kms-key"
        self.training_metrics = {"trainingLoss": 0.0}  # hard coded
        self.validation_metrics = [{"validationLoss": 0.0}]  # hard coded

    @classmethod
    def _validate_s3_uri(cls, config: Dict[str, Any], field_name: str) -> None:
        """Raise ValidationException unless ``config["s3Uri"]`` is an S3 URI."""
        s3_uri = config.get("s3Uri")
        if not isinstance(s3_uri, str) or not re.match(cls._S3_URI_PATTERN, s3_uri):
            raise ValidationException(
                "Validation error detected: "
                f"Value '{config}' at '{field_name}' failed to satisfy constraint: "
                "Member must satisfy regular expression pattern: "
                "s3://.*"
            )

    def to_dict(self) -> Dict[str, Any]:
        """Return the job as a dict with camelCase keys, omitting falsy values."""
        dct = {
            "baseModelArn": self.base_model_arn,
            "clientRequestToken": self.client_request_token,
            "creationTime": self.creation_time,
            "customizationType": self.customization_type,
            "endTime": self.end_time,
            "failureMessage": self.failure_message,
            "hyperParameters": self.hyper_parameters,
            "jobArn": self.job_arn,
            "jobName": self.job_name,
            "lastModifiedTime": self.last_modified_time,
            "outputDataConfig": self.output_data_config,
            "outputModelArn": self.output_model_arn,
            "outputModelKmsKeyArn": self.output_model_kms_key_arn,
            "outputModelName": self.output_model_name,
            "roleArn": self.role_arn,
            "status": self.status,
            "trainingDataConfig": self.training_data_config,
            "trainingMetrics": self.training_metrics,
            "validationDataConfig": self.validation_data_config,
            "validationMetrics": self.validation_metrics,
            "vpcConfig": self.vpc_config,
        }
        return {k: v for k, v in dct.items() if v}
class CustomModel(BaseModel):
    """In-memory model of a Bedrock custom model.

    Stores the values handed to the constructor, derives the model ARN from
    the region, account and model name, and records a creation timestamp.
    """

    def __init__(
        self,
        model_name: str,
        job_name: str,
        job_arn: str,
        base_model_arn: str,
        hyper_parameters: Dict[str, str],
        output_data_config: Dict[str, str],
        training_data_config: Dict[str, str],
        training_metrics: Dict[str, float],
        base_model_name: str,
        region_name: str,
        account_id: str,
        customization_type: Optional[str],
        model_kms_key_arn: Optional[str],
        validation_data_config: Optional[Dict[str, Any]],
        validation_metrics: Optional[List[Dict[str, float]]],
    ):
        # Naming and provenance.
        self.model_name = model_name
        self.base_model_name = base_model_name
        self.base_model_arn = base_model_arn
        self.job_name = job_name
        self.job_arn = job_arn

        # Where the model lives.
        self.region_name = region_name
        self.account_id = account_id

        # Customization settings.
        self.customization_type = customization_type
        self.model_kms_key_arn = model_kms_key_arn
        self.hyper_parameters = hyper_parameters

        # Data locations and reported metrics.
        self.training_data_config = training_data_config
        self.validation_data_config = validation_data_config
        self.output_data_config = output_data_config
        self.training_metrics = training_metrics
        self.validation_metrics = validation_metrics

        # Derived values.
        self.model_arn = f"arn:{get_partition(self.region_name)}:bedrock:{self.region_name}:{self.account_id}:custom-model/{self.model_name}"
        self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def to_dict(self) -> Dict[str, Any]:
        """Return the model as a dict with camelCase keys, omitting falsy values."""
        candidates = (
            ("baseModelArn", self.base_model_arn),
            ("creationTime", self.creation_time),
            ("customizationType", self.customization_type),
            ("hyperParameters", self.hyper_parameters),
            ("jobArn", self.job_arn),
            ("jobName", self.job_name),
            ("modelArn", self.model_arn),
            ("modelKmsKeyArn", self.model_kms_key_arn),
            ("modelName", self.model_name),
            ("outputDataConfig", self.output_data_config),
            ("trainingDataConfig", self.training_data_config),
            ("trainingMetrics", self.training_metrics),
            ("validationDataConfig", self.validation_data_config),
            ("validationMetrics", self.validation_metrics),
        )
        populated: Dict[str, Any] = {}
        for key, value in candidates:
            if value:
                populated[key] = value
        return populated
class model_invocation_logging_configuration(BaseModel):
    """Holder for the model invocation logging configuration dict."""

    def __init__(self, logging_config: Dict[str, Any]) -> None:
        # The configuration is stored as given; no copy or validation is made.
        self.logging_config: Dict[str, Any] = logging_config
class BedrockBackend(BaseBackend):
    """Implementation of Bedrock APIs."""

    PAGINATION_MODEL = {
        "list_model_customization_jobs": {
            "input_token": "next_token",
            "limit_key": "max_results",
            "limit_default": 100,
            "unique_attribute": "job_arn",
        },
        "list_custom_models": {
            "input_token": "next_token",
            "limit_key": "max_results",
            "limit_default": 100,
            "unique_attribute": "model_arn",
        },
    }

    # Upper bound on the number of tags a single resource may carry.
    MAX_TAGS_PER_RESOURCE = 50

    def __init__(self, region_name: str, account_id: str) -> None:
        super().__init__(region_name, account_id)
        # Both registries are keyed by name (job name / custom model name).
        self.model_customization_jobs: Dict[str, ModelCustomizationJob] = {}
        self.custom_models: Dict[str, CustomModel] = {}
        self.model_invocation_logging_configuration: Optional[
            model_invocation_logging_configuration
        ] = None
        self.tagger = TaggingService()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _list_arns(self) -> List[str]:
        """Return the ARNs of every customization job and custom model."""
        return [job.job_arn for job in self.model_customization_jobs.values()] + [
            model.model_arn for model in self.custom_models.values()
        ]

    def _find_job(self, job_identifier: str) -> ModelCustomizationJob:
        """Resolve a job by name or, failing that, by ARN.

        Raises:
            ResourceNotFoundException: if no job matches the identifier.
        """
        job = self.model_customization_jobs.get(job_identifier)
        if job is None:
            job = next(
                (
                    candidate
                    for candidate in self.model_customization_jobs.values()
                    if candidate.job_arn == job_identifier
                ),
                None,
            )
        if job is None:
            raise ResourceNotFoundException(
                f"Model customization job {job_identifier} not found"
            )
        return job

    def _check_tag_limit(self, resource_arn: str, tags: List[Dict[str, str]]) -> None:
        """Raise TooManyTagsException if adding ``tags`` would exceed the limit.

        The count is taken over distinct keys after merging with the tags
        already stored, so overwriting an existing key does not count twice.
        """
        existing = self.tagger.list_tags_for_resource(resource_arn)["Tags"]
        merged_keys = {tag["Key"] for tag in existing} | {tag["key"] for tag in tags}
        if len(merged_keys) > self.MAX_TAGS_PER_RESOURCE:
            raise TooManyTagsException(
                "Member must have length less than or equal to 50"
            )

    @staticmethod
    def _filter_by_creation_time(
        resources: List[Any],
        creation_time_after: Optional[datetime],
        creation_time_before: Optional[datetime],
    ) -> List[Any]:
        """Keep resources created strictly inside the given time window.

        ``creation_time`` is stored as a string, so the bounds are compared
        through their ``str()`` form.
        """
        if creation_time_after is not None:
            lower_bound = str(creation_time_after)
            resources = [r for r in resources if r.creation_time > lower_bound]
        if creation_time_before is not None:
            upper_bound = str(creation_time_before)
            resources = [r for r in resources if r.creation_time < upper_bound]
        return resources

    @staticmethod
    def _sort_by_creation_time(
        resources: List[Any], sort_by: Optional[str], sort_order: Optional[str]
    ) -> List[Any]:
        """Sort by creation time when ``sort_by`` is given.

        Raises:
            ValidationException: if ``sort_by`` is not ``"CreationTime"`` or
                ``sort_order`` is neither ``"Ascending"`` nor ``"Descending"``.
        """
        if sort_by is None:
            return resources
        if sort_by != "CreationTime":
            raise ValidationException(f"Invalid sort by field: {sort_by}")
        if sort_order not in ("Ascending", "Descending"):
            raise ValidationException(f"Invalid sort order: {sort_order}")
        return sorted(
            resources,
            key=lambda r: r.creation_time,
            reverse=sort_order == "Descending",
        )

    # ------------------------------------------------------------------
    # Model customization jobs
    # ------------------------------------------------------------------

    def create_model_customization_job(
        self,
        job_name: str,
        custom_model_name: str,
        role_arn: str,
        base_model_identifier: str,
        training_data_config: Dict[str, Any],
        output_data_config: Dict[str, str],
        hyper_parameters: Dict[str, str],
        client_request_token: Optional[str],
        customization_type: Optional[str],
        custom_model_kms_key_id: Optional[str],
        job_tags: Optional[List[Dict[str, str]]],
        custom_model_tags: Optional[List[Dict[str, str]]],
        validation_data_config: Optional[Dict[str, Any]],
        vpc_config: Optional[Dict[str, Any]],
    ) -> str:
        """Create a customization job and its custom model; return the job ARN.

        Raises:
            ResourceInUseException: if the job name or the custom model name
                is already registered.
            ValidationException: raised by ModelCustomizationJob for invalid
                data configurations.
            TooManyTagsException: if either tag list exceeds the tag limit.
        """
        if job_name in self.model_customization_jobs:
            raise ResourceInUseException(
                f"Model customization job {job_name} already exists"
            )
        if custom_model_name in self.custom_models:
            raise ResourceInUseException(
                f"Custom model {custom_model_name} already exists"
            )

        model_customization_job = ModelCustomizationJob(
            job_name,
            custom_model_name,
            role_arn,
            base_model_identifier,
            training_data_config,
            output_data_config,
            hyper_parameters,
            self.region_name,
            self.account_id,
            client_request_token,
            customization_type,
            custom_model_kms_key_id,
            job_tags,
            custom_model_tags,
            validation_data_config,
            vpc_config,
        )
        # Create associated custom model
        custom_model = CustomModel(
            custom_model_name,
            job_name,
            model_customization_job.job_arn,
            model_customization_job.base_model_arn,
            model_customization_job.hyper_parameters,
            model_customization_job.output_data_config,
            model_customization_job.training_data_config,
            model_customization_job.training_metrics,
            model_customization_job.base_model_identifier,
            self.region_name,
            self.account_id,
            model_customization_job.customization_type,
            model_customization_job.output_model_kms_key_arn,
            model_customization_job.validation_data_config,
            model_customization_job.validation_metrics,
        )

        # Check the tag limits before touching backend state, so a rejected
        # request cannot leave a job registered without its custom model.
        if job_tags:
            self._check_tag_limit(model_customization_job.job_arn, job_tags)
        if custom_model_tags:
            self._check_tag_limit(custom_model.model_arn, custom_model_tags)

        self.model_customization_jobs[job_name] = model_customization_job
        self.custom_models[custom_model_name] = custom_model
        if job_tags:
            self.tag_resource(model_customization_job.job_arn, job_tags)
        if custom_model_tags:
            self.tag_resource(custom_model.model_arn, custom_model_tags)
        return model_customization_job.job_arn

    def get_model_customization_job(self, job_identifier: str) -> ModelCustomizationJob:
        """Return the job matching the given name or ARN.

        Raises:
            ResourceNotFoundException: if no such job exists.
        """
        return self._find_job(job_identifier)

    def stop_model_customization_job(self, job_identifier: str) -> None:
        """Mark the job matching the given name or ARN as ``Stopped``.

        Raises:
            ResourceNotFoundException: if no such job exists.
        """
        self._find_job(job_identifier).status = "Stopped"

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_model_customization_jobs(
        self,
        creation_time_after: Optional[datetime],
        creation_time_before: Optional[datetime],
        status_equals: Optional[str],
        name_contains: Optional[str],
        sort_by: Optional[str],
        sort_order: Optional[str],
    ) -> List[ModelCustomizationJob]:
        """Return the jobs matching every supplied filter, optionally sorted.

        Raises:
            ValidationException: for an unsupported ``sort_by`` or
                ``sort_order`` value.
        """
        jobs = list(self.model_customization_jobs.values())
        if name_contains is not None:
            jobs = [job for job in jobs if name_contains in job.job_name]
        jobs = self._filter_by_creation_time(
            jobs, creation_time_after, creation_time_before
        )
        if status_equals is not None:
            jobs = [job for job in jobs if job.status == status_equals]
        return self._sort_by_creation_time(jobs, sort_by, sort_order)

    # ------------------------------------------------------------------
    # Model invocation logging
    # ------------------------------------------------------------------

    def get_model_invocation_logging_configuration(self) -> Optional[Dict[str, Any]]:
        """Return the stored logging configuration, or ``{}`` if none is set."""
        if self.model_invocation_logging_configuration:
            return self.model_invocation_logging_configuration.logging_config
        return {}

    def put_model_invocation_logging_configuration(
        self, logging_config: Dict[str, Any]
    ) -> None:
        """Replace the stored logging configuration."""
        self.model_invocation_logging_configuration = (
            model_invocation_logging_configuration(logging_config)
        )

    def delete_model_invocation_logging_configuration(self) -> None:
        """Reset the stored logging configuration to an empty dict, if set."""
        if self.model_invocation_logging_configuration:
            self.model_invocation_logging_configuration.logging_config = {}

    # ------------------------------------------------------------------
    # Custom models
    # ------------------------------------------------------------------

    def get_custom_model(self, model_identifier: str) -> CustomModel:
        """Return the custom model matching the given name or ARN.

        The name is tried first so that models whose name happens to start
        with ``arn`` can still be looked up by name.

        Raises:
            ResourceNotFoundException: if no such model exists.
        """
        model = self.custom_models.get(model_identifier)
        if model is not None:
            return model
        for candidate in self.custom_models.values():
            if candidate.model_arn == model_identifier:
                return candidate
        raise ResourceNotFoundException(f"Custom model {model_identifier} not found")

    def delete_custom_model(self, model_identifier: str) -> None:
        """Delete the custom model matching the given name or ARN.

        Tags stored for the model are removed as well, so a later model with
        the same ARN does not inherit them.

        Raises:
            ResourceNotFoundException: if no such model exists.
        """
        model = self.get_custom_model(model_identifier)
        stale_keys = [
            tag["Key"]
            for tag in self.tagger.list_tags_for_resource(model.model_arn)["Tags"]
        ]
        if stale_keys:
            self.tagger.untag_resource_using_names(model.model_arn, stale_keys)
        del self.custom_models[model.model_name]

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_custom_models(
        self,
        creation_time_before: Optional[datetime],
        creation_time_after: Optional[datetime],
        name_contains: Optional[str],
        base_model_arn_equals: Optional[str],
        foundation_model_arn_equals: Optional[str],
        sort_by: Optional[str],
        sort_order: Optional[str],
    ) -> List[CustomModel]:
        """
        The foundation_model_arn_equals-argument is not yet supported

        Note that name_contains is matched against the job name of each model.
        """
        models = list(self.custom_models.values())
        if name_contains is not None:
            models = [model for model in models if name_contains in model.job_name]
        models = self._filter_by_creation_time(
            models, creation_time_after, creation_time_before
        )
        if base_model_arn_equals is not None:
            models = [
                model
                for model in models
                if model.base_model_arn == base_model_arn_equals
            ]
        return self._sort_by_creation_time(models, sort_by, sort_order)

    # ------------------------------------------------------------------
    # Tagging
    # ------------------------------------------------------------------

    def tag_resource(self, resource_arn: str, tags: List[Dict[str, str]]) -> None:
        """Attach ``tags`` (``key``/``value`` dicts) to a job or custom model.

        Raises:
            ResourceNotFoundException: if the ARN is not a known resource.
            TooManyTagsException: if the resource would end up with more
                than the allowed number of tags.
        """
        if resource_arn not in self._list_arns():
            raise ResourceNotFoundException(f"Resource {resource_arn} not found")
        self._check_tag_limit(resource_arn, tags)
        # The tagging service is fed capitalised "Key"/"Value" entries.
        fixed_tags = [
            {"Key": tag_dict["key"], "Value": tag_dict["value"]} for tag_dict in tags
        ]
        self.tagger.tag_resource(resource_arn, fixed_tags)

    def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None:
        """Remove the tags with the given keys from a job or custom model.

        Raises:
            ResourceNotFoundException: if the ARN is not a known resource.
        """
        if resource_arn not in self._list_arns():
            raise ResourceNotFoundException(f"Resource {resource_arn} not found")
        self.tagger.untag_resource_using_names(resource_arn, tag_keys)

    def list_tags_for_resource(self, resource_arn: str) -> List[Dict[str, str]]:
        """Return the resource's tags as a list of ``key``/``value`` dicts.

        Raises:
            ResourceNotFoundException: if the ARN is not a known resource.
        """
        if resource_arn not in self._list_arns():
            raise ResourceNotFoundException(f"Resource {resource_arn} not found")
        tags = self.tagger.list_tags_for_resource(resource_arn)
        return [
            {"key": tag_dict["Key"], "value": tag_dict["Value"]}
            for tag_dict in tags["Tags"]
        ]
# Module-level backend registry for the "bedrock" service, built from BedrockBackend.
# NOTE(review): BackendDict presumably hands out one backend per account and
# region — confirm against moto.core.base_backend.
bedrock_backends = BackendDict(BedrockBackend, "bedrock")