"""TransferBackend class with methods for supported APIs.""" from typing import Any, Dict, List, Optional, Tuple from moto.core.base_backend import BackendDict, BaseBackend from moto.core.utils import unix_time from moto.transfer.exceptions import PublicKeyNotFound, ServerNotFound, UserNotFound from .types import ( Server, ServerDomain, ServerEndpointType, ServerIdentityProviderType, ServerProtocols, User, UserHomeDirectoryType, ) class TransferBackend(BaseBackend): """Implementation of Transfer APIs.""" def __init__(self, region_name: str, account_id: str) -> None: super().__init__(region_name, account_id) self.servers: Dict[str, Server] = {} def create_server( self, certificate: Optional[str], domain: Optional[ServerDomain], endpoint_details: Optional[Dict[str, Any]], endpoint_type: Optional[ServerEndpointType], host_key: str, identity_provider_details: Optional[Dict[str, Any]], identity_provider_type: Optional[ServerIdentityProviderType], logging_role: Optional[str], post_authentication_login_banner: Optional[str], pre_authentication_login_banner: Optional[str], protocols: Optional[List[ServerProtocols]], protocol_details: Optional[Dict[str, Any]], security_policy_name: Optional[str], tags: Optional[List[Dict[str, str]]], workflow_details: Optional[Dict[str, Any]], structured_log_destinations: Optional[List[str]], s3_storage_options: Optional[Dict[str, Optional[str]]], ) -> str: server = Server( certificate=certificate, domain=domain, endpoint_type=endpoint_type, host_key_fingerprint=host_key, identity_provider_type=identity_provider_type, logging_role=logging_role, post_authentication_login_banner=post_authentication_login_banner, pre_authentication_login_banner=pre_authentication_login_banner, protocols=protocols, security_policy_name=security_policy_name, structured_log_destinations=structured_log_destinations, tags=(tags or []), ) if endpoint_details is not None: endpoint_details = { "address_allocation_ids": endpoint_details.get("AddressAllocationIds"), "subnet_ids": endpoint_details.get("SubnetIds"), "vpc_endpoint_id": endpoint_details.get("VpcEndpointId"), "vpc_id": endpoint_details.get("VpcId"), "security_group_ids": endpoint_details.get("SecurityGroupIds"), } server.endpoint_details = endpoint_details if identity_provider_details is not None: identity_provider_details = { "url": identity_provider_details.get("Url"), "invocation_role": identity_provider_details.get("InvocationRole"), "directory_id": identity_provider_details.get("DirectoryId"), "function": identity_provider_details.get("Function"), "sftp_authentication_methods": identity_provider_details.get( "SftpAuthenticationMethods" ), } server.identity_provider_details = identity_provider_details if protocol_details is not None: protocol_details = { "passive_ip": protocol_details.get("PassiveIp"), "tls_session_resumption_mode": protocol_details.get( "TlsSessionResumptionMode" ), "set_stat_option": protocol_details.get("SetStatOption"), "as2_transports": protocol_details.get("As2Transports"), } server.protocol_details = protocol_details if s3_storage_options is not None: server.s3_storage_options = { "directory_listing_optimization": s3_storage_options.get( "DirectoryListingOptimization" ) } if workflow_details is not None: server.workflow_details = { "on_upload": [ { "workflow_id": workflow.get("WorkflowId"), "execution_role": workflow.get("ExecutionRole"), } for workflow in (workflow_details.get("OnUpload") or []) ], "on_partial_upload": [ { "workflow_id": workflow.get("WorkflowId"), "execution_role": workflow.get("ExecutionRole"), } for workflow in (workflow_details.get("OnPartialUpload") or []) ], } server_id = server.server_id self.servers[server_id] = server return server_id def describe_server(self, server_id: str) -> Server: if server_id not in self.servers: ServerNotFound(server_id=server_id) server = self.servers[server_id] return server def delete_server(self, server_id: str) -> None: if server_id not in self.servers: ServerNotFound(server_id=server_id) del self.servers[server_id] return def create_user( self, home_directory: Optional[str], home_directory_type: Optional[UserHomeDirectoryType], home_directory_mappings: Optional[List[Dict[str, Optional[str]]]], policy: Optional[str], posix_profile: Optional[Dict[str, Any]], role: str, server_id: str, ssh_public_key_body: Optional[str], tags: Optional[List[Dict[str, str]]], user_name: str, ) -> Tuple[str, str]: if server_id not in self.servers: ServerNotFound(server_id=server_id) user = User( home_directory=home_directory, home_directory_type=home_directory_type, policy=policy, role=role, tags=(tags or []), user_name=user_name, ) if home_directory_mappings: for mapping in home_directory_mappings: user.home_directory_mappings.append( { "entry": mapping.get("Entry"), "target": mapping.get("Target"), "type": mapping.get("Type"), } ) if posix_profile is not None: posix_profile = { "gid": posix_profile.get("Gid"), "uid": posix_profile.get("Uid"), "secondary_gids": posix_profile.get("SecondaryGids"), } user.posix_profile = posix_profile if ssh_public_key_body is not None: now = unix_time() ssh_public_keys = [ { "date_imported": str(now), "ssh_public_key_body": ssh_public_key_body, "ssh_public_key_id": "mock_ssh_public_key_id_{ssh_public_key_body}_{now}", } ] user.ssh_public_keys = ssh_public_keys self.servers[server_id]._users.append(user) self.servers[server_id].user_count += 1 return server_id, user_name def describe_user(self, server_id: str, user_name: str) -> Tuple[str, User]: if server_id not in self.servers: raise ServerNotFound(server_id=server_id) for user in self.servers[server_id]._users: if user.user_name == user_name: return server_id, user raise UserNotFound(user_name=user_name, server_id=server_id) def delete_user(self, server_id: str, user_name: str) -> None: if server_id not in self.servers: raise ServerNotFound(server_id=server_id) for i, user in enumerate(self.servers[server_id]._users): if user.user_name == user_name: del self.servers[server_id]._users[i] self.servers[server_id].user_count -= 1 return raise UserNotFound(server_id=server_id, user_name=user_name) def import_ssh_public_key( self, server_id: str, ssh_public_key_body: str, user_name: str ) -> Tuple[str, str, str]: if server_id not in self.servers: raise ServerNotFound(server_id=server_id) for user in self.servers[server_id]._users: if user.user_name == user_name: date_imported = unix_time() ssh_public_key_id = ( f"{server_id}:{user_name}:public_key:{date_imported}" ) key = { "ssh_public_key_id": ssh_public_key_id, "ssh_public_key_body": ssh_public_key_body, "date_imported": str(date_imported), } user.ssh_public_keys.append(key) return server_id, ssh_public_key_id, user_name raise UserNotFound(user_name=user_name, server_id=server_id) def delete_ssh_public_key( self, server_id: str, ssh_public_key_id: str, user_name: str ) -> None: if server_id not in self.servers: raise ServerNotFound(server_id=server_id) for i, user in enumerate(self.servers[server_id]._users): if user.user_name == user_name: for j, key in enumerate( self.servers[server_id]._users[i].ssh_public_keys ): if key["ssh_public_key_id"] == ssh_public_key_id: del user.ssh_public_keys[j] return raise PublicKeyNotFound( user_name=user_name, server_id=server_id, ssh_public_key_id=ssh_public_key_id, ) raise UserNotFound(user_name=user_name, server_id=server_id) transfer_backends = BackendDict(TransferBackend, "transfer")
Memory