import copy
import itertools
import json
from collections import defaultdict
from typing import Any, Dict, Iterator, List, Optional, Tuple
from moto.core.common_models import BaseModel, CloudFormationModel
from moto.core.utils import aws_api_matches
from ..exceptions import (
InvalidCIDRSubnetError,
InvalidGroupIdMalformedError,
InvalidPermissionDuplicateError,
InvalidPermissionNotFoundError,
InvalidSecurityGroupDuplicateError,
InvalidSecurityGroupNotFoundError,
MissingParameterError,
MotoNotImplementedError,
RulesPerSecurityGroupLimitExceededError,
)
from ..utils import (
is_tag_filter,
is_valid_cidr,
is_valid_ipv6_cidr,
is_valid_security_group_id,
random_security_group_id,
random_security_group_rule_id,
tag_filter_matches,
)
from .core import TaggedEC2Resource
class SecurityRule(TaggedEC2Resource):
    """A single ingress or egress permission belonging to a security group.

    A rule targets at most one kind of peer: an IP range, a source security
    group or a prefix list.  The unused peers are stored as empty dicts.
    """

    # Lower-cased protocol names / numbers mapped to the canonical protocol
    # name.  Lookups lower-case the input first, so only lower-case keys are
    # needed here.
    _PROTOCOL_ALIASES = {
        "tcp": "tcp",
        "6": "tcp",
        "udp": "udp",
        "17": "udp",
        "all": "-1",
        "-1": "-1",
        "1": "icmp",
        "icmp": "icmp",
    }

    def __init__(
        self,
        ec2_backend: Any,
        ip_protocol: str,
        group_id: str,
        from_port: Optional[str],
        to_port: Optional[str],
        ip_range: Optional[Dict[str, str]],
        source_group: Optional[Dict[str, str]] = None,
        prefix_list_id: Optional[Dict[str, str]] = None,
        is_egress: bool = True,
        tags: Optional[Dict[str, str]] = None,
        description: str = "",
    ):
        """Create a rule with a freshly generated rule id.

        :param ec2_backend: backend that owns the rule (provides ``account_id``).
        :param ip_protocol: protocol name or number; known aliases are
            normalised to ``tcp``/``udp``/``icmp``/``-1``, anything else is
            kept as given.
        :param group_id: id of the security group the rule belongs to.
        :param from_port: start of the port range; converted with ``int()``
            unless the protocol (as given, before normalisation) is ``-1``.
        :param to_port: end of the port range; same conversion as ``from_port``.
        :param ip_range: dict describing the CIDR peer, or ``None``.
        :param source_group: dict describing the security-group peer, or ``None``.
        :param prefix_list_id: dict describing the prefix-list peer, or ``None``.
        :param is_egress: ``True`` for an outbound rule.
        :param tags: tags to attach to the rule; ``None`` means no tags.
        :param description: free-text description of the rule.
        """
        self.ec2_backend = ec2_backend
        self.id = random_security_group_rule_id()
        self.ip_protocol = str(ip_protocol) if ip_protocol else None
        self.ip_range = ip_range or {}
        self.source_group = source_group or {}
        self.prefix_list_id = prefix_list_id or {}
        self.from_port = self.to_port = None
        self.is_egress = is_egress
        self.description = description
        self.group_id = group_id

        # Ports are only stored for protocols other than the literal "-1".
        if self.ip_protocol and self.ip_protocol != "-1":
            self.from_port = int(from_port)  # type: ignore[arg-type]
            self.to_port = int(to_port)  # type: ignore[arg-type]

        # Normalise known aliases; unknown protocols are left untouched.
        if self.ip_protocol:
            self.ip_protocol = self._PROTOCOL_ALIASES.get(
                self.ip_protocol.lower(), self.ip_protocol
            )

        # A fresh dict per call avoids sharing one mutable default.
        self.add_tags(tags or {})

    @property
    def owner_id(self) -> str:
        """Account id of the backend that owns this rule."""
        return self.ec2_backend.account_id

    def __eq__(self, other: object) -> bool:
        """Compare two rules by protocol, peer and (unless protocol is -1) ports.

        The CIDR comparison is driven by ``self``: a CIDR key is only compared
        when this rule carries it.  Rule ids, tags and descriptions are not
        part of the comparison.
        """
        if not isinstance(other, SecurityRule):
            return NotImplemented
        if self.ip_protocol != other.ip_protocol:
            return False
        for cidr_key in ("CidrIp", "CidrIpv6"):
            if cidr_key in self.ip_range and self.ip_range.get(
                cidr_key
            ) != other.ip_range.get(cidr_key):
                return False
        if self.source_group != other.source_group:
            return False
        if self.prefix_list_id != other.prefix_list_id:
            return False
        if self.ip_protocol != "-1":
            if self.from_port != other.from_port:
                return False
            if self.to_port != other.to_port:
                return False
        return True

    def __deepcopy__(self, memodict: Dict[Any, Any]) -> BaseModel:
        """Deep-copy every attribute except ``ec2_backend``, which is shared."""
        # Only substitute a new memo when none was supplied; an empty memo
        # passed in by ``copy.deepcopy`` must stay shared with the caller.
        if memodict is None:
            memodict = {}
        cls = self.__class__
        new = cls.__new__(cls)
        memodict[id(self)] = new
        for attr_name, attr_value in self.__dict__.items():
            if attr_name == "ec2_backend":
                setattr(new, attr_name, self.ec2_backend)
            else:
                setattr(new, attr_name, copy.deepcopy(attr_value, memodict))
        return new
class GroupedSecurityRuleView:
    """Rules sharing one protocol and port range, grouped for presentation.

    The peers of every grouped rule are collected in ``ip_ranges``,
    ``source_groups`` and ``prefix_list_ids``; all three start out empty.
    """

    def __init__(
        self,
        from_port: Optional[int],
        to_port: Optional[int],
        ip_protocol: Optional[str],
    ):
        # Each view owns its own, initially empty, peer collections.
        self.ip_ranges: List[Dict[str, str]] = list()
        self.source_groups: List[Dict[str, str]] = list()
        self.prefix_list_ids: List[Dict[str, str]] = list()
        # The grouping key.
        self.ip_protocol = ip_protocol
        self.from_port, self.to_port = from_port, to_port
class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
    """An EC2 security group together with its ingress and egress rules."""

    def __init__(
        self,
        ec2_backend: Any,
        group_id: str,
        name: str,
        description: str,
        vpc_id: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        is_default: Optional[bool] = None,
    ):
        """Create the group and, for a known VPC, its default egress rules.

        :param ec2_backend: owning backend; ``account_id``, ``region_name``
            and ``vpcs`` are read from it.
        :param group_id: id of the group.
        :param name: group name.
        :param description: group description.
        :param vpc_id: VPC the group lives in, if any.
        :param tags: initial tags.
        :param is_default: whether this is a default group (``None`` -> False).
        """
        self.ec2_backend = ec2_backend
        self.id = group_id
        self.group_id = self.id
        self.name = name
        self.group_name = self.name
        self.description = description
        self.ingress_rules: List[SecurityRule] = []
        self.egress_rules: List[SecurityRule] = []
        self.vpc_id: Optional[str] = vpc_id
        self.owner_id = ec2_backend.account_id
        self.add_tags(tags or {})
        self.is_default = is_default or False
        self.arn = f"arn:aws:ec2:{ec2_backend.region_name}:{ec2_backend.account_id}:security-group/{group_id}"

        # Groups in a known VPC start with an allow-all IPv4 egress rule, plus
        # an allow-all IPv6 egress rule when the VPC has an IPv6 CIDR block.
        if vpc_id:
            vpc = self.ec2_backend.vpcs.get(vpc_id)
            if vpc:
                self.egress_rules.append(
                    SecurityRule(
                        self.ec2_backend,
                        "-1",
                        self.id,
                        None,
                        None,
                        {"CidrIp": "0.0.0.0/0"},
                    )
                )
                if len(vpc.get_cidr_block_association_set(ipv6=True)) > 0:
                    self.egress_rules.append(
                        SecurityRule(
                            self.ec2_backend,
                            "-1",
                            self.id,
                            None,
                            None,
                            {"CidrIpv6": "::/0"},
                        )
                    )

        # each filter as a simple function in a mapping
        self.filters = {
            "description": self.filter_description,
            "egress.ip-permission.cidr": self.filter_egress__ip_permission__cidr,
            "egress.ip-permission.from-port": self.filter_egress__ip_permission__from_port,
            "egress.ip-permission.group-id": self.filter_egress__ip_permission__group_id,
            "egress.ip-permission.group-name": self.filter_egress__ip_permission__group_name,
            "egress.ip-permission.ipv6-cidr": self.filter_egress__ip_permission__ipv6_cidr,
            "egress.ip-permission.prefix-list-id": self.filter_egress__ip_permission__prefix_list_id,
            "egress.ip-permission.protocol": self.filter_egress__ip_permission__protocol,
            "egress.ip-permission.to-port": self.filter_egress__ip_permission__to_port,
            "egress.ip-permission.user-id": self.filter_egress__ip_permission__user_id,
            "group-id": self.filter_group_id,
            "group-name": self.filter_group_name,
            "ip-permission.cidr": self.filter_ip_permission__cidr,
            "ip-permission.from-port": self.filter_ip_permission__from_port,
            "ip-permission.group-id": self.filter_ip_permission__group_id,
            "ip-permission.group-name": self.filter_ip_permission__group_name,
            "ip-permission.ipv6-cidr": self.filter_ip_permission__ipv6_cidr,
            "ip-permission.prefix-list-id": self.filter_ip_permission__prefix_list_id,
            "ip-permission.protocol": self.filter_ip_permission__protocol,
            "ip-permission.to-port": self.filter_ip_permission__to_port,
            "ip-permission.user-id": self.filter_ip_permission__user_id,
            "owner-id": self.filter_owner_id,
            "vpc-id": self.filter_vpc_id,
        }

    # ------------------------------------------------------------------
    # CloudFormation integration
    # ------------------------------------------------------------------
    @staticmethod
    def cloudformation_name_type() -> str:
        """Name of the CloudFormation property holding the resource name."""
        return "GroupName"

    @staticmethod
    def cloudformation_type() -> str:
        """CloudFormation resource type handled by this model."""
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-securitygroup.html
        return "AWS::EC2::SecurityGroup"

    @classmethod
    def create_from_cloudformation_json(  # type: ignore[misc]
        cls,
        resource_name: str,
        cloudformation_json: Any,
        account_id: str,
        region_name: str,
        **kwargs: Any,
    ) -> "SecurityGroup":
        """Create a group, its tags and its inline ingress rules from a template.

        ``FromPort``/``ToPort`` are optional in an inline ingress rule (they
        are not needed for protocol ``-1``); ``IpProtocol`` is required.
        """
        from ..models import ec2_backends

        properties = cloudformation_json["Properties"]
        ec2_backend = ec2_backends[account_id][region_name]
        vpc_id = properties.get("VpcId")
        security_group = ec2_backend.create_security_group(
            name=resource_name,
            description=properties.get("GroupDescription"),
            vpc_id=vpc_id,
        )

        for tag in properties.get("Tags", []):
            security_group.add_tag(tag["Key"], tag["Value"])

        for ingress_rule in properties.get("SecurityGroupIngress", []):
            source_group = {}
            source_group_id = ingress_rule.get("SourceSecurityGroupId")
            source_group_name = ingress_rule.get("SourceSecurityGroupName")
            if source_group_id:
                source_group["GroupId"] = source_group_id
            if source_group_name:
                source_group["GroupName"] = source_group_name

            ec2_backend.authorize_security_group_ingress(
                group_name_or_id=security_group.id,
                ip_protocol=ingress_rule["IpProtocol"],
                # .get(): a missing port must not abort with KeyError.
                from_port=ingress_rule.get("FromPort"),
                to_port=ingress_rule.get("ToPort"),
                ip_ranges=ingress_rule.get("CidrIp", []),
                source_groups=[source_group] if source_group else [],
                vpc_id=vpc_id,
            )

        return security_group

    @classmethod
    def update_from_cloudformation_json(  # type: ignore[misc]
        cls,
        original_resource: Any,
        new_resource_name: str,
        cloudformation_json: Any,
        account_id: str,
        region_name: str,
    ) -> "SecurityGroup":
        """Replace the group: delete the original, then create it anew."""
        cls._delete_security_group_given_vpc_id(
            original_resource.name, original_resource.vpc_id, account_id, region_name
        )
        return cls.create_from_cloudformation_json(
            new_resource_name, cloudformation_json, account_id, region_name
        )

    @classmethod
    def delete_from_cloudformation_json(  # type: ignore[misc]
        cls,
        resource_name: str,
        cloudformation_json: Any,
        account_id: str,
        region_name: str,
    ) -> None:
        """Delete the group named ``resource_name`` in the template's VPC."""
        vpc_id = cloudformation_json["Properties"].get("VpcId")
        cls._delete_security_group_given_vpc_id(
            resource_name, vpc_id, account_id, region_name
        )

    @classmethod
    def _delete_security_group_given_vpc_id(
        cls, resource_name: str, vpc_id: str, account_id: str, region_name: str
    ) -> None:
        """Delete the group if it can be found; do nothing otherwise."""
        from ..models import ec2_backends

        ec2_backend = ec2_backends[account_id][region_name]
        security_group = ec2_backend.get_security_group_by_name_or_id(
            resource_name, vpc_id
        )
        if security_group:
            security_group.delete(account_id, region_name)

    def delete(
        self,
        account_id: str,
        region_name: str,  # pylint: disable=unused-argument
    ) -> None:
        """Not exposed as part of the EC2 API - used for CloudFormation."""
        self.ec2_backend.delete_security_group(group_id=self.id)

    @property
    def physical_resource_id(self) -> str:
        """The group id."""
        return self.id

    # ------------------------------------------------------------------
    # Filters
    # ------------------------------------------------------------------
    @staticmethod
    def _any_value_matches(values: List[Any], candidate: Any) -> bool:
        """True if any filter value matches ``candidate``."""
        return any(aws_api_matches(value, candidate) for value in values)

    @staticmethod
    def _any_rule_matches(
        values: List[Any], rules: List[SecurityRule], getter: Any
    ) -> bool:
        """True if any filter value matches ``getter(rule)`` for any rule."""
        return any(
            aws_api_matches(value, getter(rule)) for value in values for rule in rules
        )

    def filter_description(self, values: List[Any]) -> bool:
        """Match on the group description."""
        return self._any_value_matches(values, self.description)

    def filter_egress__ip_permission__cidr(self, values: List[Any]) -> bool:
        """Match on the IPv4 CIDR of any egress rule."""
        return self._any_rule_matches(
            values, self.egress_rules, lambda rule: rule.ip_range.get("CidrIp", "NONE")
        )

    def filter_egress__ip_permission__from_port(self, values: List[Any]) -> bool:
        """Match on the start port of any egress rule whose protocol is not -1."""
        return any(
            rule.ip_protocol != "-1" and aws_api_matches(value, str(rule.from_port))
            for value in values
            for rule in self.egress_rules
        )

    def filter_egress__ip_permission__group_id(self, values: List[Any]) -> bool:
        """Match on the source group id of any egress rule."""
        return self._any_rule_matches(
            values, self.egress_rules, lambda rule: rule.source_group.get("GroupId")
        )

    def filter_egress__ip_permission__group_name(self, values: List[Any]) -> bool:
        """Match on the source group name of any egress rule."""
        return self._any_rule_matches(
            values, self.egress_rules, lambda rule: rule.source_group.get("GroupName")
        )

    def filter_egress__ip_permission__ipv6_cidr(self, values: List[Any]) -> bool:
        """Not implemented; always raises ``MotoNotImplementedError``."""
        raise MotoNotImplementedError("egress.ip-permission.ipv6-cidr filter")

    def filter_egress__ip_permission__prefix_list_id(self, values: List[Any]) -> bool:
        """Not implemented; always raises ``MotoNotImplementedError``."""
        raise MotoNotImplementedError("egress.ip-permission.prefix-list-id filter")

    def filter_egress__ip_permission__protocol(self, values: List[Any]) -> bool:
        """Match on the protocol of any egress rule."""
        return self._any_rule_matches(
            values, self.egress_rules, lambda rule: rule.ip_protocol
        )

    def filter_egress__ip_permission__to_port(self, values: List[Any]) -> bool:
        """Match on the end port of any egress rule."""
        return self._any_rule_matches(
            values, self.egress_rules, lambda rule: rule.to_port
        )

    def filter_egress__ip_permission__user_id(self, values: List[Any]) -> bool:
        """Match on the owner id of any egress rule."""
        return self._any_rule_matches(
            values, self.egress_rules, lambda rule: rule.owner_id
        )

    def filter_group_id(self, values: List[Any]) -> bool:
        """Match on the group id."""
        return self._any_value_matches(values, self.id)

    def filter_group_name(self, values: List[Any]) -> bool:
        """Match on the group name."""
        return self._any_value_matches(values, self.group_name)

    def filter_ip_permission__cidr(self, values: List[Any]) -> bool:
        """Match on the IPv4 CIDR of any ingress rule."""
        return self._any_rule_matches(
            values, self.ingress_rules, lambda rule: rule.ip_range.get("CidrIp", "NONE")
        )

    def filter_ip_permission__from_port(self, values: List[Any]) -> bool:
        """Match on the start port of any ingress rule."""
        return self._any_rule_matches(
            values, self.ingress_rules, lambda rule: rule.from_port
        )

    def filter_ip_permission__group_id(self, values: List[Any]) -> bool:
        """Match on the source group id of any ingress rule."""
        return self._any_rule_matches(
            values, self.ingress_rules, lambda rule: rule.source_group.get("GroupId")
        )

    def filter_ip_permission__group_name(self, values: List[Any]) -> bool:
        """Match on the source group name of any ingress rule."""
        return self._any_rule_matches(
            values, self.ingress_rules, lambda rule: rule.source_group.get("GroupName")
        )

    def filter_ip_permission__ipv6_cidr(self, values: List[Any]) -> bool:
        """Not implemented; always raises ``MotoNotImplementedError``."""
        raise MotoNotImplementedError("ip-permission.ipv6-cidr filter")

    def filter_ip_permission__prefix_list_id(self, values: List[Any]) -> bool:
        """Not implemented; always raises ``MotoNotImplementedError``."""
        raise MotoNotImplementedError("ip-permission.prefix-list-id filter")

    def filter_ip_permission__protocol(self, values: List[Any]) -> bool:
        """Match on the protocol of any ingress rule."""
        return self._any_rule_matches(
            values, self.ingress_rules, lambda rule: rule.ip_protocol
        )

    def filter_ip_permission__to_port(self, values: List[Any]) -> bool:
        """Match on the end port of any ingress rule."""
        return self._any_rule_matches(
            values, self.ingress_rules, lambda rule: rule.to_port
        )

    def filter_ip_permission__user_id(self, values: List[Any]) -> bool:
        """Match on the owner id of any ingress rule."""
        return self._any_rule_matches(
            values, self.ingress_rules, lambda rule: rule.owner_id
        )

    def filter_owner_id(self, values: List[Any]) -> bool:
        """Match on the owning account id."""
        return self._any_value_matches(values, self.owner_id)

    def filter_vpc_id(self, values: List[Any]) -> bool:
        """Match on the VPC id."""
        return self._any_value_matches(values, self.vpc_id)

    def matches_filter(self, key: str, filter_value: Any) -> Any:
        """Evaluate one filter; tag filters are handled separately.

        Non-tag keys are dispatched through ``self.filters``; an unknown key
        raises ``KeyError``.
        """
        if is_tag_filter(key):
            tag_value = self.get_filter_value(key)
            if isinstance(filter_value, list):
                return tag_filter_matches(self, key, filter_value)
            return tag_value in filter_value
        return self.filters[key](filter_value)

    def matches_filters(self, filters: Any) -> bool:
        """True when every filter in the ``filters`` mapping matches."""
        return all(self.matches_filter(key, value) for key, value in filters.items())

    @classmethod
    def has_cfn_attr(cls, attr: str) -> bool:
        """Only ``GroupId`` is available through ``Fn::GetAtt``."""
        return attr in ["GroupId"]

    def get_cfn_attribute(self, attribute_name: str) -> str:
        """Return the value for ``Fn::GetAtt``; only ``GroupId`` is supported."""
        from moto.cloudformation.exceptions import UnformattedGetAttTemplateException

        if attribute_name == "GroupId":
            return self.id
        raise UnformattedGetAttTemplateException()

    # ------------------------------------------------------------------
    # Rules
    # ------------------------------------------------------------------
    def add_ingress_rule(self, rule: SecurityRule) -> None:
        """Append an ingress rule; raise if an equal rule already exists."""
        if rule in self.ingress_rules:
            raise InvalidPermissionDuplicateError()
        self.ingress_rules.append(rule)

    def add_egress_rule(self, rule: SecurityRule) -> None:
        """Append an egress rule; raise if an equal rule already exists."""
        if rule in self.egress_rules:
            raise InvalidPermissionDuplicateError()
        self.egress_rules.append(rule)

    def get_number_of_ingress_rules(self) -> int:
        """Number of ingress rules."""
        return len(self.ingress_rules)

    def get_number_of_egress_rules(self) -> int:
        """Number of egress rules."""
        return len(self.egress_rules)

    @property
    def flattened_ingress_rules(self) -> List[GroupedSecurityRuleView]:
        """Ingress rules grouped by protocol and port range."""
        return self._flattened_rules(copy.copy(self.ingress_rules))

    @property
    def flattened_egress_rules(self) -> List[GroupedSecurityRuleView]:
        """Egress rules grouped by protocol and port range."""
        return self._flattened_rules(copy.copy(self.egress_rules))

    def _flattened_rules(
        self, rules: List[SecurityRule]
    ) -> List[GroupedSecurityRuleView]:
        """Group ``rules`` into one view per (from_port, to_port, protocol).

        Views appear in the order their first rule was encountered.
        """
        views: List[GroupedSecurityRuleView] = []
        for rule in rules:
            view = next(
                (
                    candidate
                    for candidate in views
                    if candidate.from_port == rule.from_port
                    and candidate.to_port == rule.to_port
                    and candidate.ip_protocol == rule.ip_protocol
                ),
                None,
            )
            if view is None:
                view = GroupedSecurityRuleView(
                    rule.from_port, rule.to_port, rule.ip_protocol
                )
                views.append(view)
            # Only non-empty peers are recorded on the view.
            if rule.ip_range:
                view.ip_ranges.append(rule.ip_range)
            if rule.source_group:
                view.source_groups.append(rule.source_group)
            if rule.prefix_list_id:
                view.prefix_list_ids.append(rule.prefix_list_id)
        return views
class SecurityGroupBackend:
    """Backend operations on security groups and their rules.

    The methods read ``self.default_vpc``, ``self.vpcs`` and
    ``self.account_id``, which are not defined here and are expected to be
    supplied by the class this one is combined with.
    """

    def __init__(self) -> None:
        # the key in the dict group is the vpc_id or None (non-vpc)
        self.groups: Dict[str, Dict[str, SecurityGroup]] = defaultdict(dict)

    # ------------------------------------------------------------------
    # Groups
    # ------------------------------------------------------------------
    def create_security_group(
        self,
        name: str,
        description: str,
        vpc_id: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        force: bool = False,
        is_default: Optional[bool] = None,
    ) -> SecurityGroup:
        """Create and register a group, defaulting to the default VPC.

        :raises MissingParameterError: when ``description`` is empty.
        :raises InvalidSecurityGroupDuplicateError: when ``force`` is false
            and a group with this name (or id) already exists.
        """
        vpc_id = vpc_id or self.default_vpc.id  # type: ignore[attr-defined]
        if not description:
            raise MissingParameterError("GroupDescription")

        group_id = random_security_group_id()
        if not force and self.get_security_group_by_name_or_id(name, vpc_id):
            raise InvalidSecurityGroupDuplicateError(name)

        group = SecurityGroup(
            self,
            group_id,
            name,
            description,
            vpc_id=vpc_id,
            tags=tags,
            is_default=is_default,
        )
        self.groups[vpc_id][group_id] = group
        return group

    def describe_security_groups(
        self,
        group_ids: Optional[List[str]] = None,
        groupnames: Optional[List[str]] = None,
        filters: Any = None,
    ) -> List[SecurityGroup]:
        """Return the groups selected by ids, names and filters (all ANDed).

        :raises InvalidSecurityGroupNotFoundError: with the set of requested
            ids (or names) that matched no group.
        """
        # Snapshot both dict levels before iterating.
        matches = list(
            itertools.chain.from_iterable(
                per_vpc.copy().values() for per_vpc in self.groups.copy().values()
            )
        )
        if group_ids:
            matches = [grp for grp in matches if grp.id in group_ids]
            # Compare ids with ids so only the truly unknown ones are reported.
            unknown_ids = set(group_ids) - {grp.id for grp in matches}
            if unknown_ids:
                raise InvalidSecurityGroupNotFoundError(unknown_ids)
        if groupnames:
            matches = [grp for grp in matches if grp.name in groupnames]
            unknown_names = set(groupnames) - {grp.name for grp in matches}
            if unknown_names:
                raise InvalidSecurityGroupNotFoundError(unknown_names)
        if filters:
            matches = [grp for grp in matches if grp.matches_filters(filters)]
        return matches

    @staticmethod
    def _all_rules(groups: List[SecurityGroup]) -> List[SecurityRule]:
        """Ingress then egress rules of each group, in group order."""
        rules: List[SecurityRule] = []
        for group in groups:
            rules.extend(group.ingress_rules)
            rules.extend(group.egress_rules)
        return rules

    def describe_security_group_rules(
        self,
        group_ids: Optional[List[str]] = None,
        sg_rule_ids: Optional[List[str]] = None,
        filters: Any = None,
    ) -> List[SecurityRule]:
        """Return rules selected by rule ids, group ids, or filters.

        Precedence: ``sg_rule_ids``, then ``group_ids``, then a ``group-id``
        filter, and finally tag matching over every rule.

        :raises InvalidGroupIdMalformedError: for a malformed id in the
            ``group-id`` filter.
        """
        if sg_rule_ids:
            return [
                rule
                for per_vpc in self.groups.values()
                for group in per_vpc.values()
                for rule in itertools.chain(group.egress_rules, group.ingress_rules)
                if rule.id in sg_rule_ids
            ]

        if group_ids:
            return self._all_rules(self.describe_security_groups(group_ids=group_ids))

        if filters and "group-id" in filters:
            for group_id in filters["group-id"]:
                if not is_valid_security_group_id(group_id):
                    raise InvalidGroupIdMalformedError(group_id)
            return self._all_rules(
                self.describe_security_groups(group_ids=group_ids, filters=filters)
            )

        results: List[SecurityRule] = []
        for group in self.describe_security_groups():
            results.extend(self._match_sg_rules(group.ingress_rules, filters))
            results.extend(self._match_sg_rules(group.egress_rules, filters))
        return results

    @staticmethod
    def _match_sg_rules(  # type: ignore[misc]
        rules_list: List[SecurityRule], filters: Any
    ) -> List[SecurityRule]:
        """Rules whose tags match ``filters``."""
        return [rule for rule in rules_list if rule.match_tags(filters)]

    def _delete_security_group(self, vpc_id: Optional[str], group_id: str) -> None:
        """Remove the group from its VPC bucket (default VPC when unset)."""
        vpc_id = vpc_id or self.default_vpc.id  # type: ignore[attr-defined]
        self.groups[vpc_id].pop(group_id)

    def delete_security_group(
        self, name: Optional[str] = None, group_id: Optional[str] = None
    ) -> None:
        """Delete a group by id (preferred) or by name.

        :raises InvalidSecurityGroupNotFoundError: when no group matches.
        """
        if group_id:
            # loop over all the SGs, find the right one
            for vpc_id, groups in self.groups.items():
                if group_id in groups:
                    return self._delete_security_group(vpc_id, group_id)
            raise InvalidSecurityGroupNotFoundError(group_id)
        elif name:
            group = self.get_security_group_by_name_or_id(name)
            if group:
                # Delete from the VPC the group actually lives in; the lookup
                # above searches every VPC, not just the default one.
                return self._delete_security_group(group.vpc_id, group.id)
            raise InvalidSecurityGroupNotFoundError(name)

    def get_security_group_from_id(self, group_id: str) -> Optional[SecurityGroup]:
        """Find a group by id across all VPCs, or return ``None``."""
        # 2 levels of chaining necessary since it's a complex structure
        all_groups = itertools.chain.from_iterable(
            [x.copy().values() for x in self.groups.copy().values()]
        )
        return next((group for group in all_groups if group.id == group_id), None)

    def get_security_group_from_name(
        self, name: str, vpc_id: Optional[str] = None
    ) -> Optional[SecurityGroup]:
        """Find a group by name in ``vpc_id``, or in every VPC when unset."""
        if vpc_id:
            # .get() avoids creating an empty bucket for an unknown VPC.
            buckets = [self.groups.get(vpc_id, {})]
        else:
            buckets = list(self.groups.values())
        for bucket in buckets:
            for group in bucket.values():
                if group.name == name:
                    return group
        return None

    def get_security_group_by_name_or_id(
        self, group_name_or_id: str, vpc_id: Optional[str] = None
    ) -> Optional[SecurityGroup]:
        """Look the value up as an id first, then as a name."""
        group = self.get_security_group_from_id(group_name_or_id)
        if group is None:
            group = self.get_security_group_from_name(group_name_or_id, vpc_id)
        return group

    def get_default_security_group(
        self, vpc_id: Optional[str] = None
    ) -> Optional[SecurityGroup]:
        """Return the default group of the VPC (default VPC when unset)."""
        bucket = self.groups.get(vpc_id or self.default_vpc.id, {})  # type: ignore[attr-defined]
        return next((group for group in bucket.values() if group.is_default), None)

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    def _iterate_security_rules(
        self,
        ip_protocol: str,
        group_id: str,
        from_port: str,
        to_port: str,
        ip_ranges: List[Any],
        source_groups: List[Dict[str, Any]],
        prefix_list_ids: List[Dict[str, str]],
        is_egress: bool = False,
        tags: Optional[Dict[str, str]] = None,
    ) -> Iterator[SecurityRule]:
        """Yield one new rule per IP range, source group and prefix list."""
        # One peer slot is filled per rule; the other two stay None.
        peers = itertools.chain(
            ((ip_range, None, None) for ip_range in ip_ranges),
            ((None, source_group, None) for source_group in source_groups),
            ((None, None, prefix_list_id) for prefix_list_id in prefix_list_ids),
        )
        for ip_range, source_group, prefix_list_id in peers:
            yield SecurityRule(
                self,
                ip_protocol,
                group_id,
                from_port,
                to_port,
                ip_range,
                source_group,
                prefix_list_id,
                is_egress=is_egress,
                tags=tags or {},
            )

    @staticmethod
    def _normalize_ip_ranges(ip_ranges: Any) -> Any:
        """Wrap a non-list ``ip_ranges`` value into a one-element list.

        A plain string not containing ``CidrIp`` is treated as a CIDR; any
        other non-list value is parsed as JSON.  Falsy values and lists are
        returned unchanged.
        """
        if ip_ranges and not isinstance(ip_ranges, list):
            if isinstance(ip_ranges, str) and "CidrIp" not in ip_ranges:
                return [{"CidrIp": ip_ranges}]
            return [json.loads(ip_ranges)]
        return ip_ranges

    @staticmethod
    def _validate_cidrs(ip_ranges: Any) -> None:
        """Raise ``InvalidCIDRSubnetError`` for the first invalid dict/str entry."""
        for cidr in ip_ranges or []:
            if isinstance(cidr, dict):
                valid = any(
                    [
                        is_valid_cidr(cidr.get("CidrIp", "")),
                        is_valid_ipv6_cidr(cidr.get("CidrIpv6", "")),
                    ]
                )
            elif isinstance(cidr, str):
                valid = any([is_valid_cidr(cidr), is_valid_ipv6_cidr(cidr)])
            else:
                # Entries of other types are not validated.
                valid = True
            if not valid:
                raise InvalidCIDRSubnetError(cidr=cidr)

    @staticmethod
    def _known_rule_ids(
        existing_rules: List[SecurityRule], security_rule_ids: List[str]
    ) -> List[str]:
        """Return ``security_rule_ids`` after checking each exists.

        :raises InvalidPermissionNotFoundError: when an id is unknown.
        """
        existing_ids = [rule.id for rule in existing_rules]
        for rule_id in security_rule_ids:
            if rule_id not in existing_ids:
                raise InvalidPermissionNotFoundError()
        return list(security_rule_ids)

    @staticmethod
    def _matching_rule_ids(
        existing_rules: List[SecurityRule], requested_rules: Iterator[SecurityRule]
    ) -> List[str]:
        """Return ids of the existing rules equal to each requested rule.

        :raises InvalidPermissionNotFoundError: when a requested rule has no
            equal among ``existing_rules``.
        """
        matched: List[str] = []
        for requested in requested_rules:
            try:
                idx = existing_rules.index(requested)
            except ValueError:
                raise InvalidPermissionNotFoundError() from None
            matched.append(existing_rules[idx].id)
        return matched

    # ------------------------------------------------------------------
    # Ingress
    # ------------------------------------------------------------------
    def authorize_security_group_ingress(
        self,
        group_name_or_id: str,
        ip_protocol: str,
        from_port: str,
        to_port: str,
        ip_ranges: List[Any],
        sgrule_tags: Optional[Dict[str, str]] = None,
        source_groups: Optional[List[Dict[str, str]]] = None,
        prefix_list_ids: Optional[List[Dict[str, str]]] = None,
        security_rule_ids: Optional[List[str]] = None,  # pylint:disable=unused-argument
        vpc_id: Optional[str] = None,
    ) -> Tuple[List[SecurityRule], SecurityGroup]:
        """Add ingress rules to a group and return ``(new rules, group)``.

        :raises InvalidSecurityGroupNotFoundError: unknown group or source group.
        :raises InvalidCIDRSubnetError: invalid CIDR in ``ip_ranges``.
        :raises RulesPerSecurityGroupLimitExceededError: rule limit exceeded.
        :raises InvalidPermissionDuplicateError: an equal rule already exists.
        """
        group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id)
        if group is None:
            raise InvalidSecurityGroupNotFoundError(group_name_or_id)

        # Ingress treats every string as a bare CIDR (no JSON detection).
        if ip_ranges:
            if isinstance(ip_ranges, str):
                ip_ranges = [{"CidrIp": str(ip_ranges)}]
            elif not isinstance(ip_ranges, list):
                ip_ranges = [json.loads(ip_ranges)]
        self._validate_cidrs(ip_ranges)

        self._verify_group_will_respect_rule_count_limit(
            group, group.get_number_of_ingress_rules(), ip_ranges, source_groups
        )

        _source_groups = self._add_source_group(source_groups, vpc_id)
        rules_added: List[SecurityRule] = []
        for security_rule in self._iterate_security_rules(
            ip_protocol,
            group.group_id,
            from_port,
            to_port,
            ip_ranges,
            _source_groups,
            prefix_list_ids or [],
            is_egress=False,
            tags=sgrule_tags,
        ):
            # add_ingress_rule raises InvalidPermissionDuplicateError itself.
            group.add_ingress_rule(security_rule)
            rules_added.append(security_rule)
        return rules_added, group

    def revoke_security_group_ingress(
        self,
        group_name_or_id: str,
        ip_protocol: str,
        from_port: str,
        to_port: str,
        ip_ranges: List[Any],
        source_groups: Optional[List[Dict[str, Any]]] = None,
        prefix_list_ids: Optional[List[Dict[str, str]]] = None,
        security_rule_ids: Optional[List[str]] = None,
        vpc_id: Optional[str] = None,
    ) -> None:
        """Remove ingress rules, selected by rule id or by rule description.

        Nothing is removed unless every requested rule exists.

        :raises InvalidSecurityGroupNotFoundError: unknown group or source group.
        :raises InvalidPermissionNotFoundError: a requested rule is unknown.
        """
        group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id)
        if group is None:
            raise InvalidSecurityGroupNotFoundError(group_name_or_id)

        if security_rule_ids:
            rules_to_remove = self._known_rule_ids(
                group.ingress_rules, security_rule_ids
            )
        else:
            _source_groups = self._add_source_group(source_groups, vpc_id)
            rules_to_remove = self._matching_rule_ids(
                group.ingress_rules,
                self._iterate_security_rules(
                    ip_protocol,
                    group.group_id,
                    from_port,
                    to_port,
                    ip_ranges,
                    _source_groups,
                    prefix_list_ids or [],
                    is_egress=False,
                ),
            )

        group.ingress_rules = [
            rule for rule in group.ingress_rules if rule.id not in rules_to_remove
        ]

    # ------------------------------------------------------------------
    # Egress
    # ------------------------------------------------------------------
    def authorize_security_group_egress(
        self,
        group_name_or_id: str,
        ip_protocol: str,
        from_port: str,
        to_port: str,
        ip_ranges: List[Any],
        sgrule_tags: Optional[Dict[str, str]] = None,
        source_groups: Optional[List[Dict[str, Any]]] = None,
        prefix_list_ids: Optional[List[Dict[str, str]]] = None,
        security_rule_ids: Optional[List[str]] = None,  # pylint:disable=unused-argument
        vpc_id: Optional[str] = None,
    ) -> Tuple[List[SecurityRule], SecurityGroup]:
        """Add egress rules to a group and return ``(new rules, group)``.

        :raises InvalidSecurityGroupNotFoundError: unknown group or source group.
        :raises InvalidCIDRSubnetError: invalid CIDR in ``ip_ranges``.
        :raises RulesPerSecurityGroupLimitExceededError: rule limit exceeded.
        :raises InvalidPermissionDuplicateError: an equal rule already exists.
        """
        group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id)
        if group is None:
            raise InvalidSecurityGroupNotFoundError(group_name_or_id)

        ip_ranges = self._normalize_ip_ranges(ip_ranges)
        self._validate_cidrs(ip_ranges)

        self._verify_group_will_respect_rule_count_limit(
            group,
            group.get_number_of_egress_rules(),
            ip_ranges,
            source_groups,
            egress=True,
        )

        _source_groups = self._add_source_group(source_groups, vpc_id)
        rules_added: List[SecurityRule] = []
        for security_rule in self._iterate_security_rules(
            ip_protocol,
            group.group_id,
            from_port,
            to_port,
            ip_ranges,
            _source_groups,
            prefix_list_ids or [],
            is_egress=True,
            tags=sgrule_tags,
        ):
            # add_egress_rule raises InvalidPermissionDuplicateError itself.
            group.add_egress_rule(security_rule)
            rules_added.append(security_rule)
        return rules_added, group

    def revoke_security_group_egress(
        self,
        group_name_or_id: str,
        ip_protocol: str,
        from_port: str,
        to_port: str,
        ip_ranges: List[Any],
        source_groups: Optional[List[Dict[str, Any]]] = None,
        prefix_list_ids: Optional[List[Dict[str, str]]] = None,
        security_rule_ids: Optional[List[str]] = None,
        vpc_id: Optional[str] = None,
    ) -> None:
        """Remove egress rules, selected by rule id or by rule description.

        When the group's VPC has no IPv6 CIDR block, IPv6 entries in
        ``ip_ranges`` are ignored.  Nothing is removed unless every requested
        rule exists.  The caller's ``ip_ranges`` list is left untouched.

        :raises InvalidSecurityGroupNotFoundError: unknown group or source group.
        :raises InvalidPermissionNotFoundError: a requested rule is unknown.
        """
        group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id)
        if group is None:
            raise InvalidSecurityGroupNotFoundError(group_name_or_id)

        if security_rule_ids:
            rules_to_remove = self._known_rule_ids(
                group.egress_rules, security_rule_ids
            )
        else:
            _source_groups = self._add_source_group(source_groups, vpc_id)

            if group.vpc_id:
                vpc = self.vpcs.get(group.vpc_id)  # type: ignore[attr-defined]
                if vpc and not len(vpc.get_cidr_block_association_set(ipv6=True)) > 0:
                    # Filter into a new list rather than editing the argument.
                    ip_ranges = [item for item in ip_ranges if "CidrIpv6" not in item]

            rules_to_remove = self._matching_rule_ids(
                group.egress_rules,
                self._iterate_security_rules(
                    ip_protocol,
                    group.group_id,
                    from_port,
                    to_port,
                    ip_ranges,
                    _source_groups,
                    prefix_list_ids or [],
                    is_egress=True,
                ),
            )

        group.egress_rules = [
            rule for rule in group.egress_rules if rule.id not in rules_to_remove
        ]

    # ------------------------------------------------------------------
    # Rule descriptions
    # ------------------------------------------------------------------
    def _update_rule_descriptions(
        self,
        group_name_or_id: str,
        ip_protocol: str,
        from_port: str,
        to_port: str,
        ip_ranges: Any,
        source_groups: Optional[List[Dict[str, Any]]],
        prefix_list_ids: Optional[List[Dict[str, str]]],
        vpc_id: Optional[str],
        is_egress: bool,
    ) -> SecurityGroup:
        """Shared body of the ingress/egress description updates.

        Requested rules without an equal existing rule are skipped silently.
        """
        group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id)
        if group is None:
            raise InvalidSecurityGroupNotFoundError(group_name_or_id)

        ip_ranges = self._normalize_ip_ranges(ip_ranges)
        self._validate_cidrs(ip_ranges)

        _source_groups = self._add_source_group(source_groups, vpc_id)
        existing_rules = group.egress_rules if is_egress else group.ingress_rules
        for security_rule in self._iterate_security_rules(
            ip_protocol,
            group.group_id,
            from_port,
            to_port,
            ip_ranges,
            _source_groups,
            prefix_list_ids or [],
            is_egress=is_egress,
        ):
            try:
                idx = existing_rules.index(security_rule)
            except ValueError:
                continue
            self._sg_update_description(security_rule, existing_rules[idx])
        return group

    def update_security_group_rule_descriptions_ingress(
        self,
        group_name_or_id: str,
        ip_protocol: str,
        from_port: str,
        to_port: str,
        ip_ranges: List[str],
        source_groups: Optional[List[Dict[str, Any]]] = None,
        prefix_list_ids: Optional[List[Dict[str, str]]] = None,
        security_rule_ids: Optional[List[str]] = None,  # pylint:disable=unused-argument
        vpc_id: Optional[str] = None,
    ) -> SecurityGroup:
        """Copy descriptions onto matching ingress rules and return the group.

        :raises InvalidSecurityGroupNotFoundError: unknown group or source group.
        :raises InvalidCIDRSubnetError: invalid CIDR in ``ip_ranges``.
        """
        return self._update_rule_descriptions(
            group_name_or_id,
            ip_protocol,
            from_port,
            to_port,
            ip_ranges,
            source_groups,
            prefix_list_ids,
            vpc_id,
            is_egress=False,
        )

    def update_security_group_rule_descriptions_egress(
        self,
        group_name_or_id: str,
        ip_protocol: str,
        from_port: str,
        to_port: str,
        ip_ranges: List[str],
        source_groups: Optional[List[Dict[str, Any]]] = None,
        prefix_list_ids: Optional[List[Dict[str, str]]] = None,
        security_rule_ids: Optional[List[str]] = None,  # pylint:disable=unused-argument
        vpc_id: Optional[str] = None,
    ) -> SecurityGroup:
        """Copy descriptions onto matching egress rules and return the group.

        :raises InvalidSecurityGroupNotFoundError: unknown group or source group.
        :raises InvalidCIDRSubnetError: invalid CIDR in ``ip_ranges``.
        """
        return self._update_rule_descriptions(
            group_name_or_id,
            ip_protocol,
            from_port,
            to_port,
            ip_ranges,
            source_groups,
            prefix_list_ids,
            vpc_id,
            is_egress=True,
        )

    def _sg_update_description(
        self, security_rule: SecurityRule, rule: SecurityRule
    ) -> None:
        """Copy ``Description`` values from ``security_rule`` onto ``rule``.

        ``rule`` is modified in place; ``security_rule`` is only read.
        """
        requested_range = security_rule.ip_range
        if "Description" in requested_range:
            # The first CIDR key present on ``rule`` with an equal value wins.
            for cidr_key in ("CidrIp", "CidrIpv6"):
                if cidr_key in rule.ip_range and rule.ip_range.get(
                    cidr_key
                ) == requested_range.get(cidr_key):
                    rule.ip_range["Description"] = requested_range["Description"]
                    break

        requested_group = security_rule.source_group
        if "Description" in requested_group:
            same_id = requested_group.get("GroupId") == rule.source_group.get("GroupId")
            same_name = requested_group.get("GroupName") == rule.source_group.get(
                "GroupName"
            )
            if same_id or same_name:
                rule.source_group["Description"] = requested_group["Description"]

    def _add_source_group(
        self, source_groups: Optional[List[Dict[str, Any]]], vpc_id: Optional[str]
    ) -> List[Dict[str, Any]]:
        """Return resolved copies of ``source_groups``.

        Each copy gets an ``OwnerId`` when missing, and a ``GroupName`` is
        replaced by the matching ``GroupId``.  The caller's dicts are not
        modified.

        :raises InvalidSecurityGroupNotFoundError: unknown group id or name.
        """
        resolved: List[Dict[str, Any]] = []
        for original in source_groups or []:
            item = dict(original)
            if "OwnerId" not in item:
                item["OwnerId"] = self.account_id  # type: ignore[attr-defined]
            # for VPCs
            if "GroupId" in item:
                if not self.get_security_group_by_name_or_id(item["GroupId"], vpc_id):
                    raise InvalidSecurityGroupNotFoundError(item["GroupId"])
            if "GroupName" in item:
                source_group = self.get_security_group_by_name_or_id(
                    item["GroupName"], vpc_id
                )
                if not source_group:
                    raise InvalidSecurityGroupNotFoundError(item["GroupName"])
                item["GroupId"] = source_group.id
                item.pop("GroupName")
            resolved.append(item)
        return resolved

    def _verify_group_will_respect_rule_count_limit(
        self,
        group: SecurityGroup,
        current_rule_nb: int,
        ip_ranges: List[str],
        source_groups: Optional[List[Dict[str, str]]] = None,
        egress: bool = False,
    ) -> None:
        """Raise if adding the given peers would exceed the rule limit.

        The limit is 60 for a group with a ``vpc_id`` and 100 otherwise.
        ``egress`` is accepted but not used by the check.

        :raises RulesPerSecurityGroupLimitExceededError: limit exceeded.
        """
        max_nb_rules = 60 if group.vpc_id else 100
        future_group_nb_rules = (
            current_rule_nb + len(ip_ranges or []) + len(source_groups or [])
        )
        if future_group_nb_rules > max_nb_rules:
            raise RulesPerSecurityGroupLimitExceededError()
class SecurityGroupIngress(CloudFormationModel):
    """CloudFormation resource that adds one ingress rule to a security group."""

    def __init__(self, security_group: SecurityGroup, properties: Any):
        """Remember the target group and the raw template properties."""
        self.security_group = security_group
        self.properties = properties

    @staticmethod
    def cloudformation_name_type() -> str:
        """This resource type has no name property."""
        return ""

    @staticmethod
    def cloudformation_type() -> str:
        """CloudFormation resource type handled by this model."""
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-securitygroupingress.html
        return "AWS::EC2::SecurityGroupIngress"

    @classmethod
    def create_from_cloudformation_json(  # type: ignore[misc]
        cls,
        resource_name: str,
        cloudformation_json: Any,
        account_id: str,
        region_name: str,
        **kwargs: Any,
    ) -> "SecurityGroupIngress":
        """Authorize the ingress rule described by the template properties.

        The target group is looked up by ``GroupId`` when given, otherwise by
        ``GroupName``.  The rule source may be ``CidrIp``, ``CidrIpv6``,
        ``SourceSecurityGroupId`` or ``SourceSecurityGroupName``;
        ``Description`` is attached to the CIDR entries.

        :raises AssertionError: when the group, the protocol or every source
            property is missing.
        """
        from ..models import ec2_backends

        properties = cloudformation_json["Properties"]
        ec2_backend = ec2_backends[account_id][region_name]

        group_name = properties.get("GroupName")
        group_id = properties.get("GroupId")
        ip_protocol = properties.get("IpProtocol")
        cidr_ip = properties.get("CidrIp")
        cidr_desc = properties.get("Description")
        cidr_ipv6 = properties.get("CidrIpv6")
        from_port = properties.get("FromPort")
        to_port = properties.get("ToPort")
        source_security_group_id = properties.get("SourceSecurityGroupId")
        source_security_group_name = properties.get("SourceSecurityGroupName")
        # source_security_owner_id =
        # properties.get("SourceSecurityGroupOwnerId")  # IGNORED AT THE MOMENT

        assert group_id or group_name, "GroupId or GroupName is required"
        assert (
            source_security_group_name
            or cidr_ip
            or cidr_ipv6
            or source_security_group_id
        ), "a CIDR or source security group is required"
        assert ip_protocol, "IpProtocol is required"

        source_group = {}
        if source_security_group_id:
            source_group["GroupId"] = source_security_group_id
        if source_security_group_name:
            source_group["GroupName"] = source_security_group_name

        ip_ranges = []
        if cidr_ip:
            ip_ranges.append({"CidrIp": cidr_ip, "Description": cidr_desc})
        if cidr_ipv6:
            # Without this entry an IPv6-only resource would authorize nothing.
            ip_ranges.append({"CidrIpv6": cidr_ipv6, "Description": cidr_desc})

        if group_id:
            security_group = ec2_backend.describe_security_groups(
                group_ids=[group_id]
            )[0]
        else:
            security_group = ec2_backend.describe_security_groups(
                groupnames=[group_name]
            )[0]

        ec2_backend.authorize_security_group_ingress(
            group_name_or_id=security_group.id,
            ip_protocol=ip_protocol,
            from_port=from_port,
            to_port=to_port,
            ip_ranges=ip_ranges,
            source_groups=[source_group] if source_group else [],
        )

        return cls(security_group, properties)