from typing import Any, Dict, List, Optional
from moto.appmesh.dataclasses.route import (
GrcpRouteRetryPolicy,
GrpcMetadatum,
GrpcRoute,
GrpcRouteMatch,
HttpRoute,
HttpRouteMatch,
HttpRouteRetryPolicy,
Match,
QueryParameterMatch,
Range,
RouteAction,
RouteActionWeightedTarget,
RouteMatchPath,
RouteMatchQueryParameter,
RouteSpec,
TCPRoute,
TCPRouteMatch,
)
from moto.appmesh.dataclasses.shared import Duration, Timeout
from moto.appmesh.dataclasses.virtual_node import (
ACM,
DNS,
SDS,
AccessLog,
AccessLogFile,
AWSCloudMap,
Backend,
BackendDefaults,
BackendTrust,
Certificate,
CertificateFile,
CertificateFileWithPrivateKey,
ClientPolicy,
ConnectionPool,
GRPCOrHTTP2Connection,
HealthCheck,
HTTPConnection,
KeyValue,
Listener,
ListenerCertificateACM,
ListenerTLS,
Logging,
LoggingFormat,
OutlierDetection,
PortMapping,
ProtocolTimeouts,
ServiceDiscovery,
SubjectAlternativeNames,
TCPConnection,
TCPTimeout,
TLSBackendValidation,
TLSClientPolicy,
TLSListenerCertificate,
TLSListenerValidation,
Trust,
VirtualNodeSpec,
VirtualService,
)
from moto.appmesh.dataclasses.virtual_node import (
Match as VirtualNodeMatch,
)
from moto.appmesh.dataclasses.virtual_router import PortMapping as RouterPortMapping
from moto.appmesh.exceptions import (
MissingRequiredFieldError,
)
def port_mappings_from_router_spec(spec: Any) -> List[RouterPortMapping]:  # type: ignore[misc]
    """Collect one ``RouterPortMapping`` per listener in a virtual router spec.

    A falsy ``spec``, a missing/falsy ``listeners`` entry, or a listener
    without a ``portMapping`` are all tolerated: the first two yield an
    empty list, the last yields a mapping whose ``port`` and ``protocol``
    are ``None``.
    """
    router_spec = spec or {}
    mappings: List[RouterPortMapping] = []
    for entry in router_spec.get("listeners") or []:
        # Look the mapping up once and reuse it for both fields.
        raw_mapping = entry.get("portMapping") or {}
        mappings.append(
            RouterPortMapping(
                port=raw_mapping.get("port"),
                protocol=raw_mapping.get("protocol"),
            )
        )
    return mappings
def get_action_from_route(route: Any) -> RouteAction:  # type: ignore[misc]
    """Build the ``RouteAction`` for a route from its ``action.weightedTargets``.

    A missing or falsy ``action`` or ``weightedTargets`` entry results in a
    ``RouteAction`` with an empty target list.
    """
    action = route.get("action") or {}
    targets = []
    for raw_target in action.get("weightedTargets") or []:
        targets.append(
            RouteActionWeightedTarget(
                port=raw_target.get("port"),
                virtual_node=raw_target.get("virtualNode"),
                weight=raw_target.get("weight"),
            )
        )
    return RouteAction(weighted_targets=targets)
def get_route_match_metadata(metadata: List[Any]) -> List[GrpcMetadatum]:  # type: ignore[misc]
    """Convert raw metadata entries into ``GrpcMetadatum`` objects.

    Each entry contributes its ``invert`` and ``name`` values plus an
    optional ``match``; the match in turn carries an optional ``range``.
    Entries without a ``match`` key get ``match=None``.
    """

    def convert_match(raw_match: Any) -> Optional[Match]:  # type: ignore[misc]
        # Translate one raw ``match`` mapping; ``None`` passes through.
        if raw_match is None:
            return None
        raw_range = raw_match.get("range")
        bounds = None
        if raw_range is not None:
            bounds = Range(start=raw_range.get("start"), end=raw_range.get("end"))
        return Match(
            exact=raw_match.get("exact"),
            prefix=raw_match.get("prefix"),
            range=bounds,
            regex=raw_match.get("regex"),
            suffix=raw_match.get("suffix"),
        )

    return [
        GrpcMetadatum(
            invert=entry.get("invert"),
            match=convert_match(entry.get("match")),
            name=entry.get("name"),
        )
        for entry in metadata
    ]
def get_grpc_route_match(route: Any) -> GrpcRouteMatch:  # type: ignore[misc]
    """Build the ``GrpcRouteMatch`` described by ``route["match"]``.

    When the route has no ``match`` entry, a ``GrpcRouteMatch`` with every
    field set to ``None`` is returned. Previously that case raised
    ``AttributeError`` because the ``None`` value was dereferenced after
    the metadata guard.

    When ``match`` is present, ``metadata`` is always a list (empty if the
    ``metadata`` key is missing or falsy).
    """
    route_match = route.get("match")
    if route_match is None:
        return GrpcRouteMatch(
            metadata=None,
            method_name=None,
            port=None,
            service_name=None,
        )
    return GrpcRouteMatch(
        metadata=get_route_match_metadata(route_match.get("metadata") or []),
        method_name=route_match.get("methodName"),
        port=route_match.get("port"),
        service_name=route_match.get("serviceName"),
    )
def get_http_match_from_route(route: Any) -> HttpRouteMatch:  # type: ignore[misc]
    """Build the ``HttpRouteMatch`` described by ``route["match"]``.

    A missing or falsy ``match`` entry is treated as an empty mapping, so
    ``headers`` is always a list (possibly empty), while ``path`` and
    ``query_parameters`` stay ``None`` unless their keys are present.
    """
    match_data = route.get("match") or {}

    headers = get_route_match_metadata(match_data.get("headers") or [])

    path = None
    path_data = match_data.get("path")
    if path_data is not None:
        path = RouteMatchPath(
            exact=path_data.get("exact"),
            regex=path_data.get("regex"),
        )

    query_parameters = None
    query_data = match_data.get("queryParameters")
    if query_data is not None:
        query_parameters = []
        for param in query_data:
            param_match = param.get("match")
            parsed_match = None
            if param_match is not None:
                parsed_match = QueryParameterMatch(exact=param_match.get("exact"))
            query_parameters.append(
                RouteMatchQueryParameter(name=param.get("name"), match=parsed_match)
            )

    return HttpRouteMatch(
        headers=headers,
        method=match_data.get("method"),
        path=path,
        port=match_data.get("port"),
        prefix=match_data.get("prefix"),
        query_parameters=query_parameters,
        scheme=match_data.get("scheme"),
    )
def get_http_retry_policy_from_route(route: Any) -> Optional[HttpRouteRetryPolicy]:  # type: ignore[misc]
    """Build the ``HttpRouteRetryPolicy`` for a route, if it declares one.

    Returns ``None`` when the route has no ``retryPolicy`` entry.

    Raises:
        MissingRequiredFieldError: if ``retryPolicy`` is present but has no
            ``perRetryTimeout``. This used to surface as an
            ``AttributeError`` from dereferencing ``None``.
    """
    raw_policy = route.get("retryPolicy")
    if raw_policy is None:
        return None

    raw_per_retry_timeout = raw_policy.get("perRetryTimeout")
    if raw_per_retry_timeout is None:
        raise MissingRequiredFieldError("perRetryTimeout")

    return HttpRouteRetryPolicy(
        max_retries=raw_policy.get("maxRetries"),
        http_retry_events=raw_policy.get("httpRetryEvents"),
        per_retry_timeout=Duration(
            unit=raw_per_retry_timeout.get("unit"),
            value=raw_per_retry_timeout.get("value"),
        ),
        tcp_retry_events=raw_policy.get("tcpRetryEvents"),
    )
def get_timeout_from_route(route: Any) -> Optional[Timeout]:  # type: ignore[misc]
    """Build the ``Timeout`` for a route from its ``timeout`` entry.

    Returns ``None`` when neither ``idle`` nor ``perRequest`` is present
    (including when ``timeout`` itself is missing or falsy); otherwise
    returns a ``Timeout`` whose absent component is ``None``.
    """
    raw_timeout = route.get("timeout") or {}
    raw_idle = raw_timeout.get("idle")
    raw_per_request = raw_timeout.get("perRequest")

    if raw_idle is None and raw_per_request is None:
        return None

    idle = None
    if raw_idle is not None:
        idle = Duration(unit=raw_idle.get("unit"), value=raw_idle.get("value"))

    per_request = None
    if raw_per_request is not None:
        per_request = Duration(
            unit=raw_per_request.get("unit"),
            value=raw_per_request.get("value"),
        )

    return Timeout(idle=idle, per_request=per_request)
def get_tls_for_client_policy(tls: Any) -> TLSClientPolicy:  # type: ignore[misc]
    """Build a ``TLSClientPolicy`` from a raw client-policy ``tls`` mapping.

    ``certificate`` is optional. ``validation`` and its nested ``trust``
    are mandatory.

    Raises:
        MissingRequiredFieldError: if ``validation`` is absent, or if
            ``validation`` has no ``trust`` entry.
    """
    raw_certificate = tls.get("certificate")
    raw_validation = tls.get("validation")

    # The optional client certificate is parsed before validation is checked.
    certificate = None
    if raw_certificate is not None:
        raw_cert_file = raw_certificate.get("file")
        raw_cert_sds = raw_certificate.get("sds")
        cert_file = None
        if raw_cert_file is not None:
            cert_file = CertificateFileWithPrivateKey(
                certificate_chain=raw_cert_file.get("certificateChain"),
                private_key=raw_cert_file.get("privateKey"),
            )
        cert_sds = None
        if raw_cert_sds is not None:
            cert_sds = SDS(secret_name=raw_cert_sds.get("secretName"))
        certificate = Certificate(file=cert_file, sds=cert_sds)

    if raw_validation is None:
        raise MissingRequiredFieldError("validation")

    raw_names = raw_validation.get("subjectAlternativeNames")
    raw_trust = raw_validation.get("trust")

    alternative_names = None
    if raw_names is not None:
        # A missing ``match`` or ``exact`` collapses to an empty list.
        exact_names = (raw_names.get("match") or {}).get("exact") or []
        alternative_names = SubjectAlternativeNames(
            match=VirtualNodeMatch(exact=exact_names)
        )

    if raw_trust is None:
        raise MissingRequiredFieldError("trust")

    raw_trust_file = raw_trust.get("file")
    raw_trust_sds = raw_trust.get("sds")
    raw_trust_acm = raw_trust.get("acm")

    trust_file = None
    if raw_trust_file is not None:
        trust_file = CertificateFile(
            certificate_chain=raw_trust_file.get("certificateChain")
        )
    trust_sds = None
    if raw_trust_sds is not None:
        trust_sds = SDS(secret_name=raw_trust_sds.get("secretName"))
    trust_acm = None
    if raw_trust_acm is not None:
        trust_acm = ACM(
            certificate_authority_arns=raw_trust_acm.get("certificateAuthorityArns")
        )

    return TLSClientPolicy(
        certificate=certificate,
        enforce=tls.get("enforce"),
        ports=tls.get("ports"),
        validation=TLSBackendValidation(
            subject_alternative_names=alternative_names,
            trust=BackendTrust(file=trust_file, sds=trust_sds, acm=trust_acm),
        ),
    )
def build_route_spec(spec: Dict[str, Any]) -> RouteSpec:  # type: ignore[misc]
    """Convert a raw route ``spec`` mapping into a ``RouteSpec``.

    Each of ``grpcRoute``, ``httpRoute``, ``http2Route`` and ``tcpRoute``
    is parsed only when present; absent routes are ``None`` on the result.
    ``priority`` is copied through unchanged.

    Raises:
        MissingRequiredFieldError: if a ``retryPolicy`` is given without a
            ``perRetryTimeout``. For the gRPC route this used to surface as
            an ``AttributeError`` from dereferencing ``None``.
    """

    def parse_http_route(raw_route: Any) -> HttpRoute:  # type: ignore[misc]
        # httpRoute and http2Route share one shape and one parser.
        return HttpRoute(
            action=get_action_from_route(raw_route),
            match=get_http_match_from_route(raw_route),
            retry_policy=get_http_retry_policy_from_route(raw_route),
            timeout=get_timeout_from_route(raw_route),
        )

    raw_grpc = spec.get("grpcRoute")
    raw_http = spec.get("httpRoute")
    raw_http2 = spec.get("http2Route")
    raw_tcp = spec.get("tcpRoute")

    grpc_route = None
    if raw_grpc is not None:
        grpc_action = get_action_from_route(raw_grpc)
        grpc_match = get_grpc_route_match(raw_grpc)

        grpc_retry_policy = None
        raw_retry = raw_grpc.get("retryPolicy")
        if raw_retry is not None:
            raw_per_retry_timeout = raw_retry.get("perRetryTimeout")
            if raw_per_retry_timeout is None:
                raise MissingRequiredFieldError("perRetryTimeout")
            grpc_retry_policy = GrcpRouteRetryPolicy(
                grpc_retry_events=raw_retry.get("grpcRetryEvents"),
                http_retry_events=raw_retry.get("httpRetryEvents"),
                max_retries=raw_retry.get("maxRetries"),
                per_retry_timeout=Duration(
                    unit=raw_per_retry_timeout.get("unit"),
                    value=raw_per_retry_timeout.get("value"),
                ),
                tcp_retry_events=raw_retry.get("tcpRetryEvents"),
            )

        grpc_route = GrpcRoute(
            action=grpc_action,
            match=grpc_match,
            retry_policy=grpc_retry_policy,
            timeout=get_timeout_from_route(raw_grpc),
        )

    http_route = None if raw_http is None else parse_http_route(raw_http)
    http2_route = None if raw_http2 is None else parse_http_route(raw_http2)

    tcp_route = None
    if raw_tcp is not None:
        tcp_action = get_action_from_route(raw_tcp)
        tcp_timeout = get_timeout_from_route(raw_tcp)
        raw_tcp_match = raw_tcp.get("match")
        tcp_match = None
        if raw_tcp_match is not None:
            tcp_match = TCPRouteMatch(port=raw_tcp_match.get("port"))
        tcp_route = TCPRoute(action=tcp_action, match=tcp_match, timeout=tcp_timeout)

    return RouteSpec(
        grpc_route=grpc_route,
        http_route=http_route,
        http2_route=http2_route,
        priority=spec.get("priority"),
        tcp_route=tcp_route,
    )
def build_virtual_node_spec(spec: Dict[str, Any]) -> VirtualNodeSpec:  # type: ignore[misc]
    """Convert a raw virtual node ``spec`` mapping into a ``VirtualNodeSpec``.

    The sections ``backendDefaults``, ``backends``, ``listeners``,
    ``logging`` and ``serviceDiscovery`` are each parsed only when present;
    absent sections are ``None`` on the result. Parsing of each section is
    delegated to the private ``_build_*`` helpers defined below.

    Raises:
        MissingRequiredFieldError: if a listener lacks ``portMapping``; if
            ``outlierDetection`` lacks ``baseEjectionDuration`` or
            ``interval``; if a listener ``tcp`` timeout lacks ``idle``; if
            listener ``tls`` lacks ``certificate``; if listener TLS
            ``validation`` lacks ``trust``; if ``awsCloudMap`` lacks
            ``attributes``; or from ``get_tls_for_client_policy`` for
            client-policy TLS settings.
    """
    raw_backend_defaults = spec.get("backendDefaults")
    raw_backends = spec.get("backends")
    raw_listeners = spec.get("listeners")
    raw_logging = spec.get("logging")
    raw_service_discovery = spec.get("serviceDiscovery")

    backend_defaults = None
    if raw_backend_defaults is not None:
        backend_defaults = BackendDefaults(
            client_policy=_build_client_policy(
                raw_backend_defaults.get("clientPolicy")
            )
        )

    backends = None
    if raw_backends is not None:
        backends = [_build_backend(raw_backend) for raw_backend in raw_backends]

    listeners = None
    if raw_listeners is not None:
        listeners = [_build_listener(raw_listener) for raw_listener in raw_listeners]

    node_logging = None
    if raw_logging is not None:
        node_logging = _build_logging(raw_logging)

    service_discovery = None
    if raw_service_discovery is not None:
        service_discovery = _build_service_discovery(raw_service_discovery)

    return VirtualNodeSpec(
        backend_defaults=backend_defaults,
        backends=backends,
        listeners=listeners,
        logging=node_logging,
        service_discovery=service_discovery,
    )


def _build_duration(raw_duration: Any) -> Duration:  # type: ignore[misc]
    """Build a ``Duration`` from a mapping with ``unit`` and ``value`` keys."""
    return Duration(unit=raw_duration.get("unit"), value=raw_duration.get("value"))


def _build_request_timeout(raw_timeout: Any) -> Timeout:  # type: ignore[misc]
    """Build a listener ``Timeout``; absent ``idle``/``perRequest`` become ``None``."""
    raw_idle = raw_timeout.get("idle")
    raw_per_request = raw_timeout.get("perRequest")
    return Timeout(
        idle=None if raw_idle is None else _build_duration(raw_idle),
        per_request=(
            None if raw_per_request is None else _build_duration(raw_per_request)
        ),
    )


def _build_client_policy(raw_policy: Any) -> Optional[ClientPolicy]:  # type: ignore[misc]
    """Build a ``ClientPolicy`` with optional TLS, or ``None`` if no policy is given."""
    if raw_policy is None:
        return None
    raw_tls = raw_policy.get("tls")
    tls = None if raw_tls is None else get_tls_for_client_policy(raw_tls)
    return ClientPolicy(tls=tls)


def _build_backend(raw_backend: Any) -> Backend:  # type: ignore[misc]
    """Build one ``Backend`` and its optional ``VirtualService``."""
    raw_virtual_service = raw_backend.get("virtualService")
    virtual_service = None
    if raw_virtual_service is not None:
        virtual_service = VirtualService(
            client_policy=_build_client_policy(
                raw_virtual_service.get("clientPolicy")
            ),
            virtual_service_name=raw_virtual_service.get("virtualServiceName"),
        )
    return Backend(virtual_service=virtual_service)


def _build_connection_pool(raw_pool: Any) -> ConnectionPool:  # type: ignore[misc]
    """Build a listener ``ConnectionPool`` from its per-protocol settings."""
    raw_grpc = raw_pool.get("grpc")
    raw_http = raw_pool.get("http")
    raw_http2 = raw_pool.get("http2")
    raw_tcp = raw_pool.get("tcp")

    grpc = None
    if raw_grpc is not None:
        grpc = GRPCOrHTTP2Connection(max_requests=raw_grpc.get("maxRequests"))
    http = None
    if raw_http is not None:
        http = HTTPConnection(
            max_connections=raw_http.get("maxConnections"),
            max_pending_requests=raw_http.get("maxPendingRequests"),
        )
    http2 = None
    if raw_http2 is not None:
        http2 = GRPCOrHTTP2Connection(max_requests=raw_http2.get("maxRequests"))
    tcp = None
    if raw_tcp is not None:
        tcp = TCPConnection(max_connections=raw_tcp.get("maxConnections"))

    return ConnectionPool(grpc=grpc, http=http, http2=http2, tcp=tcp)


def _build_health_check(raw_check: Any) -> HealthCheck:  # type: ignore[misc]
    """Build a listener ``HealthCheck`` by copying its fields across."""
    return HealthCheck(
        healthy_threshold=raw_check.get("healthyThreshold"),
        interval_millis=raw_check.get("intervalMillis"),
        path=raw_check.get("path"),
        port=raw_check.get("port"),
        protocol=raw_check.get("protocol"),
        timeout_millis=raw_check.get("timeoutMillis"),
        unhealthy_threshold=raw_check.get("unhealthyThreshold"),
    )


def _build_outlier_detection(raw_detection: Any) -> OutlierDetection:  # type: ignore[misc]
    """Build ``OutlierDetection``; both durations are mandatory."""
    raw_base_ejection = raw_detection.get("baseEjectionDuration")
    raw_interval = raw_detection.get("interval")
    if raw_base_ejection is None:
        raise MissingRequiredFieldError("baseEjectionDuration")
    base_ejection_duration = _build_duration(raw_base_ejection)
    if raw_interval is None:
        raise MissingRequiredFieldError("interval")
    interval = _build_duration(raw_interval)
    return OutlierDetection(
        base_ejection_duration=base_ejection_duration,
        interval=interval,
        max_ejection_percent=raw_detection.get("maxEjectionPercent"),
        max_server_errors=raw_detection.get("maxServerErrors"),
    )


def _build_protocol_timeouts(raw_timeouts: Any) -> ProtocolTimeouts:  # type: ignore[misc]
    """Build listener ``ProtocolTimeouts``; a ``tcp`` timeout requires ``idle``."""
    raw_grpc = raw_timeouts.get("grpc")
    raw_http = raw_timeouts.get("http")
    raw_http2 = raw_timeouts.get("http2")
    raw_tcp = raw_timeouts.get("tcp")

    grpc = None if raw_grpc is None else _build_request_timeout(raw_grpc)
    http = None if raw_http is None else _build_request_timeout(raw_http)
    http2 = None if raw_http2 is None else _build_request_timeout(raw_http2)

    tcp = None
    if raw_tcp is not None:
        raw_idle = raw_tcp.get("idle")
        if raw_idle is None:
            raise MissingRequiredFieldError("idle")
        tcp = TCPTimeout(idle=_build_duration(raw_idle))

    return ProtocolTimeouts(grpc=grpc, http=http, http2=http2, tcp=tcp)


def _build_listener_tls_validation(raw_validation: Any) -> TLSListenerValidation:  # type: ignore[misc]
    """Build listener ``TLSListenerValidation``; ``trust`` is mandatory."""
    raw_names = raw_validation.get("subjectAlternativeNames")
    raw_trust = raw_validation.get("trust")

    alternative_names = None
    if raw_names is not None:
        # Tolerate a missing ``match`` instead of failing with AttributeError,
        # as the client-policy TLS parser does.
        raw_match = raw_names.get("match") or {}
        alternative_names = SubjectAlternativeNames(
            match=VirtualNodeMatch(exact=raw_match.get("exact"))
        )

    if raw_trust is None:
        raise MissingRequiredFieldError("trust")

    raw_trust_file = raw_trust.get("file")
    raw_trust_sds = raw_trust.get("sds")
    trust_file = None
    if raw_trust_file is not None:
        trust_file = CertificateFile(
            certificate_chain=raw_trust_file.get("certificateChain")
        )
    trust_sds = None
    if raw_trust_sds is not None:
        trust_sds = SDS(secret_name=raw_trust_sds.get("secretName"))

    return TLSListenerValidation(
        subject_alternative_names=alternative_names,
        trust=Trust(file=trust_file, sds=trust_sds),
    )


def _build_listener_tls(raw_tls: Any) -> ListenerTLS:  # type: ignore[misc]
    """Build ``ListenerTLS``; ``certificate`` is mandatory, ``validation`` optional."""
    raw_certificate = raw_tls.get("certificate")
    raw_validation = raw_tls.get("validation")
    if raw_certificate is None:
        raise MissingRequiredFieldError("certificate")

    raw_cert_file = raw_certificate.get("file")
    raw_cert_sds = raw_certificate.get("sds")
    raw_cert_acm = raw_certificate.get("acm")

    cert_file = None
    if raw_cert_file is not None:
        cert_file = CertificateFileWithPrivateKey(
            certificate_chain=raw_cert_file.get("certificateChain"),
            private_key=raw_cert_file.get("privateKey"),
        )
    cert_sds = None
    if raw_cert_sds is not None:
        cert_sds = SDS(secret_name=raw_cert_sds.get("secretName"))
    cert_acm = None
    if raw_cert_acm is not None:
        cert_acm = ListenerCertificateACM(
            certificate_arn=raw_cert_acm.get("certificateArn")
        )
    certificate = TLSListenerCertificate(file=cert_file, sds=cert_sds, acm=cert_acm)

    validation = None
    if raw_validation is not None:
        validation = _build_listener_tls_validation(raw_validation)

    return ListenerTLS(
        certificate=certificate,
        mode=raw_tls.get("mode"),
        validation=validation,
    )


def _build_listener(raw_listener: Any) -> Listener:  # type: ignore[misc]
    """Build one ``Listener``; ``portMapping`` is the only mandatory section."""
    raw_pool = raw_listener.get("connectionPool")
    raw_health_check = raw_listener.get("healthCheck")
    raw_outlier_detection = raw_listener.get("outlierDetection")
    raw_port_mapping = raw_listener.get("portMapping")
    raw_timeouts = raw_listener.get("timeout")
    raw_tls = raw_listener.get("tls")

    # Sections are parsed in a fixed order so that, when several required
    # fields are missing, the first error reported stays the same.
    connection_pool = None
    if raw_pool is not None:
        connection_pool = _build_connection_pool(raw_pool)
    health_check = None
    if raw_health_check is not None:
        health_check = _build_health_check(raw_health_check)
    outlier_detection = None
    if raw_outlier_detection is not None:
        outlier_detection = _build_outlier_detection(raw_outlier_detection)

    if raw_port_mapping is None:
        raise MissingRequiredFieldError("portMapping")
    port_mapping = PortMapping(
        port=raw_port_mapping.get("port"),
        protocol=raw_port_mapping.get("protocol"),
    )

    timeout = None
    if raw_timeouts is not None:
        timeout = _build_protocol_timeouts(raw_timeouts)
    listener_tls = None
    if raw_tls is not None:
        listener_tls = _build_listener_tls(raw_tls)

    return Listener(
        connection_pool=connection_pool,
        health_check=health_check,
        outlier_detection=outlier_detection,
        port_mapping=port_mapping,
        timeout=timeout,
        tls=listener_tls,
    )


def _build_logging(raw_logging: Any) -> Logging:  # type: ignore[misc]
    """Build ``Logging`` with its optional access-log file and format."""
    raw_access_log = raw_logging.get("accessLog")
    access_log = None
    if raw_access_log is not None:
        raw_file = raw_access_log.get("file")
        log_file = None
        if raw_file is not None:
            raw_format = raw_file.get("format")
            log_format = None
            if raw_format is not None:
                raw_json = raw_format.get("json")
                json_pairs = None
                if raw_json is not None:
                    json_pairs = [
                        KeyValue(key=item.get("key"), value=item.get("value"))
                        for item in raw_json
                    ]
                log_format = LoggingFormat(
                    json=json_pairs, text=raw_format.get("text")
                )
            log_file = AccessLogFile(format=log_format, path=raw_file.get("path"))
        access_log = AccessLog(file=log_file)
    return Logging(access_log=access_log)


def _build_service_discovery(raw_discovery: Any) -> ServiceDiscovery:  # type: ignore[misc]
    """Build ``ServiceDiscovery``; ``awsCloudMap`` must carry ``attributes``."""
    raw_cloud_map = raw_discovery.get("awsCloudMap")
    raw_dns = raw_discovery.get("dns")

    aws_cloud_map = None
    if raw_cloud_map is not None:
        raw_attributes = raw_cloud_map.get("attributes")
        if raw_attributes is None:
            raise MissingRequiredFieldError("attributes")
        aws_cloud_map = AWSCloudMap(
            attributes=[
                KeyValue(key=attribute.get("key"), value=attribute.get("value"))
                for attribute in raw_attributes
            ],
            ip_preference=raw_cloud_map.get("ipPreference"),
            namespace_name=raw_cloud_map.get("namespaceName"),
            service_name=raw_cloud_map.get("serviceName"),
        )

    dns = None
    if raw_dns is not None:
        dns = DNS(
            hostname=raw_dns.get("hostname"),
            ip_preference=raw_dns.get("ipPreference"),
            response_type=raw_dns.get("responseType"),
        )

    return ServiceDiscovery(aws_cloud_map=aws_cloud_map, dns=dns)