import copy
from typing import Any, Optional, Set, TypedDict

from moto.stepfunctions.parser.asl.jsonata.jsonata import (
    VariableDeclarations,
    VariableReference,
    encode_jsonata_variable_declarations,
)
from moto.stepfunctions.parser.asl.utils.json_path import extract_json
from moto.stepfunctions.parser.utils import long_uid

# Name of the reserved `$states` variable and of its sub-fields, as they are
# spelled inside variable references.
_STATES_PREFIX: str = "$states"
_STATES_INPUT_PREFIX: str = "$states.input"
_STATES_CONTEXT_PREFIX: str = "$states.context"
_STATES_RESULT_PREFIX: str = "$states.result"
_STATES_ERROR_OUTPUT_PREFIX: str = "$states.errorOutput"


class ExecutionData(TypedDict):
    """Shape of the ``Execution`` entry of the context object."""

    Id: str
    Input: Optional[Any]
    Name: str
    RoleArn: str
    StartTime: str  # Format: ISO 8601.


class StateData(TypedDict):
    """Shape of the ``State`` entry of the context object."""

    EnteredTime: str  # Format: ISO 8601.
    Name: str
    RetryCount: int


class StateMachineData(TypedDict):
    """Shape of the ``StateMachine`` entry of the context object."""

    Id: str
    Name: str


class TaskData(TypedDict):
    """Shape of the ``Task`` entry of the context object."""

    Token: str


class ItemData(TypedDict):
    """Shape of the ``Item`` entry nested under ``Map``."""

    # Contains the index number for the array item that is being currently processed.
    Index: int
    # Contains the array item being processed.
    Value: Optional[Any]


class MapData(TypedDict):
    """Shape of the ``Map`` entry of the context object."""

    Item: ItemData


class ContextObjectData(TypedDict):
    """Top-level shape of the context object, grouping the entries above."""

    Execution: ExecutionData
    State: Optional[StateData]
    StateMachine: StateMachineData
    Task: Optional[TaskData]  # Null if the Parameters field is outside a task state.
    Map: Optional[MapData]  # Only available when processing a Map state.
class ContextObject:
    """Thin wrapper holding a reference to the context object data."""

    context_object_data: ContextObjectData

    def __init__(self, context_object: ContextObjectData):
        # The mapping is stored by reference, not copied.
        self.context_object_data = context_object

    def update_task_token(self) -> str:
        """Create a new task token, store it under ``Task`` and return it.

        Any previous ``Task`` entry is replaced by a fresh ``TaskData``.
        """
        new_token = long_uid()
        self.context_object_data["Task"] = TaskData(Token=new_token)
        return new_token


class _StatesDataRequired(TypedDict):
    # Keys that ``States.__init__`` always populates.
    input: Any
    context: ContextObjectData


class StatesData(_StatesDataRequired, total=False):
    """Backing mapping for the ``$states`` variable.

    ``result`` and ``errorOutput`` are declared as not-required keys because
    ``States.__init__`` builds the mapping without them and the accessors on
    ``States`` distinguish an absent key from a stored ``None``.
    """

    result: Optional[Any]
    errorOutput: Optional[Any]


class States:
    """Holds the ``$states`` data of an execution and exposes accessors for it.

    Values are deep-copied when stored and when read, so objects handed in or
    out never alias the internally held data.
    """

    _states_data: StatesData
    context_object: ContextObject

    def __init__(self, context: ContextObjectData):
        """Seed ``input`` from ``context["Execution"]["Input"]``.

        ``result`` and ``errorOutput`` are deliberately left unset here.
        """
        input_value = context["Execution"]["Input"]
        self._states_data = StatesData(input=input_value, context=context)
        self.context_object = ContextObject(context_object=context)

    @staticmethod
    def _extract(query: Optional[str], data: Any) -> Any:
        """Return a deep copy of ``data``, narrowed by ``query`` when given.

        With ``query`` set to ``None`` the whole of ``data`` is copied;
        otherwise ``extract_json(query, data)`` is evaluated first.
        """
        if query is None:
            result = data
        else:
            result = extract_json(query, data)
        return copy.deepcopy(result)

    def extract(self, query: str) -> Any:
        """Evaluate a ``$states...`` reference against the states data.

        Raises:
            RuntimeError: if ``query`` does not start with ``$states``.
        """
        if not query.startswith(_STATES_PREFIX):
            raise RuntimeError(f"No such variable {query} in $states")
        # The leading "$" is swapped for "$.", so "$states.input" becomes
        # "$.states.input".
        # NOTE(review): the states data visible here has the top-level keys
        # input/context/result/errorOutput and no "states" key - confirm that
        # extract_json resolves this path as intended.
        jsonpath_states_query = "$." + query[1:]
        return self._extract(jsonpath_states_query, self._states_data)

    def get_input(self, query: Optional[str] = None) -> Any:
        """Return a deep copy of ``input``, optionally narrowed by ``query``."""
        return self._extract(query, self._states_data["input"])

    def reset(self, input_value: Any) -> None:
        """Store a deep copy of ``input_value`` as ``input`` and clear outputs.

        ``result`` and ``errorOutput`` are set to ``None``, so the keys are
        present after this call.
        """
        clone_input_value = copy.deepcopy(input_value)
        self._states_data["input"] = clone_input_value
        self._states_data["result"] = None
        self._states_data["errorOutput"] = None

    def get_context(self, query: Optional[str] = None) -> Any:
        """Return a deep copy of ``context``, optionally narrowed by ``query``."""
        return self._extract(query, self._states_data["context"])

    def get_result(self, query: Optional[str] = None) -> Any:
        """Return a deep copy of ``result``, optionally narrowed by ``query``.

        Raises:
            RuntimeError: if the ``result`` key is absent, i.e. neither
                ``reset`` nor ``set_result`` has been called yet.
        """
        if "result" not in self._states_data:
            raise RuntimeError("Illegal access to $states.result")
        return self._extract(query, self._states_data["result"])

    def set_result(self, result: Any) -> None:
        """Store a deep copy of ``result`` under ``result``."""
        clone_result = copy.deepcopy(result)
        self._states_data["result"] = clone_result

    def get_error_output(self, query: Optional[str] = None) -> Any:
        """Return a deep copy of ``errorOutput``, optionally narrowed by ``query``.

        Raises:
            RuntimeError: if the ``errorOutput`` key is absent, i.e. neither
                ``reset`` nor ``set_error_output`` has been called yet.
        """
        if "errorOutput" not in self._states_data:
            raise RuntimeError("Illegal access to $states.errorOutput")
        return self._extract(query, self._states_data["errorOutput"])

    def set_error_output(self, error_output: Any) -> None:
        """Store a deep copy of ``error_output`` under ``errorOutput``."""
        clone_error_output = copy.deepcopy(error_output)
        self._states_data["errorOutput"] = clone_error_output

    def to_variable_declarations(
        self, variable_references: Optional[Set[VariableReference]] = None
    ) -> VariableDeclarations:
        """Encode ``$states`` as variable declarations.

        When ``variable_references`` is empty/``None`` or contains the bare
        ``$states`` name, the full states data is bound. Otherwise only the
        sub-fields (``input``, ``context``, ``result``, ``errorOutput``) that
        some reference starts with are bound.

        Raises:
            KeyError: if a reference targets ``result`` or ``errorOutput``
                while that key is absent from the states data.
        """
        if not variable_references or _STATES_PREFIX in variable_references:
            return encode_jsonata_variable_declarations(
                bindings={_STATES_PREFIX: self._states_data}
            )
        # Sub-fields not yet bound, mapped to the prefix that selects them.
        candidate_sub_states = {
            "input": _STATES_INPUT_PREFIX,
            "context": _STATES_CONTEXT_PREFIX,
            "result": _STATES_RESULT_PREFIX,
            "errorOutput": _STATES_ERROR_OUTPUT_PREFIX,
        }
        sub_states = {}
        for variable_reference in variable_references:
            if not candidate_sub_states:
                # Every sub-field is already bound; nothing left to match.
                break
            for sub_states_key, sub_states_prefix in candidate_sub_states.items():
                if variable_reference.startswith(sub_states_prefix):
                    sub_states[sub_states_key] = self._states_data[sub_states_key]  # noqa
                    # Deleting while iterating is safe only because the inner
                    # loop is left immediately afterwards.
                    del candidate_sub_states[sub_states_key]
                    break
        return encode_jsonata_variable_declarations(
            bindings={_STATES_PREFIX: sub_states}
        )
Memory