"""MemoryDBBackend class with methods for supported APIs."""
import copy
import random
from datetime import datetime
from typing import Any, Dict, List, Optional
from moto.core.base_backend import BackendDict, BaseBackend
from moto.core.common_models import BaseModel
from moto.ec2 import ec2_backends
from moto.utilities.tagging_service import TaggingService
from .exceptions import (
ClusterAlreadyExistsFault,
ClusterNotFoundFault,
InvalidParameterValueException,
InvalidSubnetError,
SnapshotAlreadyExistsFault,
SnapshotNotFoundFault,
SubnetGroupAlreadyExistsFault,
SubnetGroupInUseFault,
SubnetGroupNotFoundFault,
TagNotFoundFault,
)
class MemoryDBCluster(BaseModel):
    """In-memory model of a MemoryDB cluster together with its shards and nodes.

    The constructor fills in defaults for every optional setting that was not
    supplied and immediately generates the shard/node layout. The cluster is
    created directly in the ``available`` state.
    """

    # Patch release reported for each engine version this model knows about.
    _ENGINE_PATCH_VERSIONS: Dict[str, str] = {
        "6.2": "6.2.6",
        "7.0": "7.0.7",
        "7.1": "7.1.1",
    }
    # Engine version used when none (or an unknown one) is requested.
    _DEFAULT_ENGINE_VERSION = "7.1"

    def __init__(
        self,
        cluster_name: str,
        node_type: str,
        parameter_group_name: str,
        description: str,
        num_shards: int,
        num_replicas_per_shard: int,
        subnet_group_name: str,
        vpc_id: str,
        maintenance_window: str,
        port: int,
        sns_topic_arn: str,
        kms_key_id: str,
        snapshot_arns: List[str],
        snapshot_name: str,
        snapshot_retention_limit: int,
        snapshot_window: str,
        acl_name: str,
        engine_version: str,
        region: str,
        account_id: str,
        security_group_ids: List[str],
        auto_minor_version_upgrade: bool,
        data_tiering: bool,
        tls_enabled: bool,
    ):
        """Store the cluster settings, applying defaults for missing values.

        Falsy values for ``parameter_group_name``, ``num_shards``,
        ``maintenance_window``, ``port``, ``snapshot_retention_limit`` and
        ``snapshot_window`` are replaced by their defaults. ``tls_enabled``
        and ``auto_minor_version_upgrade`` default to ``True`` only when they
        are ``None``. ``num_replicas_per_shard`` defaults to 1 only when it is
        ``None``, so an explicit 0 (a single-node shard) is honoured.
        """
        self.cluster_name = cluster_name
        self.node_type = node_type
        # Default is set to 'default.memorydb-redis7'.
        self.parameter_group_name = parameter_group_name or "default.memorydb-redis7"
        # Set to 'in-sync'; other options are 'active' or 'applying'.
        self.parameter_group_status = "in-sync"
        self.description = description
        self.num_shards = num_shards or 1  # Default shards is set to 1
        # Defaults to 1 (i.e. 2 nodes per shard). Only None triggers the
        # default: 0 is a legitimate request for shards without replicas.
        self.num_replicas_per_shard = (
            1 if num_replicas_per_shard is None else num_replicas_per_shard
        )
        self.subnet_group_name = subnet_group_name
        self.vpc_id = vpc_id
        self.maintenance_window = maintenance_window or "wed:08:00-wed:09:00"
        self.port = port or 6379  # Default is set to 6379
        self.sns_topic_arn = sns_topic_arn
        self.tls_enabled = tls_enabled if tls_enabled is not None else True
        # Clusters that do not have TLS enabled must use the "open-access" ACL
        # to provide open authentication.
        self.acl_name = "open-access" if tls_enabled is False else acl_name
        self.kms_key_id = kms_key_id
        self.snapshot_arns = snapshot_arns
        self.snapshot_name = snapshot_name
        self.snapshot_retention_limit = snapshot_retention_limit or 0
        self.snapshot_window = snapshot_window or "03:00-04:00"
        self.region = region
        # Unknown or missing engine versions fall back to the default.
        if engine_version not in self._ENGINE_PATCH_VERSIONS:
            engine_version = self._DEFAULT_ENGINE_VERSION
        self.engine_version = engine_version
        self.engine_patch_version = self._ENGINE_PATCH_VERSIONS[engine_version]
        self.auto_minor_version_upgrade = (
            auto_minor_version_upgrade
            if auto_minor_version_upgrade is not None
            else True
        )
        self.data_tiering = "true" if data_tiering else "false"
        # Set to 'available' right away; other options are 'creating', 'updating'.
        self.status = "available"
        self.pending_updates: Dict[Any, Any] = {}  # TODO
        self.shards = self.get_shard_details()
        self.availability_mode = self._get_availability_mode()
        self.cluster_endpoint = {
            "Address": f"clustercfg.{self.cluster_name}.aoneci.memorydb.{region}.amazonaws.com",
            "Port": self.port,
        }
        self.security_group_ids = security_group_ids or []
        self.security_groups = self._build_security_groups()
        self.arn = f"arn:aws:memorydb:{region}:{account_id}:cluster/{self.cluster_name}"
        self.sns_topic_status = "active" if self.sns_topic_arn else ""

    def _get_availability_mode(self) -> str:
        """Return 'SingleAZ' when shards have no replicas, else 'MultiAZ'."""
        return "SingleAZ" if self.num_replicas_per_shard == 0 else "MultiAZ"

    def _build_security_groups(self) -> List[Dict[str, str]]:
        """Return one 'active' membership entry per configured security group id."""
        return [
            {"SecurityGroupId": sg, "Status": "active"}
            for sg in self.security_group_ids
        ]

    def get_shard_details(self) -> List[Dict[str, Any]]:
        """Generate shard and node descriptions for the current configuration.

        Produces ``num_shards`` shards, each holding
        ``num_replicas_per_shard + 1`` nodes. Availability zones and the slot
        range upper bound are picked at random, so repeated calls return
        different details.
        """
        azs = ["a", "b", "c", "d"]
        num_nodes = self.num_replicas_per_shard + 1
        shards = []
        for i in range(self.num_shards):
            shard_name = f"{i+1:04}"
            nodes = []
            for n in range(num_nodes):
                node_name = f"{self.cluster_name}-{shard_name}-{n+1:03}"
                node = {
                    "Name": node_name,
                    "Status": "available",
                    "AvailabilityZone": f"{self.region}{random.choice(azs)}",
                    "CreateTime": datetime.now().strftime(
                        "%Y-%m-%dT%H:%M:%S.000%f+0000"
                    ),
                    "Endpoint": {
                        "Address": f"{node_name}.{self.cluster_name}.aoneci.memorydb.{self.region}.amazonaws.com",
                        "Port": self.port,
                    },
                }
                nodes.append(node)
            shard = {
                "Name": shard_name,
                # Set to 'available'; other options are 'creating', 'modifying', 'deleting'.
                "Status": "available",
                "Slots": f"0-{random.randint(10000, 99999)}",
                "Nodes": nodes,
                "NumberOfNodes": num_nodes,
            }
            shards.append(shard)
        return shards

    def update(
        self,
        description: Optional[str],
        security_group_ids: Optional[List[str]],
        maintenance_window: Optional[str],
        sns_topic_arn: Optional[str],
        sns_topic_status: Optional[str],
        parameter_group_name: Optional[str],
        snapshot_window: Optional[str],
        snapshot_retention_limit: Optional[int],
        node_type: Optional[str],
        engine_version: Optional[str],
        replica_configuration: Optional[Dict[str, int]],
        shard_configuration: Optional[Dict[str, int]],
        acl_name: Optional[str],
    ) -> None:
        """Apply every argument that is not ``None`` to the cluster.

        Derived state is kept in sync with the updated settings: the security
        group membership list follows ``security_group_ids``, the patch
        version follows ``engine_version`` when that version is known, and the
        availability mode and shard layout follow the replica/shard counts.
        """
        if description is not None:
            self.description = description
        if security_group_ids is not None:
            self.security_group_ids = security_group_ids
            # to_dict() reports security_groups, so it must be rebuilt too.
            self.security_groups = self._build_security_groups()
        if maintenance_window is not None:
            self.maintenance_window = maintenance_window
        if sns_topic_arn is not None:
            self.sns_topic_arn = sns_topic_arn
        if sns_topic_status is not None:
            self.sns_topic_status = sns_topic_status
        if parameter_group_name is not None:
            self.parameter_group_name = parameter_group_name
        if snapshot_window is not None:
            self.snapshot_window = snapshot_window
        if snapshot_retention_limit is not None:
            self.snapshot_retention_limit = snapshot_retention_limit
        if node_type is not None:
            self.node_type = node_type
        if engine_version is not None:
            self.engine_version = engine_version
            # Keep the previous patch version if the new version is unknown.
            self.engine_patch_version = self._ENGINE_PATCH_VERSIONS.get(
                engine_version, self.engine_patch_version
            )
        if replica_configuration is not None:
            self.num_replicas_per_shard = replica_configuration["ReplicaCount"]
            self.availability_mode = self._get_availability_mode()
        if shard_configuration is not None:
            self.num_shards = shard_configuration["ShardCount"]
        if replica_configuration is not None or shard_configuration is not None:
            # Regenerate shards and nodes once, after both counts are final.
            self.shards = self.get_shard_details()
        if acl_name is not None:
            self.acl_name = acl_name

    def to_dict(self) -> Dict[str, Any]:
        """Return the cluster summary without shard details.

        Entries with falsy values are omitted, except ``TLSEnabled``,
        ``AutoMinorVersionUpgrade`` and ``SnapshotRetentionLimit``, which are
        always present.
        """
        dct = {
            "Name": self.cluster_name,
            "Description": self.description,
            "Status": self.status,
            "PendingUpdates": self.pending_updates,
            "NumberOfShards": self.num_shards,
            "AvailabilityMode": self.availability_mode,
            "ClusterEndpoint": self.cluster_endpoint,
            "NodeType": self.node_type,
            "EngineVersion": self.engine_version,
            "EnginePatchVersion": self.engine_patch_version,
            "ParameterGroupName": self.parameter_group_name,
            "ParameterGroupStatus": self.parameter_group_status,
            "SecurityGroups": self.security_groups,
            "SubnetGroupName": self.subnet_group_name,
            "KmsKeyId": self.kms_key_id,
            "ARN": self.arn,
            "SnsTopicArn": self.sns_topic_arn,
            "SnsTopicStatus": self.sns_topic_status,
            "MaintenanceWindow": self.maintenance_window,
            "SnapshotWindow": self.snapshot_window,
            "ACLName": self.acl_name,
            "DataTiering": self.data_tiering,
        }
        dct_items = {k: v for k, v in dct.items() if v}
        # These may legitimately be False/0, so they bypass the falsy filter.
        dct_items["TLSEnabled"] = self.tls_enabled
        dct_items["AutoMinorVersionUpgrade"] = self.auto_minor_version_upgrade
        dct_items["SnapshotRetentionLimit"] = self.snapshot_retention_limit
        return dct_items

    def to_desc_dict(self) -> Dict[str, Any]:
        """Return the ``to_dict`` summary extended with the ``Shards`` list."""
        dct = self.to_dict()
        dct["Shards"] = self.shards
        return dct
class MemoryDBSubnetGroup(BaseModel):
    """A named collection of subnet ids, resolved through an EC2 backend."""

    def __init__(
        self,
        region_name: str,
        account_id: str,
        ec2_backend: Any,
        subnet_group_name: str,
        description: str,
        subnet_ids: List[str],
        tags: Optional[List[Dict[str, str]]] = None,
    ):
        """Record the group settings and verify the subnet ids resolve.

        ``tags`` is accepted but not stored by this constructor.

        Raises:
            InvalidSubnetError: if the EC2 backend returns no subnets for
                ``subnet_ids``.
        """
        self.ec2_backend = ec2_backend
        self.subnet_group_name = subnet_group_name
        self.description = description
        self.subnet_ids = subnet_ids
        resolved = self.subnets
        if not resolved:
            raise InvalidSubnetError(subnet_ids)
        self.arn = f"arn:aws:memorydb:{region_name}:{account_id}:subnetgroup/{subnet_group_name}"

    @property
    def subnets(self) -> Any:  # type: ignore[misc]
        """Subnets looked up in the EC2 backend on every access."""
        subnet_filter = {"subnet-id": self.subnet_ids}
        return self.ec2_backend.describe_subnets(filters=subnet_filter)

    @property
    def vpc_id(self) -> str:
        """VPC id of the first subnet returned by the EC2 backend."""
        first_subnet = self.subnets[0]
        return first_subnet.vpc_id

    def to_dict(self) -> Dict[str, Any]:
        """Return name, description, VPC id, subnets and ARN of the group."""
        vpc = self.vpc_id
        subnet_entries = []
        for subnet in self.subnets:
            subnet_entries.append(
                {
                    "Identifier": subnet.id,
                    "AvailabilityZone": {"Name": subnet.availability_zone},
                }
            )
        return {
            "Name": self.subnet_group_name,
            "Description": self.description,
            "VpcId": vpc,
            "Subnets": subnet_entries,
            "ARN": self.arn,
        }
class MemoryDBSnapshot(BaseModel):
    """Snapshot of a cluster, holding a shallow copy of it at creation time."""

    def __init__(
        self,
        account_id: str,
        region_name: str,
        cluster: MemoryDBCluster,
        snapshot_name: str,
        kms_key_id: Optional[str],
        tags: Optional[List[Dict[str, str]]],
        source: Optional[str],
    ):
        """Copy the cluster and derive per-shard snapshot entries from it.

        ``kms_key_id`` falls back to the cluster's key when falsy. ``tags`` is
        accepted but not stored by this constructor.
        """
        # Shallow copy, so later attribute rebinding on the live cluster is
        # not reflected in this snapshot.
        frozen = copy.copy(cluster)
        self.cluster = frozen
        self.cluster_name = frozen.cluster_name
        self.snapshot_name = snapshot_name
        self.status = "available"
        self.source = source
        self.kms_key_id = kms_key_id or cluster.kms_key_id
        self.arn = (
            f"arn:aws:memorydb:{region_name}:{account_id}:snapshot/{snapshot_name}"
        )
        self.vpc_id = frozen.vpc_id
        self.shards = [
            {
                "Name": cluster_shard["Name"],
                "Configuration": {
                    "Slots": cluster_shard["Slots"],
                    "ReplicaCount": frozen.num_replicas_per_shard,
                },
                "Size": "11 MB",
                "SnapshotCreationTime": datetime.now().strftime(
                    "%Y-%m-%dT%H:%M:%S.000%f+0000"
                ),
            }
            for cluster_shard in frozen.shards
        ]

    def to_dict(self) -> Dict[str, Any]:
        """Return the snapshot summary, omitting top-level falsy entries."""
        source_cluster = self.cluster
        cluster_configuration = {
            "Name": self.cluster_name,
            "Description": source_cluster.description,
            "NodeType": source_cluster.node_type,
            "EngineVersion": source_cluster.engine_version,
            "MaintenanceWindow": source_cluster.maintenance_window,
            "TopicArn": source_cluster.sns_topic_arn,
            "Port": source_cluster.port,
            "ParameterGroupName": source_cluster.parameter_group_name,
            "SubnetGroupName": source_cluster.subnet_group_name,
            "VpcId": self.vpc_id,
            "SnapshotRetentionLimit": source_cluster.snapshot_retention_limit,
            "SnapshotWindow": source_cluster.snapshot_window,
            "NumShards": source_cluster.num_shards,
        }
        summary = {
            "Name": self.snapshot_name,
            "Status": self.status,
            "Source": self.source,
            "KmsKeyId": self.kms_key_id,
            "ARN": self.arn,
            "ClusterConfiguration": cluster_configuration,
            "DataTiering": source_cluster.data_tiering,
        }
        result: Dict[str, Any] = {}
        for key, value in summary.items():
            if value:
                result[key] = value
        return result

    def to_desc_dict(self) -> Dict[str, Any]:
        """Return ``to_dict`` with shard entries added to the cluster configuration."""
        described = self.to_dict()
        described["ClusterConfiguration"]["Shards"] = self.shards
        return described
class MemoryDBBackend(BaseBackend):
    """Implementation of MemoryDB APIs."""

    def __init__(self, region_name: str, account_id: str):
        """Create empty cluster/snapshot stores and the 'default' subnet group."""
        super().__init__(region_name, account_id)
        self.ec2_backend = ec2_backends[account_id][region_name]
        self.clusters: Dict[str, MemoryDBCluster] = dict()
        self.subnet_groups: Dict[str, MemoryDBSubnetGroup] = {
            "default": MemoryDBSubnetGroup(
                region_name,
                account_id,
                self.ec2_backend,
                "default",
                "Default MemoryDB Subnet Group",
                self.get_default_subnets(),
            )
        }
        self.snapshots: Dict[str, MemoryDBSnapshot] = dict()
        self.tagger = TaggingService()

    def get_default_subnets(self) -> List[str]:
        """Return the ids of the subnets the EC2 backend reports as default-for-az."""
        default_subnets = self.ec2_backend.describe_subnets(
            filters={"default-for-az": "true"}
        )
        return [subnet.id for subnet in default_subnets]

    def _list_arns(self) -> List[str]:
        """Return the ARNs of all clusters, snapshots and subnet groups."""
        cluster_arns = [cluster.arn for cluster in self.clusters.values()]
        snapshot_arns = [snapshot.arn for snapshot in self.snapshots.values()]
        subnet_group_arns = [subnet.arn for subnet in self.subnet_groups.values()]
        return cluster_arns + snapshot_arns + subnet_group_arns

    def _validate_resource_arn(self, resource_arn: str) -> None:
        """Raise the matching not-found fault if the ARN is not a known resource.

        The fault type is chosen from the ARN text: ``subnetgroup`` and
        ``snapshot`` map to their own faults, anything else is reported as a
        missing cluster.
        """
        if resource_arn in self._list_arns():
            return
        # Get the resource name from the resource_arn
        resource_name = resource_arn.split("/")[-1]
        if "subnetgroup" in resource_arn:
            raise SubnetGroupNotFoundFault(f"{resource_name} is not present")
        if "snapshot" in resource_arn:
            raise SnapshotNotFoundFault(f"{resource_name} is not present")
        raise ClusterNotFoundFault(f"{resource_name} is not present")

    def create_cluster(
        self,
        cluster_name: str,
        node_type: str,
        parameter_group_name: str,
        description: str,
        subnet_group_name: str,
        security_group_ids: List[str],
        maintenance_window: str,
        port: int,
        sns_topic_arn: str,
        tls_enabled: bool,
        kms_key_id: str,
        snapshot_arns: List[str],
        snapshot_name: str,
        snapshot_retention_limit: int,
        tags: List[Dict[str, str]],
        snapshot_window: str,
        acl_name: str,
        engine_version: str,
        auto_minor_version_upgrade: bool,
        data_tiering: bool,
        num_shards: int,
        num_replicas_per_shard: int,
    ) -> MemoryDBCluster:
        """Create, register and tag a new cluster.

        A falsy ``subnet_group_name`` selects the 'default' subnet group.

        Raises:
            ClusterAlreadyExistsFault: if ``cluster_name`` is already in use.
            SubnetGroupNotFoundFault: if the subnet group does not exist.
        """
        if cluster_name in self.clusters:
            raise ClusterAlreadyExistsFault(
                msg="Cluster with specified name already exists."
            )
        subnet_group_name = subnet_group_name or "default"
        # Report an unknown subnet group as a service fault, not a KeyError.
        if subnet_group_name not in self.subnet_groups:
            raise SubnetGroupNotFoundFault(
                msg=f"Subnet group {subnet_group_name} not found."
            )
        subnet_group = self.subnet_groups[subnet_group_name]
        vpc_id = subnet_group.vpc_id
        cluster = MemoryDBCluster(
            cluster_name=cluster_name,
            node_type=node_type,
            parameter_group_name=parameter_group_name,
            description=description,
            num_shards=num_shards,
            num_replicas_per_shard=num_replicas_per_shard,
            subnet_group_name=subnet_group_name,
            vpc_id=vpc_id,
            security_group_ids=security_group_ids,
            maintenance_window=maintenance_window,
            port=port,
            sns_topic_arn=sns_topic_arn,
            tls_enabled=tls_enabled,
            kms_key_id=kms_key_id,
            snapshot_arns=snapshot_arns,
            snapshot_name=snapshot_name,
            snapshot_retention_limit=snapshot_retention_limit,
            snapshot_window=snapshot_window,
            acl_name=acl_name,
            engine_version=engine_version,
            auto_minor_version_upgrade=auto_minor_version_upgrade,
            data_tiering=data_tiering,
            region=self.region_name,
            account_id=self.account_id,
        )
        self.clusters[cluster.cluster_name] = cluster
        self.tag_resource(cluster.arn, tags)
        return cluster

    def create_subnet_group(
        self,
        subnet_group_name: str,
        description: str,
        subnet_ids: List[str],
        tags: Optional[List[Dict[str, str]]] = None,
    ) -> MemoryDBSubnetGroup:
        """Create, register and (when tags are given) tag a subnet group.

        Raises:
            SubnetGroupAlreadyExistsFault: if the name is already in use.
        """
        if subnet_group_name in self.subnet_groups:
            raise SubnetGroupAlreadyExistsFault(
                msg=f"Subnet group {subnet_group_name} already exists."
            )
        subnet_group = MemoryDBSubnetGroup(
            self.region_name,
            self.account_id,
            self.ec2_backend,
            subnet_group_name,
            description,
            subnet_ids,
            tags,
        )
        self.subnet_groups[subnet_group_name] = subnet_group
        if tags:
            self.tag_resource(subnet_group.arn, tags)
        return subnet_group

    def create_snapshot(
        self,
        cluster_name: str,
        snapshot_name: str,
        kms_key_id: Optional[str] = None,
        tags: Optional[List[Dict[str, str]]] = None,
        source: str = "manual",
    ) -> MemoryDBSnapshot:
        """Create, register and (when tags are given) tag a cluster snapshot.

        Raises:
            ClusterNotFoundFault: if ``cluster_name`` is unknown.
            SnapshotAlreadyExistsFault: if ``snapshot_name`` is already in use.
        """
        if cluster_name not in self.clusters:
            raise ClusterNotFoundFault(msg=f"Cluster not found: {cluster_name}")
        cluster = self.clusters[cluster_name]
        if snapshot_name in self.snapshots:
            raise SnapshotAlreadyExistsFault(
                msg="Snapshot with specified name already exists."
            )
        snapshot = MemoryDBSnapshot(
            account_id=self.account_id,
            region_name=self.region_name,
            cluster=cluster,
            snapshot_name=snapshot_name,
            kms_key_id=kms_key_id,
            tags=tags,
            source=source,
        )
        self.snapshots[snapshot_name] = snapshot
        if tags:
            self.tag_resource(snapshot.arn, tags)
        return snapshot

    def describe_clusters(
        self, cluster_name: Optional[str] = None
    ) -> List[MemoryDBCluster]:
        """Return the named cluster, or every cluster when no name is given.

        Raises:
            ClusterNotFoundFault: if a name is given but no such cluster exists.
        """
        if not cluster_name:
            return list(self.clusters.values())
        if cluster_name not in self.clusters:
            raise ClusterNotFoundFault(msg=f"Cluster {cluster_name} not found")
        return [self.clusters[cluster_name]]

    def describe_snapshots(
        self,
        cluster_name: Optional[str] = None,
        snapshot_name: Optional[str] = None,
        source: Optional[str] = None,
    ) -> List[MemoryDBSnapshot]:
        """Return snapshots filtered by cluster name, snapshot name and source.

        When ``source`` is ``None`` both 'automated' and 'manual' snapshots
        match. Filtering by cluster name alone may return an empty list.

        Raises:
            SnapshotNotFoundFault: if ``snapshot_name`` is given and nothing
                matches.
        """
        sources = ["automated", "manual"] if source is None else [source]
        if cluster_name and snapshot_name:
            for snapshot in self.snapshots.values():
                if (
                    snapshot.cluster_name == cluster_name
                    and snapshot.snapshot_name == snapshot_name
                    and snapshot.source in sources
                ):
                    return [snapshot]
            raise SnapshotNotFoundFault(
                msg=f"Snapshot with name {snapshot_name} not found"
            )
        if cluster_name:
            return [
                snapshot
                for snapshot in self.snapshots.values()
                if snapshot.cluster_name == cluster_name
                and snapshot.source in sources
            ]
        if snapshot_name:
            snapshots = [
                snapshot
                for snapshot in self.snapshots.values()
                if snapshot.snapshot_name == snapshot_name
                and snapshot.source in sources
            ]
            if snapshots:
                return snapshots
            raise SnapshotNotFoundFault(
                msg=f"Snapshot with name {snapshot_name} not found"
            )
        return [
            snapshot
            for snapshot in self.snapshots.values()
            if snapshot.source in sources
        ]

    def describe_subnet_groups(
        self, subnet_group_name: str
    ) -> List[MemoryDBSubnetGroup]:
        """Return the named subnet group, or all groups when the name is falsy.

        Raises:
            SubnetGroupNotFoundFault: if a name is given but no such group exists.
        """
        if not subnet_group_name:
            return list(self.subnet_groups.values())
        if subnet_group_name not in self.subnet_groups:
            raise SubnetGroupNotFoundFault(
                msg=f"Subnet group {subnet_group_name} not found."
            )
        return [self.subnet_groups[subnet_group_name]]

    def list_tags(self, resource_arn: str) -> List[Dict[str, str]]:
        """Return the tags stored for a known resource ARN.

        Raises:
            SubnetGroupNotFoundFault, SnapshotNotFoundFault or
            ClusterNotFoundFault: if the ARN is not a known resource.
        """
        self._validate_resource_arn(resource_arn)
        return self.tagger.list_tags_for_resource(arn=resource_arn)["Tags"]

    def tag_resource(
        self, resource_arn: str, tags: List[Dict[str, str]]
    ) -> List[Dict[str, str]]:
        """Add tags to a known resource and return its resulting tag list.

        Raises:
            SubnetGroupNotFoundFault, SnapshotNotFoundFault or
            ClusterNotFoundFault: if the ARN is not a known resource.
        """
        self._validate_resource_arn(resource_arn)
        self.tagger.tag_resource(resource_arn, tags)
        return self.tagger.list_tags_for_resource(arn=resource_arn)["Tags"]

    def untag_resource(
        self, resource_arn: str, tag_keys: List[str]
    ) -> List[Dict[str, str]]:
        """Remove the given tag keys and return the remaining tag list.

        Raises:
            SubnetGroupNotFoundFault, SnapshotNotFoundFault or
            ClusterNotFoundFault: if the ARN is not a known resource.
            TagNotFoundFault: if any of ``tag_keys`` is not set on the resource.
        """
        # list_tags validates the ARN before returning the current tags.
        list_tags = self.list_tags(resource_arn=resource_arn)
        list_keys = [i["Key"] for i in list_tags]
        invalid_keys = [key for key in tag_keys if key not in list_keys]
        if invalid_keys:
            raise TagNotFoundFault(msg=f"These tags are not present : {[invalid_keys]}")
        self.tagger.untag_resource_using_names(resource_arn, tag_keys)
        return self.tagger.list_tags_for_resource(arn=resource_arn)["Tags"]

    def update_cluster(
        self,
        cluster_name: str,
        description: Optional[str],
        security_group_ids: Optional[List[str]],
        maintenance_window: Optional[str],
        sns_topic_arn: Optional[str],
        sns_topic_status: Optional[str],
        parameter_group_name: Optional[str],
        snapshot_window: Optional[str],
        snapshot_retention_limit: Optional[int],
        node_type: Optional[str],
        engine_version: Optional[str],
        replica_configuration: Optional[Dict[str, int]],
        shard_configuration: Optional[Dict[str, int]],
        acl_name: Optional[str],
    ) -> MemoryDBCluster:
        """Forward the supplied settings to the cluster's ``update`` and return it.

        Raises:
            ClusterNotFoundFault: if ``cluster_name`` is unknown.
        """
        if cluster_name not in self.clusters:
            raise ClusterNotFoundFault(msg="Cluster not found.")
        cluster = self.clusters[cluster_name]
        cluster.update(
            description,
            security_group_ids,
            maintenance_window,
            sns_topic_arn,
            sns_topic_status,
            parameter_group_name,
            snapshot_window,
            snapshot_retention_limit,
            node_type,
            engine_version,
            replica_configuration,
            shard_configuration,
            acl_name,
        )
        return cluster

    def delete_cluster(
        self, cluster_name: str, final_snapshot_name: Optional[str]
    ) -> MemoryDBCluster:
        """Remove a cluster, optionally taking a final manual snapshot first.

        The returned cluster has its status set to 'deleting'.

        Raises:
            ClusterNotFoundFault: if ``cluster_name`` is unknown.
            SnapshotAlreadyExistsFault: if ``final_snapshot_name`` is already
                in use; the cluster is then left registered and unmodified.
        """
        if cluster_name not in self.clusters:
            raise ClusterNotFoundFault(cluster_name)
        if final_snapshot_name is not None:  # create snapshot
            # Snapshot before touching the cluster, so a failure here does not
            # leave a still-registered cluster stuck in 'deleting'.
            self.create_snapshot(
                cluster_name=cluster_name,
                snapshot_name=final_snapshot_name,
                source="manual",
            )
        cluster = self.clusters.pop(cluster_name)
        cluster.status = "deleting"
        return cluster

    def delete_snapshot(self, snapshot_name: str) -> MemoryDBSnapshot:
        """Remove a snapshot and return it with status 'deleting'.

        Raises:
            SnapshotNotFoundFault: if ``snapshot_name`` is unknown.
        """
        if snapshot_name not in self.snapshots:
            raise SnapshotNotFoundFault(snapshot_name)
        snapshot = self.snapshots.pop(snapshot_name)
        snapshot.status = "deleting"
        return snapshot

    def delete_subnet_group(self, subnet_group_name: str) -> MemoryDBSubnetGroup:
        """Remove and return a subnet group that no cluster is using.

        Raises:
            SubnetGroupNotFoundFault: if the group does not exist.
            InvalidParameterValueException: if the group is 'default'.
            SubnetGroupInUseFault: if any cluster references the group.
        """
        if subnet_group_name not in self.subnet_groups:
            raise SubnetGroupNotFoundFault(
                msg=f"Subnet group {subnet_group_name} not found."
            )
        if subnet_group_name == "default":
            raise InvalidParameterValueException(
                msg="default is reserved and cannot be modified."
            )
        if any(
            c.subnet_group_name == subnet_group_name for c in self.clusters.values()
        ):
            raise SubnetGroupInUseFault(
                msg=f"Subnet group {subnet_group_name} is currently in use by a cluster."
            )
        return self.subnet_groups.pop(subnet_group_name)
# Module-level registry of MemoryDBBackend instances for the "memorydb"
# service, built with the BackendDict imported from moto.core.base_backend.
memorydb_backends = BackendDict(MemoryDBBackend, "memorydb")