feat: S3 BatchExports missing features (#17233)
Committed by GitHub
parent 8a835df63b
commit 17c5ec1710

.github/workflows/ci-backend.yml (vendored): 2
@@ -117,7 +117,7 @@ jobs:
       - name: Install SAML (python3-saml) dependencies
         run: |
           sudo apt-get update
-          sudo apt-get install libxml2-dev libxmlsec1-dev libxmlsec1-openssl
+          sudo apt-get install libxml2-dev libxmlsec1 libxmlsec1-dev libxmlsec1-openssl

       - name: Install python dependencies
         if: steps.cache-backend-tests.outputs.cache-hit != 'true'
@@ -50,6 +50,8 @@ class S3BatchExportInputs:
     aws_access_key_id: str | None = None
     aws_secret_access_key: str | None = None
     data_interval_end: str | None = None
+    compression: str | None = None
+    exclude_events: list[str] | None = None


 @dataclass
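The two new fields, compression and exclude_events, are threaded from these workflow inputs down to the S3 insert activity. A minimal sketch of how a caller might populate them; the values are illustrative, and the other fields shown are the ones that appear elsewhere in this diff (team_id, batch_export_id, bucket_name, region, prefix):

```python
# Hedged sketch: values are placeholders, remaining fields keep their defaults.
from uuid import uuid4

from posthog.temporal.workflows.s3_batch_export import S3BatchExportInputs

inputs = S3BatchExportInputs(
    team_id=1,
    batch_export_id=str(uuid4()),
    bucket_name="my-bucket",
    region="us-east-1",
    prefix="posthog-events",
    compression="brotli",                     # None, "gzip" or "brotli"
    exclude_events=["$feature_flag_called"],  # event names the export should skip
)
```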
@@ -193,6 +193,53 @@ async def test_get_rows_count_handles_duplicates(client):
|
||||
assert row_count == 10000
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_rows_count_can_exclude_events(client):
|
||||
"""Test the count of rows returned by get_rows_count can exclude events."""
|
||||
team_id = randint(1, 1000000)
|
||||
|
||||
events: list[EventValues] = [
|
||||
{
|
||||
"uuid": str(uuid4()),
|
||||
"event": f"test-{i}",
|
||||
"_timestamp": "2023-04-20 14:30:00",
|
||||
"timestamp": f"2023-04-20 14:30:00.{i:06d}",
|
||||
"inserted_at": f"2023-04-20 14:30:00.{i:06d}",
|
||||
"created_at": "2023-04-20 14:30:00.000000",
|
||||
"distinct_id": str(uuid4()),
|
||||
"person_id": str(uuid4()),
|
||||
"person_properties": {"$browser": "Chrome", "$os": "Mac OS X"},
|
||||
"team_id": team_id,
|
||||
"properties": {
|
||||
"$browser": "Chrome",
|
||||
"$os": "Mac OS X",
|
||||
"$ip": "127.0.0.1",
|
||||
"$current_url": "http://localhost.com",
|
||||
},
|
||||
"elements_chain": "this that and the other",
|
||||
"elements": json.dumps("this that and the other"),
|
||||
"ip": "127.0.0.1",
|
||||
"site_url": "http://localhost.com",
|
||||
"set": None,
|
||||
"set_once": None,
|
||||
}
|
||||
for i in range(10000)
|
||||
]
|
||||
# Duplicate everything
|
||||
duplicate_events = events * 2
|
||||
|
||||
await insert_events(
|
||||
ch_client=client,
|
||||
events=duplicate_events,
|
||||
)
|
||||
|
||||
# Exclude the latter half of events.
|
||||
exclude_events = (f"test-{i}" for i in range(5000, 10000))
|
||||
row_count = await get_rows_count(client, team_id, "2023-04-20 14:30:00", "2023-04-20 14:31:00", exclude_events)
|
||||
assert row_count == 5000
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_results_iterator(client):
|
||||
@@ -300,6 +347,63 @@ async def test_get_results_iterator_handles_duplicates(client):
|
||||
assert value == expected[key], f"{key} value in {result} didn't match value in {expected}"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_results_iterator_can_exclude_events(client):
|
||||
"""Test the rows returned by get_results_iterator can exclude events."""
|
||||
team_id = randint(1, 1000000)
|
||||
|
||||
events: list[EventValues] = [
|
||||
{
|
||||
"uuid": str(uuid4()),
|
||||
"event": f"test-{i}",
|
||||
"_timestamp": "2023-04-20 14:30:00",
|
||||
"timestamp": f"2023-04-20 14:30:00.{i:06d}",
|
||||
"inserted_at": f"2023-04-20 14:30:00.{i:06d}",
|
||||
"created_at": "2023-04-20 14:30:00.000000",
|
||||
"distinct_id": str(uuid4()),
|
||||
"person_id": str(uuid4()),
|
||||
"person_properties": {"$browser": "Chrome", "$os": "Mac OS X"},
|
||||
"team_id": team_id,
|
||||
"properties": {
|
||||
"$browser": "Chrome",
|
||||
"$os": "Mac OS X",
|
||||
"$ip": "127.0.0.1",
|
||||
"$current_url": "http://localhost.com",
|
||||
},
|
||||
"elements_chain": "this that and the other",
|
||||
"elements": json.dumps("this that and the other"),
|
||||
"ip": "127.0.0.1",
|
||||
"site_url": "http://localhost.com",
|
||||
"set": None,
|
||||
"set_once": None,
|
||||
}
|
||||
for i in range(10000)
|
||||
]
|
||||
duplicate_events = events * 2
|
||||
|
||||
await insert_events(
|
||||
ch_client=client,
|
||||
events=duplicate_events,
|
||||
)
|
||||
|
||||
# Exclude the latter half of events.
|
||||
exclude_events = (f"test-{i}" for i in range(5000, 10000))
|
||||
iter_ = get_results_iterator(client, team_id, "2023-04-20 14:30:00", "2023-04-20 14:31:00", exclude_events)
|
||||
rows = [row for row in iter_]
|
||||
|
||||
all_expected = sorted(events[:5000], key=operator.itemgetter("event"))
|
||||
all_result = sorted(rows, key=operator.itemgetter("event"))
|
||||
|
||||
assert len(all_expected) == len(all_result)
|
||||
assert len([row["uuid"] for row in all_result]) == len(set(row["uuid"] for row in all_result))
|
||||
|
||||
for expected, result in zip(all_expected, all_result):
|
||||
for key, value in result.items():
|
||||
# Some keys will be missing from result, so let's only check the ones we have.
|
||||
assert value == expected[key], f"{key} value in {result} didn't match value in {expected}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"interval,data_interval_end,expected",
|
||||
[
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import datetime as dt
|
||||
import functools
|
||||
import gzip
|
||||
import itertools
|
||||
import json
|
||||
from random import randint
|
||||
from unittest import mock
|
||||
from uuid import uuid4
|
||||
|
||||
import boto3
|
||||
import brotli
|
||||
import pytest
|
||||
from django.conf import settings
|
||||
from django.test import Client as HttpClient
|
||||
@@ -26,7 +29,6 @@ from posthog.temporal.tests.batch_exports.fixtures import (
|
||||
afetch_batch_export_runs,
|
||||
)
|
||||
from posthog.temporal.workflows.base import create_export_run, update_export_run_status
|
||||
from posthog.temporal.workflows.batch_exports import get_results_iterator
|
||||
from posthog.temporal.workflows.clickhouse import ClickHouseClient
|
||||
from posthog.temporal.workflows.s3_batch_export import (
|
||||
S3BatchExportInputs,
|
||||
@@ -74,7 +76,9 @@ def s3_client(bucket_name):
|
||||
s3_client.delete_bucket(Bucket=bucket_name)
|
||||
|
||||
|
||||
def assert_events_in_s3(s3_client, bucket_name, key_prefix, events):
|
||||
def assert_events_in_s3(
|
||||
s3_client, bucket_name, key_prefix, events, compression: str | None = None, exclude_events: list[str] | None = None
|
||||
):
|
||||
"""Assert provided events written to JSON in key_prefix in S3 bucket_name."""
|
||||
# List the objects in the bucket with the prefix.
|
||||
objects = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=key_prefix)
|
||||
@@ -89,13 +93,28 @@ def assert_events_in_s3(s3_client, bucket_name, key_prefix, events):
|
||||
data = object["Body"].read()
|
||||
|
||||
# Check that the data is correct.
|
||||
match compression:
|
||||
case "gzip":
|
||||
data = gzip.decompress(data)
|
||||
case "brotli":
|
||||
data = brotli.decompress(data)
|
||||
case _:
|
||||
pass
|
||||
|
||||
json_data = [json.loads(line) for line in data.decode("utf-8").split("\n") if line]
|
||||
# Pull out the fields we inserted only
|
||||
|
||||
json_data.sort(key=lambda x: x["timestamp"])
|
||||
|
||||
# Remove team_id, _timestamp from events
|
||||
expected_events = [{k: v for k, v in event.items() if k not in ["team_id", "_timestamp"]} for event in events]
|
||||
if exclude_events is None:
|
||||
exclude_events = []
|
||||
|
||||
expected_events = [
|
||||
{k: v for k, v in event.items() if k not in ["team_id", "_timestamp"]}
|
||||
for event in events
|
||||
if event["event"] not in exclude_events
|
||||
]
|
||||
expected_events.sort(key=lambda x: x["timestamp"])
|
||||
|
||||
# First check one event, the first one, so that we can get a nice diff if
|
||||
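The helper above dispatches on the compression setting before parsing newline-delimited JSON. A standalone sketch of that read path, using only the standard library plus the brotli package the tests already import:

```python
import gzip
import json

import brotli


def read_jsonl(data: bytes, compression: str | None = None) -> list[dict]:
    """Decompress (if needed) and parse newline-delimited JSON."""
    match compression:
        case "gzip":
            data = gzip.decompress(data)
        case "brotli":
            data = brotli.decompress(data)
        case _:
            pass
    return [json.loads(line) for line in data.decode("utf-8").split("\n") if line]


payload = b'{"event": "test-0"}\n{"event": "test-1"}\n'
assert read_jsonl(gzip.compress(payload), "gzip") == read_jsonl(payload)
```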
@@ -106,7 +125,13 @@ def assert_events_in_s3(s3_client, bucket_name, key_prefix, events):
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.asyncio
|
||||
async def test_insert_into_s3_activity_puts_data_into_s3(bucket_name, s3_client, activity_environment):
|
||||
@pytest.mark.parametrize(
|
||||
"compression,exclude_events",
|
||||
itertools.product([None, "gzip", "brotli"], [None, ["test-exclude"]]),
|
||||
)
|
||||
async def test_insert_into_s3_activity_puts_data_into_s3(
|
||||
bucket_name, s3_client, activity_environment, compression, exclude_events
|
||||
):
|
||||
"""Test that the insert_into_s3_activity function puts data into S3."""
|
||||
|
||||
data_interval_start = "2023-04-20 14:00:00"
|
||||
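The parametrization above crosses the compression modes with the exclude list via itertools.product, so this single test body runs once per combination. For reference, the product expands to six cases:

```python
import itertools

params = list(itertools.product([None, "gzip", "brotli"], [None, ["test-exclude"]]))
assert len(params) == 6
assert params[0] == (None, None)
assert params[-1] == ("brotli", ["test-exclude"])
```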
@@ -155,7 +180,7 @@ async def test_insert_into_s3_activity_puts_data_into_s3(bucket_name, s3_client,
|
||||
EventValues(
|
||||
{
|
||||
"uuid": str(uuid4()),
|
||||
"event": "test",
|
||||
"event": "test-exclude",
|
||||
"_timestamp": "2023-04-20 14:29:00",
|
||||
"timestamp": "2023-04-20 14:29:00.000000",
|
||||
"inserted_at": "2023-04-20 14:30:00.000000",
|
||||
@@ -241,6 +266,8 @@ async def test_insert_into_s3_activity_puts_data_into_s3(bucket_name, s3_client,
|
||||
data_interval_end=data_interval_end,
|
||||
aws_access_key_id="object_storage_root_user",
|
||||
aws_secret_access_key="object_storage_root_password",
|
||||
compression=compression,
|
||||
exclude_events=exclude_events,
|
||||
)
|
||||
|
||||
with override_settings(
|
||||
@@ -249,13 +276,18 @@ async def test_insert_into_s3_activity_puts_data_into_s3(bucket_name, s3_client,
|
||||
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
|
||||
await activity_environment.run(insert_into_s3_activity, insert_inputs)
|
||||
|
||||
assert_events_in_s3(s3_client, bucket_name, prefix, events)
|
||||
assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("interval", ["hour", "day"])
|
||||
async def test_s3_export_workflow_with_minio_bucket(client: HttpClient, s3_client, bucket_name, interval):
|
||||
@pytest.mark.parametrize(
|
||||
"interval,compression,exclude_events",
|
||||
itertools.product(["hour", "day"], [None, "gzip", "brotli"], [None, ["test-exclude"]]),
|
||||
)
|
||||
async def test_s3_export_workflow_with_minio_bucket(
|
||||
client: HttpClient, s3_client, bucket_name, interval, compression, exclude_events
|
||||
):
|
||||
"""Test S3 Export Workflow end-to-end by using a local MinIO bucket instead of S3.
|
||||
|
||||
The workflow should update the batch export run status to completed and produce the expected
|
||||
@@ -270,6 +302,8 @@ async def test_s3_export_workflow_with_minio_bucket(client: HttpClient, s3_clien
|
||||
"prefix": prefix,
|
||||
"aws_access_key_id": "object_storage_root_user",
|
||||
"aws_secret_access_key": "object_storage_root_password",
|
||||
"compression": compression,
|
||||
"exclude_events": exclude_events,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -305,7 +339,7 @@ async def test_s3_export_workflow_with_minio_bucket(client: HttpClient, s3_clien
|
||||
},
|
||||
{
|
||||
"uuid": str(uuid4()),
|
||||
"event": "test",
|
||||
"event": "test-exclude",
|
||||
"timestamp": "2023-04-25 14:29:00.000000",
|
||||
"created_at": "2023-04-25 14:29:00.000000",
|
||||
"inserted_at": "2023-04-25 14:29:00.000000",
|
||||
@@ -385,12 +419,15 @@ async def test_s3_export_workflow_with_minio_bucket(client: HttpClient, s3_clien
|
||||
run = runs[0]
|
||||
assert run.status == "Completed"
|
||||
|
||||
assert_events_in_s3(s3_client, bucket_name, prefix, events)
|
||||
assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.asyncio
|
||||
async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(client: HttpClient, s3_client, bucket_name):
|
||||
@pytest.mark.parametrize("compression", [None, "gzip"])
|
||||
async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(
|
||||
client: HttpClient, s3_client, bucket_name, compression
|
||||
):
|
||||
"""Test the full S3 workflow targetting a MinIO bucket.
|
||||
|
||||
The workflow should update the batch export run status to completed and produce the expected
|
||||
@@ -412,6 +449,7 @@ async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(client: Ht
|
||||
"prefix": prefix,
|
||||
"aws_access_key_id": "object_storage_root_user",
|
||||
"aws_secret_access_key": "object_storage_root_password",
|
||||
"compression": compression,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -477,7 +515,7 @@ async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(client: Ht
|
||||
id=workflow_id,
|
||||
task_queue=settings.TEMPORAL_TASK_QUEUE,
|
||||
retry_policy=RetryPolicy(maximum_attempts=1),
|
||||
execution_timeout=dt.timedelta(seconds=180),
|
||||
execution_timeout=dt.timedelta(seconds=360),
|
||||
)
|
||||
|
||||
runs = await afetch_batch_export_runs(batch_export_id=batch_export.id)
|
||||
@@ -486,12 +524,15 @@ async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(client: Ht
|
||||
run = runs[0]
|
||||
assert run.status == "Completed"
|
||||
|
||||
assert_events_in_s3(s3_client, bucket_name, prefix.format(year=2023, month="04", day="25"), events)
|
||||
assert_events_in_s3(s3_client, bucket_name, prefix.format(year=2023, month="04", day="25"), events, compression)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.asyncio
|
||||
async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(client: HttpClient, s3_client, bucket_name):
|
||||
@pytest.mark.parametrize("compression", [None, "gzip", "brotli"])
|
||||
async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(
|
||||
client: HttpClient, s3_client, bucket_name, compression
|
||||
):
|
||||
"""Test the full S3 workflow targetting a MinIO bucket.
|
||||
|
||||
In this scenario we assert that when inserted_at is NULL, we default to _timestamp.
|
||||
@@ -513,6 +554,7 @@ async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(clie
|
||||
"prefix": prefix,
|
||||
"aws_access_key_id": "object_storage_root_user",
|
||||
"aws_secret_access_key": "object_storage_root_password",
|
||||
"compression": compression,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -600,12 +642,15 @@ async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(clie
|
||||
run = runs[0]
|
||||
assert run.status == "Completed"
|
||||
|
||||
assert_events_in_s3(s3_client, bucket_name, prefix, events)
|
||||
assert_events_in_s3(s3_client, bucket_name, prefix, events, compression)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.asyncio
|
||||
async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(client: HttpClient, s3_client, bucket_name):
|
||||
@pytest.mark.parametrize("compression", [None, "gzip", "brotli"])
|
||||
async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(
|
||||
client: HttpClient, s3_client, bucket_name, compression
|
||||
):
|
||||
"""Test the S3BatchExport Workflow utilizing a custom key prefix.
|
||||
|
||||
We will be asserting that exported events land in the appropriate S3 key according to the prefix.
|
||||
@@ -626,6 +671,7 @@ async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(client
|
||||
"prefix": prefix,
|
||||
"aws_access_key_id": "object_storage_root_user",
|
||||
"aws_secret_access_key": "object_storage_root_password",
|
||||
"compression": compression,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -707,324 +753,15 @@ async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(client
|
||||
assert len(objects.get("Contents", [])) == 1
|
||||
assert key.startswith(expected_key_prefix)
|
||||
|
||||
assert_events_in_s3(s3_client, bucket_name, expected_key_prefix, events)
|
||||
assert_events_in_s3(s3_client, bucket_name, expected_key_prefix, events, compression)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.asyncio
|
||||
async def test_s3_export_workflow_continues_on_json_decode_error(client: HttpClient, s3_client, bucket_name):
|
||||
"""Test that S3 Export Workflow end-to-end by using a local MinIO bucket instead of S3.
|
||||
|
||||
In this particular case, we should be handling JSONDecodeErrors produced by attempting to parse
|
||||
ClickHouse error strings.
|
||||
"""
|
||||
ch_client = ClickHouseClient(
|
||||
url=settings.CLICKHOUSE_HTTP_URL,
|
||||
user=settings.CLICKHOUSE_USER,
|
||||
password=settings.CLICKHOUSE_PASSWORD,
|
||||
database=settings.CLICKHOUSE_DATABASE,
|
||||
)
|
||||
|
||||
prefix = f"posthog-events-{str(uuid4())}"
|
||||
destination_data = {
|
||||
"type": "S3",
|
||||
"config": {
|
||||
"bucket_name": bucket_name,
|
||||
"region": "us-east-1",
|
||||
"prefix": prefix,
|
||||
"aws_access_key_id": "object_storage_root_user",
|
||||
"aws_secret_access_key": "object_storage_root_password",
|
||||
},
|
||||
}
|
||||
|
||||
batch_export_data = {
|
||||
"name": "my-production-s3-bucket-destination",
|
||||
"destination": destination_data,
|
||||
"interval": "hour",
|
||||
}
|
||||
|
||||
organization = await acreate_organization("test")
|
||||
team = await acreate_team(organization=organization)
|
||||
batch_export = await acreate_batch_export(
|
||||
team_id=team.pk,
|
||||
name=batch_export_data["name"],
|
||||
destination_data=batch_export_data["destination"],
|
||||
interval=batch_export_data["interval"],
|
||||
)
|
||||
|
||||
events: list[EventValues] = [
|
||||
{
|
||||
"uuid": str(uuid4()),
|
||||
"event": "test",
|
||||
"timestamp": "2023-04-25 13:30:00.000000",
|
||||
"created_at": "2023-04-25 13:30:00.000000",
|
||||
"inserted_at": "2023-04-25 13:30:00.000000",
|
||||
"_timestamp": "2023-04-25 13:30:00",
|
||||
"person_id": str(uuid4()),
|
||||
"person_properties": {"$browser": "Chrome", "$os": "Mac OS X"},
|
||||
"team_id": team.pk,
|
||||
"properties": {"$browser": "Chrome", "$os": "Mac OS X"},
|
||||
"distinct_id": str(uuid4()),
|
||||
"elements_chain": "this is a comman, separated, list, of css selectors(?)",
|
||||
},
|
||||
{
|
||||
"uuid": str(uuid4()),
|
||||
"event": "test",
|
||||
"timestamp": "2023-04-25 14:29:00.000000",
|
||||
"inserted_at": "2023-04-25 14:29:00.000000",
|
||||
"created_at": "2023-04-25 14:29:00.000000",
|
||||
"_timestamp": "2023-04-25 14:29:00",
|
||||
"person_id": str(uuid4()),
|
||||
"person_properties": {"$browser": "Chrome", "$os": "Mac OS X"},
|
||||
"team_id": team.pk,
|
||||
"properties": {"$browser": "Chrome", "$os": "Mac OS X"},
|
||||
"distinct_id": str(uuid4()),
|
||||
"elements_chain": "this is a comman, separated, list, of css selectors(?)",
|
||||
},
|
||||
]
|
||||
|
||||
# Insert some data into the `sharded_events` table.
|
||||
await insert_events(
|
||||
client=ch_client,
|
||||
events=events,
|
||||
)
|
||||
|
||||
workflow_id = str(uuid4())
|
||||
inputs = S3BatchExportInputs(
|
||||
team_id=team.pk,
|
||||
batch_export_id=str(batch_export.id),
|
||||
data_interval_end="2023-04-25 14:30:00.000000",
|
||||
**batch_export.destination.config,
|
||||
)
|
||||
error_raised = False
|
||||
|
||||
def fake_get_results_iterator(*args, **kwargs):
|
||||
nonlocal error_raised
|
||||
|
||||
for result in get_results_iterator(*args, **kwargs):
|
||||
if error_raised is False:
|
||||
error_raised = True
|
||||
raise json.JSONDecodeError("Test error", "A ClickHouse error message\n", 0)
|
||||
|
||||
yield result
|
||||
|
||||
async with await WorkflowEnvironment.start_time_skipping() as activity_environment:
|
||||
async with Worker(
|
||||
activity_environment.client,
|
||||
task_queue=settings.TEMPORAL_TASK_QUEUE,
|
||||
workflows=[S3BatchExportWorkflow],
|
||||
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
|
||||
workflow_runner=UnsandboxedWorkflowRunner(),
|
||||
):
|
||||
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
|
||||
with mock.patch(
|
||||
"posthog.temporal.workflows.s3_batch_export.get_results_iterator",
|
||||
side_effect=fake_get_results_iterator,
|
||||
) as mocked_iterator:
|
||||
await activity_environment.client.execute_workflow(
|
||||
S3BatchExportWorkflow.run,
|
||||
inputs,
|
||||
id=workflow_id,
|
||||
task_queue=settings.TEMPORAL_TASK_QUEUE,
|
||||
retry_policy=RetryPolicy(maximum_attempts=1),
|
||||
execution_timeout=dt.timedelta(seconds=10),
|
||||
)
|
||||
|
||||
assert mocked_iterator.call_count == 2 # An extra call for the error
|
||||
mocked_iterator.assert_has_calls(
|
||||
[
|
||||
mock.call(
|
||||
client=mock.ANY,
|
||||
interval_start="2023-04-25T13:30:00",
|
||||
interval_end="2023-04-25T14:30:00",
|
||||
team_id=team.pk,
|
||||
),
|
||||
mock.call(
|
||||
client=mock.ANY,
|
||||
interval_start="2023-04-25T13:30:00",
|
||||
interval_end="2023-04-25T14:30:00",
|
||||
team_id=team.pk,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert error_raised is True
|
||||
|
||||
runs = await afetch_batch_export_runs(batch_export_id=batch_export.id)
|
||||
assert len(runs) == 1
|
||||
|
||||
run = runs[0]
|
||||
assert run.status == "Completed"
|
||||
|
||||
assert_events_in_s3(s3_client, bucket_name, prefix, events)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.asyncio
|
||||
async def test_s3_export_workflow_continues_on_multiple_json_decode_error(client: HttpClient, s3_client, bucket_name):
|
||||
"""Test that S3 Export Workflow end-to-end by using a local MinIO bucket instead of S3.
|
||||
|
||||
In this particular case, we should be handling JSONDecodeErrors produced by attempting to parse
|
||||
ClickHouse error strings.
|
||||
"""
|
||||
ch_client = ClickHouseClient(
|
||||
url=settings.CLICKHOUSE_HTTP_URL,
|
||||
user=settings.CLICKHOUSE_USER,
|
||||
password=settings.CLICKHOUSE_PASSWORD,
|
||||
database=settings.CLICKHOUSE_DATABASE,
|
||||
)
|
||||
|
||||
prefix = f"posthog-events-{str(uuid4())}"
|
||||
destination_data = {
|
||||
"type": "S3",
|
||||
"config": {
|
||||
"bucket_name": bucket_name,
|
||||
"region": "us-east-1",
|
||||
"prefix": prefix,
|
||||
"aws_access_key_id": "object_storage_root_user",
|
||||
"aws_secret_access_key": "object_storage_root_password",
|
||||
},
|
||||
}
|
||||
|
||||
batch_export_data = {
|
||||
"name": "my-production-s3-bucket-destination",
|
||||
"destination": destination_data,
|
||||
"interval": "hour",
|
||||
}
|
||||
|
||||
organization = await acreate_organization("test")
|
||||
team = await acreate_team(organization=organization)
|
||||
batch_export = await acreate_batch_export(
|
||||
team_id=team.pk,
|
||||
name=batch_export_data["name"],
|
||||
destination_data=batch_export_data["destination"],
|
||||
interval=batch_export_data["interval"],
|
||||
)
|
||||
|
||||
# Produce a list of 10 events.
|
||||
events: list[EventValues] = [
|
||||
{
|
||||
"uuid": str(uuid4()),
|
||||
"event": str(i),
|
||||
"timestamp": f"2023-04-25 13:3{i}:00.000000",
|
||||
"inserted_at": f"2023-04-25 13:3{i}:00.000000",
|
||||
"created_at": f"2023-04-25 13:3{i}:00.000000",
|
||||
"_timestamp": f"2023-04-25 13:3{i}:00",
|
||||
"person_id": str(uuid4()),
|
||||
"person_properties": {"$browser": "Chrome", "$os": "Mac OS X"},
|
||||
"team_id": team.pk,
|
||||
"properties": {"$browser": "Chrome", "$os": "Mac OS X"},
|
||||
"distinct_id": str(uuid4()),
|
||||
"elements_chain": "this is a comman, separated, list, of css selectors(?)",
|
||||
}
|
||||
for i in range(10)
|
||||
]
|
||||
|
||||
# Insert some data into the `sharded_events` table.
|
||||
await insert_events(
|
||||
client=ch_client,
|
||||
events=events,
|
||||
)
|
||||
|
||||
workflow_id = str(uuid4())
|
||||
inputs = S3BatchExportInputs(
|
||||
team_id=team.pk,
|
||||
batch_export_id=str(batch_export.id),
|
||||
data_interval_end="2023-04-25 14:30:00.000000",
|
||||
**batch_export.destination.config,
|
||||
)
|
||||
|
||||
failed_events = set()
|
||||
|
||||
def should_fail(event):
|
||||
return bool(int(event["event"]) % 2)
|
||||
|
||||
def fake_get_results_iterator(*args, **kwargs):
|
||||
for result in get_results_iterator(*args, **kwargs):
|
||||
if result["event"] not in failed_events and should_fail(result):
|
||||
# Will raise an exception every other row.
|
||||
failed_events.add(result["event"]) # Otherwise we infinite loop
|
||||
raise json.JSONDecodeError("Test error", "A ClickHouse error message\n", 0)
|
||||
|
||||
yield result
|
||||
|
||||
async with await WorkflowEnvironment.start_time_skipping() as activity_environment:
|
||||
async with Worker(
|
||||
activity_environment.client,
|
||||
task_queue=settings.TEMPORAL_TASK_QUEUE,
|
||||
workflows=[S3BatchExportWorkflow],
|
||||
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
|
||||
workflow_runner=UnsandboxedWorkflowRunner(),
|
||||
):
|
||||
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
|
||||
with mock.patch(
|
||||
"posthog.temporal.workflows.s3_batch_export.get_results_iterator",
|
||||
side_effect=fake_get_results_iterator,
|
||||
) as mocked_iterator:
|
||||
await activity_environment.client.execute_workflow(
|
||||
S3BatchExportWorkflow.run,
|
||||
inputs,
|
||||
id=workflow_id,
|
||||
task_queue=settings.TEMPORAL_TASK_QUEUE,
|
||||
retry_policy=RetryPolicy(maximum_attempts=1),
|
||||
execution_timeout=dt.timedelta(seconds=10),
|
||||
)
|
||||
|
||||
assert mocked_iterator.call_count == 6 # 5 failures so 5 extra calls
|
||||
mocked_iterator.assert_has_calls( # The first call + we resume on all the even dates.
|
||||
[
|
||||
mock.call(
|
||||
client=mock.ANY,
|
||||
interval_start="2023-04-25T13:30:00",
|
||||
interval_end="2023-04-25T14:30:00",
|
||||
team_id=team.pk,
|
||||
),
|
||||
mock.call(
|
||||
client=mock.ANY,
|
||||
interval_start="2023-04-25 13:30:00.000000",
|
||||
interval_end="2023-04-25T14:30:00",
|
||||
team_id=team.pk,
|
||||
),
|
||||
mock.call(
|
||||
client=mock.ANY,
|
||||
interval_start="2023-04-25 13:32:00.000000",
|
||||
interval_end="2023-04-25T14:30:00",
|
||||
team_id=team.pk,
|
||||
),
|
||||
mock.call(
|
||||
client=mock.ANY,
|
||||
interval_start="2023-04-25 13:34:00.000000",
|
||||
interval_end="2023-04-25T14:30:00",
|
||||
team_id=team.pk,
|
||||
),
|
||||
mock.call(
|
||||
client=mock.ANY,
|
||||
interval_start="2023-04-25 13:36:00.000000",
|
||||
interval_end="2023-04-25T14:30:00",
|
||||
team_id=team.pk,
|
||||
),
|
||||
mock.call(
|
||||
client=mock.ANY,
|
||||
interval_start="2023-04-25 13:38:00.000000",
|
||||
interval_end="2023-04-25T14:30:00",
|
||||
team_id=team.pk,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
runs = await afetch_batch_export_runs(batch_export_id=batch_export.id)
|
||||
assert len(runs) == 1
|
||||
|
||||
run = runs[0]
|
||||
assert run.status == "Completed"
|
||||
|
||||
duplicate_events = [event for event in events if not should_fail(event)]
|
||||
assert_events_in_s3(s3_client, bucket_name, prefix, events + duplicate_events)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.asyncio
|
||||
async def test_s3_export_workflow_with_minio_bucket_produces_no_duplicates(client: HttpClient, s3_client, bucket_name):
|
||||
@pytest.mark.parametrize("compression", [None, "gzip", "brotli"])
|
||||
async def test_s3_export_workflow_with_minio_bucket_produces_no_duplicates(
|
||||
client: HttpClient, s3_client, bucket_name, compression
|
||||
):
|
||||
"""Test that S3 Export Workflow end-to-end by using a local MinIO bucket instead of S3.
|
||||
|
||||
In this particular instance of the test, we assert no duplicates are exported to S3.
|
||||
@@ -1045,6 +782,7 @@ async def test_s3_export_workflow_with_minio_bucket_produces_no_duplicates(clien
|
||||
"prefix": prefix,
|
||||
"aws_access_key_id": "object_storage_root_user",
|
||||
"aws_secret_access_key": "object_storage_root_password",
|
||||
"compression": compression,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1150,7 +888,7 @@ async def test_s3_export_workflow_with_minio_bucket_produces_no_duplicates(clien
|
||||
run = runs[0]
|
||||
assert run.status == "Completed"
|
||||
|
||||
assert_events_in_s3(s3_client, bucket_name, prefix, events)
|
||||
assert_events_in_s3(s3_client, bucket_name, prefix, events, compression)
|
||||
|
||||
|
||||
# We don't care about these for the next test, just need something to be defined.
|
||||
@@ -1182,6 +920,26 @@ base_inputs = {
|
||||
),
|
||||
"2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl",
|
||||
),
|
||||
(
|
||||
S3InsertInputs(
|
||||
prefix="",
|
||||
data_interval_start="2023-01-01 00:00:00",
|
||||
data_interval_end="2023-01-01 01:00:00",
|
||||
compression="gzip",
|
||||
**base_inputs,
|
||||
),
|
||||
"2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl.gz",
|
||||
),
|
||||
(
|
||||
S3InsertInputs(
|
||||
prefix="",
|
||||
data_interval_start="2023-01-01 00:00:00",
|
||||
data_interval_end="2023-01-01 01:00:00",
|
||||
compression="brotli",
|
||||
**base_inputs,
|
||||
),
|
||||
"2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl.br",
|
||||
),
|
||||
(
|
||||
S3InsertInputs(
|
||||
prefix="my-fancy-prefix",
|
||||
@@ -1200,6 +958,26 @@ base_inputs = {
|
||||
),
|
||||
"my-fancy-prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl",
|
||||
),
|
||||
(
|
||||
S3InsertInputs(
|
||||
prefix="my-fancy-prefix",
|
||||
data_interval_start="2023-01-01 00:00:00",
|
||||
data_interval_end="2023-01-01 01:00:00",
|
||||
compression="gzip",
|
||||
**base_inputs,
|
||||
),
|
||||
"my-fancy-prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl.gz",
|
||||
),
|
||||
(
|
||||
S3InsertInputs(
|
||||
prefix="my-fancy-prefix",
|
||||
data_interval_start="2023-01-01 00:00:00",
|
||||
data_interval_end="2023-01-01 01:00:00",
|
||||
compression="brotli",
|
||||
**base_inputs,
|
||||
),
|
||||
"my-fancy-prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl.br",
|
||||
),
|
||||
(
|
||||
S3InsertInputs(
|
||||
prefix="my-fancy-prefix-with-a-forwardslash/",
|
||||
@@ -1236,6 +1014,26 @@ base_inputs = {
|
||||
),
|
||||
"nested/prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl",
|
||||
),
|
||||
(
|
||||
S3InsertInputs(
|
||||
prefix="/nested/prefix/",
|
||||
data_interval_start="2023-01-01 00:00:00",
|
||||
data_interval_end="2023-01-01 01:00:00",
|
||||
compression="gzip",
|
||||
**base_inputs,
|
||||
),
|
||||
"nested/prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl.gz",
|
||||
),
|
||||
(
|
||||
S3InsertInputs(
|
||||
prefix="/nested/prefix/",
|
||||
data_interval_start="2023-01-01 00:00:00",
|
||||
data_interval_end="2023-01-01 01:00:00",
|
||||
compression="brotli",
|
||||
**base_inputs,
|
||||
),
|
||||
"nested/prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl.br",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_s3_key(inputs, expected):
|
||||
|
||||
@@ -65,6 +65,7 @@ def person_overrides_table(query_inputs):
|
||||
sync_execute(PERSON_OVERRIDES_CREATE_TABLE_SQL)
|
||||
sync_execute(KAFKA_PERSON_OVERRIDES_TABLE_SQL)
|
||||
sync_execute(PERSON_OVERRIDES_CREATE_MATERIALIZED_VIEW_SQL)
|
||||
sync_execute("TRUNCATE TABLE person_overrides")
|
||||
|
||||
yield
|
||||
|
||||
@@ -89,9 +90,9 @@ def person_overrides_data(person_overrides_table):
|
||||
"""
|
||||
person_overrides = {
|
||||
# These numbers are all arbitrary.
|
||||
1: {PersonOverrideTuple(uuid4(), uuid4()) for _ in range(5)},
|
||||
2: {PersonOverrideTuple(uuid4(), uuid4()) for _ in range(4)},
|
||||
3: {PersonOverrideTuple(uuid4(), uuid4()) for _ in range(3)},
|
||||
100: {PersonOverrideTuple(uuid4(), uuid4()) for _ in range(5)},
|
||||
200: {PersonOverrideTuple(uuid4(), uuid4()) for _ in range(4)},
|
||||
300: {PersonOverrideTuple(uuid4(), uuid4()) for _ in range(3)},
|
||||
}
|
||||
|
||||
all_test_values = []
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import collections.abc
|
||||
import csv
|
||||
import datetime as dt
|
||||
import gzip
|
||||
import json
|
||||
import tempfile
|
||||
import typing
|
||||
from string import Template
|
||||
|
||||
import brotli
|
||||
from temporalio import workflow
|
||||
|
||||
SELECT_QUERY_TEMPLATE = Template(
|
||||
@@ -21,17 +23,33 @@ SELECT_QUERY_TEMPLATE = Template(
|
||||
AND COALESCE(inserted_at, _timestamp) >= toDateTime64({data_interval_start}, 6, 'UTC')
|
||||
AND COALESCE(inserted_at, _timestamp) < toDateTime64({data_interval_end}, 6, 'UTC')
|
||||
AND team_id = {team_id}
|
||||
$exclude_events
|
||||
$order_by
|
||||
$format
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
async def get_rows_count(client, team_id: int, interval_start: str, interval_end: str) -> int:
|
||||
async def get_rows_count(
|
||||
client,
|
||||
team_id: int,
|
||||
interval_start: str,
|
||||
interval_end: str,
|
||||
exclude_events: collections.abc.Iterable[str] | None = None,
|
||||
) -> int:
|
||||
data_interval_start_ch = dt.datetime.fromisoformat(interval_start).strftime("%Y-%m-%d %H:%M:%S")
|
||||
data_interval_end_ch = dt.datetime.fromisoformat(interval_end).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
if exclude_events:
|
||||
exclude_events_statement = f"AND event NOT IN {str(tuple(exclude_events))}"
|
||||
else:
|
||||
exclude_events_statement = ""
|
||||
|
||||
query = SELECT_QUERY_TEMPLATE.substitute(
|
||||
fields="count(DISTINCT event, cityHash64(distinct_id), cityHash64(uuid)) as count", order_by="", format=""
|
||||
fields="count(DISTINCT event, cityHash64(distinct_id), cityHash64(uuid)) as count",
|
||||
order_by="",
|
||||
format="",
|
||||
exclude_events=exclude_events_statement,
|
||||
)
|
||||
|
||||
count = await client.read_query(
|
||||
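When exclude_events is provided, the template's $exclude_events slot becomes an `AND event NOT IN (...)` clause built by materializing the iterable into a tuple and formatting its repr. A quick sketch of what that substitution produces, with made-up event names:

```python
exclude_events = ["$feature_flag_called", "$pageleave"]  # illustrative values
exclude_events_statement = f"AND event NOT IN {str(tuple(exclude_events))}"
assert exclude_events_statement == "AND event NOT IN ('$feature_flag_called', '$pageleave')"
```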
@@ -68,14 +86,25 @@ elements_chain
|
||||
|
||||
|
||||
def get_results_iterator(
|
||||
client, team_id: int, interval_start: str, interval_end: str
|
||||
client,
|
||||
team_id: int,
|
||||
interval_start: str,
|
||||
interval_end: str,
|
||||
exclude_events: collections.abc.Iterable[str] | None = None,
|
||||
) -> typing.Generator[dict[str, typing.Any], None, None]:
|
||||
data_interval_start_ch = dt.datetime.fromisoformat(interval_start).strftime("%Y-%m-%d %H:%M:%S")
|
||||
data_interval_end_ch = dt.datetime.fromisoformat(interval_end).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
if exclude_events:
|
||||
exclude_events_statement = f"AND event NOT IN {str(tuple(exclude_events))}"
|
||||
else:
|
||||
exclude_events_statement = ""
|
||||
|
||||
query = SELECT_QUERY_TEMPLATE.substitute(
|
||||
fields=FIELDS,
|
||||
order_by="ORDER BY inserted_at",
|
||||
format="FORMAT ArrowStream",
|
||||
exclude_events=exclude_events_statement,
|
||||
)
|
||||
|
||||
for batch in client.stream_query_as_arrow(
|
||||
@@ -204,6 +233,7 @@ class BatchExportTemporaryFile:
|
||||
self,
|
||||
mode: str = "w+b",
|
||||
buffering=-1,
|
||||
compression: str | None = None,
|
||||
encoding: str | None = None,
|
||||
newline: str | None = None,
|
||||
suffix: str | None = None,
|
||||
@@ -222,10 +252,12 @@ class BatchExportTemporaryFile:
|
||||
dir=dir,
|
||||
errors=errors,
|
||||
)
|
||||
self.compression = compression
|
||||
self.bytes_total = 0
|
||||
self.records_total = 0
|
||||
self.bytes_since_last_reset = 0
|
||||
self.records_since_last_reset = 0
|
||||
self._brotli_compressor = None
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Pass get attr to underlying tempfile.NamedTemporaryFile."""
|
||||
@@ -240,11 +272,37 @@ class BatchExportTemporaryFile:
|
||||
"""Context-manager protocol exit method."""
|
||||
return self._file.__exit__(exc, value, tb)
|
||||
|
||||
@property
|
||||
def brotli_compressor(self):
|
||||
if self._brotli_compressor is None:
|
||||
self._brotli_compressor = brotli.Compressor()
|
||||
return self._brotli_compressor
|
||||
|
||||
def compress(self, content: bytes | str) -> bytes:
|
||||
if isinstance(content, str):
|
||||
encoded = content.encode("utf-8")
|
||||
else:
|
||||
encoded = content
|
||||
|
||||
match self.compression:
|
||||
case "gzip":
|
||||
return gzip.compress(encoded)
|
||||
case "brotli":
|
||||
self.brotli_compressor.process(encoded)
|
||||
return self.brotli_compressor.flush()
|
||||
case None:
|
||||
return encoded
|
||||
case _:
|
||||
raise ValueError(f"Unsupported compression: '{self.compression}'")
|
||||
|
||||
def write(self, content: bytes | str):
|
||||
"""Write bytes to underlying file keeping track of how many bytes were written."""
|
||||
if "b" in self.mode and isinstance(content, str):
|
||||
content = content.encode("utf-8")
|
||||
result = self._file.write(content)
|
||||
compressed_content = self.compress(content)
|
||||
|
||||
if "b" in self.mode:
|
||||
result = self._file.write(compressed_content)
|
||||
else:
|
||||
result = self._file.write(compressed_content.decode("utf-8"))
|
||||
|
||||
self.bytes_total += result
|
||||
self.bytes_since_last_reset += result
|
||||
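With compress() sitting in front of every write(), each chunk is compressed before it reaches the underlying temporary file, and the byte counters therefore track compressed sizes. A minimal usage sketch, assuming the class behaves as this diff describes:

```python
from posthog.temporal.workflows.batch_exports import BatchExportTemporaryFile

with BatchExportTemporaryFile(compression="gzip") as temp_file:
    temp_file.write_records_to_jsonl([{"event": "test"}])
    assert temp_file.records_total == 1
    assert temp_file.bytes_total > 0  # counts compressed bytes, not raw JSON bytes
```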
@@ -316,6 +374,18 @@ class BatchExportTemporaryFile:
|
||||
quoting=quoting,
|
||||
)
|
||||
|
||||
def rewind(self):
|
||||
"""Rewind the file before reading it."""
|
||||
if self.compression == "brotli":
|
||||
result = self._file.write(self.brotli_compressor.finish())
|
||||
|
||||
self.bytes_total += result
|
||||
self.bytes_since_last_reset += result
|
||||
|
||||
self._brotli_compressor = None
|
||||
|
||||
self._file.seek(0)
|
||||
|
||||
def reset(self):
|
||||
"""Reset underlying file by truncating it.
|
||||
|
||||
|
||||
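rewind() needs special handling for brotli because brotli is a streaming compressor: bytes passed to process() can stay buffered inside the compressor until finish() flushes them, so the tail must be written out before the file is read back for upload. A standalone illustration of that behaviour, using the same brotli API the class relies on:

```python
import brotli

compressor = brotli.Compressor()
head = compressor.process(b"some data to compress")  # may be empty while still buffered
tail = compressor.finish()                           # flush whatever remains
assert brotli.decompress(head + tail) == b"some data to compress"
```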
@@ -171,7 +171,7 @@ class ClickHouseClient:
             ClickHouseError: If the status code is not 200.
         """
         if response.status_code != 200:
-            error_message = response.text()
+            error_message = response.text
             raise ClickHouseError(query, error_message)

     @contextlib.asynccontextmanager
@@ -2,9 +2,8 @@ import asyncio
|
||||
import datetime as dt
|
||||
import json
|
||||
import posixpath
|
||||
import tempfile
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import boto3
|
||||
from django.conf import settings
|
||||
@@ -20,15 +19,13 @@ from posthog.temporal.workflows.base import (
|
||||
update_export_run_status,
|
||||
)
|
||||
from posthog.temporal.workflows.batch_exports import (
|
||||
BatchExportTemporaryFile,
|
||||
get_data_interval,
|
||||
get_results_iterator,
|
||||
get_rows_count,
|
||||
)
|
||||
from posthog.temporal.workflows.clickhouse import get_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mypy_boto3_s3.type_defs import CompletedPartTypeDef
|
||||
|
||||
|
||||
def get_allowed_template_variables(inputs) -> dict[str, str]:
|
||||
"""Derive from inputs a dictionary of supported template variables for the S3 key prefix."""
|
||||
@@ -50,7 +47,17 @@ def get_s3_key(inputs) -> str:
|
||||
"""Return an S3 key given S3InsertInputs."""
|
||||
template_variables = get_allowed_template_variables(inputs)
|
||||
key_prefix = inputs.prefix.format(**template_variables)
|
||||
key = posixpath.join(key_prefix, f"{inputs.data_interval_start}-{inputs.data_interval_end}.jsonl")
|
||||
|
||||
base_file_name = f"{inputs.data_interval_start}-{inputs.data_interval_end}"
|
||||
match inputs.compression:
|
||||
case "gzip":
|
||||
file_name = base_file_name + ".jsonl.gz"
|
||||
case "brotli":
|
||||
file_name = base_file_name + ".jsonl.br"
|
||||
case _:
|
||||
file_name = base_file_name + ".jsonl"
|
||||
|
||||
key = posixpath.join(key_prefix, file_name)
|
||||
|
||||
if posixpath.isabs(key):
|
||||
# Keys are relative to root dir, so this would add an extra "/"
|
||||
@@ -59,6 +66,152 @@ def get_s3_key(inputs) -> str:
|
||||
return key
|
||||
|
||||
|
||||
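With the compression-aware suffix, the same interval now maps to a .jsonl, .jsonl.gz or .jsonl.br key. A hedged usage sketch mirroring the test_get_s3_key cases earlier in this diff; base_inputs stands in for the remaining required S3InsertInputs fields, exactly as that test defines it:

```python
# Sketch only: `base_inputs` is the placeholder dict from the test_get_s3_key parametrization.
inputs = S3InsertInputs(
    prefix="/nested/prefix/",
    data_interval_start="2023-01-01 00:00:00",
    data_interval_end="2023-01-01 01:00:00",
    compression="gzip",
    **base_inputs,
)
assert get_s3_key(inputs) == "nested/prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl.gz"
```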
class UploadAlreadyInProgressError(Exception):
|
||||
"""Exception raised when an S3MultiPartUpload is already in progress."""
|
||||
|
||||
def __init__(self, upload_id):
|
||||
super().__init__(f"This upload is already in progress with ID: {upload_id}. Instantiate a new object.")
|
||||
|
||||
|
||||
class NoUploadInProgressError(Exception):
|
||||
"""Exception raised when there is no S3MultiPartUpload in progress."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("No multi-part upload is in progress. Call 'create' to start one.")
|
||||
|
||||
|
||||
class S3MultiPartUploadState(typing.NamedTuple):
|
||||
upload_id: str
|
||||
parts: list[dict[str, str | int]]
|
||||
|
||||
|
||||
class S3MultiPartUpload:
|
||||
"""An S3 multi-part upload."""
|
||||
|
||||
def __init__(self, s3_client, bucket_name, key):
|
||||
self.s3_client = s3_client
|
||||
self.bucket_name = bucket_name
|
||||
self.key = key
|
||||
self.upload_id = None
|
||||
self.parts = []
|
||||
|
||||
def to_state(self) -> S3MultiPartUploadState:
|
||||
"""Produce state tuple that can be used to resume this S3MultiPartUpload."""
|
||||
# The second predicate is trivial but required by type-checking.
|
||||
if self.is_upload_in_progress() is False or self.upload_id is None:
|
||||
raise NoUploadInProgressError()
|
||||
|
||||
return S3MultiPartUploadState(self.upload_id, self.parts)
|
||||
|
||||
@property
|
||||
def part_number(self):
|
||||
"""Return the current part number."""
|
||||
return len(self.parts)
|
||||
|
||||
def is_upload_in_progress(self) -> bool:
|
||||
"""Whether this S3MultiPartUpload is in progress or not."""
|
||||
if self.upload_id is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
def start(self) -> str:
|
||||
"""Start this S3MultiPartUpload."""
|
||||
if self.is_upload_in_progress() is True:
|
||||
raise UploadAlreadyInProgressError(self.upload_id)
|
||||
|
||||
multipart_response = self.s3_client.create_multipart_upload(Bucket=self.bucket_name, Key=self.key)
|
||||
self.upload_id = multipart_response["UploadId"]
|
||||
|
||||
return self.upload_id
|
||||
|
||||
def continue_from_state(self, state: S3MultiPartUploadState):
|
||||
"""Continue this S3MultiPartUpload from a previous state."""
|
||||
self.upload_id = state.upload_id
|
||||
self.parts = state.parts
|
||||
|
||||
return self.upload_id
|
||||
|
||||
def complete(self) -> str:
|
||||
if self.is_upload_in_progress() is False:
|
||||
raise NoUploadInProgressError()
|
||||
|
||||
response = self.s3_client.complete_multipart_upload(
|
||||
Bucket=self.bucket_name,
|
||||
Key=self.key,
|
||||
UploadId=self.upload_id,
|
||||
MultipartUpload={"Parts": self.parts},
|
||||
)
|
||||
|
||||
self.upload_id = None
|
||||
self.parts = []
|
||||
|
||||
return response["Location"]
|
||||
|
||||
def abort(self):
|
||||
if self.is_upload_in_progress() is False:
|
||||
raise NoUploadInProgressError()
|
||||
|
||||
self.s3_client.abort_multipart_upload(
|
||||
Bucket=self.bucket_name,
|
||||
Key=self.key,
|
||||
UploadId=self.upload_id,
|
||||
)
|
||||
|
||||
self.upload_id = None
|
||||
self.parts = []
|
||||
|
||||
def upload_part(self, body: BatchExportTemporaryFile, rewind: bool = True):
|
||||
next_part_number = self.part_number + 1
|
||||
|
||||
if rewind is True:
|
||||
body.rewind()
|
||||
|
||||
response = self.s3_client.upload_part(
|
||||
Bucket=self.bucket_name,
|
||||
Key=self.key,
|
||||
PartNumber=next_part_number,
|
||||
UploadId=self.upload_id,
|
||||
Body=body,
|
||||
)
|
||||
|
||||
self.parts.append({"PartNumber": next_part_number, "ETag": response["ETag"]})
|
||||
|
||||
def __enter__(self):
|
||||
if not self.is_upload_in_progress():
|
||||
self.start()
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> bool:
|
||||
if exc_value is None:
|
||||
# Successfully completed the upload
|
||||
self.complete()
|
||||
return True
|
||||
|
||||
if exc_type == asyncio.CancelledError:
|
||||
# Ensure we clean-up the cancelled upload.
|
||||
self.abort()
|
||||
|
||||
return False
|
||||
|
||||
|
||||
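S3MultiPartUpload wraps boto3's multipart calls behind a context manager: entering start()s the upload unless one is being resumed, upload_part() rewinds the temporary file and records the returned ETag, a clean exit complete()s the upload, and a cancellation aborts it. A minimal usage sketch, assuming a reachable bucket and configured credentials:

```python
import boto3

s3_client = boto3.client("s3", region_name="us-east-1")  # assumes credentials are configured
upload = S3MultiPartUpload(s3_client, "my-bucket", "exports/2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl")

with upload as upload:
    with BatchExportTemporaryFile() as temp_file:
        temp_file.write_records_to_jsonl([{"event": "test"}])
        upload.upload_part(temp_file)  # rewinds the file, uploads it, records the part's ETag
# Exiting without an exception calls complete(); an asyncio.CancelledError triggers abort().
```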
class HeartbeatDetails(typing.NamedTuple):
|
||||
"""This tuple allows us to enforce a schema on the Heartbeat details.
|
||||
Attributes:
|
||||
last_uploaded_part_timestamp: The timestamp of the last part we managed to upload.
|
||||
upload_state: State to continue a S3MultiPartUpload when activity execution resumes.
|
||||
"""
|
||||
|
||||
last_uploaded_part_timestamp: str
|
||||
upload_state: S3MultiPartUploadState
|
||||
|
||||
@classmethod
|
||||
def from_activity_details(cls, details):
|
||||
last_uploaded_part_timestamp = details[0]
|
||||
upload_state = S3MultiPartUploadState(*details[1])
|
||||
return HeartbeatDetails(last_uploaded_part_timestamp, upload_state)
|
||||
|
||||
|
||||
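HeartbeatDetails is the bridge between activity.heartbeat(...) and resumption: while running, the activity heartbeats a (timestamp, upload-state) pair, and a retried attempt gets those positional details back as a sequence that from_activity_details() re-hydrates. A sketch of the round trip with made-up values:

```python
state = S3MultiPartUploadState(upload_id="abc123", parts=[{"PartNumber": 1, "ETag": '"etag-1"'}])

# While running, the activity reports progress roughly like this:
#     activity.heartbeat("2023-04-25 13:30:00.000000", state)

# On retry, Temporal hands the details back as a sequence:
details = ("2023-04-25 13:30:00.000000", state)
heartbeat = HeartbeatDetails.from_activity_details(details)
assert heartbeat.last_uploaded_part_timestamp == "2023-04-25 13:30:00.000000"
assert heartbeat.upload_state.upload_id == "abc123"
```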
@dataclass
|
||||
class S3InsertInputs:
|
||||
"""Inputs for S3 exports."""
|
||||
@@ -75,6 +228,55 @@ class S3InsertInputs:
|
||||
data_interval_end: str
|
||||
aws_access_key_id: str | None = None
|
||||
aws_secret_access_key: str | None = None
|
||||
compression: str | None = None
|
||||
exclude_events: list[str] | None = None
|
||||
|
||||
|
||||
def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tuple[S3MultiPartUpload, str]:
|
||||
"""Initialize a S3MultiPartUpload and resume it from a hearbeat state if available."""
|
||||
key = get_s3_key(inputs)
|
||||
s3_client = boto3.client(
|
||||
"s3",
|
||||
region_name=inputs.region,
|
||||
aws_access_key_id=inputs.aws_access_key_id,
|
||||
aws_secret_access_key=inputs.aws_secret_access_key,
|
||||
)
|
||||
s3_upload = S3MultiPartUpload(s3_client, inputs.bucket_name, key)
|
||||
|
||||
details = activity.info().heartbeat_details
|
||||
|
||||
try:
|
||||
interval_start, upload_state = HeartbeatDetails.from_activity_details(details)
|
||||
except IndexError:
|
||||
# This is the error we expect when there are no details, as the sequence will be empty.
|
||||
interval_start = inputs.data_interval_start
|
||||
activity.logger.info(
|
||||
f"Did not receive details from previous activity Excecution. Export will start from the beginning: {interval_start}"
|
||||
)
|
||||
except Exception as e:
|
||||
# We still start from the beginning, but we make a point to log unexpected errors.
|
||||
# Ideally, any new exceptions should be added to the previous block after they first occur, so we never land here again.
|
||||
interval_start = inputs.data_interval_start
|
||||
activity.logger.warning(
|
||||
f"Did not receive details from previous activity Excecution due to an unexpected error. Export will start from the beginning: {interval_start}",
|
||||
exc_info=e,
|
||||
)
|
||||
else:
|
||||
activity.logger.info(
|
||||
f"Received details from previous activity. Export will attempt to resume from: {interval_start}"
|
||||
)
|
||||
s3_upload.continue_from_state(upload_state)
|
||||
|
||||
if inputs.compression == "brotli":
|
||||
# Even if we receive details, we cannot resume a brotli-compressed upload, as we have lost the compressor state.
|
||||
interval_start = inputs.data_interval_start
|
||||
|
||||
activity.logger.info(
|
||||
f"Export will start from the beginning as we are using brotli compression: {interval_start}"
|
||||
)
|
||||
s3_upload.abort()
|
||||
|
||||
return s3_upload, interval_start
|
||||
|
||||
|
||||
@activity.defn
|
||||
@@ -111,28 +313,7 @@ async def insert_into_s3_activity(inputs: S3InsertInputs):
|
||||
|
||||
activity.logger.info("BatchExporting %s rows to S3", count)
|
||||
|
||||
# Create a multipart upload to S3
|
||||
key = get_s3_key(inputs)
|
||||
s3_client = boto3.client(
|
||||
"s3",
|
||||
region_name=inputs.region,
|
||||
aws_access_key_id=inputs.aws_access_key_id,
|
||||
aws_secret_access_key=inputs.aws_secret_access_key,
|
||||
)
|
||||
|
||||
details = activity.info().heartbeat_details
|
||||
|
||||
parts: List[CompletedPartTypeDef] = []
|
||||
|
||||
if len(details) == 4:
|
||||
interval_start, upload_id, parts, part_number = details
|
||||
activity.logger.info(f"Received details from previous activity. Export will resume from {interval_start}")
|
||||
|
||||
else:
|
||||
multipart_response = s3_client.create_multipart_upload(Bucket=inputs.bucket_name, Key=key)
|
||||
upload_id = multipart_response["UploadId"]
|
||||
interval_start = inputs.data_interval_start
|
||||
part_number = 1
|
||||
s3_upload, interval_start = initialize_and_resume_multipart_upload(inputs)
|
||||
|
||||
# Iterate through chunks of results from ClickHouse and push them to S3
|
||||
# as a multipart upload. The intention here is to keep memory usage low,
|
||||
@@ -145,6 +326,7 @@ async def insert_into_s3_activity(inputs: S3InsertInputs):
|
||||
team_id=inputs.team_id,
|
||||
interval_start=interval_start,
|
||||
interval_end=inputs.data_interval_end,
|
||||
exclude_events=inputs.exclude_events,
|
||||
)
|
||||
|
||||
result = None
|
||||
@@ -156,46 +338,14 @@ async def insert_into_s3_activity(inputs: S3InsertInputs):
|
||||
activity.logger.warn(
|
||||
f"Worker shutting down! Reporting back latest exported part {last_uploaded_part_timestamp}"
|
||||
)
|
||||
activity.heartbeat(last_uploaded_part_timestamp, upload_id)
|
||||
activity.heartbeat(last_uploaded_part_timestamp, s3_upload.to_state())
|
||||
|
||||
asyncio.create_task(worker_shutdown_handler())
|
||||
|
||||
with tempfile.NamedTemporaryFile() as local_results_file:
|
||||
while True:
|
||||
try:
|
||||
result = results_iterator.__next__()
|
||||
except StopIteration:
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
# This is raised by aiochclient as we try to decode an error message from ClickHouse.
|
||||
# So far, this error message only indicated that we were too slow consuming rows.
|
||||
# So, we can resume from the last result.
|
||||
if result is None:
|
||||
# We failed right at the beginning
|
||||
new_interval_start = None
|
||||
else:
|
||||
new_interval_start = result.get("inserted_at", None)
|
||||
|
||||
if not isinstance(new_interval_start, str):
|
||||
new_interval_start = inputs.data_interval_start
|
||||
|
||||
activity.logger.warn(
|
||||
f"Failed to decode a JSON value while iterating, potentially due to a ClickHouse error. Resuming from {new_interval_start}"
|
||||
)
|
||||
|
||||
results_iterator = get_results_iterator(
|
||||
client=client,
|
||||
team_id=inputs.team_id,
|
||||
interval_start=new_interval_start, # This means we'll generate at least one duplicate.
|
||||
interval_end=inputs.data_interval_end,
|
||||
)
|
||||
continue
|
||||
|
||||
if not result:
|
||||
break
|
||||
|
||||
content = json.dumps(
|
||||
{
|
||||
with s3_upload as s3_upload:
|
||||
with BatchExportTemporaryFile(compression=inputs.compression) as local_results_file:
|
||||
for result in results_iterator:
|
||||
record = {
|
||||
"created_at": result["created_at"],
|
||||
"distinct_id": result["distinct_id"],
|
||||
"elements_chain": result["elements_chain"],
|
||||
@@ -207,60 +357,36 @@ async def insert_into_s3_activity(inputs: S3InsertInputs):
|
||||
"timestamp": result["timestamp"],
|
||||
"uuid": result["uuid"],
|
||||
}
|
||||
)
|
||||
|
||||
# Write the results to a local file
|
||||
local_results_file.write(content.encode("utf-8"))
|
||||
local_results_file.write("\n".encode("utf-8"))
|
||||
local_results_file.write_records_to_jsonl([record])
|
||||
|
||||
# Write results to S3 when the file reaches 50MB and reset the
|
||||
# file, or if there is nothing else to write.
|
||||
if (
|
||||
local_results_file.tell()
|
||||
and local_results_file.tell() > settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES
|
||||
):
|
||||
activity.logger.info("Uploading part %s", part_number)
|
||||
if local_results_file.tell() > settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES:
|
||||
activity.logger.info(
|
||||
"Uploading part %s containing %s records with size %s bytes to S3",
|
||||
s3_upload.part_number + 1,
|
||||
local_results_file.records_since_last_reset,
|
||||
local_results_file.bytes_since_last_reset,
|
||||
)
|
||||
|
||||
local_results_file.seek(0)
|
||||
response = s3_client.upload_part(
|
||||
Bucket=inputs.bucket_name,
|
||||
Key=key,
|
||||
PartNumber=part_number,
|
||||
UploadId=upload_id,
|
||||
Body=local_results_file,
|
||||
s3_upload.upload_part(local_results_file)
|
||||
|
||||
last_uploaded_part_timestamp = result["inserted_at"]
|
||||
activity.heartbeat(last_uploaded_part_timestamp, s3_upload.to_state())
|
||||
|
||||
local_results_file.reset()
|
||||
|
||||
if local_results_file.tell() > 0 and result is not None:
|
||||
activity.logger.info(
|
||||
"Uploading last part %s containing %s records with size %s bytes to S3",
|
||||
s3_upload.part_number + 1,
|
||||
local_results_file.records_since_last_reset,
|
||||
local_results_file.bytes_since_last_reset,
|
||||
)
|
||||
|
||||
s3_upload.upload_part(local_results_file)
|
||||
|
||||
last_uploaded_part_timestamp = result["inserted_at"]
|
||||
# Record the ETag for the part
|
||||
parts.append({"PartNumber": part_number, "ETag": response["ETag"]})
|
||||
part_number += 1
|
||||
|
||||
activity.heartbeat(last_uploaded_part_timestamp, upload_id, parts, part_number)
|
||||
|
||||
# Reset the file
|
||||
local_results_file.seek(0)
|
||||
local_results_file.truncate()
|
||||
|
||||
# Upload the last part
|
||||
local_results_file.seek(0)
|
||||
response = s3_client.upload_part(
|
||||
Bucket=inputs.bucket_name,
|
||||
Key=key,
|
||||
PartNumber=part_number,
|
||||
UploadId=upload_id,
|
||||
Body=local_results_file,
|
||||
)
|
||||
activity.heartbeat(last_uploaded_part_timestamp, upload_id, parts, part_number)
|
||||
|
||||
# Record the ETag for the last part
|
||||
parts.append({"PartNumber": part_number, "ETag": response["ETag"]})
|
||||
|
||||
# Complete the multipart upload
|
||||
s3_client.complete_multipart_upload(
|
||||
Bucket=inputs.bucket_name,
|
||||
Key=key,
|
||||
UploadId=upload_id,
|
||||
MultipartUpload={"Parts": parts},
|
||||
)
|
||||
activity.heartbeat(last_uploaded_part_timestamp, s3_upload.to_state())
|
||||
|
||||
|
||||
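Taken as a whole, the rewritten activity body is a buffered, resumable loop: append each record to the temporary file; once the file passes BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES, upload it as a part, heartbeat the progress and reset the buffer; then flush whatever is left as a final part. A condensed sketch of that control flow (record_from is a placeholder for the field mapping shown above):

```python
# Condensed control flow only; not the literal implementation.
with s3_upload as s3_upload:
    with BatchExportTemporaryFile(compression=inputs.compression) as local_results_file:
        result = None
        for result in results_iterator:
            local_results_file.write_records_to_jsonl([record_from(result)])  # record_from: placeholder

            if local_results_file.tell() > settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES:
                s3_upload.upload_part(local_results_file)
                activity.heartbeat(result["inserted_at"], s3_upload.to_state())
                local_results_file.reset()

        if local_results_file.tell() > 0 and result is not None:
            s3_upload.upload_part(local_results_file)
            activity.heartbeat(result["inserted_at"], s3_upload.to_state())
```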
@workflow.defn(name="s3-export")
|
||||
@@ -314,6 +440,8 @@ class S3BatchExportWorkflow(PostHogWorkflow):
|
||||
aws_secret_access_key=inputs.aws_secret_access_key,
|
||||
data_interval_start=data_interval_start.isoformat(),
|
||||
data_interval_end=data_interval_end.isoformat(),
|
||||
compression=inputs.compression,
|
||||
exclude_events=inputs.exclude_events,
|
||||
)
|
||||
try:
|
||||
await workflow.execute_activity(
|
||||
@@ -334,6 +462,7 @@ class S3BatchExportWorkflow(PostHogWorkflow):
|
||||
except Exception as e:
|
||||
workflow.logger.exception("S3 BatchExport failed.", exc_info=e)
|
||||
update_inputs.status = "Failed"
|
||||
update_inputs.latest_error = str(e)
|
||||
raise
|
||||
|
||||
finally:
|
||||
|
||||
@@ -32,8 +32,8 @@ class TestGzipMiddleware(APIBaseTest):

     def test_no_compression_for_unsuccessful_requests_to_paths_on_the_allow_list(self) -> None:
         with self.settings(GZIP_RESPONSE_ALLOW_LIST=["something-else", "snapshots$"]):
-            response = self._get_path("/api/projects/12/session_recordings/blah/snapshots")
-            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+            response = self._get_path(f"/api/projects/{self.team.pk}/session_recordings/blah/snapshots")
+            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND, msg=response.content.decode("utf-8"))

             contentEncoding = response.headers.get("Content-Encoding", None)
             self.assertEqual(contentEncoding, None)
@@ -9,6 +9,7 @@ antlr4-python3-runtime==4.13.0
 amqp==2.6.0
 boto3==1.26.66
 boto3-stubs[s3]
+brotli==1.0.9
 celery==4.4.7
 celery-redbeat==2.0.0
 clickhouse-driver==0.2.4

@@ -51,6 +51,8 @@ botocore==1.29.66
     #   s3transfer
 botocore-stubs==1.29.130
     # via boto3-stubs
+brotli==1.0.9
+    # via -r requirements.in
 celery==4.4.7
     # via
     #   -r requirements.in