from __future__ import annotations
import datetime
import json
import logging
from typing import Dict, Optional
from moto.stepfunctions.parser.api import (
Arn,
CloudWatchEventsExecutionDataDetails,
DescribeExecutionOutput,
DescribeStateMachineForExecutionOutput,
ExecutionListItem,
ExecutionStatus,
GetExecutionHistoryOutput,
HistoryEventList,
InvalidName,
SensitiveCause,
SensitiveError,
StartExecutionOutput,
StartSyncExecutionOutput,
StateMachineType,
SyncExecutionStatus,
Timestamp,
TraceHeader,
VariableReferences,
)
from moto.stepfunctions.parser.asl.eval.evaluation_details import (
AWSExecutionDetails,
EvaluationDetails,
ExecutionDetails,
StateMachineDetails,
)
from moto.stepfunctions.parser.asl.eval.event.logging import (
CloudWatchLoggingSession,
)
from moto.stepfunctions.parser.asl.eval.program_state import (
ProgramEnded,
ProgramError,
ProgramState,
ProgramStopped,
ProgramTimedOut,
)
from moto.stepfunctions.parser.asl.static_analyser.variable_references_static_analyser import (
VariableReferencesStaticAnalyser,
)
from moto.stepfunctions.parser.asl.utils.encoding import to_json_str
from moto.stepfunctions.parser.backend.activity import Activity
from moto.stepfunctions.parser.backend.execution_worker import (
ExecutionWorker,
SyncExecutionWorker,
)
from moto.stepfunctions.parser.backend.execution_worker_comm import (
ExecutionWorkerCommunication,
)
from moto.stepfunctions.parser.backend.state_machine import (
StateMachineInstance,
StateMachineVersion,
)
# Module-level logger, named after this module.
LOG = logging.getLogger(name=__name__)
class BaseExecutionWorkerCommunication(ExecutionWorkerCommunication):
    """Worker communication channel that copies the final program state onto an Execution.

    When the worker reports termination, the program state of the worker's
    environment is translated into the execution's stop date, status and,
    depending on the outcome, its output or error/cause.
    """

    execution: Execution

    def __init__(self, execution: Execution):
        self.execution = execution

    def _reflect_execution_status(self):
        """Update ``self.execution`` from the worker environment's program state.

        Raises:
            RuntimeWarning: if the program state is none of the handled types.
        """
        execution = self.execution
        environment = execution.exec_worker.env
        final_state: ProgramState = environment.program_state()
        execution.stop_date = datetime.datetime.now(tz=datetime.timezone.utc)

        if isinstance(final_state, ProgramEnded):
            execution.exec_status = ExecutionStatus.SUCCEEDED
            # On success, the current state input is exposed as the execution output.
            execution.output = environment.states.get_input()
            return
        if isinstance(final_state, ProgramStopped):
            execution.exec_status = ExecutionStatus.ABORTED
            return
        if isinstance(final_state, ProgramError):
            execution.exec_status = ExecutionStatus.FAILED
            error_details = final_state.error
            execution.error = error_details.get("error")
            execution.cause = error_details.get("cause")
            return
        if isinstance(final_state, ProgramTimedOut):
            execution.exec_status = ExecutionStatus.TIMED_OUT
            return
        raise RuntimeWarning(
            f"Execution ended with unsupported ProgramState type '{type(final_state)}'."
        )

    def terminated(self) -> None:
        """Handle worker termination by reflecting the final status onto the execution."""
        self._reflect_execution_status()
class Execution:
    """A single state machine execution and its API-facing representations.

    Stores the execution's identity, input and configuration. Its status,
    stop date, output and error/cause are updated by the worker communication
    channel when the program terminates. The class also renders the
    Describe/List/History/Start output shapes.
    """

    name: str
    sm_type: StateMachineType
    role_arn: Arn
    exec_arn: Arn

    account_id: str
    region_name: str

    state_machine: StateMachineInstance
    start_date: Timestamp
    input_data: Optional[json]
    input_details: Optional[CloudWatchEventsExecutionDataDetails]
    trace_header: Optional[TraceHeader]
    _cloud_watch_logging_session: Optional[CloudWatchLoggingSession]

    exec_status: Optional[ExecutionStatus]
    stop_date: Optional[Timestamp]

    output: Optional[json]
    output_details: Optional[CloudWatchEventsExecutionDataDetails]

    error: Optional[SensitiveError]
    cause: Optional[SensitiveCause]

    # None until start() creates the worker.
    exec_worker: Optional[ExecutionWorker]

    _activity_store: Dict[Arn, Activity]

    def __init__(
        self,
        name: str,
        sm_type: StateMachineType,
        role_arn: Arn,
        exec_arn: Arn,
        account_id: str,
        region_name: str,
        state_machine: StateMachineInstance,
        start_date: Timestamp,
        cloud_watch_logging_session: Optional[CloudWatchLoggingSession],
        activity_store: Dict[Arn, Activity],
        input_data: Optional[json] = None,
        trace_header: Optional[TraceHeader] = None,
    ):
        """Create a not-yet-started execution.

        ``exec_status``, ``stop_date``, ``output``, ``error``, ``cause`` and
        ``exec_worker`` all start as ``None``.
        """
        self.name = name
        self.sm_type = sm_type
        self.role_arn = role_arn
        self.exec_arn = exec_arn
        self.execution_arn = exec_arn
        self.account_id = account_id
        self.region_name = region_name
        self.state_machine = state_machine
        self._cloud_watch_logging_session = cloud_watch_logging_session
        self.input_data = input_data
        self.input_details = CloudWatchEventsExecutionDataDetails(included=True)
        self.trace_header = trace_header
        self.exec_status = None
        self.stop_date = None
        self.output = None
        self.output_details = CloudWatchEventsExecutionDataDetails(included=True)
        self.exec_worker = None
        self.error = None
        self.cause = None
        self._activity_store = activity_store

        # Compatibility with mock SFN
        self.state_machine_arn = state_machine.arn
        self.start_date = start_date
        self.execution_input = input_data

    @property
    def status(self):
        """Return ``exec_status.value``, or ``None`` if no status has been set yet."""
        # exec_status stays None until start() is called; avoid an AttributeError.
        if self.exec_status is None:
            return None
        return self.exec_status.value

    def to_start_output(self) -> StartExecutionOutput:
        """Build the StartExecution response (execution ARN and start date)."""
        return StartExecutionOutput(
            executionArn=self.exec_arn, startDate=self.start_date
        )

    def to_describe_output(self) -> DescribeExecutionOutput:
        """Build the DescribeExecution response.

        ``output``/``outputDetails`` are included only when the status is
        SUCCEEDED. ``error``/``cause`` are included only when they are set.
        """
        describe_output = DescribeExecutionOutput(
            executionArn=self.exec_arn,
            stateMachineArn=self.state_machine.arn,
            name=self.name,
            status=self.exec_status,
            startDate=self.start_date,
            stopDate=self.stop_date,
            input=to_json_str(self.input_data, separators=(",", ":")),
            inputDetails=self.input_details,
            traceHeader=self.trace_header,
        )
        if describe_output["status"] == ExecutionStatus.SUCCEEDED:
            describe_output["output"] = to_json_str(self.output, separators=(",", ":"))
            describe_output["outputDetails"] = self.output_details
        if self.error is not None:
            describe_output["error"] = self.error
        if self.cause is not None:
            describe_output["cause"] = self.cause
        return describe_output

    def to_describe_state_machine_for_execution_output(
        self,
    ) -> DescribeStateMachineForExecutionOutput:
        """Build the DescribeStateMachineForExecution response.

        For a ``StateMachineVersion``, the source state machine's ARN is
        reported. ``revisionId`` and ``variableReferences`` are added only
        when they are truthy.
        """
        state_machine: StateMachineInstance = self.state_machine
        state_machine_arn = (
            state_machine.source_arn
            if isinstance(state_machine, StateMachineVersion)
            else state_machine.arn
        )
        out = DescribeStateMachineForExecutionOutput(
            stateMachineArn=state_machine_arn,
            name=state_machine.name,
            definition=state_machine.definition,
            roleArn=self.role_arn,
            # The date and time the state machine associated with an execution was updated.
            updateDate=state_machine.create_date,
            loggingConfiguration=state_machine.logging_config,
        )
        revision_id = state_machine.revision_id
        if revision_id:
            out["revisionId"] = revision_id
        variable_references: VariableReferences = (
            VariableReferencesStaticAnalyser.process_and_get(
                definition=state_machine.definition
            )
        )
        if variable_references:
            out["variableReferences"] = variable_references
        return out

    def to_execution_list_item(self) -> ExecutionListItem:
        """Build a ListExecutions item.

        For a ``StateMachineVersion``, ``stateMachineArn`` is the source ARN and
        ``stateMachineVersionArn`` is the version's own ARN.
        """
        if isinstance(self.state_machine, StateMachineVersion):
            state_machine_arn = self.state_machine.source_arn
            state_machine_version_arn = self.state_machine.arn
        else:
            state_machine_arn = self.state_machine.arn
            state_machine_version_arn = None

        item = ExecutionListItem(
            executionArn=self.exec_arn,
            stateMachineArn=state_machine_arn,
            name=self.name,
            status=self.exec_status,
            startDate=self.start_date,
            stopDate=self.stop_date,
        )
        if state_machine_version_arn is not None:
            item["stateMachineVersionArn"] = state_machine_version_arn
        return item

    def to_history_output(self) -> GetExecutionHistoryOutput:
        """Build the GetExecutionHistory response.

        Returns an empty event list when the execution has no worker yet, or
        the worker has no environment yet.
        """
        event_history: HistoryEventList = list()
        exec_worker: Optional[ExecutionWorker] = self.exec_worker
        # exec_worker is None until start(); do not dereference it blindly.
        env = exec_worker.env if exec_worker is not None else None
        if env is not None:
            event_history = env.event_manager.get_event_history()
        # Otherwise the execution has not started yet, so there is no history.
        return GetExecutionHistoryOutput(events=event_history)

    def _get_start_execution_worker_comm(self) -> BaseExecutionWorkerCommunication:
        """Create the communication channel the worker reports termination through."""
        return BaseExecutionWorkerCommunication(self)

    def _get_start_aws_execution_details(self) -> AWSExecutionDetails:
        """Bundle account, region and role ARN for the evaluation context."""
        return AWSExecutionDetails(
            account=self.account_id, region=self.region_name, role_arn=self.role_arn
        )

    def get_start_execution_details(self) -> ExecutionDetails:
        """Bundle this execution's ARN, name, role, input and start time."""
        return ExecutionDetails(
            arn=self.exec_arn,
            name=self.name,
            role_arn=self.role_arn,
            inpt=self.input_data,
            start_time=self.start_date,
        )

    def get_start_state_machine_details(self) -> StateMachineDetails:
        """Bundle the state machine's ARN, name, type and definition."""
        return StateMachineDetails(
            arn=self.state_machine.arn,
            name=self.state_machine.name,
            typ=self.state_machine.sm_type,
            definition=self.state_machine.definition,
        )

    def _get_start_execution_worker(self) -> ExecutionWorker:
        """Create the (asynchronous) worker that will run this execution."""
        return ExecutionWorker(
            evaluation_details=EvaluationDetails(
                aws_execution_details=self._get_start_aws_execution_details(),
                execution_details=self.get_start_execution_details(),
                state_machine_details=self.get_start_state_machine_details(),
            ),
            exec_comm=self._get_start_execution_worker_comm(),
            cloud_watch_logging_session=self._cloud_watch_logging_session,
            activity_store=self._activity_store,
        )

    def start(self) -> None:
        """Create the worker, mark the execution RUNNING and start the worker.

        Raises:
            InvalidName: if the execution already has a worker (already started).
        """
        # TODO: raise a more appropriate error than InvalidName for a repeated start.
        if self.exec_worker:
            raise InvalidName()  # TODO.
        self.exec_worker = self._get_start_execution_worker()
        self.exec_status = ExecutionStatus.RUNNING
        self.exec_worker.start()

    def stop(
        self, stop_date: datetime.datetime, error: Optional[str], cause: Optional[str]
    ):
        """Ask the worker to stop. Does nothing if the execution was never started."""
        exec_worker: Optional[ExecutionWorker] = self.exec_worker
        if exec_worker:
            exec_worker.stop(stop_date=stop_date, cause=cause, error=error)
class SyncExecutionWorkerCommunication(BaseExecutionWorkerCommunication):
    """Worker communication for synchronous executions.

    After the base status reflection, this class also derives the
    ``SyncExecutionStatus`` stored on the execution.
    """

    execution: SyncExecution

    def _reflect_execution_status(self) -> None:
        super()._reflect_execution_status()
        final_status: ExecutionStatus = self.execution.exec_status
        # SUCCEEDED and TIMED_OUT map directly; every other outcome is FAILED.
        if final_status == ExecutionStatus.SUCCEEDED:
            sync_status = SyncExecutionStatus.SUCCEEDED
        elif final_status == ExecutionStatus.TIMED_OUT:
            sync_status = SyncExecutionStatus.TIMED_OUT
        else:
            sync_status = SyncExecutionStatus.FAILED
        self.execution.sync_execution_status = sync_status
class SyncExecution(Execution):
    """An execution started via StartSyncExecution.

    It runs on a ``SyncExecutionWorker`` and carries a ``SyncExecutionStatus``
    in addition to the regular execution fields.
    """

    sync_execution_status: Optional[SyncExecutionStatus] = None

    def _get_start_execution_worker(self) -> SyncExecutionWorker:
        """Create the synchronous worker that will run this execution."""
        details = EvaluationDetails(
            aws_execution_details=self._get_start_aws_execution_details(),
            execution_details=self.get_start_execution_details(),
            state_machine_details=self.get_start_state_machine_details(),
        )
        return SyncExecutionWorker(
            evaluation_details=details,
            exec_comm=self._get_start_execution_worker_comm(),
            cloud_watch_logging_session=self._cloud_watch_logging_session,
            activity_store=self._activity_store,
        )

    def _get_start_execution_worker_comm(self) -> BaseExecutionWorkerCommunication:
        """Use the sync-aware channel so ``sync_execution_status`` gets populated."""
        return SyncExecutionWorkerCommunication(execution=self)

    def to_start_sync_execution_output(self) -> StartSyncExecutionOutput:
        """Build the StartSyncExecution response.

        ``output`` is included only when the sync status is SUCCEEDED.
        ``outputDetails`` is included when truthy. ``error``/``cause`` are
        included when set.
        """
        compact = (",", ":")
        result = StartSyncExecutionOutput(
            executionArn=self.exec_arn,
            stateMachineArn=self.state_machine.arn,
            name=self.name,
            status=self.sync_execution_status,
            startDate=self.start_date,
            stopDate=self.stop_date,
            input=to_json_str(self.input_data, separators=compact),
            inputDetails=self.input_details,
            traceHeader=self.trace_header,
        )
        succeeded = self.sync_execution_status == SyncExecutionStatus.SUCCEEDED
        if succeeded:
            result["output"] = to_json_str(self.output, separators=compact)
        if self.output_details:
            result["outputDetails"] = self.output_details
        for key, value in (("error", self.error), ("cause", self.cause)):
            if value is not None:
                result[key] = value
        return result