feat(max): dagster evaluation runner (#36320)

Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Michael Matloka <michael@matloka.com>
Georgiy Tarasov
2025-08-26 17:34:14 +02:00
committed by GitHub
parent 8afe171d5b
commit a5394c47f2
59 changed files with 2593 additions and 566 deletions


@@ -62,3 +62,6 @@ rust/cyclotron-node/index.node
rust/cyclotron-node/node_modules
rust/docker
rust/target
!docker-compose.base.yml
!docker-compose.dev.yml
!docker

.github/workflows/cd-ai-evals-image.yml (new file, 66 lines)

@@ -0,0 +1,66 @@
#
# Build and push PostHog container images for AI evaluations to AWS ECR
#
# - posthog_ai_evals_build: build and push the PostHog container image to AWS ECR
#
name: AI Evals Container Images CD
on:
push:
branches:
- master
paths-ignore:
- 'rust/**'
- 'livestream/**'
- 'plugin-server/**'
pull_request:
workflow_dispatch:
jobs:
posthog_ai_evals_build:
name: Build and push container image
if: |
github.repository == 'PostHog/posthog' && (
github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'build-ai-evals-image')
)
runs-on: depot-ubuntu-latest
permissions:
id-token: write # allow issuing OIDC tokens for this workflow run
contents: read # allow at least reading the repo contents, add other permissions if necessary
steps:
- name: Check out
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
with:
fetch-depth: 2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3
- name: Set up QEMU
uses: docker/setup-qemu-action@29109295f81e9208d7d86ff1c6c12d2833863392 # v3
- name: Set up Depot CLI
uses: depot/setup-action@b0b1ea4f69e92ebf5dea3f8713a1b0c37b2126a5 # v1
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: us-east-1
- name: Login to Amazon ECR
id: aws-ecr
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2
- name: Build and push container image
id: build
uses: depot/build-push-action@2583627a84956d07561420dcc1d0eb1f2af3fac0 # v1
with:
file: ./Dockerfile.ai-evals
buildx-fallback: false # the fallback is so slow it's better to just fail
push: true
tags: ${{ github.ref == 'refs/heads/master' && format('{0}/posthog-ai-evals:master,{0}/posthog-ai-evals:{1}', steps.aws-ecr.outputs.registry, github.sha) || format('{0}/posthog-ai-evals:{1}', steps.aws-ecr.outputs.registry, github.sha) }}
platforms: linux/arm64,linux/amd64
build-args: COMMIT_HASH=${{ github.sha }}


@@ -59,8 +59,10 @@ jobs:
run: bin/check_kafka_clickhouse_up
- name: Run LLM evals
run: pytest ee/hogai/eval -vv
run: pytest ee/hogai/eval/ci -vv
env:
EVAL_MODE: ci
EXPORT_EVAL_RESULTS: true
BRAINTRUST_API_KEY: ${{ secrets.BRAINTRUST_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}

Dockerfile.ai-evals (new file, 52 lines)

@@ -0,0 +1,52 @@
FROM python:3.11.9-slim-bookworm AS python-base
FROM cruizba/ubuntu-dind:latest
SHELL ["/bin/bash", "-e", "-o", "pipefail", "-c"]
# Copy Python base
COPY --from=python-base /usr/local /usr/local
# Set working directory
WORKDIR /code
# Copy Docker scripts
COPY docker/ ./docker/
# Install system dependencies
RUN apt-get update && \
apt-get install -y --no-install-recommends \
"build-essential" \
"git" \
"libpq-dev" \
"libxmlsec1" \
"libxmlsec1-dev" \
"libffi-dev" \
"zlib1g-dev" \
"pkg-config" \
"netcat-openbsd" \
"postgresql-client"
# Copy uv dependencies
COPY pyproject.toml uv.lock docker-compose.base.yml docker-compose.dev.yml ./
# Install python deps
RUN rm -rf /var/lib/apt/lists/* && \
pip install uv~=0.7.0 --no-cache-dir && \
UV_PROJECT_ENVIRONMENT=/python-runtime uv sync --frozen --no-cache --compile-bytecode \
--no-binary-package lxml --no-binary-package xmlsec
# Copy project files
COPY bin/ ./bin/
COPY manage.py manage.py
COPY common/esbuilder common/esbuilder
COPY common/hogvm common/hogvm/
COPY posthog posthog/
COPY products/ products/
COPY ee ee/
ENV PATH=/python-runtime/bin:$PATH \
PYTHONPATH=/python-runtime
# Make scripts executable
RUN chmod +x bin/*
CMD bin/docker-ai-evals

bin/check_clickhouse_up (new executable file, 8 lines)

@@ -0,0 +1,8 @@
#!/bin/bash
set -e
# Check ClickHouse
while true; do
curl -s -o /dev/null -I 'http://localhost:8123/' && break || echo 'Checking ClickHouse status...' && sleep 1
done


@@ -7,7 +7,4 @@ while true; do
nc -z localhost 9092 && break || echo 'Checking Kafka status...' && sleep 1
done
# Check ClickHouse
while true; do
curl -s -o /dev/null -I 'http://localhost:8123/' && break || echo 'Checking ClickHouse status...' && sleep 1
done
./bin/check_clickhouse_up

bin/docker-ai-evals (new executable file, 29 lines)

@@ -0,0 +1,29 @@
#!/bin/bash
set -e
export DEBUG=1
export IN_EVAL_TESTING=1
export EVAL_MODE=offline
export EXPORT_EVAL_RESULTS=1
cleanup() {
echo "🧹 Cleaning up..."
docker compose -f docker-compose.dev.yml -p evals down -v --remove-orphans || true
}
trap cleanup EXIT INT TERM
echo "127.0.0.1 kafka clickhouse objectstorage db" >> /etc/hosts
echo "🚀 Starting services..."
docker compose -f docker-compose.dev.yml -p evals up -d db clickhouse objectstorage
echo "🔄 Waiting for services to start..."
bin/check_postgres_up & bin/check_kafka_clickhouse_up
echo "🏃 Running evaluation..."
if [ -z "$EVAL_SCRIPT" ]; then
echo "Error: EVAL_SCRIPT environment variable is not set"
exit 1
fi
$EVAL_SCRIPT
echo "🎉 Done."


@@ -1,4 +1,5 @@
import dagster
from dagster_docker import PipesDockerClient
from dags.max_ai.run_evaluation import run_evaluation
@@ -6,5 +7,8 @@ from . import resources
defs = dagster.Definitions(
jobs=[run_evaluation],
resources=resources,
resources={
**resources,
"docker_pipes_client": PipesDockerClient(),
},
)


@@ -1,22 +1,144 @@
import base64
from django.conf import settings
import boto3
import dagster
from dagster_docker import PipesDockerClient
from pydantic import Field
from tenacity import retry, stop_after_attempt, wait_exponential
from dags.common import JobOwners
from dags.max_ai.snapshot_project_data import snapshot_clickhouse_project_data, snapshot_postgres_project_data
from dags.max_ai.snapshot_team_data import (
ClickhouseTeamDataSnapshot,
PostgresTeamDataSnapshot,
snapshot_clickhouse_team_data,
snapshot_postgres_team_data,
)
from ee.hogai.eval.schema import DatasetInput, EvalsDockerImageConfig, TeamEvaluationSnapshot
class ExportProjectsConfig(dagster.Config):
project_ids: list[int]
"""Project IDs to run the evaluation for."""
def get_object_storage_endpoint() -> str:
"""
Get the object storage endpoint.
Debug mode uses the local object storage, so we need to set a DNS endpoint (like orb.dev).
Production mode uses AWS S3.
"""
if settings.DEBUG:
val = dagster.EnvVar("EVALS_DIND_OBJECT_STORAGE_ENDPOINT").get_value("http://objectstorage.posthog.orb.local")
if not val:
raise ValueError("EVALS_DIND_OBJECT_STORAGE_ENDPOINT is not set")
return val
return settings.OBJECT_STORAGE_ENDPOINT
class ExportTeamsConfig(dagster.Config):
team_ids: list[int]
"""Team IDs to run the evaluation for."""
@dagster.op(out=dagster.DynamicOut(int))
def export_projects(config: ExportProjectsConfig):
seen_projects = set()
for pid in config.project_ids:
if pid in seen_projects:
def export_teams(config: ExportTeamsConfig):
seen_teams = set()
for tid in config.team_ids:
if tid in seen_teams:
continue
seen_projects.add(pid)
yield dagster.DynamicOutput(pid, mapping_key=str(pid))
seen_teams.add(tid)
yield dagster.DynamicOutput(tid, mapping_key=str(tid))
class EvaluationConfig(dagster.Config):
image_name: str = Field(description="Name of the Docker image to run.")
image_tag: str = Field(description="Tag of the Docker image to run.")
experiment_name: str = Field(description="Name of the experiment.")
evaluation_module: str = Field(description="Python module containing the evaluation runner.")
@property
def image(self) -> str:
# We use the local Docker image in debug mode
if settings.DEBUG:
return f"{self.image_name}:{self.image_tag}"
return f"{dagster.EnvVar('AWS_EKS_REGISTRY_URL').get_value()}/{self.image_name}:{self.image_tag}"
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=2))
def get_registry_credentials():
# We use the local Docker image in debug mode
if settings.DEBUG:
return None
client = boto3.client("ecr")
# https://boto3.amazonaws.com/v1/documentation/api/1.29.2/reference/services/ecr/client/get_authorization_token.html
token = client.get_authorization_token()["authorizationData"][0]["authorizationToken"]
username, password = base64.b64decode(token).decode("utf-8").split(":")
return {
"url": dagster.EnvVar("AWS_EKS_REGISTRY_URL").get_value(),
"username": username,
"password": password,
}
@dagster.op
def spawn_evaluation_container(
context: dagster.OpExecutionContext,
config: EvaluationConfig,
docker_pipes_client: PipesDockerClient,
team_ids: list[int],
postgres_snapshots: list[PostgresTeamDataSnapshot],
clickhouse_snapshots: list[ClickhouseTeamDataSnapshot],
):
evaluation_config = EvalsDockerImageConfig(
aws_endpoint_url=get_object_storage_endpoint(),
aws_bucket_name=settings.OBJECT_STORAGE_BUCKET,
team_snapshots=[
TeamEvaluationSnapshot(team_id=team_id, postgres=postgres, clickhouse=clickhouse).model_dump()
for team_id, postgres, clickhouse in zip(team_ids, postgres_snapshots, clickhouse_snapshots)
],
experiment_name=config.experiment_name,
dataset=[
DatasetInput(
team_id=team_id,
input={"query": "List all events from the last 7 days. Use SQL."},
expected={"output": "SELECT * FROM events WHERE timestamp >= now() - INTERVAL 7 day"},
)
for team_id in team_ids
],
)
context.log.info(f"Running evaluation for the image: {config.image}")
asset_result = docker_pipes_client.run(
context=context,
image=config.image,
container_kwargs={
"privileged": True,
"auto_remove": True,
},
env={
"EVAL_SCRIPT": f"pytest {config.evaluation_module} -s -vv",
"OBJECT_STORAGE_ACCESS_KEY_ID": settings.OBJECT_STORAGE_ACCESS_KEY_ID, # type: ignore
"OBJECT_STORAGE_SECRET_ACCESS_KEY": settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, # type: ignore
"OPENAI_API_KEY": settings.OPENAI_API_KEY,
"ANTHROPIC_API_KEY": settings.ANTHROPIC_API_KEY,
"GEMINI_API_KEY": settings.GEMINI_API_KEY,
"INKEEP_API_KEY": settings.INKEEP_API_KEY,
"PPLX_API_KEY": settings.PPLX_API_KEY,
"AZURE_INFERENCE_ENDPOINT": settings.AZURE_INFERENCE_ENDPOINT,
"AZURE_INFERENCE_CREDENTIAL": settings.AZURE_INFERENCE_CREDENTIAL,
"BRAINTRUST_API_KEY": settings.BRAINTRUST_API_KEY,
},
extras=evaluation_config.model_dump(exclude_unset=True),
registry=get_registry_credentials(),
).get_materialize_result()
context.log_event(
dagster.AssetMaterialization(
asset_key=asset_result.asset_key or "evaluation_report",
metadata=asset_result.metadata,
tags={"owner": JobOwners.TEAM_MAX_AI.value},
)
)
@dagster.job(
@@ -28,11 +150,18 @@ def export_projects(config: ExportProjectsConfig):
executor_def=dagster.multiprocess_executor.configured({"max_concurrent": 4}),
config=dagster.RunConfig(
ops={
"export_projects": ExportProjectsConfig(project_ids=[]),
"export_teams": ExportTeamsConfig(team_ids=[]),
"spawn_evaluation_container": EvaluationConfig(
evaluation_module="",
experiment_name="offline_evaluation",
image_name="posthog-ai-evals",
image_tag="master",
),
}
),
)
def run_evaluation():
project_ids = export_projects()
project_ids.map(snapshot_postgres_project_data)
project_ids.map(snapshot_clickhouse_project_data)
team_ids = export_teams()
postgres_snapshots = team_ids.map(snapshot_postgres_team_data)
clickhouse_snapshots = team_ids.map(snapshot_clickhouse_team_data)
spawn_evaluation_container(team_ids.collect(), postgres_snapshots.collect(), clickhouse_snapshots.collect())
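For orientation: the op above hands the `EvalsDockerImageConfig` payload to the container through Dagster Pipes (`extras=evaluation_config.model_dump(...)`). Below is a minimal sketch of the container-side half of that handshake, assuming the standard `dagster_pipes` API; the actual consumer lives in the eval image and is not part of this diff.

```python
# Minimal sketch, not the shipped consumer: shows how the containerized eval run
# could read the extras passed by PipesDockerClient and report a result back.
from dagster_pipes import open_dagster_pipes

from ee.hogai.eval.schema import EvalsDockerImageConfig

with open_dagster_pipes() as pipes:
    # `extras` carries the payload built in spawn_evaluation_container above.
    config = EvalsDockerImageConfig.model_validate(pipes.extras)
    pipes.log.info(f"Evaluating {len(config.team_snapshots)} team snapshot(s) for {config.experiment_name}")
    # ... restore snapshots, run the evaluation module via pytest, collect results ...
    pipes.report_asset_materialization(
        asset_key="evaluation_report",
        metadata={"experiment_name": config.experiment_name},
    )
```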


@@ -9,6 +9,9 @@ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_ex
from posthog.schema import (
ActorsPropertyTaxonomyQuery,
ActorsPropertyTaxonomyResponse,
CachedActorsPropertyTaxonomyQueryResponse,
CachedEventTaxonomyQueryResponse,
CachedTeamTaxonomyQueryResponse,
EventTaxonomyQuery,
TeamTaxonomyItem,
TeamTaxonomyQuery,
@@ -18,6 +21,7 @@ from posthog.errors import InternalCHQueryError
from posthog.hogql_queries.ai.actors_property_taxonomy_query_runner import ActorsPropertyTaxonomyQueryRunner
from posthog.hogql_queries.ai.event_taxonomy_query_runner import EventTaxonomyQueryRunner
from posthog.hogql_queries.ai.team_taxonomy_query_runner import TeamTaxonomyQueryRunner
from posthog.hogql_queries.query_runner import ExecutionMode
from posthog.models import GroupTypeMapping, Team
from posthog.models.property_definition import PropertyDefinition
@@ -26,10 +30,10 @@ from dags.max_ai.utils import check_dump_exists, compose_clickhouse_dump_path, c
from ee.hogai.eval.schema import (
ActorsPropertyTaxonomySnapshot,
BaseSnapshot,
ClickhouseProjectDataSnapshot,
ClickhouseTeamDataSnapshot,
DataWarehouseTableSnapshot,
GroupTypeMappingSnapshot,
PostgresProjectDataSnapshot,
PostgresTeamDataSnapshot,
PropertyDefinitionSnapshot,
PropertyTaxonomySnapshot,
TeamSnapshot,
@@ -47,26 +51,35 @@ DEFAULT_RETRY_POLICY = dagster.RetryPolicy(
SchemaBound = TypeVar("SchemaBound", bound=BaseSnapshot)
class SnapshotUnrecoverableError(ValueError):
"""
An unrecoverable snapshot error: the operation failed in a way that should not be retried.
"""
pass
def snapshot_postgres_model(
context: dagster.OpExecutionContext,
model_type: type[SchemaBound],
file_name: str,
s3: S3Resource,
project_id: int,
team_id: int,
code_version: str | None = None,
) -> str:
file_key = compose_postgres_dump_path(project_id, file_name, code_version)
file_key = compose_postgres_dump_path(team_id, file_name, code_version)
if check_dump_exists(s3, file_key):
context.log.info(f"Skipping {file_key} because it already exists")
return file_key
context.log.info(f"Dumping {file_key}")
with dump_model(s3=s3, schema=model_type, file_key=file_key) as dump:
dump(model_type.serialize_for_project(project_id))
dump(model_type.serialize_for_team(team_id=team_id))
return file_key
@dagster.op(
description="Snapshots Postgres project data (property definitions, DWH schema, etc.)",
description="Snapshots Postgres team data (property definitions, DWH schema, etc.)",
retry_policy=DEFAULT_RETRY_POLICY,
code_version="v1",
tags={
@@ -74,29 +87,38 @@ def snapshot_postgres_model(
"dagster/max_runtime": 60 * 15, # 15 minutes
},
)
def snapshot_postgres_project_data(
context: dagster.OpExecutionContext, project_id: int, s3: S3Resource
) -> PostgresProjectDataSnapshot:
context.log.info(f"Snapshotting Postgres project data for {project_id}")
def snapshot_postgres_team_data(
context: dagster.OpExecutionContext, team_id: int, s3: S3Resource
) -> PostgresTeamDataSnapshot:
context.log.info(f"Snapshotting Postgres team data for {team_id}")
snapshot_map: dict[str, type[BaseSnapshot]] = {
"project": TeamSnapshot,
"team": TeamSnapshot,
"property_definitions": PropertyDefinitionSnapshot,
"group_type_mappings": GroupTypeMappingSnapshot,
"data_warehouse_tables": DataWarehouseTableSnapshot,
}
deps = {
file_name: snapshot_postgres_model(context, model_type, file_name, s3, project_id, context.op_def.version)
for file_name, model_type in snapshot_map.items()
}
context.log_event(
dagster.AssetMaterialization(
asset_key="project_postgres_snapshot",
description="Avro snapshots of project Postgres data",
metadata={"project_id": project_id, **deps},
tags={"owner": JobOwners.TEAM_MAX_AI.value},
try:
deps = {
file_name: snapshot_postgres_model(context, model_type, file_name, s3, team_id, context.op_def.version)
for file_name, model_type in snapshot_map.items()
}
context.log_event(
dagster.AssetMaterialization(
asset_key="team_postgres_snapshot",
description="Avro snapshots of team Postgres data",
metadata={"team_id": team_id, **deps},
tags={"owner": JobOwners.TEAM_MAX_AI.value},
)
)
)
return PostgresProjectDataSnapshot(**deps)
except Team.DoesNotExist as e:
raise dagster.Failure(
description=f"Team {team_id} does not exist",
metadata={"team_id": team_id},
allow_retries=False,
) from e
return PostgresTeamDataSnapshot(**deps)
C = TypeVar("C")
@@ -120,17 +142,20 @@ def snapshot_properties_taxonomy(
):
results: list[PropertyTaxonomySnapshot] = []
def snapshot_event(item: TeamTaxonomyItem):
return call_query_runner(
lambda: EventTaxonomyQueryRunner(
query=EventTaxonomyQuery(event=item.event),
team=team,
).calculate()
def wrapped_query_runner(item: TeamTaxonomyItem):
response = EventTaxonomyQueryRunner(query=EventTaxonomyQuery(event=item.event), team=team).run(
execution_mode=ExecutionMode.RECENT_CACHE_CALCULATE_BLOCKING_IF_STALE
)
if not isinstance(response, CachedEventTaxonomyQueryResponse):
raise SnapshotUnrecoverableError(f"Unexpected response type from event taxonomy query: {type(response)}")
return response
def call_event_taxonomy_query(item: TeamTaxonomyItem):
return call_query_runner(lambda: wrapped_query_runner(item))
for item in events:
context.log.info(f"Snapshotting properties taxonomy for event {item.event} of {team.id}")
results.append(PropertyTaxonomySnapshot(event=item.event, results=snapshot_event(item).results))
results.append(PropertyTaxonomySnapshot(event=item.event, results=call_event_taxonomy_query(item).results))
context.log.info(f"Dumping properties taxonomy to {file_key}")
with dump_model(s3=s3, schema=PropertyTaxonomySnapshot, file_key=file_key) as dump:
@@ -152,9 +177,18 @@ def snapshot_events_taxonomy(
context.log.info(f"Snapshotting events taxonomy for {team.id}")
res = call_query_runner(lambda: TeamTaxonomyQueryRunner(query=TeamTaxonomyQuery(), team=team).calculate())
def snapshot_events_taxonomy():
response = TeamTaxonomyQueryRunner(query=TeamTaxonomyQuery(), team=team).run(
execution_mode=ExecutionMode.RECENT_CACHE_CALCULATE_BLOCKING_IF_STALE
)
if not isinstance(response, CachedTeamTaxonomyQueryResponse):
raise SnapshotUnrecoverableError(f"Unexpected response type from events taxonomy query: {type(response)}")
return response
res = call_query_runner(snapshot_events_taxonomy)
if not res.results:
raise ValueError("No results from events taxonomy query")
raise SnapshotUnrecoverableError("No results from events taxonomy query")
# Dump properties
snapshot_properties_taxonomy(context, s3, properties_file_key, team, res.results)
@@ -216,18 +250,24 @@ def snapshot_actors_property_taxonomy(
# Query ClickHouse in batches of 200 properties
for batch in chunked(property_defs, 200):
def snapshot(index: int | None, batch: list[str]):
return call_query_runner(
lambda: ActorsPropertyTaxonomyQueryRunner(
query=ActorsPropertyTaxonomyQuery(groupTypeIndex=index, properties=batch, maxPropertyValues=25),
team=team,
).calculate()
)
def wrapped_query_runner(index: int | None, batch: list[str]):
response = ActorsPropertyTaxonomyQueryRunner(
query=ActorsPropertyTaxonomyQuery(groupTypeIndex=index, properties=batch, maxPropertyValues=25),
team=team,
).run(execution_mode=ExecutionMode.RECENT_CACHE_CALCULATE_BLOCKING_IF_STALE)
if not isinstance(response, CachedActorsPropertyTaxonomyQueryResponse):
raise SnapshotUnrecoverableError(
f"Unexpected response type from actors property taxonomy query: {type(response)}"
)
return response
res = snapshot(index, batch)
def call_actors_property_taxonomy_query(index: int | None, batch: list[str]):
return call_query_runner(lambda: wrapped_query_runner(index, batch))
res = call_actors_property_taxonomy_query(index, batch)
if not res.results:
raise ValueError(
raise SnapshotUnrecoverableError(
f"No results from actors property taxonomy query for group type {index} and properties {batch}"
)
@@ -244,7 +284,7 @@ def snapshot_actors_property_taxonomy(
@dagster.op(
description="Snapshots ClickHouse project data",
description="Snapshots ClickHouse team data",
retry_policy=DEFAULT_RETRY_POLICY,
tags={
"owner": JobOwners.TEAM_MAX_AI.value,
@@ -252,17 +292,33 @@ def snapshot_actors_property_taxonomy(
},
code_version="v1",
)
def snapshot_clickhouse_project_data(
context: dagster.OpExecutionContext, project_id: int, s3: S3Resource
) -> ClickhouseProjectDataSnapshot:
team = Team.objects.get(id=project_id)
def snapshot_clickhouse_team_data(
context: dagster.OpExecutionContext, team_id: int, s3: S3Resource
) -> ClickhouseTeamDataSnapshot:
try:
team = Team.objects.get(id=team_id)
event_taxonomy_file_key, properties_taxonomy_file_key = snapshot_events_taxonomy(
context, s3, team, context.op_def.version
)
actors_property_taxonomy_file_key = snapshot_actors_property_taxonomy(context, s3, team, context.op_def.version)
except Team.DoesNotExist as e:
raise dagster.Failure(
description=f"Team {team_id} does not exist",
metadata={"team_id": team_id},
allow_retries=False,
) from e
materialized_result = ClickhouseProjectDataSnapshot(
try:
event_taxonomy_file_key, properties_taxonomy_file_key = snapshot_events_taxonomy(
context, s3, team, context.op_def.version
)
actors_property_taxonomy_file_key = snapshot_actors_property_taxonomy(context, s3, team, context.op_def.version)
except SnapshotUnrecoverableError as e:
raise dagster.Failure(
description=f"Error snapshotting team {team_id}",
metadata={"team_id": team_id},
allow_retries=False,
) from e
materialized_result = ClickhouseTeamDataSnapshot(
event_taxonomy=event_taxonomy_file_key,
properties_taxonomy=properties_taxonomy_file_key,
actors_property_taxonomy=actors_property_taxonomy_file_key,
@@ -270,10 +326,10 @@ def snapshot_clickhouse_project_data(
context.log_event(
dagster.AssetMaterialization(
asset_key="project_clickhouse_snapshot",
description="Avro snapshots of project's ClickHouse queries",
asset_key="team_clickhouse_snapshot",
description="Avro snapshots of team's ClickHouse queries",
metadata={
"project_id": project_id,
"team_id": team_id,
**materialized_result.model_dump(),
},
tags={"owner": JobOwners.TEAM_MAX_AI.value},


@@ -5,19 +5,27 @@ import dagster
from dagster import OpExecutionContext
from dagster_aws.s3.resources import S3Resource
from posthog.schema import ActorsPropertyTaxonomyResponse, EventTaxonomyItem, TeamTaxonomyItem
from posthog.schema import (
ActorsPropertyTaxonomyResponse,
CachedEventTaxonomyQueryResponse,
CachedTeamTaxonomyQueryResponse,
EventTaxonomyItem,
TeamTaxonomyItem,
)
from posthog.models import GroupTypeMapping, Organization, Project, Team
from posthog.models.property_definition import PropertyDefinition
from dags.max_ai.snapshot_project_data import (
from dags.max_ai.snapshot_team_data import (
SnapshotUnrecoverableError,
snapshot_actors_property_taxonomy,
snapshot_clickhouse_team_data,
snapshot_events_taxonomy,
snapshot_postgres_model,
snapshot_postgres_project_data,
snapshot_postgres_team_data,
snapshot_properties_taxonomy,
)
from ee.hogai.eval.schema import PostgresProjectDataSnapshot, TeamSnapshot
from ee.hogai.eval.schema import PostgresTeamDataSnapshot, TeamSnapshot
@pytest.fixture
@@ -58,7 +66,7 @@ def team():
@pytest.fixture
def mock_dump():
with patch("dags.max_ai.snapshot_project_data.dump_model") as mock_dump_model:
with patch("dags.max_ai.snapshot_team_data.dump_model") as mock_dump_model:
mock_dump_context = MagicMock()
mock_dump_function = MagicMock()
mock_dump_context.__enter__ = MagicMock(return_value=mock_dump_function)
@@ -67,8 +75,8 @@ def mock_dump():
yield mock_dump_function
@patch("dags.max_ai.snapshot_project_data.compose_postgres_dump_path")
@patch("dags.max_ai.snapshot_project_data.check_dump_exists")
@patch("dags.max_ai.snapshot_team_data.compose_postgres_dump_path")
@patch("dags.max_ai.snapshot_team_data.check_dump_exists")
def test_snapshot_postgres_model_skips_when_file_exists(
mock_check_dump_exists, mock_compose_path, mock_context, mock_s3
):
@@ -78,7 +86,7 @@ def test_snapshot_postgres_model_skips_when_file_exists(
mock_compose_path.return_value = file_key
mock_check_dump_exists.return_value = True
project_id = 123
team_id = 123
file_name = "teams"
code_version = "v1"
@@ -88,19 +96,19 @@ def test_snapshot_postgres_model_skips_when_file_exists(
model_type=TeamSnapshot,
file_name=file_name,
s3=mock_s3,
project_id=project_id,
team_id=team_id,
code_version=code_version,
)
# Verify
assert result == file_key
mock_compose_path.assert_called_once_with(project_id, file_name, code_version)
mock_compose_path.assert_called_once_with(team_id, file_name, code_version)
mock_check_dump_exists.assert_called_once_with(mock_s3, file_key)
mock_context.log.info.assert_called_once_with(f"Skipping {file_key} because it already exists")
@patch("dags.max_ai.snapshot_project_data.compose_postgres_dump_path")
@patch("dags.max_ai.snapshot_project_data.check_dump_exists")
@patch("dags.max_ai.snapshot_team_data.compose_postgres_dump_path")
@patch("dags.max_ai.snapshot_team_data.check_dump_exists")
def test_snapshot_postgres_model_dumps_when_file_not_exists(
mock_check_dump_exists, mock_compose_path, mock_context, mock_s3, mock_dump
):
@@ -110,10 +118,10 @@ def test_snapshot_postgres_model_dumps_when_file_not_exists(
mock_compose_path.return_value = file_key
mock_check_dump_exists.return_value = False
# Mock the serialize_for_project method
# Mock the serialize_for_team method
mock_serialized_data = [{"id": 1, "name": "test"}]
with patch.object(TeamSnapshot, "serialize_for_project", return_value=mock_serialized_data):
project_id = 123
with patch.object(TeamSnapshot, "serialize_for_team", return_value=mock_serialized_data):
team_id = 123
file_name = "teams"
code_version = "v1"
@@ -123,25 +131,25 @@ def test_snapshot_postgres_model_dumps_when_file_not_exists(
model_type=TeamSnapshot,
file_name=file_name,
s3=mock_s3,
project_id=project_id,
team_id=team_id,
code_version=code_version,
)
# Verify
assert result == file_key
mock_compose_path.assert_called_once_with(project_id, file_name, code_version)
mock_compose_path.assert_called_once_with(team_id, file_name, code_version)
mock_check_dump_exists.assert_called_once_with(mock_s3, file_key)
mock_context.log.info.assert_called_with(f"Dumping {file_key}")
mock_dump.assert_called_once_with(mock_serialized_data)
@patch("dags.max_ai.snapshot_project_data.snapshot_postgres_model")
def test_snapshot_postgres_project_data_exports_all_models(mock_snapshot_postgres_model, mock_s3):
"""Test that snapshot_postgres_project_data exports all expected models."""
@patch("dags.max_ai.snapshot_team_data.snapshot_postgres_model")
def test_snapshot_postgres_team_data_exports_all_models(mock_snapshot_postgres_model, mock_s3):
"""Test that snapshot_postgres_team_data exports all expected models."""
# Setup
project_id = 456
team_id = 456
mock_snapshot_postgres_model.side_effect = [
"path/to/project.avro",
"path/to/team.avro",
"path/to/property_definitions.avro",
"path/to/group_type_mappings.avro",
"path/to/data_warehouse_tables.avro",
@@ -151,11 +159,11 @@ def test_snapshot_postgres_project_data_exports_all_models(mock_snapshot_postgre
context = dagster.build_op_context()
# Execute
result = snapshot_postgres_project_data(context, project_id, mock_s3)
result = snapshot_postgres_team_data(context, team_id, mock_s3)
# Verify all expected models are in the result
assert isinstance(result, PostgresProjectDataSnapshot)
assert result.project == "path/to/project.avro"
assert isinstance(result, PostgresTeamDataSnapshot)
assert result.team == "path/to/team.avro"
assert result.property_definitions == "path/to/property_definitions.avro"
assert result.group_type_mappings == "path/to/group_type_mappings.avro"
assert result.data_warehouse_tables == "path/to/data_warehouse_tables.avro"
@@ -165,7 +173,7 @@ def test_snapshot_postgres_project_data_exports_all_models(mock_snapshot_postgre
@pytest.mark.django_db
@patch("dags.max_ai.snapshot_project_data.call_query_runner")
@patch("dags.max_ai.snapshot_team_data.call_query_runner")
def test_snapshot_properties_taxonomy(mock_call_query_runner, mock_context, mock_s3, team, mock_dump):
"""Test that snapshot_properties_taxonomy correctly processes events and dumps results."""
# Setup
@@ -190,9 +198,47 @@ def test_snapshot_properties_taxonomy(mock_call_query_runner, mock_context, mock
mock_dump.assert_called_once()
@patch("dags.max_ai.snapshot_project_data.check_dump_exists")
@patch("dags.max_ai.snapshot_project_data.EventTaxonomyQueryRunner.calculate")
@patch("dags.max_ai.snapshot_project_data.TeamTaxonomyQueryRunner.calculate")
@patch("dags.max_ai.snapshot_team_data.snapshot_postgres_model")
def test_snapshot_postgres_team_data_raises_failure_on_missing_team(mock_snapshot_postgres_model, mock_s3):
mock_snapshot_postgres_model.side_effect = Team.DoesNotExist()
context = dagster.build_op_context()
with pytest.raises(dagster.Failure) as exc:
snapshot_postgres_team_data(context=context, team_id=999999, s3=mock_s3)
assert getattr(exc.value, "allow_retries", None) is False
assert "Team 999999 does not exist" in str(exc.value)
@pytest.mark.django_db
def test_snapshot_clickhouse_team_data_raises_failure_on_missing_team(mock_s3):
context = dagster.build_op_context()
with pytest.raises(dagster.Failure) as exc:
snapshot_clickhouse_team_data(context=context, team_id=424242, s3=mock_s3)
assert getattr(exc.value, "allow_retries", None) is False
assert "Team 424242 does not exist" in str(exc.value)
@pytest.mark.django_db
@patch("dags.max_ai.snapshot_team_data.snapshot_events_taxonomy")
def test_snapshot_clickhouse_team_data_raises_failure_on_unrecoverable_error(
mock_snapshot_events_taxonomy, mock_s3, team
):
context = dagster.build_op_context()
mock_snapshot_events_taxonomy.side_effect = SnapshotUnrecoverableError("boom")
with pytest.raises(dagster.Failure) as exc:
snapshot_clickhouse_team_data(context=context, team_id=team.id, s3=mock_s3)
assert getattr(exc.value, "allow_retries", None) is False
@patch("dags.max_ai.snapshot_team_data.check_dump_exists")
@patch("dags.max_ai.snapshot_team_data.EventTaxonomyQueryRunner.run")
@patch("dags.max_ai.snapshot_team_data.TeamTaxonomyQueryRunner.run")
@pytest.mark.django_db
def test_snapshot_events_taxonomy(
mock_team_taxonomy_query_runner,
@@ -207,16 +253,18 @@ def test_snapshot_events_taxonomy(
mock_check_dump_exists.return_value = False
mock_team_taxonomy_query_runner.return_value = MagicMock(
spec=CachedTeamTaxonomyQueryResponse,
results=[
TeamTaxonomyItem(event="pageview", count=2),
TeamTaxonomyItem(event="click", count=1),
]
],
)
mock_event_taxonomy_query_runner.return_value = MagicMock(
spec=CachedEventTaxonomyQueryResponse,
results=[
EventTaxonomyItem(property="$current_url", sample_values=["https://posthog.com"], sample_count=1),
]
],
)
mock_s3_client = MagicMock()
@@ -228,7 +276,7 @@ def test_snapshot_events_taxonomy(
assert mock_dump.call_count == 2
@patch("dags.max_ai.snapshot_project_data.check_dump_exists")
@patch("dags.max_ai.snapshot_team_data.check_dump_exists")
@pytest.mark.django_db
def test_snapshot_events_taxonomy_can_be_skipped(mock_check_dump_exists, mock_context, mock_s3, team, mock_dump):
"""est that snapshot_events_taxonomy can be skipped when file already exists."""
@@ -241,7 +289,7 @@ def test_snapshot_events_taxonomy_can_be_skipped(mock_check_dump_exists, mock_co
assert mock_dump.call_count == 0
@patch("dags.max_ai.snapshot_project_data.check_dump_exists")
@patch("dags.max_ai.snapshot_team_data.check_dump_exists")
@pytest.mark.django_db
def test_snapshot_actors_property_taxonomy_can_be_skipped(
mock_check_dump_exists, mock_context, mock_s3, team, mock_dump
@@ -263,8 +311,8 @@ def test_snapshot_actors_property_taxonomy_can_be_skipped(
)
@patch("dags.max_ai.snapshot_project_data.check_dump_exists")
@patch("dags.max_ai.snapshot_project_data.call_query_runner")
@patch("dags.max_ai.snapshot_team_data.check_dump_exists")
@patch("dags.max_ai.snapshot_team_data.call_query_runner")
@pytest.mark.django_db
def test_snapshot_actors_property_taxonomy_dumps_with_group_type_mapping(
mock_call_query_runner, mock_check_dump_exists, mock_context, mock_s3, team, mock_dump


@@ -12,12 +12,12 @@ We currently use [Braintrust](https://braintrust.dev) as our evaluation platform
2. Run all evals with:
```bash
pytest ee/hogai/eval
pytest ee/hogai/eval/ci
```
The key bit is specifying the `ee/hogai/eval` directory that activates our eval-specific config, `ee/hogai/eval/pytest.ini`!
The key bit is specifying the `ee/hogai/eval/ci` directory that activates our eval-specific config, `ee/hogai/eval/pytest.ini`!
As always with pytest, you can also run a specific file, e.g. `pytest ee/hogai/eval/eval_root.py`.
As always with pytest, you can also run a specific file, e.g. `pytest ee/hogai/eval/ci/eval_root.py`. Apply the `--eval sql` argument to only run evals for test cases that contain `sql`.
3. Voila! Max ran, evals executed, and results and traces uploaded to the Braintrust platform + summarized in the terminal.

ee/hogai/eval/base.py (new file, 65 lines)

@@ -0,0 +1,65 @@
import os
import asyncio
from collections.abc import Sequence
from functools import partial
import pytest
from braintrust import EvalAsync, Metadata, init_logger
from braintrust.framework import EvalData, EvalScorer, EvalTask, Input, Output
async def BaseMaxEval(
experiment_name: str,
data: EvalData[Input, Output],
task: EvalTask[Input, Output],
scores: Sequence[EvalScorer[Input, Output]],
pytestconfig: pytest.Config,
metadata: Metadata | None = None,
is_public: bool = False,
no_send_logs: bool = True,
):
# We need a separate Braintrust project for each MaxEval() suite so that comparison to baseline works.
# That's what the Braintrust folks recommended: Braintrust projects are much more lightweight than PostHog ones.
project_name = f"max-ai-{experiment_name}"
init_logger(project_name)
# Filter by the --eval <eval_case_name_part> pytest flag
case_filter = pytestconfig.option.eval
if case_filter:
if asyncio.iscoroutine(data):
data = await data
data = [case for case in data if case_filter in str(case.input)] # type: ignore
timeout = 60 * 8 # 8 minutes
if os.getenv("EVAL_MODE") == "offline":
timeout = 60 * 60 # 1 hour
result = await EvalAsync(
project_name,
data=data,
task=task,
scores=scores,
timeout=timeout,
max_concurrency=100,
is_public=is_public,
no_send_logs=no_send_logs,
metadata=metadata,
)
# If we're running in offline mode and the test case is marked as public, the pipeline must fail completely.
if os.getenv("EVAL_MODE") == "offline" and is_public:
raise RuntimeError("Evaluation cases must be private when EVAL_MODE is set to offline.")
if os.getenv("EXPORT_EVAL_RESULTS"):
with open("eval_results.jsonl", "a") as f:
f.write(result.summary.as_json() + "\n")
return result
MaxPublicEval = partial(BaseMaxEval, is_public=True, no_send_logs=False)
"""Evaluation case that is publicly accessible."""
MaxPrivateEval = partial(BaseMaxEval, is_public=False, no_send_logs=True)
"""Evaluation case is not accessible publicly."""


@@ -0,0 +1,159 @@
import os
import datetime
from collections.abc import Generator
import pytest
from django.test import override_settings
from braintrust_langchain import BraintrustCallbackHandler, set_global_handler
from posthog.schema import FailureMessage, HumanMessage, VisualizationMessage
from posthog.demo.matrix.manager import MatrixManager
from posthog.models import Organization, Team, User
from posthog.tasks.demo_create_data import HedgeboxMatrix
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
# We want the PostHog set_up_evals fixture here
from ee.hogai.eval.conftest import set_up_evals # noqa: F401
from ee.hogai.eval.scorers import PlanAndQueryOutput
from ee.hogai.graph.graph import AssistantGraph, InsightsAssistantGraph
from ee.hogai.utils.types import AssistantNodeName, AssistantState
from ee.models.assistant import Conversation, CoreMemory
handler = BraintrustCallbackHandler()
if os.environ.get("EVAL_MODE") == "ci" and os.environ.get("BRAINTRUST_API_KEY"):
set_global_handler(handler)
EVAL_USER_FULL_NAME = "Karen Smith"
@pytest.fixture
def call_root_for_insight_generation(demo_org_team_user):
# This graph structure will first get a plan, then generate the SQL query.
insights_subgraph = (
# Insights subgraph without query execution, so we only create the queries
InsightsAssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
.add_query_creation_flow(next_node=AssistantNodeName.END)
.compile()
)
graph = (
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root(
path_map={
"insights": AssistantNodeName.INSIGHTS_SUBGRAPH,
"root": AssistantNodeName.END,
"search_documentation": AssistantNodeName.END,
"end": AssistantNodeName.END,
}
)
.add_node(AssistantNodeName.INSIGHTS_SUBGRAPH, insights_subgraph)
.add_edge(AssistantNodeName.INSIGHTS_SUBGRAPH, AssistantNodeName.END)
# TRICKY: We need to set a checkpointer here because async tests create a new event loop.
.compile(checkpointer=DjangoCheckpointer())
)
async def callable(query_with_extra_context: str | tuple[str, str]) -> PlanAndQueryOutput:
# If query_with_extra_context is a tuple, the first element is the query, the second is the extra context
# in case there's an ask_user tool call.
query = query_with_extra_context[0] if isinstance(query_with_extra_context, tuple) else query_with_extra_context
# Initial state for the graph
initial_state = AssistantState(
messages=[HumanMessage(content=f"Answer this question: {query}")],
)
conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])
# Invoke the graph. The state will be updated through planner and then generator.
final_state_raw = await graph.ainvoke(initial_state, {"configurable": {"thread_id": conversation.id}})
final_state = AssistantState.model_validate(final_state_raw)
# If we have extra context for the potential ask_user tool, and there's no message of type ai/failure
# or ai/visualization, we should answer with that extra context. We only do this once at most in an eval case.
if isinstance(query_with_extra_context, tuple) and not any(
isinstance(m, VisualizationMessage | FailureMessage) for m in final_state.messages
):
final_state.messages = [*final_state.messages, HumanMessage(content=query_with_extra_context[1])]
final_state.graph_status = "resumed"
final_state_raw = await graph.ainvoke(final_state, {"configurable": {"thread_id": conversation.id}})
final_state = AssistantState.model_validate(final_state_raw)
if not final_state.messages or not isinstance(final_state.messages[-1], VisualizationMessage):
return {
"plan": None,
"query": None,
"query_generation_retry_count": final_state.query_generation_retry_count,
}
return {
"plan": final_state.messages[-1].plan,
"query": final_state.messages[-1].answer,
"query_generation_retry_count": final_state.query_generation_retry_count,
}
return callable
@pytest.fixture(scope="package")
def demo_org_team_user(set_up_evals, django_db_blocker) -> Generator[tuple[Organization, Team, User], None, None]: # noqa: F811
with django_db_blocker.unblock():
team: Team | None = Team.objects.order_by("-created_at").first()
today = datetime.date.today()
# If there's no eval team or it's older than today, we need to create a new one with fresh data
if not team or team.created_at.date() < today:
print(f"Generating fresh demo data for evals...") # noqa: T201
matrix = HedgeboxMatrix(
seed="b1ef3c66-5f43-488a-98be-6b46d92fbcef", # this seed generates all events
days_past=120,
days_future=30,
n_clusters=500,
group_type_index_offset=0,
)
matrix_manager = MatrixManager(matrix, print_steps=True)
with override_settings(TEST=False):
# Simulation saving should occur in non-test mode, so that Kafka isn't mocked. Normally in tests we don't
# want to ingest via Kafka, but simulation saving is specifically designed to use that route for speed
org, team, user = matrix_manager.ensure_account_and_save(
f"eval-{today.isoformat()}", EVAL_USER_FULL_NAME, "Hedgebox Inc."
)
else:
print(f"Using existing demo data for evals...") # noqa: T201
org = team.organization
membership = org.memberships.first()
assert membership is not None
user = membership.user
yield org, team, user
@pytest.fixture(scope="package", autouse=True)
def core_memory(demo_org_team_user, django_db_blocker) -> Generator[CoreMemory, None, None]:
initial_memory = """Hedgebox is a cloud storage service enabling users to store, share, and access files across devices.
The company operates in the cloud storage and collaboration market for individuals and businesses.
Their audience includes professionals and organizations seeking file management and collaboration solutions.
Hedgebox's freemium model provides free accounts with limited storage and paid subscription plans for additional features.
Core features include file storage, synchronization, sharing, and collaboration tools for seamless file access and sharing.
It integrates with third-party applications to enhance functionality and streamline workflows.
Hedgebox sponsors the YouTube channel Marius Tech Tips."""
with django_db_blocker.unblock():
core_memory, _ = CoreMemory.objects.get_or_create(
team=demo_org_team_user[1],
defaults={
"text": initial_memory,
"initial_text": initial_memory,
"scraping_status": CoreMemory.ScrapingStatus.COMPLETED,
},
)
yield core_memory


@@ -15,13 +15,13 @@ from posthog.schema import (
from ee.hogai.graph.funnels.toolkit import FUNNEL_SCHEMA
from .conftest import MaxEval
from .scorers import PlanAndQueryOutput, PlanCorrectness, QueryAndPlanAlignment, QueryKindSelection, TimeRangeRelevancy
from ..base import MaxPublicEval
from ..scorers import PlanAndQueryOutput, PlanCorrectness, QueryAndPlanAlignment, QueryKindSelection, TimeRangeRelevancy
@pytest.mark.django_db
async def eval_funnel(call_root_for_insight_generation, pytestconfig):
await MaxEval(
await MaxPublicEval(
experiment_name="funnel",
task=call_root_for_insight_generation,
scores=[


@@ -11,8 +11,8 @@ from ee.hogai.graph import AssistantGraph
from ee.hogai.utils.types import AssistantNodeName, AssistantState
from ee.models.assistant import Conversation
from .conftest import MaxEval
from .scorers import InsightEvaluationAccuracy, InsightSearchOutput
from ..base import MaxPublicEval
from ..scorers import InsightEvaluationAccuracy, InsightSearchOutput
def extract_evaluation_info_from_state(state) -> dict:
@@ -144,7 +144,7 @@ def call_insight_search(demo_org_team_user):
@pytest.mark.django_db
async def eval_insight_evaluation_accuracy(call_insight_search, pytestconfig):
"""Evaluate the accuracy of the insight evaluation decision."""
await MaxEval(
await MaxPublicEval(
experiment_name="insight_evaluation_accuracy",
task=call_insight_search,
scores=[InsightEvaluationAccuracy()],


@@ -4,6 +4,7 @@ import pytest
from autoevals.llm import LLMClassifier
from braintrust import EvalCase, Score
from langchain_core.messages import AIMessage as LangchainAIMessage
from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage
@@ -12,8 +13,8 @@ from ee.hogai.graph import AssistantGraph
from ee.hogai.utils.types import AssistantNodeName, AssistantState
from ee.models.assistant import Conversation
from .conftest import MaxEval
from .scorers import ToolRelevance
from ..base import MaxPublicEval
from ..scorers import ToolRelevance
class MemoryContentRelevance(LLMClassifier):
@@ -83,9 +84,12 @@ def call_node(demo_org_team_user, core_memory):
state = AssistantState.model_validate(raw_state)
if not state.memory_collection_messages:
return None
last_message = state.memory_collection_messages[-1]
if not isinstance(last_message, LangchainAIMessage):
return None
return AssistantMessage(
content=state.memory_collection_messages[-1].content,
tool_calls=state.memory_collection_messages[-1].tool_calls,
content=last_message.content,
tool_calls=last_message.tool_calls,
)
return callable
@@ -93,7 +97,7 @@ def call_node(demo_org_team_user, core_memory):
@pytest.mark.django_db
async def eval_memory(call_node, pytestconfig):
await MaxEval(
await MaxPublicEval(
experiment_name="memory",
task=call_node,
scores=[ToolRelevance(semantic_similarity_args={"memory_content", "new_fragment"}), MemoryContentRelevance()],


@@ -6,13 +6,13 @@ from posthog.schema import AssistantRetentionEventsNode, AssistantRetentionFilte
from ee.hogai.graph.retention.toolkit import RETENTION_SCHEMA
from .conftest import MaxEval
from .scorers import PlanAndQueryOutput, PlanCorrectness, QueryAndPlanAlignment, QueryKindSelection, TimeRangeRelevancy
from ..base import MaxPublicEval
from ..scorers import PlanAndQueryOutput, PlanCorrectness, QueryAndPlanAlignment, QueryKindSelection, TimeRangeRelevancy
@pytest.mark.django_db
async def eval_retention(call_root_for_insight_generation, pytestconfig):
await MaxEval(
await MaxPublicEval(
experiment_name="retention",
task=call_root_for_insight_generation,
scores=[


@@ -11,8 +11,8 @@ from ee.hogai.graph import AssistantGraph
from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState
from ee.models.assistant import Conversation
from .conftest import MaxEval
from .scorers import ToolRelevance
from ..base import MaxPublicEval
from ..scorers import ToolRelevance
@pytest.fixture
@@ -47,7 +47,7 @@ def call_root(demo_org_team_user):
@pytest.mark.django_db
async def eval_root(call_root, pytestconfig):
await MaxEval(
await MaxPublicEval(
experiment_name="root",
task=call_root,
scores=[ToolRelevance(semantic_similarity_args={"query_description"})],


@@ -9,7 +9,8 @@ from ee.hogai.graph import AssistantGraph
from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState
from ee.models.assistant import Conversation
from .conftest import EVAL_USER_FULL_NAME, MaxEval
from ..base import MaxPublicEval
from .conftest import EVAL_USER_FULL_NAME
class StyleChecker(LLMClassifier):
@@ -90,7 +91,7 @@ def call_root(demo_org_team_user):
@pytest.mark.django_db
async def eval_root_style(call_root, pytestconfig):
await MaxEval(
await MaxPublicEval(
experiment_name="root_style",
task=call_root,
scores=[StyleChecker(user_name=EVAL_USER_FULL_NAME)],


@@ -6,16 +6,13 @@ from braintrust_core.score import Scorer
from posthog.schema import AssistantHogQLQuery, NodeKind
from posthog.hogql.errors import BaseHogQLError
from posthog.errors import InternalCHQueryError
from posthog.hogql_queries.hogql_query_runner import HogQLQueryRunner
from posthog.models.team.team import Team
from posthog.models import Team
from ee.hogai.eval.scorers.sql import evaluate_sql_query
from ee.hogai.graph.sql.toolkit import SQL_SCHEMA
from .conftest import MaxEval
from .scorers import PlanAndQueryOutput, PlanCorrectness, QueryAndPlanAlignment, QueryKindSelection, TimeRangeRelevancy
from ..base import MaxPublicEval
from ..scorers import PlanAndQueryOutput, PlanCorrectness, QueryAndPlanAlignment, QueryKindSelection, TimeRangeRelevancy
QUERY_GENERATION_MAX_RETRIES = 3
@@ -43,57 +40,23 @@ class RetryEfficiency(Scorer):
return Score(name=self._name(), score=score, metadata={"query_generation_retry_count": retry_count})
class SQLSyntaxCorrectness(Scorer):
"""Evaluate if the generated SQL query has correct syntax."""
class HogQLQuerySyntaxCorrectness(Scorer):
def _name(self):
return "sql_syntax_correctness"
async def _run_eval_async(self, output, expected=None, **kwargs):
if not output:
return Score(
name=self._name(), score=None, metadata={"reason": "No SQL query to verify, skipping evaluation"}
)
query = {"query": output}
team = await Team.objects.alatest("created_at")
try:
# Try to parse, print, and run the query
await sync_to_async(HogQLQueryRunner(query, team).calculate)()
except BaseHogQLError as e:
return Score(name=self._name(), score=0.0, metadata={"reason": f"HogQL-level error: {str(e)}"})
except InternalCHQueryError as e:
return Score(name=self._name(), score=0.5, metadata={"reason": f"ClickHouse-level error: {str(e)}"})
else:
return Score(name=self._name(), score=1.0)
async def _run_eval_async(self, output: PlanAndQueryOutput, *args, **kwargs):
return await sync_to_async(self._evaluate)(output)
def _run_eval_sync(self, output, expected=None, **kwargs):
if not output:
return Score(
name=self._name(), score=None, metadata={"reason": "No SQL query to verify, skipping evaluation"}
)
query = {"query": output}
def _run_eval_sync(self, output: PlanAndQueryOutput, *args, **kwargs):
return self._evaluate(output)
def _evaluate(self, output: PlanAndQueryOutput) -> Score:
team = Team.objects.latest("created_at")
try:
# Try to parse, print, and run the query
HogQLQueryRunner(query, team).calculate()
except BaseHogQLError as e:
return Score(name=self._name(), score=0.0, metadata={"reason": f"HogQL-level error: {str(e)}"})
except InternalCHQueryError as e:
return Score(name=self._name(), score=0.5, metadata={"reason": f"ClickHouse-level error: {str(e)}"})
if isinstance(output["query"], AssistantHogQLQuery):
query = output["query"].query
else:
return Score(name=self._name(), score=1.0)
class HogQLQuerySyntaxCorrectness(SQLSyntaxCorrectness):
async def _run_eval_async(self, output, expected=None, **kwargs):
return await super()._run_eval_async(
output["query"].query if output and output.get("query") else None, expected, **kwargs
)
def _run_eval_sync(self, output, expected=None, **kwargs):
return super()._run_eval_sync(
output["query"].query if output and output.get("query") else None, expected, **kwargs
)
query = None
return evaluate_sql_query(self._name(), query, team)
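The syntax scorer now delegates to `evaluate_sql_query` from `ee.hogai.eval.scorers.sql`, which is not shown in this diff. Based on the inline logic removed above, it presumably looks roughly like this (an assumption, not the shipped implementation):

```python
# Sketch of the extracted helper, reconstructed from the removed inline scorer logic.
from braintrust import Score

from posthog.errors import InternalCHQueryError
from posthog.hogql.errors import BaseHogQLError
from posthog.hogql_queries.hogql_query_runner import HogQLQueryRunner
from posthog.models import Team


def evaluate_sql_query(name: str, query: str | None, team: Team) -> Score:
    if not query:
        return Score(name=name, score=None, metadata={"reason": "No SQL query to verify, skipping evaluation"})
    try:
        # Try to parse, print, and run the query, mirroring the old SQLSyntaxCorrectness scorer.
        HogQLQueryRunner({"query": query}, team).calculate()
    except BaseHogQLError as e:
        return Score(name=name, score=0.0, metadata={"reason": f"HogQL-level error: {str(e)}"})
    except InternalCHQueryError as e:
        return Score(name=name, score=0.5, metadata={"reason": f"ClickHouse-level error: {str(e)}"})
    return Score(name=name, score=1.0)
```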
@pytest.mark.django_db
@@ -697,7 +660,7 @@ ORDER BY ABS(corr(toFloat(uploads_30d), toFloat(churned))) DESC
),
]
await MaxEval(
await MaxPublicEval(
experiment_name="sql",
task=call_root_for_insight_generation,
scores=[


@@ -13,7 +13,7 @@ from products.surveys.backend.max_tools import CreateSurveyTool, FeatureFlagLook
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.models.assistant import Conversation
from .conftest import MaxEval
from ..base import MaxPublicEval
def validate_survey_output(output, scorer_name):
@@ -630,7 +630,7 @@ async def eval_surveys(call_surveys_max_tool, pytestconfig):
"""
Evaluation for survey creation functionality.
"""
await MaxEval(
await MaxPublicEval(
experiment_name="surveys",
task=call_surveys_max_tool,
scores=[


@@ -16,8 +16,8 @@ from posthog.schema import (
from ee.hogai.graph.trends.toolkit import TRENDS_SCHEMA
from .conftest import MaxEval
from .scorers import PlanAndQueryOutput, PlanCorrectness, QueryAndPlanAlignment, QueryKindSelection, TimeRangeRelevancy
from ..base import MaxPublicEval
from ..scorers import PlanAndQueryOutput, PlanCorrectness, QueryAndPlanAlignment, QueryKindSelection, TimeRangeRelevancy
TRENDS_CASES = [
EvalCase(
@@ -408,7 +408,7 @@ Query kind:
@pytest.mark.django_db
async def eval_trends(call_root_for_insight_generation, pytestconfig):
await MaxEval(
await MaxPublicEval(
experiment_name="trends",
task=call_root_for_insight_generation,
scores=[


@@ -19,8 +19,8 @@ from ee.hogai.graph import AssistantGraph
from ee.hogai.utils.types import AssistantNodeName, AssistantState
from ee.models.assistant import Conversation
from .conftest import MaxEval
from .scorers import ToolRelevance
from ..base import MaxPublicEval
from ..scorers import ToolRelevance
@pytest.fixture
@@ -71,7 +71,7 @@ def sample_action(demo_org_team_user):
@pytest.mark.django_db
async def eval_ui_context_actions(call_root_with_ui_context, sample_action, pytestconfig):
"""Test that actions in UI context are properly used in RAG context retrieval"""
await MaxEval(
await MaxPublicEval(
experiment_name="ui_context_actions",
task=call_root_with_ui_context,
scores=[
@@ -136,7 +136,7 @@ async def eval_ui_context_actions(call_root_with_ui_context, sample_action, pyte
@pytest.mark.django_db
async def eval_ui_context_events(call_root_with_ui_context, pytestconfig):
"""Test that events in UI context are properly used in taxonomy agent"""
await MaxEval(
await MaxPublicEval(
experiment_name="ui_context_events",
task=call_root_with_ui_context,
scores=[


@@ -14,9 +14,8 @@ from posthog.sync import database_sync_to_async
from products.data_warehouse.backend.max_tools import HogQLGeneratorArgs, HogQLGeneratorTool
from ee.hogai.eval.conftest import MaxEval
from ee.hogai.eval.eval_sql import SQLSyntaxCorrectness
from ee.hogai.eval.scorers import SQLSemanticsCorrectness
from ee.hogai.eval.base import MaxPublicEval
from ee.hogai.eval.scorers import SQLSemanticsCorrectness, SQLSyntaxCorrectness
from ee.hogai.utils.markdown import remove_markdown
from ee.hogai.utils.types import AssistantState
from ee.hogai.utils.warehouse import serialize_database_schema
@@ -73,7 +72,8 @@ async def database_schema(demo_org_team_user):
return await serialize_database_schema(database, context)
async def sql_semantics_scorer(input: EvalInput, expected: str, output: str, metadata: dict) -> Score:
async def sql_semantics_scorer(input: EvalInput, expected: str, output: str, **kwargs) -> Score:
metadata: dict = kwargs["metadata"]
metric = SQLSemanticsCorrectness()
return await metric.eval_async(
input=input.instructions, expected=expected, output=output, database_schema=metadata["schema"]
@@ -84,7 +84,7 @@ async def sql_semantics_scorer(input: EvalInput, expected: str, output: str, met
async def eval_tool_generate_hogql_query(call_generate_hogql_query, database_schema, pytestconfig):
metadata = {"schema": database_schema}
await MaxEval(
await MaxPublicEval(
experiment_name="tool_generate_hogql_query",
task=call_generate_hogql_query,
scores=[SQLSyntaxCorrectness(), sql_semantics_scorer],


@@ -24,9 +24,10 @@ from products.replay.backend.max_tools import SessionReplayFilterOptionsGraph
from products.replay.backend.prompts import USER_FILTER_OPTIONS_PROMPT
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.eval.conftest import MaxEval
from ee.models.assistant import Conversation
from ...base import MaxPublicEval
logger = logging.getLogger(__name__)
DUMMY_CURRENT_FILTERS = MaxRecordingUniversalFilters(
@@ -178,7 +179,7 @@ class DateTimeFilteringCorrectness(Scorer):
@pytest.mark.django_db
async def eval_tool_search_session_recordings(call_search_session_recordings, pytestconfig):
await MaxEval(
await MaxPublicEval(
experiment_name="tool_search_session_recordings",
task=call_search_session_recordings,
scores=[FilterGenerationCorrectness(), DateTimeFilteringCorrectness()],
@@ -586,7 +587,7 @@ async def eval_tool_search_session_recordings(call_search_session_recordings, py
@pytest.mark.django_db
async def eval_tool_search_session_recordings_ask_user_for_help(call_search_session_recordings, pytestconfig):
await MaxEval(
await MaxPublicEval(
experiment_name="tool_search_session_recordings_ask_user_for_help",
task=call_search_session_recordings,
scores=[AskUserForHelp()],


@@ -0,0 +1,598 @@
import logging
import pytest
from braintrust import EvalCase, Score
from braintrust_core.score import Scorer
from deepdiff import DeepDiff
from posthog.schema import (
DurationType,
EventPropertyFilter,
FilterLogicalOperator,
MaxInnerUniversalFiltersGroup,
MaxOuterUniversalFiltersGroup,
MaxRecordingUniversalFilters,
PersonPropertyFilter,
PropertyOperator,
RecordingDurationFilter,
RecordingOrder,
RecordingOrderDirection,
)
from products.replay.backend.max_tools import SessionReplayFilterOptionsGraph
from products.replay.backend.prompts import USER_FILTER_OPTIONS_PROMPT
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.models.assistant import Conversation
from ...base import MaxPublicEval
logger = logging.getLogger(__name__)
DUMMY_CURRENT_FILTERS = MaxRecordingUniversalFilters(
date_from="-7d",
date_to=None,
duration=[
RecordingDurationFilter(key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0)
],
filter_group=MaxOuterUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[MaxInnerUniversalFiltersGroup(type=FilterLogicalOperator.AND_, values=[])],
),
filter_test_accounts=True,
order=RecordingOrder.START_TIME,
order_direction=RecordingOrderDirection.DESC,
)
@pytest.fixture
def call_search_session_recordings(demo_org_team_user):
graph = SessionReplayFilterOptionsGraph(demo_org_team_user[1], demo_org_team_user[2]).compile_full_graph(
checkpointer=DjangoCheckpointer()
)
async def callable(change: str) -> dict:
conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])
# Convert filters to JSON string and use test-specific prompt
filters_json = DUMMY_CURRENT_FILTERS.model_dump_json(indent=2)
graph_input = {
"change": USER_FILTER_OPTIONS_PROMPT.format(change=change, current_filters=filters_json),
"output": None,
}
result = await graph.ainvoke(graph_input, config={"configurable": {"thread_id": conversation.id}})
return result
return callable
class FilterGenerationCorrectness(Scorer):
"""Score the correctness of generated filters."""
def _name(self):
return "filter_generation_correctness"
async def _run_eval_async(self, output, expected=None, **kwargs):
return self._run_eval_sync(output, expected, **kwargs)
def _run_eval_sync(self, output, expected=None, **kwargs):
try:
actual_filters = MaxRecordingUniversalFilters.model_validate(output["output"])
except Exception as e:
logger.exception(f"Error parsing filters: {e}")
return Score(name=self._name(), score=0.0, metadata={"reason": "LLM returned invalid filter structure"})
# Convert both objects to dict for deepdiff comparison
actual_dict = actual_filters.model_dump()
expected_dict = expected.model_dump()
# Use deepdiff to find differences
diff = DeepDiff(expected_dict, actual_dict, ignore_order=True, report_repetition=True)
if not diff:
return Score(name=self._name(), score=1.0, metadata={"reason": "Perfect match"})
# Calculate score based on number of differences
total_fields = len(expected_dict.keys())
changed_fields = (
len(diff.get("values_changed", {}))
+ len(diff.get("dictionary_item_added", set()))
+ len(diff.get("dictionary_item_removed", set()))
)
score = max(0.0, (total_fields - changed_fields) / total_fields)
return Score(
name=self._name(),
score=score,
metadata={
"differences": str(diff),
"total_fields": total_fields,
"changed_fields": changed_fields,
"reason": f"Found {changed_fields} differences out of {total_fields} fields",
},
)
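For intuition, a minimal sketch of how the partial score above falls out of the DeepDiff counts; the field count and diff contents below are hypothetical:

# Hypothetical: the expected dump has 7 top-level fields and only date_from differs.
total_fields = 7
changed_fields = 1  # a single entry in diff["values_changed"]
score = max(0.0, (total_fields - changed_fields) / total_fields)  # ~0.86
# Nested differences (e.g. inside filter_group) each count via their own DeepDiff
# path, so changed_fields can exceed total_fields; max(0.0, ...) clamps the score.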
class AskUserForHelp(Scorer):
"""Score the correctness of the ask_user_for_help tool."""
def _name(self):
return "ask_user_for_help_scorer"
def _run_eval_sync(self, output, expected=None, **kwargs):
if "output" not in output or output["output"] is None:
if (
"intermediate_steps" in output
and len(output["intermediate_steps"]) > 0
and output["intermediate_steps"][-1][0].tool == "ask_user_for_help"
):
return Score(
name=self._name(), score=1, metadata={"reason": "LLM returned valid ask_user_for_help response"}
)
else:
return Score(
name=self._name(),
score=0,
metadata={"reason": "LLM did not return valid ask_user_for_help response"},
)
else:
return Score(name=self._name(), score=0.0, metadata={"reason": "LLM returned a filter"})
class DateTimeFilteringCorrectness(Scorer):
"""Score the correctness of the date time filtering."""
def _name(self):
return "date_time_filtering_correctness"
async def _run_eval_async(self, output, expected=None, **kwargs):
return self._run_eval_sync(output, expected, **kwargs)
def _run_eval_sync(self, output, expected=None, **kwargs):
try:
actual_filters = MaxRecordingUniversalFilters.model_validate(output["output"])
except Exception as e:
logger.exception(f"Error parsing filters: {e}")
return Score(name=self._name(), score=0.0, metadata={"reason": "LLM returned invalid filter structure"})
if actual_filters.date_from == expected.date_from and actual_filters.date_to == expected.date_to:
return Score(name=self._name(), score=1.0, metadata={"reason": "LLM returned valid date time filters"})
elif actual_filters.date_from == expected.date_from:
return Score(
name=self._name(),
score=0.5,
metadata={"reason": "LLM returned valid date_from but did not return valid date_to"},
)
elif actual_filters.date_to == expected.date_to:
return Score(
name=self._name(),
score=0.5,
metadata={"reason": "LLM returned valid date_to but did not return valid date_from"},
)
else:
return Score(name=self._name(), score=0.0, metadata={"reason": "LLM returned invalid date time filters"})
@pytest.mark.django_db
async def eval_tool_search_session_recordings(call_search_session_recordings, pytestconfig):
await MaxPublicEval(
experiment_name="tool_search_session_recordings",
task=call_search_session_recordings,
scores=[FilterGenerationCorrectness(), DateTimeFilteringCorrectness()],
data=[
# Test basic filter generation for mobile devices
EvalCase(
input="show me recordings of users that were using a mobile device (use events)",
expected=MaxRecordingUniversalFilters(
date_from="-7d",
date_to=None,
duration=[
RecordingDurationFilter(
key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
)
],
filter_group=MaxOuterUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[
MaxInnerUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[
EventPropertyFilter(
key="$device_type",
type="event",
value=["Mobile"],
operator=PropertyOperator.EXACT,
)
],
)
],
),
filter_test_accounts=True,
order=RecordingOrder.START_TIME,
),
),
# Test date range filtering
EvalCase(
input="Show recordings from the last 2 hours",
expected=MaxRecordingUniversalFilters(
date_from="-2h",
date_to=None,
duration=[
RecordingDurationFilter(
key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
)
],
filter_group=MaxOuterUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[MaxInnerUniversalFiltersGroup(type=FilterLogicalOperator.AND_, values=[])],
),
filter_test_accounts=True,
order=RecordingOrder.START_TIME,
),
),
# Test location filtering
EvalCase(
input="Show recordings for users located in the US",
expected=MaxRecordingUniversalFilters(
date_from="-7d",
date_to=None,
duration=[
RecordingDurationFilter(
key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
)
],
filter_group=MaxOuterUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[
MaxInnerUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[
PersonPropertyFilter(
key="$geoip_country_code",
type="person",
value=["US"],
operator=PropertyOperator.EXACT,
)
],
)
],
),
filter_test_accounts=True,
order=RecordingOrder.START_TIME,
),
),
# Test browser-specific filtering
EvalCase(
input="Show recordings from users that were using a browser in English",
expected=MaxRecordingUniversalFilters(
date_from="-7d",
date_to=None,
duration=[
RecordingDurationFilter(
key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
)
],
filter_group=MaxOuterUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[
MaxInnerUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[
PersonPropertyFilter(
key="$browser_language",
type="person",
value=["EN-en"],
operator=PropertyOperator.EXACT,
)
],
)
],
),
filter_test_accounts=True,
order=RecordingOrder.START_TIME,
),
),
EvalCase(
input="Show me recordings from chrome browsers",
expected=MaxRecordingUniversalFilters(
date_from="-7d",
date_to=None,
duration=[
RecordingDurationFilter(
key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
)
],
filter_group=MaxOuterUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[
MaxInnerUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[
EventPropertyFilter(
key="$browser",
type="event",
value=["Chrome"],
operator=PropertyOperator.EXACT,
)
],
)
],
),
filter_test_accounts=True,
order=RecordingOrder.START_TIME,
),
),
# Test user behavior filtering
EvalCase(
input="Show recordings where users visited the posthog.com/checkout_page",
expected=MaxRecordingUniversalFilters(
date_from="-7d",
date_to=None,
duration=[
RecordingDurationFilter(
key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
)
],
filter_group=MaxOuterUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[
MaxInnerUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[
EventPropertyFilter(
key="$current_url",
type="event",
value=["posthog.com/checkout_page"],
operator=PropertyOperator.ICONTAINS,
)
],
)
],
),
filter_test_accounts=True,
order=RecordingOrder.START_TIME,
),
),
# Test session duration filtering
EvalCase(
input="Show recordings longer than 5 minutes",
expected=MaxRecordingUniversalFilters(
date_from="-7d",
date_to=None,
duration=[
RecordingDurationFilter(
key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=300.0
)
],
filter_group=MaxOuterUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[MaxInnerUniversalFiltersGroup(type=FilterLogicalOperator.AND_, values=[])],
),
filter_test_accounts=True,
order=RecordingOrder.START_TIME,
),
),
# Test user action
EvalCase(
input="Show recordings from users that performed a billing action",
expected=MaxRecordingUniversalFilters(
date_from="-7d",
date_to=None,
duration=[
RecordingDurationFilter(
key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
)
],
filter_group=MaxOuterUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[
MaxInnerUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[
EventPropertyFilter(
key="paid_bill",
type="event",
value=None,
operator=PropertyOperator.IS_SET,
)
],
)
],
),
filter_test_accounts=True,
order=RecordingOrder.START_TIME,
),
),
# Test page-specific filtering
EvalCase(
input="Show recordings from users who visited the pricing page",
expected=MaxRecordingUniversalFilters(
date_from="-7d",
date_to=None,
duration=[
RecordingDurationFilter(
key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
)
],
filter_group=MaxOuterUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[
MaxInnerUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[
EventPropertyFilter(
key="$pathname",
type="event",
value=["/pricing/"],
operator=PropertyOperator.ICONTAINS,
)
],
)
],
),
filter_test_accounts=True,
order=RecordingOrder.START_TIME,
),
),
# Test conversion funnel filtering
EvalCase(
input="Show recordings from users who completed the signup flow",
expected=MaxRecordingUniversalFilters(
date_from="-7d",
date_to=None,
duration=[
RecordingDurationFilter(
key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
)
],
filter_group=MaxOuterUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[
MaxInnerUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[
EventPropertyFilter(
key="signup_completed",
type="event",
value=None,
operator=PropertyOperator.IS_SET,
)
],
)
],
),
filter_test_accounts=True,
order=RecordingOrder.START_TIME,
),
),
# Test device and browser combination
EvalCase(
input="Show recordings from mobile Safari users",
expected=MaxRecordingUniversalFilters(
date_from="-7d",
date_to=None,
duration=[
RecordingDurationFilter(
key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
)
],
filter_group=MaxOuterUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[
MaxInnerUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[
EventPropertyFilter(
key="$device_type",
type="event",
value=["Mobile"],
operator=PropertyOperator.EXACT,
),
EventPropertyFilter(
key="$browser",
type="event",
value=["Safari"],
operator=PropertyOperator.EXACT,
),
],
)
],
),
filter_test_accounts=True,
order=RecordingOrder.START_TIME,
),
),
# Test time-based filtering
EvalCase(
input="Show recordings from yesterday",
expected=DUMMY_CURRENT_FILTERS.model_copy(update={"date_from": "-1d", "date_to": "-1d"}),
),
EvalCase(
input="Show me recordings since the 1st of August",
expected=DUMMY_CURRENT_FILTERS.model_copy(update={"date_from": "2025-08-01T00:00:00:000"}),
),
EvalCase(
input="Show me recordings until the 31st of August",
expected=DUMMY_CURRENT_FILTERS.model_copy(update={"date_to": "2025-08-31T23:59:59:999"}),
),
EvalCase(
input="Show me recordings from the 1st of September to the 31st of September",
expected=DUMMY_CURRENT_FILTERS.model_copy(
update={"date_from": "2025-09-01T00:00:00:000", "date_to": "2025-09-30T23:59:59:999"}
),
),
EvalCase(
input="Show me recordings from the 1st of September to the 1st of September",
expected=DUMMY_CURRENT_FILTERS.model_copy(
update={"date_from": "2025-09-01T00:00:00:000", "date_to": "2025-09-01T23:59:59:999"}
),
),
EvalCase(
input="show me recordings of users who signed up on mobile",
expected=MaxRecordingUniversalFilters(
date_from="-7d",
date_to=None,
duration=[
RecordingDurationFilter(
key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
)
],
filter_group=MaxOuterUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[
MaxInnerUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[
EventPropertyFilter(
key="$device_type",
type="event",
value=["Mobile"],
operator=PropertyOperator.EXACT,
)
],
)
],
),
filter_test_accounts=True,
order=RecordingOrder.START_TIME,
),
),
EvalCase(
input="Show recordings in an ascending order by duration",
expected=MaxRecordingUniversalFilters(
date_from="-7d",
date_to=None,
duration=[
RecordingDurationFilter(
key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
)
],
filter_group=MaxOuterUniversalFiltersGroup(
type=FilterLogicalOperator.AND_,
values=[MaxInnerUniversalFiltersGroup(type=FilterLogicalOperator.AND_, values=[])],
),
filter_test_accounts=True,
order=RecordingOrder.DURATION,
order_direction=RecordingOrderDirection.ASC,
),
),
],
pytestconfig=pytestconfig,
)
@pytest.mark.django_db
async def eval_tool_search_session_recordings_ask_user_for_help(call_search_session_recordings, pytestconfig):
await MaxPublicEval(
experiment_name="tool_search_session_recordings_ask_user_for_help",
task=call_search_session_recordings,
scores=[AskUserForHelp()],
data=[
EvalCase(input="Tell me something about insights", expected="clarify"),
],
pytestconfig=pytestconfig,
)

View File

@@ -1,218 +1,28 @@
import os
import asyncio
import datetime
from collections import namedtuple
from collections.abc import Generator, Sequence
import pytest
from unittest import mock
from django.test import override_settings
from _pytest.terminal import TerminalReporter
from braintrust import EvalAsync, Metadata, init_logger
from braintrust.framework import EvalData, EvalScorer, EvalTask, Input, Output
from braintrust_langchain import BraintrustCallbackHandler, set_global_handler
from posthog.schema import FailureMessage, HumanMessage, VisualizationMessage
# We want the PostHog django_db_setup fixture here
from posthog.conftest import django_db_setup # noqa: F401
from posthog.demo.matrix.manager import MatrixManager
from posthog.models import Team
from posthog.tasks.demo_create_data import HedgeboxMatrix
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.eval.scorers import PlanAndQueryOutput
from ee.hogai.graph.graph import AssistantGraph, InsightsAssistantGraph
from ee.hogai.utils.types import AssistantNodeName, AssistantState
from ee.hogai.utils.types.base import AnyAssistantGeneratedQuery
from ee.models.assistant import Conversation, CoreMemory
handler = BraintrustCallbackHandler()
if os.environ.get("BRAINTRUST_API_KEY"):
set_global_handler(handler)
EVAL_USER_FULL_NAME = "Karen Smith"
def pytest_addoption(parser):
# Example: pytest ee/hogai/eval/eval_sql.py --eval churn - to only run cases containing "churn" in input
# Example: pytest ee/hogai/eval/ci/eval_sql.py --eval churn - to only run cases containing "churn" in input
parser.addoption("--eval", action="store")
async def MaxEval(
experiment_name: str,
data: EvalData[Input, Output],
task: EvalTask[Input, Output],
scores: Sequence[EvalScorer[Input, Output]],
pytestconfig: pytest.Config,
metadata: Metadata | None = None,
):
# We need to specify a separate project for each MaxEval() suite for comparison to baseline to work
# That's the way Braintrust folks recommended - Braintrust projects are much more lightweight than PostHog ones
project_name = f"max-ai-{experiment_name}"
init_logger(project_name)
# Filter by --eval <eval_case_name_part> pytest flag
case_filter = pytestconfig.option.eval
if case_filter:
if asyncio.iscoroutine(data):
data = await data
data = [case for case in data if case_filter in str(case.input)] # type: ignore
result = await EvalAsync(
project_name,
data=data,
task=task,
scores=scores,
timeout=60 * 8,
max_concurrency=100,
is_public=True,
metadata=metadata,
)
if os.getenv("GITHUB_EVENT_NAME") == "pull_request":
with open("eval_results.jsonl", "a") as f:
f.write(result.summary.as_json() + "\n")
return result
@pytest.fixture
def call_root_for_insight_generation(demo_org_team_user):
# This graph structure will first get a plan, then generate the SQL query.
insights_subgraph = (
# Insights subgraph without query execution, so we only create the queries
InsightsAssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
.add_query_creation_flow(next_node=AssistantNodeName.END)
.compile()
)
graph = (
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root(
path_map={
"insights": AssistantNodeName.INSIGHTS_SUBGRAPH,
"root": AssistantNodeName.END,
"search_documentation": AssistantNodeName.END,
"end": AssistantNodeName.END,
}
)
.add_node(AssistantNodeName.INSIGHTS_SUBGRAPH, insights_subgraph)
.add_edge(AssistantNodeName.INSIGHTS_SUBGRAPH, AssistantNodeName.END)
# TRICKY: We need to set a checkpointer here because async tests create a new event loop.
.compile(checkpointer=DjangoCheckpointer())
)
async def callable(query_with_extra_context: str | tuple[str, str]) -> PlanAndQueryOutput:
# If query_with_extra_context is a tuple, the first element is the query, the second is the extra context
# in case there's an ask_user tool call.
query = query_with_extra_context[0] if isinstance(query_with_extra_context, tuple) else query_with_extra_context
# Initial state for the graph
initial_state = AssistantState(
messages=[HumanMessage(content=f"Answer this question: {query}")],
)
conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])
# Invoke the graph. The state will be updated through planner and then generator.
final_state_raw = await graph.ainvoke(initial_state, {"configurable": {"thread_id": conversation.id}})
final_state = AssistantState.model_validate(final_state_raw)
# If we have extra context for the potential ask_user tool, and there's no message of type ai/failure
# or ai/visualization, we should answer with that extra context. We only do this once at most in an eval case.
if isinstance(query_with_extra_context, tuple) and not any(
isinstance(m, VisualizationMessage | FailureMessage) for m in final_state.messages
):
final_state.messages = [*final_state.messages, HumanMessage(content=query_with_extra_context[1])]
final_state.graph_status = "resumed"
final_state_raw = await graph.ainvoke(final_state, {"configurable": {"thread_id": conversation.id}})
final_state = AssistantState.model_validate(final_state_raw)
if (
not final_state.messages
or not isinstance(final_state.messages[-1], VisualizationMessage)
or not isinstance(final_state.messages[-1].answer, AnyAssistantGeneratedQuery)
):
return {
"plan": None,
"query": None,
"query_generation_retry_count": final_state.query_generation_retry_count,
}
return {
"plan": final_state.messages[-1].plan,
"query": final_state.messages[-1].answer,
"query_generation_retry_count": final_state.query_generation_retry_count,
}
return callable
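A hypothetical sketch of the two calling conventions this fixture's callable accepts (the questions and follow-up text are illustrative only):

# Plain question:
#   await call_root_for_insight_generation("What is our signup conversion rate?")
# Question plus canned extra context, applied only if the graph stops to ask for
# clarification instead of producing a visualization or failure message:
#   await call_root_for_insight_generation(
#       ("Show churn over time", "By churn I mean users with no events for 30 days")
#   )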
@pytest.fixture(scope="package")
def demo_org_team_user(django_db_setup, django_db_blocker): # noqa: F811
with django_db_blocker.unblock():
team = Team.objects.order_by("-created_at").first()
today = datetime.date.today()
# If there's no eval team or it's older than today, we need to create a new one with fresh data
should_create_new_team = not team or team.created_at.date() < today
if should_create_new_team:
print(f"Generating fresh demo data for evals...") # noqa: T201
matrix = HedgeboxMatrix(
seed="b1ef3c66-5f43-488a-98be-6b46d92fbcef", # this seed generates all events
days_past=120,
days_future=30,
n_clusters=500,
group_type_index_offset=0,
)
matrix_manager = MatrixManager(matrix, print_steps=True)
with override_settings(TEST=False):
# Simulation saving should occur in non-test mode, so that Kafka isn't mocked. Normally in tests we don't
# want to ingest via Kafka, but simulation saving is specifically designed to use that route for speed
org, team, user = matrix_manager.ensure_account_and_save(
f"eval-{today.isoformat()}", EVAL_USER_FULL_NAME, "Hedgebox Inc."
)
else:
print(f"Using existing demo data for evals...") # noqa: T201
org = team.organization
user = org.memberships.first().user
yield org, team, user
@pytest.fixture(scope="package", autouse=True)
def core_memory(demo_org_team_user, django_db_blocker) -> Generator[CoreMemory, None, None]:
initial_memory = """Hedgebox is a cloud storage service enabling users to store, share, and access files across devices.
The company operates in the cloud storage and collaboration market for individuals and businesses.
Their audience includes professionals and organizations seeking file management and collaboration solutions.
Hedgebox's freemium model provides free accounts with limited storage and paid subscription plans for additional features.
Core features include file storage, synchronization, sharing, and collaboration tools for seamless file access and sharing.
It integrates with third-party applications to enhance functionality and streamline workflows.
Hedgebox sponsors the YouTube channel Marius Tech Tips."""
with django_db_blocker.unblock():
core_memory, _ = CoreMemory.objects.get_or_create(
team=demo_org_team_user[1],
defaults={
"text": initial_memory,
"initial_text": initial_memory,
"scraping_status": CoreMemory.ScrapingStatus.COMPLETED,
},
)
yield core_memory
_nodeid_to_results_url_map: dict[str, str] = {}
"""Map of test nodeid (file + test name) to Braintrust results URL."""
@pytest.fixture(scope="package")
def set_up_evals(django_db_setup): # noqa: F811
yield
@pytest.fixture(autouse=True)
def capture_stdout(request, capsys):
yield

View File

View File

@@ -0,0 +1,70 @@
from collections.abc import Generator
from typing import Annotated
import pytest
from asgiref.sync import async_to_sync
from dagster_pipes import PipesContext, open_dagster_pipes
from pydantic import BaseModel, ConfigDict, SkipValidation
from posthog.models import Organization, User
# We want the PostHog set_up_evals fixture here
from ee.hogai.eval.conftest import set_up_evals # noqa: F401
from ee.hogai.eval.offline.snapshot_loader import SnapshotLoader
from ee.hogai.eval.schema import DatasetInput
@pytest.fixture(scope="package")
def dagster_context() -> Generator[PipesContext, None, None]:
with open_dagster_pipes() as context:
yield context
class EvaluationContext(BaseModel):
# We don't want to validate Django models here.
model_config = ConfigDict(arbitrary_types_allowed=True)
organization: Annotated[Organization, SkipValidation]
user: Annotated[User, SkipValidation]
experiment_name: str
dataset: list[DatasetInput]
@pytest.fixture(scope="package", autouse=True)
def eval_ctx(
set_up_evals, # noqa: F811
dagster_context: PipesContext,
django_db_blocker,
) -> Generator[EvaluationContext, None, None]:
"""
Fixture that restores dumped Django models and patches AI query runners.
Each snapshot team is recreated under its original team_id and attached to a freshly
created project, all under a single shared organization and user.
"""
with django_db_blocker.unblock():
dagster_context.log.info(f"Loading Postgres and ClickHouse snapshots...")
loader = SnapshotLoader(dagster_context)
org, user, dataset = async_to_sync(loader.load_snapshots)()
dagster_context.log.info(f"Running tests...")
yield EvaluationContext(
organization=org,
user=user,
experiment_name=loader.config.experiment_name,
dataset=dataset,
)
dagster_context.log.info(f"Cleaning up...")
loader.cleanup()
dagster_context.log.info(f"Reporting results...")
with open("eval_results.jsonl") as f:
lines = f.readlines()
dagster_context.report_asset_materialization(
asset_key="evaluation_report",
metadata={
"output": "\n".join(lines),
},
)

View File

@@ -0,0 +1,107 @@
import asyncio
from typing import TypedDict, cast
import pytest
from braintrust import EvalCase, Score
from pydantic import BaseModel, Field
from posthog.schema import AssistantHogQLQuery, HumanMessage, VisualizationMessage
from posthog.hogql.context import HogQLContext
from posthog.hogql.database.database import create_hogql_database
from posthog.models import Team
from posthog.sync import database_sync_to_async
from ee.hogai.eval.base import MaxPrivateEval
from ee.hogai.eval.offline.conftest import EvaluationContext
from ee.hogai.eval.schema import DatasetInput
from ee.hogai.eval.scorers.sql import SQLSemanticsCorrectness, SQLSyntaxCorrectness
from ee.hogai.graph import AssistantGraph
from ee.hogai.utils.helpers import find_last_message_of_type
from ee.hogai.utils.types import AssistantState
from ee.hogai.utils.warehouse import serialize_database_schema
from ee.models import Conversation
class EvalOutput(BaseModel):
database_schema: str
query_kind: str | None = Field(default=None)
sql_query: str | None = Field(default=None)
class EvalMetadata(TypedDict):
team_id: int
async def serialize_database(team: Team):
database = await database_sync_to_async(create_hogql_database)(team=team)
context = HogQLContext(team=team, database=database, enable_select_queries=True)
return await serialize_database_schema(database, context)
@pytest.fixture
def call_graph(eval_ctx: EvaluationContext):
async def callable(entry: DatasetInput, *args, **kwargs) -> EvalOutput:
team = await Team.objects.aget(id=entry.team_id)
conversation, database_schema = await asyncio.gather(
Conversation.objects.acreate(team=team, user=eval_ctx.user),
serialize_database(team),
)
graph = AssistantGraph(team, eval_ctx.user).compile_full_graph()
state = await graph.ainvoke(
AssistantState(messages=[HumanMessage(content=entry.input["query"])]),
{
"configurable": {
"thread_id": conversation.id,
"user": eval_ctx.user,
"team": team,
"distinct_id": eval_ctx.user.distinct_id,
}
},
)
maybe_viz_message = find_last_message_of_type(state["messages"], VisualizationMessage)
if maybe_viz_message:
return EvalOutput(
database_schema=database_schema,
query_kind=maybe_viz_message.answer.kind,
sql_query=maybe_viz_message.answer.query
if isinstance(maybe_viz_message.answer, AssistantHogQLQuery)
else None,
)
return EvalOutput(database_schema=database_schema)
return callable
async def sql_semantics_scorer(input: DatasetInput, expected: str, output: EvalOutput, **kwargs) -> Score:
metric = SQLSemanticsCorrectness()
return await metric.eval_async(
input=input.input["query"], expected=expected, output=output.sql_query, database_schema=output.database_schema
)
async def sql_syntax_scorer(input: DatasetInput, expected: str, output: EvalOutput, **kwargs) -> Score:
metric = SQLSyntaxCorrectness()
return await metric.eval_async(
input=input.input["query"], expected=expected, output=output.sql_query, database_schema=output.database_schema
)
def generate_test_cases(eval_ctx: EvaluationContext):
for entry in eval_ctx.dataset:
metadata: EvalMetadata = {"team_id": entry.team_id}
yield EvalCase(input=entry, expected=entry.expected["output"], metadata=cast(dict, metadata))
@pytest.mark.django_db
async def eval_offline_sql(eval_ctx: EvaluationContext, call_graph, pytestconfig):
await MaxPrivateEval(
experiment_name=eval_ctx.experiment_name,
task=call_graph,
scores=[sql_syntax_scorer, sql_semantics_scorer],
data=generate_test_cases(eval_ctx),
pytestconfig=pytestconfig,
)

View File

@@ -0,0 +1,65 @@
from collections import defaultdict
from typing import Literal
from posthog.schema import (
ActorsPropertyTaxonomyQueryResponse,
ActorsPropertyTaxonomyResponse,
EventTaxonomyItem,
EventTaxonomyQueryResponse,
TeamTaxonomyItem,
TeamTaxonomyQueryResponse,
)
from posthog.hogql_queries.ai.actors_property_taxonomy_query_runner import ActorsPropertyTaxonomyQueryRunner
from posthog.hogql_queries.ai.event_taxonomy_query_runner import EventTaxonomyQueryRunner
from posthog.hogql_queries.ai.team_taxonomy_query_runner import TeamTaxonomyQueryRunner
# This is a global state that is used to store the patched results for the team taxonomy query.
TEAM_TAXONOMY_QUERY_DATA_SOURCE: dict[int, list[TeamTaxonomyItem]] = {}
class PatchedTeamTaxonomyQueryRunner(TeamTaxonomyQueryRunner):
def _calculate(self):
results: list[TeamTaxonomyItem] = []
if precomputed_results := TEAM_TAXONOMY_QUERY_DATA_SOURCE.get(self.team.id):
results = precomputed_results
return TeamTaxonomyQueryResponse(results=results, modifiers=self.modifiers)
# This is a global state that is used to store the patched results for the event taxonomy query.
EVENT_TAXONOMY_QUERY_DATA_SOURCE: dict[int, dict[str | int, list[EventTaxonomyItem]]] = defaultdict(dict)
class PatchedEventTaxonomyQueryRunner(EventTaxonomyQueryRunner):
def _calculate(self):
results: list[EventTaxonomyItem] = []
team_data = EVENT_TAXONOMY_QUERY_DATA_SOURCE.get(self.team.id, {})
if self.query.event in team_data:
results = team_data[self.query.event]
elif self.query.actionId in team_data:
results = team_data[self.query.actionId]
return EventTaxonomyQueryResponse(results=results, modifiers=self.modifiers)
# This is a global state that is used to store the patched results for the actors property taxonomy query.
ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE: dict[
int, dict[int | Literal["person"], dict[str, ActorsPropertyTaxonomyResponse]]
] = defaultdict(lambda: defaultdict(dict))
class PatchedActorsPropertyTaxonomyQueryRunner(ActorsPropertyTaxonomyQueryRunner):
def _calculate(self):
key: int | Literal["person"] = (
self.query.groupTypeIndex if isinstance(self.query.groupTypeIndex, int) else "person"
)
if (
self.team.id in ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE
and key in ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE[self.team.id]
):
data = ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE[self.team.id][key]
result: list[ActorsPropertyTaxonomyResponse] = []
for prop in self.query.properties:
result.append(data.get(prop, ActorsPropertyTaxonomyResponse(sample_values=[], sample_count=0)))
else:
result = [ActorsPropertyTaxonomyResponse(sample_values=[], sample_count=0) for _ in self.query.properties]
return ActorsPropertyTaxonomyQueryResponse(results=result, modifiers=self.modifiers)

View File

@@ -0,0 +1,182 @@
import os
import asyncio
from collections.abc import Generator
from io import BytesIO
from typing import TYPE_CHECKING, Any, Literal, TypeVar
from unittest.mock import patch
import backoff
import aioboto3
from asgiref.sync import sync_to_async
from dagster_pipes import PipesContext
from fastavro import reader
from pydantic_avro import AvroBase
from posthog.models import GroupTypeMapping, Organization, Project, PropertyDefinition, Team, User
from posthog.warehouse.models.table import DataWarehouseTable
from ee.hogai.eval.schema import (
ActorsPropertyTaxonomySnapshot,
DatasetInput,
DataWarehouseTableSnapshot,
EvalsDockerImageConfig,
GroupTypeMappingSnapshot,
PropertyDefinitionSnapshot,
PropertyTaxonomySnapshot,
TeamEvaluationSnapshot,
TeamSnapshot,
TeamTaxonomyItemSnapshot,
)
from .query_patches import (
ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE,
EVENT_TAXONOMY_QUERY_DATA_SOURCE,
TEAM_TAXONOMY_QUERY_DATA_SOURCE,
PatchedActorsPropertyTaxonomyQueryRunner,
PatchedEventTaxonomyQueryRunner,
PatchedTeamTaxonomyQueryRunner,
)
if TYPE_CHECKING:
from types_aiobotocore_s3.client import S3Client
T = TypeVar("T", bound=AvroBase)
class SnapshotLoader:
"""Loads snapshots from S3, restores Django models, and patches query runners."""
def __init__(self, context: PipesContext):
self.context = context
self.config = EvalsDockerImageConfig.model_validate(context.extras)
self.patches: list[Any] = []
async def load_snapshots(self) -> tuple[Organization, User, list[DatasetInput]]:
self.organization = await Organization.objects.acreate(name="PostHog")
self.user = await sync_to_async(User.objects.create_and_join)(self.organization, "test@posthog.com", "12345678")
for snapshot in self.config.team_snapshots:
self.context.log.info(f"Loading Postgres snapshot for team {snapshot.team_id}...")
project = await Project.objects.acreate(
id=await sync_to_async(Team.objects.increment_id_sequence)(), organization=self.organization
)
(
project_snapshot_bytes,
property_definitions_snapshot_bytes,
group_type_mappings_snapshot_bytes,
data_warehouse_tables_snapshot_bytes,
event_taxonomy_snapshot_bytes,
properties_taxonomy_snapshot_bytes,
actors_property_taxonomy_snapshot_bytes,
) = await self._get_all_snapshots(snapshot)
team = await self._load_team_snapshot(project, snapshot.team_id, project_snapshot_bytes)
await asyncio.gather(
self._load_property_definitions(property_definitions_snapshot_bytes, team=team, project=project),
self._load_group_type_mappings(group_type_mappings_snapshot_bytes, team=team, project=project),
self._load_data_warehouse_tables(data_warehouse_tables_snapshot_bytes, team=team, project=project),
)
self._load_event_taxonomy(event_taxonomy_snapshot_bytes, team=team)
self._load_properties_taxonomy(properties_taxonomy_snapshot_bytes, team=team)
self._load_actors_property_taxonomy(actors_property_taxonomy_snapshot_bytes, team=team)
self._patch_query_runners()
return self.organization, self.user, self.config.dataset
def cleanup(self):
for mock in self.patches:
mock.stop()
@backoff.on_exception(backoff.expo, Exception, max_tries=3)
async def _get_all_snapshots(self, snapshot: TeamEvaluationSnapshot):
async with aioboto3.Session().client(
"s3",
endpoint_url=self.config.aws_endpoint_url,
aws_access_key_id=os.getenv("OBJECT_STORAGE_ACCESS_KEY_ID"),
aws_secret_access_key=os.getenv("OBJECT_STORAGE_SECRET_ACCESS_KEY"),
) as client:
loaded_snapshots = await asyncio.gather(
self._get_snapshot_from_s3(client, snapshot.postgres.team),
self._get_snapshot_from_s3(client, snapshot.postgres.property_definitions),
self._get_snapshot_from_s3(client, snapshot.postgres.group_type_mappings),
self._get_snapshot_from_s3(client, snapshot.postgres.data_warehouse_tables),
self._get_snapshot_from_s3(client, snapshot.clickhouse.event_taxonomy),
self._get_snapshot_from_s3(client, snapshot.clickhouse.properties_taxonomy),
self._get_snapshot_from_s3(client, snapshot.clickhouse.actors_property_taxonomy),
)
return loaded_snapshots
async def _get_snapshot_from_s3(self, client: "S3Client", file_key: str):
response = await client.get_object(Bucket=self.config.aws_bucket_name, Key=file_key)
content = await response["Body"].read()
return BytesIO(content)
def _parse_snapshot_to_schema(self, schema: type[T], buffer: BytesIO) -> Generator[T, None, None]:
for record in reader(buffer):
yield schema.model_validate(record)
async def _load_team_snapshot(self, project: Project, team_id: int, buffer: BytesIO) -> Team:
team_snapshot = next(self._parse_snapshot_to_schema(TeamSnapshot, buffer))
team = next(TeamSnapshot.deserialize_for_team([team_snapshot], team_id=team_id, project_id=project.id))
team.project = project
team.organization = self.organization
team.api_token = f"team_{team_id}"
await team.asave()
return team
async def _load_property_definitions(self, buffer: BytesIO, *, team: Team, project: Project):
snapshot = list(self._parse_snapshot_to_schema(PropertyDefinitionSnapshot, buffer))
property_definitions = PropertyDefinitionSnapshot.deserialize_for_team(
snapshot, team_id=team.id, project_id=project.id
)
return await PropertyDefinition.objects.abulk_create(property_definitions, batch_size=500)
async def _load_group_type_mappings(self, buffer: BytesIO, *, team: Team, project: Project):
snapshot = list(self._parse_snapshot_to_schema(GroupTypeMappingSnapshot, buffer))
group_type_mappings = GroupTypeMappingSnapshot.deserialize_for_team(
snapshot, team_id=team.id, project_id=project.id
)
return await GroupTypeMapping.objects.abulk_create(group_type_mappings, batch_size=500)
async def _load_data_warehouse_tables(self, buffer: BytesIO, *, team: Team, project: Project):
snapshot = list(self._parse_snapshot_to_schema(DataWarehouseTableSnapshot, buffer))
data_warehouse_tables = DataWarehouseTableSnapshot.deserialize_for_team(
snapshot, team_id=team.id, project_id=project.id
)
return await DataWarehouseTable.objects.abulk_create(data_warehouse_tables, batch_size=500)
def _load_event_taxonomy(self, buffer: BytesIO, *, team: Team):
snapshot = next(self._parse_snapshot_to_schema(TeamTaxonomyItemSnapshot, buffer))
TEAM_TAXONOMY_QUERY_DATA_SOURCE[team.id] = snapshot.results
def _load_properties_taxonomy(self, buffer: BytesIO, *, team: Team):
for item in self._parse_snapshot_to_schema(PropertyTaxonomySnapshot, buffer):
EVENT_TAXONOMY_QUERY_DATA_SOURCE[team.id][item.event] = item.results
def _load_actors_property_taxonomy(self, buffer: BytesIO, *, team: Team):
for item in self._parse_snapshot_to_schema(ActorsPropertyTaxonomySnapshot, buffer):
key: int | Literal["person"] = item.group_type_index if isinstance(item.group_type_index, int) else "person"
ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE[team.pk][key][item.property] = item.results
def _patch_query_runners(self):
self.patches = [
patch(
"posthog.hogql_queries.ai.team_taxonomy_query_runner.TeamTaxonomyQueryRunner",
new=PatchedTeamTaxonomyQueryRunner,
),
patch(
"posthog.hogql_queries.ai.event_taxonomy_query_runner.EventTaxonomyQueryRunner",
new=PatchedEventTaxonomyQueryRunner,
),
patch(
"posthog.hogql_queries.ai.actors_property_taxonomy_query_runner.ActorsPropertyTaxonomyQueryRunner",
new=PatchedActorsPropertyTaxonomyQueryRunner,
),
]
for mock in self.patches:
mock.start()

View File

@@ -1,11 +1,11 @@
import json
from abc import ABC, abstractmethod
from collections.abc import Generator, Sequence
from typing import Generic, Self, TypeVar
from typing import Any, Generic, Self, TypeVar
from django.db.models import Model
from pydantic import BaseModel
from pydantic import BaseModel, Field
from pydantic_avro import AvroBase
from posthog.schema import ActorsPropertyTaxonomyResponse, EventTaxonomyItem, TeamTaxonomyItem
@@ -18,14 +18,12 @@ T = TypeVar("T", bound=Model)
class BaseSnapshot(AvroBase, ABC, Generic[T]):
@classmethod
@abstractmethod
def serialize_for_project(cls, project_id: int) -> Generator[Self, None, None]:
def serialize_for_team(cls, *, team_id: int) -> Generator[Self, None, None]:
raise NotImplementedError
@classmethod
@abstractmethod
def deserialize_for_project(
cls, project_id: int, models: Sequence[Self], *, team_id: int
) -> Generator[T, None, None]:
def deserialize_for_team(cls, models: Sequence[Self], *, team_id: int, project_id: int) -> Generator[T, None, None]:
raise NotImplementedError
@@ -35,15 +33,18 @@ class TeamSnapshot(BaseSnapshot[Team]):
test_account_filters: str
@classmethod
def serialize_for_project(cls, project_id: int):
team = Team.objects.get(pk=project_id)
def serialize_for_team(cls, *, team_id: int):
team = Team.objects.get(pk=team_id)
yield TeamSnapshot(name=team.name, test_account_filters=json.dumps(team.test_account_filters))
@classmethod
def deserialize_for_project(cls, project_id: int, models: Sequence[Self], **kwargs) -> Generator[Team, None, None]:
def deserialize_for_team(
cls, models: Sequence[Self], *, team_id: int, project_id: int
) -> Generator[Team, None, None]:
for model in models:
yield Team(
id=project_id,
id=team_id,
project_id=project_id,
name=model.name,
test_account_filters=json.loads(model.test_account_filters),
)
@@ -58,8 +59,8 @@ class PropertyDefinitionSnapshot(BaseSnapshot[PropertyDefinition]):
group_type_index: int | None
@classmethod
def serialize_for_project(cls, project_id: int):
for prop in PropertyDefinition.objects.filter(project_id=project_id).iterator(500):
def serialize_for_team(cls, *, team_id: int):
for prop in PropertyDefinition.objects.filter(team_id=team_id).iterator(500):
yield PropertyDefinitionSnapshot(
name=prop.name,
is_numerical=prop.is_numerical,
@@ -69,7 +70,7 @@ class PropertyDefinitionSnapshot(BaseSnapshot[PropertyDefinition]):
)
@classmethod
def deserialize_for_project(cls, project_id: int, models: Sequence[Self], **kwargs):
def deserialize_for_team(cls, models: Sequence[Self], *, team_id: int, project_id: int):
for model in models:
yield PropertyDefinition(
name=model.name,
@@ -77,7 +78,8 @@ class PropertyDefinitionSnapshot(BaseSnapshot[PropertyDefinition]):
property_type=model.property_type,
type=model.type,
group_type_index=model.group_type_index,
team_id=project_id,
team_id=team_id,
project_id=project_id,
)
@@ -89,8 +91,8 @@ class GroupTypeMappingSnapshot(BaseSnapshot[GroupTypeMapping]):
name_plural: str | None
@classmethod
def serialize_for_project(cls, project_id: int):
for mapping in GroupTypeMapping.objects.filter(project_id=project_id).iterator(500):
def serialize_for_team(cls, *, team_id: int):
for mapping in GroupTypeMapping.objects.filter(team_id=team_id).iterator(500):
yield GroupTypeMappingSnapshot(
group_type=mapping.group_type,
group_type_index=mapping.group_type_index,
@@ -99,7 +101,7 @@ class GroupTypeMappingSnapshot(BaseSnapshot[GroupTypeMapping]):
)
@classmethod
def deserialize_for_project(cls, project_id: int, models: Sequence[Self], *, team_id: int):
def deserialize_for_team(cls, models: Sequence[Self], *, team_id: int, project_id: int):
for model in models:
yield GroupTypeMapping(
group_type=model.group_type,
@@ -118,8 +120,8 @@ class DataWarehouseTableSnapshot(BaseSnapshot[DataWarehouseTable]):
columns: dict
@classmethod
def serialize_for_project(cls, project_id: int):
for table in DataWarehouseTable.objects.filter(team_id=project_id).iterator(500):
def serialize_for_team(cls, *, team_id: int):
for table in DataWarehouseTable.objects.filter(team_id=team_id).iterator(500):
yield DataWarehouseTableSnapshot(
name=table.name,
format=table.format,
@@ -127,19 +129,19 @@ class DataWarehouseTableSnapshot(BaseSnapshot[DataWarehouseTable]):
)
@classmethod
def deserialize_for_project(cls, project_id: int, models: Sequence[Self], **kwargs):
def deserialize_for_team(cls, models: Sequence[Self], *, team_id: int, project_id: int):
for model in models:
yield DataWarehouseTable(
name=model.name,
format=model.format,
columns=model.columns,
url_pattern="http://localhost", # Hardcoded. It's not important for evaluations what the value is.
team_id=project_id,
team_id=team_id,
)
class PostgresProjectDataSnapshot(BaseModel):
project: str
class PostgresTeamDataSnapshot(BaseModel):
team: str
property_definitions: str
group_type_mappings: str
data_warehouse_tables: str
@@ -163,7 +165,46 @@ class ActorsPropertyTaxonomySnapshot(AvroBase):
results: ActorsPropertyTaxonomyResponse
class ClickhouseProjectDataSnapshot(BaseModel):
class ClickhouseTeamDataSnapshot(BaseModel):
event_taxonomy: str
properties_taxonomy: str
actors_property_taxonomy: str
class TeamEvaluationSnapshot(BaseModel):
team_id: int
postgres: PostgresTeamDataSnapshot
clickhouse: ClickhouseTeamDataSnapshot
class DatasetInput(BaseModel):
team_id: int
input: dict[str, Any]
expected: dict[str, Any]
metadata: dict[str, Any] = Field(default_factory=dict)
class EvalsDockerImageConfig(BaseModel):
class Config:
extra = "allow"
aws_bucket_name: str
"""
AWS S3 bucket name for the raw snapshots for all projects.
"""
aws_endpoint_url: str
"""
AWS S3 endpoint URL for the raw snapshots for all projects.
"""
team_snapshots: list[TeamEvaluationSnapshot]
"""
Raw snapshots for all projects.
"""
experiment_name: str
"""
Name of the experiment.
"""
dataset: list[DatasetInput]
"""
Parsed dataset.
"""

View File

@@ -1,15 +1,27 @@
import json
from typing import TypedDict
from typing import Any, TypedDict, cast
from autoevals.llm import LLMClassifier
from autoevals.partial import ScorerWithPartial
from autoevals.ragas import AnswerSimilarity
from braintrust import Score
from langchain_core.messages import AIMessage as LangchainAIMessage
from posthog.schema import AssistantMessage, AssistantToolCall, NodeKind
from ee.hogai.utils.types.base import AnyAssistantGeneratedQuery
from ee.hogai.utils.types.base import AnyAssistantGeneratedQuery, AnyAssistantSupportedQuery
from .sql import SQLSemanticsCorrectness, SQLSyntaxCorrectness
__all__ = [
"SQLSemanticsCorrectness",
"SQLSyntaxCorrectness",
"ToolRelevance",
"QueryKindSelection",
"PlanCorrectness",
"QueryAndPlanAlignment",
"TimeRangeRelevancy",
"InsightEvaluationAccuracy",
]
class ToolRelevance(ScorerWithPartial):
@@ -25,7 +37,7 @@ class ToolRelevance(ScorerWithPartial):
return Score(name=self._name(), score=0)
if not isinstance(expected, AssistantToolCall):
raise TypeError(f"Eval case expected must be an AssistantToolCall, not {type(expected)}")
if not isinstance(output, AssistantMessage | LangchainAIMessage):
if not isinstance(output, AssistantMessage):
raise TypeError(f"Eval case output must be an AssistantMessage, not {type(output)}")
if output.tool_calls and len(output.tool_calls) > 1:
raise ValueError("Parallel tool calls not supported by this scorer yet")
@@ -52,16 +64,24 @@ class ToolRelevance(ScorerWithPartial):
class PlanAndQueryOutput(TypedDict, total=False):
plan: str | None
query: AnyAssistantGeneratedQuery
query: AnyAssistantGeneratedQuery | AnyAssistantSupportedQuery | None
query_generation_retry_count: int | None
def serialize_output(output: PlanAndQueryOutput | dict | None) -> PlanAndQueryOutput | None:
class SerializedPlanAndQueryOutput(TypedDict):
plan: str | None
query: dict[str, Any] | None
query_generation_retry_count: int | None
def serialize_output(output: PlanAndQueryOutput | dict | None) -> SerializedPlanAndQueryOutput | None:
if output:
return {
query = output.get("query")
serialized_output = {
**output,
"query": output.get("query").model_dump(exclude_none=True),
"query": query.model_dump(exclude_none=True) if query else None,
}
return cast(SerializedPlanAndQueryOutput, serialized_output)
return None
@@ -75,13 +95,14 @@ class QueryKindSelection(ScorerWithPartial):
self._expected = expected
def _run_eval_sync(self, output: PlanAndQueryOutput, expected=None, **kwargs):
if not output.get("query"):
query = output.get("query")
if not query:
return Score(name=self._name(), score=None, metadata={"reason": "No query present"})
score = 1 if output["query"].kind == self._expected else 0
score = 1 if query.kind == self._expected else 0
return Score(
name=self._name(),
score=score,
metadata={"reason": f"Expected {self._expected}, got {output['query'].kind}"} if not score else {},
metadata={"reason": f"Expected {self._expected}, got {query.kind}"} if not score else {},
)
@@ -89,22 +110,18 @@ class PlanCorrectness(LLMClassifier):
"""Evaluate if the generated plan correctly answers the user's question."""
async def _run_eval_async(self, output: PlanAndQueryOutput, expected=None, **kwargs):
if not output.get("plan"):
return Score(name=self._name(), score=0.0, metadata={"reason": "No plan present"})
output = PlanAndQueryOutput(
plan=output.get("plan"),
query=output["query"].model_dump_json(exclude_none=True) if output.get("query") else None, # Clean up
)
return await super()._run_eval_async(output, serialize_output(expected), **kwargs)
plan = output.get("plan")
query = output.get("query")
if not plan or not query:
return Score(name=self._name(), score=0.0, metadata={"reason": "No plan or query present"})
return await super()._run_eval_async(serialize_output(output), serialize_output(expected), **kwargs)
def _run_eval_sync(self, output: PlanAndQueryOutput, expected=None, **kwargs):
if not output.get("plan"):
return Score(name=self._name(), score=0.0, metadata={"reason": "No plan present"})
output = PlanAndQueryOutput(
plan=output.get("plan"),
query=output["query"].model_dump_json(exclude_none=True) if output.get("query") else None, # Clean up
)
return super()._run_eval_sync(output, serialize_output(expected), **kwargs)
plan = output.get("plan")
query = output.get("query")
if not plan or not query:
return Score(name=self._name(), score=0.0, metadata={"reason": "No plan or query present"})
return super()._run_eval_sync(serialize_output(output), serialize_output(expected), **kwargs)
def __init__(self, query_kind: NodeKind, evaluation_criteria: str, **kwargs):
super().__init__(
@@ -167,34 +184,30 @@ class QueryAndPlanAlignment(LLMClassifier):
"""Evaluate if the generated SQL query aligns with the plan generated in the previous step."""
async def _run_eval_async(self, output: PlanAndQueryOutput, expected: PlanAndQueryOutput | None = None, **kwargs):
if not output.get("plan"):
plan = output.get("plan")
if not plan:
return Score(
name=self._name(),
score=None,
metadata={"reason": "No plan present in the first place, skipping evaluation"},
)
if not output.get("query"):
query = output.get("query")
if not query:
return Score(name=self._name(), score=0.0, metadata={"reason": "Query failed to be generated"})
output = PlanAndQueryOutput(
plan=output.get("plan"),
query=output["query"].model_dump_json(exclude_none=True) if output.get("query") else None, # Clean up
)
return await super()._run_eval_async(output, serialize_output(expected), **kwargs)
return await super()._run_eval_async(serialize_output(output), serialize_output(expected), **kwargs)
def _run_eval_sync(self, output: PlanAndQueryOutput, expected: PlanAndQueryOutput | None = None, **kwargs):
if not output.get("plan"):
plan = output.get("plan")
if not plan:
return Score(
name=self._name(),
score=None,
metadata={"reason": "No plan present in the first place, skipping evaluation"},
)
if not output.get("query"):
query = output.get("query")
if not query:
return Score(name=self._name(), score=0.0, metadata={"reason": "Query failed to be generated"})
output = PlanAndQueryOutput(
plan=output.get("plan"),
query=output["query"].model_dump_json(exclude_none=True) if output.get("query") else None, # Clean up
)
return super()._run_eval_sync(output, serialize_output(expected), **kwargs)
return super()._run_eval_sync(serialize_output(output), serialize_output(expected), **kwargs)
def __init__(self, query_kind: NodeKind, json_schema: dict, evaluation_criteria: str, **kwargs):
json_schema_str = json.dumps(json_schema)
@@ -270,15 +283,15 @@ Details matter greatly here - including math types or property types - so be har
class TimeRangeRelevancy(LLMClassifier):
"""Evaluate if the generated query's time range, interval, or period correctly answers the user's question."""
async def _run_eval_async(self, output, expected=None, **kwargs):
async def _run_eval_async(self, output: PlanAndQueryOutput, expected: PlanAndQueryOutput | None = None, **kwargs):
if not output.get("query"):
return Score(name=self._name(), score=None, metadata={"reason": "No query to check, skipping evaluation"})
return await super()._run_eval_async(output, serialize_output(expected), **kwargs)
return await super()._run_eval_async(serialize_output(output), serialize_output(expected), **kwargs)
def _run_eval_sync(self, output, expected=None, **kwargs):
def _run_eval_sync(self, output: PlanAndQueryOutput, expected: PlanAndQueryOutput | None = None, **kwargs):
if not output.get("query"):
return Score(name=self._name(), score=None, metadata={"reason": "No query to check"})
return super()._run_eval_sync(output, serialize_output(expected), **kwargs)
return super()._run_eval_sync(serialize_output(output), serialize_output(expected), **kwargs)
def __init__(self, query_kind: NodeKind, **kwargs):
super().__init__(
@@ -344,78 +357,6 @@ How would you rate the time range relevancy of the generated query? Choose one:
)
SQL_SEMANTICS_CORRECTNESS_PROMPT = """
<system>
You are an expert ClickHouse SQL auditor.
Your job is to decide whether two ClickHouse SQL queries are **semantically equivalent for every possible valid database state**, given the same task description and schema.
When you respond, think step-by-step **internally**, but reveal **nothing** except the final verdict:
• Output **Pass** if the candidate query would always return the same result set (ignoring column aliases, ordering, or trivial formatting) as the reference query.
• Output **Fail** otherwise, or if you are uncertain.
Respond with a single word—**Pass** or **Fail**—and no additional text.
</system>
<input>
Task / natural-language question:
```
{{input}}
```
Database schema (tables and columns):
```
{{database_schema}}
```
Reference (human-labelled) SQL:
```sql
{{expected}}
```
Candidate (generated) SQL:
```sql
{{output}}
```
</input>
<reminder>
Think through edge cases: NULL handling, grouping, filters, joins, HAVING clauses, aggregations, sub-queries, limits, and data-type quirks.
If any logical difference could yield different outputs under some data scenario, the queries are *not* equivalent.
</reminder>
When ready, output your verdict—**Pass** or **Fail**—with absolutely no extra characters.
""".strip()
class SQLSemanticsCorrectness(LLMClassifier):
"""Evaluate if the actual query matches semantically the expected query."""
def __init__(self, **kwargs):
super().__init__(
name="sql_semantics_correctness",
prompt_template=SQL_SEMANTICS_CORRECTNESS_PROMPT,
choice_scores={
"Pass": 1.0,
"Fail": 0.0,
},
model="gpt-4.1",
**kwargs,
)
async def _run_eval_async(
self, output: str | None, expected: str | None = None, database_schema: str | None = None, **kwargs
):
if not output or output.strip() == "":
return Score(name=self._name(), score=None, metadata={"reason": "No query to check, skipping evaluation"})
return await super()._run_eval_async(output, expected, database_schema=database_schema, **kwargs)
def _run_eval_sync(
self, output: str | None, expected: str | None = None, database_schema: str | None = None, **kwargs
):
if not output or output.strip() == "":
return Score(name=self._name(), score=None, metadata={"reason": "No query to check, skipping evaluation"})
return super()._run_eval_sync(output, expected, database_schema=database_schema, **kwargs)
class InsightSearchOutput(TypedDict, total=False):
"""Output structure for insight search evaluations."""

View File

@@ -0,0 +1,115 @@
from asgiref.sync import sync_to_async
from autoevals.llm import LLMClassifier
from braintrust import Score
from braintrust_core.score import Scorer
from posthog.hogql.errors import BaseHogQLError
from posthog.errors import InternalCHQueryError
from posthog.hogql_queries.hogql_query_runner import HogQLQueryRunner
from posthog.models.team.team import Team
def evaluate_sql_query(name: str, output: str | None, team: Team | None = None) -> Score:
if not output:
return Score(name=name, score=None, metadata={"reason": "No SQL query to verify, skipping evaluation"})
if not team:
return Score(name=name, score=None, metadata={"reason": "No team provided, skipping evaluation"})
query = {"query": output}
try:
# Try to parse, print, and run the query
HogQLQueryRunner(query, team).calculate()
except BaseHogQLError as e:
return Score(name=name, score=0.0, metadata={"reason": f"HogQL-level error: {str(e)}"})
except InternalCHQueryError as e:
return Score(name=name, score=0.5, metadata={"reason": f"ClickHouse-level error: {str(e)}"})
else:
return Score(name=name, score=1.0)
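In short, the graded outcome this helper produces (the call below is a hypothetical example; the thresholds are the ones implemented above):

# evaluate_sql_query("sql_syntax_correctness", "SELECT count() FROM events", team)
#   -> score 1.0 when the query parses, resolves, and executes,
#   -> score 0.0 on a HogQL-level error (bad syntax or unknown fields),
#   -> score 0.5 when HogQL compiles but ClickHouse rejects the generated SQL.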
class SQLSyntaxCorrectness(Scorer):
"""Evaluate if the generated SQL query has correct syntax."""
def _name(self):
return "sql_syntax_correctness"
async def _run_eval_async(self, output: str | None, team: Team | None = None, **kwargs):
return await sync_to_async(self._evaluate)(output, team)
def _run_eval_sync(self, output: str | None, team: Team | None = None, **kwargs):
return self._evaluate(output, team)
def _evaluate(self, output: str | None, team: Team | None = None) -> Score:
return evaluate_sql_query(self._name(), output, team)
SQL_SEMANTICS_CORRECTNESS_PROMPT = """
<system>
You are an expert ClickHouse SQL auditor.
Your job is to decide whether two ClickHouse SQL queries are **semantically equivalent for every possible valid database state**, given the same task description and schema.
When you respond, think step-by-step **internally**, but reveal **nothing** except the final verdict:
• Output **Pass** if the candidate query would always return the same result set (ignoring column aliases, ordering, or trivial formatting) as the reference query.
• Output **Fail** otherwise, or if you are uncertain.
Respond with a single word—**Pass** or **Fail**—and no additional text.
</system>
<input>
Task / natural-language question:
```
{{input}}
```
Database schema (tables and columns):
```
{{database_schema}}
```
Reference (human-labelled) SQL:
```sql
{{expected}}
```
Candidate (generated) SQL:
```sql
{{output}}
```
</input>
<reminder>
Think through edge cases: NULL handling, grouping, filters, joins, HAVING clauses, aggregations, sub-queries, limits, and data-type quirks.
If any logical difference could yield different outputs under some data scenario, the queries are *not* equivalent.
</reminder>
When ready, output your verdict—**Pass** or **Fail**—with absolutely no extra characters.
""".strip()
class SQLSemanticsCorrectness(LLMClassifier):
"""Evaluate if the actual query matches semantically the expected query."""
def __init__(self, **kwargs):
super().__init__(
name="sql_semantics_correctness",
prompt_template=SQL_SEMANTICS_CORRECTNESS_PROMPT,
choice_scores={
"Pass": 1.0,
"Fail": 0.0,
},
model="gpt-4.1",
**kwargs,
)
async def _run_eval_async(
self, output: str | None, expected: str | None = None, database_schema: str | None = None, **kwargs
):
if not output or output.strip() == "":
return Score(name=self._name(), score=None, metadata={"reason": "No query to check, skipping evaluation"})
return await super()._run_eval_async(output, expected, database_schema=database_schema, **kwargs)
def _run_eval_sync(
self, output: str | None, expected: str | None = None, database_schema: str | None = None, **kwargs
):
if not output or output.strip() == "":
return Score(name=self._name(), score=None, metadata={"reason": "No query to check, skipping evaluation"})
return super()._run_eval_sync(output, expected, database_schema=database_schema, **kwargs)
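# Illustrative wiring (an assumption about the harness, not shown in this file):
# both scorers are meant to be handed to a Braintrust eval next to the generation
# task under test, roughly:
#
#     Eval(
#         "hogql-generator",            # hypothetical project name
#         data=load_dataset(),          # hypothetical loader yielding input/expected pairs
#         task=generate_hogql,          # the SQL generator being evaluated
#         scores=[SQLSyntaxCorrectness(), SQLSemanticsCorrectness()],
#     )
#
# The {{input}}, {{expected}}, {{output}}, and {{database_schema}} placeholders in
# SQL_SEMANTICS_CORRECTNESS_PROMPT are mustache variables, filled from the output,
# the expected value, and any extra keyword arguments (such as database_schema)
# forwarded to the classifier at scoring time.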

View File

@@ -0,0 +1,133 @@
from posthog.test.base import BaseTest
from posthog.schema import (
ActorsPropertyTaxonomyQuery,
ActorsPropertyTaxonomyResponse,
EventTaxonomyItem,
EventTaxonomyQuery,
TeamTaxonomyItem,
TeamTaxonomyQuery,
)
from ee.hogai.eval.offline.query_patches import (
ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE,
EVENT_TAXONOMY_QUERY_DATA_SOURCE,
TEAM_TAXONOMY_QUERY_DATA_SOURCE,
PatchedActorsPropertyTaxonomyQueryRunner,
PatchedEventTaxonomyQueryRunner,
PatchedTeamTaxonomyQueryRunner,
)
class TestQueryPatches(BaseTest):
def setUp(self):
super().setUp()
TEAM_TAXONOMY_QUERY_DATA_SOURCE[self.team.id] = [
TeamTaxonomyItem(count=10, event="$pageview"),
]
EVENT_TAXONOMY_QUERY_DATA_SOURCE[self.team.id] = {
"$pageview": [
EventTaxonomyItem(property="$browser", sample_values=["Safari"], sample_count=1),
],
}
ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE[self.team.id]["person"] = {
"$location": ActorsPropertyTaxonomyResponse(sample_values=["US"], sample_count=1),
"$browser": ActorsPropertyTaxonomyResponse(sample_values=["Safari"], sample_count=10),
}
ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE[self.team.id][0] = {
"$device": ActorsPropertyTaxonomyResponse(sample_values=["Phone"], sample_count=1),
}
def tearDown(self):
super().tearDown()
TEAM_TAXONOMY_QUERY_DATA_SOURCE.clear()
EVENT_TAXONOMY_QUERY_DATA_SOURCE.clear()
ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE.clear()
def test_patched_team_taxonomy_query_runner_returns_result(self):
query_runner = PatchedTeamTaxonomyQueryRunner(
team=self.team,
query=TeamTaxonomyQuery(),
).calculate()
self.assertEqual(query_runner.results, [TeamTaxonomyItem(count=10, event="$pageview")])
def test_patched_team_taxonomy_query_runner_handles_no_results(self):
TEAM_TAXONOMY_QUERY_DATA_SOURCE.clear()
query_runner = PatchedTeamTaxonomyQueryRunner(
team=self.team,
query=TeamTaxonomyQuery(),
).calculate()
self.assertEqual(query_runner.results, [])
def test_patched_event_taxonomy_query_runner_returns_result(self):
query_runner = PatchedEventTaxonomyQueryRunner(
team=self.team,
query=EventTaxonomyQuery(event="$pageview"),
).calculate()
self.assertEqual(
query_runner.results, [EventTaxonomyItem(property="$browser", sample_values=["Safari"], sample_count=1)]
)
def test_patched_event_taxonomy_query_runner_handles_no_results(self):
EVENT_TAXONOMY_QUERY_DATA_SOURCE.clear()
query_runner = PatchedEventTaxonomyQueryRunner(
team=self.team,
query=EventTaxonomyQuery(event="$pageview"),
).calculate()
self.assertEqual(query_runner.results, [])
def test_patched_event_taxonomy_query_runner_returns_result_for_action_id(self):
EVENT_TAXONOMY_QUERY_DATA_SOURCE[self.team.id][123] = [
EventTaxonomyItem(property="$browser", sample_values=["Safari"], sample_count=1),
]
query_runner = PatchedEventTaxonomyQueryRunner(
team=self.team,
query=EventTaxonomyQuery(actionId=123),
).calculate()
self.assertEqual(
query_runner.results,
[EventTaxonomyItem(property="$browser", sample_values=["Safari"], sample_count=1)],
)
def test_patched_actors_property_taxonomy_query_runner_returns_result(self):
query_runner = PatchedActorsPropertyTaxonomyQueryRunner(
team=self.team,
query=ActorsPropertyTaxonomyQuery(groupTypeIndex=0, properties=["$device"]),
).calculate()
self.assertEqual(
query_runner.results,
[ActorsPropertyTaxonomyResponse(sample_values=["Phone"], sample_count=1)],
)
query_runner = PatchedActorsPropertyTaxonomyQueryRunner(
team=self.team,
query=ActorsPropertyTaxonomyQuery(properties=["$location", "$browser"]),
).calculate()
self.assertEqual(
query_runner.results,
[
ActorsPropertyTaxonomyResponse(sample_values=["US"], sample_count=1),
ActorsPropertyTaxonomyResponse(sample_values=["Safari"], sample_count=10),
],
)
def test_patched_actors_property_taxonomy_query_runner_handles_mixed_results(self):
query_runner = PatchedActorsPropertyTaxonomyQueryRunner(
team=self.team,
query=ActorsPropertyTaxonomyQuery(properties=["$location", "$latitude"]),
).calculate()
self.assertEqual(
query_runner.results,
[
ActorsPropertyTaxonomyResponse(sample_values=["US"], sample_count=1),
ActorsPropertyTaxonomyResponse(sample_values=[], sample_count=0),
],
)
def test_patched_actors_property_taxonomy_query_runner_handles_no_results(self):
ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE.clear()
query_runner = PatchedActorsPropertyTaxonomyQueryRunner(
team=self.team,
query=ActorsPropertyTaxonomyQuery(properties=["$device"]),
).calculate()
self.assertEqual(query_runner.results, [ActorsPropertyTaxonomyResponse(sample_values=[], sample_count=0)])

View File

@@ -0,0 +1,254 @@
from __future__ import annotations
from io import BytesIO
from typing import Any
from posthog.test.base import BaseTest
from unittest.mock import AsyncMock, MagicMock, patch
from asgiref.sync import async_to_sync
from posthog.schema import (
ActorsPropertyTaxonomyQuery,
ActorsPropertyTaxonomyResponse,
EventTaxonomyItem,
EventTaxonomyQuery,
TeamTaxonomyItem,
TeamTaxonomyQuery,
)
from posthog.hogql_queries.query_runner import get_query_runner
from posthog.models import GroupTypeMapping, Organization, PropertyDefinition, Team
from posthog.warehouse.models.table import DataWarehouseTable
from ee.hogai.eval.offline.query_patches import (
ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE,
EVENT_TAXONOMY_QUERY_DATA_SOURCE,
TEAM_TAXONOMY_QUERY_DATA_SOURCE,
)
from ee.hogai.eval.offline.snapshot_loader import SnapshotLoader
from ee.hogai.eval.schema import (
ActorsPropertyTaxonomySnapshot,
ClickhouseTeamDataSnapshot,
DataWarehouseTableSnapshot,
GroupTypeMappingSnapshot,
PostgresTeamDataSnapshot,
PropertyDefinitionSnapshot,
PropertyTaxonomySnapshot,
TeamSnapshot,
TeamTaxonomyItemSnapshot,
)
class TestSnapshotLoader(BaseTest):
def setUp(self):
super().setUp()
TEAM_TAXONOMY_QUERY_DATA_SOURCE.clear()
EVENT_TAXONOMY_QUERY_DATA_SOURCE.clear()
ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE.clear()
def tearDown(self):
TEAM_TAXONOMY_QUERY_DATA_SOURCE.clear()
EVENT_TAXONOMY_QUERY_DATA_SOURCE.clear()
ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE.clear()
super().tearDown()
def _extras(self, team_id: int) -> dict[str, Any]:
return {
"aws_endpoint_url": "http://localhost:9000",
"aws_bucket_name": "evals",
"experiment_name": "offline_evaluation",
"team_snapshots": [
{
"team_id": team_id,
"postgres": PostgresTeamDataSnapshot(
team=f"pg/team_{team_id}.avro",
property_definitions=f"pg/property_defs_{team_id}.avro",
group_type_mappings=f"pg/group_mappings_{team_id}.avro",
data_warehouse_tables=f"pg/dw_tables_{team_id}.avro",
).model_dump(),
"clickhouse": ClickhouseTeamDataSnapshot(
event_taxonomy=f"ch/event_taxonomy_{team_id}.avro",
properties_taxonomy=f"ch/properties_taxonomy_{team_id}.avro",
actors_property_taxonomy=f"ch/actors_property_taxonomy_{team_id}.avro",
).model_dump(),
}
],
"dataset": [],
}
def _fake_parse(self, schema, _buffer):
if schema is TeamSnapshot:
yield TeamSnapshot(name="Evaluated Team", test_account_filters="{}")
return
if schema is PropertyDefinitionSnapshot:
yield PropertyDefinitionSnapshot(
name="$browser", is_numerical=False, property_type="String", type=1, group_type_index=None
)
yield PropertyDefinitionSnapshot(
name="$device", is_numerical=False, property_type="String", type=1, group_type_index=0
)
return
if schema is GroupTypeMappingSnapshot:
yield GroupTypeMappingSnapshot(
group_type="organization", group_type_index=0, name_singular="org", name_plural="orgs"
)
yield GroupTypeMappingSnapshot(
group_type="company", group_type_index=1, name_singular="company", name_plural="companies"
)
return
if schema is DataWarehouseTableSnapshot:
yield DataWarehouseTableSnapshot(name="users", format="Parquet", columns={"id": "Int64"})
yield DataWarehouseTableSnapshot(name="orders", format="Parquet", columns={"id": "Int64"})
return
if schema is TeamTaxonomyItemSnapshot:
yield TeamTaxonomyItemSnapshot(
results=[
TeamTaxonomyItem(count=123, event="$pageview"),
TeamTaxonomyItem(count=45, event="$autocapture"),
]
)
return
if schema is PropertyTaxonomySnapshot:
# Two events to fill EVENT_TAXONOMY_QUERY_DATA_SOURCE
yield PropertyTaxonomySnapshot(
event="$pageview",
results=[
EventTaxonomyItem(property="$browser", sample_values=["Safari"], sample_count=10),
EventTaxonomyItem(property="$os", sample_values=["macOS"], sample_count=5),
],
)
yield PropertyTaxonomySnapshot(
event="$autocapture",
results=[
EventTaxonomyItem(property="$element_type", sample_values=["a"], sample_count=3),
],
)
return
if schema is ActorsPropertyTaxonomySnapshot:
# person (None), group 0, and group 1
yield ActorsPropertyTaxonomySnapshot(
group_type_index=None,
property="$browser",
results=ActorsPropertyTaxonomyResponse(sample_values=["Safari"], sample_count=10),
)
yield ActorsPropertyTaxonomySnapshot(
group_type_index=0,
property="$device",
results=ActorsPropertyTaxonomyResponse(sample_values=["Phone"], sample_count=2),
)
yield ActorsPropertyTaxonomySnapshot(
group_type_index=1,
property="$industry",
results=ActorsPropertyTaxonomyResponse(sample_values=["Tech"], sample_count=1),
)
return
raise AssertionError(f"Unhandled schema in fake parser: {schema}")
def _build_context(self, extras: dict[str, Any]) -> MagicMock:
ctx = MagicMock()
ctx.extras = extras
ctx.log = MagicMock()
ctx.log.info = MagicMock()
return ctx
def _load_with_mocks(self) -> tuple[Organization, Any, list[Any], Team]:
extras = self._extras(9990)
ctx = self._build_context(extras)
async_get = AsyncMock(side_effect=lambda client, key: BytesIO(b"ok"))
with patch.object(SnapshotLoader, "_get_snapshot_from_s3", new=async_get):
with patch.object(SnapshotLoader, "_parse_snapshot_to_schema", new=self._fake_parse):
loader = SnapshotLoader(ctx)
org, user, dataset = async_to_sync(loader.load_snapshots)()
return org, user, dataset, Team.objects.get(id=9990)
def test_loads_data_from_s3(self):
team_id = 99990
extras = self._extras(team_id)
ctx = self._build_context(extras)
calls: list[tuple[dict[str, Any], str]] = []
async def record_call(_self, client, key: str):
calls.append(({"Bucket": extras["aws_bucket_name"]}, key))
return BytesIO(b"ok")
with patch.object(SnapshotLoader, "_get_snapshot_from_s3", new=record_call):
with patch.object(SnapshotLoader, "_parse_snapshot_to_schema", new=self._fake_parse):
loader = SnapshotLoader(ctx)
async_to_sync(loader.load_snapshots)()
keys = [k for _, k in calls]
self.assertIn(f"pg/team_{team_id}.avro", keys)
self.assertIn(f"pg/property_defs_{team_id}.avro", keys)
self.assertIn(f"pg/group_mappings_{team_id}.avro", keys)
self.assertIn(f"pg/dw_tables_{team_id}.avro", keys)
self.assertIn(f"ch/event_taxonomy_{team_id}.avro", keys)
self.assertIn(f"ch/properties_taxonomy_{team_id}.avro", keys)
self.assertIn(f"ch/actors_property_taxonomy_{team_id}.avro", keys)
def test_restores_org_team_user(self):
org, user, _dataset, team = self._load_with_mocks()
self.assertEqual(org.name, "PostHog")
self.assertEqual(team.organization_id, org.id)
self.assertEqual(team.id, 9990)
self.assertEqual(team.api_token, "team_9990")
def test_restores_models_counts(self):
_org, _user, _dataset, team = self._load_with_mocks()
self.assertEqual(PropertyDefinition.objects.filter(team_id=team.id).count(), 2)
self.assertEqual(GroupTypeMapping.objects.filter(team_id=team.id).count(), 2)
self.assertEqual(DataWarehouseTable.objects.filter(team_id=team.id).count(), 2)
def test_loads_team_taxonomy_data_source(self):
_org, _user, _dataset, team = self._load_with_mocks()
self.assertIn(team.id, TEAM_TAXONOMY_QUERY_DATA_SOURCE)
self.assertEqual([i.event for i in TEAM_TAXONOMY_QUERY_DATA_SOURCE[team.id]], ["$pageview", "$autocapture"])
def test_loads_event_taxonomy_data_source(self):
_org, _user, _dataset, team = self._load_with_mocks()
self.assertIn("$pageview", EVENT_TAXONOMY_QUERY_DATA_SOURCE[team.id])
self.assertIn("$autocapture", EVENT_TAXONOMY_QUERY_DATA_SOURCE[team.id])
props = [i.property for i in EVENT_TAXONOMY_QUERY_DATA_SOURCE[team.id]["$pageview"]]
self.assertIn("$browser", props)
self.assertIn("$os", props)
def test_loads_actors_property_taxonomy_data_source_various_group_types(self):
_org, _user, _dataset, team = self._load_with_mocks()
self.assertIn("person", ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE[team.id])
self.assertIn(0, ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE[team.id])
self.assertIn(1, ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE[team.id])
def test_runs_query_runner_patches(self):
_org, _user, _dataset, team = self._load_with_mocks()
# Team taxonomy
team_resp = get_query_runner(TeamTaxonomyQuery(), team=team).calculate()
self.assertTrue(any(item.event == "$pageview" for item in team_resp.results))
# Event taxonomy by event
event_resp = get_query_runner(EventTaxonomyQuery(event="$pageview", maxPropertyValues=5), team=team).calculate()
self.assertTrue(any(item.property == "$browser" for item in event_resp.results))
# Actors property taxonomy for person
actors_resp_person = get_query_runner(
ActorsPropertyTaxonomyQuery(properties=["$browser"], groupTypeIndex=None),
team=team,
).calculate()
self.assertEqual(actors_resp_person.results[0].sample_values, ["Safari"])
# Actors property taxonomy for group 0
actors_resp_group0 = get_query_runner(
ActorsPropertyTaxonomyQuery(properties=["$device"], groupTypeIndex=0),
team=team,
).calculate()
self.assertEqual(actors_resp_group0.results[0].sample_values, ["Phone"])
# Actors property taxonomy for group 1
actors_resp_group1 = get_query_runner(
ActorsPropertyTaxonomyQuery(properties=["$industry"], groupTypeIndex=1),
team=team,
).calculate()
self.assertEqual(actors_resp_group1.results[0].sample_values, ["Tech"])

View File

@@ -76,6 +76,10 @@ OPENAI_API_KEY = get_from_env("OPENAI_API_KEY", "")
INKEEP_API_KEY = get_from_env("INKEEP_API_KEY", "")
MISTRAL_API_KEY = get_from_env("MISTRAL_API_KEY", "")
GEMINI_API_KEY = get_from_env("GEMINI_API_KEY", "")
PPLX_API_KEY = get_from_env("PPLX_API_KEY", "")
AZURE_INFERENCE_ENDPOINT = get_from_env("AZURE_INFERENCE_ENDPOINT", "")
AZURE_INFERENCE_CREDENTIAL = get_from_env("AZURE_INFERENCE_CREDENTIAL", "")
BRAINTRUST_API_KEY = get_from_env("BRAINTRUST_API_KEY", "")
MAILJET_PUBLIC_KEY = get_from_env("MAILJET_PUBLIC_KEY", "", type_cast=str)
MAILJET_SECRET_KEY = get_from_env("MAILJET_SECRET_KEY", "", type_cast=str)
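# Illustrative use (an assumption, not part of this diff): the eval runner is
# expected to authenticate against Braintrust with BRAINTRUST_API_KEY, along the
# lines of:
#
#     import braintrust
#     braintrust.login(api_key=settings.BRAINTRUST_API_KEY)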

Binary files not shown: 12 images updated (76–142 KiB; only two change size, 120→121 KiB and 142→141 KiB).

View File

@@ -299,20 +299,6 @@ ee/clickhouse/views/test/test_experiment_saved_metrics.py:0: error: Item "None"
ee/clickhouse/views/test/test_experiment_saved_metrics.py:0: error: Item "None" of "Any | None" has no attribute "query" [union-attr]
ee/clickhouse/views/test/test_experiment_saved_metrics.py:0: error: Item "None" of "Any | None" has no attribute "query" [union-attr]
ee/clickhouse/views/test/test_experiment_saved_metrics.py:0: error: Item "None" of "ExperimentToSavedMetric | None" has no attribute "metadata" [union-attr]
ee/hogai/eval/conftest.py:0: error: Incompatible types (expression has type "None", TypedDict item "query" has type "AssistantTrendsQuery | AssistantFunnelsQuery | AssistantRetentionQuery | AssistantHogQLQuery") [typeddict-item]
ee/hogai/eval/conftest.py:0: error: Item "None" of "Any | None" has no attribute "organization" [union-attr]
ee/hogai/eval/eval_memory.py:0: error: "BaseMessage" has no attribute "tool_calls" [attr-defined]
ee/hogai/eval/max_tools/eval_hogql_generator_tool.py:0: error: List item 1 has incompatible type "Callable[[EvalInput, str, str, dict[Any, Any]], Coroutine[Any, Any, Score]]"; expected "AsyncScorerLike[Any, Any] | type[SyncScorerLike[Any, Any]] | type[AsyncScorerLike[Any, Any]] | Callable[[Any, Any, Any], float | int | bool | Score | list[Score] | None] | Callable[[Any, Any, Any], Awaitable[float | int | bool | Score | list[Score] | None]]" [list-item]
ee/hogai/eval/scorers.py:0: error: Incompatible types (expression has type "dict[str, Any] | Any", TypedDict item "query" has type "AssistantTrendsQuery | AssistantFunnelsQuery | AssistantRetentionQuery | AssistantHogQLQuery") [typeddict-item]
ee/hogai/eval/scorers.py:0: error: Incompatible types (expression has type "str | None", TypedDict item "query" has type "AssistantTrendsQuery | AssistantFunnelsQuery | AssistantRetentionQuery | AssistantHogQLQuery") [typeddict-item]
ee/hogai/eval/scorers.py:0: error: Incompatible types (expression has type "str | None", TypedDict item "query" has type "AssistantTrendsQuery | AssistantFunnelsQuery | AssistantRetentionQuery | AssistantHogQLQuery") [typeddict-item]
ee/hogai/eval/scorers.py:0: error: Incompatible types (expression has type "str | None", TypedDict item "query" has type "AssistantTrendsQuery | AssistantFunnelsQuery | AssistantRetentionQuery | AssistantHogQLQuery") [typeddict-item]
ee/hogai/eval/scorers.py:0: error: Incompatible types (expression has type "str | None", TypedDict item "query" has type "AssistantTrendsQuery | AssistantFunnelsQuery | AssistantRetentionQuery | AssistantHogQLQuery") [typeddict-item]
ee/hogai/eval/scorers.py:0: error: Item "None" of "AssistantTrendsQuery | AssistantFunnelsQuery | AssistantRetentionQuery | AssistantHogQLQuery | Any | None" has no attribute "model_dump" [union-attr]
ee/hogai/eval/scorers.py:0: error: Item "ToolCall" of "AssistantToolCall | ToolCall" has no attribute "args" [union-attr]
ee/hogai/eval/scorers.py:0: error: Item "ToolCall" of "AssistantToolCall | ToolCall" has no attribute "args" [union-attr]
ee/hogai/eval/scorers.py:0: error: Item "ToolCall" of "AssistantToolCall | ToolCall" has no attribute "args" [union-attr]
ee/hogai/eval/scorers.py:0: error: Item "ToolCall" of "AssistantToolCall | ToolCall" has no attribute "name" [union-attr]
ee/hogai/graph/billing/test/test_nodes.py:0: error: Item "None" of "Any | None" has no attribute "data" [union-attr]
ee/hogai/graph/billing/test/test_nodes.py:0: error: Item "None" of "Any | None" has no attribute "dates" [union-attr]
ee/hogai/graph/billing/test/test_nodes.py:0: error: Item "None" of "Any | None" has no attribute "dates" [union-attr]

View File

@@ -1,6 +1,6 @@
from django.db import models
from posthog.models.utils import UUIDModel, BytecodeModelMixin
from posthog.models.utils import BytecodeModelMixin, UUIDModel
class GroupUsageMetric(UUIDModel, BytecodeModelMixin):

View File

@@ -1,6 +1,7 @@
from posthog.models import GroupUsageMetric
from posthog.test.base import BaseTest
from posthog.models import GroupUsageMetric
class GroupUsageMetricTestCase(BaseTest):
def test_bytecode_generation(self):

View File

@@ -17,9 +17,10 @@ from django.db.models import Q, UniqueConstraint
from django.db.models.constraints import BaseConstraint
from django.utils.text import slugify
from posthog.constants import MAX_SLUG_LENGTH
from posthog.hogql import ast
from posthog.constants import MAX_SLUG_LENGTH
if TYPE_CHECKING:
from random import Random

View File

@@ -23,6 +23,8 @@ dependencies = [
"dagster-cloud==1.10.18",
"dagster-aws==0.26.18",
"dagster-celery==0.26.18",
"dagster-docker==0.26.18",
"dagster-pipes==1.10.18",
"dagster-postgres==0.26.18",
"dagster-slack==0.26.18",
"dagster-webserver==1.10.18",
@@ -168,8 +170,9 @@ dev = [
"autoevals==0.0.129",
"black~=23.9.1",
"boto3-stubs[s3]>=1.34.84",
"braintrust==0.2.0",
"braintrust-langchain==0.0.2",
"braintrust==0.2.4",
"braintrust-langchain==0.0.4",
"dagster-dg-cli>=1.10.18",
"datamodel-code-generator==0.28.5",
"deepdiff>=8.5.0",
"django-linear-migrations==2.16.*",
@@ -207,7 +210,7 @@ dev = [
"responses==0.23.1",
"ruff~=0.8.1",
"sqlalchemy==2.0.38",
"stpyv8==13.1.201.22",
"stpyv8==13.1.201.22; sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64')",
"syrupy~=4.6.0",
"tach~=0.20.0",
"types-aioboto3[s3]>=14.3.0",

149
uv.lock generated
View File

@@ -520,7 +520,7 @@ wheels = [
[[package]]
name = "braintrust"
version = "0.2.0"
version = "0.2.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "chevron" },
@@ -533,9 +533,9 @@ dependencies = [
{ name = "tqdm" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ca/2d/c1edc49d8ea3548552a844251de01946e566621b2f6a42af284a058c360d/braintrust-0.2.0.tar.gz", hash = "sha256:fd45665b9521276d581c12a947bf9033a51c5365c986454df6332e01729da270", size = 145441, upload-time = "2025-07-25T05:34:44.555Z" }
sdist = { url = "https://files.pythonhosted.org/packages/5a/c3/01db8b072d56b3bcebe443af7acc57466c87bf2f49612a444ea707f76a90/braintrust-0.2.4.tar.gz", hash = "sha256:b1866a646d2368bd4c2ccba305e5c1e1d97c7f8c715833d3fc69810644fdc66f", size = 159340, upload-time = "2025-08-20T21:15:18.8Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/6a/60/83fb8bcc84f7fe408b7373c520187e777859c79bc6aa9ad5d1eb7d4b4c8a/braintrust-0.2.0-py3-none-any.whl", hash = "sha256:dffc74fe85f9c855461712210a852c543d8573767544403b3f73fbc0e645f5a2", size = 171651, upload-time = "2025-07-25T05:34:43.249Z" },
{ url = "https://files.pythonhosted.org/packages/31/32/4a48037e9956592f3a7cfae442ecf23c4ea6884950932126561916c089b4/braintrust-0.2.4-py3-none-any.whl", hash = "sha256:7dc30f32ec5e909b0b35e75ee6596090bdc15e4ed9f4d82eda52103b1aaaf721", size = 185252, upload-time = "2025-08-20T21:15:17.137Z" },
]
[[package]]
@@ -549,15 +549,15 @@ wheels = [
[[package]]
name = "braintrust-langchain"
version = "0.0.2"
version = "0.0.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "braintrust" },
{ name = "langchain" },
]
sdist = { url = "https://files.pythonhosted.org/packages/77/1d/9f1def0527de67191c41fa3941a8cc72b187c6fd52c1f8b7a9b24fc6240a/braintrust_langchain-0.0.2.tar.gz", hash = "sha256:f5af5ebc600cd45fd67a4ba4b6e0b99248eb6f6caea79f521d0bee1f045ac4b0", size = 42543, upload-time = "2025-03-10T21:24:25.926Z" }
sdist = { url = "https://files.pythonhosted.org/packages/7a/28/c302623e8d1ad293b2a7029dd7a2de3feaa50a486335f6d37a577b7453eb/braintrust_langchain-0.0.4.tar.gz", hash = "sha256:5ee372f97857f436d3f81d19a9600d4a741e570b6bf80a714caf424406713cfb", size = 43081, upload-time = "2025-08-06T19:40:11.861Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/3a/7e/235c99c8d934b922b8ffd9cd2220f21e14fabb4cc126eb2e3a914baa54f1/braintrust_langchain-0.0.2-py3-none-any.whl", hash = "sha256:8a5cab14b84f375e43500eb5feae64d5311683a268a3d7c4c70e592dfdbaa444", size = 43040, upload-time = "2025-03-10T21:24:24.421Z" },
{ url = "https://files.pythonhosted.org/packages/e5/f3/aaf624f046902454c0ebf84e8426864be9587938d8ef85e4484502c444ed/braintrust_langchain-0.0.4-py3-none-any.whl", hash = "sha256:c9bb9cfbcb938f71b0344821fdb8e69091e207d3691c0fead8e780088798bbc6", size = 43492, upload-time = "2025-08-06T19:40:10.456Z" },
]
[[package]]
@@ -748,6 +748,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188, upload-time = "2024-12-21T18:38:41.666Z" },
]
[[package]]
name = "click-aliases"
version = "1.0.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "click" },
]
sdist = { url = "https://files.pythonhosted.org/packages/40/d8/adbaeadc13c9686b9bda8b4c50e5a3983f504faae2ffbea5165d5beb1cdb/click_aliases-1.0.5.tar.gz", hash = "sha256:e37d4cabbaad68e1c48ec0f063a59dfa15f0e7450ec901bd1ce4f4b954bc881d", size = 3105, upload-time = "2024-10-17T15:44:19.056Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/3d/1a/d5e29a6f896293e32ab3e63201df5d396599e57a726575adaafbcd9d70a6/click_aliases-1.0.5-py3-none-any.whl", hash = "sha256:cbb83a348acc00809fe18b6da13a7f6307bc71b3c5f69cc730e012dfb4bbfdc3", size = 3524, upload-time = "2024-10-17T15:44:17.389Z" },
]
[[package]]
name = "click-didyoumean"
version = "0.3.1"
@@ -1092,6 +1104,61 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e5/14/d8e23457fbabd97c5e71852864a423cdef983781daba360e70631987a917/dagster_cloud_cli-1.10.18-py3-none-any.whl", hash = "sha256:ad513f6ca11b7c79703b2eaa4989f90395db0d0a439679cccebd794a336f6199", size = 107883, upload-time = "2025-05-29T21:36:21.843Z" },
]
[[package]]
name = "dagster-dg-cli"
version = "1.10.18"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "dagster-dg-core" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0b/f3/dcdeca179a69fad1cc75054725e72ddab4ebd3c095b2af6a63cf18228fe0/dagster_dg_cli-1.10.18.tar.gz", hash = "sha256:666d64f448b707ef557d855b94e2431b6f9ce86e35e09b877be16ae164cd678b", size = 539236, upload-time = "2025-05-29T21:52:30.092Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/8b/4f/0e452fa2846a41ba6db0fa3a3772b9fde40ea0c342ddf17f9a27eb7e2b69/dagster_dg_cli-1.10.18-py3-none-any.whl", hash = "sha256:949043b72d9ce5af0eee2499c606e132b8655db51dc2b78f8e7b2e380f475109", size = 569884, upload-time = "2025-05-29T21:52:28.111Z" },
]
[[package]]
name = "dagster-dg-core"
version = "1.10.18"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "click" },
{ name = "click-aliases" },
{ name = "dagster-cloud-cli" },
{ name = "dagster-shared" },
{ name = "gql", extra = ["requests"] },
{ name = "jinja2" },
{ name = "jsonschema" },
{ name = "markdown" },
{ name = "packaging" },
{ name = "python-dotenv" },
{ name = "pyyaml" },
{ name = "rich" },
{ name = "setuptools" },
{ name = "tomlkit" },
{ name = "typer" },
{ name = "typing-extensions" },
{ name = "watchdog" },
{ name = "yaspin" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d2/a8/3965d6d769d541bddd545e5624378c853539a7aa7582479e78d5b516bf30/dagster_dg_core-1.10.18.tar.gz", hash = "sha256:4e988e30f1d390ccf4a1f0ec15d571a2482b14e18c4089bce515ef07c533cff0", size = 43181, upload-time = "2025-05-29T21:36:04.451Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e6/58/59cc3dce6fa82cc62da6e6b876d8decc549255995a7394736d707543c1b7/dagster_dg_core-1.10.18-py3-none-any.whl", hash = "sha256:3299badb16949405033b8690796269112b4243f90766410edd0098c638c556ec", size = 49977, upload-time = "2025-05-29T21:36:03.419Z" },
]
[[package]]
name = "dagster-docker"
version = "0.26.18"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "dagster" },
{ name = "docker" },
{ name = "docker-image-py" },
]
sdist = { url = "https://files.pythonhosted.org/packages/fd/87/b091ac095ea590f581199c060058d01afeb6aea8de5cd78deb3937db908b/dagster_docker-0.26.18.tar.gz", hash = "sha256:7aa7625b7271c67eacda26a06d384134bdb2f34e548f4c7df74bd8fb085298b7", size = 16129, upload-time = "2025-05-29T21:37:41.881Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7f/9e/2ab929c7103c63df6ebaef2bd1d16390173a842cdf7f60bd991655446f1c/dagster_docker-0.26.18-py3-none-any.whl", hash = "sha256:d62f07b8aa1e03f990cad413917a0c75515fc80a0c949c34e431c7465c1f1a8f", size = 19580, upload-time = "2025-05-29T21:37:40.775Z" },
]
[[package]]
name = "dagster-graphql"
version = "1.10.18"
@@ -1676,6 +1743,32 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/9b/ed/28fb14146c7033ba0e89decd92a4fa16b0b69b84471e2deab3cc4337cc35/dnspython-2.2.1-py3-none-any.whl", hash = "sha256:a851e51367fb93e9e1361732c1d60dab63eff98712e503ea7d92e6eccb109b4f", size = 269084, upload-time = "2022-03-06T23:36:12.209Z" },
]
[[package]]
name = "docker"
version = "7.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pywin32", marker = "sys_platform == 'win32'" },
{ name = "requests" },
{ name = "urllib3" },
]
sdist = { url = "https://files.pythonhosted.org/packages/91/9b/4a2ea29aeba62471211598dac5d96825bb49348fa07e906ea930394a83ce/docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c", size = 117834, upload-time = "2024-05-23T11:13:57.216Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774, upload-time = "2024-05-23T11:13:55.01Z" },
]
[[package]]
name = "docker-image-py"
version = "0.1.13"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "regex" },
]
sdist = { url = "https://files.pythonhosted.org/packages/2f/10/28dad68f4693a5131ff859ac9dc8eda7967fd922509c1e667e143b0760b8/docker_image_py-0.1.13.tar.gz", hash = "sha256:de2755de0a09c99ae3b4cf42cc470a83b35bfdc4bf8ad66a3d7550d622a8915a", size = 11211, upload-time = "2024-07-17T05:36:04.298Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/39/06/c8d170aeb3e9eb3d951dd37acf1b6bad2b9401d0bee3346e40295c9e15a2/docker_image_py-0.1.13-py3-none-any.whl", hash = "sha256:c217fc72e8cdf2aa2caa718c758cec1ad41c09972ec1c94e85a4bf79a8f81061", size = 8894, upload-time = "2024-07-17T05:36:02.563Z" },
]
[[package]]
name = "docopt"
version = "0.6.2"
@@ -3015,6 +3108,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" },
]
[[package]]
name = "markdown"
version = "3.8.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d7/c2/4ab49206c17f75cb08d6311171f2d65798988db4360c4d1485bd0eedd67c/markdown-3.8.2.tar.gz", hash = "sha256:247b9a70dd12e27f67431ce62523e675b866d254f900c4fe75ce3dda62237c45", size = 362071, upload-time = "2025-06-19T17:12:44.483Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/96/2b/34cc11786bc00d0f04d0f5fdc3a2b1ae0b6239eef72d3d345805f9ad92a1/markdown-3.8.2-py3-none-any.whl", hash = "sha256:5c83764dbd4e00bdd94d85a19b8d55ccca20fe35b2e678a1422b380324dd5f24", size = 106827, upload-time = "2025-06-19T17:12:42.994Z" },
]
[[package]]
name = "markdown-it-py"
version = "3.0.0"
@@ -3929,6 +4031,8 @@ dependencies = [
{ name = "dagster-aws" },
{ name = "dagster-celery" },
{ name = "dagster-cloud" },
{ name = "dagster-docker" },
{ name = "dagster-pipes" },
{ name = "dagster-postgres" },
{ name = "dagster-slack" },
{ name = "dagster-webserver" },
@@ -4071,6 +4175,7 @@ dev = [
{ name = "boto3-stubs", extra = ["s3"] },
{ name = "braintrust" },
{ name = "braintrust-langchain" },
{ name = "dagster-dg-cli" },
{ name = "datamodel-code-generator" },
{ name = "deepdiff" },
{ name = "django-linear-migrations" },
@@ -4107,7 +4212,7 @@ dev = [
{ name = "responses" },
{ name = "ruff" },
{ name = "sqlalchemy" },
{ name = "stpyv8" },
{ name = "stpyv8", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64') or sys_platform != 'linux'" },
{ name = "syrupy" },
{ name = "tach" },
{ name = "types-aioboto3", extra = ["s3"] },
@@ -4153,6 +4258,8 @@ requires-dist = [
{ name = "dagster-aws", specifier = "==0.26.18" },
{ name = "dagster-celery", specifier = "==0.26.18" },
{ name = "dagster-cloud", specifier = "==1.10.18" },
{ name = "dagster-docker", specifier = "==0.26.18" },
{ name = "dagster-pipes", specifier = "==1.10.18" },
{ name = "dagster-postgres", specifier = "==0.26.18" },
{ name = "dagster-slack", specifier = "==0.26.18" },
{ name = "dagster-webserver", specifier = "==1.10.18" },
@@ -4293,8 +4400,9 @@ dev = [
{ name = "autoevals", specifier = "==0.0.129" },
{ name = "black", specifier = "~=23.9.1" },
{ name = "boto3-stubs", extras = ["s3"], specifier = ">=1.34.84" },
{ name = "braintrust", specifier = "==0.2.0" },
{ name = "braintrust-langchain", specifier = "==0.0.2" },
{ name = "braintrust", specifier = "==0.2.4" },
{ name = "braintrust-langchain", specifier = "==0.0.4" },
{ name = "dagster-dg-cli", specifier = ">=1.10.18" },
{ name = "datamodel-code-generator", specifier = "==0.28.5" },
{ name = "deepdiff", specifier = ">=8.5.0" },
{ name = "django-linear-migrations", specifier = "==2.16.*" },
@@ -4331,7 +4439,7 @@ dev = [
{ name = "responses", specifier = "==0.23.1" },
{ name = "ruff", specifier = "~=0.8.1" },
{ name = "sqlalchemy", specifier = "==2.0.38" },
{ name = "stpyv8", specifier = "==13.1.201.22" },
{ name = "stpyv8", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64') or sys_platform != 'linux'", specifier = "==13.1.201.22" },
{ name = "syrupy", specifier = "~=4.6.0" },
{ name = "tach", specifier = "~=0.20.0" },
{ name = "types-aioboto3", extras = ["s3"], specifier = ">=14.3.0" },
@@ -5930,6 +6038,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138", size = 28248, upload-time = "2025-04-02T08:25:07.678Z" },
]
[[package]]
name = "termcolor"
version = "2.3.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b8/85/147a0529b4e80b6b9d021ca8db3a820fcac53ec7374b87073d004aaf444c/termcolor-2.3.0.tar.gz", hash = "sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a", size = 12163, upload-time = "2023-04-23T19:45:24.004Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/67/e1/434566ffce04448192369c1a282931cf4ae593e91907558eaecd2e9f2801/termcolor-2.3.0-py3-none-any.whl", hash = "sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475", size = 6872, upload-time = "2023-04-23T19:45:22.671Z" },
]
[[package]]
name = "text-unidecode"
version = "1.3"
@@ -6920,6 +7037,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ea/1f/70c57b3d7278e94ed22d85e09685d3f0a38ebdd8c5c73b65ba4c0d0fe002/yarl-1.20.0-py3-none-any.whl", hash = "sha256:5d0fe6af927a47a230f31e6004621fd0959eaa915fc62acfafa67ff7229a3124", size = 46124, upload-time = "2025-04-17T00:45:12.199Z" },
]
[[package]]
name = "yaspin"
version = "3.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "termcolor" },
]
sdist = { url = "https://files.pythonhosted.org/packages/07/3c/70df5034e6712fcc238b76f6afd1871de143a2a124d80ae2c377cde180f3/yaspin-3.1.0.tar.gz", hash = "sha256:7b97c7e257ec598f98cef9878e038bfa619ceb54ac31d61d8ead2b3128f8d7c7", size = 36791, upload-time = "2024-09-22T17:07:09.376Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/89/78/fa25b385d9f2c406719b5cf574a0980f5ccc6ea1f8411d56249f44acd3c2/yaspin-3.1.0-py3-none-any.whl", hash = "sha256:5e3d4dfb547d942cae6565718123f1ecfa93e745b7e51871ad2bbae839e71b73", size = 18629, upload-time = "2024-09-22T17:07:06.923Z" },
]
[[package]]
name = "zeep"
version = "4.3.1"