Merge branch 'master' into master

Authored by BillyOutlast on 2025-11-18 09:29:18 -05:00; committed by GitHub.

1495 changed files with 65672 additions and 25351 deletions.


@@ -1 +1,10 @@
# Local dev only, see dags/README.md for more info
concurrency:
runs:
tag_concurrency_limits:
# See dags/sessions.py
- key: 'sessions_backfill_concurrency'
limit: 3
value:
applyLimitPerUniqueValue: true
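For context, here is a minimal sketch of how a run opts into this limit. The job below is hypothetical; the real definitions live in `dags/sessions.py`. With `applyLimitPerUniqueValue: true`, Dagster caps concurrency at 3 runs per distinct value of the `sessions_backfill_concurrency` tag.

```python
import dagster


@dagster.op
def backfill_sessions_chunk():
    """Placeholder op standing in for the actual backfill work."""


# Hypothetical job: every run of it carries the concurrency tag, so runs sharing
# the same tag value are capped at 3 concurrent. Distinct values each get their
# own pool of 3, per applyLimitPerUniqueValue: true above.
@dagster.job(tags={"sessions_backfill_concurrency": "default"})
def sessions_backfill():
    backfill_sessions_chunk()
```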


@@ -16,5 +16,7 @@ load_from:
- python_module: dags.locations.analytics_platform
- python_module: dags.locations.experiments
- python_module: dags.locations.growth
- python_module: dags.locations.llma
- python_module: dags.locations.max_ai
- python_module: dags.locations.web_analytics
- python_module: dags.locations.ingestion


@@ -45,6 +45,7 @@ cmake = { pkg-path = "cmake", version = "3.31.5", pkg-group = "cmake" }
sqlx-cli = { pkg-path = "sqlx-cli", version = "0.8.3" } # sqlx
postgresql = { pkg-path = "postgresql_14" } # psql
ffmpeg.pkg-path = "ffmpeg"
ngrok = { pkg-path = "ngrok" }
# Set environment variables in the `[vars]` section. These variables may not
# reference one another, and are added to the environment without first


@@ -21,7 +21,10 @@ inputs:
description: 'Repository name (owner/repo)'
required: true
commit-sha:
description: 'Commit SHA'
description: 'Commit SHA that was tested'
required: true
branch-name:
description: 'Branch name'
required: true
github-token:
description: 'GitHub token for authentication'
@@ -66,9 +69,48 @@ runs:
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
- name: Verify branch hasn't advanced
id: verify
shell: bash
run: |
# Check if the remote branch HEAD still matches the commit we tested
TESTED_SHA="${{ inputs.commit-sha }}"
CURRENT_SHA=$(git ls-remote origin "refs/heads/${{ inputs.branch-name }}" | cut -f1)
if [ "$TESTED_SHA" != "$CURRENT_SHA" ]; then
echo "branch-advanced=true" >> $GITHUB_OUTPUT
echo "current-sha=$CURRENT_SHA" >> $GITHUB_OUTPUT
echo "⚠️ Branch has advanced during workflow execution"
echo " Tested: $TESTED_SHA"
echo " Current: $CURRENT_SHA"
echo " Skipping snapshot commit - new workflow run will handle it"
else
echo "branch-advanced=false" >> $GITHUB_OUTPUT
echo "✓ Branch HEAD matches tested commit - proceeding with snapshot commit"
fi
- name: Post skip comment
if: steps.verify.outputs.branch-advanced == 'true'
shell: bash
env:
GH_TOKEN: ${{ inputs.github-token }}
run: |
TESTED_SHA="${{ inputs.commit-sha }}"
CURRENT_SHA="${{ steps.verify.outputs.current-sha }}"
gh pr comment ${{ inputs.pr-number }} --body "⏭️ Skipped snapshot commit because branch advanced to \`${CURRENT_SHA:0:7}\` while workflow was testing \`${TESTED_SHA:0:7}\`.
The new commit will trigger its own snapshot update workflow.
**If you expected this workflow to succeed:** This can happen due to concurrent commits. To get a fresh workflow run, either:
- Merge master into your branch, or
- Push an empty commit: \`git commit --allow-empty -m 'trigger CI' && git push\`"
- name: Count and commit changes
id: commit
if: steps.verify.outputs.branch-advanced == 'false'
shell: bash
env:
GH_TOKEN: ${{ inputs.github-token }}
run: |
CHANGES_JSON=$(.github/scripts/count-snapshot-changes.sh ${{ inputs.snapshot-path }})
echo "changes=$CHANGES_JSON" >> $GITHUB_OUTPUT
@@ -76,14 +118,17 @@ runs:
TOTAL=$(echo "$CHANGES_JSON" | jq -r '.total')
if [ "$TOTAL" -gt 0 ]; then
# Disable auto-merge BEFORE committing
# This ensures auto-merge is disabled even if workflow gets cancelled after push
gh pr merge --disable-auto ${{ inputs.pr-number }} || echo "Auto-merge was not enabled"
echo "Auto-merge disabled - snapshot changes require human review"
# Now commit and push
git add ${{ inputs.snapshot-path }} -A
git commit -m "${{ inputs.commit-message }}"
# Pull before push to avoid race condition when multiple workflows commit simultaneously
# Only fetch current branch to avoid verbose output from fetching all branches
CURRENT_BRANCH=$(git rev-parse --abbrev-ref HEAD)
git fetch origin "$CURRENT_BRANCH" --quiet
git rebase origin/"$CURRENT_BRANCH" --autostash
git push --quiet
# Push directly - SHA verification already ensured we're up to date
BRANCH_NAME="${{ inputs.branch-name }}"
git push origin HEAD:"$BRANCH_NAME" --quiet
SNAPSHOT_SHA=$(git rev-parse HEAD)
echo "committed=true" >> $GITHUB_OUTPUT
echo "snapshot-sha=$SNAPSHOT_SHA" >> $GITHUB_OUTPUT
@@ -92,15 +137,6 @@ runs:
echo "committed=false" >> $GITHUB_OUTPUT
fi
- name: Disable auto-merge
if: steps.commit.outputs.committed == 'true'
shell: bash
env:
GH_TOKEN: ${{ inputs.github-token }}
run: |
gh pr merge --disable-auto ${{ inputs.pr-number }} || echo "Auto-merge was not enabled"
echo "Snapshot changes require human review"
- name: Post snapshot comment
if: steps.commit.outputs.committed == 'true'
shell: bash


@@ -233,7 +233,7 @@ runs:
- name: Upload updated timing data as artifacts
uses: actions/upload-artifact@v4
if: ${{ inputs.person-on-events != 'true' && inputs.clickhouse-server-image == 'clickhouse/clickhouse-server:25.6.9.98' }}
if: ${{ inputs.person-on-events != 'true' && inputs.clickhouse-server-image == 'clickhouse/clickhouse-server:25.8.11.66' }}
with:
name: timing_data-${{ inputs.segment }}-${{ inputs.group }}
path: .test_durations

.github/scripts/post-eval-summary.js (new file, 133 lines)

@@ -0,0 +1,133 @@
// Export for use with actions/github-script
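// Score diffs larger than ±2 percentage points are treated as significant (regression or improvement)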
const DIFF_THRESHOLD = 0.02
module.exports = ({ github, context, fs }) => {
// Read the eval results
const evalResults = fs
.readFileSync('eval_results.jsonl', 'utf8')
.trim()
.split('\n')
.filter((line) => line.trim().length > 0)
.map((line) => JSON.parse(line))
if (evalResults.length === 0) {
console.warn('No eval results found')
return
}
// Calculate max diff for each experiment and categorize
const experimentsWithMaxDiff = evalResults.map((result) => {
const scores = result.scores || {}
const diffs = Object.values(scores)
.map((v) => v.diff)
.filter((d) => typeof d === 'number')
const minDiff = diffs.length > 0 ? Math.min(...diffs) : 0
const maxDiff = diffs.length > 0 ? Math.max(...diffs) : 0
const category = minDiff < -DIFF_THRESHOLD ? 'regression' : maxDiff > DIFF_THRESHOLD ? 'improvement' : 'neutral'
return {
result,
category,
maxDiffInCategoryAbs:
category === 'regression'
? -minDiff
: category === 'improvement'
? maxDiff
: Math.max(Math.abs(minDiff), Math.abs(maxDiff)),
}
})
// Sort: regressions first (most negative to least), then improvements (most positive to least), then neutral
experimentsWithMaxDiff.sort((a, b) => {
if (a.category === b.category) {
return b.maxDiffInCategoryAbs - a.maxDiffInCategoryAbs
}
const order = { regression: 0, improvement: 1, neutral: 2 }
return order[a.category] - order[b.category]
})
// Generate experiment summaries
const experimentSummaries = experimentsWithMaxDiff.map(({ result }) => {
// Format scores as bullet points with improvements/regressions and baseline comparison
const scoresList = Object.entries(result.scores || {})
.map(([key, value]) => {
const score = typeof value.score === 'number' ? `${(value.score * 100).toFixed(2)}%` : value.score
let baselineComparison = null
const diffHighlight = Math.abs(value.diff) > DIFF_THRESHOLD ? '**' : ''
let diffEmoji = '🆕'
if (result.comparison_experiment_name?.startsWith('master-')) {
baselineComparison = `${diffHighlight}${value.diff > 0 ? '+' : value.diff < 0 ? '' : '±'}${(
value.diff * 100
).toFixed(
2
)}%${diffHighlight} (improvements: ${value.improvements}, regressions: ${value.regressions})`
diffEmoji = value.diff > DIFF_THRESHOLD ? '🟢' : value.diff < -DIFF_THRESHOLD ? '🔴' : '🔵'
}
return `${diffEmoji} **${key}**: **${score}**${baselineComparison ? `, ${baselineComparison}` : ''}`
})
.join('\n')
// Format key metrics concisely
const metrics = result.metrics || {}
const duration = metrics.duration ? `⏱️ ${metrics.duration.metric.toFixed(2)} s` : null
const totalTokens = metrics.total_tokens ? `🔢 ${Math.floor(metrics.total_tokens.metric)} tokens` : null
const cost = metrics.estimated_cost ? `💵 $${metrics.estimated_cost.metric.toFixed(4)} in tokens` : null
const metricsText = [duration, totalTokens, cost].filter(Boolean).join(', ')
const baselineLink = `[${result.comparison_experiment_name}](${result.project_url}/experiments/${result.comparison_experiment_name})`
// Create concise experiment summary with header only showing experiment name
const experimentName = result.project_name.replace(/^max-ai-/, '')
return [
`### [${experimentName}](${result.experiment_url})`,
scoresList,
`Baseline: ${baselineLink} • Avg. case performance: ${metricsText}`,
].join('\n\n')
})
// Split summaries by category
const regressions = []
const improvements = []
const neutral = []
experimentsWithMaxDiff.forEach(({ category }, idx) => {
if (category === 'regression') regressions.push(experimentSummaries[idx])
else if (category === 'improvement') improvements.push(experimentSummaries[idx])
else neutral.push(experimentSummaries[idx])
})
const totalExperiments = evalResults.length
const totalMetrics = evalResults.reduce((acc, result) => acc + Object.keys(result.scores || {}).length, 0)
const bodyParts = [
`## 🧠 AI eval results`,
`Evaluated **${totalExperiments}** experiments, comprising **${totalMetrics}** metrics. Showing experiments with largest regressions first.`,
]
bodyParts.push(...regressions)
bodyParts.push(...improvements)
if (neutral.length > 0) {
bodyParts.push(
`<details><summary>${neutral.length} ${neutral.length === 1 ? 'experiment' : 'experiments'} with no significant changes</summary>\n\n${neutral.join('\n\n')}\n\n</details>`
)
}
bodyParts.push(
`_Triggered by [this commit](https://github.com/${context.repo.owner}/${context.repo.repo}/pull/${context.payload.pull_request.number}/commits/${context.payload.pull_request.head.sha})._`
)
const body = bodyParts.join('\n\n')
// Post comment on PR
if (context.payload.pull_request) {
github.rest.issues.createComment({
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
body: body,
})
} else {
// Just log the summary if this is a push to master
console.info(body)
}
}


@@ -80,80 +80,10 @@ jobs:
- name: Post eval summary to PR
# always() because we want to post even if `pytest` exited with an error (likely just one eval suite errored)
if: always() && github.event_name == 'pull_request'
uses: actions/github-script@v6
uses: actions/github-script@v8
with:
github-token: ${{ secrets.POSTHOG_BOT_PAT }}
script: |
const fs = require("fs")
// Read the eval results
const evalResults = fs
.readFileSync("eval_results.jsonl", "utf8")
.trim()
.split("\n")
.map((line) => JSON.parse(line))
if (evalResults.length === 0) {
console.log("No eval results found")
return
}
// Generate concise experiment summaries
const experimentSummaries = evalResults.map((result) => {
// Format scores as bullet points with improvements/regressions and baseline comparison
const scoresList = Object.entries(result.scores || {})
.map(([key, value]) => {
const score = typeof value.score === "number" ? `${(value.score * 100).toFixed(2)}%` : value.score
let baselineComparison = null
const diffHighlight = Math.abs(value.diff) > 0.01 ? "**" : ""
let diffEmoji = "🆕"
if (result.comparison_experiment_name?.startsWith("master-")) {
baselineComparison = `${diffHighlight}${value.diff > 0 ? "+" : value.diff < 0 ? "" : "±"}${(
value.diff * 100
).toFixed(2)}%${diffHighlight} (improvements: ${value.improvements}, regressions: ${value.regressions})`
diffEmoji = value.diff > 0.01 ? "🟢" : value.diff < -0.01 ? "🔴" : "🔵"
}
return `${diffEmoji} **${key}**: **${score}**${baselineComparison ? `, ${baselineComparison}` : ""}`
})
.join("\n")
// Format key metrics concisely
const metrics = result.metrics || {}
const duration = metrics.duration ? `⏱️ ${metrics.duration.metric.toFixed(2)} s` : null
const totalTokens = metrics.total_tokens ? `🔢 ${Math.floor(metrics.total_tokens.metric)} tokens` : null
const cost = metrics.estimated_cost ? `💵 $${metrics.estimated_cost.metric.toFixed(4)} in tokens` : null
const metricsText = [duration, totalTokens, cost].filter(Boolean).join(", ")
const baselineLink = `[${result.comparison_experiment_name}](${result.project_url}/experiments/${result.comparison_experiment_name})`
// Create concise experiment summary with header only showing experiment name
const experimentName = result.project_name.replace(/^max-ai-/, "")
return [
`### [${experimentName}](${result.experiment_url})`,
scoresList,
`Baseline: ${baselineLink} • Avg. case performance: ${metricsText}`,
].join("\n\n")
})
const totalExperiments = evalResults.length
const totalMetrics = evalResults.reduce((acc, result) => acc + Object.keys(result.scores || {}).length, 0)
const body = [
`## 🧠 AI eval results`,
`Evaluated **${totalExperiments}** experiments, comprising **${totalMetrics}** metrics.`,
...experimentSummaries,
`_Triggered by [this commit](https://github.com/${context.repo.owner}/${context.repo.repo}/pull/${context.payload.pull_request.number}/commits/${context.payload.pull_request.head.sha})._`,
].join("\n\n")
// Post comment on PR
if (context.payload.pull_request) {
github.rest.issues.createComment({
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
body: body,
})
} else {
// Just log the summary if this is a push to master
console.log(body)
}
const script = require('.github/scripts/post-eval-summary.js')
script({ github, context, fs })


@@ -29,13 +29,13 @@ jobs:
group: 1
token: ${{ secrets.POSTHOG_BOT_PAT }}
python-version: '3.12.11'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
segment: 'FOSS'
person-on-events: false
- name: Upload updated timing data as artifacts
uses: actions/upload-artifact@v4
if: ${{ inputs.person-on-events != 'true' && inputs.clickhouse-server-image == 'clickhouse/clickhouse-server:25.6.9.98' }}
if: ${{ inputs.person-on-events != 'true' && inputs.clickhouse-server-image == 'clickhouse/clickhouse-server:25.8.11.66' }}
with:
name: timing_data-${{ inputs.segment }}-${{ inputs.group }}
path: .test_durations


@@ -389,7 +389,7 @@ jobs:
fail-fast: false
matrix:
python-version: ['3.12.11']
clickhouse-server-image: ['clickhouse/clickhouse-server:25.6.9.98']
clickhouse-server-image: ['clickhouse/clickhouse-server:25.8.11.66']
segment: ['Core']
person-on-events: [false]
# :NOTE: Keep concurrency and groups in sync
@@ -440,121 +440,121 @@ jobs:
include:
- segment: 'Core'
person-on-events: true
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
python-version: '3.12.11'
concurrency: 10
group: 1
- segment: 'Core'
person-on-events: true
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
python-version: '3.12.11'
concurrency: 10
group: 2
- segment: 'Core'
person-on-events: true
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
python-version: '3.12.11'
concurrency: 10
group: 3
- segment: 'Core'
person-on-events: true
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
python-version: '3.12.11'
concurrency: 10
group: 4
- segment: 'Core'
person-on-events: true
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
python-version: '3.12.11'
concurrency: 10
group: 5
- segment: 'Core'
person-on-events: true
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
python-version: '3.12.11'
concurrency: 10
group: 6
- segment: 'Core'
person-on-events: true
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
python-version: '3.12.11'
concurrency: 10
group: 7
- segment: 'Core'
person-on-events: true
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
python-version: '3.12.11'
concurrency: 10
group: 8
- segment: 'Core'
person-on-events: true
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
python-version: '3.12.11'
concurrency: 10
group: 9
- segment: 'Core'
person-on-events: true
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
python-version: '3.12.11'
concurrency: 10
group: 10
- segment: 'Temporal'
person-on-events: false
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
python-version: '3.12.11'
concurrency: 10
group: 1
- segment: 'Temporal'
person-on-events: false
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
python-version: '3.12.11'
concurrency: 10
group: 2
- segment: 'Temporal'
person-on-events: false
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
python-version: '3.12.11'
concurrency: 10
group: 3
- segment: 'Temporal'
person-on-events: false
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
python-version: '3.12.11'
concurrency: 10
group: 4
- segment: 'Temporal'
person-on-events: false
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
python-version: '3.12.11'
concurrency: 10
group: 5
- segment: 'Temporal'
person-on-events: false
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
python-version: '3.12.11'
concurrency: 10
group: 6
- segment: 'Temporal'
person-on-events: false
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
python-version: '3.12.11'
concurrency: 10
group: 7
- segment: 'Temporal'
person-on-events: false
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
python-version: '3.12.11'
concurrency: 10
group: 8
- segment: 'Temporal'
person-on-events: false
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
python-version: '3.12.11'
concurrency: 10
group: 9
- segment: 'Temporal'
person-on-events: false
clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
python-version: '3.12.11'
concurrency: 10
group: 10
@@ -649,6 +649,22 @@ jobs:
run: |
sudo apt-get update && sudo apt-get install libxml2-dev libxmlsec1-dev libxmlsec1-openssl
- name: Install Rust
if: needs.changes.outputs.backend == 'true'
uses: dtolnay/rust-toolchain@6691ebadcb18182cc1391d07c9f295f657c593cd # 1.88
with:
toolchain: 1.88.0
components: cargo
- name: Cache Rust dependencies
if: needs.changes.outputs.backend == 'true'
uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1
- name: Install sqlx-cli
if: needs.changes.outputs.backend == 'true'
run: |
cargo install sqlx-cli --version 0.8.0 --features postgres --no-default-features --locked
- name: Determine if hogql-parser has changed compared to master
shell: bash
id: hogql-parser-diff
@@ -796,7 +812,8 @@ jobs:
shell: bash
env:
AWS_S3_ALLOW_UNSAFE_RENAME: 'true'
RUNLOOP_API_KEY: ${{ needs.changes.outputs.tasks_temporal == 'true' && secrets.RUNLOOP_API_KEY || '' }}
MODAL_TOKEN_ID: ${{ needs.changes.outputs.tasks_temporal == 'true' && secrets.MODAL_TOKEN_ID || '' }}
MODAL_TOKEN_SECRET: ${{ needs.changes.outputs.tasks_temporal == 'true' && secrets.MODAL_TOKEN_SECRET || '' }}
run: |
set +e
pytest posthog/temporal products/batch_exports/backend/tests/temporal products/tasks/backend/temporal -m "not async_migrations" \
@@ -821,7 +838,7 @@ jobs:
- name: Upload updated timing data as artifacts
uses: actions/upload-artifact@v4
if: ${{ needs.changes.outputs.backend == 'true' && !matrix.person-on-events && matrix.clickhouse-server-image == 'clickhouse/clickhouse-server:25.6.9.98' }}
if: ${{ needs.changes.outputs.backend == 'true' && !matrix.person-on-events && matrix.clickhouse-server-image == 'clickhouse/clickhouse-server:25.8.11.66' }}
with:
name: timing_data-${{ matrix.segment }}-${{ matrix.group }}
path: .test_durations
@@ -906,7 +923,7 @@ jobs:
strategy:
fail-fast: false
matrix:
clickhouse-server-image: ['clickhouse/clickhouse-server:25.6.9.98']
clickhouse-server-image: ['clickhouse/clickhouse-server:25.8.11.66']
if: needs.changes.outputs.backend == 'true'
runs-on: ubuntu-latest
steps:
@@ -945,6 +962,19 @@ jobs:
sudo apt-get update
sudo apt-get install libxml2-dev libxmlsec1-dev libxmlsec1-openssl
- name: Install Rust
uses: dtolnay/rust-toolchain@6691ebadcb18182cc1391d07c9f295f657c593cd # 1.88
with:
toolchain: 1.88.0
components: cargo
- name: Cache Rust dependencies
uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1
- name: Install sqlx-cli
run: |
cargo install sqlx-cli --version 0.8.0 --features postgres --no-default-features --locked
- name: Install python dependencies
shell: bash
run: |


@@ -65,7 +65,7 @@ jobs:
strategy:
fail-fast: false
matrix:
clickhouse-server-image: ['clickhouse/clickhouse-server:25.6.9.98']
clickhouse-server-image: ['clickhouse/clickhouse-server:25.8.11.66']
if: needs.changes.outputs.dagster == 'true'
runs-on: depot-ubuntu-latest
steps:


@@ -43,9 +43,9 @@ env:
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
# This is so that the workflow run isn't canceled when a screenshot update is pushed within it by posthog-bot
# We do however cancel from container-images-ci.yml if a commit is pushed by someone OTHER than posthog-bot
cancel-in-progress: false
# Cancel in-progress runs when new commits are pushed
# SHA verification ensures we don't commit stale snapshots if runs aren't cancelled in time
cancel-in-progress: true
jobs:
changes:
@@ -120,7 +120,7 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.ref }}
ref: ${{ github.event.pull_request.head.sha }}
repository: ${{ github.event.pull_request.head.repo.full_name }}
token: ${{ secrets.POSTHOG_BOT_PAT || github.token }}
fetch-depth: 50 # Need enough history for flap detection to find last human commit
@@ -134,7 +134,7 @@ jobs:
- name: Stop/Start stack with Docker Compose
shell: bash
run: |
export CLICKHOUSE_SERVER_IMAGE=clickhouse/clickhouse-server:25.6.9.98
export CLICKHOUSE_SERVER_IMAGE=clickhouse/clickhouse-server:25.8.11.66
export DOCKER_REGISTRY_PREFIX="us-east1-docker.pkg.dev/posthog-301601/mirror/"
cp posthog/user_scripts/latest_user_defined_function.xml docker/clickhouse/user_defined_function.xml
@@ -495,7 +495,7 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.ref }}
ref: ${{ github.event.pull_request.head.sha }}
token: ${{ secrets.POSTHOG_BOT_PAT || github.token }}
- name: Download screenshot patches
@@ -535,8 +535,33 @@ jobs:
workflow-type: playwright
patch-path: /tmp/screenshot-patches/
snapshot-path: playwright/
commit-message: Update E2E screenshots (Playwright)
commit-message: 'test(e2e): update screenshots'
pr-number: ${{ github.event.pull_request.number }}
repository: ${{ github.repository }}
commit-sha: ${{ github.event.pull_request.head.sha }}
branch-name: ${{ github.event.pull_request.head.ref }}
github-token: ${{ secrets.POSTHOG_BOT_PAT || github.token }}
# Job to collate the status of the matrix jobs for requiring passing status
# Must depend on handle-screenshots to prevent auto-merge before commits complete
playwright_tests:
needs: [playwright, handle-screenshots]
name: Playwright tests pass
runs-on: ubuntu-latest
if: always()
steps:
- name: Check matrix outcome
run: |
# Check playwright matrix result
if [[ "${{ needs.playwright.result }}" != "success" && "${{ needs.playwright.result }}" != "skipped" ]]; then
echo "One or more jobs in the Playwright test matrix failed."
exit 1
fi
# Check handle-screenshots result (OK if skipped, but fail if it failed)
if [[ "${{ needs.handle-screenshots.result }}" == "failure" ]]; then
echo "Screenshot commit job failed."
exit 1
fi
echo "All jobs passed or were skipped successfully."


@@ -26,47 +26,6 @@ jobs:
- '.github/workflows/mcp-ci.yml'
- '.github/workflows/mcp-publish.yml'
lint-and-format:
name: Lint, Format, and Type Check
runs-on: ubuntu-latest
needs: changes
if: needs.changes.outputs.mcp == 'true'
permissions:
contents: read
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install pnpm
uses: pnpm/action-setup@a7487c7e89a18df4991f7f222e4898a00d66ddda # v4
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: 22
cache: 'pnpm'
- name: Install dependencies
run: pnpm install
- name: Run linter
run: cd products/mcp && pnpm run lint
- name: Run formatter
run: cd products/mcp && pnpm run format
- name: Run type check
run: cd products/mcp && pnpm run typecheck
- name: Check for changes
run: |
if [ -n "$(git status --porcelain)" ]; then
echo "Code formatting or linting changes detected!"
git diff
exit 1
fi
unit-tests:
name: Unit Tests
runs-on: ubuntu-latest


@@ -135,11 +135,6 @@ jobs:
run: |
mypy --version && mypy --cache-fine-grained . | mypy-baseline filter || (echo "run 'pnpm run mypy-baseline-sync' to update the baseline" && exit 1)
- name: Check if "schema.py" is up to date
shell: bash
run: |
npm run schema:build:python && git diff --exit-code
- name: Check hogli manifest completeness
shell: bash
run: |


@@ -4,9 +4,9 @@ on:
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
# This is so that the workflow run isn't canceled when a snapshot update is pushed within it by posthog-bot
# We do however cancel from container-images-ci.yml if a commit is pushed by someone OTHER than posthog-bot
cancel-in-progress: false
# Cancel in-progress runs when new commits are pushed
# SHA verification ensures we don't commit stale snapshots if runs aren't cancelled in time
cancel-in-progress: true
permissions:
contents: write
@@ -48,7 +48,7 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.ref }}
ref: ${{ github.event.pull_request.head.sha }}
repository: ${{ github.event.pull_request.head.repo.full_name }}
# Use PostHog Bot token when not on forks to enable proper snapshot updating
token: ${{ secrets.POSTHOG_BOT_PAT || github.token }}
@@ -124,7 +124,7 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.ref }}
ref: ${{ github.event.pull_request.head.sha }}
repository: ${{ github.event.pull_request.head.repo.full_name }}
# Use PostHog Bot token when not on forks to enable proper snapshot updating
token: ${{ secrets.POSTHOG_BOT_PAT || github.token }}
@@ -232,7 +232,7 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.ref }}
ref: ${{ github.event.pull_request.head.sha }}
repository: ${{ github.event.pull_request.head.repo.full_name }}
# Use PostHog Bot token when not on forks to enable proper snapshot updating
token: ${{ secrets.POSTHOG_BOT_PAT || github.token }}
@@ -339,7 +339,7 @@ jobs:
if [ $EXIT_CODE -ne 0 ]; then
# Still failing - check if we should update snapshots
if [ "${{ needs.detect-snapshot-mode.outputs.mode }}" == "update" ] && grep -Eqi "(snapshot.*(failed|was not written)|Expected image to be the same size)" /tmp/test-output-${{ matrix.browser }}-${{ matrix.shard }}-attempt2.log 2>/dev/null; then
if [ "${{ needs.detect-snapshot-mode.outputs.mode }}" == "update" ] && grep -Eqi "(snapshot.*(failed|was not written)|Expected image to)" /tmp/test-output-${{ matrix.browser }}-${{ matrix.shard }}-attempt2.log 2>/dev/null; then
echo "Snapshot failure confirmed on attempt 2 - updating snapshots"
# Attempt 3: Update snapshots
echo "Attempt 3: Running @storybook/test-runner (update)"
@@ -413,21 +413,28 @@ jobs:
if-no-files-found: ignore
# Job to collate the status of the matrix jobs for requiring passing status
# Must depend on handle-snapshots to prevent auto-merge before commits complete
visual_regression_tests:
needs: [visual-regression]
needs: [visual-regression, handle-snapshots]
name: Visual regression tests pass
runs-on: ubuntu-latest
if: always()
steps:
- name: Check matrix outcome
run: |
# The `needs.visual-regression.result` will be 'success' only if all jobs in the matrix succeeded.
# Otherwise, it will be 'failure'.
# Check visual-regression matrix result
if [[ "${{ needs.visual-regression.result }}" != "success" && "${{ needs.visual-regression.result }}" != "skipped" ]]; then
echo "One or more jobs in the visual-regression test matrix failed."
exit 1
fi
echo "All jobs in the visual-regression test matrix passed."
# Check handle-snapshots result (OK if skipped, but fail if it failed)
if [[ "${{ needs.handle-snapshots.result }}" == "failure" ]]; then
echo "Snapshot commit job failed."
exit 1
fi
echo "All jobs passed or were skipped successfully."
handle-snapshots:
name: Handle snapshot changes
@@ -437,7 +444,7 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.ref }}
ref: ${{ github.event.pull_request.head.sha }}
repository: ${{ github.event.pull_request.head.repo.full_name }}
token: ${{ secrets.POSTHOG_BOT_PAT || github.token }}
fetch-depth: 1
@@ -479,10 +486,11 @@ jobs:
workflow-type: storybook
patch-path: /tmp/snapshot-patches/
snapshot-path: frontend/__snapshots__/
commit-message: Update UI snapshots (all browsers)
commit-message: 'test(storybook): update UI snapshots'
pr-number: ${{ github.event.pull_request.number }}
repository: ${{ github.repository }}
commit-sha: ${{ github.event.pull_request.head.sha }}
branch-name: ${{ github.event.pull_request.head.ref }}
github-token: ${{ secrets.POSTHOG_BOT_PAT || github.token }}
calculate-running-time:

View File

@@ -23,7 +23,7 @@ jobs:
with:
filters: |
mcp:
- 'mcp/**'
- 'products/mcp/**'
- '.github/workflows/mcp-ci.yml'
- '.github/workflows/mcp-publish.yml'


@@ -16,7 +16,10 @@
"common/plugin_transpiler/dist",
"common/hogvm/__tests__/__snapshots__",
"common/hogvm/__tests__/__snapshots__/**",
"frontend/src/queries/validators.js"
"frontend/src/queries/validators.js",
"products/mcp/**/generated.ts",
"products/mcp/schema/tool-inputs.json",
"products/mcp/python"
],
"rules": {
"no-constant-condition": "off",


@@ -6,6 +6,7 @@ staticfiles
.env
*.code-workspace
.mypy_cache
.cache/
*Type.ts
.idea
.yalc
@@ -15,8 +16,7 @@ common/storybook/dist/
dist/
node_modules/
pnpm-lock.yaml
posthog/templates/email/*
posthog/templates/**/*.html
posthog/templates/**/*
common/hogvm/typescript/src/stl/bytecode.ts
common/hogvm/__tests__/__snapshots__/*
rust/
@@ -26,3 +26,10 @@ cli/tests/_cases/**/*
frontend/src/products.tsx
frontend/src/layout.html
**/fixtures/**
products/mcp/schema/**/*
products/mcp/python/**/*
products/mcp/**/generated.ts
products/mcp/typescript/worker-configuration.d.ts
frontend/src/taxonomy/core-filter-definitions-by-group.json
frontend/dist/
frontend/**/*LogicType.ts


@@ -14,6 +14,7 @@
"^@posthog.*$",
"^lib/(.*)$|^scenes/(.*)$",
"^~/(.*)$",
"^@/(.*)$",
"^public/(.*)$",
"^products/(.*)$",
"^storybook/(.*)$",


@@ -329,8 +329,7 @@ COPY --from=frontend-build --chown=posthog:posthog /code/frontend/dist /code/fro
# Copy the GeoLite2-City database from the fetch-geoip-db stage.
COPY --from=fetch-geoip-db --chown=posthog:posthog /code/share/GeoLite2-City.mmdb /code/share/GeoLite2-City.mmdb
# Add in the Gunicorn config, custom bin files and Django deps.
COPY --chown=posthog:posthog gunicorn.config.py ./
# Add in custom bin files and Django deps.
COPY --chown=posthog:posthog ./bin ./bin/
COPY --chown=posthog:posthog manage.py manage.py
COPY --chown=posthog:posthog posthog posthog/
@@ -339,9 +338,6 @@ COPY --chown=posthog:posthog common/hogvm common/hogvm/
COPY --chown=posthog:posthog dags dags/
COPY --chown=posthog:posthog products products/
# Keep server command backwards compatible
RUN cp ./bin/docker-server-unit ./bin/docker-server
# Setup ENV.
ENV NODE_ENV=production \
CHROME_BIN=/usr/bin/chromium \


@@ -1,12 +1,14 @@
-- temporary sql to initialise log tables for local development
-- will be removed once we have migrations set up
CREATE OR REPLACE FUNCTION extractIPv4Substrings AS
(
body -> extractAll(body, '(\d\.((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){2,2}([0-9]))')
);
-- temporary sql to initialise log tables for local development
-- will be removed once we have migrations set up
CREATE TABLE if not exists logs27
CREATE OR REPLACE TABLE logs31
(
`time_bucket` DateTime MATERIALIZED toStartOfInterval(timestamp, toIntervalMinute(5)) CODEC(DoubleDelta, ZSTD(1)),
-- time bucket is set to day which means it's effectively not in the order by key (same as partition)
-- but gives us flexibility to add the bucket to the order key if needed
`time_bucket` DateTime MATERIALIZED toStartOfDay(timestamp) CODEC(DoubleDelta, ZSTD(1)),
`uuid` String CODEC(ZSTD(1)),
`team_id` Int32 CODEC(ZSTD(1)),
`trace_id` String CODEC(ZSTD(1)),
@@ -20,32 +22,22 @@ CREATE TABLE if not exists logs27
`severity_number` Int32 CODEC(ZSTD(1)),
`service_name` String CODEC(ZSTD(1)),
`resource_attributes` Map(String, String) CODEC(ZSTD(1)),
`resource_fingerprint` UInt64 MATERIALIZED cityHash64(resource_attributes) CODEC(DoubleDelta, ZSTD(1)),
`resource_id` String CODEC(ZSTD(1)),
`instrumentation_scope` String CODEC(ZSTD(1)),
`event_name` String CODEC(ZSTD(1)),
`attributes_map_str` Map(String, String) CODEC(ZSTD(1)),
`attributes_map_float` Map(String, Float64) CODEC(ZSTD(1)),
`attributes_map_datetime` Map(String, DateTime64(6)) CODEC(ZSTD(1)),
`attributes_map_str` Map(LowCardinality(String), String) CODEC(ZSTD(1)),
`attributes_map_float` Map(LowCardinality(String), Float64) CODEC(ZSTD(1)),
`attributes_map_datetime` Map(LowCardinality(String), DateTime64(6)) CODEC(ZSTD(1)),
`level` String ALIAS severity_text,
`mat_body_ipv4_matches` Array(String) ALIAS extractAll(body, '(\\d\\.((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\\.){2,2}([0-9]))'),
`time_minute` DateTime ALIAS toStartOfMinute(timestamp),
`attributes` Map(String, String) ALIAS mapApply((k, v) -> (k, toJSONString(v)), attributes_map_str),
`attributes` Map(String, String) ALIAS mapApply((k, v) -> (left(k, -5), v), attributes_map_str),
INDEX idx_severity_text_set severity_text TYPE set(10) GRANULARITY 1,
INDEX idx_attributes_str_keys mapKeys(attributes_map_str) TYPE bloom_filter(0.01) GRANULARITY 1,
INDEX idx_attributes_str_values mapValues(attributes_map_str) TYPE bloom_filter(0.001) GRANULARITY 1,
INDEX idx_mat_body_ipv4_matches mat_body_ipv4_matches TYPE bloom_filter(0.01) GRANULARITY 1,
INDEX idx_body_ngram3 body TYPE ngrambf_v1(3, 20000, 2, 0) GRANULARITY 1,
CONSTRAINT assume_time_bucket ASSUME toStartOfInterval(timestamp, toIntervalMinute(5)) = time_bucket,
PROJECTION projection_trace_span
(
SELECT
trace_id,
timestamp,
_part_offset
ORDER BY
trace_id,
timestamp
),
INDEX idx_body_ngram3 body TYPE ngrambf_v1(3, 25000, 2, 0) GRANULARITY 1,
PROJECTION projection_aggregate_counts
(
SELECT
@@ -54,8 +46,7 @@ CREATE TABLE if not exists logs27
toStartOfMinute(timestamp),
service_name,
severity_text,
resource_attributes,
resource_id,
resource_fingerprint,
count() AS event_count
GROUP BY
team_id,
@@ -63,43 +54,47 @@ CREATE TABLE if not exists logs27
toStartOfMinute(timestamp),
service_name,
severity_text,
resource_attributes,
resource_id
resource_fingerprint
)
)
ENGINE = ReplicatedReplacingMergeTree('/clickhouse/tables/{shard}/logs27', '{replica}')
ENGINE = MergeTree
PARTITION BY toDate(timestamp)
ORDER BY (team_id, time_bucket DESC, service_name, resource_attributes, severity_text, timestamp DESC, uuid, trace_id, span_id)
SETTINGS allow_remote_fs_zero_copy_replication = 1,
allow_experimental_reverse_key = 1,
deduplicate_merge_projection_mode = 'ignore';
PRIMARY KEY (team_id, time_bucket, service_name, resource_fingerprint, severity_text, timestamp)
ORDER BY (team_id, time_bucket, service_name, resource_fingerprint, severity_text, timestamp)
SETTINGS
index_granularity_bytes = 104857600,
index_granularity = 8192,
ttl_only_drop_parts = 1;
create or replace TABLE logs AS logs27 ENGINE = Distributed('posthog', 'default', 'logs27');
create or replace TABLE logs AS logs31 ENGINE = Distributed('posthog', 'default', 'logs31');
create table if not exists log_attributes
create or replace table default.log_attributes
(
`team_id` Int32,
`time_bucket` DateTime64(0),
`service_name` LowCardinality(String),
`resource_id` String DEFAULT '',
`resource_fingerprint` UInt64 DEFAULT 0,
`attribute_key` LowCardinality(String),
`attribute_value` String,
`attribute_count` SimpleAggregateFunction(sum, UInt64),
INDEX idx_attribute_key attribute_key TYPE bloom_filter(0.01) GRANULARITY 1,
INDEX idx_attribute_value attribute_value TYPE bloom_filter(0.001) GRANULARITY 1,
INDEX idx_attribute_value attribute_value TYPE bloom_filter(0.01) GRANULARITY 1,
INDEX idx_attribute_key_n3 attribute_key TYPE ngrambf_v1(3, 32768, 3, 0) GRANULARITY 1,
INDEX idx_attribute_value_n3 attribute_value TYPE ngrambf_v1(3, 32768, 3, 0) GRANULARITY 1
)
ENGINE = ReplicatedAggregatingMergeTree('/clickhouse/tables/{shard}/log_attributes', '{replica}')
ENGINE = AggregatingMergeTree
PARTITION BY toDate(time_bucket)
ORDER BY (team_id, service_name, time_bucket, attribute_key, attribute_value);
ORDER BY (team_id, time_bucket, resource_fingerprint, attribute_key, attribute_value);
set enable_dynamic_type=1;
CREATE MATERIALIZED VIEW if not exists log_to_log_attributes TO log_attributes
drop view if exists log_to_log_attributes;
CREATE MATERIALIZED VIEW log_to_log_attributes TO log_attributes
(
`team_id` Int32,
`time_bucket` DateTime64(0),
`service_name` LowCardinality(String),
`resource_fingerprint` UInt64,
`attribute_key` LowCardinality(String),
`attribute_value` String,
`attribute_count` SimpleAggregateFunction(sum, UInt64)
@@ -108,23 +103,28 @@ AS SELECT
team_id,
time_bucket,
service_name,
resource_fingerprint,
attribute_key,
attribute_value,
attribute_count
FROM (select
team_id AS team_id,
toStartOfInterval(timestamp, toIntervalMinute(10)) AS time_bucket,
service_name AS service_name,
arrayJoin(arrayMap((k, v) -> (k, if(length(v) > 256, '', v)), arrayFilter((k, v) -> (length(k) < 256), CAST(attributes, 'Array(Tuple(String, String))')))) AS attribute,
attribute.1 AS attribute_key,
CAST(JSONExtract(attribute.2, 'Dynamic'), 'String') AS attribute_value,
sumSimpleState(1) AS attribute_count
FROM logs27
GROUP BY
team_id,
time_bucket,
service_name,
attribute
FROM
(
SELECT
team_id AS team_id,
toStartOfInterval(timestamp, toIntervalMinute(10)) AS time_bucket,
service_name AS service_name,
resource_fingerprint,
arrayJoin(mapFilter((k, v) -> ((length(k) < 256) AND (length(v) < 256)), attributes)) AS attribute,
attribute.1 AS attribute_key,
attribute.2 AS attribute_value,
sumSimpleState(1) AS attribute_count
FROM logs31
GROUP BY
team_id,
time_bucket,
service_name,
resource_fingerprint,
attribute
);
CREATE OR REPLACE TABLE kafka_logs_avro
@@ -139,11 +139,10 @@ CREATE OR REPLACE TABLE kafka_logs_avro
`severity_text` String,
`severity_number` Int32,
`service_name` String,
`resource_attributes` Map(String, String),
`resource_id` String,
`resource_attributes` Map(LowCardinality(String), String),
`instrumentation_scope` String,
`event_name` String,
`attributes` Map(String, Nullable(String))
`attributes` Map(LowCardinality(String), String)
)
ENGINE = Kafka('kafka:9092', 'clickhouse_logs', 'clickhouse-logs-avro', 'Avro')
SETTINGS
@@ -152,15 +151,14 @@ SETTINGS
kafka_thread_per_consumer = 1,
kafka_num_consumers = 1,
kafka_poll_timeout_ms=15000,
kafka_poll_max_batch_size=1,
kafka_max_block_size=1;
kafka_poll_max_batch_size=10,
kafka_max_block_size=10;
drop table if exists kafka_logs_avro_mv;
CREATE MATERIALIZED VIEW kafka_logs_avro_mv TO logs27
CREATE MATERIALIZED VIEW kafka_logs_avro_mv TO logs31
(
`uuid` String,
`team_id` Int32,
`trace_id` String,
`span_id` String,
`trace_flags` Int32,
@@ -170,11 +168,10 @@ CREATE MATERIALIZED VIEW kafka_logs_avro_mv TO logs27
`severity_text` String,
`severity_number` Int32,
`service_name` String,
`resource_attributes` Map(String, String),
`resource_id` String,
`resource_attributes` Map(LowCardinality(String), String),
`instrumentation_scope` String,
`event_name` String,
`attributes` Map(String, Nullable(String))
`attributes` Map(LowCardinality(String), String)
)
AS SELECT
* except (attributes, resource_attributes),


@@ -1,27 +1,4 @@
#!/bin/bash
set -e
./bin/migrate-check
# To ensure we are able to expose metrics from multiple processes, we need to
# provide a directory for `prometheus_client` to store a shared registry.
export PROMETHEUS_MULTIPROC_DIR=$(mktemp -d)
trap 'rm -rf "$PROMETHEUS_MULTIPROC_DIR"' EXIT
export PROMETHEUS_METRICS_EXPORT_PORT=8001
export STATSD_PORT=${STATSD_PORT:-8125}
exec gunicorn posthog.wsgi \
--config gunicorn.config.py \
--bind 0.0.0.0:8000 \
--log-file - \
--log-level info \
--access-logfile - \
--worker-tmp-dir /dev/shm \
--workers=2 \
--threads=8 \
--keep-alive=60 \
--backlog=${GUNICORN_BACKLOG:-1000} \
--worker-class=gthread \
${STATSD_HOST:+--statsd-host $STATSD_HOST:$STATSD_PORT} \
--limit-request-line=16384 $@
# Wrapper script for backward compatibility
# Calls the dual-mode server script (Granian/Unit)
exec "$(dirname "$0")/docker-server-unit" "$@"


@@ -7,15 +7,48 @@ set -e
# provide a directory for `prometheus_client` to store a shared registry.
export PROMETHEUS_MULTIPROC_DIR=$(mktemp -d)
chmod -R 777 $PROMETHEUS_MULTIPROC_DIR
trap 'rm -rf "$PROMETHEUS_MULTIPROC_DIR"' EXIT
export PROMETHEUS_METRICS_EXPORT_PORT=8001
export STATSD_PORT=${STATSD_PORT:-8125}
export NGINX_UNIT_PYTHON_PROTOCOL=${NGINX_UNIT_PYTHON_PROTOCOL:-wsgi}
export NGINX_UNIT_APP_PROCESSES=${NGINX_UNIT_APP_PROCESSES:-4}
envsubst < /docker-entrypoint.d/unit.json.tpl > /docker-entrypoint.d/unit.json
# Dual-mode support: USE_GRANIAN env var switches between Granian and Unit (default)
if [ "${USE_GRANIAN:-false}" = "true" ]; then
echo "🚀 Starting with Granian ASGI server (opt-in via USE_GRANIAN=true)..."
# We need to run as --user root so that nginx unit can proxy the control socket for stats
# However each application is run as "nobody"
exec /usr/local/bin/docker-entrypoint.sh unitd --no-daemon --user root
# Granian configuration
export GRANIAN_WORKERS=${GRANIAN_WORKERS:-4}
export GRANIAN_THREADS=2
# Start metrics HTTP server in background on port 8001
python ./bin/granian_metrics.py &
METRICS_PID=$!
# Combined trap: kill metrics server and cleanup temp directory
trap 'kill $METRICS_PID 2>/dev/null; rm -rf "$PROMETHEUS_MULTIPROC_DIR"' EXIT
exec granian \
--interface asgi \
posthog.asgi:application \
--workers $GRANIAN_WORKERS \
--runtime-threads $GRANIAN_THREADS \
--runtime-mode mt \
--loop uvloop \
--host 0.0.0.0 \
--port 8000 \
--log-level warning \
--access-log \
--respawn-failed-workers
else
echo "🔧 Starting with Nginx Unit server (default, stable)..."
# Cleanup temp directory on exit
trap 'rm -rf "$PROMETHEUS_MULTIPROC_DIR"' EXIT
export NGINX_UNIT_PYTHON_PROTOCOL=${NGINX_UNIT_PYTHON_PROTOCOL:-wsgi}
export NGINX_UNIT_APP_PROCESSES=${NGINX_UNIT_APP_PROCESSES:-4}
envsubst < /docker-entrypoint.d/unit.json.tpl > /docker-entrypoint.d/unit.json
# We need to run as --user root so that nginx unit can proxy the control socket for stats
# However each application is run as "nobody"
exec /usr/local/bin/docker-entrypoint.sh unitd --no-daemon --user root
fi

bin/granian_metrics.py (new executable file, 66 lines)

@@ -0,0 +1,66 @@
#!/usr/bin/env python3
"""
Metrics HTTP server for Granian multi-process setup.
Serves Prometheus metrics on port 8001 (configurable via PROMETHEUS_METRICS_EXPORT_PORT).
Aggregates metrics from all Granian worker processes using prometheus_client multiprocess mode.
Exposes Granian-equivalent metrics to maintain dashboard compatibility with the previous Gunicorn setup.
"""
import os
import time
import logging
from prometheus_client import CollectorRegistry, Gauge, multiprocess, start_http_server
logger = logging.getLogger(__name__)
def create_granian_metrics(registry: CollectorRegistry) -> None:
"""
Create Granian-equivalent metrics to maintain compatibility with existing Grafana dashboards.
These metrics mirror the Gunicorn metrics that were previously exposed, allowing existing
dashboards and alerts to continue working with minimal changes.
"""
# Read Granian configuration from environment
workers = int(os.environ.get("GRANIAN_WORKERS", 4))
threads = int(os.environ.get("GRANIAN_THREADS", 2))
# Expose static configuration as gauges
# These provide equivalent metrics to what gunicorn/unit previously exposed
max_worker_threads = Gauge(
"granian_max_worker_threads",
"Maximum number of threads per worker",
registry=registry,
)
max_worker_threads.set(threads)
total_workers = Gauge(
"granian_workers_total",
"Total number of Granian workers configured",
registry=registry,
)
total_workers.set(workers)
def main():
"""Start HTTP server to expose Prometheus metrics from all workers."""
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
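    # MultiProcessCollector aggregates the per-worker metric files that prometheus_client
    # writes to PROMETHEUS_MULTIPROC_DIR (exported by bin/docker-server-unit).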
# Create Granian-specific metrics for dashboard compatibility
create_granian_metrics(registry)
port = int(os.environ.get("PROMETHEUS_METRICS_EXPORT_PORT", 8001))
logger.info(f"Starting Prometheus metrics server on port {port}")
start_http_server(port=port, registry=registry)
# Keep the server running
while True:
time.sleep(3600)
if __name__ == "__main__":
main()


@@ -61,5 +61,9 @@ procs:
shell: 'pnpm --filter=@posthog/storybook install && pnpm run storybook'
autostart: false
hedgebox-dummy:
shell: 'bin/check_postgres_up && cd hedgebox-dummy && pnpm install && pnpm run dev'
autostart: false
mouse_scroll_speed: 1
scrollback: 10000


@@ -25,4 +25,13 @@ else
echo "🐧 Linux detected, binding to Docker bridge gateway: $HOST_BIND"
fi
python ${DEBUG:+ -m debugpy --listen 127.0.0.1:5678} -m uvicorn --reload posthog.asgi:application --host $HOST_BIND --log-level debug --reload-include "posthog/" --reload-include "ee/" --reload-include "products/"
python ${DEBUG:+ -m debugpy --listen 127.0.0.1:5678} -m granian \
--interface asgi \
posthog.asgi:application \
--reload \
--reload-paths ./posthog \
--reload-paths ./ee \
--reload-paths ./products \
--host $HOST_BIND \
--log-level debug \
--workers 1


@@ -1,5 +1,10 @@
# posthog-cli
# 0.5.11
- Do not read bundle files as part of hermes sourcemap commands
- Change hermes clone command to take two file paths (for the minified and composed maps respectively)
# 0.5.10
- Add terminal checks for login and query command


@@ -5,7 +5,9 @@ and bump the package version number at the same time.
```bash
git checkout -b "cli/release-v0.1.0-pre1"
# Bump version number in Cargo.toml and build to update Cargo.lock
# Bump version number in Cargo.toml
# Build to update Cargo.lock (cargo build)
# Update the CHANGELOG.md
git add .
git commit -m "Bump version number"
git tag "posthog-cli-v0.1.0-prerelease.1"

cli/Cargo.lock (generated, 2 changed lines)

@@ -1521,7 +1521,7 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "posthog-cli"
version = "0.5.10"
version = "0.5.11"
dependencies = [
"anyhow",
"chrono",


@@ -1,6 +1,6 @@
[package]
name = "posthog-cli"
version = "0.5.10"
version = "0.5.11"
authors = [
"David <david@posthog.com>",
"Olly <oliver@posthog.com>",


@@ -16,7 +16,7 @@ use crate::{
pub struct SourceMapContent {
#[serde(skip_serializing_if = "Option::is_none")]
pub release_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(skip_serializing_if = "Option::is_none", alias = "debugId")]
pub chunk_id: Option<String>,
#[serde(flatten)]
pub fields: BTreeMap<String, Value>,
@@ -49,6 +49,10 @@ impl SourceMapFile {
self.inner.content.release_id.clone()
}
pub fn has_release_id(&self) -> bool {
self.get_release_id().is_some()
}
pub fn apply_adjustment(&mut self, adjustment: SourceMap) -> Result<()> {
let new_content = {
let content = serde_json::to_string(&self.inner.content)?.into_bytes();


@@ -1,27 +1,22 @@
use std::path::PathBuf;
use anyhow::{anyhow, bail, Result};
use tracing::{info, warn};
use anyhow::{anyhow, Result};
use tracing::info;
use crate::{
invocation_context::context,
sourcemaps::{
content::SourceMapFile,
hermes::{get_composed_map, inject::is_metro_bundle},
inject::get_release_for_pairs,
source_pairs::read_pairs,
},
sourcemaps::{content::SourceMapFile, inject::get_release_for_maps},
};
#[derive(clap::Args)]
pub struct CloneArgs {
/// The directory containing the bundled chunks
/// The path of the minified source map
#[arg(short, long)]
pub directory: PathBuf,
pub minified_map_path: PathBuf,
/// One or more directory glob patterns to ignore
/// The path of the composed source map
#[arg(short, long)]
pub ignore: Vec<String>,
pub composed_map_path: PathBuf,
/// The project name associated with the uploaded chunks. Required to have the uploaded chunks associated with
/// a specific release. We will try to auto-derive this from git information if not provided. Strongly recommended
@@ -39,31 +34,30 @@ pub fn clone(args: &CloneArgs) -> Result<()> {
context().capture_command_invoked("hermes_clone");
let CloneArgs {
directory,
ignore,
minified_map_path,
composed_map_path,
project,
version,
} = args;
let directory = directory.canonicalize().map_err(|e| {
let mut minified_map = SourceMapFile::load(minified_map_path).map_err(|e| {
anyhow!(
"Directory '{}' not found or inaccessible: {}",
directory.display(),
"Failed to load minified map at '{}': {}",
minified_map_path.display(),
e
)
})?;
info!("Processing directory: {}", directory.display());
let pairs = read_pairs(&directory, ignore, is_metro_bundle, &None)?;
let mut composed_map = SourceMapFile::load(composed_map_path).map_err(|e| {
anyhow!(
"Failed to load composed map at '{}': {}",
composed_map_path.display(),
e
)
})?;
if pairs.is_empty() {
bail!("No source files found");
}
info!("Found {} pairs", pairs.len());
let release_id =
get_release_for_pairs(&directory, project, version, &pairs)?.map(|r| r.id.to_string());
let release_id = get_release_for_maps(minified_map_path, project, version, [&minified_map])?
.map(|r| r.id.to_string());
// The flow here differs from plain sourcemap injection a bit - here, we don't ever
// overwrite the chunk ID, because at this point in the build process, we no longer
@@ -73,47 +67,26 @@ pub fn clone(args: &CloneArgs) -> Result<()> {
// tries to run `clone` twice, changing release but not posthog env, we'll error out. The
// correct way to upload the same set of artefacts to the same posthog env as part of
// two different releases is, 1, not to, but failing that, 2, to re-run the bundling process
let mut pairs = pairs;
for pair in &mut pairs {
if !pair.has_release_id() || pair.get_release_id() != release_id {
pair.set_release_id(release_id.clone());
pair.save()?;
}
if !minified_map.has_release_id() || minified_map.get_release_id() != release_id {
minified_map.set_release_id(release_id.clone());
minified_map.save()?;
}
let pairs = pairs;
let maps: Result<Vec<(&SourceMapFile, Option<SourceMapFile>)>> = pairs
.iter()
.map(|p| get_composed_map(p).map(|c| (&p.sourcemap, c)))
.collect();
let maps = maps?;
for (minified, composed) in maps {
let Some(mut composed) = composed else {
warn!(
"Could not find composed map for minified sourcemap {}",
minified.inner.path.display()
);
continue;
};
// Copy metadata from source map to composed map
if let Some(chunk_id) = minified.get_chunk_id() {
composed.set_chunk_id(Some(chunk_id));
}
if let Some(release_id) = minified.get_release_id() {
composed.set_release_id(Some(release_id));
}
composed.save()?;
info!(
"Successfully cloned metadata to {}",
composed.inner.path.display()
);
// Copy metadata from source map to composed map
if let Some(chunk_id) = minified_map.get_chunk_id() {
composed_map.set_chunk_id(Some(chunk_id));
}
if let Some(release_id) = minified_map.get_release_id() {
composed_map.set_release_id(Some(release_id));
}
composed_map.save()?;
info!(
"Successfully cloned metadata to {}",
composed_map.inner.path.display()
);
info!("Finished cloning metadata");
Ok(())
}


@@ -2,28 +2,23 @@ use std::path::PathBuf;
use anyhow::{anyhow, Ok, Result};
use tracing::{info, warn};
use walkdir::WalkDir;
use crate::api::symbol_sets::{self, SymbolSetUpload};
use crate::invocation_context::context;
use crate::sourcemaps::hermes::get_composed_map;
use crate::sourcemaps::hermes::inject::is_metro_bundle;
use crate::sourcemaps::source_pairs::read_pairs;
use crate::sourcemaps::content::SourceMapFile;
#[derive(clap::Args, Clone)]
pub struct Args {
/// The directory containing the bundled chunks
#[arg(short, long)]
pub directory: PathBuf,
/// One or more directory glob patterns to ignore
#[arg(short, long)]
pub ignore: Vec<String>,
}
pub fn upload(args: &Args) -> Result<()> {
context().capture_command_invoked("hermes_upload");
let Args { directory, ignore } = args;
let Args { directory } = args;
let directory = directory.canonicalize().map_err(|e| {
anyhow!(
@@ -34,17 +29,10 @@ pub fn upload(args: &Args) -> Result<()> {
})?;
info!("Processing directory: {}", directory.display());
let pairs = read_pairs(&directory, ignore, is_metro_bundle, &None)?;
let maps: Result<Vec<_>> = pairs.iter().map(get_composed_map).collect();
let maps = maps?;
let maps = read_maps(&directory);
let mut uploads: Vec<SymbolSetUpload> = Vec::new();
for map in maps.into_iter() {
let Some(map) = map else {
continue;
};
if map.get_chunk_id().is_none() {
warn!("Skipping map {}, no chunk ID", map.inner.path.display());
continue;
@@ -53,9 +41,22 @@ pub fn upload(args: &Args) -> Result<()> {
uploads.push(map.try_into()?);
}
info!("Found {} bundles to upload", uploads.len());
info!("Found {} maps to upload", uploads.len());
symbol_sets::upload(&uploads, 100)?;
Ok(())
}
fn read_maps(directory: &PathBuf) -> Vec<SourceMapFile> {
WalkDir::new(directory)
.into_iter()
.filter_map(Result::ok)
.filter(|e| e.file_type().is_file())
.map(|e| {
let path = e.path().canonicalize()?;
SourceMapFile::load(&path)
})
.filter_map(Result::ok)
.collect()
}


@@ -5,7 +5,10 @@ use walkdir::DirEntry;
use crate::{
api::releases::{Release, ReleaseBuilder},
sourcemaps::source_pairs::{read_pairs, SourcePair},
sourcemaps::{
content::SourceMapFile,
source_pairs::{read_pairs, SourcePair},
},
utils::git::get_git_info,
};
@@ -61,9 +64,14 @@ pub fn inject_impl(args: &InjectArgs, matcher: impl Fn(&DirEntry) -> bool) -> Re
bail!("no source files found");
}
let created_release_id = get_release_for_pairs(&directory, project, version, &pairs)?
.as_ref()
.map(|r| r.id.to_string());
let created_release_id = get_release_for_maps(
&directory,
project,
version,
pairs.iter().map(|p| &p.sourcemap),
)?
.as_ref()
.map(|r| r.id.to_string());
pairs = inject_pairs(pairs, created_release_id)?;
@@ -97,16 +105,16 @@ pub fn inject_pairs(
Ok(pairs)
}
pub fn get_release_for_pairs<'a>(
pub fn get_release_for_maps<'a>(
directory: &Path,
project: &Option<String>,
version: &Option<String>,
pairs: impl IntoIterator<Item = &'a SourcePair>,
maps: impl IntoIterator<Item = &'a SourceMapFile>,
) -> Result<Option<Release>> {
// We need to fetch or create a release if: the user specified one, any pair is missing one, or the user
// forced release overriding
let needs_release =
project.is_some() || version.is_some() || pairs.into_iter().any(|p| !p.has_release_id());
project.is_some() || version.is_some() || maps.into_iter().any(|p| !p.has_release_id());
let mut created_release = None;
if needs_release {


@@ -28,7 +28,7 @@ impl SourcePair {
}
pub fn has_release_id(&self) -> bool {
self.get_release_id().is_some()
self.sourcemap.has_release_id()
}
pub fn get_release_id(&self) -> Option<String> {


@@ -206,6 +206,7 @@ export const commonConfig = {
'.woff2': 'file',
'.mp3': 'file',
'.lottie': 'file',
'.sql': 'text',
},
metafile: true,
}


@@ -86,12 +86,6 @@ core:
dev:up:
description: Start full PostHog dev stack via mprocs
hidden: true
dev:setup:
cmd: python manage.py setup_dev
description: Initialize local development environment (one-time setup)
services:
- postgresql
- clickhouse
dev:demo-data:
cmd: python manage.py generate_demo_data
description: Generate demo data for local testing
@@ -100,14 +94,14 @@ core:
- clickhouse
dev:reset:
steps:
- docker:services:down
- docker:services:remove
- docker:services:up
- check:postgres
- check:clickhouse
- migrations:run
- dev:demo-data
- migrations:sync-flags
description: Full reset - stop services, migrate, load demo data, sync flags
description: Full reset - wipe volumes, migrate, load demo data, sync flags
health_checks:
check:clickhouse:
bin_script: check_clickhouse_up
@@ -272,6 +266,11 @@ docker:
description: Stop Docker infrastructure services
services:
- docker
docker:services:remove:
cmd: docker compose -f docker-compose.dev.yml down -v
description: Stop Docker infrastructure services and remove all volumes (complete wipe)
services:
- docker
docker:deprecated:
bin_script: docker
description: '[DEPRECATED] Use `hogli start` instead'
@@ -302,7 +301,7 @@ docker:
hidden: true
docker:server:
bin_script: docker-server
description: Run gunicorn application server with Prometheus metrics support
description: Run dual-mode server (Granian/Unit) with Prometheus metrics support
hidden: true
docker:server-unit:
bin_script: docker-server-unit
@@ -433,6 +432,10 @@ tools:
bin_script: create-notebook-node.sh
description: Create a new NotebookNode file and update types and editor references
hidden: true
tool:granian-metrics:
bin_script: granian_metrics.py
description: HTTP server that aggregates Prometheus metrics from Granian workers
hidden: true
sync:storage:
bin_script: sync-storage
description: 'TODO: add description for sync-storage'

View File

@@ -162,6 +162,11 @@ function createEntry(entry) {
test: /monaco-editor\/.*\.m?js/,
loader: 'babel-loader',
},
{
// Apply rule for .sql files
test: /\.sql$/,
type: 'asset/source',
},
],
},
// add devServer config only to 'main' entry

View File

@@ -8,4 +8,15 @@
--text-xxs--line-height: 0.75rem;
--spacing-button-padding-x: 7px; /* Match the padding of the button primitives base size (6px) + 1px border */
--animate-input-focus-pulse: inputFocusPulse 1s ease-in-out forwards;
@keyframes inputFocusPulse {
0% {
box-shadow: 0 0 0px 2px var(--color-accent);
}
100% {
box-shadow: 0 0 0px 2px transparent;
}
}
}

View File

@@ -20,6 +20,116 @@ Dagster is an open-source data orchestration tool designed to help you define an
- Individual DAG files (e.g., `exchange_rate.py`, `deletes.py`, `person_overrides.py`)
- `tests/`: Tests for the DAGs
### Cloud access for PostHog employees
Ask someone in #team-infrastructure or #team-clickhouse to add you to Dagster Cloud. You might also want to join the #dagster-posthog Slack channel.
### Adding a New Team
To set up a new team with their own Dagster definitions and Slack alerts, follow these steps:
1. **Create a new definitions file** in `locations/<team_name>.py`:
```python
import dagster
from dags import my_module # Import your DAGs
from . import resources # Import shared resources (if needed)
defs = dagster.Definitions(
assets=[
# List your assets here
my_module.my_asset,
],
jobs=[
# List your jobs here
my_module.my_job,
],
schedules=[
# List your schedules here
my_module.my_schedule,
],
resources=resources, # Include shared resources (ClickHouse, S3, Slack, etc.)
)
```
**Examples**: See `locations/analytics_platform.py` (simple) or `locations/web_analytics.py` (complex with conditional schedules)
2. **Register the location in the workspace** (for local development):
Add your module to `.dagster_home/workspace.yaml`:
```yaml
load_from:
- python_module: dags.locations.your_team
```
**Note**: Only add locations that should run locally. Heavy operations should remain commented out.
3. **Configure production deployment**:
For PostHog employees, add the new location to the Dagster configuration in the [charts repository](https://github.com/PostHog/charts) (see `config/dagster/`).
Sample PR: https://github.com/PostHog/charts/pull/6366
4. **Add team to the `JobOwners` enum** in `common/common.py`:
```python
class JobOwners(str, Enum):
TEAM_ANALYTICS_PLATFORM = "team-analytics-platform"
TEAM_YOUR_TEAM = "team-your-team" # Add your team here (alphabetically sorted)
# ... other teams
```
5. **Add Slack channel mapping** in `slack_alerts.py`:
```python
notification_channel_per_team = {
JobOwners.TEAM_ANALYTICS_PLATFORM.value: "#alerts-analytics-platform",
JobOwners.TEAM_YOUR_TEAM.value: "#alerts-your-team", # Add mapping here (alphabetically sorted)
# ... other teams
}
```
6. **Create the Slack channel** (if it doesn't exist) and ensure the Alertmanager/Max Slack bot is invited to the channel
7. **Apply owner tags to your team's assets and jobs** (see next section)
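A minimal sketch of what the owner tag looks like on an asset and a job, using the `TEAM_YOUR_TEAM` placeholder from step 4 (the asset and job names are illustrative):
```python
import dagster

from dags.common import JobOwners


@dagster.asset(tags={"owner": JobOwners.TEAM_YOUR_TEAM.value})
def my_asset():
    ...


@dagster.job(tags={"owner": JobOwners.TEAM_YOUR_TEAM.value})
def my_job():
    pass
```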
### How Slack alerts work
- The `notify_slack_on_failure` sensor (defined in `slack_alerts.py`) monitors all job failures across all code locations
- Alerts are only sent in production (when `CLOUD_DEPLOYMENT` environment variable is set)
- Each team has a dedicated Slack channel where their alerts are routed based on job ownership
- Failed jobs send a message to the appropriate team channel with a link to the Dagster run
#### Consecutive Failure Thresholds
Some jobs are configured to only alert after multiple consecutive failures to avoid alert fatigue. Configure this in `slack_alerts.py`:
```python
CONSECUTIVE_FAILURE_THRESHOLDS = {
"web_pre_aggregate_current_day_hourly_job": 3, # Alert after 3 consecutive failures
"your_job_name": 2, # Add your threshold here
}
```
#### Disabling Notifications
To disable Slack notifications for a specific job, add the `disable_slack_notifications` tag:
```python
@dagster.job(tags={"disable_slack_notifications": "true"})
def quiet_job():
pass
```
#### Testing Alerts Locally
When running Dagster locally (with `DEBUG=1`), the Slack resource is replaced with a dummy resource, so no actual notifications are sent. This prevents test alerts from being sent to production Slack channels during development.
To test the alert routing logic, write unit tests in `tests/test_slack_alerts.py`.
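As a starting point, a test can assert that the routing table added in step 5 contains an entry for the new team (a minimal sketch; the import paths and channel prefix are assumptions about the current layout):
```python
# Hypothetical test for dags/tests/test_slack_alerts.py; adjust imports to the real module layout.
from dags.common import JobOwners
from dags.slack_alerts import notification_channel_per_team


def test_your_team_alerts_are_routed_to_a_channel():
    # The mapping from step 5 should expose a channel for the new team.
    channel = notification_channel_per_team[JobOwners.TEAM_YOUR_TEAM.value]
    assert channel.startswith("#alerts-")
```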
## Local Development
### Environment Setup

View File

@@ -2,7 +2,7 @@ import re
import time
from collections.abc import Callable
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from datetime import UTC, datetime
from typing import Any, Optional
from django.conf import settings
@@ -88,6 +88,26 @@ NON_SHARDED_TABLES = [
]
@dataclass
class Table:
name: str
def is_backup_in_progress(self, client: Client) -> bool:
# We query the processes table to check if a backup for the requested table is in progress
rows = client.execute(
f"""
SELECT EXISTS(
SELECT 1
FROM system.processes
WHERE query_kind = 'Backup' AND query like '%{self.name}%'
)
"""
)
[[exists]] = rows
return exists
@dataclass
class BackupStatus:
hostname: str
@@ -95,6 +115,12 @@ class BackupStatus:
event_time_microseconds: datetime
error: Optional[str] = None
def created(self) -> bool:
return self.status == "BACKUP_CREATED"
def creating(self) -> bool:
return self.status == "CREATING_BACKUP"
@dataclass
class Backup:
@@ -141,7 +167,8 @@ class Backup:
backup_settings = {
"async": "1",
"max_backup_bandwidth": get_max_backup_bandwidth(),
"s3_disable_checksum": "1", # There is a CH issue that makes bandwith be half than what is configured: https://github.com/ClickHouse/ClickHouse/issues/78213
# There is a CH issue that makes bandwidth half of what is configured: https://github.com/ClickHouse/ClickHouse/issues/78213
"s3_disable_checksum": "1",
# According to CH docs, disabling this is safe enough as checksums are already made: https://clickhouse.com/docs/operations/settings/settings#s3_disable_checksum
}
if self.base_backup:
@@ -228,12 +255,6 @@ class BackupConfig(dagster.Config):
default="",
description="The table to backup. If not specified, the entire database will be backed up.",
)
date: str = pydantic.Field(
default=datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
description="The date to backup. If not specified, the current date will be used.",
pattern=r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z$",
validate_default=True,
)
workload: Workload = Workload.OFFLINE
@@ -257,13 +278,38 @@ def get_shards(cluster: dagster.ResourceParam[ClickhouseCluster]):
@dagster.op
def get_latest_backup(
def check_running_backup_for_table(
config: BackupConfig, cluster: dagster.ResourceParam[ClickhouseCluster]
) -> Optional[Table]:
"""
Check if a backup for the requested table is in progress (it shouldn't be, so fail if it is).
"""
table = Table(name=config.table)
is_running_backup = (
cluster.map_hosts_by_role(table.is_backup_in_progress, node_role=NodeRole.DATA, workload=config.workload)
.result()
.values()
)
if any(is_running_backup):
raise dagster.Failure(
description=f"A backup for table {table.name} is still in progress, this run shouldn't have been triggered. Review concurrency limits / schedule triggering logic. If there is not Dagster job running and there is a backup going on, it's worth checking what happened."
)
return table
@dagster.op
def get_latest_backups(
context: dagster.OpExecutionContext,
config: BackupConfig,
s3: S3Resource,
running_backup: Optional[Table] = None,
shard: Optional[int] = None,
) -> Optional[Backup]:
) -> list[Backup]:
"""
Get the latest backup metadata for a ClickHouse database / table from S3.
Get the metadata of the latest 15 backups for a ClickHouse database / table from S3.
They are sorted from most recent to oldest.
"""
shard_path = shard if shard else NO_SHARD_PATH
@@ -277,24 +323,28 @@ def get_latest_backup(
)
if "CommonPrefixes" not in backups:
return None
return []
latest_backup = sorted(backups["CommonPrefixes"], key=lambda x: x["Prefix"])[-1]["Prefix"]
return Backup.from_s3_path(latest_backup)
latest_backups = [
Backup.from_s3_path(backup["Prefix"])
for backup in sorted(backups["CommonPrefixes"], key=lambda x: x["Prefix"], reverse=True)
]
context.log.info(f"Found {len(latest_backups)} latest backups: {latest_backups}")
return latest_backups[:15]
@dagster.op
def check_latest_backup_status(
def get_latest_successful_backup(
context: dagster.OpExecutionContext,
config: BackupConfig,
latest_backup: Optional[Backup],
latest_backups: list[Backup],
cluster: dagster.ResourceParam[ClickhouseCluster],
) -> Optional[Backup]:
"""
Check if the latest backup is done.
Finds the latest successful backup to use as the base backup.
"""
if not latest_backup:
context.log.info("No latest backup found. Skipping status check.")
if not latest_backups or not config.incremental:
context.log.info("No latest backup found or a full backup was requested. Skipping status check.")
return
def map_hosts(func: Callable[[Client], Any]):
@@ -304,33 +354,23 @@ def check_latest_backup_status(
)
return cluster.map_hosts_by_role(fn=func, node_role=NodeRole.DATA, workload=config.workload)
is_done = map_hosts(latest_backup.is_done).result().values()
if not all(is_done):
context.log.info(f"Latest backup {latest_backup.path} is still in progress, waiting for it to finish")
map_hosts(latest_backup.wait).result()
else:
context.log.info(f"Find latest successful created backup")
for latest_backup in latest_backups:
context.log.info(f"Checking status of backup: {latest_backup.path}")
most_recent_status = get_most_recent_status(map_hosts(latest_backup.status).result().values())
if most_recent_status and most_recent_status.status != "BACKUP_CREATED":
# Check if the backup is stuck (CREATING_BACKUP with no active process)
if most_recent_status.status == "CREATING_BACKUP":
# Check how old the backup status is
time_since_status = datetime.now(UTC) - most_recent_status.event_time_microseconds.replace(tzinfo=UTC)
if time_since_status > timedelta(hours=2):
context.log.warning(
f"Previous backup {latest_backup.path} is stuck in CREATING_BACKUP status for {time_since_status}. "
f"This usually happens when the server was restarted during backup. "
f"Proceeding with new backup as the old one is no longer active."
)
# Don't raise an error - the backup is dead and won't interfere
return None
# For other unexpected statuses (like BACKUP_FAILED), still raise an error
raise ValueError(
f"Latest backup {latest_backup.path} finished with an unexpected status: {most_recent_status.status} on the host {most_recent_status.hostname}. Please check the backup logs."
if most_recent_status and not most_recent_status.created():
context.log.warning(
f"Backup {latest_backup.path} finished with an unexpected status: {most_recent_status.status} on the host {most_recent_status.hostname}. Checking next backup."
)
else:
context.log.info(f"Latest backup {latest_backup.path} finished successfully")
context.log.info(
f"Backup {latest_backup.path} finished successfully. Using it as the base backup for the new backup."
)
return latest_backup
return latest_backup
raise dagster.Failure(
f"All {len(latest_backups)} latest backups finished with an unexpected status. Please review them before launching a new one."
)
@dagster.op
@@ -352,16 +392,12 @@ def run_backup(
id=context.run_id,
database=config.database,
table=config.table,
date=config.date,
date=datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
base_backup=latest_backup if config.incremental else None,
shard=shard,
)
if latest_backup and latest_backup.path == backup.path:
context.log.warning(
f"This backup directory exists in S3. Skipping its run, if you want to run it again, remove the data from {backup.path}"
)
return
context.log.info(f"Running backup for table {config.table} in path: {backup.path}")
if backup.shard:
cluster.map_any_host_in_shards_by_role(
@@ -376,6 +412,15 @@ def run_backup(
workload=config.workload,
).result()
context.add_output_metadata(
{
"table": dagster.MetadataValue.text(str(backup.table)),
"path": dagster.MetadataValue.text(backup.path),
"incremental": dagster.MetadataValue.bool(config.incremental),
"date": dagster.MetadataValue.text(backup.date),
}
)
return backup
@@ -383,7 +428,7 @@ def run_backup(
def wait_for_backup(
context: dagster.OpExecutionContext,
config: BackupConfig,
backup: Optional[Backup],
backup: Backup,
cluster: dagster.ResourceParam[ClickhouseCluster],
):
"""
@@ -404,18 +449,27 @@ def wait_for_backup(
tries += 1
map_hosts(backup.wait).result().values()
most_recent_status = get_most_recent_status(map_hosts(backup.status).result().values())
if most_recent_status and most_recent_status.status == "CREATING_BACKUP":
if most_recent_status and most_recent_status.creating() and tries < 3:
context.log.warning(
f"Backup {backup.path} is no longer running but status is still creating. Waiting a bit longer in case ClickHouse didn't flush logs yet..."
)
continue
if most_recent_status and most_recent_status.status == "BACKUP_CREATED":
elif most_recent_status and most_recent_status.created():
context.log.info(f"Backup for table {backup.table} in path {backup.path} finished successfully")
done = True
if (most_recent_status and most_recent_status.status != "BACKUP_CREATED") or (
most_recent_status and tries >= 5
):
elif most_recent_status and not most_recent_status.created():
raise ValueError(
f"Backup {backup.path} finished with an unexpected status: {most_recent_status.status} on the host {most_recent_status.hostname}."
)
else:
context.log.info("No backup to wait for")
context.add_output_metadata(
{
"table": dagster.MetadataValue.text(str(backup.table)),
"path": dagster.MetadataValue.text(backup.path),
"incremental": dagster.MetadataValue.bool(config.incremental),
"date": dagster.MetadataValue.text(backup.date),
}
)
@dagster.job(
@@ -431,8 +485,8 @@ def sharded_backup():
"""
def run_backup_for_shard(shard: int):
latest_backup = get_latest_backup(shard=shard)
checked_backup = check_latest_backup_status(latest_backup=latest_backup)
latest_backups = get_latest_backups(check_running_backup_for_table(), shard=shard)
checked_backup = get_latest_successful_backup(latest_backups=latest_backups)
new_backup = run_backup(latest_backup=checked_backup, shard=shard)
wait_for_backup(backup=new_backup)
@@ -460,8 +514,8 @@ def non_sharded_backup():
Since we don't want to keep state about which host was selected to run the backup, we always search for backups by name on every node.
When we find it on one of the nodes, we keep waiting on it only on that node. This is handy when we retry the job while a backup is in progress on any node, as we'll always wait for it to finish.
"""
latest_backup = get_latest_backup()
new_backup = run_backup(check_latest_backup_status(latest_backup))
latest_backups = get_latest_backups(check_running_backup_for_table())
new_backup = run_backup(get_latest_successful_backup(latest_backups=latest_backups))
wait_for_backup(new_backup)
@@ -469,14 +523,20 @@ def prepare_run_config(config: BackupConfig) -> dagster.RunConfig:
return dagster.RunConfig(
{
op.name: {"config": config.model_dump(mode="json")}
for op in [get_latest_backup, run_backup, check_latest_backup_status, wait_for_backup]
for op in [
check_running_backup_for_table,
get_latest_backups,
run_backup,
get_latest_successful_backup,
wait_for_backup,
]
}
)
def run_backup_request(
table: str, incremental: bool, context: dagster.ScheduleEvaluationContext
) -> dagster.RunRequest | dagster.SkipReason:
) -> Optional[dagster.RunRequest]:
skip_reason = check_for_concurrent_runs(
context,
tags={
@@ -484,12 +544,12 @@ def run_backup_request(
},
)
if skip_reason:
return skip_reason
context.log.info(skip_reason.skip_message)
return None
timestamp = datetime.now(UTC)
config = BackupConfig(
database=settings.CLICKHOUSE_DATABASE,
date=timestamp.strftime("%Y-%m-%dT%H:%M:%SZ"),
table=table,
incremental=incremental,
)
@@ -508,23 +568,29 @@ def run_backup_request(
@dagster.schedule(
job=sharded_backup,
cron_schedule=settings.CLICKHOUSE_FULL_BACKUP_SCHEDULE,
should_execute=lambda context: 1 <= context.scheduled_execution_time.day <= 7,
default_status=dagster.DefaultScheduleStatus.RUNNING,
)
def full_sharded_backup_schedule(context: dagster.ScheduleEvaluationContext):
"""Launch a full backup for sharded tables"""
for table in SHARDED_TABLES:
yield run_backup_request(table, incremental=False, context=context)
request = run_backup_request(table, incremental=False, context=context)
if request:
yield request
@dagster.schedule(
job=non_sharded_backup,
cron_schedule=settings.CLICKHOUSE_FULL_BACKUP_SCHEDULE,
should_execute=lambda context: 1 <= context.scheduled_execution_time.day <= 7,
default_status=dagster.DefaultScheduleStatus.RUNNING,
)
def full_non_sharded_backup_schedule(context: dagster.ScheduleEvaluationContext):
"""Launch a full backup for non-sharded tables"""
for table in NON_SHARDED_TABLES:
yield run_backup_request(table, incremental=False, context=context)
request = run_backup_request(table, incremental=False, context=context)
if request:
yield request
@dagster.schedule(
@@ -535,7 +601,9 @@ def full_non_sharded_backup_schedule(context: dagster.ScheduleEvaluationContext)
def incremental_sharded_backup_schedule(context: dagster.ScheduleEvaluationContext):
"""Launch an incremental backup for sharded tables"""
for table in SHARDED_TABLES:
yield run_backup_request(table, incremental=True, context=context)
request = run_backup_request(table, incremental=True, context=context)
if request:
yield request
@dagster.schedule(
@@ -546,4 +614,6 @@ def incremental_sharded_backup_schedule(context: dagster.ScheduleEvaluationConte
def incremental_non_sharded_backup_schedule(context: dagster.ScheduleEvaluationContext):
"""Launch an incremental backup for non-sharded tables"""
for table in NON_SHARDED_TABLES:
yield run_backup_request(table, incremental=True, context=context)
request = run_backup_request(table, incremental=True, context=context)
if request:
yield request

View File

@@ -1,8 +1,13 @@
import base64
from contextlib import suppress
from enum import Enum
from typing import Optional
from django.conf import settings
import dagster
import psycopg2
import psycopg2.extras
from clickhouse_driver.errors import Error, ErrorCodes
from posthog.clickhouse import query_tagging
@@ -13,14 +18,17 @@ from posthog.redis import get_client, redis
class JobOwners(str, Enum):
TEAM_ANALYTICS_PLATFORM = "team-analytics-platform"
TEAM_CLICKHOUSE = "team-clickhouse"
TEAM_DATA_WAREHOUSE = "team-data-warehouse"
TEAM_ERROR_TRACKING = "team-error-tracking"
TEAM_EXPERIMENTS = "team-experiments"
TEAM_GROWTH = "team-growth"
TEAM_INGESTION = "team-ingestion"
TEAM_LLMA = "team-llma"
TEAM_MAX_AI = "team-max-ai"
TEAM_REVENUE_ANALYTICS = "team-revenue-analytics"
TEAM_WEB_ANALYTICS = "team-web-analytics"
TEAM_ERROR_TRACKING = "team-error-tracking"
TEAM_GROWTH = "team-growth"
TEAM_EXPERIMENTS = "team-experiments"
TEAM_MAX_AI = "team-max-ai"
TEAM_DATA_WAREHOUSE = "team-data-warehouse"
class ClickhouseClusterResource(dagster.ConfigurableResource):
@@ -75,6 +83,28 @@ class RedisResource(dagster.ConfigurableResource):
return client
class PostgresResource(dagster.ConfigurableResource):
"""
A Postgres database connection resource that returns a psycopg2 connection.
"""
host: str
port: str = "5432"
database: str
user: str
password: str
def create_resource(self, context: dagster.InitResourceContext) -> psycopg2.extensions.connection:
return psycopg2.connect(
host=self.host,
port=int(self.port),
database=self.database,
user=self.user,
password=self.password,
cursor_factory=psycopg2.extras.RealDictCursor,
)
def report_job_status_metric(
context: dagster.RunStatusSensorContext, cluster: dagster.ResourceParam[ClickhouseCluster]
) -> None:
@@ -171,3 +201,33 @@ def check_for_concurrent_runs(
return dagster.SkipReason(f"Skipping {job_name} run because another run of the same job is already active")
return None
def metabase_debug_query_url(run_id: str) -> Optional[str]:
cloud_deployment = getattr(settings, "CLOUD_DEPLOYMENT", None)
if cloud_deployment == "US":
return f"https://metabase.prod-us.posthog.dev/question/1671-get-clickhouse-query-log-for-given-dagster-run-id?dagster_run_id={run_id}"
if cloud_deployment == "EU":
return f"https://metabase.prod-eu.posthog.dev/question/544-get-clickhouse-query-log-for-given-dagster-run-id?dagster_run_id={run_id}"
sql = f"""
SELECT
hostName() as host,
event_time,
type,
exception IS NOT NULL and exception != '' as has_exception,
query_duration_ms,
formatReadableSize(memory_usage) as memory_used,
formatReadableSize(read_bytes) as data_read,
JSONExtractString(log_comment, 'dagster', 'run_id') AS dagster_run_id,
JSONExtractString(log_comment, 'dagster', 'job_name') AS dagster_job_name,
JSONExtractString(log_comment, 'dagster', 'asset_key') AS dagster_asset_key,
JSONExtractString(log_comment, 'dagster', 'op_name') AS dagster_op_name,
exception,
query
FROM clusterAllReplicas('posthog', system.query_log)
WHERE
dagster_run_id = '{run_id}'
AND event_date >= today() - 1
ORDER BY event_time DESC;
"""
return f"http://localhost:8123/play?user=default#{base64.b64encode(sql.encode("utf-8")).decode("utf-8")}"

View File

@@ -23,6 +23,8 @@ from posthog.hogql_queries.experiments.experiment_query_runner import Experiment
from posthog.hogql_queries.experiments.utils import get_experiment_stats_method
from posthog.models.experiment import Experiment, ExperimentMetricResult
from products.experiments.stats.shared.statistics import StatisticError
from dags.common import JobOwners
from dags.experiments import (
_parse_partition_key,
@@ -80,31 +82,32 @@ def experiment_regular_metrics_timeseries(context: dagster.AssetExecutionContext
if not metric_dict:
raise dagster.Failure(f"Metric {metric_uuid} not found in experiment {experiment_id}")
try:
metric_type = metric_dict.get("metric_type")
metric_obj: Union[ExperimentMeanMetric, ExperimentFunnelMetric, ExperimentRatioMetric]
if metric_type == "mean":
metric_obj = ExperimentMeanMetric(**metric_dict)
elif metric_type == "funnel":
metric_obj = ExperimentFunnelMetric(**metric_dict)
elif metric_type == "ratio":
metric_obj = ExperimentRatioMetric(**metric_dict)
else:
raise dagster.Failure(f"Unknown metric type: {metric_type}")
metric_type = metric_dict.get("metric_type")
metric_obj: Union[ExperimentMeanMetric, ExperimentFunnelMetric, ExperimentRatioMetric]
if metric_type == "mean":
metric_obj = ExperimentMeanMetric(**metric_dict)
elif metric_type == "funnel":
metric_obj = ExperimentFunnelMetric(**metric_dict)
elif metric_type == "ratio":
metric_obj = ExperimentRatioMetric(**metric_dict)
else:
raise dagster.Failure(f"Unknown metric type: {metric_type}")
# Validate experiment start date upfront
if not experiment.start_date:
raise dagster.Failure(
f"Experiment {experiment_id} has no start_date - only launched experiments should be processed"
)
query_from_utc = experiment.start_date
query_to_utc = datetime.now(ZoneInfo("UTC"))
try:
experiment_query = ExperimentQuery(
experiment_id=experiment_id,
metric=metric_obj,
)
# Cumulative calculation: from experiment start to current time
if not experiment.start_date:
raise dagster.Failure(
f"Experiment {experiment_id} has no start_date - only launched experiments should be processed"
)
query_from_utc = experiment.start_date
query_to_utc = datetime.now(ZoneInfo("UTC"))
query_runner = ExperimentQueryRunner(query=experiment_query, team=experiment.team)
result = query_runner._calculate()
@@ -127,7 +130,6 @@ def experiment_regular_metrics_timeseries(context: dagster.AssetExecutionContext
},
)
# Add metadata for Dagster UI display
context.add_output_metadata(
metadata={
"experiment_id": experiment_id,
@@ -144,7 +146,6 @@ def experiment_regular_metrics_timeseries(context: dagster.AssetExecutionContext
"metric_definition": str(metric_dict),
"query_from": query_from_utc.isoformat(),
"query_to": query_to_utc.isoformat(),
"results_status": "success",
}
)
return {
@@ -157,14 +158,50 @@ def experiment_regular_metrics_timeseries(context: dagster.AssetExecutionContext
"result": result.model_dump(),
}
except Exception as e:
if not experiment.start_date:
raise dagster.Failure(
f"Experiment {experiment_id} has no start_date - only launched experiments should be processed"
)
query_from_utc = experiment.start_date
query_to_utc = datetime.now(ZoneInfo("UTC"))
except (StatisticError, ZeroDivisionError) as e:
# Insufficient data - do not fail so that real failures are visible
ExperimentMetricResult.objects.update_or_create(
experiment_id=experiment_id,
metric_uuid=metric_uuid,
fingerprint=fingerprint,
query_to=query_to_utc,
defaults={
"query_from": query_from_utc,
"status": ExperimentMetricResult.Status.FAILED,
"result": None,
"query_id": None,
"completed_at": None,
"error_message": str(e),
},
)
context.add_output_metadata(
metadata={
"experiment_id": experiment_id,
"metric_uuid": metric_uuid,
"fingerprint": fingerprint,
"metric_type": metric_type,
"metric_name": metric_dict.get("name", f"Metric {metric_uuid}"),
"experiment_name": experiment.name,
"query_from": query_from_utc.isoformat(),
"query_to": query_to_utc.isoformat(),
"error_type": type(e).__name__,
"error_message": str(e),
}
)
return {
"experiment_id": experiment_id,
"metric_uuid": metric_uuid,
"fingerprint": fingerprint,
"metric_definition": metric_dict,
"query_from": query_from_utc.isoformat(),
"query_to": query_to_utc.isoformat(),
"error": str(e),
"error_type": type(e).__name__,
}
except Exception as e:
ExperimentMetricResult.objects.update_or_create(
experiment_id=experiment_id,
metric_uuid=metric_uuid,

View File

@@ -21,6 +21,8 @@ from posthog.hogql_queries.experiments.experiment_query_runner import Experiment
from posthog.hogql_queries.experiments.utils import get_experiment_stats_method
from posthog.models.experiment import Experiment, ExperimentMetricResult
from products.experiments.stats.shared.statistics import StatisticError
from dags.common import JobOwners
from dags.experiments import (
_parse_partition_key,
@@ -97,20 +99,21 @@ def experiment_saved_metrics_timeseries(context: dagster.AssetExecutionContext)
else:
raise dagster.Failure(f"Unknown metric type: {metric_type}")
# Validate experiment start date upfront
if not experiment.start_date:
raise dagster.Failure(
f"Experiment {experiment_id} has no start_date - only launched experiments should be processed"
)
query_from_utc = experiment.start_date
query_to_utc = datetime.now(ZoneInfo("UTC"))
try:
experiment_query = ExperimentQuery(
experiment_id=experiment_id,
metric=metric_obj,
)
# Cumulative calculation: from experiment start to current time
if not experiment.start_date:
raise dagster.Failure(
f"Experiment {experiment_id} has no start_date - only launched experiments should be processed"
)
query_from_utc = experiment.start_date
query_to_utc = datetime.now(ZoneInfo("UTC"))
query_runner = ExperimentQueryRunner(query=experiment_query, team=experiment.team, user_facing=False)
result = query_runner._calculate()
@@ -118,7 +121,7 @@ def experiment_saved_metrics_timeseries(context: dagster.AssetExecutionContext)
completed_at = datetime.now(ZoneInfo("UTC"))
experiment_metric_result, created = ExperimentMetricResult.objects.update_or_create(
experiment_metric_result, _ = ExperimentMetricResult.objects.update_or_create(
experiment_id=experiment_id,
metric_uuid=metric_uuid,
fingerprint=fingerprint,
@@ -133,7 +136,6 @@ def experiment_saved_metrics_timeseries(context: dagster.AssetExecutionContext)
},
)
# Add metadata for Dagster UI display
context.add_output_metadata(
metadata={
"experiment_id": experiment_id,
@@ -150,7 +152,6 @@ def experiment_saved_metrics_timeseries(context: dagster.AssetExecutionContext)
else None,
"query_from": query_from_utc.isoformat(),
"query_to": query_to_utc.isoformat(),
"results_status": "success",
}
)
return {
@@ -164,14 +165,52 @@ def experiment_saved_metrics_timeseries(context: dagster.AssetExecutionContext)
"result": result.model_dump(),
}
except Exception as e:
if not experiment.start_date:
raise dagster.Failure(
f"Experiment {experiment_id} has no start_date - only launched experiments should be processed"
)
query_from_utc = experiment.start_date
query_to_utc = datetime.now(ZoneInfo("UTC"))
except (StatisticError, ZeroDivisionError) as e:
# Insufficient data - do not fail so that real failures are visible
ExperimentMetricResult.objects.update_or_create(
experiment_id=experiment_id,
metric_uuid=metric_uuid,
fingerprint=fingerprint,
query_to=query_to_utc,
defaults={
"query_from": query_from_utc,
"status": ExperimentMetricResult.Status.FAILED,
"result": None,
"query_id": None,
"completed_at": None,
"error_message": str(e),
},
)
context.add_output_metadata(
metadata={
"experiment_id": experiment_id,
"saved_metric_id": saved_metric.id,
"saved_metric_uuid": saved_metric.query.get("uuid"),
"saved_metric_name": saved_metric.name,
"fingerprint": fingerprint,
"metric_type": metric_type,
"experiment_name": experiment.name,
"query_from": query_from_utc.isoformat(),
"query_to": query_to_utc.isoformat(),
"error_type": type(e).__name__,
"error_message": str(e),
}
)
return {
"experiment_id": experiment_id,
"saved_metric_id": saved_metric.id,
"saved_metric_uuid": saved_metric.query.get("uuid"),
"saved_metric_name": saved_metric.name,
"fingerprint": fingerprint,
"query_from": query_from_utc.isoformat(),
"query_to": query_to_utc.isoformat(),
"error": str(e),
"error_type": type(e).__name__,
}
except Exception as e:
ExperimentMetricResult.objects.update_or_create(
experiment_id=experiment_id,
metric_uuid=metric_uuid,

dags/llma/README.md Normal file
View File

@@ -0,0 +1,62 @@
# LLMA (LLM Analytics) Dagster Location
Data pipelines for LLM analytics and observability.
## Overview
The LLMA location contains pipelines for aggregating and analyzing AI/LLM
events tracked through PostHog. These pipelines power analytics, cost tracking,
and observability features for AI products.
## Structure
```text
dags/llma/
├── README.md
├── daily_metrics/ # Daily aggregation pipeline
│ ├── README.md # Detailed pipeline documentation
│ ├── config.py # Pipeline configuration
│ ├── main.py # Dagster asset and schedule
│ ├── utils.py # SQL generation helpers
│ └── sql/ # Modular SQL templates
│ ├── event_counts.sql # Event count metrics
│ ├── error_rates.sql # Error rate metrics
│ ├── trace_counts.sql # Unique trace count metrics
│ ├── session_counts.sql # Unique session count metrics
│ ├── trace_error_rates.sql # Trace-level error rates
│ └── pageview_counts.sql # LLM Analytics pageview metrics
└── __init__.py
Tests: dags/tests/llma/daily_metrics/test_sql_metrics.py
```
## Pipelines
### Daily Metrics
Aggregates AI event metrics ($ai_trace, $ai_generation, $ai_span,
$ai_embedding) by team and date into the `llma_metrics_daily` ClickHouse
table.
Features:
- Modular SQL template system for easy metric additions
- Event counts, trace counts, session counts, and pageview metrics
- Error rates at event and trace level (proportions 0.0-1.0)
- Long-format schema so new metrics can be added without schema changes
- Daily schedule at 6 AM UTC
See [daily_metrics/README.md](daily_metrics/README.md) for detailed
documentation.
## Local Development
The LLMA location is loaded in `.dagster_home/workspace.yaml` for local
development.
View in Dagster UI:
```bash
# Dagster runs on port 3030
open http://localhost:3030
```

dags/llma/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
"""
LLMA (LLM Analytics) team Dagster assets.
This module contains data pipeline assets for tracking and analyzing LLM usage metrics.
"""

View File

@@ -0,0 +1,206 @@
# LLMA Daily Metrics Pipeline
Daily aggregation of AI event metrics into the `llma_metrics_daily` ClickHouse
table.
## Overview
Aggregates AI event metrics ($ai_trace, $ai_generation, $ai_span,
$ai_embedding) by team and date into a long-format metrics table for efficient
querying.
## Architecture
The pipeline uses a modular SQL template system:
- Each metric type lives in its own `.sql` file under `sql/`
- Templates are auto-discovered and combined with UNION ALL
- To add a new metric, simply drop a new `.sql` file in the directory
### SQL Template Format
Each SQL file should return these columns:
```sql
SELECT
date(timestamp) as date,
team_id,
'metric_name' as metric_name,
toFloat64(value) as metric_value
FROM events
WHERE ...
```
Templates have access to these Jinja2 variables (a rendering sketch follows the list):
- `event_types`: List of AI event types to aggregate
- `date_start`: Start date for aggregation (YYYY-MM-DD)
- `date_end`: End date for aggregation (YYYY-MM-DD)
- `pageview_mappings`: List of (url_path, metric_suffix) tuples for pageview categorization
- `include_error_rates`: Boolean flag for error rate calculation (default: true)
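A rough sketch of how one of these templates might be rendered with that context, mirroring what `utils.get_insert_query` does (the literal values below are illustrative only):
```python
from jinja2 import Template

template_context = {
    "event_types": ["$ai_trace", "$ai_generation", "$ai_span", "$ai_embedding"],
    "date_start": "2025-01-01",
    "date_end": "2025-01-02",
    "pageview_mappings": [("/llm-analytics/traces", "traces")],
    "include_error_rates": True,
}

# Render a single template; the pipeline renders every file in sql/ and joins them with UNION ALL.
with open("dags/llma/daily_metrics/sql/event_counts.sql") as f:
    print(Template(f.read()).render(**template_context))
```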
## Output Schema
```sql
CREATE TABLE llma_metrics_daily (
date Date,
team_id UInt64,
metric_name String,
metric_value Float64
) ENGINE = ReplicatedMergeTree()
PARTITION BY toYYYYMM(date)
ORDER BY (team_id, date, metric_name)
```
Long format allows adding new metrics without schema changes.
## Current Metrics
### Event Counts
Defined in `sql/event_counts.sql`:
- `ai_generation_count`: Number of AI generation events
- `ai_trace_count`: Number of AI trace events
- `ai_span_count`: Number of AI span events
- `ai_embedding_count`: Number of AI embedding events
Each event is counted individually, even if multiple events share the same trace_id.
### Trace Counts
Defined in `sql/trace_counts.sql`:
- `ai_trace_id_count`: Number of unique traces (distinct $ai_trace_id values)
Counts unique traces across all AI event types. A trace may contain multiple events (generations, spans, etc).
### Session Counts
Defined in `sql/session_counts.sql`:
- `ai_session_id_count`: Number of unique sessions (distinct $ai_session_id values)
Counts unique sessions across all AI event types. A session can link multiple related traces together.
### Event Error Rates
Defined in `sql/error_rates.sql`:
- `ai_generation_error_rate`: Proportion of AI generation events with errors (0.0 to 1.0)
- `ai_trace_error_rate`: Proportion of AI trace events with errors (0.0 to 1.0)
- `ai_span_error_rate`: Proportion of AI span events with errors (0.0 to 1.0)
- `ai_embedding_error_rate`: Proportion of AI embedding events with errors (0.0 to 1.0)
### Trace Error Rates
Defined in `sql/trace_error_rates.sql`:
- `ai_trace_id_has_error_rate`: Proportion of unique traces that had at least one error (0.0 to 1.0)
A trace is considered errored if ANY event within it has an error. Compare with event error rates which report the proportion of individual events with errors.
### Pageview Metrics
Defined in `sql/pageview_counts.sql`:
- `pageviews_traces`: Pageviews on /llm-analytics/traces
- `pageviews_generations`: Pageviews on /llm-analytics/generations
- `pageviews_users`: Pageviews on /llm-analytics/users
- `pageviews_sessions`: Pageviews on /llm-analytics/sessions
- `pageviews_playground`: Pageviews on /llm-analytics/playground
- `pageviews_datasets`: Pageviews on /llm-analytics/datasets
- `pageviews_evaluations`: Pageviews on /llm-analytics/evaluations
Tracks $pageview events on LLM Analytics pages. URL patterns are mapped to page types via config.pageview_mappings.
### Error Detection
All error metrics treat an event as errored when either (see the sketch after this list):
- the `$ai_error` property is set (non-empty and not null)
- the `$ai_is_error` property is true
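For intuition, the same predicate expressed in Python (a hypothetical helper for reasoning about mock event rows, not part of the pipeline):
```python
def is_ai_error(properties: dict) -> bool:
    # Mirrors the SQL check: $ai_error present and not JSON null, or $ai_is_error is true.
    has_error_payload = "$ai_error" in properties and properties["$ai_error"] is not None
    return has_error_payload or properties.get("$ai_is_error") is True
```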
## Configuration
See `config.py` for configuration options (a usage sketch follows the list):
- `partition_start_date`: First date to backfill (default: 2025-01-01)
- `cron_schedule`: Schedule for daily runs (default: 6 AM UTC)
- `max_partitions_per_run`: Max partitions to process in backfill (default: 14)
- `ai_event_types`: List of AI event types to track (default: $ai_trace, $ai_generation, $ai_span, $ai_embedding)
- `pageview_mappings`: URL path to metric name mappings for pageview tracking
- `include_error_rates`: Enable error rate metrics (default: true)
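A quick usage sketch of these options (the pipeline reads the module-level `config` instance, so real changes are edits to `config.py`; the override values below are purely illustrative):
```python
from dags.llma.daily_metrics.config import LLMADailyMetricsConfig

custom = LLMADailyMetricsConfig(
    cron_schedule="0 7 * * *",   # run an hour later than the default 6 AM UTC
    max_partitions_per_run=7,    # smaller backfill batches
    include_error_rates=False,   # turn off error-rate metrics
)
# List fields left as None get their defaults filled in by __post_init__.
assert custom.ai_event_types == ["$ai_trace", "$ai_generation", "$ai_span", "$ai_embedding"]
```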
## Schedule
Runs daily at 6 AM UTC for the previous day's partition.
## Local Development
Query results in ClickHouse:
```bash
docker exec posthog-clickhouse-1 clickhouse-client --query \
"SELECT * FROM llma_metrics_daily WHERE date = today() FORMAT Pretty"
```
Or use HogQL in PostHog UI:
```sql
SELECT
date,
metric_name,
sum(metric_value) as total
FROM llma_metrics_daily
WHERE date >= today() - INTERVAL 7 DAY
GROUP BY date, metric_name
ORDER BY date DESC, metric_name
```
## Testing
Run the test suite to validate SQL structure and logic:
```bash
python -m pytest dags/tests/llma/daily_metrics/test_sql_metrics.py -v
```
Tests validate:
- SQL templates render without errors
- All SQL files produce the correct 4-column output
- Calculation logic is correct (using mock data)
- Date filtering and grouping work as expected
Test file location: `dags/tests/llma/daily_metrics/test_sql_metrics.py`
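A minimal sketch of a structural test along those lines (hypothetical; the real suite also checks calculation logic against mock data):
```python
from pathlib import Path

from jinja2 import Template

SQL_DIR = Path("dags/llma/daily_metrics/sql")


def test_templates_render_with_expected_columns():
    context = {
        "event_types": ["$ai_generation"],
        "date_start": "2025-01-01",
        "date_end": "2025-01-02",
        "pageview_mappings": [("/llm-analytics/traces", "traces")],
        "include_error_rates": True,
    }
    for sql_file in SQL_DIR.glob("*.sql"):
        rendered = Template(sql_file.read_text()).render(**context)
        for column in ("date", "team_id", "metric_name", "metric_value"):
            assert column in rendered, f"{sql_file.name} is missing {column}"
```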
## Adding New Metrics
1. Create a new SQL file in `sql/` (e.g., `sql/token_counts.sql`)
2. Use Jinja2 template syntax with `event_types`, `date_start`, `date_end`
3. Return columns: `date`, `team_id`, `metric_name`, `metric_value`
4. The pipeline will automatically discover and include it
5. Add test coverage in `dags/tests/llma/daily_metrics/test_sql_metrics.py` with mock data and expected output
Example:
```sql
{% for event_type in event_types %}
{% set metric_name = event_type.lstrip('$') + '_tokens' %}
SELECT
date(timestamp) as date,
team_id,
'{{ metric_name }}' as metric_name,
toFloat64(sum(JSONExtractInt(properties, '$ai_total_tokens'))) as metric_value
FROM events
WHERE event = '{{ event_type }}'
AND timestamp >= toDateTime('{{ date_start }}', 'UTC')
AND timestamp < toDateTime('{{ date_end }}', 'UTC')
GROUP BY date, team_id
HAVING metric_value > 0
{% if not loop.last %}
UNION ALL
{% endif %}
{% endfor %}
```

View File

@@ -0,0 +1,67 @@
"""
Configuration for LLMA (LLM Analytics) metrics aggregation.
This module defines all parameters and constants used in the LLMA metrics pipeline.
Modify this file to add new metrics or change aggregation behavior.
"""
from dataclasses import dataclass
@dataclass
class LLMADailyMetricsConfig:
"""Configuration for LLMA daily metrics aggregation pipeline."""
# ClickHouse table name
table_name: str = "llma_metrics_daily"
# Start date for daily partitions (when AI events were introduced)
partition_start_date: str = "2025-01-01"
# Schedule: Daily at 6 AM UTC
cron_schedule: str = "0 6 * * *"
# Backfill policy: process N days per run
max_partitions_per_run: int = 14
# ClickHouse query settings
clickhouse_max_execution_time: int = 300 # 5 minutes
# Dagster job timeout (seconds)
job_timeout: int = 1800 # 30 minutes
# Include error rate metrics (percentage of events with errors)
include_error_rates: bool = True
# AI event types to track
# Add new event types here to automatically include them in daily aggregations
ai_event_types: list[str] | None = None
# Pageview URL path to metric name mappings
# Maps URL patterns to metric names for tracking pageviews
# Order matters: more specific patterns should come before general ones
pageview_mappings: list[tuple[str, str]] | None = None
def __post_init__(self):
"""Set default values for list fields."""
if self.ai_event_types is None:
self.ai_event_types = [
"$ai_trace",
"$ai_generation",
"$ai_span",
"$ai_embedding",
]
if self.pageview_mappings is None:
self.pageview_mappings = [
("/llm-analytics/traces", "traces"),
("/llm-analytics/generations", "generations"),
("/llm-analytics/users", "users"),
("/llm-analytics/sessions", "sessions"),
("/llm-analytics/playground", "playground"),
("/llm-analytics/datasets", "datasets"),
("/llm-analytics/evaluations", "evaluations"),
]
# Global config instance
config = LLMADailyMetricsConfig()

View File

@@ -0,0 +1,135 @@
"""
Daily aggregation of LLMA (LLM Analytics) metrics.
Aggregates AI event counts from the events table into a daily metrics table
for efficient querying and cost analysis.
"""
from datetime import UTC, datetime, timedelta
import pandas as pd
import dagster
from dagster import BackfillPolicy, DailyPartitionsDefinition
from posthog.clickhouse import query_tagging
from posthog.clickhouse.client import sync_execute
from posthog.clickhouse.cluster import ClickhouseCluster
from dags.common import JobOwners, dagster_tags
from dags.llma.daily_metrics.config import config
from dags.llma.daily_metrics.utils import get_delete_query, get_insert_query
# Partition definition for daily aggregations
partition_def = DailyPartitionsDefinition(start_date=config.partition_start_date, end_offset=1)
# Backfill policy: process N days per run
backfill_policy_def = BackfillPolicy.multi_run(max_partitions_per_run=config.max_partitions_per_run)
# ClickHouse settings for aggregation queries
LLMA_CLICKHOUSE_SETTINGS = {
"max_execution_time": str(config.clickhouse_max_execution_time),
}
@dagster.asset(
name="llma_metrics_daily",
group_name="llma",
partitions_def=partition_def,
backfill_policy=backfill_policy_def,
metadata={"table": config.table_name},
tags={"owner": JobOwners.TEAM_LLMA.value},
)
def llma_metrics_daily(
context: dagster.AssetExecutionContext,
cluster: dagster.ResourceParam[ClickhouseCluster],
) -> None:
"""
Daily aggregation of LLMA metrics.
Aggregates AI event counts ($ai_trace, $ai_generation, $ai_span, $ai_embedding)
by team and date into a long-format metrics table for efficient querying.
Long format allows adding new metrics without schema changes.
"""
query_tagging.get_query_tags().with_dagster(dagster_tags(context))
if not context.partition_time_window:
raise dagster.Failure("This asset should only be run with a partition_time_window")
start_datetime, end_datetime = context.partition_time_window
date_start = start_datetime.strftime("%Y-%m-%d")
date_end = end_datetime.strftime("%Y-%m-%d")
context.log.info(f"Aggregating LLMA metrics for {date_start} to {date_end}")
try:
delete_query = get_delete_query(date_start, date_end)
sync_execute(delete_query, settings=LLMA_CLICKHOUSE_SETTINGS)
insert_query = get_insert_query(date_start, date_end)
context.log.info(f"Metrics query: \n{insert_query}")
sync_execute(insert_query, settings=LLMA_CLICKHOUSE_SETTINGS)
# Query and log the metrics that were just aggregated
metrics_query = f"""
SELECT
metric_name,
count(DISTINCT team_id) as teams,
sum(metric_value) as total_value
FROM {config.table_name}
WHERE date >= '{date_start}' AND date < '{date_end}'
GROUP BY metric_name
ORDER BY metric_name
"""
metrics_results = sync_execute(metrics_query)
if metrics_results:
df = pd.DataFrame(metrics_results, columns=["metric_name", "teams", "total_value"])
context.log.info(f"Aggregated {len(df)} metric types for {date_start}:\n{df.to_string(index=False)}")
else:
context.log.info(f"No AI events found for {date_start}")
context.log.info(f"Successfully aggregated LLMA metrics for {date_start}")
except Exception as e:
raise dagster.Failure(f"Failed to aggregate LLMA metrics: {str(e)}") from e
# Define the job that runs the asset
llma_metrics_daily_job = dagster.define_asset_job(
name="llma_metrics_daily_job",
selection=["llma_metrics_daily"],
tags={
"owner": JobOwners.TEAM_LLMA.value,
"dagster/max_runtime": str(config.job_timeout),
},
)
@dagster.schedule(
cron_schedule=config.cron_schedule,
job=llma_metrics_daily_job,
execution_timezone="UTC",
tags={"owner": JobOwners.TEAM_LLMA.value},
)
def llma_metrics_daily_schedule(context: dagster.ScheduleEvaluationContext):
"""
Runs daily for the previous day's partition.
Schedule configured in dags.llma.daily_metrics.config.
This aggregates AI event metrics from the events table into the
llma_metrics_daily table for efficient querying.
"""
# Calculate yesterday's partition
yesterday = (datetime.now(UTC) - timedelta(days=1)).strftime("%Y-%m-%d")
context.log.info(f"Scheduling LLMA metrics aggregation for {yesterday}")
return dagster.RunRequest(
partition_key=yesterday,
run_config={
"ops": {
"llma_metrics_daily": {"config": {}},
}
},
)

View File

@@ -0,0 +1,28 @@
/*
Event Error Rates - Proportion of events with errors by type
Calculates the proportion of events of each type that had an error (0.0 to 1.0).
An event is considered errored if $ai_error is set or $ai_is_error is true.
Produces metrics: ai_generation_error_rate, ai_embedding_error_rate, etc.
Example: If 2 out of 10 generation events had errors, this reports 0.20.
Compare with trace_error_rates.sql which reports proportion of traces with any error.
*/
SELECT
date(timestamp) as date,
team_id,
concat(substring(event, 2), '_error_rate') as metric_name,
round(
countIf(
(JSONExtractRaw(properties, '$ai_error') != '' AND JSONExtractRaw(properties, '$ai_error') != 'null')
OR JSONExtractBool(properties, '$ai_is_error') = true
) / count(*),
4
) as metric_value
FROM events
WHERE event IN ({% for event_type in event_types %}'{{ event_type }}'{% if not loop.last %}, {% endif %}{% endfor %})
AND timestamp >= toDateTime('{{ date_start }}', 'UTC')
AND timestamp < toDateTime('{{ date_end }}', 'UTC')
GROUP BY date, team_id, event

View File

@@ -0,0 +1,22 @@
/*
Event Counts - Individual AI events by type
Counts the total number of individual AI events ($ai_generation, $ai_embedding, etc).
Each event is counted separately, even if multiple events share the same trace_id.
Produces metrics: ai_generation_count, ai_embedding_count, ai_span_count, ai_trace_count
Example: If a trace contains 3 span events, this counts all 3 individually.
Compare with trace_counts.sql which would count this as 1 unique trace.
*/
SELECT
date(timestamp) as date,
team_id,
concat(substring(event, 2), '_count') as metric_name,
toFloat64(count(*)) as metric_value
FROM events
WHERE event IN ({% for event_type in event_types %}'{{ event_type }}'{% if not loop.last %}, {% endif %}{% endfor %})
AND timestamp >= toDateTime('{{ date_start }}', 'UTC')
AND timestamp < toDateTime('{{ date_end }}', 'UTC')
GROUP BY date, team_id, event

View File

@@ -0,0 +1,39 @@
/*
Pageview Counts - Page views by LLM Analytics page type
Counts $pageview events on LLM Analytics pages, categorized by page type.
URL patterns are mapped to page types via config.pageview_mappings.
More specific patterns should be listed before general ones in config.
Produces metrics: pageviews_traces, pageviews_generations, pageviews_users, etc.
Example: Pageview to /project/1/llm-analytics/traces → pageviews_traces metric
*/
SELECT
date(timestamp) as date,
team_id,
concat('pageviews_', page_type) as metric_name,
toFloat64(count(*)) as metric_value
FROM (
SELECT
timestamp,
team_id,
CASE
{% for url_path, metric_suffix in pageview_mappings %}
WHEN JSONExtractString(properties, '$current_url') LIKE '%{{ url_path }}%' THEN '{{ metric_suffix }}'
{% endfor %}
ELSE NULL
END as page_type
FROM events
WHERE event = '$pageview'
AND timestamp >= toDateTime('{{ date_start }}', 'UTC')
AND timestamp < toDateTime('{{ date_end }}', 'UTC')
AND (
{% for url_path, metric_suffix in pageview_mappings %}
JSONExtractString(properties, '$current_url') LIKE '%{{ url_path }}%'{% if not loop.last %} OR {% endif %}
{% endfor %}
)
)
WHERE page_type IS NOT NULL
GROUP BY date, team_id, page_type

View File

@@ -0,0 +1,23 @@
/*
Session Counts - Unique sessions (distinct $ai_session_id)
Counts the number of unique sessions by counting distinct $ai_session_id values
across all AI event types. A session can link multiple related traces together.
Produces metric: ai_session_id_count
Example: If there are 10 traces belonging to 3 unique session_ids, this counts 3.
Compare with trace_counts.sql which would count all 10 unique traces.
*/
SELECT
date(timestamp) as date,
team_id,
'ai_session_id_count' as metric_name,
toFloat64(count(DISTINCT JSONExtractString(properties, '$ai_session_id'))) as metric_value
FROM events
WHERE event IN ({% for event_type in event_types %}'{{ event_type }}'{% if not loop.last %}, {% endif %}{% endfor %})
AND timestamp >= toDateTime('{{ date_start }}', 'UTC')
AND timestamp < toDateTime('{{ date_end }}', 'UTC')
AND JSONExtractString(properties, '$ai_session_id') != ''
GROUP BY date, team_id

View File

@@ -0,0 +1,23 @@
/*
Trace Counts - Unique traces (distinct $ai_trace_id)
Counts the number of unique traces by counting distinct $ai_trace_id values
across all AI event types. A trace may contain multiple events (generations, spans, etc).
Produces metric: ai_trace_id_count
Example: If there are 10 events belonging to 3 unique trace_ids, this counts 3.
Compare with event_counts.sql which would count all 10 events individually.
*/
SELECT
date(timestamp) as date,
team_id,
'ai_trace_id_count' as metric_name,
toFloat64(count(DISTINCT JSONExtractString(properties, '$ai_trace_id'))) as metric_value
FROM events
WHERE event IN ({% for event_type in event_types %}'{{ event_type }}'{% if not loop.last %}, {% endif %}{% endfor %})
AND timestamp >= toDateTime('{{ date_start }}', 'UTC')
AND timestamp < toDateTime('{{ date_end }}', 'UTC')
AND JSONExtractString(properties, '$ai_trace_id') != ''
GROUP BY date, team_id

View File

@@ -0,0 +1,32 @@
/*
Trace Error Rates - Proportion of traces with any error
Calculates the proportion of unique traces that had at least one error event (0.0 to 1.0).
A trace is considered errored if ANY event within it has $ai_error set or $ai_is_error is true.
Produces metric: ai_trace_id_has_error_rate
Example: If 2 out of 8 unique traces had any error event, this reports 0.25.
Compare with error_rates.sql which reports proportion of individual events with errors.
Note: A single erroring event in a trace makes the entire trace count as errored.
*/
SELECT
date(timestamp) as date,
team_id,
'ai_trace_id_has_error_rate' as metric_name,
round(
countDistinctIf(
JSONExtractString(properties, '$ai_trace_id'),
(JSONExtractRaw(properties, '$ai_error') != '' AND JSONExtractRaw(properties, '$ai_error') != 'null')
OR JSONExtractBool(properties, '$ai_is_error') = true
) / count(DISTINCT JSONExtractString(properties, '$ai_trace_id')),
4
) as metric_value
FROM events
WHERE event IN ({% for event_type in event_types %}'{{ event_type }}'{% if not loop.last %}, {% endif %}{% endfor %})
AND timestamp >= toDateTime('{{ date_start }}', 'UTC')
AND timestamp < toDateTime('{{ date_end }}', 'UTC')
AND JSONExtractString(properties, '$ai_trace_id') != ''
GROUP BY date, team_id

View File

@@ -0,0 +1,65 @@
"""
Utility functions for LLMA daily metrics aggregation.
"""
from pathlib import Path
from jinja2 import Template
from dags.llma.daily_metrics.config import config
# SQL template directory
SQL_DIR = Path(__file__).parent / "sql"
# Metric types to include (matches SQL filenames without .sql extension)
# Set to None to include all, or list specific ones to include
ENABLED_METRICS: list[str] | None = None # or ["event_counts", "error_rates"]
def get_insert_query(date_start: str, date_end: str) -> str:
"""
Generate SQL to aggregate AI event counts by team and metric type.
Uses long format: each metric_name is a separate row for easy schema evolution.
Automatically discovers and combines all SQL templates in the sql/ directory.
To add a new metric type, simply add a new .sql file in dags/llma/daily_metrics/sql/.
Each SQL file should return columns: date, team_id, metric_name, metric_value
"""
# Discover all SQL template files
sql_files = sorted(SQL_DIR.glob("*.sql"))
# Filter by enabled metrics if specified
if ENABLED_METRICS is not None:
sql_files = [f for f in sql_files if f.stem in ENABLED_METRICS]
if not sql_files:
raise ValueError(f"No SQL template files found in {SQL_DIR}")
# Load and render each template
rendered_queries = []
template_context = {
"event_types": config.ai_event_types,
"pageview_mappings": config.pageview_mappings,
"date_start": date_start,
"date_end": date_end,
"include_error_rates": config.include_error_rates,
}
for sql_file in sql_files:
with open(sql_file) as f:
template = Template(f.read())
rendered = template.render(**template_context)
if rendered.strip(): # Only include non-empty queries
rendered_queries.append(rendered)
# Combine all queries with UNION ALL
combined_query = "\n\nUNION ALL\n\n".join(rendered_queries)
# Wrap in INSERT INTO statement
return f"INSERT INTO {config.table_name} (date, team_id, metric_name, metric_value)\n\n{combined_query}"
def get_delete_query(date_start: str, date_end: str) -> str:
"""Generate SQL to delete existing data for the date range."""
return f"ALTER TABLE {config.table_name} DELETE WHERE date >= '{date_start}' AND date < '{date_end}'"

View File

@@ -5,7 +5,7 @@ import dagster_slack
from dagster_aws.s3.io_manager import s3_pickle_io_manager
from dagster_aws.s3.resources import S3Resource
from dags.common import ClickhouseClusterResource, RedisResource
from dags.common import ClickhouseClusterResource, PostgresResource, RedisResource
# Define resources for different environments
resources_by_env = {
@@ -18,6 +18,14 @@ resources_by_env = {
"s3": S3Resource(),
# Using EnvVar instead of the Django setting to ensure that the token is not leaked anywhere in the Dagster UI
"slack": dagster_slack.SlackResource(token=dagster.EnvVar("SLACK_TOKEN")),
# Postgres resource (universal for all dags)
"database": PostgresResource(
host=dagster.EnvVar("POSTGRES_HOST"),
port=dagster.EnvVar("POSTGRES_PORT"),
database=dagster.EnvVar("POSTGRES_DATABASE"),
user=dagster.EnvVar("POSTGRES_USER"),
password=dagster.EnvVar("POSTGRES_PASSWORD"),
),
},
"local": {
"cluster": ClickhouseClusterResource.configure_at_launch(),
@@ -29,6 +37,14 @@ resources_by_env = {
aws_secret_access_key=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY,
),
"slack": dagster.ResourceDefinition.none_resource(description="Dummy Slack resource for local development"),
# Postgres resource (universal for all dags) - use Django settings or env vars for local dev
"database": PostgresResource(
host=dagster.EnvVar("POSTGRES_HOST"),
port=dagster.EnvVar("POSTGRES_PORT"),
database=dagster.EnvVar("POSTGRES_DATABASE"),
user=dagster.EnvVar("POSTGRES_USER"),
password=dagster.EnvVar("POSTGRES_PASSWORD"),
),
},
}

View File

@@ -0,0 +1,15 @@
import dagster
from dags import persons_new_backfill
from . import resources
defs = dagster.Definitions(
assets=[
persons_new_backfill.postgres_env_check,
],
jobs=[
persons_new_backfill.persons_new_backfill_job,
],
resources=resources,
)

dags/locations/llma.py Normal file
View File

@@ -0,0 +1,16 @@
import dagster
from dags.llma.daily_metrics.main import llma_metrics_daily, llma_metrics_daily_job, llma_metrics_daily_schedule
from . import resources
defs = dagster.Definitions(
assets=[llma_metrics_daily],
jobs=[
llma_metrics_daily_job,
],
schedules=[
llma_metrics_daily_schedule,
],
resources=resources,
)

View File

@@ -0,0 +1,398 @@
"""Dagster job for backfilling posthog_persons data from source to destination Postgres database."""
import os
import time
from typing import Any
import dagster
import psycopg2
import psycopg2.errors
from dagster_k8s import k8s_job_executor
from posthog.clickhouse.cluster import ClickhouseCluster
from posthog.clickhouse.custom_metrics import MetricsClient
from dags.common import JobOwners
class PersonsNewBackfillConfig(dagster.Config):
"""Configuration for the persons new backfill job."""
chunk_size: int = 1_000_000 # ID range per chunk
batch_size: int = 100_000 # Records per batch insert
source_table: str = "posthog_persons"
destination_table: str = "posthog_persons_new"
max_id: int | None = None # Optional override for max ID to resume from partial state
@dagster.op
def get_id_range(
context: dagster.OpExecutionContext,
config: PersonsNewBackfillConfig,
database: dagster.ResourceParam[psycopg2.extensions.connection],
) -> tuple[int, int]:
"""
Query source database for MIN(id) and optionally MAX(id) from posthog_persons table.
If max_id is provided in config, uses that instead of querying.
Returns tuple (min_id, max_id).
"""
with database.cursor() as cursor:
# Always query for min_id
min_query = f"SELECT MIN(id) as min_id FROM {config.source_table}"
context.log.info(f"Querying min ID: {min_query}")
cursor.execute(min_query)
min_result = cursor.fetchone()
if min_result is None or min_result["min_id"] is None:
context.log.exception("Source table is empty or has no valid IDs")
# Note: No metrics client here as this is get_id_range op, not copy_chunk
raise dagster.Failure("Source table is empty or has no valid IDs")
min_id = int(min_result["min_id"])
# Use config max_id if provided, otherwise query database
if config.max_id is not None:
max_id = config.max_id
context.log.info(f"Using configured max_id override: {max_id}")
else:
max_query = f"SELECT MAX(id) as max_id FROM {config.source_table}"
context.log.info(f"Querying max ID: {max_query}")
cursor.execute(max_query)
max_result = cursor.fetchone()
if max_result is None or max_result["max_id"] is None:
context.log.exception("Source table has no valid max ID")
# Note: No metrics client here as this is get_id_range op, not copy_chunk
raise dagster.Failure("Source table has no valid max ID")
max_id = int(max_result["max_id"])
# Validate that max_id >= min_id
if max_id < min_id:
error_msg = f"Invalid ID range: max_id ({max_id}) < min_id ({min_id})"
context.log.error(error_msg)
# Note: No metrics client here as this is get_id_range op, not copy_chunk
raise dagster.Failure(error_msg)
context.log.info(f"ID range: min={min_id}, max={max_id}, total_ids={max_id - min_id + 1}")
context.add_output_metadata(
{
"min_id": dagster.MetadataValue.int(min_id),
"max_id": dagster.MetadataValue.int(max_id),
"max_id_source": dagster.MetadataValue.text("config" if config.max_id is not None else "database"),
"total_ids": dagster.MetadataValue.int(max_id - min_id + 1),
}
)
return (min_id, max_id)
@dagster.op(out=dagster.DynamicOut(tuple[int, int]))
def create_chunks(
context: dagster.OpExecutionContext,
config: PersonsNewBackfillConfig,
id_range: tuple[int, int],
):
"""
Divide ID space into chunks of chunk_size.
Yields DynamicOutput for each chunk in reverse order (highest IDs first, lowest IDs last).
This ensures that if the job fails partway through, the final chunk to process will be
the one starting at the source table's min_id.
"""
min_id, max_id = id_range
chunk_size = config.chunk_size
# First, collect all chunks
chunks = []
chunk_min = min_id
chunk_num = 0
while chunk_min <= max_id:
chunk_max = min(chunk_min + chunk_size - 1, max_id)
chunks.append((chunk_min, chunk_max, chunk_num))
chunk_min = chunk_max + 1
chunk_num += 1
context.log.info(f"Created {chunk_num} chunks total")
# Yield chunks in reverse order (highest IDs first)
for chunk_min, chunk_max, chunk_num in reversed(chunks):
chunk_key = f"chunk_{chunk_min}_{chunk_max}"
context.log.info(f"Yielding chunk {chunk_num}: {chunk_min} to {chunk_max}")
yield dagster.DynamicOutput(
value=(chunk_min, chunk_max),
mapping_key=chunk_key,
)
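# Illustration with hypothetical values: for chunk_size=1000 and id_range=(1, 2500) the loop
# builds [(1, 1000), (1001, 2000), (2001, 2500)] and then yields them in reverse, so
# (2001, 2500) is emitted first and (1, 1000) last.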
@dagster.op
def copy_chunk(
context: dagster.OpExecutionContext,
config: PersonsNewBackfillConfig,
chunk: tuple[int, int],
database: dagster.ResourceParam[psycopg2.extensions.connection],
cluster: dagster.ResourceParam[ClickhouseCluster],
) -> dict[str, Any]:
"""
Copy a chunk of records from source to destination database.
Processes in batches of batch_size records.
"""
chunk_min, chunk_max = chunk
batch_size = config.batch_size
source_table = config.source_table
destination_table = config.destination_table
chunk_id = f"chunk_{chunk_min}_{chunk_max}"
job_name = context.run.job_name
# Initialize metrics client
metrics_client = MetricsClient(cluster)
context.log.info(f"Starting chunk copy: {chunk_min} to {chunk_max}")
total_records_copied = 0
batch_start_id = chunk_min
failed_batch_start_id: int | None = None
try:
with database.cursor() as cursor:
# Set session-level settings once for the entire chunk
cursor.execute("SET application_name = 'backfill_posthog_persons_to_posthog_persons_new'")
cursor.execute("SET lock_timeout = '5s'")
cursor.execute("SET statement_timeout = '30min'")
cursor.execute("SET maintenance_work_mem = '12GB'")
cursor.execute("SET work_mem = '512MB'")
cursor.execute("SET temp_buffers = '512MB'")
cursor.execute("SET max_parallel_workers_per_gather = 2")
cursor.execute("SET max_parallel_maintenance_workers = 2")
cursor.execute("SET synchronous_commit = off")
retry_attempt = 0
while batch_start_id <= chunk_max:
try:
# Track batch start time for duration metric
batch_start_time = time.time()
# Calculate batch end ID
batch_end_id = min(batch_start_id + batch_size, chunk_max)
# Track records attempted - this is also our exit condition
records_attempted = batch_end_id - batch_start_id
if records_attempted <= 0:
break
# Begin transaction (settings already applied at session level)
cursor.execute("BEGIN")
# Execute INSERT INTO ... SELECT with NOT EXISTS check
insert_query = f"""
INSERT INTO {destination_table}
SELECT s.*
FROM {source_table} s
WHERE s.id >= %s AND s.id <= %s
AND NOT EXISTS (
SELECT 1
FROM {destination_table} d
WHERE d.team_id = s.team_id
AND d.id = s.id
)
ORDER BY s.id DESC
"""
cursor.execute(insert_query, (batch_start_id, batch_end_id))
records_inserted = cursor.rowcount
# Commit the transaction
cursor.execute("COMMIT")
try:
metrics_client.increment(
"persons_new_backfill_records_attempted_total",
labels={"job_name": job_name, "chunk_id": chunk_id},
value=float(records_attempted),
).result()
except Exception:
pass # Don't fail on metrics error
batch_duration_seconds = time.time() - batch_start_time
try:
metrics_client.increment(
"persons_new_backfill_records_inserted_total",
labels={"job_name": job_name, "chunk_id": chunk_id},
value=float(records_inserted),
).result()
except Exception:
pass # Don't fail on metrics error
try:
metrics_client.increment(
"persons_new_backfill_batches_copied_total",
labels={"job_name": job_name, "chunk_id": chunk_id},
value=1.0,
).result()
except Exception:
pass
# Track batch duration metric (IV)
try:
metrics_client.increment(
"persons_new_backfill_batch_duration_seconds_total",
labels={"job_name": job_name, "chunk_id": chunk_id},
value=batch_duration_seconds,
).result()
except Exception:
pass
total_records_copied += records_inserted
context.log.info(
f"Copied batch: {records_inserted} records "
f"(chunk {chunk_min}-{chunk_max}, batch ID range {batch_start_id} to {batch_end_id})"
)
# Update batch_start_id for next iteration
batch_start_id = batch_end_id + 1
retry_attempt = 0
except Exception as batch_error:
# Rollback transaction on error
try:
cursor.execute("ROLLBACK")
except Exception as rollback_error:
context.log.exception(
f"Failed to rollback transaction for batch starting at ID {batch_start_id}"
f"in chunk {chunk_min}-{chunk_max}: {str(rollback_error)}"
)
pass # Ignore rollback errors
# Check if error is a duplicate key violation, pause and retry if so
is_unique_violation = isinstance(batch_error, psycopg2.errors.UniqueViolation) or (
isinstance(batch_error, psycopg2.Error) and getattr(batch_error, "pgcode", None) == "23505"
)
if is_unique_violation:
error_msg = (
f"Duplicate key violation detected for batch starting at ID {batch_start_id} "
f"in chunk {chunk_min}-{chunk_max}. Error is: {batch_error}. "
"This is expected if records already exist in destination table. "
)
context.log.warning(error_msg)
if retry_attempt < 3:
retry_attempt += 1
context.log.info(f"Retrying batch {retry_attempt} of 3...")
time.sleep(1)
continue
failed_batch_start_id = batch_start_id
error_msg = (
f"Failed to copy batch starting at ID {batch_start_id} "
f"in chunk {chunk_min}-{chunk_max}: {str(batch_error)}"
)
context.log.exception(error_msg)
# Report fatal error metric before raising
try:
metrics_client.increment(
"persons_new_backfill_error",
labels={"job_name": job_name, "chunk_id": chunk_id, "reason": "batch_copy_failed"},
value=1.0,
).result()
except Exception:
pass # Don't fail on metrics error
raise dagster.Failure(
description=error_msg,
metadata={
"chunk_min_id": dagster.MetadataValue.int(chunk_min),
"chunk_max_id": dagster.MetadataValue.int(chunk_max),
"failed_batch_start_id": dagster.MetadataValue.int(failed_batch_start_id)
if failed_batch_start_id
else dagster.MetadataValue.text("N/A"),
"error_message": dagster.MetadataValue.text(str(batch_error)),
"records_copied_before_failure": dagster.MetadataValue.int(total_records_copied),
},
) from batch_error
except dagster.Failure:
# Re-raise Dagster failures as-is (they already have metadata and metrics)
raise
except Exception as e:
# Catch any other unexpected errors
error_msg = f"Unexpected error copying chunk {chunk_min}-{chunk_max}: {str(e)}"
context.log.exception(error_msg)
# Report fatal error metric before raising
try:
metrics_client.increment(
"persons_new_backfill_error",
labels={"job_name": job_name, "chunk_id": chunk_id, "reason": "unexpected_copy_error"},
value=1.0,
).result()
except Exception:
pass # Don't fail on metrics error
raise dagster.Failure(
description=error_msg,
metadata={
"chunk_min_id": dagster.MetadataValue.int(chunk_min),
"chunk_max_id": dagster.MetadataValue.int(chunk_max),
"failed_batch_start_id": dagster.MetadataValue.int(failed_batch_start_id)
if failed_batch_start_id
else dagster.MetadataValue.int(batch_start_id),
"error_message": dagster.MetadataValue.text(str(e)),
"records_copied_before_failure": dagster.MetadataValue.int(total_records_copied),
},
) from e
context.log.info(f"Completed chunk {chunk_min}-{chunk_max}: copied {total_records_copied} records")
# Emit metric for chunk completion
run_id = context.run.run_id
try:
metrics_client.increment(
"persons_new_backfill_chunks_completed_total",
labels={"job_name": job_name, "run_id": run_id, "chunk_id": chunk_id},
value=1.0,
).result()
except Exception:
pass # Don't fail on metrics error
context.add_output_metadata(
{
"chunk_min": dagster.MetadataValue.int(chunk_min),
"chunk_max": dagster.MetadataValue.int(chunk_max),
"records_copied": dagster.MetadataValue.int(total_records_copied),
}
)
return {
"chunk_min": chunk_min,
"chunk_max": chunk_max,
"records_copied": total_records_copied,
}
@dagster.asset
def postgres_env_check(context: dagster.AssetExecutionContext) -> None:
"""
Simple asset that prints PostgreSQL environment variables being used.
Useful for debugging connection configuration.
"""
env_vars = {
"POSTGRES_HOST": os.getenv("POSTGRES_HOST", "not set"),
"POSTGRES_PORT": os.getenv("POSTGRES_PORT", "not set"),
"POSTGRES_DATABASE": os.getenv("POSTGRES_DATABASE", "not set"),
"POSTGRES_USER": os.getenv("POSTGRES_USER", "not set"),
"POSTGRES_PASSWORD": "***" if os.getenv("POSTGRES_PASSWORD") else "not set",
}
context.log.info("PostgreSQL environment variables:")
for key, value in env_vars.items():
context.log.info(f" {key}: {value}")
@dagster.job(
tags={"owner": JobOwners.TEAM_INGESTION.value},
executor_def=k8s_job_executor,
)
def persons_new_backfill_job():
"""
Backfill posthog_persons data from source to destination Postgres database.
Divides the ID space into chunks and processes them in parallel.
"""
id_range = get_id_range()
chunks = create_chunks(id_range)
chunks.map(copy_chunk)
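# Sketch of run configuration for resuming a partial backfill (values are hypothetical;
# op names and fields match the ops above, and each op carries its own
# PersonsNewBackfillConfig, so unlisted fields keep their defaults):
#
#   {
#       "ops": {
#           "get_id_range": {"config": {"max_id": 250_000_000}},
#           "create_chunks": {"config": {"chunk_size": 1_000_000}},
#           "copy_chunk": {"config": {"batch_size": 100_000}},
#       }
#   }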

View File

@@ -1,15 +1,34 @@
from clickhouse_driver import Client
from dagster import AssetExecutionContext, BackfillPolicy, DailyPartitionsDefinition, asset
from posthog.clickhouse.client import sync_execute
from posthog.clickhouse.client.connection import Workload
from posthog.clickhouse.cluster import get_cluster
from posthog.clickhouse.query_tagging import tags_context
from posthog.git import get_git_commit_short
from posthog.models.raw_sessions.sessions_v3 import (
RAW_SESSION_TABLE_BACKFILL_RECORDINGS_SQL_V3,
RAW_SESSION_TABLE_BACKFILL_SQL_V3,
)
from dags.common import dagster_tags
from dags.common.common import JobOwners, metabase_debug_query_url
# This is the number of days to backfill in one SQL operation
MAX_PARTITIONS_PER_RUN = 30
MAX_PARTITIONS_PER_RUN = 1
# Keep the number of concurrent runs low to avoid overloading ClickHouse and running into the dreaded "Too many parts" error.
# This tag needs to also exist in Dagster Cloud (and the local dev dagster.yaml) for the concurrency limit to take effect.
# concurrency:
# runs:
# tag_concurrency_limits:
# - key: 'sessions_backfill_concurrency'
# limit: 3
# value:
# applyLimitPerUniqueValue: true
CONCURRENCY_TAG = {
"sessions_backfill_concurrency": "sessions_v3",
}
daily_partitions = DailyPartitionsDefinition(
start_date="2019-01-01", # this is a year before posthog was founded, so should be early enough even including data imports
@@ -17,6 +36,17 @@ daily_partitions = DailyPartitionsDefinition(
end_offset=1, # include today's partition (note that this will create a partition with incomplete data, but all our backfills are idempotent, so this is OK provided we re-run later)
)
ONE_HOUR_IN_SECONDS = 60 * 60
ONE_GB_IN_BYTES = 1024 * 1024 * 1024
settings = {
# see this run which took around 2hrs 10min for 1 day https://posthog.dagster.plus/prod-us/runs/0ba8afaa-f3cc-4845-97c5-96731ec8231d?focusedTime=1762898705269&selection=sessions_v3_backfill&logs=step%3Asessions_v3_backfill
# so to give some margin, allow 4 hours per partition
"max_execution_time": MAX_PARTITIONS_PER_RUN * 4 * ONE_HOUR_IN_SECONDS,
"max_memory_usage": 100 * ONE_GB_IN_BYTES,
"distributed_aggregation_memory_efficient": "1",
}
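# With MAX_PARTITIONS_PER_RUN = 1 this works out to 1 * 4 * 3600 = 14,400 s (4 h) of
# execution time per run, and max_memory_usage to 100 GiB.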
def get_partition_where_clause(context: AssetExecutionContext, timestamp_field: str) -> str:
start_incl = context.partition_time_window.start.strftime("%Y-%m-%d")
@@ -31,13 +61,14 @@ def get_partition_where_clause(context: AssetExecutionContext, timestamp_field:
partitions_def=daily_partitions,
name="sessions_v3_backfill",
backfill_policy=BackfillPolicy.multi_run(max_partitions_per_run=MAX_PARTITIONS_PER_RUN),
tags={"owner": JobOwners.TEAM_ANALYTICS_PLATFORM.value, **CONCURRENCY_TAG},
)
def sessions_v3_backfill(context: AssetExecutionContext) -> None:
where_clause = get_partition_where_clause(context, timestamp_field="timestamp")
# note that this is idempotent, so we don't need to worry about running it multiple times for the same partition
# as long as the backfill has run at least once for each partition, the data will be correct
backfill_sql = RAW_SESSION_TABLE_BACKFILL_SQL_V3(where=where_clause)
backfill_sql = RAW_SESSION_TABLE_BACKFILL_SQL_V3(where=where_clause, use_sharded_source=True)
partition_range = context.partition_key_range
partition_range_str = f"{partition_range.start} to {partition_range.end}"
@@ -45,8 +76,17 @@ def sessions_v3_backfill(context: AssetExecutionContext) -> None:
f"Running backfill for {partition_range_str} (where='{where_clause}') using commit {get_git_commit_short() or 'unknown'} "
)
context.log.info(backfill_sql)
if debug_url := metabase_debug_query_url(context.run_id):
context.log.info(f"Debug query: {debug_url}")
sync_execute(backfill_sql, workload=Workload.OFFLINE)
cluster = get_cluster()
tags = dagster_tags(context)
def backfill_per_shard(client: Client):
with tags_context(kind="dagster", dagster=tags):
sync_execute(backfill_sql, settings=settings, sync_client=client)
cluster.map_one_host_per_shard(backfill_per_shard).result()
context.log.info(f"Successfully backfilled sessions_v3 for {partition_range_str}")
@@ -55,13 +95,14 @@ def sessions_v3_backfill(context: AssetExecutionContext) -> None:
partitions_def=daily_partitions,
name="sessions_v3_replay_backfill",
backfill_policy=BackfillPolicy.multi_run(max_partitions_per_run=MAX_PARTITIONS_PER_RUN),
tags={"owner": JobOwners.TEAM_ANALYTICS_PLATFORM.value, **CONCURRENCY_TAG},
)
def sessions_v3_backfill_replay(context: AssetExecutionContext) -> None:
where_clause = get_partition_where_clause(context, timestamp_field="min_first_timestamp")
# note that this is idempotent, so we don't need to worry about running it multiple times for the same partition
# as long as the backfill has run at least once for each partition, the data will be correct
backfill_sql = RAW_SESSION_TABLE_BACKFILL_RECORDINGS_SQL_V3(where=where_clause)
backfill_sql = RAW_SESSION_TABLE_BACKFILL_RECORDINGS_SQL_V3(where=where_clause, use_sharded_source=True)
partition_range = context.partition_key_range
partition_range_str = f"{partition_range.start} to {partition_range.end}"
@@ -69,7 +110,16 @@ def sessions_v3_backfill_replay(context: AssetExecutionContext) -> None:
f"Running backfill for {partition_range_str} (where='{where_clause}') using commit {get_git_commit_short() or 'unknown'} "
)
context.log.info(backfill_sql)
if debug_url := metabase_debug_query_url(context.run_id):
context.log.info(f"Debug query: {debug_url}")
sync_execute(backfill_sql, workload=Workload.OFFLINE)
cluster = get_cluster()
tags = dagster_tags(context)
def backfill_per_shard(client: Client):
with tags_context(kind="dagster", dagster=tags):
sync_execute(backfill_sql, workload=Workload.OFFLINE, settings=settings, sync_client=client)
cluster.map_one_host_per_shard(backfill_per_shard).result()
context.log.info(f"Successfully backfilled sessions_v3 for {partition_range_str}")

View File

@@ -9,14 +9,16 @@ from dagster import DagsterRunStatus, RunsFilter
from dags.common import JobOwners
notification_channel_per_team = {
JobOwners.TEAM_ANALYTICS_PLATFORM.value: "#alerts-analytics-platform",
JobOwners.TEAM_CLICKHOUSE.value: "#alerts-clickhouse",
JobOwners.TEAM_WEB_ANALYTICS.value: "#alerts-web-analytics",
JobOwners.TEAM_REVENUE_ANALYTICS.value: "#alerts-revenue-analytics",
JobOwners.TEAM_ERROR_TRACKING.value: "#alerts-error-tracking",
JobOwners.TEAM_GROWTH.value: "#alerts-growth",
JobOwners.TEAM_EXPERIMENTS.value: "#alerts-experiments",
JobOwners.TEAM_MAX_AI.value: "#alerts-max-ai",
JobOwners.TEAM_DATA_WAREHOUSE.value: "#alerts-data-warehouse",
JobOwners.TEAM_ERROR_TRACKING.value: "#alerts-error-tracking",
JobOwners.TEAM_EXPERIMENTS.value: "#alerts-experiments-dagster",
JobOwners.TEAM_GROWTH.value: "#alerts-growth",
JobOwners.TEAM_INGESTION.value: "#alerts-ingestion",
JobOwners.TEAM_MAX_AI.value: "#alerts-max-ai",
JobOwners.TEAM_REVENUE_ANALYTICS.value: "#alerts-revenue-analytics",
JobOwners.TEAM_WEB_ANALYTICS.value: "#alerts-web-analytics",
}
CONSECUTIVE_FAILURE_THRESHOLDS = {

View File

@@ -0,0 +1,371 @@
"""
Tests that execute SQL templates against mock data to validate output and logic.
Tests both the structure and calculation logic of each metric SQL file.
"""
from datetime import datetime
from pathlib import Path
import pytest
from jinja2 import Template
from dags.llma.daily_metrics.config import config
from dags.llma.daily_metrics.utils import SQL_DIR
# Expected output columns
EXPECTED_COLUMNS = ["date", "team_id", "metric_name", "metric_value"]
def get_all_sql_files():
"""Get all SQL template files."""
return sorted(SQL_DIR.glob("*.sql"))
@pytest.fixture
def template_context():
"""Provide sample context for rendering Jinja2 templates."""
return {
"event_types": config.ai_event_types,
"pageview_mappings": config.pageview_mappings,
"date_start": "2025-01-01",
"date_end": "2025-01-02",
"include_error_rates": config.include_error_rates,
}
@pytest.fixture
def mock_events_data():
"""
Mock events data for testing SQL queries.
Simulates the events table with various AI events for testing.
Returns a list of event dicts with timestamp, team_id, event, and properties.
"""
base_time = datetime(2025, 1, 1, 12, 0, 0)
events = [
# Team 1: 3 generations, 1 with error
{
"timestamp": base_time,
"team_id": 1,
"event": "$ai_generation",
"properties": {
"$ai_trace_id": "trace-1",
"$ai_session_id": "session-1",
"$ai_error": "",
},
},
{
"timestamp": base_time,
"team_id": 1,
"event": "$ai_generation",
"properties": {
"$ai_trace_id": "trace-1",
"$ai_session_id": "session-1",
"$ai_error": "rate limit exceeded",
},
},
{
"timestamp": base_time,
"team_id": 1,
"event": "$ai_generation",
"properties": {
"$ai_trace_id": "trace-2",
"$ai_session_id": "session-1",
"$ai_error": "",
},
},
# Team 1: 2 spans from same trace
{
"timestamp": base_time,
"team_id": 1,
"event": "$ai_span",
"properties": {
"$ai_trace_id": "trace-1",
"$ai_session_id": "session-1",
"$ai_is_error": True,
},
},
{
"timestamp": base_time,
"team_id": 1,
"event": "$ai_span",
"properties": {
"$ai_trace_id": "trace-1",
"$ai_session_id": "session-1",
"$ai_is_error": False,
},
},
# Team 2: 1 generation, no errors
{
"timestamp": base_time,
"team_id": 2,
"event": "$ai_generation",
"properties": {
"$ai_trace_id": "trace-3",
"$ai_session_id": "session-2",
"$ai_error": "",
},
},
# Team 1: Pageviews on LLM Analytics
{
"timestamp": base_time,
"team_id": 1,
"event": "$pageview",
"properties": {
"$current_url": "https://app.posthog.com/project/123/llm-analytics/traces?filter=active",
},
},
{
"timestamp": base_time,
"team_id": 1,
"event": "$pageview",
"properties": {
"$current_url": "https://app.posthog.com/project/123/llm-analytics/traces",
},
},
{
"timestamp": base_time,
"team_id": 1,
"event": "$pageview",
"properties": {
"$current_url": "https://app.posthog.com/project/123/llm-analytics/generations",
},
},
]
return events
@pytest.fixture
def expected_metrics(mock_events_data):
"""
Expected metric outputs based on mock_events_data.
This serves as the source of truth for what each SQL file should produce.
"""
return {
"event_counts.sql": [
{"date": "2025-01-01", "team_id": 1, "metric_name": "ai_generation_count", "metric_value": 3.0},
{"date": "2025-01-01", "team_id": 1, "metric_name": "ai_span_count", "metric_value": 2.0},
{"date": "2025-01-01", "team_id": 2, "metric_name": "ai_generation_count", "metric_value": 1.0},
],
"trace_counts.sql": [
{
"date": "2025-01-01",
"team_id": 1,
"metric_name": "ai_trace_id_count",
"metric_value": 2.0,
}, # trace-1, trace-2
{"date": "2025-01-01", "team_id": 2, "metric_name": "ai_trace_id_count", "metric_value": 1.0}, # trace-3
],
"session_counts.sql": [
{
"date": "2025-01-01",
"team_id": 1,
"metric_name": "ai_session_id_count",
"metric_value": 1.0,
}, # session-1
{
"date": "2025-01-01",
"team_id": 2,
"metric_name": "ai_session_id_count",
"metric_value": 1.0,
}, # session-2
],
"error_rates.sql": [
# Team 1: 1 errored generation out of 3 = 0.3333
{"date": "2025-01-01", "team_id": 1, "metric_name": "ai_generation_error_rate", "metric_value": 0.3333},
# Team 1: 1 errored span out of 2 = 0.5
{"date": "2025-01-01", "team_id": 1, "metric_name": "ai_span_error_rate", "metric_value": 0.5},
# Team 2: 0 errored out of 1 = 0.0
{"date": "2025-01-01", "team_id": 2, "metric_name": "ai_generation_error_rate", "metric_value": 0.0},
],
"trace_error_rates.sql": [
# Team 1: trace-1 has errors, trace-2 doesn't = 1/2 = 0.5
{"date": "2025-01-01", "team_id": 1, "metric_name": "ai_trace_id_has_error_rate", "metric_value": 0.5},
# Team 2: trace-3 has no errors = 0/1 = 0.0
{"date": "2025-01-01", "team_id": 2, "metric_name": "ai_trace_id_has_error_rate", "metric_value": 0.0},
],
"pageview_counts.sql": [
{"date": "2025-01-01", "team_id": 1, "metric_name": "pageviews_traces", "metric_value": 2.0},
{"date": "2025-01-01", "team_id": 1, "metric_name": "pageviews_generations", "metric_value": 1.0},
],
}
@pytest.mark.parametrize("sql_file", get_all_sql_files(), ids=lambda f: f.stem)
def test_sql_output_structure(sql_file: Path, template_context: dict, mock_events_data: list):
"""
Test that each SQL file produces output with the correct structure.
This test verifies:
1. SQL renders without errors
2. Output has exactly 4 columns
3. Columns are named correctly: date, team_id, metric_name, metric_value
"""
with open(sql_file) as f:
template = Template(f.read())
rendered = template.render(**template_context)
# Basic smoke test - SQL should render
assert rendered.strip(), f"{sql_file.name} rendered to empty string"
# SQL should be a SELECT statement
assert "SELECT" in rendered.upper(), f"{sql_file.name} should contain SELECT"
# Should have all required column aliases (or direct column references for team_id)
for col in EXPECTED_COLUMNS:
# team_id is often selected directly without an alias
if col == "team_id":
assert "team_id" in rendered.lower(), f"{sql_file.name} missing column: {col}"
else:
assert f"as {col}" in rendered.lower(), f"{sql_file.name} missing column alias: {col}"
def test_event_counts_logic(template_context: dict, mock_events_data: list, expected_metrics: dict):
"""Test event_counts.sql produces correct counts."""
sql_file = SQL_DIR / "event_counts.sql"
with open(sql_file) as f:
template = Template(f.read())
rendered = template.render(**template_context)
# Verify SQL structure expectations
assert "count(*)" in rendered.lower(), "event_counts.sql should use count(*)"
assert "group by date, team_id, event" in rendered.lower(), "Should group by date, team_id, event"
# Verify expected metrics documentation
expected = expected_metrics["event_counts.sql"]
assert len(expected) == 3, "Expected 3 metric rows based on mock data"
# Team 1 should have 3 generation events
gen_team1 = next(m for m in expected if m["team_id"] == 1 and "generation" in m["metric_name"])
assert gen_team1["metric_value"] == 3.0
# Team 1 should have 2 span events
span_team1 = next(m for m in expected if m["team_id"] == 1 and "span" in m["metric_name"])
assert span_team1["metric_value"] == 2.0
def test_trace_counts_logic(template_context: dict, mock_events_data: list, expected_metrics: dict):
"""Test trace_counts.sql produces correct distinct trace counts."""
sql_file = SQL_DIR / "trace_counts.sql"
with open(sql_file) as f:
template = Template(f.read())
rendered = template.render(**template_context)
# Verify SQL uses count(DISTINCT)
assert "count(distinct" in rendered.lower(), "trace_counts.sql should use count(DISTINCT)"
assert "$ai_trace_id" in rendered, "Should count distinct $ai_trace_id"
expected = expected_metrics["trace_counts.sql"]
# Team 1: Should have 2 unique traces (trace-1, trace-2)
team1 = next(m for m in expected if m["team_id"] == 1)
assert team1["metric_value"] == 2.0, "Team 1 should have 2 unique traces"
# Team 2: Should have 1 unique trace (trace-3)
team2 = next(m for m in expected if m["team_id"] == 2)
assert team2["metric_value"] == 1.0, "Team 2 should have 1 unique trace"
def test_session_counts_logic(template_context: dict, mock_events_data: list, expected_metrics: dict):
"""Test session_counts.sql produces correct distinct session counts."""
sql_file = SQL_DIR / "session_counts.sql"
with open(sql_file) as f:
template = Template(f.read())
rendered = template.render(**template_context)
# Verify SQL uses count(DISTINCT)
assert "count(distinct" in rendered.lower(), "session_counts.sql should use count(DISTINCT)"
assert "$ai_session_id" in rendered, "Should count distinct $ai_session_id"
expected = expected_metrics["session_counts.sql"]
# Team 1: Should have 1 unique session (session-1)
team1 = next(m for m in expected if m["team_id"] == 1)
assert team1["metric_value"] == 1.0, "Team 1 should have 1 unique session"
# Team 2: Should have 1 unique session (session-2)
team2 = next(m for m in expected if m["team_id"] == 2)
assert team2["metric_value"] == 1.0, "Team 2 should have 1 unique session"
def test_error_rates_logic(template_context: dict, mock_events_data: list, expected_metrics: dict):
"""Test error_rates.sql calculates proportions correctly."""
sql_file = SQL_DIR / "error_rates.sql"
with open(sql_file) as f:
template = Template(f.read())
rendered = template.render(**template_context)
# Verify SQL calculates proportions
assert "countif" in rendered.lower(), "error_rates.sql should use countIf"
assert "/ count(*)" in rendered.lower(), "Should divide by total count"
assert "round(" in rendered.lower(), "Should round the result"
# Verify error detection logic
assert "$ai_error" in rendered, "Should check $ai_error property"
assert "$ai_is_error" in rendered, "Should check $ai_is_error property"
expected = expected_metrics["error_rates.sql"]
# Team 1 generations: 1 error out of 3 = 0.3333
gen_team1 = next(m for m in expected if m["team_id"] == 1 and "generation" in m["metric_name"])
assert abs(gen_team1["metric_value"] - 0.3333) < 0.0001, "Generation error rate should be ~0.3333"
# Team 1 spans: 1 error out of 2 = 0.5
span_team1 = next(m for m in expected if m["team_id"] == 1 and "span" in m["metric_name"])
assert span_team1["metric_value"] == 0.5, "Span error rate should be 0.5"
def test_trace_error_rates_logic(template_context: dict, mock_events_data: list, expected_metrics: dict):
"""Test trace_error_rates.sql calculates trace-level error proportions correctly."""
sql_file = SQL_DIR / "trace_error_rates.sql"
with open(sql_file) as f:
template = Template(f.read())
rendered = template.render(**template_context)
# Verify SQL uses distinct count
assert "countdistinctif" in rendered.lower(), "Should use countDistinctIf"
assert "count(distinct" in rendered.lower(), "Should use count(DISTINCT) for total"
assert "$ai_trace_id" in rendered, "Should work with $ai_trace_id"
expected = expected_metrics["trace_error_rates.sql"]
# Team 1: trace-1 has errors, trace-2 doesn't = 1/2 = 0.5
team1 = next(m for m in expected if m["team_id"] == 1)
assert team1["metric_value"] == 0.5, "Team 1 should have 50% of traces with errors"
# Team 2: trace-3 has no errors = 0/1 = 0.0
team2 = next(m for m in expected if m["team_id"] == 2)
assert team2["metric_value"] == 0.0, "Team 2 should have 0% of traces with errors"
def test_pageview_counts_logic(template_context: dict, mock_events_data: list, expected_metrics: dict):
"""Test pageview_counts.sql categorizes and counts pageviews correctly."""
sql_file = SQL_DIR / "pageview_counts.sql"
with open(sql_file) as f:
template = Template(f.read())
rendered = template.render(**template_context)
# Verify SQL filters $pageview events
assert "event = '$pageview'" in rendered, "Should filter for $pageview events"
assert "$current_url" in rendered, "Should use $current_url property"
assert "LIKE" in rendered, "Should use LIKE for URL matching"
# Verify pageview mappings are used
assert config.pageview_mappings is not None
for url_path, _ in config.pageview_mappings:
assert url_path in rendered, f"Should include pageview mapping for {url_path}"
expected = expected_metrics["pageview_counts.sql"]
# Team 1: 2 trace pageviews
traces = next(m for m in expected if "traces" in m["metric_name"])
assert traces["metric_value"] == 2.0, "Should count 2 trace pageviews"
# Team 1: 1 generation pageview
gens = next(m for m in expected if "generations" in m["metric_name"])
assert gens["metric_value"] == 1.0, "Should count 1 generation pageview"

View File

@@ -16,7 +16,27 @@ from posthog import settings
from posthog.clickhouse.client.connection import Workload
from posthog.clickhouse.cluster import ClickhouseCluster
from dags.backups import Backup, BackupConfig, get_latest_backup, non_sharded_backup, prepare_run_config, sharded_backup
from dags.backups import (
Backup,
BackupConfig,
BackupStatus,
get_latest_backups,
get_latest_successful_backup,
non_sharded_backup,
prepare_run_config,
sharded_backup,
)
def test_get_latest_backup_empty():
mock_s3 = MagicMock()
mock_s3.get_client().list_objects_v2.return_value = {}
config = BackupConfig(database="posthog", table="dummy")
context = dagster.build_op_context()
result = get_latest_backups(context=context, config=config, s3=mock_s3)
assert result == []
@pytest.mark.parametrize("table", ["", "test"])
@@ -31,15 +51,84 @@ def test_get_latest_backup(table: str):
}
config = BackupConfig(database="posthog", table=table)
result = get_latest_backup(config=config, s3=mock_s3)
context = dagster.build_op_context()
result = get_latest_backups(context=context, config=config, s3=mock_s3)
assert isinstance(result, Backup)
assert result.database == "posthog"
assert result.date == "2024-03-01T07:54:04Z"
assert result.base_backup is None
assert isinstance(result, list)
assert result[0].database == "posthog"
assert result[0].date == "2024-03-01T07:54:04Z"
assert result[0].base_backup is None
assert result[1].database == "posthog"
assert result[1].date == "2024-02-01T07:54:04Z"
assert result[1].base_backup is None
assert result[2].database == "posthog"
assert result[2].date == "2024-01-01T07:54:04Z"
assert result[2].base_backup is None
expected_table = table if table else None
assert result.table == expected_table
assert result[0].table == expected_table
assert result[1].table == expected_table
assert result[2].table == expected_table
def test_get_latest_successful_backup_returns_latest_backup():
config = BackupConfig(database="posthog", table="test", incremental=True)
backup1 = Backup(database="posthog", date="2024-02-01T07:54:04Z", table="test")
backup1.is_done = MagicMock(return_value=True) # type: ignore
backup1.status = MagicMock( # type: ignore
return_value=BackupStatus(hostname="test", status="CREATING_BACKUP", event_time_microseconds=datetime.now())
)
backup2 = Backup(database="posthog", date="2024-01-01T07:54:04Z", table="test")
backup2.is_done = MagicMock(return_value=True) # type: ignore
backup2.status = MagicMock( # type: ignore
return_value=BackupStatus(hostname="test", status="BACKUP_CREATED", event_time_microseconds=datetime.now())
)
def mock_map_hosts(fn, **kwargs):
mock_result = MagicMock()
mock_client = MagicMock()
mock_result.result.return_value = {"host1": fn(mock_client)}
return mock_result
cluster = MagicMock()
cluster.map_hosts_by_role.side_effect = mock_map_hosts
result = get_latest_successful_backup(
context=dagster.build_op_context(),
config=config,
latest_backups=[backup1, backup2],
cluster=cluster,
)
assert result == backup2
def test_get_latest_successful_backup_fails():
config = BackupConfig(database="posthog", table="test", incremental=True)
backup1 = Backup(database="posthog", date="2024-02-01T07:54:04Z", table="test")
backup1.status = MagicMock( # type: ignore
return_value=BackupStatus(hostname="test", status="CREATING_BACKUP", event_time_microseconds=datetime.now())
)
def mock_map_hosts(fn, **kwargs):
mock_result = MagicMock()
mock_client = MagicMock()
mock_result.result.return_value = {"host1": fn(mock_client)}
return mock_result
cluster = MagicMock()
cluster.map_hosts_by_role.side_effect = mock_map_hosts
with pytest.raises(dagster.Failure):
get_latest_successful_backup(
context=dagster.build_op_context(),
config=config,
latest_backups=[backup1],
cluster=cluster,
)
def run_backup_test(
@@ -144,7 +233,6 @@ def run_backup_test(
def test_full_non_sharded_backup(cluster: ClickhouseCluster):
config = BackupConfig(
database=settings.CLICKHOUSE_DATABASE,
date=datetime.now().strftime("%Y-%m-%dT%H:%M:%SZ"),
table="person_distinct_id_overrides",
incremental=False,
workload=Workload.ONLINE,
@@ -161,7 +249,6 @@ def test_full_non_sharded_backup(cluster: ClickhouseCluster):
def test_full_sharded_backup(cluster: ClickhouseCluster):
config = BackupConfig(
database=settings.CLICKHOUSE_DATABASE,
date=datetime.now().strftime("%Y-%m-%dT%H:%M:%SZ"),
table="person_distinct_id_overrides",
incremental=False,
workload=Workload.ONLINE,
@@ -178,7 +265,6 @@ def test_full_sharded_backup(cluster: ClickhouseCluster):
def test_incremental_non_sharded_backup(cluster: ClickhouseCluster):
config = BackupConfig(
database=settings.CLICKHOUSE_DATABASE,
date=datetime.now().strftime("%Y-%m-%dT%H:%M:%SZ"),
table="person_distinct_id_overrides",
incremental=True,
workload=Workload.ONLINE,
@@ -195,7 +281,6 @@ def test_incremental_non_sharded_backup(cluster: ClickhouseCluster):
def test_incremental_sharded_backup(cluster: ClickhouseCluster):
config = BackupConfig(
database=settings.CLICKHOUSE_DATABASE,
date=datetime.now().strftime("%Y-%m-%dT%H:%M:%SZ"),
table="person_distinct_id_overrides",
incremental=True,
workload=Workload.ONLINE,

View File

@@ -0,0 +1,485 @@
"""Tests for the persons new backfill job."""
from unittest.mock import MagicMock, PropertyMock, patch
import psycopg2.errors
from dagster import build_op_context
from dags.persons_new_backfill import PersonsNewBackfillConfig, copy_chunk, create_chunks
class TestCreateChunks:
"""Test the create_chunks function."""
def test_create_chunks_produces_non_overlapping_ranges(self):
"""Test that chunks produce non-overlapping ranges."""
config = PersonsNewBackfillConfig(chunk_size=1000)
id_range = (1, 5000) # min_id=1, max_id=5000
context = build_op_context()
chunks = list(create_chunks(context, config, id_range))
# Extract all chunk ranges from DynamicOutput objects
chunk_ranges = [chunk.value for chunk in chunks]
# Verify no overlaps
for i, (min1, max1) in enumerate(chunk_ranges):
for j, (min2, max2) in enumerate(chunk_ranges):
if i != j:
# Chunks should not overlap
assert not (
min1 <= min2 <= max1 or min1 <= max2 <= max1 or min2 <= min1 <= max2
), f"Chunks overlap: ({min1}, {max1}) and ({min2}, {max2})"
def test_create_chunks_covers_entire_id_space(self):
"""Test that chunks cover the entire ID space from min to max."""
config = PersonsNewBackfillConfig(chunk_size=1000)
min_id, max_id = 1, 5000
id_range = (min_id, max_id)
context = build_op_context()
chunks = list(create_chunks(context, config, id_range))
# Extract all chunk ranges from DynamicOutput objects
chunk_ranges = [chunk.value for chunk in chunks]
# Find the overall min and max covered
all_ids_covered: set[int] = set()
for chunk_min, chunk_max in chunk_ranges:
all_ids_covered.update(range(chunk_min, chunk_max + 1))
# Verify all IDs from min_id to max_id are covered
expected_ids = set(range(min_id, max_id + 1))
assert all_ids_covered == expected_ids, (
f"Missing IDs: {expected_ids - all_ids_covered}, " f"Extra IDs: {all_ids_covered - expected_ids}"
)
def test_create_chunks_first_chunk_includes_max_id(self):
"""Test that the first chunk (in yielded order) includes the source table max_id."""
config = PersonsNewBackfillConfig(chunk_size=1000)
min_id, max_id = 1, 5000
id_range = (min_id, max_id)
context = build_op_context()
chunks = list(create_chunks(context, config, id_range))
# First chunk in the list (yielded first, highest IDs)
first_chunk_min, first_chunk_max = chunks[0].value
assert first_chunk_max == max_id, f"First chunk max ({first_chunk_max}) should equal source max_id ({max_id})"
assert (
first_chunk_min <= max_id <= first_chunk_max
), f"First chunk ({first_chunk_min}, {first_chunk_max}) should include max_id ({max_id})"
def test_create_chunks_final_chunk_includes_min_id(self):
"""Test that the final chunk (in yielded order) includes the source table min_id."""
config = PersonsNewBackfillConfig(chunk_size=1000)
min_id, max_id = 1, 5000
id_range = (min_id, max_id)
context = build_op_context()
chunks = list(create_chunks(context, config, id_range))
# Last chunk in the list (yielded last, lowest IDs)
final_chunk_min, final_chunk_max = chunks[-1].value
assert final_chunk_min == min_id, f"Final chunk min ({final_chunk_min}) should equal source min_id ({min_id})"
assert (
final_chunk_min <= min_id <= final_chunk_max
), f"Final chunk ({final_chunk_min}, {final_chunk_max}) should include min_id ({min_id})"
def test_create_chunks_reverse_order(self):
"""Test that chunks are yielded in reverse order (highest IDs first)."""
config = PersonsNewBackfillConfig(chunk_size=1000)
min_id, max_id = 1, 5000
id_range = (min_id, max_id)
context = build_op_context()
chunks = list(create_chunks(context, config, id_range))
# Verify chunks are in descending order by max_id
for i in range(len(chunks) - 1):
current_max = chunks[i].value[1]
next_max = chunks[i + 1].value[1]
assert (
current_max > next_max
), f"Chunks not in reverse order: chunk {i} max ({current_max}) should be > chunk {i+1} max ({next_max})"
def test_create_chunks_exact_multiple(self):
"""Test chunk creation when ID range is an exact multiple of chunk_size."""
config = PersonsNewBackfillConfig(chunk_size=1000)
min_id, max_id = 1, 5000 # Exactly 5 chunks of 1000
id_range = (min_id, max_id)
context = build_op_context()
chunks = list(create_chunks(context, config, id_range))
assert len(chunks) == 5, f"Expected 5 chunks, got {len(chunks)}"
# Verify first chunk (highest IDs)
assert chunks[0].value == (4001, 5000), f"First chunk should be (4001, 5000), got {chunks[0].value}"
# Verify last chunk (lowest IDs)
assert chunks[-1].value == (1, 1000), f"Last chunk should be (1, 1000), got {chunks[-1].value}"
def test_create_chunks_non_exact_multiple(self):
"""Test chunk creation when ID range is not an exact multiple of chunk_size."""
config = PersonsNewBackfillConfig(chunk_size=1000)
min_id, max_id = 1, 3750 # 3 full chunks + 1 partial chunk
id_range = (min_id, max_id)
context = build_op_context()
chunks = list(create_chunks(context, config, id_range))
assert len(chunks) == 4, f"Expected 4 chunks, got {len(chunks)}"
# Verify first chunk (highest IDs) - should be the partial chunk
assert chunks[0].value == (3001, 3750), f"First chunk should be (3001, 3750), got {chunks[0].value}"
# Verify last chunk (lowest IDs)
assert chunks[-1].value == (1, 1000), f"Last chunk should be (1, 1000), got {chunks[-1].value}"
def test_create_chunks_single_chunk(self):
"""Test chunk creation when ID range fits in a single chunk."""
config = PersonsNewBackfillConfig(chunk_size=1000)
min_id, max_id = 100, 500
id_range = (min_id, max_id)
context = build_op_context()
chunks = list(create_chunks(context, config, id_range))
assert len(chunks) == 1, f"Expected 1 chunk, got {len(chunks)}"
assert chunks[0].value == (100, 500), f"Chunk should be (100, 500), got {chunks[0].value}"
assert chunks[0].value[0] == min_id and chunks[0].value[1] == max_id
def create_mock_database_resource(rowcount_values=None):
"""
Create a mock database resource that mimics psycopg2.extensions.connection.
Args:
rowcount_values: List of rowcount values to return per INSERT call.
If None, defaults to 0. If a single int, uses that for all calls.
"""
mock_cursor = MagicMock()
if rowcount_values is None:
mock_cursor.rowcount = 0
elif isinstance(rowcount_values, int):
mock_cursor.rowcount = rowcount_values
else:
# Use side_effect to return different rowcounts per call
call_count = [0]
def get_rowcount():
if call_count[0] < len(rowcount_values):
result = rowcount_values[call_count[0]]
call_count[0] += 1
return result
return rowcount_values[-1] if rowcount_values else 0
# Attach via PropertyMock on the mock's type: assigning a bare property() to the
# instance would never be invoked on attribute access.
type(mock_cursor).rowcount = PropertyMock(side_effect=get_rowcount)
mock_cursor.execute = MagicMock()
# Make cursor() return a context manager
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_conn
def create_mock_cluster_resource():
"""Create a mock ClickhouseCluster resource."""
return MagicMock()
class TestCopyChunk:
"""Test the copy_chunk function."""
def test_copy_chunk_single_batch_success(self):
"""Test successful copy of a single batch within a chunk."""
config = PersonsNewBackfillConfig(
chunk_size=1000, batch_size=100, source_table="posthog_persons", destination_table="posthog_persons_new"
)
chunk = (1, 100) # Single batch covers entire chunk
mock_db = create_mock_database_resource(rowcount_values=50)
mock_cluster = create_mock_cluster_resource()
context = build_op_context(
resources={"database": mock_db, "cluster": mock_cluster},
)
# Patch context.run.job_name where it's accessed in copy_chunk
from unittest.mock import PropertyMock
with patch.object(type(context), "run", PropertyMock(return_value=MagicMock(job_name="test_job"))):
result = copy_chunk(context, config, chunk)
# Verify result
assert result["chunk_min"] == 1
assert result["chunk_max"] == 100
assert result["records_copied"] == 50
# Verify SET statements called once (session-level, before loop)
set_statements = [
"SET application_name = 'backfill_posthog_persons_to_posthog_persons_new'",
"SET lock_timeout = '5s'",
"SET statement_timeout = '30min'",
"SET maintenance_work_mem = '12GB'",
"SET work_mem = '512MB'",
"SET temp_buffers = '512MB'",
"SET max_parallel_workers_per_gather = 2",
"SET max_parallel_maintenance_workers = 2",
"SET synchronous_commit = off",
]
cursor = mock_db.cursor.return_value.__enter__.return_value
execute_calls = [call[0][0] for call in cursor.execute.call_args_list]
# Check SET statements were called
for stmt in set_statements:
assert any(stmt in call for call in execute_calls), f"SET statement not found: {stmt}"
# Verify BEGIN, INSERT, COMMIT called once
assert execute_calls.count("BEGIN") == 1
assert execute_calls.count("COMMIT") == 1
# Verify INSERT query format
insert_calls = [call for call in execute_calls if "INSERT INTO" in call]
assert len(insert_calls) == 1
insert_query = insert_calls[0]
assert "INSERT INTO posthog_persons_new" in insert_query
assert "SELECT s.*" in insert_query
assert "FROM posthog_persons s" in insert_query
assert "WHERE s.id >" in insert_query
assert "AND s.id <=" in insert_query
assert "NOT EXISTS" in insert_query
assert "ORDER BY s.id DESC" in insert_query
def test_copy_chunk_multiple_batches(self):
"""Test copy with multiple batches in a chunk."""
config = PersonsNewBackfillConfig(
chunk_size=1000, batch_size=100, source_table="posthog_persons", destination_table="posthog_persons_new"
)
chunk = (1, 250) # 3 batches with batch_size=100: (1, 101), (102, 202), (203, 250)
mock_db = create_mock_database_resource()
mock_cluster = create_mock_cluster_resource()
# Track rowcount per batch - use a list to track INSERT calls
rowcounts = [50, 75, 25]
insert_call_count = [0]
cursor = mock_db.cursor.return_value.__enter__.return_value
# Track INSERT calls and set rowcount accordingly
def execute_with_rowcount(query, *args):
if "INSERT INTO" in query:
if insert_call_count[0] < len(rowcounts):
cursor.rowcount = rowcounts[insert_call_count[0]]
insert_call_count[0] += 1
else:
cursor.rowcount = 0
cursor.execute.side_effect = execute_with_rowcount
context = build_op_context(
resources={"database": mock_db, "cluster": mock_cluster},
)
# Patch context.run.job_name where it's accessed in copy_chunk
from unittest.mock import PropertyMock
with patch.object(type(context), "run", PropertyMock(return_value=MagicMock(job_name="test_job"))):
result = copy_chunk(context, config, chunk)
# Verify result
assert result["chunk_min"] == 1
assert result["chunk_max"] == 250
assert result["records_copied"] == 150 # 50 + 75 + 25
# Verify SET statements called once (before loop)
cursor = mock_db.cursor.return_value.__enter__.return_value
execute_calls = [call[0][0] for call in cursor.execute.call_args_list]
# Verify BEGIN/COMMIT called 3 times (one per batch)
assert execute_calls.count("BEGIN") == 3
assert execute_calls.count("COMMIT") == 3
# Verify INSERT called 3 times
insert_calls = [call for call in execute_calls if "INSERT INTO" in call]
assert len(insert_calls) == 3
def test_copy_chunk_duplicate_key_violation_retry(self):
"""Test that duplicate key violation triggers retry."""
config = PersonsNewBackfillConfig(
chunk_size=1000, batch_size=100, source_table="posthog_persons", destination_table="posthog_persons_new"
)
chunk = (1, 100)
mock_db = create_mock_database_resource()
mock_cluster = create_mock_cluster_resource()
cursor = mock_db.cursor.return_value.__enter__.return_value
# Track INSERT attempts
insert_attempts = [0]
# First INSERT raises UniqueViolation, second succeeds
def execute_side_effect(query, *args):
if "INSERT INTO" in query:
insert_attempts[0] += 1
if insert_attempts[0] == 1:
# First INSERT attempt raises error
# Use real UniqueViolation - pgcode is readonly but isinstance check will pass
raise psycopg2.errors.UniqueViolation("duplicate key value violates unique constraint")
# Subsequent calls succeed
cursor.rowcount = 50 # Success on retry
cursor.execute.side_effect = execute_side_effect
context = build_op_context(
resources={"database": mock_db, "cluster": mock_cluster},
)
# Need to patch time.sleep and run.job_name
from unittest.mock import PropertyMock
mock_run = MagicMock(job_name="test_job")
with (
patch("dags.persons_new_backfill.time.sleep"),
patch.object(type(context), "run", PropertyMock(return_value=mock_run)),
):
copy_chunk(context, config, chunk)
# Verify ROLLBACK was called on error
execute_calls = [call[0][0] for call in cursor.execute.call_args_list]
assert "ROLLBACK" in execute_calls
# Verify retry succeeded (should have INSERT called twice, COMMIT once)
insert_calls = [call for call in execute_calls if "INSERT INTO" in call]
assert len(insert_calls) >= 1 # At least one successful INSERT
def test_copy_chunk_error_handling_and_rollback(self):
"""Test error handling and rollback on non-duplicate errors."""
config = PersonsNewBackfillConfig(
chunk_size=1000, batch_size=100, source_table="posthog_persons", destination_table="posthog_persons_new"
)
chunk = (1, 100)
mock_db = create_mock_database_resource()
mock_cluster = create_mock_cluster_resource()
cursor = mock_db.cursor.return_value.__enter__.return_value
# Raise generic error on INSERT
def execute_side_effect(query, *args):
if "INSERT INTO" in query:
raise Exception("Connection lost")
cursor.execute.side_effect = execute_side_effect
context = build_op_context(
resources={"database": mock_db, "cluster": mock_cluster},
)
# Patch context.run.job_name where it's accessed in copy_chunk
from unittest.mock import PropertyMock
mock_run = MagicMock(job_name="test_job")
with patch.object(type(context), "run", PropertyMock(return_value=mock_run)):
# Should raise Dagster.Failure
from dagster import Failure
try:
copy_chunk(context, config, chunk)
raise AssertionError("Expected Dagster.Failure to be raised")
except Failure as e:
# Verify error metadata
assert e.description is not None
assert "Failed to copy batch" in e.description
# Verify ROLLBACK was called
execute_calls = [call[0][0] for call in cursor.execute.call_args_list]
assert "ROLLBACK" in execute_calls
def test_copy_chunk_insert_query_format(self):
"""Test that INSERT query has correct format."""
config = PersonsNewBackfillConfig(
chunk_size=1000, batch_size=100, source_table="test_source", destination_table="test_dest"
)
chunk = (1, 100)
mock_db = create_mock_database_resource(rowcount_values=10)
mock_cluster = create_mock_cluster_resource()
context = build_op_context(
resources={"database": mock_db, "cluster": mock_cluster},
)
# Patch context.run.job_name where it's accessed in copy_chunk
from unittest.mock import PropertyMock
with patch.object(type(context), "run", PropertyMock(return_value=MagicMock(job_name="test_job"))):
copy_chunk(context, config, chunk)
cursor = mock_db.cursor.return_value.__enter__.return_value
execute_calls = [call[0][0] for call in cursor.execute.call_args_list]
# Find INSERT query
insert_query = next((call for call in execute_calls if "INSERT INTO" in call), None)
assert insert_query is not None
# Verify query components
assert "INSERT INTO test_dest" in insert_query
assert "SELECT s.*" in insert_query
assert "FROM test_source s" in insert_query
assert "WHERE s.id >" in insert_query
assert "AND s.id <=" in insert_query
assert "NOT EXISTS" in insert_query
assert "d.team_id = s.team_id" in insert_query
assert "d.id = s.id" in insert_query
assert "ORDER BY s.id DESC" in insert_query
def test_copy_chunk_session_settings_applied_once(self):
"""Test that SET statements are applied once at session level before batch loop."""
config = PersonsNewBackfillConfig(
chunk_size=1000, batch_size=50, source_table="posthog_persons", destination_table="posthog_persons_new"
)
chunk = (1, 150) # 3 batches
mock_db = create_mock_database_resource(rowcount_values=25)
mock_cluster = create_mock_cluster_resource()
context = build_op_context(
resources={"database": mock_db, "cluster": mock_cluster},
)
# Patch context.run.job_name where it's accessed in copy_chunk
from unittest.mock import PropertyMock
with patch.object(type(context), "run", PropertyMock(return_value=MagicMock(job_name="test_job"))):
copy_chunk(context, config, chunk)
cursor = mock_db.cursor.return_value.__enter__.return_value
execute_calls = [call[0][0] for call in cursor.execute.call_args_list]
# Count SET statements (should be called once each, before loop)
set_statements = [
"SET application_name",
"SET lock_timeout",
"SET statement_timeout",
"SET maintenance_work_mem",
"SET work_mem",
"SET temp_buffers",
"SET max_parallel_workers_per_gather",
"SET max_parallel_maintenance_workers",
"SET synchronous_commit",
]
for stmt in set_statements:
count = sum(1 for call in execute_calls if stmt in call)
assert count == 1, f"Expected {stmt} to be called once, but it was called {count} times"
# Verify SET statements come before BEGIN statements
set_indices = [i for i, call in enumerate(execute_calls) if any(stmt in call for stmt in set_statements)]
begin_indices = [i for i, call in enumerate(execute_calls) if call == "BEGIN"]
if set_indices and begin_indices:
assert max(set_indices) < min(begin_indices), "SET statements should come before BEGIN statements"

View File

@@ -113,7 +113,7 @@ services:
# Note: please keep the default version in sync across
# `posthog` and the `charts-clickhouse` repos
#
image: ${CLICKHOUSE_SERVER_IMAGE:-clickhouse/clickhouse-server:25.6.13.41}
image: ${CLICKHOUSE_SERVER_IMAGE:-clickhouse/clickhouse-server:25.8.11.66}
restart: on-failure
environment:
CLICKHOUSE_SKIP_USER_SETUP: 1
@@ -296,17 +296,15 @@ services:
MAXMIND_DB_PATH: '/share/GeoLite2-City.mmdb'
# Shared Redis for non-critical path (analytics, billing, cookieless)
REDIS_URL: 'redis://redis:6379/'
# Optional: Use separate Redis URLs for read/write separation
# Optional: Use separate Redis URL for read replicas
# REDIS_READER_URL: 'redis://redis-replica:6379/'
# REDIS_WRITER_URL: 'redis://redis-primary:6379/'
# Optional: Increase Redis timeout (default is 100ms)
# REDIS_TIMEOUT_MS: 200
# Dedicated Redis database for critical path (team cache + flags cache)
# Hobby deployments start in Mode 1 (shared-only). Developers override in docker-compose.dev.yml for Mode 2.
# FLAGS_REDIS_URL: 'redis://redis:6379/1'
# Optional: Use separate Flags Redis URLs for read/write separation
# Optional: Use separate Flags Redis URL for read replicas
# FLAGS_REDIS_READER_URL: 'redis://redis-replica:6379/1'
# FLAGS_REDIS_WRITER_URL: 'redis://redis-primary:6379/1'
ADDRESS: '0.0.0.0:3001'
RUST_LOG: 'info'
COOKIELESS_REDIS_HOST: redis7

View File

@@ -111,7 +111,7 @@ services:
service: clickhouse
hostname: clickhouse
# Development performance optimizations
mem_limit: 4g
mem_limit: 6g
cpus: 2
environment:
- AWS_ACCESS_KEY_ID=object_storage_root_user

View File

@@ -103,7 +103,7 @@ class ConversationViewSet(TeamAndOrgViewSetMixin, ListModelMixin, RetrieveModelM
# Only for streaming
and self.action == "create"
# Strict limits are skipped for select US region teams (PostHog + an active user we've chatted with)
and not (get_instance_region() == "US" and self.team_id in (2, 87921, 41124))
and not (get_instance_region() == "US" and self.team_id in (2, 87921, 41124, 103224))
):
return [AIBurstRateThrottle(), AISustainedRateThrottle()]

View File

@@ -79,6 +79,12 @@ class AccessControlSerializer(serializers.ModelSerializer):
f"Access level cannot be set below the minimum '{min_level}' for {resource}."
)
max_level = highest_access_level(resource)
if levels.index(access_level) > levels.index(max_level):
raise serializers.ValidationError(
f"Access level cannot be set above the maximum '{max_level}' for {resource}."
)
return access_level
def validate(self, data):
@@ -219,6 +225,7 @@ class AccessControlViewSetMixin(_GenericViewSet):
else ordered_access_levels(resource),
"default_access_level": "editor" if is_resource_level else default_access_level(resource),
"minimum_access_level": minimum_access_level(resource) if not is_resource_level else "none",
"maximum_access_level": highest_access_level(resource) if not is_resource_level else "manager",
"user_access_level": user_access_level,
"user_can_edit_access_levels": user_access_control.check_can_modify_access_levels_for_object(obj),
}

View File

@@ -146,6 +146,42 @@ class TestAccessControlMinimumLevelValidation(BaseAccessControlTest):
)
assert res.status_code == status.HTTP_200_OK, f"Failed for level {level}: {res.json()}"
def test_activity_log_access_level_cannot_be_above_viewer(self):
"""Test that activity_log access level cannot be set above maximum 'viewer'"""
self._org_membership(OrganizationMembership.Level.ADMIN)
for level in ["editor", "manager"]:
res = self.client.put(
"/api/projects/@current/resource_access_controls",
{"resource": "activity_log", "access_level": level},
)
assert res.status_code == status.HTTP_400_BAD_REQUEST, f"Failed for level {level}: {res.json()}"
assert "cannot be set above the maximum 'viewer'" in res.json()["detail"]
def test_activity_log_access_restricted_for_users_without_access(self):
"""Test that users without access to activity_log cannot access activity log endpoints"""
self._org_membership(OrganizationMembership.Level.ADMIN)
res = self.client.put(
"/api/projects/@current/resource_access_controls",
{"resource": "activity_log", "access_level": "none"},
)
assert res.status_code == status.HTTP_200_OK, f"Failed to set access control: {res.json()}"
from ee.models.rbac.access_control import AccessControl
ac = AccessControl.objects.filter(team=self.team, resource="activity_log", resource_id=None).first()
assert ac is not None, "Access control was not created"
assert ac.access_level == "none", f"Access level is {ac.access_level}, expected 'none'"
self._org_membership(OrganizationMembership.Level.MEMBER)
res = self.client.get("/api/projects/@current/activity_log/")
assert res.status_code == status.HTTP_403_FORBIDDEN, f"Expected 403, got {res.status_code}: {res.json()}"
res = self.client.get("/api/projects/@current/advanced_activity_logs/")
assert res.status_code == status.HTTP_403_FORBIDDEN, f"Expected 403, got {res.status_code}: {res.json()}"
class TestAccessControlResourceLevelAPI(BaseAccessControlTest):
def setUp(self):
@@ -186,6 +222,7 @@ class TestAccessControlResourceLevelAPI(BaseAccessControlTest):
"default_access_level": "editor",
"user_can_edit_access_levels": True,
"minimum_access_level": "none",
"maximum_access_level": "manager",
}
def test_change_rejected_if_not_org_admin(self):

View File

@@ -171,8 +171,6 @@
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND type IN ('QueryFinish',
'ExceptionWhileProcessing')
AND event_date >= 'today'
AND event_time >= 'today 00:00:00'
ORDER BY event_time DESC
@@ -180,6 +178,25 @@
# ---
# name: TestCohortQuery.test_cohort_filter_with_another_cohort_with_event_sequence.1
'''
/* celery:posthog.tasks.calculate_cohort.collect_cohort_query_stats */
SELECT query_id,
query_duration_ms,
read_rows,
read_bytes,
written_rows,
memory_usage,
exception
FROM query_log_archive
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND event_date >= 'today'
AND event_time >= 'today 00:00:00'
ORDER BY event_time DESC
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_another_cohort_with_event_sequence.2
'''
SELECT count(DISTINCT person_id)
FROM cohortpeople
@@ -188,7 +205,7 @@
AND version = 0
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_another_cohort_with_event_sequence.2
# name: TestCohortQuery.test_cohort_filter_with_another_cohort_with_event_sequence.3
'''
(SELECT cohort_people.person_id AS id
@@ -269,7 +286,7 @@
allow_experimental_join_condition=1
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_another_cohort_with_event_sequence.3
# name: TestCohortQuery.test_cohort_filter_with_another_cohort_with_event_sequence.4
'''
SELECT if(funnel_query.person_id = '00000000-0000-0000-0000-000000000000', person.person_id, funnel_query.person_id) AS id
@@ -341,8 +358,6 @@
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND type IN ('QueryFinish',
'ExceptionWhileProcessing')
AND event_date >= 'today'
AND event_time >= 'today 00:00:00'
ORDER BY event_time DESC
@@ -350,12 +365,21 @@
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.1
'''
SELECT count(DISTINCT person_id)
FROM cohortpeople
WHERE team_id = 99999
AND cohort_id = 99999
AND version = 0
/* celery:posthog.tasks.calculate_cohort.collect_cohort_query_stats */
SELECT query_id,
query_duration_ms,
read_rows,
read_bytes,
written_rows,
memory_usage,
exception
FROM query_log_archive
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND event_date >= 'today'
AND event_time >= 'today 00:00:00'
ORDER BY event_time DESC
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.10
@@ -638,6 +662,16 @@
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.2
'''
SELECT count(DISTINCT person_id)
FROM cohortpeople
WHERE team_id = 99999
AND cohort_id = 99999
AND version = 0
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.3
'''
(SELECT cohort_people.person_id AS id
@@ -684,7 +718,7 @@
allow_experimental_join_condition=1
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.3
# name: TestCohortQuery.test_cohort_filter_with_extra.4
'''
SELECT if(behavior_query.person_id = '00000000-0000-0000-0000-000000000000', person.person_id, behavior_query.person_id) AS id
@@ -726,7 +760,7 @@
join_algorithm = 'auto'
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.4
# name: TestCohortQuery.test_cohort_filter_with_extra.5
'''
/* celery:posthog.tasks.calculate_cohort.collect_cohort_query_stats */
SELECT query_id,
@@ -740,14 +774,31 @@
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND type IN ('QueryFinish',
'ExceptionWhileProcessing')
AND event_date >= 'today'
AND event_time >= 'today 00:00:00'
ORDER BY event_time DESC
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.5
# name: TestCohortQuery.test_cohort_filter_with_extra.6
'''
/* celery:posthog.tasks.calculate_cohort.collect_cohort_query_stats */
SELECT query_id,
query_duration_ms,
read_rows,
read_bytes,
written_rows,
memory_usage,
exception
FROM query_log_archive
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND event_date >= 'today'
AND event_time >= 'today 00:00:00'
ORDER BY event_time DESC
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.7
'''
SELECT count(DISTINCT person_id)
@@ -757,99 +808,6 @@
AND version = 0
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.6
'''
(
(SELECT persons.id AS id
FROM
(SELECT argMax(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(person.properties, 'name'), ''), 'null'), '^"|"$', ''), person.version) AS properties___name,
person.id AS id
FROM person
WHERE and(equals(person.team_id, 99999), in(id,
(SELECT where_optimization.id AS id
FROM person AS where_optimization
WHERE and(equals(where_optimization.team_id, 99999), ifNull(equals(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(where_optimization.properties, 'name'), ''), 'null'), '^"|"$', ''), 'test'), 0)))))
GROUP BY person.id
HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0))) AS persons
WHERE ifNull(equals(persons.properties___name, 'test'), 0)
ORDER BY persons.id ASC
LIMIT 1000000000 SETTINGS optimize_aggregation_in_order=1,
join_algorithm='auto'))
UNION DISTINCT (
(SELECT source.id AS id
FROM
(SELECT actor_id AS actor_id,
count() AS event_count,
groupUniqArray(distinct_id) AS event_distinct_ids,
actor_id AS id
FROM
(SELECT if(not(empty(e__override.distinct_id)), e__override.person_id, e.person_id) AS actor_id,
toTimeZone(e.timestamp, 'UTC') AS timestamp,
e.uuid AS uuid,
e.distinct_id AS distinct_id
FROM events AS e
LEFT OUTER JOIN
(SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id,
person_distinct_id_overrides.distinct_id AS distinct_id
FROM person_distinct_id_overrides
WHERE equals(person_distinct_id_overrides.team_id, 99999)
GROUP BY person_distinct_id_overrides.distinct_id
HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS e__override ON equals(e.distinct_id, e__override.distinct_id)
WHERE and(equals(e.team_id, 99999), greaterOrEquals(timestamp, toDateTime64('explicit_redacted_timestamp', 6, 'UTC')), lessOrEquals(timestamp, toDateTime64('today', 6, 'UTC')), equals(e.event, '$pageview')))
GROUP BY actor_id) AS source
ORDER BY source.id ASC
LIMIT 1000000000 SETTINGS optimize_aggregation_in_order=1,
join_algorithm='auto')) SETTINGS readonly=2,
max_execution_time=600,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.7
'''
SELECT if(behavior_query.person_id = '00000000-0000-0000-0000-000000000000', person.person_id, behavior_query.person_id) AS id
FROM
(SELECT if(not(empty(pdi.distinct_id)), pdi.person_id, e.person_id) AS person_id,
countIf(timestamp > now() - INTERVAL 1 week
AND timestamp < now()
AND event = '$pageview'
AND 1=1) > 0 AS performed_event_condition_None_level_level_0_level_1_level_0_0
FROM events e
LEFT OUTER JOIN
(SELECT distinct_id,
argMax(person_id, version) as person_id
FROM person_distinct_id2
WHERE team_id = 99999
GROUP BY distinct_id
HAVING argMax(is_deleted, version) = 0) AS pdi ON e.distinct_id = pdi.distinct_id
WHERE team_id = 99999
AND event IN ['$pageview']
AND timestamp <= now()
AND timestamp >= now() - INTERVAL 1 week
GROUP BY person_id) behavior_query
FULL OUTER JOIN
(SELECT *,
id AS person_id
FROM
(SELECT id,
argMax(properties, version) as person_props
FROM person
WHERE team_id = 99999
GROUP BY id
HAVING max(is_deleted) = 0 SETTINGS optimize_aggregation_in_order = 1)) person ON person.person_id = behavior_query.person_id
WHERE 1 = 1
AND ((((has(['test'], replaceRegexpAll(JSONExtractRaw(person_props, 'name'), '^"|"$', ''))))
OR ((coalesce(performed_event_condition_None_level_level_0_level_1_level_0_0, false))))) SETTINGS optimize_aggregation_in_order = 1,
join_algorithm = 'auto'
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.8
'''
(
@@ -1972,8 +1930,6 @@
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND type IN ('QueryFinish',
'ExceptionWhileProcessing')
AND event_date >= 'today'
AND event_time >= 'today 00:00:00'
ORDER BY event_time DESC
@@ -1981,6 +1937,25 @@
# ---
# name: TestCohortQuery.test_precalculated_cohort_filter_with_extra_filters.1
'''
/* celery:posthog.tasks.calculate_cohort.collect_cohort_query_stats */
SELECT query_id,
query_duration_ms,
read_rows,
read_bytes,
written_rows,
memory_usage,
exception
FROM query_log_archive
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND event_date >= 'today'
AND event_time >= 'today 00:00:00'
ORDER BY event_time DESC
'''
# ---
# name: TestCohortQuery.test_precalculated_cohort_filter_with_extra_filters.2
'''
SELECT count(DISTINCT person_id)
FROM cohortpeople
@@ -1989,45 +1964,6 @@
AND version = 0
'''
# ---
# name: TestCohortQuery.test_precalculated_cohort_filter_with_extra_filters.2
'''
(SELECT cohort_people.person_id AS id
FROM
(SELECT DISTINCT cohortpeople.person_id AS person_id,
cohortpeople.cohort_id AS cohort_id,
cohortpeople.team_id AS team_id
FROM cohortpeople
WHERE and(equals(cohortpeople.team_id, 99999), in(tuple(cohortpeople.cohort_id, cohortpeople.version), [(99999, 0)]))) AS cohort_people
WHERE and(ifNull(equals(cohort_people.cohort_id, 99999), 0), ifNull(equals(cohort_people.team_id, 99999), 0))
LIMIT 1000000000)
UNION DISTINCT
(SELECT persons.id AS id
FROM
(SELECT argMax(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(person.properties, 'name'), ''), 'null'), '^"|"$', ''), person.version) AS properties___name,
person.id AS id
FROM person
WHERE and(equals(person.team_id, 99999), in(id,
(SELECT where_optimization.id AS id
FROM person AS where_optimization
WHERE and(equals(where_optimization.team_id, 99999), ifNull(equals(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(where_optimization.properties, 'name'), ''), 'null'), '^"|"$', ''), 'test2'), 0)))))
GROUP BY person.id
HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0))) AS persons
WHERE ifNull(equals(persons.properties___name, 'test2'), 0)
ORDER BY persons.id ASC
LIMIT 1000000000 SETTINGS optimize_aggregation_in_order=1,
join_algorithm='auto') SETTINGS readonly=2,
max_execution_time=600,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
'''
# ---
# name: TestCohortQuery.test_precalculated_cohort_filter_with_extra_filters.3
'''
@@ -2068,6 +2004,45 @@
'''
# ---
# name: TestCohortQuery.test_precalculated_cohort_filter_with_extra_filters.4
'''
(SELECT cohort_people.person_id AS id
FROM
(SELECT DISTINCT cohortpeople.person_id AS person_id,
cohortpeople.cohort_id AS cohort_id,
cohortpeople.team_id AS team_id
FROM cohortpeople
WHERE and(equals(cohortpeople.team_id, 99999), in(tuple(cohortpeople.cohort_id, cohortpeople.version), [(99999, 0)]))) AS cohort_people
WHERE and(ifNull(equals(cohort_people.cohort_id, 99999), 0), ifNull(equals(cohort_people.team_id, 99999), 0))
LIMIT 1000000000)
UNION DISTINCT
(SELECT persons.id AS id
FROM
(SELECT argMax(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(person.properties, 'name'), ''), 'null'), '^"|"$', ''), person.version) AS properties___name,
person.id AS id
FROM person
WHERE and(equals(person.team_id, 99999), in(id,
(SELECT where_optimization.id AS id
FROM person AS where_optimization
WHERE and(equals(where_optimization.team_id, 99999), ifNull(equals(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(where_optimization.properties, 'name'), ''), 'null'), '^"|"$', ''), 'test2'), 0)))))
GROUP BY person.id
HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0))) AS persons
WHERE ifNull(equals(persons.properties___name, 'test2'), 0)
ORDER BY persons.id ASC
LIMIT 1000000000 SETTINGS optimize_aggregation_in_order=1,
join_algorithm='auto') SETTINGS readonly=2,
max_execution_time=600,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
'''
# ---
# name: TestCohortQuery.test_precalculated_cohort_filter_with_extra_filters.5
'''
SELECT person.person_id AS id
@@ -2314,8 +2289,6 @@
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND type IN ('QueryFinish',
'ExceptionWhileProcessing')
AND event_date >= 'today'
AND event_time >= 'today 00:00:00'
ORDER BY event_time DESC
@@ -2323,6 +2296,25 @@
# ---
# name: TestCohortQuery.test_unwrapping_static_cohort_filter_hidden_in_layers_of_cohorts.1
'''
/* celery:posthog.tasks.calculate_cohort.collect_cohort_query_stats */
SELECT query_id,
query_duration_ms,
read_rows,
read_bytes,
written_rows,
memory_usage,
exception
FROM query_log_archive
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND event_date >= 'today'
AND event_time >= 'today 00:00:00'
ORDER BY event_time DESC
'''
# ---
# name: TestCohortQuery.test_unwrapping_static_cohort_filter_hidden_in_layers_of_cohorts.2
'''
SELECT count(DISTINCT person_id)
FROM cohortpeople
@@ -2331,7 +2323,7 @@
AND version = 0
'''
# ---
# name: TestCohortQuery.test_unwrapping_static_cohort_filter_hidden_in_layers_of_cohorts.2
# name: TestCohortQuery.test_unwrapping_static_cohort_filter_hidden_in_layers_of_cohorts.3
'''
(SELECT cohort_people.person_id AS id
@@ -2379,7 +2371,7 @@
allow_experimental_join_condition=1
'''
# ---
# name: TestCohortQuery.test_unwrapping_static_cohort_filter_hidden_in_layers_of_cohorts.3
# name: TestCohortQuery.test_unwrapping_static_cohort_filter_hidden_in_layers_of_cohorts.4
'''
SELECT if(behavior_query.person_id = '00000000-0000-0000-0000-000000000000', person.person_id, behavior_query.person_id) AS id

View File

@@ -13,8 +13,6 @@
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND type IN ('QueryFinish',
'ExceptionWhileProcessing')
AND event_date >= '2021-01-21'
AND event_time >= '2021-01-21 00:00:00'
ORDER BY event_time DESC
@@ -22,6 +20,25 @@
# ---
# name: TestEventQuery.test_account_filters.1
'''
/* celery:posthog.tasks.calculate_cohort.collect_cohort_query_stats */
SELECT query_id,
query_duration_ms,
read_rows,
read_bytes,
written_rows,
memory_usage,
exception
FROM query_log_archive
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND event_date >= '2021-01-21'
AND event_time >= '2021-01-21 00:00:00'
ORDER BY event_time DESC
'''
# ---
# name: TestEventQuery.test_account_filters.2
'''
SELECT count(DISTINCT person_id)
FROM cohortpeople
@@ -30,7 +47,7 @@
AND version = 0
'''
# ---
# name: TestEventQuery.test_account_filters.2
# name: TestEventQuery.test_account_filters.3
'''
SELECT e.timestamp as timestamp,
if(notEmpty(pdi.distinct_id), pdi.person_id, e.person_id) as person_id

View File

@@ -49,7 +49,7 @@ from posthog.errors import ch_error_type, wrap_query_error
),
(
ServerException(
"Code: 439. DB::Exception: Cannot schedule a task: cannot allocate thread (threads=36, jobs=36). (CANNOT_SCHEDULE_TASK) (version 24.8.14.39 (official build))",
"Code: 439. DB::Exception: Cannot schedule a task: cannot allocate thread (threads=36, jobs=36). (CANNOT_SCHEDULE_TASK) (version 25.8.11.66 (official build))",
code=439,
),
"ClickHouseAtCapacity",
@@ -59,7 +59,7 @@ from posthog.errors import ch_error_type, wrap_query_error
),
(
ServerException(
"Code: 159. DB::Exception: Timeout exceeded: elapsed 60.046752587 seconds, maximum: 60. (TIMEOUT_EXCEEDED) (version 24.8.7.41 (official build))",
"Code: 159. DB::Exception: Timeout exceeded: elapsed 60.046752587 seconds, maximum: 60. (TIMEOUT_EXCEEDED) (version 25.8.11.66 (official build))",
code=159,
),
"ClickHouseQueryTimeOut",

View File

@@ -14,7 +14,7 @@ from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.api.shared import UserBasicSerializer
from posthog.models.activity_logging.activity_log import Detail, changes_between, log_activity
from posthog.models.experiment import ExperimentHoldout
from posthog.models.signals import model_activity_signal
from posthog.models.signals import model_activity_signal, mutable_receiver
class ExperimentHoldoutSerializer(serializers.ModelSerializer):
@@ -125,7 +125,7 @@ class ExperimentHoldoutViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
return super().destroy(request, *args, **kwargs)
@receiver(model_activity_signal, sender=ExperimentHoldout)
@mutable_receiver(model_activity_signal, sender=ExperimentHoldout)
def handle_experiment_holdout_change(
sender, scope, before_update, after_update, activity, user=None, was_impersonated=False, **kwargs
):

View File

@@ -20,7 +20,7 @@ from posthog.api.shared import UserBasicSerializer
from posthog.api.tagged_item import TaggedItemSerializerMixin
from posthog.models.activity_logging.activity_log import Detail, changes_between, log_activity
from posthog.models.experiment import ExperimentSavedMetric, ExperimentToSavedMetric
from posthog.models.signals import model_activity_signal
from posthog.models.signals import model_activity_signal, mutable_receiver
class ExperimentToSavedMetricSerializer(serializers.ModelSerializer):
@@ -112,7 +112,7 @@ class ExperimentSavedMetricViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet
serializer_class = ExperimentSavedMetricSerializer
@receiver(model_activity_signal, sender=ExperimentSavedMetric)
@mutable_receiver(model_activity_signal, sender=ExperimentSavedMetric)
def handle_experiment_saved_metric_change(
sender, scope, before_update, after_update, activity, user, was_impersonated=False, **kwargs
):

View File

@@ -1,12 +1,11 @@
from copy import deepcopy
from datetime import date, timedelta
from datetime import date, datetime, timedelta
from enum import Enum
from typing import Any, Literal
from zoneinfo import ZoneInfo
from django.db.models import Case, F, Prefetch, Q, QuerySet, Value, When
from django.db.models.functions import Now
from django.dispatch import receiver
from rest_framework import serializers, viewsets
from rest_framework.exceptions import ValidationError
@@ -35,7 +34,7 @@ from posthog.models.experiment import (
)
from posthog.models.feature_flag.feature_flag import FeatureFlag, FeatureFlagEvaluationTag
from posthog.models.filters.filter import Filter
from posthog.models.signals import model_activity_signal
from posthog.models.signals import model_activity_signal, mutable_receiver
from posthog.models.team.team import Team
from posthog.rbac.access_control_api_mixin import AccessControlViewSetMixin
from posthog.rbac.user_access_control import UserAccessControlSerializerMixin
@@ -1057,8 +1056,49 @@ class EnterpriseExperimentsViewSet(
status=201,
)
@action(methods=["GET"], detail=False, url_path="stats", required_scopes=["experiment:read"])
def stats(self, request: Request, **kwargs: Any) -> Response:
"""Get experimentation velocity statistics."""
team_tz = ZoneInfo(self.team.timezone) if self.team.timezone else ZoneInfo("UTC")
today = datetime.now(team_tz).date()
last_30d_start = today - timedelta(days=30)
previous_30d_start = today - timedelta(days=60)
previous_30d_end = last_30d_start
base_queryset = Experiment.objects.filter(team=self.team, deleted=False, archived=False)
launched_last_30d = base_queryset.filter(
start_date__gte=last_30d_start, start_date__lt=today + timedelta(days=1)
).count()
launched_previous_30d = base_queryset.filter(
start_date__gte=previous_30d_start, start_date__lt=previous_30d_end
).count()
if launched_previous_30d == 0:
percent_change = 100.0 if launched_last_30d > 0 else 0.0
else:
percent_change = ((launched_last_30d - launched_previous_30d) / launched_previous_30d) * 100
active_experiments = base_queryset.filter(start_date__isnull=False, end_date__isnull=True).count()
completed_last_30d = base_queryset.filter(
end_date__gte=last_30d_start, end_date__lt=today + timedelta(days=1)
).count()
return Response(
{
"launched_last_30d": launched_last_30d,
"launched_previous_30d": launched_previous_30d,
"percent_change": round(percent_change, 1),
"active_experiments": active_experiments,
"completed_last_30d": completed_last_30d,
}
)
@receiver(model_activity_signal, sender=Experiment)
@mutable_receiver(model_activity_signal, sender=Experiment)
def handle_experiment_change(
sender, scope, before_update, after_update, activity, user, was_impersonated=False, **kwargs
):
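A quick worked sketch of the percent-change logic in the new `stats` endpoint above, with illustrative counts only:

```python
def percent_change(launched_last_30d: int, launched_previous_30d: int) -> float:
    # Mirrors the endpoint: treat growth from an empty previous window as 100%,
    # otherwise compute the relative delta and round to one decimal place.
    if launched_previous_30d == 0:
        return 100.0 if launched_last_30d > 0 else 0.0
    return round((launched_last_30d - launched_previous_30d) / launched_previous_30d * 100, 1)


print(percent_change(6, 4))  # 50.0  (6 launches this window vs 4 in the previous one)
print(percent_change(3, 0))  # 100.0 (anything launched after an empty window counts as full growth)
print(percent_change(0, 0))  # 0.0
```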

View File

@@ -6,7 +6,7 @@
"mobile-replay:schema:build:json": "pnpm mobile-replay:web:schema:build:json && pnpm mobile-replay:mobile:schema:build:json"
},
"dependencies": {
"posthog-js": "1.290.0"
"posthog-js": "1.292.0"
},
"devDependencies": {
"ts-json-schema-generator": "^v2.4.0-next.6"

View File

@@ -127,9 +127,9 @@ For a _lot_ of great detail on prompting, check out the [GPT-4.1 prompting guide
## Support new query types
Max can now read multiple query types from frontend context, such as trends, funnels, retention, and HogQL queries. To add support for new query types, you need to extend both the QueryExecutor and the Root node.
PostHog AI can now read multiple query types from frontend context, such as trends, funnels, retention, and HogQL queries. To add support for new query types, you need to extend both the QueryExecutor and the Root node.
NOTE: this won't extend query type generation. For that, talk to the Max AI team.
NOTE: this won't extend query type generation. For that, talk to the PostHog AI team.
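Purely as a hypothetical sketch of the "extend the QueryExecutor" step (the class and method names below are illustrative stand-ins, not the actual PostHog AI code), adding a query type largely comes down to registering one more formatter for its `kind`:

```python
from typing import Any


class QueryExecutorSketch:
    """Illustrative only: dispatch on the query kind and format results for the LLM."""

    def run_and_format_query(self, query: dict[str, Any]) -> str:
        formatters = {
            "TrendsQuery": self._format_trends,
            "MyNewQuery": self._format_my_new_query,  # the query type being added
        }
        kind = query.get("kind")
        if kind not in formatters:
            raise ValueError(f"Unsupported query kind: {kind}")
        return formatters[kind](query)

    def _format_trends(self, query: dict[str, Any]) -> str:
        return f"Trend results bucketed by {query.get('interval', 'day')}"

    def _format_my_new_query(self, query: dict[str, Any]) -> str:
        return f"Results for {query.get('metric', 'unknown metric')}"


print(QueryExecutorSketch().run_and_format_query({"kind": "MyNewQuery", "metric": "MRR"}))
```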
### Adding a new query type

View File

@@ -453,18 +453,15 @@ Query results: 42 events
configurable={
"contextual_tools": {
"search_session_recordings": {"current_filters": {}},
"navigate": {"page_key": "insights"},
}
}
)
context_manager = AssistantContextManager(self.team, self.user, config)
tools = context_manager.get_contextual_tools()
self.assertEqual(len(tools), 2)
self.assertEqual(len(tools), 1)
self.assertIn("search_session_recordings", tools)
self.assertIn("navigate", tools)
self.assertEqual(tools["search_session_recordings"], {"current_filters": {}})
self.assertEqual(tools["navigate"], {"page_key": "insights"})
def test_get_contextual_tools_empty(self):
"""Test extraction of contextual tools returns empty dict when no tools"""

View File

@@ -187,23 +187,6 @@ async def eval_root(call_root, pytestconfig):
id="call_insight_default_props_2",
),
),
# Ensure we try and navigate to the relevant page when asked about specific topics
EvalCase(
input="What's my MRR?",
expected=AssistantToolCall(
name="navigate",
args={"page_key": "revenueAnalytics"},
id="call_navigate_1",
),
),
EvalCase(
input="Can you help me create a survey to collect NPS ratings?",
expected=AssistantToolCall(
name="navigate",
args={"page_key": "surveys"},
id="call_navigate_1",
),
),
EvalCase(
input="Give me the signup to purchase conversion rate for the dates between 8 Jul and 9 Sep",
expected=AssistantToolCall(

View File

@@ -0,0 +1,407 @@
"""Evaluations for CreateExperimentTool."""
import uuid
import pytest
from autoevals.partial import ScorerWithPartial
from autoevals.ragas import AnswerSimilarity
from braintrust import EvalCase, Score
from posthog.models import Experiment, FeatureFlag
from products.experiments.backend.max_tools import CreateExperimentTool
from ee.hogai.eval.base import MaxPublicEval
from ee.hogai.utils.types import AssistantState
from ee.models.assistant import Conversation
class ExperimentOutputScorer(ScorerWithPartial):
"""Custom scorer for experiment tool output that combines semantic similarity for text and exact matching for numbers/booleans."""
def __init__(self, semantic_fields: set[str] | None = None, **kwargs):
super().__init__(**kwargs)
self.semantic_fields = semantic_fields or {"message"}
def _run_eval_sync(self, output: dict, expected: dict, **kwargs):
if not expected:
return Score(name=self._name(), score=None, metadata={"reason": "No expected value provided"})
if not output:
return Score(name=self._name(), score=0.0, metadata={"reason": "No output provided"})
total_fields = len(expected)
if total_fields == 0:
return Score(name=self._name(), score=1.0)
score_per_field = 1.0 / total_fields
total_score = 0.0
metadata = {}
for field_name, expected_value in expected.items():
actual_value = output.get(field_name)
if field_name in self.semantic_fields:
# Use semantic similarity for text fields
if actual_value is not None and expected_value is not None:
similarity_scorer = AnswerSimilarity(model="text-embedding-3-small")
result = similarity_scorer.eval(output=str(actual_value), expected=str(expected_value))
field_score = result.score * score_per_field
total_score += field_score
metadata[f"{field_name}_score"] = result.score
else:
metadata[f"{field_name}_missing"] = True
else:
# Use exact match for numeric/boolean fields
if actual_value == expected_value:
total_score += score_per_field
metadata[f"{field_name}_match"] = True
else:
metadata[f"{field_name}_mismatch"] = {
"expected": expected_value,
"actual": actual_value,
}
return Score(name=self._name(), score=total_score, metadata=metadata)
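A small arithmetic sketch of how the scorer above weights fields, using the default `semantic_fields={"message"}` and an assumed similarity score (at runtime that number comes from `AnswerSimilarity`):

```python
expected = {"message": "Successfully created experiment", "experiment_created": True, "experiment_name": "Pricing Test"}
output = {"message": "Created the experiment successfully", "experiment_created": True, "experiment_name": "Pricing Test"}

score_per_field = 1.0 / len(expected)   # 3 fields -> each worth ~0.333
assumed_message_similarity = 0.9        # placeholder for the AnswerSimilarity result

total = assumed_message_similarity * score_per_field  # "message" is scored semantically
total += score_per_field                              # "experiment_created" matches exactly
total += score_per_field                              # "experiment_name" matches exactly
print(round(total, 2))  # 0.97
```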
@pytest.mark.django_db
async def eval_create_experiment_basic(pytestconfig, demo_org_team_user):
"""Test basic experiment creation."""
_, team, user = demo_org_team_user
conversation = await Conversation.objects.acreate(team=team, user=user)
async def task_create_experiment(args: dict):
# Create feature flag first (required by the tool)
await FeatureFlag.objects.acreate(
team=team,
created_by=user,
key=args["feature_flag_key"],
name=f"Flag for {args['name']}",
filters={
"groups": [{"properties": [], "rollout_percentage": 100}],
"multivariate": {
"variants": [
{"key": "control", "name": "Control", "rollout_percentage": 50},
{"key": "test", "name": "Test", "rollout_percentage": 50},
]
},
},
)
tool = await CreateExperimentTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)
result_message, artifact = await tool._arun_impl(
name=args["name"],
feature_flag_key=args["feature_flag_key"],
description=args.get("description"),
type=args.get("type", "product"),
)
exp_created = await Experiment.objects.filter(team=team, name=args["name"], deleted=False).aexists()
return {
"message": result_message,
"experiment_created": exp_created,
"experiment_name": artifact.get("experiment_name") if artifact else None,
}
await MaxPublicEval(
experiment_name="create_experiment_basic",
task=task_create_experiment, # type: ignore
scores=[ExperimentOutputScorer(semantic_fields={"message", "experiment_name"})],
data=[
EvalCase(
input={"name": "Pricing Test", "feature_flag_key": "pricing-test-flag"},
expected={
"message": "Successfully created experiment",
"experiment_created": True,
"experiment_name": "Pricing Test",
},
),
EvalCase(
input={
"name": "Homepage Redesign",
"feature_flag_key": "homepage-redesign",
"description": "Testing new homepage layout for better conversion",
},
expected={
"message": "Successfully created experiment",
"experiment_created": True,
"experiment_name": "Homepage Redesign",
},
),
],
pytestconfig=pytestconfig,
)
@pytest.mark.django_db
async def eval_create_experiment_types(pytestconfig, demo_org_team_user):
"""Test experiment creation with different types (product vs web)."""
_, team, user = demo_org_team_user
conversation = await Conversation.objects.acreate(team=team, user=user)
async def task_create_typed_experiment(args: dict):
# Create feature flag first
await FeatureFlag.objects.acreate(
team=team,
created_by=user,
key=args["feature_flag_key"],
name=f"Flag for {args['name']}",
filters={
"groups": [{"properties": [], "rollout_percentage": 100}],
"multivariate": {
"variants": [
{"key": "control", "name": "Control", "rollout_percentage": 50},
{"key": "test", "name": "Test", "rollout_percentage": 50},
]
},
},
)
tool = await CreateExperimentTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)
result_message, artifact = await tool._arun_impl(
name=args["name"],
feature_flag_key=args["feature_flag_key"],
type=args["type"],
)
# Verify experiment type
experiment = await Experiment.objects.aget(team=team, name=args["name"])
return {
"message": result_message,
"experiment_type": experiment.type,
"artifact_type": artifact.get("type") if artifact else None,
}
await MaxPublicEval(
experiment_name="create_experiment_types",
task=task_create_typed_experiment, # type: ignore
scores=[ExperimentOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input={"name": "Product Feature Test", "feature_flag_key": "product-test", "type": "product"},
expected={
"message": "Successfully created experiment",
"experiment_type": "product",
"artifact_type": "product",
},
),
EvalCase(
input={"name": "Web UI Test", "feature_flag_key": "web-test", "type": "web"},
expected={
"message": "Successfully created experiment",
"experiment_type": "web",
"artifact_type": "web",
},
),
],
pytestconfig=pytestconfig,
)
@pytest.mark.django_db
async def eval_create_experiment_with_existing_flag(pytestconfig, demo_org_team_user):
"""Test experiment creation with an existing feature flag."""
_, team, user = demo_org_team_user
# Create an existing flag with unique key and multivariate variants
unique_key = f"reusable-flag-{uuid.uuid4().hex[:8]}"
await FeatureFlag.objects.acreate(
team=team,
key=unique_key,
name="Reusable Flag",
created_by=user,
filters={
"groups": [{"properties": [], "rollout_percentage": 100}],
"multivariate": {
"variants": [
{"key": "control", "name": "Control", "rollout_percentage": 50},
{"key": "test", "name": "Test", "rollout_percentage": 50},
]
},
},
)
conversation = await Conversation.objects.acreate(team=team, user=user)
async def task_create_experiment_reuse_flag(args: dict):
tool = await CreateExperimentTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)
result_message, artifact = await tool._arun_impl(
name=args["name"],
feature_flag_key=args["feature_flag_key"],
)
return {
"message": result_message,
"experiment_created": artifact is not None and "experiment_id" in artifact,
}
await MaxPublicEval(
experiment_name="create_experiment_with_existing_flag",
task=task_create_experiment_reuse_flag, # type: ignore
scores=[ExperimentOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input={"name": "Reuse Flag Test", "feature_flag_key": unique_key},
expected={
"message": "Successfully created experiment",
"experiment_created": True,
},
),
],
pytestconfig=pytestconfig,
)
@pytest.mark.django_db
async def eval_create_experiment_duplicate_name_error(pytestconfig, demo_org_team_user):
"""Test that creating a duplicate experiment returns an error."""
_, team, user = demo_org_team_user
# Create an existing experiment with unique flag key
unique_flag_key = f"test-flag-{uuid.uuid4().hex[:8]}"
flag = await FeatureFlag.objects.acreate(team=team, key=unique_flag_key, created_by=user)
await Experiment.objects.acreate(team=team, name="Existing Experiment", feature_flag=flag, created_by=user)
conversation = await Conversation.objects.acreate(team=team, user=user)
async def task_create_duplicate_experiment(args: dict):
# Create a different flag for the duplicate attempt
await FeatureFlag.objects.acreate(
team=team,
created_by=user,
key=args["feature_flag_key"],
name=f"Flag for {args['name']}",
)
tool = await CreateExperimentTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)
result_message, artifact = await tool._arun_impl(
name=args["name"],
feature_flag_key=args["feature_flag_key"],
)
return {
"message": result_message,
"has_error": artifact.get("error") is not None if artifact else False,
}
await MaxPublicEval(
experiment_name="create_experiment_duplicate_name_error",
task=task_create_duplicate_experiment, # type: ignore
scores=[ExperimentOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input={"name": "Existing Experiment", "feature_flag_key": "another-flag"},
expected={
"message": "Failed to create experiment: An experiment with name 'Existing Experiment' already exists",
"has_error": True,
},
),
],
pytestconfig=pytestconfig,
)
@pytest.mark.django_db
async def eval_create_experiment_flag_already_used_error(pytestconfig, demo_org_team_user):
"""Test that using a flag already tied to another experiment returns an error."""
_, team, user = demo_org_team_user
# Create an experiment with a flag (unique key)
unique_flag_key = f"used-flag-{uuid.uuid4().hex[:8]}"
flag = await FeatureFlag.objects.acreate(team=team, key=unique_flag_key, created_by=user)
await Experiment.objects.acreate(team=team, name="First Experiment", feature_flag=flag, created_by=user)
conversation = await Conversation.objects.acreate(team=team, user=user)
async def task_create_experiment_with_used_flag(args: dict):
tool = await CreateExperimentTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)
result_message, artifact = await tool._arun_impl(
name=args["name"],
feature_flag_key=args["feature_flag_key"],
)
return {
"message": result_message,
"has_error": artifact.get("error") is not None if artifact else False,
}
await MaxPublicEval(
experiment_name="create_experiment_flag_already_used_error",
task=task_create_experiment_with_used_flag, # type: ignore
scores=[ExperimentOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input={"name": "Second Experiment", "feature_flag_key": unique_flag_key},
expected={
"message": "Failed to create experiment: Feature flag is already used by experiment",
"has_error": True,
},
),
],
pytestconfig=pytestconfig,
)

View File

@@ -365,3 +365,467 @@ async def eval_create_feature_flag_duplicate_handling(pytestconfig, demo_org_tea
],
pytestconfig=pytestconfig,
)
@pytest.mark.django_db
async def eval_create_multivariate_feature_flag(pytestconfig, demo_org_team_user):
"""Test multivariate feature flag creation for A/B testing."""
_, team, user = demo_org_team_user
conversation = await Conversation.objects.acreate(team=team, user=user)
# Generate unique keys for this test run
unique_suffix = uuid.uuid4().hex[:6]
async def task_create_multivariate_flag(instructions: str):
tool = await CreateFeatureFlagTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)
result_message, artifact = await tool._arun_impl(instructions=instructions)
flag_key = artifact.get("flag_key")
# Verify multivariate schema structure if flag was created
variant_count = 0
schema_valid = False
has_multivariate = False
if flag_key:
try:
flag = await FeatureFlag.objects.aget(team=team, key=flag_key)
multivariate = flag.filters.get("multivariate")
if multivariate:
has_multivariate = True
variants = multivariate.get("variants", [])
variant_count = len(variants)
# Verify each variant has the required schema structure
schema_valid = True
for variant in variants:
# Verify variant has key and rollout_percentage (name is optional)
if not all(key in variant for key in ["key", "rollout_percentage"]):
schema_valid = False
break
# Verify rollout_percentage is a number
if not isinstance(variant["rollout_percentage"], int | float):
schema_valid = False
break
except FeatureFlag.DoesNotExist:
pass
return {
"message": result_message,
"has_multivariate": has_multivariate,
"variant_count": variant_count,
"created": flag_key is not None,
"schema_valid": schema_valid,
}
await MaxPublicEval(
experiment_name="create_multivariate_feature_flag",
task=task_create_multivariate_flag, # type: ignore
scores=[FeatureFlagOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input=f"Create an A/B test flag called 'ab-test-{unique_suffix}' with control and test variants",
expected={
"message": f"Successfully created feature flag 'ab-test-{unique_suffix}' with A/B test",
"has_multivariate": True,
"variant_count": 2,
"created": True,
"schema_valid": True,
},
),
EvalCase(
input=f"Create a multivariate flag called 'abc-test-{unique_suffix}' with 3 variants for testing",
expected={
"message": f"Successfully created feature flag 'abc-test-{unique_suffix}' with multivariate",
"has_multivariate": True,
"variant_count": 3,
"created": True,
"schema_valid": True,
},
),
EvalCase(
input=f"Create an A/B test flag called 'pricing-test-{unique_suffix}' for testing new pricing",
expected={
"message": f"Successfully created feature flag 'pricing-test-{unique_suffix}' with A/B test",
"has_multivariate": True,
"variant_count": 2,
"created": True,
"schema_valid": True,
},
),
],
pytestconfig=pytestconfig,
)
@pytest.mark.django_db
async def eval_create_multivariate_with_rollout(pytestconfig, demo_org_team_user):
"""Test multivariate feature flags with rollout percentages for targeted experiments."""
_, team, user = demo_org_team_user
conversation = await Conversation.objects.acreate(team=team, user=user)
# Generate unique keys for this test run
unique_suffix = uuid.uuid4().hex[:6]
async def task_create_multivariate_with_rollout(instructions: str):
tool = await CreateFeatureFlagTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)
result_message, artifact = await tool._arun_impl(instructions=instructions)
flag_key = artifact.get("flag_key")
# Verify multivariate and rollout schema structure
has_multivariate = False
has_rollout = False
variant_count = 0
rollout_percentage = None
schema_valid = False
if flag_key:
try:
flag = await FeatureFlag.objects.aget(team=team, key=flag_key)
# Check multivariate config
multivariate = flag.filters.get("multivariate")
if multivariate:
has_multivariate = True
variants = multivariate.get("variants", [])
variant_count = len(variants)
# Check rollout in groups
groups = flag.filters.get("groups", [])
if groups and len(groups) > 0:
group = groups[0]
rollout_percentage = group.get("rollout_percentage")
has_rollout = rollout_percentage is not None
# Verify schema
schema_valid = has_multivariate and variant_count > 0
if has_rollout:
schema_valid = schema_valid and isinstance(rollout_percentage, int | float)
except FeatureFlag.DoesNotExist:
pass
return {
"message": result_message,
"has_multivariate": has_multivariate,
"has_rollout": has_rollout,
"variant_count": variant_count,
"created": flag_key is not None,
"schema_valid": schema_valid,
}
await MaxPublicEval(
experiment_name="create_multivariate_with_rollout",
task=task_create_multivariate_with_rollout, # type: ignore
scores=[FeatureFlagOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input=f"Create an A/B test flag called 'ab-rollout-{unique_suffix}' with control and test variants at 50% rollout",
expected={
"message": f"Successfully created feature flag 'ab-rollout-{unique_suffix}' with A/B test and 50% rollout",
"has_multivariate": True,
"has_rollout": True,
"variant_count": 2,
"created": True,
"schema_valid": True,
},
),
EvalCase(
input=f"Create a multivariate flag called 'experiment-{unique_suffix}' with 3 variants at 10% rollout",
expected={
"message": f"Successfully created feature flag 'experiment-{unique_suffix}' with multivariate and 10% rollout",
"has_multivariate": True,
"has_rollout": True,
"variant_count": 3,
"created": True,
"schema_valid": True,
},
),
],
pytestconfig=pytestconfig,
)
@pytest.mark.django_db
async def eval_create_multivariate_with_property_filters(pytestconfig, demo_org_team_user):
"""Test multivariate feature flags with property-based targeting for segment-specific experiments."""
_, team, user = demo_org_team_user
conversation = await Conversation.objects.acreate(team=team, user=user)
# Generate unique keys for this test run
unique_suffix = uuid.uuid4().hex[:6]
async def task_create_multivariate_with_properties(instructions: str):
tool = await CreateFeatureFlagTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)
result_message, artifact = await tool._arun_impl(instructions=instructions)
flag_key = artifact.get("flag_key")
# Verify multivariate and property filter schema structure
has_multivariate = False
has_properties = False
variant_count = 0
property_count = 0
schema_valid = False
if flag_key:
try:
flag = await FeatureFlag.objects.aget(team=team, key=flag_key)
# Check multivariate config
multivariate = flag.filters.get("multivariate")
if multivariate:
has_multivariate = True
variants = multivariate.get("variants", [])
variant_count = len(variants)
# Check properties in groups
groups = flag.filters.get("groups", [])
for group in groups:
properties = group.get("properties", [])
property_count += len(properties)
# Verify each property has required schema structure
for prop in properties:
if all(key in prop for key in ["key", "type", "value", "operator"]):
has_properties = True
else:
schema_valid = False
break
# Verify schema
schema_valid = has_multivariate and variant_count > 0 and has_properties
except FeatureFlag.DoesNotExist:
pass
return {
"message": result_message,
"has_multivariate": has_multivariate,
"has_properties": has_properties,
"variant_count": variant_count,
"created": flag_key is not None,
"schema_valid": schema_valid,
}
await MaxPublicEval(
experiment_name="create_multivariate_with_property_filters",
task=task_create_multivariate_with_properties, # type: ignore
scores=[FeatureFlagOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input=f"Create an A/B test flag called 'email-test-{unique_suffix}' for users where email contains @company.com with control and test variants",
expected={
"message": f"Successfully created feature flag 'email-test-{unique_suffix}' with A/B test for users where email contains @company.com",
"has_multivariate": True,
"has_properties": True,
"variant_count": 2,
"created": True,
"schema_valid": True,
},
),
EvalCase(
input=f"Create a multivariate flag called 'us-experiment-{unique_suffix}' with 3 variants targeting US users",
expected={
"message": f"Successfully created feature flag 'us-experiment-{unique_suffix}' with multivariate targeting US users",
"has_multivariate": True,
"has_properties": True,
"variant_count": 3,
"created": True,
"schema_valid": True,
},
),
],
pytestconfig=pytestconfig,
)
@pytest.mark.django_db
async def eval_create_multivariate_with_custom_percentages(pytestconfig, demo_org_team_user):
"""Test multivariate feature flags with custom variant percentage distributions."""
_, team, user = demo_org_team_user
conversation = await Conversation.objects.acreate(team=team, user=user)
# Generate unique keys for this test run
unique_suffix = uuid.uuid4().hex[:6]
async def task_create_multivariate_custom_percentages(instructions: str):
tool = await CreateFeatureFlagTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)
result_message, artifact = await tool._arun_impl(instructions=instructions)
flag_key = artifact.get("flag_key")
# Verify multivariate with custom percentages
has_multivariate = False
variant_count = 0
percentages_sum_to_100 = False
schema_valid = False
if flag_key:
try:
flag = await FeatureFlag.objects.aget(team=team, key=flag_key)
multivariate = flag.filters.get("multivariate")
if multivariate:
has_multivariate = True
variants = multivariate.get("variants", [])
variant_count = len(variants)
# Check if percentages sum to 100
total_percentage = sum(v.get("rollout_percentage", 0) for v in variants)
percentages_sum_to_100 = total_percentage == 100
# Verify schema
schema_valid = True
for variant in variants:
if not all(key in variant for key in ["key", "rollout_percentage"]):
schema_valid = False
break
except FeatureFlag.DoesNotExist:
pass
return {
"message": result_message,
"has_multivariate": has_multivariate,
"variant_count": variant_count,
"percentages_valid": percentages_sum_to_100,
"created": flag_key is not None,
"schema_valid": schema_valid and percentages_sum_to_100,
}
await MaxPublicEval(
experiment_name="create_multivariate_with_custom_percentages",
task=task_create_multivariate_custom_percentages, # type: ignore
scores=[FeatureFlagOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input=f"Create an A/B test flag called 'uneven-test-{unique_suffix}' with control at 70% and test at 30%",
expected={
"message": f"Successfully created feature flag 'uneven-test-{unique_suffix}' with A/B test",
"has_multivariate": True,
"variant_count": 2,
"percentages_valid": True,
"created": True,
"schema_valid": True,
},
),
EvalCase(
input=f"Create a multivariate flag called 'weighted-test-{unique_suffix}' with control (33%), variant_a (33%), variant_b (34%)",
expected={
"message": f"Successfully created feature flag 'weighted-test-{unique_suffix}' with multivariate",
"has_multivariate": True,
"variant_count": 3,
"percentages_valid": True,
"created": True,
"schema_valid": True,
},
),
],
pytestconfig=pytestconfig,
)
@pytest.mark.django_db
async def eval_create_multivariate_error_handling(pytestconfig, demo_org_team_user):
"""Test multivariate feature flag error handling for invalid configurations."""
_, team, user = demo_org_team_user
conversation = await Conversation.objects.acreate(team=team, user=user)
# Generate unique keys for this test run
unique_suffix = uuid.uuid4().hex[:6]
async def task_create_invalid_multivariate(instructions: str):
tool = await CreateFeatureFlagTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)
result_message, artifact = await tool._arun_impl(instructions=instructions)
# Check if error was properly reported
has_error = (
"error" in artifact or "invalid" in result_message.lower() or "must sum to 100" in result_message.lower()
)
return {
"message": result_message,
"has_error": has_error,
}
await MaxPublicEval(
experiment_name="create_multivariate_error_handling",
task=task_create_invalid_multivariate, # type: ignore
scores=[FeatureFlagOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input=f"Create an A/B test flag called 'invalid-percentage-{unique_suffix}' with control at 60% and test at 50%",
expected={
"message": "The variant percentages you provided (control: 60%, test: 50%) sum to 110%, but they must sum to exactly 100%. Please adjust the percentages so they add up to 100.",
"has_error": True,
},
),
],
pytestconfig=pytestconfig,
)

View File

@@ -1,147 +0,0 @@
import pytest
from braintrust import EvalCase
from pydantic import BaseModel, Field
from posthog.schema import AssistantMessage, AssistantNavigateUrl, AssistantToolCall, FailureMessage, HumanMessage
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.graph.graph import AssistantGraph
from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState
from ee.models.assistant import Conversation
from ...base import MaxPublicEval
from ...scorers import ToolRelevance
TOOLS_PROMPT = """
- **actions**: Combine several related events into one, which you can then analyze in insights and dashboards as if it were a single event.
- **cohorts**: A catalog of identified persons and your created cohorts.
- **dashboards**: Create and manage your dashboards
- **earlyAccessFeatures**: Allow your users to individually enable or disable features that are in public beta.
- **errorTracking**: Track and analyze your error tracking data to understand and fix issues. [tools: Filter issues, Find impactful issues]
- **experiments**: Experiments help you test changes to your product to see which changes will lead to optimal results. Automatic statistical calculations let you see if the results are valid or if they are likely just a chance occurrence.
- **featureFlags**: Use feature flags to safely deploy and roll back new features in an easy-to-manage way. Roll variants out to certain groups, a percentage of users, or everyone all at once.
- **notebooks**: Notebooks are a way to organize your work and share it with others.
- **persons**: A catalog of all the people behind your events
- **insights**: Track, analyze, and experiment with user behavior.
- **insightNew** [tools: Edit the insight]
- **savedInsights**: Track, analyze, and experiment with user behavior.
- **alerts**: Track, analyze, and experiment with user behavior.
- **replay**: Replay recordings of user sessions to understand how users interact with your product or website. [tools: Search recordings]
- **revenueAnalytics**: Track and analyze your revenue metrics to understand your business performance and growth. [tools: Filter revenue analytics]
- **surveys**: Create surveys to collect feedback from your users [tools: Create surveys, Analyze survey responses]
- **webAnalytics**: Analyze your web analytics data to understand website performance and user behavior.
- **webAnalyticsWebVitals**: Analyze your web analytics data to understand website performance and user behavior.
- **activity**: A catalog of all user interactions with your app or website.
- **sqlEditor**: Write and execute SQL queries against your data warehouse [tools: Write and tweak SQL]
- **heatmaps**: Heatmaps are a way to visualize user behavior on your website.
""".strip()
class EvalInput(BaseModel):
messages: str | list[AssistantMessageUnion]
current_page: str = Field(default="")
@pytest.fixture
def call_root(demo_org_team_user):
graph = (
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root(lambda state: AssistantNodeName.END)
# TRICKY: We need to set a checkpointer here because async tests create a new event loop.
.compile(checkpointer=DjangoCheckpointer())
)
async def callable(input: EvalInput) -> AssistantMessage:
conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])
initial_state = AssistantState(
messages=[HumanMessage(content=input.messages)] if isinstance(input.messages, str) else input.messages
)
raw_state = await graph.ainvoke(
initial_state,
{
"configurable": {
"thread_id": conversation.id,
"contextual_tools": {
"navigate": {"scene_descriptions": TOOLS_PROMPT, "current_page": input.current_page}
},
}
},
)
state = AssistantState.model_validate(raw_state)
assert isinstance(state.messages[-1], AssistantMessage)
return state.messages[-1]
return callable
@pytest.mark.django_db
async def eval_root_navigate_tool(call_root, pytestconfig):
await MaxPublicEval(
experiment_name="root_navigate_tool",
task=call_root,
scores=[ToolRelevance(semantic_similarity_args={"query_description"})],
data=[
# Shouldn't navigate to the insights page
EvalCase(
input=EvalInput(messages="build pageview insight"),
expected=AssistantToolCall(
id="1",
name="create_and_query_insight",
args={
"query_description": "Create a trends insight showing pageview events over time. Track the $pageview event to visualize how many pageviews are happening."
},
),
),
# Should navigate to the persons page
EvalCase(
input=EvalInput(
messages="I added tracking of persons, but I can't find where the persons are in the app"
),
expected=AssistantToolCall(
id="1",
name="navigate",
args={"page_key": AssistantNavigateUrl.PERSONS.value},
),
),
# Should navigate to the surveys page
EvalCase(
input=EvalInput(
messages="were is my survey. I jsut created a survey and save it as draft, I cannot find it now",
current_page="/project/1/surveys/new",
),
expected=AssistantToolCall(
id="1",
name="navigate",
args={"page_key": AssistantNavigateUrl.SURVEYS.value},
),
),
# Should not navigate to the SQL editor
EvalCase(
input=EvalInput(
messages="I need a query written in SQL to tell me what all of my identified events are for any given day."
),
expected=AssistantToolCall(
id="1",
name="create_and_query_insight",
args={"query_description": "All identified events for any given day"},
),
),
# Should just say that the query failed
EvalCase(
input=EvalInput(
messages=[
HumanMessage(
content="I need a query written in SQL to tell me what all of my identified events are for any given day."
),
FailureMessage(
content="An unknown failure occurred while accessing the `events` table",
),
]
),
expected=None,
),
],
pytestconfig=pytestconfig,
)

View File

@@ -7,10 +7,12 @@ from uuid import uuid4
from langchain_anthropic import ChatAnthropic
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage as LangchainAIMessage,
BaseMessage,
HumanMessage as LangchainHumanMessage,
)
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import BaseModel
from posthog.schema import AssistantMessage, AssistantToolCallMessage, ContextMessage, HumanMessage
@@ -36,7 +38,7 @@ class ConversationCompactionManager(ABC):
Manages conversation window boundaries, message filtering, and summarization decisions.
"""
CONVERSATION_WINDOW_SIZE = 64000
CONVERSATION_WINDOW_SIZE = 100_000
"""
Determines the maximum number of tokens allowed in the conversation window.
"""
@@ -54,7 +56,7 @@ class ConversationCompactionManager(ABC):
new_window_id: str | None = None
for message in reversed(messages):
# Handle limits before assigning the window ID.
max_tokens -= self._get_estimated_tokens(message)
max_tokens -= self._get_estimated_assistant_message_tokens(message)
max_messages -= 1
if max_tokens < 0 or max_messages < 0:
break
@@ -83,12 +85,20 @@ class ConversationCompactionManager(ABC):
Determine if the conversation should be summarized based on token count.
Avoids summarizing if there are only two human messages or fewer.
"""
return await self.calculate_token_count(model, messages, tools, **kwargs) > self.CONVERSATION_WINDOW_SIZE
async def calculate_token_count(
self, model: BaseChatModel, messages: list[BaseMessage], tools: LangchainTools | None = None, **kwargs
) -> int:
"""
Calculate the token count for a conversation.
"""
# Avoid summarizing the conversation if there are only two human messages or fewer.
human_messages = [message for message in messages if isinstance(message, LangchainHumanMessage)]
if len(human_messages) <= 2:
return False
token_count = await self._get_token_count(model, messages, tools, **kwargs)
return token_count > self.CONVERSATION_WINDOW_SIZE
tool_tokens = self._get_estimated_tools_tokens(tools) if tools else 0
return sum(self._get_estimated_langchain_message_tokens(message) for message in messages) + tool_tokens
return await self._get_token_count(model, messages, tools, **kwargs)
def update_window(
self, messages: Sequence[T], summary_message: ContextMessage, start_id: str | None = None
@@ -134,7 +144,7 @@ class ConversationCompactionManager(ABC):
updated_window_start_id=window_start_id_candidate,
)
def _get_estimated_tokens(self, message: AssistantMessageUnion) -> int:
def _get_estimated_assistant_message_tokens(self, message: AssistantMessageUnion) -> int:
"""
Estimate token count for a message using character/4 heuristic.
"""
@@ -149,6 +159,24 @@ class ConversationCompactionManager(ABC):
char_count = len(message.content)
return round(char_count / self.APPROXIMATE_TOKEN_LENGTH)
def _get_estimated_langchain_message_tokens(self, message: BaseMessage) -> int:
"""
Estimate token count for a message using character/4 heuristic.
"""
char_count = 0
if isinstance(message.content, str):
char_count = len(message.content)
else:
for content in message.content:
if isinstance(content, str):
char_count += len(content)
elif isinstance(content, dict):
char_count += self._count_json_tokens(content)
if isinstance(message, LangchainAIMessage) and message.tool_calls:
for tool_call in message.tool_calls:
char_count += len(json.dumps(tool_call, separators=(",", ":")))
return round(char_count / self.APPROXIMATE_TOKEN_LENGTH)
def _get_conversation_window(self, messages: Sequence[T], start_id: str) -> Sequence[T]:
"""
Get messages from the start_id onwards.
@@ -158,6 +186,22 @@ class ConversationCompactionManager(ABC):
return messages[idx:]
return messages
def _get_estimated_tools_tokens(self, tools: LangchainTools) -> int:
"""
Estimate token count for tools by converting them to JSON schemas.
"""
if not tools:
return 0
total_chars = 0
for tool in tools:
tool_schema = convert_to_openai_tool(tool)
total_chars += self._count_json_tokens(tool_schema)
return round(total_chars / self.APPROXIMATE_TOKEN_LENGTH)
def _count_json_tokens(self, json_data: dict) -> int:
return len(json.dumps(json_data, separators=(",", ":")))
@abstractmethod
async def _get_token_count(
self,

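The compaction manager above estimates tokens with a characters-divided-by-four heuristic, counting string content, structured content blocks serialized as compact JSON, tool-call arguments, and the JSON schemas of the bound tools. A minimal standalone sketch of the same idea, assuming plain message dicts rather than the LangChain classes used above (function names here are illustrative):

import json

APPROXIMATE_TOKEN_LENGTH = 4  # rough characters-per-token ratio

def estimate_message_tokens(message: dict) -> int:
    """Estimate tokens for one message: plain text, content blocks, and tool calls."""
    chars = 0
    content = message.get("content", "")
    if isinstance(content, str):
        chars += len(content)
    else:  # list of strings and/or dict content blocks
        for block in content:
            chars += len(block) if isinstance(block, str) else len(json.dumps(block, separators=(",", ":")))
    for tool_call in message.get("tool_calls", []):
        chars += len(json.dumps(tool_call, separators=(",", ":")))
    return round(chars / APPROXIMATE_TOKEN_LENGTH)

def estimate_conversation_tokens(messages: list[dict], tool_schemas: list[dict] | None = None) -> int:
    """Sum message estimates plus the JSON size of the tool schemas sent with every request."""
    total = sum(estimate_message_tokens(m) for m in messages)
    for schema in tool_schemas or []:
        total += round(len(json.dumps(schema, separators=(",", ":"))) / APPROXIMATE_TOKEN_LENGTH)
    return total

# Example: 1,000 characters of content ≈ 250 tokens under this heuristic.
assert estimate_message_tokens({"content": "A" * 1000}) == 250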
View File

@@ -13,7 +13,6 @@ from langchain_core.messages import (
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langgraph.errors import NodeInterrupt
from langgraph.types import Send
from posthoganalytics import capture_exception
@@ -27,6 +26,7 @@ from ee.hogai.graph.conversation_summarizer.nodes import AnthropicConversationSu
from ee.hogai.graph.shared_prompts import CORE_MEMORY_PROMPT
from ee.hogai.llm import MaxChatAnthropic
from ee.hogai.tool import ToolMessagesArtifact
from ee.hogai.tool_errors import MaxToolError
from ee.hogai.tools import ReadDataTool, ReadTaxonomyTool, SearchTool, TodoWriteTool
from ee.hogai.utils.anthropic import add_cache_control, convert_to_anthropic_messages
from ee.hogai.utils.helpers import convert_tool_messages_to_dict, normalize_ai_message
@@ -42,9 +42,9 @@ from ee.hogai.utils.types.base import NodePath
from .compaction_manager import AnthropicConversationCompactionManager
from .prompts import (
AGENT_CORE_MEMORY_PROMPT,
AGENT_PROMPT,
BASIC_FUNCTIONALITY_PROMPT,
CORE_MEMORY_INSTRUCTIONS_PROMPT,
DOING_TASKS_PROMPT,
PROACTIVENESS_PROMPT,
ROLE_PROMPT,
@@ -187,12 +187,18 @@ class AgentExecutable(BaseAgentExecutable):
start_id = state.start_id
# Summarize the conversation if it's too long.
if await self._window_manager.should_compact_conversation(
current_token_count = await self._window_manager.calculate_token_count(
model, langchain_messages, tools=tools, thinking_config=self.THINKING_CONFIG
):
)
if current_token_count > self._window_manager.CONVERSATION_WINDOW_SIZE:
# Exclude the last message if it's the first turn.
messages_to_summarize = langchain_messages[:-1] if self._is_first_turn(state) else langchain_messages
summary = await AnthropicConversationSummarizer(self._team, self._user).summarize(messages_to_summarize)
summary = await AnthropicConversationSummarizer(
self._team,
self._user,
extend_context_window=current_token_count > 195_000,
).summarize(messages_to_summarize)
summary_message = ContextMessage(
content=ROOT_CONVERSATION_SUMMARY_PROMPT.format(summary=summary),
id=str(uuid4()),
@@ -212,6 +218,7 @@ class AgentExecutable(BaseAgentExecutable):
system_prompts = ChatPromptTemplate.from_messages(
[
("system", self._get_system_prompt(state, config)),
("system", AGENT_CORE_MEMORY_PROMPT),
],
template_format="mustache",
).format_messages(
@@ -221,7 +228,7 @@ class AgentExecutable(BaseAgentExecutable):
)
# Mark the longest default prefix as cacheable
add_cache_control(system_prompts[-1])
add_cache_control(system_prompts[0], ttl="1h")
message = await model.ainvoke(system_prompts + langchain_messages, config)
assistant_message = self._process_output_message(message)
@@ -263,7 +270,6 @@ class AgentExecutable(BaseAgentExecutable):
- `{{{task_management}}}`
- `{{{doing_tasks}}}`
- `{{{tool_usage_policy}}}`
- `{{{core_memory_instructions}}}`
The variables above can contain the following nested variables, which will be injected:
- `{{{groups}}}` a prompt containing the description of the groups.
@@ -291,7 +297,6 @@ class AgentExecutable(BaseAgentExecutable):
task_management=TASK_MANAGEMENT_PROMPT,
doing_tasks=DOING_TASKS_PROMPT,
tool_usage_policy=TOOL_USAGE_POLICY_PROMPT,
core_memory_instructions=CORE_MEMORY_INSTRUCTIONS_PROMPT,
)
async def _get_billing_prompt(self) -> str:
@@ -319,10 +324,11 @@ class AgentExecutable(BaseAgentExecutable):
stream_usage=True,
user=self._user,
team=self._team,
betas=["interleaved-thinking-2025-05-14"],
betas=["interleaved-thinking-2025-05-14", "context-1m-2025-08-07"],
max_tokens=8192,
thinking=self.THINKING_CONFIG,
conversation_start_dt=state.start_dt,
billable=True,
)
# The agent can operate in loops. Since insight building is an expensive operation, we want to limit the recursion depth.
@@ -465,6 +471,30 @@ class AgentToolsExecutable(BaseAgentExecutable):
raise ValueError(
f"Tool '{tool_call.name}' returned {type(result).__name__}, expected LangchainToolMessage"
)
except MaxToolError as e:
logger.exception(
"maxtool_error", extra={"tool": tool_call.name, "error": str(e), "retry_strategy": e.retry_strategy}
)
capture_exception(
e,
distinct_id=self._get_user_distinct_id(config),
properties={
**self._get_debug_props(config),
"tool": tool_call.name,
"retry_strategy": e.retry_strategy,
},
)
content = f"Tool failed: {e.to_summary()}.{e.retry_hint}"
return PartialAssistantState(
messages=[
AssistantToolCallMessage(
content=content,
id=str(uuid4()),
tool_call_id=tool_call.id,
)
],
)
except Exception as e:
logger.exception("Error calling tool", extra={"tool_name": tool_call.name, "error": str(e)})
capture_exception(
@@ -485,21 +515,6 @@ class AgentToolsExecutable(BaseAgentExecutable):
messages=result.artifact.messages,
)
# If this is a navigation tool call, pause the graph execution
# so that the frontend can re-initialise Max with a new set of contextual tools.
if tool_call.name == "navigate":
navigate_message = AssistantToolCallMessage(
content=str(result.content) if result.content else "",
ui_payload={tool_call.name: result.artifact},
id=str(uuid4()),
tool_call_id=tool_call.id,
)
# Raising a `NodeInterrupt` ensures the assistant graph stops here and
# surfaces the navigation confirmation to the client. The next user
# interaction will resume the graph with potentially different
# contextual tools.
raise NodeInterrupt(navigate_message)
tool_message = AssistantToolCallMessage(
content=str(result.content) if result.content else "",
ui_payload={tool_call.name: result.artifact},

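The new except MaxToolError branch above converts typed tool failures into a tool message the agent can act on, using each error's retry strategy and retry hint rather than crashing the graph. A minimal sketch of that pattern — the class names, retry_strategy values, and to_summary/retry_hint attributes mirror the diff and its tests, but the hint wording and the dict-shaped message below are simplifications of the real ee.hogai.tool_errors module and AssistantToolCallMessage:

class MaxToolError(Exception):
    """Base class for tool failures the agent should see as a tool message."""
    retry_strategy = "never"
    retry_hint = ""

    def to_summary(self) -> str:
        return str(self)

class MaxToolFatalError(MaxToolError):
    retry_strategy = "never"  # do not retry at all

class MaxToolRetryableError(MaxToolError):
    retry_strategy = "adjusted"
    retry_hint = " You can retry with adjusted inputs."

class MaxToolTransientError(MaxToolError):
    retry_strategy = "once"
    retry_hint = " You can retry this operation once without changes."

def tool_error_to_message(error: MaxToolError, tool_call_id: str) -> dict:
    """Mirror the except branch above: surface the failure to the agent instead of raising."""
    return {
        "role": "tool",
        "tool_call_id": tool_call_id,
        "content": f"Tool failed: {error.to_summary()}.{error.retry_hint}",
    }

# Example: a rate-limited tool produces a message that invites exactly one retry.
print(tool_error_to_message(MaxToolTransientError("Rate limit exceeded"), "tool-123"))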
View File

@@ -27,6 +27,7 @@ Do not create links like "here" or "click here". All links should have relevant
We always use sentence case rather than title case, including in titles, headings, subheadings, or bold text. However if quoting provided text, we keep the original case.
When writing numbers in the thousands to the billions, it's acceptable to abbreviate them (like 10M or 100B - capital letter, no space). If you write out the full number, use commas (like 15,000,000).
You can use light Markdown formatting for readability. Never use the em-dash (—) if you can use the en-dash (–).
For headers, use sentence case rather than title case.
</writing_style>
""".strip()
@@ -128,14 +129,10 @@ The user is a product engineer and will primarily request you perform product ma
TOOL_USAGE_POLICY_PROMPT = """
<tool_usage_policy>
- You can invoke multiple tools within a single response. When a request involves several independent pieces of information, batch your tool calls together for optimal performance
- Retry failed tool calls only if the error proposes retrying, or suggests how to fix tool arguments
</tool_usage_policy>
""".strip()
CORE_MEMORY_INSTRUCTIONS_PROMPT = """
{{{core_memory}}}
New memories will automatically be added to the core memory as the conversation progresses. If users ask to save, update, or delete the core memory, say you have done it. If the '/remember [information]' command is used, the information gets appended verbatim to core memory.
""".strip()
AGENT_PROMPT = """
{{{role}}}
@@ -154,8 +151,11 @@ AGENT_PROMPT = """
{{{tool_usage_policy}}}
{{{billing_context}}}
""".strip()
{{{core_memory_instructions}}}
AGENT_CORE_MEMORY_PROMPT = """
{{{core_memory}}}
New memories will automatically be added to the core memory as the conversation progresses. If users ask to save, update, or delete the core memory, say you have done it. If the '/remember [information]' command is used, the information gets appended verbatim to core memory.
""".strip()
# Conditional prompts

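Splitting core memory into its own system message (AGENT_CORE_MEMORY_PROMPT above) lets the stable role/instructions prefix be cached separately from the frequently changing memory block; the agent node diff marks that prefix with add_cache_control(..., ttl="1h"). A rough sketch of what such an annotation helper could look like, assuming Anthropic-style content blocks carrying a cache_control field — the real add_cache_control in ee.hogai.utils.anthropic may differ:

def add_cache_control(message: dict, ttl: str | None = None) -> dict:
    """Mark the last content block of a message as a prompt-cache breakpoint."""
    content = message["content"]
    if isinstance(content, str):
        # Normalize plain strings into a single text block so it can carry metadata.
        content = [{"type": "text", "text": content}]
        message["content"] = content
    cache_control: dict = {"type": "ephemeral"}
    if ttl:
        cache_control["ttl"] = ttl  # e.g. "1h" with the extended cache TTL beta
    content[-1]["cache_control"] = cache_control
    return message

# Example: cache the static system prefix for an hour; dynamic suffixes stay uncached.
system = {"role": "system", "content": "You are Max, PostHog's AI assistant..."}
add_cache_control(system, ttl="1h")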
View File

@@ -114,11 +114,11 @@ class TestAnthropicConversationCompactionManager(BaseTest):
@parameterized.expand(
[
# (num_human_messages, token_count, should_compact)
[1, 70000, False], # Only 1 human message
[2, 70000, False], # Only 2 human messages
[3, 50000, False], # 3 human messages but under token limit
[3, 70000, True], # 3 human messages and over token limit
[5, 70000, True], # Many messages over limit
[1, 90000, False], # Only 1 human message, under limit
[2, 90000, False], # Only 2 human messages, under limit
[3, 80000, False], # 3 human messages but under token limit
[3, 110000, True], # 3 human messages and over token limit
[5, 110000, True], # Many messages over limit
]
)
async def test_should_compact_conversation(self, num_human_messages, token_count, should_compact):
@@ -136,19 +136,57 @@ class TestAnthropicConversationCompactionManager(BaseTest):
result = await self.window_manager.should_compact_conversation(mock_model, messages)
self.assertEqual(result, should_compact)
def test_get_estimated_tokens_human_message(self):
async def test_should_compact_conversation_with_tools_under_limit(self):
"""Test that tools are accounted for when estimating tokens with 2 or fewer human messages"""
from langchain_core.tools import tool
@tool
def test_tool(query: str) -> str:
"""A test tool"""
return f"Result for {query}"
messages: list[BaseMessage] = [
LangchainHumanMessage(content="A" * 1000), # ~250 tokens
LangchainAIMessage(content="B" * 1000), # ~250 tokens
]
tools = [test_tool]
mock_model = MagicMock()
# With 2 human messages, should use estimation and not call _get_token_count
result = await self.window_manager.should_compact_conversation(mock_model, messages, tools=tools)
# Total should be well under 100k limit
self.assertFalse(result)
async def test_should_compact_conversation_with_tools_over_limit(self):
"""Test that tools push estimation over limit with 2 or fewer human messages"""
messages: list[BaseMessage] = [
LangchainHumanMessage(content="A" * 200000), # ~50k tokens
LangchainAIMessage(content="B" * 200000), # ~50k tokens
]
# Create large tool schemas to push over 100k limit
tools = [{"type": "function", "function": {"name": f"tool_{i}", "description": "X" * 1000}} for i in range(100)]
mock_model = MagicMock()
result = await self.window_manager.should_compact_conversation(mock_model, messages, tools=tools)
# Should be over the 100k limit
self.assertTrue(result)
def test_get_estimated_assistant_message_tokens_human_message(self):
"""Test token estimation for human messages"""
message = HumanMessage(content="A" * 100, id="1") # 100 chars = ~25 tokens
tokens = self.window_manager._get_estimated_tokens(message)
tokens = self.window_manager._get_estimated_assistant_message_tokens(message)
self.assertEqual(tokens, 25)
def test_get_estimated_tokens_assistant_message(self):
def test_get_estimated_assistant_message_tokens_assistant_message(self):
"""Test token estimation for assistant messages without tool calls"""
message = AssistantMessage(content="A" * 100, id="1") # 100 chars = ~25 tokens
tokens = self.window_manager._get_estimated_tokens(message)
tokens = self.window_manager._get_estimated_assistant_message_tokens(message)
self.assertEqual(tokens, 25)
def test_get_estimated_tokens_assistant_message_with_tool_calls(self):
def test_get_estimated_assistant_message_tokens_assistant_message_with_tool_calls(self):
"""Test token estimation for assistant messages with tool calls"""
message = AssistantMessage(
content="A" * 100, # 100 chars
@@ -162,17 +200,117 @@ class TestAnthropicConversationCompactionManager(BaseTest):
],
)
# Should count content + JSON serialized args
tokens = self.window_manager._get_estimated_tokens(message)
tokens = self.window_manager._get_estimated_assistant_message_tokens(message)
# 100 chars content + ~15 chars for args = ~29 tokens
self.assertGreater(tokens, 25)
self.assertLess(tokens, 35)
def test_get_estimated_tokens_tool_call_message(self):
def test_get_estimated_assistant_message_tokens_tool_call_message(self):
"""Test token estimation for tool call messages"""
message = AssistantToolCallMessage(content="A" * 200, id="1", tool_call_id="t1")
tokens = self.window_manager._get_estimated_tokens(message)
tokens = self.window_manager._get_estimated_assistant_message_tokens(message)
self.assertEqual(tokens, 50)
def test_get_estimated_langchain_message_tokens_string_content(self):
"""Test token estimation for langchain messages with string content"""
message = LangchainHumanMessage(content="A" * 100)
tokens = self.window_manager._get_estimated_langchain_message_tokens(message)
self.assertEqual(tokens, 25)
def test_get_estimated_langchain_message_tokens_list_content_with_strings(self):
"""Test token estimation for langchain messages with list of string content"""
message = LangchainHumanMessage(content=["A" * 100, "B" * 100])
tokens = self.window_manager._get_estimated_langchain_message_tokens(message)
self.assertEqual(tokens, 50)
def test_get_estimated_langchain_message_tokens_list_content_with_dicts(self):
"""Test token estimation for langchain messages with dict content"""
message = LangchainHumanMessage(content=[{"type": "text", "text": "A" * 100}])
tokens = self.window_manager._get_estimated_langchain_message_tokens(message)
# 100 chars for text + overhead for JSON structure
self.assertGreater(tokens, 25)
self.assertLess(tokens, 40)
def test_get_estimated_langchain_message_tokens_ai_message_with_tool_calls(self):
"""Test token estimation for AI messages with tool calls"""
message = LangchainAIMessage(
content="A" * 100,
tool_calls=[
{"id": "t1", "name": "test_tool", "args": {"key": "value"}},
{"id": "t2", "name": "another_tool", "args": {"foo": "bar"}},
],
)
tokens = self.window_manager._get_estimated_langchain_message_tokens(message)
# Content + tool calls JSON
self.assertGreater(tokens, 25)
self.assertLess(tokens, 70)
def test_get_estimated_langchain_message_tokens_ai_message_without_tool_calls(self):
"""Test token estimation for AI messages without tool calls"""
message = LangchainAIMessage(content="A" * 100)
tokens = self.window_manager._get_estimated_langchain_message_tokens(message)
self.assertEqual(tokens, 25)
def test_count_json_tokens(self):
"""Test JSON token counting helper"""
json_data = {"key": "value", "nested": {"foo": "bar"}}
char_count = self.window_manager._count_json_tokens(json_data)
# Should match length of compact JSON
import json
expected = len(json.dumps(json_data, separators=(",", ":")))
self.assertEqual(char_count, expected)
def test_get_estimated_tools_tokens_empty(self):
"""Test tool token estimation with no tools"""
tokens = self.window_manager._get_estimated_tools_tokens([])
self.assertEqual(tokens, 0)
def test_get_estimated_tools_tokens_with_dict_tools(self):
"""Test tool token estimation with dict tools"""
tools = [
{"type": "function", "function": {"name": "test_tool", "description": "A test tool"}},
]
tokens = self.window_manager._get_estimated_tools_tokens(tools)
# Should be positive and reasonable
self.assertGreater(tokens, 0)
self.assertLess(tokens, 100)
def test_get_estimated_tools_tokens_with_base_tool(self):
"""Test tool token estimation with BaseTool"""
from langchain_core.tools import tool
@tool
def sample_tool(query: str) -> str:
"""A sample tool for testing"""
return f"Result for {query}"
tools = [sample_tool]
tokens = self.window_manager._get_estimated_tools_tokens(tools)
# Should count the tool schema
self.assertGreater(tokens, 0)
self.assertLess(tokens, 200)
def test_get_estimated_tools_tokens_multiple_tools(self):
"""Test tool token estimation with multiple tools"""
from langchain_core.tools import tool
@tool
def tool1(x: int) -> int:
"""First tool"""
return x * 2
@tool
def tool2(y: str) -> str:
"""Second tool"""
return y.upper()
tools = [tool1, tool2]
tokens = self.window_manager._get_estimated_tools_tokens(tools)
# Should count both tool schemas
self.assertGreater(tokens, 0)
self.assertLess(tokens, 400)
async def test_get_token_count_calls_model(self):
"""Test that _get_token_count properly calls the model's token counting"""
mock_model = MagicMock()

View File

@@ -13,7 +13,6 @@ from langchain_core.messages import (
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import RunnableConfig
from langgraph.errors import NodeInterrupt
from parameterized import parameterized
from posthog.schema import (
@@ -34,6 +33,7 @@ from posthog.models.organization import OrganizationMembership
from products.replay.backend.max_tools import SearchSessionRecordingsTool
from ee.hogai.context import AssistantContextManager
from ee.hogai.tool_errors import MaxToolError, MaxToolFatalError, MaxToolRetryableError, MaxToolTransientError
from ee.hogai.tools.read_taxonomy import ReadEvents
from ee.hogai.utils.tests import FakeChatAnthropic, FakeChatOpenAI
from ee.hogai.utils.types import AssistantState, PartialAssistantState
@@ -445,6 +445,46 @@ class TestAgentNode(ClickhouseTestMixin, BaseTest):
self.assertIn("You are currently in project ", system_content)
self.assertIn("The user's name appears to be ", system_content)
async def test_node_includes_core_memory_in_system_prompt(self):
"""Test that core memory content is appended to the conversation in system prompts"""
with (
patch("os.environ", {"ANTHROPIC_API_KEY": "foo"}),
patch("langchain_anthropic.chat_models.ChatAnthropic._agenerate") as mock_generate,
patch("ee.hogai.graph.agent_modes.nodes.AgentExecutable._aget_core_memory_text") as mock_core_memory,
):
mock_core_memory.return_value = "User prefers concise responses and technical details"
mock_generate.return_value = ChatResult(
generations=[ChatGeneration(message=AIMessage(content="Response"))],
llm_output={},
)
node = _create_agent_node(self.team, self.user)
config = RunnableConfig(configurable={})
node._config = config
await node.arun(AssistantState(messages=[HumanMessage(content="Test")]), config)
# Verify _agenerate was called
mock_generate.assert_called_once()
# Get the messages passed to _agenerate
call_args = mock_generate.call_args
messages = call_args[0][0]
# Check system messages contain core memory
system_messages = [msg for msg in messages if isinstance(msg, SystemMessage)]
self.assertGreater(len(system_messages), 0)
content_parts = []
for msg in system_messages:
if isinstance(msg.content, str):
content_parts.append(msg.content)
else:
content_parts.append(str(msg.content))
system_content = "\n\n".join(content_parts)
self.assertIn("User prefers concise responses and technical details", system_content)
@parameterized.expand(
[
# (membership_level, add_context, expected_prompt)
@@ -480,13 +520,12 @@ class TestAgentNode(ClickhouseTestMixin, BaseTest):
self.assertEqual(await node._get_billing_prompt(), expected_prompt)
@patch("ee.hogai.graph.agent_modes.nodes.AgentExecutable._get_model", return_value=FakeChatOpenAI(responses=[]))
@patch(
"ee.hogai.graph.agent_modes.compaction_manager.AnthropicConversationCompactionManager.should_compact_conversation"
)
@patch("ee.hogai.graph.agent_modes.compaction_manager.AnthropicConversationCompactionManager.calculate_token_count")
@patch("ee.hogai.graph.conversation_summarizer.nodes.AnthropicConversationSummarizer.summarize")
async def test_conversation_summarization_flow(self, mock_summarize, mock_should_compact, mock_model):
async def test_conversation_summarization_flow(self, mock_summarize, mock_calculate_tokens, mock_model):
"""Test that conversation is summarized when it gets too long"""
mock_should_compact.return_value = True
# Return a token count higher than CONVERSATION_WINDOW_SIZE (100,000)
mock_calculate_tokens.return_value = 150_000
mock_summarize.return_value = "This is a summary of the conversation so far."
mock_model_instance = FakeChatOpenAI(responses=[LangchainAIMessage(content="Response after summary")])
@@ -514,13 +553,12 @@ class TestAgentNode(ClickhouseTestMixin, BaseTest):
self.assertIn("This is a summary of the conversation so far.", context_messages[0].content)
@patch("ee.hogai.graph.agent_modes.nodes.AgentExecutable._get_model", return_value=FakeChatOpenAI(responses=[]))
@patch(
"ee.hogai.graph.agent_modes.compaction_manager.AnthropicConversationCompactionManager.should_compact_conversation"
)
@patch("ee.hogai.graph.agent_modes.compaction_manager.AnthropicConversationCompactionManager.calculate_token_count")
@patch("ee.hogai.graph.conversation_summarizer.nodes.AnthropicConversationSummarizer.summarize")
async def test_conversation_summarization_on_first_turn(self, mock_summarize, mock_should_compact, mock_model):
async def test_conversation_summarization_on_first_turn(self, mock_summarize, mock_calculate_tokens, mock_model):
"""Test that on first turn, the last message is excluded from summarization"""
mock_should_compact.return_value = True
# Return a token count higher than CONVERSATION_WINDOW_SIZE (100,000)
mock_calculate_tokens.return_value = 150_000
mock_summarize.return_value = "Summary without last message"
mock_model_instance = FakeChatOpenAI(responses=[LangchainAIMessage(content="Response")])
@@ -871,40 +909,6 @@ class TestAgentToolsNode(BaseTest):
self.assertEqual(len(result.messages), 1)
self.assertIsInstance(result.messages[0], AssistantToolCallMessage)
async def test_navigate_tool_call_raises_node_interrupt(self):
"""Test that navigate tool calls raise NodeInterrupt to pause graph execution"""
node = _create_agent_tools_node(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="I'll help you navigate to insights",
id="test-id",
tool_calls=[AssistantToolCall(id="nav-123", name="navigate", args={"page_key": "insights"})],
)
],
root_tool_call_id="nav-123",
)
mock_navigate_tool = AsyncMock()
mock_navigate_tool.ainvoke.return_value = LangchainToolMessage(
content="XXX", tool_call_id="nav-123", artifact={"page_key": "insights"}
)
# The navigate tool call should raise NodeInterrupt
with self.assertRaises(NodeInterrupt) as cm:
await node(state, {"configurable": {"contextual_tools": {"navigate": {}}}})
# Verify the NodeInterrupt contains the expected message
# NodeInterrupt wraps the message in an Interrupt object
interrupt_data = cm.exception.args[0]
if isinstance(interrupt_data, list):
interrupt_data = interrupt_data[0].value
self.assertIsInstance(interrupt_data, AssistantToolCallMessage)
self.assertIn("Navigated to **insights**.", interrupt_data.content)
self.assertEqual(interrupt_data.tool_call_id, "nav-123")
self.assertEqual(interrupt_data.ui_payload, {"navigate": {"page_key": "insights"}})
async def test_arun_tool_returns_wrong_type_returns_error_message(self):
"""Test that tool returning wrong type returns an error message"""
node = _create_agent_tools_node(self.team, self.user)
@@ -955,3 +959,163 @@ class TestAgentToolsNode(BaseTest):
assert isinstance(result.messages[0], AssistantToolCallMessage)
self.assertEqual(result.messages[0].tool_call_id, "tool-123")
self.assertIn("does not exist", result.messages[0].content)
@patch("ee.hogai.tools.read_taxonomy.ReadTaxonomyTool._run_impl")
async def test_max_tool_fatal_error_returns_error_message(self, read_taxonomy_mock):
"""Test that MaxToolFatalError is caught and converted to tool message."""
read_taxonomy_mock.side_effect = MaxToolFatalError(
"Configuration error: INKEEP_API_KEY environment variable is not set"
)
node = _create_agent_tools_node(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Using tool that will fail",
id="test-id",
tool_calls=[
AssistantToolCall(id="tool-123", name="read_taxonomy", args={"query": {"kind": "events"}})
],
)
],
root_tool_call_id="tool-123",
)
result = await node.arun(state, {})
self.assertIsInstance(result, PartialAssistantState)
assert result is not None
self.assertEqual(len(result.messages), 1)
assert isinstance(result.messages[0], AssistantToolCallMessage)
self.assertEqual(result.messages[0].tool_call_id, "tool-123")
self.assertIn("Configuration error", result.messages[0].content)
self.assertIn("INKEEP_API_KEY", result.messages[0].content)
self.assertNotIn("retry", result.messages[0].content.lower())
@patch("ee.hogai.tools.read_taxonomy.ReadTaxonomyTool._run_impl")
async def test_max_tool_retryable_error_returns_error_with_retry_hint(self, read_taxonomy_mock):
"""Test that MaxToolRetryableError includes retry hint for adjusted inputs."""
read_taxonomy_mock.side_effect = MaxToolRetryableError(
"Invalid entity kind: 'unknown_entity'. Must be one of: person, session, organization"
)
node = _create_agent_tools_node(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Using tool with invalid input",
id="test-id",
tool_calls=[
AssistantToolCall(id="tool-123", name="read_taxonomy", args={"query": {"kind": "events"}})
],
)
],
root_tool_call_id="tool-123",
)
result = await node.arun(state, {})
self.assertIsInstance(result, PartialAssistantState)
assert result is not None
self.assertEqual(len(result.messages), 1)
assert isinstance(result.messages[0], AssistantToolCallMessage)
self.assertEqual(result.messages[0].tool_call_id, "tool-123")
self.assertIn("Invalid entity kind", result.messages[0].content)
self.assertIn("retry with adjusted inputs", result.messages[0].content.lower())
@patch("ee.hogai.tools.read_taxonomy.ReadTaxonomyTool._run_impl")
async def test_max_tool_transient_error_returns_error_with_once_retry_hint(self, read_taxonomy_mock):
"""Test that MaxToolTransientError includes hint to retry once without changes."""
read_taxonomy_mock.side_effect = MaxToolTransientError("Rate limit exceeded. Please try again in a few moments")
node = _create_agent_tools_node(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Using tool that hits rate limit",
id="test-id",
tool_calls=[
AssistantToolCall(id="tool-123", name="read_taxonomy", args={"query": {"kind": "events"}})
],
)
],
root_tool_call_id="tool-123",
)
result = await node.arun(state, {})
self.assertIsInstance(result, PartialAssistantState)
assert result is not None
self.assertEqual(len(result.messages), 1)
assert isinstance(result.messages[0], AssistantToolCallMessage)
self.assertEqual(result.messages[0].tool_call_id, "tool-123")
self.assertIn("Rate limit exceeded", result.messages[0].content)
self.assertIn("retry this operation once without changes", result.messages[0].content.lower())
@patch("ee.hogai.tools.read_taxonomy.ReadTaxonomyTool._run_impl")
async def test_generic_exception_returns_internal_error_message(self, read_taxonomy_mock):
"""Test that generic exceptions are caught and return internal error message."""
read_taxonomy_mock.side_effect = RuntimeError("Unexpected internal error")
node = _create_agent_tools_node(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Using tool that crashes unexpectedly",
id="test-id",
tool_calls=[
AssistantToolCall(id="tool-123", name="read_taxonomy", args={"query": {"kind": "events"}})
],
)
],
root_tool_call_id="tool-123",
)
result = await node.arun(state, {})
self.assertIsInstance(result, PartialAssistantState)
assert result is not None
self.assertEqual(len(result.messages), 1)
assert isinstance(result.messages[0], AssistantToolCallMessage)
self.assertEqual(result.messages[0].tool_call_id, "tool-123")
self.assertIn("internal error", result.messages[0].content.lower())
self.assertIn("do not immediately retry", result.messages[0].content.lower())
@parameterized.expand(
[
("fatal", MaxToolFatalError("Fatal error"), "never"),
("transient", MaxToolTransientError("Transient error"), "once"),
("retryable", MaxToolRetryableError("Retryable error"), "adjusted"),
]
)
@patch("ee.hogai.tools.read_taxonomy.ReadTaxonomyTool._run_impl")
async def test_all_error_types_are_logged_with_retry_strategy(
self, name, error, expected_strategy, read_taxonomy_mock
):
"""Test that all MaxToolError types are logged with their retry strategy."""
read_taxonomy_mock.side_effect = error
node = _create_agent_tools_node(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Using tool",
id="test-id",
tool_calls=[
AssistantToolCall(id="tool-123", name="read_taxonomy", args={"query": {"kind": "events"}})
],
)
],
root_tool_call_id="tool-123",
)
with patch("ee.hogai.graph.agent_modes.nodes.capture_exception") as mock_capture:
_ = await node.arun(state, {})
mock_capture.assert_called_once()
call_kwargs = mock_capture.call_args.kwargs
captured_error = mock_capture.call_args.args[0]
self.assertIsInstance(captured_error, MaxToolError)
self.assertEqual(call_kwargs["properties"]["retry_strategy"], expected_strategy)
self.assertEqual(call_kwargs["properties"]["tool"], "read_taxonomy")

View File

@@ -55,15 +55,22 @@ class ConversationSummarizer:
class AnthropicConversationSummarizer(ConversationSummarizer):
def __init__(self, team: Team, user: User, extend_context_window: bool | None = False):
super().__init__(team, user)
self._extend_context_window = extend_context_window
def _get_model(self):
# Haiku has 200k token limit. Sonnet has 1M token limit.
return MaxChatAnthropic(
model="claude-haiku-4-5",
model="claude-sonnet-4-5" if self._extend_context_window else "claude-haiku-4-5",
streaming=False,
stream_usage=False,
max_tokens=8192,
disable_streaming=True,
user=self._user,
team=self._team,
billable=True,
betas=["context-1m-2025-08-07"] if self._extend_context_window else None,
)
def _construct_messages(self, messages: Sequence[BaseMessage]):

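The summarizer change above picks the model by required context: Haiku (200k window) by default, and Sonnet with the 1M-context beta when the conversation is too large for Haiku (the agent node passes extend_context_window=True above roughly 195k estimated tokens). A minimal sketch of that selection logic — the helper name and the returned dict shape are chosen here for illustration, while the model names, beta flag, and threshold come from the diff:

HAIKU_CONTEXT_LIMIT = 200_000          # Haiku's context window
EXTENDED_CONTEXT_THRESHOLD = 195_000   # leave headroom before switching models

def pick_summarizer_config(estimated_tokens: int) -> dict:
    """Choose model and beta flags for conversation summarization."""
    extend = estimated_tokens > EXTENDED_CONTEXT_THRESHOLD
    return {
        "model": "claude-sonnet-4-5" if extend else "claude-haiku-4-5",
        # The 1M-token context is gated behind a beta flag on the Anthropic API.
        "betas": ["context-1m-2025-08-07"] if extend else None,
        "max_tokens": 8192,
    }

# Example: a ~150k-token conversation still fits Haiku; ~300k needs Sonnet.
assert pick_summarizer_config(150_000)["model"] == "claude-haiku-4-5"
assert pick_summarizer_config(300_000)["model"] == "claude-sonnet-4-5"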
View File

@@ -5,7 +5,6 @@ from django.db import transaction
import structlog
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI
from pydantic import BaseModel
from posthog.schema import (
@@ -28,6 +27,7 @@ from ee.hogai.graph.parallel_task_execution.mixins import (
)
from ee.hogai.graph.parallel_task_execution.nodes import BaseTaskExecutorNode, TaskExecutionInputTuple
from ee.hogai.graph.shared_prompts import HYPERLINK_USAGE_INSTRUCTIONS
from ee.hogai.llm import MaxChatOpenAI
from ee.hogai.utils.helpers import build_dashboard_url, build_insight_url, cast_assistant_query
from ee.hogai.utils.types import AssistantState, PartialAssistantState
from ee.hogai.utils.types.base import BaseStateWithTasks, InsightArtifact, InsightQuery, TaskResult
@@ -417,10 +417,14 @@ class DashboardCreationNode(AssistantNode):
@property
def _model(self):
return ChatOpenAI(
return MaxChatOpenAI(
model="gpt-4.1-mini",
temperature=0.3,
max_completion_tokens=500,
max_retries=3,
disable_streaming=True,
user=self._user,
team=self._team,
billable=True,
inject_context=False,
)

View File

@@ -14,7 +14,6 @@ import structlog
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from posthog.schema import AssistantToolCallMessage, VisualizationMessage
@@ -25,6 +24,7 @@ from ee.hogai.context import SUPPORTED_QUERY_MODEL_BY_KIND
from ee.hogai.graph.base import AssistantNode
from ee.hogai.graph.query_executor.query_executor import AssistantQueryExecutor, SupportedQueryTypes
from ee.hogai.graph.shared_prompts import HYPERLINK_USAGE_INSTRUCTIONS
from ee.hogai.llm import MaxChatOpenAI
from ee.hogai.utils.helpers import build_insight_url
from ee.hogai.utils.types import AssistantState, PartialAssistantState
@@ -856,7 +856,7 @@ class InsightSearchNode(AssistantNode):
@property
def _model(self):
return ChatOpenAI(
return MaxChatOpenAI(
model="gpt-4.1-mini",
temperature=0.7,
max_completion_tokens=1000,
@@ -864,4 +864,8 @@ class InsightSearchNode(AssistantNode):
stream_usage=False,
max_retries=3,
disable_streaming=True,
user=self._user,
team=self._team,
billable=True,
inject_context=False,
)

View File

@@ -266,7 +266,7 @@ class TestInsightSearchNode(BaseTest):
# Note: Additional visualization messages depend on query type support in test data
@patch("ee.hogai.graph.insights.nodes.ChatOpenAI")
@patch("ee.hogai.graph.insights.nodes.MaxChatOpenAI")
def test_search_insights_iteratively_single_page(self, mock_openai):
"""Test iterative search with single page (no pagination)."""
@@ -296,7 +296,7 @@ class TestInsightSearchNode(BaseTest):
self.assertIn(self.insight1.id, result)
self.assertIn(self.insight2.id, result)
@patch("ee.hogai.graph.insights.nodes.ChatOpenAI")
@patch("ee.hogai.graph.insights.nodes.MaxChatOpenAI")
def test_search_insights_iteratively_with_pagination(self, mock_openai):
"""Test iterative search with pagination returns valid IDs."""
@@ -326,7 +326,7 @@ class TestInsightSearchNode(BaseTest):
self.assertIn(existing_insight_ids[0], result)
self.assertIn(existing_insight_ids[1], result)
@patch("ee.hogai.graph.insights.nodes.ChatOpenAI")
@patch("ee.hogai.graph.insights.nodes.MaxChatOpenAI")
def test_search_insights_iteratively_fallback(self, mock_openai):
"""Test iterative search when LLM fails - should return empty list."""
@@ -586,7 +586,7 @@ class TestInsightSearchNode(BaseTest):
self.assertEqual(len(self.node._evaluation_selections), 0)
self.assertEqual(self.node._rejection_reason, "None of these match the user's needs")
@patch("ee.hogai.graph.insights.nodes.ChatOpenAI")
@patch("ee.hogai.graph.insights.nodes.MaxChatOpenAI")
async def test_evaluate_insights_with_tools_selection(self, mock_openai):
"""Test the new tool-based evaluation with insight selection."""
# Load insights
@@ -662,7 +662,7 @@ class TestInsightSearchNode(BaseTest):
query_info_empty = self.node._extract_query_metadata({})
self.assertIsNone(query_info_empty)
@patch("ee.hogai.graph.insights.nodes.ChatOpenAI")
@patch("ee.hogai.graph.insights.nodes.MaxChatOpenAI")
async def test_non_executable_insights_handling(self, mock_openai):
"""Test that non-executable insights are presented to LLM but rejected."""
# Create a mock insight that can't be visualized
@@ -707,7 +707,7 @@ class TestInsightSearchNode(BaseTest):
# The explanation should indicate why the insight was rejected
self.assertTrue(len(result["explanation"]) > 0)
@patch("ee.hogai.graph.insights.nodes.ChatOpenAI")
@patch("ee.hogai.graph.insights.nodes.MaxChatOpenAI")
async def test_evaluate_insights_with_tools_rejection(self, mock_openai):
"""Test the new tool-based evaluation with rejection."""
# Load insights
@@ -737,7 +737,7 @@ class TestInsightSearchNode(BaseTest):
result["explanation"], "User is looking for retention analysis, but these are trends and funnels"
)
@patch("ee.hogai.graph.insights.nodes.ChatOpenAI")
@patch("ee.hogai.graph.insights.nodes.MaxChatOpenAI")
async def test_evaluate_insights_with_tools_multiple_selection(self, mock_openai):
"""Test the evaluation with multiple selection mode."""
# Load insights

View File

@@ -424,7 +424,7 @@ class MemoryCollectorNode(MemoryOnboardingShouldRunMixin):
@property
def _model(self):
return MaxChatOpenAI(
model="gpt-4.1", temperature=0.3, disable_streaming=True, user=self._user, team=self._team
model="gpt-4.1", temperature=0.3, disable_streaming=True, user=self._user, team=self._team, billable=True
).bind_tools(memory_collector_tools)
def _construct_messages(self, state: AssistantState) -> list[BaseMessage]:

View File

@@ -155,6 +155,7 @@ class QueryPlannerNode(TaxonomyUpdateDispatcherNodeMixin, AssistantNode):
# Ref: https://forum.langchain.com/t/langgraph-openai-responses-api-400-error-web-search-call-was-provided-without-its-required-reasoning-item/1740/2
output_version="responses/v1",
disable_streaming=True,
billable=True,
).bind_tools(
[
retrieve_event_properties,

View File

@@ -58,7 +58,13 @@ class SchemaGeneratorNode(AssistantNode, Generic[Q]):
@property
def _model(self):
return MaxChatOpenAI(
model="gpt-4.1", temperature=0.3, disable_streaming=True, user=self._user, team=self._team, max_tokens=8192
model="gpt-4.1",
temperature=0.3,
disable_streaming=True,
user=self._user,
team=self._team,
max_tokens=8192,
billable=True,
).with_structured_output(
self.OUTPUT_SCHEMA,
method="json_schema",

View File

@@ -59,12 +59,24 @@ class SessionSummarizationNode(AssistantNode):
self._session_search = _SessionSearch(self)
self._session_summarizer = _SessionSummarizer(self)
async def _stream_progress(self, progress_message: str) -> None:
def _stream_progress(self, progress_message: str) -> None:
"""Push summarization progress as reasoning messages"""
content = prepare_reasoning_progress_message(progress_message)
if content:
self.dispatcher.update(content)
def _stream_filters(self, filters: MaxRecordingUniversalFilters) -> None:
"""Stream filters to the user"""
self.dispatcher.message(
AssistantToolCallMessage(
content="",
ui_payload={"search_session_recordings": filters.model_dump(exclude_none=True)},
# Randomized tool call ID: this must not become the recorded result of the actual session summarization tool call.
# It's safe because the message is only dispatched ephemerally, so it never gets added to the state
tool_call_id=str(uuid4()),
)
)
async def _stream_notebook_content(self, content: dict, state: AssistantState, partial: bool = True) -> None:
"""Stream TipTap content directly to a notebook if notebook_id is present in state."""
# Check if we have a notebook_id in the state
@@ -270,9 +282,7 @@ class _SessionSearch:
def _convert_current_filters_to_recordings_query(self, current_filters: dict[str, Any]) -> RecordingsQuery:
"""Convert current filters into recordings query format"""
from ee.session_recordings.playlist_counters.recordings_that_match_playlist_filters import (
convert_filters_to_recordings_query,
)
from posthog.session_recordings.playlist_counters import convert_filters_to_recordings_query
# Create a temporary playlist object to use the conversion function
temp_playlist = SessionRecordingPlaylist(filters=current_filters)
@@ -386,6 +396,7 @@ class _SessionSearch:
root_tool_call_id=None,
)
# Use filters when generated successfully
self._node._stream_filters(filter_generation_result)
replay_filters = self._convert_max_filters_to_recordings_query(filter_generation_result)
# Query the filters to get session ids
query_limit = state.session_summarization_limit
@@ -429,13 +440,13 @@ class _SessionSummarizer:
)
completed += 1
# Update the user on the progress
await self._node._stream_progress(progress_message=f"Watching sessions ({completed}/{total})")
self._node._stream_progress(progress_message=f"Watching sessions ({completed}/{total})")
return result
# Run all tasks concurrently
tasks = [_summarize(sid) for sid in session_ids]
summaries = await asyncio.gather(*tasks)
await self._node._stream_progress(progress_message=f"Generating a summary, almost there")
self._node._stream_progress(progress_message=f"Generating a summary, almost there")
# Stringify the summaries: the chat doesn't need the full JSON to stay context-aware, and including it could overload the context
stringified_summaries = []
for summary in summaries:
@@ -483,7 +494,7 @@ class _SessionSummarizer:
# Update intermediate state based on step enum (no content, as it's just a status message)
self._intermediate_state.update_step_progress(content=None, step=step)
# Status message - stream to user
await self._node._stream_progress(progress_message=data)
self._node._stream_progress(progress_message=data)
# Notebook intermediate data update messages
elif update_type == SessionSummaryStreamUpdate.NOTEBOOK_UPDATE:
if not isinstance(data, dict):
@@ -543,7 +554,7 @@ class _SessionSummarizer:
base_message = f"Found sessions ({len(session_ids)})"
if len(session_ids) <= GROUP_SUMMARIES_MIN_SESSIONS:
# If small amount of sessions - there are no patterns to extract, so summarize them individually and return as is
await self._node._stream_progress(
self._node._stream_progress(
progress_message=f"{base_message}. We will do a quick summary, as the scope is small",
)
summaries_content = await self._summarize_sessions_individually(session_ids=session_ids)
@@ -558,7 +569,7 @@ class _SessionSummarizer:
state.notebook_short_id = notebook.short_id
# For large groups, process in detail, searching for patterns
# TODO: Allow users to define the pattern themselves (or rather catch it from the query)
await self._node._stream_progress(
self._node._stream_progress(
progress_message=f"{base_message}. We will analyze in detail, and store the report in a notebook",
)
summaries_content = await self._summarize_sessions_as_group(

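_stream_filters above dispatches the generated recording filters as a UI payload on a throwaway tool-call ID, so the frontend can render them immediately without the message becoming the recorded result of the real tool call. A small sketch of that ephemeral-dispatch pattern — the dispatcher stand-in and dict-shaped message below are simplifications of the node's dispatcher and AssistantToolCallMessage:

from uuid import uuid4

class PrintDispatcher:
    """Stand-in for the node's dispatcher: just prints what would be streamed."""
    def message(self, payload: dict) -> None:
        print(payload)

def stream_filters_ephemerally(dispatcher, filters: dict) -> None:
    """Send filters to the UI without attaching them to the real tool call's result."""
    dispatcher.message({
        "type": "tool_call_message",
        "content": "",
        "ui_payload": {"search_session_recordings": filters},
        # Random ID: the message is never stored in state, so it must not collide
        # with the real session-summarization tool call.
        "tool_call_id": str(uuid4()),
    })

# Example usage with hypothetical filter values.
stream_filters_ephemerally(PrintDispatcher(), {"date_from": "-7d"})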
View File

@@ -25,18 +25,18 @@
WHERE and(equals(s.team_id, 99999), greaterOrEquals(toTimeZone(s.min_first_timestamp, 'UTC'), toDateTime64('2025-08-24 00:00:00.000000', 6, 'UTC')), lessOrEquals(toTimeZone(s.min_first_timestamp, 'UTC'), toDateTime64('2025-09-03 23:59:59.000000', 6, 'UTC')))
GROUP BY s.session_id
HAVING and(ifNull(greaterOrEquals(expiry_time, toDateTime64('2025-09-03 12:00:00.000000', 6, 'UTC')), 0), ifNull(greater(active_seconds, 8.0), 0))
ORDER BY start_time DESC
LIMIT 101
OFFSET 0 SETTINGS readonly=2,
max_execution_time=60,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
ORDER BY start_time DESC,
s.session_id DESC
LIMIT 101 SETTINGS readonly=2,
max_execution_time=60,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
'''
# ---
# name: TestSessionSummarizationNodeFilterGeneration.test_get_session_ids_respects_limit
@@ -65,18 +65,18 @@
WHERE and(equals(s.team_id, 99999), greaterOrEquals(toTimeZone(s.min_first_timestamp, 'UTC'), toDateTime64('2025-08-27 00:00:00.000000', 6, 'UTC')), lessOrEquals(toTimeZone(s.min_first_timestamp, 'UTC'), toDateTime64('2025-08-31 23:59:59.000000', 6, 'UTC')))
GROUP BY s.session_id
HAVING ifNull(greaterOrEquals(expiry_time, toDateTime64('2025-09-03 12:00:00.000000', 6, 'UTC')), 0)
ORDER BY start_time DESC
LIMIT 2
OFFSET 0 SETTINGS readonly=2,
max_execution_time=60,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
ORDER BY start_time DESC,
s.session_id DESC
LIMIT 2 SETTINGS readonly=2,
max_execution_time=60,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
'''
# ---
# name: TestSessionSummarizationNodeFilterGeneration.test_use_current_filters_with_date_range
@@ -105,18 +105,18 @@
WHERE and(equals(s.team_id, 99999), greaterOrEquals(toTimeZone(s.min_first_timestamp, 'UTC'), toDateTime64('2025-08-27 00:00:00.000000', 6, 'UTC')), lessOrEquals(toTimeZone(s.min_first_timestamp, 'UTC'), toDateTime64('2025-08-31 23:59:59.000000', 6, 'UTC')))
GROUP BY s.session_id
HAVING and(ifNull(greaterOrEquals(expiry_time, toDateTime64('2025-09-03 12:00:00.000000', 6, 'UTC')), 0), ifNull(greater(active_seconds, 7.0), 0))
ORDER BY start_time DESC
LIMIT 101
OFFSET 0 SETTINGS readonly=2,
max_execution_time=60,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
ORDER BY start_time DESC,
s.session_id DESC
LIMIT 101 SETTINGS readonly=2,
max_execution_time=60,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
'''
# ---
# name: TestSessionSummarizationNodeFilterGeneration.test_use_current_filters_with_os_and_events
@@ -149,42 +149,42 @@
GROUP BY events.`$session_id`
HAVING 1
ORDER BY min(toTimeZone(events.timestamp, 'UTC')) DESC
LIMIT 10000)), globalIn(s.session_id,
(SELECT events.`$session_id` AS session_id
FROM events
LEFT OUTER JOIN
(SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, person_distinct_id_overrides.distinct_id AS distinct_id
FROM person_distinct_id_overrides
WHERE equals(person_distinct_id_overrides.team_id, 99999)
GROUP BY person_distinct_id_overrides.distinct_id
HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS events__override ON equals(events.distinct_id, events__override.distinct_id)
LEFT JOIN
(SELECT person.id AS id, replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(person.properties, '$os'), ''), 'null'), '^"|"$', '') AS `properties___$os`
FROM person
WHERE and(equals(person.team_id, 99999), in(tuple(person.id, person.version),
(SELECT person.id AS id, max(person.version) AS version
FROM person
WHERE equals(person.team_id, 99999)
GROUP BY person.id
HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0))))) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id)
WHERE and(equals(events.team_id, 99999), notEmpty(events.`$session_id`), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), toDateTime64('explicit_redacted_timestamp', 6, 'UTC')), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), now64(6, 'UTC')), ifNull(equals(events__person.`properties___$os`, 'Mac OS X'), 0), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), toDateTime64('explicit_redacted_timestamp', 6, 'UTC')), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), toDateTime64('explicit_redacted_timestamp', 6, 'UTC')))
GROUP BY events.`$session_id`
HAVING 1
ORDER BY min(toTimeZone(events.timestamp, 'UTC')) DESC
LIMIT 10000))))
LIMIT 1000000)), globalIn(s.session_id,
(SELECT events.`$session_id` AS session_id
FROM events
LEFT OUTER JOIN
(SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, person_distinct_id_overrides.distinct_id AS distinct_id
FROM person_distinct_id_overrides
WHERE equals(person_distinct_id_overrides.team_id, 99999)
GROUP BY person_distinct_id_overrides.distinct_id
HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS events__override ON equals(events.distinct_id, events__override.distinct_id)
LEFT JOIN
(SELECT person.id AS id, replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(person.properties, '$os'), ''), 'null'), '^"|"$', '') AS `properties___$os`
FROM person
WHERE and(equals(person.team_id, 99999), in(tuple(person.id, person.version),
(SELECT person.id AS id, max(person.version) AS version
FROM person
WHERE equals(person.team_id, 99999)
GROUP BY person.id
HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0))))) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id)
WHERE and(equals(events.team_id, 99999), notEmpty(events.`$session_id`), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), toDateTime64('explicit_redacted_timestamp', 6, 'UTC')), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), now64(6, 'UTC')), ifNull(equals(events__person.`properties___$os`, 'Mac OS X'), 0), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), toDateTime64('explicit_redacted_timestamp', 6, 'UTC')), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), toDateTime64('explicit_redacted_timestamp', 6, 'UTC')))
GROUP BY events.`$session_id`
HAVING 1
ORDER BY min(toTimeZone(events.timestamp, 'UTC')) DESC
LIMIT 1000000))))
GROUP BY s.session_id
HAVING and(ifNull(greaterOrEquals(expiry_time, toDateTime64('explicit_redacted_timestamp', 6, 'UTC')), 0), ifNull(greater(active_seconds, 6.0), 0))
ORDER BY start_time DESC
LIMIT 101
OFFSET 0 SETTINGS readonly=2,
max_execution_time=60,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
ORDER BY start_time DESC,
s.session_id DESC
LIMIT 101 SETTINGS readonly=2,
max_execution_time=60,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
'''
# ---

View File

@@ -38,29 +38,50 @@ Standardized events/properties such as pageview or screen start with `$`. Custom
`virtual_table` and `lazy_table` fields are connections to linked tables, e.g. the virtual table field `person` allows accessing person properties like so: `person.properties.foo`.
<person_id_join_limitation>
There is a known issue with queries that join multiple events tables where join constraints
reference person_id fields. The person_id fields are ExpressionFields that expand to
expressions referencing override tables (e.g., e_all__override). However, these expressions
are resolved during type resolution (in printer.py) BEFORE lazy table processing begins.
This creates forward references to override tables that don't exist yet.
CRITICAL: There is a known issue with queries where JOIN constraints reference events.person_id fields.
Example problematic HogQL:
SELECT MAX(e_all.timestamp) AS last_seen
FROM events e_dl
JOIN persons p ON e_dl.person_id = p.id
JOIN events e_all ON e_dl.person_id = e_all.person_id
TECHNICAL CAUSE:
The person_id fields are ExpressionFields that expand to expressions referencing override tables
(e.g., e_all__override). However, these expressions are resolved during type resolution (in printer.py)
BEFORE lazy table processing begins. This creates forward references to override tables that don't
exist yet, causing ClickHouse errors like:
"Missing columns: '_--e__override.person_id' '_--e__override.distinct_id'"
The join constraint "e_dl.person_id = e_all.person_id" expands to:
if(NOT empty(e_dl__override.distinct_id), e_dl__override.person_id, e_dl.person_id) =
if(NOT empty(e_all__override.distinct_id), e_all__override.person_id, e_all.person_id)
PROBLEMATIC PATTERNS:
1. Joining persons to events using events.person_id:
❌ FROM persons p ALL INNER JOIN events e ON p.id = e.person_id
But e_all__override is defined later in the SQL, causing a ClickHouse error.
2. Joining multiple events tables using person_id:
❌ FROM events e_dl
JOIN persons p ON e_dl.person_id = p.id
JOIN events e_all ON e_dl.person_id = e_all.person_id
WORKAROUND: Use subqueries or rewrite queries to avoid direct joins between multiple events tables:
SELECT MAX(e.timestamp) AS last_seen
FROM events e
JOIN persons p ON e.person_id = p.id
WHERE e.event IN (SELECT event FROM events WHERE ...)
The join constraint "e_dl.person_id = e_all.person_id" expands to:
if(NOT empty(e_dl__override.distinct_id), e_dl__override.person_id, e_dl.person_id) =
if(NOT empty(e_all__override.distinct_id), e_all__override.person_id, e_all.person_id)
But e_all__override is defined later in the SQL, causing the error.
REQUIRED WORKAROUNDS:
1. For accessing person data, use the person virtual table from events:
✅ SELECT e.person.id, e.person.properties.email, e.event
FROM events e
WHERE e.timestamp > now() - INTERVAL 7 DAY
2. For filtering persons by event data, use subqueries with WHERE IN:
✅ SELECT p.id, p.properties.email
FROM persons p
WHERE p.id IN (
SELECT DISTINCT person_id FROM events
WHERE event = 'purchase' AND timestamp > now() - INTERVAL 7 DAY
)
3. For multiple events tables, use subqueries to avoid direct joins:
✅ SELECT MAX(e.timestamp) AS last_seen
FROM events e
WHERE e.person_id IN (SELECT DISTINCT person_id FROM events WHERE ...)
NEVER use events.person_id directly in JOIN ON constraints - always use one of the workarounds above.
</person_id_join_limitation>
ONLY make formatting or casing changes if explicitly requested by the user.

View File

@@ -68,8 +68,16 @@ class TaxonomyAgentNode(
return EntityType.values() + self._team_group_types
def _get_model(self, state: TaxonomyStateType):
# Check if this invocation should be billable (set by the calling tool)
billable = getattr(state, "billable", False)
return MaxChatOpenAI(
model="gpt-4.1", streaming=False, temperature=0.3, user=self._user, team=self._team, disable_streaming=True
model="gpt-4.1",
streaming=False,
temperature=0.3,
user=self._user,
team=self._team,
disable_streaming=True,
billable=billable,
).bind_tools(
self._toolkit.get_tools(),
tool_choice="required",
@@ -146,6 +154,7 @@ class TaxonomyAgentNode(
intermediate_steps=intermediate_steps,
output=state.output,
iteration_count=state.iteration_count + 1 if state.iteration_count is not None else 1,
billable=state.billable,
)
@@ -198,6 +207,7 @@ class TaxonomyAgentToolsNode(
return self._partial_state_class(
output=tool_input.arguments.answer, # type: ignore
intermediate_steps=None,
billable=state.billable,
)
if tool_input.name == "ask_user_for_help":
@@ -237,6 +247,7 @@ class TaxonomyAgentToolsNode(
tool_progress_messages=[*old_msg, *tool_msgs],
intermediate_steps=steps,
iteration_count=state.iteration_count,
billable=state.billable,
)
def router(self, state: TaxonomyStateType) -> str:
@@ -257,4 +268,5 @@ class TaxonomyAgentToolsNode(
)
]
reset_state.output = output
reset_state.billable = state.billable
return reset_state # type: ignore[return-value]

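The taxonomy changes above thread a billable flag from the calling tool through every partial state update so the final LLM call is billed correctly; omitting it in any node would silently reset it to the default. A minimal sketch of carrying such a flag through a loop of state updates — the TaxonomyState dataclass below is a simplified stand-in for the real taxonomy state types, with field names taken from the diff:

from dataclasses import dataclass

@dataclass
class TaxonomyState:
    iteration_count: int = 0
    output: str | None = None
    billable: bool = False  # set by the calling tool; defaults to non-billable

def agent_step(state: TaxonomyState) -> TaxonomyState:
    """Each node builds a fresh partial state, so the flag must be copied forward explicitly."""
    return TaxonomyState(
        iteration_count=state.iteration_count + 1,
        output=state.output,
        billable=state.billable,  # forgetting this line would silently reset billing to False
    )

def finalize(state: TaxonomyState, answer: str) -> TaxonomyState:
    return TaxonomyState(
        iteration_count=state.iteration_count,
        output=answer,
        billable=state.billable,
    )

# Example: the flag set by the calling tool survives several iterations and the final answer.
state = TaxonomyState(billable=True)
for _ in range(3):
    state = agent_step(state)
state = finalize(state, "done")
assert state.billable is True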
Some files were not shown because too many files have changed in this diff.