import copy import datetime from collections import OrderedDict from typing import Any, Dict, Iterable, List, Optional from dateutil.tz import tzutc from moto.core.base_backend import BackendDict, BaseBackend from moto.core.common_models import BaseModel, CloudFormationModel from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.ec2 import ec2_backends from moto.ec2.models.security_groups import SecurityGroup as EC2SecurityGroup from moto.moto_api._internal import mock_random from moto.utilities.utils import get_partition from .exceptions import ( ClusterAlreadyExistsFaultError, ClusterNotFoundError, ClusterParameterGroupNotFoundError, ClusterSecurityGroupNotFoundError, ClusterSecurityGroupNotFoundFaultError, ClusterSnapshotAlreadyExistsError, ClusterSnapshotNotFoundError, ClusterSubnetGroupNotFoundError, InvalidClusterSnapshotStateFaultError, InvalidParameterCombinationError, InvalidParameterValueError, InvalidSubnetError, ResourceNotFoundFaultError, SnapshotCopyAlreadyDisabledFaultError, SnapshotCopyAlreadyEnabledFaultError, SnapshotCopyDisabledFaultError, SnapshotCopyGrantAlreadyExistsFaultError, SnapshotCopyGrantNotFoundFaultError, UnknownSnapshotCopyRegionFaultError, ) class TaggableResourceMixin: resource_type = "" def __init__( self, account_id: str, region_name: str, tags: Optional[List[Dict[str, Any]]] ): self.account_id = account_id self.region = region_name self.tags = tags or [] @property def resource_id(self) -> str: return "" @property def arn(self) -> str: return f"arn:{get_partition(self.region)}:redshift:{self.region}:{self.account_id}:{self.resource_type}:{self.resource_id}" def create_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]: new_keys = [tag_set["Key"] for tag_set in tags] self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags def delete_tags(self, tag_keys: List[str]) -> List[Dict[str, str]]: self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] return self.tags class Cluster(TaggableResourceMixin, CloudFormationModel): resource_type = "cluster" def __init__( self, redshift_backend: "RedshiftBackend", cluster_identifier: str, node_type: str, master_username: str, master_user_password: str, db_name: str, cluster_type: str, cluster_security_groups: List[str], vpc_security_group_ids: List[str], cluster_subnet_group_name: str, availability_zone: str, preferred_maintenance_window: str, cluster_parameter_group_name: str, automated_snapshot_retention_period: str, port: str, cluster_version: str, allow_version_upgrade: str, number_of_nodes: str, publicly_accessible: str, encrypted: str, region_name: str, tags: Optional[List[Dict[str, str]]] = None, iam_roles_arn: Optional[List[str]] = None, enhanced_vpc_routing: Optional[str] = None, restored_from_snapshot: bool = False, kms_key_id: Optional[str] = None, ): super().__init__(redshift_backend.account_id, region_name, tags) self.redshift_backend = redshift_backend self.cluster_identifier = cluster_identifier self.create_time = iso_8601_datetime_with_milliseconds() self.status = "available" self.node_type = node_type self.master_username = master_username self.master_user_password = master_user_password self.db_name = db_name if db_name else "dev" self.vpc_security_group_ids = vpc_security_group_ids self.enhanced_vpc_routing = ( enhanced_vpc_routing if enhanced_vpc_routing is not None else False ) self.cluster_subnet_group_name = cluster_subnet_group_name self.publicly_accessible = publicly_accessible self.encrypted = encrypted self.allow_version_upgrade = ( allow_version_upgrade if allow_version_upgrade is not None else True ) self.cluster_version = cluster_version if cluster_version else "1.0" self.port = int(port) if port else 5439 self.automated_snapshot_retention_period = ( int(automated_snapshot_retention_period) if automated_snapshot_retention_period else 1 ) self.preferred_maintenance_window = ( preferred_maintenance_window if preferred_maintenance_window else "Mon:03:00-Mon:03:30" ) if cluster_parameter_group_name: self.cluster_parameter_group_name = [cluster_parameter_group_name] else: self.cluster_parameter_group_name = ["default.redshift-1.0"] if cluster_security_groups: self.cluster_security_groups = cluster_security_groups else: self.cluster_security_groups = ["Default"] if availability_zone: self.availability_zone = availability_zone else: # This could probably be smarter, but there doesn't appear to be a # way to pull AZs for a region in boto self.availability_zone = region_name + "a" if cluster_type == "single-node": self.number_of_nodes = 1 elif number_of_nodes: self.number_of_nodes = int(number_of_nodes) else: self.number_of_nodes = 1 self.iam_roles_arn = iam_roles_arn or [] self.restored_from_snapshot = restored_from_snapshot self.kms_key_id = kms_key_id self.cluster_snapshot_copy_status: Optional[Dict[str, Any]] = None self.total_storage_capacity = 0 self.logging_details = { "LoggingEnabled": "false", # Lower case is required in response so we use string to simplify "BucketName": "", "S3KeyPrefix": "", "LastSuccessfulDeliveryTime": datetime.datetime.now(), "LastFailureTime": datetime.datetime.now(), "LastFailureMessage": "", "LogDestinationType": "", "LogExports": [], } @staticmethod def cloudformation_name_type() -> str: return "" @staticmethod def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-redshift-cluster.html return "AWS::Redshift::Cluster" @classmethod def create_from_cloudformation_json( # type: ignore[misc] cls, resource_name: str, cloudformation_json: Any, account_id: str, region_name: str, **kwargs: Any, ) -> "Cluster": redshift_backend = redshift_backends[account_id][region_name] properties = cloudformation_json["Properties"] if "ClusterSubnetGroupName" in properties: subnet_group_name = properties[ "ClusterSubnetGroupName" ].cluster_subnet_group_name else: subnet_group_name = None cluster = redshift_backend.create_cluster( cluster_identifier=resource_name, node_type=properties.get("NodeType"), master_username=properties.get("MasterUsername"), master_user_password=properties.get("MasterUserPassword"), db_name=properties.get("DBName"), cluster_type=properties.get("ClusterType"), cluster_security_groups=properties.get("ClusterSecurityGroups", []), vpc_security_group_ids=properties.get("VpcSecurityGroupIds", []), cluster_subnet_group_name=subnet_group_name, availability_zone=properties.get("AvailabilityZone"), preferred_maintenance_window=properties.get("PreferredMaintenanceWindow"), cluster_parameter_group_name=properties.get("ClusterParameterGroupName"), automated_snapshot_retention_period=properties.get( "AutomatedSnapshotRetentionPeriod" ), port=properties.get("Port"), cluster_version=properties.get("ClusterVersion"), allow_version_upgrade=properties.get("AllowVersionUpgrade"), enhanced_vpc_routing=properties.get("EnhancedVpcRouting"), number_of_nodes=properties.get("NumberOfNodes"), publicly_accessible=properties.get("PubliclyAccessible"), encrypted=properties.get("Encrypted"), region_name=region_name, kms_key_id=properties.get("KmsKeyId"), ) return cluster @classmethod def has_cfn_attr(cls, attr: str) -> bool: return attr in ["Endpoint.Address", "Endpoint.Port"] def get_cfn_attribute(self, attribute_name: str) -> Any: from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "Endpoint.Address": return self.endpoint if attribute_name == "Endpoint.Port": return self.port raise UnformattedGetAttTemplateException() @property def endpoint(self) -> str: return f"{self.cluster_identifier}.cg034hpkmmjt.{self.region}.redshift.amazonaws.com" @property def security_groups(self) -> List["SecurityGroup"]: return [ security_group for security_group in self.redshift_backend.describe_cluster_security_groups() if security_group.cluster_security_group_name in self.cluster_security_groups ] @property def vpc_security_groups(self) -> List["EC2SecurityGroup"]: return [ security_group for security_group in self.redshift_backend.ec2_backend.describe_security_groups() if security_group.id in self.vpc_security_group_ids ] @property def parameter_groups(self) -> List["ParameterGroup"]: return [ parameter_group for parameter_group in self.redshift_backend.describe_cluster_parameter_groups() if parameter_group.cluster_parameter_group_name in self.cluster_parameter_group_name ] @property def resource_id(self) -> str: return self.cluster_identifier def pause(self) -> None: self.status = "paused" def resume(self) -> None: self.status = "available" def to_json(self) -> Dict[str, Any]: json_response = { "MasterUsername": self.master_username, "MasterUserPassword": "****", "ClusterVersion": self.cluster_version, "VpcSecurityGroups": [ {"Status": "active", "VpcSecurityGroupId": group.id} for group in self.vpc_security_groups ], "ClusterSubnetGroupName": self.cluster_subnet_group_name, "AvailabilityZone": self.availability_zone, "ClusterStatus": self.status, "NumberOfNodes": self.number_of_nodes, "AutomatedSnapshotRetentionPeriod": self.automated_snapshot_retention_period, "PubliclyAccessible": self.publicly_accessible, "Encrypted": self.encrypted, "DBName": self.db_name, "PreferredMaintenanceWindow": self.preferred_maintenance_window, "ClusterParameterGroups": [ { "ParameterApplyStatus": "in-sync", "ParameterGroupName": group.cluster_parameter_group_name, } for group in self.parameter_groups ], "ClusterSecurityGroups": [ { "Status": "active", "ClusterSecurityGroupName": group.cluster_security_group_name, } for group in self.security_groups ], "Port": self.port, "NodeType": self.node_type, "ClusterIdentifier": self.cluster_identifier, "AllowVersionUpgrade": self.allow_version_upgrade, "Endpoint": {"Address": self.endpoint, "Port": self.port}, "ClusterCreateTime": self.create_time, "PendingModifiedValues": [], "Tags": self.tags, "EnhancedVpcRouting": self.enhanced_vpc_routing, "IamRoles": [ {"ApplyStatus": "in-sync", "IamRoleArn": iam_role_arn} for iam_role_arn in self.iam_roles_arn ], "KmsKeyId": self.kms_key_id, "TotalStorageCapacityInMegaBytes": self.total_storage_capacity, } if self.restored_from_snapshot: json_response["RestoreStatus"] = { "Status": "completed", "CurrentRestoreRateInMegaBytesPerSecond": 123.0, "SnapshotSizeInMegaBytes": 123, "ProgressInMegaBytes": 123, "ElapsedTimeInSeconds": 123, "EstimatedTimeToCompletionInSeconds": 123, } if self.cluster_snapshot_copy_status is not None: json_response["ClusterSnapshotCopyStatus"] = ( self.cluster_snapshot_copy_status ) return json_response class SnapshotCopyGrant(TaggableResourceMixin, BaseModel): resource_type = "snapshotcopygrant" def __init__(self, snapshot_copy_grant_name: str, kms_key_id: str): self.snapshot_copy_grant_name = snapshot_copy_grant_name self.kms_key_id = kms_key_id def to_json(self) -> Dict[str, Any]: return { "SnapshotCopyGrantName": self.snapshot_copy_grant_name, "KmsKeyId": self.kms_key_id, } class SubnetGroup(TaggableResourceMixin, CloudFormationModel): resource_type = "subnetgroup" def __init__( self, ec2_backend: Any, cluster_subnet_group_name: str, description: str, subnet_ids: List[str], region_name: str, tags: Optional[List[Dict[str, str]]] = None, ): super().__init__(ec2_backend.account_id, region_name, tags) self.ec2_backend = ec2_backend self.cluster_subnet_group_name = cluster_subnet_group_name self.description = description self.subnet_ids = subnet_ids if not self.subnets: raise InvalidSubnetError(subnet_ids) @staticmethod def cloudformation_name_type() -> str: return "" @staticmethod def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-redshift-clustersubnetgroup.html return "AWS::Redshift::ClusterSubnetGroup" @classmethod def create_from_cloudformation_json( # type: ignore[misc] cls, resource_name: str, cloudformation_json: Any, account_id: str, region_name: str, **kwargs: Any, ) -> "SubnetGroup": redshift_backend = redshift_backends[account_id][region_name] properties = cloudformation_json["Properties"] subnet_group = redshift_backend.create_cluster_subnet_group( cluster_subnet_group_name=resource_name, description=properties.get("Description"), subnet_ids=properties.get("SubnetIds", []), region_name=region_name, ) return subnet_group @property def subnets(self) -> Any: # type: ignore[misc] return self.ec2_backend.describe_subnets(filters={"subnet-id": self.subnet_ids}) @property def vpc_id(self) -> str: return self.subnets[0].vpc_id @property def resource_id(self) -> str: return self.cluster_subnet_group_name def to_json(self) -> Dict[str, Any]: return { "VpcId": self.vpc_id, "Description": self.description, "ClusterSubnetGroupName": self.cluster_subnet_group_name, "SubnetGroupStatus": "Complete", "Subnets": [ { "SubnetStatus": "Active", "SubnetIdentifier": subnet.id, "SubnetAvailabilityZone": {"Name": subnet.availability_zone}, } for subnet in self.subnets ], "Tags": self.tags, } class SecurityGroup(TaggableResourceMixin, BaseModel): resource_type = "securitygroup" def __init__( self, cluster_security_group_name: str, description: str, account_id: str, region_name: str, tags: Optional[List[Dict[str, str]]] = None, ): super().__init__(account_id, region_name, tags) self.cluster_security_group_name = cluster_security_group_name self.description = description self.ingress_rules: List[str] = [] @property def resource_id(self) -> str: return self.cluster_security_group_name def to_json(self) -> Dict[str, Any]: return { "EC2SecurityGroups": [], "IPRanges": [], "Description": self.description, "ClusterSecurityGroupName": self.cluster_security_group_name, "Tags": self.tags, } class ParameterGroup(TaggableResourceMixin, CloudFormationModel): resource_type = "parametergroup" def __init__( self, cluster_parameter_group_name: str, group_family: str, description: str, account_id: str, region_name: str, tags: Optional[List[Dict[str, str]]] = None, ): super().__init__(account_id, region_name, tags) self.cluster_parameter_group_name = cluster_parameter_group_name self.group_family = group_family self.description = description @staticmethod def cloudformation_name_type() -> str: return "" @staticmethod def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-redshift-clusterparametergroup.html return "AWS::Redshift::ClusterParameterGroup" @classmethod def create_from_cloudformation_json( # type: ignore[misc] cls, resource_name: str, cloudformation_json: Any, account_id: str, region_name: str, **kwargs: Any, ) -> "ParameterGroup": redshift_backend = redshift_backends[account_id][region_name] properties = cloudformation_json["Properties"] parameter_group = redshift_backend.create_cluster_parameter_group( cluster_parameter_group_name=resource_name, description=properties.get("Description"), group_family=properties.get("ParameterGroupFamily"), region_name=region_name, ) return parameter_group @property def resource_id(self) -> str: return self.cluster_parameter_group_name def to_json(self) -> Dict[str, Any]: return { "ParameterGroupFamily": self.group_family, "Description": self.description, "ParameterGroupName": self.cluster_parameter_group_name, "Tags": self.tags, } class Snapshot(TaggableResourceMixin, BaseModel): resource_type = "snapshot" def __init__( self, cluster: Any, snapshot_identifier: str, account_id: str, region_name: str, tags: Optional[List[Dict[str, str]]] = None, iam_roles_arn: Optional[List[str]] = None, snapshot_type: str = "manual", ): super().__init__(account_id, region_name, tags) self.cluster = copy.copy(cluster) self.snapshot_identifier = snapshot_identifier self.snapshot_type = snapshot_type self.status = "available" self.create_time = iso_8601_datetime_with_milliseconds() self.iam_roles_arn = iam_roles_arn or [] @property def resource_id(self) -> str: return f"{self.cluster.cluster_identifier}/{self.snapshot_identifier}" def to_json(self) -> Dict[str, Any]: return { "SnapshotIdentifier": self.snapshot_identifier, "ClusterIdentifier": self.cluster.cluster_identifier, "SnapshotCreateTime": self.create_time, "Status": self.status, "Port": self.cluster.port, "AvailabilityZone": self.cluster.availability_zone, "MasterUsername": self.cluster.master_username, "ClusterVersion": self.cluster.cluster_version, "SnapshotType": self.snapshot_type, "NodeType": self.cluster.node_type, "NumberOfNodes": self.cluster.number_of_nodes, "DBName": self.cluster.db_name, "Tags": self.tags, "EnhancedVpcRouting": self.cluster.enhanced_vpc_routing, "IamRoles": [ {"ApplyStatus": "in-sync", "IamRoleArn": iam_role_arn} for iam_role_arn in self.iam_roles_arn ], } class RedshiftBackend(BaseBackend): def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) self.clusters: Dict[str, Cluster] = {} self.subnet_groups: Dict[str, SubnetGroup] = {} self.security_groups: Dict[str, SecurityGroup] = { "Default": SecurityGroup( "Default", "Default Redshift Security Group", account_id, region_name ) } self.parameter_groups: Dict[str, ParameterGroup] = { "default.redshift-1.0": ParameterGroup( "default.redshift-1.0", "redshift-1.0", "Default Redshift parameter group", self.account_id, self.region_name, ) } self.ec2_backend = ec2_backends[self.account_id][self.region_name] self.snapshots: Dict[str, Snapshot] = OrderedDict() self.RESOURCE_TYPE_MAP: Dict[str, Dict[str, TaggableResourceMixin]] = { "cluster": self.clusters, # type: ignore "parametergroup": self.parameter_groups, # type: ignore "securitygroup": self.security_groups, # type: ignore "snapshot": self.snapshots, # type: ignore "subnetgroup": self.subnet_groups, # type: ignore } self.snapshot_copy_grants: Dict[str, SnapshotCopyGrant] = {} def enable_snapshot_copy(self, **kwargs: Any) -> Cluster: cluster_identifier = kwargs["cluster_identifier"] cluster = self.clusters[cluster_identifier] if cluster.cluster_snapshot_copy_status is None: if ( cluster.encrypted == "true" and kwargs["snapshot_copy_grant_name"] is None ): raise InvalidParameterValueError( "SnapshotCopyGrantName is required for Snapshot Copy on KMS encrypted clusters." ) if kwargs["destination_region"] == self.region_name: raise UnknownSnapshotCopyRegionFaultError( f"Invalid region {self.region_name}" ) status = { "DestinationRegion": kwargs["destination_region"], "RetentionPeriod": kwargs["retention_period"], "SnapshotCopyGrantName": kwargs["snapshot_copy_grant_name"], } cluster.cluster_snapshot_copy_status = status return cluster raise SnapshotCopyAlreadyEnabledFaultError(cluster_identifier) def disable_snapshot_copy(self, **kwargs: Any) -> Cluster: cluster_identifier = kwargs["cluster_identifier"] cluster = self.clusters[cluster_identifier] if cluster.cluster_snapshot_copy_status is not None: cluster.cluster_snapshot_copy_status = None return cluster raise SnapshotCopyAlreadyDisabledFaultError(cluster_identifier) def modify_snapshot_copy_retention_period( self, cluster_identifier: str, retention_period: str ) -> Cluster: cluster = self.clusters[cluster_identifier] if cluster.cluster_snapshot_copy_status is not None: cluster.cluster_snapshot_copy_status["RetentionPeriod"] = retention_period return cluster else: raise SnapshotCopyDisabledFaultError(cluster_identifier) def create_cluster(self, **cluster_kwargs: Any) -> Cluster: cluster_identifier = cluster_kwargs["cluster_identifier"] if cluster_identifier in self.clusters: raise ClusterAlreadyExistsFaultError() cluster = Cluster(self, **cluster_kwargs) self.clusters[cluster_identifier] = cluster snapshot_id = ( f"rs:{cluster_identifier}-" f"{datetime.datetime.now(tzutc()).strftime('%Y-%m-%d-%H-%M')}" ) # Automated snapshots don't copy over the tags self.create_cluster_snapshot( cluster_identifier, snapshot_id, cluster.region, None, snapshot_type="automated", ) return cluster def pause_cluster(self, cluster_id: str) -> Cluster: if cluster_id not in self.clusters: raise ClusterNotFoundError(cluster_identifier=cluster_id) self.clusters[cluster_id].pause() return self.clusters[cluster_id] def resume_cluster(self, cluster_id: str) -> Cluster: if cluster_id not in self.clusters: raise ClusterNotFoundError(cluster_identifier=cluster_id) self.clusters[cluster_id].resume() return self.clusters[cluster_id] def describe_clusters( self, cluster_identifier: Optional[str] = None ) -> List[Cluster]: if cluster_identifier: if cluster_identifier in self.clusters: return [self.clusters[cluster_identifier]] raise ClusterNotFoundError(cluster_identifier) return list(self.clusters.values()) def modify_cluster(self, **cluster_kwargs: Any) -> Cluster: cluster_identifier = cluster_kwargs.pop("cluster_identifier") new_cluster_identifier = cluster_kwargs.pop("new_cluster_identifier", None) cluster_type = cluster_kwargs.get("cluster_type") if cluster_type and cluster_type not in ["multi-node", "single-node"]: raise InvalidParameterValueError( "Invalid cluster type. Cluster type can be one of multi-node or single-node" ) if cluster_type == "single-node": # AWS will always silently override this value for single-node clusters. cluster_kwargs["number_of_nodes"] = 1 elif cluster_type == "multi-node": if cluster_kwargs.get("number_of_nodes", 0) < 2: raise InvalidParameterCombinationError( "Number of nodes for cluster type multi-node must be greater than or equal to 2" ) cluster = self.describe_clusters(cluster_identifier)[0] for key, value in cluster_kwargs.items(): setattr(cluster, key, value) if new_cluster_identifier: dic = { "cluster_identifier": cluster_identifier, "skip_final_snapshot": True, "final_cluster_snapshot_identifier": None, } self.delete_cluster(**dic) cluster.cluster_identifier = new_cluster_identifier self.clusters[new_cluster_identifier] = cluster return cluster def delete_automated_snapshots(self, cluster_identifier: str) -> None: snapshots = self.describe_cluster_snapshots( cluster_identifier=cluster_identifier ) for snapshot in snapshots: if snapshot.snapshot_type == "automated": self.snapshots.pop(snapshot.snapshot_identifier) def delete_cluster(self, **cluster_kwargs: Any) -> Cluster: cluster_identifier = cluster_kwargs.pop("cluster_identifier") cluster_skip_final_snapshot = cluster_kwargs.pop("skip_final_snapshot") cluster_snapshot_identifer = cluster_kwargs.pop( "final_cluster_snapshot_identifier" ) if cluster_identifier in self.clusters: if ( cluster_skip_final_snapshot is False and cluster_snapshot_identifer is None ): raise InvalidParameterCombinationError( "FinalClusterSnapshotIdentifier is required unless " "SkipFinalClusterSnapshot is specified." ) if ( cluster_skip_final_snapshot is False and cluster_snapshot_identifer is not None ): # create snapshot cluster = self.describe_clusters(cluster_identifier)[0] self.create_cluster_snapshot( cluster_identifier, cluster_snapshot_identifer, cluster.region, cluster.tags, ) self.delete_automated_snapshots(cluster_identifier) return self.clusters.pop(cluster_identifier) raise ClusterNotFoundError(cluster_identifier) def create_cluster_subnet_group( self, cluster_subnet_group_name: str, description: str, subnet_ids: List[str], region_name: str, tags: Optional[List[Dict[str, str]]] = None, ) -> SubnetGroup: subnet_group = SubnetGroup( self.ec2_backend, cluster_subnet_group_name, description, subnet_ids, region_name, tags, ) self.subnet_groups[cluster_subnet_group_name] = subnet_group return subnet_group def describe_cluster_subnet_groups( self, subnet_identifier: Optional[str] = None ) -> List[SubnetGroup]: if subnet_identifier: if subnet_identifier in self.subnet_groups: return [self.subnet_groups[subnet_identifier]] raise ClusterSubnetGroupNotFoundError(subnet_identifier) return list(self.subnet_groups.values()) def delete_cluster_subnet_group(self, subnet_identifier: str) -> SubnetGroup: if subnet_identifier in self.subnet_groups: return self.subnet_groups.pop(subnet_identifier) raise ClusterSubnetGroupNotFoundError(subnet_identifier) def create_cluster_security_group( self, cluster_security_group_name: str, description: str, tags: Optional[List[Dict[str, str]]] = None, ) -> SecurityGroup: security_group = SecurityGroup( cluster_security_group_name, description, self.account_id, self.region_name, tags, ) self.security_groups[cluster_security_group_name] = security_group return security_group def describe_cluster_security_groups( self, security_group_name: Optional[str] = None ) -> List[SecurityGroup]: if security_group_name: if security_group_name in self.security_groups: return [self.security_groups[security_group_name]] raise ClusterSecurityGroupNotFoundError(security_group_name) return list(self.security_groups.values()) def delete_cluster_security_group( self, security_group_identifier: str ) -> SecurityGroup: if security_group_identifier in self.security_groups: return self.security_groups.pop(security_group_identifier) raise ClusterSecurityGroupNotFoundError(security_group_identifier) def authorize_cluster_security_group_ingress( self, security_group_name: str, cidr_ip: str ) -> SecurityGroup: security_group = self.security_groups.get(security_group_name) if not security_group: raise ClusterSecurityGroupNotFoundFaultError() # just adding the cidr_ip as ingress rule for now as there is no security rule security_group.ingress_rules.append(cidr_ip) return security_group def create_cluster_parameter_group( self, cluster_parameter_group_name: str, group_family: str, description: str, region_name: str, tags: Optional[List[Dict[str, str]]] = None, ) -> ParameterGroup: parameter_group = ParameterGroup( cluster_parameter_group_name, group_family, description, self.account_id, region_name, tags, ) self.parameter_groups[cluster_parameter_group_name] = parameter_group return parameter_group def describe_cluster_parameter_groups( self, parameter_group_name: Optional[str] = None ) -> List[ParameterGroup]: if parameter_group_name: if parameter_group_name in self.parameter_groups: return [self.parameter_groups[parameter_group_name]] raise ClusterParameterGroupNotFoundError(parameter_group_name) return list(self.parameter_groups.values()) def delete_cluster_parameter_group( self, parameter_group_name: str ) -> ParameterGroup: if parameter_group_name in self.parameter_groups: return self.parameter_groups.pop(parameter_group_name) raise ClusterParameterGroupNotFoundError(parameter_group_name) def create_cluster_snapshot( self, cluster_identifier: str, snapshot_identifier: str, region_name: str, tags: Optional[List[Dict[str, str]]], snapshot_type: str = "manual", ) -> Snapshot: cluster = self.clusters.get(cluster_identifier) if not cluster: raise ClusterNotFoundError(cluster_identifier) if self.snapshots.get(snapshot_identifier) is not None: raise ClusterSnapshotAlreadyExistsError(snapshot_identifier) snapshot = Snapshot( cluster, snapshot_identifier, self.account_id, region_name, tags, snapshot_type=snapshot_type, ) self.snapshots[snapshot_identifier] = snapshot return snapshot def describe_cluster_snapshots( self, cluster_identifier: Optional[str] = None, snapshot_identifier: Optional[str] = None, snapshot_type: Optional[str] = None, ) -> List[Snapshot]: snapshot_types = ( ["automated", "manual"] if snapshot_type is None else [snapshot_type] ) if cluster_identifier: cluster_snapshots = [] for snapshot in self.snapshots.values(): if snapshot.cluster.cluster_identifier == cluster_identifier: if snapshot.snapshot_type in snapshot_types: cluster_snapshots.append(snapshot) if cluster_snapshots: return cluster_snapshots if snapshot_identifier: if snapshot_identifier in self.snapshots: if self.snapshots[snapshot_identifier].snapshot_type in snapshot_types: return [self.snapshots[snapshot_identifier]] raise ClusterSnapshotNotFoundError(snapshot_identifier) return list(self.snapshots.values()) def delete_cluster_snapshot(self, snapshot_identifier: str) -> Snapshot: if snapshot_identifier not in self.snapshots: raise ClusterSnapshotNotFoundError(snapshot_identifier) snapshot = self.describe_cluster_snapshots( snapshot_identifier=snapshot_identifier )[0] if snapshot.snapshot_type == "automated": raise InvalidClusterSnapshotStateFaultError(snapshot_identifier) deleted_snapshot = self.snapshots.pop(snapshot_identifier) deleted_snapshot.status = "deleted" return deleted_snapshot def restore_from_cluster_snapshot(self, **kwargs: Any) -> Cluster: snapshot_identifier = kwargs.pop("snapshot_identifier") snapshot = self.describe_cluster_snapshots( snapshot_identifier=snapshot_identifier )[0] create_kwargs = { "node_type": snapshot.cluster.node_type, "master_username": snapshot.cluster.master_username, "master_user_password": snapshot.cluster.master_user_password, "db_name": snapshot.cluster.db_name, "cluster_type": "multi-node" if snapshot.cluster.number_of_nodes > 1 else "single-node", "availability_zone": snapshot.cluster.availability_zone, "port": snapshot.cluster.port, "cluster_version": snapshot.cluster.cluster_version, "number_of_nodes": snapshot.cluster.number_of_nodes, "encrypted": snapshot.cluster.encrypted, "tags": snapshot.cluster.tags, "restored_from_snapshot": True, "enhanced_vpc_routing": snapshot.cluster.enhanced_vpc_routing, } create_kwargs.update(kwargs) return self.create_cluster(**create_kwargs) def create_snapshot_copy_grant(self, **kwargs: Any) -> SnapshotCopyGrant: snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"] kms_key_id = kwargs["kms_key_id"] if snapshot_copy_grant_name not in self.snapshot_copy_grants: snapshot_copy_grant = SnapshotCopyGrant( snapshot_copy_grant_name, kms_key_id ) self.snapshot_copy_grants[snapshot_copy_grant_name] = snapshot_copy_grant return snapshot_copy_grant raise SnapshotCopyGrantAlreadyExistsFaultError(snapshot_copy_grant_name) def delete_snapshot_copy_grant(self, **kwargs: Any) -> SnapshotCopyGrant: snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"] if snapshot_copy_grant_name in self.snapshot_copy_grants: return self.snapshot_copy_grants.pop(snapshot_copy_grant_name) raise SnapshotCopyGrantNotFoundFaultError(snapshot_copy_grant_name) def describe_snapshot_copy_grants(self, **kwargs: Any) -> List[SnapshotCopyGrant]: copy_grants = list(self.snapshot_copy_grants.values()) snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"] if snapshot_copy_grant_name: if snapshot_copy_grant_name in self.snapshot_copy_grants: return [self.snapshot_copy_grants[snapshot_copy_grant_name]] raise SnapshotCopyGrantNotFoundFaultError(snapshot_copy_grant_name) return copy_grants def _get_resource_from_arn(self, arn: str) -> TaggableResourceMixin: try: arn_breakdown = arn.split(":") resource_type = arn_breakdown[5] if resource_type == "snapshot": resource_id = arn_breakdown[6].split("/")[1] else: resource_id = arn_breakdown[6] except IndexError: resource_type = resource_id = arn resources = self.RESOURCE_TYPE_MAP.get(resource_type) if resources is None: message = ( "Tagging is not supported for this type of resource: " f"'{resource_type}' (the ARN is potentially malformed, " "please check the ARN documentation for more information)" ) raise ResourceNotFoundFaultError(message=message) try: resource = resources[resource_id] except KeyError: raise ResourceNotFoundFaultError(resource_type, resource_id) return resource @staticmethod def _describe_tags_for_resources(resources: Iterable[Any]) -> List[Dict[str, Any]]: # type: ignore[misc] tagged_resources = [] for resource in resources: for tag in resource.tags: data = { "ResourceName": resource.arn, "ResourceType": resource.resource_type, "Tag": {"Key": tag["Key"], "Value": tag["Value"]}, } tagged_resources.append(data) return tagged_resources def _describe_tags_for_resource_type( self, resource_type: str ) -> List[Dict[str, Any]]: resources = self.RESOURCE_TYPE_MAP.get(resource_type) if not resources: raise ResourceNotFoundFaultError(resource_type=resource_type) return self._describe_tags_for_resources(resources.values()) def _describe_tags_for_resource_name( self, resource_name: str ) -> List[Dict[str, Any]]: resource = self._get_resource_from_arn(resource_name) return self._describe_tags_for_resources([resource]) def create_tags(self, resource_name: str, tags: List[Dict[str, str]]) -> None: resource = self._get_resource_from_arn(resource_name) resource.create_tags(tags) def describe_tags( self, resource_name: str, resource_type: str ) -> List[Dict[str, Any]]: if resource_name and resource_type: raise InvalidParameterValueError( "You cannot filter a list of resources using an Amazon " "Resource Name (ARN) and a resource type together in the " "same request. Retry the request using either an ARN or " "a resource type, but not both." ) if resource_type: return self._describe_tags_for_resource_type(resource_type.lower()) if resource_name: return self._describe_tags_for_resource_name(resource_name) # If name and type are not specified, return all tagged resources. # TODO: Implement aws marker pagination tagged_resources = [] for resource_type in self.RESOURCE_TYPE_MAP: try: tagged_resources += self._describe_tags_for_resource_type(resource_type) except ResourceNotFoundFaultError: pass return tagged_resources def delete_tags(self, resource_name: str, tag_keys: List[str]) -> None: resource = self._get_resource_from_arn(resource_name) resource.delete_tags(tag_keys) def get_cluster_credentials( self, cluster_identifier: str, db_user: str, auto_create: bool, duration_seconds: int, ) -> Dict[str, Any]: if duration_seconds < 900 or duration_seconds > 3600: raise InvalidParameterValueError( "Token duration must be between 900 and 3600 seconds" ) expiration = datetime.datetime.now(tzutc()) + datetime.timedelta( 0, duration_seconds ) if cluster_identifier in self.clusters: user_prefix = "IAM:" if auto_create is False else "IAMA:" db_user = user_prefix + db_user return { "DbUser": db_user, "DbPassword": mock_random.get_random_string(32), "Expiration": expiration, } raise ClusterNotFoundError(cluster_identifier) def enable_logging( self, cluster_identifier: str, bucket_name: str, s3_key_prefix: str, log_destination_type: str, log_exports: List[str], ) -> Dict[str, Any]: if cluster_identifier not in self.clusters: raise ClusterNotFoundError(cluster_identifier) cluster = self.clusters[cluster_identifier] cluster.logging_details["LoggingEnabled"] = "true" cluster.logging_details["BucketName"] = bucket_name cluster.logging_details["S3KeyPrefix"] = s3_key_prefix cluster.logging_details["LogDestinationType"] = log_destination_type cluster.logging_details["LogExports"] = log_exports return cluster.logging_details def disable_logging(self, cluster_identifier: str) -> Dict[str, Any]: if cluster_identifier not in self.clusters: raise ClusterNotFoundError(cluster_identifier) cluster = self.clusters[cluster_identifier] cluster.logging_details["LoggingEnabled"] = "false" return cluster.logging_details def describe_logging_status(self, cluster_identifier: str) -> Dict[str, Any]: if cluster_identifier not in self.clusters: raise ClusterNotFoundError(cluster_identifier) cluster = self.clusters[cluster_identifier] return cluster.logging_details redshift_backends = BackendDict(RedshiftBackend, "redshift")
Memory