import copy
import datetime
import json
import uuid
from typing import Any, Dict, List, Optional

from moto.core.common_models import BackendDict
from moto.stepfunctions.models import StateMachine, StepFunctionBackend
from moto.stepfunctions.parser.api import (
    Definition,
    EncryptionConfiguration,
    ExecutionStatus,
    GetExecutionHistoryOutput,
    InvalidDefinition,
    InvalidExecutionInput,
    InvalidToken,
    LoggingConfiguration,
    MissingRequiredParameter,
    Name,
    ResourceNotFound,
    SendTaskFailureOutput,
    SendTaskHeartbeatOutput,
    SendTaskSuccessOutput,
    SensitiveCause,
    SensitiveData,
    SensitiveError,
    TaskDoesNotExist,
    TaskTimedOut,
    TaskToken,
    TraceHeader,
    TracingConfiguration,
)
from moto.stepfunctions.parser.asl.component.state.exec.state_map.iteration.itemprocessor.map_run_record import (
    MapRunRecord,
)
from moto.stepfunctions.parser.asl.eval.callback.callback import (
    CallbackConsumerTimeout,
    CallbackNotifyConsumerError,
    CallbackOutcomeFailure,
    CallbackOutcomeSuccess,
)
from moto.stepfunctions.parser.asl.parse.asl_parser import (
    AmazonStateLanguageParser,
    ASLParserException,
)
from moto.stepfunctions.parser.backend.execution import Execution


class StepFunctionsParserBackend(StepFunctionBackend):
    """Step Functions backend that parses definitions and runs executions.

    State machine definitions are validated with ``AmazonStateLanguageParser``
    before being handed to the parent backend, and executions are started
    through ``Execution`` objects attached to their state machine.
    """

    def _get_executions(
        self, execution_status: Optional[ExecutionStatus] = None
    ) -> List[Execution]:
        """Collect the executions of every state machine in this backend.

        Args:
            execution_status: When given, only executions whose ``status``
                equals this value are returned; ``None`` returns them all.
        """
        return [
            execution
            for state_machine in self.state_machines
            for execution in state_machine.executions
            if execution_status is None or execution_status == execution.status
        ]

    def _revision_by_name(self, name: str) -> Optional[StateMachine]:
        """Return the first state machine named ``name``, or ``None``."""
        for state_machine in self.state_machines:
            if state_machine.name == name:
                return state_machine
        return None

    @staticmethod
    def _validate_definition(definition: str) -> None:
        """Parse ``definition`` and raise ``InvalidDefinition`` if that fails.

        Raises:
            InvalidDefinition: For parser errors (message is the ``repr`` of
                the parser exception) and for any other exception raised while
                parsing (message names the exception class and its args).
        """
        # Validate
        # TODO: pass through static analyser.
        try:
            AmazonStateLanguageParser.parse(definition)
        except ASLParserException as asl_parser_exception:
            raise InvalidDefinition(
                message=repr(asl_parser_exception)
            ) from asl_parser_exception
        except Exception as exception:
            exception_name = exception.__class__.__name__
            exception_args = list(exception.args)
            raise InvalidDefinition(
                message=f"Error={exception_name} Args={exception_args} in definition '{definition}'."
            ) from exception

    @staticmethod
    def _task_error_for(consumer_error: CallbackNotifyConsumerError) -> Exception:
        """Map a callback notification error to the API error to raise."""
        if isinstance(consumer_error, CallbackConsumerTimeout):
            return TaskTimedOut()
        return TaskDoesNotExist()

    def create_state_machine(
        self,
        name: str,
        definition: str,
        roleArn: str,
        tags: Optional[List[Dict[str, str]]] = None,
        publish: Optional[bool] = None,
        loggingConfiguration: Optional[LoggingConfiguration] = None,
        tracingConfiguration: Optional[TracingConfiguration] = None,
        encryptionConfiguration: Optional[EncryptionConfiguration] = None,
        version_description: Optional[str] = None,
    ) -> StateMachine:
        """Validate ``definition`` and create the state machine via the parent.

        Raises:
            InvalidDefinition: If the definition cannot be parsed.
        """
        StepFunctionsParserBackend._validate_definition(definition=definition)

        return super().create_state_machine(
            name=name,
            definition=definition,
            roleArn=roleArn,
            tags=tags,
            publish=publish,
            loggingConfiguration=loggingConfiguration,
            tracingConfiguration=tracingConfiguration,
            encryptionConfiguration=encryptionConfiguration,
            version_description=version_description,
        )

    def send_task_heartbeat(self, task_token: TaskToken) -> SendTaskHeartbeatOutput:
        """Forward a heartbeat for ``task_token`` to the running executions.

        Each RUNNING execution's callback pool is tried in turn; the first one
        that reports a truthy result ends the search.

        Raises:
            TaskTimedOut / TaskDoesNotExist: If a callback pool raises
                ``CallbackNotifyConsumerError``.
            InvalidToken: If no running execution accepted the token.
        """
        for execution in self._get_executions(ExecutionStatus.RUNNING):
            try:
                if execution.exec_worker.env.callback_pool_manager.heartbeat(
                    callback_id=task_token
                ):
                    return SendTaskHeartbeatOutput()
            except CallbackNotifyConsumerError as consumer_error:
                raise self._task_error_for(consumer_error) from consumer_error
        raise InvalidToken()

    def send_task_success(
        self, task_token: TaskToken, outcome: str
    ) -> SendTaskSuccessOutput:
        """Report a successful outcome for ``task_token`` to running executions.

        Args:
            task_token: Callback identifier of the waiting task.
            outcome: Output passed to ``CallbackOutcomeSuccess``.

        Raises:
            TaskTimedOut / TaskDoesNotExist: If a callback pool raises
                ``CallbackNotifyConsumerError``.
            InvalidToken: If no running execution accepted the token.
        """
        # Use a separate local so the ``outcome`` parameter is not shadowed.
        callback_outcome = CallbackOutcomeSuccess(
            callback_id=task_token, output=outcome
        )
        for execution in self._get_executions(ExecutionStatus.RUNNING):
            try:
                if execution.exec_worker.env.callback_pool_manager.notify(
                    callback_id=task_token, outcome=callback_outcome
                ):
                    return SendTaskSuccessOutput()
            except CallbackNotifyConsumerError as consumer_error:
                raise self._task_error_for(consumer_error) from consumer_error
        raise InvalidToken()

    def send_task_failure(
        self,
        task_token: TaskToken,
        error: SensitiveError = None,
        cause: SensitiveCause = None,
    ) -> SendTaskFailureOutput:
        """Report a failed outcome for ``task_token``.

        Args:
            task_token: Callback identifier of the waiting task.
            error: Error passed to ``CallbackOutcomeFailure``.
            cause: Cause passed to ``CallbackOutcomeFailure``.

        Raises:
            TaskTimedOut / TaskDoesNotExist: If a callback pool raises
                ``CallbackNotifyConsumerError``.
            InvalidToken: If no execution accepted the token.
        """
        callback_outcome = CallbackOutcomeFailure(
            callback_id=task_token, error=error, cause=cause
        )
        # NOTE(review): unlike heartbeat/success, this scans every execution,
        # not only RUNNING ones - confirm whether that is intentional.
        for execution in self._get_executions():
            try:
                if execution.exec_worker.env.callback_pool_manager.notify(
                    callback_id=task_token, outcome=callback_outcome
                ):
                    return SendTaskFailureOutput()
            except CallbackNotifyConsumerError as consumer_error:
                raise self._task_error_for(consumer_error) from consumer_error
        raise InvalidToken()

    def start_execution(
        self,
        state_machine_arn: str,
        name: Optional[Name] = None,
        execution_input: Optional[SensitiveData] = None,
        trace_header: Optional[TraceHeader] = None,
    ) -> Execution:
        """Create, register and start an execution of a state machine.

        Args:
            state_machine_arn: ARN of the state machine to execute.
            name: Execution name; a UUID is generated when omitted.
            execution_input: JSON document used as the execution input; an
                empty object is used when omitted.
            trace_header: Passed through to the ``Execution``.

        Returns:
            The started ``Execution``.

        Raises:
            InvalidExecutionInput: If ``execution_input`` is not valid JSON.
        """
        state_machine = self.describe_state_machine(state_machine_arn)

        # Later updates change the state machine's parameters and should not
        # affect those of this execution, so it runs against a deep copy.
        state_machine_clone = copy.deepcopy(state_machine)

        if execution_input is None:
            input_data = {}
        else:
            try:
                input_data = json.loads(execution_input)
            except Exception as ex:
                # TODO: report parsing error like AWS.
                raise InvalidExecutionInput(str(ex)) from ex

        # TODO: validate name format
        # Without a caller-supplied name the ARN would end in ":None" and all
        # unnamed executions of a state machine would share it; AWS generates
        # a UUID in that case.
        exec_name = name if name is not None else str(uuid.uuid4())

        execution_arn = "arn:{}:states:{}:{}:execution:{}:{}".format(
            self.partition,
            self.region_name,
            self.account_id,
            state_machine.name,
            exec_name,
        )

        execution = Execution(
            name=exec_name,
            sm_type=state_machine_clone.sm_type,
            role_arn=state_machine_clone.roleArn,
            exec_arn=execution_arn,
            account_id=self.account_id,
            region_name=self.region_name,
            state_machine=state_machine_clone,
            start_date=datetime.datetime.now(tz=datetime.timezone.utc),
            cloud_watch_logging_session=None,
            input_data=input_data,
            trace_header=trace_header,
            activity_store={},
        )
        # The execution is tracked on the original state machine, not the copy.
        state_machine.executions.append(execution)

        execution.start()
        return execution

    def update_state_machine(
        self,
        arn: str,
        definition: Optional[Definition] = None,
        role_arn: Optional[str] = None,
        logging_configuration: Optional[LoggingConfiguration] = None,
        tracing_configuration: Optional[TracingConfiguration] = None,
        encryption_configuration: Optional[EncryptionConfiguration] = None,
        publish: Optional[bool] = None,
        version_description: Optional[str] = None,
    ) -> StateMachine:
        """Validate the request and update the state machine via the parent.

        Raises:
            MissingRequiredParameter: If none of definition, role ARN, logging,
                tracing or encryption configuration is truthy.
            InvalidDefinition: If a definition is given and cannot be parsed.
        """
        updatable_fields = (
            definition,
            role_arn,
            logging_configuration,
            tracing_configuration,
            encryption_configuration,
        )
        if not any(updatable_fields):
            raise MissingRequiredParameter(
                "Either the definition, the role ARN, the LoggingConfiguration, the EncryptionConfiguration or the TracingConfiguration must be specified"
            )

        if definition is not None:
            self._validate_definition(definition=definition)

        return super().update_state_machine(
            arn,
            definition,
            role_arn,
            logging_configuration=logging_configuration,
            tracing_configuration=tracing_configuration,
            encryption_configuration=encryption_configuration,
            publish=publish,
            version_description=version_description,
        )

    def describe_map_run(self, map_run_arn: str) -> Dict[str, Any]:
        """Describe the map run with the given ARN.

        Every execution's map-run record pool is searched in turn.

        Raises:
            ResourceNotFound: If no execution holds a record for the ARN.
        """
        for execution in self._get_executions():
            map_run_record: Optional[MapRunRecord] = (
                execution.exec_worker.env.map_run_record_pool_manager.get(map_run_arn)
            )
            if map_run_record is not None:
                return map_run_record.describe()
        raise ResourceNotFound()

    def list_map_runs(self, execution_arn: str) -> Dict[str, Any]:
        """
        Pagination is not yet implemented
        """
        execution = self.describe_execution(execution_arn=execution_arn)
        map_run_records: List[MapRunRecord] = (
            execution.exec_worker.env.map_run_record_pool_manager.get_all()
        )
        return {
            "mapRuns": [
                map_run_record.list_item() for map_run_record in map_run_records
            ]
        }

    def update_map_run(
        self,
        map_run_arn: str,
        max_concurrency: int,
        tolerated_failure_count: str,
        tolerated_failure_percentage: str,
    ) -> None:
        """Update the settings of the map run with the given ARN.

        Raises:
            ResourceNotFound: If no execution holds a record for the ARN.
        """
        # TODO: investigate behaviour of empty requests.
        for execution in self._get_executions():
            map_run_record = execution.exec_worker.env.map_run_record_pool_manager.get(
                map_run_arn
            )
            if map_run_record is not None:
                map_run_record.update(
                    max_concurrency=max_concurrency,
                    tolerated_failure_count=tolerated_failure_count,
                    tolerated_failure_percentage=tolerated_failure_percentage,
                )
                return
        raise ResourceNotFound()

    def get_execution_history(self, execution_arn: str) -> GetExecutionHistoryOutput:
        """Return the history output of the execution with the given ARN."""
        execution = self.describe_execution(execution_arn=execution_arn)
        return execution.to_history_output()


stepfunctions_parser_backends = BackendDict(StepFunctionsParserBackend, "stepfunctions")