"""EMRServerlessBackend class with methods for supported APIs.""" import inspect import re from datetime import datetime from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from moto.core.base_backend import BackendDict, BaseBackend from moto.core.common_models import BaseModel from moto.core.utils import iso_8601_datetime_without_milliseconds from moto.emrcontainers.utils import paginated_list from moto.utilities.utils import get_partition from .exceptions import ( AccessDeniedException, ResourceNotFoundException, ValidationException, ) from .utils import ( default_auto_start_configuration, default_auto_stop_configuration, random_appplication_id, random_job_id, ) APPLICATION_ARN_TEMPLATE = "arn:{partition}:emr-serverless:{region}:{account_id}:/applications/{application_id}" JOB_RUN_ARN_TEMPLATE = "arn:{partition}:emr-serverless:{region}:{account_id}:/applications/{application_id}/jobruns/{job_run_id}" # Defaults used for creating an EMR Serverless application APPLICATION_STATUS = "STARTED" JOB_STATUS = "SUCCESS" class FakeApplication(BaseModel): def __init__( self, name: str, release_label: str, application_type: str, client_token: str, account_id: str, region_name: str, initial_capacity: str, maximum_capacity: str, tags: Dict[str, str], auto_start_configuration: str, auto_stop_configuration: str, network_configuration: str, ): # Provided parameters self.name = name self.release_label = release_label self.application_type = application_type.capitalize() self.client_token = client_token self.initial_capacity = initial_capacity self.maximum_capacity = maximum_capacity self.auto_start_configuration = ( auto_start_configuration or default_auto_start_configuration() ) self.auto_stop_configuration = ( auto_stop_configuration or default_auto_stop_configuration() ) self.network_configuration = network_configuration self.tags: Dict[str, str] = tags or {} # Service-generated-parameters self.id = random_appplication_id() self.arn = APPLICATION_ARN_TEMPLATE.format( partition="aws", region=region_name, account_id=account_id, application_id=self.id, ) self.state = APPLICATION_STATUS self.state_details = "" self.created_at = iso_8601_datetime_without_milliseconds( datetime.today().replace(hour=0, minute=0, second=0, microsecond=0) ) self.updated_at = self.created_at def __iter__(self) -> Iterator[Tuple[str, Any]]: yield "applicationId", self.id yield "name", self.name yield "arn", self.arn yield ( "autoStartConfig", self.auto_start_configuration, ) yield ( "autoStopConfig", self.auto_stop_configuration, ) def to_dict(self) -> Dict[str, Any]: """ Dictionary representation of an EMR Serverless Application. When used in `list-applications`, capacity, auto-start/stop configs, and tags are not returned. https://docs.aws.amazon.com/emr-serverless/latest/APIReference/API_ListApplications.html When used in `get-application`, more details are returned. https://docs.aws.amazon.com/emr-serverless/latest/APIReference/API_GetApplication.html#API_GetApplication_ResponseSyntax """ caller_methods = inspect.stack()[1].function caller_methods_type = caller_methods.split("_")[0] if caller_methods_type in ["get", "update"]: response = { "applicationId": self.id, "name": self.name, "arn": self.arn, "releaseLabel": self.release_label, "type": self.application_type, "state": self.state, "stateDetails": self.state_details, "createdAt": self.created_at, "updatedAt": self.updated_at, "autoStartConfiguration": self.auto_start_configuration, "autoStopConfiguration": self.auto_stop_configuration, "tags": self.tags, } else: response = { "id": self.id, "name": self.name, "arn": self.arn, "releaseLabel": self.release_label, "type": self.application_type, "state": self.state, "stateDetails": self.state_details, "createdAt": self.created_at, "updatedAt": self.updated_at, } if self.network_configuration: response.update({"networkConfiguration": self.network_configuration}) if self.initial_capacity: response.update({"initialCapacity": self.initial_capacity}) if self.maximum_capacity: response.update({"maximumCapacity": self.maximum_capacity}) return response class FakeJobRun(BaseModel): def __init__( self, application_id: str, client_token: str, execution_role_arn: str, account_id: str, region_name: str, release_label: str, application_type: str, job_driver: Optional[Dict[str, Dict[str, Union[str, List[str]]]]], configuration_overrides: Optional[Dict[str, Union[List[Any], Dict[str, Any]]]], tags: Optional[Dict[str, str]], network_configuration: Optional[Dict[str, List[str]]], execution_timeout_minutes: Optional[int], name: Optional[str], ): self.name = name self.application_id = application_id self.client_token = client_token self.execution_role_arn = execution_role_arn self.job_driver = job_driver self.configuration_overrides = configuration_overrides self.network_configuration = network_configuration self.execution_timeout_minutes = execution_timeout_minutes or 720 # Service-generated-parameters self.id = random_job_id() self.arn = JOB_RUN_ARN_TEMPLATE.format( partition="aws", account_id=account_id, application_id=self.application_id, region=region_name, job_run_id=self.id, ) self.release_label = release_label self.application_type = application_type self.state = JOB_STATUS self.state_details: Optional[str] = None self.created_by: Optional[str] = None self.created_at: str = iso_8601_datetime_without_milliseconds( datetime.today().replace(hour=0, minute=0, second=0, microsecond=0) ) self.updated_at: str = self.created_at self.total_execution_duration_seconds: int = 0 self.billed_resource_utilization: Dict[str, float] = { "vCPUHour": 0.0, "memoryGBHour": 0.0, "storageGBHour": 0.0, } self.tags = tags def to_dict(self, caller_methods_type: str) -> Dict[str, Any]: # The response structure is different for get/update and list if caller_methods_type in ["get", "update"]: response = { "applicationId": self.application_id, "jobRunId": self.id, "name": self.name, "arn": self.arn, "createdBy": self.created_by, "createdAt": self.created_at, "updatedAt": self.updated_at, "executionRole": self.execution_role_arn, "state": self.state, "stateDetails": self.state_details, "releaseLabel": self.release_label, "configurationOverrides": self.configuration_overrides, "jobDriver": self.job_driver, "tags": self.tags, "networkConfiguration": self.network_configuration, "totalExecutionDurationSeconds": self.total_execution_duration_seconds, "executionTimeoutMinutes": self.execution_timeout_minutes, "billedResourceUtilization": self.billed_resource_utilization, } else: response = { "applicationId": self.application_id, "id": self.id, "name": self.name, "arn": self.arn, "createdBy": self.created_by, "createdAt": self.created_at, "updatedAt": self.updated_at, "executionRole": self.execution_role_arn, "state": self.state, "stateDetails": self.state_details, "releaseLabel": self.release_label, "type": self.application_type, } return response class EMRServerlessBackend(BaseBackend): """Implementation of EMRServerless APIs.""" def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) self.region_name = region_name self.partition = get_partition(region_name) self.applications: Dict[str, FakeApplication] = dict() self.job_runs: Dict[str, List[FakeJobRun]] = ( dict() ) # {application_id: [job_run1, job_run2]} def create_application( self, name: str, release_label: str, application_type: str, client_token: str, initial_capacity: str, maximum_capacity: str, tags: Dict[str, str], auto_start_configuration: str, auto_stop_configuration: str, network_configuration: str, ) -> FakeApplication: if application_type not in ["HIVE", "SPARK"]: raise ValidationException(f"Unsupported engine {application_type}") if not re.match(r"emr-[0-9]{1}\.[0-9]{1,2}\.0(" "|-[0-9]{8})", release_label): raise ValidationException( f"Type '{application_type}' is not supported for release label '{release_label}' or release label does not exist" ) application = FakeApplication( name=name, release_label=release_label, application_type=application_type, account_id=self.account_id, region_name=self.region_name, client_token=client_token, initial_capacity=initial_capacity, maximum_capacity=maximum_capacity, tags=tags, auto_start_configuration=auto_start_configuration, auto_stop_configuration=auto_stop_configuration, network_configuration=network_configuration, ) self.applications[application.id] = application return application def delete_application(self, application_id: str) -> None: if application_id not in self.applications.keys(): raise ResourceNotFoundException(application_id) if self.applications[application_id].state not in ["CREATED", "STOPPED"]: raise ValidationException( f"Application {application_id} must be in one of the following statuses [CREATED, STOPPED]. " f"Current status: {self.applications[application_id].state}" ) self.applications[application_id].state = "TERMINATED" def get_application(self, application_id: str) -> Dict[str, Any]: if application_id not in self.applications.keys(): raise ResourceNotFoundException(application_id) return self.applications[application_id].to_dict() def list_applications( self, next_token: Optional[str], max_results: int, states: Optional[List[str]] ) -> Tuple[List[Dict[str, Any]], Optional[str]]: applications = [ application.to_dict() for application in self.applications.values() ] if states: applications = [ application for application in applications if application["state"] in states ] sort_key = "name" return paginated_list(applications, sort_key, max_results, next_token) def start_application(self, application_id: str) -> None: if application_id not in self.applications.keys(): raise ResourceNotFoundException(application_id) self.applications[application_id].state = "STARTED" def stop_application(self, application_id: str) -> None: if application_id not in self.applications.keys(): raise ResourceNotFoundException(application_id) self.applications[application_id].state = "STOPPED" def update_application( self, application_id: str, initial_capacity: Optional[str], maximum_capacity: Optional[str], auto_start_configuration: Optional[str], auto_stop_configuration: Optional[str], network_configuration: Optional[str], ) -> Dict[str, Any]: if application_id not in self.applications.keys(): raise ResourceNotFoundException(application_id) if self.applications[application_id].state not in ["CREATED", "STOPPED"]: raise ValidationException( f"Application {application_id} must be in one of the following statuses [CREATED, STOPPED]. " f"Current status: {self.applications[application_id].state}" ) if initial_capacity: self.applications[application_id].initial_capacity = initial_capacity if maximum_capacity: self.applications[application_id].maximum_capacity = maximum_capacity if auto_start_configuration: self.applications[ application_id ].auto_start_configuration = auto_start_configuration if auto_stop_configuration: self.applications[ application_id ].auto_stop_configuration = auto_stop_configuration if network_configuration: self.applications[ application_id ].network_configuration = network_configuration self.applications[ application_id ].updated_at = iso_8601_datetime_without_milliseconds( datetime.today().replace(hour=0, minute=0, second=0, microsecond=0) ) return self.applications[application_id].to_dict() def start_job_run( self, application_id: str, client_token: str, execution_role_arn: str, job_driver: Optional[Dict[str, Dict[str, Union[str, List[str]]]]], configuration_overrides: Optional[Dict[str, Union[List[Any], Dict[str, Any]]]], tags: Optional[Dict[str, str]], execution_timeout_minutes: Optional[int], name: Optional[str], ) -> FakeJobRun: role_account_id = execution_role_arn.split(":")[4] if role_account_id != self.account_id: raise AccessDeniedException("Cross-account pass role is not allowed.") if execution_timeout_minutes and execution_timeout_minutes < 5: raise ValidationException("RunTimeout must be at least 5 minutes.") application_resp = self.get_application(application_id) job_run = FakeJobRun( application_id=application_id, client_token=client_token, execution_role_arn=execution_role_arn, account_id=self.account_id, region_name=self.region_name, release_label=application_resp["releaseLabel"], application_type=application_resp["type"], job_driver=job_driver, configuration_overrides=configuration_overrides, tags=tags, network_configuration=application_resp.get("networkConfiguration"), execution_timeout_minutes=execution_timeout_minutes, name=name, ) if application_resp["state"] == "TERMINATED": raise ValidationException( f"Application {application_id} is terminated. Cannot start job run." ) if application_id not in self.job_runs: self.job_runs[application_id] = [] self.job_runs[application_id].append(job_run) return job_run def get_job_run(self, application_id: str, job_run_id: str) -> FakeJobRun: if application_id not in self.job_runs.keys(): raise ResourceNotFoundException(application_id, "Application") job_run_ids = [job_run.id for job_run in self.job_runs[application_id]] if job_run_id not in job_run_ids: raise ResourceNotFoundException(job_run_id, "JobRun") filtered_job_runs = [ job_run for job_run in self.job_runs[application_id] if job_run.id == job_run_id ] assert len(filtered_job_runs) == 1 job_run: FakeJobRun = filtered_job_runs[0] return job_run def cancel_job_run(self, application_id: str, job_run_id: str) -> Tuple[str, str]: # implement here if application_id not in self.job_runs.keys(): raise ResourceNotFoundException(application_id, "Application") job_run_ids = [job_run.id for job_run in self.job_runs[application_id]] if job_run_id not in job_run_ids: raise ResourceNotFoundException(job_run_id, "JobRun") self.job_runs[application_id][job_run_ids.index(job_run_id)].state = "CANCELLED" return application_id, job_run_id def list_job_runs( self, application_id: str, max_results: int, next_token: Optional[str], created_at_after: Optional[str], created_at_before: Optional[str], states: Optional[List[str]], ) -> Tuple[List[Dict[str, Any]], Optional[str]]: if application_id not in self.job_runs.keys(): raise ResourceNotFoundException(application_id, "Application") job_runs = self.job_runs[application_id] if states: job_runs = [job_run for job_run in job_runs if job_run.state in states] if created_at_after: job_runs = [ job_run for job_run in job_runs if datetime.strptime(job_run.created_at, "%Y-%m-%dT%H:%M:%SZ") > datetime.strptime(created_at_after, "%Y-%m-%dT%H:%M:%SZ") ] if created_at_before: job_runs = [ job_run for job_run in job_runs if datetime.strptime(job_run.created_at, "%Y-%m-%dT%H:%M:%SZ") < datetime.strptime(created_at_before, "%Y-%m-%dT%H:%M:%SZ") ] job_run_dicts = [job_run.to_dict("list") for job_run in job_runs] sort_key = "createdAt" return paginated_list(job_run_dicts, sort_key, max_results, next_token) emrserverless_backends = BackendDict(EMRServerlessBackend, "emr-serverless")
Memory