feat(dagster): persons backfill job (#41565)
Co-authored-by: Lucas Ricoy <2034367+lricoy@users.noreply.github.com>
@@ -19,3 +19,4 @@ load_from:
  - python_module: dags.locations.llma
  - python_module: dags.locations.max_ai
  - python_module: dags.locations.web_analytics
  - python_module: dags.locations.ingestion
@@ -6,6 +6,8 @@ from typing import Optional
from django.conf import settings

import dagster
import psycopg2
import psycopg2.extras
from clickhouse_driver.errors import Error, ErrorCodes

from posthog.clickhouse import query_tagging
@@ -22,6 +24,7 @@ class JobOwners(str, Enum):
    TEAM_ERROR_TRACKING = "team-error-tracking"
    TEAM_EXPERIMENTS = "team-experiments"
    TEAM_GROWTH = "team-growth"
    TEAM_INGESTION = "team-ingestion"
    TEAM_LLMA = "team-llma"
    TEAM_MAX_AI = "team-max-ai"
    TEAM_REVENUE_ANALYTICS = "team-revenue-analytics"
@@ -80,6 +83,28 @@ class RedisResource(dagster.ConfigurableResource):
        return client


class PostgresResource(dagster.ConfigurableResource):
    """
    A Postgres database connection resource that returns a psycopg2 connection.
    """

    host: str
    port: str = "5432"
    database: str
    user: str
    password: str

    def create_resource(self, context: dagster.InitResourceContext) -> psycopg2.extensions.connection:
        return psycopg2.connect(
            host=self.host,
            port=int(self.port),
            database=self.database,
            user=self.user,
            password=self.password,
            cursor_factory=psycopg2.extras.RealDictCursor,
        )


def report_job_status_metric(
    context: dagster.RunStatusSensorContext, cluster: dagster.ResourceParam[ClickhouseCluster]
) -> None:
@@ -5,7 +5,7 @@ import dagster_slack
from dagster_aws.s3.io_manager import s3_pickle_io_manager
from dagster_aws.s3.resources import S3Resource

-from dags.common import ClickhouseClusterResource, RedisResource
+from dags.common import ClickhouseClusterResource, PostgresResource, RedisResource

# Define resources for different environments
resources_by_env = {
@@ -18,6 +18,14 @@ resources_by_env = {
        "s3": S3Resource(),
        # Using EnvVar instead of the Django setting to ensure that the token is not leaked anywhere in the Dagster UI
        "slack": dagster_slack.SlackResource(token=dagster.EnvVar("SLACK_TOKEN")),
        # Postgres resource (universal for all dags)
        "database": PostgresResource(
            host=dagster.EnvVar("POSTGRES_HOST"),
            port=dagster.EnvVar("POSTGRES_PORT"),
            database=dagster.EnvVar("POSTGRES_DATABASE"),
            user=dagster.EnvVar("POSTGRES_USER"),
            password=dagster.EnvVar("POSTGRES_PASSWORD"),
        ),
    },
    "local": {
        "cluster": ClickhouseClusterResource.configure_at_launch(),
@@ -29,6 +37,14 @@ resources_by_env = {
            aws_secret_access_key=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY,
        ),
        "slack": dagster.ResourceDefinition.none_resource(description="Dummy Slack resource for local development"),
        # Postgres resource (universal for all dags) - use Django settings or env vars for local dev
        "database": PostgresResource(
            host=dagster.EnvVar("POSTGRES_HOST"),
            port=dagster.EnvVar("POSTGRES_PORT"),
            database=dagster.EnvVar("POSTGRES_DATABASE"),
            user=dagster.EnvVar("POSTGRES_USER"),
            password=dagster.EnvVar("POSTGRES_PASSWORD"),
        ),
    },
}
15  dags/locations/ingestion.py  Normal file
@@ -0,0 +1,15 @@
import dagster

from dags import persons_new_backfill

from . import resources

defs = dagster.Definitions(
    assets=[
        persons_new_backfill.postgres_env_check,
    ],
    jobs=[
        persons_new_backfill.persons_new_backfill_job,
    ],
    resources=resources,
)
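The location module above only wires together objects defined elsewhere in this commit. A quick check along the following lines (not part of the diff; it assumes the dags package is importable in the test environment) would confirm the job is actually registered with the new code location:

# Sketch only, not in the commit: verify the ingestion location exposes the backfill job.
from dags.locations.ingestion import defs

def test_ingestion_location_exposes_backfill_job():
    # get_job_def raises if the job name is unknown, so reaching the assert is the real check
    assert defs.get_job_def("persons_new_backfill_job") is not None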
397  dags/persons_new_backfill.py  Normal file
@@ -0,0 +1,397 @@
"""Dagster job for backfilling posthog_persons data from source to destination Postgres database."""

import os
import time
from typing import Any

import dagster
import psycopg2
import psycopg2.errors

from posthog.clickhouse.cluster import ClickhouseCluster
from posthog.clickhouse.custom_metrics import MetricsClient

from dags.common import JobOwners


class PersonsNewBackfillConfig(dagster.Config):
    """Configuration for the persons new backfill job."""

    chunk_size: int = 1_000_000  # ID range per chunk
    batch_size: int = 100_000  # Records per batch insert
    source_table: str = "posthog_persons"
    destination_table: str = "posthog_persons_new"
    max_id: int | None = None  # Optional override for max ID to resume from partial state
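To make the defaults concrete: with chunk_size=1_000_000 and batch_size=100_000, a source table spanning, say, 250 million IDs fans out into 250 dynamic chunks, each copied in about 10 batch transactions. A rough sketch of that arithmetic (illustrative only; the ID range is hypothetical, the real one comes from get_id_range below):

import math

min_id, max_id = 1, 250_000_000  # hypothetical ID range
chunk_size, batch_size = 1_000_000, 100_000  # defaults from PersonsNewBackfillConfig

total_ids = max_id - min_id + 1
num_chunks = math.ceil(total_ids / chunk_size)          # dynamic outputs mapped over copy_chunk
batches_per_chunk = math.ceil(chunk_size / batch_size)  # transactions per full chunk
print(num_chunks, batches_per_chunk)  # 250 10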
@dagster.op
def get_id_range(
    context: dagster.OpExecutionContext,
    config: PersonsNewBackfillConfig,
    database: dagster.ResourceParam[psycopg2.extensions.connection],
) -> tuple[int, int]:
    """
    Query source database for MIN(id) and optionally MAX(id) from posthog_persons table.
    If max_id is provided in config, uses that instead of querying.
    Returns tuple (min_id, max_id).
    """
    with database.cursor() as cursor:
        # Always query for min_id
        min_query = f"SELECT MIN(id) as min_id FROM {config.source_table}"
        context.log.info(f"Querying min ID: {min_query}")
        cursor.execute(min_query)
        min_result = cursor.fetchone()

        if min_result is None or min_result["min_id"] is None:
            context.log.exception("Source table is empty or has no valid IDs")
            # Note: No metrics client here as this is get_id_range op, not copy_chunk
            raise dagster.Failure("Source table is empty or has no valid IDs")

        min_id = int(min_result["min_id"])

        # Use config max_id if provided, otherwise query database
        if config.max_id is not None:
            max_id = config.max_id
            context.log.info(f"Using configured max_id override: {max_id}")
        else:
            max_query = f"SELECT MAX(id) as max_id FROM {config.source_table}"
            context.log.info(f"Querying max ID: {max_query}")
            cursor.execute(max_query)
            max_result = cursor.fetchone()

            if max_result is None or max_result["max_id"] is None:
                context.log.exception("Source table has no valid max ID")
                # Note: No metrics client here as this is get_id_range op, not copy_chunk
                raise dagster.Failure("Source table has no valid max ID")

            max_id = int(max_result["max_id"])

    # Validate that max_id >= min_id
    if max_id < min_id:
        error_msg = f"Invalid ID range: max_id ({max_id}) < min_id ({min_id})"
        context.log.error(error_msg)
        # Note: No metrics client here as this is get_id_range op, not copy_chunk
        raise dagster.Failure(error_msg)

    context.log.info(f"ID range: min={min_id}, max={max_id}, total_ids={max_id - min_id + 1}")
    context.add_output_metadata(
        {
            "min_id": dagster.MetadataValue.int(min_id),
            "max_id": dagster.MetadataValue.int(max_id),
            "max_id_source": dagster.MetadataValue.text("config" if config.max_id is not None else "database"),
            "total_ids": dagster.MetadataValue.int(max_id - min_id + 1),
        }
    )

    return (min_id, max_id)


@dagster.op(out=dagster.DynamicOut(tuple[int, int]))
def create_chunks(
    context: dagster.OpExecutionContext,
    config: PersonsNewBackfillConfig,
    id_range: tuple[int, int],
):
    """
    Divide ID space into chunks of chunk_size.
    Yields DynamicOutput for each chunk in reverse order (highest IDs first, lowest IDs last).
    This ensures that if the job fails partway through, the final chunk to process will be
    the one starting at the source table's min_id.
    """
    min_id, max_id = id_range
    chunk_size = config.chunk_size

    # First, collect all chunks
    chunks = []
    chunk_min = min_id
    chunk_num = 0

    while chunk_min <= max_id:
        chunk_max = min(chunk_min + chunk_size - 1, max_id)
        chunks.append((chunk_min, chunk_max, chunk_num))
        chunk_min = chunk_max + 1
        chunk_num += 1

    context.log.info(f"Created {chunk_num} chunks total")

    # Yield chunks in reverse order (highest IDs first)
    for chunk_min, chunk_max, chunk_num in reversed(chunks):
        chunk_key = f"chunk_{chunk_min}_{chunk_max}"
        context.log.info(f"Yielding chunk {chunk_num}: {chunk_min} to {chunk_max}")
        yield dagster.DynamicOutput(
            value=(chunk_min, chunk_max),
            mapping_key=chunk_key,
        )
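To illustrate the reverse ordering the docstring describes (not part of the commit; values are made up): for min_id=1, max_id=2500 and chunk_size=1000 the op yields the highest-ID chunk first, and the chunk that contains min_id is processed last.

# Standalone sketch mirroring the loop above, with hypothetical values.
min_id, max_id, chunk_size = 1, 2_500, 1_000
chunks = []
chunk_min = min_id
while chunk_min <= max_id:
    chunk_max = min(chunk_min + chunk_size - 1, max_id)
    chunks.append((chunk_min, chunk_max))
    chunk_min = chunk_max + 1
print(list(reversed(chunks)))  # [(2001, 2500), (1001, 2000), (1, 1000)]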
@dagster.op
def copy_chunk(
    context: dagster.OpExecutionContext,
    config: PersonsNewBackfillConfig,
    chunk: tuple[int, int],
    database: dagster.ResourceParam[psycopg2.extensions.connection],
    cluster: dagster.ResourceParam[ClickhouseCluster],
) -> dict[str, Any]:
    """
    Copy a chunk of records from source to destination database.
    Processes in batches of batch_size records.
    """
    chunk_min, chunk_max = chunk
    batch_size = config.batch_size
    source_table = config.source_table
    destination_table = config.destination_table
    chunk_id = f"chunk_{chunk_min}_{chunk_max}"
    job_name = context.run.job_name

    # Initialize metrics client
    metrics_client = MetricsClient(cluster)

    context.log.info(f"Starting chunk copy: {chunk_min} to {chunk_max}")

    total_records_copied = 0
    batch_start_id = chunk_min
    failed_batch_start_id: int | None = None

    try:
        with database.cursor() as cursor:
            # Set session-level settings once for the entire chunk
            cursor.execute("SET application_name = 'backfill_posthog_persons_to_posthog_persons_new'")
            cursor.execute("SET lock_timeout = '5s'")
            cursor.execute("SET statement_timeout = '30min'")
            cursor.execute("SET maintenance_work_mem = '12GB'")
            cursor.execute("SET work_mem = '512MB'")
            cursor.execute("SET temp_buffers = '512MB'")
            cursor.execute("SET max_parallel_workers_per_gather = 2")
            cursor.execute("SET max_parallel_maintenance_workers = 2")
            cursor.execute("SET synchronous_commit = off")

            retry_attempt = 0
            while batch_start_id <= chunk_max:
                try:
                    # Track batch start time for duration metric
                    batch_start_time = time.time()

                    # Calculate batch end ID
                    batch_end_id = min(batch_start_id + batch_size, chunk_max)

                    # Track records attempted - this is also our exit condition
                    records_attempted = batch_end_id - batch_start_id
                    if records_attempted <= 0:
                        break
                    # Begin transaction (settings already applied at session level)
                    cursor.execute("BEGIN")

                    # Execute INSERT INTO ... SELECT with NOT EXISTS check
                    insert_query = f"""
                        INSERT INTO {destination_table}
                        SELECT s.*
                        FROM {source_table} s
                        WHERE s.id >= %s AND s.id <= %s
                        AND NOT EXISTS (
                            SELECT 1
                            FROM {destination_table} d
                            WHERE d.team_id = s.team_id
                            AND d.id = s.id
                        )
                        ORDER BY s.id DESC
                    """
                    cursor.execute(insert_query, (batch_start_id, batch_end_id))
                    records_inserted = cursor.rowcount

                    # Commit the transaction
                    cursor.execute("COMMIT")

                    try:
                        metrics_client.increment(
                            "persons_new_backfill_records_attempted_total",
                            labels={"job_name": job_name, "chunk_id": chunk_id},
                            value=float(records_attempted),
                        ).result()
                    except Exception:
                        pass  # Don't fail on metrics error

                    batch_duration_seconds = time.time() - batch_start_time

                    try:
                        metrics_client.increment(
                            "persons_new_backfill_records_inserted_total",
                            labels={"job_name": job_name, "chunk_id": chunk_id},
                            value=float(records_inserted),
                        ).result()
                    except Exception:
                        pass  # Don't fail on metrics error

                    try:
                        metrics_client.increment(
                            "persons_new_backfill_batches_copied_total",
                            labels={"job_name": job_name, "chunk_id": chunk_id},
                            value=1.0,
                        ).result()
                    except Exception:
                        pass
                    # Track batch duration metric (IV)
                    try:
                        metrics_client.increment(
                            "persons_new_backfill_batch_duration_seconds_total",
                            labels={"job_name": job_name, "chunk_id": chunk_id},
                            value=batch_duration_seconds,
                        ).result()
                    except Exception:
                        pass

                    total_records_copied += records_inserted

                    context.log.info(
                        f"Copied batch: {records_inserted} records "
                        f"(chunk {chunk_min}-{chunk_max}, batch ID range {batch_start_id} to {batch_end_id})"
                    )

                    # Update batch_start_id for next iteration
                    batch_start_id = batch_end_id + 1
                    retry_attempt = 0

                except Exception as batch_error:
                    # Rollback transaction on error
                    try:
                        cursor.execute("ROLLBACK")
                    except Exception as rollback_error:
                        context.log.exception(
                            f"Failed to rollback transaction for batch starting at ID {batch_start_id}"
                            f"in chunk {chunk_min}-{chunk_max}: {str(rollback_error)}"
                        )
                        pass  # Ignore rollback errors

                    # Check if error is a duplicate key violation, pause and retry if so
                    is_unique_violation = isinstance(batch_error, psycopg2.errors.UniqueViolation) or (
                        isinstance(batch_error, psycopg2.Error) and getattr(batch_error, "pgcode", None) == "23505"
                    )
                    if is_unique_violation:
                        error_msg = (
                            f"Duplicate key violation detected for batch starting at ID {batch_start_id} "
                            f"in chunk {chunk_min}-{chunk_max}. Error is: {batch_error}. "
                            "This is expected if records already exist in destination table. "
                        )
                        context.log.warning(error_msg)
                        if retry_attempt < 3:
                            retry_attempt += 1
                            context.log.info(f"Retrying batch {retry_attempt} of 3...")
                            time.sleep(1)
                            continue

                    failed_batch_start_id = batch_start_id
                    error_msg = (
                        f"Failed to copy batch starting at ID {batch_start_id} "
                        f"in chunk {chunk_min}-{chunk_max}: {str(batch_error)}"
                    )
                    context.log.exception(error_msg)
                    # Report fatal error metric before raising
                    try:
                        metrics_client.increment(
                            "persons_new_backfill_error",
                            labels={"job_name": job_name, "chunk_id": chunk_id, "reason": "batch_copy_failed"},
                            value=1.0,
                        ).result()
                    except Exception:
                        pass  # Don't fail on metrics error

                    raise dagster.Failure(
                        description=error_msg,
                        metadata={
                            "chunk_min_id": dagster.MetadataValue.int(chunk_min),
                            "chunk_max_id": dagster.MetadataValue.int(chunk_max),
                            "failed_batch_start_id": dagster.MetadataValue.int(failed_batch_start_id)
                            if failed_batch_start_id
                            else dagster.MetadataValue.text("N/A"),
                            "error_message": dagster.MetadataValue.text(str(batch_error)),
                            "records_copied_before_failure": dagster.MetadataValue.int(total_records_copied),
                        },
                    ) from batch_error

    except dagster.Failure:
        # Re-raise Dagster failures as-is (they already have metadata and metrics)
        raise
    except Exception as e:
        # Catch any other unexpected errors
        error_msg = f"Unexpected error copying chunk {chunk_min}-{chunk_max}: {str(e)}"
        context.log.exception(error_msg)
        # Report fatal error metric before raising
        try:
            metrics_client.increment(
                "persons_new_backfill_error",
                labels={"job_name": job_name, "chunk_id": chunk_id, "reason": "unexpected_copy_error"},
                value=1.0,
            ).result()
        except Exception:
            pass  # Don't fail on metrics error
        raise dagster.Failure(
            description=error_msg,
            metadata={
                "chunk_min_id": dagster.MetadataValue.int(chunk_min),
                "chunk_max_id": dagster.MetadataValue.int(chunk_max),
                "failed_batch_start_id": dagster.MetadataValue.int(failed_batch_start_id)
                if failed_batch_start_id
                else dagster.MetadataValue.int(batch_start_id),
                "error_message": dagster.MetadataValue.text(str(e)),
                "records_copied_before_failure": dagster.MetadataValue.int(total_records_copied),
            },
        ) from e

    context.log.info(f"Completed chunk {chunk_min}-{chunk_max}: copied {total_records_copied} records")

    # Emit metric for chunk completion
    run_id = context.run.run_id
    try:
        metrics_client.increment(
            "persons_new_backfill_chunks_completed_total",
            labels={"job_name": job_name, "run_id": run_id, "chunk_id": chunk_id},
            value=1.0,
        ).result()
    except Exception:
        pass  # Don't fail on metrics error

    context.add_output_metadata(
        {
            "chunk_min": dagster.MetadataValue.int(chunk_min),
            "chunk_max": dagster.MetadataValue.int(chunk_max),
            "records_copied": dagster.MetadataValue.int(total_records_copied),
        }
    )

    return {
        "chunk_min": chunk_min,
        "chunk_max": chunk_max,
        "records_copied": total_records_copied,
    }


@dagster.asset
def postgres_env_check(context: dagster.AssetExecutionContext) -> None:
    """
    Simple asset that prints PostgreSQL environment variables being used.
    Useful for debugging connection configuration.
    """
    env_vars = {
        "POSTGRES_HOST": os.getenv("POSTGRES_HOST", "not set"),
        "POSTGRES_PORT": os.getenv("POSTGRES_PORT", "not set"),
        "POSTGRES_DATABASE": os.getenv("POSTGRES_DATABASE", "not set"),
        "POSTGRES_USER": os.getenv("POSTGRES_USER", "not set"),
        "POSTGRES_PASSWORD": "***" if os.getenv("POSTGRES_PASSWORD") else "not set",
    }

    context.log.info("PostgreSQL environment variables:")
    for key, value in env_vars.items():
        context.log.info(f"  {key}: {value}")


@dagster.job(
    tags={"owner": JobOwners.TEAM_INGESTION.value},
    executor_def=dagster.multiprocess_executor.configured({"max_concurrent": 32}),
)
def persons_new_backfill_job():
    """
    Backfill posthog_persons data from source to destination Postgres database.
    Divides the ID space into chunks and processes them in parallel.
    """
    id_range = get_id_range()
    chunks = create_chunks(id_range)
    chunks.map(copy_chunk)
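All three ops consume the same PersonsNewBackfillConfig, so launching the job means supplying that config once per op under the ops key of the run config. A sketch of what a launchpad or execute_in_process run config might look like (not part of the commit; the values, including the max_id resume override, are hypothetical):

# Illustrative only: per-op config for persons_new_backfill_job.
backfill_config = {
    "chunk_size": 1_000_000,
    "batch_size": 100_000,
    "source_table": "posthog_persons",
    "destination_table": "posthog_persons_new",
    "max_id": 42_000_000,  # optional: resume below the last fully copied chunk
}
run_config = {
    "ops": {
        "get_id_range": {"config": backfill_config},
        "create_chunks": {"config": backfill_config},
        "copy_chunk": {"config": backfill_config},
    }
}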
@@ -15,6 +15,7 @@ notification_channel_per_team = {
    JobOwners.TEAM_ERROR_TRACKING.value: "#alerts-error-tracking",
    JobOwners.TEAM_EXPERIMENTS.value: "#alerts-experiments-dagster",
    JobOwners.TEAM_GROWTH.value: "#alerts-growth",
    JobOwners.TEAM_INGESTION.value: "#alerts-ingestion",
    JobOwners.TEAM_MAX_AI.value: "#alerts-max-ai",
    JobOwners.TEAM_REVENUE_ANALYTICS.value: "#alerts-revenue-analytics",
    JobOwners.TEAM_WEB_ANALYTICS.value: "#alerts-web-analytics",
485  dags/tests/test_persons_new_backfill.py  Normal file
@@ -0,0 +1,485 @@
"""Tests for the persons new backfill job."""

from unittest.mock import MagicMock, patch

import psycopg2.errors
from dagster import build_op_context

from dags.persons_new_backfill import PersonsNewBackfillConfig, copy_chunk, create_chunks


class TestCreateChunks:
    """Test the create_chunks function."""

    def test_create_chunks_produces_non_overlapping_ranges(self):
        """Test that chunks produce non-overlapping ranges."""
        config = PersonsNewBackfillConfig(chunk_size=1000)
        id_range = (1, 5000)  # min_id=1, max_id=5000

        context = build_op_context()
        chunks = list(create_chunks(context, config, id_range))

        # Extract all chunk ranges from DynamicOutput objects
        chunk_ranges = [chunk.value for chunk in chunks]

        # Verify no overlaps
        for i, (min1, max1) in enumerate(chunk_ranges):
            for j, (min2, max2) in enumerate(chunk_ranges):
                if i != j:
                    # Chunks should not overlap
                    assert not (
                        min1 <= min2 <= max1 or min1 <= max2 <= max1 or min2 <= min1 <= max2
                    ), f"Chunks overlap: ({min1}, {max1}) and ({min2}, {max2})"

    def test_create_chunks_covers_entire_id_space(self):
        """Test that chunks cover the entire ID space from min to max."""
        config = PersonsNewBackfillConfig(chunk_size=1000)
        min_id, max_id = 1, 5000
        id_range = (min_id, max_id)

        context = build_op_context()
        chunks = list(create_chunks(context, config, id_range))

        # Extract all chunk ranges from DynamicOutput objects
        chunk_ranges = [chunk.value for chunk in chunks]

        # Find the overall min and max covered
        all_ids_covered: set[int] = set()
        for chunk_min, chunk_max in chunk_ranges:
            all_ids_covered.update(range(chunk_min, chunk_max + 1))

        # Verify all IDs from min_id to max_id are covered
        expected_ids = set(range(min_id, max_id + 1))
        assert all_ids_covered == expected_ids, (
            f"Missing IDs: {expected_ids - all_ids_covered}, " f"Extra IDs: {all_ids_covered - expected_ids}"
        )

    def test_create_chunks_first_chunk_includes_max_id(self):
        """Test that the first chunk (in yielded order) includes the source table max_id."""
        config = PersonsNewBackfillConfig(chunk_size=1000)
        min_id, max_id = 1, 5000
        id_range = (min_id, max_id)

        context = build_op_context()
        chunks = list(create_chunks(context, config, id_range))

        # First chunk in the list (yielded first, highest IDs)
        first_chunk_min, first_chunk_max = chunks[0].value

        assert first_chunk_max == max_id, f"First chunk max ({first_chunk_max}) should equal source max_id ({max_id})"
        assert (
            first_chunk_min <= max_id <= first_chunk_max
        ), f"First chunk ({first_chunk_min}, {first_chunk_max}) should include max_id ({max_id})"

    def test_create_chunks_final_chunk_includes_min_id(self):
        """Test that the final chunk (in yielded order) includes the source table min_id."""
        config = PersonsNewBackfillConfig(chunk_size=1000)
        min_id, max_id = 1, 5000
        id_range = (min_id, max_id)

        context = build_op_context()
        chunks = list(create_chunks(context, config, id_range))

        # Last chunk in the list (yielded last, lowest IDs)
        final_chunk_min, final_chunk_max = chunks[-1].value

        assert final_chunk_min == min_id, f"Final chunk min ({final_chunk_min}) should equal source min_id ({min_id})"
        assert (
            final_chunk_min <= min_id <= final_chunk_max
        ), f"Final chunk ({final_chunk_min}, {final_chunk_max}) should include min_id ({min_id})"

    def test_create_chunks_reverse_order(self):
        """Test that chunks are yielded in reverse order (highest IDs first)."""
        config = PersonsNewBackfillConfig(chunk_size=1000)
        min_id, max_id = 1, 5000
        id_range = (min_id, max_id)

        context = build_op_context()
        chunks = list(create_chunks(context, config, id_range))

        # Verify chunks are in descending order by max_id
        for i in range(len(chunks) - 1):
            current_max = chunks[i].value[1]
            next_max = chunks[i + 1].value[1]
            assert (
                current_max > next_max
            ), f"Chunks not in reverse order: chunk {i} max ({current_max}) should be > chunk {i+1} max ({next_max})"

    def test_create_chunks_exact_multiple(self):
        """Test chunk creation when ID range is an exact multiple of chunk_size."""
        config = PersonsNewBackfillConfig(chunk_size=1000)
        min_id, max_id = 1, 5000  # Exactly 5 chunks of 1000
        id_range = (min_id, max_id)

        context = build_op_context()
        chunks = list(create_chunks(context, config, id_range))

        assert len(chunks) == 5, f"Expected 5 chunks, got {len(chunks)}"

        # Verify first chunk (highest IDs)
        assert chunks[0].value == (4001, 5000), f"First chunk should be (4001, 5000), got {chunks[0].value}"

        # Verify last chunk (lowest IDs)
        assert chunks[-1].value == (1, 1000), f"Last chunk should be (1, 1000), got {chunks[-1].value}"

    def test_create_chunks_non_exact_multiple(self):
        """Test chunk creation when ID range is not an exact multiple of chunk_size."""
        config = PersonsNewBackfillConfig(chunk_size=1000)
        min_id, max_id = 1, 3750  # 3 full chunks + 1 partial chunk
        id_range = (min_id, max_id)

        context = build_op_context()
        chunks = list(create_chunks(context, config, id_range))

        assert len(chunks) == 4, f"Expected 4 chunks, got {len(chunks)}"

        # Verify first chunk (highest IDs) - should be the partial chunk
        assert chunks[0].value == (3001, 3750), f"First chunk should be (3001, 3750), got {chunks[0].value}"

        # Verify last chunk (lowest IDs)
        assert chunks[-1].value == (1, 1000), f"Last chunk should be (1, 1000), got {chunks[-1].value}"

    def test_create_chunks_single_chunk(self):
        """Test chunk creation when ID range fits in a single chunk."""
        config = PersonsNewBackfillConfig(chunk_size=1000)
        min_id, max_id = 100, 500
        id_range = (min_id, max_id)

        context = build_op_context()
        chunks = list(create_chunks(context, config, id_range))

        assert len(chunks) == 1, f"Expected 1 chunk, got {len(chunks)}"
        assert chunks[0].value == (100, 500), f"Chunk should be (100, 500), got {chunks[0].value}"
        assert chunks[0].value[0] == min_id and chunks[0].value[1] == max_id

def create_mock_database_resource(rowcount_values=None):
    """
    Create a mock database resource that mimics psycopg2.extensions.connection.

    Args:
        rowcount_values: List of rowcount values to return per INSERT call.
            If None, defaults to 0. If a single int, uses that for all calls.
    """
    mock_cursor = MagicMock()
    if rowcount_values is None:
        mock_cursor.rowcount = 0
    elif isinstance(rowcount_values, int):
        mock_cursor.rowcount = rowcount_values
    else:
        # Use side_effect to return different rowcounts per call
        call_count = [0]

        def get_rowcount():
            if call_count[0] < len(rowcount_values):
                result = rowcount_values[call_count[0]]
                call_count[0] += 1
                return result
            return rowcount_values[-1] if rowcount_values else 0

        mock_cursor.rowcount = property(lambda self: get_rowcount())

    mock_cursor.execute = MagicMock()

    # Make cursor() return a context manager
    mock_conn = MagicMock()
    mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
    mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)

    return mock_conn


def create_mock_cluster_resource():
    """Create a mock ClickhouseCluster resource."""
    return MagicMock()

class TestCopyChunk:
    """Test the copy_chunk function."""

    def test_copy_chunk_single_batch_success(self):
        """Test successful copy of a single batch within a chunk."""
        config = PersonsNewBackfillConfig(
            chunk_size=1000, batch_size=100, source_table="posthog_persons", destination_table="posthog_persons_new"
        )
        chunk = (1, 100)  # Single batch covers entire chunk

        mock_db = create_mock_database_resource(rowcount_values=50)
        mock_cluster = create_mock_cluster_resource()

        context = build_op_context(
            resources={"database": mock_db, "cluster": mock_cluster},
        )
        # Patch context.run.job_name where it's accessed in copy_chunk
        from unittest.mock import PropertyMock

        with patch.object(type(context), "run", PropertyMock(return_value=MagicMock(job_name="test_job"))):
            result = copy_chunk(context, config, chunk)

        # Verify result
        assert result["chunk_min"] == 1
        assert result["chunk_max"] == 100
        assert result["records_copied"] == 50

        # Verify SET statements called once (session-level, before loop)
        set_statements = [
            "SET application_name = 'backfill_posthog_persons_to_posthog_persons_new'",
            "SET lock_timeout = '5s'",
            "SET statement_timeout = '30min'",
            "SET maintenance_work_mem = '12GB'",
            "SET work_mem = '512MB'",
            "SET temp_buffers = '512MB'",
            "SET max_parallel_workers_per_gather = 2",
            "SET max_parallel_maintenance_workers = 2",
            "SET synchronous_commit = off",
        ]

        cursor = mock_db.cursor.return_value.__enter__.return_value
        execute_calls = [call[0][0] for call in cursor.execute.call_args_list]

        # Check SET statements were called
        for stmt in set_statements:
            assert any(stmt in call for call in execute_calls), f"SET statement not found: {stmt}"

        # Verify BEGIN, INSERT, COMMIT called once
        assert execute_calls.count("BEGIN") == 1
        assert execute_calls.count("COMMIT") == 1

        # Verify INSERT query format
        insert_calls = [call for call in execute_calls if "INSERT INTO" in call]
        assert len(insert_calls) == 1
        insert_query = insert_calls[0]
        assert "INSERT INTO posthog_persons_new" in insert_query
        assert "SELECT s.*" in insert_query
        assert "FROM posthog_persons s" in insert_query
        assert "WHERE s.id >" in insert_query
        assert "AND s.id <=" in insert_query
        assert "NOT EXISTS" in insert_query
        assert "ORDER BY s.id DESC" in insert_query

    def test_copy_chunk_multiple_batches(self):
        """Test copy with multiple batches in a chunk."""
        config = PersonsNewBackfillConfig(
            chunk_size=1000, batch_size=100, source_table="posthog_persons", destination_table="posthog_persons_new"
        )
        chunk = (1, 250)  # 3 batches: (1,100), (100,200), (200,250)

        mock_db = create_mock_database_resource()
        mock_cluster = create_mock_cluster_resource()

        # Track rowcount per batch - use a list to track INSERT calls
        rowcounts = [50, 75, 25]
        insert_call_count = [0]

        cursor = mock_db.cursor.return_value.__enter__.return_value

        # Track INSERT calls and set rowcount accordingly
        def execute_with_rowcount(query, *args):
            if "INSERT INTO" in query:
                if insert_call_count[0] < len(rowcounts):
                    cursor.rowcount = rowcounts[insert_call_count[0]]
                    insert_call_count[0] += 1
                else:
                    cursor.rowcount = 0

        cursor.execute.side_effect = execute_with_rowcount

        context = build_op_context(
            resources={"database": mock_db, "cluster": mock_cluster},
        )
        # Patch context.run.job_name where it's accessed in copy_chunk
        from unittest.mock import PropertyMock

        with patch.object(type(context), "run", PropertyMock(return_value=MagicMock(job_name="test_job"))):
            result = copy_chunk(context, config, chunk)

        # Verify result
        assert result["chunk_min"] == 1
        assert result["chunk_max"] == 250
        assert result["records_copied"] == 150  # 50 + 75 + 25

        # Verify SET statements called once (before loop)
        cursor = mock_db.cursor.return_value.__enter__.return_value
        execute_calls = [call[0][0] for call in cursor.execute.call_args_list]

        # Verify BEGIN/COMMIT called 3 times (one per batch)
        assert execute_calls.count("BEGIN") == 3
        assert execute_calls.count("COMMIT") == 3

        # Verify INSERT called 3 times
        insert_calls = [call for call in execute_calls if "INSERT INTO" in call]
        assert len(insert_calls) == 3

    def test_copy_chunk_duplicate_key_violation_retry(self):
        """Test that duplicate key violation triggers retry."""
        config = PersonsNewBackfillConfig(
            chunk_size=1000, batch_size=100, source_table="posthog_persons", destination_table="posthog_persons_new"
        )
        chunk = (1, 100)

        mock_db = create_mock_database_resource()
        mock_cluster = create_mock_cluster_resource()

        cursor = mock_db.cursor.return_value.__enter__.return_value

        # Track INSERT attempts
        insert_attempts = [0]

        # First INSERT raises UniqueViolation, second succeeds
        def execute_side_effect(query, *args):
            if "INSERT INTO" in query:
                insert_attempts[0] += 1
                if insert_attempts[0] == 1:
                    # First INSERT attempt raises error
                    # Use real UniqueViolation - pgcode is readonly but isinstance check will pass
                    raise psycopg2.errors.UniqueViolation("duplicate key value violates unique constraint")
                # Subsequent calls succeed
                cursor.rowcount = 50  # Success on retry

        cursor.execute.side_effect = execute_side_effect

        context = build_op_context(
            resources={"database": mock_db, "cluster": mock_cluster},
        )
        # Need to patch time.sleep and run.job_name
        from unittest.mock import PropertyMock

        mock_run = MagicMock(job_name="test_job")
        with (
            patch("dags.persons_new_backfill.time.sleep"),
            patch.object(type(context), "run", PropertyMock(return_value=mock_run)),
        ):
            copy_chunk(context, config, chunk)

        # Verify ROLLBACK was called on error
        execute_calls = [call[0][0] for call in cursor.execute.call_args_list]
        assert "ROLLBACK" in execute_calls

        # Verify retry succeeded (should have INSERT called twice, COMMIT once)
        insert_calls = [call for call in execute_calls if "INSERT INTO" in call]
        assert len(insert_calls) >= 1  # At least one successful INSERT

    def test_copy_chunk_error_handling_and_rollback(self):
        """Test error handling and rollback on non-duplicate errors."""
        config = PersonsNewBackfillConfig(
            chunk_size=1000, batch_size=100, source_table="posthog_persons", destination_table="posthog_persons_new"
        )
        chunk = (1, 100)

        mock_db = create_mock_database_resource()
        mock_cluster = create_mock_cluster_resource()

        cursor = mock_db.cursor.return_value.__enter__.return_value

        # Raise generic error on INSERT
        def execute_side_effect(query, *args):
            if "INSERT INTO" in query:
                raise Exception("Connection lost")

        cursor.execute.side_effect = execute_side_effect

        context = build_op_context(
            resources={"database": mock_db, "cluster": mock_cluster},
        )
        # Patch context.run.job_name where it's accessed in copy_chunk
        from unittest.mock import PropertyMock

        mock_run = MagicMock(job_name="test_job")
        with patch.object(type(context), "run", PropertyMock(return_value=mock_run)):
            # Should raise Dagster.Failure
            from dagster import Failure

            try:
                copy_chunk(context, config, chunk)
                raise AssertionError("Expected Dagster.Failure to be raised")
            except Failure as e:
                # Verify error metadata
                assert e.description is not None
                assert "Failed to copy batch" in e.description

        # Verify ROLLBACK was called
        execute_calls = [call[0][0] for call in cursor.execute.call_args_list]
        assert "ROLLBACK" in execute_calls

    def test_copy_chunk_insert_query_format(self):
        """Test that INSERT query has correct format."""
        config = PersonsNewBackfillConfig(
            chunk_size=1000, batch_size=100, source_table="test_source", destination_table="test_dest"
        )
        chunk = (1, 100)

        mock_db = create_mock_database_resource(rowcount_values=10)
        mock_cluster = create_mock_cluster_resource()

        context = build_op_context(
            resources={"database": mock_db, "cluster": mock_cluster},
        )
        # Patch context.run.job_name where it's accessed in copy_chunk
        from unittest.mock import PropertyMock

        with patch.object(type(context), "run", PropertyMock(return_value=MagicMock(job_name="test_job"))):
            copy_chunk(context, config, chunk)

        cursor = mock_db.cursor.return_value.__enter__.return_value
        execute_calls = [call[0][0] for call in cursor.execute.call_args_list]

        # Find INSERT query
        insert_query = next((call for call in execute_calls if "INSERT INTO" in call), None)
        assert insert_query is not None

        # Verify query components
        assert "INSERT INTO test_dest" in insert_query
        assert "SELECT s.*" in insert_query
        assert "FROM test_source s" in insert_query
        assert "WHERE s.id >" in insert_query
        assert "AND s.id <=" in insert_query
        assert "NOT EXISTS" in insert_query
        assert "d.team_id = s.team_id" in insert_query
        assert "d.id = s.id" in insert_query
        assert "ORDER BY s.id DESC" in insert_query

    def test_copy_chunk_session_settings_applied_once(self):
        """Test that SET statements are applied once at session level before batch loop."""
        config = PersonsNewBackfillConfig(
            chunk_size=1000, batch_size=50, source_table="posthog_persons", destination_table="posthog_persons_new"
        )
        chunk = (1, 150)  # 3 batches

        mock_db = create_mock_database_resource(rowcount_values=25)
        mock_cluster = create_mock_cluster_resource()

        context = build_op_context(
            resources={"database": mock_db, "cluster": mock_cluster},
        )
        # Patch context.run.job_name where it's accessed in copy_chunk
        from unittest.mock import PropertyMock

        with patch.object(type(context), "run", PropertyMock(return_value=MagicMock(job_name="test_job"))):
            copy_chunk(context, config, chunk)

        cursor = mock_db.cursor.return_value.__enter__.return_value
        execute_calls = [call[0][0] for call in cursor.execute.call_args_list]

        # Count SET statements (should be called once each, before loop)
        set_statements = [
            "SET application_name",
            "SET lock_timeout",
            "SET statement_timeout",
            "SET maintenance_work_mem",
            "SET work_mem",
            "SET temp_buffers",
            "SET max_parallel_workers_per_gather",
            "SET max_parallel_maintenance_workers",
            "SET synchronous_commit",
        ]

        for stmt in set_statements:
            count = sum(1 for call in execute_calls if stmt in call)
            assert count == 1, f"Expected {stmt} to be called once, but it was called {count} times"

        # Verify SET statements come before BEGIN statements
        set_indices = [i for i, call in enumerate(execute_calls) if any(stmt in call for stmt in set_statements)]
        begin_indices = [i for i, call in enumerate(execute_calls) if call == "BEGIN"]

        if set_indices and begin_indices:
            assert max(set_indices) < min(begin_indices), "SET statements should come before BEGIN statements"