feat: S3 BatchExports missing features (#17233)

Tomás Farías Santana
2023-09-01 16:19:02 +02:00
committed by GitHub
parent 8a835df63b
commit 17c5ec1710
11 changed files with 565 additions and 458 deletions

View File

@@ -117,7 +117,7 @@ jobs:
- name: Install SAML (python3-saml) dependencies
run: |
sudo apt-get update
sudo apt-get install libxml2-dev libxmlsec1-dev libxmlsec1-openssl
sudo apt-get install libxml2-dev libxmlsec1 libxmlsec1-dev libxmlsec1-openssl
- name: Install python dependencies
if: steps.cache-backend-tests.outputs.cache-hit != 'true'

View File

@@ -50,6 +50,8 @@ class S3BatchExportInputs:
aws_access_key_id: str | None = None
aws_secret_access_key: str | None = None
data_interval_end: str | None = None
compression: str | None = None
exclude_events: list[str] | None = None
@dataclass
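The two optional fields added above (compression and exclude_events) are threaded through the whole export. A hedged sketch of constructing the workflow inputs with them set; the values are illustrative and only the field names come from this diff and the tests below:
# Hedged sketch, not code from the PR: enable gzip output and skip one event type.
inputs = S3BatchExportInputs(
    team_id=1,
    batch_export_id="hypothetical-batch-export-id",
    bucket_name="my-export-bucket",
    region="us-east-1",
    prefix="posthog-events",
    data_interval_end="2023-04-25 14:30:00",
    compression="gzip",                        # None, "gzip" or "brotli"
    exclude_events=["$feature_flag_called"],   # events skipped by the NOT IN clause added below
)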

View File

@@ -193,6 +193,53 @@ async def test_get_rows_count_handles_duplicates(client):
assert row_count == 10000
@pytest.mark.django_db
@pytest.mark.asyncio
async def test_get_rows_count_can_exclude_events(client):
"""Test the count of rows returned by get_rows_count can exclude events."""
team_id = randint(1, 1000000)
events: list[EventValues] = [
{
"uuid": str(uuid4()),
"event": f"test-{i}",
"_timestamp": "2023-04-20 14:30:00",
"timestamp": f"2023-04-20 14:30:00.{i:06d}",
"inserted_at": f"2023-04-20 14:30:00.{i:06d}",
"created_at": "2023-04-20 14:30:00.000000",
"distinct_id": str(uuid4()),
"person_id": str(uuid4()),
"person_properties": {"$browser": "Chrome", "$os": "Mac OS X"},
"team_id": team_id,
"properties": {
"$browser": "Chrome",
"$os": "Mac OS X",
"$ip": "127.0.0.1",
"$current_url": "http://localhost.com",
},
"elements_chain": "this that and the other",
"elements": json.dumps("this that and the other"),
"ip": "127.0.0.1",
"site_url": "http://localhost.com",
"set": None,
"set_once": None,
}
for i in range(10000)
]
# Duplicate everything
duplicate_events = events * 2
await insert_events(
ch_client=client,
events=duplicate_events,
)
# Exclude the latter half of events.
exclude_events = (f"test-{i}" for i in range(5000, 10000))
row_count = await get_rows_count(client, team_id, "2023-04-20 14:30:00", "2023-04-20 14:31:00", exclude_events)
assert row_count == 5000
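The expected value of 5000 follows from the deduplicating count used by get_rows_count (count(DISTINCT event, cityHash64(distinct_id), cityHash64(uuid)), visible later in this diff) combined with excluding the upper half of the event names. A rough sketch of the same arithmetic in Python, reusing the test's duplicate_events:
# Hedged sketch of the expected count: deduplicate on (event, distinct_id, uuid),
# then drop excluded event names, mirroring the ClickHouse query.
excluded = {f"test-{i}" for i in range(5000, 10000)}
unique_rows = {
    (e["event"], e["distinct_id"], e["uuid"])
    for e in duplicate_events
    if e["event"] not in excluded
}
assert len(unique_rows) == 5000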
@pytest.mark.django_db
@pytest.mark.asyncio
async def test_get_results_iterator(client):
@@ -300,6 +347,63 @@ async def test_get_results_iterator_handles_duplicates(client):
assert value == expected[key], f"{key} value in {result} didn't match value in {expected}"
@pytest.mark.django_db
@pytest.mark.asyncio
async def test_get_results_iterator_can_exclude_events(client):
"""Test the rows returned by get_results_iterator can exclude events."""
team_id = randint(1, 1000000)
events: list[EventValues] = [
{
"uuid": str(uuid4()),
"event": f"test-{i}",
"_timestamp": "2023-04-20 14:30:00",
"timestamp": f"2023-04-20 14:30:00.{i:06d}",
"inserted_at": f"2023-04-20 14:30:00.{i:06d}",
"created_at": "2023-04-20 14:30:00.000000",
"distinct_id": str(uuid4()),
"person_id": str(uuid4()),
"person_properties": {"$browser": "Chrome", "$os": "Mac OS X"},
"team_id": team_id,
"properties": {
"$browser": "Chrome",
"$os": "Mac OS X",
"$ip": "127.0.0.1",
"$current_url": "http://localhost.com",
},
"elements_chain": "this that and the other",
"elements": json.dumps("this that and the other"),
"ip": "127.0.0.1",
"site_url": "http://localhost.com",
"set": None,
"set_once": None,
}
for i in range(10000)
]
duplicate_events = events * 2
await insert_events(
ch_client=client,
events=duplicate_events,
)
# Exclude the latter half of events.
exclude_events = (f"test-{i}" for i in range(5000, 10000))
iter_ = get_results_iterator(client, team_id, "2023-04-20 14:30:00", "2023-04-20 14:31:00", exclude_events)
rows = [row for row in iter_]
all_expected = sorted(events[:5000], key=operator.itemgetter("event"))
all_result = sorted(rows, key=operator.itemgetter("event"))
assert len(all_expected) == len(all_result)
assert len([row["uuid"] for row in all_result]) == len(set(row["uuid"] for row in all_result))
for expected, result in zip(all_expected, all_result):
for key, value in result.items():
# Some keys will be missing from result, so let's only check the ones we have.
assert value == expected[key], f"{key} value in {result} didn't match value in {expected}"
@pytest.mark.parametrize(
"interval,data_interval_end,expected",
[

View File

@@ -1,11 +1,14 @@
import datetime as dt
import functools
import gzip
import itertools
import json
from random import randint
from unittest import mock
from uuid import uuid4
import boto3
import brotli
import pytest
from django.conf import settings
from django.test import Client as HttpClient
@@ -26,7 +29,6 @@ from posthog.temporal.tests.batch_exports.fixtures import (
afetch_batch_export_runs,
)
from posthog.temporal.workflows.base import create_export_run, update_export_run_status
from posthog.temporal.workflows.batch_exports import get_results_iterator
from posthog.temporal.workflows.clickhouse import ClickHouseClient
from posthog.temporal.workflows.s3_batch_export import (
S3BatchExportInputs,
@@ -74,7 +76,9 @@ def s3_client(bucket_name):
s3_client.delete_bucket(Bucket=bucket_name)
def assert_events_in_s3(s3_client, bucket_name, key_prefix, events):
def assert_events_in_s3(
s3_client, bucket_name, key_prefix, events, compression: str | None = None, exclude_events: list[str] | None = None
):
"""Assert provided events written to JSON in key_prefix in S3 bucket_name."""
# List the objects in the bucket with the prefix.
objects = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=key_prefix)
@@ -89,13 +93,28 @@ def assert_events_in_s3(s3_client, bucket_name, key_prefix, events):
data = object["Body"].read()
# Check that the data is correct.
match compression:
case "gzip":
data = gzip.decompress(data)
case "brotli":
data = brotli.decompress(data)
case _:
pass
json_data = [json.loads(line) for line in data.decode("utf-8").split("\n") if line]
# Pull out the fields we inserted only
json_data.sort(key=lambda x: x["timestamp"])
# Remove team_id, _timestamp from events
expected_events = [{k: v for k, v in event.items() if k not in ["team_id", "_timestamp"]} for event in events]
if exclude_events is None:
exclude_events = []
expected_events = [
{k: v for k, v in event.items() if k not in ["team_id", "_timestamp"]}
for event in events
if event["event"] not in exclude_events
]
expected_events.sort(key=lambda x: x["timestamp"])
# First check one event, the first one, so that we can get a nice diff if
@@ -106,7 +125,13 @@ def assert_events_in_s3(s3_client, bucket_name, key_prefix, events):
@pytest.mark.django_db
@pytest.mark.asyncio
async def test_insert_into_s3_activity_puts_data_into_s3(bucket_name, s3_client, activity_environment):
@pytest.mark.parametrize(
"compression,exclude_events",
itertools.product([None, "gzip", "brotli"], [None, ["test-exclude"]]),
)
async def test_insert_into_s3_activity_puts_data_into_s3(
bucket_name, s3_client, activity_environment, compression, exclude_events
):
"""Test that the insert_into_s3_activity function puts data into S3."""
data_interval_start = "2023-04-20 14:00:00"
@@ -155,7 +180,7 @@ async def test_insert_into_s3_activity_puts_data_into_s3(bucket_name, s3_client,
EventValues(
{
"uuid": str(uuid4()),
"event": "test",
"event": "test-exclude",
"_timestamp": "2023-04-20 14:29:00",
"timestamp": "2023-04-20 14:29:00.000000",
"inserted_at": "2023-04-20 14:30:00.000000",
@@ -241,6 +266,8 @@ async def test_insert_into_s3_activity_puts_data_into_s3(bucket_name, s3_client,
data_interval_end=data_interval_end,
aws_access_key_id="object_storage_root_user",
aws_secret_access_key="object_storage_root_password",
compression=compression,
exclude_events=exclude_events,
)
with override_settings(
@@ -249,13 +276,18 @@ async def test_insert_into_s3_activity_puts_data_into_s3(bucket_name, s3_client,
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
await activity_environment.run(insert_into_s3_activity, insert_inputs)
assert_events_in_s3(s3_client, bucket_name, prefix, events)
assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)
@pytest.mark.django_db
@pytest.mark.asyncio
@pytest.mark.parametrize("interval", ["hour", "day"])
async def test_s3_export_workflow_with_minio_bucket(client: HttpClient, s3_client, bucket_name, interval):
@pytest.mark.parametrize(
"interval,compression,exclude_events",
itertools.product(["hour", "day"], [None, "gzip", "brotli"], [None, ["test-exclude"]]),
)
async def test_s3_export_workflow_with_minio_bucket(
client: HttpClient, s3_client, bucket_name, interval, compression, exclude_events
):
"""Test S3 Export Workflow end-to-end by using a local MinIO bucket instead of S3.
The workflow should update the batch export run status to completed and produce the expected
@@ -270,6 +302,8 @@ async def test_s3_export_workflow_with_minio_bucket(client: HttpClient, s3_clien
"prefix": prefix,
"aws_access_key_id": "object_storage_root_user",
"aws_secret_access_key": "object_storage_root_password",
"compression": compression,
"exclude_events": exclude_events,
},
}
@@ -305,7 +339,7 @@ async def test_s3_export_workflow_with_minio_bucket(client: HttpClient, s3_clien
},
{
"uuid": str(uuid4()),
"event": "test",
"event": "test-exclude",
"timestamp": "2023-04-25 14:29:00.000000",
"created_at": "2023-04-25 14:29:00.000000",
"inserted_at": "2023-04-25 14:29:00.000000",
@@ -385,12 +419,15 @@ async def test_s3_export_workflow_with_minio_bucket(client: HttpClient, s3_clien
run = runs[0]
assert run.status == "Completed"
assert_events_in_s3(s3_client, bucket_name, prefix, events)
assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)
@pytest.mark.django_db
@pytest.mark.asyncio
async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(client: HttpClient, s3_client, bucket_name):
@pytest.mark.parametrize("compression", [None, "gzip"])
async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(
client: HttpClient, s3_client, bucket_name, compression
):
"""Test the full S3 workflow targetting a MinIO bucket.
The workflow should update the batch export run status to completed and produce the expected
@@ -412,6 +449,7 @@ async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(client: Ht
"prefix": prefix,
"aws_access_key_id": "object_storage_root_user",
"aws_secret_access_key": "object_storage_root_password",
"compression": compression,
},
}
@@ -477,7 +515,7 @@ async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(client: Ht
id=workflow_id,
task_queue=settings.TEMPORAL_TASK_QUEUE,
retry_policy=RetryPolicy(maximum_attempts=1),
execution_timeout=dt.timedelta(seconds=180),
execution_timeout=dt.timedelta(seconds=360),
)
runs = await afetch_batch_export_runs(batch_export_id=batch_export.id)
@@ -486,12 +524,15 @@ async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(client: Ht
run = runs[0]
assert run.status == "Completed"
assert_events_in_s3(s3_client, bucket_name, prefix.format(year=2023, month="04", day="25"), events)
assert_events_in_s3(s3_client, bucket_name, prefix.format(year=2023, month="04", day="25"), events, compression)
@pytest.mark.django_db
@pytest.mark.asyncio
async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(client: HttpClient, s3_client, bucket_name):
@pytest.mark.parametrize("compression", [None, "gzip", "brotli"])
async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(
client: HttpClient, s3_client, bucket_name, compression
):
"""Test the full S3 workflow targetting a MinIO bucket.
In this scenario we assert that when inserted_at is NULL, we default to _timestamp.
@@ -513,6 +554,7 @@ async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(clie
"prefix": prefix,
"aws_access_key_id": "object_storage_root_user",
"aws_secret_access_key": "object_storage_root_password",
"compression": compression,
},
}
@@ -600,12 +642,15 @@ async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(clie
run = runs[0]
assert run.status == "Completed"
assert_events_in_s3(s3_client, bucket_name, prefix, events)
assert_events_in_s3(s3_client, bucket_name, prefix, events, compression)
@pytest.mark.django_db
@pytest.mark.asyncio
async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(client: HttpClient, s3_client, bucket_name):
@pytest.mark.parametrize("compression", [None, "gzip", "brotli"])
async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(
client: HttpClient, s3_client, bucket_name, compression
):
"""Test the S3BatchExport Workflow utilizing a custom key prefix.
We will be asserting that exported events land in the appropriate S3 key according to the prefix.
@@ -626,6 +671,7 @@ async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(client
"prefix": prefix,
"aws_access_key_id": "object_storage_root_user",
"aws_secret_access_key": "object_storage_root_password",
"compression": compression,
},
}
@@ -707,324 +753,15 @@ async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(client
assert len(objects.get("Contents", [])) == 1
assert key.startswith(expected_key_prefix)
assert_events_in_s3(s3_client, bucket_name, expected_key_prefix, events)
assert_events_in_s3(s3_client, bucket_name, expected_key_prefix, events, compression)
@pytest.mark.django_db
@pytest.mark.asyncio
async def test_s3_export_workflow_continues_on_json_decode_error(client: HttpClient, s3_client, bucket_name):
"""Test that S3 Export Workflow end-to-end by using a local MinIO bucket instead of S3.
In this particular case, we should be handling JSONDecodeErrors produced by attempting to parse
ClickHouse error strings.
"""
ch_client = ClickHouseClient(
url=settings.CLICKHOUSE_HTTP_URL,
user=settings.CLICKHOUSE_USER,
password=settings.CLICKHOUSE_PASSWORD,
database=settings.CLICKHOUSE_DATABASE,
)
prefix = f"posthog-events-{str(uuid4())}"
destination_data = {
"type": "S3",
"config": {
"bucket_name": bucket_name,
"region": "us-east-1",
"prefix": prefix,
"aws_access_key_id": "object_storage_root_user",
"aws_secret_access_key": "object_storage_root_password",
},
}
batch_export_data = {
"name": "my-production-s3-bucket-destination",
"destination": destination_data,
"interval": "hour",
}
organization = await acreate_organization("test")
team = await acreate_team(organization=organization)
batch_export = await acreate_batch_export(
team_id=team.pk,
name=batch_export_data["name"],
destination_data=batch_export_data["destination"],
interval=batch_export_data["interval"],
)
events: list[EventValues] = [
{
"uuid": str(uuid4()),
"event": "test",
"timestamp": "2023-04-25 13:30:00.000000",
"created_at": "2023-04-25 13:30:00.000000",
"inserted_at": "2023-04-25 13:30:00.000000",
"_timestamp": "2023-04-25 13:30:00",
"person_id": str(uuid4()),
"person_properties": {"$browser": "Chrome", "$os": "Mac OS X"},
"team_id": team.pk,
"properties": {"$browser": "Chrome", "$os": "Mac OS X"},
"distinct_id": str(uuid4()),
"elements_chain": "this is a comman, separated, list, of css selectors(?)",
},
{
"uuid": str(uuid4()),
"event": "test",
"timestamp": "2023-04-25 14:29:00.000000",
"inserted_at": "2023-04-25 14:29:00.000000",
"created_at": "2023-04-25 14:29:00.000000",
"_timestamp": "2023-04-25 14:29:00",
"person_id": str(uuid4()),
"person_properties": {"$browser": "Chrome", "$os": "Mac OS X"},
"team_id": team.pk,
"properties": {"$browser": "Chrome", "$os": "Mac OS X"},
"distinct_id": str(uuid4()),
"elements_chain": "this is a comman, separated, list, of css selectors(?)",
},
]
# Insert some data into the `sharded_events` table.
await insert_events(
client=ch_client,
events=events,
)
workflow_id = str(uuid4())
inputs = S3BatchExportInputs(
team_id=team.pk,
batch_export_id=str(batch_export.id),
data_interval_end="2023-04-25 14:30:00.000000",
**batch_export.destination.config,
)
error_raised = False
def fake_get_results_iterator(*args, **kwargs):
nonlocal error_raised
for result in get_results_iterator(*args, **kwargs):
if error_raised is False:
error_raised = True
raise json.JSONDecodeError("Test error", "A ClickHouse error message\n", 0)
yield result
async with await WorkflowEnvironment.start_time_skipping() as activity_environment:
async with Worker(
activity_environment.client,
task_queue=settings.TEMPORAL_TASK_QUEUE,
workflows=[S3BatchExportWorkflow],
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
with mock.patch(
"posthog.temporal.workflows.s3_batch_export.get_results_iterator",
side_effect=fake_get_results_iterator,
) as mocked_iterator:
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
id=workflow_id,
task_queue=settings.TEMPORAL_TASK_QUEUE,
retry_policy=RetryPolicy(maximum_attempts=1),
execution_timeout=dt.timedelta(seconds=10),
)
assert mocked_iterator.call_count == 2 # An extra call for the error
mocked_iterator.assert_has_calls(
[
mock.call(
client=mock.ANY,
interval_start="2023-04-25T13:30:00",
interval_end="2023-04-25T14:30:00",
team_id=team.pk,
),
mock.call(
client=mock.ANY,
interval_start="2023-04-25T13:30:00",
interval_end="2023-04-25T14:30:00",
team_id=team.pk,
),
]
)
assert error_raised is True
runs = await afetch_batch_export_runs(batch_export_id=batch_export.id)
assert len(runs) == 1
run = runs[0]
assert run.status == "Completed"
assert_events_in_s3(s3_client, bucket_name, prefix, events)
@pytest.mark.django_db
@pytest.mark.asyncio
async def test_s3_export_workflow_continues_on_multiple_json_decode_error(client: HttpClient, s3_client, bucket_name):
"""Test that S3 Export Workflow end-to-end by using a local MinIO bucket instead of S3.
In this particular case, we should be handling JSONDecodeErrors produced by attempting to parse
ClickHouse error strings.
"""
ch_client = ClickHouseClient(
url=settings.CLICKHOUSE_HTTP_URL,
user=settings.CLICKHOUSE_USER,
password=settings.CLICKHOUSE_PASSWORD,
database=settings.CLICKHOUSE_DATABASE,
)
prefix = f"posthog-events-{str(uuid4())}"
destination_data = {
"type": "S3",
"config": {
"bucket_name": bucket_name,
"region": "us-east-1",
"prefix": prefix,
"aws_access_key_id": "object_storage_root_user",
"aws_secret_access_key": "object_storage_root_password",
},
}
batch_export_data = {
"name": "my-production-s3-bucket-destination",
"destination": destination_data,
"interval": "hour",
}
organization = await acreate_organization("test")
team = await acreate_team(organization=organization)
batch_export = await acreate_batch_export(
team_id=team.pk,
name=batch_export_data["name"],
destination_data=batch_export_data["destination"],
interval=batch_export_data["interval"],
)
# Produce a list of 10 events.
events: list[EventValues] = [
{
"uuid": str(uuid4()),
"event": str(i),
"timestamp": f"2023-04-25 13:3{i}:00.000000",
"inserted_at": f"2023-04-25 13:3{i}:00.000000",
"created_at": f"2023-04-25 13:3{i}:00.000000",
"_timestamp": f"2023-04-25 13:3{i}:00",
"person_id": str(uuid4()),
"person_properties": {"$browser": "Chrome", "$os": "Mac OS X"},
"team_id": team.pk,
"properties": {"$browser": "Chrome", "$os": "Mac OS X"},
"distinct_id": str(uuid4()),
"elements_chain": "this is a comman, separated, list, of css selectors(?)",
}
for i in range(10)
]
# Insert some data into the `sharded_events` table.
await insert_events(
client=ch_client,
events=events,
)
workflow_id = str(uuid4())
inputs = S3BatchExportInputs(
team_id=team.pk,
batch_export_id=str(batch_export.id),
data_interval_end="2023-04-25 14:30:00.000000",
**batch_export.destination.config,
)
failed_events = set()
def should_fail(event):
return bool(int(event["event"]) % 2)
def fake_get_results_iterator(*args, **kwargs):
for result in get_results_iterator(*args, **kwargs):
if result["event"] not in failed_events and should_fail(result):
# Will raise an exception every other row.
failed_events.add(result["event"]) # Otherwise we infinite loop
raise json.JSONDecodeError("Test error", "A ClickHouse error message\n", 0)
yield result
async with await WorkflowEnvironment.start_time_skipping() as activity_environment:
async with Worker(
activity_environment.client,
task_queue=settings.TEMPORAL_TASK_QUEUE,
workflows=[S3BatchExportWorkflow],
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
with mock.patch(
"posthog.temporal.workflows.s3_batch_export.get_results_iterator",
side_effect=fake_get_results_iterator,
) as mocked_iterator:
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
id=workflow_id,
task_queue=settings.TEMPORAL_TASK_QUEUE,
retry_policy=RetryPolicy(maximum_attempts=1),
execution_timeout=dt.timedelta(seconds=10),
)
assert mocked_iterator.call_count == 6 # 5 failures so 5 extra calls
mocked_iterator.assert_has_calls( # The first call + we resume on all the even dates.
[
mock.call(
client=mock.ANY,
interval_start="2023-04-25T13:30:00",
interval_end="2023-04-25T14:30:00",
team_id=team.pk,
),
mock.call(
client=mock.ANY,
interval_start="2023-04-25 13:30:00.000000",
interval_end="2023-04-25T14:30:00",
team_id=team.pk,
),
mock.call(
client=mock.ANY,
interval_start="2023-04-25 13:32:00.000000",
interval_end="2023-04-25T14:30:00",
team_id=team.pk,
),
mock.call(
client=mock.ANY,
interval_start="2023-04-25 13:34:00.000000",
interval_end="2023-04-25T14:30:00",
team_id=team.pk,
),
mock.call(
client=mock.ANY,
interval_start="2023-04-25 13:36:00.000000",
interval_end="2023-04-25T14:30:00",
team_id=team.pk,
),
mock.call(
client=mock.ANY,
interval_start="2023-04-25 13:38:00.000000",
interval_end="2023-04-25T14:30:00",
team_id=team.pk,
),
]
)
runs = await afetch_batch_export_runs(batch_export_id=batch_export.id)
assert len(runs) == 1
run = runs[0]
assert run.status == "Completed"
duplicate_events = [event for event in events if not should_fail(event)]
assert_events_in_s3(s3_client, bucket_name, prefix, events + duplicate_events)
@pytest.mark.django_db
@pytest.mark.asyncio
async def test_s3_export_workflow_with_minio_bucket_produces_no_duplicates(client: HttpClient, s3_client, bucket_name):
@pytest.mark.parametrize("compression", [None, "gzip", "brotli"])
async def test_s3_export_workflow_with_minio_bucket_produces_no_duplicates(
client: HttpClient, s3_client, bucket_name, compression
):
"""Test that S3 Export Workflow end-to-end by using a local MinIO bucket instead of S3.
In this particular instance of the test, we assert no duplicates are exported to S3.
@@ -1045,6 +782,7 @@ async def test_s3_export_workflow_with_minio_bucket_produces_no_duplicates(clien
"prefix": prefix,
"aws_access_key_id": "object_storage_root_user",
"aws_secret_access_key": "object_storage_root_password",
"compression": compression,
},
}
@@ -1150,7 +888,7 @@ async def test_s3_export_workflow_with_minio_bucket_produces_no_duplicates(clien
run = runs[0]
assert run.status == "Completed"
assert_events_in_s3(s3_client, bucket_name, prefix, events)
assert_events_in_s3(s3_client, bucket_name, prefix, events, compression)
# We don't care about these for the next test, just need something to be defined.
@@ -1182,6 +920,26 @@ base_inputs = {
),
"2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl",
),
(
S3InsertInputs(
prefix="",
data_interval_start="2023-01-01 00:00:00",
data_interval_end="2023-01-01 01:00:00",
compression="gzip",
**base_inputs,
),
"2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl.gz",
),
(
S3InsertInputs(
prefix="",
data_interval_start="2023-01-01 00:00:00",
data_interval_end="2023-01-01 01:00:00",
compression="brotli",
**base_inputs,
),
"2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl.br",
),
(
S3InsertInputs(
prefix="my-fancy-prefix",
@@ -1200,6 +958,26 @@ base_inputs = {
),
"my-fancy-prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl",
),
(
S3InsertInputs(
prefix="my-fancy-prefix",
data_interval_start="2023-01-01 00:00:00",
data_interval_end="2023-01-01 01:00:00",
compression="gzip",
**base_inputs,
),
"my-fancy-prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl.gz",
),
(
S3InsertInputs(
prefix="my-fancy-prefix",
data_interval_start="2023-01-01 00:00:00",
data_interval_end="2023-01-01 01:00:00",
compression="brotli",
**base_inputs,
),
"my-fancy-prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl.br",
),
(
S3InsertInputs(
prefix="my-fancy-prefix-with-a-forwardslash/",
@@ -1236,6 +1014,26 @@ base_inputs = {
),
"nested/prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl",
),
(
S3InsertInputs(
prefix="/nested/prefix/",
data_interval_start="2023-01-01 00:00:00",
data_interval_end="2023-01-01 01:00:00",
compression="gzip",
**base_inputs,
),
"nested/prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl.gz",
),
(
S3InsertInputs(
prefix="/nested/prefix/",
data_interval_start="2023-01-01 00:00:00",
data_interval_end="2023-01-01 01:00:00",
compression="brotli",
**base_inputs,
),
"nested/prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl.br",
),
],
)
def test_get_s3_key(inputs, expected):

View File

@@ -65,6 +65,7 @@ def person_overrides_table(query_inputs):
sync_execute(PERSON_OVERRIDES_CREATE_TABLE_SQL)
sync_execute(KAFKA_PERSON_OVERRIDES_TABLE_SQL)
sync_execute(PERSON_OVERRIDES_CREATE_MATERIALIZED_VIEW_SQL)
sync_execute("TRUNCATE TABLE person_overrides")
yield
@@ -89,9 +90,9 @@ def person_overrides_data(person_overrides_table):
"""
person_overrides = {
# These numbers are all arbitrary.
1: {PersonOverrideTuple(uuid4(), uuid4()) for _ in range(5)},
2: {PersonOverrideTuple(uuid4(), uuid4()) for _ in range(4)},
3: {PersonOverrideTuple(uuid4(), uuid4()) for _ in range(3)},
100: {PersonOverrideTuple(uuid4(), uuid4()) for _ in range(5)},
200: {PersonOverrideTuple(uuid4(), uuid4()) for _ in range(4)},
300: {PersonOverrideTuple(uuid4(), uuid4()) for _ in range(3)},
}
all_test_values = []

View File

@@ -1,11 +1,13 @@
import collections.abc
import csv
import datetime as dt
import gzip
import json
import tempfile
import typing
from string import Template
import brotli
from temporalio import workflow
SELECT_QUERY_TEMPLATE = Template(
@@ -21,17 +23,33 @@ SELECT_QUERY_TEMPLATE = Template(
AND COALESCE(inserted_at, _timestamp) >= toDateTime64({data_interval_start}, 6, 'UTC')
AND COALESCE(inserted_at, _timestamp) < toDateTime64({data_interval_end}, 6, 'UTC')
AND team_id = {team_id}
$exclude_events
$order_by
$format
"""
)
async def get_rows_count(client, team_id: int, interval_start: str, interval_end: str) -> int:
async def get_rows_count(
client,
team_id: int,
interval_start: str,
interval_end: str,
exclude_events: collections.abc.Iterable[str] | None = None,
) -> int:
data_interval_start_ch = dt.datetime.fromisoformat(interval_start).strftime("%Y-%m-%d %H:%M:%S")
data_interval_end_ch = dt.datetime.fromisoformat(interval_end).strftime("%Y-%m-%d %H:%M:%S")
if exclude_events:
exclude_events_statement = f"AND event NOT IN {str(tuple(exclude_events))}"
else:
exclude_events_statement = ""
query = SELECT_QUERY_TEMPLATE.substitute(
fields="count(DISTINCT event, cityHash64(distinct_id), cityHash64(uuid)) as count", order_by="", format=""
fields="count(DISTINCT event, cityHash64(distinct_id), cityHash64(uuid)) as count",
order_by="",
format="",
exclude_events=exclude_events_statement,
)
count = await client.read_query(
@@ -68,14 +86,25 @@ elements_chain
def get_results_iterator(
client, team_id: int, interval_start: str, interval_end: str
client,
team_id: int,
interval_start: str,
interval_end: str,
exclude_events: collections.abc.Iterable[str] | None = None,
) -> typing.Generator[dict[str, typing.Any], None, None]:
data_interval_start_ch = dt.datetime.fromisoformat(interval_start).strftime("%Y-%m-%d %H:%M:%S")
data_interval_end_ch = dt.datetime.fromisoformat(interval_end).strftime("%Y-%m-%d %H:%M:%S")
if exclude_events:
exclude_events_statement = f"AND event NOT IN {str(tuple(exclude_events))}"
else:
exclude_events_statement = ""
query = SELECT_QUERY_TEMPLATE.substitute(
fields=FIELDS,
order_by="ORDER BY inserted_at",
format="FORMAT ArrowStream",
exclude_events=exclude_events_statement,
)
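For reference, the exclusion clause is produced by stringifying a Python tuple, so (as a hedged illustration, not code from the diff) two excluded events render like this:
# Illustration of the generated SQL fragment.
exclude_events = ["$pageview", "$pageleave"]
statement = f"AND event NOT IN {str(tuple(exclude_events))}"
# statement == "AND event NOT IN ('$pageview', '$pageleave')"
# A single excluded event renders as "AND event NOT IN ('$pageview',)", with Python's
# trailing comma for one-element tuples.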
for batch in client.stream_query_as_arrow(
@@ -204,6 +233,7 @@ class BatchExportTemporaryFile:
self,
mode: str = "w+b",
buffering=-1,
compression: str | None = None,
encoding: str | None = None,
newline: str | None = None,
suffix: str | None = None,
@@ -222,10 +252,12 @@ class BatchExportTemporaryFile:
dir=dir,
errors=errors,
)
self.compression = compression
self.bytes_total = 0
self.records_total = 0
self.bytes_since_last_reset = 0
self.records_since_last_reset = 0
self._brotli_compressor = None
def __getattr__(self, name):
"""Pass get attr to underlying tempfile.NamedTemporaryFile."""
@@ -240,11 +272,37 @@ class BatchExportTemporaryFile:
"""Context-manager protocol exit method."""
return self._file.__exit__(exc, value, tb)
@property
def brotli_compressor(self):
if self._brotli_compressor is None:
self._brotli_compressor = brotli.Compressor()
return self._brotli_compressor
def compress(self, content: bytes | str) -> bytes:
if isinstance(content, str):
encoded = content.encode("utf-8")
else:
encoded = content
match self.compression:
case "gzip":
return gzip.compress(encoded)
case "brotli":
self.brotli_compressor.process(encoded)
return self.brotli_compressor.flush()
case None:
return encoded
case _:
raise ValueError(f"Unsupported compression: '{self.compression}'")
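The two branches behave differently: gzip.compress is stateless and emits a complete gzip member per call, while brotli compresses one continuous stream through the shared Compressor, which is why rewind() below has to call finish() to terminate the stream before the file is read back. A minimal standalone sketch of that brotli behaviour, using the same Compressor methods this diff relies on:
# Hedged sketch: an incrementally built brotli stream, terminated the same way rewind() does.
import brotli

compressor = brotli.Compressor()
chunks = [compressor.process(b"line 1\n"), compressor.flush()]
chunks += [compressor.process(b"line 2\n"), compressor.flush()]
chunks.append(compressor.finish())  # the extra bytes rewind() appends in this diff
assert brotli.decompress(b"".join(chunks)) == b"line 1\nline 2\n"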
def write(self, content: bytes | str):
"""Write bytes to underlying file keeping track of how many bytes were written."""
if "b" in self.mode and isinstance(content, str):
content = content.encode("utf-8")
result = self._file.write(content)
compressed_content = self.compress(content)
if "b" in self.mode:
result = self._file.write(compressed_content)
else:
result = self._file.write(compressed_content.decode("utf-8"))
self.bytes_total += result
self.bytes_since_last_reset += result
@@ -316,6 +374,18 @@ class BatchExportTemporaryFile:
quoting=quoting,
)
def rewind(self):
"""Rewind the file before reading it."""
if self.compression == "brotli":
result = self._file.write(self.brotli_compressor.finish())
self.bytes_total += result
self.bytes_since_last_reset += result
self._brotli_compressor = None
self._file.seek(0)
def reset(self):
"""Reset underlying file by truncating it.

View File

@@ -171,7 +171,7 @@ class ClickHouseClient:
ClickHouseError: If the status code is not 200.
"""
if response.status_code != 200:
error_message = response.text()
error_message = response.text
raise ClickHouseError(query, error_message)
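The one-character fix above implies the response object exposes text as an attribute or property rather than a method; calling it as response.text() would raise before the intended ClickHouseError is surfaced. A small hypothetical illustration (FakeResponse is not from the codebase):
# Hedged illustration of the bug being fixed: when .text is a plain string attribute,
# the old code raised TypeError instead of surfacing the ClickHouse error message.
class FakeResponse:
    status_code = 500
    text = "A ClickHouse error message"

response = FakeResponse()
error_message = response.text   # -> the error message string
# response.text()               # -> TypeError: 'str' object is not callable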
@contextlib.asynccontextmanager

View File

@@ -2,9 +2,8 @@ import asyncio
import datetime as dt
import json
import posixpath
import tempfile
import typing
from dataclasses import dataclass
from typing import TYPE_CHECKING, List
import boto3
from django.conf import settings
@@ -20,15 +19,13 @@ from posthog.temporal.workflows.base import (
update_export_run_status,
)
from posthog.temporal.workflows.batch_exports import (
BatchExportTemporaryFile,
get_data_interval,
get_results_iterator,
get_rows_count,
)
from posthog.temporal.workflows.clickhouse import get_client
if TYPE_CHECKING:
from mypy_boto3_s3.type_defs import CompletedPartTypeDef
def get_allowed_template_variables(inputs) -> dict[str, str]:
"""Derive from inputs a dictionary of supported template variables for the S3 key prefix."""
@@ -50,7 +47,17 @@ def get_s3_key(inputs) -> str:
"""Return an S3 key given S3InsertInputs."""
template_variables = get_allowed_template_variables(inputs)
key_prefix = inputs.prefix.format(**template_variables)
key = posixpath.join(key_prefix, f"{inputs.data_interval_start}-{inputs.data_interval_end}.jsonl")
base_file_name = f"{inputs.data_interval_start}-{inputs.data_interval_end}"
match inputs.compression:
case "gzip":
file_name = base_file_name + ".jsonl.gz"
case "brotli":
file_name = base_file_name + ".jsonl.br"
case _:
file_name = base_file_name + ".jsonl"
key = posixpath.join(key_prefix, file_name)
if posixpath.isabs(key):
# Keys are relative to root dir, so this would add an extra "/"
@@ -59,6 +66,152 @@ def get_s3_key(inputs) -> str:
return key
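Given the mapping above, a brief hedged usage sketch; the prefix and intervals are illustrative, the remaining required fields are shown with placeholder values, and the extensions match the parametrized expectations in the test file earlier in this diff:
# Hedged sketch of the resulting key for each compression setting.
inputs = S3InsertInputs(
    bucket_name="my-bucket",
    region="us-east-1",
    prefix="nested/prefix/",
    team_id=1,
    data_interval_start="2023-01-01 00:00:00",
    data_interval_end="2023-01-01 01:00:00",
    compression="brotli",
)
get_s3_key(inputs)  # -> "nested/prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl.br"
# compression="gzip" -> ".jsonl.gz"; compression=None -> ".jsonl"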
class UploadAlreadyInProgressError(Exception):
"""Exception raised when an S3MultiPartUpload is already in progress."""
def __init__(self, upload_id):
super().__init__(f"This upload is already in progress with ID: {upload_id}. Instantiate a new object.")
class NoUploadInProgressError(Exception):
"""Exception raised when there is no S3MultiPartUpload in progress."""
def __init__(self):
super().__init__("No multi-part upload is in progress. Call 'create' to start one.")
class S3MultiPartUploadState(typing.NamedTuple):
upload_id: str
parts: list[dict[str, str | int]]
class S3MultiPartUpload:
"""An S3 multi-part upload."""
def __init__(self, s3_client, bucket_name, key):
self.s3_client = s3_client
self.bucket_name = bucket_name
self.key = key
self.upload_id = None
self.parts = []
def to_state(self) -> S3MultiPartUploadState:
"""Produce state tuple that can be used to resume this S3MultiPartUpload."""
# The second predicate is trivial but required by type-checking.
if self.is_upload_in_progress() is False or self.upload_id is None:
raise NoUploadInProgressError()
return S3MultiPartUploadState(self.upload_id, self.parts)
@property
def part_number(self):
"""Return the current part number."""
return len(self.parts)
def is_upload_in_progress(self) -> bool:
"""Whether this S3MultiPartUpload is in progress or not."""
if self.upload_id is None:
return False
return True
def start(self) -> str:
"""Start this S3MultiPartUpload."""
if self.is_upload_in_progress() is True:
raise UploadAlreadyInProgressError(self.upload_id)
multipart_response = self.s3_client.create_multipart_upload(Bucket=self.bucket_name, Key=self.key)
self.upload_id = multipart_response["UploadId"]
return self.upload_id
def continue_from_state(self, state: S3MultiPartUploadState):
"""Continue this S3MultiPartUpload from a previous state."""
self.upload_id = state.upload_id
self.parts = state.parts
return self.upload_id
def complete(self) -> str:
if self.is_upload_in_progress() is False:
raise NoUploadInProgressError()
response = self.s3_client.complete_multipart_upload(
Bucket=self.bucket_name,
Key=self.key,
UploadId=self.upload_id,
MultipartUpload={"Parts": self.parts},
)
self.upload_id = None
self.parts = []
return response["Location"]
def abort(self):
if self.is_upload_in_progress() is False:
raise NoUploadInProgressError()
self.s3_client.abort_multipart_upload(
Bucket=self.bucket_name,
Key=self.key,
UploadId=self.upload_id,
)
self.upload_id = None
self.parts = []
def upload_part(self, body: BatchExportTemporaryFile, rewind: bool = True):
next_part_number = self.part_number + 1
if rewind is True:
body.rewind()
response = self.s3_client.upload_part(
Bucket=self.bucket_name,
Key=self.key,
PartNumber=next_part_number,
UploadId=self.upload_id,
Body=body,
)
self.parts.append({"PartNumber": next_part_number, "ETag": response["ETag"]})
def __enter__(self):
if not self.is_upload_in_progress():
self.start()
return self
def __exit__(self, exc_type, exc_value, traceback) -> bool:
if exc_value is None:
# Successfully completed the upload
self.complete()
return True
if exc_type == asyncio.CancelledError:
# Ensure we clean up the cancelled upload.
self.abort()
return False
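Putting the class together, a hedged usage sketch of the happy path; the bucket, key, and client construction are illustrative, and the context manager completes the upload on a clean exit while aborting it on cancellation, as __exit__ above shows:
# Hedged usage sketch of S3MultiPartUpload with a BatchExportTemporaryFile.
import boto3

s3_client = boto3.client("s3", region_name="us-east-1")  # credentials/config are illustrative
upload = S3MultiPartUpload(s3_client, "my-bucket", "exports/2023-01-01.jsonl")

with upload, BatchExportTemporaryFile() as local_file:
    local_file.write(b'{"event": "example"}\n')
    upload.upload_part(local_file)  # rewinds the file and records the part's ETag
# On normal exit, __exit__ calls complete(); an asyncio.CancelledError triggers abort().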
class HeartbeatDetails(typing.NamedTuple):
"""This tuple allows us to enforce a schema on the Heartbeat details.
Attributes:
last_uploaded_part_timestamp: The timestamp of the last part we managed to upload.
upload_state: State to continue a S3MultiPartUpload when activity execution resumes.
"""
last_uploaded_part_timestamp: str
upload_state: S3MultiPartUploadState
@classmethod
def from_activity_details(cls, details):
last_uploaded_part_timestamp = details[0]
upload_state = S3MultiPartUploadState(*details[1])
return HeartbeatDetails(last_uploaded_part_timestamp, upload_state)
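The activity heartbeats a (timestamp, upload state) pair and from_activity_details rebuilds it on resume; the positional reconstruction above implies the details come back as a plain sequence. A hedged sketch of the round trip, with illustrative values, valid only inside a running Temporal activity:
# Hedged sketch of the heartbeat round trip, mirroring the calls later in this file.
state = S3MultiPartUploadState(
    upload_id="hypothetical-upload-id",
    parts=[{"PartNumber": 1, "ETag": "etag-1"}],
)
activity.heartbeat("2023-04-25 13:30:00.000000", state)

# ...on a retried attempt, the details come back through activity.info():
details = activity.info().heartbeat_details
resumed = HeartbeatDetails.from_activity_details(details)
# resumed.last_uploaded_part_timestamp -> "2023-04-25 13:30:00.000000"
# resumed.upload_state.parts           -> [{"PartNumber": 1, "ETag": "etag-1"}]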
@dataclass
class S3InsertInputs:
"""Inputs for S3 exports."""
@@ -75,6 +228,55 @@ class S3InsertInputs:
data_interval_end: str
aws_access_key_id: str | None = None
aws_secret_access_key: str | None = None
compression: str | None = None
exclude_events: list[str] | None = None
def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tuple[S3MultiPartUpload, str]:
"""Initialize a S3MultiPartUpload and resume it from a hearbeat state if available."""
key = get_s3_key(inputs)
s3_client = boto3.client(
"s3",
region_name=inputs.region,
aws_access_key_id=inputs.aws_access_key_id,
aws_secret_access_key=inputs.aws_secret_access_key,
)
s3_upload = S3MultiPartUpload(s3_client, inputs.bucket_name, key)
details = activity.info().heartbeat_details
try:
interval_start, upload_state = HeartbeatDetails.from_activity_details(details)
except IndexError:
# This is the error we expect when there are no details, as the sequence will be empty.
interval_start = inputs.data_interval_start
activity.logger.info(
f"Did not receive details from previous activity Excecution. Export will start from the beginning: {interval_start}"
)
except Exception as e:
# We still start from the beginning, but we make a point to log unexpected errors.
# Ideally, any new exceptions should be added to the previous block after they first appear, so we never land here again.
interval_start = inputs.data_interval_start
activity.logger.warning(
f"Did not receive details from previous activity Excecution due to an unexpected error. Export will start from the beginning: {interval_start}",
exc_info=e,
)
else:
activity.logger.info(
f"Received details from previous activity. Export will attempt to resume from: {interval_start}"
)
s3_upload.continue_from_state(upload_state)
if inputs.compression == "brotli":
# Even if we receive details we cannot resume a brotli compressed upload as we have lost the compressor state.
interval_start = inputs.data_interval_start
activity.logger.info(
f"Export will start from the beginning as we are using brotli compression: {interval_start}"
)
s3_upload.abort()
return s3_upload, interval_start
@activity.defn
@@ -111,28 +313,7 @@ async def insert_into_s3_activity(inputs: S3InsertInputs):
activity.logger.info("BatchExporting %s rows to S3", count)
# Create a multipart upload to S3
key = get_s3_key(inputs)
s3_client = boto3.client(
"s3",
region_name=inputs.region,
aws_access_key_id=inputs.aws_access_key_id,
aws_secret_access_key=inputs.aws_secret_access_key,
)
details = activity.info().heartbeat_details
parts: List[CompletedPartTypeDef] = []
if len(details) == 4:
interval_start, upload_id, parts, part_number = details
activity.logger.info(f"Received details from previous activity. Export will resume from {interval_start}")
else:
multipart_response = s3_client.create_multipart_upload(Bucket=inputs.bucket_name, Key=key)
upload_id = multipart_response["UploadId"]
interval_start = inputs.data_interval_start
part_number = 1
s3_upload, interval_start = initialize_and_resume_multipart_upload(inputs)
# Iterate through chunks of results from ClickHouse and push them to S3
# as a multipart upload. The intention here is to keep memory usage low,
@@ -145,6 +326,7 @@ async def insert_into_s3_activity(inputs: S3InsertInputs):
team_id=inputs.team_id,
interval_start=interval_start,
interval_end=inputs.data_interval_end,
exclude_events=inputs.exclude_events,
)
result = None
@@ -156,46 +338,14 @@ async def insert_into_s3_activity(inputs: S3InsertInputs):
activity.logger.warn(
f"Worker shutting down! Reporting back latest exported part {last_uploaded_part_timestamp}"
)
activity.heartbeat(last_uploaded_part_timestamp, upload_id)
activity.heartbeat(last_uploaded_part_timestamp, s3_upload.to_state())
asyncio.create_task(worker_shutdown_handler())
with tempfile.NamedTemporaryFile() as local_results_file:
while True:
try:
result = results_iterator.__next__()
except StopIteration:
break
except json.JSONDecodeError:
# This is raised by aiochclient as we try to decode an error message from ClickHouse.
# So far, this error message only indicated that we were too slow consuming rows.
# So, we can resume from the last result.
if result is None:
# We failed right at the beginning
new_interval_start = None
else:
new_interval_start = result.get("inserted_at", None)
if not isinstance(new_interval_start, str):
new_interval_start = inputs.data_interval_start
activity.logger.warn(
f"Failed to decode a JSON value while iterating, potentially due to a ClickHouse error. Resuming from {new_interval_start}"
)
results_iterator = get_results_iterator(
client=client,
team_id=inputs.team_id,
interval_start=new_interval_start, # This means we'll generate at least one duplicate.
interval_end=inputs.data_interval_end,
)
continue
if not result:
break
content = json.dumps(
{
with s3_upload as s3_upload:
with BatchExportTemporaryFile(compression=inputs.compression) as local_results_file:
for result in results_iterator:
record = {
"created_at": result["created_at"],
"distinct_id": result["distinct_id"],
"elements_chain": result["elements_chain"],
@@ -207,60 +357,36 @@ async def insert_into_s3_activity(inputs: S3InsertInputs):
"timestamp": result["timestamp"],
"uuid": result["uuid"],
}
)
# Write the results to a local file
local_results_file.write(content.encode("utf-8"))
local_results_file.write("\n".encode("utf-8"))
local_results_file.write_records_to_jsonl([record])
# Write results to S3 when the file reaches 50MB and reset the
# file, or if there is nothing else to write.
if (
local_results_file.tell()
and local_results_file.tell() > settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES
):
activity.logger.info("Uploading part %s", part_number)
if local_results_file.tell() > settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES:
activity.logger.info(
"Uploading part %s containing %s records with size %s bytes to S3",
s3_upload.part_number + 1,
local_results_file.records_since_last_reset,
local_results_file.bytes_since_last_reset,
)
local_results_file.seek(0)
response = s3_client.upload_part(
Bucket=inputs.bucket_name,
Key=key,
PartNumber=part_number,
UploadId=upload_id,
Body=local_results_file,
s3_upload.upload_part(local_results_file)
last_uploaded_part_timestamp = result["inserted_at"]
activity.heartbeat(last_uploaded_part_timestamp, s3_upload.to_state())
local_results_file.reset()
if local_results_file.tell() > 0 and result is not None:
activity.logger.info(
"Uploading last part %s containing %s records with size %s bytes to S3",
s3_upload.part_number + 1,
local_results_file.records_since_last_reset,
local_results_file.bytes_since_last_reset,
)
s3_upload.upload_part(local_results_file)
last_uploaded_part_timestamp = result["inserted_at"]
# Record the ETag for the part
parts.append({"PartNumber": part_number, "ETag": response["ETag"]})
part_number += 1
activity.heartbeat(last_uploaded_part_timestamp, upload_id, parts, part_number)
# Reset the file
local_results_file.seek(0)
local_results_file.truncate()
# Upload the last part
local_results_file.seek(0)
response = s3_client.upload_part(
Bucket=inputs.bucket_name,
Key=key,
PartNumber=part_number,
UploadId=upload_id,
Body=local_results_file,
)
activity.heartbeat(last_uploaded_part_timestamp, upload_id, parts, part_number)
# Record the ETag for the last part
parts.append({"PartNumber": part_number, "ETag": response["ETag"]})
# Complete the multipart upload
s3_client.complete_multipart_upload(
Bucket=inputs.bucket_name,
Key=key,
UploadId=upload_id,
MultipartUpload={"Parts": parts},
)
activity.heartbeat(last_uploaded_part_timestamp, s3_upload.to_state())
@workflow.defn(name="s3-export")
@@ -314,6 +440,8 @@ class S3BatchExportWorkflow(PostHogWorkflow):
aws_secret_access_key=inputs.aws_secret_access_key,
data_interval_start=data_interval_start.isoformat(),
data_interval_end=data_interval_end.isoformat(),
compression=inputs.compression,
exclude_events=inputs.exclude_events,
)
try:
await workflow.execute_activity(
@@ -334,6 +462,7 @@ class S3BatchExportWorkflow(PostHogWorkflow):
except Exception as e:
workflow.logger.exception("S3 BatchExport failed.", exc_info=e)
update_inputs.status = "Failed"
update_inputs.latest_error = str(e)
raise
finally:

View File

@@ -32,8 +32,8 @@ class TestGzipMiddleware(APIBaseTest):
def test_no_compression_for_unsuccessful_requests_to_paths_on_the_allow_list(self) -> None:
with self.settings(GZIP_RESPONSE_ALLOW_LIST=["something-else", "snapshots$"]):
response = self._get_path("/api/projects/12/session_recordings/blah/snapshots")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
response = self._get_path(f"/api/projects/{self.team.pk}/session_recordings/blah/snapshots")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND, msg=response.content.decode("utf-8"))
contentEncoding = response.headers.get("Content-Encoding", None)
self.assertEqual(contentEncoding, None)

View File

@@ -9,6 +9,7 @@ antlr4-python3-runtime==4.13.0
amqp==2.6.0
boto3==1.26.66
boto3-stubs[s3]
brotli==1.0.9
celery==4.4.7
celery-redbeat==2.0.0
clickhouse-driver==0.2.4

View File

@@ -51,6 +51,8 @@ botocore==1.29.66
# s3transfer
botocore-stubs==1.29.130
# via boto3-stubs
brotli==1.0.9
# via -r requirements.in
celery==4.4.7
# via
# -r requirements.in