feat(workflows): use SES tenants (#40612)

This commit is contained in:
Haven
2025-11-05 11:42:00 -06:00
committed by GitHub
parent 40cfa5f6c0
commit aafa854ce3
4 changed files with 411 additions and 21 deletions

View File

@@ -0,0 +1,170 @@
import logging
from collections.abc import Iterable
from django.conf import settings
from django.core.management.base import BaseCommand
from django.core.paginator import Paginator
from django.db.models import Q
import boto3
from botocore.exceptions import BotoCoreError, ClientError
from posthog.models.integration import Integration
logger = logging.getLogger(__name__)
def _batched(iterable: Iterable, size: int) -> Iterable[list]:
batch: list = []
for item in iterable:
batch.append(item)
if len(batch) >= size:
yield batch
batch = []
if batch:
yield batch
def migrate_ses_tenants(team_ids: list[int], domains: list[str], dry_run: bool = False):
"""
Ensure existing SES email identities have SES Tenants and Tenant Resource Associations.
The command is idempotent.
"""
if team_ids and domains:
print("Please provide either team_ids or domains, not both") # noqa: T201
return
query = (
Integration.objects.filter(kind="email")
.filter(Q(config__provider="ses") | Q(config__provider__isnull=True))
.order_by("id")
)
if team_ids:
print("Setting up SES tenants for teams:", team_ids) # noqa: T201
query = query.filter(team_id__in=team_ids)
elif domains:
print("Setting up SES tenants for domains:", domains) # noqa: T201
# Domains are stored in Integration.config["domain"]
query = query.filter(config__domain__in=domains)
else:
print("Setting up SES tenants for all SES email identities") # noqa: T201
# Collect unique (team_id, domain) pairs to avoid duplicate work per domain
pairs: list[tuple[int, str]] = []
paginator = Paginator(query, 200)
for page_num in paginator.page_range:
page = paginator.page(page_num)
for integration in page.object_list:
domain = integration.config.get("domain")
if not domain:
continue
provider = integration.config.get("provider", "mailjet")
if provider != "ses":
continue
pair = (integration.team_id, domain)
if pair not in pairs:
pairs.append(pair)
if not pairs:
print("No SES email identities found to migrate.") # noqa: T201
return
sts_client = boto3.client(
"sts",
)
tenant_client = boto3.client(
"sesv2",
)
try:
aws_account_id = sts_client.get_caller_identity()["Account"]
except (ClientError, BotoCoreError) as e:
logger.exception("Failed to get AWS account id for SES tenant association: %s", e)
print("Error determining AWS account ID. Aborting.") # noqa: T201
return
for batch in _batched(pairs, 50):
for team_id, domain in batch:
tenant_name = f"team-{team_id}"
identity_arn = f"arn:aws:ses:{settings.SES_REGION}:{aws_account_id}:identity/{domain}"
# Create tenant if missing
try:
if dry_run:
print(f"[DRY-RUN] Would ensure tenant '{tenant_name}' exists") # noqa: T201
else:
try:
tenant_client.create_tenant(
TenantName=tenant_name,
Tags=[{"Key": "team_id", "Value": str(team_id)}],
)
print(f"Created SES tenant '{tenant_name}'") # noqa: T201
except ClientError as e:
if e.response.get("Error", {}).get("Code") == "AlreadyExistsException":
print(f"Tenant '{tenant_name}' already exists") # noqa: T201
else:
raise
except (ClientError, BotoCoreError) as e:
logger.exception("Error creating SES tenant '%s': %s", tenant_name, e)
print(f"Error creating tenant '{tenant_name}': {e}") # noqa: T201
continue
# Create association if missing
try:
if dry_run:
print(f"[DRY-RUN] Would associate identity '{identity_arn}' with tenant '{tenant_name}'") # noqa: T201
else:
try:
tenant_client.create_tenant_resource_association(
TenantName=tenant_name,
ResourceArn=identity_arn,
)
print(f"Associated identity '{domain}' with tenant '{tenant_name}'") # noqa: T201
except ClientError as e:
if e.response.get("Error", {}).get("Code") == "AlreadyExistsException":
print(f"Association already exists for '{domain}' and tenant '{tenant_name}'") # noqa: T201
else:
raise
except (ClientError, BotoCoreError) as e:
logger.exception(
"Error creating SES tenant_resource_association for '%s' on '%s': %s",
domain,
tenant_name,
e,
)
print(f"Error creating tenant_resource_association for '{domain}' on '{tenant_name}': {e}") # noqa: T201
continue
class Command(BaseCommand):
help = "Migrate existing SES identities to use SES Tenants and resource associations"
def add_arguments(self, parser):
parser.add_argument(
"--dry-run",
action="store_true",
help="If set, will not perform changes, only print actions",
)
parser.add_argument(
"--team-ids",
type=str,
help="Comma separated list of team ids to migrate",
)
parser.add_argument(
"--domains",
type=str,
help="Comma separated list of email domains to migrate (e.g., example.com,foo.bar)",
)
def handle(self, *args, **options):
dry_run: bool = bool(options.get("dry_run"))
team_ids_opt = options.get("team_ids")
domains_opt = options.get("domains")
team_ids = [int(x) for x in team_ids_opt.split(",")] if team_ids_opt else []
domains = [x.strip() for x in domains_opt.split(",")] if domains_opt else []
migrate_ses_tenants(team_ids=team_ids, domains=domains, dry_run=dry_run)

View File

@@ -0,0 +1,122 @@
from posthog.test.base import BaseTest
from unittest.mock import patch
from django.test import override_settings
from posthog.management.commands.migrate_ses_tenants import migrate_ses_tenants
from posthog.models.integration import Integration
class _FakeSESv2Client:
def __init__(self):
self.created_tenants: list[str] = []
self.associations: list[tuple[str, str]] = []
def get_caller_identity(self):
return {"Account": "123456789012"}
def create_tenant(self, TenantName: str, Tags: list[dict]): # noqa: N803
# emulate idempotency externally in test assertions
if TenantName in self.created_tenants:
from botocore.exceptions import ClientError
raise ClientError({"Error": {"Code": "AlreadyExistsException", "Message": "Tenant exists"}}, "CreateTenant")
self.created_tenants.append(TenantName)
return {"TenantName": TenantName}
def create_tenant_resource_association(self, TenantName: str, ResourceArn: str): # noqa: N803
# emulate idempotency externally in test assertions
pair = (TenantName, ResourceArn)
if pair in self.associations:
from botocore.exceptions import ClientError
raise ClientError(
{"Error": {"Code": "AlreadyExistsException", "Message": "Association exists"}},
"CreateTenantResourceAssociation",
)
self.associations.append(pair)
return {"TenantName": TenantName, "ResourceArn": ResourceArn}
class TestMigrateSESTenants(BaseTest):
def setUp(self):
super().setUp()
# Two SES email integrations on the same domain (should dedupe by (team, domain))
Integration.objects.create(
team=self.team,
kind="email",
integration_id="noreply@example.com",
config={"domain": "example.com", "provider": "ses"},
created_by=self.user,
)
Integration.objects.create(
team=self.team,
kind="email",
integration_id="alerts@example.com",
config={"domain": "example.com", "provider": "ses"},
created_by=self.user,
)
# Non-SES provider should be ignored
Integration.objects.create(
team=self.team,
kind="email",
integration_id="ops@other.com",
config={"domain": "other.com", "provider": "mailjet"},
created_by=self.user,
)
@override_settings(SES_ACCESS_KEY_ID="test", SES_SECRET_ACCESS_KEY="test", SES_REGION="us-east-1", SES_ENDPOINT="")
@patch("posthog.management.commands.migrate_ses_tenants.boto3.client")
def test_dry_run(self, mock_boto_client):
# Arrange stub clients
sesv2 = _FakeSESv2Client()
mock_boto_client.side_effect = lambda service, **kwargs: sesv2
# Act: dry-run should not attempt create calls but will still resolve account id
migrate_ses_tenants(team_ids=[], domains=[], dry_run=True)
# Assert: no tenants/associations performed
assert sesv2.created_tenants == []
assert sesv2.associations == []
@override_settings(SES_ACCESS_KEY_ID="test", SES_SECRET_ACCESS_KEY="test", SES_REGION="us-east-1", SES_ENDPOINT="")
@patch("posthog.management.commands.migrate_ses_tenants.boto3.client")
def test_migrate_for_team(self, mock_boto_client):
sesv2 = _FakeSESv2Client()
mock_boto_client.side_effect = lambda service, **kwargs: sesv2
migrate_ses_tenants(team_ids=[self.team.id], domains=[], dry_run=False)
# Deduped: only one tenant and one association for (team, example.com)
assert sesv2.created_tenants == [f"team-{self.team.id}"]
expected_arn = f"arn:aws:ses:us-east-1:123456789012:identity/example.com"
assert sesv2.associations == [(f"team-{self.team.id}", expected_arn)]
@override_settings(SES_ACCESS_KEY_ID="test", SES_SECRET_ACCESS_KEY="test", SES_REGION="eu-west-1", SES_ENDPOINT="")
@patch("posthog.management.commands.migrate_ses_tenants.boto3.client")
def test_migrate_for_domain_filter(self, mock_boto_client):
sesv2 = _FakeSESv2Client()
mock_boto_client.side_effect = lambda service, **kwargs: sesv2
# Use domains filter; should match example.com only
migrate_ses_tenants(team_ids=[], domains=["example.com"], dry_run=False)
assert sesv2.created_tenants == [f"team-{self.team.id}"]
expected_arn = f"arn:aws:ses:eu-west-1:123456789012:identity/example.com"
assert sesv2.associations == [(f"team-{self.team.id}", expected_arn)]
@override_settings(SES_ACCESS_KEY_ID="test", SES_SECRET_ACCESS_KEY="test", SES_REGION="us-east-1", SES_ENDPOINT="")
@patch("posthog.management.commands.migrate_ses_tenants.boto3.client")
def test_idempotent_on_repeated_run(self, mock_boto_client):
sesv2 = _FakeSESv2Client()
mock_boto_client.side_effect = lambda service, **kwargs: sesv2
# First run creates
migrate_ses_tenants(team_ids=[self.team.id], domains=[], dry_run=False)
# Second run should hit AlreadyExistsException internally and not error
migrate_ses_tenants(team_ids=[self.team.id], domains=[], dry_run=False)
# Still only one tenant and association recorded
assert sesv2.created_tenants == [f"team-{self.team.id}"]
expected_arn = f"arn:aws:ses:us-east-1:123456789012:identity/example.com"
assert sesv2.associations == [(f"team-{self.team.id}", expected_arn)]

View File

@@ -12,19 +12,48 @@ logger = logging.getLogger(__name__)
class SESProvider:
def __init__(self):
# Initialize SES client
self.client = boto3.client(
# Initialize the boto3 clients
self.sts_client = boto3.client(
"sts",
aws_access_key_id=settings.SES_ACCESS_KEY_ID,
aws_secret_access_key=settings.SES_SECRET_ACCESS_KEY,
region_name=settings.SES_REGION,
)
self.ses_client = boto3.client(
"ses",
aws_access_key_id=settings.SES_ACCESS_KEY_ID,
aws_secret_access_key=settings.SES_SECRET_ACCESS_KEY,
region_name=settings.SES_REGION,
endpoint_url=settings.SES_ENDPOINT if settings.SES_ENDPOINT else None,
)
self.ses_v2_client = boto3.client(
"sesv2",
aws_access_key_id=settings.SES_ACCESS_KEY_ID,
aws_secret_access_key=settings.SES_SECRET_ACCESS_KEY,
region_name=settings.SES_REGION,
)
def create_email_domain(self, domain: str, team_id: int):
# NOTE: For sesv1 creation is done through verification
# NOTE: For sesv1, domain Identity creation is done through verification
self.verify_email_domain(domain, team_id)
# Create a tenant for the domain if not exists
tenant_name = f"team-{team_id}"
try:
self.ses_v2_client.create_tenant(TenantName=tenant_name, Tags=[{"Key": "team_id", "Value": str(team_id)}])
except ClientError as e:
if e.response["Error"]["Code"] != "AlreadyExistsException":
raise
# Associate the new domain identity with the tenant
try:
self.ses_v2_client.create_tenant_resource_association(
TenantName=tenant_name,
ResourceArn=f"arn:aws:ses:{settings.SES_REGION}:{self.sts_client.get_caller_identity()['Account']}:identity/{domain}",
)
except ClientError as e:
if e.response["Error"]["Code"] != "AlreadyExistsException":
raise
def verify_email_domain(self, domain: str, team_id: int):
# Validate the domain contains valid characters for a domain name
DOMAIN_REGEX = r"(?i)^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$"
@@ -36,7 +65,7 @@ class SESProvider:
# Start/ensure domain verification (TXT at _amazonses.domain) ---
verification_token = None
try:
resp = self.client.verify_domain_identity(Domain=domain)
resp = self.ses_client.verify_domain_identity(Domain=domain)
verification_token = resp.get("VerificationToken")
except ClientError as e:
# If already requested/exists, carry on; SES v1 is idempotent-ish here
@@ -57,7 +86,7 @@ class SESProvider:
# Start/ensure DKIM (three CNAMEs) ---
dkim_tokens: list[str] = []
try:
resp = self.client.verify_domain_dkim(Domain=domain)
resp = self.ses_client.verify_domain_dkim(Domain=domain)
dkim_tokens = resp.get("DkimTokens", []) or []
except ClientError as e:
if e.response["Error"]["Code"] not in ("InvalidParameterValue",):
@@ -86,7 +115,7 @@ class SESProvider:
# Current verification / DKIM statuses to compute overall status & per-record statuses ---
try:
id_attrs = self.client.get_identity_verification_attributes(Identities=[domain])
id_attrs = self.ses_client.get_identity_verification_attributes(Identities=[domain])
verification_status = (
id_attrs["VerificationAttributes"].get(domain, {}).get("VerificationStatus", "Unknown")
)
@@ -94,7 +123,7 @@ class SESProvider:
verification_status = "Unknown"
try:
dkim_attrs = self.client.get_identity_dkim_attributes(Identities=[domain])
dkim_attrs = self.ses_client.get_identity_dkim_attributes(Identities=[domain])
dkim_status = dkim_attrs["DkimAttributes"].get(domain, {}).get("DkimVerificationStatus", "Unknown")
except ClientError:
dkim_status = "Unknown"
@@ -131,7 +160,7 @@ class SESProvider:
Delete an identity from SES
"""
try:
self.client.delete_identity(Identity=identity)
self.ses_client.delete_identity(Identity=identity)
logger.info(f"Identity {identity} deleted from SES")
except (ClientError, BotoCoreError) as e:
logger.exception(f"SES API error deleting identity: {e}")

View File

@@ -1,6 +1,8 @@
from typing import Optional
import pytest
from unittest import TestCase
from unittest.mock import patch
from unittest.mock import MagicMock, patch
from django.test import override_settings
@@ -10,10 +12,30 @@ TEST_DOMAIN = "test.posthog.com"
class TestSESProvider(TestCase):
boto3_client_patcher: Optional[patch] = None # type: ignore
mock_boto3_client: Optional[MagicMock] = None
@classmethod
def setUpClass(cls):
# Patch boto3.client for all tests in this class
patcher = patch("products.workflows.backend.providers.ses.boto3.client")
cls.boto3_client_patcher = patcher
cls.mock_boto3_client = patcher.start()
# Set up a default mock client with safe return values
mock_client_instance = cls.mock_boto3_client.return_value
mock_client_instance.list_identities.return_value = {"Identities": []}
mock_client_instance.delete_identity.return_value = None
mock_client_instance.get_identity_verification_attributes.return_value = {"VerificationAttributes": {}}
mock_client_instance.get_identity_dkim_attributes.return_value = {"DkimAttributes": {}}
mock_client_instance.verify_domain_identity.return_value = {"VerificationToken": "test-token-123"}
mock_client_instance.verify_domain_dkim.return_value = {"DkimTokens": ["token1", "token2", "token3"]}
mock_client_instance.get_caller_identity.return_value = {"Account": "123456789012"}
def setUp(self):
# Remove all domains from SES
# Remove all domains from SES (mocked)
ses_provider = SESProvider()
if TEST_DOMAIN in ses_provider.client.list_identities()["Identities"]:
if TEST_DOMAIN in ses_provider.ses_client.list_identities()["Identities"]:
ses_provider.delete_identity(TEST_DOMAIN)
def test_init_with_valid_credentials(self):
@@ -24,11 +46,43 @@ class TestSESProvider(TestCase):
SES_ENDPOINT="",
):
provider = SESProvider()
assert provider.client
assert provider.ses_client
assert provider.ses_v2_client
assert provider.sts_client
def test_create_email_domain_success(self):
provider = SESProvider()
provider.create_email_domain(TEST_DOMAIN, team_id=1)
# Mock the SES and SESv2 clients on the provider instance
with (
patch.object(provider, "ses_client") as mock_ses_client,
patch.object(provider, "ses_v2_client") as mock_ses_v2_client,
):
# Mock the verification attributes to return a success status
mock_ses_client.get_identity_verification_attributes.return_value = {
"VerificationAttributes": {
TEST_DOMAIN: {
"VerificationStatus": "Success",
"VerificationToken": "test-token-123",
}
}
}
# Mock DKIM attributes to return a success status
mock_ses_client.get_identity_dkim_attributes.return_value = {
"DkimAttributes": {TEST_DOMAIN: {"DkimVerificationStatus": "Success"}}
}
# Mock the domain verification and DKIM setup calls
mock_ses_client.verify_domain_identity.return_value = {"VerificationToken": "test-token-123"}
mock_ses_client.verify_domain_dkim.return_value = {"DkimTokens": ["token1", "token2", "token3"]}
# Mock tenant client methods
mock_ses_v2_client.create_tenant.return_value = {}
mock_ses_v2_client.get_caller_identity.return_value = {"Account": "123456789012"}
mock_ses_v2_client.create_tenant_resource_association.return_value = {}
provider.create_email_domain(TEST_DOMAIN, team_id=1)
@patch("products.workflows.backend.providers.ses.boto3.client")
def test_create_email_domain_invalid_domain(self, mock_boto_client):
@@ -42,10 +96,10 @@ class TestSESProvider(TestCase):
def test_verify_email_domain_initial_setup(self):
provider = SESProvider()
# Mock the client on the provider instance
with patch.object(provider, "client") as mock_client:
# Mock the SES client on the provider instance
with patch.object(provider, "ses_client") as mock_ses_client:
# Mock the verification attributes to return a non-success status
mock_client.get_identity_verification_attributes.return_value = {
mock_ses_client.get_identity_verification_attributes.return_value = {
"VerificationAttributes": {
TEST_DOMAIN: {
"VerificationStatus": "Pending", # Non-success status
@@ -55,7 +109,7 @@ class TestSESProvider(TestCase):
}
# Mock DKIM attributes to return a non-success status
mock_client.get_identity_dkim_attributes.return_value = {
mock_ses_client.get_identity_dkim_attributes.return_value = {
"DkimAttributes": {
TEST_DOMAIN: {
"DkimVerificationStatus": "Pending" # Non-success status
@@ -64,8 +118,8 @@ class TestSESProvider(TestCase):
}
# Mock the domain verification and DKIM setup calls
mock_client.verify_domain_identity.return_value = {"VerificationToken": "test-token-123"}
mock_client.verify_domain_dkim.return_value = {"DkimTokens": ["token1", "token2", "token3"]}
mock_ses_client.verify_domain_identity.return_value = {"VerificationToken": "test-token-123"}
mock_ses_client.verify_domain_dkim.return_value = {"DkimTokens": ["token1", "token2", "token3"]}
result = provider.verify_email_domain(TEST_DOMAIN, team_id=1)
@@ -114,6 +168,21 @@ class TestSESProvider(TestCase):
def test_verify_email_domain_success(self):
provider = SESProvider()
result = provider.verify_email_domain(TEST_DOMAIN, team_id=1)
# Patch the SES client to return 'Success' for both verification and DKIM
with (
patch.object(provider.ses_client, "get_identity_verification_attributes") as mock_verif_attrs,
patch.object(provider.ses_client, "get_identity_dkim_attributes") as mock_dkim_attrs,
):
mock_verif_attrs.return_value = {
"VerificationAttributes": {
TEST_DOMAIN: {
"VerificationStatus": "Success",
"VerificationToken": "test-token-123",
}
}
}
mock_dkim_attrs.return_value = {"DkimAttributes": {TEST_DOMAIN: {"DkimVerificationStatus": "Success"}}}
result = provider.verify_email_domain(TEST_DOMAIN, team_id=1)
# Should return verified status with no DNS records needed
assert result == {"status": "success", "dnsRecords": []}