import json
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Pattern, Tuple

from dateutil.tz import tzlocal

from moto import settings
from moto.core.base_backend import BackendDict, BaseBackend
from moto.core.common_models import CloudFormationModel
from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.moto_api._internal import mock_random
from moto.utilities.paginator import paginate
from moto.utilities.utils import ARN_PARTITION_REGEX, get_partition

from .exceptions import (
    ExecutionAlreadyExists,
    ExecutionDoesNotExist,
    InvalidArn,
    InvalidExecutionInput,
    InvalidName,
    NameTooLongException,
    ResourceNotFound,
    StateMachineDoesNotExist,
)
from .utils import PAGINATION_MODEL, api_to_cfn_tags, cfn_to_api_tags


class StateMachineInstance:
    """Attributes shared by a state machine and each of its published versions."""

    def __init__(
        self,
        arn: str,
        name: str,
        definition: str,
        roleArn: str,
        encryptionConfiguration: Optional[Dict[str, Any]] = None,
        loggingConfiguration: Optional[Dict[str, Any]] = None,
        tracingConfiguration: Optional[Dict[str, Any]] = None,
    ):
        self.creation_date = iso_8601_datetime_with_milliseconds()
        self.update_date = self.creation_date
        self.arn = arn
        self.name = name
        self.definition = definition
        self.roleArn = roleArn
        self.executions: List[Execution] = []
        self.type = "STANDARD"
        # Each configuration falls back to a default when none is supplied.
        self.encryptionConfiguration = encryptionConfiguration or {
            "type": "AWS_OWNED_KEY"
        }
        self.loggingConfiguration = loggingConfiguration or {"level": "OFF"}
        self.tracingConfiguration = tracingConfiguration or {"enabled": False}
        self.sm_type = "STANDARD"  # or express
        self.description: Optional[str] = None


class StateMachineVersion(StateMachineInstance, CloudFormationModel):
    """A numbered snapshot of a StateMachine, created by StateMachine.publish()."""

    def __init__(
        self, source: StateMachineInstance, version: int, description: Optional[str]
    ):
        # Versions are addressed as "<source state machine arn>:<version number>".
        version_arn = f"{source.arn}:{version}"
        StateMachineInstance.__init__(
            self,
            arn=version_arn,
            name=source.name,
            definition=source.definition,
            roleArn=source.roleArn,
            encryptionConfiguration=source.encryptionConfiguration,
            loggingConfiguration=source.loggingConfiguration,
            tracingConfiguration=source.tracingConfiguration,
        )
        self.source_arn = source.arn
        self.version = version
        self.description = description


class StateMachine(StateMachineInstance, CloudFormationModel):
    """A state machine: holds its tags, executions and published versions."""

    def __init__(
        self,
        arn: str,
        name: str,
        definition: str,
        roleArn: str,
        tags: Optional[List[Dict[str, str]]] = None,
        encryptionConfiguration: Optional[Dict[str, Any]] = None,
        loggingConfiguration: Optional[Dict[str, Any]] = None,
        tracingConfiguration: Optional[Dict[str, Any]] = None,
    ):
        StateMachineInstance.__init__(
            self,
            arn=arn,
            name=name,
            definition=definition,
            roleArn=roleArn,
            encryptionConfiguration=encryptionConfiguration,
            loggingConfiguration=loggingConfiguration,
            tracingConfiguration=tracingConfiguration,
        )
        self.tags: List[Dict[str, str]] = []
        if tags:
            self.add_tags(tags)
        self.latest_version_number = 0
        self.versions: Dict[int, StateMachineVersion] = {}
        self.latest_version: Optional[StateMachineVersion] = None

    def publish(self, description: Optional[str]) -> None:
        """Snapshot the current definition/config as the next numbered version."""
        new_version_number = self.latest_version_number + 1
        new_version = StateMachineVersion(
            source=self, version=new_version_number, description=description
        )
        self.versions[new_version_number] = new_version
        self.latest_version = new_version
        self.latest_version_number = new_version_number

    def start_execution(
        self,
        region_name: str,
        account_id: str,
        execution_name: str,
        execution_input: str,
    ) -> "Execution":
        """Create, record and return a new Execution of this state machine.

        Raises:
            ExecutionAlreadyExists: an execution with this name was already started.
            InvalidExecutionInput: ``execution_input`` is not valid JSON.
        """
        self._ensure_execution_name_doesnt_exist(execution_name)
        parsed_input = self._validate_execution_input(execution_input)
        execution = Execution(
            region_name=region_name,
            account_id=account_id,
            state_machine_name=self.name,
            execution_name=execution_name,
            state_machine_arn=self.arn,
            execution_input=parsed_input,
        )
        self.executions.append(execution)
        return execution

    def stop_execution(self, execution_arn: str) -> "Execution":
        """Mark the execution with ``execution_arn`` as stopped and return it.

        Raises:
            ExecutionDoesNotExist: no execution of this state machine has that ARN.
        """
        execution = next(
            (x for x in self.executions if x.execution_arn == execution_arn), None
        )
        if not execution:
            raise ExecutionDoesNotExist(
                "Execution Does Not Exist: '" + execution_arn + "'"
            )
        execution.stop(stop_date=datetime.now(), error="", cause="")
        return execution

    def _ensure_execution_name_doesnt_exist(self, name: str) -> None:
        for execution in self.executions:
            if execution.name == name:
                raise ExecutionAlreadyExists(
                    "Execution Already Exists: '" + execution.execution_arn + "'"
                )

    def _validate_execution_input(self, execution_input: str) -> Any:
        """Parse ``execution_input`` as JSON and return the parsed value.

        Raises:
            InvalidExecutionInput: the input could not be parsed.
        """
        try:
            return json.loads(execution_input)
        except Exception as ex:
            raise InvalidExecutionInput(
                "Invalid State Machine Execution Input: '" + str(ex) + "'"
            )

    def update(self, **kwargs: Any) -> None:
        """Set every non-None keyword as an attribute and bump ``update_date``."""
        for key, value in kwargs.items():
            if value is not None:
                setattr(self, key, value)
        self.update_date = iso_8601_datetime_with_milliseconds()

    def add_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Merge ``tags`` into the existing tags and return the merged list.

        Existing tags keep their position; when a new tag has the same "key" it
        replaces the old one in place. New keys are appended in input order.
        The caller's ``tags`` list is not modified.
        """
        # Work on a copy: matched entries are popped so they are not appended twice.
        pending = list(tags)
        merged_tags = []
        for tag in self.tags:
            replacement_index = next(
                (index for (index, d) in enumerate(pending) if d["key"] == tag["key"]),
                None,
            )
            if replacement_index is not None:
                merged_tags.append(pending.pop(replacement_index))
            else:
                merged_tags.append(tag)
        merged_tags.extend(pending)
        self.tags = merged_tags
        return self.tags

    def remove_tags(self, tag_keys: List[str]) -> List[Dict[str, str]]:
        """Drop every tag whose "key" is in ``tag_keys`` and return the rest."""
        self.tags = [tag_set for tag_set in self.tags if tag_set["key"] not in tag_keys]
        return self.tags

    @property
    def physical_resource_id(self) -> str:
        return self.arn

    def get_cfn_properties(self, prop_overrides: Dict[str, Any]) -> Dict[str, Any]:
        """Return this machine's CloudFormation properties with overrides applied."""
        property_names = [
            "DefinitionString",
            "RoleArn",
            "StateMachineName",
        ]
        properties = {}
        for prop in property_names:
            properties[prop] = prop_overrides.get(prop, self.get_cfn_attribute(prop))
        # Special handling for Tags: overridden keys replace the originals
        overridden_keys = [tag["Key"] for tag in prop_overrides.get("Tags", [])]
        original_tags_to_include = [
            tag
            for tag in self.get_cfn_attribute("Tags")
            if tag["Key"] not in overridden_keys
        ]
        properties["Tags"] = original_tags_to_include + prop_overrides.get("Tags", [])
        return properties

    @classmethod
    def has_cfn_attr(cls, attr: str) -> bool:
        return attr in [
            "Name",
            "DefinitionString",
            "RoleArn",
            "StateMachineName",
            "Tags",
        ]

    def get_cfn_attribute(self, attribute_name: str) -> Any:
        from moto.cloudformation.exceptions import UnformattedGetAttTemplateException

        if attribute_name == "Name":
            return self.name
        elif attribute_name == "DefinitionString":
            return self.definition
        elif attribute_name == "RoleArn":
            return self.roleArn
        elif attribute_name == "StateMachineName":
            return self.name
        elif attribute_name == "Tags":
            return api_to_cfn_tags(self.tags)

        raise UnformattedGetAttTemplateException()

    @staticmethod
    def cloudformation_name_type() -> str:
        return "StateMachine"

    @staticmethod
    def cloudformation_type() -> str:
        return "AWS::StepFunctions::StateMachine"

    @classmethod
    def create_from_cloudformation_json(  # type: ignore[misc]
        cls,
        resource_name: str,
        cloudformation_json: Any,
        account_id: str,
        region_name: str,
        **kwargs: Any,
    ) -> "StateMachine":
        properties = cloudformation_json["Properties"]
        name = properties.get("StateMachineName", resource_name)
        definition = properties.get("DefinitionString", "")
        role_arn = properties.get("RoleArn", "")
        tags = cfn_to_api_tags(properties.get("Tags", []))
        sf_backend = stepfunctions_backends[account_id][region_name]
        return sf_backend.create_state_machine(name, definition, role_arn, tags=tags)

    @classmethod
    def delete_from_cloudformation_json(  # type: ignore[misc]
        cls,
        resource_name: str,
        cloudformation_json: Any,
        account_id: str,
        region_name: str,
    ) -> None:
        sf_backend = stepfunctions_backends[account_id][region_name]
        sf_backend.delete_state_machine(resource_name)

    @classmethod
    def update_from_cloudformation_json(  # type: ignore[misc]
        cls,
        original_resource: Any,
        new_resource_name: str,
        cloudformation_json: Any,
        account_id: str,
        region_name: str,
    ) -> "StateMachine":
        properties = cloudformation_json.get("Properties", {})
        name = properties.get("StateMachineName", original_resource.name)

        if name != original_resource.name:
            # Replacement: a rename creates a new machine and deletes the old one
            new_properties = original_resource.get_cfn_properties(properties)
            cloudformation_json["Properties"] = new_properties
            new_resource = cls.create_from_cloudformation_json(
                name, cloudformation_json, account_id, region_name
            )
            cls.delete_from_cloudformation_json(
                original_resource.arn, cloudformation_json, account_id, region_name
            )
            return new_resource

        else:
            # No Interruption: update the existing machine in place
            definition = properties.get("DefinitionString")
            role_arn = properties.get("RoleArn")
            tags = cfn_to_api_tags(properties.get("Tags", []))
            sf_backend = stepfunctions_backends[account_id][region_name]
            state_machine = sf_backend.update_state_machine(
                original_resource.arn, definition=definition, role_arn=role_arn
            )
            state_machine.add_tags(tags)
            return state_machine


class Execution:
    """A single (simulated) run of a state machine."""

    def __init__(
        self,
        region_name: str,
        account_id: str,
        state_machine_name: str,
        execution_name: str,
        state_machine_arn: str,
        execution_input: Any,
    ):
        # execution_input is the already-parsed JSON value, not the raw string.
        execution_arn = "arn:{}:states:{}:{}:execution:{}:{}"
        execution_arn = execution_arn.format(
            get_partition(region_name),
            region_name,
            account_id,
            state_machine_name,
            execution_name,
        )
        self.execution_arn = execution_arn
        self.name = execution_name
        self.start_date = datetime.now()
        self.state_machine_arn = state_machine_arn
        self.execution_input = execution_input
        # The initial status follows the SF_EXECUTION_HISTORY_TYPE setting.
        self.status = (
            "RUNNING"
            if settings.get_sf_execution_history_type() == "SUCCESS"
            else "FAILED"
        )
        self.stop_date: Optional[datetime] = None
        self.account_id = account_id
        self.region_name = region_name
        self.output: Optional[str] = None
        self.output_details: Optional[str] = None
        self.cause: Optional[str] = None
        self.error: Optional[str] = None

    def get_execution_history(self, roleArn: str) -> List[Dict[str, Any]]:
        """Return canned history events selected by SF_EXECUTION_HISTORY_TYPE.

        "SUCCESS" yields a four-event successful run and "FAILURE" a three-event
        failed run; any other value yields an empty list. Only ``roleArn`` is
        taken from the caller; everything else is static.
        """

        def timestamp(second: int) -> str:
            # Fixed timestamps on 2020-01-01, offset by ``second`` seconds.
            return iso_8601_datetime_with_milliseconds(
                datetime(2020, 1, 1, 0, 0, second, tzinfo=tzlocal())
            )

        sf_execution_history_type = settings.get_sf_execution_history_type()
        if sf_execution_history_type == "SUCCESS":
            return [
                {
                    "timestamp": timestamp(0),
                    "type": "ExecutionStarted",
                    "id": 1,
                    "previousEventId": 0,
                    "executionStartedEventDetails": {
                        "input": "{}",
                        "inputDetails": {"truncated": False},
                        "roleArn": roleArn,
                    },
                },
                {
                    "timestamp": timestamp(10),
                    "type": "PassStateEntered",
                    "id": 2,
                    "previousEventId": 0,
                    "stateEnteredEventDetails": {
                        "name": "A State",
                        "input": "{}",
                        "inputDetails": {"truncated": False},
                    },
                },
                {
                    "timestamp": timestamp(10),
                    "type": "PassStateExited",
                    "id": 3,
                    "previousEventId": 2,
                    "stateExitedEventDetails": {
                        "name": "A State",
                        "output": "An output",
                        "outputDetails": {"truncated": False},
                    },
                },
                {
                    "timestamp": timestamp(20),
                    "type": "ExecutionSucceeded",
                    "id": 4,
                    "previousEventId": 3,
                    "executionSucceededEventDetails": {
                        "output": "An output",
                        "outputDetails": {"truncated": False},
                    },
                },
            ]
        elif sf_execution_history_type == "FAILURE":
            return [
                {
                    "timestamp": timestamp(0),
                    "type": "ExecutionStarted",
                    "id": 1,
                    "previousEventId": 0,
                    "executionStartedEventDetails": {
                        "input": "{}",
                        "inputDetails": {"truncated": False},
                        "roleArn": roleArn,
                    },
                },
                {
                    "timestamp": timestamp(10),
                    "type": "FailStateEntered",
                    "id": 2,
                    "previousEventId": 0,
                    "stateEnteredEventDetails": {
                        "name": "A State",
                        "input": "{}",
                        "inputDetails": {"truncated": False},
                    },
                },
                {
                    "timestamp": timestamp(10),
                    "type": "ExecutionFailed",
                    "id": 3,
                    "previousEventId": 2,
                    "executionFailedEventDetails": {
                        "error": "AnError",
                        "cause": "An error occurred!",
                    },
                },
            ]
        return []

    def stop(self, *args: Any, **kwargs: Any) -> None:
        """Mark the execution ABORTED; any arguments passed in are ignored."""
        self.status = "ABORTED"
        self.stop_date = datetime.now()


class StepFunctionBackend(BaseBackend):
    """
    Configure Moto to explicitly parse and execute the StateMachine:

    .. sourcecode:: python

        @mock_aws(config={"stepfunctions": {"execute_state_machine": True}})

    By default, executing a StateMachine does nothing, and calling `get_execution_history` will return static data.

    Set the following environment variable if you want the static data (and the status of new executions) to be FAILED:

    .. sourcecode:: bash

        SF_EXECUTION_HISTORY_TYPE=FAILURE

    """

    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.create_state_machine
    # A name must not contain:
    # whitespace
    # brackets < > { } [ ]
    # wildcard characters ? *
    # special characters " # % \ ^ | ~ ` $ & , ; : /
    invalid_chars_for_name = [
        " ",
        "{",
        "}",
        "[",
        "]",
        "<",
        ">",
        "?",
        "*",
        '"',
        "#",
        "%",
        "\\",
        "^",
        "|",
        "~",
        "`",
        "$",
        "&",
        ",",
        ";",
        ":",
        "/",
    ]
    # control characters (U+0000-001F , U+007F-009F )
    invalid_unicodes_for_name = [
        chr(code_point) for code_point in (*range(0x00, 0x20), *range(0x7F, 0xA0))
    ]
    accepted_role_arn_format = re.compile(
        ARN_PARTITION_REGEX + r":iam::(?P<account_id>[0-9]{12}):role/.+"
    )
    accepted_mchn_arn_format = re.compile(
        ARN_PARTITION_REGEX
        + r":states:[-0-9a-zA-Z]+:(?P<account_id>[0-9]{12}):stateMachine:.+"
    )
    accepted_exec_arn_format = re.compile(
        ARN_PARTITION_REGEX
        + r":states:[-0-9a-zA-Z]+:(?P<account_id>[0-9]{12}):execution:.+"
    )

    def __init__(self, region_name: str, account_id: str):
        super().__init__(region_name, account_id)
        self.state_machines: List[StateMachine] = []
        # NOTE(review): not read anywhere in this module; kept in case subclasses rely on it.
        self._account_id = None

    def create_state_machine(
        self,
        name: str,
        definition: str,
        roleArn: str,
        tags: Optional[List[Dict[str, str]]] = None,
        publish: Optional[bool] = None,
        loggingConfiguration: Optional[Dict[str, Any]] = None,
        tracingConfiguration: Optional[Dict[str, Any]] = None,
        encryptionConfiguration: Optional[Dict[str, Any]] = None,
        version_description: Optional[str] = None,
    ) -> StateMachine:
        """Create a state machine, or return the existing one with the same ARN."""
        self._validate_name(name)
        self._validate_role_arn(roleArn)
        arn = f"arn:{get_partition(self.region_name)}:states:{self.region_name}:{self.account_id}:stateMachine:{name}"
        try:
            return self.describe_state_machine(arn)
        except StateMachineDoesNotExist:
            state_machine = StateMachine(
                arn,
                name,
                definition,
                roleArn,
                tags,
                encryptionConfiguration,
                loggingConfiguration,
                tracingConfiguration,
            )
            if publish:
                state_machine.publish(description=version_description)
            self.state_machines.append(state_machine)
            return state_machine

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_state_machines(self) -> List[StateMachine]:
        return sorted(self.state_machines, key=lambda x: x.creation_date)

    def describe_state_machine(self, arn: str) -> StateMachine:
        """Return the state machine (or published version) identified by ``arn``.

        Raises:
            InvalidArn: ``arn`` is not a state machine ARN.
            StateMachineDoesNotExist: nothing matches ``arn``.
        """
        self._validate_machine_arn(arn)
        sm = next((x for x in self.state_machines if x.arn == arn), None)
        if not sm:
            if (
                (arn_parts := arn.split(":"))
                and len(arn_parts) > 7
                and arn_parts[-1].isnumeric()
            ):
                # we might have a versioned arn, ending in :stateMachine:name:version_nr
                source_arn = ":".join(arn_parts[:-1])
                source_sm = next(
                    (x for x in self.state_machines if x.arn == source_arn), None
                )
                if source_sm:
                    sm = source_sm.versions.get(int(arn_parts[-1]))  # type: ignore[assignment]
        if not sm:
            raise StateMachineDoesNotExist(f"State Machine Does Not Exist: '{arn}'")
        return sm  # type: ignore[return-value]

    def delete_state_machine(self, arn: str) -> None:
        """Remove the state machine with ``arn``; unknown ARNs are ignored."""
        self._validate_machine_arn(arn)
        sm = next((x for x in self.state_machines if x.arn == arn), None)
        if sm:
            self.state_machines.remove(sm)

    def update_state_machine(
        self,
        arn: str,
        definition: Optional[str] = None,
        role_arn: Optional[str] = None,
        logging_configuration: Optional[Dict[str, bool]] = None,
        tracing_configuration: Optional[Dict[str, bool]] = None,
        encryption_configuration: Optional[Dict[str, Any]] = None,
        publish: Optional[bool] = None,
        version_description: Optional[str] = None,
    ) -> StateMachine:
        """Apply the supplied (non-empty) changes, optionally publishing a version."""
        sm = self.describe_state_machine(arn)
        updates: Dict[str, Any] = {
            "definition": definition,
            "roleArn": role_arn,
        }
        if encryption_configuration:
            updates["encryptionConfiguration"] = encryption_configuration
        if logging_configuration:
            updates["loggingConfiguration"] = logging_configuration
        if tracing_configuration:
            updates["tracingConfiguration"] = tracing_configuration
        sm.update(**updates)
        if publish:
            sm.publish(version_description)
        return sm

    def start_execution(
        self, state_machine_arn: str, name: str, execution_input: str
    ) -> Execution:
        """Start an execution; a random UUID is used when ``name`` is empty."""
        if name:
            self._validate_name(name)
        state_machine = self.describe_state_machine(state_machine_arn)
        return state_machine.start_execution(
            region_name=self.region_name,
            account_id=self.account_id,
            execution_name=name or str(mock_random.uuid4()),
            execution_input=execution_input,
        )

    def stop_execution(self, execution_arn: str) -> Execution:
        self._validate_execution_arn(execution_arn)
        state_machine = self._get_state_machine_for_execution(execution_arn)
        return state_machine.stop_execution(execution_arn)

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_executions(
        self, state_machine_arn: str, status_filter: Optional[str] = None
    ) -> List[Execution]:
        """Return the machine's executions, newest first, optionally by status."""
        executions = self.describe_state_machine(state_machine_arn).executions

        if status_filter:
            executions = list(filter(lambda e: e.status == status_filter, executions))

        return sorted(executions, key=lambda x: x.start_date, reverse=True)

    def describe_execution(self, execution_arn: str) -> Execution:
        _, execution = self._find_execution(execution_arn)
        return execution

    def get_execution_history(self, execution_arn: str) -> Dict[str, Any]:
        state_machine, execution = self._find_execution(execution_arn)
        return {"events": execution.get_execution_history(state_machine.roleArn)}

    def describe_state_machine_for_execution(self, execution_arn: str) -> StateMachine:
        for sm in self.state_machines:
            for exc in sm.executions:
                if exc.execution_arn == execution_arn:
                    return sm
        raise ResourceNotFound(execution_arn)

    def list_tags_for_resource(self, arn: str) -> List[Dict[str, str]]:
        try:
            state_machine = self.describe_state_machine(arn)
            return state_machine.tags or []
        except StateMachineDoesNotExist:
            return []

    def tag_resource(self, resource_arn: str, tags: List[Dict[str, str]]) -> None:
        try:
            state_machine = self.describe_state_machine(resource_arn)
            state_machine.add_tags(tags)
        except StateMachineDoesNotExist:
            raise ResourceNotFound(resource_arn)

    def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None:
        try:
            state_machine = self.describe_state_machine(resource_arn)
            state_machine.remove_tags(tag_keys)
        except StateMachineDoesNotExist:
            raise ResourceNotFound(resource_arn)

    def send_task_failure(self, task_token: str, error: Optional[str] = None) -> None:
        pass

    def send_task_heartbeat(self, task_token: str) -> None:
        pass

    def send_task_success(self, task_token: str, outcome: str) -> None:
        pass

    def describe_map_run(self, map_run_arn: str) -> Dict[str, Any]:
        return {}

    def list_map_runs(self, execution_arn: str) -> Any:
        return []

    def update_map_run(
        self,
        map_run_arn: str,
        max_concurrency: int,
        tolerated_failure_count: str,
        tolerated_failure_percentage: str,
    ) -> None:
        pass

    def _find_execution(self, execution_arn: str) -> Tuple[StateMachine, Execution]:
        """Validate ``execution_arn`` and return (owning state machine, execution).

        Raises:
            InvalidArn: ``execution_arn`` is not an execution ARN.
            ExecutionDoesNotExist: the machine or the execution is unknown.
        """
        self._validate_execution_arn(execution_arn)
        state_machine = self._get_state_machine_for_execution(execution_arn)
        execution = next(
            (x for x in state_machine.executions if x.execution_arn == execution_arn),
            None,
        )
        if not execution:
            raise ExecutionDoesNotExist(
                "Execution Does Not Exist: '" + execution_arn + "'"
            )
        return state_machine, execution

    def _validate_name(self, name: str) -> None:
        forbidden = self.invalid_chars_for_name + self.invalid_unicodes_for_name
        if any(char in name for char in forbidden):
            raise InvalidName("Invalid Name: '" + name + "'")

        if len(name) > 80:
            raise NameTooLongException(name)

    def _validate_role_arn(self, role_arn: str) -> None:
        self._validate_arn(
            arn=role_arn,
            regex=self.accepted_role_arn_format,
            invalid_msg=f"Invalid Role Arn: '{role_arn}'",
        )

    def _validate_machine_arn(self, machine_arn: str) -> None:
        self._validate_arn(
            arn=machine_arn,
            regex=self.accepted_mchn_arn_format,
            invalid_msg=f"Invalid State Machine Arn: '{machine_arn}'",
        )

    def _validate_execution_arn(self, execution_arn: str) -> None:
        self._validate_arn(
            arn=execution_arn,
            regex=self.accepted_exec_arn_format,
            invalid_msg=f"Execution Does Not Exist: '{execution_arn}'",
        )

    def _validate_arn(self, arn: str, regex: Pattern[str], invalid_msg: str) -> None:
        # Test for a missing ARN first: regex.match(None) would raise TypeError.
        if not arn or not regex.match(arn):
            raise InvalidArn(invalid_msg)

    def _get_state_machine_for_execution(self, execution_arn: str) -> StateMachine:
        # Execution ARNs look like arn:<p>:states:<region>:<account>:execution:<sm name>:<exec name>
        state_machine_name = execution_arn.split(":")[6]
        state_machine_arn = next(
            (x.arn for x in self.state_machines if x.name == state_machine_name), None
        )
        if not state_machine_arn:
            # Assume that if the state machine arn is not present, then neither will the
            # execution
            raise ExecutionDoesNotExist(
                "Execution Does Not Exist: '" + execution_arn + "'"
            )
        return self.describe_state_machine(state_machine_arn)


stepfunctions_backends = BackendDict(StepFunctionBackend, "stepfunctions")
Memory