mirror of https://github.com/BillyOutlast/posthog.git
feat: Build authentication HyperCache for local_evaluation (#37596)
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
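
In brief: the commit maintains, per team, a cached list of sha256-hashed secrets (the team's secret API tokens plus eligible personal API keys) keyed by the team's project API key, so local_evaluation can authenticate without a database round trip. A minimal sketch of the lookup this enables, using only names introduced in this diff (the endpoint wiring itself is not shown in the visible hunks):

from posthog.models.personal_api_key import hash_key_value
from posthog.storage.team_access_cache import team_access_tokens_hypercache

def authorized_team_id(project_api_key: str, presented_secret: str) -> int | None:
    """Sketch: return the team_id if the presented secret is authorized, else None."""
    data = team_access_tokens_hypercache.get_from_cache(project_api_key)
    if not data:
        return None  # cache miss: callers would fall back to the database path
    hashed = hash_key_value(presented_secret, mode="sha256")
    return data["team_id"] if hashed in set(data["hashed_tokens"]) else None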
File diff suppressed because it is too large
@@ -317,6 +317,7 @@ field_exclusions: dict[ActivityScope, list[str]] = {
         "id",
         "secret_api_token",
         "secret_api_token_backup",
+        "_old_api_token",
     ],
     "Project": ["id", "created_at"],
     "DataWarehouseSavedQuery": [
@@ -5,7 +5,7 @@ from typing import Any, Optional
 from django.conf import settings
 from django.core.cache import cache
 from django.db import models, transaction
-from django.db.models.signals import post_save
+from django.db.models.signals import post_delete, post_save, pre_save
 from django.dispatch.dispatcher import receiver
 from django.http import HttpRequest
 from django.utils import timezone
@@ -19,9 +19,12 @@ from posthog.exceptions_capture import capture_exception
 from posthog.models.error_tracking.error_tracking import ErrorTrackingSuppressionRule
 from posthog.models.feature_flag.feature_flag import FeatureFlag
 from posthog.models.hog_functions.hog_function import HogFunction
+from posthog.models.organization import OrganizationMembership
+from posthog.models.personal_api_key import PersonalAPIKey
 from posthog.models.plugin import PluginConfig
 from posthog.models.surveys.survey import Survey
 from posthog.models.team.team import Team
 from posthog.models.user import User
 from posthog.models.utils import UUIDTModel, execute_with_timeout
+from posthog.storage.hypercache import HyperCache, HyperCacheStoreMissing
@@ -467,10 +470,30 @@ def _update_team_remote_config(team_id: int):
     update_team_remote_config.delay(team_id)


+@receiver(pre_save, sender=Team)
+def team_pre_save(sender, instance: "Team", **kwargs):
+    """Capture old api_token value before save for cache cleanup."""
+    from posthog.storage.team_access_cache_signal_handlers import capture_old_api_token
+
+    capture_old_api_token(instance, **kwargs)
+
+
 @receiver(post_save, sender=Team)
 def team_saved(sender, instance: "Team", created, **kwargs):
     transaction.on_commit(lambda: _update_team_remote_config(instance.id))
+
+    from posthog.storage.team_access_cache_signal_handlers import update_team_authentication_cache
+
+    transaction.on_commit(lambda: update_team_authentication_cache(instance, created, **kwargs))
+
+
+@receiver(post_delete, sender=Team)
+def team_deleted(sender, instance: "Team", **kwargs):
+    """Handle team deletion for access cache."""
+    from posthog.storage.team_access_cache_signal_handlers import update_team_authentication_cache_on_delete
+
+    transaction.on_commit(lambda: update_team_authentication_cache_on_delete(instance, **kwargs))
+
+
 @receiver(post_save, sender=FeatureFlag)
 def feature_flag_saved(sender, instance: "FeatureFlag", created, **kwargs):
@@ -500,3 +523,85 @@ def survey_saved(sender, instance: "Survey", created, **kwargs):
 @receiver(post_save, sender=ErrorTrackingSuppressionRule)
 def error_tracking_suppression_rule_saved(sender, instance: "ErrorTrackingSuppressionRule", created, **kwargs):
     transaction.on_commit(lambda: _update_team_remote_config(instance.team_id))
+
+
+@receiver(post_save, sender=PersonalAPIKey)
+def personal_api_key_saved(sender, instance: "PersonalAPIKey", created, **kwargs):
+    """
+    Handle PersonalAPIKey save for team access cache invalidation.
+
+    Skip cache updates for last_used_at field updates to avoid unnecessary cache warming
+    during authentication requests.
+    """
+    from posthog.storage.team_access_cache_signal_handlers import update_personal_api_key_authentication_cache
+
+    # Skip cache updates if only last_used_at is being updated
+    update_fields = kwargs.get("update_fields")
+    if update_fields is not None and set(update_fields) == {"last_used_at"}:
+        return
+
+    transaction.on_commit(lambda: update_personal_api_key_authentication_cache(instance))
+
+
+@receiver(post_delete, sender=PersonalAPIKey)
+def personal_api_key_deleted(sender, instance: "PersonalAPIKey", **kwargs):
+    """
+    Handle PersonalAPIKey delete for team access cache invalidation.
+    """
+    from posthog.storage.team_access_cache_signal_handlers import update_personal_api_key_deleted_cache
+
+    transaction.on_commit(lambda: update_personal_api_key_deleted_cache(instance))
+
+
+@receiver(post_save, sender=User)
+def user_saved(sender, instance: "User", created, **kwargs):
+    """
+    Handle User save for team access cache updates when is_active changes.
+
+    When a user's is_active status changes, their Personal API Keys need to be
+    added or removed from team authentication caches.
+    """
+    update_fields = kwargs.get("update_fields")
+    if update_fields is not None and "is_active" not in update_fields:
+        logger.debug(f"User {instance.id} updated but is_active unchanged, skipping cache update")
+        return
+
+    # If update_fields is None, we need to update cache since all fields (including is_active) might have changed
+
+    from posthog.storage.team_access_cache_signal_handlers import update_user_authentication_cache
+
+    transaction.on_commit(lambda: update_user_authentication_cache(instance, **kwargs))
+
+
+@receiver(post_save, sender=OrganizationMembership)
+def organization_membership_saved(sender, instance: "OrganizationMembership", created, **kwargs):
+    """
+    Handle OrganizationMembership creation for team access cache updates.
+
+    When a user is added to an organization, their unscoped personal API keys
+    should gain access to teams within that organization. This ensures
+    that the authentication cache is updated to reflect the new access rights.
+
+    Note: We intentionally only handle creation (created=True), not updates.
+    Changes to membership level (e.g., MEMBER → ADMIN) don't affect API key
+    access - Personal API keys grant access based on organization membership
+    existence, not role level.
+    """
+    if created:
+        from posthog.storage.team_access_cache_signal_handlers import update_organization_membership_created_cache
+
+        transaction.on_commit(lambda: update_organization_membership_created_cache(instance))
+
+
+@receiver(post_delete, sender=OrganizationMembership)
+def organization_membership_deleted(sender, instance: "OrganizationMembership", **kwargs):
+    """
+    Handle OrganizationMembership deletion for team access cache invalidation.
+
+    When a user is removed from an organization, their unscoped personal API keys
+    should no longer have access to teams within that organization. This ensures
+    that the authentication cache is updated to reflect the change in access rights.
+    """
+    from posthog.storage.team_access_cache_signal_handlers import update_organization_membership_deleted_cache
+
+    transaction.on_commit(lambda: update_organization_membership_deleted_cache(instance))
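
Why the last_used_at guard in personal_api_key_saved exists: authentication itself writes last_used_at on the key row, so without the early return every authenticated request would schedule a cache rebuild. A hedged sketch of the two save paths (key and team are assumed, pre-fetched instances):

from django.utils import timezone

# Routine bookkeeping during authentication: update_fields == {"last_used_at"},
# so personal_api_key_saved returns before scheduling any cache work
key.last_used_at = timezone.now()
key.save(update_fields=["last_used_at"])

# An actual permission change: schedules update_personal_api_key_authentication_cache
# via transaction.on_commit
key.scoped_teams = [team.id]
key.save(update_fields=["scoped_teams"])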
@@ -258,9 +258,9 @@ class TestHogFunctionsBackgroundReloading(TestCase, QueryMatchingTest):
             {"key": "$host", "operator": "regex", "value": "^(localhost|127\\.0\\.0\\.1)($|:)"},
             {"key": "$pageview", "operator": "regex", "value": "test"},
         ]
-        # 1 update team, 1 load hog flows, 1 load hog functions, 1 update hog functions
-        with self.assertNumQueries(4):
+        # 1 select team (for field comparison), 1 update team, 1 load hog flows, 1 load hog functions, 1 update hog functions
+        # Note: RemoteConfig refresh queries are now deferred via async signals
+        with self.assertNumQueries(5):
             self.team.save()
             hog_function_1.refresh_from_db()
             hog_function_2.refresh_from_db()

208  posthog/models/test/test_remote_config_signals.py  Normal file
@@ -0,0 +1,208 @@
"""
Tests for signal handlers in posthog/models/remote_config.py.
"""

from unittest.mock import MagicMock, patch

from django.test import TestCase

from parameterized import parameterized

from posthog.models.organization import OrganizationMembership
from posthog.models.remote_config import organization_membership_deleted, user_saved
from posthog.models.user import User


class TestUserSavedSignalHandler(TestCase):
    """Test the user_saved signal handler in remote_config.py."""

    @parameterized.expand(
        [
            # (update_fields, should_schedule_update, description)
            (["is_active", "email"], True, "is_active in update_fields"),
            (None, True, "update_fields is None (bulk operation)"),
            (["email", "name"], False, "is_active not in update_fields"),
            ([], False, "empty update_fields list"),
            (["is_active"], True, "only is_active in update_fields"),
        ]
    )
    @patch("django.db.transaction.on_commit")
    def test_user_saved_update_fields_scenarios(
        self, update_fields, should_schedule_update, description, mock_on_commit
    ):
        """Test user_saved signal handler for various update_fields scenarios."""

        # Create mock user
        mock_user = MagicMock()
        mock_user.id = 42
        mock_user.is_active = True

        # Call user_saved with specified update_fields
        user_saved(sender=User, instance=mock_user, created=False, update_fields=update_fields)

        # Verify transaction.on_commit behavior
        if should_schedule_update:
            mock_on_commit.assert_called_once()
        else:
            mock_on_commit.assert_not_called()

    @patch("posthog.models.remote_config.logger")
    @patch("django.db.transaction.on_commit")
    def test_user_saved_logs_debug_when_skipping_update(self, mock_on_commit, mock_logger):
        """Test that user_saved logs debug message when skipping cache update."""

        # Create mock user
        mock_user = MagicMock()
        mock_user.id = 42
        mock_user.is_active = True

        # Call user_saved with is_active NOT in update_fields
        user_saved(sender=User, instance=mock_user, created=False, update_fields=["email", "name"])

        # Verify debug message was logged
        mock_logger.debug.assert_called_once_with("User 42 updated but is_active unchanged, skipping cache update")

        # Verify transaction.on_commit was not called
        mock_on_commit.assert_not_called()

    @patch("posthog.storage.team_access_cache_signal_handlers.update_user_authentication_cache")
    @patch("django.db.transaction.on_commit")
    def test_user_saved_uses_transaction_on_commit(self, mock_on_commit, mock_update_cache):
        """Test that user_saved uses transaction.on_commit to defer cache updates."""

        # Create mock user
        mock_user = MagicMock()
        mock_user.id = 42
        mock_user.is_active = True

        # Call user_saved with is_active in update_fields
        user_saved(sender=User, instance=mock_user, created=False, update_fields=["is_active"])

        # Verify transaction.on_commit was called
        mock_on_commit.assert_called_once()

        # Get the lambda function that was passed to on_commit and call it
        on_commit_lambda = mock_on_commit.call_args[0][0]
        on_commit_lambda()

        # Verify that the update function would be called after transaction commits
        # The lambda passes instance and **kwargs (which doesn't include created)
        mock_update_cache.assert_called_once_with(mock_user, update_fields=["is_active"])


class TestOrganizationMembershipDeletedSignalHandler(TestCase):
    """Test the organization_membership_deleted signal handler in remote_config.py."""

    @patch("django.db.transaction.on_commit")
    def test_organization_membership_deleted_calls_update_when_user_removed(self, mock_on_commit):
        """Test that organization_membership_deleted schedules cache update when a user is removed from org."""

        # Create mock user and organization
        mock_user = MagicMock()
        mock_user.id = 42

        mock_org = MagicMock()
        mock_org.id = "test-org-uuid"

        # Create mock OrganizationMembership
        mock_membership = MagicMock()
        mock_membership.user = mock_user
        mock_membership.organization = mock_org

        # Call organization_membership_deleted
        organization_membership_deleted(sender=OrganizationMembership, instance=mock_membership)

        # Verify transaction.on_commit was called (update was scheduled)
        mock_on_commit.assert_called_once()

    @patch("posthog.storage.team_access_cache_signal_handlers.update_organization_membership_deleted_cache")
    @patch("django.db.transaction.on_commit")
    def test_organization_membership_deleted_uses_transaction_on_commit(self, mock_on_commit, mock_update_cache):
        """Test that organization_membership_deleted uses transaction.on_commit to defer cache updates."""

        # Create mock user and organization
        mock_user = MagicMock()
        mock_user.id = 42

        mock_org = MagicMock()
        mock_org.id = "test-org-uuid"

        # Create mock OrganizationMembership
        mock_membership = MagicMock()
        mock_membership.user = mock_user
        mock_membership.organization = mock_org
        mock_membership.organization_id = "test-org-uuid"
        mock_membership.user_id = 42

        # Call organization_membership_deleted
        organization_membership_deleted(sender=OrganizationMembership, instance=mock_membership)

        # Verify transaction.on_commit was called
        mock_on_commit.assert_called_once()

        # Get the lambda function that was passed to on_commit and call it
        on_commit_lambda = mock_on_commit.call_args[0][0]
        on_commit_lambda()

        # Verify that the update function would be called after transaction commits
        # The lambda passes the membership instance
        mock_update_cache.assert_called_once_with(mock_membership)

    @patch("posthog.models.remote_config.logger")
    @patch("django.db.transaction.on_commit")
    def test_organization_membership_deleted_logs_when_scheduled(self, mock_on_commit, mock_logger):
        """Test that organization_membership_deleted properly schedules cache updates."""

        # Create mock user and organization
        mock_user = MagicMock()
        mock_user.id = 42

        mock_org = MagicMock()
        mock_org.id = "test-org-uuid"

        # Create mock OrganizationMembership
        mock_membership = MagicMock()
        mock_membership.user = mock_user
        mock_membership.organization = mock_org

        # Call organization_membership_deleted
        organization_membership_deleted(sender=OrganizationMembership, instance=mock_membership)

        # Verify transaction.on_commit was called (update was scheduled)
        mock_on_commit.assert_called_once()

    def test_organization_membership_deleted_handles_none_user(self):
        """Test that organization_membership_deleted handles membership with None user gracefully."""

        # Create mock OrganizationMembership with None user
        mock_membership = MagicMock()
        mock_membership.user = None

        # Should not raise an exception
        try:
            organization_membership_deleted(sender=OrganizationMembership, instance=mock_membership)
        except Exception as e:
            self.fail(f"organization_membership_deleted raised an exception with None user: {e}")

    @patch("django.db.transaction.on_commit")
    def test_organization_membership_deleted_with_different_kwargs(self, mock_on_commit):
        """Test that organization_membership_deleted properly forwards kwargs to the cache update."""

        # Create mock user and organization
        mock_user = MagicMock()
        mock_user.id = 42

        mock_org = MagicMock()
        mock_org.id = "test-org-uuid"

        # Create mock OrganizationMembership
        mock_membership = MagicMock()
        mock_membership.user = mock_user
        mock_membership.organization = mock_org

        # Call organization_membership_deleted with additional kwargs
        test_kwargs = {"raw": False, "using": "default"}
        organization_membership_deleted(sender=OrganizationMembership, instance=mock_membership, **test_kwargs)

        # Verify transaction.on_commit was called
        mock_on_commit.assert_called_once()
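
These tests call the handler functions directly with mocks; the same code paths are reached in production through Django's signal dispatch, since remote_config.py registers them with @receiver. A hedged sketch of that dispatch route (note that send() fires every receiver registered for User, so handlers beyond user_saved may also run in the full application):

from unittest.mock import MagicMock, patch

from django.db.models.signals import post_save

from posthog.models.user import User

mock_user = MagicMock()
mock_user.id = 42

with patch("django.db.transaction.on_commit") as mock_on_commit:
    # Routed to user_saved via its @receiver(post_save, sender=User) registration
    post_save.send(sender=User, instance=mock_user, created=False, update_fields=["is_active"])
    assert mock_on_commit.called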

383  posthog/storage/team_access_cache.py  Normal file
@@ -0,0 +1,383 @@
"""
Per-team access token cache layer for cache-based authentication.

This module provides Redis-based caching of hashed access tokens per team,
enabling zero-database-call authentication for the local_evaluation endpoint.
"""

import logging
from datetime import UTC, datetime
from typing import Any

from django.db import transaction
from django.db.models import Q

from posthog.models.organization import OrganizationMembership
from posthog.models.personal_api_key import PersonalAPIKey, hash_key_value
from posthog.models.team.team import Team
from posthog.storage.hypercache import HyperCache, HyperCacheStoreMissing, KeyType

logger = logging.getLogger(__name__)

# Cache configuration
DEFAULT_TTL = 300  # 5 minutes
CACHE_KEY_PREFIX = "cache/teams"


class TeamAccessTokenCache:
    """
    HyperCache-based cache for per-team access token lists.

    This class manages hashed token lists per team to enable fast authentication
    lookups without database queries. Each team has its own cache entry with
    JSON data containing hashed authorized tokens and metadata. Uses HyperCache
    for automatic Redis + S3 backup and improved reliability.
    """

    def __init__(self, ttl: int = DEFAULT_TTL):
        """
        Initialize the team access token cache.

        Args:
            ttl: Time-to-live for cache entries in seconds
        """
        self.ttl = ttl

    def update_team_tokens(self, project_api_key: str, team_id: int, hashed_tokens: list[str]) -> None:
        """
        Update a team's complete token list in cache.

        Args:
            project_api_key: The team's project API key
            team_id: The team's ID
            hashed_tokens: List of hashed tokens (already in sha256$ format)
        """
        try:
            token_data = {
                "hashed_tokens": hashed_tokens,
                "last_updated": datetime.now(UTC).isoformat(),
                "team_id": team_id,
            }
            team_access_tokens_hypercache.set_cache_value(project_api_key, token_data)

            logger.info(
                f"Updated token cache for team {project_api_key} with {len(hashed_tokens)} tokens",
                extra={"team_project_api_key": project_api_key, "token_count": len(hashed_tokens)},
            )

        except Exception as e:
            logger.exception(f"Error updating tokens for team {project_api_key}: {e}")
            raise

    def invalidate_team(self, project_api_key: str) -> None:
        """
        Invalidate (delete) a team's token cache.

        Args:
            project_api_key: The team's project API key
        """
        try:
            team_access_tokens_hypercache.clear_cache(project_api_key)

            logger.info(f"Invalidated token cache for team {project_api_key}")

        except Exception as e:
            logger.exception(f"Error invalidating cache for team {project_api_key}: {e}")
            raise


def _load_team_access_tokens(team_token: KeyType) -> dict[str, Any] | HyperCacheStoreMissing:
    """
    Load team access tokens from the database.

    Args:
        team_token: Team identifier (can be Team object, API token string, or team ID)

    Returns:
        Dictionary containing hashed tokens and metadata, or HyperCacheStoreMissing if team not found
    """
    try:
        # Use transaction isolation to ensure consistent reads across all queries
        with transaction.atomic():
            if isinstance(team_token, str):
                team = Team.objects.select_related("organization").get(api_token=team_token)
            elif isinstance(team_token, int):
                team = Team.objects.select_related("organization").get(id=team_token)
            else:
                # team_token is already a Team object, but ensure organization is loaded
                team = team_token
                if not hasattr(team, "organization") or team.organization is None:
                    team = Team.objects.select_related("organization").get(id=team.id)

            hashed_tokens: list[str] = []

            # Get all relevant personal API keys in one optimized query
            # Combines scoped and unscoped keys with proper filtering
            personal_keys = (
                PersonalAPIKey.objects.select_related("user")
                .filter(
                    user__organization_membership__organization_id=team.organization_id,
                    user__is_active=True,
                )
                .filter(
                    # Organization scoping: key must either have no org restriction OR include this org
                    Q(scoped_organizations__isnull=True)
                    | Q(scoped_organizations=[])
                    | Q(scoped_organizations__contains=[str(team.organization_id)])
                )
                .filter(
                    (
                        # Scoped keys: explicitly include this team AND have feature flag read or write access
                        Q(scoped_teams__contains=[team.id])
                        & (
                            # Keys with write permission implicitly have read permission
                            Q(scopes__contains=["feature_flag:read"]) | Q(scopes__contains=["feature_flag:write"])
                        )
                    )
                    | (
                        # Unscoped keys: no team restriction (null or empty array)
                        (Q(scoped_teams__isnull=True) | Q(scoped_teams=[]))
                        & (
                            # AND either no scope restriction OR has feature flag read or write access
                            Q(scopes__isnull=True)
                            | Q(scopes=[])
                            | Q(scopes__contains=["feature_flag:read"])
                            | Q(scopes__contains=["feature_flag:write"])
                        )
                    )
                )
                .distinct()
                .values_list("secure_value", flat=True)
            )

            # Collect personal API key tokens
            hashed_tokens.extend(secure_value for secure_value in personal_keys if secure_value)

            # Add team secret tokens
            if team.secret_api_token:
                hashed_secret = hash_key_value(team.secret_api_token, mode="sha256")
                hashed_tokens.append(hashed_secret)

            if team.secret_api_token_backup:
                hashed_secret_backup = hash_key_value(team.secret_api_token_backup, mode="sha256")
                hashed_tokens.append(hashed_secret_backup)

            return {
                "hashed_tokens": hashed_tokens,
                "last_updated": datetime.now(UTC).isoformat(),
                "team_id": team.id,  # Include team_id for zero-DB-call authentication
            }

    except Team.DoesNotExist:
        logger.warning(f"Team not found for project API key: {team_token}")
        return HyperCacheStoreMissing()

    except Exception as e:
        logger.exception(f"Error loading team access tokens for {team_token}: {e}")
        return HyperCacheStoreMissing()


# HyperCache instance for team access tokens
team_access_tokens_hypercache = HyperCache(
    namespace="team_access_tokens",
    value="access_tokens.json",
    token_based=True,  # Use team API token as key
    load_fn=_load_team_access_tokens,
)

# Global instance for convenience
team_access_cache = TeamAccessTokenCache()


def warm_team_token_cache(project_api_key: str) -> bool:
    """
    Warm the token cache for a specific team by loading from database.

    This function now uses the HyperCache to update the cache, which handles
    both Redis and S3 storage automatically.

    Args:
        project_api_key: The team's project API key

    Returns:
        True if cache warming succeeded, False otherwise
    """
    # Use HyperCache to update the cache - this will call _load_team_access_tokens
    # It does not raise an exception if the cache is not updated, so we don't need to try/except
    success = team_access_tokens_hypercache.update_cache(project_api_key)

    if success:
        logger.info(
            f"Warmed token cache for team {project_api_key} using HyperCache",
            extra={"project_api_key": project_api_key},
        )
    else:
        logger.warning(f"Failed to warm token cache for team {project_api_key}")

    return success


def get_teams_needing_cache_refresh(limit: int | None = None, offset: int = 0) -> list[str]:
    """
    Get a list of project API keys for teams that need cache refresh.

    This function now supports pagination to handle large datasets efficiently.
    For installations with many teams, use limit/offset to process in batches.

    Args:
        limit: Maximum number of teams to check. None means no limit (all teams).
        offset: Number of teams to skip before starting to check.

    Returns:
        List of project API keys that need cache refresh

    Raises:
        Exception: Database connectivity or other systemic issues that should trigger retries
    """
    # Build queryset with pagination support
    # Note: Filtering by project__isnull=False may be needed in production
    # but is removed for testing since test teams often don't have projects
    queryset = Team.objects.values_list("api_token", flat=True).order_by("id")  # Consistent ordering for pagination

    # Apply pagination if specified
    if offset > 0:
        queryset = queryset[offset:]
    if limit is not None:
        queryset = queryset[:limit]

    # Check which teams have missing caches in HyperCache
    teams_needing_refresh = []

    for project_api_key in queryset:
        try:
            token_data = team_access_tokens_hypercache.get_from_cache(project_api_key)
            if token_data is None:
                teams_needing_refresh.append(project_api_key)
        except Exception as e:
            # Log individual team cache check failure but continue with others
            logger.warning(
                f"Failed to check cache for team {project_api_key}: {e}",
                extra={"project_api_key": project_api_key, "error": str(e)},
            )
            # Assume this team needs refresh if we can't check its cache
            teams_needing_refresh.append(project_api_key)

    return teams_needing_refresh


def get_teams_needing_cache_refresh_paginated(batch_size: int = 1000):
    """
    Generator that yields batches of teams needing cache refresh.

    This is the recommended approach for processing large numbers of teams
    to avoid memory issues. It processes teams in chunks and yields each
    batch as it's completed.

    Args:
        batch_size: Number of teams to process per batch

    Yields:
        List[str]: Batches of project API keys that need cache refresh
    """
    offset = 0

    while True:
        batch = get_teams_needing_cache_refresh(limit=batch_size, offset=offset)

        if not batch:
            # No more teams to process
            break

        yield batch
        offset += batch_size

        # If we got fewer teams than requested, we've reached the end
        if len(batch) < batch_size:
            break


def get_teams_for_user_personal_api_keys(user_id: int) -> set[str]:
    """
    Get all project API keys for teams that a user's PersonalAPIKeys have access to.

    This function eliminates N+1 queries by determining all affected teams for
    a user's personal API keys in minimal database queries (1-3 queries maximum).

    Args:
        user_id: The user ID whose personal API keys to analyze

    Returns:
        Set of project API keys (strings) for all teams the user's keys have access to
    """
    # Get all personal API keys for the user
    personal_keys = list(PersonalAPIKey.objects.filter(user_id=user_id).values("id", "scoped_teams"))

    if not personal_keys:
        return set()

    affected_teams = set()
    scoped_team_ids: set[int] = set()
    has_unscoped_keys = False

    # Analyze all keys to determine which teams they affect
    for key_data in personal_keys:
        scoped_teams = key_data["scoped_teams"] or []
        if scoped_teams:
            # Scoped key - add specific team IDs
            scoped_team_ids.update(scoped_teams)
        else:
            # Unscoped key - will need all teams in user's organizations
            has_unscoped_keys = True

    # Get project API keys for scoped teams (if any) in one query
    if scoped_team_ids:
        scoped_team_tokens = Team.objects.filter(id__in=scoped_team_ids).values_list("api_token", flat=True)
        affected_teams.update(scoped_team_tokens)

    # Get project API keys for unscoped keys (if any) in two queries maximum
    if has_unscoped_keys:
        # Get user's organizations
        user_organizations = OrganizationMembership.objects.filter(user_id=user_id).values_list(
            "organization_id", flat=True
        )

        if user_organizations:
            # Get all teams in those organizations
            org_team_tokens = Team.objects.filter(organization_id__in=user_organizations).values_list(
                "api_token", flat=True
            )
            affected_teams.update(org_team_tokens)

    return affected_teams


def get_teams_for_single_personal_api_key(personal_api_key_instance: "PersonalAPIKey") -> list[str]:
    """
    Get project API keys for teams that a single PersonalAPIKey has access to.

    This is a helper function that internally uses the optimized user-based function.
    For better performance when processing multiple keys for the same user, use
    get_teams_for_user_personal_api_keys() directly.

    Args:
        personal_api_key_instance: The PersonalAPIKey instance

    Returns:
        List of project API keys (strings) for teams the key has access to
    """
    user_id = personal_api_key_instance.user_id
    all_user_teams = get_teams_for_user_personal_api_keys(user_id)

    # Filter to only teams this specific key has access to
    scoped_teams = personal_api_key_instance.scoped_teams or []

    if scoped_teams:
        # Scoped key - only return teams in the scoped list
        scoped_team_tokens = set(Team.objects.filter(id__in=scoped_teams).values_list("api_token", flat=True))
        # Return intersection of user teams and scoped teams
        result = list(all_user_teams.intersection(scoped_team_tokens))
    else:
        # Unscoped key - return all teams the user has access to
        result = list(all_user_teams)

    return result
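
The heart of this module is the Q-object query in _load_team_access_tokens. Restated as plain Python for readability (an illustrative paraphrase, not code from this commit; the query additionally requires the key's owner to be an active member of the team's organization):

FLAG_SCOPES = {"feature_flag:read", "feature_flag:write"}

def key_grants_access(scoped_organizations, scoped_teams, scopes, team_id, team_org_id) -> bool:
    # Organization scoping: no restriction, or the team's organization is listed
    if scoped_organizations and str(team_org_id) not in scoped_organizations:
        return False
    if scoped_teams:
        # Scoped key: must name this team AND carry a feature-flag scope
        return team_id in scoped_teams and bool(FLAG_SCOPES & set(scopes or []))
    # Unscoped key: allowed if it has no scope restriction or carries a feature-flag scope
    return not scopes or bool(FLAG_SCOPES & set(scopes))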

326  posthog/storage/team_access_cache_signal_handlers.py  Normal file
@@ -0,0 +1,326 @@
"""
Signal handler functions for team access token cache invalidation.

This module provides handler functions that automatically update
the team access token cache when PersonalAPIKey or Team models change,
ensuring cache consistency with the database.

Note: Signal subscriptions are registered in posthog/models/remote_config.py
"""

import logging

from posthog.models.organization import OrganizationMembership
from posthog.models.personal_api_key import PersonalAPIKey
from posthog.models.team.team import Team
from posthog.storage.team_access_cache import (
    get_teams_for_user_personal_api_keys,
    team_access_cache,
    warm_team_token_cache,
)

logger = logging.getLogger(__name__)


def capture_old_api_token(instance: Team, **kwargs):
    """
    Capture the old api_token value before save for cleanup.

    This pre_save handler stores the old api_token value so the post_save
    handler can clean up the old cache entry when the token changes.
    """
    if instance.pk:  # Only for existing teams
        try:
            old_team = Team.objects.only("api_token").get(pk=instance.pk)
            # Store the old api_token value for post_save cleanup
            instance._old_api_token = old_team.api_token  # type: ignore[attr-defined]
        except Team.DoesNotExist:
            pass


def update_team_authentication_cache(instance: Team, created: bool, **kwargs):
    """
    Rebuild team access cache when Team model is saved.

    This handler only rebuilds the cache when authentication-related fields change
    to avoid unnecessary cache operations for unrelated team updates.
    """
    try:
        if not instance.api_token:
            return

        if created:
            logger.debug(f"New team created: {instance.pk}")
            return

        # Check if this is a new team being created
        if hasattr(instance, "_state") and instance._state.adding:
            logger.debug(f"Team {instance.pk} is being created, skipping cache update")
            return

        # Check if api_token changed (project API key regeneration)
        # We look for the old value stored before save
        old_api_token = getattr(instance, "_old_api_token", None)

        # If update_fields is specified, only rebuild cache if auth-related fields changed
        update_fields = kwargs.get("update_fields")
        auth_related_fields = {"api_token", "secret_api_token", "secret_api_token_backup"}

        if update_fields is not None:
            # Convert update_fields to set for efficient intersection
            updated_fields = set(update_fields) if update_fields else set()

            # Check if any auth-related fields were updated
            if not updated_fields.intersection(auth_related_fields):
                logger.debug(
                    f"Team {instance.pk} updated but no auth fields changed, skipping cache update",
                    extra={
                        "team_id": instance.pk,
                        "updated_fields": list(updated_fields),
                        "auth_fields": list(auth_related_fields),
                    },
                )
                return

        try:
            # Clean up old cache if api_token changed
            if old_api_token and old_api_token != instance.api_token:
                team_access_cache.invalidate_team(old_api_token)
                logger.info(
                    f"Invalidated old cache for team {instance.pk} after API token change",
                    extra={"team_id": instance.pk, "old_api_token": old_api_token, "new_api_token": instance.api_token},
                )

            warm_team_token_cache(instance.api_token)
            logger.info(
                f"Rebuilt team access cache for team {instance.pk} after auth field change",
                extra={"team_id": instance.pk, "project_api_key": instance.api_token},
            )
        except Exception as e:
            logger.warning(
                f"Failed to rebuild cache for team {instance.pk}, falling back to invalidation: {e}",
                extra={"team_id": instance.pk},
            )
            # Fall back to invalidation if rebuild fails
            team_access_cache.invalidate_team(instance.api_token)

    except Exception as e:
        logger.exception(f"Error updating cache on team save for team {instance.pk}: {e}")


def update_team_authentication_cache_on_delete(instance: Team, **kwargs):
    """
    Invalidate team access cache when Team is deleted.
    """
    try:
        if instance.api_token:
            team_access_cache.invalidate_team(instance.api_token)
            logger.info(f"Invalidated cache for deleted team {instance.pk}")

    except Exception as e:
        logger.exception(f"Error invalidating cache on team delete for team {instance.pk}: {e}")


def update_personal_api_key_authentication_cache(instance: PersonalAPIKey):
    """
    Update team access cache when PersonalAPIKey is saved.

    This handler warms the cache for all teams that the user's personal API keys have
    access to. For optimal performance, it uses the user-based function to warm all
    affected teams at once. Since warming completely rebuilds the cache from the database,
    no prior invalidation is needed.
    """
    # Get the list of affected teams using optimized user-based function
    affected_teams = get_teams_for_user_personal_api_keys(instance.user_id)

    # Warm the cache for each affected team
    for project_api_key in affected_teams:
        try:
            warm_team_token_cache(project_api_key)
            logger.debug(f"Warmed cache for team {project_api_key} after PersonalAPIKey change")
        except Exception as e:
            logger.warning(
                f"Failed to warm cache for team {project_api_key} after PersonalAPIKey change: {e}",
                extra={"project_api_key": project_api_key, "personal_api_key_id": instance.id},
            )

    logger.info(
        f"Updated authentication cache for {len(affected_teams)} teams after PersonalAPIKey change",
        extra={"personal_api_key_id": instance.id, "affected_teams_count": len(affected_teams)},
    )


def update_user_authentication_cache(instance, **kwargs):
    """
    Update team access caches when a User's status changes.

    When a user is activated/deactivated, their Personal API Keys need to be
    added/removed from the authentication caches of all teams they have access to.
    This includes both scoped and unscoped keys.

    Note: The update_fields filtering is now handled by the user_saved signal handler
    in remote_config.py before calling this function.

    Args:
        instance: The User instance that changed
        **kwargs: Additional signal arguments
    """
    try:
        # Get all teams that the user's personal API keys have access to
        affected_teams = get_teams_for_user_personal_api_keys(instance.id)

        _warm_cache_for_teams(affected_teams, "user status change", str(instance.id), None)

    except Exception as e:
        logger.exception(f"Error updating authentication cache for user {instance.id} status change: {e}")


def update_personal_api_key_deleted_cache(instance: PersonalAPIKey):
    """
    Update team access caches when a PersonalAPIKey is deleted.

    When a PersonalAPIKey is deleted, it needs to be removed from all team caches
    that it had access to. This includes both scoped and unscoped keys.

    Args:
        instance: The PersonalAPIKey instance that was deleted
    """
    try:
        # Get all teams that this specific key had access to
        # We need to determine this based on the key's scoping
        scoped_teams = instance.scoped_teams or []

        if scoped_teams:
            # Scoped key - only affects specific teams
            team_api_tokens = Team.objects.filter(id__in=scoped_teams).values_list("api_token", flat=True)
        else:
            # Unscoped key - affects all teams in user's organizations
            user_organizations = OrganizationMembership.objects.filter(user_id=instance.user_id).values_list(
                "organization_id", flat=True
            )

            if user_organizations:
                team_api_tokens = Team.objects.filter(organization_id__in=user_organizations).values_list(
                    "api_token", flat=True
                )
            else:
                team_api_tokens = []

        _warm_cache_for_teams(team_api_tokens, "PersonalAPIKey deletion", str(instance.user_id), None)

    except Exception as e:
        logger.exception(
            f"Error updating team caches after PersonalAPIKey deletion: {e}",
            extra={
                "personal_api_key_id": getattr(instance, "id", None),
                "user_id": getattr(instance, "user_id", None),
            },
        )


def update_organization_membership_created_cache(membership_instance):
    """
    Update team access caches when an OrganizationMembership is created.

    When a user is added to an organization, their unscoped Personal API Keys should
    gain access to teams within that organization. This function updates the caches for
    all teams in the organization that was joined.

    Args:
        membership_instance: The OrganizationMembership instance that was created
    """
    try:
        # Get all teams in the organization the user joined
        organization_id = membership_instance.organization_id
        user_id = membership_instance.user_id

        team_api_tokens = Team.objects.filter(organization_id=organization_id).values_list("api_token", flat=True)

        _warm_cache_for_teams(team_api_tokens, "adding user to organization", user_id, organization_id)

    except Exception as e:
        logger.exception(
            f"Error updating team caches after OrganizationMembership creation: {e}",
            extra={
                "user_id": getattr(membership_instance, "user_id", None),
                "organization_id": getattr(membership_instance, "organization_id", None),
            },
        )


def update_organization_membership_deleted_cache(membership_instance):
    """
    Update team access caches when an OrganizationMembership is deleted.

    When a user is removed from an organization, their Personal API Keys should no longer
    have access to teams within that organization. This function updates the caches for
    all teams in the organization that was removed from.

    This is different from update_user_authentication_cache because we need to update
    the teams from the organization the user was REMOVED from, not their current teams
    (which won't include the removed organization anymore).

    Args:
        membership_instance: The OrganizationMembership instance that was deleted
    """
    try:
        # Get all teams in the organization the user was removed from
        organization_id = membership_instance.organization_id
        user_id = membership_instance.user_id

        team_api_tokens = Team.objects.filter(organization_id=organization_id).values_list("api_token", flat=True)

        _warm_cache_for_teams(team_api_tokens, "removing user from organization", user_id, organization_id)

    except Exception as e:
        logger.exception(
            f"Error updating team caches after OrganizationMembership deletion: {e}",
            extra={
                "user_id": getattr(membership_instance, "user_id", None),
                "organization_id": getattr(membership_instance, "organization_id", None),
            },
        )


def _warm_cache_for_teams(
    team_api_tokens: set[str] | list[str], action: str, user_id: str, organization_id: str | None
):
    """
    Warm the cache for a set of teams.
    """
    if not team_api_tokens:
        logger.debug(f"No teams found in organization {organization_id}, no cache updates needed")
        return

    # Warm the cache for each team in the organization
    # This will rebuild the cache without the removed user's keys
    for project_api_key in team_api_tokens:
        try:
            warm_team_token_cache(project_api_key)
            logger.debug(
                f"Warmed cache for team {project_api_key} after {action}",
                extra={
                    "project_api_key": project_api_key,
                    "user_id": user_id,
                    "organization_id": organization_id,
                },
            )
        except Exception as e:
            logger.warning(
                f"Failed to warm cache for team {project_api_key} after {action}: {e}",
                extra={
                    "project_api_key": project_api_key,
                    "user_id": user_id,
                    "organization_id": organization_id,
                },
            )

    logger.info(
        f"Updated {len(team_api_tokens)} team caches after {action}",
        extra={
            "user_id": user_id,
            "organization_id": organization_id,
            "teams_updated": len(team_api_tokens),
        },
    )
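
Read together, capture_old_api_token and update_team_authentication_cache implement API-token rotation: the pre_save hook stashes the outgoing token on the instance, and the post_save hook evicts the stale cache entry before warming the new one. In outline (team is an assumed, existing instance; the token value is hypothetical):

team.api_token = "phc_rotated_example"
team.save(update_fields=["api_token"])
# pre_save:  instance._old_api_token <- previous token
# post_save (after commit): invalidate_team(previous token),
#                           then warm_team_token_cache("phc_rotated_example")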

2609  posthog/storage/test/test_team_access_cache.py  Normal file
File diff suppressed because it is too large
@@ -54,6 +54,7 @@ from posthog.tasks.tasks import (
     update_survey_iteration,
     verify_persons_data_in_sync,
 )
+from posthog.tasks.team_access_cache_tasks import warm_all_team_access_caches_task
 from posthog.utils import get_crontab

 TWENTY_FOUR_HOURS = 24 * 60 * 60
@@ -95,6 +96,14 @@ def setup_periodic_tasks(sender: Celery, **kwargs: Any) -> None:
         name="schedule warming for largest teams",
     )

+    # Team access cache warming - every 10 minutes
+    add_periodic_task_with_expiry(
+        sender,
+        600,  # Every 10 minutes (no TTL, just fill missing entries)
+        warm_all_team_access_caches_task.s(),
+        name="warm team access caches",
+    )
+
     # Update events table partitions twice a week
     sender.add_periodic_task(
         crontab(day_of_week="mon,fri", hour="0", minute="0"),

117  posthog/tasks/team_access_cache_tasks.py  Normal file
@@ -0,0 +1,117 @@
"""
Background tasks for warming team access token caches.

This module provides Celery tasks to periodically warm the team access token
caches, ensuring that the cached authentication system has fresh data.
"""

import logging

from django.conf import settings

from celery import shared_task
from celery.app.task import Task

from posthog.storage.team_access_cache import get_teams_needing_cache_refresh_paginated, warm_team_token_cache

logger = logging.getLogger(__name__)

# Configuration
CACHE_WARMING_BATCH_SIZE = getattr(settings, "CACHE_WARMING_BATCH_SIZE", 50)
CACHE_WARMING_PAGE_SIZE = getattr(settings, "CACHE_WARMING_PAGE_SIZE", 1000)  # Teams per database page


@shared_task(bind=True, max_retries=3)
def warm_team_cache_task(self: "Task", project_api_key: str) -> dict:
    """
    Warm the token cache for a specific team.

    Args:
        project_api_key: The team's project API key

    Returns:
        Dictionary with operation results
    """
    success = warm_team_token_cache(project_api_key)

    if not success:
        # Log a warning, but don't retry. We'll let the next scheduled task pick it up.
        logger.warning(f"Failed to warm cache for team {project_api_key}")
        return {"status": "failure", "project_api_key": project_api_key}

    logger.info(
        f"Successfully warmed cache for team {project_api_key}",
        extra={"project_api_key": project_api_key},
    )

    return {"status": "success", "project_api_key": project_api_key}


@shared_task(bind=True, max_retries=1)
def warm_all_team_access_caches_task(self: "Task") -> dict:
    """
    Warm caches for all teams that need refreshing.

    This task identifies teams with expired or missing caches and
    schedules individual warming tasks for each team.

    Returns:
        Dictionary with operation results
    """
    try:
        teams_scheduled = 0
        failed_teams = 0
        teams_pages_processed = 0
        total_teams_found = 0

        # Use paginated approach for memory efficiency
        logger.info(f"Using paginated cache warming with page size {CACHE_WARMING_PAGE_SIZE}")

        for teams_page in get_teams_needing_cache_refresh_paginated(batch_size=CACHE_WARMING_PAGE_SIZE):
            teams_pages_processed += 1

            if not teams_page:
                continue

            total_teams_found += len(teams_page)

            logger.debug(
                f"Processing page {teams_pages_processed} with {len(teams_page)} teams needing refresh",
                extra={"page": teams_pages_processed, "teams_in_page": len(teams_page)},
            )

            # Process teams in batches to avoid overwhelming the system
            for i in range(0, len(teams_page), CACHE_WARMING_BATCH_SIZE):
                batch = teams_page[i : i + CACHE_WARMING_BATCH_SIZE]

                # Schedule warming tasks for this batch
                for project_api_key in batch:
                    try:
                        warm_team_cache_task.delay(project_api_key)
                        teams_scheduled += 1
                    except Exception as e:
                        # Log individual team scheduling failure but continue with others
                        failed_teams += 1
                        logger.warning(
                            f"Failed to schedule cache warming for team {project_api_key}: {e}",
                            extra={"project_api_key": project_api_key, "error": str(e)},
                        )

                logger.debug(f"Scheduled cache warming for batch of {len(batch)} teams")

        logger.info(
            "Cache warming completed",
            extra={"teams_found": total_teams_found, "teams_scheduled": teams_scheduled, "failed_teams": failed_teams},
        )

        return {
            "status": "success",
            "teams_found": total_teams_found,
            "teams_scheduled": teams_scheduled,
            "failed_teams": failed_teams,
        }

    except Exception as e:
        # Retry for systemic failures (database connectivity, etc.)
        logger.exception(f"Systemic failure in cache warming batch task: {e}")
        raise self.retry(exc=e, countdown=300)  # 5 minutes
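
Since a @shared_task-decorated function can still be invoked synchronously, the batch task can be driven from a shell for debugging (the tests below rely on exactly this); in production it runs on the 10-minute schedule registered in setup_periodic_tasks above. A quick sketch:

from posthog.tasks.team_access_cache_tasks import warm_all_team_access_caches_task

result = warm_all_team_access_caches_task()  # schedules warm_team_cache_task.delay(...) per team
print(result)  # {"status": "success", "teams_found": ..., "teams_scheduled": ..., "failed_teams": ...}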

167  posthog/tasks/test/test_team_access_cache_tasks.py  Normal file
@@ -0,0 +1,167 @@
"""
Tests for team access cache Celery tasks.
"""

from unittest.mock import MagicMock, patch

from django.test import TestCase

from posthog.tasks.team_access_cache_tasks import warm_all_team_access_caches_task, warm_team_cache_task


class TestWarmTeamCacheTask(TestCase):
    """Test the individual team cache warming task."""

    @patch("posthog.tasks.team_access_cache_tasks.warm_team_token_cache")
    def test_warm_team_cache_task_success(self, mock_warm: MagicMock) -> None:
        """Test successful cache warming for a team."""
        mock_warm.return_value = True

        project_api_key = "phs_test_team_123"

        result = warm_team_cache_task(project_api_key)

        mock_warm.assert_called_once_with(project_api_key)

        assert result["status"] == "success"
        assert result["project_api_key"] == project_api_key

    @patch("posthog.tasks.team_access_cache_tasks.warm_team_token_cache")
    def test_warm_team_cache_task_failure(self, mock_warm: MagicMock) -> None:
        """Test that cache warming failure does not trigger retry."""
        mock_warm.return_value = False

        project_api_key = "phs_test_team_123"

        result = warm_team_cache_task(project_api_key)

        mock_warm.assert_called_once_with(project_api_key)

        assert result["status"] == "failure"
        assert result["project_api_key"] == project_api_key


class TestWarmAllTeamsCachesTask(TestCase):
    """Test the batch cache warming task."""

    @patch("posthog.tasks.team_access_cache_tasks.get_teams_needing_cache_refresh_paginated")
    def test_warm_all_team_access_caches_task_no_teams(self, mock_get_teams_paginated: MagicMock) -> None:
        """Test batch warming when no teams need refresh."""
        mock_get_teams_paginated.return_value = iter([])

        result = warm_all_team_access_caches_task()

        assert result["status"] == "success"
        assert result["teams_found"] == 0
        assert result["teams_scheduled"] == 0
        assert result["failed_teams"] == 0

    @patch("posthog.tasks.team_access_cache_tasks.get_teams_needing_cache_refresh_paginated")
    @patch("posthog.tasks.team_access_cache_tasks.warm_team_cache_task.delay")
    @patch("posthog.tasks.team_access_cache_tasks.CACHE_WARMING_BATCH_SIZE", 2)
    def test_warm_all_team_access_caches_task_batching(
        self, mock_delay: MagicMock, mock_get_teams_paginated: MagicMock
    ) -> None:
        """Test that teams are processed in configured batches."""
        # Setup mock teams - more than batch size, split across pages
        page1 = ["phs_team1_123", "phs_team2_456"]
        page2 = ["phs_team3_789", "phs_team4_012"]
        mock_get_teams_paginated.return_value = iter([page1, page2])

        result = warm_all_team_access_caches_task()

        # Verify all teams were scheduled despite batching
        assert mock_delay.call_count == 4
        all_teams = page1 + page2
        for team in all_teams:
            mock_delay.assert_any_call(team)

        assert result["status"] == "success"
        assert result["teams_scheduled"] == 4
        assert result["teams_found"] == 4
        assert result["failed_teams"] == 0

    @patch("posthog.tasks.team_access_cache_tasks.get_teams_needing_cache_refresh_paginated")
    @patch("posthog.tasks.team_access_cache_tasks.warm_team_cache_task.delay")
    def test_warm_all_team_access_caches_task_handles_individual_failures(
        self, mock_delay: MagicMock, mock_get_teams_paginated: MagicMock
    ) -> None:
        """Test that batch warming handles individual team failures gracefully."""
        # Setup mocks - single page of teams
        mock_teams = ["phs_team1_123", "phs_team2_456", "phs_team3_789"]
        mock_get_teams_paginated.return_value = iter([mock_teams])

        # Make delay fail for the second team only
        def delay_side_effect(project_api_key: str) -> None:
            if project_api_key == "phs_team2_456":
                raise Exception("Task scheduling failed for team 2")
            return None

        mock_delay.side_effect = delay_side_effect

        # Execute task - should NOT raise exception, but should handle failure gracefully
        result = warm_all_team_access_caches_task()

        mock_get_teams_paginated.assert_called_once()

        # Verify all teams were attempted
        assert mock_delay.call_count == 3
        mock_delay.assert_any_call("phs_team1_123")
        mock_delay.assert_any_call("phs_team2_456")
        mock_delay.assert_any_call("phs_team3_789")

        # Verify result shows partial success
        assert result["status"] == "success"
        assert result["teams_scheduled"] == 2  # team2 failed to schedule
        assert result["failed_teams"] == 1
        assert result["teams_found"] == 3

    @patch("posthog.tasks.team_access_cache_tasks.get_teams_needing_cache_refresh_paginated")
    def test_warm_all_team_access_caches_task_handles_systemic_failures(
        self, mock_get_teams_paginated: MagicMock
    ) -> None:
        """Test that batch warming retries on systemic failures like database connectivity issues."""
        # Make get_teams_needing_cache_refresh_paginated fail (systemic issue)
        mock_get_teams_paginated.side_effect = Exception("Database connection failed")

        # Execute task - should raise some exception that triggers retry
        # The retry mechanism may raise the original exception or a Retry exception
        with self.assertRaises(Exception) as cm:
            warm_all_team_access_caches_task()

        mock_get_teams_paginated.assert_called_once()

        # Verify the exception is related to our test failure
        self.assertIn("Database connection failed", str(cm.exception))


class TestTaskIntegration(TestCase):
    """Integration tests for the complete task flow."""

    @patch("posthog.tasks.team_access_cache_tasks.warm_team_token_cache")
    @patch("posthog.tasks.team_access_cache_tasks.get_teams_needing_cache_refresh_paginated")
    def test_complete_cache_refresh_flow(self, mock_get_teams_paginated: MagicMock, mock_warm: MagicMock) -> None:
        """Test the complete flow from batch task to individual warming."""
        # Setup mocks - single page with one team
        mock_teams = ["phs_team1_123"]
        mock_get_teams_paginated.return_value = iter([mock_teams])
        mock_warm.return_value = True

        # Execute batch task (without actually scheduling async tasks)
        with patch("posthog.tasks.team_access_cache_tasks.warm_team_cache_task.delay") as mock_delay:
            batch_result = warm_all_team_access_caches_task()

        # Verify batch task scheduled individual task
        mock_delay.assert_called_once_with("phs_team1_123")

        # Simulate individual task execution
        individual_result = warm_team_cache_task("phs_team1_123")

        # Verify both tasks succeeded
        assert batch_result["status"] == "success"
        assert batch_result["teams_scheduled"] == 1
        assert batch_result["teams_found"] == 1
        assert batch_result["failed_teams"] == 0

        assert individual_result["status"] == "success"
        assert individual_result["project_api_key"] == "phs_team1_123"