"""OpenSearchServiceServerlessBackend class with methods for supported APIs."""
import json
from typing import Any, Dict, List, Tuple
from moto.core.base_backend import BackendDict, BaseBackend
from moto.core.common_models import BaseModel
from moto.core.utils import unix_time
from moto.moto_api._internal import mock_random
from moto.utilities.tagging_service import TaggingService
from .exceptions import (
ConflictException,
ResourceNotFoundException,
ValidationException,
)
class SecurityPolicy(BaseModel):
    """In-memory model of a security policy of type ``encryption`` or ``network``.

    The policy document arrives as a JSON string and is stored parsed. The
    resource patterns it names are flattened into ``self.resources`` so that
    callers can filter policies by resource.
    """

    def __init__(
        self,
        client_token: str,
        description: str,
        name: str,
        policy: str,
        type: str,
    ):
        """Build the policy and derive its flattened resource list.

        :param client_token: idempotency token supplied with (or generated for)
            the request.
        :param description: free-text description; may be empty.
        :param name: policy name.
        :param policy: policy document as a JSON string.
        :param type: policy type; decides how the document is traversed.
        :raises KeyError: if the parsed document lacks the ``Rules`` /
            ``Resource`` keys that the traversal below reads.
        """
        self.client_token = client_token
        self.description = description
        self.name = name
        self.type = type
        self.created_date = int(unix_time() * 1000)
        # A freshly created policy has never been modified, so both timestamps
        # start out identical.
        self.last_modified_date = self.created_date
        self.policy = json.loads(policy)
        self.policy_version = mock_random.get_random_string(20)
        if type == "encryption":
            # Encryption documents are a single object holding "Rules".
            self.resources = [
                res for rule in self.policy["Rules"] for res in rule["Resource"]
            ]
        else:
            # Every other type is treated as a list of objects, each holding
            # its own "Rules".
            self.resources = [
                res
                for p in self.policy
                for rule in p["Rules"]
                for res in rule["Resource"]
            ]

    def to_dict(self) -> Dict[str, Any]:
        """Return the full policy detail, omitting falsy fields."""
        dct = {
            "createdDate": self.created_date,
            "description": self.description,
            "lastModifiedDate": self.last_modified_date,
            "name": self.name,
            "policy": self.policy,
            "policyVersion": self.policy_version,
            "type": self.type,
        }
        return {k: v for k, v in dct.items() if v}

    def to_dict_list(self) -> Dict[str, Any]:
        """Return the summary form: the detail dict without the policy document."""
        dct = self.to_dict()
        # to_dict() has already dropped falsy values, so "policy" is absent when
        # the document is empty; pop with a default instead of raising KeyError.
        dct.pop("policy", None)
        return dct
class Collection(BaseModel):
    """In-memory model of a collection.

    The collection is created with status ``ACTIVE`` and gets a random
    lower-case 20-character id, from which its ARN, collection endpoint and
    dashboard endpoint are derived.
    """

    def __init__(
        self,
        client_token: str,
        description: str,
        name: str,
        standby_replicas: str,
        tags: List[Dict[str, str]],
        type: str,
        policy: Any,
        region: str,
        account_id: str,
    ):
        """Build the collection.

        :param client_token: idempotency token for the request.
        :param description: free-text description; may be empty.
        :param name: collection name.
        :param standby_replicas: standby replica setting, stored as given.
        :param tags: tags supplied at creation, stored as given.
        :param type: collection type, stored as given.
        :param policy: parsed policy document covering this collection; its
            ``KmsARN`` entry, when present, becomes ``kms_key_arn``.
        :param region: region used in the ARN and endpoints.
        :param account_id: account id used in the ARN.
        """
        self.client_token = client_token
        self.description = description
        self.name = name
        self.standby_replicas = standby_replicas
        self.tags = tags
        self.type = type
        self.id = mock_random.get_random_string(length=20, lower_case=True)
        self.arn = f"arn:aws:aoss:{region}:{account_id}:collection/{self.id}"
        self.created_date = int(unix_time() * 1000)
        # A policy document may carry no "KmsARN" key at all (presumably the
        # AWS-owned-key case; confirm against the service docs). Tolerate that,
        # and a non-dict document, instead of failing collection creation;
        # to_dict() omits the field when it is None.
        self.kms_key_arn = policy.get("KmsARN") if isinstance(policy, dict) else None
        self.last_modified_date = int(unix_time() * 1000)
        self.status = "ACTIVE"
        self.collection_endpoint = f"https://{self.id}.{region}.aoss.amazonaws.com"
        self.dashboard_endpoint = (
            f"https://{self.id}.{region}.aoss.amazonaws.com/_dashboards"
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return the collection detail, omitting falsy fields."""
        dct = {
            "arn": self.arn,
            "createdDate": self.created_date,
            "description": self.description,
            "id": self.id,
            "kmsKeyArn": self.kms_key_arn,
            "lastModifiedDate": self.last_modified_date,
            "name": self.name,
            "standbyReplicas": self.standby_replicas,
            "status": self.status,
            "type": self.type,
        }
        return {k: v for k, v in dct.items() if v}

    def to_dict_list(self) -> Dict[str, Any]:
        """Return the summary form (arn, id, name, status), omitting falsy fields."""
        dct = {"arn": self.arn, "id": self.id, "name": self.name, "status": self.status}
        return {k: v for k, v in dct.items() if v}

    def to_dict_batch(self) -> Dict[str, Any]:
        """Return the detail dict extended with the two endpoint URLs."""
        dct = self.to_dict()
        dct_options = {
            "collectionEndpoint": self.collection_endpoint,
            "dashboardEndpoint": self.dashboard_endpoint,
        }
        for key, value in dct_options.items():
            if value is not None:
                dct[key] = value
        return dct
class OSEndpoint(BaseModel):
    """In-memory model of a VPC endpoint; created with status ``ACTIVE``."""

    def __init__(
        self,
        client_token: str,
        name: str,
        security_group_ids: List[str],
        subnet_ids: List[str],
        vpc_id: str,
    ):
        """Store the request values and generate a ``vpce-0...`` identifier."""
        # Request-supplied values are stored untouched.
        self.client_token = client_token
        self.name = name
        self.security_group_ids = security_group_ids
        self.subnet_ids = subnet_ids
        self.vpc_id = vpc_id

        # Generated values: "vpce-0" followed by 16 random lower-case characters.
        random_suffix = mock_random.get_random_string(length=16, lower_case=True)
        self.id = f"vpce-0{random_suffix}"
        self.status = "ACTIVE"

    def to_dict(self) -> Dict[str, Any]:
        """Return id, name and status, leaving out any that are falsy."""
        fields = (("id", self.id), ("name", self.name), ("status", self.status))
        summary: Dict[str, Any] = {}
        for key, value in fields:
            if value:
                summary[key] = value
        return summary
class OpenSearchServiceServerlessBackend(BaseBackend):
    """Implementation of OpenSearchServiceServerless APIs."""

    def __init__(self, region_name: str, account_id: str):
        """Create empty stores for collections, security policies and endpoints."""
        super().__init__(region_name, account_id)
        # Keyed by collection id.
        self.collections: Dict[str, Collection] = dict()
        # Keyed by "<type>/<name>", the pair that must be unique.
        self.security_policies: Dict[str, SecurityPolicy] = dict()
        # Keyed by endpoint id.
        self.os_endpoints: Dict[str, OSEndpoint] = dict()
        self.tagger = TaggingService(
            tag_name="tags", key_name="key", value_name="value"
        )

    @staticmethod
    def _extract_policy_resources(policy: Any, type: str) -> List[str]:
        """Flatten the resource patterns of a parsed policy document.

        Mirrors SecurityPolicy.__init__: an encryption document is a single
        object holding "Rules"; any other type is a list of such objects.
        """
        statements = [policy] if type == "encryption" else policy
        return [
            res
            for statement in statements
            for rule in statement["Rules"]
            for res in rule["Resource"]
        ]

    @staticmethod
    def _collection_not_found_error(key: str, value: str) -> Dict[str, str]:
        """Build a fresh NOT_FOUND error entry for one requested id or name."""
        return {
            "errorCode": "NOT_FOUND",
            "errorMessage": "The specified Collection is not found.",
            key: value,
        }

    def create_security_policy(
        self, client_token: str, description: str, name: str, policy: str, type: str
    ) -> SecurityPolicy:
        """Create and store a security policy.

        :raises ConflictException: a policy with the same name and type exists.
        :raises ValidationException: ``type`` is not ``encryption`` or ``network``.
        """
        if not client_token:
            client_token = mock_random.get_random_string(10)
        if any(
            sp.name == name and sp.type == type
            for sp in self.security_policies.values()
        ):
            raise ConflictException(
                msg=f"Policy with name {name} and type {type} already exists"
            )
        if type not in ["encryption", "network"]:
            raise ValidationException(
                msg=f"1 validation error detected: Value '{type}' at 'type' failed to satisfy constraint: Member must satisfy enum value set: [encryption, network]"
            )
        security_policy = SecurityPolicy(
            client_token=client_token,
            description=description,
            name=name,
            policy=policy,
            type=type,
        )
        # Key on the policy's identity, not on the client token: two requests
        # that reuse a token would otherwise silently overwrite each other.
        self.security_policies[f"{type}/{name}"] = security_policy
        return security_policy

    def get_security_policy(self, name: str, type: str) -> SecurityPolicy:
        """Return the policy with this name and type.

        :raises ResourceNotFoundException: no such policy exists.
        """
        for sp in self.security_policies.values():
            if sp.name == name and sp.type == type:
                return sp
        raise ResourceNotFoundException(
            msg=f"Policy with name {name} and type {type} is not found"
        )

    def list_security_policies(
        self, resource: List[str], type: str
    ) -> List[SecurityPolicy]:
        """
        Pagination is not yet implemented

        Return the policies of ``type``. When ``resource`` is non-empty, only
        policies naming at least one of those resources are returned, each
        policy at most once.
        """
        of_type = [sp for sp in self.security_policies.values() if sp.type == type]
        if not resource:
            return of_type
        matched: List[SecurityPolicy] = []
        seen = set()
        for res in resource:
            for sp in of_type:
                # A policy can match several requested resources; report it once.
                if res in sp.resources and id(sp) not in seen:
                    seen.add(id(sp))
                    matched.append(sp)
        return matched

    def update_security_policy(
        self,
        client_token: str,
        description: str,
        name: str,
        policy: str,
        policy_version: str,
        type: str,
    ) -> SecurityPolicy:
        """Update the policy identified by ``name`` and ``type``.

        The modification date and policy version change only when the policy
        document itself differs from the stored one.

        :raises ResourceNotFoundException: no such policy exists.
        :raises ValidationException: ``policy_version`` is not the current version.
        """
        if not client_token:
            client_token = mock_random.get_random_string(10)
        sp = self.get_security_policy(name, type)
        if sp.policy_version != policy_version:
            raise ValidationException(
                msg="Policy version specified in the request refers to an older version and policy has since changed"
            )
        # A request may omit the policy document (e.g. a description-only
        # update); keep the stored document in that case.
        new_policy = json.loads(policy) if policy else sp.policy
        if sp.policy != new_policy:
            sp.last_modified_date = int(unix_time() * 1000)
            sp.policy_version = mock_random.get_random_string(20)
            sp.policy = new_policy
            # Keep the flattened resource list in step with the new document so
            # that resource-based filtering and collection creation see it.
            sp.resources = self._extract_policy_resources(new_policy, sp.type)
        sp.client_token = client_token
        sp.description = description
        return sp

    def create_collection(
        self,
        client_token: str,
        description: str,
        name: str,
        standby_replicas: str,
        tags: List[Dict[str, str]],
        type: str,
    ) -> Collection:
        """Create a collection covered by an existing encryption policy.

        :raises ValidationException: no encryption policy names
            ``collection/<name>`` among its resources.
        """
        policy = ""
        if not client_token:
            client_token = mock_random.get_random_string(10)
        for sp in self.security_policies.values():
            # Only encryption policies qualify; a network policy naming the
            # same resource has a different shape and carries no KMS settings.
            if sp.type == "encryption" and f"collection/{name}" in sp.resources:
                policy = sp.policy
        if not policy:
            raise ValidationException(
                msg=f"No matching security policy of encryption type found for collection name: {name}. Please create security policy of encryption type for this collection."
            )
        collection = Collection(
            client_token=client_token,
            description=description,
            name=name,
            standby_replicas=standby_replicas,
            tags=tags,
            type=type,
            policy=policy,
            region=self.region_name,
            account_id=self.account_id,
        )
        self.collections[collection.id] = collection
        self.tag_resource(collection.arn, tags)
        return collection

    def list_collections(self, collection_filters: Dict[str, str]) -> List[Collection]:
        """
        Pagination is not yet implemented

        Return all collections, or only those whose name equals
        ``collection_filters["name"]`` when that filter is given.
        """
        if collection_filters and "name" in collection_filters:
            wanted = collection_filters["name"]
            return [c for c in self.collections.values() if c.name == wanted]
        return list(self.collections.values())

    def create_vpc_endpoint(
        self,
        client_token: str,
        name: str,
        security_group_ids: List[str],
        subnet_ids: List[str],
        vpc_id: str,
    ) -> OSEndpoint:
        """Create a VPC endpoint.

        :raises ConflictException: an endpoint already exists for ``vpc_id``.
        """
        if not client_token:
            client_token = mock_random.get_random_string(10)
        # Only one endpoint may exist under each VPC.
        if any(ose.vpc_id == vpc_id for ose in self.os_endpoints.values()):
            raise ConflictException(
                msg=f"Failed to create a VpcEndpoint {name} for AccountId {self.account_id} :: There is already a VpcEndpoint exist under VpcId {vpc_id}"
            )
        os_endpoint = OSEndpoint(
            client_token=client_token,
            name=name,
            security_group_ids=security_group_ids,
            subnet_ids=subnet_ids,
            vpc_id=vpc_id,
        )
        # Key on the generated endpoint id so that a reused client token cannot
        # overwrite an earlier endpoint.
        self.os_endpoints[os_endpoint.id] = os_endpoint
        return os_endpoint

    def delete_collection(self, client_token: str, id: str) -> Collection:
        """Remove the collection and return it with status ``DELETING``.

        ``client_token`` is accepted for interface compatibility and is not used.

        :raises ResourceNotFoundException: no collection has this id.
        """
        if id not in self.collections:
            raise ResourceNotFoundException(
                msg=f"Collection with ID {id} cannot be found."
            )
        collection = self.collections.pop(id)
        collection.status = "DELETING"
        return collection

    def tag_resource(self, resource_arn: str, tags: List[Dict[str, str]]) -> None:
        """Attach ``tags`` to ``resource_arn`` via the tagging service."""
        self.tagger.tag_resource(resource_arn, tags)

    def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None:
        """Remove the tags named in ``tag_keys`` from ``resource_arn``."""
        self.tagger.untag_resource_using_names(resource_arn, tag_keys)

    def list_tags_for_resource(self, resource_arn: str) -> List[Dict[str, str]]:
        """Return the tags recorded for ``resource_arn``."""
        return self.tagger.list_tags_for_resource(resource_arn)["tags"]

    def batch_get_collection(
        self, ids: List[str], names: List[str]
    ) -> Tuple[List[Any], List[Dict[str, str]]]:
        """Look up collections by id or by name.

        :returns: ``(details, errors)``. ``details`` holds the batch dict of
            every collection found; ``errors`` holds one NOT_FOUND entry per
            requested id or name that matched nothing.
        :raises ValidationException: both ``ids`` and ``names`` were supplied.
        """
        if ids and names:
            raise ValidationException(
                msg="You need to provide IDs or names. You can't provide both IDs and names in the same request"
            )
        collection_details: List[Any] = []
        collection_error_details: List[Dict[str, str]] = []
        for collection_id in ids or []:
            collection = self.collections.get(collection_id)
            if collection is not None:
                collection_details.append(collection.to_dict_batch())
            else:
                # A new dict per miss: sharing one mutable dict would make every
                # error entry show the last requested id.
                collection_error_details.append(
                    self._collection_not_found_error("id", collection_id)
                )
        for collection_name in names or []:
            matches = [
                c for c in self.collections.values() if c.name == collection_name
            ]
            if matches:
                collection_details.extend(c.to_dict_batch() for c in matches)
            else:
                # Exactly one error per unmatched name, including when no
                # collections exist at all.
                collection_error_details.append(
                    self._collection_not_found_error("name", collection_name)
                )
        return collection_details, collection_error_details
# Module-level backend registry for the "opensearchserverless" service.
# NOTE(review): presumably BackendDict creates one
# OpenSearchServiceServerlessBackend per account/region on first access;
# confirm against moto.core.base_backend.
opensearchserverless_backends = BackendDict(
OpenSearchServiceServerlessBackend, "opensearchserverless"
)