from typing import Any, Dict, List, Optional

from moto.core.utils import iso_8601_datetime_with_milliseconds, utcnow

from ..exceptions import (
    InvalidKeyPairDuplicateError,
    InvalidKeyPairFormatError,
    InvalidKeyPairNameError,
)
from ..utils import (
    generic_filter,
    public_key_fingerprint,
    public_key_parse,
    random_ed25519_key_pair,
    random_key_pair_id,
    random_rsa_key_pair,
    select_hash_algorithm,
)
from .core import TaggedEC2Resource


class KeyPair(TaggedEC2Resource):
    """A single EC2 key pair held by the mocked backend.

    Stores the key name, a generated id, the public-key fingerprint, the
    private key material (``None`` for imported keys), the OpenSSH-format
    public key, a creation timestamp and any tags.
    """

    def __init__(
        self,
        name: str,
        fingerprint: str,
        material: Optional[str],
        material_public: str,
        tags: Dict[str, str],
        ec2_backend: Any,
    ):
        self.id = random_key_pair_id()
        self.name = name
        self.fingerprint = fingerprint  # public key fingerprint
        self.material = material  # PEM encoded private key
        self.material_public = material_public  # public key in OpenSSH format
        self.create_time = utcnow()
        self.ec2_backend = ec2_backend
        # Tolerate tags=None from callers.
        self.add_tags(tags or {})

    @property
    def created_iso_8601(self) -> str:
        """Return the creation time as an ISO 8601 string with milliseconds."""
        return iso_8601_datetime_with_milliseconds(self.create_time)

    def get_filter_value(
        self, filter_name: str, method_name: Optional[str] = None
    ) -> str:
        """Return this key pair's value for a DescribeKeyPairs filter.

        ``key-name`` and ``fingerprint`` are answered directly. Any other
        filter name is delegated to ``TaggedEC2Resource.get_filter_value``
        with the method name ``"DescribeKeyPairs"``. The base class presumably
        handles tag filters and rejects unsupported names; verify against
        ``.core``.
        """
        if filter_name == "key-name":
            return self.name
        elif filter_name == "fingerprint":
            return self.fingerprint
        else:
            return super().get_filter_value(filter_name, "DescribeKeyPairs")


class KeyPairBackend:
    """Backend mixin storing key pairs in memory, keyed by key name."""

    def __init__(self) -> None:
        self.keypairs: Dict[str, KeyPair] = {}

    def create_key_pair(
        self, name: str, key_type: str, tags: Dict[str, str]
    ) -> KeyPair:
        """Generate and store a new key pair called *name*.

        ``key_type == "ed25519"`` generates an Ed25519 key. Any other value
        generates an RSA key.

        Raises:
            InvalidKeyPairDuplicateError: a key pair named *name* exists.
        """
        if name in self.keypairs:
            raise InvalidKeyPairDuplicateError(name)
        # The random_* helpers return the fingerprint/material/material_public
        # keyword arguments that KeyPair expects.
        if key_type == "ed25519":
            keypair = KeyPair(
                name, **random_ed25519_key_pair(), tags=tags, ec2_backend=self
            )
        else:
            keypair = KeyPair(
                name, **random_rsa_key_pair(), tags=tags, ec2_backend=self
            )
        self.keypairs[name] = keypair
        return keypair

    def delete_key_pair(self, name: str) -> None:
        """Remove the key pair called *name*; unknown names are ignored."""
        self.keypairs.pop(name, None)

    def describe_key_pairs(
        self, key_names: List[str], filters: Any = None
    ) -> List[KeyPair]:
        """Return stored key pairs, optionally restricted by name and filters.

        If *key_names* contains at least one non-empty name, only key pairs
        with those names are returned. Every requested name must exist.
        Otherwise all key pairs are considered. *filters*, when truthy, is
        then applied with ``generic_filter``.

        Raises:
            InvalidKeyPairNameError: with the set of requested names that do
                not match any stored key pair.
        """
        if any(key_names):
            wanted = set(key_names)
            results = [
                keypair
                for keypair in self.keypairs.values()
                if keypair.name in wanted
            ]
            # Compare names with names (not KeyPair objects). Only report
            # names that are truly missing. Duplicated names in the request
            # must not trigger a spurious error.
            unknown_keys = wanted - {keypair.name for keypair in results}
            if unknown_keys:
                raise InvalidKeyPairNameError(unknown_keys)
        else:
            results = list(self.keypairs.values())

        if filters:
            return generic_filter(filters, results)
        else:
            return results

    def import_key_pair(
        self, key_name: str, public_key_material: str, tags: Dict[str, str]
    ) -> KeyPair:
        """Store a key pair built from caller-supplied public key material.

        The fingerprint is computed from the parsed public key, using the
        hash algorithm chosen by ``select_hash_algorithm``. No private key
        material is stored (``material=None``).

        Raises:
            InvalidKeyPairDuplicateError: a key pair named *key_name* exists.
            InvalidKeyPairFormatError: the public key material cannot be
                parsed (``public_key_parse`` raised ``ValueError``).
        """
        if key_name in self.keypairs:
            raise InvalidKeyPairDuplicateError(key_name)

        try:
            public_key = public_key_parse(public_key_material)
        except ValueError:
            raise InvalidKeyPairFormatError()

        hash_constructor = select_hash_algorithm(public_key)
        fingerprint = public_key_fingerprint(public_key, hash_constructor)
        keypair = KeyPair(
            key_name,
            material_public=public_key_material,
            material=None,
            fingerprint=fingerprint,
            tags=tags,
            ec2_backend=self,
        )
        self.keypairs[key_name] = keypair
        return keypair
Memory