feat: Build authentication HyperCache for local_evaluation (#37596)

commit 8989f01aec (parent 6730a5c16b)
Author: Phil Haack
Date: 2025-09-12 09:48:41 -07:00
Committed by: GitHub
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>

11 changed files with 5877 additions and 1889 deletions

File diff suppressed because it is too large.


@@ -317,6 +317,7 @@ field_exclusions: dict[ActivityScope, list[str]] = {
"id",
"secret_api_token",
"secret_api_token_backup",
"_old_api_token",
],
"Project": ["id", "created_at"],
"DataWarehouseSavedQuery": [


@@ -5,7 +5,7 @@ from typing import Any, Optional
from django.conf import settings
from django.core.cache import cache
from django.db import models, transaction
-from django.db.models.signals import post_save
+from django.db.models.signals import post_delete, post_save, pre_save
from django.dispatch.dispatcher import receiver
from django.http import HttpRequest
from django.utils import timezone
@@ -19,9 +19,12 @@ from posthog.exceptions_capture import capture_exception
from posthog.models.error_tracking.error_tracking import ErrorTrackingSuppressionRule
from posthog.models.feature_flag.feature_flag import FeatureFlag
from posthog.models.hog_functions.hog_function import HogFunction
from posthog.models.organization import OrganizationMembership
from posthog.models.personal_api_key import PersonalAPIKey
from posthog.models.plugin import PluginConfig
from posthog.models.surveys.survey import Survey
from posthog.models.team.team import Team
from posthog.models.user import User
from posthog.models.utils import UUIDTModel, execute_with_timeout
from posthog.storage.hypercache import HyperCache, HyperCacheStoreMissing
@@ -467,10 +470,30 @@ def _update_team_remote_config(team_id: int):
update_team_remote_config.delay(team_id)
@receiver(pre_save, sender=Team)
def team_pre_save(sender, instance: "Team", **kwargs):
"""Capture old api_token value before save for cache cleanup."""
from posthog.storage.team_access_cache_signal_handlers import capture_old_api_token
capture_old_api_token(instance, **kwargs)
@receiver(post_save, sender=Team)
def team_saved(sender, instance: "Team", created, **kwargs):
transaction.on_commit(lambda: _update_team_remote_config(instance.id))
from posthog.storage.team_access_cache_signal_handlers import update_team_authentication_cache
transaction.on_commit(lambda: update_team_authentication_cache(instance, created, **kwargs))
@receiver(post_delete, sender=Team)
def team_deleted(sender, instance: "Team", **kwargs):
"""Handle team deletion for access cache."""
from posthog.storage.team_access_cache_signal_handlers import update_team_authentication_cache_on_delete
transaction.on_commit(lambda: update_team_authentication_cache_on_delete(instance, **kwargs))
@receiver(post_save, sender=FeatureFlag)
def feature_flag_saved(sender, instance: "FeatureFlag", created, **kwargs):
@@ -500,3 +523,85 @@ def survey_saved(sender, instance: "Survey", created, **kwargs):
@receiver(post_save, sender=ErrorTrackingSuppressionRule)
def error_tracking_suppression_rule_saved(sender, instance: "ErrorTrackingSuppressionRule", created, **kwargs):
transaction.on_commit(lambda: _update_team_remote_config(instance.team_id))
@receiver(post_save, sender=PersonalAPIKey)
def personal_api_key_saved(sender, instance: "PersonalAPIKey", created, **kwargs):
"""
Handle PersonalAPIKey save for team access cache invalidation.
Skip cache updates when only last_used_at changes, to avoid unnecessary cache warming
during authentication requests.
"""
from posthog.storage.team_access_cache_signal_handlers import update_personal_api_key_authentication_cache
# Skip cache updates if only last_used_at is being updated
update_fields = kwargs.get("update_fields")
if update_fields is not None and set(update_fields) == {"last_used_at"}:
return
transaction.on_commit(lambda: update_personal_api_key_authentication_cache(instance))
@receiver(post_delete, sender=PersonalAPIKey)
def personal_api_key_deleted(sender, instance: "PersonalAPIKey", **kwargs):
"""
Handle PersonalAPIKey delete for team access cache invalidation.
"""
from posthog.storage.team_access_cache_signal_handlers import update_personal_api_key_deleted_cache
transaction.on_commit(lambda: update_personal_api_key_deleted_cache(instance))
@receiver(post_save, sender=User)
def user_saved(sender, instance: "User", created, **kwargs):
"""
Handle User save for team access cache updates when is_active changes.
When a user's is_active status changes, their Personal API Keys need to be
added or removed from team authentication caches.
"""
update_fields = kwargs.get("update_fields")
if update_fields is not None and "is_active" not in update_fields:
logger.debug(f"User {instance.id} updated but is_active unchanged, skipping cache update")
return
# If update_fields is None, we need to update cache since all fields (including is_active) might have changed
from posthog.storage.team_access_cache_signal_handlers import update_user_authentication_cache
transaction.on_commit(lambda: update_user_authentication_cache(instance, **kwargs))
@receiver(post_save, sender=OrganizationMembership)
def organization_membership_saved(sender, instance: "OrganizationMembership", created, **kwargs):
"""
Handle OrganizationMembership creation for team access cache updates.
When a user is added to an organization, their unscoped personal API keys
should gain access to teams within that organization. This ensures
that the authentication cache is updated to reflect the new access rights.
Note: We intentionally only handle creation (created=True), not updates.
Changes to membership level (e.g., MEMBER → ADMIN) don't affect API key
access: Personal API keys grant access based on organization membership
existence, not role level.
"""
if created:
from posthog.storage.team_access_cache_signal_handlers import update_organization_membership_created_cache
transaction.on_commit(lambda: update_organization_membership_created_cache(instance))
@receiver(post_delete, sender=OrganizationMembership)
def organization_membership_deleted(sender, instance: "OrganizationMembership", **kwargs):
"""
Handle OrganizationMembership deletion for team access cache invalidation.
When a user is removed from an organization, their unscoped personal API keys
should no longer have access to teams within that organization. This ensures
that the authentication cache is updated to reflect the change in access rights.
"""
from posthog.storage.team_access_cache_signal_handlers import update_organization_membership_deleted_cache
transaction.on_commit(lambda: update_organization_membership_deleted_cache(instance))
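The update_fields guards in the handlers above are what keep hot paths cheap: a save that touches only last_used_at (written on every authenticated request) must not fan out into cache warming. A minimal sketch of how typical saves hit or miss those guards; the object names below are hypothetical, and only the signal behavior is taken from this diff:

from django.utils import timezone

from posthog.models.personal_api_key import PersonalAPIKey
from posthog.models.user import User

user = User.objects.get(email="editor@example.com")  # hypothetical user
key = PersonalAPIKey.objects.filter(user=user).first()  # hypothetical key

# Authentication bookkeeping: update_fields == {"last_used_at"}, so
# personal_api_key_saved returns early and schedules no cache work.
key.last_used_at = timezone.now()
key.save(update_fields=["last_used_at"])

# Scoping change: update_fields is not exactly {"last_used_at"}, so the
# handler defers update_personal_api_key_authentication_cache via
# transaction.on_commit.
key.scoped_teams = []
key.save(update_fields=["scoped_teams"])

# Plain save with update_fields=None: user_saved cannot tell whether
# is_active changed, so it conservatively schedules a cache update.
user.is_active = False
user.save()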


@@ -258,9 +258,9 @@ class TestHogFunctionsBackgroundReloading(TestCase, QueryMatchingTest):
{"key": "$host", "operator": "regex", "value": "^(localhost|127\\.0\\.0\\.1)($|:)"},
{"key": "$pageview", "operator": "regex", "value": "test"},
]
-# 1 update team, 1 load hog flows, 1 load hog functions, 1 update hog functions
+# 1 select team (for field comparison), 1 update team, 1 load hog flows, 1 load hog functions, 1 update hog functions
+# Note: RemoteConfig refresh queries are now deferred via async signals
-with self.assertNumQueries(4):
+with self.assertNumQueries(5):
self.team.save()
hog_function_1.refresh_from_db()
hog_function_2.refresh_from_db()


@@ -0,0 +1,208 @@
"""
Tests for signal handlers in posthog/models/remote_config.py.
"""
from unittest.mock import MagicMock, patch
from django.test import TestCase
from parameterized import parameterized
from posthog.models.organization import OrganizationMembership
from posthog.models.remote_config import organization_membership_deleted, user_saved
from posthog.models.user import User
class TestUserSavedSignalHandler(TestCase):
"""Test the user_saved signal handler in remote_config.py."""
@parameterized.expand(
[
# (update_fields, should_schedule_update, description)
(["is_active", "email"], True, "is_active in update_fields"),
(None, True, "update_fields is None (bulk operation)"),
(["email", "name"], False, "is_active not in update_fields"),
([], False, "empty update_fields list"),
(["is_active"], True, "only is_active in update_fields"),
]
)
@patch("django.db.transaction.on_commit")
def test_user_saved_update_fields_scenarios(
self, update_fields, should_schedule_update, description, mock_on_commit
):
"""Test user_saved signal handler for various update_fields scenarios."""
# Create mock user
mock_user = MagicMock()
mock_user.id = 42
mock_user.is_active = True
# Call user_saved with specified update_fields
user_saved(sender=User, instance=mock_user, created=False, update_fields=update_fields)
# Verify transaction.on_commit behavior
if should_schedule_update:
mock_on_commit.assert_called_once()
else:
mock_on_commit.assert_not_called()
@patch("posthog.models.remote_config.logger")
@patch("django.db.transaction.on_commit")
def test_user_saved_logs_debug_when_skipping_update(self, mock_on_commit, mock_logger):
"""Test that user_saved logs debug message when skipping cache update."""
# Create mock user
mock_user = MagicMock()
mock_user.id = 42
mock_user.is_active = True
# Call user_saved with is_active NOT in update_fields
user_saved(sender=User, instance=mock_user, created=False, update_fields=["email", "name"])
# Verify debug message was logged
mock_logger.debug.assert_called_once_with("User 42 updated but is_active unchanged, skipping cache update")
# Verify transaction.on_commit was not called
mock_on_commit.assert_not_called()
@patch("posthog.storage.team_access_cache_signal_handlers.update_user_authentication_cache")
@patch("django.db.transaction.on_commit")
def test_user_saved_uses_transaction_on_commit(self, mock_on_commit, mock_update_cache):
"""Test that user_saved uses transaction.on_commit to defer cache updates."""
# Create mock user
mock_user = MagicMock()
mock_user.id = 42
mock_user.is_active = True
# Call user_saved with is_active in update_fields
user_saved(sender=User, instance=mock_user, created=False, update_fields=["is_active"])
# Verify transaction.on_commit was called
mock_on_commit.assert_called_once()
# Get the lambda function that was passed to on_commit and call it
on_commit_lambda = mock_on_commit.call_args[0][0]
on_commit_lambda()
# Verify that the update function would be called after transaction commits
# The lambda passes instance and **kwargs (which doesn't include created)
mock_update_cache.assert_called_once_with(mock_user, update_fields=["is_active"])
class TestOrganizationMembershipDeletedSignalHandler(TestCase):
"""Test the organization_membership_deleted signal handler in remote_config.py."""
@patch("django.db.transaction.on_commit")
def test_organization_membership_deleted_calls_update_when_user_removed(self, mock_on_commit):
"""Test that organization_membership_deleted schedules cache update when a user is removed from org."""
# Create mock user and organization
mock_user = MagicMock()
mock_user.id = 42
mock_org = MagicMock()
mock_org.id = "test-org-uuid"
# Create mock OrganizationMembership
mock_membership = MagicMock()
mock_membership.user = mock_user
mock_membership.organization = mock_org
# Call organization_membership_deleted
organization_membership_deleted(sender=OrganizationMembership, instance=mock_membership)
# Verify transaction.on_commit was called (update was scheduled)
mock_on_commit.assert_called_once()
@patch("posthog.storage.team_access_cache_signal_handlers.update_organization_membership_deleted_cache")
@patch("django.db.transaction.on_commit")
def test_organization_membership_deleted_uses_transaction_on_commit(self, mock_on_commit, mock_update_cache):
"""Test that organization_membership_deleted uses transaction.on_commit to defer cache updates."""
# Create mock user and organization
mock_user = MagicMock()
mock_user.id = 42
mock_org = MagicMock()
mock_org.id = "test-org-uuid"
# Create mock OrganizationMembership
mock_membership = MagicMock()
mock_membership.user = mock_user
mock_membership.organization = mock_org
mock_membership.organization_id = "test-org-uuid"
mock_membership.user_id = 42
# Call organization_membership_deleted
organization_membership_deleted(sender=OrganizationMembership, instance=mock_membership)
# Verify transaction.on_commit was called
mock_on_commit.assert_called_once()
# Get the lambda function that was passed to on_commit and call it
on_commit_lambda = mock_on_commit.call_args[0][0]
on_commit_lambda()
# Verify that the update function would be called after transaction commits
# The lambda passes the membership instance
mock_update_cache.assert_called_once_with(mock_membership)
@patch("posthog.models.remote_config.logger")
@patch("django.db.transaction.on_commit")
def test_organization_membership_deleted_logs_when_scheduled(self, mock_on_commit, mock_logger):
"""Test that organization_membership_deleted properly schedules cache updates."""
# Create mock user and organization
mock_user = MagicMock()
mock_user.id = 42
mock_org = MagicMock()
mock_org.id = "test-org-uuid"
# Create mock OrganizationMembership
mock_membership = MagicMock()
mock_membership.user = mock_user
mock_membership.organization = mock_org
# Call organization_membership_deleted
organization_membership_deleted(sender=OrganizationMembership, instance=mock_membership)
# Verify transaction.on_commit was called (update was scheduled)
mock_on_commit.assert_called_once()
def test_organization_membership_deleted_handles_none_user(self):
"""Test that organization_membership_deleted handles membership with None user gracefully."""
# Create mock OrganizationMembership with None user
mock_membership = MagicMock()
mock_membership.user = None
# Should not raise an exception
try:
organization_membership_deleted(sender=OrganizationMembership, instance=mock_membership)
except Exception as e:
self.fail(f"organization_membership_deleted raised an exception with None user: {e}")
@patch("django.db.transaction.on_commit")
def test_organization_membership_deleted_with_different_kwargs(self, mock_on_commit):
"""Test that organization_membership_deleted properly forwards kwargs to the cache update."""
# Create mock user and organization
mock_user = MagicMock()
mock_user.id = 42
mock_org = MagicMock()
mock_org.id = "test-org-uuid"
# Create mock OrganizationMembership
mock_membership = MagicMock()
mock_membership.user = mock_user
mock_membership.organization = mock_org
# Call organization_membership_deleted with additional kwargs
test_kwargs = {"raw": False, "using": "default"}
organization_membership_deleted(sender=OrganizationMembership, instance=mock_membership, **test_kwargs)
# Verify transaction.on_commit was called
mock_on_commit.assert_called_once()


@@ -0,0 +1,383 @@
"""
Per-team access token cache layer for cache-based authentication.
This module provides Redis-based caching of hashed access tokens per team,
enabling zero-database-call authentication for the local_evaluation endpoint.
"""
import logging
from datetime import UTC, datetime
from typing import Any
from django.db import transaction
from django.db.models import Q
from posthog.models.organization import OrganizationMembership
from posthog.models.personal_api_key import PersonalAPIKey, hash_key_value
from posthog.models.team.team import Team
from posthog.storage.hypercache import HyperCache, HyperCacheStoreMissing, KeyType
logger = logging.getLogger(__name__)
# Cache configuration
DEFAULT_TTL = 300 # 5 minutes
CACHE_KEY_PREFIX = "cache/teams"
class TeamAccessTokenCache:
"""
HyperCache-based cache for per-team access token lists.
This class manages hashed token lists per team to enable fast authentication
lookups without database queries. Each team has its own cache entry with
JSON data containing hashed authorized tokens and metadata. Uses HyperCache
for automatic Redis + S3 backup and improved reliability.
"""
def __init__(self, ttl: int = DEFAULT_TTL):
"""
Initialize the team access token cache.
Args:
ttl: Time-to-live for cache entries in seconds
"""
self.ttl = ttl
def update_team_tokens(self, project_api_key: str, team_id: int, hashed_tokens: list[str]) -> None:
"""
Update a team's complete token list in cache.
Args:
project_api_key: The team's project API key
team_id: The team's ID
hashed_tokens: List of hashed tokens (already in sha256$ format)
"""
try:
token_data = {
"hashed_tokens": hashed_tokens,
"last_updated": datetime.now(UTC).isoformat(),
"team_id": team_id,
}
team_access_tokens_hypercache.set_cache_value(project_api_key, token_data)
logger.info(
f"Updated token cache for team {project_api_key} with {len(hashed_tokens)} tokens",
extra={"team_project_api_key": project_api_key, "token_count": len(hashed_tokens)},
)
except Exception as e:
logger.exception(f"Error updating tokens for team {project_api_key}: {e}")
raise
def invalidate_team(self, project_api_key: str) -> None:
"""
Invalidate (delete) a team's token cache.
Args:
project_api_key: The team's project API key
"""
try:
team_access_tokens_hypercache.clear_cache(project_api_key)
logger.info(f"Invalidated token cache for team {project_api_key}")
except Exception as e:
logger.exception(f"Error invalidating cache for team {project_api_key}: {e}")
raise
def _load_team_access_tokens(team_token: KeyType) -> dict[str, Any] | HyperCacheStoreMissing:
"""
Load team access tokens from the database.
Args:
team_token: Team identifier (can be Team object, API token string, or team ID)
Returns:
Dictionary containing hashed tokens and metadata, or HyperCacheStoreMissing if team not found
"""
try:
# Use transaction isolation to ensure consistent reads across all queries
with transaction.atomic():
if isinstance(team_token, str):
team = Team.objects.select_related("organization").get(api_token=team_token)
elif isinstance(team_token, int):
team = Team.objects.select_related("organization").get(id=team_token)
else:
# team_token is already a Team object, but ensure organization is loaded
team = team_token
if not hasattr(team, "organization") or team.organization is None:
team = Team.objects.select_related("organization").get(id=team.id)
hashed_tokens: list[str] = []
# Get all relevant personal API keys in one optimized query
# Combines scoped and unscoped keys with proper filtering
personal_keys = (
PersonalAPIKey.objects.select_related("user")
.filter(
user__organization_membership__organization_id=team.organization_id,
user__is_active=True,
)
.filter(
# Organization scoping: key must either have no org restriction OR include this org
Q(scoped_organizations__isnull=True)
| Q(scoped_organizations=[])
| Q(scoped_organizations__contains=[str(team.organization_id)])
)
.filter(
(
# Scoped keys: explicitly include this team AND have feature flag read or write access
Q(scoped_teams__contains=[team.id])
& (
# Keys with write permission implicitly have read permission
Q(scopes__contains=["feature_flag:read"]) | Q(scopes__contains=["feature_flag:write"])
)
)
| (
# Unscoped keys: no team restriction (null or empty array)
(Q(scoped_teams__isnull=True) | Q(scoped_teams=[]))
& (
# AND either no scope restriction OR has feature flag read or write access
Q(scopes__isnull=True)
| Q(scopes=[])
| Q(scopes__contains=["feature_flag:read"])
| Q(scopes__contains=["feature_flag:write"])
)
)
)
.distinct()
.values_list("secure_value", flat=True)
)
# Collect personal API key tokens
hashed_tokens.extend(secure_value for secure_value in personal_keys if secure_value)
# Add team secret tokens
if team.secret_api_token:
hashed_secret = hash_key_value(team.secret_api_token, mode="sha256")
hashed_tokens.append(hashed_secret)
if team.secret_api_token_backup:
hashed_secret_backup = hash_key_value(team.secret_api_token_backup, mode="sha256")
hashed_tokens.append(hashed_secret_backup)
return {
"hashed_tokens": hashed_tokens,
"last_updated": datetime.now(UTC).isoformat(),
"team_id": team.id, # Include team_id for zero-DB-call authentication
}
except Team.DoesNotExist:
logger.warning(f"Team not found for project API key: {team_token}")
return HyperCacheStoreMissing()
except Exception as e:
logger.exception(f"Error loading team access tokens for {team_token}: {e}")
return HyperCacheStoreMissing()
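# Illustrative summary of the key filter above, for a team with id=1 (not
# an exhaustive truth table; assumes the key's user is an active member of
# the team's organization and passes the scoped_organizations filter):
#
#   scoped_teams   scopes                      token included?
#   [1]            ["feature_flag:read"]       yes - scoped, has read access
#   [1]            ["insight:read"]            no  - scoped, lacks flag scope
#   [2]            ["feature_flag:read"]       no  - scoped to another team
#   [] or NULL     [] or NULL                  yes - unscoped, unrestricted
#   [] or NULL     ["feature_flag:write"]      yes - write implies read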
# HyperCache instance for team access tokens
team_access_tokens_hypercache = HyperCache(
namespace="team_access_tokens",
value="access_tokens.json",
token_based=True, # Use team API token as key
load_fn=_load_team_access_tokens,
)
# Global instance for convenience
team_access_cache = TeamAccessTokenCache()
def warm_team_token_cache(project_api_key: str) -> bool:
"""
Warm the token cache for a specific team by loading from the database.
This function uses the HyperCache to update the cache, which handles
both Redis and S3 storage automatically.
Args:
project_api_key: The team's project API key
Returns:
True if cache warming succeeded, False otherwise
"""
# Use HyperCache to update the cache - this will call _load_team_access_tokens
# It does not raise an exception if the cache is not updated, so we don't need to try/except
success = team_access_tokens_hypercache.update_cache(project_api_key)
if success:
logger.info(
f"Warmed token cache for team {project_api_key} using HyperCache",
extra={"project_api_key": project_api_key},
)
else:
logger.warning(f"Failed to warm token cache for team {project_api_key}")
return success
def get_teams_needing_cache_refresh(limit: int | None = None, offset: int = 0) -> list[str]:
"""
Get a list of project API keys for teams that need cache refresh.
This function supports pagination to handle large datasets efficiently.
For installations with many teams, use limit/offset to process in batches.
Args:
limit: Maximum number of teams to check. None means no limit (all teams).
offset: Number of teams to skip before starting to check.
Returns:
List of project API keys that need cache refresh
Raises:
Exception: Database connectivity or other systemic issues that should trigger retries
"""
# Build queryset with pagination support
# Note: Filtering by project__isnull=False may be needed in production
# but is removed for testing since test teams often don't have projects
queryset = Team.objects.values_list("api_token", flat=True).order_by("id") # Consistent ordering for pagination
# Apply pagination if specified
if offset > 0:
queryset = queryset[offset:]
if limit is not None:
queryset = queryset[:limit]
# Check which teams have missing caches in HyperCache
teams_needing_refresh = []
for project_api_key in queryset:
try:
token_data = team_access_tokens_hypercache.get_from_cache(project_api_key)
if token_data is None:
teams_needing_refresh.append(project_api_key)
except Exception as e:
# Log individual team cache check failure but continue with others
logger.warning(
f"Failed to check cache for team {project_api_key}: {e}",
extra={"project_api_key": project_api_key, "error": str(e)},
)
# Assume this team needs refresh if we can't check its cache
teams_needing_refresh.append(project_api_key)
return teams_needing_refresh
def get_teams_needing_cache_refresh_paginated(batch_size: int = 1000):
"""
Generator that yields batches of teams needing cache refresh.
This is the recommended approach for processing large numbers of teams
to avoid memory issues. It processes teams in chunks and yields each
batch as it's completed.
Args:
batch_size: Number of teams to process per batch
Yields:
List[str]: Batches of project API keys that need cache refresh
"""
offset = 0
while True:
batch = get_teams_needing_cache_refresh(limit=batch_size, offset=offset)
if not batch:
# No more teams to process
break
yield batch
offset += batch_size
# If we got fewer teams than requested, we've reached the end
if len(batch) < batch_size:
break
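# Usage sketch for the generator above (hypothetical driver loop): memory
# stays bounded by batch_size no matter how many teams exist.
#
#   for batch in get_teams_needing_cache_refresh_paginated(batch_size=500):
#       for project_api_key in batch:
#           warm_team_token_cache(project_api_key)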
def get_teams_for_user_personal_api_keys(user_id: int) -> set[str]:
"""
Get all project API keys for teams that a user's PersonalAPIKeys have access to.
This function eliminates N+1 queries by determining all affected teams for
a user's personal API keys in minimal database queries (1-3 queries maximum).
Args:
user_id: The user ID whose personal API keys to analyze
Returns:
Set of project API keys (strings) for all teams the user's keys have access to
"""
# Get all personal API keys for the user
personal_keys = list(PersonalAPIKey.objects.filter(user_id=user_id).values("id", "scoped_teams"))
if not personal_keys:
return set()
affected_teams = set()
scoped_team_ids: set[int] = set()
has_unscoped_keys = False
# Analyze all keys to determine which teams they affect
for key_data in personal_keys:
scoped_teams = key_data["scoped_teams"] or []
if scoped_teams:
# Scoped key - add specific team IDs
scoped_team_ids.update(scoped_teams)
else:
# Unscoped key - will need all teams in user's organizations
has_unscoped_keys = True
# Get project API keys for scoped teams (if any) in one query
if scoped_team_ids:
scoped_team_tokens = Team.objects.filter(id__in=scoped_team_ids).values_list("api_token", flat=True)
affected_teams.update(scoped_team_tokens)
# Get project API keys for unscoped keys (if any) in two queries maximum
if has_unscoped_keys:
# Get user's organizations
user_organizations = OrganizationMembership.objects.filter(user_id=user_id).values_list(
"organization_id", flat=True
)
if user_organizations:
# Get all teams in those organizations
org_team_tokens = Team.objects.filter(organization_id__in=user_organizations).values_list(
"api_token", flat=True
)
affected_teams.update(org_team_tokens)
return affected_teams
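# Usage sketch (hypothetical user id): rebuild every cache a user's keys
# can reach in one pass, using the minimal-query lookup above.
#
#   for project_api_key in get_teams_for_user_personal_api_keys(user_id=42):
#       warm_team_token_cache(project_api_key)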
def get_teams_for_single_personal_api_key(personal_api_key_instance: "PersonalAPIKey") -> list[str]:
"""
Get project API keys for teams that a single PersonalAPIKey has access to.
This is a helper function that internally uses the optimized user-based function.
For better performance when processing multiple keys for the same user, use
get_teams_for_user_personal_api_keys() directly.
Args:
personal_api_key_instance: The PersonalAPIKey instance
Returns:
List of project API keys (strings) for teams the key has access to
"""
user_id = personal_api_key_instance.user_id
all_user_teams = get_teams_for_user_personal_api_keys(user_id)
# Filter to only teams this specific key has access to
scoped_teams = personal_api_key_instance.scoped_teams or []
if scoped_teams:
# Scoped key - only return teams in the scoped list
scoped_team_tokens = set(Team.objects.filter(id__in=scoped_teams).values_list("api_token", flat=True))
# Return intersection of user teams and scoped teams
result = list(all_user_teams.intersection(scoped_team_tokens))
else:
# Unscoped key - return all teams the user has access to
result = list(all_user_teams)
return result
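Together, the cached payload (hashed_tokens plus team_id) supports the zero-database-call authentication the module docstring promises. A minimal sketch of that read path; authenticate_via_cache is an illustrative name, not a function in this diff, and it assumes presented tokens hash into the stored sha256$ format:

from posthog.models.personal_api_key import hash_key_value
from posthog.storage.team_access_cache import team_access_tokens_hypercache


def authenticate_via_cache(project_api_key: str, presented_token: str) -> int | None:
    """Return the team id if the token is authorized, else None.

    Illustrative only; the real endpoint wiring is outside this diff.
    """
    token_data = team_access_tokens_hypercache.get_from_cache(project_api_key)
    if token_data is None:
        # Cache miss: fall back to database authentication; the periodic
        # warming task will backfill the entry.
        return None
    if hash_key_value(presented_token, mode="sha256") in token_data["hashed_tokens"]:
        return token_data["team_id"]
    return None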


@@ -0,0 +1,326 @@
"""
Signal handler functions for team access token cache invalidation.
This module provides handler functions that automatically update
the team access token cache when PersonalAPIKey or Team models change,
ensuring cache consistency with the database.
Note: Signal subscriptions are registered in posthog/models/remote_config.py
"""
import logging
from posthog.models.organization import OrganizationMembership
from posthog.models.personal_api_key import PersonalAPIKey
from posthog.models.team.team import Team
from posthog.storage.team_access_cache import (
get_teams_for_user_personal_api_keys,
team_access_cache,
warm_team_token_cache,
)
logger = logging.getLogger(__name__)
def capture_old_api_token(instance: Team, **kwargs):
"""
Capture the old api_token value before save for cleanup.
This pre_save handler stores the old api_token value so the post_save
handler can clean up the old cache entry when the token changes.
"""
if instance.pk: # Only for existing teams
try:
old_team = Team.objects.only("api_token").get(pk=instance.pk)
# Store the old api_token value for post_save cleanup
instance._old_api_token = old_team.api_token # type: ignore[attr-defined]
except Team.DoesNotExist:
pass
def update_team_authentication_cache(instance: Team, created: bool, **kwargs):
"""
Rebuild team access cache when Team model is saved.
This handler only rebuilds the cache when authentication-related fields change
to avoid unnecessary cache operations for unrelated team updates.
"""
try:
if not instance.api_token:
return
if created:
logger.debug(f"New team created: {instance.pk}")
return
# Check if this is a new team being created
if hasattr(instance, "_state") and instance._state.adding:
logger.debug(f"Team {instance.pk} is being created, skipping cache update")
return
# Check if api_token changed (project API key regeneration)
# We look for the old value stored before save
old_api_token = getattr(instance, "_old_api_token", None)
# If update_fields is specified, only rebuild cache if auth-related fields changed
update_fields = kwargs.get("update_fields")
auth_related_fields = {"api_token", "secret_api_token", "secret_api_token_backup"}
if update_fields is not None:
# Convert update_fields to set for efficient intersection
updated_fields = set(update_fields) if update_fields else set()
# Check if any auth-related fields were updated
if not updated_fields.intersection(auth_related_fields):
logger.debug(
f"Team {instance.pk} updated but no auth fields changed, skipping cache update",
extra={
"team_id": instance.pk,
"updated_fields": list(updated_fields),
"auth_fields": list(auth_related_fields),
},
)
return
try:
# Clean up old cache if api_token changed
if old_api_token and old_api_token != instance.api_token:
team_access_cache.invalidate_team(old_api_token)
logger.info(
f"Invalidated old cache for team {instance.pk} after API token change",
extra={"team_id": instance.pk, "old_api_token": old_api_token, "new_api_token": instance.api_token},
)
warm_team_token_cache(instance.api_token)
logger.info(
f"Rebuilt team access cache for team {instance.pk} after auth field change",
extra={"team_id": instance.pk, "project_api_key": instance.api_token},
)
except Exception as e:
logger.warning(
f"Failed to rebuild cache for team {instance.pk}, falling back to invalidation: {e}",
extra={"team_id": instance.pk},
)
# Fall back to invalidation if rebuild fails
team_access_cache.invalidate_team(instance.api_token)
except Exception as e:
logger.exception(f"Error updating cache on team save for team {instance.pk}: {e}")
def update_team_authentication_cache_on_delete(instance: Team, **kwargs):
"""
Invalidate team access cache when Team is deleted.
"""
try:
if instance.api_token:
team_access_cache.invalidate_team(instance.api_token)
logger.info(f"Invalidated cache for deleted team {instance.pk}")
except Exception as e:
logger.exception(f"Error invalidating cache on team delete for team {instance.pk}: {e}")
def update_personal_api_key_authentication_cache(instance: PersonalAPIKey):
"""
Update team access cache when PersonalAPIKey is saved.
This handler warms the cache for all teams that the user's personal API keys have
access to. For optimal performance, it uses the user-based function to warm all
affected teams at once. Since warming completely rebuilds the cache from the database,
no prior invalidation is needed.
"""
# Get the list of affected teams using optimized user-based function
affected_teams = get_teams_for_user_personal_api_keys(instance.user_id)
# Warm the cache for each affected team
for project_api_key in affected_teams:
try:
warm_team_token_cache(project_api_key)
logger.debug(f"Warmed cache for team {project_api_key} after PersonalAPIKey change")
except Exception as e:
logger.warning(
f"Failed to warm cache for team {project_api_key} after PersonalAPIKey change: {e}",
extra={"project_api_key": project_api_key, "personal_api_key_id": instance.id},
)
logger.info(
f"Updated authentication cache for {len(affected_teams)} teams after PersonalAPIKey change",
extra={"personal_api_key_id": instance.id, "affected_teams_count": len(affected_teams)},
)
def update_user_authentication_cache(instance, **kwargs):
"""
Update team access caches when a User's status changes.
When a user is activated/deactivated, their Personal API Keys need to be
added/removed from the authentication caches of all teams they have access to.
This includes both scoped and unscoped keys.
Note: The update_fields filtering is handled by the user_saved signal handler
in remote_config.py before calling this function.
Args:
sender: The model class (User)
instance: The User instance that changed
**kwargs: Additional signal arguments
"""
try:
# Get all teams that the user's personal API keys have access to
affected_teams = get_teams_for_user_personal_api_keys(instance.id)
_warm_cache_for_teams(affected_teams, "user status change", str(instance.id), None)
except Exception as e:
logger.exception(f"Error updating authentication cache for user {instance.id} status change: {e}")
def update_personal_api_key_deleted_cache(instance: PersonalAPIKey):
"""
Update team access caches when a PersonalAPIKey is deleted.
When a PersonalAPIKey is deleted, it needs to be removed from all team caches
that it had access to. This includes both scoped and unscoped keys.
Args:
instance: The PersonalAPIKey instance that was deleted
"""
try:
# Get all teams that this specific key had access to
# We need to determine this based on the key's scoping
scoped_teams = instance.scoped_teams or []
if scoped_teams:
# Scoped key - only affects specific teams
team_api_tokens = Team.objects.filter(id__in=scoped_teams).values_list("api_token", flat=True)
else:
# Unscoped key - affects all teams in user's organizations
user_organizations = OrganizationMembership.objects.filter(user_id=instance.user_id).values_list(
"organization_id", flat=True
)
if user_organizations:
team_api_tokens = Team.objects.filter(organization_id__in=user_organizations).values_list(
"api_token", flat=True
)
else:
team_api_tokens = []
_warm_cache_for_teams(team_api_tokens, "PersonalAPIKey deletion", str(instance.user_id), None)
except Exception as e:
logger.exception(
f"Error updating team caches after PersonalAPIKey deletion: {e}",
extra={
"personal_api_key_id": getattr(instance, "id", None),
"user_id": getattr(instance, "user_id", None),
},
)
def update_organization_membership_created_cache(membership_instance):
"""
Update team access caches when an OrganizationMembership is created.
When a user is added to an organization, their unscoped Personal API Keys should
gain access to teams within that organization. This function updates the caches for
all teams in the organization that was joined.
Args:
membership_instance: The OrganizationMembership instance that was created
"""
try:
# Get all teams in the organization the user joined
organization_id = membership_instance.organization_id
user_id = membership_instance.user_id
team_api_tokens = Team.objects.filter(organization_id=organization_id).values_list("api_token", flat=True)
_warm_cache_for_teams(team_api_tokens, "adding user to organization", user_id, organization_id)
except Exception as e:
logger.exception(
f"Error updating team caches after OrganizationMembership creation: {e}",
extra={
"user_id": getattr(membership_instance, "user_id", None),
"organization_id": getattr(membership_instance, "organization_id", None),
},
)
def update_organization_membership_deleted_cache(membership_instance):
"""
Update team access caches when an OrganizationMembership is deleted.
When a user is removed from an organization, their Personal API Keys should no longer
have access to teams within that organization. This function updates the caches for
all teams in the organization that was removed from.
This is different from update_user_authentication_cache because we need to update
the teams from the organization the user was REMOVED from, not their current teams
(which won't include the removed organization anymore).
Args:
membership_instance: The OrganizationMembership instance that was deleted
"""
try:
# Get all teams in the organization the user was removed from
organization_id = membership_instance.organization_id
user_id = membership_instance.user_id
team_api_tokens = Team.objects.filter(organization_id=organization_id).values_list("api_token", flat=True)
_warm_cache_for_teams(team_api_tokens, "removing user from organization", user_id, organization_id)
except Exception as e:
logger.exception(
f"Error updating team caches after OrganizationMembership deletion: {e}",
extra={
"user_id": getattr(membership_instance, "user_id", None),
"organization_id": getattr(membership_instance, "organization_id", None),
},
)
def _warm_cache_for_teams(
team_api_tokens: set[str] | list[str], action: str, user_id: str | int, organization_id: str | None
):
"""
Warm the cache for a set of teams.
"""
if not team_api_tokens:
logger.debug(f"No teams found in organization {organization_id}, no cache updates needed")
return
# Warm the cache for each team in the organization
# This will rebuild the cache without the removed user's keys
for project_api_key in team_api_tokens:
try:
warm_team_token_cache(project_api_key)
logger.debug(
f"Warmed cache for team {project_api_key} after {action}",
extra={
"project_api_key": project_api_key,
"user_id": user_id,
"organization_id": organization_id,
},
)
except Exception as e:
logger.warning(
f"Failed to warm cache for team {project_api_key} after {action}: {e}",
extra={
"project_api_key": project_api_key,
"user_id": user_id,
"organization_id": organization_id,
},
)
logger.info(
f"Updated {len(team_api_tokens)} team caches after {action}",
extra={
"user_id": user_id,
"organization_id": organization_id,
"teams_updated": len(team_api_tokens),
},
)
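One subtle path worth tracing through the handlers above is project API key rotation: the cache is keyed by api_token, so rotating it must both remove the stale entry and create a fresh one. A sketch with hypothetical values; the pre_save/post_save interplay is taken from capture_old_api_token and update_team_authentication_cache:

from posthog.models.team.team import Team

team = Team.objects.get(pk=1)  # hypothetical team
old_token = team.api_token

# team_pre_save stashes the current token on the instance as
# _old_api_token; after commit, update_team_authentication_cache sees the
# changed auth field, invalidates the entry under old_token, and warms a
# new entry under the rotated api_token.
team.api_token = "phc_rotated_example_token"  # illustrative value
team.save(update_fields=["api_token"])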

File diff suppressed because it is too large.


@@ -54,6 +54,7 @@ from posthog.tasks.tasks import (
update_survey_iteration,
verify_persons_data_in_sync,
)
from posthog.tasks.team_access_cache_tasks import warm_all_team_access_caches_task
from posthog.utils import get_crontab
TWENTY_FOUR_HOURS = 24 * 60 * 60
@@ -95,6 +96,14 @@ def setup_periodic_tasks(sender: Celery, **kwargs: Any) -> None:
name="schedule warming for largest teams",
)
# Team access cache warming - every 10 minutes
add_periodic_task_with_expiry(
sender,
600, # Every 10 minutes (no TTL, just fill missing entries)
warm_all_team_access_caches_task.s(),
name="warm team access caches",
)
# Update events table partitions twice a week
sender.add_periodic_task(
crontab(day_of_week="mon,fri", hour="0", minute="0"),


@@ -0,0 +1,117 @@
"""
Background tasks for warming team access token caches.
This module provides Celery tasks to periodically warm the team access token
caches, ensuring that the cached authentication system has fresh data.
"""
import logging
from django.conf import settings
from celery import shared_task
from celery.app.task import Task
from posthog.storage.team_access_cache import get_teams_needing_cache_refresh_paginated, warm_team_token_cache
logger = logging.getLogger(__name__)
# Configuration
CACHE_WARMING_BATCH_SIZE = getattr(settings, "CACHE_WARMING_BATCH_SIZE", 50)
CACHE_WARMING_PAGE_SIZE = getattr(settings, "CACHE_WARMING_PAGE_SIZE", 1000) # Teams per database page
@shared_task(bind=True, max_retries=3)
def warm_team_cache_task(self: "Task", project_api_key: str) -> dict:
"""
Warm the token cache for a specific team.
Args:
project_api_key: The team's project API key
Returns:
Dictionary with operation results
"""
success = warm_team_token_cache(project_api_key)
if not success:
# Log a warning, but don't retry. We'll let the next scheduled task pick it up.
logger.warning(f"Failed to warm cache for team {project_api_key}")
return {"status": "failure", "project_api_key": project_api_key}
logger.info(
f"Successfully warmed cache for team {project_api_key}",
extra={"project_api_key": project_api_key},
)
return {"status": "success", "project_api_key": project_api_key}
@shared_task(bind=True, max_retries=1)
def warm_all_team_access_caches_task(self: "Task") -> dict:
"""
Warm caches for all teams that need refreshing.
This task identifies teams with expired or missing caches and
schedules individual warming tasks for each team.
Returns:
Dictionary with operation results
"""
try:
teams_scheduled = 0
failed_teams = 0
teams_pages_processed = 0
total_teams_found = 0
# Use paginated approach for memory efficiency
logger.info(f"Using paginated cache warming with page size {CACHE_WARMING_PAGE_SIZE}")
for teams_page in get_teams_needing_cache_refresh_paginated(batch_size=CACHE_WARMING_PAGE_SIZE):
teams_pages_processed += 1
if not teams_page:
continue
total_teams_found += len(teams_page)
logger.debug(
f"Processing page {teams_pages_processed} with {len(teams_page)} teams needing refresh",
extra={"page": teams_pages_processed, "teams_in_page": len(teams_page)},
)
# Process teams in batches to avoid overwhelming the system
for i in range(0, len(teams_page), CACHE_WARMING_BATCH_SIZE):
batch = teams_page[i : i + CACHE_WARMING_BATCH_SIZE]
# Schedule warming tasks for this batch
for project_api_key in batch:
try:
warm_team_cache_task.delay(project_api_key)
teams_scheduled += 1
except Exception as e:
# Log individual team scheduling failure but continue with others
failed_teams += 1
logger.warning(
f"Failed to schedule cache warming for team {project_api_key}: {e}",
extra={"project_api_key": project_api_key, "error": str(e)},
)
logger.debug(f"Scheduled cache warming for batch of {len(batch)} teams")
logger.info(
"Cache warming completed",
extra={"teams_found": total_teams_found, "teams_scheduled": teams_scheduled, "failed_teams": failed_teams},
)
return {
"status": "success",
"teams_found": total_teams_found,
"teams_scheduled": teams_scheduled,
"failed_teams": failed_teams,
}
except Exception as e:
# Retry for systemic failures (database connectivity, etc.)
logger.exception(f"Systemic failure in cache warming batch task: {e}")
raise self.retry(exc=e, countdown=300) # 5 minutes
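For manual operations (for example, repopulating after a Redis flush), the batch task can also be driven by hand; a sketch assuming a Django shell:

from posthog.tasks.team_access_cache_tasks import warm_all_team_access_caches_task

# Running the task function synchronously in the current process; use
# .delay() to enqueue it on a Celery worker instead.
result = warm_all_team_access_caches_task()
# e.g. {"status": "success", "teams_found": 12, "teams_scheduled": 12,
#       "failed_teams": 0}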


@@ -0,0 +1,167 @@
"""
Tests for team access cache Celery tasks.
"""
from unittest.mock import MagicMock, patch
from django.test import TestCase
from posthog.tasks.team_access_cache_tasks import warm_all_team_access_caches_task, warm_team_cache_task
class TestWarmTeamCacheTask(TestCase):
"""Test the individual team cache warming task."""
@patch("posthog.tasks.team_access_cache_tasks.warm_team_token_cache")
def test_warm_team_cache_task_success(self, mock_warm: MagicMock) -> None:
"""Test successful cache warming for a team."""
mock_warm.return_value = True
project_api_key = "phs_test_team_123"
result = warm_team_cache_task(project_api_key)
mock_warm.assert_called_once_with(project_api_key)
assert result["status"] == "success"
assert result["project_api_key"] == project_api_key
@patch("posthog.tasks.team_access_cache_tasks.warm_team_token_cache")
def test_warm_team_cache_task_failure(self, mock_warm: MagicMock) -> None:
"""Test that cache warming failure does not trigger retry."""
mock_warm.return_value = False
project_api_key = "phs_test_team_123"
result = warm_team_cache_task(project_api_key)
mock_warm.assert_called_once_with(project_api_key)
assert result["status"] == "failure"
assert result["project_api_key"] == project_api_key
class TestWarmAllTeamsCachesTask(TestCase):
"""Test the batch cache warming task."""
@patch("posthog.tasks.team_access_cache_tasks.get_teams_needing_cache_refresh_paginated")
def test_warm_all_team_access_caches_task_no_teams(self, mock_get_teams_paginated: MagicMock) -> None:
"""Test batch warming when no teams need refresh."""
mock_get_teams_paginated.return_value = iter([])
result = warm_all_team_access_caches_task()
assert result["status"] == "success"
assert result["teams_found"] == 0
assert result["teams_scheduled"] == 0
assert result["failed_teams"] == 0
@patch("posthog.tasks.team_access_cache_tasks.get_teams_needing_cache_refresh_paginated")
@patch("posthog.tasks.team_access_cache_tasks.warm_team_cache_task.delay")
@patch("posthog.tasks.team_access_cache_tasks.CACHE_WARMING_BATCH_SIZE", 2)
def test_warm_all_team_access_caches_task_batching(
self, mock_delay: MagicMock, mock_get_teams_paginated: MagicMock
) -> None:
"""Test that teams are processed in configured batches."""
# Setup mock teams - more than batch size, split across pages
page1 = ["phs_team1_123", "phs_team2_456"]
page2 = ["phs_team3_789", "phs_team4_012"]
mock_get_teams_paginated.return_value = iter([page1, page2])
result = warm_all_team_access_caches_task()
# Verify all teams were scheduled despite batching
assert mock_delay.call_count == 4
all_teams = page1 + page2
for team in all_teams:
mock_delay.assert_any_call(team)
assert result["status"] == "success"
assert result["teams_scheduled"] == 4
assert result["teams_found"] == 4
assert result["failed_teams"] == 0
@patch("posthog.tasks.team_access_cache_tasks.get_teams_needing_cache_refresh_paginated")
@patch("posthog.tasks.team_access_cache_tasks.warm_team_cache_task.delay")
def test_warm_all_team_access_caches_task_handles_individual_failures(
self, mock_delay: MagicMock, mock_get_teams_paginated: MagicMock
) -> None:
"""Test that batch warming handles individual team failures gracefully."""
# Setup mocks - single page of teams
mock_teams = ["phs_team1_123", "phs_team2_456", "phs_team3_789"]
mock_get_teams_paginated.return_value = iter([mock_teams])
# Make delay fail for the second team only
def delay_side_effect(project_api_key: str) -> None:
if project_api_key == "phs_team2_456":
raise Exception("Task scheduling failed for team 2")
return None
mock_delay.side_effect = delay_side_effect
# Execute task - should NOT raise exception, but should handle failure gracefully
result = warm_all_team_access_caches_task()
mock_get_teams_paginated.assert_called_once()
# Verify all teams were attempted
assert mock_delay.call_count == 3
mock_delay.assert_any_call("phs_team1_123")
mock_delay.assert_any_call("phs_team2_456")
mock_delay.assert_any_call("phs_team3_789")
# Verify result shows partial success
assert result["status"] == "success"
assert result["teams_scheduled"] == 2 # team2 failed to schedule
assert result["failed_teams"] == 1
assert result["teams_found"] == 3
@patch("posthog.tasks.team_access_cache_tasks.get_teams_needing_cache_refresh_paginated")
def test_warm_all_team_access_caches_task_handles_systemic_failures(
self, mock_get_teams_paginated: MagicMock
) -> None:
"""Test that batch warming retries on systemic failures like database connectivity issues."""
# Make get_teams_needing_cache_refresh_paginated fail (systemic issue)
mock_get_teams_paginated.side_effect = Exception("Database connection failed")
# Execute task - should raise some exception that triggers retry
# The retry mechanism may raise the original exception or a Retry exception
with self.assertRaises(Exception) as cm:
warm_all_team_access_caches_task()
mock_get_teams_paginated.assert_called_once()
# Verify the exception is related to our test failure
self.assertIn("Database connection failed", str(cm.exception))
class TestTaskIntegration(TestCase):
"""Integration tests for the complete task flow."""
@patch("posthog.tasks.team_access_cache_tasks.warm_team_token_cache")
@patch("posthog.tasks.team_access_cache_tasks.get_teams_needing_cache_refresh_paginated")
def test_complete_cache_refresh_flow(self, mock_get_teams_paginated: MagicMock, mock_warm: MagicMock) -> None:
"""Test the complete flow from batch task to individual warming."""
# Setup mocks - single page with one team
mock_teams = ["phs_team1_123"]
mock_get_teams_paginated.return_value = iter([mock_teams])
mock_warm.return_value = True
# Execute batch task (without actually scheduling async tasks)
with patch("posthog.tasks.team_access_cache_tasks.warm_team_cache_task.delay") as mock_delay:
batch_result = warm_all_team_access_caches_task()
# Verify batch task scheduled individual task
mock_delay.assert_called_once_with("phs_team1_123")
# Simulate individual task execution
individual_result = warm_team_cache_task("phs_team1_123")
# Verify both tasks succeeded
assert batch_result["status"] == "success"
assert batch_result["teams_scheduled"] == 1
assert batch_result["teams_found"] == 1
assert batch_result["failed_teams"] == 0
assert individual_result["status"] == "success"
assert individual_result["project_api_key"] == "phs_team1_123"