from typing import Any, Dict, List, Optional
from ..exceptions import (
DependencyViolationError,
InvalidNetworkAclIdError,
InvalidRouteTableIdError,
NetworkAclEntryAlreadyExistsError,
)
from ..utils import (
generic_filter,
random_network_acl_id,
random_network_acl_subnet_association_id,
)
from .core import TaggedEC2Resource
class NetworkAclBackend:
def __init__(self) -> None:
self.network_acls: Dict[str, "NetworkAcl"] = {}
def get_network_acl(self, network_acl_id: str) -> "NetworkAcl":
network_acl = self.network_acls.get(network_acl_id, None)
if not network_acl:
raise InvalidNetworkAclIdError(network_acl_id)
return network_acl
def create_network_acl(
self,
vpc_id: str,
tags: Optional[List[Dict[str, str]]] = None,
default: bool = False,
) -> "NetworkAcl":
network_acl_id = random_network_acl_id()
self.get_vpc(vpc_id) # type: ignore[attr-defined]
network_acl = NetworkAcl(self, network_acl_id, vpc_id, default)
for tag in tags or []:
network_acl.add_tag(tag["Key"], tag["Value"])
self.network_acls[network_acl_id] = network_acl
if default:
self.add_default_entries(network_acl_id)
return network_acl
def add_default_entries(self, network_acl_id: str) -> None:
default_acl_entries = [
{"rule_number": "100", "rule_action": "allow", "egress": "true"},
{"rule_number": "32767", "rule_action": "deny", "egress": "true"},
{"rule_number": "100", "rule_action": "allow", "egress": "false"},
{"rule_number": "32767", "rule_action": "deny", "egress": "false"},
]
for entry in default_acl_entries:
self.create_network_acl_entry(
network_acl_id=network_acl_id,
rule_number=entry["rule_number"],
protocol="-1",
rule_action=entry["rule_action"],
egress=entry["egress"],
cidr_block="0.0.0.0/0",
icmp_code=None,
icmp_type=None,
port_range_from=None,
port_range_to=None,
ipv6_cidr_block=None,
)
def delete_network_acl(self, network_acl_id: str) -> "NetworkAcl":
if any(
network_acl.id == network_acl_id and len(network_acl.associations) > 0
for network_acl in self.network_acls.values()
):
raise DependencyViolationError(
f"The network ACL '{network_acl_id}' has dependencies and cannot be deleted."
)
deleted = self.network_acls.pop(network_acl_id, None)
if not deleted:
raise InvalidNetworkAclIdError(network_acl_id)
return deleted
def create_network_acl_entry(
self,
network_acl_id: str,
rule_number: str,
protocol: str,
rule_action: str,
egress: str,
cidr_block: str,
icmp_code: Optional[int],
icmp_type: Optional[int],
port_range_from: Optional[int],
port_range_to: Optional[int],
ipv6_cidr_block: Optional[str],
) -> "NetworkAclEntry":
network_acl = self.get_network_acl(network_acl_id)
if any(
entry.egress == egress and entry.rule_number == rule_number
for entry in network_acl.network_acl_entries
):
raise NetworkAclEntryAlreadyExistsError(rule_number)
network_acl_entry = NetworkAclEntry(
self,
network_acl_id=network_acl_id,
rule_number=rule_number,
protocol=protocol,
rule_action=rule_action,
egress=egress,
cidr_block=cidr_block,
icmp_code=icmp_code,
icmp_type=icmp_type,
port_range_from=port_range_from,
port_range_to=port_range_to,
ipv6_cidr_block=ipv6_cidr_block,
)
network_acl.network_acl_entries.append(network_acl_entry)
return network_acl_entry
def delete_network_acl_entry(
self, network_acl_id: str, rule_number: str, egress: str
) -> "NetworkAclEntry":
network_acl = self.get_network_acl(network_acl_id)
entry = next(
entry
for entry in network_acl.network_acl_entries
if entry.egress == egress and entry.rule_number == rule_number
)
if entry is not None:
network_acl.network_acl_entries.remove(entry)
return entry
def replace_network_acl_entry(
self,
network_acl_id: str,
rule_number: str,
protocol: str,
rule_action: str,
egress: str,
cidr_block: str,
icmp_code: int,
icmp_type: int,
port_range_from: int,
port_range_to: int,
ipv6_cidr_block: Optional[str],
) -> "NetworkAclEntry":
self.delete_network_acl_entry(network_acl_id, rule_number, egress)
network_acl_entry = self.create_network_acl_entry(
network_acl_id=network_acl_id,
rule_number=rule_number,
protocol=protocol,
rule_action=rule_action,
egress=egress,
cidr_block=cidr_block,
icmp_code=icmp_code,
icmp_type=icmp_type,
port_range_from=port_range_from,
port_range_to=port_range_to,
ipv6_cidr_block=ipv6_cidr_block,
)
return network_acl_entry
def replace_network_acl_association(
self, association_id: str, network_acl_id: str
) -> "NetworkAclAssociation":
# lookup existing association for subnet and delete it
default_acl = next(
value
for key, value in self.network_acls.items()
if association_id in value.associations.keys()
)
subnet_id = None
for key in default_acl.associations:
if key == association_id:
subnet_id = default_acl.associations[key].subnet_id
del default_acl.associations[key]
break
new_assoc_id = random_network_acl_subnet_association_id()
association = NetworkAclAssociation(
self, new_assoc_id, subnet_id, network_acl_id
)
new_acl = self.get_network_acl(network_acl_id)
new_acl.associations[new_assoc_id] = association
return association
def associate_default_network_acl_with_subnet(
self, subnet_id: str, vpc_id: str
) -> None:
association_id = random_network_acl_subnet_association_id()
acl = next(
acl
for acl in self.network_acls.values()
if acl.default and acl.vpc_id == vpc_id
)
acl.associations[association_id] = NetworkAclAssociation(
self, association_id, subnet_id, acl.id
)
def describe_network_acls(
self, network_acl_ids: Optional[List[str]] = None, filters: Any = None
) -> List["NetworkAcl"]:
network_acls = list(self.network_acls.values())
if network_acl_ids:
network_acls = [
network_acl
for network_acl in network_acls
if network_acl.id in network_acl_ids
]
if len(network_acls) != len(network_acl_ids):
invalid_id = list(
set(network_acl_ids).difference(
set([network_acl.id for network_acl in network_acls])
)
)[0]
raise InvalidRouteTableIdError(invalid_id)
return generic_filter(filters, network_acls)
class NetworkAclAssociation:
def __init__(
self,
ec2_backend: Any,
new_association_id: str,
subnet_id: Optional[str],
network_acl_id: str,
):
self.ec2_backend = ec2_backend
self.id = new_association_id
self.new_association_id = new_association_id
self.subnet_id = subnet_id
self.network_acl_id = network_acl_id
class NetworkAcl(TaggedEC2Resource):
def __init__(
self,
ec2_backend: Any,
network_acl_id: str,
vpc_id: str,
default: bool = False,
owner_id: Optional[str] = None,
):
self.ec2_backend = ec2_backend
self.id = network_acl_id
self.vpc_id = vpc_id
self.owner_id = owner_id or ec2_backend.account_id
self.network_acl_entries: List[NetworkAclEntry] = []
self.associations: Dict[str, NetworkAclAssociation] = {}
self.default = "true" if default is True else "false"
def get_filter_value(
self, filter_name: str, method_name: Optional[str] = None
) -> Any:
if filter_name == "default":
return self.default
elif filter_name == "vpc-id":
return self.vpc_id
elif filter_name == "association.network-acl-id":
return self.id
elif filter_name == "association.subnet-id":
return [assoc.subnet_id for assoc in self.associations.values()]
elif filter_name == "entry.cidr":
return [entry.cidr_block for entry in self.network_acl_entries]
elif filter_name == "entry.protocol":
return [entry.protocol for entry in self.network_acl_entries]
elif filter_name == "entry.rule-number":
return [entry.rule_number for entry in self.network_acl_entries]
elif filter_name == "entry.rule-action":
return [entry.rule_action for entry in self.network_acl_entries]
elif filter_name == "entry.egress":
return [entry.egress for entry in self.network_acl_entries]
elif filter_name == "owner-id":
return self.owner_id
else:
return super().get_filter_value(filter_name, "DescribeNetworkAcls")
class NetworkAclEntry(TaggedEC2Resource):
def __init__(
self,
ec2_backend: Any,
network_acl_id: str,
rule_number: str,
protocol: str,
rule_action: str,
egress: str,
cidr_block: str,
icmp_code: Optional[int],
icmp_type: Optional[int],
port_range_from: Optional[int],
port_range_to: Optional[int],
ipv6_cidr_block: Optional[str],
):
self.ec2_backend = ec2_backend
self.network_acl_id = network_acl_id
self.rule_number = rule_number
self.protocol = protocol
self.rule_action = rule_action
self.egress = egress
self.cidr_block = cidr_block
self.ipv6_cidr_block = ipv6_cidr_block
self.icmp_code = icmp_code
self.icmp_type = icmp_type
self.port_range_from = port_range_from
self.port_range_to = port_range_to