feat(max): dagster evaluation runner (#36320)
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Michael Matloka <michael@matloka.com>
@@ -62,3 +62,6 @@ rust/cyclotron-node/index.node
 rust/cyclotron-node/node_modules
 rust/docker
 rust/target
+!docker-compose.base.yml
+!docker-compose.dev.yml
+!docker

.github/workflows/cd-ai-evals-image.yml (new file)
@@ -0,0 +1,66 @@
+#
+# Build and push PostHog container images for AI evaluations to AWS ECR
+#
+# - posthog_ai_evals_build: build and push the PostHog container image to AWS ECR
+#
+name: AI Evals Container Images CD
+
+on:
+    push:
+        branches:
+            - master
+        paths-ignore:
+            - 'rust/**'
+            - 'livestream/**'
+            - 'plugin-server/**'
+    pull_request:
+    workflow_dispatch:
+
+jobs:
+    posthog_ai_evals_build:
+        name: Build and push container image
+        if: |
+            github.repository == 'PostHog/posthog' && (
+                github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'build-ai-evals-image')
+            )
+        runs-on: depot-ubuntu-latest
+        permissions:
+            id-token: write # allow issuing OIDC tokens for this workflow run
+            contents: read # allow at least reading the repo contents, add other permissions if necessary
+
+        steps:
+            - name: Check out
+              uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
+              with:
+                  fetch-depth: 2
+
+            - name: Set up Docker Buildx
+              uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3
+
+            - name: Set up QEMU
+              uses: docker/setup-qemu-action@29109295f81e9208d7d86ff1c6c12d2833863392 # v3
+
+            - name: Set up Depot CLI
+              uses: depot/setup-action@b0b1ea4f69e92ebf5dea3f8713a1b0c37b2126a5 # v1
+
+            - name: Configure AWS credentials
+              uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4
+              with:
+                  aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+                  aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+                  aws-region: us-east-1
+
+            - name: Login to Amazon ECR
+              id: aws-ecr
+              uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2
+
+            - name: Build and push container image
+              id: build
+              uses: depot/build-push-action@2583627a84956d07561420dcc1d0eb1f2af3fac0 # v1
+              with:
+                  file: ./Dockerfile.ai-evals
+                  buildx-fallback: false # the fallback is so slow it's better to just fail
+                  push: true
+                  tags: ${{ github.ref == 'refs/heads/master' && format('{0}/posthog-ai-evals:master,{0}/posthog-ai-evals:{1}', steps.aws-ecr.outputs.registry, github.sha) || format('{0}/posthog-ai-evals:{1}', steps.aws-ecr.outputs.registry, github.sha) }}
+                  platforms: linux/arm64,linux/amd64
+                  build-args: COMMIT_HASH=${{ github.sha }}
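The `tags` expression above is a GitHub Actions ternary. A minimal Python sketch of how it resolves (the registry host is a made-up placeholder, not the real account):

```python
def resolve_tags(ref: str, registry: str, sha: str) -> str:
    """Mirror of the workflow's `tags` ternary; `registry` stands in for steps.aws-ecr.outputs.registry."""
    repo = f"{registry}/posthog-ai-evals"
    if ref == "refs/heads/master":
        # Pushes to master publish a moving `master` tag plus an immutable SHA tag
        return f"{repo}:master,{repo}:{sha}"
    # Label-triggered PR builds publish only the immutable SHA tag
    return f"{repo}:{sha}"


print(resolve_tags("refs/heads/master", "<account>.dkr.ecr.us-east-1.amazonaws.com", "abc123"))
print(resolve_tags("refs/pull/1/merge", "<account>.dkr.ecr.us-east-1.amazonaws.com", "abc123"))
```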

.github/workflows/ci-ai.yml
@@ -59,8 +59,10 @@ jobs:
               run: bin/check_kafka_clickhouse_up
 
             - name: Run LLM evals
-              run: pytest ee/hogai/eval -vv
+              run: pytest ee/hogai/eval/ci -vv
               env:
+                  EVAL_MODE: ci
+                  EXPORT_EVAL_RESULTS: true
                   BRAINTRUST_API_KEY: ${{ secrets.BRAINTRUST_API_KEY }}
                   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
                   ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}

Dockerfile.ai-evals (new file)
@@ -0,0 +1,52 @@
+FROM python:3.11.9-slim-bookworm AS python-base
+FROM cruizba/ubuntu-dind:latest
+SHELL ["/bin/bash", "-e", "-o", "pipefail", "-c"]
+
+# Copy Python base
+COPY --from=python-base /usr/local /usr/local
+
+# Set working directory
+WORKDIR /code
+
+# Copy Docker scripts
+COPY docker/ ./docker/
+
+# Install system dependencies
+RUN apt-get update && \
+    apt-get install -y --no-install-recommends \
+    "build-essential" \
+    "git" \
+    "libpq-dev" \
+    "libxmlsec1" \
+    "libxmlsec1-dev" \
+    "libffi-dev" \
+    "zlib1g-dev" \
+    "pkg-config" \
+    "netcat-openbsd" \
+    "postgresql-client"
+
+# Copy uv dependencies
+COPY pyproject.toml uv.lock docker-compose.base.yml docker-compose.dev.yml ./
+
+# Install python deps
+RUN rm -rf /var/lib/apt/lists/* && \
+    pip install uv~=0.7.0 --no-cache-dir && \
+    UV_PROJECT_ENVIRONMENT=/python-runtime uv sync --frozen --no-cache --compile-bytecode \
+    --no-binary-package lxml --no-binary-package xmlsec
+
+# Copy project files
+COPY bin/ ./bin/
+COPY manage.py manage.py
+COPY common/esbuilder common/esbuilder
+COPY common/hogvm common/hogvm/
+COPY posthog posthog/
+COPY products/ products/
+COPY ee ee/
+
+ENV PATH=/python-runtime/bin:$PATH \
+    PYTHONPATH=/python-runtime
+
+# Make scripts executable
+RUN chmod +x bin/*
+
+CMD bin/docker-ai-evals

bin/check_clickhouse_up (new executable file)
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+set -e
+
+# Check ClickHouse
+while true; do
+    curl -s -o /dev/null -I 'http://localhost:8123/' && break || echo 'Checking ClickHouse status...' && sleep 1
+done

bin/check_kafka_clickhouse_up
@@ -7,7 +7,4 @@ while true; do
     nc -z localhost 9092 && break || echo 'Checking Kafka status...' && sleep 1
 done
 
-# Check ClickHouse
-while true; do
-    curl -s -o /dev/null -I 'http://localhost:8123/' && break || echo 'Checking ClickHouse status...' && sleep 1
-done
+./bin/check_clickhouse_up

bin/docker-ai-evals (new executable file)
@@ -0,0 +1,29 @@
+#!/bin/bash
+set -e
+export DEBUG=1
+export IN_EVAL_TESTING=1
+export EVAL_MODE=offline
+export EXPORT_EVAL_RESULTS=1
+
+cleanup() {
+    echo "🧹 Cleaning up..."
+    docker compose -f docker-compose.dev.yml -p evals down -v --remove-orphans || true
+}
+trap cleanup EXIT INT TERM
+
+echo "127.0.0.1 kafka clickhouse objectstorage db" >> /etc/hosts
+
+echo "🚀 Starting services..."
+docker compose -f docker-compose.dev.yml -p evals up -d db clickhouse objectstorage
+
+echo "🔄 Waiting for services to start..."
+bin/check_postgres_up & bin/check_kafka_clickhouse_up
+
+echo "🏃 Running evaluation..."
+if [ -z "$EVAL_SCRIPT" ]; then
+    echo "Error: EVAL_SCRIPT environment variable is not set"
+    exit 1
+fi
+$EVAL_SCRIPT
+
+echo "🎉 Done."
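For local debugging it can help to run this image by hand, outside Dagster. A hedged sketch using the Docker SDK for Python; the image tag and eval module path here are illustrative, not taken from the PR. The container must be privileged because the image is Docker-in-Docker and brings up the compose stack itself:

```python
import docker  # pip install docker

client = docker.from_env()
logs = client.containers.run(
    "posthog-ai-evals:master",  # assumes the image was built locally under this tag
    environment={"EVAL_SCRIPT": "pytest ee/hogai/eval -s -vv"},  # consumed by bin/docker-ai-evals
    privileged=True,  # required for the DinD daemon inside the image
    remove=True,
)
print(logs.decode())
```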

@@ -1,4 +1,5 @@
 import dagster
+from dagster_docker import PipesDockerClient
 
 from dags.max_ai.run_evaluation import run_evaluation
 
@@ -6,5 +7,8 @@ from . import resources
 
 defs = dagster.Definitions(
     jobs=[run_evaluation],
-    resources=resources,
+    resources={
+        **resources,
+        "docker_pipes_client": PipesDockerClient(),
+    },
 )

dags/max_ai/run_evaluation.py
@@ -1,22 +1,144 @@
+import base64
+
+from django.conf import settings
+
+import boto3
 import dagster
+from dagster_docker import PipesDockerClient
+from pydantic import Field
+from tenacity import retry, stop_after_attempt, wait_exponential
 
 from dags.common import JobOwners
-from dags.max_ai.snapshot_project_data import snapshot_clickhouse_project_data, snapshot_postgres_project_data
+from dags.max_ai.snapshot_team_data import (
+    ClickhouseTeamDataSnapshot,
+    PostgresTeamDataSnapshot,
+    snapshot_clickhouse_team_data,
+    snapshot_postgres_team_data,
+)
+from ee.hogai.eval.schema import DatasetInput, EvalsDockerImageConfig, TeamEvaluationSnapshot
 
 
-class ExportProjectsConfig(dagster.Config):
-    project_ids: list[int]
-    """Project IDs to run the evaluation for."""
+def get_object_storage_endpoint() -> str:
+    """
+    Get the object storage endpoint.
+    Debug mode uses the local object storage, so we need to set a DNS endpoint (like orb.dev).
+    Production mode uses AWS S3.
+    """
+    if settings.DEBUG:
+        val = dagster.EnvVar("EVALS_DIND_OBJECT_STORAGE_ENDPOINT").get_value("http://objectstorage.posthog.orb.local")
+        if not val:
+            raise ValueError("EVALS_DIND_OBJECT_STORAGE_ENDPOINT is not set")
+        return val
+    return settings.OBJECT_STORAGE_ENDPOINT
+
+
+class ExportTeamsConfig(dagster.Config):
+    team_ids: list[int]
+    """Team IDs to run the evaluation for."""
 
 
 @dagster.op(out=dagster.DynamicOut(int))
-def export_projects(config: ExportProjectsConfig):
-    seen_projects = set()
-    for pid in config.project_ids:
-        if pid in seen_projects:
+def export_teams(config: ExportTeamsConfig):
+    seen_teams = set()
+    for tid in config.team_ids:
+        if tid in seen_teams:
             continue
-        seen_projects.add(pid)
-        yield dagster.DynamicOutput(pid, mapping_key=str(pid))
+        seen_teams.add(tid)
+        yield dagster.DynamicOutput(tid, mapping_key=str(tid))
+
+
+class EvaluationConfig(dagster.Config):
+    image_name: str = Field(description="Name of the Docker image to run.")
+    image_tag: str = Field(description="Tag of the Docker image to run.")
+    experiment_name: str = Field(description="Name of the experiment.")
+    evaluation_module: str = Field(description="Python module containing the evaluation runner.")
+
+    @property
+    def image(self) -> str:
+        # We use the local Docker image in debug mode
+        if settings.DEBUG:
+            return f"{self.image_name}:{self.image_tag}"
+        return f"{dagster.EnvVar('AWS_EKS_REGISTRY_URL').get_value()}/{self.image_name}:{self.image_tag}"
+
+
+@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=2))
+def get_registry_credentials():
+    # We use the local Docker image in debug mode
+    if settings.DEBUG:
+        return None
+
+    client = boto3.client("ecr")
+    # https://boto3.amazonaws.com/v1/documentation/api/1.29.2/reference/services/ecr/client/get_authorization_token.html
+    token = client.get_authorization_token()["authorizationData"][0]["authorizationToken"]
+    username, password = base64.b64decode(token).decode("utf-8").split(":")
+
+    return {
+        "url": dagster.EnvVar("AWS_EKS_REGISTRY_URL").get_value(),
+        "username": username,
+        "password": password,
+    }
+
+
+@dagster.op
+def spawn_evaluation_container(
+    context: dagster.OpExecutionContext,
+    config: EvaluationConfig,
+    docker_pipes_client: PipesDockerClient,
+    team_ids: list[int],
+    postgres_snapshots: list[PostgresTeamDataSnapshot],
+    clickhouse_snapshots: list[ClickhouseTeamDataSnapshot],
+):
+    evaluation_config = EvalsDockerImageConfig(
+        aws_endpoint_url=get_object_storage_endpoint(),
+        aws_bucket_name=settings.OBJECT_STORAGE_BUCKET,
+        team_snapshots=[
+            TeamEvaluationSnapshot(team_id=team_id, postgres=postgres, clickhouse=clickhouse).model_dump()
+            for team_id, postgres, clickhouse in zip(team_ids, postgres_snapshots, clickhouse_snapshots)
+        ],
+        experiment_name=config.experiment_name,
+        dataset=[
+            DatasetInput(
+                team_id=team_id,
+                input={"query": "List all events from the last 7 days. Use SQL."},
+                expected={"output": "SELECT * FROM events WHERE timestamp >= now() - INTERVAL 7 day"},
+            )
+            for team_id in team_ids
+        ],
+    )
+
+    context.log.info(f"Running evaluation for the image: {config.image}")
+
+    asset_result = docker_pipes_client.run(
+        context=context,
+        image=config.image,
+        container_kwargs={
+            "privileged": True,
+            "auto_remove": True,
+        },
+        env={
+            "EVAL_SCRIPT": f"pytest {config.evaluation_module} -s -vv",
+            "OBJECT_STORAGE_ACCESS_KEY_ID": settings.OBJECT_STORAGE_ACCESS_KEY_ID,  # type: ignore
+            "OBJECT_STORAGE_SECRET_ACCESS_KEY": settings.OBJECT_STORAGE_SECRET_ACCESS_KEY,  # type: ignore
+            "OPENAI_API_KEY": settings.OPENAI_API_KEY,
+            "ANTHROPIC_API_KEY": settings.ANTHROPIC_API_KEY,
+            "GEMINI_API_KEY": settings.GEMINI_API_KEY,
+            "INKEEP_API_KEY": settings.INKEEP_API_KEY,
+            "PPLX_API_KEY": settings.PPLX_API_KEY,
+            "AZURE_INFERENCE_ENDPOINT": settings.AZURE_INFERENCE_ENDPOINT,
+            "AZURE_INFERENCE_CREDENTIAL": settings.AZURE_INFERENCE_CREDENTIAL,
+            "BRAINTRUST_API_KEY": settings.BRAINTRUST_API_KEY,
+        },
+        extras=evaluation_config.model_dump(exclude_unset=True),
+        registry=get_registry_credentials(),
+    ).get_materialize_result()
+
+    context.log_event(
+        dagster.AssetMaterialization(
+            asset_key=asset_result.asset_key or "evaluation_report",
+            metadata=asset_result.metadata,
+            tags={"owner": JobOwners.TEAM_MAX_AI.value},
+        )
+    )
 
 
 @dagster.job(
@@ -28,11 +150,18 @@ def export_projects(config: ExportProjectsConfig):
     executor_def=dagster.multiprocess_executor.configured({"max_concurrent": 4}),
     config=dagster.RunConfig(
         ops={
-            "export_projects": ExportProjectsConfig(project_ids=[]),
+            "export_teams": ExportTeamsConfig(team_ids=[]),
+            "spawn_evaluation_container": EvaluationConfig(
+                evaluation_module="",
+                experiment_name="offline_evaluation",
+                image_name="posthog-ai-evals",
+                image_tag="master",
+            ),
         }
     ),
 )
 def run_evaluation():
-    project_ids = export_projects()
-    project_ids.map(snapshot_postgres_project_data)
-    project_ids.map(snapshot_clickhouse_project_data)
+    team_ids = export_teams()
+    postgres_snapshots = team_ids.map(snapshot_postgres_team_data)
+    clickhouse_snapshots = team_ids.map(snapshot_clickhouse_team_data)
+    spawn_evaluation_container(team_ids.collect(), postgres_snapshots.collect(), clickhouse_snapshots.collect())
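On the orchestrator side, `PipesDockerClient.run` injects the `extras` payload into the container and collects materialization metadata back out of it. The container-side counterpart is not part of this diff; a sketch of what it plausibly looks like with `dagster-pipes` (the metadata key is illustrative):

```python
from dagster_pipes import open_dagster_pipes

with open_dagster_pipes() as pipes:
    config = pipes.extras  # the EvalsDockerImageConfig dump passed via `extras=` above
    pipes.log.info(f"Running experiment {config['experiment_name']}")
    # ... execute EVAL_SCRIPT / pytest here ...
    pipes.report_asset_materialization(
        metadata={"dataset_size": len(config["dataset"])},
    )
```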

dags/max_ai/snapshot_team_data.py (renamed from dags/max_ai/snapshot_project_data.py)
@@ -9,6 +9,9 @@ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
 from posthog.schema import (
     ActorsPropertyTaxonomyQuery,
     ActorsPropertyTaxonomyResponse,
+    CachedActorsPropertyTaxonomyQueryResponse,
+    CachedEventTaxonomyQueryResponse,
+    CachedTeamTaxonomyQueryResponse,
     EventTaxonomyQuery,
     TeamTaxonomyItem,
     TeamTaxonomyQuery,
@@ -18,6 +21,7 @@ from posthog.errors import InternalCHQueryError
 from posthog.hogql_queries.ai.actors_property_taxonomy_query_runner import ActorsPropertyTaxonomyQueryRunner
 from posthog.hogql_queries.ai.event_taxonomy_query_runner import EventTaxonomyQueryRunner
 from posthog.hogql_queries.ai.team_taxonomy_query_runner import TeamTaxonomyQueryRunner
+from posthog.hogql_queries.query_runner import ExecutionMode
 from posthog.models import GroupTypeMapping, Team
 from posthog.models.property_definition import PropertyDefinition
 
@@ -26,10 +30,10 @@ from dags.max_ai.utils import check_dump_exists, compose_clickhouse_dump_path, compose_postgres_dump_path
 from ee.hogai.eval.schema import (
     ActorsPropertyTaxonomySnapshot,
     BaseSnapshot,
-    ClickhouseProjectDataSnapshot,
+    ClickhouseTeamDataSnapshot,
     DataWarehouseTableSnapshot,
     GroupTypeMappingSnapshot,
-    PostgresProjectDataSnapshot,
+    PostgresTeamDataSnapshot,
     PropertyDefinitionSnapshot,
     PropertyTaxonomySnapshot,
     TeamSnapshot,
@@ -47,26 +51,35 @@ DEFAULT_RETRY_POLICY = dagster.RetryPolicy(
 SchemaBound = TypeVar("SchemaBound", bound=BaseSnapshot)
 
 
+class SnapshotUnrecoverableError(ValueError):
+    """
+    An error that indicates that the snapshot operation cannot be recovered from.
+    This is used to indicate that the snapshot operation failed and should not be retried.
+    """
+
+    pass
+
+
 def snapshot_postgres_model(
     context: dagster.OpExecutionContext,
     model_type: type[SchemaBound],
     file_name: str,
     s3: S3Resource,
-    project_id: int,
+    team_id: int,
     code_version: str | None = None,
 ) -> str:
-    file_key = compose_postgres_dump_path(project_id, file_name, code_version)
+    file_key = compose_postgres_dump_path(team_id, file_name, code_version)
     if check_dump_exists(s3, file_key):
         context.log.info(f"Skipping {file_key} because it already exists")
         return file_key
     context.log.info(f"Dumping {file_key}")
     with dump_model(s3=s3, schema=model_type, file_key=file_key) as dump:
-        dump(model_type.serialize_for_project(project_id))
+        dump(model_type.serialize_for_team(team_id=team_id))
     return file_key
 
 
 @dagster.op(
-    description="Snapshots Postgres project data (property definitions, DWH schema, etc.)",
+    description="Snapshots Postgres team data (property definitions, DWH schema, etc.)",
     retry_policy=DEFAULT_RETRY_POLICY,
     code_version="v1",
     tags={
@@ -74,29 +87,38 @@ def snapshot_postgres_model(
         "dagster/max_runtime": 60 * 15,  # 15 minutes
     },
 )
-def snapshot_postgres_project_data(
-    context: dagster.OpExecutionContext, project_id: int, s3: S3Resource
-) -> PostgresProjectDataSnapshot:
-    context.log.info(f"Snapshotting Postgres project data for {project_id}")
+def snapshot_postgres_team_data(
+    context: dagster.OpExecutionContext, team_id: int, s3: S3Resource
+) -> PostgresTeamDataSnapshot:
+    context.log.info(f"Snapshotting Postgres team data for {team_id}")
     snapshot_map: dict[str, type[BaseSnapshot]] = {
-        "project": TeamSnapshot,
+        "team": TeamSnapshot,
         "property_definitions": PropertyDefinitionSnapshot,
         "group_type_mappings": GroupTypeMappingSnapshot,
         "data_warehouse_tables": DataWarehouseTableSnapshot,
     }
-    deps = {
-        file_name: snapshot_postgres_model(context, model_type, file_name, s3, project_id, context.op_def.version)
-        for file_name, model_type in snapshot_map.items()
-    }
-    context.log_event(
-        dagster.AssetMaterialization(
-            asset_key="project_postgres_snapshot",
-            description="Avro snapshots of project Postgres data",
-            metadata={"project_id": project_id, **deps},
-            tags={"owner": JobOwners.TEAM_MAX_AI.value},
+    try:
+        deps = {
+            file_name: snapshot_postgres_model(context, model_type, file_name, s3, team_id, context.op_def.version)
+            for file_name, model_type in snapshot_map.items()
+        }
+        context.log_event(
+            dagster.AssetMaterialization(
+                asset_key="team_postgres_snapshot",
+                description="Avro snapshots of team Postgres data",
+                metadata={"team_id": team_id, **deps},
+                tags={"owner": JobOwners.TEAM_MAX_AI.value},
+            )
         )
-    )
-    return PostgresProjectDataSnapshot(**deps)
+    except Team.DoesNotExist as e:
+        raise dagster.Failure(
+            description=f"Team {team_id} does not exist",
+            metadata={"team_id": team_id},
+            allow_retries=False,
+        ) from e
+
+    return PostgresTeamDataSnapshot(**deps)
 
 
 C = TypeVar("C")
@@ -120,17 +142,20 @@ def snapshot_properties_taxonomy(
 ):
     results: list[PropertyTaxonomySnapshot] = []
 
-    def snapshot_event(item: TeamTaxonomyItem):
-        return call_query_runner(
-            lambda: EventTaxonomyQueryRunner(
-                query=EventTaxonomyQuery(event=item.event),
-                team=team,
-            ).calculate()
-        )
+    def wrapped_query_runner(item: TeamTaxonomyItem):
+        response = EventTaxonomyQueryRunner(query=EventTaxonomyQuery(event=item.event), team=team).run(
+            execution_mode=ExecutionMode.RECENT_CACHE_CALCULATE_BLOCKING_IF_STALE
+        )
+        if not isinstance(response, CachedEventTaxonomyQueryResponse):
+            raise SnapshotUnrecoverableError(f"Unexpected response type from event taxonomy query: {type(response)}")
+        return response
+
+    def call_event_taxonomy_query(item: TeamTaxonomyItem):
+        return call_query_runner(lambda: wrapped_query_runner(item))
 
     for item in events:
         context.log.info(f"Snapshotting properties taxonomy for event {item.event} of {team.id}")
-        results.append(PropertyTaxonomySnapshot(event=item.event, results=snapshot_event(item).results))
+        results.append(PropertyTaxonomySnapshot(event=item.event, results=call_event_taxonomy_query(item).results))
 
     context.log.info(f"Dumping properties taxonomy to {file_key}")
     with dump_model(s3=s3, schema=PropertyTaxonomySnapshot, file_key=file_key) as dump:
@@ -152,9 +177,18 @@ def snapshot_events_taxonomy(
 
     context.log.info(f"Snapshotting events taxonomy for {team.id}")
 
-    res = call_query_runner(lambda: TeamTaxonomyQueryRunner(query=TeamTaxonomyQuery(), team=team).calculate())
+    def snapshot_events_taxonomy():
+        response = TeamTaxonomyQueryRunner(query=TeamTaxonomyQuery(), team=team).run(
+            execution_mode=ExecutionMode.RECENT_CACHE_CALCULATE_BLOCKING_IF_STALE
+        )
+        if not isinstance(response, CachedTeamTaxonomyQueryResponse):
+            raise SnapshotUnrecoverableError(f"Unexpected response type from events taxonomy query: {type(response)}")
+        return response
+
+    res = call_query_runner(snapshot_events_taxonomy)
 
     if not res.results:
-        raise ValueError("No results from events taxonomy query")
+        raise SnapshotUnrecoverableError("No results from events taxonomy query")
 
     # Dump properties
     snapshot_properties_taxonomy(context, s3, properties_file_key, team, res.results)
@@ -216,18 +250,24 @@ def snapshot_actors_property_taxonomy(
     # Query ClickHouse in batches of 200 properties
     for batch in chunked(property_defs, 200):
 
-        def snapshot(index: int | None, batch: list[str]):
-            return call_query_runner(
-                lambda: ActorsPropertyTaxonomyQueryRunner(
-                    query=ActorsPropertyTaxonomyQuery(groupTypeIndex=index, properties=batch, maxPropertyValues=25),
-                    team=team,
-                ).calculate()
-            )
+        def wrapped_query_runner(index: int | None, batch: list[str]):
+            response = ActorsPropertyTaxonomyQueryRunner(
+                query=ActorsPropertyTaxonomyQuery(groupTypeIndex=index, properties=batch, maxPropertyValues=25),
+                team=team,
+            ).run(execution_mode=ExecutionMode.RECENT_CACHE_CALCULATE_BLOCKING_IF_STALE)
+            if not isinstance(response, CachedActorsPropertyTaxonomyQueryResponse):
+                raise SnapshotUnrecoverableError(
+                    f"Unexpected response type from actors property taxonomy query: {type(response)}"
+                )
+            return response
 
-        res = snapshot(index, batch)
+        def call_actors_property_taxonomy_query(index: int | None, batch: list[str]):
+            return call_query_runner(lambda: wrapped_query_runner(index, batch))
+
+        res = call_actors_property_taxonomy_query(index, batch)
 
         if not res.results:
-            raise ValueError(
+            raise SnapshotUnrecoverableError(
                 f"No results from actors property taxonomy query for group type {index} and properties {batch}"
             )
@@ -244,7 +284,7 @@ def snapshot_actors_property_taxonomy(
 
 
 @dagster.op(
-    description="Snapshots ClickHouse project data",
+    description="Snapshots ClickHouse team data",
     retry_policy=DEFAULT_RETRY_POLICY,
     tags={
         "owner": JobOwners.TEAM_MAX_AI.value,
@@ -252,17 +292,33 @@ def snapshot_actors_property_taxonomy(
     },
     code_version="v1",
 )
-def snapshot_clickhouse_project_data(
-    context: dagster.OpExecutionContext, project_id: int, s3: S3Resource
-) -> ClickhouseProjectDataSnapshot:
-    team = Team.objects.get(id=project_id)
+def snapshot_clickhouse_team_data(
+    context: dagster.OpExecutionContext, team_id: int, s3: S3Resource
+) -> ClickhouseTeamDataSnapshot:
+    try:
+        team = Team.objects.get(id=team_id)
+    except Team.DoesNotExist as e:
+        raise dagster.Failure(
+            description=f"Team {team_id} does not exist",
+            metadata={"team_id": team_id},
+            allow_retries=False,
+        ) from e
 
-    event_taxonomy_file_key, properties_taxonomy_file_key = snapshot_events_taxonomy(
-        context, s3, team, context.op_def.version
-    )
-    actors_property_taxonomy_file_key = snapshot_actors_property_taxonomy(context, s3, team, context.op_def.version)
+    try:
+        event_taxonomy_file_key, properties_taxonomy_file_key = snapshot_events_taxonomy(
+            context, s3, team, context.op_def.version
+        )
+        actors_property_taxonomy_file_key = snapshot_actors_property_taxonomy(context, s3, team, context.op_def.version)
+    except SnapshotUnrecoverableError as e:
+        raise dagster.Failure(
+            description=f"Error snapshotting team {team_id}",
+            metadata={"team_id": team_id},
+            allow_retries=False,
+        ) from e
 
-    materialized_result = ClickhouseProjectDataSnapshot(
+    materialized_result = ClickhouseTeamDataSnapshot(
         event_taxonomy=event_taxonomy_file_key,
         properties_taxonomy=properties_taxonomy_file_key,
         actors_property_taxonomy=actors_property_taxonomy_file_key,
@@ -270,10 +326,10 @@ def snapshot_clickhouse_project_data(
 
     context.log_event(
         dagster.AssetMaterialization(
-            asset_key="project_clickhouse_snapshot",
-            description="Avro snapshots of project's ClickHouse queries",
+            asset_key="team_clickhouse_snapshot",
+            description="Avro snapshots of team's ClickHouse queries",
             metadata={
-                "project_id": project_id,
+                "team_id": team_id,
                 **materialized_result.model_dump(),
             },
             tags={"owner": JobOwners.TEAM_MAX_AI.value},
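The `call_query_runner` helper itself is not touched by this diff. Given the tenacity imports at the top of the file, its shape is presumably along these lines (a sketch, not the actual definition): transient ClickHouse errors are retried, while `SnapshotUnrecoverableError` escapes on the first raise and is converted to a non-retriable `dagster.Failure` by the ops above.

```python
from posthog.errors import InternalCHQueryError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential


@retry(
    retry=retry_if_exception_type(InternalCHQueryError),  # retry only transient ClickHouse errors
    stop=stop_after_attempt(3),
    wait=wait_exponential(min=2),
)
def call_query_runner(runner_fn):
    # SnapshotUnrecoverableError is not listed above, so it propagates immediately
    return runner_fn()
```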

@@ -5,19 +5,27 @@ import dagster
 from dagster import OpExecutionContext
 from dagster_aws.s3.resources import S3Resource
 
-from posthog.schema import ActorsPropertyTaxonomyResponse, EventTaxonomyItem, TeamTaxonomyItem
+from posthog.schema import (
+    ActorsPropertyTaxonomyResponse,
+    CachedEventTaxonomyQueryResponse,
+    CachedTeamTaxonomyQueryResponse,
+    EventTaxonomyItem,
+    TeamTaxonomyItem,
+)
 
 from posthog.models import GroupTypeMapping, Organization, Project, Team
 from posthog.models.property_definition import PropertyDefinition
 
-from dags.max_ai.snapshot_project_data import (
+from dags.max_ai.snapshot_team_data import (
+    SnapshotUnrecoverableError,
     snapshot_actors_property_taxonomy,
+    snapshot_clickhouse_team_data,
     snapshot_events_taxonomy,
     snapshot_postgres_model,
-    snapshot_postgres_project_data,
+    snapshot_postgres_team_data,
     snapshot_properties_taxonomy,
 )
-from ee.hogai.eval.schema import PostgresProjectDataSnapshot, TeamSnapshot
+from ee.hogai.eval.schema import PostgresTeamDataSnapshot, TeamSnapshot
@@ -58,7 +66,7 @@ def team():
 
 @pytest.fixture
 def mock_dump():
-    with patch("dags.max_ai.snapshot_project_data.dump_model") as mock_dump_model:
+    with patch("dags.max_ai.snapshot_team_data.dump_model") as mock_dump_model:
         mock_dump_context = MagicMock()
         mock_dump_function = MagicMock()
         mock_dump_context.__enter__ = MagicMock(return_value=mock_dump_function)
@@ -67,8 +75,8 @@ def mock_dump():
         yield mock_dump_function
 
 
-@patch("dags.max_ai.snapshot_project_data.compose_postgres_dump_path")
-@patch("dags.max_ai.snapshot_project_data.check_dump_exists")
+@patch("dags.max_ai.snapshot_team_data.compose_postgres_dump_path")
+@patch("dags.max_ai.snapshot_team_data.check_dump_exists")
 def test_snapshot_postgres_model_skips_when_file_exists(
     mock_check_dump_exists, mock_compose_path, mock_context, mock_s3
 ):
@@ -78,7 +86,7 @@ def test_snapshot_postgres_model_skips_when_file_exists(
     mock_compose_path.return_value = file_key
     mock_check_dump_exists.return_value = True
 
-    project_id = 123
+    team_id = 123
     file_name = "teams"
     code_version = "v1"
 
@@ -88,19 +96,19 @@ def test_snapshot_postgres_model_skips_when_file_exists(
         model_type=TeamSnapshot,
         file_name=file_name,
         s3=mock_s3,
-        project_id=project_id,
+        team_id=team_id,
        code_version=code_version,
     )
 
     # Verify
     assert result == file_key
-    mock_compose_path.assert_called_once_with(project_id, file_name, code_version)
+    mock_compose_path.assert_called_once_with(team_id, file_name, code_version)
     mock_check_dump_exists.assert_called_once_with(mock_s3, file_key)
     mock_context.log.info.assert_called_once_with(f"Skipping {file_key} because it already exists")
 
 
-@patch("dags.max_ai.snapshot_project_data.compose_postgres_dump_path")
-@patch("dags.max_ai.snapshot_project_data.check_dump_exists")
+@patch("dags.max_ai.snapshot_team_data.compose_postgres_dump_path")
+@patch("dags.max_ai.snapshot_team_data.check_dump_exists")
 def test_snapshot_postgres_model_dumps_when_file_not_exists(
     mock_check_dump_exists, mock_compose_path, mock_context, mock_s3, mock_dump
 ):
@@ -110,10 +118,10 @@ def test_snapshot_postgres_model_dumps_when_file_not_exists(
     mock_compose_path.return_value = file_key
     mock_check_dump_exists.return_value = False
 
-    # Mock the serialize_for_project method
+    # Mock the serialize_for_team method
     mock_serialized_data = [{"id": 1, "name": "test"}]
-    with patch.object(TeamSnapshot, "serialize_for_project", return_value=mock_serialized_data):
-        project_id = 123
+    with patch.object(TeamSnapshot, "serialize_for_team", return_value=mock_serialized_data):
+        team_id = 123
         file_name = "teams"
         code_version = "v1"
 
@@ -123,25 +131,25 @@ def test_snapshot_postgres_model_dumps_when_file_not_exists(
             model_type=TeamSnapshot,
             file_name=file_name,
             s3=mock_s3,
-            project_id=project_id,
+            team_id=team_id,
            code_version=code_version,
         )
 
         # Verify
         assert result == file_key
-        mock_compose_path.assert_called_once_with(project_id, file_name, code_version)
+        mock_compose_path.assert_called_once_with(team_id, file_name, code_version)
         mock_check_dump_exists.assert_called_once_with(mock_s3, file_key)
         mock_context.log.info.assert_called_with(f"Dumping {file_key}")
         mock_dump.assert_called_once_with(mock_serialized_data)
 
 
-@patch("dags.max_ai.snapshot_project_data.snapshot_postgres_model")
-def test_snapshot_postgres_project_data_exports_all_models(mock_snapshot_postgres_model, mock_s3):
-    """Test that snapshot_postgres_project_data exports all expected models."""
+@patch("dags.max_ai.snapshot_team_data.snapshot_postgres_model")
+def test_snapshot_postgres_team_data_exports_all_models(mock_snapshot_postgres_model, mock_s3):
+    """Test that snapshot_postgres_team_data exports all expected models."""
     # Setup
-    project_id = 456
+    team_id = 456
     mock_snapshot_postgres_model.side_effect = [
-        "path/to/project.avro",
+        "path/to/team.avro",
         "path/to/property_definitions.avro",
         "path/to/group_type_mappings.avro",
         "path/to/data_warehouse_tables.avro",
@@ -151,11 +159,11 @@ def test_snapshot_postgres_project_data_exports_all_models(mock_snapshot_postgre
     context = dagster.build_op_context()
 
     # Execute
-    result = snapshot_postgres_project_data(context, project_id, mock_s3)
+    result = snapshot_postgres_team_data(context, team_id, mock_s3)
 
     # Verify all expected models are in the result
-    assert isinstance(result, PostgresProjectDataSnapshot)
-    assert result.project == "path/to/project.avro"
+    assert isinstance(result, PostgresTeamDataSnapshot)
+    assert result.team == "path/to/team.avro"
     assert result.property_definitions == "path/to/property_definitions.avro"
     assert result.group_type_mappings == "path/to/group_type_mappings.avro"
     assert result.data_warehouse_tables == "path/to/data_warehouse_tables.avro"
@@ -165,7 +173,7 @@ def test_snapshot_postgres_project_data_exports_all_models(mock_snapshot_postgre
 
 
 @pytest.mark.django_db
-@patch("dags.max_ai.snapshot_project_data.call_query_runner")
+@patch("dags.max_ai.snapshot_team_data.call_query_runner")
 def test_snapshot_properties_taxonomy(mock_call_query_runner, mock_context, mock_s3, team, mock_dump):
     """Test that snapshot_properties_taxonomy correctly processes events and dumps results."""
     # Setup
@@ -190,9 +198,47 @@ def test_snapshot_properties_taxonomy(mock_call_query_runner, mock_context, mock
     mock_dump.assert_called_once()
 
 
-@patch("dags.max_ai.snapshot_project_data.check_dump_exists")
-@patch("dags.max_ai.snapshot_project_data.EventTaxonomyQueryRunner.calculate")
-@patch("dags.max_ai.snapshot_project_data.TeamTaxonomyQueryRunner.calculate")
+@patch("dags.max_ai.snapshot_team_data.snapshot_postgres_model")
+def test_snapshot_postgres_team_data_raises_failure_on_missing_team(mock_snapshot_postgres_model, mock_s3):
+    mock_snapshot_postgres_model.side_effect = Team.DoesNotExist()
+
+    context = dagster.build_op_context()
+
+    with pytest.raises(dagster.Failure) as exc:
+        snapshot_postgres_team_data(context=context, team_id=999999, s3=mock_s3)
+
+    assert getattr(exc.value, "allow_retries", None) is False
+    assert "Team 999999 does not exist" in str(exc.value)
+
+
+@pytest.mark.django_db
+def test_snapshot_clickhouse_team_data_raises_failure_on_missing_team(mock_s3):
+    context = dagster.build_op_context()
+
+    with pytest.raises(dagster.Failure) as exc:
+        snapshot_clickhouse_team_data(context=context, team_id=424242, s3=mock_s3)
+
+    assert getattr(exc.value, "allow_retries", None) is False
+    assert "Team 424242 does not exist" in str(exc.value)
+
+
+@pytest.mark.django_db
+@patch("dags.max_ai.snapshot_team_data.snapshot_events_taxonomy")
+def test_snapshot_clickhouse_team_data_raises_failure_on_unrecoverable_error(
+    mock_snapshot_events_taxonomy, mock_s3, team
+):
+    context = dagster.build_op_context()
+    mock_snapshot_events_taxonomy.side_effect = SnapshotUnrecoverableError("boom")
+
+    with pytest.raises(dagster.Failure) as exc:
+        snapshot_clickhouse_team_data(context=context, team_id=team.id, s3=mock_s3)
+
+    assert getattr(exc.value, "allow_retries", None) is False
+
+
+@patch("dags.max_ai.snapshot_team_data.check_dump_exists")
+@patch("dags.max_ai.snapshot_team_data.EventTaxonomyQueryRunner.run")
+@patch("dags.max_ai.snapshot_team_data.TeamTaxonomyQueryRunner.run")
 @pytest.mark.django_db
 def test_snapshot_events_taxonomy(
     mock_team_taxonomy_query_runner,
@@ -207,16 +253,18 @@ def test_snapshot_events_taxonomy(
     mock_check_dump_exists.return_value = False
 
     mock_team_taxonomy_query_runner.return_value = MagicMock(
+        spec=CachedTeamTaxonomyQueryResponse,
         results=[
             TeamTaxonomyItem(event="pageview", count=2),
             TeamTaxonomyItem(event="click", count=1),
-        ]
+        ],
     )
 
     mock_event_taxonomy_query_runner.return_value = MagicMock(
+        spec=CachedEventTaxonomyQueryResponse,
         results=[
            EventTaxonomyItem(property="$current_url", sample_values=["https://posthog.com"], sample_count=1),
-        ]
+        ],
    )
 
     mock_s3_client = MagicMock()
@@ -228,7 +276,7 @@ def test_snapshot_events_taxonomy(
     assert mock_dump.call_count == 2
 
 
-@patch("dags.max_ai.snapshot_project_data.check_dump_exists")
+@patch("dags.max_ai.snapshot_team_data.check_dump_exists")
 @pytest.mark.django_db
 def test_snapshot_events_taxonomy_can_be_skipped(mock_check_dump_exists, mock_context, mock_s3, team, mock_dump):
     """Test that snapshot_events_taxonomy can be skipped when file already exists."""
@@ -241,7 +289,7 @@ def test_snapshot_events_taxonomy_can_be_skipped(mock_check_dump_exists, mock_co
     assert mock_dump.call_count == 0
 
 
-@patch("dags.max_ai.snapshot_project_data.check_dump_exists")
+@patch("dags.max_ai.snapshot_team_data.check_dump_exists")
 @pytest.mark.django_db
 def test_snapshot_actors_property_taxonomy_can_be_skipped(
     mock_check_dump_exists, mock_context, mock_s3, team, mock_dump
@@ -263,8 +311,8 @@ def test_snapshot_actors_property_taxonomy_can_be_skipped(
     )
 
 
-@patch("dags.max_ai.snapshot_project_data.check_dump_exists")
-@patch("dags.max_ai.snapshot_project_data.call_query_runner")
+@patch("dags.max_ai.snapshot_team_data.check_dump_exists")
+@patch("dags.max_ai.snapshot_team_data.call_query_runner")
 @pytest.mark.django_db
 def test_snapshot_actors_property_taxonomy_dumps_with_group_type_mapping(
     mock_call_query_runner, mock_check_dump_exists, mock_context, mock_s3, team, mock_dump
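The switch from bare `MagicMock(...)` to `MagicMock(spec=Cached...Response, ...)` in the tests above is what keeps the new `isinstance` guards passing: a spec'd mock reports the spec class as its `__class__`. A self-contained illustration (the class here is a stand-in, not the real schema type):

```python
from unittest.mock import MagicMock


class CachedResponse:  # stand-in for CachedTeamTaxonomyQueryResponse
    results: list


assert not isinstance(MagicMock(), CachedResponse)  # a plain mock fails the guard
assert isinstance(MagicMock(spec=CachedResponse), CachedResponse)  # a spec'd mock passes
```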

@@ -12,12 +12,12 @@ We currently use [Braintrust](https://braintrust.dev) as our evaluation platform
 2. Run all evals with:
 
    ```bash
-   pytest ee/hogai/eval
+   pytest ee/hogai/eval/ci
    ```
 
-   The key bit is specifying the `ee/hogai/eval` directory – that activates our eval-specific config, `ee/hogai/eval/pytest.ini`!
+   The key bit is specifying the `ee/hogai/eval/ci` directory – that activates our eval-specific config, `ee/hogai/eval/pytest.ini`!
 
-   As always with pytest, you can also run a specific file, e.g. `pytest ee/hogai/eval/eval_root.py`.
+   As always with pytest, you can also run a specific file, e.g. `pytest ee/hogai/eval/ci/eval_root.py`. Apply the `--eval sql` argument to only run evals for test cases that contain `sql`.
 
 3. Voila! Max ran, evals executed, and results and traces uploaded to the Braintrust platform + summarized in the terminal.
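The `--eval` flag referenced here is read in `ee/hogai/eval/base.py` via `pytestconfig.option.eval`. Its registration is not shown in this diff, but a conftest hook along these lines would provide it (a sketch under that assumption, not the actual code):

```python
# Hypothetical location: a shared conftest.py in ee/hogai/eval/
def pytest_addoption(parser):
    parser.addoption(
        "--eval",
        default=None,
        help="Only run eval cases whose input contains this substring, e.g. --eval sql",
    )
```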

ee/hogai/eval/base.py (new file)
@@ -0,0 +1,65 @@
+import os
+import asyncio
+from collections.abc import Sequence
+from functools import partial
+
+import pytest
+
+from braintrust import EvalAsync, Metadata, init_logger
+from braintrust.framework import EvalData, EvalScorer, EvalTask, Input, Output
+
+
+async def BaseMaxEval(
+    experiment_name: str,
+    data: EvalData[Input, Output],
+    task: EvalTask[Input, Output],
+    scores: Sequence[EvalScorer[Input, Output]],
+    pytestconfig: pytest.Config,
+    metadata: Metadata | None = None,
+    is_public: bool = False,
+    no_send_logs: bool = True,
+):
+    # We need to specify a separate project for each MaxEval() suite for comparison to baseline to work.
+    # That's the way Braintrust folks recommended - Braintrust projects are much more lightweight than PostHog ones.
+    project_name = f"max-ai-{experiment_name}"
+    init_logger(project_name)
+
+    # Filter by the --eval <eval_case_name_part> pytest flag
+    case_filter = pytestconfig.option.eval
+    if case_filter:
+        if asyncio.iscoroutine(data):
+            data = await data
+        data = [case for case in data if case_filter in str(case.input)]  # type: ignore
+
+    timeout = 60 * 8  # 8 minutes
+    if os.getenv("EVAL_MODE") == "offline":
+        timeout = 60 * 60  # 1 hour
+
+    result = await EvalAsync(
+        project_name,
+        data=data,
+        task=task,
+        scores=scores,
+        timeout=timeout,
+        max_concurrency=100,
+        is_public=is_public,
+        no_send_logs=no_send_logs,
+        metadata=metadata,
+    )
+
+    # If we're running in offline mode and the eval case is marked as public, the pipeline must fail completely.
+    if os.getenv("EVAL_MODE") == "offline" and is_public:
+        raise RuntimeError("Evaluation cases must be private when EVAL_MODE is set to offline.")
+
+    if os.getenv("EXPORT_EVAL_RESULTS"):
+        with open("eval_results.jsonl", "a") as f:
+            f.write(result.summary.as_json() + "\n")
+
+    return result
+
+
+MaxPublicEval = partial(BaseMaxEval, is_public=True, no_send_logs=False)
+"""Evaluation case that is publicly accessible."""
+
+MaxPrivateEval = partial(BaseMaxEval, is_public=False, no_send_logs=True)
+"""Evaluation case that is not publicly accessible."""
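A typical call site, mirroring the `eval_*` functions updated below (the case values and empty scorer list are illustrative):

```python
import pytest
from braintrust import EvalCase

from ee.hogai.eval.base import MaxPublicEval


@pytest.mark.django_db
async def eval_example(call_root_for_insight_generation, pytestconfig):
    await MaxPublicEval(
        experiment_name="example",
        data=[EvalCase(input="Show weekly signups", expected={"plan": "...", "query": "..."})],
        task=call_root_for_insight_generation,
        scores=[],  # add scorers such as PlanCorrectness() here
        pytestconfig=pytestconfig,
    )
```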

ee/hogai/eval/ci/conftest.py (new file)
@@ -0,0 +1,159 @@
+import os
+import datetime
+from collections.abc import Generator
+
+import pytest
+
+from django.test import override_settings
+
+from braintrust_langchain import BraintrustCallbackHandler, set_global_handler
+
+from posthog.schema import FailureMessage, HumanMessage, VisualizationMessage
+
+from posthog.demo.matrix.manager import MatrixManager
+from posthog.models import Organization, Team, User
+from posthog.tasks.demo_create_data import HedgeboxMatrix
+
+from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
+
+# We want the PostHog set_up_evals fixture here
+from ee.hogai.eval.conftest import set_up_evals  # noqa: F401
+from ee.hogai.eval.scorers import PlanAndQueryOutput
+from ee.hogai.graph.graph import AssistantGraph, InsightsAssistantGraph
+from ee.hogai.utils.types import AssistantNodeName, AssistantState
+from ee.models.assistant import Conversation, CoreMemory
+
+handler = BraintrustCallbackHandler()
+if os.environ.get("EVAL_MODE") == "ci" and os.environ.get("BRAINTRUST_API_KEY"):
+    set_global_handler(handler)
+
+
+EVAL_USER_FULL_NAME = "Karen Smith"
+
+
+@pytest.fixture
+def call_root_for_insight_generation(demo_org_team_user):
+    # This graph structure will first get a plan, then generate the SQL query.
+    insights_subgraph = (
+        # Insights subgraph without query execution, so we only create the queries
+        InsightsAssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
+        .add_query_creation_flow(next_node=AssistantNodeName.END)
+        .compile()
+    )
+    graph = (
+        AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
+        .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
+        .add_root(
+            path_map={
+                "insights": AssistantNodeName.INSIGHTS_SUBGRAPH,
+                "root": AssistantNodeName.END,
+                "search_documentation": AssistantNodeName.END,
+                "end": AssistantNodeName.END,
+            }
+        )
+        .add_node(AssistantNodeName.INSIGHTS_SUBGRAPH, insights_subgraph)
+        .add_edge(AssistantNodeName.INSIGHTS_SUBGRAPH, AssistantNodeName.END)
+        # TRICKY: We need to set a checkpointer here because async tests create a new event loop.
+        .compile(checkpointer=DjangoCheckpointer())
+    )
+
+    async def callable(query_with_extra_context: str | tuple[str, str]) -> PlanAndQueryOutput:
+        # If query_with_extra_context is a tuple, the first element is the query, the second is the extra context
+        # in case there's an ask_user tool call.
+        query = query_with_extra_context[0] if isinstance(query_with_extra_context, tuple) else query_with_extra_context
+        # Initial state for the graph
+        initial_state = AssistantState(
+            messages=[HumanMessage(content=f"Answer this question: {query}")],
+        )
+        conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])
+
+        # Invoke the graph. The state will be updated through planner and then generator.
+        final_state_raw = await graph.ainvoke(initial_state, {"configurable": {"thread_id": conversation.id}})
+        final_state = AssistantState.model_validate(final_state_raw)
+
+        # If we have extra context for the potential ask_user tool, and there's no message of type ai/failure
+        # or ai/visualization, we should answer with that extra context. We only do this once at most in an eval case.
+        if isinstance(query_with_extra_context, tuple) and not any(
+            isinstance(m, VisualizationMessage | FailureMessage) for m in final_state.messages
+        ):
+            final_state.messages = [*final_state.messages, HumanMessage(content=query_with_extra_context[1])]
+            final_state.graph_status = "resumed"
+            final_state_raw = await graph.ainvoke(final_state, {"configurable": {"thread_id": conversation.id}})
+            final_state = AssistantState.model_validate(final_state_raw)
+
+        if not final_state.messages or not isinstance(final_state.messages[-1], VisualizationMessage):
+            return {
+                "plan": None,
+                "query": None,
+                "query_generation_retry_count": final_state.query_generation_retry_count,
+            }
+        return {
+            "plan": final_state.messages[-1].plan,
+            "query": final_state.messages[-1].answer,
+            "query_generation_retry_count": final_state.query_generation_retry_count,
+        }
+
+    return callable
+
+
+@pytest.fixture(scope="package")
+def demo_org_team_user(set_up_evals, django_db_blocker) -> Generator[tuple[Organization, Team, User], None, None]:  # noqa: F811
+    with django_db_blocker.unblock():
+        team: Team | None = Team.objects.order_by("-created_at").first()
+        today = datetime.date.today()
+        # If there's no eval team or it's older than today, we need to create a new one with fresh data
+        if not team or team.created_at.date() < today:
+            print(f"Generating fresh demo data for evals...")  # noqa: T201
+
+            matrix = HedgeboxMatrix(
+                seed="b1ef3c66-5f43-488a-98be-6b46d92fbcef",  # this seed generates all events
+                days_past=120,
+                days_future=30,
+                n_clusters=500,
+                group_type_index_offset=0,
+            )
+            matrix_manager = MatrixManager(matrix, print_steps=True)
+            with override_settings(TEST=False):
+                # Simulation saving should occur in non-test mode, so that Kafka isn't mocked. Normally in tests we don't
+                # want to ingest via Kafka, but simulation saving is specifically designed to use that route for speed
+                org, team, user = matrix_manager.ensure_account_and_save(
+                    f"eval-{today.isoformat()}", EVAL_USER_FULL_NAME, "Hedgebox Inc."
+                )
+        else:
+            print(f"Using existing demo data for evals...")  # noqa: T201
+            org = team.organization
+            membership = org.memberships.first()
+            assert membership is not None
+            user = membership.user
+
+    yield org, team, user
+
+
+@pytest.fixture(scope="package", autouse=True)
+def core_memory(demo_org_team_user, django_db_blocker) -> Generator[CoreMemory, None, None]:
+    initial_memory = """Hedgebox is a cloud storage service enabling users to store, share, and access files across devices.
+
+The company operates in the cloud storage and collaboration market for individuals and businesses.
+
+Their audience includes professionals and organizations seeking file management and collaboration solutions.
+
+Hedgebox's freemium model provides free accounts with limited storage and paid subscription plans for additional features.
+
+Core features include file storage, synchronization, sharing, and collaboration tools for seamless file access and sharing.
+
+It integrates with third-party applications to enhance functionality and streamline workflows.
+
+Hedgebox sponsors the YouTube channel Marius Tech Tips."""
+
+    with django_db_blocker.unblock():
+        core_memory, _ = CoreMemory.objects.get_or_create(
+            team=demo_org_team_user[1],
+            defaults={
+                "text": initial_memory,
+                "initial_text": initial_memory,
+                "scraping_status": CoreMemory.ScrapingStatus.COMPLETED,
+            },
+        )
+    yield core_memory
@@ -15,13 +15,13 @@ from posthog.schema import (
|
||||
|
||||
from ee.hogai.graph.funnels.toolkit import FUNNEL_SCHEMA
|
||||
|
||||
from .conftest import MaxEval
|
||||
from .scorers import PlanAndQueryOutput, PlanCorrectness, QueryAndPlanAlignment, QueryKindSelection, TimeRangeRelevancy
|
||||
from ..base import MaxPublicEval
|
||||
from ..scorers import PlanAndQueryOutput, PlanCorrectness, QueryAndPlanAlignment, QueryKindSelection, TimeRangeRelevancy
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
async def eval_funnel(call_root_for_insight_generation, pytestconfig):
|
||||
await MaxEval(
|
||||
await MaxPublicEval(
|
||||
experiment_name="funnel",
|
||||
task=call_root_for_insight_generation,
|
||||
scores=[
|
||||
@@ -11,8 +11,8 @@ from ee.hogai.graph import AssistantGraph
|
||||
from ee.hogai.utils.types import AssistantNodeName, AssistantState
|
||||
from ee.models.assistant import Conversation
|
||||
|
||||
from .conftest import MaxEval
|
||||
from .scorers import InsightEvaluationAccuracy, InsightSearchOutput
|
||||
from ..base import MaxPublicEval
|
||||
from ..scorers import InsightEvaluationAccuracy, InsightSearchOutput
|
||||
|
||||
|
||||
def extract_evaluation_info_from_state(state) -> dict:
|
||||
@@ -144,7 +144,7 @@ def call_insight_search(demo_org_team_user):
|
||||
@pytest.mark.django_db
|
||||
async def eval_insight_evaluation_accuracy(call_insight_search, pytestconfig):
|
||||
"""Evaluate the accuracy of the insight evaluation decision."""
|
||||
await MaxEval(
|
||||
await MaxPublicEval(
|
||||
experiment_name="insight_evaluation_accuracy",
|
||||
task=call_insight_search,
|
||||
scores=[InsightEvaluationAccuracy()],
|
||||
@@ -4,6 +4,7 @@ import pytest
|
||||
|
||||
from autoevals.llm import LLMClassifier
|
||||
from braintrust import EvalCase, Score
|
||||
from langchain_core.messages import AIMessage as LangchainAIMessage
|
||||
|
||||
from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage
|
||||
|
||||
@@ -12,8 +13,8 @@ from ee.hogai.graph import AssistantGraph
|
||||
from ee.hogai.utils.types import AssistantNodeName, AssistantState
|
||||
from ee.models.assistant import Conversation
|
||||
|
||||
from .conftest import MaxEval
|
||||
from .scorers import ToolRelevance
|
||||
from ..base import MaxPublicEval
|
||||
from ..scorers import ToolRelevance
|
||||
|
||||
|
||||
class MemoryContentRelevance(LLMClassifier):
|
||||
@@ -83,9 +84,12 @@ def call_node(demo_org_team_user, core_memory):
|
||||
state = AssistantState.model_validate(raw_state)
|
||||
if not state.memory_collection_messages:
|
||||
return None
|
||||
last_message = state.memory_collection_messages[-1]
|
||||
if not isinstance(last_message, LangchainAIMessage):
|
||||
return None
|
||||
return AssistantMessage(
|
||||
content=state.memory_collection_messages[-1].content,
|
||||
tool_calls=state.memory_collection_messages[-1].tool_calls,
|
||||
content=last_message.content,
|
||||
tool_calls=last_message.tool_calls,
|
||||
)
|
||||
|
||||
return callable
|
||||
@@ -93,7 +97,7 @@ def call_node(demo_org_team_user, core_memory):
|
||||
|
||||
@pytest.mark.django_db
|
||||
async def eval_memory(call_node, pytestconfig):
|
||||
await MaxEval(
|
||||
await MaxPublicEval(
|
||||
experiment_name="memory",
|
||||
task=call_node,
|
||||
scores=[ToolRelevance(semantic_similarity_args={"memory_content", "new_fragment"}), MemoryContentRelevance()],
|
||||
@@ -6,13 +6,13 @@ from posthog.schema import AssistantRetentionEventsNode, AssistantRetentionFilte
|
||||
|
||||
from ee.hogai.graph.retention.toolkit import RETENTION_SCHEMA
|
||||
|
||||
from .conftest import MaxEval
|
||||
from .scorers import PlanAndQueryOutput, PlanCorrectness, QueryAndPlanAlignment, QueryKindSelection, TimeRangeRelevancy
|
||||
from ..base import MaxPublicEval
|
||||
from ..scorers import PlanAndQueryOutput, PlanCorrectness, QueryAndPlanAlignment, QueryKindSelection, TimeRangeRelevancy
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
async def eval_retention(call_root_for_insight_generation, pytestconfig):
|
||||
await MaxEval(
|
||||
await MaxPublicEval(
|
||||
experiment_name="retention",
|
||||
task=call_root_for_insight_generation,
|
||||
scores=[
|
||||
@@ -11,8 +11,8 @@ from ee.hogai.graph import AssistantGraph
|
||||
from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState
|
||||
from ee.models.assistant import Conversation
|
||||
|
||||
from .conftest import MaxEval
|
||||
from .scorers import ToolRelevance
|
||||
from ..base import MaxPublicEval
|
||||
from ..scorers import ToolRelevance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -47,7 +47,7 @@ def call_root(demo_org_team_user):
|
||||
|
||||
@pytest.mark.django_db
|
||||
async def eval_root(call_root, pytestconfig):
|
||||
await MaxEval(
|
||||
await MaxPublicEval(
|
||||
experiment_name="root",
|
||||
task=call_root,
|
||||
scores=[ToolRelevance(semantic_similarity_args={"query_description"})],
|
||||
@@ -9,7 +9,8 @@ from ee.hogai.graph import AssistantGraph
|
||||
from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState
|
||||
from ee.models.assistant import Conversation
|
||||
|
||||
from .conftest import EVAL_USER_FULL_NAME, MaxEval
|
||||
from ..base import MaxPublicEval
|
||||
from .conftest import EVAL_USER_FULL_NAME
|
||||
|
||||
|
||||
class StyleChecker(LLMClassifier):
|
||||
@@ -90,7 +91,7 @@ def call_root(demo_org_team_user):
|
||||
|
||||
@pytest.mark.django_db
|
||||
async def eval_root_style(call_root, pytestconfig):
|
||||
await MaxEval(
|
||||
await MaxPublicEval(
|
||||
experiment_name="root_style",
|
||||
task=call_root,
|
||||
scores=[StyleChecker(user_name=EVAL_USER_FULL_NAME)],
|
||||
@@ -6,16 +6,13 @@ from braintrust_core.score import Scorer

 from posthog.schema import AssistantHogQLQuery, NodeKind

-from posthog.hogql.errors import BaseHogQLError
-
-from posthog.errors import InternalCHQueryError
-from posthog.hogql_queries.hogql_query_runner import HogQLQueryRunner
-from posthog.models.team.team import Team
+from posthog.models import Team

+from ee.hogai.eval.scorers.sql import evaluate_sql_query
 from ee.hogai.graph.sql.toolkit import SQL_SCHEMA

-from .conftest import MaxEval
-from .scorers import PlanAndQueryOutput, PlanCorrectness, QueryAndPlanAlignment, QueryKindSelection, TimeRangeRelevancy
+from ..base import MaxPublicEval
+from ..scorers import PlanAndQueryOutput, PlanCorrectness, QueryAndPlanAlignment, QueryKindSelection, TimeRangeRelevancy

 QUERY_GENERATION_MAX_RETRIES = 3

@@ -43,57 +40,23 @@ class RetryEfficiency(Scorer):
         return Score(name=self._name(), score=score, metadata={"query_generation_retry_count": retry_count})


-class SQLSyntaxCorrectness(Scorer):
-    """Evaluate if the generated SQL query has correct syntax."""
-
+class HogQLQuerySyntaxCorrectness(Scorer):
     def _name(self):
         return "sql_syntax_correctness"

-    async def _run_eval_async(self, output, expected=None, **kwargs):
-        if not output:
-            return Score(
-                name=self._name(), score=None, metadata={"reason": "No SQL query to verify, skipping evaluation"}
-            )
-        query = {"query": output}
-        team = await Team.objects.alatest("created_at")
-        try:
-            # Try to parse, print, and run the query
-            await sync_to_async(HogQLQueryRunner(query, team).calculate)()
-        except BaseHogQLError as e:
-            return Score(name=self._name(), score=0.0, metadata={"reason": f"HogQL-level error: {str(e)}"})
-        except InternalCHQueryError as e:
-            return Score(name=self._name(), score=0.5, metadata={"reason": f"ClickHouse-level error: {str(e)}"})
-        else:
-            return Score(name=self._name(), score=1.0)
+    async def _run_eval_async(self, output: PlanAndQueryOutput, *args, **kwargs):
+        return await sync_to_async(self._evaluate)(output)

-    def _run_eval_sync(self, output, expected=None, **kwargs):
-        if not output:
-            return Score(
-                name=self._name(), score=None, metadata={"reason": "No SQL query to verify, skipping evaluation"}
-            )
-        query = {"query": output}
+    def _run_eval_sync(self, output: PlanAndQueryOutput, *args, **kwargs):
+        return self._evaluate(output)

+    def _evaluate(self, output: PlanAndQueryOutput) -> Score:
         team = Team.objects.latest("created_at")
-        try:
-            # Try to parse, print, and run the query
-            HogQLQueryRunner(query, team).calculate()
-        except BaseHogQLError as e:
-            return Score(name=self._name(), score=0.0, metadata={"reason": f"HogQL-level error: {str(e)}"})
-        except InternalCHQueryError as e:
-            return Score(name=self._name(), score=0.5, metadata={"reason": f"ClickHouse-level error: {str(e)}"})
+        if isinstance(output["query"], AssistantHogQLQuery):
+            query = output["query"].query
         else:
-            return Score(name=self._name(), score=1.0)
-
-
-class HogQLQuerySyntaxCorrectness(SQLSyntaxCorrectness):
-    async def _run_eval_async(self, output, expected=None, **kwargs):
-        return await super()._run_eval_async(
-            output["query"].query if output and output.get("query") else None, expected, **kwargs
-        )
-
-    def _run_eval_sync(self, output, expected=None, **kwargs):
-        return super()._run_eval_sync(
-            output["query"].query if output and output.get("query") else None, expected, **kwargs
-        )
+            query = None
+        return evaluate_sql_query(self._name(), query, team)


 @pytest.mark.django_db
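The syntax check now delegates to a shared `evaluate_sql_query` helper in `ee/hogai/eval/scorers/sql.py`, which this diff imports but does not show. A minimal sketch of that helper, assuming it mirrors the inline logic removed above (HogQL-level failures score 0.0, ClickHouse-level failures 0.5, success 1.0); the real implementation may differ:

# Hypothetical sketch of evaluate_sql_query, assuming it mirrors the inline
# logic it replaces; not the actual ee/hogai/eval/scorers/sql.py source.
from braintrust import Score

from posthog.errors import InternalCHQueryError
from posthog.hogql.errors import BaseHogQLError
from posthog.hogql_queries.hogql_query_runner import HogQLQueryRunner
from posthog.models import Team


def evaluate_sql_query(name: str, query: str | None, team: Team) -> Score:
    if not query:
        return Score(name=name, score=None, metadata={"reason": "No SQL query to verify, skipping evaluation"})
    try:
        # Parse, print, and run the query, surfacing both HogQL and ClickHouse failures
        HogQLQueryRunner({"query": query}, team).calculate()
    except BaseHogQLError as e:
        return Score(name=name, score=0.0, metadata={"reason": f"HogQL-level error: {e}"})
    except InternalCHQueryError as e:
        return Score(name=name, score=0.5, metadata={"reason": f"ClickHouse-level error: {e}"})
    return Score(name=name, score=1.0)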
@@ -697,7 +660,7 @@ ORDER BY ABS(corr(toFloat(uploads_30d), toFloat(churned))) DESC
         ),
     ]

-    await MaxEval(
+    await MaxPublicEval(
         experiment_name="sql",
         task=call_root_for_insight_generation,
         scores=[
@@ -13,7 +13,7 @@ from products.surveys.backend.max_tools import CreateSurveyTool, FeatureFlagLook
 from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
 from ee.models.assistant import Conversation

-from .conftest import MaxEval
+from ..base import MaxPublicEval


 def validate_survey_output(output, scorer_name):
@@ -630,7 +630,7 @@ async def eval_surveys(call_surveys_max_tool, pytestconfig):
     """
     Evaluation for survey creation functionality.
     """
-    await MaxEval(
+    await MaxPublicEval(
         experiment_name="surveys",
         task=call_surveys_max_tool,
         scores=[
@@ -16,8 +16,8 @@ from posthog.schema import (

 from ee.hogai.graph.trends.toolkit import TRENDS_SCHEMA

-from .conftest import MaxEval
-from .scorers import PlanAndQueryOutput, PlanCorrectness, QueryAndPlanAlignment, QueryKindSelection, TimeRangeRelevancy
+from ..base import MaxPublicEval
+from ..scorers import PlanAndQueryOutput, PlanCorrectness, QueryAndPlanAlignment, QueryKindSelection, TimeRangeRelevancy

 TRENDS_CASES = [
     EvalCase(
@@ -408,7 +408,7 @@ Query kind:

 @pytest.mark.django_db
 async def eval_trends(call_root_for_insight_generation, pytestconfig):
-    await MaxEval(
+    await MaxPublicEval(
         experiment_name="trends",
         task=call_root_for_insight_generation,
         scores=[
@@ -19,8 +19,8 @@ from ee.hogai.graph import AssistantGraph
 from ee.hogai.utils.types import AssistantNodeName, AssistantState
 from ee.models.assistant import Conversation

-from .conftest import MaxEval
-from .scorers import ToolRelevance
+from ..base import MaxPublicEval
+from ..scorers import ToolRelevance


 @pytest.fixture
@@ -71,7 +71,7 @@ def sample_action(demo_org_team_user):
 @pytest.mark.django_db
 async def eval_ui_context_actions(call_root_with_ui_context, sample_action, pytestconfig):
     """Test that actions in UI context are properly used in RAG context retrieval"""
-    await MaxEval(
+    await MaxPublicEval(
         experiment_name="ui_context_actions",
         task=call_root_with_ui_context,
         scores=[
@@ -136,7 +136,7 @@ async def eval_ui_context_actions(call_root_with_ui_context, sample_action, pyte
 @pytest.mark.django_db
 async def eval_ui_context_events(call_root_with_ui_context, pytestconfig):
     """Test that events in UI context are properly used in taxonomy agent"""
-    await MaxEval(
+    await MaxPublicEval(
         experiment_name="ui_context_events",
         task=call_root_with_ui_context,
         scores=[
ee/hogai/eval/ci/max_tools/__init__.py (new file, 0 lines)
@@ -14,9 +14,8 @@ from posthog.sync import database_sync_to_async

 from products.data_warehouse.backend.max_tools import HogQLGeneratorArgs, HogQLGeneratorTool

-from ee.hogai.eval.conftest import MaxEval
-from ee.hogai.eval.eval_sql import SQLSyntaxCorrectness
-from ee.hogai.eval.scorers import SQLSemanticsCorrectness
+from ee.hogai.eval.base import MaxPublicEval
+from ee.hogai.eval.scorers import SQLSemanticsCorrectness, SQLSyntaxCorrectness
 from ee.hogai.utils.markdown import remove_markdown
 from ee.hogai.utils.types import AssistantState
 from ee.hogai.utils.warehouse import serialize_database_schema
@@ -73,7 +72,8 @@ async def database_schema(demo_org_team_user):
     return await serialize_database_schema(database, context)


-async def sql_semantics_scorer(input: EvalInput, expected: str, output: str, metadata: dict) -> Score:
+async def sql_semantics_scorer(input: EvalInput, expected: str, output: str, **kwargs) -> Score:
+    metadata: dict = kwargs["metadata"]
     metric = SQLSemanticsCorrectness()
     return await metric.eval_async(
         input=input.instructions, expected=expected, output=output, database_schema=metadata["schema"]
@@ -84,7 +84,7 @@ async def sql_semantics_scorer(input: EvalInput, expected: str, output: str, met
 async def eval_tool_generate_hogql_query(call_generate_hogql_query, database_schema, pytestconfig):
     metadata = {"schema": database_schema}

-    await MaxEval(
+    await MaxPublicEval(
         experiment_name="tool_generate_hogql_query",
         task=call_generate_hogql_query,
         scores=[SQLSyntaxCorrectness(), sql_semantics_scorer],
@@ -24,9 +24,10 @@ from products.replay.backend.max_tools import SessionReplayFilterOptionsGraph
 from products.replay.backend.prompts import USER_FILTER_OPTIONS_PROMPT

 from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
-from ee.hogai.eval.conftest import MaxEval
 from ee.models.assistant import Conversation

+from ...base import MaxPublicEval
+
 logger = logging.getLogger(__name__)

 DUMMY_CURRENT_FILTERS = MaxRecordingUniversalFilters(
@@ -178,7 +179,7 @@ class DateTimeFilteringCorrectness(Scorer):

 @pytest.mark.django_db
 async def eval_tool_search_session_recordings(call_search_session_recordings, pytestconfig):
-    await MaxEval(
+    await MaxPublicEval(
         experiment_name="tool_search_session_recordings",
         task=call_search_session_recordings,
         scores=[FilterGenerationCorrectness(), DateTimeFilteringCorrectness()],
@@ -586,7 +587,7 @@ async def eval_tool_search_session_recordings_ask_user_for_help(call_search_sess

 @pytest.mark.django_db
 async def eval_tool_search_session_recordings_ask_user_for_help(call_search_session_recordings, pytestconfig):
-    await MaxEval(
+    await MaxPublicEval(
         experiment_name="tool_search_session_recordings_ask_user_for_help",
         task=call_search_session_recordings,
         scores=[AskUserForHelp()],
ee/hogai/eval/ci/max_tools/eval_tool_filter_generation.py (new file, 598 lines)
@@ -0,0 +1,598 @@
import logging

import pytest

from braintrust import EvalCase, Score
from braintrust_core.score import Scorer
from deepdiff import DeepDiff

from posthog.schema import (
    DurationType,
    EventPropertyFilter,
    FilterLogicalOperator,
    MaxInnerUniversalFiltersGroup,
    MaxOuterUniversalFiltersGroup,
    MaxRecordingUniversalFilters,
    PersonPropertyFilter,
    PropertyOperator,
    RecordingDurationFilter,
    RecordingOrder,
    RecordingOrderDirection,
)

from products.replay.backend.max_tools import SessionReplayFilterOptionsGraph
from products.replay.backend.prompts import USER_FILTER_OPTIONS_PROMPT

from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.models.assistant import Conversation

from ...base import MaxPublicEval

logger = logging.getLogger(__name__)

DUMMY_CURRENT_FILTERS = MaxRecordingUniversalFilters(
    date_from="-7d",
    date_to=None,
    duration=[
        RecordingDurationFilter(key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0)
    ],
    filter_group=MaxOuterUniversalFiltersGroup(
        type=FilterLogicalOperator.AND_,
        values=[MaxInnerUniversalFiltersGroup(type=FilterLogicalOperator.AND_, values=[])],
    ),
    filter_test_accounts=True,
    order=RecordingOrder.START_TIME,
    order_direction=RecordingOrderDirection.DESC,
)


@pytest.fixture
def call_search_session_recordings(demo_org_team_user):
    graph = SessionReplayFilterOptionsGraph(demo_org_team_user[1], demo_org_team_user[2]).compile_full_graph(
        checkpointer=DjangoCheckpointer()
    )

    async def callable(change: str) -> dict:
        conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])

        # Convert filters to JSON string and use test-specific prompt
        filters_json = DUMMY_CURRENT_FILTERS.model_dump_json(indent=2)

        graph_input = {
            "change": USER_FILTER_OPTIONS_PROMPT.format(change=change, current_filters=filters_json),
            "output": None,
        }

        result = await graph.ainvoke(graph_input, config={"configurable": {"thread_id": conversation.id}})
        return result

    return callable


class FilterGenerationCorrectness(Scorer):
    """Score the correctness of generated filters."""

    def _name(self):
        return "filter_generation_correctness"

    async def _run_eval_async(self, output, expected=None, **kwargs):
        return self._run_eval_sync(output, expected, **kwargs)

    def _run_eval_sync(self, output, expected=None, **kwargs):
        try:
            actual_filters = MaxRecordingUniversalFilters.model_validate(output["output"])
        except Exception as e:
            logger.exception(f"Error parsing filters: {e}")
            return Score(name=self._name(), score=0.0, metadata={"reason": "LLM returned invalid filter structure"})

        # Convert both objects to dict for deepdiff comparison
        actual_dict = actual_filters.model_dump()
        expected_dict = expected.model_dump()

        # Use deepdiff to find differences
        diff = DeepDiff(expected_dict, actual_dict, ignore_order=True, report_repetition=True)

        if not diff:
            return Score(name=self._name(), score=1.0, metadata={"reason": "Perfect match"})

        # Calculate score based on number of differences
        total_fields = len(expected_dict.keys())
        changed_fields = (
            len(diff.get("values_changed", {}))
            + len(diff.get("dictionary_item_added", set()))
            + len(diff.get("dictionary_item_removed", set()))
        )

        score = max(0.0, (total_fields - changed_fields) / total_fields)

        return Score(
            name=self._name(),
            score=score,
            metadata={
                "differences": str(diff),
                "total_fields": total_fields,
                "changed_fields": changed_fields,
                "reason": f"Found {changed_fields} differences out of {total_fields} fields",
            },
        )


class AskUserForHelp(Scorer):
    """Score the correctness of the ask_user_for_help tool."""

    def _name(self):
        return "ask_user_for_help_scorer"

    def _run_eval_sync(self, output, expected=None, **kwargs):
        if "output" not in output or output["output"] is None:
            if (
                "intermediate_steps" in output
                and len(output["intermediate_steps"]) > 0
                and output["intermediate_steps"][-1][0].tool == "ask_user_for_help"
            ):
                return Score(
                    name=self._name(), score=1, metadata={"reason": "LLM returned valid ask_user_for_help response"}
                )
            else:
                return Score(
                    name=self._name(),
                    score=0,
                    metadata={"reason": "LLM did not return valid ask_user_for_help response"},
                )
        else:
            return Score(name=self._name(), score=0.0, metadata={"reason": "LLM returned a filter"})


class DateTimeFilteringCorrectness(Scorer):
    """Score the correctness of the date time filtering."""

    def _name(self):
        return "date_time_filtering_correctness"

    async def _run_eval_async(self, output, expected=None, **kwargs):
        return self._run_eval_sync(output, expected, **kwargs)

    def _run_eval_sync(self, output, expected=None, **kwargs):
        try:
            actual_filters = MaxRecordingUniversalFilters.model_validate(output["output"])
        except Exception as e:
            logger.exception(f"Error parsing filters: {e}")
            return Score(name=self._name(), score=0.0, metadata={"reason": "LLM returned invalid filter structure"})

        if actual_filters.date_from == expected.date_from and actual_filters.date_to == expected.date_to:
            return Score(name=self._name(), score=1.0, metadata={"reason": "LLM returned valid date time filters"})
        elif actual_filters.date_from == expected.date_from:
            return Score(
                name=self._name(),
                score=0.5,
                metadata={"reason": "LLM returned valid date_from but did not return valid date_to"},
            )
        elif actual_filters.date_to == expected.date_to:
            return Score(
                name=self._name(),
                score=0.5,
                metadata={"reason": "LLM returned valid date_to but did not return valid date_from"},
            )
        else:
            return Score(name=self._name(), score=0.0, metadata={"reason": "LLM returned invalid date time filters"})


@pytest.mark.django_db
async def eval_tool_search_session_recordings(call_search_session_recordings, pytestconfig):
    await MaxPublicEval(
        experiment_name="tool_search_session_recordings",
        task=call_search_session_recordings,
        scores=[FilterGenerationCorrectness(), DateTimeFilteringCorrectness()],
        data=[
            # Test basic filter generation for mobile devices
            EvalCase(
                input="show me recordings of users that were using a mobile device (use events)",
                expected=MaxRecordingUniversalFilters(
                    date_from="-7d",
                    date_to=None,
                    duration=[
                        RecordingDurationFilter(
                            key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
                        )
                    ],
                    filter_group=MaxOuterUniversalFiltersGroup(
                        type=FilterLogicalOperator.AND_,
                        values=[
                            MaxInnerUniversalFiltersGroup(
                                type=FilterLogicalOperator.AND_,
                                values=[
                                    EventPropertyFilter(
                                        key="$device_type",
                                        type="event",
                                        value=["Mobile"],
                                        operator=PropertyOperator.EXACT,
                                    )
                                ],
                            )
                        ],
                    ),
                    filter_test_accounts=True,
                    order=RecordingOrder.START_TIME,
                ),
            ),
            # Test date range filtering
            EvalCase(
                input="Show recordings from the last 2 hours",
                expected=MaxRecordingUniversalFilters(
                    date_from="-2h",
                    date_to=None,
                    duration=[
                        RecordingDurationFilter(
                            key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
                        )
                    ],
                    filter_group=MaxOuterUniversalFiltersGroup(
                        type=FilterLogicalOperator.AND_,
                        values=[MaxInnerUniversalFiltersGroup(type=FilterLogicalOperator.AND_, values=[])],
                    ),
                    filter_test_accounts=True,
                    order=RecordingOrder.START_TIME,
                ),
            ),
            # Test location filtering
            EvalCase(
                input="Show recordings for users located in the US",
                expected=MaxRecordingUniversalFilters(
                    date_from="-7d",
                    date_to=None,
                    duration=[
                        RecordingDurationFilter(
                            key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
                        )
                    ],
                    filter_group=MaxOuterUniversalFiltersGroup(
                        type=FilterLogicalOperator.AND_,
                        values=[
                            MaxInnerUniversalFiltersGroup(
                                type=FilterLogicalOperator.AND_,
                                values=[
                                    PersonPropertyFilter(
                                        key="$geoip_country_code",
                                        type="person",
                                        value=["US"],
                                        operator=PropertyOperator.EXACT,
                                    )
                                ],
                            )
                        ],
                    ),
                    filter_test_accounts=True,
                    order=RecordingOrder.START_TIME,
                ),
            ),
            # Test browser-specific filtering
            EvalCase(
                input="Show recordings from users that were using a browser in English",
                expected=MaxRecordingUniversalFilters(
                    date_from="-7d",
                    date_to=None,
                    duration=[
                        RecordingDurationFilter(
                            key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
                        )
                    ],
                    filter_group=MaxOuterUniversalFiltersGroup(
                        type=FilterLogicalOperator.AND_,
                        values=[
                            MaxInnerUniversalFiltersGroup(
                                type=FilterLogicalOperator.AND_,
                                values=[
                                    PersonPropertyFilter(
                                        key="$browser_language",
                                        type="person",
                                        value=["EN-en"],
                                        operator=PropertyOperator.EXACT,
                                    )
                                ],
                            )
                        ],
                    ),
                    filter_test_accounts=True,
                    order=RecordingOrder.START_TIME,
                ),
            ),
            EvalCase(
                input="Show me recordings from chrome browsers",
                expected=MaxRecordingUniversalFilters(
                    date_from="-7d",
                    date_to=None,
                    duration=[
                        RecordingDurationFilter(
                            key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
                        )
                    ],
                    filter_group=MaxOuterUniversalFiltersGroup(
                        type=FilterLogicalOperator.AND_,
                        values=[
                            MaxInnerUniversalFiltersGroup(
                                type=FilterLogicalOperator.AND_,
                                values=[
                                    EventPropertyFilter(
                                        key="$browser",
                                        type="event",
                                        value=["Chrome"],
                                        operator=PropertyOperator.EXACT,
                                    )
                                ],
                            )
                        ],
                    ),
                    filter_test_accounts=True,
                    order=RecordingOrder.START_TIME,
                ),
            ),
            # Test user behavior filtering
            EvalCase(
                input="Show recordings where users visited the posthog.com/checkout_page",
                expected=MaxRecordingUniversalFilters(
                    date_from="-7d",
                    date_to=None,
                    duration=[
                        RecordingDurationFilter(
                            key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
                        )
                    ],
                    filter_group=MaxOuterUniversalFiltersGroup(
                        type=FilterLogicalOperator.AND_,
                        values=[
                            MaxInnerUniversalFiltersGroup(
                                type=FilterLogicalOperator.AND_,
                                values=[
                                    EventPropertyFilter(
                                        key="$current_url",
                                        type="event",
                                        value=["posthog.com/checkout_page"],
                                        operator=PropertyOperator.ICONTAINS,
                                    )
                                ],
                            )
                        ],
                    ),
                    filter_test_accounts=True,
                    order=RecordingOrder.START_TIME,
                ),
            ),
            # Test session duration filtering
            EvalCase(
                input="Show recordings longer than 5 minutes",
                expected=MaxRecordingUniversalFilters(
                    date_from="-7d",
                    date_to=None,
                    duration=[
                        RecordingDurationFilter(
                            key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=300.0
                        )
                    ],
                    filter_group=MaxOuterUniversalFiltersGroup(
                        type=FilterLogicalOperator.AND_,
                        values=[MaxInnerUniversalFiltersGroup(type=FilterLogicalOperator.AND_, values=[])],
                    ),
                    filter_test_accounts=True,
                    order=RecordingOrder.START_TIME,
                ),
            ),
            # Test user action
            EvalCase(
                input="Show recordings from users that performed a billing action",
                expected=MaxRecordingUniversalFilters(
                    date_from="-7d",
                    date_to=None,
                    duration=[
                        RecordingDurationFilter(
                            key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
                        )
                    ],
                    filter_group=MaxOuterUniversalFiltersGroup(
                        type=FilterLogicalOperator.AND_,
                        values=[
                            MaxInnerUniversalFiltersGroup(
                                type=FilterLogicalOperator.AND_,
                                values=[
                                    EventPropertyFilter(
                                        key="paid_bill",
                                        type="event",
                                        value=None,
                                        operator=PropertyOperator.IS_SET,
                                    )
                                ],
                            )
                        ],
                    ),
                    filter_test_accounts=True,
                    order=RecordingOrder.START_TIME,
                ),
            ),
            # Test page-specific filtering
            EvalCase(
                input="Show recordings from users who visited the pricing page",
                expected=MaxRecordingUniversalFilters(
                    date_from="-7d",
                    date_to=None,
                    duration=[
                        RecordingDurationFilter(
                            key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
                        )
                    ],
                    filter_group=MaxOuterUniversalFiltersGroup(
                        type=FilterLogicalOperator.AND_,
                        values=[
                            MaxInnerUniversalFiltersGroup(
                                type=FilterLogicalOperator.AND_,
                                values=[
                                    EventPropertyFilter(
                                        key="$pathname",
                                        type="event",
                                        value=["/pricing/"],
                                        operator=PropertyOperator.ICONTAINS,
                                    )
                                ],
                            )
                        ],
                    ),
                    filter_test_accounts=True,
                    order=RecordingOrder.START_TIME,
                ),
            ),
            # Test conversion funnel filtering
            EvalCase(
                input="Show recordings from users who completed the signup flow",
                expected=MaxRecordingUniversalFilters(
                    date_from="-7d",
                    date_to=None,
                    duration=[
                        RecordingDurationFilter(
                            key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
                        )
                    ],
                    filter_group=MaxOuterUniversalFiltersGroup(
                        type=FilterLogicalOperator.AND_,
                        values=[
                            MaxInnerUniversalFiltersGroup(
                                type=FilterLogicalOperator.AND_,
                                values=[
                                    EventPropertyFilter(
                                        key="signup_completed",
                                        type="event",
                                        value=None,
                                        operator=PropertyOperator.IS_SET,
                                    )
                                ],
                            )
                        ],
                    ),
                    filter_test_accounts=True,
                    order=RecordingOrder.START_TIME,
                ),
            ),
            # Test device and browser combination
            EvalCase(
                input="Show recordings from mobile Safari users",
                expected=MaxRecordingUniversalFilters(
                    date_from="-7d",
                    date_to=None,
                    duration=[
                        RecordingDurationFilter(
                            key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
                        )
                    ],
                    filter_group=MaxOuterUniversalFiltersGroup(
                        type=FilterLogicalOperator.AND_,
                        values=[
                            MaxInnerUniversalFiltersGroup(
                                type=FilterLogicalOperator.AND_,
                                values=[
                                    EventPropertyFilter(
                                        key="$device_type",
                                        type="event",
                                        value=["Mobile"],
                                        operator=PropertyOperator.EXACT,
                                    ),
                                    EventPropertyFilter(
                                        key="$browser",
                                        type="event",
                                        value=["Safari"],
                                        operator=PropertyOperator.EXACT,
                                    ),
                                ],
                            )
                        ],
                    ),
                    filter_test_accounts=True,
                    order=RecordingOrder.START_TIME,
                ),
            ),
            # Test time-based filtering
            EvalCase(
                input="Show recordings from yesterday",
                expected=DUMMY_CURRENT_FILTERS.model_copy(update={"date_from": "-1d", "date_to": "-1d"}),
            ),
            EvalCase(
                input="Show me recordings since the 1st of August",
                expected=DUMMY_CURRENT_FILTERS.model_copy(update={"date_from": "2025-08-01T00:00:00:000"}),
            ),
            EvalCase(
                input="Show me recordings until the 31st of August",
                expected=DUMMY_CURRENT_FILTERS.model_copy(update={"date_to": "2025-08-31T23:59:59:999"}),
            ),
            EvalCase(
                input="Show me recordings from the 1st of September to the 31st of September",
                expected=DUMMY_CURRENT_FILTERS.model_copy(
                    update={"date_from": "2025-09-01T00:00:00:000", "date_to": "2025-09-30T23:59:59:999"}
                ),
            ),
            EvalCase(
                input="Show me recordings from the 1st of September to the 1st of September",
                expected=DUMMY_CURRENT_FILTERS.model_copy(
                    update={"date_from": "2025-09-01T00:00:00:000", "date_to": "2025-09-01T23:59:59:999"}
                ),
            ),
            EvalCase(
                input="show me recordings of users who signed up on mobile",
                expected=MaxRecordingUniversalFilters(
                    date_from="-7d",
                    date_to=None,
                    duration=[
                        RecordingDurationFilter(
                            key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
                        )
                    ],
                    filter_group=MaxOuterUniversalFiltersGroup(
                        type=FilterLogicalOperator.AND_,
                        values=[
                            MaxInnerUniversalFiltersGroup(
                                type=FilterLogicalOperator.AND_,
                                values=[
                                    EventPropertyFilter(
                                        key="$device_type",
                                        type="event",
                                        value=["Mobile"],
                                        operator=PropertyOperator.EXACT,
                                    )
                                ],
                            )
                        ],
                    ),
                    filter_test_accounts=True,
                    order=RecordingOrder.START_TIME,
                ),
            ),
            EvalCase(
                input="Show recordings in an ascending order by duration",
                expected=MaxRecordingUniversalFilters(
                    date_from="-7d",
                    date_to=None,
                    duration=[
                        RecordingDurationFilter(
                            key=DurationType.DURATION, operator=PropertyOperator.GT, type="recording", value=60.0
                        )
                    ],
                    filter_group=MaxOuterUniversalFiltersGroup(
                        type=FilterLogicalOperator.AND_,
                        values=[MaxInnerUniversalFiltersGroup(type=FilterLogicalOperator.AND_, values=[])],
                    ),
                    filter_test_accounts=True,
                    order=RecordingOrder.DURATION,
                    order_direction=RecordingOrderDirection.ASC,
                ),
            ),
        ],
        pytestconfig=pytestconfig,
    )


@pytest.mark.django_db
async def eval_tool_search_session_recordings_ask_user_for_help(call_search_session_recordings, pytestconfig):
    await MaxPublicEval(
        experiment_name="tool_search_session_recordings_ask_user_for_help",
        task=call_search_session_recordings,
        scores=[AskUserForHelp()],
        data=[
            EvalCase(input="Tell me something about insights", expected="clarify"),
        ],
        pytestconfig=pytestconfig,
    )
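FilterGenerationCorrectness awards partial credit by counting top-level differences between the expected and actual filter dicts. A self-contained illustration of that arithmetic, using hypothetical four-field dicts rather than real MaxRecordingUniversalFilters output:

from deepdiff import DeepDiff

expected = {"date_from": "-7d", "date_to": None, "filter_test_accounts": True, "order": "start_time"}
actual = {"date_from": "-2h", "date_to": None, "filter_test_accounts": True, "order": "start_time"}

diff = DeepDiff(expected, actual, ignore_order=True, report_repetition=True)
changed = (
    len(diff.get("values_changed", {}))
    + len(diff.get("dictionary_item_added", set()))
    + len(diff.get("dictionary_item_removed", set()))
)
# One of four top-level fields differs, so the score is (4 - 1) / 4 = 0.75
score = max(0.0, (len(expected) - changed) / len(expected))
print(score)  # 0.75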
@@ -1,218 +1,28 @@
-import os
-import asyncio
-import datetime
-from collections import namedtuple
-from collections.abc import Generator, Sequence
-
 import pytest
-from unittest import mock
-
-from django.test import override_settings
-
-from _pytest.terminal import TerminalReporter
-from braintrust import EvalAsync, Metadata, init_logger
-from braintrust.framework import EvalData, EvalScorer, EvalTask, Input, Output
-from braintrust_langchain import BraintrustCallbackHandler, set_global_handler
-
-from posthog.schema import FailureMessage, HumanMessage, VisualizationMessage

 # We want the PostHog django_db_setup fixture here
 from posthog.conftest import django_db_setup  # noqa: F401
-from posthog.demo.matrix.manager import MatrixManager
-from posthog.models import Team
-from posthog.tasks.demo_create_data import HedgeboxMatrix
-
-from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
-from ee.hogai.eval.scorers import PlanAndQueryOutput
-from ee.hogai.graph.graph import AssistantGraph, InsightsAssistantGraph
-from ee.hogai.utils.types import AssistantNodeName, AssistantState
-from ee.hogai.utils.types.base import AnyAssistantGeneratedQuery
-from ee.models.assistant import Conversation, CoreMemory
-
-handler = BraintrustCallbackHandler()
-if os.environ.get("BRAINTRUST_API_KEY"):
-    set_global_handler(handler)
-
-EVAL_USER_FULL_NAME = "Karen Smith"


 def pytest_addoption(parser):
-    # Example: pytest ee/hogai/eval/eval_sql.py --eval churn - to only run cases containing "churn" in input
+    # Example: pytest ee/hogai/eval/ci/eval_sql.py --eval churn - to only run cases containing "churn" in input
     parser.addoption("--eval", action="store")


-async def MaxEval(
-    experiment_name: str,
-    data: EvalData[Input, Output],
-    task: EvalTask[Input, Output],
-    scores: Sequence[EvalScorer[Input, Output]],
-    pytestconfig: pytest.Config,
-    metadata: Metadata | None = None,
-):
-    # We need to specify a separate project for each MaxEval() suite for comparison to baseline to work
-    # That's the way Braintrust folks recommended - Braintrust projects are much more lightweight than PostHog ones
-    project_name = f"max-ai-{experiment_name}"
-    init_logger(project_name)
-
-    # Filter by --case <eval_case_name_part> pytest flag
-    case_filter = pytestconfig.option.eval
-    if case_filter:
-        if asyncio.iscoroutine(data):
-            data = await data
-        data = [case for case in data if case_filter in str(case.input)]  # type: ignore
-
-    result = await EvalAsync(
-        project_name,
-        data=data,
-        task=task,
-        scores=scores,
-        timeout=60 * 8,
-        max_concurrency=100,
-        is_public=True,
-        metadata=metadata,
-    )
-    if os.getenv("GITHUB_EVENT_NAME") == "pull_request":
-        with open("eval_results.jsonl", "a") as f:
-            f.write(result.summary.as_json() + "\n")
-    return result
-
-
-@pytest.fixture
-def call_root_for_insight_generation(demo_org_team_user):
-    # This graph structure will first get a plan, then generate the SQL query.
-    insights_subgraph = (
-        # Insights subgraph without query execution, so we only create the queries
-        InsightsAssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
-        .add_query_creation_flow(next_node=AssistantNodeName.END)
-        .compile()
-    )
-    graph = (
-        AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
-        .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
-        .add_root(
-            path_map={
-                "insights": AssistantNodeName.INSIGHTS_SUBGRAPH,
-                "root": AssistantNodeName.END,
-                "search_documentation": AssistantNodeName.END,
-                "end": AssistantNodeName.END,
-            }
-        )
-        .add_node(AssistantNodeName.INSIGHTS_SUBGRAPH, insights_subgraph)
-        .add_edge(AssistantNodeName.INSIGHTS_SUBGRAPH, AssistantNodeName.END)
-        # TRICKY: We need to set a checkpointer here because async tests create a new event loop.
-        .compile(checkpointer=DjangoCheckpointer())
-    )
-
-    async def callable(query_with_extra_context: str | tuple[str, str]) -> PlanAndQueryOutput:
-        # If query_with_extra_context is a tuple, the first element is the query, the second is the extra context
-        # in case there's an ask_user tool call.
-        query = query_with_extra_context[0] if isinstance(query_with_extra_context, tuple) else query_with_extra_context
-        # Initial state for the graph
-        initial_state = AssistantState(
-            messages=[HumanMessage(content=f"Answer this question: {query}")],
-        )
-        conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])
-
-        # Invoke the graph. The state will be updated through planner and then generator.
-        final_state_raw = await graph.ainvoke(initial_state, {"configurable": {"thread_id": conversation.id}})
-        final_state = AssistantState.model_validate(final_state_raw)
-
-        # If we have extra context for the potential ask_user tool, and there's no message of type ai/failure
-        # or ai/visualization, we should answer with that extra context. We only do this once at most in an eval case.
-        if isinstance(query_with_extra_context, tuple) and not any(
-            isinstance(m, VisualizationMessage | FailureMessage) for m in final_state.messages
-        ):
-            final_state.messages = [*final_state.messages, HumanMessage(content=query_with_extra_context[1])]
-            final_state.graph_status = "resumed"
-            final_state_raw = await graph.ainvoke(final_state, {"configurable": {"thread_id": conversation.id}})
-            final_state = AssistantState.model_validate(final_state_raw)
-
-        if (
-            not final_state.messages
-            or not isinstance(final_state.messages[-1], VisualizationMessage)
-            or not isinstance(final_state.messages[-1].answer, AnyAssistantGeneratedQuery)
-        ):
-            return {
-                "plan": None,
-                "query": None,
-                "query_generation_retry_count": final_state.query_generation_retry_count,
-            }
-        return {
-            "plan": final_state.messages[-1].plan,
-            "query": final_state.messages[-1].answer,
-            "query_generation_retry_count": final_state.query_generation_retry_count,
-        }
-
-    return callable
-
-
-@pytest.fixture(scope="package")
-def demo_org_team_user(django_db_setup, django_db_blocker):  # noqa: F811
-    with django_db_blocker.unblock():
-        team = Team.objects.order_by("-created_at").first()
-        today = datetime.date.today()
-        # If there's no eval team or it's older than today, we need to create a new one with fresh data
-        should_create_new_team = not team or team.created_at.date() < today
-
-        if should_create_new_team:
-            print(f"Generating fresh demo data for evals...")  # noqa: T201
-            matrix = HedgeboxMatrix(
-                seed="b1ef3c66-5f43-488a-98be-6b46d92fbcef",  # this seed generates all events
-                days_past=120,
-                days_future=30,
-                n_clusters=500,
-                group_type_index_offset=0,
-            )
-            matrix_manager = MatrixManager(matrix, print_steps=True)
-            with override_settings(TEST=False):
-                # Simulation saving should occur in non-test mode, so that Kafka isn't mocked. Normally in tests we don't
-                # want to ingest via Kafka, but simulation saving is specifically designed to use that route for speed
-                org, team, user = matrix_manager.ensure_account_and_save(
-                    f"eval-{today.isoformat()}", EVAL_USER_FULL_NAME, "Hedgebox Inc."
-                )
-        else:
-            print(f"Using existing demo data for evals...")  # noqa: T201
-            org = team.organization
-            user = org.memberships.first().user
-
-    yield org, team, user
-
-
-@pytest.fixture(scope="package", autouse=True)
-def core_memory(demo_org_team_user, django_db_blocker) -> Generator[CoreMemory, None, None]:
-    initial_memory = """Hedgebox is a cloud storage service enabling users to store, share, and access files across devices.
-
-The company operates in the cloud storage and collaboration market for individuals and businesses.
-
-Their audience includes professionals and organizations seeking file management and collaboration solutions.
-
-Hedgebox's freemium model provides free accounts with limited storage and paid subscription plans for additional features.
-
-Core features include file storage, synchronization, sharing, and collaboration tools for seamless file access and sharing.
-
-It integrates with third-party applications to enhance functionality and streamline workflows.
-
-Hedgebox sponsors the YouTube channel Marius Tech Tips."""
-
-    with django_db_blocker.unblock():
-        core_memory, _ = CoreMemory.objects.get_or_create(
-            team=demo_org_team_user[1],
-            defaults={
-                "text": initial_memory,
-                "initial_text": initial_memory,
-                "scraping_status": CoreMemory.ScrapingStatus.COMPLETED,
-            },
-        )
-    yield core_memory


+_nodeid_to_results_url_map: dict[str, str] = {}
+"""Map of test nodeid (file + test name) to Braintrust results URL."""


+@pytest.fixture(scope="package")
+def set_up_evals(django_db_setup):  # noqa: F811
+    yield


 @pytest.fixture(autouse=True)
 def capture_stdout(request, capsys):
     yield
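On pull requests, the removed MaxEval (now the base for MaxPublicEval/MaxPrivateEval) appends one Braintrust summary per experiment to eval_results.jsonl, which CI tooling and the offline Dagster runner below pick up. A sketch of consuming that file; the exact fields depend on Braintrust's ExperimentSummary.as_json(), so the keys shown here are assumptions for illustration:

import json

# Each line in eval_results.jsonl is one experiment summary, appended when
# GITHUB_EVENT_NAME == "pull_request". Field names are assumed, not verified.
with open("eval_results.jsonl") as f:
    for line in f:
        summary = json.loads(line)
        print(summary.get("project_name"), summary.get("scores"))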
ee/hogai/eval/offline/__init__.py (new file, 0 lines)

ee/hogai/eval/offline/conftest.py (new file, 70 lines)
@@ -0,0 +1,70 @@
from collections.abc import Generator
from typing import Annotated

import pytest

from asgiref.sync import async_to_sync
from dagster_pipes import PipesContext, open_dagster_pipes
from pydantic import BaseModel, ConfigDict, SkipValidation

from posthog.models import Organization, User

# We want the PostHog set_up_evals fixture here
from ee.hogai.eval.conftest import set_up_evals  # noqa: F401
from ee.hogai.eval.offline.snapshot_loader import SnapshotLoader
from ee.hogai.eval.schema import DatasetInput


@pytest.fixture(scope="package")
def dagster_context() -> Generator[PipesContext, None, None]:
    with open_dagster_pipes() as context:
        yield context


class EvaluationContext(BaseModel):
    # We don't want to validate Django models here.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    organization: Annotated[Organization, SkipValidation]
    user: Annotated[User, SkipValidation]
    experiment_name: str
    dataset: list[DatasetInput]


@pytest.fixture(scope="package", autouse=True)
def eval_ctx(
    set_up_evals,  # noqa: F811
    dagster_context: PipesContext,
    django_db_blocker,
) -> Generator[EvaluationContext, None, None]:
    """
    Restores dumped Django models and patches AI query runners.
    Creates teams with team_id=project_id for the same single user and organization,
    keeping the original project_ids for teams.
    """
    with django_db_blocker.unblock():
        dagster_context.log.info("Loading Postgres and ClickHouse snapshots...")

        loader = SnapshotLoader(dagster_context)
        org, user, dataset = async_to_sync(loader.load_snapshots)()

        dagster_context.log.info("Running tests...")
        yield EvaluationContext(
            organization=org,
            user=user,
            experiment_name=loader.config.experiment_name,
            dataset=dataset,
        )

        dagster_context.log.info("Cleaning up...")
        loader.cleanup()

        dagster_context.log.info("Reporting results...")
        with open("eval_results.jsonl") as f:
            lines = f.readlines()
        dagster_context.report_asset_materialization(
            asset_key="evaluation_report",
            metadata={
                "output": "\n".join(lines),
            },
        )
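The Dagster side of this handshake is not part of this file: some asset launches the eval run and passes EvalsDockerImageConfig fields through Pipes extras, which open_dagster_pipes() receives here. A rough sketch of such an orchestration asset using Dagster's PipesSubprocessClient; the asset name, command, and extras keys are illustrative assumptions, and the actual runner may use a Docker-based Pipes client instead:

from dagster import AssetExecutionContext, PipesSubprocessClient, asset


@asset
def max_ai_offline_eval(context: AssetExecutionContext, pipes_subprocess_client: PipesSubprocessClient):
    # Extras surface inside the eval process through open_dagster_pipes() and are
    # validated against EvalsDockerImageConfig by SnapshotLoader. Keys are assumptions.
    return pipes_subprocess_client.run(
        command=["pytest", "ee/hogai/eval/offline", "-vv"],
        context=context,
        extras={"experiment_name": "offline-sql", "aws_bucket_name": "max-ai-evals"},
    ).get_materialize_result()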
ee/hogai/eval/offline/eval_sql.py (new file, 107 lines)
@@ -0,0 +1,107 @@
import asyncio
from typing import TypedDict, cast

import pytest

from braintrust import EvalCase, Score
from pydantic import BaseModel, Field

from posthog.schema import AssistantHogQLQuery, HumanMessage, VisualizationMessage

from posthog.hogql.context import HogQLContext
from posthog.hogql.database.database import create_hogql_database

from posthog.models import Team
from posthog.sync import database_sync_to_async

from ee.hogai.eval.base import MaxPrivateEval
from ee.hogai.eval.offline.conftest import EvaluationContext
from ee.hogai.eval.schema import DatasetInput
from ee.hogai.eval.scorers.sql import SQLSemanticsCorrectness, SQLSyntaxCorrectness
from ee.hogai.graph import AssistantGraph
from ee.hogai.utils.helpers import find_last_message_of_type
from ee.hogai.utils.types import AssistantState
from ee.hogai.utils.warehouse import serialize_database_schema
from ee.models import Conversation


class EvalOutput(BaseModel):
    database_schema: str
    query_kind: str | None = Field(default=None)
    sql_query: str | None = Field(default=None)


class EvalMetadata(TypedDict):
    team_id: int


async def serialize_database(team: Team):
    database = await database_sync_to_async(create_hogql_database)(team=team)
    context = HogQLContext(team=team, database=database, enable_select_queries=True)
    return await serialize_database_schema(database, context)


@pytest.fixture
def call_graph(eval_ctx: EvaluationContext):
    async def callable(entry: DatasetInput, *args, **kwargs) -> EvalOutput:
        team = await Team.objects.aget(id=entry.team_id)
        conversation, database_schema = await asyncio.gather(
            Conversation.objects.acreate(team=team, user=eval_ctx.user),
            serialize_database(team),
        )
        graph = AssistantGraph(team, eval_ctx.user).compile_full_graph()

        state = await graph.ainvoke(
            AssistantState(messages=[HumanMessage(content=entry.input["query"])]),
            {
                "configurable": {
                    "thread_id": conversation.id,
                    "user": eval_ctx.user,
                    "team": team,
                    "distinct_id": eval_ctx.user.distinct_id,
                }
            },
        )
        maybe_viz_message = find_last_message_of_type(state["messages"], VisualizationMessage)
        if maybe_viz_message:
            return EvalOutput(
                database_schema=database_schema,
                query_kind=maybe_viz_message.answer.kind,
                sql_query=maybe_viz_message.answer.query
                if isinstance(maybe_viz_message.answer, AssistantHogQLQuery)
                else None,
            )
        return EvalOutput(database_schema=database_schema)

    return callable


async def sql_semantics_scorer(input: DatasetInput, expected: str, output: EvalOutput, **kwargs) -> Score:
    metric = SQLSemanticsCorrectness()
    return await metric.eval_async(
        input=input.input["query"], expected=expected, output=output.sql_query, database_schema=output.database_schema
    )


async def sql_syntax_scorer(input: DatasetInput, expected: str, output: EvalOutput, **kwargs) -> Score:
    metric = SQLSyntaxCorrectness()
    return await metric.eval_async(
        input=input.input["query"], expected=expected, output=output.sql_query, database_schema=output.database_schema
    )


def generate_test_cases(eval_ctx: EvaluationContext):
    for entry in eval_ctx.dataset:
        metadata: EvalMetadata = {"team_id": entry.team_id}
        yield EvalCase(input=entry, expected=entry.expected["output"], metadata=cast(dict, metadata))


@pytest.mark.django_db
async def eval_offline_sql(eval_ctx: EvaluationContext, call_graph, pytestconfig):
    await MaxPrivateEval(
        experiment_name=eval_ctx.experiment_name,
        task=call_graph,
        scores=[sql_syntax_scorer, sql_semantics_scorer],
        data=generate_test_cases(eval_ctx),
        pytestconfig=pytestconfig,
    )
ee/hogai/eval/offline/query_patches.py (new file, 65 lines)
@@ -0,0 +1,65 @@
from collections import defaultdict
from typing import Literal

from posthog.schema import (
    ActorsPropertyTaxonomyQueryResponse,
    ActorsPropertyTaxonomyResponse,
    EventTaxonomyItem,
    EventTaxonomyQueryResponse,
    TeamTaxonomyItem,
    TeamTaxonomyQueryResponse,
)

from posthog.hogql_queries.ai.actors_property_taxonomy_query_runner import ActorsPropertyTaxonomyQueryRunner
from posthog.hogql_queries.ai.event_taxonomy_query_runner import EventTaxonomyQueryRunner
from posthog.hogql_queries.ai.team_taxonomy_query_runner import TeamTaxonomyQueryRunner

# This is a global state that is used to store the patched results for the team taxonomy query.
TEAM_TAXONOMY_QUERY_DATA_SOURCE: dict[int, list[TeamTaxonomyItem]] = {}


class PatchedTeamTaxonomyQueryRunner(TeamTaxonomyQueryRunner):
    def _calculate(self):
        results: list[TeamTaxonomyItem] = []
        if precomputed_results := TEAM_TAXONOMY_QUERY_DATA_SOURCE.get(self.team.id):
            results = precomputed_results
        return TeamTaxonomyQueryResponse(results=results, modifiers=self.modifiers)


# This is a global state that is used to store the patched results for the event taxonomy query.
EVENT_TAXONOMY_QUERY_DATA_SOURCE: dict[int, dict[str | int, list[EventTaxonomyItem]]] = defaultdict(dict)


class PatchedEventTaxonomyQueryRunner(EventTaxonomyQueryRunner):
    def _calculate(self):
        results: list[EventTaxonomyItem] = []
        team_data = EVENT_TAXONOMY_QUERY_DATA_SOURCE.get(self.team.id, {})
        if self.query.event in team_data:
            results = team_data[self.query.event]
        elif self.query.actionId in team_data:
            results = team_data[self.query.actionId]
        return EventTaxonomyQueryResponse(results=results, modifiers=self.modifiers)


# This is a global state that is used to store the patched results for the actors property taxonomy query.
ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE: dict[
    int, dict[int | Literal["person"], dict[str, ActorsPropertyTaxonomyResponse]]
] = defaultdict(lambda: defaultdict(dict))


class PatchedActorsPropertyTaxonomyQueryRunner(ActorsPropertyTaxonomyQueryRunner):
    def _calculate(self):
        key: int | Literal["person"] = (
            self.query.groupTypeIndex if isinstance(self.query.groupTypeIndex, int) else "person"
        )
        if (
            self.team.id in ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE
            and key in ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE[self.team.id]
        ):
            data = ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE[self.team.id][key]
            result: list[ActorsPropertyTaxonomyResponse] = []
            for prop in self.query.properties:
                result.append(data.get(prop, ActorsPropertyTaxonomyResponse(sample_values=[], sample_count=0)))
        else:
            result = [ActorsPropertyTaxonomyResponse(sample_values=[], sample_count=0) for _ in self.query.properties]
        return ActorsPropertyTaxonomyQueryResponse(results=result, modifiers=self.modifiers)
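The patched runners read from these module-level dicts, so the loader only needs to seed them before the assistant graph runs. A small sketch of that seeding; the team id and item values are illustrative, and TeamTaxonomyItem's exact fields should be checked against posthog.schema:

from posthog.schema import TeamTaxonomyItem

from ee.hogai.eval.offline.query_patches import TEAM_TAXONOMY_QUERY_DATA_SOURCE

# Seed precomputed taxonomy for team 42 (illustrative values). Once
# _patch_query_runners() swaps the runner classes in, any team taxonomy query
# the assistant issues for this team is answered from this dict instead of ClickHouse.
TEAM_TAXONOMY_QUERY_DATA_SOURCE[42] = [
    TeamTaxonomyItem(event="$pageview", count=12345),
    TeamTaxonomyItem(event="signup_completed", count=678),
]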
ee/hogai/eval/offline/snapshot_loader.py (new file, 182 lines)
@@ -0,0 +1,182 @@
import os
import asyncio
from collections.abc import Generator
from io import BytesIO
from typing import TYPE_CHECKING, Any, Literal, TypeVar

from unittest.mock import patch

import backoff
import aioboto3
from asgiref.sync import sync_to_async
from dagster_pipes import PipesContext
from fastavro import reader
from pydantic_avro import AvroBase

from posthog.models import GroupTypeMapping, Organization, Project, PropertyDefinition, Team, User
from posthog.warehouse.models.table import DataWarehouseTable

from ee.hogai.eval.schema import (
    ActorsPropertyTaxonomySnapshot,
    DatasetInput,
    DataWarehouseTableSnapshot,
    EvalsDockerImageConfig,
    GroupTypeMappingSnapshot,
    PropertyDefinitionSnapshot,
    PropertyTaxonomySnapshot,
    TeamEvaluationSnapshot,
    TeamSnapshot,
    TeamTaxonomyItemSnapshot,
)

from .query_patches import (
    ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE,
    EVENT_TAXONOMY_QUERY_DATA_SOURCE,
    TEAM_TAXONOMY_QUERY_DATA_SOURCE,
    PatchedActorsPropertyTaxonomyQueryRunner,
    PatchedEventTaxonomyQueryRunner,
    PatchedTeamTaxonomyQueryRunner,
)

if TYPE_CHECKING:
    from types_aiobotocore_s3.client import S3Client


T = TypeVar("T", bound=AvroBase)


class SnapshotLoader:
    """Loads snapshots from S3, restores Django models, and patches query runners."""

    def __init__(self, context: PipesContext):
        self.context = context
        self.config = EvalsDockerImageConfig.model_validate(context.extras)
        self.patches: list[Any] = []

    async def load_snapshots(self) -> tuple[Organization, User, list[DatasetInput]]:
        self.organization = await Organization.objects.acreate(name="PostHog")
        self.user = await sync_to_async(User.objects.create_and_join)(self.organization, "test@posthog.com", "12345678")

        for snapshot in self.config.team_snapshots:
            self.context.log.info(f"Loading Postgres snapshot for team {snapshot.team_id}...")

            project = await Project.objects.acreate(
                id=await sync_to_async(Team.objects.increment_id_sequence)(), organization=self.organization
            )

            (
                project_snapshot_bytes,
                property_definitions_snapshot_bytes,
                group_type_mappings_snapshot_bytes,
                data_warehouse_tables_snapshot_bytes,
                event_taxonomy_snapshot_bytes,
                properties_taxonomy_snapshot_bytes,
                actors_property_taxonomy_snapshot_bytes,
            ) = await self._get_all_snapshots(snapshot)

            team = await self._load_team_snapshot(project, snapshot.team_id, project_snapshot_bytes)
            await asyncio.gather(
                self._load_property_definitions(property_definitions_snapshot_bytes, team=team, project=project),
                self._load_group_type_mappings(group_type_mappings_snapshot_bytes, team=team, project=project),
                self._load_data_warehouse_tables(data_warehouse_tables_snapshot_bytes, team=team, project=project),
            )
            self._load_event_taxonomy(event_taxonomy_snapshot_bytes, team=team)
            self._load_properties_taxonomy(properties_taxonomy_snapshot_bytes, team=team)
            self._load_actors_property_taxonomy(actors_property_taxonomy_snapshot_bytes, team=team)

        self._patch_query_runners()

        return self.organization, self.user, self.config.dataset

    def cleanup(self):
        for mock in self.patches:
            mock.stop()

    @backoff.on_exception(backoff.expo, Exception, max_tries=3)
    async def _get_all_snapshots(self, snapshot: TeamEvaluationSnapshot):
        async with aioboto3.Session().client(
            "s3",
            endpoint_url=self.config.aws_endpoint_url,
            aws_access_key_id=os.getenv("OBJECT_STORAGE_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("OBJECT_STORAGE_SECRET_ACCESS_KEY"),
        ) as client:
            loaded_snapshots = await asyncio.gather(
                self._get_snapshot_from_s3(client, snapshot.postgres.team),
                self._get_snapshot_from_s3(client, snapshot.postgres.property_definitions),
                self._get_snapshot_from_s3(client, snapshot.postgres.group_type_mappings),
                self._get_snapshot_from_s3(client, snapshot.postgres.data_warehouse_tables),
                self._get_snapshot_from_s3(client, snapshot.clickhouse.event_taxonomy),
                self._get_snapshot_from_s3(client, snapshot.clickhouse.properties_taxonomy),
                self._get_snapshot_from_s3(client, snapshot.clickhouse.actors_property_taxonomy),
            )
            return loaded_snapshots

    async def _get_snapshot_from_s3(self, client: "S3Client", file_key: str):
        response = await client.get_object(Bucket=self.config.aws_bucket_name, Key=file_key)
        content = await response["Body"].read()
        return BytesIO(content)

    def _parse_snapshot_to_schema(self, schema: type[T], buffer: BytesIO) -> Generator[T, None, None]:
        for record in reader(buffer):
            yield schema.model_validate(record)

    async def _load_team_snapshot(self, project: Project, team_id: int, buffer: BytesIO) -> Team:
        team_snapshot = next(self._parse_snapshot_to_schema(TeamSnapshot, buffer))
        team = next(TeamSnapshot.deserialize_for_team([team_snapshot], team_id=team_id, project_id=project.id))
        team.project = project
        team.organization = self.organization
        team.api_token = f"team_{team_id}"
        await team.asave()
        return team

    async def _load_property_definitions(self, buffer: BytesIO, *, team: Team, project: Project):
        snapshot = list(self._parse_snapshot_to_schema(PropertyDefinitionSnapshot, buffer))
        property_definitions = PropertyDefinitionSnapshot.deserialize_for_team(
            snapshot, team_id=team.id, project_id=project.id
        )
        return await PropertyDefinition.objects.abulk_create(property_definitions, batch_size=500)

    async def _load_group_type_mappings(self, buffer: BytesIO, *, team: Team, project: Project):
        snapshot = list(self._parse_snapshot_to_schema(GroupTypeMappingSnapshot, buffer))
        group_type_mappings = GroupTypeMappingSnapshot.deserialize_for_team(
            snapshot, team_id=team.id, project_id=project.id
        )
        return await GroupTypeMapping.objects.abulk_create(group_type_mappings, batch_size=500)

    async def _load_data_warehouse_tables(self, buffer: BytesIO, *, team: Team, project: Project):
        snapshot = list(self._parse_snapshot_to_schema(DataWarehouseTableSnapshot, buffer))
        data_warehouse_tables = DataWarehouseTableSnapshot.deserialize_for_team(
            snapshot, team_id=team.id, project_id=project.id
        )
        return await DataWarehouseTable.objects.abulk_create(data_warehouse_tables, batch_size=500)

    def _load_event_taxonomy(self, buffer: BytesIO, *, team: Team):
        snapshot = next(self._parse_snapshot_to_schema(TeamTaxonomyItemSnapshot, buffer))
        TEAM_TAXONOMY_QUERY_DATA_SOURCE[team.id] = snapshot.results

    def _load_properties_taxonomy(self, buffer: BytesIO, *, team: Team):
        for item in self._parse_snapshot_to_schema(PropertyTaxonomySnapshot, buffer):
            EVENT_TAXONOMY_QUERY_DATA_SOURCE[team.id][item.event] = item.results

    def _load_actors_property_taxonomy(self, buffer: BytesIO, *, team: Team):
        for item in self._parse_snapshot_to_schema(ActorsPropertyTaxonomySnapshot, buffer):
            key: int | Literal["person"] = item.group_type_index if isinstance(item.group_type_index, int) else "person"
            ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE[team.pk][key][item.property] = item.results

    def _patch_query_runners(self):
        self.patches = [
            patch(
                "posthog.hogql_queries.ai.team_taxonomy_query_runner.TeamTaxonomyQueryRunner",
                new=PatchedTeamTaxonomyQueryRunner,
            ),
            patch(
                "posthog.hogql_queries.ai.event_taxonomy_query_runner.EventTaxonomyQueryRunner",
                new=PatchedEventTaxonomyQueryRunner,
            ),
            patch(
                "posthog.hogql_queries.ai.actors_property_taxonomy_query_runner.ActorsPropertyTaxonomyQueryRunner",
                new=PatchedActorsPropertyTaxonomyQueryRunner,
            ),
        ]
        for mock in self.patches:
            mock.start()
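The loader assumes each S3 object is an Avro container whose records validate against the matching pydantic model. A sketch of the round-trip this relies on, assuming pydantic-avro's avro_schema() helper and a Django context where team 1 exists; the snapshot-producing side of the pipeline is not shown in this diff:

from io import BytesIO

from fastavro import parse_schema, reader, writer

from ee.hogai.eval.schema import TeamSnapshot

# pydantic-avro derives the Avro schema from the model, fastavro writes and
# reads the records, and model_validate re-hydrates them, mirroring
# _parse_snapshot_to_schema above.
snapshots = list(TeamSnapshot.serialize_for_team(team_id=1))
buffer = BytesIO()
writer(buffer, parse_schema(TeamSnapshot.avro_schema()), [s.model_dump() for s in snapshots])
buffer.seek(0)
restored = [TeamSnapshot.model_validate(record) for record in reader(buffer)]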
@@ -1,11 +1,11 @@
import json
from abc import ABC, abstractmethod
from collections.abc import Generator, Sequence
from typing import Generic, Self, TypeVar
from typing import Any, Generic, Self, TypeVar

from django.db.models import Model

from pydantic import BaseModel
from pydantic import BaseModel, Field
from pydantic_avro import AvroBase

from posthog.schema import ActorsPropertyTaxonomyResponse, EventTaxonomyItem, TeamTaxonomyItem

@@ -18,14 +18,12 @@ T = TypeVar("T", bound=Model)
class BaseSnapshot(AvroBase, ABC, Generic[T]):
    @classmethod
    @abstractmethod
    def serialize_for_project(cls, project_id: int) -> Generator[Self, None, None]:
    def serialize_for_team(cls, *, team_id: int) -> Generator[Self, None, None]:
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def deserialize_for_project(
        cls, project_id: int, models: Sequence[Self], *, team_id: int
    ) -> Generator[T, None, None]:
    def deserialize_for_team(cls, models: Sequence[Self], *, team_id: int, project_id: int) -> Generator[T, None, None]:
        raise NotImplementedError

@@ -35,15 +33,18 @@ class TeamSnapshot(BaseSnapshot[Team]):
    test_account_filters: str

    @classmethod
    def serialize_for_project(cls, project_id: int):
        team = Team.objects.get(pk=project_id)
    def serialize_for_team(cls, *, team_id: int):
        team = Team.objects.get(pk=team_id)
        yield TeamSnapshot(name=team.name, test_account_filters=json.dumps(team.test_account_filters))

    @classmethod
    def deserialize_for_project(cls, project_id: int, models: Sequence[Self], **kwargs) -> Generator[Team, None, None]:
    def deserialize_for_team(
        cls, models: Sequence[Self], *, team_id: int, project_id: int
    ) -> Generator[Team, None, None]:
        for model in models:
            yield Team(
                id=project_id,
                id=team_id,
                project_id=project_id,
                name=model.name,
                test_account_filters=json.loads(model.test_account_filters),
            )

@@ -58,8 +59,8 @@ class PropertyDefinitionSnapshot(BaseSnapshot[PropertyDefinition]):
    group_type_index: int | None

    @classmethod
    def serialize_for_project(cls, project_id: int):
        for prop in PropertyDefinition.objects.filter(project_id=project_id).iterator(500):
    def serialize_for_team(cls, *, team_id: int):
        for prop in PropertyDefinition.objects.filter(team_id=team_id).iterator(500):
            yield PropertyDefinitionSnapshot(
                name=prop.name,
                is_numerical=prop.is_numerical,

@@ -69,7 +70,7 @@ class PropertyDefinitionSnapshot(BaseSnapshot[PropertyDefinition]):
            )

    @classmethod
    def deserialize_for_project(cls, project_id: int, models: Sequence[Self], **kwargs):
    def deserialize_for_team(cls, models: Sequence[Self], *, team_id: int, project_id: int):
        for model in models:
            yield PropertyDefinition(
                name=model.name,

@@ -77,7 +78,8 @@ class PropertyDefinitionSnapshot(BaseSnapshot[PropertyDefinition]):
                property_type=model.property_type,
                type=model.type,
                group_type_index=model.group_type_index,
                team_id=project_id,
                team_id=team_id,
                project_id=project_id,
            )

@@ -89,8 +91,8 @@ class GroupTypeMappingSnapshot(BaseSnapshot[GroupTypeMapping]):
    name_plural: str | None

    @classmethod
    def serialize_for_project(cls, project_id: int):
        for mapping in GroupTypeMapping.objects.filter(project_id=project_id).iterator(500):
    def serialize_for_team(cls, *, team_id: int):
        for mapping in GroupTypeMapping.objects.filter(team_id=team_id).iterator(500):
            yield GroupTypeMappingSnapshot(
                group_type=mapping.group_type,
                group_type_index=mapping.group_type_index,

@@ -99,7 +101,7 @@ class GroupTypeMappingSnapshot(BaseSnapshot[GroupTypeMapping]):
            )

    @classmethod
    def deserialize_for_project(cls, project_id: int, models: Sequence[Self], *, team_id: int):
    def deserialize_for_team(cls, models: Sequence[Self], *, team_id: int, project_id: int):
        for model in models:
            yield GroupTypeMapping(
                group_type=model.group_type,

@@ -118,8 +120,8 @@ class DataWarehouseTableSnapshot(BaseSnapshot[DataWarehouseTable]):
    columns: dict

    @classmethod
    def serialize_for_project(cls, project_id: int):
        for table in DataWarehouseTable.objects.filter(team_id=project_id).iterator(500):
    def serialize_for_team(cls, *, team_id: int):
        for table in DataWarehouseTable.objects.filter(team_id=team_id).iterator(500):
            yield DataWarehouseTableSnapshot(
                name=table.name,
                format=table.format,

@@ -127,19 +129,19 @@ class DataWarehouseTableSnapshot(BaseSnapshot[DataWarehouseTable]):
            )

    @classmethod
    def deserialize_for_project(cls, project_id: int, models: Sequence[Self], **kwargs):
    def deserialize_for_team(cls, models: Sequence[Self], *, team_id: int, project_id: int):
        for model in models:
            yield DataWarehouseTable(
                name=model.name,
                format=model.format,
                columns=model.columns,
                url_pattern="http://localhost",  # Hardcoded. It's not important for evaluations what the value is.
                team_id=project_id,
                team_id=team_id,
            )


class PostgresProjectDataSnapshot(BaseModel):
    project: str
class PostgresTeamDataSnapshot(BaseModel):
    team: str
    property_definitions: str
    group_type_mappings: str
    data_warehouse_tables: str

@@ -163,7 +165,46 @@ class ActorsPropertyTaxonomySnapshot(AvroBase):
    results: ActorsPropertyTaxonomyResponse


class ClickhouseProjectDataSnapshot(BaseModel):
class ClickhouseTeamDataSnapshot(BaseModel):
    event_taxonomy: str
    properties_taxonomy: str
    actors_property_taxonomy: str


class TeamEvaluationSnapshot(BaseModel):
    team_id: int
    postgres: PostgresTeamDataSnapshot
    clickhouse: ClickhouseTeamDataSnapshot


class DatasetInput(BaseModel):
    team_id: int
    input: dict[str, Any]
    expected: dict[str, Any]
    metadata: dict[str, Any] = Field(default_factory=dict)


class EvalsDockerImageConfig(BaseModel):
    class Config:
        extra = "allow"

    aws_bucket_name: str
    """
    AWS S3 bucket name for the raw snapshots for all projects.
    """
    aws_endpoint_url: str
    """
    AWS S3 endpoint URL for the raw snapshots for all projects.
    """
    team_snapshots: list[TeamEvaluationSnapshot]
    """
    Raw snapshots for all projects.
    """
    experiment_name: str
    """
    Name of the experiment.
    """
    dataset: list[DatasetInput]
    """
    Parsed dataset.
    """
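Since `EvalsDockerImageConfig` is a plain pydantic model with `extra = "allow"`, the evaluation image can validate whatever JSON payload the Dagster side hands it while tolerating unknown keys. A small sketch of parsing such a payload (the values are illustrative, modeled on the test fixtures later in this PR, not a fixed contract):

import json

raw = json.dumps(
    {
        "aws_bucket_name": "evals",
        "aws_endpoint_url": "http://localhost:9000",
        "experiment_name": "offline_evaluation",
        "team_snapshots": [
            {
                "team_id": 9990,
                "postgres": {
                    "team": "pg/team_9990.avro",
                    "property_definitions": "pg/property_defs_9990.avro",
                    "group_type_mappings": "pg/group_mappings_9990.avro",
                    "data_warehouse_tables": "pg/dw_tables_9990.avro",
                },
                "clickhouse": {
                    "event_taxonomy": "ch/event_taxonomy_9990.avro",
                    "properties_taxonomy": "ch/properties_taxonomy_9990.avro",
                    "actors_property_taxonomy": "ch/actors_property_taxonomy_9990.avro",
                },
            }
        ],
        "dataset": [],
        "extra_flag": True,  # unknown key, tolerated because extra = "allow"
    }
)

config = EvalsDockerImageConfig.model_validate_json(raw)
assert config.team_snapshots[0].postgres.team == "pg/team_9990.avro"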
@@ -1,15 +1,27 @@
import json
from typing import TypedDict
from typing import Any, TypedDict, cast

from autoevals.llm import LLMClassifier
from autoevals.partial import ScorerWithPartial
from autoevals.ragas import AnswerSimilarity
from braintrust import Score
from langchain_core.messages import AIMessage as LangchainAIMessage

from posthog.schema import AssistantMessage, AssistantToolCall, NodeKind

from ee.hogai.utils.types.base import AnyAssistantGeneratedQuery
from ee.hogai.utils.types.base import AnyAssistantGeneratedQuery, AnyAssistantSupportedQuery

from .sql import SQLSemanticsCorrectness, SQLSyntaxCorrectness

__all__ = [
    "SQLSemanticsCorrectness",
    "SQLSyntaxCorrectness",
    "ToolRelevance",
    "QueryKindSelection",
    "PlanCorrectness",
    "QueryAndPlanAlignment",
    "TimeRangeRelevancy",
    "InsightEvaluationAccuracy",
]


class ToolRelevance(ScorerWithPartial):

@@ -25,7 +37,7 @@ class ToolRelevance(ScorerWithPartial):
            return Score(name=self._name(), score=0)
        if not isinstance(expected, AssistantToolCall):
            raise TypeError(f"Eval case expected must be an AssistantToolCall, not {type(expected)}")
        if not isinstance(output, AssistantMessage | LangchainAIMessage):
        if not isinstance(output, AssistantMessage):
            raise TypeError(f"Eval case output must be an AssistantMessage, not {type(output)}")
        if output.tool_calls and len(output.tool_calls) > 1:
            raise ValueError("Parallel tool calls not supported by this scorer yet")

@@ -52,16 +64,24 @@

class PlanAndQueryOutput(TypedDict, total=False):
    plan: str | None
    query: AnyAssistantGeneratedQuery
    query: AnyAssistantGeneratedQuery | AnyAssistantSupportedQuery | None
    query_generation_retry_count: int | None


def serialize_output(output: PlanAndQueryOutput | dict | None) -> PlanAndQueryOutput | None:
class SerializedPlanAndQueryOutput(TypedDict):
    plan: str | None
    query: dict[str, Any] | None
    query_generation_retry_count: int | None


def serialize_output(output: PlanAndQueryOutput | dict | None) -> SerializedPlanAndQueryOutput | None:
    if output:
        return {
        query = output.get("query")
        serialized_output = {
            **output,
            "query": output.get("query").model_dump(exclude_none=True),
            "query": query.model_dump(exclude_none=True) if query else None,
        }
        return cast(SerializedPlanAndQueryOutput, serialized_output)
    return None

@@ -75,13 +95,14 @@ class QueryKindSelection(ScorerWithPartial):
        self._expected = expected

    def _run_eval_sync(self, output: PlanAndQueryOutput, expected=None, **kwargs):
        if not output.get("query"):
        query = output.get("query")
        if not query:
            return Score(name=self._name(), score=None, metadata={"reason": "No query present"})
        score = 1 if output["query"].kind == self._expected else 0
        score = 1 if query.kind == self._expected else 0
        return Score(
            name=self._name(),
            score=score,
            metadata={"reason": f"Expected {self._expected}, got {output['query'].kind}"} if not score else {},
            metadata={"reason": f"Expected {self._expected}, got {query.kind}"} if not score else {},
        )

@@ -89,22 +110,18 @@ class PlanCorrectness(LLMClassifier):
    """Evaluate if the generated plan correctly answers the user's question."""

    async def _run_eval_async(self, output: PlanAndQueryOutput, expected=None, **kwargs):
        if not output.get("plan"):
            return Score(name=self._name(), score=0.0, metadata={"reason": "No plan present"})
        output = PlanAndQueryOutput(
            plan=output.get("plan"),
            query=output["query"].model_dump_json(exclude_none=True) if output.get("query") else None,  # Clean up
        )
        return await super()._run_eval_async(output, serialize_output(expected), **kwargs)
        plan = output.get("plan")
        query = output.get("query")
        if not plan or not query:
            return Score(name=self._name(), score=0.0, metadata={"reason": "No plan or query present"})
        return await super()._run_eval_async(serialize_output(output), serialize_output(expected), **kwargs)

    def _run_eval_sync(self, output: PlanAndQueryOutput, expected=None, **kwargs):
        if not output.get("plan"):
            return Score(name=self._name(), score=0.0, metadata={"reason": "No plan present"})
        output = PlanAndQueryOutput(
            plan=output.get("plan"),
            query=output["query"].model_dump_json(exclude_none=True) if output.get("query") else None,  # Clean up
        )
        return super()._run_eval_sync(output, serialize_output(expected), **kwargs)
        plan = output.get("plan")
        query = output.get("query")
        if not plan or not query:
            return Score(name=self._name(), score=0.0, metadata={"reason": "No plan or query present"})
        return super()._run_eval_sync(serialize_output(output), serialize_output(expected), **kwargs)

    def __init__(self, query_kind: NodeKind, evaluation_criteria: str, **kwargs):
        super().__init__(

@@ -167,34 +184,30 @@ class QueryAndPlanAlignment(LLMClassifier):
    """Evaluate if the generated SQL query aligns with the plan generated in the previous step."""

    async def _run_eval_async(self, output: PlanAndQueryOutput, expected: PlanAndQueryOutput | None = None, **kwargs):
        if not output.get("plan"):
        plan = output.get("plan")
        if not plan:
            return Score(
                name=self._name(),
                score=None,
                metadata={"reason": "No plan present in the first place, skipping evaluation"},
            )
        if not output.get("query"):
        query = output.get("query")
        if not query:
            return Score(name=self._name(), score=0.0, metadata={"reason": "Query failed to be generated"})
        output = PlanAndQueryOutput(
            plan=output.get("plan"),
            query=output["query"].model_dump_json(exclude_none=True) if output.get("query") else None,  # Clean up
        )
        return await super()._run_eval_async(output, serialize_output(expected), **kwargs)
        return await super()._run_eval_async(serialize_output(output), serialize_output(expected), **kwargs)

    def _run_eval_sync(self, output: PlanAndQueryOutput, expected: PlanAndQueryOutput | None = None, **kwargs):
        if not output.get("plan"):
        plan = output.get("plan")
        if not plan:
            return Score(
                name=self._name(),
                score=None,
                metadata={"reason": "No plan present in the first place, skipping evaluation"},
            )
        if not output.get("query"):
        query = output.get("query")
        if not query:
            return Score(name=self._name(), score=0.0, metadata={"reason": "Query failed to be generated"})
        output = PlanAndQueryOutput(
            plan=output.get("plan"),
            query=output["query"].model_dump_json(exclude_none=True) if output.get("query") else None,  # Clean up
        )
        return super()._run_eval_sync(output, serialize_output(expected), **kwargs)
        return super()._run_eval_sync(serialize_output(output), serialize_output(expected), **kwargs)

    def __init__(self, query_kind: NodeKind, json_schema: dict, evaluation_criteria: str, **kwargs):
        json_schema_str = json.dumps(json_schema)

@@ -270,15 +283,15 @@ Details matter greatly here - including math types or property types - so be har
class TimeRangeRelevancy(LLMClassifier):
    """Evaluate if the generated query's time range, interval, or period correctly answers the user's question."""

    async def _run_eval_async(self, output, expected=None, **kwargs):
    async def _run_eval_async(self, output: PlanAndQueryOutput, expected: PlanAndQueryOutput | None = None, **kwargs):
        if not output.get("query"):
            return Score(name=self._name(), score=None, metadata={"reason": "No query to check, skipping evaluation"})
        return await super()._run_eval_async(output, serialize_output(expected), **kwargs)
        return await super()._run_eval_async(serialize_output(output), serialize_output(expected), **kwargs)

    def _run_eval_sync(self, output, expected=None, **kwargs):
    def _run_eval_sync(self, output: PlanAndQueryOutput, expected: PlanAndQueryOutput | None = None, **kwargs):
        if not output.get("query"):
            return Score(name=self._name(), score=None, metadata={"reason": "No query to check"})
        return super()._run_eval_sync(output, serialize_output(expected), **kwargs)
        return super()._run_eval_sync(serialize_output(output), serialize_output(expected), **kwargs)

    def __init__(self, query_kind: NodeKind, **kwargs):
        super().__init__(

@@ -344,78 +357,6 @@ How would you rate the time range relevancy of the generated query? Choose one:
        )


SQL_SEMANTICS_CORRECTNESS_PROMPT = """
<system>
You are an expert ClickHouse SQL auditor.
Your job is to decide whether two ClickHouse SQL queries are **semantically equivalent for every possible valid database state**, given the same task description and schema.

When you respond, think step-by-step **internally**, but reveal **nothing** except the final verdict:
• Output **Pass** if the candidate query would always return the same result set (ignoring column aliases, ordering, or trivial formatting) as the reference query.
• Output **Fail** otherwise, or if you are uncertain.
Respond with a single word—**Pass** or **Fail**—and no additional text.
</system>

<input>
Task / natural-language question:
```
{{input}}
```

Database schema (tables and columns):
```
{{database_schema}}
```

Reference (human-labelled) SQL:
```sql
{{expected}}
```

Candidate (generated) SQL:
```sql
{{output}}
```
</input>

<reminder>
Think through edge cases: NULL handling, grouping, filters, joins, HAVING clauses, aggregations, sub-queries, limits, and data-type quirks.
If any logical difference could yield different outputs under some data scenario, the queries are *not* equivalent.
</reminder>

When ready, output your verdict—**Pass** or **Fail**—with absolutely no extra characters.
""".strip()


class SQLSemanticsCorrectness(LLMClassifier):
    """Evaluate if the actual query matches semantically the expected query."""

    def __init__(self, **kwargs):
        super().__init__(
            name="sql_semantics_correctness",
            prompt_template=SQL_SEMANTICS_CORRECTNESS_PROMPT,
            choice_scores={
                "Pass": 1.0,
                "Fail": 0.0,
            },
            model="gpt-4.1",
            **kwargs,
        )

    async def _run_eval_async(
        self, output: str | None, expected: str | None = None, database_schema: str | None = None, **kwargs
    ):
        if not output or output.strip() == "":
            return Score(name=self._name(), score=None, metadata={"reason": "No query to check, skipping evaluation"})
        return await super()._run_eval_async(output, expected, database_schema=database_schema, **kwargs)

    def _run_eval_sync(
        self, output: str | None, expected: str | None = None, database_schema: str | None = None, **kwargs
    ):
        if not output or output.strip() == "":
            return Score(name=self._name(), score=None, metadata={"reason": "No query to check, skipping evaluation"})
        return super()._run_eval_sync(output, expected, database_schema=database_schema, **kwargs)


class InsightSearchOutput(TypedDict, total=False):
    """Output structure for insight search evaluations."""
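One convention worth calling out in the scorers above: `Score(score=None)` marks a case as skipped (nothing gradeable was produced upstream), while `Score(score=0.0)` is a real failure. A minimal sketch of the same pattern in a standalone scorer function (hypothetical, not part of this PR):

from braintrust import Score


def grade_query_kind(output: dict | None, expected_kind: str) -> Score:
    # Nothing to grade: skip the case rather than punish it
    if not output or not output.get("query"):
        return Score(name="query_kind", score=None, metadata={"reason": "No query present"})
    kind = output["query"].get("kind")
    # Wrong kind: hard failure, with a reason surfaced in the eval UI
    if kind != expected_kind:
        return Score(name="query_kind", score=0, metadata={"reason": f"Expected {expected_kind}, got {kind}"})
    return Score(name="query_kind", score=1)


assert grade_query_kind({"query": {"kind": "TrendsQuery"}}, "TrendsQuery").score == 1
assert grade_query_kind(None, "TrendsQuery").score is None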
ee/hogai/eval/scorers/sql.py
Normal file
@@ -0,0 +1,115 @@
from asgiref.sync import sync_to_async
from autoevals.llm import LLMClassifier
from braintrust import Score
from braintrust_core.score import Scorer

from posthog.hogql.errors import BaseHogQLError

from posthog.errors import InternalCHQueryError
from posthog.hogql_queries.hogql_query_runner import HogQLQueryRunner
from posthog.models.team.team import Team


def evaluate_sql_query(name: str, output: str | None, team: Team | None = None) -> Score:
    if not output:
        return Score(name=name, score=None, metadata={"reason": "No SQL query to verify, skipping evaluation"})
    if not team:
        return Score(name=name, score=None, metadata={"reason": "No team provided, skipping evaluation"})
    query = {"query": output}
    try:
        # Try to parse, print, and run the query
        HogQLQueryRunner(query, team).calculate()
    except BaseHogQLError as e:
        return Score(name=name, score=0.0, metadata={"reason": f"HogQL-level error: {str(e)}"})
    except InternalCHQueryError as e:
        return Score(name=name, score=0.5, metadata={"reason": f"ClickHouse-level error: {str(e)}"})
    else:
        return Score(name=name, score=1.0)


class SQLSyntaxCorrectness(Scorer):
    """Evaluate if the generated SQL query has correct syntax."""

    def _name(self):
        return "sql_syntax_correctness"

    async def _run_eval_async(self, output: str | None, team: Team | None = None, **kwargs):
        return await sync_to_async(self._evaluate)(output, team)

    def _run_eval_sync(self, output: str | None, team: Team | None = None, **kwargs):
        return self._evaluate(output, team)

    def _evaluate(self, output: str | None, team: Team | None = None) -> Score:
        return evaluate_sql_query(self._name(), output, team)


SQL_SEMANTICS_CORRECTNESS_PROMPT = """
<system>
You are an expert ClickHouse SQL auditor.
Your job is to decide whether two ClickHouse SQL queries are **semantically equivalent for every possible valid database state**, given the same task description and schema.

When you respond, think step-by-step **internally**, but reveal **nothing** except the final verdict:
• Output **Pass** if the candidate query would always return the same result set (ignoring column aliases, ordering, or trivial formatting) as the reference query.
• Output **Fail** otherwise, or if you are uncertain.
Respond with a single word—**Pass** or **Fail**—and no additional text.
</system>

<input>
Task / natural-language question:
```
{{input}}
```

Database schema (tables and columns):
```
{{database_schema}}
```

Reference (human-labelled) SQL:
```sql
{{expected}}
```

Candidate (generated) SQL:
```sql
{{output}}
```
</input>

<reminder>
Think through edge cases: NULL handling, grouping, filters, joins, HAVING clauses, aggregations, sub-queries, limits, and data-type quirks.
If any logical difference could yield different outputs under some data scenario, the queries are *not* equivalent.
</reminder>

When ready, output your verdict—**Pass** or **Fail**—with absolutely no extra characters.
""".strip()


class SQLSemanticsCorrectness(LLMClassifier):
    """Evaluate if the actual query matches semantically the expected query."""

    def __init__(self, **kwargs):
        super().__init__(
            name="sql_semantics_correctness",
            prompt_template=SQL_SEMANTICS_CORRECTNESS_PROMPT,
            choice_scores={
                "Pass": 1.0,
                "Fail": 0.0,
            },
            model="gpt-4.1",
            **kwargs,
        )

    async def _run_eval_async(
        self, output: str | None, expected: str | None = None, database_schema: str | None = None, **kwargs
    ):
        if not output or output.strip() == "":
            return Score(name=self._name(), score=None, metadata={"reason": "No query to check, skipping evaluation"})
        return await super()._run_eval_async(output, expected, database_schema=database_schema, **kwargs)

    def _run_eval_sync(
        self, output: str | None, expected: str | None = None, database_schema: str | None = None, **kwargs
    ):
        if not output or output.strip() == "":
            return Score(name=self._name(), score=None, metadata={"reason": "No query to check, skipping evaluation"})
        return super()._run_eval_sync(output, expected, database_schema=database_schema, **kwargs)
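The syntax scorer above grades on three tiers: 1.0 when the query parses, prints, and runs; 0.5 when HogQL compiles but ClickHouse rejects the result; 0.0 when HogQL itself rejects the query. A rough usage sketch, assuming a Django test context where `team` is a real posthog `Team` and ClickHouse is reachable:

good = evaluate_sql_query("sql_syntax_correctness", "select event from events limit 1", team)
assert good.score == 1.0  # parses, prints, and runs

bad = evaluate_sql_query("sql_syntax_correctness", "selct event frm events", team)
assert bad.score == 0.0  # rejected at the HogQL layer
print(bad.metadata["reason"])

skipped = evaluate_sql_query("sql_syntax_correctness", None, team)
assert skipped.score is None  # nothing to verify, case is skipped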
ee/hogai/test/test_query_patches.py
Normal file
@@ -0,0 +1,133 @@
from posthog.test.base import BaseTest

from posthog.schema import (
    ActorsPropertyTaxonomyQuery,
    ActorsPropertyTaxonomyResponse,
    EventTaxonomyItem,
    EventTaxonomyQuery,
    TeamTaxonomyItem,
    TeamTaxonomyQuery,
)

from ee.hogai.eval.offline.query_patches import (
    ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE,
    EVENT_TAXONOMY_QUERY_DATA_SOURCE,
    TEAM_TAXONOMY_QUERY_DATA_SOURCE,
    PatchedActorsPropertyTaxonomyQueryRunner,
    PatchedEventTaxonomyQueryRunner,
    PatchedTeamTaxonomyQueryRunner,
)


class TestQueryPatches(BaseTest):
    def setUp(self):
        super().setUp()
        TEAM_TAXONOMY_QUERY_DATA_SOURCE[self.team.id] = [
            TeamTaxonomyItem(count=10, event="$pageview"),
        ]
        EVENT_TAXONOMY_QUERY_DATA_SOURCE[self.team.id] = {
            "$pageview": [
                EventTaxonomyItem(property="$browser", sample_values=["Safari"], sample_count=1),
            ],
        }
        ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE[self.team.id]["person"] = {
            "$location": ActorsPropertyTaxonomyResponse(sample_values=["US"], sample_count=1),
            "$browser": ActorsPropertyTaxonomyResponse(sample_values=["Safari"], sample_count=10),
        }
        ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE[self.team.id][0] = {
            "$device": ActorsPropertyTaxonomyResponse(sample_values=["Phone"], sample_count=1),
        }

    def tearDown(self):
        super().tearDown()
        TEAM_TAXONOMY_QUERY_DATA_SOURCE.clear()
        EVENT_TAXONOMY_QUERY_DATA_SOURCE.clear()
        ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE.clear()

    def test_patched_team_taxonomy_query_runner_returns_result(self):
        query_runner = PatchedTeamTaxonomyQueryRunner(
            team=self.team,
            query=TeamTaxonomyQuery(),
        ).calculate()
        self.assertEqual(query_runner.results, [TeamTaxonomyItem(count=10, event="$pageview")])

    def test_patched_team_taxonomy_query_runner_handles_no_results(self):
        TEAM_TAXONOMY_QUERY_DATA_SOURCE.clear()
        query_runner = PatchedTeamTaxonomyQueryRunner(
            team=self.team,
            query=TeamTaxonomyQuery(),
        ).calculate()
        self.assertEqual(query_runner.results, [])

    def test_patched_event_taxonomy_query_runner_returns_result(self):
        query_runner = PatchedEventTaxonomyQueryRunner(
            team=self.team,
            query=EventTaxonomyQuery(event="$pageview"),
        ).calculate()
        self.assertEqual(
            query_runner.results, [EventTaxonomyItem(property="$browser", sample_values=["Safari"], sample_count=1)]
        )

    def test_patched_event_taxonomy_query_runner_handles_no_results(self):
        EVENT_TAXONOMY_QUERY_DATA_SOURCE.clear()
        query_runner = PatchedEventTaxonomyQueryRunner(
            team=self.team,
            query=EventTaxonomyQuery(event="$pageview"),
        ).calculate()
        self.assertEqual(query_runner.results, [])

    def test_patched_event_taxonomy_query_runner_returns_result_for_action_id(self):
        EVENT_TAXONOMY_QUERY_DATA_SOURCE[self.team.id][123] = [
            EventTaxonomyItem(property="$browser", sample_values=["Safari"], sample_count=1),
        ]
        query_runner = PatchedEventTaxonomyQueryRunner(
            team=self.team,
            query=EventTaxonomyQuery(actionId=123),
        ).calculate()
        self.assertEqual(
            query_runner.results,
            [EventTaxonomyItem(property="$browser", sample_values=["Safari"], sample_count=1)],
        )

    def test_patched_actors_property_taxonomy_query_runner_returns_result(self):
        query_runner = PatchedActorsPropertyTaxonomyQueryRunner(
            team=self.team,
            query=ActorsPropertyTaxonomyQuery(groupTypeIndex=0, properties=["$device"]),
        ).calculate()
        self.assertEqual(
            query_runner.results,
            [ActorsPropertyTaxonomyResponse(sample_values=["Phone"], sample_count=1)],
        )

        query_runner = PatchedActorsPropertyTaxonomyQueryRunner(
            team=self.team,
            query=ActorsPropertyTaxonomyQuery(properties=["$location", "$browser"]),
        ).calculate()
        self.assertEqual(
            query_runner.results,
            [
                ActorsPropertyTaxonomyResponse(sample_values=["US"], sample_count=1),
                ActorsPropertyTaxonomyResponse(sample_values=["Safari"], sample_count=10),
            ],
        )

    def test_patched_actors_property_taxonomy_query_runner_handles_mixed_results(self):
        query_runner = PatchedActorsPropertyTaxonomyQueryRunner(
            team=self.team,
            query=ActorsPropertyTaxonomyQuery(properties=["$location", "$latitude"]),
        ).calculate()
        self.assertEqual(
            query_runner.results,
            [
                ActorsPropertyTaxonomyResponse(sample_values=["US"], sample_count=1),
                ActorsPropertyTaxonomyResponse(sample_values=[], sample_count=0),
            ],
        )

    def test_patched_actors_property_taxonomy_query_runner_handles_no_results(self):
        ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE.clear()
        query_runner = PatchedActorsPropertyTaxonomyQueryRunner(
            team=self.team,
            query=ActorsPropertyTaxonomyQuery(properties=["$device"]),
        ).calculate()
        self.assertEqual(query_runner.results, [ActorsPropertyTaxonomyResponse(sample_values=[], sample_count=0)])
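For orientation, the tests above imply the keying scheme of the in-memory data sources: team taxonomy is keyed by team id; event taxonomy by team id and then event name or action id; actors property taxonomy by team id and then either a group type index or the literal "person". A sketch of that lookup logic under those assumptions (stand-in names, not the PR's actual module):

from collections import defaultdict
from typing import Literal

# Hypothetical stand-in mirroring ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE usage above
ACTORS: defaultdict[int, dict[int | Literal["person"], dict[str, list[str]]]] = defaultdict(dict)

ACTORS[1]["person"] = {"$browser": ["Safari"]}  # person properties live under "person"
ACTORS[1][0] = {"$device": ["Phone"]}           # group properties live under their group type index


def lookup(team_id: int, group_type_index: int | None, prop: str) -> list[str]:
    # None means "person", mirroring _load_actors_property_taxonomy earlier in this PR
    key: int | Literal["person"] = group_type_index if group_type_index is not None else "person"
    return ACTORS[team_id].get(key, {}).get(prop, [])


assert lookup(1, None, "$browser") == ["Safari"]
assert lookup(1, 0, "$device") == ["Phone"]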
ee/hogai/test/test_snapshot_loader.py
Normal file
@@ -0,0 +1,254 @@
from __future__ import annotations

from io import BytesIO
from typing import Any

from posthog.test.base import BaseTest
from unittest.mock import AsyncMock, MagicMock, patch

from asgiref.sync import async_to_sync

from posthog.schema import (
    ActorsPropertyTaxonomyQuery,
    ActorsPropertyTaxonomyResponse,
    EventTaxonomyItem,
    EventTaxonomyQuery,
    TeamTaxonomyItem,
    TeamTaxonomyQuery,
)

from posthog.hogql_queries.query_runner import get_query_runner
from posthog.models import GroupTypeMapping, Organization, PropertyDefinition, Team
from posthog.warehouse.models.table import DataWarehouseTable

from ee.hogai.eval.offline.query_patches import (
    ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE,
    EVENT_TAXONOMY_QUERY_DATA_SOURCE,
    TEAM_TAXONOMY_QUERY_DATA_SOURCE,
)
from ee.hogai.eval.offline.snapshot_loader import SnapshotLoader
from ee.hogai.eval.schema import (
    ActorsPropertyTaxonomySnapshot,
    ClickhouseTeamDataSnapshot,
    DataWarehouseTableSnapshot,
    GroupTypeMappingSnapshot,
    PostgresTeamDataSnapshot,
    PropertyDefinitionSnapshot,
    PropertyTaxonomySnapshot,
    TeamSnapshot,
    TeamTaxonomyItemSnapshot,
)


class TestSnapshotLoader(BaseTest):
    def setUp(self):
        super().setUp()
        TEAM_TAXONOMY_QUERY_DATA_SOURCE.clear()
        EVENT_TAXONOMY_QUERY_DATA_SOURCE.clear()
        ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE.clear()

    def tearDown(self):
        TEAM_TAXONOMY_QUERY_DATA_SOURCE.clear()
        EVENT_TAXONOMY_QUERY_DATA_SOURCE.clear()
        ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE.clear()
        super().tearDown()

    def _extras(self, team_id: int) -> dict[str, Any]:
        return {
            "aws_endpoint_url": "http://localhost:9000",
            "aws_bucket_name": "evals",
            "experiment_name": "offline_evaluation",
            "team_snapshots": [
                {
                    "team_id": team_id,
                    "postgres": PostgresTeamDataSnapshot(
                        team=f"pg/team_{team_id}.avro",
                        property_definitions=f"pg/property_defs_{team_id}.avro",
                        group_type_mappings=f"pg/group_mappings_{team_id}.avro",
                        data_warehouse_tables=f"pg/dw_tables_{team_id}.avro",
                    ).model_dump(),
                    "clickhouse": ClickhouseTeamDataSnapshot(
                        event_taxonomy=f"ch/event_taxonomy_{team_id}.avro",
                        properties_taxonomy=f"ch/properties_taxonomy_{team_id}.avro",
                        actors_property_taxonomy=f"ch/actors_property_taxonomy_{team_id}.avro",
                    ).model_dump(),
                }
            ],
            "dataset": [],
        }

    def _fake_parse(self, schema, _buffer):
        if schema is TeamSnapshot:
            yield TeamSnapshot(name="Evaluated Team", test_account_filters="{}")
            return
        if schema is PropertyDefinitionSnapshot:
            yield PropertyDefinitionSnapshot(
                name="$browser", is_numerical=False, property_type="String", type=1, group_type_index=None
            )
            yield PropertyDefinitionSnapshot(
                name="$device", is_numerical=False, property_type="String", type=1, group_type_index=0
            )
            return
        if schema is GroupTypeMappingSnapshot:
            yield GroupTypeMappingSnapshot(
                group_type="organization", group_type_index=0, name_singular="org", name_plural="orgs"
            )
            yield GroupTypeMappingSnapshot(
                group_type="company", group_type_index=1, name_singular="company", name_plural="companies"
            )
            return
        if schema is DataWarehouseTableSnapshot:
            yield DataWarehouseTableSnapshot(name="users", format="Parquet", columns={"id": "Int64"})
            yield DataWarehouseTableSnapshot(name="orders", format="Parquet", columns={"id": "Int64"})
            return
        if schema is TeamTaxonomyItemSnapshot:
            yield TeamTaxonomyItemSnapshot(
                results=[
                    TeamTaxonomyItem(count=123, event="$pageview"),
                    TeamTaxonomyItem(count=45, event="$autocapture"),
                ]
            )
            return
        if schema is PropertyTaxonomySnapshot:
            # Two events to fill EVENT_TAXONOMY_QUERY_DATA_SOURCE
            yield PropertyTaxonomySnapshot(
                event="$pageview",
                results=[
                    EventTaxonomyItem(property="$browser", sample_values=["Safari"], sample_count=10),
                    EventTaxonomyItem(property="$os", sample_values=["macOS"], sample_count=5),
                ],
            )
            yield PropertyTaxonomySnapshot(
                event="$autocapture",
                results=[
                    EventTaxonomyItem(property="$element_type", sample_values=["a"], sample_count=3),
                ],
            )
            return
        if schema is ActorsPropertyTaxonomySnapshot:
            # person (None), group 0, and group 1
            yield ActorsPropertyTaxonomySnapshot(
                group_type_index=None,
                property="$browser",
                results=ActorsPropertyTaxonomyResponse(sample_values=["Safari"], sample_count=10),
            )
            yield ActorsPropertyTaxonomySnapshot(
                group_type_index=0,
                property="$device",
                results=ActorsPropertyTaxonomyResponse(sample_values=["Phone"], sample_count=2),
            )
            yield ActorsPropertyTaxonomySnapshot(
                group_type_index=1,
                property="$industry",
                results=ActorsPropertyTaxonomyResponse(sample_values=["Tech"], sample_count=1),
            )
            return
        raise AssertionError(f"Unhandled schema in fake parser: {schema}")

    def _build_context(self, extras: dict[str, Any]) -> MagicMock:
        ctx = MagicMock()
        ctx.extras = extras
        ctx.log = MagicMock()
        ctx.log.info = MagicMock()
        return ctx

    def _load_with_mocks(self) -> tuple[Organization, Any, list[Any], Team]:
        extras = self._extras(9990)
        ctx = self._build_context(extras)

        async_get = AsyncMock(side_effect=lambda client, key: BytesIO(b"ok"))

        with patch.object(SnapshotLoader, "_get_snapshot_from_s3", new=async_get):
            with patch.object(SnapshotLoader, "_parse_snapshot_to_schema", new=self._fake_parse):
                loader = SnapshotLoader(ctx)
                org, user, dataset = async_to_sync(loader.load_snapshots)()
        return org, user, dataset, Team.objects.get(id=9990)

    def test_loads_data_from_s3(self):
        team_id = 99990
        extras = self._extras(team_id)
        ctx = self._build_context(extras)

        calls: list[tuple[dict[str, Any], str]] = []

        async def record_call(_self, client, key: str):
            calls.append(({"Bucket": extras["aws_bucket_name"]}, key))
            return BytesIO(b"ok")

        with patch.object(SnapshotLoader, "_get_snapshot_from_s3", new=record_call):
            with patch.object(SnapshotLoader, "_parse_snapshot_to_schema", new=self._fake_parse):
                loader = SnapshotLoader(ctx)
                async_to_sync(loader.load_snapshots)()

        keys = [k for _, k in calls]
        self.assertIn(f"pg/team_{team_id}.avro", keys)
        self.assertIn(f"pg/property_defs_{team_id}.avro", keys)
        self.assertIn(f"pg/group_mappings_{team_id}.avro", keys)
        self.assertIn(f"pg/dw_tables_{team_id}.avro", keys)
        self.assertIn(f"ch/event_taxonomy_{team_id}.avro", keys)
        self.assertIn(f"ch/properties_taxonomy_{team_id}.avro", keys)
        self.assertIn(f"ch/actors_property_taxonomy_{team_id}.avro", keys)

    def test_restores_org_team_user(self):
        org, user, _dataset, team = self._load_with_mocks()
        self.assertEqual(org.name, "PostHog")
        self.assertEqual(team.organization_id, org.id)
        self.assertEqual(team.id, 9990)
        self.assertEqual(team.api_token, "team_9990")

    def test_restores_models_counts(self):
        _org, _user, _dataset, team = self._load_with_mocks()
        self.assertEqual(PropertyDefinition.objects.filter(team_id=team.id).count(), 2)
        self.assertEqual(GroupTypeMapping.objects.filter(team_id=team.id).count(), 2)
        self.assertEqual(DataWarehouseTable.objects.filter(team_id=team.id).count(), 2)

    def test_loads_team_taxonomy_data_source(self):
        _org, _user, _dataset, team = self._load_with_mocks()
        self.assertIn(team.id, TEAM_TAXONOMY_QUERY_DATA_SOURCE)
        self.assertEqual([i.event for i in TEAM_TAXONOMY_QUERY_DATA_SOURCE[team.id]], ["$pageview", "$autocapture"])

    def test_loads_event_taxonomy_data_source(self):
        _org, _user, _dataset, team = self._load_with_mocks()
        self.assertIn("$pageview", EVENT_TAXONOMY_QUERY_DATA_SOURCE[team.id])
        self.assertIn("$autocapture", EVENT_TAXONOMY_QUERY_DATA_SOURCE[team.id])
        props = [i.property for i in EVENT_TAXONOMY_QUERY_DATA_SOURCE[team.id]["$pageview"]]
        self.assertIn("$browser", props)
        self.assertIn("$os", props)

    def test_loads_actors_property_taxonomy_data_source_various_group_types(self):
        _org, _user, _dataset, team = self._load_with_mocks()
        self.assertIn("person", ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE[team.id])
        self.assertIn(0, ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE[team.id])
        self.assertIn(1, ACTORS_PROPERTY_TAXONOMY_QUERY_DATA_SOURCE[team.id])

    def test_runs_query_runner_patches(self):
        _org, _user, _dataset, team = self._load_with_mocks()

        # Team taxonomy
        team_resp = get_query_runner(TeamTaxonomyQuery(), team=team).calculate()
        self.assertTrue(any(item.event == "$pageview" for item in team_resp.results))

        # Event taxonomy by event
        event_resp = get_query_runner(EventTaxonomyQuery(event="$pageview", maxPropertyValues=5), team=team).calculate()
        self.assertTrue(any(item.property == "$browser" for item in event_resp.results))

        # Actors property taxonomy for person
        actors_resp_person = get_query_runner(
            ActorsPropertyTaxonomyQuery(properties=["$browser"], groupTypeIndex=None),
            team=team,
        ).calculate()
        self.assertEqual(actors_resp_person.results[0].sample_values, ["Safari"])

        # Actors property taxonomy for group 0
        actors_resp_group0 = get_query_runner(
            ActorsPropertyTaxonomyQuery(properties=["$device"], groupTypeIndex=0),
            team=team,
        ).calculate()
        self.assertEqual(actors_resp_group0.results[0].sample_values, ["Phone"])

        # Actors property taxonomy for group 1
        actors_resp_group1 = get_query_runner(
            ActorsPropertyTaxonomyQuery(properties=["$industry"], groupTypeIndex=1),
            team=team,
        ).calculate()
        self.assertEqual(actors_resp_group1.results[0].sample_values, ["Tech"])
@@ -76,6 +76,10 @@ OPENAI_API_KEY = get_from_env("OPENAI_API_KEY", "")
INKEEP_API_KEY = get_from_env("INKEEP_API_KEY", "")
MISTRAL_API_KEY = get_from_env("MISTRAL_API_KEY", "")
GEMINI_API_KEY = get_from_env("GEMINI_API_KEY", "")
PPLX_API_KEY = get_from_env("PPLX_API_KEY", "")
AZURE_INFERENCE_ENDPOINT = get_from_env("AZURE_INFERENCE_ENDPOINT", "")
AZURE_INFERENCE_CREDENTIAL = get_from_env("AZURE_INFERENCE_CREDENTIAL", "")
BRAINTRUST_API_KEY = get_from_env("BRAINTRUST_API_KEY", "")

MAILJET_PUBLIC_KEY = get_from_env("MAILJET_PUBLIC_KEY", "", type_cast=str)
MAILJET_SECRET_KEY = get_from_env("MAILJET_SECRET_KEY", "", type_cast=str)
[Binary diff: 12 frontend UI snapshot images updated; sizes effectively unchanged (76–142 KiB).]
@@ -299,20 +299,6 @@ ee/clickhouse/views/test/test_experiment_saved_metrics.py:0: error: Item "None"
ee/clickhouse/views/test/test_experiment_saved_metrics.py:0: error: Item "None" of "Any | None" has no attribute "query" [union-attr]
ee/clickhouse/views/test/test_experiment_saved_metrics.py:0: error: Item "None" of "Any | None" has no attribute "query" [union-attr]
ee/clickhouse/views/test/test_experiment_saved_metrics.py:0: error: Item "None" of "ExperimentToSavedMetric | None" has no attribute "metadata" [union-attr]
ee/hogai/eval/conftest.py:0: error: Incompatible types (expression has type "None", TypedDict item "query" has type "AssistantTrendsQuery | AssistantFunnelsQuery | AssistantRetentionQuery | AssistantHogQLQuery") [typeddict-item]
ee/hogai/eval/conftest.py:0: error: Item "None" of "Any | None" has no attribute "organization" [union-attr]
ee/hogai/eval/eval_memory.py:0: error: "BaseMessage" has no attribute "tool_calls" [attr-defined]
ee/hogai/eval/max_tools/eval_hogql_generator_tool.py:0: error: List item 1 has incompatible type "Callable[[EvalInput, str, str, dict[Any, Any]], Coroutine[Any, Any, Score]]"; expected "AsyncScorerLike[Any, Any] | type[SyncScorerLike[Any, Any]] | type[AsyncScorerLike[Any, Any]] | Callable[[Any, Any, Any], float | int | bool | Score | list[Score] | None] | Callable[[Any, Any, Any], Awaitable[float | int | bool | Score | list[Score] | None]]" [list-item]
ee/hogai/eval/scorers.py:0: error: Incompatible types (expression has type "dict[str, Any] | Any", TypedDict item "query" has type "AssistantTrendsQuery | AssistantFunnelsQuery | AssistantRetentionQuery | AssistantHogQLQuery") [typeddict-item]
ee/hogai/eval/scorers.py:0: error: Incompatible types (expression has type "str | None", TypedDict item "query" has type "AssistantTrendsQuery | AssistantFunnelsQuery | AssistantRetentionQuery | AssistantHogQLQuery") [typeddict-item]
ee/hogai/eval/scorers.py:0: error: Incompatible types (expression has type "str | None", TypedDict item "query" has type "AssistantTrendsQuery | AssistantFunnelsQuery | AssistantRetentionQuery | AssistantHogQLQuery") [typeddict-item]
ee/hogai/eval/scorers.py:0: error: Incompatible types (expression has type "str | None", TypedDict item "query" has type "AssistantTrendsQuery | AssistantFunnelsQuery | AssistantRetentionQuery | AssistantHogQLQuery") [typeddict-item]
ee/hogai/eval/scorers.py:0: error: Incompatible types (expression has type "str | None", TypedDict item "query" has type "AssistantTrendsQuery | AssistantFunnelsQuery | AssistantRetentionQuery | AssistantHogQLQuery") [typeddict-item]
ee/hogai/eval/scorers.py:0: error: Item "None" of "AssistantTrendsQuery | AssistantFunnelsQuery | AssistantRetentionQuery | AssistantHogQLQuery | Any | None" has no attribute "model_dump" [union-attr]
ee/hogai/eval/scorers.py:0: error: Item "ToolCall" of "AssistantToolCall | ToolCall" has no attribute "args" [union-attr]
ee/hogai/eval/scorers.py:0: error: Item "ToolCall" of "AssistantToolCall | ToolCall" has no attribute "args" [union-attr]
ee/hogai/eval/scorers.py:0: error: Item "ToolCall" of "AssistantToolCall | ToolCall" has no attribute "args" [union-attr]
ee/hogai/eval/scorers.py:0: error: Item "ToolCall" of "AssistantToolCall | ToolCall" has no attribute "name" [union-attr]
ee/hogai/graph/billing/test/test_nodes.py:0: error: Item "None" of "Any | None" has no attribute "data" [union-attr]
ee/hogai/graph/billing/test/test_nodes.py:0: error: Item "None" of "Any | None" has no attribute "dates" [union-attr]
ee/hogai/graph/billing/test/test_nodes.py:0: error: Item "None" of "Any | None" has no attribute "dates" [union-attr]
@@ -1,6 +1,6 @@
from django.db import models

from posthog.models.utils import UUIDModel, BytecodeModelMixin
from posthog.models.utils import BytecodeModelMixin, UUIDModel


class GroupUsageMetric(UUIDModel, BytecodeModelMixin):
@@ -1,6 +1,7 @@
from posthog.models import GroupUsageMetric
from posthog.test.base import BaseTest

from posthog.models import GroupUsageMetric


class GroupUsageMetricTestCase(BaseTest):
    def test_bytecode_generation(self):
@@ -17,9 +17,10 @@ from django.db.models import Q, UniqueConstraint
from django.db.models.constraints import BaseConstraint
from django.utils.text import slugify

from posthog.constants import MAX_SLUG_LENGTH
from posthog.hogql import ast

from posthog.constants import MAX_SLUG_LENGTH

if TYPE_CHECKING:
    from random import Random
@@ -23,6 +23,8 @@ dependencies = [
    "dagster-cloud==1.10.18",
    "dagster-aws==0.26.18",
    "dagster-celery==0.26.18",
    "dagster-docker==0.26.18",
    "dagster-pipes==1.10.18",
    "dagster-postgres==0.26.18",
    "dagster-slack==0.26.18",
    "dagster-webserver==1.10.18",

@@ -168,8 +170,9 @@ dev = [
    "autoevals==0.0.129",
    "black~=23.9.1",
    "boto3-stubs[s3]>=1.34.84",
    "braintrust==0.2.0",
    "braintrust-langchain==0.0.2",
    "braintrust==0.2.4",
    "braintrust-langchain==0.0.4",
    "dagster-dg-cli>=1.10.18",
    "datamodel-code-generator==0.28.5",
    "deepdiff>=8.5.0",
    "django-linear-migrations==2.16.*",

@@ -207,7 +210,7 @@ dev = [
    "responses==0.23.1",
    "ruff~=0.8.1",
    "sqlalchemy==2.0.38",
    "stpyv8==13.1.201.22",
    "stpyv8==13.1.201.22; sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64')",
    "syrupy~=4.6.0",
    "tach~=0.20.0",
    "types-aioboto3[s3]>=14.3.0",
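The `stpyv8` pin above now carries a PEP 508 environment marker: the dependency is skipped only on Linux arm64 (presumably because no prebuilt wheel exists there) and installed everywhere else. A quick way to sanity-check marker logic with the `packaging` library (the environments below are illustrative):

from packaging.markers import Marker

marker = Marker("sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64')")

assert marker.evaluate({"sys_platform": "darwin", "platform_machine": "arm64"})       # macOS arm64: install
assert marker.evaluate({"sys_platform": "linux", "platform_machine": "x86_64"})       # Linux x86_64: install
assert not marker.evaluate({"sys_platform": "linux", "platform_machine": "aarch64"})  # Linux aarch64: skip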
149
uv.lock
generated
@@ -520,7 +520,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "braintrust"
|
||||
version = "0.2.0"
|
||||
version = "0.2.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "chevron" },
|
||||
@@ -533,9 +533,9 @@ dependencies = [
|
||||
{ name = "tqdm" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ca/2d/c1edc49d8ea3548552a844251de01946e566621b2f6a42af284a058c360d/braintrust-0.2.0.tar.gz", hash = "sha256:fd45665b9521276d581c12a947bf9033a51c5365c986454df6332e01729da270", size = 145441, upload-time = "2025-07-25T05:34:44.555Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/5a/c3/01db8b072d56b3bcebe443af7acc57466c87bf2f49612a444ea707f76a90/braintrust-0.2.4.tar.gz", hash = "sha256:b1866a646d2368bd4c2ccba305e5c1e1d97c7f8c715833d3fc69810644fdc66f", size = 159340, upload-time = "2025-08-20T21:15:18.8Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/60/83fb8bcc84f7fe408b7373c520187e777859c79bc6aa9ad5d1eb7d4b4c8a/braintrust-0.2.0-py3-none-any.whl", hash = "sha256:dffc74fe85f9c855461712210a852c543d8573767544403b3f73fbc0e645f5a2", size = 171651, upload-time = "2025-07-25T05:34:43.249Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/31/32/4a48037e9956592f3a7cfae442ecf23c4ea6884950932126561916c089b4/braintrust-0.2.4-py3-none-any.whl", hash = "sha256:7dc30f32ec5e909b0b35e75ee6596090bdc15e4ed9f4d82eda52103b1aaaf721", size = 185252, upload-time = "2025-08-20T21:15:17.137Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -549,15 +549,15 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "braintrust-langchain"
|
||||
version = "0.0.2"
|
||||
version = "0.0.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "braintrust" },
|
||||
{ name = "langchain" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/77/1d/9f1def0527de67191c41fa3941a8cc72b187c6fd52c1f8b7a9b24fc6240a/braintrust_langchain-0.0.2.tar.gz", hash = "sha256:f5af5ebc600cd45fd67a4ba4b6e0b99248eb6f6caea79f521d0bee1f045ac4b0", size = 42543, upload-time = "2025-03-10T21:24:25.926Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7a/28/c302623e8d1ad293b2a7029dd7a2de3feaa50a486335f6d37a577b7453eb/braintrust_langchain-0.0.4.tar.gz", hash = "sha256:5ee372f97857f436d3f81d19a9600d4a741e570b6bf80a714caf424406713cfb", size = 43081, upload-time = "2025-08-06T19:40:11.861Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/3a/7e/235c99c8d934b922b8ffd9cd2220f21e14fabb4cc126eb2e3a914baa54f1/braintrust_langchain-0.0.2-py3-none-any.whl", hash = "sha256:8a5cab14b84f375e43500eb5feae64d5311683a268a3d7c4c70e592dfdbaa444", size = 43040, upload-time = "2025-03-10T21:24:24.421Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e5/f3/aaf624f046902454c0ebf84e8426864be9587938d8ef85e4484502c444ed/braintrust_langchain-0.0.4-py3-none-any.whl", hash = "sha256:c9bb9cfbcb938f71b0344821fdb8e69091e207d3691c0fead8e780088798bbc6", size = 43492, upload-time = "2025-08-06T19:40:10.456Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -748,6 +748,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188, upload-time = "2024-12-21T18:38:41.666Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "click-aliases"
|
||||
version = "1.0.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "click" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/40/d8/adbaeadc13c9686b9bda8b4c50e5a3983f504faae2ffbea5165d5beb1cdb/click_aliases-1.0.5.tar.gz", hash = "sha256:e37d4cabbaad68e1c48ec0f063a59dfa15f0e7450ec901bd1ce4f4b954bc881d", size = 3105, upload-time = "2024-10-17T15:44:19.056Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/3d/1a/d5e29a6f896293e32ab3e63201df5d396599e57a726575adaafbcd9d70a6/click_aliases-1.0.5-py3-none-any.whl", hash = "sha256:cbb83a348acc00809fe18b6da13a7f6307bc71b3c5f69cc730e012dfb4bbfdc3", size = 3524, upload-time = "2024-10-17T15:44:17.389Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "click-didyoumean"
|
||||
version = "0.3.1"
|
||||
@@ -1092,6 +1104,61 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e5/14/d8e23457fbabd97c5e71852864a423cdef983781daba360e70631987a917/dagster_cloud_cli-1.10.18-py3-none-any.whl", hash = "sha256:ad513f6ca11b7c79703b2eaa4989f90395db0d0a439679cccebd794a336f6199", size = 107883, upload-time = "2025-05-29T21:36:21.843Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dagster-dg-cli"
|
||||
version = "1.10.18"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "dagster-dg-core" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0b/f3/dcdeca179a69fad1cc75054725e72ddab4ebd3c095b2af6a63cf18228fe0/dagster_dg_cli-1.10.18.tar.gz", hash = "sha256:666d64f448b707ef557d855b94e2431b6f9ce86e35e09b877be16ae164cd678b", size = 539236, upload-time = "2025-05-29T21:52:30.092Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8b/4f/0e452fa2846a41ba6db0fa3a3772b9fde40ea0c342ddf17f9a27eb7e2b69/dagster_dg_cli-1.10.18-py3-none-any.whl", hash = "sha256:949043b72d9ce5af0eee2499c606e132b8655db51dc2b78f8e7b2e380f475109", size = 569884, upload-time = "2025-05-29T21:52:28.111Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dagster-dg-core"
|
||||
version = "1.10.18"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "click" },
|
||||
{ name = "click-aliases" },
|
||||
{ name = "dagster-cloud-cli" },
|
||||
{ name = "dagster-shared" },
|
||||
{ name = "gql", extra = ["requests"] },
|
||||
{ name = "jinja2" },
|
||||
{ name = "jsonschema" },
|
||||
{ name = "markdown" },
|
||||
{ name = "packaging" },
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "rich" },
|
||||
{ name = "setuptools" },
|
||||
{ name = "tomlkit" },
|
||||
{ name = "typer" },
|
||||
{ name = "typing-extensions" },
|
||||
{ name = "watchdog" },
|
||||
{ name = "yaspin" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d2/a8/3965d6d769d541bddd545e5624378c853539a7aa7582479e78d5b516bf30/dagster_dg_core-1.10.18.tar.gz", hash = "sha256:4e988e30f1d390ccf4a1f0ec15d571a2482b14e18c4089bce515ef07c533cff0", size = 43181, upload-time = "2025-05-29T21:36:04.451Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e6/58/59cc3dce6fa82cc62da6e6b876d8decc549255995a7394736d707543c1b7/dagster_dg_core-1.10.18-py3-none-any.whl", hash = "sha256:3299badb16949405033b8690796269112b4243f90766410edd0098c638c556ec", size = 49977, upload-time = "2025-05-29T21:36:03.419Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dagster-docker"
|
||||
version = "0.26.18"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "dagster" },
|
||||
{ name = "docker" },
|
||||
{ name = "docker-image-py" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/fd/87/b091ac095ea590f581199c060058d01afeb6aea8de5cd78deb3937db908b/dagster_docker-0.26.18.tar.gz", hash = "sha256:7aa7625b7271c67eacda26a06d384134bdb2f34e548f4c7df74bd8fb085298b7", size = 16129, upload-time = "2025-05-29T21:37:41.881Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7f/9e/2ab929c7103c63df6ebaef2bd1d16390173a842cdf7f60bd991655446f1c/dagster_docker-0.26.18-py3-none-any.whl", hash = "sha256:d62f07b8aa1e03f990cad413917a0c75515fc80a0c949c34e431c7465c1f1a8f", size = 19580, upload-time = "2025-05-29T21:37:40.775Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dagster-graphql"
|
||||
version = "1.10.18"
|
||||
@@ -1676,6 +1743,32 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/9b/ed/28fb14146c7033ba0e89decd92a4fa16b0b69b84471e2deab3cc4337cc35/dnspython-2.2.1-py3-none-any.whl", hash = "sha256:a851e51367fb93e9e1361732c1d60dab63eff98712e503ea7d92e6eccb109b4f", size = 269084, upload-time = "2022-03-06T23:36:12.209Z" },
]

[[package]]
name = "docker"
version = "7.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "pywin32", marker = "sys_platform == 'win32'" },
    { name = "requests" },
    { name = "urllib3" },
]
sdist = { url = "https://files.pythonhosted.org/packages/91/9b/4a2ea29aeba62471211598dac5d96825bb49348fa07e906ea930394a83ce/docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c", size = 117834, upload-time = "2024-05-23T11:13:57.216Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774, upload-time = "2024-05-23T11:13:55.01Z" },
]

[[package]]
name = "docker-image-py"
version = "0.1.13"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "regex" },
]
sdist = { url = "https://files.pythonhosted.org/packages/2f/10/28dad68f4693a5131ff859ac9dc8eda7967fd922509c1e667e143b0760b8/docker_image_py-0.1.13.tar.gz", hash = "sha256:de2755de0a09c99ae3b4cf42cc470a83b35bfdc4bf8ad66a3d7550d622a8915a", size = 11211, upload-time = "2024-07-17T05:36:04.298Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/39/06/c8d170aeb3e9eb3d951dd37acf1b6bad2b9401d0bee3346e40295c9e15a2/docker_image_py-0.1.13-py3-none-any.whl", hash = "sha256:c217fc72e8cdf2aa2caa718c758cec1ad41c09972ec1c94e85a4bf79a8f81061", size = 8894, upload-time = "2024-07-17T05:36:02.563Z" },
]

[[package]]
name = "docopt"
version = "0.6.2"
@@ -3015,6 +3108,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" },
]

[[package]]
name = "markdown"
version = "3.8.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d7/c2/4ab49206c17f75cb08d6311171f2d65798988db4360c4d1485bd0eedd67c/markdown-3.8.2.tar.gz", hash = "sha256:247b9a70dd12e27f67431ce62523e675b866d254f900c4fe75ce3dda62237c45", size = 362071, upload-time = "2025-06-19T17:12:44.483Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/96/2b/34cc11786bc00d0f04d0f5fdc3a2b1ae0b6239eef72d3d345805f9ad92a1/markdown-3.8.2-py3-none-any.whl", hash = "sha256:5c83764dbd4e00bdd94d85a19b8d55ccca20fe35b2e678a1422b380324dd5f24", size = 106827, upload-time = "2025-06-19T17:12:42.994Z" },
]

[[package]]
name = "markdown-it-py"
version = "3.0.0"
@@ -3929,6 +4031,8 @@ dependencies = [
    { name = "dagster-aws" },
    { name = "dagster-celery" },
    { name = "dagster-cloud" },
    { name = "dagster-docker" },
    { name = "dagster-pipes" },
    { name = "dagster-postgres" },
    { name = "dagster-slack" },
    { name = "dagster-webserver" },
@@ -4071,6 +4175,7 @@ dev = [
    { name = "boto3-stubs", extra = ["s3"] },
    { name = "braintrust" },
    { name = "braintrust-langchain" },
    { name = "dagster-dg-cli" },
    { name = "datamodel-code-generator" },
    { name = "deepdiff" },
    { name = "django-linear-migrations" },
@@ -4107,7 +4212,7 @@ dev = [
    { name = "responses" },
    { name = "ruff" },
    { name = "sqlalchemy" },
    { name = "stpyv8" },
    { name = "stpyv8", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64') or sys_platform != 'linux'" },
    { name = "syrupy" },
    { name = "tach" },
    { name = "types-aioboto3", extra = ["s3"] },
@@ -4153,6 +4258,8 @@ requires-dist = [
    { name = "dagster-aws", specifier = "==0.26.18" },
    { name = "dagster-celery", specifier = "==0.26.18" },
    { name = "dagster-cloud", specifier = "==1.10.18" },
    { name = "dagster-docker", specifier = "==0.26.18" },
    { name = "dagster-pipes", specifier = "==1.10.18" },
    { name = "dagster-postgres", specifier = "==0.26.18" },
    { name = "dagster-slack", specifier = "==0.26.18" },
    { name = "dagster-webserver", specifier = "==1.10.18" },
@@ -4293,8 +4400,9 @@ dev = [
    { name = "autoevals", specifier = "==0.0.129" },
    { name = "black", specifier = "~=23.9.1" },
    { name = "boto3-stubs", extras = ["s3"], specifier = ">=1.34.84" },
    { name = "braintrust", specifier = "==0.2.0" },
    { name = "braintrust-langchain", specifier = "==0.0.2" },
    { name = "braintrust", specifier = "==0.2.4" },
    { name = "braintrust-langchain", specifier = "==0.0.4" },
    { name = "dagster-dg-cli", specifier = ">=1.10.18" },
    { name = "datamodel-code-generator", specifier = "==0.28.5" },
    { name = "deepdiff", specifier = ">=8.5.0" },
    { name = "django-linear-migrations", specifier = "==2.16.*" },
@@ -4331,7 +4439,7 @@ dev = [
    { name = "responses", specifier = "==0.23.1" },
    { name = "ruff", specifier = "~=0.8.1" },
    { name = "sqlalchemy", specifier = "==2.0.38" },
    { name = "stpyv8", specifier = "==13.1.201.22" },
    { name = "stpyv8", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64') or sys_platform != 'linux'", specifier = "==13.1.201.22" },
    { name = "syrupy", specifier = "~=4.6.0" },
    { name = "tach", specifier = "~=0.20.0" },
    { name = "types-aioboto3", extras = ["s3"], specifier = ">=14.3.0" },
@@ -5930,6 +6038,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138", size = 28248, upload-time = "2025-04-02T08:25:07.678Z" },
]

[[package]]
name = "termcolor"
version = "2.3.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b8/85/147a0529b4e80b6b9d021ca8db3a820fcac53ec7374b87073d004aaf444c/termcolor-2.3.0.tar.gz", hash = "sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a", size = 12163, upload-time = "2023-04-23T19:45:24.004Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/67/e1/434566ffce04448192369c1a282931cf4ae593e91907558eaecd2e9f2801/termcolor-2.3.0-py3-none-any.whl", hash = "sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475", size = 6872, upload-time = "2023-04-23T19:45:22.671Z" },
]

[[package]]
name = "text-unidecode"
version = "1.3"
@@ -6920,6 +7037,18 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/ea/1f/70c57b3d7278e94ed22d85e09685d3f0a38ebdd8c5c73b65ba4c0d0fe002/yarl-1.20.0-py3-none-any.whl", hash = "sha256:5d0fe6af927a47a230f31e6004621fd0959eaa915fc62acfafa67ff7229a3124", size = 46124, upload-time = "2025-04-17T00:45:12.199Z" },
]

[[package]]
name = "yaspin"
version = "3.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "termcolor" },
]
sdist = { url = "https://files.pythonhosted.org/packages/07/3c/70df5034e6712fcc238b76f6afd1871de143a2a124d80ae2c377cde180f3/yaspin-3.1.0.tar.gz", hash = "sha256:7b97c7e257ec598f98cef9878e038bfa619ceb54ac31d61d8ead2b3128f8d7c7", size = 36791, upload-time = "2024-09-22T17:07:09.376Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/89/78/fa25b385d9f2c406719b5cf574a0980f5ccc6ea1f8411d56249f44acd3c2/yaspin-3.1.0-py3-none-any.whl", hash = "sha256:5e3d4dfb547d942cae6565718123f1ecfa93e745b7e51871ad2bbae839e71b73", size = 18629, upload-time = "2024-09-22T17:07:06.923Z" },
]

[[package]]
name = "zeep"
version = "4.3.1"