feat(dagster): persons backfill job (#41565)

Co-authored-by: Lucas Ricoy <2034367+lricoy@users.noreply.github.com>
Eli Reisman
2025-11-14 19:38:28 -08:00
committed by GitHub
parent f1e1a72fb9
commit a2a7f2a7a1
7 changed files with 941 additions and 1 deletion

View File

@@ -19,3 +19,4 @@ load_from:
- python_module: dags.locations.llma
- python_module: dags.locations.max_ai
- python_module: dags.locations.web_analytics
- python_module: dags.locations.ingestion

View File

@@ -6,6 +6,8 @@ from typing import Optional
from django.conf import settings
import dagster
import psycopg2
import psycopg2.extras
from clickhouse_driver.errors import Error, ErrorCodes
from posthog.clickhouse import query_tagging
@@ -22,6 +24,7 @@ class JobOwners(str, Enum):
TEAM_ERROR_TRACKING = "team-error-tracking"
TEAM_EXPERIMENTS = "team-experiments"
TEAM_GROWTH = "team-growth"
TEAM_INGESTION = "team-ingestion"
TEAM_LLMA = "team-llma"
TEAM_MAX_AI = "team-max-ai"
TEAM_REVENUE_ANALYTICS = "team-revenue-analytics"
@@ -80,6 +83,28 @@ class RedisResource(dagster.ConfigurableResource):
return client
class PostgresResource(dagster.ConfigurableResource):
"""
A Postgres database connection resource that returns a psycopg2 connection.
"""
host: str
port: str = "5432"
database: str
user: str
password: str
def create_resource(self, context: dagster.InitResourceContext) -> psycopg2.extensions.connection:
return psycopg2.connect(
host=self.host,
port=int(self.port),
database=self.database,
user=self.user,
password=self.password,
cursor_factory=psycopg2.extras.RealDictCursor,
)
def report_job_status_metric(
context: dagster.RunStatusSensorContext, cluster: dagster.ResourceParam[ClickhouseCluster]
) -> None:

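For a quick connectivity check outside of a Dagster run, the resource above can also be instantiated directly; a minimal sketch, assuming the same POSTGRES_* environment variables are exported (RealDictCursor means rows come back as dicts):

import os
import dagster
from dags.common import PostgresResource

# Sketch only: mirror the resource definition above with values taken from the environment.
resource = PostgresResource(
    host=os.environ["POSTGRES_HOST"],
    port=os.environ.get("POSTGRES_PORT", "5432"),
    database=os.environ["POSTGRES_DATABASE"],
    user=os.environ["POSTGRES_USER"],
    password=os.environ["POSTGRES_PASSWORD"],
)
conn = resource.create_resource(dagster.build_init_resource_context())
with conn.cursor() as cursor:
    cursor.execute("SELECT 1 AS ok")
    print(cursor.fetchone())  # {'ok': 1}
conn.close()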
View File

@@ -5,7 +5,7 @@ import dagster_slack
from dagster_aws.s3.io_manager import s3_pickle_io_manager
from dagster_aws.s3.resources import S3Resource
from dags.common import ClickhouseClusterResource, RedisResource
from dags.common import ClickhouseClusterResource, PostgresResource, RedisResource
# Define resources for different environments
resources_by_env = {
@@ -18,6 +18,14 @@ resources_by_env = {
"s3": S3Resource(),
# Using EnvVar instead of the Django setting to ensure that the token is not leaked anywhere in the Dagster UI
"slack": dagster_slack.SlackResource(token=dagster.EnvVar("SLACK_TOKEN")),
# Postgres resource (universal for all dags)
"database": PostgresResource(
host=dagster.EnvVar("POSTGRES_HOST"),
port=dagster.EnvVar("POSTGRES_PORT"),
database=dagster.EnvVar("POSTGRES_DATABASE"),
user=dagster.EnvVar("POSTGRES_USER"),
password=dagster.EnvVar("POSTGRES_PASSWORD"),
),
},
"local": {
"cluster": ClickhouseClusterResource.configure_at_launch(),
@@ -29,6 +37,14 @@ resources_by_env = {
aws_secret_access_key=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY,
),
"slack": dagster.ResourceDefinition.none_resource(description="Dummy Slack resource for local development"),
# Postgres resource (universal for all dags) - read from env vars for local dev (point them at the Django dev database)
"database": PostgresResource(
host=dagster.EnvVar("POSTGRES_HOST"),
port=dagster.EnvVar("POSTGRES_PORT"),
database=dagster.EnvVar("POSTGRES_DATABASE"),
user=dagster.EnvVar("POSTGRES_USER"),
password=dagster.EnvVar("POSTGRES_PASSWORD"),
),
},
}
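Both environments expect the POSTGRES_* variables to be present before the code location loads. For local development, a hedged sketch of seeding them from the Django database settings (this mapping is an assumption, not part of the change):

import os
from django.conf import settings

# Sketch only: copy Django's default database settings into the env vars that the
# local "database" resource reads, without overriding anything already exported.
db = settings.DATABASES["default"]
os.environ.setdefault("POSTGRES_HOST", db.get("HOST") or "localhost")
os.environ.setdefault("POSTGRES_PORT", str(db.get("PORT") or 5432))
os.environ.setdefault("POSTGRES_DATABASE", db.get("NAME") or "posthog")
os.environ.setdefault("POSTGRES_USER", db.get("USER") or "posthog")
os.environ.setdefault("POSTGRES_PASSWORD", db.get("PASSWORD") or "")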

View File

@@ -0,0 +1,15 @@
import dagster
from dags import persons_new_backfill
from . import resources
defs = dagster.Definitions(
assets=[
persons_new_backfill.postgres_env_check,
],
jobs=[
persons_new_backfill.persons_new_backfill_job,
],
resources=resources,
)

View File

@@ -0,0 +1,397 @@
"""Dagster job for backfilling posthog_persons data from source to destination Postgres database."""
import os
import time
from typing import Any
import dagster
import psycopg2
import psycopg2.errors
from posthog.clickhouse.cluster import ClickhouseCluster
from posthog.clickhouse.custom_metrics import MetricsClient
from dags.common import JobOwners
class PersonsNewBackfillConfig(dagster.Config):
"""Configuration for the persons new backfill job."""
chunk_size: int = 1_000_000 # ID range per chunk
batch_size: int = 100_000 # Records per batch insert
source_table: str = "posthog_persons"
destination_table: str = "posthog_persons_new"
max_id: int | None = None # Optional override for max ID to resume from partial state
@dagster.op
def get_id_range(
context: dagster.OpExecutionContext,
config: PersonsNewBackfillConfig,
database: dagster.ResourceParam[psycopg2.extensions.connection],
) -> tuple[int, int]:
"""
Query source database for MIN(id) and optionally MAX(id) from posthog_persons table.
If max_id is provided in config, uses that instead of querying.
Returns tuple (min_id, max_id).
"""
with database.cursor() as cursor:
# Always query for min_id
min_query = f"SELECT MIN(id) as min_id FROM {config.source_table}"
context.log.info(f"Querying min ID: {min_query}")
cursor.execute(min_query)
min_result = cursor.fetchone()
if min_result is None or min_result["min_id"] is None:
context.log.exception("Source table is empty or has no valid IDs")
# Note: No metrics client here as this is get_id_range op, not copy_chunk
raise dagster.Failure("Source table is empty or has no valid IDs")
min_id = int(min_result["min_id"])
# Use config max_id if provided, otherwise query database
if config.max_id is not None:
max_id = config.max_id
context.log.info(f"Using configured max_id override: {max_id}")
else:
max_query = f"SELECT MAX(id) as max_id FROM {config.source_table}"
context.log.info(f"Querying max ID: {max_query}")
cursor.execute(max_query)
max_result = cursor.fetchone()
if max_result is None or max_result["max_id"] is None:
context.log.exception("Source table has no valid max ID")
# Note: No metrics client here as this is get_id_range op, not copy_chunk
raise dagster.Failure("Source table has no valid max ID")
max_id = int(max_result["max_id"])
# Validate that max_id >= min_id
if max_id < min_id:
error_msg = f"Invalid ID range: max_id ({max_id}) < min_id ({min_id})"
context.log.error(error_msg)
# Note: No metrics client here as this is get_id_range op, not copy_chunk
raise dagster.Failure(error_msg)
context.log.info(f"ID range: min={min_id}, max={max_id}, total_ids={max_id - min_id + 1}")
context.add_output_metadata(
{
"min_id": dagster.MetadataValue.int(min_id),
"max_id": dagster.MetadataValue.int(max_id),
"max_id_source": dagster.MetadataValue.text("config" if config.max_id is not None else "database"),
"total_ids": dagster.MetadataValue.int(max_id - min_id + 1),
}
)
return (min_id, max_id)
@dagster.op(out=dagster.DynamicOut(tuple[int, int]))
def create_chunks(
context: dagster.OpExecutionContext,
config: PersonsNewBackfillConfig,
id_range: tuple[int, int],
):
"""
Divide ID space into chunks of chunk_size.
Yields DynamicOutput for each chunk in reverse order (highest IDs first, lowest IDs last).
This ensures that if the job fails partway through, the final chunk to process will be
the one starting at the source table's min_id.
"""
min_id, max_id = id_range
chunk_size = config.chunk_size
# First, collect all chunks
chunks = []
chunk_min = min_id
chunk_num = 0
while chunk_min <= max_id:
chunk_max = min(chunk_min + chunk_size - 1, max_id)
chunks.append((chunk_min, chunk_max, chunk_num))
chunk_min = chunk_max + 1
chunk_num += 1
context.log.info(f"Created {chunk_num} chunks total")
# Yield chunks in reverse order (highest IDs first)
for chunk_min, chunk_max, chunk_num in reversed(chunks):
chunk_key = f"chunk_{chunk_min}_{chunk_max}"
context.log.info(f"Yielding chunk {chunk_num}: {chunk_min} to {chunk_max}")
yield dagster.DynamicOutput(
value=(chunk_min, chunk_max),
mapping_key=chunk_key,
)
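# Worked example (illustrative): with chunk_size=1_000_000 and id_range=(1, 2_500_000),
# the op above yields three DynamicOutputs in this order:
#   chunk_2000001_2500000 -> (2_000_001, 2_500_000)
#   chunk_1000001_2000000 -> (1_000_001, 2_000_000)
#   chunk_1_1000000       -> (1, 1_000_000)
# so the chunk that starts at the source table's min_id is always processed last.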
@dagster.op
def copy_chunk(
context: dagster.OpExecutionContext,
config: PersonsNewBackfillConfig,
chunk: tuple[int, int],
database: dagster.ResourceParam[psycopg2.extensions.connection],
cluster: dagster.ResourceParam[ClickhouseCluster],
) -> dict[str, Any]:
"""
Copy a chunk of records from source to destination database.
Processes in batches of batch_size records.
"""
chunk_min, chunk_max = chunk
batch_size = config.batch_size
source_table = config.source_table
destination_table = config.destination_table
chunk_id = f"chunk_{chunk_min}_{chunk_max}"
job_name = context.run.job_name
# Initialize metrics client
metrics_client = MetricsClient(cluster)
context.log.info(f"Starting chunk copy: {chunk_min} to {chunk_max}")
total_records_copied = 0
batch_start_id = chunk_min
failed_batch_start_id: int | None = None
try:
with database.cursor() as cursor:
# Set session-level settings once for the entire chunk
cursor.execute("SET application_name = 'backfill_posthog_persons_to_posthog_persons_new'")
cursor.execute("SET lock_timeout = '5s'")
cursor.execute("SET statement_timeout = '30min'")
cursor.execute("SET maintenance_work_mem = '12GB'")
cursor.execute("SET work_mem = '512MB'")
cursor.execute("SET temp_buffers = '512MB'")
cursor.execute("SET max_parallel_workers_per_gather = 2")
cursor.execute("SET max_parallel_maintenance_workers = 2")
cursor.execute("SET synchronous_commit = off")
retry_attempt = 0
while batch_start_id <= chunk_max:
try:
# Track batch start time for duration metric
batch_start_time = time.time()
# Calculate batch end ID
batch_end_id = min(batch_start_id + batch_size, chunk_max)
# Track records attempted - this is also our exit condition
records_attempted = batch_end_id - batch_start_id
if records_attempted <= 0:
break
# Begin transaction (settings already applied at session level)
cursor.execute("BEGIN")
# Execute INSERT INTO ... SELECT with NOT EXISTS check
insert_query = f"""
INSERT INTO {destination_table}
SELECT s.*
FROM {source_table} s
WHERE s.id >= %s AND s.id <= %s
AND NOT EXISTS (
SELECT 1
FROM {destination_table} d
WHERE d.team_id = s.team_id
AND d.id = s.id
)
ORDER BY s.id DESC
"""
cursor.execute(insert_query, (batch_start_id, batch_end_id))
records_inserted = cursor.rowcount
# Commit the transaction
cursor.execute("COMMIT")
try:
metrics_client.increment(
"persons_new_backfill_records_attempted_total",
labels={"job_name": job_name, "chunk_id": chunk_id},
value=float(records_attempted),
).result()
except Exception:
pass # Don't fail on metrics error
batch_duration_seconds = time.time() - batch_start_time
try:
metrics_client.increment(
"persons_new_backfill_records_inserted_total",
labels={"job_name": job_name, "chunk_id": chunk_id},
value=float(records_inserted),
).result()
except Exception:
pass # Don't fail on metrics error
try:
metrics_client.increment(
"persons_new_backfill_batches_copied_total",
labels={"job_name": job_name, "chunk_id": chunk_id},
value=1.0,
).result()
except Exception:
pass
# Track batch duration metric
try:
metrics_client.increment(
"persons_new_backfill_batch_duration_seconds_total",
labels={"job_name": job_name, "chunk_id": chunk_id},
value=batch_duration_seconds,
).result()
except Exception:
pass
total_records_copied += records_inserted
context.log.info(
f"Copied batch: {records_inserted} records "
f"(chunk {chunk_min}-{chunk_max}, batch ID range {batch_start_id} to {batch_end_id})"
)
# Update batch_start_id for next iteration
batch_start_id = batch_end_id + 1
retry_attempt = 0
except Exception as batch_error:
# Rollback transaction on error
try:
cursor.execute("ROLLBACK")
except Exception as rollback_error:
context.log.exception(
f"Failed to rollback transaction for batch starting at ID {batch_start_id}"
f"in chunk {chunk_min}-{chunk_max}: {str(rollback_error)}"
)
pass # Ignore rollback errors
# Check if error is a duplicate key violation, pause and retry if so
is_unique_violation = isinstance(batch_error, psycopg2.errors.UniqueViolation) or (
isinstance(batch_error, psycopg2.Error) and getattr(batch_error, "pgcode", None) == "23505"
)
if is_unique_violation:
error_msg = (
f"Duplicate key violation detected for batch starting at ID {batch_start_id} "
f"in chunk {chunk_min}-{chunk_max}. Error is: {batch_error}. "
"This is expected if records already exist in destination table. "
)
context.log.warning(error_msg)
if retry_attempt < 3:
retry_attempt += 1
context.log.info(f"Retrying batch {retry_attempt} of 3...")
time.sleep(1)
continue
failed_batch_start_id = batch_start_id
error_msg = (
f"Failed to copy batch starting at ID {batch_start_id} "
f"in chunk {chunk_min}-{chunk_max}: {str(batch_error)}"
)
context.log.exception(error_msg)
# Report fatal error metric before raising
try:
metrics_client.increment(
"persons_new_backfill_error",
labels={"job_name": job_name, "chunk_id": chunk_id, "reason": "batch_copy_failed"},
value=1.0,
).result()
except Exception:
pass # Don't fail on metrics error
raise dagster.Failure(
description=error_msg,
metadata={
"chunk_min_id": dagster.MetadataValue.int(chunk_min),
"chunk_max_id": dagster.MetadataValue.int(chunk_max),
"failed_batch_start_id": dagster.MetadataValue.int(failed_batch_start_id)
if failed_batch_start_id
else dagster.MetadataValue.text("N/A"),
"error_message": dagster.MetadataValue.text(str(batch_error)),
"records_copied_before_failure": dagster.MetadataValue.int(total_records_copied),
},
) from batch_error
except dagster.Failure:
# Re-raise Dagster failures as-is (they already have metadata and metrics)
raise
except Exception as e:
# Catch any other unexpected errors
error_msg = f"Unexpected error copying chunk {chunk_min}-{chunk_max}: {str(e)}"
context.log.exception(error_msg)
# Report fatal error metric before raising
try:
metrics_client.increment(
"persons_new_backfill_error",
labels={"job_name": job_name, "chunk_id": chunk_id, "reason": "unexpected_copy_error"},
value=1.0,
).result()
except Exception:
pass # Don't fail on metrics error
raise dagster.Failure(
description=error_msg,
metadata={
"chunk_min_id": dagster.MetadataValue.int(chunk_min),
"chunk_max_id": dagster.MetadataValue.int(chunk_max),
"failed_batch_start_id": dagster.MetadataValue.int(failed_batch_start_id)
if failed_batch_start_id
else dagster.MetadataValue.int(batch_start_id),
"error_message": dagster.MetadataValue.text(str(e)),
"records_copied_before_failure": dagster.MetadataValue.int(total_records_copied),
},
) from e
context.log.info(f"Completed chunk {chunk_min}-{chunk_max}: copied {total_records_copied} records")
# Emit metric for chunk completion
run_id = context.run.run_id
try:
metrics_client.increment(
"persons_new_backfill_chunks_completed_total",
labels={"job_name": job_name, "run_id": run_id, "chunk_id": chunk_id},
value=1.0,
).result()
except Exception:
pass # Don't fail on metrics error
context.add_output_metadata(
{
"chunk_min": dagster.MetadataValue.int(chunk_min),
"chunk_max": dagster.MetadataValue.int(chunk_max),
"records_copied": dagster.MetadataValue.int(total_records_copied),
}
)
return {
"chunk_min": chunk_min,
"chunk_max": chunk_max,
"records_copied": total_records_copied,
}
@dagster.asset
def postgres_env_check(context: dagster.AssetExecutionContext) -> None:
"""
Simple asset that logs the PostgreSQL environment variables being used.
Useful for debugging connection configuration.
"""
env_vars = {
"POSTGRES_HOST": os.getenv("POSTGRES_HOST", "not set"),
"POSTGRES_PORT": os.getenv("POSTGRES_PORT", "not set"),
"POSTGRES_DATABASE": os.getenv("POSTGRES_DATABASE", "not set"),
"POSTGRES_USER": os.getenv("POSTGRES_USER", "not set"),
"POSTGRES_PASSWORD": "***" if os.getenv("POSTGRES_PASSWORD") else "not set",
}
context.log.info("PostgreSQL environment variables:")
for key, value in env_vars.items():
context.log.info(f" {key}: {value}")
@dagster.job(
tags={"owner": JobOwners.TEAM_INGESTION.value},
executor_def=dagster.multiprocess_executor.configured({"max_concurrent": 32}),
)
def persons_new_backfill_job():
"""
Backfill posthog_persons data from source to destination Postgres database.
Divides the ID space into chunks and processes them in parallel.
"""
id_range = get_id_range()
chunks = create_chunks(id_range)
chunks.map(copy_chunk)
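For reference, a hedged sketch of launching the job ad hoc with config overrides (for example, resuming from a known max_id); importing defs from dags.locations.ingestion assumes the code location module added above, and the run_config keys follow Dagster's per-op config schema:

from dags.locations.ingestion import defs

# Sketch only (assumes the POSTGRES_* env vars are set and the ClickHouse cluster
# resource is configured): run the backfill in-process with smaller chunks and
# batches and a resumed max_id. Each op accepts the same PersonsNewBackfillConfig.
result = defs.get_job_def("persons_new_backfill_job").execute_in_process(
    run_config={
        "ops": {
            "get_id_range": {"config": {"max_id": 2_500_000}},
            "create_chunks": {"config": {"chunk_size": 500_000}},
            "copy_chunk": {"config": {"batch_size": 50_000}},
        }
    },
)
assert result.success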

View File

@@ -15,6 +15,7 @@ notification_channel_per_team = {
JobOwners.TEAM_ERROR_TRACKING.value: "#alerts-error-tracking",
JobOwners.TEAM_EXPERIMENTS.value: "#alerts-experiments-dagster",
JobOwners.TEAM_GROWTH.value: "#alerts-growth",
JobOwners.TEAM_INGESTION.value: "#alerts-ingestion",
JobOwners.TEAM_MAX_AI.value: "#alerts-max-ai",
JobOwners.TEAM_REVENUE_ANALYTICS.value: "#alerts-revenue-analytics",
JobOwners.TEAM_WEB_ANALYTICS.value: "#alerts-web-analytics",

View File

@@ -0,0 +1,485 @@
"""Tests for the persons new backfill job."""
from unittest.mock import MagicMock, PropertyMock, patch
import psycopg2.errors
from dagster import Failure, build_op_context
from dags.persons_new_backfill import PersonsNewBackfillConfig, copy_chunk, create_chunks
class TestCreateChunks:
"""Test the create_chunks function."""
def test_create_chunks_produces_non_overlapping_ranges(self):
"""Test that chunks produce non-overlapping ranges."""
config = PersonsNewBackfillConfig(chunk_size=1000)
id_range = (1, 5000) # min_id=1, max_id=5000
context = build_op_context()
chunks = list(create_chunks(context, config, id_range))
# Extract all chunk ranges from DynamicOutput objects
chunk_ranges = [chunk.value for chunk in chunks]
# Verify no overlaps
for i, (min1, max1) in enumerate(chunk_ranges):
for j, (min2, max2) in enumerate(chunk_ranges):
if i != j:
# Chunks should not overlap
assert not (
min1 <= min2 <= max1 or min1 <= max2 <= max1 or min2 <= min1 <= max2
), f"Chunks overlap: ({min1}, {max1}) and ({min2}, {max2})"
def test_create_chunks_covers_entire_id_space(self):
"""Test that chunks cover the entire ID space from min to max."""
config = PersonsNewBackfillConfig(chunk_size=1000)
min_id, max_id = 1, 5000
id_range = (min_id, max_id)
context = build_op_context()
chunks = list(create_chunks(context, config, id_range))
# Extract all chunk ranges from DynamicOutput objects
chunk_ranges = [chunk.value for chunk in chunks]
# Find the overall min and max covered
all_ids_covered: set[int] = set()
for chunk_min, chunk_max in chunk_ranges:
all_ids_covered.update(range(chunk_min, chunk_max + 1))
# Verify all IDs from min_id to max_id are covered
expected_ids = set(range(min_id, max_id + 1))
assert all_ids_covered == expected_ids, (
f"Missing IDs: {expected_ids - all_ids_covered}, " f"Extra IDs: {all_ids_covered - expected_ids}"
)
def test_create_chunks_first_chunk_includes_max_id(self):
"""Test that the first chunk (in yielded order) includes the source table max_id."""
config = PersonsNewBackfillConfig(chunk_size=1000)
min_id, max_id = 1, 5000
id_range = (min_id, max_id)
context = build_op_context()
chunks = list(create_chunks(context, config, id_range))
# First chunk in the list (yielded first, highest IDs)
first_chunk_min, first_chunk_max = chunks[0].value
assert first_chunk_max == max_id, f"First chunk max ({first_chunk_max}) should equal source max_id ({max_id})"
assert (
first_chunk_min <= max_id <= first_chunk_max
), f"First chunk ({first_chunk_min}, {first_chunk_max}) should include max_id ({max_id})"
def test_create_chunks_final_chunk_includes_min_id(self):
"""Test that the final chunk (in yielded order) includes the source table min_id."""
config = PersonsNewBackfillConfig(chunk_size=1000)
min_id, max_id = 1, 5000
id_range = (min_id, max_id)
context = build_op_context()
chunks = list(create_chunks(context, config, id_range))
# Last chunk in the list (yielded last, lowest IDs)
final_chunk_min, final_chunk_max = chunks[-1].value
assert final_chunk_min == min_id, f"Final chunk min ({final_chunk_min}) should equal source min_id ({min_id})"
assert (
final_chunk_min <= min_id <= final_chunk_max
), f"Final chunk ({final_chunk_min}, {final_chunk_max}) should include min_id ({min_id})"
def test_create_chunks_reverse_order(self):
"""Test that chunks are yielded in reverse order (highest IDs first)."""
config = PersonsNewBackfillConfig(chunk_size=1000)
min_id, max_id = 1, 5000
id_range = (min_id, max_id)
context = build_op_context()
chunks = list(create_chunks(context, config, id_range))
# Verify chunks are in descending order by max_id
for i in range(len(chunks) - 1):
current_max = chunks[i].value[1]
next_max = chunks[i + 1].value[1]
assert (
current_max > next_max
), f"Chunks not in reverse order: chunk {i} max ({current_max}) should be > chunk {i+1} max ({next_max})"
def test_create_chunks_exact_multiple(self):
"""Test chunk creation when ID range is an exact multiple of chunk_size."""
config = PersonsNewBackfillConfig(chunk_size=1000)
min_id, max_id = 1, 5000 # Exactly 5 chunks of 1000
id_range = (min_id, max_id)
context = build_op_context()
chunks = list(create_chunks(context, config, id_range))
assert len(chunks) == 5, f"Expected 5 chunks, got {len(chunks)}"
# Verify first chunk (highest IDs)
assert chunks[0].value == (4001, 5000), f"First chunk should be (4001, 5000), got {chunks[0].value}"
# Verify last chunk (lowest IDs)
assert chunks[-1].value == (1, 1000), f"Last chunk should be (1, 1000), got {chunks[-1].value}"
def test_create_chunks_non_exact_multiple(self):
"""Test chunk creation when ID range is not an exact multiple of chunk_size."""
config = PersonsNewBackfillConfig(chunk_size=1000)
min_id, max_id = 1, 3750 # 3 full chunks + 1 partial chunk
id_range = (min_id, max_id)
context = build_op_context()
chunks = list(create_chunks(context, config, id_range))
assert len(chunks) == 4, f"Expected 4 chunks, got {len(chunks)}"
# Verify first chunk (highest IDs) - should be the partial chunk
assert chunks[0].value == (3001, 3750), f"First chunk should be (3001, 3750), got {chunks[0].value}"
# Verify last chunk (lowest IDs)
assert chunks[-1].value == (1, 1000), f"Last chunk should be (1, 1000), got {chunks[-1].value}"
def test_create_chunks_single_chunk(self):
"""Test chunk creation when ID range fits in a single chunk."""
config = PersonsNewBackfillConfig(chunk_size=1000)
min_id, max_id = 100, 500
id_range = (min_id, max_id)
context = build_op_context()
chunks = list(create_chunks(context, config, id_range))
assert len(chunks) == 1, f"Expected 1 chunk, got {len(chunks)}"
assert chunks[0].value == (100, 500), f"Chunk should be (100, 500), got {chunks[0].value}"
assert chunks[0].value[0] == min_id and chunks[0].value[1] == max_id
def create_mock_database_resource(rowcount_values=None):
"""
Create a mock database resource that mimics psycopg2.extensions.connection.
Args:
rowcount_values: List of rowcount values to return per INSERT call.
If None, defaults to 0. If a single int, uses that for all calls.
"""
mock_cursor = MagicMock()
mock_cursor.execute = MagicMock()
if rowcount_values is None:
    mock_cursor.rowcount = 0
elif isinstance(rowcount_values, int):
    mock_cursor.rowcount = rowcount_values
else:
    # Walk through the provided rowcounts on successive INSERT calls. A property
    # assigned to a MagicMock instance attribute is never evaluated, so drive the
    # value from an execute side effect instead.
    rowcounts = list(rowcount_values)
    insert_index = [0]

    def set_rowcount_on_insert(query, *args, **kwargs):
        if "INSERT INTO" in query:
            mock_cursor.rowcount = rowcounts[min(insert_index[0], len(rowcounts) - 1)] if rowcounts else 0
            insert_index[0] += 1

    mock_cursor.execute.side_effect = set_rowcount_on_insert
# Make cursor() return a context manager
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_conn
def create_mock_cluster_resource():
"""Create a mock ClickhouseCluster resource."""
return MagicMock()
class TestCopyChunk:
"""Test the copy_chunk function."""
def test_copy_chunk_single_batch_success(self):
"""Test successful copy of a single batch within a chunk."""
config = PersonsNewBackfillConfig(
chunk_size=1000, batch_size=100, source_table="posthog_persons", destination_table="posthog_persons_new"
)
chunk = (1, 100) # Single batch covers entire chunk
mock_db = create_mock_database_resource(rowcount_values=50)
mock_cluster = create_mock_cluster_resource()
context = build_op_context(
resources={"database": mock_db, "cluster": mock_cluster},
)
# Patch context.run.job_name where it's accessed in copy_chunk
with patch.object(type(context), "run", PropertyMock(return_value=MagicMock(job_name="test_job"))):
result = copy_chunk(context, config, chunk)
# Verify result
assert result["chunk_min"] == 1
assert result["chunk_max"] == 100
assert result["records_copied"] == 50
# Verify SET statements called once (session-level, before loop)
set_statements = [
"SET application_name = 'backfill_posthog_persons_to_posthog_persons_new'",
"SET lock_timeout = '5s'",
"SET statement_timeout = '30min'",
"SET maintenance_work_mem = '12GB'",
"SET work_mem = '512MB'",
"SET temp_buffers = '512MB'",
"SET max_parallel_workers_per_gather = 2",
"SET max_parallel_maintenance_workers = 2",
"SET synchronous_commit = off",
]
cursor = mock_db.cursor.return_value.__enter__.return_value
execute_calls = [call[0][0] for call in cursor.execute.call_args_list]
# Check SET statements were called
for stmt in set_statements:
assert any(stmt in call for call in execute_calls), f"SET statement not found: {stmt}"
# Verify BEGIN, INSERT, COMMIT called once
assert execute_calls.count("BEGIN") == 1
assert execute_calls.count("COMMIT") == 1
# Verify INSERT query format
insert_calls = [call for call in execute_calls if "INSERT INTO" in call]
assert len(insert_calls) == 1
insert_query = insert_calls[0]
assert "INSERT INTO posthog_persons_new" in insert_query
assert "SELECT s.*" in insert_query
assert "FROM posthog_persons s" in insert_query
assert "WHERE s.id >" in insert_query
assert "AND s.id <=" in insert_query
assert "NOT EXISTS" in insert_query
assert "ORDER BY s.id DESC" in insert_query
def test_copy_chunk_multiple_batches(self):
"""Test copy with multiple batches in a chunk."""
config = PersonsNewBackfillConfig(
chunk_size=1000, batch_size=100, source_table="posthog_persons", destination_table="posthog_persons_new"
)
chunk = (1, 250)  # spans 250 IDs -> 3 INSERT batches with batch_size=100
mock_db = create_mock_database_resource()
mock_cluster = create_mock_cluster_resource()
# Track rowcount per batch - use a list to track INSERT calls
rowcounts = [50, 75, 25]
insert_call_count = [0]
cursor = mock_db.cursor.return_value.__enter__.return_value
# Track INSERT calls and set rowcount accordingly
def execute_with_rowcount(query, *args):
if "INSERT INTO" in query:
if insert_call_count[0] < len(rowcounts):
cursor.rowcount = rowcounts[insert_call_count[0]]
insert_call_count[0] += 1
else:
cursor.rowcount = 0
cursor.execute.side_effect = execute_with_rowcount
context = build_op_context(
resources={"database": mock_db, "cluster": mock_cluster},
)
# Patch context.run.job_name where it's accessed in copy_chunk
with patch.object(type(context), "run", PropertyMock(return_value=MagicMock(job_name="test_job"))):
result = copy_chunk(context, config, chunk)
# Verify result
assert result["chunk_min"] == 1
assert result["chunk_max"] == 250
assert result["records_copied"] == 150 # 50 + 75 + 25
# Inspect the statements executed across all batches
cursor = mock_db.cursor.return_value.__enter__.return_value
execute_calls = [call[0][0] for call in cursor.execute.call_args_list]
# Verify BEGIN/COMMIT called 3 times (one per batch)
assert execute_calls.count("BEGIN") == 3
assert execute_calls.count("COMMIT") == 3
# Verify INSERT called 3 times
insert_calls = [call for call in execute_calls if "INSERT INTO" in call]
assert len(insert_calls) == 3
def test_copy_chunk_duplicate_key_violation_retry(self):
"""Test that duplicate key violation triggers retry."""
config = PersonsNewBackfillConfig(
chunk_size=1000, batch_size=100, source_table="posthog_persons", destination_table="posthog_persons_new"
)
chunk = (1, 100)
mock_db = create_mock_database_resource()
mock_cluster = create_mock_cluster_resource()
cursor = mock_db.cursor.return_value.__enter__.return_value
# Track INSERT attempts
insert_attempts = [0]
# First INSERT raises UniqueViolation, second succeeds
def execute_side_effect(query, *args):
if "INSERT INTO" in query:
insert_attempts[0] += 1
if insert_attempts[0] == 1:
# First INSERT attempt raises error
# Use real UniqueViolation - pgcode is readonly but isinstance check will pass
raise psycopg2.errors.UniqueViolation("duplicate key value violates unique constraint")
# Subsequent calls succeed
cursor.rowcount = 50 # Success on retry
cursor.execute.side_effect = execute_side_effect
context = build_op_context(
resources={"database": mock_db, "cluster": mock_cluster},
)
# Need to patch time.sleep and run.job_name
mock_run = MagicMock(job_name="test_job")
with (
patch("dags.persons_new_backfill.time.sleep"),
patch.object(type(context), "run", PropertyMock(return_value=mock_run)),
):
copy_chunk(context, config, chunk)
# Verify ROLLBACK was called on error
execute_calls = [call[0][0] for call in cursor.execute.call_args_list]
assert "ROLLBACK" in execute_calls
# Verify retry succeeded (should have INSERT called twice, COMMIT once)
insert_calls = [call for call in execute_calls if "INSERT INTO" in call]
assert len(insert_calls) >= 1 # At least one successful INSERT
def test_copy_chunk_error_handling_and_rollback(self):
"""Test error handling and rollback on non-duplicate errors."""
config = PersonsNewBackfillConfig(
chunk_size=1000, batch_size=100, source_table="posthog_persons", destination_table="posthog_persons_new"
)
chunk = (1, 100)
mock_db = create_mock_database_resource()
mock_cluster = create_mock_cluster_resource()
cursor = mock_db.cursor.return_value.__enter__.return_value
# Raise generic error on INSERT
def execute_side_effect(query, *args):
if "INSERT INTO" in query:
raise Exception("Connection lost")
cursor.execute.side_effect = execute_side_effect
context = build_op_context(
resources={"database": mock_db, "cluster": mock_cluster},
)
# Patch context.run.job_name where it's accessed in copy_chunk
mock_run = MagicMock(job_name="test_job")
with patch.object(type(context), "run", PropertyMock(return_value=mock_run)):
# Should raise dagster.Failure
try:
copy_chunk(context, config, chunk)
raise AssertionError("Expected dagster.Failure to be raised")
except Failure as e:
# Verify error metadata
assert e.description is not None
assert "Failed to copy batch" in e.description
# Verify ROLLBACK was called
execute_calls = [call[0][0] for call in cursor.execute.call_args_list]
assert "ROLLBACK" in execute_calls
def test_copy_chunk_insert_query_format(self):
"""Test that INSERT query has correct format."""
config = PersonsNewBackfillConfig(
chunk_size=1000, batch_size=100, source_table="test_source", destination_table="test_dest"
)
chunk = (1, 100)
mock_db = create_mock_database_resource(rowcount_values=10)
mock_cluster = create_mock_cluster_resource()
context = build_op_context(
resources={"database": mock_db, "cluster": mock_cluster},
)
# Patch context.run.job_name where it's accessed in copy_chunk
with patch.object(type(context), "run", PropertyMock(return_value=MagicMock(job_name="test_job"))):
copy_chunk(context, config, chunk)
cursor = mock_db.cursor.return_value.__enter__.return_value
execute_calls = [call[0][0] for call in cursor.execute.call_args_list]
# Find INSERT query
insert_query = next((call for call in execute_calls if "INSERT INTO" in call), None)
assert insert_query is not None
# Verify query components
assert "INSERT INTO test_dest" in insert_query
assert "SELECT s.*" in insert_query
assert "FROM test_source s" in insert_query
assert "WHERE s.id >" in insert_query
assert "AND s.id <=" in insert_query
assert "NOT EXISTS" in insert_query
assert "d.team_id = s.team_id" in insert_query
assert "d.id = s.id" in insert_query
assert "ORDER BY s.id DESC" in insert_query
def test_copy_chunk_session_settings_applied_once(self):
"""Test that SET statements are applied once at session level before batch loop."""
config = PersonsNewBackfillConfig(
chunk_size=1000, batch_size=50, source_table="posthog_persons", destination_table="posthog_persons_new"
)
chunk = (1, 150) # 3 batches
mock_db = create_mock_database_resource(rowcount_values=25)
mock_cluster = create_mock_cluster_resource()
context = build_op_context(
resources={"database": mock_db, "cluster": mock_cluster},
)
# Patch context.run.job_name where it's accessed in copy_chunk
with patch.object(type(context), "run", PropertyMock(return_value=MagicMock(job_name="test_job"))):
copy_chunk(context, config, chunk)
cursor = mock_db.cursor.return_value.__enter__.return_value
execute_calls = [call[0][0] for call in cursor.execute.call_args_list]
# Count SET statements (should be called once each, before loop)
set_statements = [
"SET application_name",
"SET lock_timeout",
"SET statement_timeout",
"SET maintenance_work_mem",
"SET work_mem",
"SET temp_buffers",
"SET max_parallel_workers_per_gather",
"SET max_parallel_maintenance_workers",
"SET synchronous_commit",
]
for stmt in set_statements:
count = sum(1 for call in execute_calls if stmt in call)
assert count == 1, f"Expected {stmt} to be called once, but it was called {count} times"
# Verify SET statements come before BEGIN statements
set_indices = [i for i, call in enumerate(execute_calls) if any(stmt in call for stmt in set_statements)]
begin_indices = [i for i, call in enumerate(execute_calls) if call == "BEGIN"]
if set_indices and begin_indices:
assert max(set_indices) < min(begin_indices), "SET statements should come before BEGIN statements"