import copy import datetime import json import logging import re import time from typing import Optional from localstack.aws.api import CommonServiceException, RequestContext from localstack.services.plugins import ServiceLifecycleHook from localstack.state import StateVisitor from localstack.utils.collections import PaginatedList from moto.stepfunctions.parser.api import ( ActivityDoesNotExist, Arn, ConflictException, CreateActivityOutput, CreateStateMachineInput, CreateStateMachineOutput, Definition, DeleteActivityOutput, DeleteStateMachineOutput, DeleteStateMachineVersionOutput, DescribeActivityOutput, DescribeExecutionOutput, DescribeMapRunOutput, DescribeStateMachineForExecutionOutput, DescribeStateMachineOutput, EncryptionConfiguration, ExecutionDoesNotExist, ExecutionList, ExecutionRedriveFilter, ExecutionStatus, GetActivityTaskOutput, GetExecutionHistoryOutput, IncludedData, IncludeExecutionDataGetExecutionHistory, InspectionLevel, InvalidArn, InvalidDefinition, InvalidExecutionInput, InvalidLoggingConfiguration, InvalidName, InvalidToken, ListActivitiesOutput, ListExecutionsOutput, ListExecutionsPageToken, ListMapRunsOutput, ListStateMachinesOutput, ListStateMachineVersionsOutput, ListTagsForResourceOutput, LoggingConfiguration, LogLevel, LongArn, MaxConcurrency, MissingRequiredParameter, Name, PageSize, PageToken, Publish, PublishStateMachineVersionOutput, ResourceNotFound, RevealSecrets, ReverseOrder, RevisionId, SendTaskFailureOutput, SendTaskHeartbeatOutput, SendTaskSuccessOutput, SensitiveCause, SensitiveData, SensitiveError, StartExecutionOutput, StartSyncExecutionOutput, StateMachineAlreadyExists, StateMachineDoesNotExist, StateMachineList, StateMachineType, StateMachineTypeNotSupported, StepfunctionsApi, StopExecutionOutput, TagKeyList, TagList, TagResourceOutput, TaskDoesNotExist, TaskTimedOut, TaskToken, TestStateOutput, ToleratedFailureCount, ToleratedFailurePercentage, TraceHeader, TracingConfiguration, UntagResourceOutput, UpdateMapRunOutput, 
UpdateStateMachineOutput, ValidateStateMachineDefinitionDiagnostic, ValidateStateMachineDefinitionDiagnosticList, ValidateStateMachineDefinitionInput, ValidateStateMachineDefinitionOutput, ValidateStateMachineDefinitionResultCode, ValidateStateMachineDefinitionSeverity, ValidationException, VersionDescription, ) from moto.stepfunctions.parser.asl.component.state.state_execution.state_map.iteration.itemprocessor.map_run_record import ( MapRunRecord, ) from moto.stepfunctions.parser.asl.eval.callback.callback import ( ActivityCallbackEndpoint, CallbackConsumerTimeout, CallbackNotifyConsumerError, CallbackOutcomeFailure, CallbackOutcomeSuccess, ) from moto.stepfunctions.parser.asl.eval.event.logging import ( CloudWatchLoggingConfiguration, CloudWatchLoggingSession, ) from moto.stepfunctions.parser.asl.parse.asl_parser import ( ASLParserException, ) from moto.stepfunctions.parser.asl.static_analyser.express_static_analyser import ( ExpressStaticAnalyser, ) from moto.stepfunctions.parser.asl.static_analyser.static_analyser import ( StaticAnalyser, ) from moto.stepfunctions.parser.asl.static_analyser.test_state.test_state_analyser import ( TestStateStaticAnalyser, ) from moto.stepfunctions.parser.asl.static_analyser.usage_metrics_static_analyser import ( UsageMetricsStaticAnalyser, ) from moto.stepfunctions.parser.backend.activity import Activity, ActivityTask from moto.stepfunctions.parser.backend.execution import Execution, SyncExecution from moto.stepfunctions.parser.backend.state_machine import ( StateMachineInstance, StateMachineRevision, StateMachineVersion, TestStateMachine, ) from moto.stepfunctions.parser.backend.store import SFNStore, sfn_stores from moto.stepfunctions.parser.backend.test_state.execution import ( TestStateExecution, ) from moto.stepfunctions.parser.stepfunctions_utils import ( assert_pagination_parameters_valid, get_next_page_token_from_arn, normalise_max_results, ) from moto.stepfunctions.parser.utils import long_uid, short_uid from 
moto.utilities.arns import (
    ARN_PARTITION_REGEX,
    stepfunctions_activity_arn,
    stepfunctions_express_execution_arn,
    stepfunctions_standard_execution_arn,
    stepfunctions_state_machine_arn,
)

LOG = logging.getLogger(__name__)


class StepFunctionsProvider(StepfunctionsApi, ServiceLifecycleHook):
    """Step Functions API provider.

    Implements the StepfunctionsApi operations on top of the per-account,
    per-region ``sfn_stores`` backend (state machines, executions, activities).
    Executions are run by the ASL interpreter via ``Execution``/``SyncExecution``
    workers; task-token callbacks are dispatched through each execution worker's
    callback pool manager.
    """

    # Maximum timeout applied to TestState calls (value only; enforcement not
    # visible in this chunk).
    _TEST_STATE_MAX_TIMEOUT_SECONDS: int = 300  # 5 minutes.

    @staticmethod
    def get_store(context: RequestContext) -> SFNStore:
        # Resolve the store scoped to the caller's account and region.
        return sfn_stores[context.account_id][context.region]

    def accept_state_visitor(self, visitor: StateVisitor) -> None:
        # State persistence hook: expose the whole store collection.
        visitor.visit(sfn_stores)

    # ARN shape validators. NOTE(review): the execution regex also accepts
    # stateMachine/express resource types — presumably intentional, since
    # execution-scoped APIs receive express execution ARNs too; confirm.
    _STATE_MACHINE_ARN_REGEX: re.Pattern = re.compile(
        rf"{ARN_PARTITION_REGEX}:states:[a-z0-9-]+:[0-9]{{12}}:stateMachine:[a-zA-Z0-9-_.]+(:\d+)?$"
    )
    _STATE_MACHINE_EXECUTION_ARN_REGEX: re.Pattern = re.compile(
        rf"{ARN_PARTITION_REGEX}:states:[a-z0-9-]+:[0-9]{{12}}:(stateMachine|execution|express):[a-zA-Z0-9-_.]+(:\d+)?(:[a-zA-Z0-9-_.]+)*$"
    )
    _ACTIVITY_ARN_REGEX: re.Pattern = re.compile(
        rf"{ARN_PARTITION_REGEX}:states:[a-z0-9-]+:[0-9]{{12}}:activity:[a-zA-Z0-9-_\.]{{1,80}}$"
    )

    @staticmethod
    def _validate_state_machine_arn(state_machine_arn: str) -> None:
        """Raise InvalidArn if the value is not a state machine (or version) ARN."""
        # TODO: InvalidArn exception message do not communicate which part of the ARN is incorrect.
        if not StepFunctionsProvider._STATE_MACHINE_ARN_REGEX.match(state_machine_arn):
            raise InvalidArn(f"Invalid arn: '{state_machine_arn}'")

    @staticmethod
    def _raise_state_machine_does_not_exist(state_machine_arn: str) -> None:
        """Always raises StateMachineDoesNotExist with the AWS-style message."""
        raise StateMachineDoesNotExist(
            f"State Machine Does Not Exist: '{state_machine_arn}'"
        )

    @staticmethod
    def _validate_state_machine_execution_arn(execution_arn: str) -> None:
        """Raise InvalidArn if the value is not an execution-style ARN."""
        # TODO: InvalidArn exception message do not communicate which part of the ARN is incorrect.
        if not StepFunctionsProvider._STATE_MACHINE_EXECUTION_ARN_REGEX.match(
            execution_arn
        ):
            raise InvalidArn(f"Invalid arn: '{execution_arn}'")

    @staticmethod
    def _validate_activity_arn(activity_arn: str) -> None:
        """Raise InvalidArn if the value is not an activity ARN."""
        # TODO: InvalidArn exception message do not communicate which part of the ARN is incorrect.
        if not StepFunctionsProvider._ACTIVITY_ARN_REGEX.match(activity_arn):
            raise InvalidArn(f"Invalid arn: '{activity_arn}'")

    def _raise_state_machine_type_not_supported(self):
        """Always raises StateMachineTypeNotSupported (e.g. EXPRESS-only ops on STANDARD)."""
        raise StateMachineTypeNotSupported(
            "This operation is not supported by this type of state machine"
        )

    @staticmethod
    def _raise_resource_type_not_in_context(resource_type: str) -> None:
        """Raise InvalidArn for an ARN whose resource type is invalid for the operation."""
        lower_resource_type = resource_type.lower()
        raise InvalidArn(
            f"Invalid Arn: 'Resource type not valid in this context: {lower_resource_type}'"
        )

    @staticmethod
    def _validate_activity_name(name: str) -> None:
        """Validate an activity name (length 1-80, restricted character set)."""
        # The activity name is validated according to the AWS StepFunctions documentation, the name should not contain:
        # - white space
        # - brackets < > { } [ ]
        # - wildcard characters ? *
        # - special characters " # % \ ^ | ~ ` $ & , ; : /
        # - control characters (U+0000-001F, U+007F-009F)
        # https://docs.aws.amazon.com/step-functions/latest/apireference/API_CreateActivity.html#API_CreateActivity_RequestSyntax
        if not (1 <= len(name) <= 80):
            raise InvalidName(f"Invalid Name: '{name}'")
        invalid_chars = set(' <>{}[]?*"#%\\^|~`$&,;:/')
        control_chars = {chr(i) for i in range(32)} | {chr(i) for i in range(127, 160)}
        invalid_chars |= control_chars
        for char in name:
            if char in invalid_chars:
                raise InvalidName(f"Invalid Name: '{name}'")

    def _get_execution(self, context: RequestContext, execution_arn: Arn) -> Execution:
        """Look up an execution by ARN or raise ExecutionDoesNotExist."""
        execution: Optional[Execution] = self.get_store(context).executions.get(
            execution_arn
        )
        if not execution:
            raise ExecutionDoesNotExist(f"Execution Does Not Exist: '{execution_arn}'")
        return execution

    def _get_executions(
        self,
        context: RequestContext,
        execution_status: Optional[ExecutionStatus] = None,
    ) -> list[Execution]:
        """Return all executions in the store, optionally filtered by status."""
        store = self.get_store(context)
        execution: list[Execution] = list(store.executions.values())
        if execution_status:
            # NOTE(review): re-reads store.executions.values() instead of
            # filtering the list built above — harmless, just redundant.
            execution = list(
                filter(
                    lambda e: e.exec_status == execution_status,
                    store.executions.values(),
                )
            )
        return execution

    def _get_activity(self, context: RequestContext, activity_arn: Arn) -> Activity:
        """Look up an activity by ARN or raise ActivityDoesNotExist."""
        maybe_activity: Optional[Activity] = self.get_store(context).activities.get(
            activity_arn, None
        )
        if maybe_activity is None:
            raise ActivityDoesNotExist(f"Activity Does Not Exist: '{activity_arn}'")
        return maybe_activity

    def _idempotent_revision(
        self,
        context: RequestContext,
        name: str,
        definition: Definition,
        state_machine_type: StateMachineType,
        logging_configuration: LoggingConfiguration,
        tracing_configuration: TracingConfiguration,
    ) -> Optional[StateMachineRevision]:
        """Return an existing revision matching this CreateStateMachine request, if any."""
        # CreateStateMachine's idempotency check is based on the state machine name, definition, type,
        # LoggingConfiguration and TracingConfiguration.
        # If a following request has a different roleArn or tags, Step Functions will ignore these differences and
        # treat it as an idempotent request of the previous. In this case, roleArn and tags will not be updated, even
        # if they are different.
        state_machines: list[StateMachineInstance] = list(
            self.get_store(context).state_machines.values()
        )
        revisions = filter(
            lambda sm: isinstance(sm, StateMachineRevision), state_machines
        )
        for state_machine in revisions:
            check = all(
                [
                    state_machine.name == name,
                    state_machine.definition == definition,
                    state_machine.sm_type == state_machine_type,
                    state_machine.logging_config == logging_configuration,
                    state_machine.tracing_config == tracing_configuration,
                ]
            )
            if check:
                return state_machine
        return None

    def _idempotent_start_execution(
        self,
        execution: Optional[Execution],
        state_machine: StateMachineInstance,
        name: Name,
        input_data: SensitiveData,
    ) -> Optional[Execution]:
        """Return the running execution if this StartExecution call is idempotent.

        Raises an ExecutionAlreadyExists service exception when the name
        collides but the request is not idempotent.
        """
        # StartExecution is idempotent for STANDARD workflows. For a STANDARD workflow,
        # if you call StartExecution with the same name and input as a running execution,
        # the call succeeds and return the same response as the original request.
        # If the execution is closed or if the input is different,
        # it returns a 400 ExecutionAlreadyExists error. You can reuse names after 90 days.
        if not execution:
            return None
        if (name, input_data, execution.exec_status, state_machine.sm_type) == (
            execution.name,
            execution.input_data,
            ExecutionStatus.RUNNING,
            StateMachineType.STANDARD,
        ):
            return execution
        raise CommonServiceException(
            code="ExecutionAlreadyExists",
            message=f"Execution Already Exists: '{execution.exec_arn}'",
        )

    def _revision_by_name(
        self, context: RequestContext, name: str
    ) -> Optional[StateMachineInstance]:
        """Return the state machine revision with the given name, if present."""
        state_machines: list[StateMachineInstance] = list(
            self.get_store(context).state_machines.values()
        )
        for state_machine in state_machines:
            if (
                isinstance(state_machine, StateMachineRevision)
                and state_machine.name == name
            ):
                return state_machine
        return None

    @staticmethod
    def _validate_definition(
        definition: str, static_analysers: list[StaticAnalyser]
    ) -> None:
        """Run the given static analysers over a definition, mapping failures to InvalidDefinition."""
        try:
            for static_analyser in static_analysers:
                static_analyser.analyse(definition)
        except ASLParserException as asl_parser_exception:
            invalid_definition = InvalidDefinition()
            invalid_definition.message = repr(asl_parser_exception)
            raise invalid_definition
        except Exception as exception:
            # Any other analyser failure is surfaced as InvalidDefinition with
            # the exception class and args embedded in the message.
            exception_name = exception.__class__.__name__
            exception_args = list(exception.args)
            invalid_definition = InvalidDefinition()
            invalid_definition.message = f"Error={exception_name} Args={exception_args} in definition '{definition}'."
            raise invalid_definition

    @staticmethod
    def _sanitise_logging_configuration(
        logging_configuration: LoggingConfiguration,
    ) -> None:
        """Validate a LoggingConfiguration in place and fill in defaults.

        Mutates the dict: 'level' defaults to OFF and 'includeExecutionData'
        defaults to False.
        """
        level = logging_configuration.get("level")
        destinations = logging_configuration.get("destinations")
        if destinations is not None and len(destinations) > 1:
            raise InvalidLoggingConfiguration(
                "Invalid Logging Configuration: Must specify exactly one Log Destination."
            )
        # A LogLevel that is not OFF, should have a destination.
        if level is not None and level != LogLevel.OFF and not destinations:
            raise InvalidLoggingConfiguration(
                "Invalid Logging Configuration: Must specify exactly one Log Destination."
            )
        # Default for level is OFF.
        level = level or LogLevel.OFF
        # Default for includeExecutionData is False.
        include_flag = logging_configuration.get("includeExecutionData", False)
        # Update configuration object.
        logging_configuration["level"] = level
        logging_configuration["includeExecutionData"] = include_flag

    def create_state_machine(
        self, context: RequestContext, request: CreateStateMachineInput, **kwargs
    ) -> CreateStateMachineOutput:
        """Create a state machine revision (and optionally publish version 1)."""
        if not request.get("publish", False) and request.get("versionDescription"):
            raise ValidationException(
                "Version description can only be set when publish is true"
            )
        # Extract parameters and set defaults.
        state_machine_name = request["name"]
        state_machine_role_arn = request["roleArn"]
        state_machine_definition = request["definition"]
        state_machine_type = request.get("type") or StateMachineType.STANDARD
        state_machine_tracing_configuration = request.get("tracingConfiguration")
        state_machine_tags = request.get("tags")
        state_machine_logging_configuration = request.get(
            "loggingConfiguration", LoggingConfiguration()
        )
        self._sanitise_logging_configuration(
            logging_configuration=state_machine_logging_configuration
        )
        # CreateStateMachine is an idempotent API. Subsequent requests won't create a duplicate resource if it was
        # already created.
        idem_state_machine: Optional[StateMachineRevision] = self._idempotent_revision(
            context=context,
            name=state_machine_name,
            definition=state_machine_definition,
            state_machine_type=state_machine_type,
            logging_configuration=state_machine_logging_configuration,
            tracing_configuration=state_machine_tracing_configuration,
        )
        if idem_state_machine is not None:
            return CreateStateMachineOutput(
                stateMachineArn=idem_state_machine.arn,
                creationDate=idem_state_machine.create_date,
            )
        # Assert this state machine name is unique.
        state_machine_with_name: Optional[StateMachineRevision] = (
            self._revision_by_name(context=context, name=state_machine_name)
        )
        if state_machine_with_name is not None:
            raise StateMachineAlreadyExists(
                f"State Machine Already Exists: '{state_machine_with_name.arn}'"
            )
        # Compute the state machine's Arn.
        state_machine_arn = stepfunctions_state_machine_arn(
            name=state_machine_name,
            account_id=context.account_id,
            region_name=context.region,
        )
        state_machines = self.get_store(context).state_machines
        # Reduce the logging configuration to a usable cloud watch representation, and validate the destinations
        # if any were given.
        cloud_watch_logging_configuration = (
            CloudWatchLoggingConfiguration.from_logging_configuration(
                state_machine_arn=state_machine_arn,
                logging_configuration=state_machine_logging_configuration,
            )
        )
        if cloud_watch_logging_configuration is not None:
            cloud_watch_logging_configuration.validate()
        # Run static analysers on the definition given.
        if state_machine_type == StateMachineType.EXPRESS:
            StepFunctionsProvider._validate_definition(
                definition=state_machine_definition,
                static_analysers=[ExpressStaticAnalyser()],
            )
        else:
            StepFunctionsProvider._validate_definition(
                definition=state_machine_definition, static_analysers=[StaticAnalyser()]
            )
        # Create the state machine and add it to the store.
        state_machine = StateMachineRevision(
            name=state_machine_name,
            arn=state_machine_arn,
            role_arn=state_machine_role_arn,
            definition=state_machine_definition,
            sm_type=state_machine_type,
            logging_config=state_machine_logging_configuration,
            cloud_watch_logging_configuration=cloud_watch_logging_configuration,
            tracing_config=state_machine_tracing_configuration,
            tags=state_machine_tags,
        )
        state_machines[state_machine_arn] = state_machine
        create_output = CreateStateMachineOutput(
            stateMachineArn=state_machine.arn, creationDate=state_machine.create_date
        )
        # Create the first version if the 'publish' flag is used.
        if request.get("publish", False):
            version_description = request.get("versionDescription")
            state_machine_version = state_machine.create_version(
                description=version_description
            )
            if state_machine_version is not None:
                state_machine_version_arn = state_machine_version.arn
                state_machines[state_machine_version_arn] = state_machine_version
                create_output["stateMachineVersionArn"] = state_machine_version_arn
        # Run static analyser on definition and collect usage metrics
        UsageMetricsStaticAnalyser.process(state_machine_definition)
        return create_output

    def describe_state_machine(
        self,
        context: RequestContext,
        state_machine_arn: Arn,
        included_data: IncludedData = None,
        **kwargs,
    ) -> DescribeStateMachineOutput:
        """Describe a state machine revision or version by ARN."""
        self._validate_state_machine_arn(state_machine_arn)
        state_machine = self.get_store(context).state_machines.get(state_machine_arn)
        if state_machine is None:
            self._raise_state_machine_does_not_exist(state_machine_arn)
        return state_machine.describe()

    def describe_state_machine_for_execution(
        self,
        context: RequestContext,
        execution_arn: Arn,
        included_data: IncludedData = None,
        **kwargs,
    ) -> DescribeStateMachineForExecutionOutput:
        """Describe the state machine snapshot associated with an execution."""
        self._validate_state_machine_execution_arn(execution_arn)
        execution: Execution = self._get_execution(
            context=context, execution_arn=execution_arn
        )
        return execution.to_describe_state_machine_for_execution_output()

    def send_task_heartbeat(
        self, context: RequestContext, task_token: TaskToken, **kwargs
    ) -> SendTaskHeartbeatOutput:
        """Forward a task heartbeat to whichever running execution owns the token."""
        running_executions: list[Execution] = self._get_executions(
            context, ExecutionStatus.RUNNING
        )
        for execution in running_executions:
            try:
                if execution.exec_worker.env.callback_pool_manager.heartbeat(
                    callback_id=task_token
                ):
                    return SendTaskHeartbeatOutput()
            except CallbackNotifyConsumerError as consumer_error:
                # Timed-out callbacks map to TaskTimedOut; any other consumer
                # error means the task is gone.
                if isinstance(consumer_error, CallbackConsumerTimeout):
                    raise TaskTimedOut()
                else:
                    raise TaskDoesNotExist()
        # No execution claimed the token.
        raise InvalidToken()

    def send_task_success(
        self,
        context: RequestContext,
        task_token: TaskToken,
        output: SensitiveData,
        **kwargs,
    ) -> SendTaskSuccessOutput:
        """Deliver a success outcome for a task token to its owning running execution."""
        outcome = CallbackOutcomeSuccess(callback_id=task_token, output=output)
        running_executions: list[Execution] = self._get_executions(
            context, ExecutionStatus.RUNNING
        )
        for execution in running_executions:
            try:
                if execution.exec_worker.env.callback_pool_manager.notify(
                    callback_id=task_token, outcome=outcome
                ):
                    return SendTaskSuccessOutput()
            except CallbackNotifyConsumerError as consumer_error:
                if isinstance(consumer_error, CallbackConsumerTimeout):
                    raise TaskTimedOut()
                else:
                    raise TaskDoesNotExist()
        raise InvalidToken("Invalid token")

    def send_task_failure(
        self,
        context: RequestContext,
        task_token: TaskToken,
        error: SensitiveError = None,
        cause: SensitiveCause = None,
        **kwargs,
    ) -> SendTaskFailureOutput:
        """Deliver a failure outcome for a task token to its owning execution.

        NOTE(review): unlike send_task_success/heartbeat, this iterates ALL
        executions, not only RUNNING ones — confirm whether intentional.
        """
        outcome = CallbackOutcomeFailure(
            callback_id=task_token, error=error, cause=cause
        )
        store = self.get_store(context)
        for execution in store.executions.values():
            try:
                if execution.exec_worker.env.callback_pool_manager.notify(
                    callback_id=task_token, outcome=outcome
                ):
                    return SendTaskFailureOutput()
            except CallbackNotifyConsumerError as consumer_error:
                if isinstance(consumer_error, CallbackConsumerTimeout):
                    raise TaskTimedOut()
                else:
                    raise TaskDoesNotExist()
        raise InvalidToken("Invalid token")

    def start_execution(
        self,
        context: RequestContext,
        state_machine_arn: Arn,
        name: Name = None,
        input: SensitiveData = None,
        trace_header: TraceHeader = None,
        **kwargs,
    ) -> StartExecutionOutput:
        """Start an asynchronous execution of a state machine revision or version."""
        self._validate_state_machine_arn(state_machine_arn)
        unsafe_state_machine: Optional[StateMachineInstance] = self.get_store(
            context
        ).state_machines.get(state_machine_arn)
        if not unsafe_state_machine:
            self._raise_state_machine_does_not_exist(state_machine_arn)
        # Deep-copy so later updates to the stored state machine cannot affect
        # this execution's snapshot.
        state_machine_clone = copy.deepcopy(unsafe_state_machine)
        if input is None:
            input_data = dict()
        else:
            try:
                input_data = json.loads(input)
            except Exception as ex:
                raise InvalidExecutionInput(
                    str(ex)
                )  # TODO: report parsing error like AWS.
        # Executions of versions are recorded under the source revision's ARN.
        normalised_state_machine_arn = (
            state_machine_clone.source_arn
            if isinstance(state_machine_clone, StateMachineVersion)
            else state_machine_clone.arn
        )
        exec_name = name or long_uid()  # TODO: validate name format
        if state_machine_clone.sm_type == StateMachineType.STANDARD:
            exec_arn = stepfunctions_standard_execution_arn(
                normalised_state_machine_arn, exec_name
            )
        else:
            # Exhaustive check on STANDARD and EXPRESS type, validated on creation.
            exec_arn = stepfunctions_express_execution_arn(
                normalised_state_machine_arn, exec_name
            )
        if execution := self.get_store(context).executions.get(exec_arn):
            # Return already running execution if name and input match
            existing_execution = self._idempotent_start_execution(
                execution=execution,
                state_machine=state_machine_clone,
                name=name,
                input_data=input_data,
            )
            if existing_execution:
                return existing_execution.to_start_output()
        # Create the execution logging session, if logging is configured.
        cloud_watch_logging_session = None
        if state_machine_clone.cloud_watch_logging_configuration is not None:
            cloud_watch_logging_session = CloudWatchLoggingSession(
                execution_arn=exec_arn,
                configuration=state_machine_clone.cloud_watch_logging_configuration,
            )
        execution = Execution(
            name=exec_name,
            sm_type=state_machine_clone.sm_type,
            role_arn=state_machine_clone.role_arn,
            exec_arn=exec_arn,
            account_id=context.account_id,
            region_name=context.region,
            state_machine=state_machine_clone,
            start_date=datetime.datetime.now(tz=datetime.timezone.utc),
            cloud_watch_logging_session=cloud_watch_logging_session,
            input_data=input_data,
            trace_header=trace_header,
            activity_store=self.get_store(context).activities,
        )
        self.get_store(context).executions[exec_arn] = execution
        execution.start()
        return execution.to_start_output()

    def start_sync_execution(
        self,
        context: RequestContext,
        state_machine_arn: Arn,
        name: Name = None,
        input: SensitiveData = None,
        trace_header: TraceHeader = None,
        included_data: IncludedData = None,
        **kwargs,
    ) -> StartSyncExecutionOutput:
        """Start a synchronous execution (EXPRESS state machines only)."""
        self._validate_state_machine_arn(state_machine_arn)
        unsafe_state_machine: Optional[StateMachineInstance] = self.get_store(
            context
        ).state_machines.get(state_machine_arn)
        if not unsafe_state_machine:
            self._raise_state_machine_does_not_exist(state_machine_arn)
        if unsafe_state_machine.sm_type == StateMachineType.STANDARD:
            self._raise_state_machine_type_not_supported()
        # Deep-copy so later updates to the stored state machine cannot affect
        # this execution's snapshot.
        state_machine_clone = copy.deepcopy(unsafe_state_machine)
        if input is None:
            input_data = dict()
        else:
            try:
                input_data = json.loads(input)
            except Exception as ex:
                raise InvalidExecutionInput(
                    str(ex)
                )  # TODO: report parsing error like AWS.
        normalised_state_machine_arn = (
            state_machine_clone.source_arn
            if isinstance(state_machine_clone, StateMachineVersion)
            else state_machine_clone.arn
        )
        exec_name = name or long_uid()  # TODO: validate name format
        exec_arn = stepfunctions_express_execution_arn(
            normalised_state_machine_arn, exec_name
        )
        if exec_arn in self.get_store(context).executions:
            raise InvalidName()  # TODO
        # Create the execution logging session, if logging is configured.
        cloud_watch_logging_session = None
        if state_machine_clone.cloud_watch_logging_configuration is not None:
            cloud_watch_logging_session = CloudWatchLoggingSession(
                execution_arn=exec_arn,
                configuration=state_machine_clone.cloud_watch_logging_configuration,
            )
        execution = SyncExecution(
            name=exec_name,
            sm_type=state_machine_clone.sm_type,
            role_arn=state_machine_clone.role_arn,
            exec_arn=exec_arn,
            account_id=context.account_id,
            region_name=context.region,
            state_machine=state_machine_clone,
            start_date=datetime.datetime.now(tz=datetime.timezone.utc),
            cloud_watch_logging_session=cloud_watch_logging_session,
            input_data=input_data,
            trace_header=trace_header,
            activity_store=self.get_store(context).activities,
        )
        self.get_store(context).executions[exec_arn] = execution
        execution.start()
        return execution.to_start_sync_execution_output()

    def describe_execution(
        self,
        context: RequestContext,
        execution_arn: Arn,
        included_data: IncludedData = None,
        **kwargs,
    ) -> DescribeExecutionOutput:
        """Describe an execution (STANDARD workflows only)."""
        self._validate_state_machine_execution_arn(execution_arn)
        execution: Execution = self._get_execution(
            context=context, execution_arn=execution_arn
        )
        # Action only compatible with STANDARD workflows.
        if execution.sm_type != StateMachineType.STANDARD:
            self._raise_resource_type_not_in_context(resource_type=execution.sm_type)
        return execution.to_describe_output()

    @staticmethod
    def _list_execution_filter(
        ex: Execution, state_machine_arn: Optional[str], status_filter: Optional[str]
    ) -> bool:
        """Predicate for list_executions: match by state machine ARN and status."""
        if state_machine_arn and ex.state_machine.arn != state_machine_arn:
            return False
        if not status_filter:
            return True
        return ex.exec_status == status_filter

    def list_executions(
        self,
        context: RequestContext,
        state_machine_arn: Arn = None,
        status_filter: ExecutionStatus = None,
        max_results: PageSize = None,
        next_token: ListExecutionsPageToken = None,
        map_run_arn: LongArn = None,
        redrive_filter: ExecutionRedriveFilter = None,
        **kwargs,
    ) -> ListExecutionsOutput:
        """List executions of a STANDARD state machine, newest first, paginated."""
        # NOTE(review): state_machine_arn defaults to None but is validated
        # unconditionally here, so the "Must provide a StateMachine ARN or
        # MapRun ARN" validation below looks unreachable — confirm.
        self._validate_state_machine_arn(state_machine_arn)
        assert_pagination_parameters_valid(
            max_results=max_results,
            next_token=next_token,
            next_token_length_limit=3096,
        )
        max_results = normalise_max_results(max_results)
        state_machine = self.get_store(context).state_machines.get(state_machine_arn)
        if state_machine is None:
            self._raise_state_machine_does_not_exist(state_machine_arn)
        if state_machine.sm_type != StateMachineType.STANDARD:
            self._raise_state_machine_type_not_supported()
        # TODO: add support for paging
        allowed_execution_status = [
            ExecutionStatus.SUCCEEDED,
            ExecutionStatus.TIMED_OUT,
            ExecutionStatus.PENDING_REDRIVE,
            ExecutionStatus.ABORTED,
            ExecutionStatus.FAILED,
            ExecutionStatus.RUNNING,
        ]
        validation_errors = []
        if status_filter and status_filter not in allowed_execution_status:
            validation_errors.append(
                f"Value '{status_filter}' at 'statusFilter' failed to satisfy constraint: Member must satisfy enum value set: [{', '.join(allowed_execution_status)}]"
            )
        if not state_machine_arn and not map_run_arn:
            validation_errors.append("Must provide a StateMachine ARN or MapRun ARN")
        if validation_errors:
            errors_message = "; ".join(validation_errors)
            message = f"{len(validation_errors)} validation {'errors' if len(validation_errors) > 1 else 'error'} detected: {errors_message}"
            raise CommonServiceException(message=message, code="ValidationException")
        executions: ExecutionList = [
            execution.to_execution_list_item()
            for execution in self.get_store(context).executions.values()
            if self._list_execution_filter(
                execution,
                state_machine_arn=state_machine_arn,
                status_filter=status_filter,
            )
        ]
        # Most recent executions first.
        executions.sort(key=lambda item: item["startDate"], reverse=True)
        paginated_executions = PaginatedList(executions)
        page, token_for_next_page = paginated_executions.get_page(
            token_generator=lambda item: get_next_page_token_from_arn(
                item.get("executionArn")
            ),
            page_size=max_results,
            next_token=next_token,
        )
        return ListExecutionsOutput(executions=page, nextToken=token_for_next_page)

    def list_state_machines(
        self,
        context: RequestContext,
        max_results: PageSize = None,
        next_token: PageToken = None,
        **kwargs,
    ) -> ListStateMachinesOutput:
        """List state machine revisions, sorted by name, paginated."""
        assert_pagination_parameters_valid(max_results, next_token)
        max_results = normalise_max_results(max_results)
        state_machines: StateMachineList = [
            sm.itemise()
            for sm in self.get_store(context).state_machines.values()
            if isinstance(sm, StateMachineRevision)
        ]
        state_machines.sort(key=lambda item: item["name"])
        paginated_state_machines = PaginatedList(state_machines)
        page, token_for_next_page = paginated_state_machines.get_page(
            token_generator=lambda item: get_next_page_token_from_arn(
                item.get("stateMachineArn")
            ),
            page_size=max_results,
            next_token=next_token,
        )
        return ListStateMachinesOutput(
            stateMachines=page, nextToken=token_for_next_page
        )

    def list_state_machine_versions(
        self,
        context: RequestContext,
        state_machine_arn: Arn,
        next_token: PageToken = None,
        max_results: PageSize = None,
        **kwargs,
    ) -> ListStateMachineVersionsOutput:
        """List published versions of a state machine, newest first, paginated."""
        self._validate_state_machine_arn(state_machine_arn)
        assert_pagination_parameters_valid(max_results, next_token)
        max_results = normalise_max_results(max_results)
        state_machines = self.get_store(context).state_machines
        state_machine_revision = state_machines.get(state_machine_arn)
        if not isinstance(state_machine_revision, StateMachineRevision):
            raise InvalidArn(f"Invalid arn: {state_machine_arn}")
        state_machine_version_items = list()
        for version_arn in state_machine_revision.versions.values():
            state_machine_version = state_machines[version_arn]
            if isinstance(state_machine_version, StateMachineVersion):
                state_machine_version_items.append(state_machine_version.itemise())
            else:
                # Store invariant violation: the revision references a version
                # ARN that does not resolve to a StateMachineVersion.
                raise RuntimeError(
                    f"Expected {version_arn} to be a StateMachine Version, but got '{type(state_machine_version)}'."
                )
        state_machine_version_items.sort(
            key=lambda item: item["creationDate"], reverse=True
        )
        paginated_state_machine_versions = PaginatedList(state_machine_version_items)
        page, token_for_next_page = paginated_state_machine_versions.get_page(
            token_generator=lambda item: get_next_page_token_from_arn(
                item.get("stateMachineVersionArn")
            ),
            page_size=max_results,
            next_token=next_token,
        )
        return ListStateMachineVersionsOutput(
            stateMachineVersions=page, nextToken=token_for_next_page
        )

    def get_execution_history(
        self,
        context: RequestContext,
        execution_arn: Arn,
        max_results: PageSize = None,
        reverse_order: ReverseOrder = None,
        next_token: PageToken = None,
        include_execution_data: IncludeExecutionDataGetExecutionHistory = None,
        **kwargs,
    ) -> GetExecutionHistoryOutput:
        """Return the event history of a STANDARD execution (no paging yet)."""
        # TODO: add support for paging, ordering, and other manipulations.
        self._validate_state_machine_execution_arn(execution_arn)
        execution: Execution = self._get_execution(
            context=context, execution_arn=execution_arn
        )
        # Action only compatible with STANDARD workflows.
        if execution.sm_type != StateMachineType.STANDARD:
            self._raise_resource_type_not_in_context(resource_type=execution.sm_type)
        history: GetExecutionHistoryOutput = execution.to_history_output()
        if reverse_order:
            history["events"].reverse()
        return history

    def delete_state_machine(
        self, context: RequestContext, state_machine_arn: Arn, **kwargs
    ) -> DeleteStateMachineOutput:
        """Delete a state machine revision and all of its versions (idempotent)."""
        # TODO: halt executions?
        self._validate_state_machine_arn(state_machine_arn)
        state_machines = self.get_store(context).state_machines
        state_machine = state_machines.get(state_machine_arn)
        if isinstance(state_machine, StateMachineRevision):
            state_machines.pop(state_machine_arn)
            for version_arn in state_machine.versions.values():
                state_machines.pop(version_arn, None)
        return DeleteStateMachineOutput()

    def delete_state_machine_version(
        self, context: RequestContext, state_machine_version_arn: LongArn, **kwargs
    ) -> DeleteStateMachineVersionOutput:
        """Delete one published version and unlink it from its revision (idempotent)."""
        self._validate_state_machine_arn(state_machine_version_arn)
        state_machines = self.get_store(context).state_machines
        state_machine_version = state_machines.get(state_machine_version_arn)
        if isinstance(state_machine_version, StateMachineVersion):
            state_machines.pop(state_machine_version.arn)
            state_machine_revision = state_machines.get(
                state_machine_version.source_arn
            )
            if isinstance(state_machine_revision, StateMachineRevision):
                state_machine_revision.delete_version(state_machine_version_arn)
        return DeleteStateMachineVersionOutput()

    def stop_execution(
        self,
        context: RequestContext,
        execution_arn: Arn,
        error: SensitiveError = None,
        cause: SensitiveCause = None,
        **kwargs,
    ) -> StopExecutionOutput:
        """Stop a running STANDARD execution."""
        self._validate_state_machine_execution_arn(execution_arn)
        execution: Execution = self._get_execution(
            context=context, execution_arn=execution_arn
        )
        # Action only compatible with STANDARD workflows.
        if execution.sm_type != StateMachineType.STANDARD:
            self._raise_resource_type_not_in_context(resource_type=execution.sm_type)
        stop_date = datetime.datetime.now(tz=datetime.timezone.utc)
        execution.stop(stop_date=stop_date, cause=cause, error=error)
        return StopExecutionOutput(stopDate=stop_date)

    def update_state_machine(
        self,
        context: RequestContext,
        state_machine_arn: Arn,
        definition: Definition = None,
        role_arn: Arn = None,
        logging_configuration: LoggingConfiguration = None,
        tracing_configuration: TracingConfiguration = None,
        publish: Publish = None,
        version_description: VersionDescription = None,
        encryption_configuration: EncryptionConfiguration = None,
        **kwargs,
    ) -> UpdateStateMachineOutput:
        """Create a new revision of a state machine, optionally publishing a version."""
        self._validate_state_machine_arn(state_machine_arn)
        state_machines = self.get_store(context).state_machines
        state_machine = state_machines.get(state_machine_arn)
        if not isinstance(state_machine, StateMachineRevision):
            self._raise_state_machine_does_not_exist(state_machine_arn)
        # TODO: Add logic to handle metrics for when SFN definitions update
        if not any([definition, role_arn, logging_configuration]):
            raise MissingRequiredParameter(
                "Either the definition, the role ARN, the LoggingConfiguration, "
                "or the TracingConfiguration must be specified"
            )
        if definition is not None:
            self._validate_definition(
                definition=definition, static_analysers=[StaticAnalyser()]
            )
        if logging_configuration is not None:
            self._sanitise_logging_configuration(
                logging_configuration=logging_configuration
            )
        revision_id = state_machine.create_revision(
            definition=definition,
            role_arn=role_arn,
            logging_configuration=logging_configuration,
        )
        version_arn = None
        if publish:
            version = state_machine.create_version(description=version_description)
            if version is not None:
                version_arn = version.arn
                state_machines[version_arn] = version
            else:
                # NOTE(review): direct index — raises KeyError if no version
                # was ever published for this revision id; confirm intended.
                target_revision_id = revision_id or state_machine.revision_id
                version_arn = state_machine.versions[target_revision_id]
        update_output = UpdateStateMachineOutput(
            updateDate=datetime.datetime.now(tz=datetime.timezone.utc)
        )
        if revision_id is not None:
            update_output["revisionId"] = revision_id
        if version_arn is not None:
            update_output["stateMachineVersionArn"] = version_arn
        return update_output

    def publish_state_machine_version(
        self,
        context: RequestContext,
        state_machine_arn: Arn,
        revision_id: RevisionId = None,
        description: VersionDescription = None,
        **kwargs,
    ) -> PublishStateMachineVersionOutput:
        """Publish the current revision of a state machine as a new version."""
        self._validate_state_machine_arn(state_machine_arn)
        state_machines = self.get_store(context).state_machines
        state_machine_revision = state_machines.get(state_machine_arn)
        if not isinstance(state_machine_revision, StateMachineRevision):
            self._raise_state_machine_does_not_exist(state_machine_arn)
        # A caller-supplied revision_id must match the current revision.
        if (
            revision_id is not None
            and state_machine_revision.revision_id != revision_id
        ):
            raise ConflictException(
                f"Failed to publish the State Machine version for revision {revision_id}. "
                f"The current State Machine revision is {state_machine_revision.revision_id}."
            )
        state_machine_version = state_machine_revision.create_version(
            description=description
        )
        if state_machine_version is not None:
            state_machines[state_machine_version.arn] = state_machine_version
        else:
            # Already published: return the existing version for this revision.
            target_revision_id = revision_id or state_machine_revision.revision_id
            state_machine_version_arn = state_machine_revision.versions.get(
                target_revision_id
            )
            state_machine_version = state_machines[state_machine_version_arn]
        return PublishStateMachineVersionOutput(
            creationDate=state_machine_version.create_date,
            stateMachineVersionArn=state_machine_version.arn,
        )

    def tag_resource(
        self, context: RequestContext, resource_arn: Arn, tags: TagList, **kwargs
    ) -> TagResourceOutput:
        """Add tags to a state machine revision."""
        # TODO: add tagging for activities.
        state_machines = self.get_store(context).state_machines
        state_machine = state_machines.get(resource_arn)
        if not isinstance(state_machine, StateMachineRevision):
            raise ResourceNotFound(f"Resource not found: '{resource_arn}'")
        state_machine.tag_manager.add_all(tags)
        return TagResourceOutput()

    def untag_resource(
        self, context: RequestContext, resource_arn: Arn, tag_keys: TagKeyList, **kwargs
    ) -> UntagResourceOutput:
        """Remove tags from a state machine revision."""
        # TODO: add untagging for activities.
        state_machines = self.get_store(context).state_machines
        state_machine = state_machines.get(resource_arn)
        if not isinstance(state_machine, StateMachineRevision):
            raise ResourceNotFound(f"Resource not found: '{resource_arn}'")
        state_machine.tag_manager.remove_all(tag_keys)
        return UntagResourceOutput()

    def list_tags_for_resource(
        self, context: RequestContext, resource_arn: Arn, **kwargs
    ) -> ListTagsForResourceOutput:
        """List the tags of a state machine revision."""
        # TODO: add untagging for activities.
        state_machines = self.get_store(context).state_machines
        state_machine = state_machines.get(resource_arn)
        if not isinstance(state_machine, StateMachineRevision):
            raise ResourceNotFound(f"Resource not found: '{resource_arn}'")
        tags: TagList = state_machine.tag_manager.to_tag_list()
        return ListTagsForResourceOutput(tags=tags)

    def describe_map_run(
        self, context: RequestContext, map_run_arn: LongArn, **kwargs
    ) -> DescribeMapRunOutput:
        """Describe a Map Run by scanning every execution's map-run record pool."""
        store = self.get_store(context)
        for execution in store.executions.values():
            map_run_record: Optional[MapRunRecord] = (
                execution.exec_worker.env.map_run_record_pool_manager.get(map_run_arn)
            )
            if map_run_record is not None:
                return map_run_record.describe()
        raise ResourceNotFound()

    def list_map_runs(
        self,
        context: RequestContext,
        execution_arn: Arn,
        max_results: PageSize = None,
        next_token: PageToken = None,
        **kwargs,
    ) -> ListMapRunsOutput:
        # TODO: add support for paging.
execution = self._get_execution(context=context, execution_arn=execution_arn) map_run_records: list[MapRunRecord] = ( execution.exec_worker.env.map_run_record_pool_manager.get_all() ) return ListMapRunsOutput( mapRuns=[map_run_record.list_item() for map_run_record in map_run_records] ) def update_map_run( self, context: RequestContext, map_run_arn: LongArn, max_concurrency: MaxConcurrency = None, tolerated_failure_percentage: ToleratedFailurePercentage = None, tolerated_failure_count: ToleratedFailureCount = None, **kwargs, ) -> UpdateMapRunOutput: if ( tolerated_failure_percentage is not None or tolerated_failure_count is not None ): raise NotImplementedError( "Updating of ToleratedFailureCount and ToleratedFailurePercentage is currently unsupported." ) # TODO: investigate behaviour of empty requests. store = self.get_store(context) for execution in store.executions.values(): map_run_record: Optional[MapRunRecord] = ( execution.exec_worker.env.map_run_record_pool_manager.get(map_run_arn) ) if map_run_record is not None: map_run_record.update( max_concurrency=max_concurrency, tolerated_failure_count=tolerated_failure_count, tolerated_failure_percentage=tolerated_failure_percentage, ) LOG.warning( "StepFunctions UpdateMapRun changes are currently not being reflected in the MapRun instances." 
) return UpdateMapRunOutput() raise ResourceNotFound() def test_state( self, context: RequestContext, definition: Definition, role_arn: Arn = None, input: SensitiveData = None, inspection_level: InspectionLevel = None, reveal_secrets: RevealSecrets = None, variables: SensitiveData = None, **kwargs, ) -> TestStateOutput: StepFunctionsProvider._validate_definition( definition=definition, static_analysers=[TestStateStaticAnalyser()] ) name: Optional[Name] = f"TestState-{short_uid()}" arn = stepfunctions_state_machine_arn( name=name, account_id=context.account_id, region_name=context.region ) state_machine = TestStateMachine( name=name, arn=arn, role_arn=role_arn, definition=definition, ) exec_arn = stepfunctions_standard_execution_arn(state_machine.arn, name) input_json = json.loads(input) execution = TestStateExecution( name=name, role_arn=role_arn, exec_arn=exec_arn, account_id=context.account_id, region_name=context.region, state_machine=state_machine, start_date=datetime.datetime.now(tz=datetime.timezone.utc), input_data=input_json, activity_store=self.get_store(context).activities, ) execution.start() test_state_output = execution.to_test_state_output( inspection_level=inspection_level or InspectionLevel.INFO ) return test_state_output def create_activity( self, context: RequestContext, name: Name, tags: TagList = None, encryption_configuration: EncryptionConfiguration = None, **kwargs, ) -> CreateActivityOutput: self._validate_activity_name(name=name) activity_arn = stepfunctions_activity_arn( name=name, account_id=context.account_id, region_name=context.region ) activities = self.get_store(context).activities if activity_arn not in activities: activity = Activity(arn=activity_arn, name=name) activities[activity_arn] = activity else: activity = activities[activity_arn] return CreateActivityOutput( activityArn=activity.arn, creationDate=activity.creation_date ) def delete_activity( self, context: RequestContext, activity_arn: Arn, **kwargs ) -> 
DeleteActivityOutput: self._validate_activity_arn(activity_arn) self.get_store(context).activities.pop(activity_arn, None) return DeleteActivityOutput() def describe_activity( self, context: RequestContext, activity_arn: Arn, **kwargs ) -> DescribeActivityOutput: self._validate_activity_arn(activity_arn) activity = self._get_activity(context=context, activity_arn=activity_arn) return activity.to_describe_activity_output() def list_activities( self, context: RequestContext, max_results: PageSize = None, next_token: PageToken = None, **kwargs, ) -> ListActivitiesOutput: activities: list[Activity] = list(self.get_store(context).activities.values()) return ListActivitiesOutput( activities=[activity.to_activity_list_item() for activity in activities] ) def _send_activity_task_started( self, context: RequestContext, task_token: TaskToken, worker_name: Optional[Name], ) -> None: executions: list[Execution] = self._get_executions(context) for execution in executions: callback_endpoint = execution.exec_worker.env.callback_pool_manager.get( callback_id=task_token ) if isinstance(callback_endpoint, ActivityCallbackEndpoint): callback_endpoint.notify_activity_task_start(worker_name=worker_name) return raise InvalidToken() @staticmethod def _pull_activity_task(activity: Activity) -> Optional[ActivityTask]: seconds_left = 60 while seconds_left > 0: try: return activity.get_task() except IndexError: time.sleep(1) seconds_left -= 1 return None def get_activity_task( self, context: RequestContext, activity_arn: Arn, worker_name: Name = None, **kwargs, ) -> GetActivityTaskOutput: self._validate_activity_arn(activity_arn) activity = self._get_activity(context=context, activity_arn=activity_arn) maybe_task: Optional[ActivityTask] = self._pull_activity_task(activity=activity) if maybe_task is not None: self._send_activity_task_started( context, maybe_task.task_token, worker_name=worker_name ) return GetActivityTaskOutput( taskToken=maybe_task.task_token, input=maybe_task.task_input ) 
return GetActivityTaskOutput(taskToken=None, input=None) def validate_state_machine_definition( self, context: RequestContext, request: ValidateStateMachineDefinitionInput, **kwargs, ) -> ValidateStateMachineDefinitionOutput: # TODO: increase parity of static analysers, current implementation is an unblocker for this API action. # TODO: add support for ValidateStateMachineDefinitionSeverity # TODO: add support for ValidateStateMachineDefinitionMaxResult state_machine_type: StateMachineType = request.get( "type", StateMachineType.STANDARD ) definition: str = request["definition"] static_analysers = list() if state_machine_type == StateMachineType.STANDARD: static_analysers.append(StaticAnalyser()) else: static_analysers.append(ExpressStaticAnalyser()) diagnostics: ValidateStateMachineDefinitionDiagnosticList = list() try: StepFunctionsProvider._validate_definition( definition=definition, static_analysers=static_analysers ) validation_result = ValidateStateMachineDefinitionResultCode.OK except InvalidDefinition as invalid_definition: validation_result = ValidateStateMachineDefinitionResultCode.FAIL diagnostics.append( ValidateStateMachineDefinitionDiagnostic( severity=ValidateStateMachineDefinitionSeverity.ERROR, code="SCHEMA_VALIDATION_FAILED", message=invalid_definition.message, ) ) except Exception as ex: validation_result = ValidateStateMachineDefinitionResultCode.FAIL LOG.error("Unknown error during validation %s", ex) return ValidateStateMachineDefinitionOutput( result=validation_result, diagnostics=diagnostics, truncated=False )
Memory