"""DirectConnectBackend class with methods for supported APIs."""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from moto.core.base_backend import BackendDict, BaseBackend
from moto.core.common_models import BaseModel
from .enums import (
ConnectionStateType,
EncryptionModeType,
LagStateType,
MacSecKeyStateType,
PortEncryptionStatusType,
)
from .exceptions import (
ConnectionIdMissing,
ConnectionNotFound,
LAGNotFound,
MacSecKeyNotFound,
)
@dataclass
class MacSecKey(BaseModel):
    """A MACsec key attached to a Direct Connect connection or LAG.

    ``cak`` is accepted and stored, but ``to_dict`` never emits it.
    """

    secret_arn: Optional[str]
    ckn: Optional[str]
    state: MacSecKeyStateType
    start_on: str
    cak: Optional[str] = None

    def to_dict(self) -> Dict[str, str]:
        """Return a camelCase dict; a missing secret ARN or CKN is rendered as ""."""
        arn_value = self.secret_arn if self.secret_arn else ""
        ckn_value = self.ckn if self.ckn else ""
        return dict(
            secretARN=arn_value,
            ckn=ckn_value,
            state=self.state,
            startOn=self.start_on,
        )
@dataclass
class Connection(BaseModel):
    """In-memory model of a Direct Connect connection.

    ``connection_id`` is not an ``__init__`` argument: ``__post_init__`` derives
    it from ``connection_name`` plus the current local time at one-second
    resolution.
    """

    aws_device_v2: str
    aws_device: str
    aws_logical_device_id: str
    bandwidth: str
    connection_name: str
    connection_state: ConnectionStateType
    encryption_mode: EncryptionModeType
    has_logical_redundancy: bool
    jumbo_frame_capable: bool
    lag_id: Optional[str]  # ID of the LAG this connection belongs to, if any
    loa_issue_time: str
    location: str
    mac_sec_capable: Optional[bool]
    mac_sec_keys: List[MacSecKey]
    owner_account: str
    partner_name: str
    port_encryption_status: PortEncryptionStatusType
    provider_name: Optional[str]
    region: str
    tags: Optional[List[Dict[str, str]]]
    vlan: int
    connection_id: str = field(default="", init=False)

    def __post_init__(self) -> None:
        # init=False means the ID always starts as ""; generate it here.
        if self.connection_id == "":
            self.connection_id = f"dx-moto-{self.connection_name}-{datetime.now().strftime('%Y%m%d%H%M%S')}"

    def to_dict(
        self,
    ) -> Dict[str, Any]:
        """Return the connection as a camelCase dict, including serialized MACsec keys."""
        return {
            "awsDevice": self.aws_device,
            "awsDeviceV2": self.aws_device_v2,
            "awsLogicalDeviceId": self.aws_logical_device_id,
            "bandwidth": self.bandwidth,
            "connectionId": self.connection_id,
            "connectionName": self.connection_name,
            "connectionState": self.connection_state,
            "encryptionMode": self.encryption_mode,
            "hasLogicalRedundancy": self.has_logical_redundancy,
            "jumboFrameCapable": self.jumbo_frame_capable,
            "lagId": self.lag_id,
            "loaIssueTime": self.loa_issue_time,
            "location": self.location,
            "macSecCapable": self.mac_sec_capable,
            "macSecKeys": [key.to_dict() for key in self.mac_sec_keys],
            # Previously stored but never serialized; part of the API's Connection shape.
            "ownerAccount": self.owner_account,
            "partnerName": self.partner_name,
            "portEncryptionStatus": self.port_encryption_status,
            "providerName": self.provider_name,
            "region": self.region,
            "tags": self.tags,
            "vlan": self.vlan,
        }
@dataclass
class LAG(BaseModel):
    """In-memory model of a Direct Connect link aggregation group (LAG).

    ``lag_id`` is not an ``__init__`` argument: ``__post_init__`` derives it
    from ``lag_name`` plus the current local time at one-second resolution.
    The generated ID always starts with ``dxlag-``.
    """

    aws_device_v2: str
    aws_device: str
    aws_logical_device_id: str
    connections_bandwidth: str
    number_of_connections: int
    minimum_links: int
    connections: List[Connection]  # member connections, serialized in to_dict
    lag_name: str
    lag_state: LagStateType
    encryption_mode: EncryptionModeType
    has_logical_redundancy: bool
    jumbo_frame_capable: bool
    location: str
    mac_sec_capable: Optional[bool]
    mac_sec_keys: List[MacSecKey]
    owner_account: str
    provider_name: Optional[str]
    region: str
    tags: Optional[List[Dict[str, str]]]
    lag_id: str = field(default="", init=False)

    def __post_init__(self) -> None:
        # init=False means the ID always starts as ""; generate it here.
        if self.lag_id == "":
            self.lag_id = (
                f"dxlag-moto-{self.lag_name}-{datetime.now().strftime('%Y%m%d%H%M%S')}"
            )

    def to_dict(
        self,
    ) -> Dict[str, Any]:
        """Return the LAG as a camelCase dict, including nested connections and keys."""
        return {
            "awsDevice": self.aws_device,
            "awsDeviceV2": self.aws_device_v2,
            "awsLogicalDeviceId": self.aws_logical_device_id,
            "connectionsBandwidth": self.connections_bandwidth,
            "numberOfConnections": self.number_of_connections,
            "minimumLinks": self.minimum_links,
            "connections": [conn.to_dict() for conn in self.connections],
            "lagId": self.lag_id,
            "lagName": self.lag_name,
            "lagState": self.lag_state,
            "encryptionMode": self.encryption_mode,
            "hasLogicalRedundancy": self.has_logical_redundancy,
            "jumboFrameCapable": self.jumbo_frame_capable,
            "location": self.location,
            "macSecCapable": self.mac_sec_capable,
            "macSecKeys": [key.to_dict() for key in self.mac_sec_keys],
            # Previously stored but never serialized; part of the API's Lag shape.
            "ownerAccount": self.owner_account,
            "providerName": self.provider_name,
            "region": self.region,
            "tags": self.tags,
        }
class DirectConnectBackend(BaseBackend):
    """Implementation of DirectConnect APIs."""

    def __init__(self, region_name: str, account_id: str) -> None:
        super().__init__(region_name, account_id)
        # Resources are keyed by their generated IDs.
        self.connections: Dict[str, Connection] = {}
        self.lags: Dict[str, LAG] = {}

    @staticmethod
    def _unique_id(candidate: str, existing: Dict[str, Any]) -> str:
        """Return ``candidate``, or ``candidate-N`` for the smallest N >= 2 not in ``existing``.

        Generated IDs only have one-second resolution. Without this, two
        resources with the same name created within the same second would share
        an ID, and the second would silently replace the first.
        """
        if candidate not in existing:
            return candidate
        suffix = 2
        while f"{candidate}-{suffix}" in existing:
            suffix += 1
        return f"{candidate}-{suffix}"

    @staticmethod
    def _default_mac_sec_keys() -> List[MacSecKey]:
        """Return a new one-element list holding a placeholder, already-associated MACsec key."""
        return [
            MacSecKey(
                secret_arn="mock_secret_arn",
                ckn="mock_ckn",
                state=MacSecKeyStateType.ASSOCIATED,
                start_on=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            )
        ]

    def describe_connections(self, connection_id: Optional[str]) -> List[Connection]:
        """Return the connection with ``connection_id``, or all connections if no ID is given.

        Raises:
            ConnectionNotFound: ``connection_id`` is given but unknown.
        """
        if not connection_id:
            return list(self.connections.values())
        if connection_id not in self.connections:
            raise ConnectionNotFound(connection_id, self.region_name)
        return [self.connections[connection_id]]

    def create_connection(
        self,
        location: str,
        bandwidth: str,
        connection_name: str,
        lag_id: Optional[str],
        tags: Optional[List[Dict[str, str]]],
        provider_name: Optional[str],
        request_mac_sec: Optional[bool],
    ) -> Connection:
        """Create, store and return a new AVAILABLE connection.

        If ``request_mac_sec`` is truthy, the connection gets encryption mode MUST
        and one placeholder associated MACsec key. Otherwise it gets encryption
        mode NO and no keys. ``lag_id`` is only recorded on the connection; it is
        neither validated nor used to register the connection with a LAG.
        """
        mac_sec_keys = self._default_mac_sec_keys() if request_mac_sec else []
        encryption_mode = (
            EncryptionModeType.MUST if request_mac_sec else EncryptionModeType.NO
        )
        connection = Connection(
            aws_device_v2="mock_device_v2",
            aws_device="mock_device",
            aws_logical_device_id="mock_logical_device_id",
            bandwidth=bandwidth,
            connection_name=connection_name,
            connection_state=ConnectionStateType.AVAILABLE,
            encryption_mode=encryption_mode,
            has_logical_redundancy=False,
            jumbo_frame_capable=False,
            lag_id=lag_id,
            loa_issue_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            location=location,
            mac_sec_capable=request_mac_sec,
            mac_sec_keys=mac_sec_keys,
            owner_account=self.account_id,
            partner_name="mock_partner",
            port_encryption_status=PortEncryptionStatusType.DOWN,
            provider_name=provider_name,
            region=self.region_name,
            tags=tags,
            vlan=0,
        )
        # Don't overwrite an existing connection whose generated ID collides.
        connection.connection_id = self._unique_id(
            connection.connection_id, self.connections
        )
        self.connections[connection.connection_id] = connection
        return connection

    def delete_connection(self, connection_id: str) -> Connection:
        """Mark a connection as DELETED (it stays in the store) and return it.

        Raises:
            ConnectionIdMissing: ``connection_id`` is empty.
            ConnectionNotFound: no connection has that ID.
        """
        if not connection_id:
            raise ConnectionIdMissing()
        connection = self.connections.get(connection_id)
        if connection is None:
            raise ConnectionNotFound(connection_id, self.region_name)
        connection.connection_state = ConnectionStateType.DELETED
        return connection

    def update_connection(
        self,
        connection_id: str,
        new_connection_name: Optional[str],
        new_encryption_mode: Optional[EncryptionModeType],
    ) -> Connection:
        """Apply whichever of the new name / encryption mode are truthy, and return the connection.

        Raises:
            ConnectionIdMissing: ``connection_id`` is empty.
            ConnectionNotFound: no connection has that ID.
        """
        if not connection_id:
            raise ConnectionIdMissing()
        connection = self.connections.get(connection_id)
        if connection is None:
            raise ConnectionNotFound(connection_id, self.region_name)
        if new_connection_name:
            connection.connection_name = new_connection_name
        if new_encryption_mode:
            connection.encryption_mode = new_encryption_mode
        return connection

    def associate_mac_sec_key(
        self,
        connection_id: str,
        secret_arn: Optional[str],
        ckn: Optional[str],
        cak: Optional[str],
    ) -> Tuple[str, List[MacSecKey]]:
        """Attach a new associated MACsec key to a connection or LAG.

        IDs starting with ``dxlag-`` are treated as LAG IDs. A missing
        ``secret_arn`` is replaced by ``"mock_secret_arn"``.

        Returns:
            ``(connection_id, keys)``, where ``keys`` is the target's full key list.
        """
        mac_sec_key = MacSecKey(
            secret_arn=secret_arn or "mock_secret_arn",
            ckn=ckn,
            cak=cak,
            state=MacSecKeyStateType.ASSOCIATED,
            start_on=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        )
        if connection_id.startswith("dxlag-"):
            return self._associate_mac_sec_key_with_lag(
                lag_id=connection_id, mac_sec_key=mac_sec_key
            )
        return self._associate_mac_sec_key_with_connection(
            connection_id=connection_id, mac_sec_key=mac_sec_key
        )

    def _associate_mac_sec_key_with_lag(
        self, lag_id: str, mac_sec_key: MacSecKey
    ) -> Tuple[str, List[MacSecKey]]:
        """Append the key to the LAG and point every member connection at the LAG's key list.

        Raises:
            LAGNotFound: no LAG has that ID.
        """
        lag = self.lags.get(lag_id)
        if lag is None:
            raise LAGNotFound(lag_id, self.region_name)
        lag.mac_sec_keys.append(mac_sec_key)
        for connection in lag.connections:
            # Members share the LAG's list object rather than holding a copy.
            connection.mac_sec_keys = lag.mac_sec_keys
        return lag_id, lag.mac_sec_keys

    def _associate_mac_sec_key_with_connection(
        self, connection_id: str, mac_sec_key: MacSecKey
    ) -> Tuple[str, List[MacSecKey]]:
        """Append the key to a single connection's key list.

        Raises:
            ConnectionNotFound: no connection has that ID.
        """
        connection = self.connections.get(connection_id)
        if connection is None:
            raise ConnectionNotFound(connection_id, self.region_name)
        connection.mac_sec_keys.append(mac_sec_key)
        return connection_id, connection.mac_sec_keys

    def create_lag(
        self,
        number_of_connections: int,
        location: str,
        connections_bandwidth: str,
        lag_name: str,
        connection_id: Optional[str],
        tags: Optional[List[Dict[str, str]]],
        child_connection_tags: Optional[List[Dict[str, str]]],
        provider_name: Optional[str],
        request_mac_sec: Optional[bool],
    ) -> LAG:
        """Create and store an AVAILABLE LAG with ``number_of_connections`` new member connections.

        Raises:
            NotImplementedError: ``connection_id`` is given (migrating an existing
                connection into a new LAG is not modelled).
        """
        if connection_id:
            raise NotImplementedError(
                "creating a lag with a connection_id is not currently supported by moto"
            )
        mac_sec_keys = self._default_mac_sec_keys() if request_mac_sec else []
        encryption_mode = (
            EncryptionModeType.MUST if request_mac_sec else EncryptionModeType.NO
        )
        lag = LAG(
            aws_device_v2="mock_device_v2",
            aws_device="mock_device",
            aws_logical_device_id="mock_logical_device_id",
            connections_bandwidth=connections_bandwidth,
            lag_name=lag_name,
            lag_state=LagStateType.AVAILABLE,
            minimum_links=0,
            encryption_mode=encryption_mode,
            has_logical_redundancy=False,
            jumbo_frame_capable=False,
            number_of_connections=number_of_connections,
            connections=[],
            location=location,
            mac_sec_capable=request_mac_sec,
            mac_sec_keys=mac_sec_keys,
            owner_account=self.account_id,
            provider_name=provider_name,
            region=self.region_name,
            tags=tags,
        )
        # Finalise the ID before creating children, because their names embed it.
        lag.lag_id = self._unique_id(lag.lag_id, self.lags)
        for i in range(number_of_connections):
            connection = self.create_connection(
                location=location,
                bandwidth=connections_bandwidth,
                connection_name=f"Requested Connection {i+1} for Lag {lag.lag_id}",
                lag_id=lag.lag_id,
                tags=child_connection_tags,
                request_mac_sec=False,
                provider_name=provider_name,
            )
            if request_mac_sec:
                # Members share the LAG's key list object and encryption mode.
                connection.mac_sec_capable = True
                connection.mac_sec_keys = mac_sec_keys
                connection.encryption_mode = encryption_mode
            lag.connections.append(connection)
        self.lags[lag.lag_id] = lag
        return lag

    def describe_lags(self, lag_id: Optional[str]) -> List[LAG]:
        """Return the LAG with ``lag_id``, or all LAGs if no ID is given.

        Raises:
            LAGNotFound: ``lag_id`` is given but unknown.
        """
        if not lag_id:
            return list(self.lags.values())
        if lag_id not in self.lags:
            raise LAGNotFound(lag_id, self.region_name)
        return [self.lags[lag_id]]

    def disassociate_mac_sec_key(
        self, connection_id: str, secret_arn: str
    ) -> Tuple[str, MacSecKey]:
        """Remove the first key whose secret ARN matches (case-insensitively) and return it.

        The removed key's state is set to DISASSOCIATED. IDs starting with
        ``dxlag-`` are treated as LAG IDs.

        Raises:
            ConnectionNotFound: no connection or LAG has that ID.
            MacSecKeyNotFound: the resource exists but holds no matching key,
                including when it holds no keys at all.
        """
        mac_sec_keys: Optional[List[MacSecKey]] = None
        if connection_id.startswith("dxlag-"):
            lag = self.lags.get(connection_id)
            if lag is not None:
                mac_sec_keys = lag.mac_sec_keys
        else:
            connection = self.connections.get(connection_id)
            if connection is not None:
                mac_sec_keys = connection.mac_sec_keys
        # Only a missing resource is "not found"; an empty key list falls through
        # to MacSecKeyNotFound below.
        if mac_sec_keys is None:
            raise ConnectionNotFound(connection_id, self.region_name)
        arn_casefold = secret_arn.casefold()
        for i, mac_sec_key in enumerate(mac_sec_keys):
            if str(mac_sec_key.secret_arn).casefold() == arn_casefold:
                mac_sec_key.state = MacSecKeyStateType.DISASSOCIATED
                return connection_id, mac_sec_keys.pop(i)
        raise MacSecKeyNotFound(secret_arn=secret_arn, connection_id=connection_id)
# Module-level BackendDict for the "directconnect" service, built from DirectConnectBackend.
# NOTE(review): presumably one backend per (account_id, region_name), matching
# DirectConnectBackend.__init__; BackendDict's keying is not visible here — confirm.
directconnect_backends = BackendDict(DirectConnectBackend, "directconnect")