import string
from typing import Any, Dict, Iterable, List, Optional, Tuple
from moto.core.base_backend import BackendDict, BaseBackend
from moto.core.common_models import BaseModel
from moto.core.utils import unix_time
from moto.moto_api._internal import mock_random as random
from moto.utilities.tagging_service import TaggingService
from moto.utilities.utils import get_partition
from .exceptions import (
ConflictingDomainExists,
CustomHealthNotFound,
InstanceNotFound,
InvalidInput,
NamespaceNotFound,
OperationNotFound,
ServiceNotFound,
)
def random_id(size: int) -> str:
return "".join(
[random.choice(string.ascii_lowercase + string.digits) for _ in range(size)]
)
class Namespace(BaseModel):
def __init__(
self,
account_id: str,
region: str,
name: str,
ns_type: str,
creator_request_id: str,
description: str,
dns_properties: Dict[str, Any],
http_properties: Dict[str, Any],
vpc: Optional[str] = None,
):
self.id = f"ns-{random_id(20)}"
self.arn = f"arn:{get_partition(region)}:servicediscovery:{region}:{account_id}:namespace/{self.id}"
self.name = name
self.type = ns_type
self.creator_request_id = creator_request_id
self.description = description
self.dns_properties = dns_properties
self.http_properties = http_properties
self.vpc = vpc
self.created = unix_time()
self.updated = unix_time()
def to_json(self) -> Dict[str, Any]:
return {
"Arn": self.arn,
"Id": self.id,
"Name": self.name,
"Description": self.description,
"Type": self.type,
"Properties": {
"DnsProperties": self.dns_properties,
"HttpProperties": self.http_properties,
},
"CreateDate": self.created,
"UpdateDate": self.updated,
"CreatorRequestId": self.creator_request_id,
}
class Service(BaseModel):
def __init__(
self,
account_id: str,
region: str,
name: str,
namespace_id: str,
description: str,
creator_request_id: str,
dns_config: Dict[str, Any],
health_check_config: Dict[str, Any],
health_check_custom_config: Dict[str, int],
service_type: str,
):
self.id = f"srv-{random_id(8)}"
self.arn = f"arn:{get_partition(region)}:servicediscovery:{region}:{account_id}:service/{self.id}"
self.name = name
self.namespace_id = namespace_id
self.description = description
self.creator_request_id = creator_request_id
self.dns_config: Optional[Dict[str, Any]] = dns_config
self.health_check_config = health_check_config
self.health_check_custom_config = health_check_custom_config
self.service_type = service_type
self.created = unix_time()
self.instances: List[ServiceInstance] = []
self.instances_revision: Dict[str, int] = {}
def update(self, details: Dict[str, Any]) -> None:
if "Description" in details:
self.description = details["Description"]
if "DnsConfig" in details:
if self.dns_config is None:
self.dns_config = {}
self.dns_config["DnsRecords"] = details["DnsConfig"]["DnsRecords"]
else:
# From the docs:
# If you omit any existing DnsRecords or HealthCheckConfig configurations from an UpdateService request,
# the configurations are deleted from the service.
self.dns_config = None
if "HealthCheckConfig" in details:
self.health_check_config = details["HealthCheckConfig"]
def to_json(self) -> Dict[str, Any]:
return {
"Arn": self.arn,
"Id": self.id,
"Name": self.name,
"NamespaceId": self.namespace_id,
"CreateDate": self.created,
"Description": self.description,
"CreatorRequestId": self.creator_request_id,
"DnsConfig": self.dns_config,
"HealthCheckConfig": self.health_check_config,
"HealthCheckCustomConfig": self.health_check_custom_config,
"Type": self.service_type,
}
class ServiceInstance(BaseModel):
def __init__(
self,
service_id: str,
instance_id: str,
creator_request_id: Optional[str] = None,
attributes: Optional[Dict[str, str]] = None,
):
self.service_id = service_id
self.instance_id = instance_id
self.attributes = attributes if attributes else {}
self.creator_request_id = (
creator_request_id if creator_request_id else random_id(32)
)
self.health_status = "HEALTHY"
def to_json(self) -> Dict[str, Any]:
return {
"Id": self.instance_id,
"CreatorRequestId": self.creator_request_id,
"Attributes": self.attributes,
}
class Operation(BaseModel):
def __init__(self, operation_type: str, targets: Dict[str, str]):
super().__init__()
self.id = f"{random_id(32)}-{random_id(8)}"
self.status = "SUCCESS"
self.operation_type = operation_type
self.created = unix_time()
self.updated = unix_time()
self.targets = targets
def to_json(self, short: bool = False) -> Dict[str, Any]:
if short:
return {"Id": self.id, "Status": self.status}
else:
return {
"Id": self.id,
"Status": self.status,
"Type": self.operation_type,
"CreateDate": self.created,
"UpdateDate": self.updated,
"Targets": self.targets,
}
class ServiceDiscoveryBackend(BaseBackend):
"""Implementation of ServiceDiscovery APIs."""
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.operations: Dict[str, Operation] = dict()
self.namespaces: Dict[str, Namespace] = dict()
self.services: Dict[str, Service] = dict()
self.tagger = TaggingService()
def list_namespaces(self) -> Iterable[Namespace]:
"""
Pagination or the Filters-parameter is not yet implemented
"""
return list(self.namespaces.values())
def create_http_namespace(
self,
name: str,
creator_request_id: str,
description: str,
tags: List[Dict[str, str]],
) -> str:
namespace = Namespace(
account_id=self.account_id,
region=self.region_name,
name=name,
ns_type="HTTP",
creator_request_id=creator_request_id,
description=description,
dns_properties={"SOA": {}},
http_properties={"HttpName": name},
)
self.namespaces[namespace.id] = namespace
if tags:
self.tagger.tag_resource(namespace.arn, tags)
operation_id = self._create_operation(
"CREATE_NAMESPACE", targets={"NAMESPACE": namespace.id}
)
return operation_id
def _create_operation(self, op_type: str, targets: Dict[str, str]) -> str:
operation = Operation(operation_type=op_type, targets=targets)
self.operations[operation.id] = operation
return operation.id
def delete_namespace(self, namespace_id: str) -> str:
if namespace_id not in self.namespaces:
raise NamespaceNotFound(namespace_id)
del self.namespaces[namespace_id]
operation_id = self._create_operation(
op_type="DELETE_NAMESPACE", targets={"NAMESPACE": namespace_id}
)
return operation_id
def get_namespace(self, namespace_id: str) -> Namespace:
if namespace_id not in self.namespaces:
raise NamespaceNotFound(namespace_id)
return self.namespaces[namespace_id]
def list_operations(self) -> Iterable[Operation]:
"""
Pagination or the Filters-argument is not yet implemented
"""
# Operations for namespaces will only be listed as long as namespaces exist
self.operations = {
op_id: op
for op_id, op in self.operations.items()
if op.targets.get("NAMESPACE") in self.namespaces
}
return self.operations.values()
def get_operation(self, operation_id: str) -> Operation:
if operation_id not in self.operations:
raise OperationNotFound()
return self.operations[operation_id]
def tag_resource(self, resource_arn: str, tags: List[Dict[str, str]]) -> None:
self.tagger.tag_resource(resource_arn, tags)
def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None:
self.tagger.untag_resource_using_names(resource_arn, tag_keys)
def list_tags_for_resource(
self, resource_arn: str
) -> Dict[str, List[Dict[str, str]]]:
return self.tagger.list_tags_for_resource(resource_arn)
def create_private_dns_namespace(
self,
name: str,
creator_request_id: str,
description: str,
vpc: str,
tags: List[Dict[str, str]],
properties: Dict[str, Any],
) -> str:
for namespace in self.namespaces.values():
if namespace.vpc == vpc:
raise ConflictingDomainExists(vpc)
dns_properties = (properties or {}).get("DnsProperties", {})
dns_properties["HostedZoneId"] = "hzi"
namespace = Namespace(
account_id=self.account_id,
region=self.region_name,
name=name,
ns_type="DNS_PRIVATE",
creator_request_id=creator_request_id,
description=description,
dns_properties=dns_properties,
http_properties={},
vpc=vpc,
)
self.namespaces[namespace.id] = namespace
if tags:
self.tagger.tag_resource(namespace.arn, tags)
operation_id = self._create_operation(
"CREATE_NAMESPACE", targets={"NAMESPACE": namespace.id}
)
return operation_id
def create_public_dns_namespace(
self,
name: str,
creator_request_id: str,
description: str,
tags: List[Dict[str, str]],
properties: Dict[str, Any],
) -> str:
dns_properties = (properties or {}).get("DnsProperties", {})
dns_properties["HostedZoneId"] = "hzi"
namespace = Namespace(
account_id=self.account_id,
region=self.region_name,
name=name,
ns_type="DNS_PUBLIC",
creator_request_id=creator_request_id,
description=description,
dns_properties=dns_properties,
http_properties={},
)
self.namespaces[namespace.id] = namespace
if tags:
self.tagger.tag_resource(namespace.arn, tags)
operation_id = self._create_operation(
"CREATE_NAMESPACE", targets={"NAMESPACE": namespace.id}
)
return operation_id
def create_service(
self,
name: str,
namespace_id: str,
creator_request_id: str,
description: str,
dns_config: Dict[str, Any],
health_check_config: Dict[str, Any],
health_check_custom_config: Dict[str, Any],
tags: List[Dict[str, str]],
service_type: str,
) -> Service:
service = Service(
account_id=self.account_id,
region=self.region_name,
name=name,
namespace_id=namespace_id,
description=description,
creator_request_id=creator_request_id,
dns_config=dns_config,
health_check_config=health_check_config,
health_check_custom_config=health_check_custom_config,
service_type=service_type,
)
self.services[service.id] = service
if tags:
self.tagger.tag_resource(service.arn, tags)
return service
def get_service(self, service_id: str) -> Service:
if service_id not in self.services:
raise ServiceNotFound(service_id)
return self.services[service_id]
def delete_service(self, service_id: str) -> None:
self.services.pop(service_id, None)
def list_services(self) -> Iterable[Service]:
"""
Pagination or the Filters-argument is not yet implemented
"""
return self.services.values()
def update_service(self, service_id: str, details: Dict[str, Any]) -> str:
service = self.get_service(service_id)
service.update(details=details)
operation_id = self._create_operation(
"UPDATE_SERVICE", targets={"SERVICE": service.id}
)
return operation_id
def update_http_namespace(
self,
_id: str,
namespace_dict: Dict[str, Any],
updater_request_id: Optional[str] = None,
) -> str:
if "Description" not in namespace_dict:
raise InvalidInput("Description is required")
namespace = self.get_namespace(namespace_id=_id)
if updater_request_id is None:
# Unused as the operation cannot fail
updater_request_id = random_id(32)
namespace.description = namespace_dict["Description"]
if "Properties" in namespace_dict:
if "HttpProperties" in namespace_dict["Properties"]:
namespace.http_properties = namespace_dict["Properties"][
"HttpProperties"
]
operation_id = self._create_operation(
"UPDATE_NAMESPACE", targets={"NAMESPACE": namespace.id}
)
return operation_id
def update_private_dns_namespace(
self, _id: str, description: str, properties: Dict[str, Any]
) -> str:
namespace = self.get_namespace(namespace_id=_id)
if description is not None:
namespace.description = description
if properties is not None:
namespace.dns_properties = properties
operation_id = self._create_operation(
"UPDATE_NAMESPACE", targets={"NAMESPACE": namespace.id}
)
return operation_id
def update_public_dns_namespace(
self, _id: str, description: str, properties: Dict[str, Any]
) -> str:
namespace = self.get_namespace(namespace_id=_id)
if description is not None:
namespace.description = description
if properties is not None:
namespace.dns_properties = properties
operation_id = self._create_operation(
"UPDATE_NAMESPACE", targets={"NAMESPACE": namespace.id}
)
return operation_id
def register_instance(
self,
service_id: str,
instance_id: str,
creator_request_id: str,
attributes: Dict[str, str],
) -> str:
service = self.get_service(service_id)
instance = ServiceInstance(
service_id=service_id,
instance_id=instance_id,
creator_request_id=creator_request_id,
attributes=attributes,
)
service.instances.append(instance)
service.instances_revision[instance_id] = (
service.instances_revision.get(instance_id, 0) + 1
)
operation_id = self._create_operation(
"REGISTER_INSTANCE", targets={"INSTANCE": instance_id}
)
return operation_id
def deregister_instance(self, service_id: str, instance_id: str) -> str:
service = self.get_service(service_id)
i = 0
while i < len(service.instances):
instance = service.instances[i]
if instance.instance_id == instance_id:
service.instances.remove(instance)
service.instances_revision[instance_id] = (
service.instances_revision.get(instance_id, 0) + 1
)
operation_id = self._create_operation(
"DEREGISTER_INSTANCE", targets={"INSTANCE": instance_id}
)
return operation_id
i += 1
raise InstanceNotFound(instance_id)
def list_instances(self, service_id: str) -> List[ServiceInstance]:
service = self.get_service(service_id)
return service.instances
def get_instance(self, service_id: str, instance_id: str) -> ServiceInstance:
for instance in self.list_instances(service_id):
if instance.instance_id == instance_id:
return instance
raise InstanceNotFound(instance_id)
def get_instances_health_status(
self,
service_id: str,
instances: Optional[List[str]] = None,
) -> List[Tuple[str, str]]:
service = self.get_service(service_id)
status = []
if instances is None:
instances = [instance.instance_id for instance in service.instances]
if not isinstance(instances, list):
raise InvalidInput("Instances must be a list")
filtered_instances = [
instance
for instance in service.instances
if instance.instance_id in instances
]
for instance in filtered_instances:
status.append((instance.instance_id, instance.health_status))
return status
def update_instance_custom_health_status(
self, service_id: str, instance_id: str, status: str
) -> None:
if status not in ["HEALTHY", "UNHEALTHY"]:
raise CustomHealthNotFound(service_id)
instance = self.get_instance(service_id, instance_id)
instance.health_status = status
def _filter_instances(
self,
instances: List[ServiceInstance],
query_parameters: Optional[Dict[str, str]] = None,
optional_parameters: Optional[Dict[str, str]] = None,
health_status: Optional[str] = None,
) -> List[ServiceInstance]:
if query_parameters is None:
query_parameters = {}
if optional_parameters is None:
optional_parameters = {}
if health_status is None:
health_status = "ALL"
filtered_instances = []
has_healthy = False
for instance in instances:
# Filter out instances with mismatching health status
if (
health_status not in ["ALL", "HEALTHY_OR_ELSE_ALL"]
and instance.health_status != health_status
):
continue
# Record if there is at least one healthy instance for HEALTHY_OR_ELSE_ALL
if instance.health_status == "HEALTHY":
has_healthy = True
# Filter out instances with mismatching query parameters
matches_query = True
for param in query_parameters:
if instance.attributes.get(param) != query_parameters[param]:
matches_query = False
break
if not matches_query:
continue
# Add instance to the list if it passed all filters
filtered_instances.append(instance)
# Handle HEALTHY_OR_ELSE_ALL
if has_healthy and health_status == "HEALTHY_OR_ELSE_ALL":
filtered_instances = [
instance
for instance in filtered_instances
if instance.health_status == "HEALTHY"
]
# Filter out instances with mismatching optional parameters
opt_filtered_instances = []
for instance in filtered_instances:
matches_optional = True
for param in optional_parameters:
if instance.attributes.get(param) != optional_parameters[param]:
matches_optional = False
break
if matches_optional:
opt_filtered_instances.append(instance)
# If no instances passed the optional parameters, return the original filtered list
return opt_filtered_instances if opt_filtered_instances else filtered_instances
def discover_instances(
self,
namespace_name: str,
service_name: str,
query_parameters: Optional[Dict[str, str]] = None,
optional_parameters: Optional[Dict[str, str]] = None,
health_status: Optional[str] = None,
) -> Tuple[List[ServiceInstance], Dict[str, int]]:
if query_parameters is None:
query_parameters = {}
if optional_parameters is None:
optional_parameters = {}
if health_status is None:
health_status = "ALL"
if health_status not in ["HEALTHY", "UNHEALTHY", "ALL", "HEALTHY_OR_ELSE_ALL"]:
raise InvalidInput("Invalid health status")
try:
namespace = [
ns for ns in self.list_namespaces() if ns.name == namespace_name
][0]
except IndexError:
raise NamespaceNotFound(namespace_name)
try:
service = [
srv
for srv in self.list_services()
if srv.name == service_name and srv.namespace_id == namespace.id
][0]
except IndexError:
raise ServiceNotFound(service_name)
instances = self.list_instances(service.id)
# Filter instances based on query parameters, optional parameters, and health status
final_instances = self._filter_instances(
instances, query_parameters, optional_parameters, health_status
)
# Get the revision number for each instance that passed the filters
instance_revisions = {
instance.instance_id: service.instances_revision.get(
instance.instance_id, 0
)
for instance in final_instances
}
return final_instances, instance_revisions
def discover_instances_revision(
self, namespace_name: str, service_name: str
) -> int:
return sum(self.discover_instances(namespace_name, service_name)[1].values())
def paginate(
self,
items: List[Any],
max_results: Optional[int] = None,
next_token: Optional[str] = None,
) -> Tuple[List[Any], Optional[str]]:
"""
Paginates a list of items. If called without optional parameters, the entire list is returned as-is.
"""
# Default to beginning of list
if next_token is None:
next_token = "0"
# Return empty list if next_token is invalid
if not next_token.isdigit():
return [], None
# Default to the entire list
if max_results is None:
max_results = len(items)
new_token = int(next_token) + max_results
# If the new token overflows the list, return the rest of the list
if new_token >= len(items):
return items[int(next_token) :], None
return items[int(next_token) : new_token], str(new_token)
servicediscovery_backends = BackendDict(ServiceDiscoveryBackend, "servicediscovery")