feat(data-warehouse): Added logic for stopping syncs when the user will go over their billing limit (#34356)

This commit is contained in:
Tom Owers
2025-07-01 17:23:36 +01:00
committed by GitHub
parent a56fbafc6c
commit d83668587f
7 changed files with 448 additions and 2 deletions

View File

@@ -1,6 +1,38 @@
ee/api/test/base.py:0: error: "setUpTestData" undefined in superclass [misc]
ee/api/test/base.py:0: error: Incompatible types in assignment (expression has type "None", variable has type "License") [assignment]
ee/api/test/test_billing.py:0: error: "_MonkeyPatchedResponse" has no attribute "url" [attr-defined]
ee/api/test/test_billing.py:0: error: "_MonkeyPatchedResponse" has no attribute "url" [attr-defined]
ee/api/test/test_billing.py:0: error: "_MonkeyPatchedResponse" has no attribute "url" [attr-defined]
ee/api/test/test_billing.py:0: error: "_MonkeyPatchedResponse" has no attribute "url" [attr-defined]
ee/api/test/test_billing.py:0: error: "_MonkeyPatchedResponse" has no attribute "url" [attr-defined]
ee/api/test/test_billing.py:0: error: "_MonkeyPatchedResponse" has no attribute "url" [attr-defined]
ee/api/test/test_billing.py:0: error: "_MonkeyPatchedResponse" has no attribute "url" [attr-defined]
ee/api/test/test_billing.py:0: error: "_MonkeyPatchedResponse" has no attribute "url" [attr-defined]
ee/api/test/test_billing.py:0: error: "_MonkeyPatchedResponse" has no attribute "url" [attr-defined]
ee/api/test/test_billing.py:0: error: Argument 1 to "assertDictEqual" of "TestCase" has incompatible type "None"; expected "Mapping[Any, object]" [arg-type]
ee/api/test/test_billing.py:0: error: Extra key "projected_amount_usd_with_limit" for TypedDict "CustomerProduct" [typeddict-unknown-key]
ee/api/test/test_billing.py:0: error: Extra keys ("free_allocation", "has_exceeded_limit", "percentage_usage") for TypedDict "CustomerProductAddon" [typeddict-unknown-key]
ee/api/test/test_billing.py:0: error: Incompatible types (expression has type "float", TypedDict item "current_amount_usd" has type "Decimal | None") [typeddict-item]
ee/api/test/test_billing.py:0: error: Incompatible types (expression has type "float", TypedDict item "current_amount_usd" has type "Decimal") [typeddict-item]
ee/api/test/test_billing.py:0: error: Incompatible types (expression has type "float", TypedDict item "projected_amount_usd" has type "Decimal") [typeddict-item]
ee/api/test/test_billing.py:0: error: Incompatible types (expression has type "int", TypedDict item "projected_amount" has type "Decimal") [typeddict-item]
ee/api/test/test_billing.py:0: error: Incompatible types (expression has type "list[dict[str, object]]", TypedDict item "tiers" has type "Tier | None") [typeddict-item]
ee/api/test/test_billing.py:0: error: Incompatible types (expression has type "str", TypedDict item "projected_amount_usd" has type "Decimal | None") [typeddict-item]
ee/api/test/test_billing.py:0: error: Incompatible types (expression has type "str", TypedDict item "unit_amount_usd" has type "Decimal | None") [typeddict-item]
ee/api/test/test_billing.py:0: error: Incompatible types (expression has type "str", TypedDict item "unit_amount_usd" has type "Decimal | None") [typeddict-item]
ee/api/test/test_billing.py:0: error: Item "None" of "License | None" has no attribute "plan" [union-attr]
ee/api/test/test_billing.py:0: error: Item "None" of "License | None" has no attribute "plan" [union-attr]
ee/api/test/test_billing.py:0: error: Item "None" of "License | None" has no attribute "valid_until" [union-attr]
ee/api/test/test_billing.py:0: error: Item "None" of "License | None" has no attribute "valid_until" [union-attr]
ee/api/test/test_billing.py:0: error: Missing keys ("available_product_features", "current_total_amount_usd_after_discount", "discount_percent", "discount_amount_usd") for TypedDict "CustomerInfo" [typeddict-item]
ee/api/test/test_billing.py:0: error: Missing keys ("current_total_amount_usd_after_discount", "discount_percent", "discount_amount_usd", "customer_trust_scores") for TypedDict "CustomerInfo" [typeddict-item]
ee/api/test/test_billing.py:0: error: Missing keys ("icon_key", "docs_url", "included_with_main_product", "inclusion_only", "unit", "contact_support") for TypedDict "CustomerProductAddon" [typeddict-item]
ee/billing/billing_manager.py:0: error: Incompatible types in assignment (expression has type "object", variable has type "bool | Combinable | None") [assignment]
ee/billing/test/test_billing_manager.py:0: error: Argument 2 to "update_org_details" of "BillingManager" has incompatible type "dict[str, dict[str, object]]"; expected "BillingStatus" [arg-type]
ee/billing/test/test_billing_manager.py:0: error: Cannot resolve keyword 'distinct_id' into field. Choices are: access_controls, explicit_team_membership, id, joined_at, level, organization, organization_id, role_membership, updated_at, user, user_id [misc]
ee/billing/test/test_billing_manager.py:0: error: Module "django.utils.timezone" does not explicitly export attribute "datetime" [attr-defined]
ee/billing/test/test_billing_manager.py:0: error: Module "django.utils.timezone" does not explicitly export attribute "datetime" [attr-defined]
ee/billing/test/test_billing_manager.py:0: error: Module "django.utils.timezone" does not explicitly export attribute "datetime" [attr-defined]
ee/clickhouse/queries/funnels/funnel_correlation.py:0: error: Statement is unreachable [unreachable]
ee/models/explicit_team_membership.py:0: error: Incompatible return value type (got "int", expected "Level") [return-value]
ee/models/license.py:0: error: "_T" has no attribute "plan" [attr-defined]

View File

@@ -333,8 +333,8 @@ class ExternalDataJobWorkflow(PostHogWorkflow):
)
except exceptions.ActivityError as e:
# Check if this is a WorkerShuttingDownError - implement Buffer One retry
if isinstance(e.cause, exceptions.ApplicationError) and e.cause.type == "WorkerShuttingDownError":
# Check if this is a WorkerShuttingDownError - implement Buffer One retry
schedule_id = str(inputs.external_data_schema_id)
await workflow.execute_activity(
trigger_schedule_buffer_one_activity,
@@ -342,6 +342,12 @@ class ExternalDataJobWorkflow(PostHogWorkflow):
start_to_close_timeout=dt.timedelta(minutes=10),
retry_policy=RetryPolicy(maximum_attempts=1),
)
elif (
isinstance(e.cause, exceptions.ApplicationError)
and e.cause.type == "BillingLimitsWillBeReachedException"
):
# Check if this is a BillingLimitsWillBeReachedException - update the job status
update_inputs.status = ExternalDataJob.Status.BILLING_LIMIT_TOO_LOW
else:
# Handle other activity errors normally
update_inputs.status = ExternalDataJob.Status.FAILED

View File

@@ -21,6 +21,7 @@ from posthog.temporal.data_imports.pipelines.pipeline.delta_table_helper import
from posthog.temporal.data_imports.pipelines.pipeline.hogql_schema import HogQLSchema
from posthog.temporal.data_imports.pipelines.pipeline.typings import SourceResponse
from posthog.temporal.data_imports.pipelines.pipeline.utils import (
BillingLimitsWillBeReachedException,
DuplicatePrimaryKeysException,
_append_debug_column_to_pyarrows_table,
_evolve_pyarrow_schema,
@@ -39,7 +40,7 @@ from posthog.temporal.data_imports.pipelines.pipeline_sync import (
from posthog.temporal.data_imports.pipelines.stripe.constants import (
CHARGE_RESOURCE_NAME as STRIPE_CHARGE_RESOURCE_NAME,
)
from posthog.temporal.data_imports.row_tracking import decrement_rows, increment_rows
from posthog.temporal.data_imports.row_tracking import decrement_rows, increment_rows, will_hit_billing_limit
from posthog.temporal.data_imports.util import prepare_s3_files_for_querying
from posthog.warehouse.models import (
DataWarehouseTable,
@@ -130,6 +131,12 @@ class PipelineNonDLT:
if self._resource.rows_to_sync:
increment_rows(self._job.team_id, self._schema.id, self._resource.rows_to_sync)
# Check billing limits against incoming rows
if will_hit_billing_limit(team_id=self._job.team_id, logger=self._logger):
raise BillingLimitsWillBeReachedException(
f"Your account will hit your Data Warehouse billing limits syncing {self._resource.name} with {self._resource.rows_to_sync} rows"
)
buffer: list[Any] = []
py_table = None
row_count = 0

View File

@@ -47,6 +47,10 @@ DEFAULT_NUMERIC_SCALE = 32 # Delta Lake maximum scale
DEFAULT_PARTITION_TARGET_SIZE_IN_BYTES = 200 * 1024 * 1024 # 200 MB
class BillingLimitsWillBeReachedException(Exception):
    """Raised when syncing the incoming rows would push the team past its billing limit."""
class DuplicatePrimaryKeysException(Exception):
    """Raised when incoming data contains duplicate primary-key values."""

View File

@@ -1,6 +1,13 @@
from contextlib import contextmanager
from dateutil import parser
from django.db.models import Sum
import uuid
from posthog.cloud_utils import get_cached_instance_license
from posthog.exceptions_capture import capture_exception
from posthog.models import Organization, Team
from posthog.settings import EE_AVAILABLE
from posthog.temporal.common.logger import FilteringBoundLogger
from posthog.warehouse.models import ExternalDataJob
from posthog.redis import get_client
@@ -76,3 +83,95 @@ def get_rows(team_id: int, schema_id: uuid.UUID | str) -> int:
return int(value)
return 0
def get_all_rows_for_team(team_id: int) -> int:
    """Sum the in-progress row counts tracked in Redis across all schemas of a team.

    Returns 0 when Redis is unavailable or nothing is being tracked.
    """
    with _get_redis() as redis:
        if not redis:
            return 0

        total = 0
        # One hash per team: field = schema id, value = row count for that schema.
        for raw_count in redis.hgetall(_get_hash_key(team_id)).values():
            total += int(raw_count)
        return total
def will_hit_billing_limit(team_id: int, logger: FilteringBoundLogger) -> bool:
    """Best-effort check of whether the organization's data-warehouse row usage
    (completed billable jobs this billing period + rows currently in flight)
    already exceeds its `rows_synced` billing limit.

    Returns False whenever the check cannot be completed (no EE build, missing
    billing data, any exception) — deliberately fail-open so a billing-service
    problem never blocks syncs.
    """
    # Billing limits are an EE-only concept; FOSS installs have nothing to enforce.
    if not EE_AVAILABLE:
        return False

    try:
        from ee.billing.billing_manager import BillingManager

        logger.debug("Running will_hit_billing_limit")

        license = get_cached_instance_license()
        billing_manager = BillingManager(license)

        # Limits apply per-organization, so usage is aggregated over every team in the org.
        team = Team.objects.get(id=team_id)
        organization: Organization = team.organization
        all_teams_in_org: list[int] = [
            value[0] for value in Team.objects.filter(organization_id=organization.id).values_list("id")
        ]

        logger.debug(f"will_hit_billing_limit: Organisation_id = {organization.id}")
        logger.debug(f"will_hit_billing_limit: Teams in org: {all_teams_in_org}")

        billing_res = billing_manager.get_billing(organization)

        logger.debug(f"will_hit_billing_limit: billing_res = {billing_res}")

        # Only jobs finished within the current billing cycle count towards the limit.
        current_billing_cycle_start = billing_res.get("billing_period", {}).get("current_period_start")
        if current_billing_cycle_start is None:
            logger.debug(
                f"will_hit_billing_limit: returning early, no current_period_start available. current_billing_cycle_start = {current_billing_cycle_start}"
            )
            return False

        # presumably an ISO-8601 timestamp from the billing service — parsed leniently via dateutil
        current_billing_cycle_start_dt = parser.parse(current_billing_cycle_start)

        logger.debug(f"will_hit_billing_limit: current_billing_cycle_start = {current_billing_cycle_start}")

        usage_summary = billing_res["usage_summary"]
        rows_synced_summary = usage_summary.get("rows_synced", None)
        if not rows_synced_summary:
            logger.debug(
                f"will_hit_billing_limit: returning early, no rows_synced key in usage_summary. {usage_summary}"
            )
            return False

        rows_synced_limit = rows_synced_summary.get("limit")

        logger.debug(f"will_hit_billing_limit: rows_synced_limit = {rows_synced_limit}")

        # A missing or non-numeric limit is treated as "unlimited" — nothing to enforce.
        if rows_synced_limit is None or not isinstance(rows_synced_limit, int | float):
            logger.debug("will_hit_billing_limit: rows_synced_limit is None or not a number, returning False")
            return False

        # Get all completed rows for all teams in org
        rows_synced_in_billing_period_dict = ExternalDataJob.objects.filter(
            team_id__in=all_teams_in_org,
            finished_at__gte=current_billing_cycle_start_dt,
            billable=True,
            status=ExternalDataJob.Status.COMPLETED,
        ).aggregate(total_rows=Sum("rows_synced"))
        # Sum() aggregates to None when no rows match, hence the `or 0`.
        rows_synced_in_billing_period = rows_synced_in_billing_period_dict.get("total_rows", 0) or 0

        logger.debug(f"will_hit_billing_limit: rows_synced_in_billing_period = {rows_synced_in_billing_period}")

        # Get all in-progress rows for all teams in org
        existing_rows_in_progress = sum(get_all_rows_for_team(t_id) for t_id in all_teams_in_org)

        expected_rows = rows_synced_in_billing_period + existing_rows_in_progress
        result = expected_rows > rows_synced_limit

        logger.debug(
            f"will_hit_billing_limit: expected_rows = {expected_rows}. rows_synced_limit = {rows_synced_limit}. Returning {result}"
        )

        return result
    except Exception as e:
        # Fail-open on purpose: report the error but never block a sync on a billing failure.
        logger.debug(f"will_hit_billing_limit: Failed with exception {e}")
        capture_exception(e)
        return False

View File

@@ -0,0 +1,145 @@
import contextlib
from datetime import datetime
from typing import Optional
from unittest import mock
import uuid
from zoneinfo import ZoneInfo
from posthog.models import Team
from posthog.tasks.usage_report import ExternalDataJob
from posthog.temporal.common.logger import FilteringBoundLogger
from posthog.temporal.data_imports.row_tracking import (
finish_row_tracking,
increment_rows,
setup_row_tracking,
will_hit_billing_limit,
)
from posthog.test.base import BaseTest
from posthog.warehouse.models import ExternalDataSource
class TestRowTracking(BaseTest):
    """Tests for `will_hit_billing_limit`: completed-job usage, in-progress Redis
    row tracking, and aggregation across teams in the same organization."""

    def _logger(self) -> FilteringBoundLogger:
        # A MagicMock satisfies the logger interface; the debug output is irrelevant here.
        return mock.MagicMock()

    @contextlib.contextmanager
    def _setup_limits(self, limit: int):
        """Mock the billing service so the customer's `rows_synced` limit is `limit`."""
        from ee.api.test.test_billing import create_billing_customer

        with mock.patch("ee.api.billing.requests.get") as mock_billing_request:
            mock_res = create_billing_customer()
            usage_summary = mock_res.get("usage_summary") or {}

            mock_billing_request.return_value.status_code = 200
            mock_billing_request.return_value.json.return_value = {
                "license": {
                    "type": "scale",
                },
                "customer": {
                    **mock_res,
                    # Override only rows_synced; everything else comes from the canned customer.
                    "usage_summary": {**usage_summary, "rows_synced": {"limit": limit, "usage": 0}},
                },
            }

            yield

    @contextlib.contextmanager
    def _setup_redis_rows(self, rows: int, team_id: Optional[int] = None):
        """Track `rows` in-progress rows in Redis under a fresh schema id; clean up on exit."""
        t_id = team_id or self.team.pk
        schema_id = str(uuid.uuid4())

        setup_row_tracking(t_id, schema_id)
        increment_rows(t_id, schema_id, rows)

        yield

        finish_row_tracking(t_id, schema_id)

    def _run(self, limit: int) -> bool:
        """Create a valid license and run `will_hit_billing_limit` under `limit`."""
        from ee.models.license import License

        # A license must exist for the billing check to run at all.
        License.objects.create(
            key="12345::67890",
            plan="enterprise",
            valid_until=datetime(2038, 1, 19, 3, 14, 7, tzinfo=ZoneInfo("UTC")),
        )

        with self._setup_limits(limit):
            return will_hit_billing_limit(self.team.pk, self._logger())

    def test_row_tracking(self):
        # No prior usage at all: a limit of 10 cannot be hit.
        assert self._run(10) is False

    def test_row_tracking_with_previous_jobs(self):
        # A completed, billable job over the limit counts against it.
        source = ExternalDataSource.objects.create(team=self.team)
        ExternalDataJob.objects.create(
            team=self.team,
            rows_synced=11,
            pipeline=source,
            finished_at=datetime.now(),
            billable=True,
            status=ExternalDataJob.Status.COMPLETED,
        )

        assert self._run(10) is True

    def test_row_tracking_with_previous_incomplete_jobs(self):
        # Jobs still RUNNING are excluded from the completed-usage aggregate.
        source = ExternalDataSource.objects.create(team=self.team)
        ExternalDataJob.objects.create(
            team=self.team,
            rows_synced=11,
            pipeline=source,
            finished_at=datetime.now(),
            billable=True,
            status=ExternalDataJob.Status.RUNNING,
        )

        assert self._run(10) is False

    def test_row_tracking_with_previous_no_finished_at_jobs(self):
        # Jobs without a finish timestamp fall outside the billing-period filter.
        source = ExternalDataSource.objects.create(team=self.team)
        ExternalDataJob.objects.create(
            team=self.team,
            rows_synced=11,
            pipeline=source,
            finished_at=None,
            billable=True,
            status=ExternalDataJob.Status.COMPLETED,
        )

        assert self._run(10) is False

    def test_row_tracking_with_previous_unbillable_jobs(self):
        # Non-billable jobs never count towards the limit.
        source = ExternalDataSource.objects.create(team=self.team)
        ExternalDataJob.objects.create(
            team=self.team,
            rows_synced=11,
            pipeline=source,
            finished_at=datetime.now(),
            billable=False,
            status=ExternalDataJob.Status.COMPLETED,
        )

        assert self._run(10) is False

    def test_row_tracking_with_in_progress_rows(self):
        # Rows mid-sync (tracked in Redis) also count towards the limit.
        with self._setup_redis_rows(20):
            assert self._run(10) is True

    def test_row_tracking_with_previous_rows_from_other_team_in_org(self):
        # Usage is aggregated across every team in the organization.
        another_team = Team.objects.create(organization=self.organization)

        source = ExternalDataSource.objects.create(team=self.team)
        ExternalDataJob.objects.create(
            team=another_team,
            rows_synced=11,
            pipeline=source,
            finished_at=datetime.now(),
            billable=True,
            status=ExternalDataJob.Status.COMPLETED,
        )

        assert self._run(10) is True

    def test_row_tracking_with_in_progress_rows_from_other_team_in_org(self):
        # In-progress Redis rows from sibling teams count as well.
        another_team = Team.objects.create(organization=self.organization)

        with self._setup_redis_rows(20, team_id=another_team.pk):
            assert self._run(10) is True

View File

@@ -1,8 +1,10 @@
from datetime import datetime
import functools
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional, cast
from unittest import mock
from zoneinfo import ZoneInfo
import aioboto3
import deltalake
@@ -2248,3 +2250,154 @@ async def test_worker_shutdown_triggers_schedule_buffer_one(team, stripe_price,
team_id=inputs.team_id, pipeline_id=inputs.external_data_source_id
)
assert run.status == ExternalDataJob.Status.COMPLETED
@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio
async def test_billing_limits_too_many_rows(team, postgres_config, postgres_connection):
    """With a `rows_synced` limit of 0, an incoming sync is stopped before importing:
    the job ends as BILLING_LIMIT_TOO_LOW and no queryable warehouse table exists."""
    from ee.api.test.test_billing import create_billing_customer
    from ee.models.license import License

    # Seed a source table with two rows so the pipeline has something it WOULD sync.
    await postgres_connection.execute(
        "CREATE TABLE IF NOT EXISTS {schema}.billing_limits (id integer)".format(schema=postgres_config["schema"])
    )
    await postgres_connection.execute(
        "INSERT INTO {schema}.billing_limits (id) VALUES (1)".format(schema=postgres_config["schema"])
    )
    await postgres_connection.execute(
        "INSERT INTO {schema}.billing_limits (id) VALUES (2)".format(schema=postgres_config["schema"])
    )
    await postgres_connection.commit()

    with (
        mock.patch("ee.api.billing.requests.get") as mock_billing_request,
        # Reset the cached "is licensed" flag so the new license below is picked up.
        mock.patch("posthog.cloud_utils.is_instance_licensed_cached", None),
    ):
        await sync_to_async(License.objects.create)(
            key="12345::67890",
            plan="enterprise",
            valid_until=datetime(2038, 1, 19, 3, 14, 7, tzinfo=ZoneInfo("UTC")),
        )

        # Billing service mock: customer with a hard rows_synced limit of 0.
        mock_res = create_billing_customer()
        usage_summary = mock_res.get("usage_summary") or {}

        mock_billing_request.return_value.status_code = 200
        mock_billing_request.return_value.json.return_value = {
            "license": {
                "type": "scale",
            },
            "customer": {
                **mock_res,
                "usage_summary": {**usage_summary, "rows_synced": {"limit": 0, "usage": 0}},
            },
        }

        # ignore_assertions: the sync is expected to abort, so the helper's usual
        # success checks would fail.
        await _run(
            team=team,
            schema_name="billing_limits",
            table_name="postgres_billing_limits",
            source_type="Postgres",
            job_inputs={
                "host": postgres_config["host"],
                "port": postgres_config["port"],
                "database": postgres_config["database"],
                "user": postgres_config["user"],
                "password": postgres_config["password"],
                "schema": postgres_config["schema"],
                "ssh_tunnel_enabled": "False",
            },
            mock_data_response=[],
            sync_type=ExternalDataSchema.SyncType.INCREMENTAL,
            sync_type_config={"incremental_field": "id", "incremental_field_type": "integer"},
            ignore_assertions=True,
        )

        # The workflow should have recorded the dedicated billing-limit status.
        job: ExternalDataJob = await sync_to_async(ExternalDataJob.objects.get)(
            team_id=team.id, schema__name="billing_limits"
        )
        assert job.status == ExternalDataJob.Status.BILLING_LIMIT_TOO_LOW

        # Nothing was synced, so querying the warehouse table must fail.
        with pytest.raises(Exception):
            await sync_to_async(execute_hogql_query)(f"SELECT * FROM postgres_billing_limits", team)
@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio
async def test_billing_limits_too_many_rows_previously(team, postgres_config, postgres_connection):
    """When prior completed usage already equals the `rows_synced` limit, a new sync
    is stopped: the job ends as BILLING_LIMIT_TOO_LOW and no table is queryable."""
    from ee.api.test.test_billing import create_billing_customer
    from ee.models.license import License

    # Seed a source table with two rows so the new sync would add usage.
    await postgres_connection.execute(
        "CREATE TABLE IF NOT EXISTS {schema}.billing_limits (id integer)".format(schema=postgres_config["schema"])
    )
    await postgres_connection.execute(
        "INSERT INTO {schema}.billing_limits (id) VALUES (1)".format(schema=postgres_config["schema"])
    )
    await postgres_connection.execute(
        "INSERT INTO {schema}.billing_limits (id) VALUES (2)".format(schema=postgres_config["schema"])
    )
    await postgres_connection.commit()

    with (
        mock.patch("ee.api.billing.requests.get") as mock_billing_request,
        # Reset the cached "is licensed" flag so the new license below is picked up.
        mock.patch("posthog.cloud_utils.is_instance_licensed_cached", None),
    ):
        source = await sync_to_async(ExternalDataSource.objects.create)(team=team)

        # A previous job that reached the billing limit
        # (rows_synced=10 against a limit of 10, so any further rows exceed it).
        await sync_to_async(ExternalDataJob.objects.create)(
            team=team,
            rows_synced=10,
            pipeline=source,
            finished_at=datetime.now(),
            billable=True,
            status=ExternalDataJob.Status.COMPLETED,
        )

        await sync_to_async(License.objects.create)(
            key="12345::67890",
            plan="enterprise",
            valid_until=datetime(2038, 1, 19, 3, 14, 7, tzinfo=ZoneInfo("UTC")),
        )

        # Billing service mock: customer with a rows_synced limit of 10.
        mock_res = create_billing_customer()
        usage_summary = mock_res.get("usage_summary") or {}

        mock_billing_request.return_value.status_code = 200
        mock_billing_request.return_value.json.return_value = {
            "license": {
                "type": "scale",
            },
            "customer": {
                **mock_res,
                "usage_summary": {**usage_summary, "rows_synced": {"limit": 10, "usage": 0}},
            },
        }

        # ignore_assertions: the sync is expected to abort, so the helper's usual
        # success checks would fail.
        await _run(
            team=team,
            schema_name="billing_limits",
            table_name="postgres_billing_limits",
            source_type="Postgres",
            job_inputs={
                "host": postgres_config["host"],
                "port": postgres_config["port"],
                "database": postgres_config["database"],
                "user": postgres_config["user"],
                "password": postgres_config["password"],
                "schema": postgres_config["schema"],
                "ssh_tunnel_enabled": "False",
            },
            mock_data_response=[],
            sync_type=ExternalDataSchema.SyncType.INCREMENTAL,
            sync_type_config={"incremental_field": "id", "incremental_field_type": "integer"},
            ignore_assertions=True,
        )

        # The workflow should have recorded the dedicated billing-limit status.
        job: ExternalDataJob = await sync_to_async(ExternalDataJob.objects.get)(
            team_id=team.id, schema__name="billing_limits"
        )
        assert job.status == ExternalDataJob.Status.BILLING_LIMIT_TOO_LOW

        # Nothing was synced, so querying the warehouse table must fail.
        with pytest.raises(Exception):
            await sync_to_async(execute_hogql_query)(f"SELECT * FROM postgres_billing_limits", team)