import base64
import json
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Iterable, List, Optional, Tuple

from moto.core.base_backend import BackendDict, BaseBackend
from moto.core.common_models import BaseModel
from moto.core.utils import unix_time
from moto.moto_api._internal import mock_random
from moto.utilities.tagging_service import TaggingService
from moto.utilities.utils import get_partition

from .exceptions import (
    BadRequestException,
    GraphqlAPICacheNotFound,
    GraphqlAPINotFound,
    GraphQLSchemaException,
)

# AWS custom scalars and directives
# https://github.com/dotansimha/graphql-code-generator/discussions/4311#discussioncomment-2921796
#
# NOTE: the text starts directly with a token (no leading newline), so anything
# prepended to it must be separated by whitespace. The "#" line inside the
# literal is a GraphQL comment and has to stay on a line of its own.
AWS_CUSTOM_GRAPHQL = """scalar AWSTime
scalar AWSDateTime
scalar AWSTimestamp
scalar AWSEmail
scalar AWSJSON
scalar AWSURL
scalar AWSPhone
scalar AWSIPAddress
scalar BigInt
scalar Double

directive @aws_subscribe(mutations: [String!]!) on FIELD_DEFINITION

# Allows transformer libraries to deprecate directive arguments.
directive @deprecated(reason: String!) on INPUT_FIELD_DEFINITION | ENUM

directive @aws_auth(cognito_groups: [String!]!) on FIELD_DEFINITION
directive @aws_api_key on FIELD_DEFINITION | OBJECT
directive @aws_iam on FIELD_DEFINITION | OBJECT
directive @aws_oidc on FIELD_DEFINITION | OBJECT
directive @aws_cognito_user_pools(
  cognito_groups: [String!]
) on FIELD_DEFINITION | OBJECT
"""


class GraphqlSchema(BaseModel):
    """A GraphQL schema definition that is parsed as soon as it is created.

    After construction ``status`` is either ``"SUCCESS"`` or ``"FAILED"``;
    in the latter case ``parse_error`` holds the parser's message.
    """

    def __init__(self, definition: Any, region_name: str):
        self.definition = definition
        self.region_name = region_name
        # [graphql.language.ast.ObjectTypeDefinitionNode, ..]
        self.types: List[Any] = []

        self.status = "PROCESSING"
        self.parse_error: Optional[str] = None
        self._parse_graphql_definition()

    def get_type(self, name: str) -> Optional[Dict[str, Any]]:
        """Return a description of the object type called ``name``.

        Only object type definitions collected while parsing are searched.
        Returns ``None`` when no such type exists.
        """
        for graphql_type in self.types:
            if graphql_type.name.value == name:
                return {
                    "name": name,
                    "description": graphql_type.description.value
                    if graphql_type.description
                    else None,
                    "arn": f"arn:{get_partition(self.region_name)}:appsync:graphql_type/{name}",
                    "definition": "NotYetImplemented",
                }
        return None

    def get_status(self) -> Tuple[str, Optional[str]]:
        """Return ``(status, parse_error)`` for this schema."""
        return self.status, self.parse_error

    def _parse_graphql_definition(self) -> None:
        """Parse ``self.definition`` and record its object types and status."""
        # Imported before the try-block on purpose: if the optional graphql
        # package is missing, the ImportError must surface as-is instead of
        # being masked by an unbound `GraphQLError` name in the except clause.
        from graphql import parse
        from graphql.error.graphql_error import GraphQLError
        from graphql.language.ast import ObjectTypeDefinitionNode

        try:
            res = parse(self.definition)
        except GraphQLError as e:
            self.status = "FAILED"
            self.parse_error = str(e)
            return

        for definition in res.definitions:
            if isinstance(definition, ObjectTypeDefinitionNode):
                self.types.append(definition)
        self.status = "SUCCESS"

    def get_introspection_schema(self, format_: str, include_directives: bool) -> str:
        """Return the introspection schema as ``"SDL"`` text or a ``"JSON"`` string.

        The AWS custom scalars and directives are appended to the definition
        before the schema is built. When ``include_directives`` is false the
        directive list of the introspection result is emptied.

        Raises:
            BadRequestException: if ``format_`` is neither ``"SDL"`` nor ``"JSON"``.
        """
        from graphql import (
            build_client_schema,
            build_schema,
            introspection_from_schema,
            print_schema,
        )

        # The newline keeps the last token (or a trailing "#" comment) of the
        # user's definition from fusing with the first AWS scalar declaration
        # when the definition does not end with a line break.
        schema = build_schema(self.definition + "\n" + AWS_CUSTOM_GRAPHQL)
        introspection_data = introspection_from_schema(schema, descriptions=False)

        if not include_directives:
            introspection_data["__schema"]["directives"] = []

        if format_ == "SDL":
            return print_schema(build_client_schema(introspection_data))
        elif format_ == "JSON":
            return json.dumps(introspection_data)
        else:
            raise BadRequestException(message=f"Invalid format {format_} given")


class GraphqlAPIKey(BaseModel):
    """An API key belonging to a GraphQL API."""

    def __init__(self, description: str, expires: Optional[int]):
        # The key id is the first six characters of a random UUID string.
        self.key_id = str(mock_random.uuid4())[0:6]
        self.description = description
        if not expires:
            # Default expiry: the current UTC time truncated to the hour,
            # plus seven days, converted with `unix_time`.
            default_expiry = datetime.now(timezone.utc)
            default_expiry = default_expiry.replace(
                minute=0, second=0, microsecond=0, tzinfo=None
            )
            default_expiry = default_expiry + timedelta(days=7)
            self.expires = unix_time(default_expiry)
        else:
            self.expires = expires

    def update(self, description: Optional[str], expires: Optional[int]) -> None:
        """Overwrite the description and/or expiry; falsy arguments are ignored."""
        if description:
            self.description = description
        if expires:
            self.expires = expires

    def to_json(self) -> Dict[str, Any]:
        """Return the key as a response dictionary."""
        return {
            "id": self.key_id,
            "description": self.description,
            "expires": self.expires,
            # This mock reports the same timestamp for "deletes" as for "expires".
            "deletes": self.expires,
        }


class APICache(BaseModel):
    """Cache configuration of a GraphQL API; its status is always ``"AVAILABLE"``."""

    def __init__(
        self,
        ttl: int,
        api_caching_behavior: str,
        type_: str,
        transit_encryption_enabled: Optional[bool] = None,
        at_rest_encryption_enabled: Optional[bool] = None,
        health_metrics_config: Optional[str] = None,
    ):
        self.ttl = ttl
        self.api_caching_behavior = api_caching_behavior
        self.type = type_
        # Unset/falsy options fall back to False / "DISABLED".
        self.transit_encryption_enabled = transit_encryption_enabled or False
        self.at_rest_encryption_enabled = at_rest_encryption_enabled or False
        self.health_metrics_config = health_metrics_config or "DISABLED"
        self.status = "AVAILABLE"

    def update(
        self,
        ttl: int,
        api_caching_behavior: str,
        type: str,
        health_metrics_config: Optional[str] = None,
    ) -> None:
        """Replace ttl, caching behavior and type unconditionally.

        ``health_metrics_config`` is only replaced when it is not ``None``.
        The encryption settings are not touched by an update.
        """
        self.ttl = ttl
        self.api_caching_behavior = api_caching_behavior
        self.type = type
        if health_metrics_config is not None:
            self.health_metrics_config = health_metrics_config

    def to_json(self) -> Dict[str, Any]:
        """Return the cache configuration as a response dictionary."""
        return {
            "ttl": self.ttl,
            "transitEncryptionEnabled": self.transit_encryption_enabled,
            "atRestEncryptionEnabled": self.at_rest_encryption_enabled,
            "apiCachingBehavior": self.api_caching_behavior,
            "type": self.type,
            "healthMetricsConfig": self.health_metrics_config,
            "status": self.status,
        }


class GraphqlAPI(BaseModel):
    """A GraphQL API together with its schema, API keys and cache."""

    def __init__(
        self,
        account_id: str,
        region: str,
        name: str,
        authentication_type: str,
        additional_authentication_providers: Optional[List[str]],
        log_config: str,
        xray_enabled: str,
        user_pool_config: str,
        open_id_connect_config: str,
        lambda_authorizer_config: str,
        visibility: str,
    ):
        self.region = region
        self.name = name
        self.api_id = str(mock_random.uuid4())
        self.authentication_type = authentication_type
        self.additional_authentication_providers = additional_authentication_providers
        self.lambda_authorizer_config = lambda_authorizer_config
        self.log_config = log_config
        self.open_id_connect_config = open_id_connect_config
        self.user_pool_config = user_pool_config
        self.xray_enabled = xray_enabled
        self.visibility = visibility or "GLOBAL"  # Default to Global if not provided

        self.arn = f"arn:{get_partition(self.region)}:appsync:{self.region}:{account_id}:apis/{self.api_id}"
        self.graphql_schema: Optional[GraphqlSchema] = None

        self.api_keys: Dict[str, GraphqlAPIKey] = dict()

        self.api_cache: Optional[APICache] = None

    def update(
        self,
        name: str,
        additional_authentication_providers: Optional[List[str]],
        authentication_type: str,
        lambda_authorizer_config: str,
        log_config: str,
        open_id_connect_config: str,
        user_pool_config: str,
        xray_enabled: str,
    ) -> None:
        """Overwrite the supplied settings, leaving the others untouched.

        Falsy arguments are ignored, except ``xray_enabled`` which is applied
        whenever it is not ``None`` (so that a false value can be stored).
        """
        if name:
            self.name = name
        if additional_authentication_providers:
            self.additional_authentication_providers = (
                additional_authentication_providers
            )
        if authentication_type:
            self.authentication_type = authentication_type
        if lambda_authorizer_config:
            self.lambda_authorizer_config = lambda_authorizer_config
        if log_config:
            self.log_config = log_config
        if open_id_connect_config:
            self.open_id_connect_config = open_id_connect_config
        if user_pool_config:
            self.user_pool_config = user_pool_config
        if xray_enabled is not None:
            self.xray_enabled = xray_enabled

    def create_api_key(self, description: str, expires: Optional[int]) -> GraphqlAPIKey:
        """Create an API key, store it under its id and return it."""
        api_key = GraphqlAPIKey(description, expires)
        self.api_keys[api_key.key_id] = api_key
        return api_key

    def list_api_keys(self) -> Iterable[GraphqlAPIKey]:
        """Return all API keys of this API."""
        return self.api_keys.values()

    def delete_api_key(self, api_key_id: str) -> None:
        """Remove the key with the given id; raises ``KeyError`` if it is unknown."""
        self.api_keys.pop(api_key_id)

    def update_api_key(
        self, api_key_id: str, description: str, expires: Optional[int]
    ) -> GraphqlAPIKey:
        """Update and return the key with the given id.

        Raises ``KeyError`` if the key id is unknown.
        """
        api_key = self.api_keys[api_key_id]
        api_key.update(description, expires)
        return api_key

    def start_schema_creation(self, definition: str) -> None:
        """Decode the base64-encoded ``definition`` and replace the schema with it."""
        graphql_definition = base64.b64decode(definition).decode("utf-8")

        self.graphql_schema = GraphqlSchema(graphql_definition, region_name=self.region)

    def get_schema_status(self) -> Any:
        """Return ``(status, parse_error)`` of the current schema.

        No check is made that a schema has been created first.
        """
        return self.graphql_schema.get_status()  # type: ignore[union-attr]

    def get_type(self, type_name: str, type_format: str) -> Any:
        """Return the type description with ``format`` set to ``type_format``.

        No check is made that a schema exists or that the type was found.
        """
        graphql_type = self.graphql_schema.get_type(type_name)  # type: ignore[union-attr]
        graphql_type["format"] = type_format  # type: ignore[index]
        return graphql_type

    def create_api_cache(
        self,
        ttl: int,
        api_caching_behavior: str,
        type: str,
        transit_encryption_enabled: Optional[bool] = None,
        at_rest_encryption_enabled: Optional[bool] = None,
        health_metrics_config: Optional[str] = None,
    ) -> APICache:
        """Create the cache for this API, replacing any existing one, and return it."""
        self.api_cache = APICache(
            ttl,
            api_caching_behavior,
            type,
            transit_encryption_enabled,
            at_rest_encryption_enabled,
            health_metrics_config,
        )
        return self.api_cache

    def update_api_cache(
        self,
        ttl: int,
        api_caching_behavior: str,
        type: str,
        health_metrics_config: Optional[str] = None,
    ) -> APICache:
        """Update and return the cache; the caller must ensure that one exists."""
        self.api_cache.update(ttl, api_caching_behavior, type, health_metrics_config)  # type: ignore[union-attr]
        return self.api_cache  # type: ignore[return-value]

    def delete_api_cache(self) -> None:
        """Drop the cache of this API."""
        self.api_cache = None

    def to_json(self) -> Dict[str, Any]:
        """Return the API as a response dictionary."""
        return {
            "name": self.name,
            "apiId": self.api_id,
            "authenticationType": self.authentication_type,
            "arn": self.arn,
            "uris": {"GRAPHQL": "http://graphql.uri"},
            "additionalAuthenticationProviders": self.additional_authentication_providers,
            "lambdaAuthorizerConfig": self.lambda_authorizer_config,
            "logConfig": self.log_config,
            "openIDConnectConfig": self.open_id_connect_config,
            "userPoolConfig": self.user_pool_config,
            "xrayEnabled": self.xray_enabled,
            "visibility": self.visibility,
        }


class AppSyncBackend(BaseBackend):
    """Implementation of AppSync APIs."""

    def __init__(self, region_name: str, account_id: str):
        super().__init__(region_name, account_id)
        self.graphql_apis: Dict[str, GraphqlAPI] = dict()
        self.tagger = TaggingService()

    def create_graphql_api(
        self,
        name: str,
        log_config: str,
        authentication_type: str,
        user_pool_config: str,
        open_id_connect_config: str,
        additional_authentication_providers: Optional[List[str]],
        xray_enabled: str,
        lambda_authorizer_config: str,
        tags: Dict[str, str],
        visibility: str,
    ) -> GraphqlAPI:
        """Create a GraphQL API, register it under its id and tag its ARN."""
        graphql_api = GraphqlAPI(
            account_id=self.account_id,
            region=self.region_name,
            name=name,
            authentication_type=authentication_type,
            additional_authentication_providers=additional_authentication_providers,
            log_config=log_config,
            xray_enabled=xray_enabled,
            user_pool_config=user_pool_config,
            open_id_connect_config=open_id_connect_config,
            lambda_authorizer_config=lambda_authorizer_config,
            visibility=visibility,
        )
        self.graphql_apis[graphql_api.api_id] = graphql_api
        self.tagger.tag_resource(
            graphql_api.arn, TaggingService.convert_dict_to_tags_input(tags)
        )
        return graphql_api

    def update_graphql_api(
        self,
        api_id: str,
        name: str,
        log_config: str,
        authentication_type: str,
        user_pool_config: str,
        open_id_connect_config: str,
        additional_authentication_providers: Optional[List[str]],
        xray_enabled: str,
        lambda_authorizer_config: str,
    ) -> GraphqlAPI:
        """Update and return the API with the given id.

        Raises:
            GraphqlAPINotFound: if no API with ``api_id`` exists.
        """
        graphql_api = self.get_graphql_api(api_id)
        graphql_api.update(
            name,
            additional_authentication_providers,
            authentication_type,
            lambda_authorizer_config,
            log_config,
            open_id_connect_config,
            user_pool_config,
            xray_enabled,
        )
        return graphql_api

    def get_graphql_api(self, api_id: str) -> GraphqlAPI:
        """Return the API with the given id.

        Raises:
            GraphqlAPINotFound: if no API with ``api_id`` exists.
        """
        if api_id not in self.graphql_apis:
            raise GraphqlAPINotFound(api_id)
        return self.graphql_apis[api_id]

    def get_graphql_schema(self, api_id: str) -> GraphqlSchema:
        """Return the schema of the given API.

        Raises:
            GraphqlAPINotFound: if no API with ``api_id`` exists.
            GraphQLSchemaException: if the API has no schema yet.
        """
        graphql_api = self.get_graphql_api(api_id)
        if not graphql_api.graphql_schema:
            # When calling get_introspection_schema without a graphql schema
            # the response GraphQLSchemaException exception includes InvalidSyntaxError
            # in the message. This might not be the case for other methods.
            raise GraphQLSchemaException(message="InvalidSyntaxError")
        return graphql_api.graphql_schema

    def delete_graphql_api(self, api_id: str) -> None:
        """Remove the API with the given id.

        Raises:
            GraphqlAPINotFound: if no API with ``api_id`` exists.
        """
        # Validate first so an unknown id yields the service error, not KeyError.
        self.get_graphql_api(api_id)
        self.graphql_apis.pop(api_id)

    def list_graphql_apis(self) -> Iterable[GraphqlAPI]:
        """
        Pagination or the maxResults-parameter have not yet been implemented.
        """
        return self.graphql_apis.values()

    def create_api_key(
        self, api_id: str, description: str, expires: Optional[int]
    ) -> GraphqlAPIKey:
        """Create an API key on the given API.

        Raises:
            GraphqlAPINotFound: if no API with ``api_id`` exists.
        """
        return self.get_graphql_api(api_id).create_api_key(description, expires)

    def delete_api_key(self, api_id: str, api_key_id: str) -> None:
        """Delete an API key from the given API.

        Raises:
            GraphqlAPINotFound: if no API with ``api_id`` exists.
        """
        self.get_graphql_api(api_id).delete_api_key(api_key_id)

    def list_api_keys(self, api_id: str) -> Iterable[GraphqlAPIKey]:
        """
        Pagination or the maxResults-parameter have not yet been implemented.
        """
        # An unknown API deliberately yields an empty result instead of an error.
        if api_id in self.graphql_apis:
            return self.graphql_apis[api_id].list_api_keys()
        else:
            return []

    def update_api_key(
        self,
        api_id: str,
        api_key_id: str,
        description: str,
        expires: Optional[int],
    ) -> GraphqlAPIKey:
        """Update an API key of the given API.

        Raises:
            GraphqlAPINotFound: if no API with ``api_id`` exists.
        """
        return self.get_graphql_api(api_id).update_api_key(
            api_key_id, description, expires
        )

    def start_schema_creation(self, api_id: str, definition: str) -> str:
        """Attach a new schema (base64-encoded ``definition``) to the given API.

        Always returns ``"PROCESSING"``; the parse outcome is available from
        ``get_schema_creation_status``.

        Raises:
            GraphqlAPINotFound: if no API with ``api_id`` exists.
        """
        self.get_graphql_api(api_id).start_schema_creation(definition)
        return "PROCESSING"

    def get_schema_creation_status(self, api_id: str) -> Any:
        """Return ``(status, parse_error)`` of the given API's schema.

        Raises:
            GraphqlAPINotFound: if no API with ``api_id`` exists.
        """
        return self.get_graphql_api(api_id).get_schema_status()

    def tag_resource(self, resource_arn: str, tags: Dict[str, str]) -> None:
        """Add the given tags to the resource."""
        self.tagger.tag_resource(
            resource_arn, TaggingService.convert_dict_to_tags_input(tags)
        )

    def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None:
        """Remove the tags with the given keys from the resource."""
        self.tagger.untag_resource_using_names(resource_arn, tag_keys)

    def list_tags_for_resource(self, resource_arn: str) -> Dict[str, str]:
        """Return the tags of the resource as a dictionary."""
        return self.tagger.get_tag_dict_for_resource(resource_arn)

    def get_type(self, api_id: str, type_name: str, type_format: str) -> Any:
        """Return the description of a type in the given API's schema.

        Raises:
            GraphqlAPINotFound: if no API with ``api_id`` exists.
        """
        return self.get_graphql_api(api_id).get_type(type_name, type_format)

    def get_api_cache(self, api_id: str) -> APICache:
        """Return the cache of the given API.

        Raises:
            GraphqlAPINotFound: if no API with ``api_id`` exists.
            GraphqlAPICacheNotFound: if the API has no cache.
        """
        api_cache = self.get_graphql_api(api_id).api_cache
        if api_cache is None:
            raise GraphqlAPICacheNotFound("get")
        return api_cache

    def delete_api_cache(self, api_id: str) -> None:
        """Delete the cache of the given API.

        Raises:
            GraphqlAPINotFound: if no API with ``api_id`` exists.
            GraphqlAPICacheNotFound: if the API has no cache.
        """
        graphql_api = self.get_graphql_api(api_id)
        if graphql_api.api_cache is None:
            raise GraphqlAPICacheNotFound("delete")
        graphql_api.delete_api_cache()

    def create_api_cache(
        self,
        api_id: str,
        ttl: int,
        api_caching_behavior: str,
        type: str,
        transit_encryption_enabled: Optional[bool] = None,
        at_rest_encryption_enabled: Optional[bool] = None,
        health_metrics_config: Optional[str] = None,
    ) -> APICache:
        """Create and return the cache of the given API.

        Raises:
            GraphqlAPINotFound: if no API with ``api_id`` exists.
            BadRequestException: if the API already has a cache.
        """
        graphql_api = self.get_graphql_api(api_id)
        if graphql_api.api_cache is not None:
            raise BadRequestException(message="The API has already enabled caching.")
        api_cache = graphql_api.create_api_cache(
            ttl,
            api_caching_behavior,
            type,
            transit_encryption_enabled,
            at_rest_encryption_enabled,
            health_metrics_config,
        )
        return api_cache

    def update_api_cache(
        self,
        api_id: str,
        ttl: int,
        api_caching_behavior: str,
        type: str,
        health_metrics_config: Optional[str] = None,
    ) -> APICache:
        """Update and return the cache of the given API.

        Raises:
            GraphqlAPINotFound: if no API with ``api_id`` exists.
            GraphqlAPICacheNotFound: if the API has no cache.
        """
        graphql_api = self.get_graphql_api(api_id)
        if graphql_api.api_cache is None:
            raise GraphqlAPICacheNotFound("update")
        api_cache = graphql_api.update_api_cache(
            ttl, api_caching_behavior, type, health_metrics_config
        )
        return api_cache

    def flush_api_cache(self, api_id: str) -> None:
        """Validate that the given API has a cache; no state is changed.

        Raises:
            GraphqlAPINotFound: if no API with ``api_id`` exists.
            GraphqlAPICacheNotFound: if the API has no cache.
        """
        if self.get_graphql_api(api_id).api_cache is None:
            raise GraphqlAPICacheNotFound("flush")


appsync_backends = BackendDict(AppSyncBackend, "appsync")
Memory