from __future__ import annotations

import copy
import logging
import threading
from typing import Any, Dict, Final, List, Optional

from moto.stepfunctions.parser.api import (
    Arn,
    ExecutionFailedEventDetails,
    StateMachineType,
    Timestamp,
)
from moto.stepfunctions.parser.asl.component.state.exec.state_map.iteration.itemprocessor.map_run_record import (
    MapRunRecordPoolManager,
)
from moto.stepfunctions.parser.asl.eval.callback.callback import CallbackPoolManager
from moto.stepfunctions.parser.asl.eval.evaluation_details import AWSExecutionDetails
from moto.stepfunctions.parser.asl.eval.event.event_manager import (
    EventHistoryContext,
    EventManager,
)
from moto.stepfunctions.parser.asl.eval.event.logging import (
    CloudWatchLoggingSession,
)
from moto.stepfunctions.parser.asl.eval.program_state import (
    ProgramEnded,
    ProgramError,
    ProgramRunning,
    ProgramState,
    ProgramStopped,
    ProgramTimedOut,
)
from moto.stepfunctions.parser.asl.eval.states import ContextObjectData, States
from moto.stepfunctions.parser.asl.eval.variable_store import VariableStore
from moto.stepfunctions.parser.backend.activity import Activity

LOG = logging.getLogger(__name__)


class Environment:
    """Evaluation environment of a state machine execution.

    An environment bundles the program state, the event manager and event
    history context, the callback and map-run-record pool managers, the
    activity store, the ``States`` context and the variable store.

    Child environments ("frames") are created with :meth:`open_frame` or
    :meth:`open_inner_frame`. Program state changes applied to a parent via
    the ``set_*`` methods are forwarded to every frame that is still
    registered with it.
    """

    # Guards program state transitions and the list of registered frames.
    # Annotated with the types themselves: the previous ``threading.RLock()``
    # / ``threading.Event()`` call expressions are not valid type expressions.
    _state_mutex: Final[threading.RLock]
    _program_state: Optional[ProgramState]
    # Pulsed (set, then cleared) by every ``set_*`` method.
    program_state_event: Final[threading.Event]

    event_manager: EventManager
    event_history_context: Final[EventHistoryContext]
    cloud_watch_logging_session: Optional[CloudWatchLoggingSession]
    aws_execution_details: Final[AWSExecutionDetails]
    execution_type: StateMachineType
    callback_pool_manager: CallbackPoolManager
    map_run_record_pool_manager: MapRunRecordPoolManager
    activity_store: Dict[Arn, Activity]

    # Frames opened from this environment and not yet closed or deleted.
    _frames: Final[List[Environment]]
    _is_frame: bool = False

    # Class-level defaults kept for backward compatibility; ``__init__``
    # rebinds both attributes on every instance it initialises.
    heap: Dict[str, Any] = dict()
    stack: List[Any] = list()
    states: States
    variable_store: VariableStore

    def __init__(
        self,
        aws_execution_details: AWSExecutionDetails,
        execution_type: StateMachineType,
        context: ContextObjectData,
        event_history_context: EventHistoryContext,
        cloud_watch_logging_session: Optional[CloudWatchLoggingSession],
        activity_store: Dict[Arn, Activity],
        variable_store: Optional[VariableStore] = None,
    ):
        """Initialise a root environment with no program state set.

        Args:
            aws_execution_details: Stored as ``self.aws_execution_details``.
            execution_type: Workflow type, compared against
                ``StateMachineType.STANDARD`` by :meth:`is_standard_workflow`.
            context: Context object data used to build ``self.states``.
            event_history_context: Event history context of this environment.
            cloud_watch_logging_session: Optional logging session; stored and
                also handed to the newly created ``EventManager``.
            activity_store: Activities by ARN; stored and also handed to the
                newly created ``CallbackPoolManager``.
            variable_store: Variable store to use. A new ``VariableStore`` is
                created when the argument is falsy.
        """
        super().__init__()
        self._state_mutex = threading.RLock()
        self._program_state = None
        self.program_state_event = threading.Event()

        self.cloud_watch_logging_session = cloud_watch_logging_session
        self.event_manager = EventManager(
            cloud_watch_logging_session=cloud_watch_logging_session
        )
        self.event_history_context = event_history_context

        self.aws_execution_details = aws_execution_details
        self.execution_type = execution_type
        self.callback_pool_manager = CallbackPoolManager(activity_store=activity_store)
        self.map_run_record_pool_manager = MapRunRecordPoolManager()
        self.activity_store = activity_store

        self._frames = []
        self._is_frame = False

        self.heap = {}
        self.stack = []
        self.states = States(context=context)
        # NOTE(review): ``or`` also replaces a falsy, non-None store; confirm
        # VariableStore does not define __bool__/__len__.
        self.variable_store = variable_store or VariableStore()

    @classmethod
    def as_frame_of(
        cls,
        env: Environment,
        event_history_frame_cache: Optional[EventHistoryContext] = None,
    ) -> Environment:
        """Build a frame of ``env`` that reuses ``env``'s variable store.

        The frame is not registered with ``env``; see :meth:`open_frame`.

        Args:
            env: Environment the frame is derived from.
            event_history_frame_cache: Event history context for the frame;
                forwarded unchanged to :meth:`as_inner_frame_of`.

        Returns:
            The newly built frame.
        """
        return Environment.as_inner_frame_of(
            env=env,
            variable_store=env.variable_store,
            event_history_frame_cache=event_history_frame_cache,
        )

    @classmethod
    def as_inner_frame_of(
        cls,
        env: Environment,
        variable_store: VariableStore,
        event_history_frame_cache: Optional[EventHistoryContext] = None,
    ) -> Environment:
        """Build a frame of ``env`` that uses the given variable store.

        The frame shares ``env``'s event manager, callback pool manager,
        map-run-record pool manager and heap (same objects), and receives deep
        copies of ``env``'s ``"State"`` context entry (when present) and of
        its program state. The frame is not registered with ``env``; see
        :meth:`open_inner_frame`.

        Args:
            env: Environment the frame is derived from.
            variable_store: Variable store assigned to the frame.
            event_history_frame_cache: Event history context for the frame.
                When ``None``, a new context is created whose previous event
                id is ``env``'s source event id.

        Returns:
            The newly built frame, with ``_is_frame`` set to ``True``.
        """
        source_data = env.states.context_object.context_object_data

        # Construct the frame's context object data.
        context = ContextObjectData(
            Execution=source_data["Execution"],
            StateMachine=source_data["StateMachine"],
        )
        if "Task" in source_data:
            context["Task"] = source_data["Task"]

        # The default logic provisions for child frame to extend the source
        # frame event id.
        if event_history_frame_cache is None:
            event_history_frame_cache = EventHistoryContext(
                previous_event_id=env.event_history_context.source_event_id
            )

        frame = cls(
            aws_execution_details=env.aws_execution_details,
            execution_type=env.execution_type,
            context=context,
            event_history_context=event_history_frame_cache,
            cloud_watch_logging_session=env.cloud_watch_logging_session,
            activity_store=env.activity_store,
            variable_store=variable_store,
        )
        frame._is_frame = True
        # Replace the managers created by the constructor with the parent's.
        frame.event_manager = env.event_manager
        if "State" in source_data:
            frame.states.context_object.context_object_data["State"] = copy.deepcopy(
                source_data["State"]
            )
        frame.callback_pool_manager = env.callback_pool_manager
        frame.map_run_record_pool_manager = env.map_run_record_pool_manager
        frame.heap = env.heap
        frame._program_state = copy.deepcopy(env._program_state)
        return frame

    @property
    def next_state_name(self) -> Optional[str]:
        """Name of the next state, or ``None`` unless the program is running."""
        if isinstance(self._program_state, ProgramRunning):
            return self._program_state.next_state_name
        return None

    @next_state_name.setter
    def next_state_name(self, next_state_name: str) -> None:
        """Set the next state, moving an unset program state to running.

        Raises:
            RuntimeError: If the program state is set and is not
                ``ProgramRunning``.
        """
        if self._program_state is None:
            self._program_state = ProgramRunning()

        if not isinstance(self._program_state, ProgramRunning):
            raise RuntimeError(
                f"Could not set NextState value when in state '{type(self._program_state)}'."
            )
        self._program_state.next_state_name = next_state_name

    def program_state(self) -> ProgramState:
        """Return a deep copy of the current program state.

        The copy is ``None`` while no program state has been set.
        """
        return copy.deepcopy(self._program_state)

    def is_running(self) -> bool:
        """Return whether the current program state is ``ProgramRunning``."""
        return isinstance(self._program_state, ProgramRunning)

    def _signal_program_state_change(self) -> None:
        """Pulse ``program_state_event``: set it, then clear it right away."""
        self.program_state_event.set()
        self.program_state_event.clear()

    def set_ended(self) -> None:
        """Mark a running program, and its registered frames, as ended.

        Nothing is changed unless the program is running; the program state
        event is pulsed either way.
        """
        with self._state_mutex:
            if isinstance(self._program_state, ProgramRunning):
                self._program_state = ProgramEnded()
                for frame in self._frames:
                    frame.set_ended()
            self._signal_program_state_change()

    def set_error(self, error: ExecutionFailedEventDetails) -> None:
        """Set the error state on this environment and its registered frames.

        Applied regardless of the current program state.

        Args:
            error: Failure details stored in the new ``ProgramError`` state.
        """
        with self._state_mutex:
            self._program_state = ProgramError(error=error)
            for frame in self._frames:
                frame.set_error(error=error)
            self._signal_program_state_change()

    def set_timed_out(self) -> None:
        """Set the timed-out state on this environment and its registered frames.

        Applied regardless of the current program state.
        """
        with self._state_mutex:
            self._program_state = ProgramTimedOut()
            for frame in self._frames:
                frame.set_timed_out()
            self._signal_program_state_change()

    def set_stop(
        self, stop_date: Timestamp, cause: Optional[str], error: Optional[str]
    ) -> None:
        """Mark a running program, and its registered frames, as stopped.

        Nothing is changed unless the program is running; the program state
        event is pulsed either way.

        Args:
            stop_date: Stop timestamp stored in the ``ProgramStopped`` state.
            cause: Optional cause stored in the ``ProgramStopped`` state.
            error: Optional error stored in the ``ProgramStopped`` state.
        """
        with self._state_mutex:
            if isinstance(self._program_state, ProgramRunning):
                self._program_state = ProgramStopped(
                    stop_date=stop_date, cause=cause, error=error
                )
                for frame in self._frames:
                    frame.set_stop(stop_date=stop_date, cause=cause, error=error)
            self._signal_program_state_change()

    def open_frame(
        self, event_history_context: Optional[EventHistoryContext] = None
    ) -> Environment:
        """Create and register a frame that reuses this variable store.

        Args:
            event_history_context: Event history context for the frame; see
                :meth:`as_inner_frame_of` for the default.

        Returns:
            The registered frame.
        """
        with self._state_mutex:
            frame = self.as_frame_of(
                env=self, event_history_frame_cache=event_history_context
            )
            self._frames.append(frame)
            return frame

    def open_inner_frame(
        self, event_history_context: Optional[EventHistoryContext] = None
    ) -> Environment:
        """Create and register a frame with an inner-scope variable store.

        The frame's variable store is obtained from
        ``VariableStore.as_inner_scope_of`` with this environment's store as
        the outer store.

        Args:
            event_history_context: Event history context for the frame; see
                :meth:`as_inner_frame_of` for the default.

        Returns:
            The registered frame.
        """
        with self._state_mutex:
            variable_store = VariableStore.as_inner_scope_of(
                outer_variable_store=self.variable_store
            )
            frame = self.as_inner_frame_of(
                env=self,
                variable_store=variable_store,
                event_history_frame_cache=event_history_context,
            )
            self._frames.append(frame)
            return frame

    def close_frame(self, frame: Environment) -> None:
        """Unregister ``frame`` and integrate its event history context.

        Does nothing when ``frame`` is not currently registered.
        """
        with self._state_mutex:
            if frame in self._frames:
                self._frames.remove(frame)
                self.event_history_context.integrate(frame.event_history_context)

    def delete_frame(self, frame: Environment) -> None:
        """Unregister ``frame`` without integrating its event history context.

        Does nothing when ``frame`` is not currently registered.
        """
        with self._state_mutex:
            if frame in self._frames:
                self._frames.remove(frame)

    def is_frame(self) -> bool:
        """Return whether this environment was built as a frame of another."""
        return self._is_frame

    def is_standard_workflow(self) -> bool:
        """Return whether the execution type is ``StateMachineType.STANDARD``."""
        return self.execution_type == StateMachineType.STANDARD
Memory