from __future__ import annotations
import abc
import datetime
import json
from collections import OrderedDict
from typing import Dict, Final, Optional
from moto.stepfunctions.parser.api import (
Definition,
DescribeStateMachineOutput,
LoggingConfiguration,
Name,
RevisionId,
StateMachineListItem,
StateMachineStatus,
StateMachineType,
StateMachineVersionListItem,
Tag,
TagKeyList,
TagList,
TracingConfiguration,
ValidationException,
VariableReferences,
)
from moto.stepfunctions.parser.asl.eval.event.logging import (
CloudWatchLoggingConfiguration,
)
from moto.stepfunctions.parser.asl.static_analyser.variable_references_static_analyser import (
VariableReferencesStaticAnalyser,
)
from moto.stepfunctions.parser.utils import long_uid
class StateMachineInstance:
    """Shared state and ``describe`` logic for the state machine flavours in this module.

    Subclasses are expected to provide ``itemise``.

    NOTE: ``itemise`` is decorated with ``abc.abstractmethod`` but this class
    does not derive from ``abc.ABC``, so Python does not block instantiation
    of this class or of subclasses that leave ``itemise`` unimplemented.
    """

    name: Name
    arn: str
    # Stays ``None`` until something assigns a revision id to the instance.
    revision_id: Optional[RevisionId]
    definition: Definition
    role_arn: str
    create_date: datetime.datetime
    sm_type: StateMachineType
    logging_config: LoggingConfiguration
    cloud_watch_logging_configuration: Optional[CloudWatchLoggingConfiguration]
    tags: Optional[TagList]
    tracing_config: Optional[TracingConfiguration]

    def __init__(
        self,
        name: Name,
        arn: str,
        definition: Definition,
        role_arn: str,
        logging_config: LoggingConfiguration,
        cloud_watch_logging_configuration: Optional[
            CloudWatchLoggingConfiguration
        ] = None,
        create_date: Optional[datetime.datetime] = None,
        sm_type: Optional[StateMachineType] = None,
        tags: Optional[TagList] = None,
        tracing_config: Optional[TracingConfiguration] = None,
    ):
        """Store the given attributes, filling in defaults for the optional ones.

        A falsy ``create_date`` is replaced by the current UTC time and a
        falsy ``sm_type`` by ``StateMachineType.STANDARD``. ``revision_id``
        always starts out as ``None``.
        """
        # Resolve the defaulted values first so the assignments below are uniform.
        if not create_date:
            create_date = datetime.datetime.now(tz=datetime.timezone.utc)
        if not sm_type:
            sm_type = StateMachineType.STANDARD

        self.name = name
        self.arn = arn
        self.revision_id = None
        self.definition = definition
        self.role_arn = role_arn
        self.create_date = create_date
        self.sm_type = sm_type
        self.logging_config = logging_config
        self.cloud_watch_logging_configuration = cloud_watch_logging_configuration
        self.tags = tags
        self.tracing_config = tracing_config

    def describe(self) -> DescribeStateMachineOutput:
        """Build the ``DescribeStateMachineOutput`` for this instance.

        The status is always reported as ``StateMachineStatus.ACTIVE``.
        ``revisionId`` and ``variableReferences`` are only included when
        they are truthy.
        """
        output = DescribeStateMachineOutput(
            stateMachineArn=self.arn,
            name=self.name,
            status=StateMachineStatus.ACTIVE,
            definition=self.definition,
            roleArn=self.role_arn,
            type=self.sm_type,
            creationDate=self.create_date,
            loggingConfiguration=self.logging_config,
        )

        current_revision = self.revision_id
        if current_revision:
            output["revisionId"] = current_revision

        # The references are recomputed from the definition on every call.
        referenced_variables: VariableReferences = (
            VariableReferencesStaticAnalyser.process_and_get(definition=self.definition)
        )
        if referenced_variables:
            output["variableReferences"] = referenced_variables

        return output

    @abc.abstractmethod
    def itemise(self):
        """Return the list-item representation of this instance (subclass hook)."""
class TestStateMachine(StateMachineInstance):
    """State machine instance of type ``STANDARD`` with no logging, tags or tracing.

    It cannot be itemised: ``itemise`` always raises ``NotImplementedError``.
    """

    def __init__(
        self,
        name: Name,
        arn: str,
        definition: Definition,
        role_arn: str,
        create_date: Optional[datetime.datetime] = None,
    ):
        """Initialise the base instance with fixed defaults for the optional fields.

        Arguments are forwarded by keyword. The previous positional call did
        not match the base-class parameter order: ``create_date`` was stored
        as ``logging_config``, ``StateMachineType.STANDARD`` was stored as
        ``cloud_watch_logging_configuration``, and the caller's
        ``create_date`` was silently replaced by the current time.
        """
        super().__init__(
            name=name,
            arn=arn,
            definition=definition,
            role_arn=role_arn,
            logging_config=None,
            cloud_watch_logging_configuration=None,
            create_date=create_date,
            sm_type=StateMachineType.STANDARD,
            tags=None,
            tracing_config=None,
        )

    def itemise(self):
        """Always raise ``NotImplementedError``; this type has no list-item form."""
        raise NotImplementedError("TestStateMachine does not support itemise.")
class TagManager:
    """Insertion-ordered collection of tag key/value pairs.

    Keys must be truthy and values must not be ``None``; violations raise
    ``ValidationException``.
    """

    _tags: Final[Dict[str, Optional[str]]]

    def __init__(self):
        """Start with no tags."""
        self._tags = OrderedDict()

    @staticmethod
    def _validate_key_value(key: str) -> None:
        """Raise ``ValidationException`` if *key* is falsy (e.g. ``None`` or ``""``)."""
        if key:
            return
        raise ValidationException()

    @staticmethod
    def _validate_tag_value(value: str) -> None:
        """Raise ``ValidationException`` if *value* is ``None``; empty strings pass."""
        if value is not None:
            return
        raise ValidationException()

    def add_all(self, tags: TagList) -> None:
        """Insert or overwrite every tag in *tags*, in order.

        Each entry is validated immediately before it is stored, so entries
        preceding an invalid one have already been applied when the
        exception is raised.
        """
        for entry in tags:
            key, value = entry["key"], entry["value"]
            self._validate_key_value(key=key)
            self._validate_tag_value(value=value)
            self._tags[key] = value

    def remove_all(self, keys: TagKeyList):
        """Remove every key in *keys*; keys that are not present are ignored.

        Each key is validated right before its removal is attempted.
        """
        for tag_key in keys:
            self._validate_key_value(key=tag_key)
            if tag_key in self._tags:
                del self._tags[tag_key]

    def to_tag_list(self) -> TagList:
        """Return the stored tags as a new list of ``Tag`` items, in insertion order."""
        return [Tag(key=tag_key, value=tag_value) for tag_key, tag_value in self._tags.items()]
class StateMachineRevision(StateMachineInstance):
    """State machine instance that can be revised in place and published as versions.

    ``versions`` maps the revision id a version was published from to that
    version's ARN. The key is ``None`` for a version published before the
    first call to ``create_revision`` changed anything.
    """

    # Number given to the most recently published version; 0 means none yet.
    _version_number: int
    versions: Final[Dict[RevisionId, str]]
    tag_manager: Final[TagManager]

    def __init__(
        self,
        name: Name,
        arn: str,
        definition: Definition,
        role_arn: str,
        logging_config: LoggingConfiguration,
        cloud_watch_logging_configuration: Optional[CloudWatchLoggingConfiguration],
        create_date: Optional[datetime.datetime] = None,
        sm_type: Optional[StateMachineType] = None,
        tags: Optional[TagList] = None,
        tracing_config: Optional[TracingConfiguration] = None,
    ):
        """Initialise the base instance, an empty version table and the tag manager.

        When *tags* is non-empty it is also loaded into ``tag_manager``,
        which validates every entry and may raise ``ValidationException``.
        """
        super().__init__(
            name,
            arn,
            definition,
            role_arn,
            logging_config,
            cloud_watch_logging_configuration,
            create_date,
            sm_type,
            tags,
            tracing_config,
        )
        self.versions = {}
        self._version_number = 0
        self.tag_manager = TagManager()
        if tags:
            self.tag_manager.add_all(tags)

    def create_revision(
        self,
        definition: Optional[str],
        role_arn: Optional[str],
        logging_configuration: Optional[LoggingConfiguration],
    ) -> Optional[RevisionId]:
        """Apply the supplied changes and assign a new revision id if anything changed.

        A field is updated only when its argument is truthy and differs from
        the current value. Definitions are compared as parsed JSON, so
        differences in formatting alone do not count as a change.

        Returns:
            The current ``revision_id``: a fresh id when at least one field
            changed, otherwise the previous value (which may be ``None``).
        """
        update_definition = definition and json.loads(definition) != json.loads(
            self.definition
        )
        if update_definition:
            self.definition = definition

        update_role_arn = role_arn and role_arn != self.role_arn
        if update_role_arn:
            self.role_arn = role_arn

        update_logging_configuration = (
            logging_configuration and logging_configuration != self.logging_config
        )
        if update_logging_configuration:
            self.logging_config = logging_configuration
            # Rebuild the derived CloudWatch configuration from the new settings.
            self.cloud_watch_logging_configuration = (
                CloudWatchLoggingConfiguration.from_logging_configuration(
                    state_machine_arn=self.arn,
                    logging_configuration=self.logging_config,
                )
            )

        if update_definition or update_role_arn or update_logging_configuration:
            self.revision_id = long_uid()

        return self.revision_id

    def create_version(
        self, description: Optional[str]
    ) -> Optional[StateMachineVersion]:
        """Publish a version of the current revision, at most once per revision id.

        Returns:
            The new ``StateMachineVersion``, or ``None`` when ``versions``
            already has an entry for the current ``revision_id``.
        """
        if self.revision_id in self.versions:
            return None

        self._version_number += 1
        version = StateMachineVersion(
            self, version=self._version_number, description=description
        )
        self.versions[self.revision_id] = version.arn
        return version

    def delete_version(self, state_machine_version_arn: str) -> None:
        """Forget the version whose ARN equals *state_machine_version_arn*.

        An ARN with no matching entry is ignored.
        """
        # ``None`` is a legitimate key in ``versions`` (a version published
        # before any revision id was assigned), so it cannot double as the
        # "not found" marker: doing so would delete that unrelated version
        # whenever an unknown ARN is passed in.
        not_found = object()
        source_revision_id = next(
            (
                revision_id
                for revision_id, version_arn in self.versions.items()
                if version_arn == state_machine_version_arn
            ),
            not_found,
        )
        if source_revision_id is not not_found:
            del self.versions[source_revision_id]

    def itemise(self) -> StateMachineListItem:
        """Return the ``StateMachineListItem`` summary of this state machine."""
        return StateMachineListItem(
            stateMachineArn=self.arn,
            name=self.name,
            type=self.sm_type,
            creationDate=self.create_date,
        )
class StateMachineVersion(StateMachineInstance):
    """Numbered snapshot taken from a ``StateMachineRevision``.

    The attribute values are read from the revision at construction time;
    the ARN is the revision's ARN followed by ``:<version>``.
    """

    # ARN of the revision this version was created from.
    source_arn: str
    version: int
    description: Optional[str]

    def __init__(
        self,
        state_machine_revision: StateMachineRevision,
        version: int,
        description: Optional[str],
    ):
        """Copy the revision's current attributes and stamp the creation time as now (UTC)."""
        created_at = datetime.datetime.now(tz=datetime.timezone.utc)
        super().__init__(
            name=state_machine_revision.name,
            arn=f"{state_machine_revision.arn}:{version}",
            definition=state_machine_revision.definition,
            role_arn=state_machine_revision.role_arn,
            logging_config=state_machine_revision.logging_config,
            cloud_watch_logging_configuration=state_machine_revision.cloud_watch_logging_configuration,
            create_date=created_at,
            sm_type=state_machine_revision.sm_type,
            tags=state_machine_revision.tags,
            tracing_config=state_machine_revision.tracing_config,
        )
        # The base initialiser resets revision_id to None, so set it afterwards.
        self.revision_id = state_machine_revision.revision_id
        self.source_arn = state_machine_revision.arn
        self.version = version
        self.description = description

    def describe(self) -> DescribeStateMachineOutput:
        """Return the base description, adding ``description`` when it is truthy."""
        output: DescribeStateMachineOutput = super().describe()
        text = self.description
        if not text:
            return output
        output["description"] = text
        return output

    def itemise(self) -> StateMachineVersionListItem:
        """Return the ``StateMachineVersionListItem`` summary of this version."""
        item = StateMachineVersionListItem(
            stateMachineVersionArn=self.arn, creationDate=self.create_date
        )
        return item