From a2a7f2a7a1d00ac45a87791c718e0b5196ae9929 Mon Sep 17 00:00:00 2001 From: Eli Reisman Date: Fri, 14 Nov 2025 19:38:28 -0800 Subject: [PATCH] feat(dagster): persons backfill job (#41565) Co-authored-by: Lucas Ricoy <2034367+lricoy@users.noreply.github.com> --- .dagster_home/workspace.yaml | 1 + dags/common/common.py | 25 ++ dags/locations/__init__.py | 18 +- dags/locations/ingestion.py | 15 + dags/persons_new_backfill.py | 397 +++++++++++++++++++ dags/slack_alerts.py | 1 + dags/tests/test_persons_new_backfill.py | 485 ++++++++++++++++++++++++ 7 files changed, 941 insertions(+), 1 deletion(-) create mode 100644 dags/locations/ingestion.py create mode 100644 dags/persons_new_backfill.py create mode 100644 dags/tests/test_persons_new_backfill.py diff --git a/.dagster_home/workspace.yaml b/.dagster_home/workspace.yaml index b83469816e..6895d20b21 100644 --- a/.dagster_home/workspace.yaml +++ b/.dagster_home/workspace.yaml @@ -19,3 +19,4 @@ load_from: - python_module: dags.locations.llma - python_module: dags.locations.max_ai - python_module: dags.locations.web_analytics + - python_module: dags.locations.ingestion diff --git a/dags/common/common.py b/dags/common/common.py index 31deec5f25..300452f7a0 100644 --- a/dags/common/common.py +++ b/dags/common/common.py @@ -6,6 +6,8 @@ from typing import Optional from django.conf import settings import dagster +import psycopg2 +import psycopg2.extras from clickhouse_driver.errors import Error, ErrorCodes from posthog.clickhouse import query_tagging @@ -22,6 +24,7 @@ class JobOwners(str, Enum): TEAM_ERROR_TRACKING = "team-error-tracking" TEAM_EXPERIMENTS = "team-experiments" TEAM_GROWTH = "team-growth" + TEAM_INGESTION = "team-ingestion" TEAM_LLMA = "team-llma" TEAM_MAX_AI = "team-max-ai" TEAM_REVENUE_ANALYTICS = "team-revenue-analytics" @@ -80,6 +83,28 @@ class RedisResource(dagster.ConfigurableResource): return client +class PostgresResource(dagster.ConfigurableResource): + """ + A Postgres database connection resource that returns a psycopg2 connection. 
+ """ + + host: str + port: str = "5432" + database: str + user: str + password: str + + def create_resource(self, context: dagster.InitResourceContext) -> psycopg2.extensions.connection: + return psycopg2.connect( + host=self.host, + port=int(self.port), + database=self.database, + user=self.user, + password=self.password, + cursor_factory=psycopg2.extras.RealDictCursor, + ) + + def report_job_status_metric( context: dagster.RunStatusSensorContext, cluster: dagster.ResourceParam[ClickhouseCluster] ) -> None: diff --git a/dags/locations/__init__.py b/dags/locations/__init__.py index 965967fd54..a4fb32a7cf 100644 --- a/dags/locations/__init__.py +++ b/dags/locations/__init__.py @@ -5,7 +5,7 @@ import dagster_slack from dagster_aws.s3.io_manager import s3_pickle_io_manager from dagster_aws.s3.resources import S3Resource -from dags.common import ClickhouseClusterResource, RedisResource +from dags.common import ClickhouseClusterResource, PostgresResource, RedisResource # Define resources for different environments resources_by_env = { @@ -18,6 +18,14 @@ resources_by_env = { "s3": S3Resource(), # Using EnvVar instead of the Django setting to ensure that the token is not leaked anywhere in the Dagster UI "slack": dagster_slack.SlackResource(token=dagster.EnvVar("SLACK_TOKEN")), + # Postgres resource (universal for all dags) + "database": PostgresResource( + host=dagster.EnvVar("POSTGRES_HOST"), + port=dagster.EnvVar("POSTGRES_PORT"), + database=dagster.EnvVar("POSTGRES_DATABASE"), + user=dagster.EnvVar("POSTGRES_USER"), + password=dagster.EnvVar("POSTGRES_PASSWORD"), + ), }, "local": { "cluster": ClickhouseClusterResource.configure_at_launch(), @@ -29,6 +37,14 @@ resources_by_env = { aws_secret_access_key=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, ), "slack": dagster.ResourceDefinition.none_resource(description="Dummy Slack resource for local development"), + # Postgres resource (universal for all dags) - use Django settings or env vars for local dev + "database": PostgresResource( + host=dagster.EnvVar("POSTGRES_HOST"), + port=dagster.EnvVar("POSTGRES_PORT"), + database=dagster.EnvVar("POSTGRES_DATABASE"), + user=dagster.EnvVar("POSTGRES_USER"), + password=dagster.EnvVar("POSTGRES_PASSWORD"), + ), }, } diff --git a/dags/locations/ingestion.py b/dags/locations/ingestion.py new file mode 100644 index 0000000000..f314d4da8d --- /dev/null +++ b/dags/locations/ingestion.py @@ -0,0 +1,15 @@ +import dagster + +from dags import persons_new_backfill + +from . 
import resources + +defs = dagster.Definitions( + assets=[ + persons_new_backfill.postgres_env_check, + ], + jobs=[ + persons_new_backfill.persons_new_backfill_job, + ], + resources=resources, +) diff --git a/dags/persons_new_backfill.py b/dags/persons_new_backfill.py new file mode 100644 index 0000000000..7469b617d0 --- /dev/null +++ b/dags/persons_new_backfill.py @@ -0,0 +1,397 @@ +"""Dagster job for backfilling posthog_persons data from source to destination Postgres database.""" + +import os +import time +from typing import Any + +import dagster +import psycopg2 +import psycopg2.errors + +from posthog.clickhouse.cluster import ClickhouseCluster +from posthog.clickhouse.custom_metrics import MetricsClient + +from dags.common import JobOwners + + +class PersonsNewBackfillConfig(dagster.Config): + """Configuration for the persons new backfill job.""" + + chunk_size: int = 1_000_000 # ID range per chunk + batch_size: int = 100_000 # Records per batch insert + source_table: str = "posthog_persons" + destination_table: str = "posthog_persons_new" + max_id: int | None = None # Optional override for max ID to resume from partial state + + +@dagster.op +def get_id_range( + context: dagster.OpExecutionContext, + config: PersonsNewBackfillConfig, + database: dagster.ResourceParam[psycopg2.extensions.connection], +) -> tuple[int, int]: + """ + Query source database for MIN(id) and optionally MAX(id) from posthog_persons table. + If max_id is provided in config, uses that instead of querying. + Returns tuple (min_id, max_id). + """ + with database.cursor() as cursor: + # Always query for min_id + min_query = f"SELECT MIN(id) as min_id FROM {config.source_table}" + context.log.info(f"Querying min ID: {min_query}") + cursor.execute(min_query) + min_result = cursor.fetchone() + + if min_result is None or min_result["min_id"] is None: + context.log.exception("Source table is empty or has no valid IDs") + # Note: No metrics client here as this is get_id_range op, not copy_chunk + raise dagster.Failure("Source table is empty or has no valid IDs") + + min_id = int(min_result["min_id"]) + + # Use config max_id if provided, otherwise query database + if config.max_id is not None: + max_id = config.max_id + context.log.info(f"Using configured max_id override: {max_id}") + else: + max_query = f"SELECT MAX(id) as max_id FROM {config.source_table}" + context.log.info(f"Querying max ID: {max_query}") + cursor.execute(max_query) + max_result = cursor.fetchone() + + if max_result is None or max_result["max_id"] is None: + context.log.exception("Source table has no valid max ID") + # Note: No metrics client here as this is get_id_range op, not copy_chunk + raise dagster.Failure("Source table has no valid max ID") + + max_id = int(max_result["max_id"]) + + # Validate that max_id >= min_id + if max_id < min_id: + error_msg = f"Invalid ID range: max_id ({max_id}) < min_id ({min_id})" + context.log.error(error_msg) + # Note: No metrics client here as this is get_id_range op, not copy_chunk + raise dagster.Failure(error_msg) + + context.log.info(f"ID range: min={min_id}, max={max_id}, total_ids={max_id - min_id + 1}") + context.add_output_metadata( + { + "min_id": dagster.MetadataValue.int(min_id), + "max_id": dagster.MetadataValue.int(max_id), + "max_id_source": dagster.MetadataValue.text("config" if config.max_id is not None else "database"), + "total_ids": dagster.MetadataValue.int(max_id - min_id + 1), + } + ) + + return (min_id, max_id) + + +@dagster.op(out=dagster.DynamicOut(tuple[int, int])) +def create_chunks( + 
context: dagster.OpExecutionContext, + config: PersonsNewBackfillConfig, + id_range: tuple[int, int], +): + """ + Divide ID space into chunks of chunk_size. + Yields DynamicOutput for each chunk in reverse order (highest IDs first, lowest IDs last). + This ensures that if the job fails partway through, the final chunk to process will be + the one starting at the source table's min_id. + """ + min_id, max_id = id_range + chunk_size = config.chunk_size + + # First, collect all chunks + chunks = [] + chunk_min = min_id + chunk_num = 0 + + while chunk_min <= max_id: + chunk_max = min(chunk_min + chunk_size - 1, max_id) + chunks.append((chunk_min, chunk_max, chunk_num)) + chunk_min = chunk_max + 1 + chunk_num += 1 + + context.log.info(f"Created {chunk_num} chunks total") + + # Yield chunks in reverse order (highest IDs first) + for chunk_min, chunk_max, chunk_num in reversed(chunks): + chunk_key = f"chunk_{chunk_min}_{chunk_max}" + context.log.info(f"Yielding chunk {chunk_num}: {chunk_min} to {chunk_max}") + yield dagster.DynamicOutput( + value=(chunk_min, chunk_max), + mapping_key=chunk_key, + ) + + +@dagster.op +def copy_chunk( + context: dagster.OpExecutionContext, + config: PersonsNewBackfillConfig, + chunk: tuple[int, int], + database: dagster.ResourceParam[psycopg2.extensions.connection], + cluster: dagster.ResourceParam[ClickhouseCluster], +) -> dict[str, Any]: + """ + Copy a chunk of records from source to destination database. + Processes in batches of batch_size records. + """ + chunk_min, chunk_max = chunk + batch_size = config.batch_size + source_table = config.source_table + destination_table = config.destination_table + chunk_id = f"chunk_{chunk_min}_{chunk_max}" + job_name = context.run.job_name + + # Initialize metrics client + metrics_client = MetricsClient(cluster) + + context.log.info(f"Starting chunk copy: {chunk_min} to {chunk_max}") + + total_records_copied = 0 + batch_start_id = chunk_min + failed_batch_start_id: int | None = None + + try: + with database.cursor() as cursor: + # Set session-level settings once for the entire chunk + cursor.execute("SET application_name = 'backfill_posthog_persons_to_posthog_persons_new'") + cursor.execute("SET lock_timeout = '5s'") + cursor.execute("SET statement_timeout = '30min'") + cursor.execute("SET maintenance_work_mem = '12GB'") + cursor.execute("SET work_mem = '512MB'") + cursor.execute("SET temp_buffers = '512MB'") + cursor.execute("SET max_parallel_workers_per_gather = 2") + cursor.execute("SET max_parallel_maintenance_workers = 2") + cursor.execute("SET synchronous_commit = off") + + retry_attempt = 0 + while batch_start_id <= chunk_max: + try: + # Track batch start time for duration metric + batch_start_time = time.time() + + # Calculate batch end ID + batch_end_id = min(batch_start_id + batch_size, chunk_max) + + # Track records attempted - this is also our exit condition + records_attempted = batch_end_id - batch_start_id + if records_attempted <= 0: + break + # Begin transaction (settings already applied at session level) + cursor.execute("BEGIN") + + # Execute INSERT INTO ... 
SELECT with NOT EXISTS check + insert_query = f""" +INSERT INTO {destination_table} +SELECT s.* +FROM {source_table} s +WHERE s.id >= %s AND s.id <= %s + AND NOT EXISTS ( + SELECT 1 + FROM {destination_table} d + WHERE d.team_id = s.team_id + AND d.id = s.id + ) +ORDER BY s.id DESC +""" + cursor.execute(insert_query, (batch_start_id, batch_end_id)) + records_inserted = cursor.rowcount + + # Commit the transaction + cursor.execute("COMMIT") + + try: + metrics_client.increment( + "persons_new_backfill_records_attempted_total", + labels={"job_name": job_name, "chunk_id": chunk_id}, + value=float(records_attempted), + ).result() + except Exception: + pass # Don't fail on metrics error + + batch_duration_seconds = time.time() - batch_start_time + + try: + metrics_client.increment( + "persons_new_backfill_records_inserted_total", + labels={"job_name": job_name, "chunk_id": chunk_id}, + value=float(records_inserted), + ).result() + except Exception: + pass # Don't fail on metrics error + + try: + metrics_client.increment( + "persons_new_backfill_batches_copied_total", + labels={"job_name": job_name, "chunk_id": chunk_id}, + value=1.0, + ).result() + except Exception: + pass + # Track batch duration metric (IV) + try: + metrics_client.increment( + "persons_new_backfill_batch_duration_seconds_total", + labels={"job_name": job_name, "chunk_id": chunk_id}, + value=batch_duration_seconds, + ).result() + except Exception: + pass + + total_records_copied += records_inserted + + context.log.info( + f"Copied batch: {records_inserted} records " + f"(chunk {chunk_min}-{chunk_max}, batch ID range {batch_start_id} to {batch_end_id})" + ) + + # Update batch_start_id for next iteration + batch_start_id = batch_end_id + 1 + retry_attempt = 0 + + except Exception as batch_error: + # Rollback transaction on error + try: + cursor.execute("ROLLBACK") + except Exception as rollback_error: + context.log.exception( + f"Failed to rollback transaction for batch starting at ID {batch_start_id}" + f"in chunk {chunk_min}-{chunk_max}: {str(rollback_error)}" + ) + pass # Ignore rollback errors + + # Check if error is a duplicate key violation, pause and retry if so + is_unique_violation = isinstance(batch_error, psycopg2.errors.UniqueViolation) or ( + isinstance(batch_error, psycopg2.Error) and getattr(batch_error, "pgcode", None) == "23505" + ) + if is_unique_violation: + error_msg = ( + f"Duplicate key violation detected for batch starting at ID {batch_start_id} " + f"in chunk {chunk_min}-{chunk_max}. Error is: {batch_error}. " + "This is expected if records already exist in destination table. 
" + ) + context.log.warning(error_msg) + if retry_attempt < 3: + retry_attempt += 1 + context.log.info(f"Retrying batch {retry_attempt} of 3...") + time.sleep(1) + continue + + failed_batch_start_id = batch_start_id + error_msg = ( + f"Failed to copy batch starting at ID {batch_start_id} " + f"in chunk {chunk_min}-{chunk_max}: {str(batch_error)}" + ) + context.log.exception(error_msg) + # Report fatal error metric before raising + try: + metrics_client.increment( + "persons_new_backfill_error", + labels={"job_name": job_name, "chunk_id": chunk_id, "reason": "batch_copy_failed"}, + value=1.0, + ).result() + except Exception: + pass # Don't fail on metrics error + + raise dagster.Failure( + description=error_msg, + metadata={ + "chunk_min_id": dagster.MetadataValue.int(chunk_min), + "chunk_max_id": dagster.MetadataValue.int(chunk_max), + "failed_batch_start_id": dagster.MetadataValue.int(failed_batch_start_id) + if failed_batch_start_id + else dagster.MetadataValue.text("N/A"), + "error_message": dagster.MetadataValue.text(str(batch_error)), + "records_copied_before_failure": dagster.MetadataValue.int(total_records_copied), + }, + ) from batch_error + + except dagster.Failure: + # Re-raise Dagster failures as-is (they already have metadata and metrics) + raise + except Exception as e: + # Catch any other unexpected errors + error_msg = f"Unexpected error copying chunk {chunk_min}-{chunk_max}: {str(e)}" + context.log.exception(error_msg) + # Report fatal error metric before raising + try: + metrics_client.increment( + "persons_new_backfill_error", + labels={"job_name": job_name, "chunk_id": chunk_id, "reason": "unexpected_copy_error"}, + value=1.0, + ).result() + except Exception: + pass # Don't fail on metrics error + raise dagster.Failure( + description=error_msg, + metadata={ + "chunk_min_id": dagster.MetadataValue.int(chunk_min), + "chunk_max_id": dagster.MetadataValue.int(chunk_max), + "failed_batch_start_id": dagster.MetadataValue.int(failed_batch_start_id) + if failed_batch_start_id + else dagster.MetadataValue.int(batch_start_id), + "error_message": dagster.MetadataValue.text(str(e)), + "records_copied_before_failure": dagster.MetadataValue.int(total_records_copied), + }, + ) from e + + context.log.info(f"Completed chunk {chunk_min}-{chunk_max}: copied {total_records_copied} records") + + # Emit metric for chunk completion + run_id = context.run.run_id + try: + metrics_client.increment( + "persons_new_backfill_chunks_completed_total", + labels={"job_name": job_name, "run_id": run_id, "chunk_id": chunk_id}, + value=1.0, + ).result() + except Exception: + pass # Don't fail on metrics error + + context.add_output_metadata( + { + "chunk_min": dagster.MetadataValue.int(chunk_min), + "chunk_max": dagster.MetadataValue.int(chunk_max), + "records_copied": dagster.MetadataValue.int(total_records_copied), + } + ) + + return { + "chunk_min": chunk_min, + "chunk_max": chunk_max, + "records_copied": total_records_copied, + } + + +@dagster.asset +def postgres_env_check(context: dagster.AssetExecutionContext) -> None: + """ + Simple asset that prints PostgreSQL environment variables being used. + Useful for debugging connection configuration. 
+ """ + env_vars = { + "POSTGRES_HOST": os.getenv("POSTGRES_HOST", "not set"), + "POSTGRES_PORT": os.getenv("POSTGRES_PORT", "not set"), + "POSTGRES_DATABASE": os.getenv("POSTGRES_DATABASE", "not set"), + "POSTGRES_USER": os.getenv("POSTGRES_USER", "not set"), + "POSTGRES_PASSWORD": "***" if os.getenv("POSTGRES_PASSWORD") else "not set", + } + + context.log.info("PostgreSQL environment variables:") + for key, value in env_vars.items(): + context.log.info(f" {key}: {value}") + + +@dagster.job( + tags={"owner": JobOwners.TEAM_INGESTION.value}, + executor_def=dagster.multiprocess_executor.configured({"max_concurrent": 32}), +) +def persons_new_backfill_job(): + """ + Backfill posthog_persons data from source to destination Postgres database. + Divides the ID space into chunks and processes them in parallel. + """ + id_range = get_id_range() + chunks = create_chunks(id_range) + chunks.map(copy_chunk) diff --git a/dags/slack_alerts.py b/dags/slack_alerts.py index d18ce893cc..7d6371b9b7 100644 --- a/dags/slack_alerts.py +++ b/dags/slack_alerts.py @@ -15,6 +15,7 @@ notification_channel_per_team = { JobOwners.TEAM_ERROR_TRACKING.value: "#alerts-error-tracking", JobOwners.TEAM_EXPERIMENTS.value: "#alerts-experiments-dagster", JobOwners.TEAM_GROWTH.value: "#alerts-growth", + JobOwners.TEAM_INGESTION.value: "#alerts-ingestion", JobOwners.TEAM_MAX_AI.value: "#alerts-max-ai", JobOwners.TEAM_REVENUE_ANALYTICS.value: "#alerts-revenue-analytics", JobOwners.TEAM_WEB_ANALYTICS.value: "#alerts-web-analytics", diff --git a/dags/tests/test_persons_new_backfill.py b/dags/tests/test_persons_new_backfill.py new file mode 100644 index 0000000000..68537c3bf8 --- /dev/null +++ b/dags/tests/test_persons_new_backfill.py @@ -0,0 +1,485 @@ +"""Tests for the persons new backfill job.""" + +from unittest.mock import MagicMock, patch + +import psycopg2.errors +from dagster import build_op_context + +from dags.persons_new_backfill import PersonsNewBackfillConfig, copy_chunk, create_chunks + + +class TestCreateChunks: + """Test the create_chunks function.""" + + def test_create_chunks_produces_non_overlapping_ranges(self): + """Test that chunks produce non-overlapping ranges.""" + config = PersonsNewBackfillConfig(chunk_size=1000) + id_range = (1, 5000) # min_id=1, max_id=5000 + + context = build_op_context() + chunks = list(create_chunks(context, config, id_range)) + + # Extract all chunk ranges from DynamicOutput objects + chunk_ranges = [chunk.value for chunk in chunks] + + # Verify no overlaps + for i, (min1, max1) in enumerate(chunk_ranges): + for j, (min2, max2) in enumerate(chunk_ranges): + if i != j: + # Chunks should not overlap + assert not ( + min1 <= min2 <= max1 or min1 <= max2 <= max1 or min2 <= min1 <= max2 + ), f"Chunks overlap: ({min1}, {max1}) and ({min2}, {max2})" + + def test_create_chunks_covers_entire_id_space(self): + """Test that chunks cover the entire ID space from min to max.""" + config = PersonsNewBackfillConfig(chunk_size=1000) + min_id, max_id = 1, 5000 + id_range = (min_id, max_id) + + context = build_op_context() + chunks = list(create_chunks(context, config, id_range)) + + # Extract all chunk ranges from DynamicOutput objects + chunk_ranges = [chunk.value for chunk in chunks] + + # Find the overall min and max covered + all_ids_covered: set[int] = set() + for chunk_min, chunk_max in chunk_ranges: + all_ids_covered.update(range(chunk_min, chunk_max + 1)) + + # Verify all IDs from min_id to max_id are covered + expected_ids = set(range(min_id, max_id + 1)) + assert all_ids_covered == 
expected_ids, ( + f"Missing IDs: {expected_ids - all_ids_covered}, " f"Extra IDs: {all_ids_covered - expected_ids}" + ) + + def test_create_chunks_first_chunk_includes_max_id(self): + """Test that the first chunk (in yielded order) includes the source table max_id.""" + config = PersonsNewBackfillConfig(chunk_size=1000) + min_id, max_id = 1, 5000 + id_range = (min_id, max_id) + + context = build_op_context() + chunks = list(create_chunks(context, config, id_range)) + + # First chunk in the list (yielded first, highest IDs) + first_chunk_min, first_chunk_max = chunks[0].value + + assert first_chunk_max == max_id, f"First chunk max ({first_chunk_max}) should equal source max_id ({max_id})" + assert ( + first_chunk_min <= max_id <= first_chunk_max + ), f"First chunk ({first_chunk_min}, {first_chunk_max}) should include max_id ({max_id})" + + def test_create_chunks_final_chunk_includes_min_id(self): + """Test that the final chunk (in yielded order) includes the source table min_id.""" + config = PersonsNewBackfillConfig(chunk_size=1000) + min_id, max_id = 1, 5000 + id_range = (min_id, max_id) + + context = build_op_context() + chunks = list(create_chunks(context, config, id_range)) + + # Last chunk in the list (yielded last, lowest IDs) + final_chunk_min, final_chunk_max = chunks[-1].value + + assert final_chunk_min == min_id, f"Final chunk min ({final_chunk_min}) should equal source min_id ({min_id})" + assert ( + final_chunk_min <= min_id <= final_chunk_max + ), f"Final chunk ({final_chunk_min}, {final_chunk_max}) should include min_id ({min_id})" + + def test_create_chunks_reverse_order(self): + """Test that chunks are yielded in reverse order (highest IDs first).""" + config = PersonsNewBackfillConfig(chunk_size=1000) + min_id, max_id = 1, 5000 + id_range = (min_id, max_id) + + context = build_op_context() + chunks = list(create_chunks(context, config, id_range)) + + # Verify chunks are in descending order by max_id + for i in range(len(chunks) - 1): + current_max = chunks[i].value[1] + next_max = chunks[i + 1].value[1] + assert ( + current_max > next_max + ), f"Chunks not in reverse order: chunk {i} max ({current_max}) should be > chunk {i+1} max ({next_max})" + + def test_create_chunks_exact_multiple(self): + """Test chunk creation when ID range is an exact multiple of chunk_size.""" + config = PersonsNewBackfillConfig(chunk_size=1000) + min_id, max_id = 1, 5000 # Exactly 5 chunks of 1000 + id_range = (min_id, max_id) + + context = build_op_context() + chunks = list(create_chunks(context, config, id_range)) + + assert len(chunks) == 5, f"Expected 5 chunks, got {len(chunks)}" + + # Verify first chunk (highest IDs) + assert chunks[0].value == (4001, 5000), f"First chunk should be (4001, 5000), got {chunks[0].value}" + + # Verify last chunk (lowest IDs) + assert chunks[-1].value == (1, 1000), f"Last chunk should be (1, 1000), got {chunks[-1].value}" + + def test_create_chunks_non_exact_multiple(self): + """Test chunk creation when ID range is not an exact multiple of chunk_size.""" + config = PersonsNewBackfillConfig(chunk_size=1000) + min_id, max_id = 1, 3750 # 3 full chunks + 1 partial chunk + id_range = (min_id, max_id) + + context = build_op_context() + chunks = list(create_chunks(context, config, id_range)) + + assert len(chunks) == 4, f"Expected 4 chunks, got {len(chunks)}" + + # Verify first chunk (highest IDs) - should be the partial chunk + assert chunks[0].value == (3001, 3750), f"First chunk should be (3001, 3750), got {chunks[0].value}" + + # Verify last chunk (lowest IDs) + 
assert chunks[-1].value == (1, 1000), f"Last chunk should be (1, 1000), got {chunks[-1].value}" + + def test_create_chunks_single_chunk(self): + """Test chunk creation when ID range fits in a single chunk.""" + config = PersonsNewBackfillConfig(chunk_size=1000) + min_id, max_id = 100, 500 + id_range = (min_id, max_id) + + context = build_op_context() + chunks = list(create_chunks(context, config, id_range)) + + assert len(chunks) == 1, f"Expected 1 chunk, got {len(chunks)}" + assert chunks[0].value == (100, 500), f"Chunk should be (100, 500), got {chunks[0].value}" + assert chunks[0].value[0] == min_id and chunks[0].value[1] == max_id + + +def create_mock_database_resource(rowcount_values=None): + """ + Create a mock database resource that mimics psycopg2.extensions.connection. + + Args: + rowcount_values: List of rowcount values to return per INSERT call. + If None, defaults to 0. If a single int, uses that for all calls. + """ + mock_cursor = MagicMock() + if rowcount_values is None: + mock_cursor.rowcount = 0 + elif isinstance(rowcount_values, int): + mock_cursor.rowcount = rowcount_values + else: + # Use side_effect to return different rowcounts per call + call_count = [0] + + def get_rowcount(): + if call_count[0] < len(rowcount_values): + result = rowcount_values[call_count[0]] + call_count[0] += 1 + return result + return rowcount_values[-1] if rowcount_values else 0 + + mock_cursor.rowcount = property(lambda self: get_rowcount()) + + mock_cursor.execute = MagicMock() + + # Make cursor() return a context manager + mock_conn = MagicMock() + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + + return mock_conn + + +def create_mock_cluster_resource(): + """Create a mock ClickhouseCluster resource.""" + return MagicMock() + + +class TestCopyChunk: + """Test the copy_chunk function.""" + + def test_copy_chunk_single_batch_success(self): + """Test successful copy of a single batch within a chunk.""" + config = PersonsNewBackfillConfig( + chunk_size=1000, batch_size=100, source_table="posthog_persons", destination_table="posthog_persons_new" + ) + chunk = (1, 100) # Single batch covers entire chunk + + mock_db = create_mock_database_resource(rowcount_values=50) + mock_cluster = create_mock_cluster_resource() + + context = build_op_context( + resources={"database": mock_db, "cluster": mock_cluster}, + ) + # Patch context.run.job_name where it's accessed in copy_chunk + from unittest.mock import PropertyMock + + with patch.object(type(context), "run", PropertyMock(return_value=MagicMock(job_name="test_job"))): + result = copy_chunk(context, config, chunk) + + # Verify result + assert result["chunk_min"] == 1 + assert result["chunk_max"] == 100 + assert result["records_copied"] == 50 + + # Verify SET statements called once (session-level, before loop) + set_statements = [ + "SET application_name = 'backfill_posthog_persons_to_posthog_persons_new'", + "SET lock_timeout = '5s'", + "SET statement_timeout = '30min'", + "SET maintenance_work_mem = '12GB'", + "SET work_mem = '512MB'", + "SET temp_buffers = '512MB'", + "SET max_parallel_workers_per_gather = 2", + "SET max_parallel_maintenance_workers = 2", + "SET synchronous_commit = off", + ] + + cursor = mock_db.cursor.return_value.__enter__.return_value + execute_calls = [call[0][0] for call in cursor.execute.call_args_list] + + # Check SET statements were called + for stmt in set_statements: + assert any(stmt in call for call in execute_calls), 
f"SET statement not found: {stmt}" + + # Verify BEGIN, INSERT, COMMIT called once + assert execute_calls.count("BEGIN") == 1 + assert execute_calls.count("COMMIT") == 1 + + # Verify INSERT query format + insert_calls = [call for call in execute_calls if "INSERT INTO" in call] + assert len(insert_calls) == 1 + insert_query = insert_calls[0] + assert "INSERT INTO posthog_persons_new" in insert_query + assert "SELECT s.*" in insert_query + assert "FROM posthog_persons s" in insert_query + assert "WHERE s.id >" in insert_query + assert "AND s.id <=" in insert_query + assert "NOT EXISTS" in insert_query + assert "ORDER BY s.id DESC" in insert_query + + def test_copy_chunk_multiple_batches(self): + """Test copy with multiple batches in a chunk.""" + config = PersonsNewBackfillConfig( + chunk_size=1000, batch_size=100, source_table="posthog_persons", destination_table="posthog_persons_new" + ) + chunk = (1, 250) # 3 batches: (1,100), (100,200), (200,250) + + mock_db = create_mock_database_resource() + mock_cluster = create_mock_cluster_resource() + + # Track rowcount per batch - use a list to track INSERT calls + rowcounts = [50, 75, 25] + insert_call_count = [0] + + cursor = mock_db.cursor.return_value.__enter__.return_value + + # Track INSERT calls and set rowcount accordingly + def execute_with_rowcount(query, *args): + if "INSERT INTO" in query: + if insert_call_count[0] < len(rowcounts): + cursor.rowcount = rowcounts[insert_call_count[0]] + insert_call_count[0] += 1 + else: + cursor.rowcount = 0 + + cursor.execute.side_effect = execute_with_rowcount + + context = build_op_context( + resources={"database": mock_db, "cluster": mock_cluster}, + ) + # Patch context.run.job_name where it's accessed in copy_chunk + from unittest.mock import PropertyMock + + with patch.object(type(context), "run", PropertyMock(return_value=MagicMock(job_name="test_job"))): + result = copy_chunk(context, config, chunk) + + # Verify result + assert result["chunk_min"] == 1 + assert result["chunk_max"] == 250 + assert result["records_copied"] == 150 # 50 + 75 + 25 + + # Verify SET statements called once (before loop) + cursor = mock_db.cursor.return_value.__enter__.return_value + execute_calls = [call[0][0] for call in cursor.execute.call_args_list] + + # Verify BEGIN/COMMIT called 3 times (one per batch) + assert execute_calls.count("BEGIN") == 3 + assert execute_calls.count("COMMIT") == 3 + + # Verify INSERT called 3 times + insert_calls = [call for call in execute_calls if "INSERT INTO" in call] + assert len(insert_calls) == 3 + + def test_copy_chunk_duplicate_key_violation_retry(self): + """Test that duplicate key violation triggers retry.""" + config = PersonsNewBackfillConfig( + chunk_size=1000, batch_size=100, source_table="posthog_persons", destination_table="posthog_persons_new" + ) + chunk = (1, 100) + + mock_db = create_mock_database_resource() + mock_cluster = create_mock_cluster_resource() + + cursor = mock_db.cursor.return_value.__enter__.return_value + + # Track INSERT attempts + insert_attempts = [0] + + # First INSERT raises UniqueViolation, second succeeds + def execute_side_effect(query, *args): + if "INSERT INTO" in query: + insert_attempts[0] += 1 + if insert_attempts[0] == 1: + # First INSERT attempt raises error + # Use real UniqueViolation - pgcode is readonly but isinstance check will pass + raise psycopg2.errors.UniqueViolation("duplicate key value violates unique constraint") + # Subsequent calls succeed + cursor.rowcount = 50 # Success on retry + + cursor.execute.side_effect = 
execute_side_effect + + context = build_op_context( + resources={"database": mock_db, "cluster": mock_cluster}, + ) + # Need to patch time.sleep and run.job_name + from unittest.mock import PropertyMock + + mock_run = MagicMock(job_name="test_job") + with ( + patch("dags.persons_new_backfill.time.sleep"), + patch.object(type(context), "run", PropertyMock(return_value=mock_run)), + ): + copy_chunk(context, config, chunk) + + # Verify ROLLBACK was called on error + execute_calls = [call[0][0] for call in cursor.execute.call_args_list] + assert "ROLLBACK" in execute_calls + + # Verify retry succeeded (should have INSERT called twice, COMMIT once) + insert_calls = [call for call in execute_calls if "INSERT INTO" in call] + assert len(insert_calls) >= 1 # At least one successful INSERT + + def test_copy_chunk_error_handling_and_rollback(self): + """Test error handling and rollback on non-duplicate errors.""" + config = PersonsNewBackfillConfig( + chunk_size=1000, batch_size=100, source_table="posthog_persons", destination_table="posthog_persons_new" + ) + chunk = (1, 100) + + mock_db = create_mock_database_resource() + mock_cluster = create_mock_cluster_resource() + + cursor = mock_db.cursor.return_value.__enter__.return_value + + # Raise generic error on INSERT + def execute_side_effect(query, *args): + if "INSERT INTO" in query: + raise Exception("Connection lost") + + cursor.execute.side_effect = execute_side_effect + + context = build_op_context( + resources={"database": mock_db, "cluster": mock_cluster}, + ) + # Patch context.run.job_name where it's accessed in copy_chunk + from unittest.mock import PropertyMock + + mock_run = MagicMock(job_name="test_job") + with patch.object(type(context), "run", PropertyMock(return_value=mock_run)): + # Should raise Dagster.Failure + from dagster import Failure + + try: + copy_chunk(context, config, chunk) + raise AssertionError("Expected Dagster.Failure to be raised") + except Failure as e: + # Verify error metadata + assert e.description is not None + assert "Failed to copy batch" in e.description + + # Verify ROLLBACK was called + execute_calls = [call[0][0] for call in cursor.execute.call_args_list] + assert "ROLLBACK" in execute_calls + + def test_copy_chunk_insert_query_format(self): + """Test that INSERT query has correct format.""" + config = PersonsNewBackfillConfig( + chunk_size=1000, batch_size=100, source_table="test_source", destination_table="test_dest" + ) + chunk = (1, 100) + + mock_db = create_mock_database_resource(rowcount_values=10) + mock_cluster = create_mock_cluster_resource() + + context = build_op_context( + resources={"database": mock_db, "cluster": mock_cluster}, + ) + # Patch context.run.job_name where it's accessed in copy_chunk + from unittest.mock import PropertyMock + + with patch.object(type(context), "run", PropertyMock(return_value=MagicMock(job_name="test_job"))): + copy_chunk(context, config, chunk) + + cursor = mock_db.cursor.return_value.__enter__.return_value + execute_calls = [call[0][0] for call in cursor.execute.call_args_list] + + # Find INSERT query + insert_query = next((call for call in execute_calls if "INSERT INTO" in call), None) + assert insert_query is not None + + # Verify query components + assert "INSERT INTO test_dest" in insert_query + assert "SELECT s.*" in insert_query + assert "FROM test_source s" in insert_query + assert "WHERE s.id >" in insert_query + assert "AND s.id <=" in insert_query + assert "NOT EXISTS" in insert_query + assert "d.team_id = s.team_id" in insert_query + assert "d.id = 
s.id" in insert_query + assert "ORDER BY s.id DESC" in insert_query + + def test_copy_chunk_session_settings_applied_once(self): + """Test that SET statements are applied once at session level before batch loop.""" + config = PersonsNewBackfillConfig( + chunk_size=1000, batch_size=50, source_table="posthog_persons", destination_table="posthog_persons_new" + ) + chunk = (1, 150) # 3 batches + + mock_db = create_mock_database_resource(rowcount_values=25) + mock_cluster = create_mock_cluster_resource() + + context = build_op_context( + resources={"database": mock_db, "cluster": mock_cluster}, + ) + # Patch context.run.job_name where it's accessed in copy_chunk + from unittest.mock import PropertyMock + + with patch.object(type(context), "run", PropertyMock(return_value=MagicMock(job_name="test_job"))): + copy_chunk(context, config, chunk) + + cursor = mock_db.cursor.return_value.__enter__.return_value + execute_calls = [call[0][0] for call in cursor.execute.call_args_list] + + # Count SET statements (should be called once each, before loop) + set_statements = [ + "SET application_name", + "SET lock_timeout", + "SET statement_timeout", + "SET maintenance_work_mem", + "SET work_mem", + "SET temp_buffers", + "SET max_parallel_workers_per_gather", + "SET max_parallel_maintenance_workers", + "SET synchronous_commit", + ] + + for stmt in set_statements: + count = sum(1 for call in execute_calls if stmt in call) + assert count == 1, f"Expected {stmt} to be called once, but it was called {count} times" + + # Verify SET statements come before BEGIN statements + set_indices = [i for i, call in enumerate(execute_calls) if any(stmt in call for stmt in set_statements)] + begin_indices = [i for i, call in enumerate(execute_calls) if call == "BEGIN"] + + if set_indices and begin_indices: + assert max(set_indices) < min(begin_indices), "SET statements should come before BEGIN statements"