Merge branch 'master' into master
@@ -1 +1,10 @@
 # Local dev only, see dags/README.md for more info
+concurrency:
+  runs:
+    tag_concurrency_limits:
+      # See dags/sessions.py
+      - key: 'sessions_backfill_concurrency'
+        limit: 3
+        value:
+          applyLimitPerUniqueValue: true

@@ -16,5 +16,7 @@ load_from:
   - python_module: dags.locations.analytics_platform
   - python_module: dags.locations.experiments
   - python_module: dags.locations.growth
   - python_module: dags.locations.llma
   - python_module: dags.locations.max_ai
   - python_module: dags.locations.web_analytics
   - python_module: dags.locations.ingestion
1 .flox/env/manifest.toml vendored
@@ -45,6 +45,7 @@ cmake = { pkg-path = "cmake", version = "3.31.5", pkg-group = "cmake" }
 sqlx-cli = { pkg-path = "sqlx-cli", version = "0.8.3" } # sqlx
 postgresql = { pkg-path = "postgresql_14" } # psql
 ffmpeg.pkg-path = "ffmpeg"
+ngrok = { pkg-path = "ngrok" }

 # Set environment variables in the `[vars]` section. These variables may not
 # reference one another, and are added to the environment without first
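Note: the one-line change above is the new ngrok entry. If you are reproducing it locally, flox can append the entry for you instead of hand-editing the manifest — a sketch, assuming a flox-activated checkout:

```bash
# Should append `ngrok = { pkg-path = "ngrok" }` to .flox/env/manifest.toml;
# hand-editing the [install] section, as in the diff, is equivalent.
flox install ngrok
```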
68 .github/actions/commit-snapshots/action.yml vendored
@@ -21,7 +21,10 @@ inputs:
     description: 'Repository name (owner/repo)'
     required: true
   commit-sha:
-    description: 'Commit SHA'
+    description: 'Commit SHA that was tested'
     required: true
+  branch-name:
+    description: 'Branch name'
+    required: true
   github-token:
     description: 'GitHub token for authentication'

@@ -66,9 +69,48 @@ runs:
        git config user.name "github-actions[bot]"
        git config user.email "github-actions[bot]@users.noreply.github.com"

+    - name: Verify branch hasn't advanced
+      id: verify
+      shell: bash
+      run: |
+        # Check if the remote branch HEAD still matches the commit we tested
+        TESTED_SHA="${{ inputs.commit-sha }}"
+        CURRENT_SHA=$(git ls-remote origin "refs/heads/${{ inputs.branch-name }}" | cut -f1)
+
+        if [ "$TESTED_SHA" != "$CURRENT_SHA" ]; then
+          echo "branch-advanced=true" >> $GITHUB_OUTPUT
+          echo "current-sha=$CURRENT_SHA" >> $GITHUB_OUTPUT
+          echo "⚠️ Branch has advanced during workflow execution"
+          echo "  Tested: $TESTED_SHA"
+          echo "  Current: $CURRENT_SHA"
+          echo "  Skipping snapshot commit - new workflow run will handle it"
+        else
+          echo "branch-advanced=false" >> $GITHUB_OUTPUT
+          echo "✓ Branch HEAD matches tested commit - proceeding with snapshot commit"
+        fi
+
+    - name: Post skip comment
+      if: steps.verify.outputs.branch-advanced == 'true'
+      shell: bash
+      env:
+        GH_TOKEN: ${{ inputs.github-token }}
+      run: |
+        TESTED_SHA="${{ inputs.commit-sha }}"
+        CURRENT_SHA="${{ steps.verify.outputs.current-sha }}"
+        gh pr comment ${{ inputs.pr-number }} --body "⏭️ Skipped snapshot commit because branch advanced to \`${CURRENT_SHA:0:7}\` while workflow was testing \`${TESTED_SHA:0:7}\`.
+
+        The new commit will trigger its own snapshot update workflow.
+
+        **If you expected this workflow to succeed:** This can happen due to concurrent commits. To get a fresh workflow run, either:
+        - Merge master into your branch, or
+        - Push an empty commit: \`git commit --allow-empty -m 'trigger CI' && git push\`"
+
     - name: Count and commit changes
       id: commit
+      if: steps.verify.outputs.branch-advanced == 'false'
       shell: bash
       env:
         GH_TOKEN: ${{ inputs.github-token }}
       run: |
         CHANGES_JSON=$(.github/scripts/count-snapshot-changes.sh ${{ inputs.snapshot-path }})
         echo "changes=$CHANGES_JSON" >> $GITHUB_OUTPUT

@@ -76,14 +118,17 @@ runs:

         TOTAL=$(echo "$CHANGES_JSON" | jq -r '.total')
         if [ "$TOTAL" -gt 0 ]; then
+          # Disable auto-merge BEFORE committing
+          # This ensures auto-merge is disabled even if workflow gets cancelled after push
+          gh pr merge --disable-auto ${{ inputs.pr-number }} || echo "Auto-merge was not enabled"
+          echo "Auto-merge disabled - snapshot changes require human review"
+
+          # Now commit and push
           git add ${{ inputs.snapshot-path }} -A
           git commit -m "${{ inputs.commit-message }}"
-          # Pull before push to avoid race condition when multiple workflows commit simultaneously
-          # Only fetch current branch to avoid verbose output from fetching all branches
-          CURRENT_BRANCH=$(git rev-parse --abbrev-ref HEAD)
-          git fetch origin "$CURRENT_BRANCH" --quiet
-          git rebase origin/"$CURRENT_BRANCH" --autostash
-          git push --quiet
+          # Push directly - SHA verification already ensured we're up to date
+          BRANCH_NAME="${{ inputs.branch-name }}"
+          git push origin HEAD:"$BRANCH_NAME" --quiet
           SNAPSHOT_SHA=$(git rev-parse HEAD)
           echo "committed=true" >> $GITHUB_OUTPUT
           echo "snapshot-sha=$SNAPSHOT_SHA" >> $GITHUB_OUTPUT

@@ -92,15 +137,6 @@ runs:
           echo "committed=false" >> $GITHUB_OUTPUT
         fi

-    - name: Disable auto-merge
-      if: steps.commit.outputs.committed == 'true'
-      shell: bash
-      env:
-        GH_TOKEN: ${{ inputs.github-token }}
-      run: |
-        gh pr merge --disable-auto ${{ inputs.pr-number }} || echo "Auto-merge was not enabled"
-        echo "Snapshot changes require human review"
-
     - name: Post snapshot comment
       if: steps.commit.outputs.committed == 'true'
       shell: bash
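The new verification step boils down to comparing the tested SHA against the remote branch head immediately before pushing. A minimal standalone sketch of that guard outside of GitHub Actions — the branch name and SHA are placeholder values standing in for the action's inputs:

```bash
#!/usr/bin/env bash
# Sketch of the stale-snapshot guard: only push if the remote branch head
# still matches the commit this run tested.
set -euo pipefail

BRANCH="my-feature-branch"   # stands in for inputs.branch-name
TESTED_SHA="0123456abcdef"   # stands in for inputs.commit-sha

# `git ls-remote` prints "<sha>\t<ref>"; cut -f1 keeps the SHA.
CURRENT_SHA=$(git ls-remote origin "refs/heads/$BRANCH" | cut -f1)

if [ "$TESTED_SHA" != "$CURRENT_SHA" ]; then
    # Someone pushed while tests were running: bail out instead of
    # committing snapshots generated from a stale commit.
    echo "Branch advanced to ${CURRENT_SHA:0:7}; skipping snapshot commit"
    exit 0
fi

git push origin HEAD:"$BRANCH" --quiet
```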
2 .github/actions/run-backend-tests/action.yml vendored
@@ -233,7 +233,7 @@ runs:

    - name: Upload updated timing data as artifacts
      uses: actions/upload-artifact@v4
-     if: ${{ inputs.person-on-events != 'true' && inputs.clickhouse-server-image == 'clickhouse/clickhouse-server:25.6.9.98' }}
+     if: ${{ inputs.person-on-events != 'true' && inputs.clickhouse-server-image == 'clickhouse/clickhouse-server:25.8.11.66' }}
      with:
        name: timing_data-${{ inputs.segment }}-${{ inputs.group }}
        path: .test_durations
133 .github/scripts/post-eval-summary.js vendored Normal file
@@ -0,0 +1,133 @@
// Export for use with actions/github-script

const DIFF_THRESHOLD = 0.02

module.exports = ({ github, context, fs }) => {
    // Read the eval results
    const evalResults = fs
        .readFileSync('eval_results.jsonl', 'utf8')
        .trim()
        .split('\n')
        .filter((line) => line.trim().length > 0)
        .map((line) => JSON.parse(line))

    if (evalResults.length === 0) {
        console.warn('No eval results found')
        return
    }

    // Calculate max diff for each experiment and categorize
    const experimentsWithMaxDiff = evalResults.map((result) => {
        const scores = result.scores || {}
        const diffs = Object.values(scores)
            .map((v) => v.diff)
            .filter((d) => typeof d === 'number')
        const minDiff = diffs.length > 0 ? Math.min(...diffs) : 0
        const maxDiff = diffs.length > 0 ? Math.max(...diffs) : 0
        const category = minDiff < -DIFF_THRESHOLD ? 'regression' : maxDiff > DIFF_THRESHOLD ? 'improvement' : 'neutral'

        return {
            result,
            category,
            maxDiffInCategoryAbs:
                category === 'regression'
                    ? -minDiff
                    : category === 'improvement'
                      ? maxDiff
                      : Math.max(Math.abs(minDiff), Math.abs(maxDiff)),
        }
    })

    // Sort: regressions first (most negative to least), then improvements (most positive to least), then neutral
    experimentsWithMaxDiff.sort((a, b) => {
        if (a.category === b.category) {
            return b.maxDiffInCategoryAbs - a.maxDiffInCategoryAbs
        }
        const order = { regression: 0, improvement: 1, neutral: 2 }
        return order[a.category] - order[b.category]
    })

    // Generate experiment summaries
    const experimentSummaries = experimentsWithMaxDiff.map(({ result }) => {
        // Format scores as bullet points with improvements/regressions and baseline comparison
        const scoresList = Object.entries(result.scores || {})
            .map(([key, value]) => {
                const score = typeof value.score === 'number' ? `${(value.score * 100).toFixed(2)}%` : value.score
                let baselineComparison = null
                const diffHighlight = Math.abs(value.diff) > DIFF_THRESHOLD ? '**' : ''
                let diffEmoji = '🆕'
                if (result.comparison_experiment_name?.startsWith('master-')) {
                    baselineComparison = `${diffHighlight}${value.diff > 0 ? '+' : value.diff < 0 ? '' : '±'}${(
                        value.diff * 100
                    ).toFixed(
                        2
                    )}%${diffHighlight} (improvements: ${value.improvements}, regressions: ${value.regressions})`
                    diffEmoji = value.diff > DIFF_THRESHOLD ? '🟢' : value.diff < -DIFF_THRESHOLD ? '🔴' : '🔵'
                }
                return `${diffEmoji} **${key}**: **${score}**${baselineComparison ? `, ${baselineComparison}` : ''}`
            })
            .join('\n')

        // Format key metrics concisely
        const metrics = result.metrics || {}
        const duration = metrics.duration ? `⏱️ ${metrics.duration.metric.toFixed(2)} s` : null
        const totalTokens = metrics.total_tokens ? `🔢 ${Math.floor(metrics.total_tokens.metric)} tokens` : null
        const cost = metrics.estimated_cost ? `💵 $${metrics.estimated_cost.metric.toFixed(4)} in tokens` : null
        const metricsText = [duration, totalTokens, cost].filter(Boolean).join(', ')
        const baselineLink = `[${result.comparison_experiment_name}](${result.project_url}/experiments/${result.comparison_experiment_name})`

        // Create concise experiment summary with header only showing experiment name
        const experimentName = result.project_name.replace(/^max-ai-/, '')

        return [
            `### [${experimentName}](${result.experiment_url})`,
            scoresList,
            `Baseline: ${baselineLink} • Avg. case performance: ${metricsText}`,
        ].join('\n\n')
    })

    // Split summaries by category
    const regressions = []
    const improvements = []
    const neutral = []
    experimentsWithMaxDiff.forEach(({ category }, idx) => {
        if (category === 'regression') regressions.push(experimentSummaries[idx])
        else if (category === 'improvement') improvements.push(experimentSummaries[idx])
        else neutral.push(experimentSummaries[idx])
    })

    const totalExperiments = evalResults.length
    const totalMetrics = evalResults.reduce((acc, result) => acc + Object.keys(result.scores || {}).length, 0)

    const bodyParts = [
        `## 🧠 AI eval results`,
        `Evaluated **${totalExperiments}** experiments, comprising **${totalMetrics}** metrics. Showing experiments with largest regressions first.`,
    ]

    bodyParts.push(...regressions)
    bodyParts.push(...improvements)
    if (neutral.length > 0) {
        bodyParts.push(
            `<details><summary>${neutral.length} ${neutral.length === 1 ? 'experiment' : 'experiments'} with no significant changes</summary>\n\n${neutral.join('\n\n')}\n\n</details>`
        )
    }

    bodyParts.push(
        `_Triggered by [this commit](https://github.com/${context.repo.owner}/${context.repo.repo}/pull/${context.payload.pull_request.number}/commits/${context.payload.pull_request.head.sha})._`
    )

    const body = bodyParts.join('\n\n')

    // Post comment on PR
    if (context.payload.pull_request) {
        github.rest.issues.createComment({
            issue_number: context.issue.number,
            owner: context.repo.owner,
            repo: context.repo.repo,
            body: body,
        })
    } else {
        // Just log the summary if this is a push to master
        console.info(body)
    }
}
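Each line of eval_results.jsonl is one experiment object carrying the per-metric `scores.<name>.diff` fields the script reads. To eyeball regressions locally without running the workflow, a jq one-liner along these lines should work — a sketch assuming the same field names, not part of the repo:

```bash
# List metrics whose diff against baseline is below -2% (the script's
# DIFF_THRESHOLD), one tab-separated row per regressed metric.
THRESHOLD=0.02
jq -r --argjson t "$THRESHOLD" '
  . as $exp
  | .scores // {} | to_entries[]
  | select(.value.diff != null and .value.diff < -$t)
  | "\($exp.project_name)\t\(.key)\t\(.value.diff * 100 | tostring)%"
' eval_results.jsonl
```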
76 .github/workflows/ci-ai.yml vendored
@@ -80,80 +80,10 @@ jobs:
     - name: Post eval summary to PR
       # always() because we want to post even if `pytest` exited with an error (likely just one eval suite errored)
       if: always() && github.event_name == 'pull_request'
-      uses: actions/github-script@v6
+      uses: actions/github-script@v8
       with:
         github-token: ${{ secrets.POSTHOG_BOT_PAT }}
         script: |
           const fs = require("fs")
-
-          // Read the eval results
-          const evalResults = fs
-              .readFileSync("eval_results.jsonl", "utf8")
-              .trim()
-              .split("\n")
-              .map((line) => JSON.parse(line))
-
-          if (evalResults.length === 0) {
-              console.log("No eval results found")
-              return
-          }
-
-          // Generate concise experiment summaries
-          const experimentSummaries = evalResults.map((result) => {
-              // Format scores as bullet points with improvements/regressions and baseline comparison
-              const scoresList = Object.entries(result.scores || {})
-                  .map(([key, value]) => {
-                      const score = typeof value.score === "number" ? `${(value.score * 100).toFixed(2)}%` : value.score
-                      let baselineComparison = null
-                      const diffHighlight = Math.abs(value.diff) > 0.01 ? "**" : ""
-                      let diffEmoji = "🆕"
-                      if (result.comparison_experiment_name?.startsWith("master-")) {
-                          baselineComparison = `${diffHighlight}${value.diff > 0 ? "+" : value.diff < 0 ? "" : "±"}${(
-                              value.diff * 100
-                          ).toFixed(2)}%${diffHighlight} (improvements: ${value.improvements}, regressions: ${value.regressions})`
-                          diffEmoji = value.diff > 0.01 ? "🟢" : value.diff < -0.01 ? "🔴" : "🔵"
-                      }
-                      return `${diffEmoji} **${key}**: **${score}**${baselineComparison ? `, ${baselineComparison}` : ""}`
-                  })
-                  .join("\n")
-
-              // Format key metrics concisely
-              const metrics = result.metrics || {}
-              const duration = metrics.duration ? `⏱️ ${metrics.duration.metric.toFixed(2)} s` : null
-              const totalTokens = metrics.total_tokens ? `🔢 ${Math.floor(metrics.total_tokens.metric)} tokens` : null
-              const cost = metrics.estimated_cost ? `💵 $${metrics.estimated_cost.metric.toFixed(4)} in tokens` : null
-              const metricsText = [duration, totalTokens, cost].filter(Boolean).join(", ")
-              const baselineLink = `[${result.comparison_experiment_name}](${result.project_url}/experiments/${result.comparison_experiment_name})`
-
-              // Create concise experiment summary with header only showing experiment name
-              const experimentName = result.project_name.replace(/^max-ai-/, "")
-
-              return [
-                  `### [${experimentName}](${result.experiment_url})`,
-                  scoresList,
-                  `Baseline: ${baselineLink} • Avg. case performance: ${metricsText}`,
-              ].join("\n\n")
-          })
-
-          const totalExperiments = evalResults.length
-          const totalMetrics = evalResults.reduce((acc, result) => acc + Object.keys(result.scores || {}).length, 0)
-
-          const body = [
-              `## 🧠 AI eval results`,
-              `Evaluated **${totalExperiments}** experiments, comprising **${totalMetrics}** metrics.`,
-              ...experimentSummaries,
-              `_Triggered by [this commit](https://github.com/${context.repo.owner}/${context.repo.repo}/pull/${context.payload.pull_request.number}/commits/${context.payload.pull_request.head.sha})._`,
-          ].join("\n\n")
-
-          // Post comment on PR
-          if (context.payload.pull_request) {
-              github.rest.issues.createComment({
-                  issue_number: context.issue.number,
-                  owner: context.repo.owner,
-                  repo: context.repo.repo,
-                  body: body,
-              })
-          } else {
-              // Just log the summary if this is a push to master
-              console.log(body)
-          }
+          const script = require('.github/scripts/post-eval-summary.js')
+          script({ github, context, fs })

@@ -29,13 +29,13 @@ jobs:
           group: 1
           token: ${{ secrets.POSTHOG_BOT_PAT }}
           python-version: '3.12.11'
-          clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+          clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
           segment: 'FOSS'
           person-on-events: false

    - name: Upload updated timing data as artifacts
      uses: actions/upload-artifact@v4
-     if: ${{ inputs.person-on-events != 'true' && inputs.clickhouse-server-image == 'clickhouse/clickhouse-server:25.6.9.98' }}
+     if: ${{ inputs.person-on-events != 'true' && inputs.clickhouse-server-image == 'clickhouse/clickhouse-server:25.8.11.66' }}
      with:
        name: timing_data-${{ inputs.segment }}-${{ inputs.group }}
        path: .test_durations
78 .github/workflows/ci-backend.yml vendored
@@ -389,7 +389,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
         python-version: ['3.12.11']
-        clickhouse-server-image: ['clickhouse/clickhouse-server:25.6.9.98']
+        clickhouse-server-image: ['clickhouse/clickhouse-server:25.8.11.66']
         segment: ['Core']
         person-on-events: [false]
         # :NOTE: Keep concurrency and groups in sync

@@ -440,121 +440,121 @@ jobs:
         include:
           - segment: 'Core'
             person-on-events: true
-            clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+            clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
             python-version: '3.12.11'
             concurrency: 10
             group: 1
           - segment: 'Core'
             person-on-events: true
-            clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+            clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
             python-version: '3.12.11'
             concurrency: 10
             group: 2
           - segment: 'Core'
             person-on-events: true
-            clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+            clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
             python-version: '3.12.11'
             concurrency: 10
             group: 3
           - segment: 'Core'
             person-on-events: true
-            clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+            clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
             python-version: '3.12.11'
             concurrency: 10
             group: 4
           - segment: 'Core'
             person-on-events: true
-            clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+            clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
             python-version: '3.12.11'
             concurrency: 10
             group: 5
           - segment: 'Core'
             person-on-events: true
-            clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+            clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
             python-version: '3.12.11'
             concurrency: 10
             group: 6
           - segment: 'Core'
             person-on-events: true
-            clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+            clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
             python-version: '3.12.11'
             concurrency: 10
             group: 7
           - segment: 'Core'
             person-on-events: true
-            clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+            clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
             python-version: '3.12.11'
             concurrency: 10
             group: 8
           - segment: 'Core'
             person-on-events: true
-            clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+            clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
             python-version: '3.12.11'
             concurrency: 10
             group: 9
           - segment: 'Core'
             person-on-events: true
-            clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+            clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
             python-version: '3.12.11'
             concurrency: 10
             group: 10
           - segment: 'Temporal'
             person-on-events: false
-            clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+            clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
             python-version: '3.12.11'
             concurrency: 10
             group: 1
           - segment: 'Temporal'
             person-on-events: false
-            clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+            clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
             python-version: '3.12.11'
             concurrency: 10
             group: 2
           - segment: 'Temporal'
             person-on-events: false
-            clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+            clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
             python-version: '3.12.11'
             concurrency: 10
             group: 3
           - segment: 'Temporal'
             person-on-events: false
-            clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+            clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
             python-version: '3.12.11'
             concurrency: 10
             group: 4
           - segment: 'Temporal'
             person-on-events: false
-            clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+            clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
             python-version: '3.12.11'
             concurrency: 10
             group: 5
           - segment: 'Temporal'
             person-on-events: false
-            clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+            clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
             python-version: '3.12.11'
             concurrency: 10
             group: 6
           - segment: 'Temporal'
             person-on-events: false
-            clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+            clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
             python-version: '3.12.11'
             concurrency: 10
             group: 7
           - segment: 'Temporal'
             person-on-events: false
-            clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+            clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
             python-version: '3.12.11'
             concurrency: 10
             group: 8
           - segment: 'Temporal'
             person-on-events: false
-            clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+            clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
             python-version: '3.12.11'
             concurrency: 10
             group: 9
           - segment: 'Temporal'
             person-on-events: false
-            clickhouse-server-image: 'clickhouse/clickhouse-server:25.6.9.98'
+            clickhouse-server-image: 'clickhouse/clickhouse-server:25.8.11.66'
             python-version: '3.12.11'
             concurrency: 10
             group: 10

@@ -649,6 +649,22 @@ jobs:
       run: |
         sudo apt-get update && sudo apt-get install libxml2-dev libxmlsec1-dev libxmlsec1-openssl

+    - name: Install Rust
+      if: needs.changes.outputs.backend == 'true'
+      uses: dtolnay/rust-toolchain@6691ebadcb18182cc1391d07c9f295f657c593cd # 1.88
+      with:
+        toolchain: 1.88.0
+        components: cargo
+
+    - name: Cache Rust dependencies
+      if: needs.changes.outputs.backend == 'true'
+      uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1
+
+    - name: Install sqlx-cli
+      if: needs.changes.outputs.backend == 'true'
+      run: |
+        cargo install sqlx-cli --version 0.8.0 --features postgres --no-default-features --locked
+
     - name: Determine if hogql-parser has changed compared to master
       shell: bash
       id: hogql-parser-diff

@@ -796,7 +812,8 @@ jobs:
      shell: bash
      env:
        AWS_S3_ALLOW_UNSAFE_RENAME: 'true'
+       RUNLOOP_API_KEY: ${{ needs.changes.outputs.tasks_temporal == 'true' && secrets.RUNLOOP_API_KEY || '' }}
        MODAL_TOKEN_ID: ${{ needs.changes.outputs.tasks_temporal == 'true' && secrets.MODAL_TOKEN_ID || '' }}
        MODAL_TOKEN_SECRET: ${{ needs.changes.outputs.tasks_temporal == 'true' && secrets.MODAL_TOKEN_SECRET || '' }}
      run: |
        set +e
        pytest posthog/temporal products/batch_exports/backend/tests/temporal products/tasks/backend/temporal -m "not async_migrations" \

@@ -821,7 +838,7 @@ jobs:

    - name: Upload updated timing data as artifacts
      uses: actions/upload-artifact@v4
-     if: ${{ needs.changes.outputs.backend == 'true' && !matrix.person-on-events && matrix.clickhouse-server-image == 'clickhouse/clickhouse-server:25.6.9.98' }}
+     if: ${{ needs.changes.outputs.backend == 'true' && !matrix.person-on-events && matrix.clickhouse-server-image == 'clickhouse/clickhouse-server:25.8.11.66' }}
      with:
        name: timing_data-${{ matrix.segment }}-${{ matrix.group }}
        path: .test_durations

@@ -906,7 +923,7 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
-       clickhouse-server-image: ['clickhouse/clickhouse-server:25.6.9.98']
+       clickhouse-server-image: ['clickhouse/clickhouse-server:25.8.11.66']
    if: needs.changes.outputs.backend == 'true'
    runs-on: ubuntu-latest
    steps:

@@ -945,6 +962,19 @@ jobs:
        sudo apt-get update
        sudo apt-get install libxml2-dev libxmlsec1-dev libxmlsec1-openssl

+    - name: Install Rust
+      uses: dtolnay/rust-toolchain@6691ebadcb18182cc1391d07c9f295f657c593cd # 1.88
+      with:
+        toolchain: 1.88.0
+        components: cargo
+
+    - name: Cache Rust dependencies
+      uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1
+
+    - name: Install sqlx-cli
+      run: |
+        cargo install sqlx-cli --version 0.8.0 --features postgres --no-default-features --locked
+
     - name: Install python dependencies
       shell: bash
       run: |
2 .github/workflows/ci-dagster.yml vendored
@@ -65,7 +65,7 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
-       clickhouse-server-image: ['clickhouse/clickhouse-server:25.6.9.98']
+       clickhouse-server-image: ['clickhouse/clickhouse-server:25.8.11.66']
    if: needs.changes.outputs.dagster == 'true'
    runs-on: depot-ubuntu-latest
    steps:
39 .github/workflows/ci-e2e-playwright.yml vendored
@@ -43,9 +43,9 @@ env:

 concurrency:
   group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
-  # This is so that the workflow run isn't canceled when a screenshot update is pushed within it by posthog-bot
-  # We do however cancel from container-images-ci.yml if a commit is pushed by someone OTHER than posthog-bot
-  cancel-in-progress: false
+  # Cancel in-progress runs when new commits are pushed
+  # SHA verification ensures we don't commit stale snapshots if runs aren't cancelled in time
+  cancel-in-progress: true

 jobs:
   changes:

@@ -120,7 +120,7 @@ jobs:
    steps:
      - uses: actions/checkout@v4
        with:
-         ref: ${{ github.event.pull_request.head.ref }}
+         ref: ${{ github.event.pull_request.head.sha }}
          repository: ${{ github.event.pull_request.head.repo.full_name }}
          token: ${{ secrets.POSTHOG_BOT_PAT || github.token }}
          fetch-depth: 50 # Need enough history for flap detection to find last human commit

@@ -134,7 +134,7 @@ jobs:
      - name: Stop/Start stack with Docker Compose
        shell: bash
        run: |
-         export CLICKHOUSE_SERVER_IMAGE=clickhouse/clickhouse-server:25.6.9.98
+         export CLICKHOUSE_SERVER_IMAGE=clickhouse/clickhouse-server:25.8.11.66
          export DOCKER_REGISTRY_PREFIX="us-east1-docker.pkg.dev/posthog-301601/mirror/"
          cp posthog/user_scripts/latest_user_defined_function.xml docker/clickhouse/user_defined_function.xml

@@ -495,7 +495,7 @@ jobs:
    steps:
      - uses: actions/checkout@v4
        with:
-         ref: ${{ github.event.pull_request.head.ref }}
+         ref: ${{ github.event.pull_request.head.sha }}
          token: ${{ secrets.POSTHOG_BOT_PAT || github.token }}

      - name: Download screenshot patches

@@ -535,8 +535,33 @@ jobs:
          workflow-type: playwright
          patch-path: /tmp/screenshot-patches/
          snapshot-path: playwright/
-         commit-message: Update E2E screenshots (Playwright)
+         commit-message: 'test(e2e): update screenshots'
          pr-number: ${{ github.event.pull_request.number }}
          repository: ${{ github.repository }}
+         commit-sha: ${{ github.event.pull_request.head.sha }}
+         branch-name: ${{ github.event.pull_request.head.ref }}
          github-token: ${{ secrets.POSTHOG_BOT_PAT || github.token }}

+  # Job to collate the status of the matrix jobs for requiring passing status
+  # Must depend on handle-screenshots to prevent auto-merge before commits complete
+  playwright_tests:
+    needs: [playwright, handle-screenshots]
+    name: Playwright tests pass
+    runs-on: ubuntu-latest
+    if: always()
+    steps:
+      - name: Check matrix outcome
+        run: |
+          # Check playwright matrix result
+          if [[ "${{ needs.playwright.result }}" != "success" && "${{ needs.playwright.result }}" != "skipped" ]]; then
+            echo "One or more jobs in the Playwright test matrix failed."
+            exit 1
+          fi
+
+          # Check handle-screenshots result (OK if skipped, but fail if it failed)
+          if [[ "${{ needs.handle-screenshots.result }}" == "failure" ]]; then
+            echo "Screenshot commit job failed."
+            exit 1
+          fi
+
+          echo "All jobs passed or were skipped successfully."
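Checking out `head.sha` instead of `head.ref` pins every job in the run to the exact commit that triggered it, even if the branch moves underneath. The local equivalent is a detached checkout — the SHA below is a placeholder:

```bash
# Detached checkout of the exact tested commit; the branch can keep
# advancing without affecting this working tree.
git fetch origin 0123456abcdef
git checkout --detach 0123456abcdef
```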
41 .github/workflows/ci-mcp.yml vendored
@@ -26,47 +26,6 @@ jobs:
            - '.github/workflows/mcp-ci.yml'
            - '.github/workflows/mcp-publish.yml'

-  lint-and-format:
-    name: Lint, Format, and Type Check
-    runs-on: ubuntu-latest
-    needs: changes
-    if: needs.changes.outputs.mcp == 'true'
-    permissions:
-      contents: read
-
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@v4
-
-      - name: Install pnpm
-        uses: pnpm/action-setup@a7487c7e89a18df4991f7f222e4898a00d66ddda # v4
-
-      - name: Setup Node.js
-        uses: actions/setup-node@v4
-        with:
-          node-version: 22
-          cache: 'pnpm'
-
-      - name: Install dependencies
-        run: pnpm install
-
-      - name: Run linter
-        run: cd products/mcp && pnpm run lint
-
-      - name: Run formatter
-        run: cd products/mcp && pnpm run format
-
-      - name: Run type check
-        run: cd products/mcp && pnpm run typecheck
-
-      - name: Check for changes
-        run: |
-          if [ -n "$(git status --porcelain)" ]; then
-            echo "Code formatting or linting changes detected!"
-            git diff
-            exit 1
-          fi
-
   unit-tests:
     name: Unit Tests
     runs-on: ubuntu-latest
5 .github/workflows/ci-python.yml vendored
@@ -135,11 +135,6 @@ jobs:
      run: |
        mypy --version && mypy --cache-fine-grained . | mypy-baseline filter || (echo "run 'pnpm run mypy-baseline-sync' to update the baseline" && exit 1)

-    - name: Check if "schema.py" is up to date
-      shell: bash
-      run: |
-        npm run schema:build:python && git diff --exit-code
-
    - name: Check hogli manifest completeness
      shell: bash
      run: |
34 .github/workflows/ci-storybook.yml vendored
@@ -4,9 +4,9 @@ on:

 concurrency:
   group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
-  # This is so that the workflow run isn't canceled when a snapshot update is pushed within it by posthog-bot
-  # We do however cancel from container-images-ci.yml if a commit is pushed by someone OTHER than posthog-bot
-  cancel-in-progress: false
+  # Cancel in-progress runs when new commits are pushed
+  # SHA verification ensures we don't commit stale snapshots if runs aren't cancelled in time
+  cancel-in-progress: true

 permissions:
   contents: write

@@ -48,7 +48,7 @@ jobs:
    steps:
      - uses: actions/checkout@v4
        with:
-         ref: ${{ github.event.pull_request.head.ref }}
+         ref: ${{ github.event.pull_request.head.sha }}
          repository: ${{ github.event.pull_request.head.repo.full_name }}
          # Use PostHog Bot token when not on forks to enable proper snapshot updating
          token: ${{ secrets.POSTHOG_BOT_PAT || github.token }}

@@ -124,7 +124,7 @@ jobs:
    steps:
      - uses: actions/checkout@v4
        with:
-         ref: ${{ github.event.pull_request.head.ref }}
+         ref: ${{ github.event.pull_request.head.sha }}
          repository: ${{ github.event.pull_request.head.repo.full_name }}
          # Use PostHog Bot token when not on forks to enable proper snapshot updating
          token: ${{ secrets.POSTHOG_BOT_PAT || github.token }}

@@ -232,7 +232,7 @@ jobs:
    steps:
      - uses: actions/checkout@v4
        with:
-         ref: ${{ github.event.pull_request.head.ref }}
+         ref: ${{ github.event.pull_request.head.sha }}
          repository: ${{ github.event.pull_request.head.repo.full_name }}
          # Use PostHog Bot token when not on forks to enable proper snapshot updating
          token: ${{ secrets.POSTHOG_BOT_PAT || github.token }}

@@ -339,7 +339,7 @@ jobs:

          if [ $EXIT_CODE -ne 0 ]; then
            # Still failing - check if we should update snapshots
-           if [ "${{ needs.detect-snapshot-mode.outputs.mode }}" == "update" ] && grep -Eqi "(snapshot.*(failed|was not written)|Expected image to be the same size)" /tmp/test-output-${{ matrix.browser }}-${{ matrix.shard }}-attempt2.log 2>/dev/null; then
+           if [ "${{ needs.detect-snapshot-mode.outputs.mode }}" == "update" ] && grep -Eqi "(snapshot.*(failed|was not written)|Expected image to)" /tmp/test-output-${{ matrix.browser }}-${{ matrix.shard }}-attempt2.log 2>/dev/null; then
              echo "Snapshot failure confirmed on attempt 2 - updating snapshots"
              # Attempt 3: Update snapshots
              echo "Attempt 3: Running @storybook/test-runner (update)"

@@ -413,21 +413,28 @@ jobs:
          if-no-files-found: ignore

   # Job to collate the status of the matrix jobs for requiring passing status
+  # Must depend on handle-snapshots to prevent auto-merge before commits complete
   visual_regression_tests:
-    needs: [visual-regression]
+    needs: [visual-regression, handle-snapshots]
     name: Visual regression tests pass
     runs-on: ubuntu-latest
     if: always()
     steps:
       - name: Check matrix outcome
         run: |
-          # The `needs.visual-regression.result` will be 'success' only if all jobs in the matrix succeeded.
-          # Otherwise, it will be 'failure'.
+          # Check visual-regression matrix result
           if [[ "${{ needs.visual-regression.result }}" != "success" && "${{ needs.visual-regression.result }}" != "skipped" ]]; then
            echo "One or more jobs in the visual-regression test matrix failed."
            exit 1
          fi
-          echo "All jobs in the visual-regression test matrix passed."
+
+          # Check handle-snapshots result (OK if skipped, but fail if it failed)
+          if [[ "${{ needs.handle-snapshots.result }}" == "failure" ]]; then
+            echo "Snapshot commit job failed."
+            exit 1
+          fi
+
+          echo "All jobs passed or were skipped successfully."

   handle-snapshots:
     name: Handle snapshot changes

@@ -437,7 +444,7 @@ jobs:
    steps:
      - uses: actions/checkout@v4
        with:
-         ref: ${{ github.event.pull_request.head.ref }}
+         ref: ${{ github.event.pull_request.head.sha }}
          repository: ${{ github.event.pull_request.head.repo.full_name }}
          token: ${{ secrets.POSTHOG_BOT_PAT || github.token }}
          fetch-depth: 1

@@ -479,10 +486,11 @@ jobs:
          workflow-type: storybook
          patch-path: /tmp/snapshot-patches/
          snapshot-path: frontend/__snapshots__/
-         commit-message: Update UI snapshots (all browsers)
+         commit-message: 'test(storybook): update UI snapshots'
          pr-number: ${{ github.event.pull_request.number }}
          repository: ${{ github.repository }}
+         commit-sha: ${{ github.event.pull_request.head.sha }}
+         branch-name: ${{ github.event.pull_request.head.ref }}
          github-token: ${{ secrets.POSTHOG_BOT_PAT || github.token }}

   calculate-running-time:
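The loosened grep pattern ("Expected image to" rather than "Expected image to be the same size") also catches Playwright's other image-comparison messages. A quick way to exercise the pattern against sample output — the log lines here are made up:

```bash
# Feed hypothetical test-runner log lines through the relaxed pattern;
# grep -q exits 0 if any line matches.
printf '%s\n' \
  "Error: Expected image to be the same size as the snapshot" \
  "Error: Expected image to match snapshot within tolerance" \
  | grep -Eqi "(snapshot.*(failed|was not written)|Expected image to)" && echo "pattern matches"
```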
2 .github/workflows/mcp-publish.yml vendored
@@ -23,7 +23,7 @@ jobs:
      with:
        filters: |
          mcp:
-           - 'mcp/**'
+           - 'products/mcp/**'
            - '.github/workflows/mcp-ci.yml'
            - '.github/workflows/mcp-publish.yml'
@@ -16,7 +16,10 @@
        "common/plugin_transpiler/dist",
-       "common/hogvm/__tests__/__snapshots__",
+       "common/hogvm/__tests__/__snapshots__/**",
-       "frontend/src/queries/validators.js"
+       "frontend/src/queries/validators.js",
+       "products/mcp/**/generated.ts",
+       "products/mcp/schema/tool-inputs.json",
+       "products/mcp/python"
    ],
    "rules": {
        "no-constant-condition": "off",

@@ -6,6 +6,7 @@ staticfiles
 .env
 *.code-workspace
 .mypy_cache
+.cache/
 *Type.ts
 .idea
 .yalc

@@ -15,8 +16,7 @@ common/storybook/dist/
 dist/
 node_modules/
 pnpm-lock.yaml
-posthog/templates/email/*
-posthog/templates/**/*.html
+posthog/templates/**/*
 common/hogvm/typescript/src/stl/bytecode.ts
 common/hogvm/__tests__/__snapshots__/*
 rust/

@@ -26,3 +26,10 @@ cli/tests/_cases/**/*
 frontend/src/products.tsx
 frontend/src/layout.html
 **/fixtures/**
+products/mcp/schema/**/*
+products/mcp/python/**/*
+products/mcp/**/generated.ts
+products/mcp/typescript/worker-configuration.d.ts
+frontend/src/taxonomy/core-filter-definitions-by-group.json
+frontend/dist/
+frontend/**/*LogicType.ts

@@ -14,6 +14,7 @@
    "^@posthog.*$",
    "^lib/(.*)$|^scenes/(.*)$",
    "^~/(.*)$",
+   "^@/(.*)$",
    "^public/(.*)$",
    "^products/(.*)$",
    "^storybook/(.*)$",
@@ -329,8 +329,7 @@ COPY --from=frontend-build --chown=posthog:posthog /code/frontend/dist /code/fro
 # Copy the GeoLite2-City database from the fetch-geoip-db stage.
 COPY --from=fetch-geoip-db --chown=posthog:posthog /code/share/GeoLite2-City.mmdb /code/share/GeoLite2-City.mmdb

-# Add in the Gunicorn config, custom bin files and Django deps.
-COPY --chown=posthog:posthog gunicorn.config.py ./
+# Add in custom bin files and Django deps.
 COPY --chown=posthog:posthog ./bin ./bin/
 COPY --chown=posthog:posthog manage.py manage.py
 COPY --chown=posthog:posthog posthog posthog/

@@ -339,9 +338,6 @@ COPY --chown=posthog:posthog common/hogvm common/hogvm/
 COPY --chown=posthog:posthog dags dags/
 COPY --chown=posthog:posthog products products/

-# Keep server command backwards compatible
-RUN cp ./bin/docker-server-unit ./bin/docker-server
-
 # Setup ENV.
 ENV NODE_ENV=production \
     CHROME_BIN=/usr/bin/chromium \
@@ -1,12 +1,14 @@
 -- temporary sql to initialise log tables for local development
 -- will be removed once we have migrations set up
+CREATE OR REPLACE FUNCTION extractIPv4Substrings AS
+(
+    body -> extractAll(body, '(\d\.((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){2,2}([0-9]))')
+);
+
-CREATE TABLE if not exists logs27
+CREATE OR REPLACE TABLE logs31
 (
-    `time_bucket` DateTime MATERIALIZED toStartOfInterval(timestamp, toIntervalMinute(5)) CODEC(DoubleDelta, ZSTD(1)),
+    -- time bucket is set to day which means it's effectively not in the order by key (same as partition)
+    -- but gives us flexibility to add the bucket to the order key if needed
+    `time_bucket` DateTime MATERIALIZED toStartOfDay(timestamp) CODEC(DoubleDelta, ZSTD(1)),
     `uuid` String CODEC(ZSTD(1)),
     `team_id` Int32 CODEC(ZSTD(1)),
     `trace_id` String CODEC(ZSTD(1)),

@@ -20,32 +22,22 @@ CREATE TABLE if not exists logs27
     `severity_number` Int32 CODEC(ZSTD(1)),
     `service_name` String CODEC(ZSTD(1)),
     `resource_attributes` Map(String, String) CODEC(ZSTD(1)),
+    `resource_fingerprint` UInt64 MATERIALIZED cityHash64(resource_attributes) CODEC(DoubleDelta, ZSTD(1)),
-    `resource_id` String CODEC(ZSTD(1)),
     `instrumentation_scope` String CODEC(ZSTD(1)),
     `event_name` String CODEC(ZSTD(1)),
-    `attributes_map_str` Map(String, String) CODEC(ZSTD(1)),
-    `attributes_map_float` Map(String, Float64) CODEC(ZSTD(1)),
-    `attributes_map_datetime` Map(String, DateTime64(6)) CODEC(ZSTD(1)),
+    `attributes_map_str` Map(LowCardinality(String), String) CODEC(ZSTD(1)),
+    `attributes_map_float` Map(LowCardinality(String), Float64) CODEC(ZSTD(1)),
+    `attributes_map_datetime` Map(LowCardinality(String), DateTime64(6)) CODEC(ZSTD(1)),
     `level` String ALIAS severity_text,
     `mat_body_ipv4_matches` Array(String) ALIAS extractAll(body, '(\\d\\.((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\\.){2,2}([0-9]))'),
     `time_minute` DateTime ALIAS toStartOfMinute(timestamp),
-    `attributes` Map(String, String) ALIAS mapApply((k, v) -> (k, toJSONString(v)), attributes_map_str),
+    `attributes` Map(String, String) ALIAS mapApply((k, v) -> (left(k, -5), v), attributes_map_str),
     INDEX idx_severity_text_set severity_text TYPE set(10) GRANULARITY 1,
     INDEX idx_attributes_str_keys mapKeys(attributes_map_str) TYPE bloom_filter(0.01) GRANULARITY 1,
     INDEX idx_attributes_str_values mapValues(attributes_map_str) TYPE bloom_filter(0.001) GRANULARITY 1,
     INDEX idx_mat_body_ipv4_matches mat_body_ipv4_matches TYPE bloom_filter(0.01) GRANULARITY 1,
-    INDEX idx_body_ngram3 body TYPE ngrambf_v1(3, 20000, 2, 0) GRANULARITY 1,
-    CONSTRAINT assume_time_bucket ASSUME toStartOfInterval(timestamp, toIntervalMinute(5)) = time_bucket,
-    PROJECTION projection_trace_span
-    (
-        SELECT
-            trace_id,
-            timestamp,
-            _part_offset
-        ORDER BY
-            trace_id,
-            timestamp
-    ),
+    INDEX idx_body_ngram3 body TYPE ngrambf_v1(3, 25000, 2, 0) GRANULARITY 1,
     PROJECTION projection_aggregate_counts
     (
         SELECT

@@ -54,8 +46,7 @@ CREATE TABLE if not exists logs27
             toStartOfMinute(timestamp),
             service_name,
             severity_text,
-            resource_attributes,
-            resource_id,
+            resource_fingerprint,
             count() AS event_count
         GROUP BY
             team_id,

@@ -63,43 +54,47 @@ CREATE TABLE if not exists logs27
             toStartOfMinute(timestamp),
             service_name,
             severity_text,
-            resource_attributes,
-            resource_id
+            resource_fingerprint
     )
 )
-ENGINE = ReplicatedReplacingMergeTree('/clickhouse/tables/{shard}/logs27', '{replica}')
+ENGINE = MergeTree
 PARTITION BY toDate(timestamp)
-ORDER BY (team_id, time_bucket DESC, service_name, resource_attributes, severity_text, timestamp DESC, uuid, trace_id, span_id)
-SETTINGS allow_remote_fs_zero_copy_replication = 1,
-    allow_experimental_reverse_key = 1,
-    deduplicate_merge_projection_mode = 'ignore';
+PRIMARY KEY (team_id, time_bucket, service_name, resource_fingerprint, severity_text, timestamp)
+ORDER BY (team_id, time_bucket, service_name, resource_fingerprint, severity_text, timestamp)
+SETTINGS
+    index_granularity_bytes = 104857600,
+    index_granularity = 8192,
+    ttl_only_drop_parts = 1;

-create or replace TABLE logs AS logs27 ENGINE = Distributed('posthog', 'default', 'logs27');
+create or replace TABLE logs AS logs31 ENGINE = Distributed('posthog', 'default', 'logs31');

-create table if not exists log_attributes
+create or replace table default.log_attributes
 (
     `team_id` Int32,
     `time_bucket` DateTime64(0),
     `service_name` LowCardinality(String),
-    `resource_id` String DEFAULT '',
+    `resource_fingerprint` UInt64 DEFAULT 0,
     `attribute_key` LowCardinality(String),
     `attribute_value` String,
     `attribute_count` SimpleAggregateFunction(sum, UInt64),
     INDEX idx_attribute_key attribute_key TYPE bloom_filter(0.01) GRANULARITY 1,
-    INDEX idx_attribute_value attribute_value TYPE bloom_filter(0.001) GRANULARITY 1,
+    INDEX idx_attribute_value attribute_value TYPE bloom_filter(0.01) GRANULARITY 1,
     INDEX idx_attribute_key_n3 attribute_key TYPE ngrambf_v1(3, 32768, 3, 0) GRANULARITY 1,
     INDEX idx_attribute_value_n3 attribute_value TYPE ngrambf_v1(3, 32768, 3, 0) GRANULARITY 1
 )
-ENGINE = ReplicatedAggregatingMergeTree('/clickhouse/tables/{shard}/log_attributes', '{replica}')
+ENGINE = AggregatingMergeTree
 PARTITION BY toDate(time_bucket)
-ORDER BY (team_id, service_name, time_bucket, attribute_key, attribute_value);
+ORDER BY (team_id, time_bucket, resource_fingerprint, attribute_key, attribute_value);

 set enable_dynamic_type=1;
-CREATE MATERIALIZED VIEW if not exists log_to_log_attributes TO log_attributes
+drop view if exists log_to_log_attributes;
+CREATE MATERIALIZED VIEW log_to_log_attributes TO log_attributes
 (
     `team_id` Int32,
     `time_bucket` DateTime64(0),
     `service_name` LowCardinality(String),
+    `resource_fingerprint` UInt64,
     `attribute_key` LowCardinality(String),
     `attribute_value` String,
     `attribute_count` SimpleAggregateFunction(sum, UInt64)

@@ -108,23 +103,28 @@ AS SELECT
     team_id,
     time_bucket,
     service_name,
+    resource_fingerprint,
     attribute_key,
     attribute_value,
     attribute_count
-FROM (select
-    team_id AS team_id,
-    toStartOfInterval(timestamp, toIntervalMinute(10)) AS time_bucket,
-    service_name AS service_name,
-    arrayJoin(arrayMap((k, v) -> (k, if(length(v) > 256, '', v)), arrayFilter((k, v) -> (length(k) < 256), CAST(attributes, 'Array(Tuple(String, String))')))) AS attribute,
-    attribute.1 AS attribute_key,
-    CAST(JSONExtract(attribute.2, 'Dynamic'), 'String') AS attribute_value,
-    sumSimpleState(1) AS attribute_count
-    FROM logs27
-    GROUP BY
-        team_id,
-        time_bucket,
-        service_name,
-        attribute
+FROM
+(
+    SELECT
+        team_id AS team_id,
+        toStartOfInterval(timestamp, toIntervalMinute(10)) AS time_bucket,
+        service_name AS service_name,
+        resource_fingerprint,
+        arrayJoin(mapFilter((k, v) -> ((length(k) < 256) AND (length(v) < 256)), attributes)) AS attribute,
+        attribute.1 AS attribute_key,
+        attribute.2 AS attribute_value,
+        sumSimpleState(1) AS attribute_count
+    FROM logs31
+    GROUP BY
+        team_id,
+        time_bucket,
+        service_name,
+        resource_fingerprint,
+        attribute
 );

@@ -139,11 +139,10 @@ CREATE OR REPLACE TABLE kafka_logs_avro
     `severity_text` String,
     `severity_number` Int32,
     `service_name` String,
-    `resource_attributes` Map(String, String),
-    `resource_id` String,
+    `resource_attributes` Map(LowCardinality(String), String),
     `instrumentation_scope` String,
     `event_name` String,
-    `attributes` Map(String, Nullable(String))
+    `attributes` Map(LowCardinality(String), String)
 )
 ENGINE = Kafka('kafka:9092', 'clickhouse_logs', 'clickhouse-logs-avro', 'Avro')
 SETTINGS

@@ -152,15 +151,14 @@ SETTINGS
     kafka_thread_per_consumer = 1,
     kafka_num_consumers = 1,
     kafka_poll_timeout_ms=15000,
-    kafka_poll_max_batch_size=1,
-    kafka_max_block_size=1;
+    kafka_poll_max_batch_size=10,
+    kafka_max_block_size=10;

 drop table if exists kafka_logs_avro_mv;

-CREATE MATERIALIZED VIEW kafka_logs_avro_mv TO logs27
+CREATE MATERIALIZED VIEW kafka_logs_avro_mv TO logs31
 (
     `uuid` String,
     `team_id` Int32,
     `trace_id` String,
     `span_id` String,
     `trace_flags` Int32,

@@ -170,11 +168,10 @@ CREATE MATERIALIZED VIEW kafka_logs_avro_mv TO logs27
     `severity_text` String,
     `severity_number` Int32,
     `service_name` String,
-    `resource_attributes` Map(String, String),
-    `resource_id` String,
+    `resource_attributes` Map(LowCardinality(String), String),
     `instrumentation_scope` String,
     `event_name` String,
-    `attributes` Map(String, Nullable(String))
+    `attributes` Map(LowCardinality(String), String)
 )
 AS SELECT
     * except (attributes, resource_attributes),
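The logs31 schema leans on MATERIALIZED and ALIAS columns (time_bucket, resource_fingerprint, level, time_minute), so queries can filter and group on them as if they were stored columns. A hedged example against the Distributed `logs` table — the team_id and time range are placeholders:

```bash
# Sketch of querying the new schema via clickhouse-client. `level` is an
# ALIAS for severity_text and time_minute is an ALIAS over timestamp, so
# both are queryable like ordinary columns.
clickhouse-client --query "
    SELECT time_minute, level, count() AS n
    FROM logs
    WHERE team_id = 1
      AND time_bucket >= toStartOfDay(now() - INTERVAL 1 DAY)
    GROUP BY time_minute, level
    ORDER BY time_minute
"
```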
@@ -1,27 +1,4 @@
 #!/bin/bash
 set -e

-./bin/migrate-check
-
-# To ensure we are able to expose metrics from multiple processes, we need to
-# provide a directory for `prometheus_client` to store a shared registry.
-export PROMETHEUS_MULTIPROC_DIR=$(mktemp -d)
-trap 'rm -rf "$PROMETHEUS_MULTIPROC_DIR"' EXIT
-
-export PROMETHEUS_METRICS_EXPORT_PORT=8001
-export STATSD_PORT=${STATSD_PORT:-8125}
-
-exec gunicorn posthog.wsgi \
-    --config gunicorn.config.py \
-    --bind 0.0.0.0:8000 \
-    --log-file - \
-    --log-level info \
-    --access-logfile - \
-    --worker-tmp-dir /dev/shm \
-    --workers=2 \
-    --threads=8 \
-    --keep-alive=60 \
-    --backlog=${GUNICORN_BACKLOG:-1000} \
-    --worker-class=gthread \
-    ${STATSD_HOST:+--statsd-host $STATSD_HOST:$STATSD_PORT} \
-    --limit-request-line=16384 $@
+# Wrapper script for backward compatibility
+# Calls the dual-mode server script (Granian/Unit)
+exec "$(dirname "$0")/docker-server-unit" "$@"
@@ -7,15 +7,48 @@ set -e
 # provide a directory for `prometheus_client` to store a shared registry.
 export PROMETHEUS_MULTIPROC_DIR=$(mktemp -d)
 chmod -R 777 $PROMETHEUS_MULTIPROC_DIR
-trap 'rm -rf "$PROMETHEUS_MULTIPROC_DIR"' EXIT

 export PROMETHEUS_METRICS_EXPORT_PORT=8001
 export STATSD_PORT=${STATSD_PORT:-8125}
-export NGINX_UNIT_PYTHON_PROTOCOL=${NGINX_UNIT_PYTHON_PROTOCOL:-wsgi}
-export NGINX_UNIT_APP_PROCESSES=${NGINX_UNIT_APP_PROCESSES:-4}
-envsubst < /docker-entrypoint.d/unit.json.tpl > /docker-entrypoint.d/unit.json

-# We need to run as --user root so that nginx unit can proxy the control socket for stats
-# However each application is run as "nobody"
-exec /usr/local/bin/docker-entrypoint.sh unitd --no-daemon --user root
+# Dual-mode support: USE_GRANIAN env var switches between Granian and Unit (default)
+if [ "${USE_GRANIAN:-false}" = "true" ]; then
+    echo "🚀 Starting with Granian ASGI server (opt-in via USE_GRANIAN=true)..."
+
+    # Granian configuration
+    export GRANIAN_WORKERS=${GRANIAN_WORKERS:-4}
+    export GRANIAN_THREADS=2
+
+    # Start metrics HTTP server in background on port 8001
+    python ./bin/granian_metrics.py &
+    METRICS_PID=$!
+
+    # Combined trap: kill metrics server and cleanup temp directory
+    trap 'kill $METRICS_PID 2>/dev/null; rm -rf "$PROMETHEUS_MULTIPROC_DIR"' EXIT
+
+    exec granian \
+        --interface asgi \
+        posthog.asgi:application \
+        --workers $GRANIAN_WORKERS \
+        --runtime-threads $GRANIAN_THREADS \
+        --runtime-mode mt \
+        --loop uvloop \
+        --host 0.0.0.0 \
+        --port 8000 \
+        --log-level warning \
+        --access-log \
+        --respawn-failed-workers
+else
+    echo "🔧 Starting with Nginx Unit server (default, stable)..."
+
+    # Cleanup temp directory on exit
+    trap 'rm -rf "$PROMETHEUS_MULTIPROC_DIR"' EXIT
+
+    export NGINX_UNIT_PYTHON_PROTOCOL=${NGINX_UNIT_PYTHON_PROTOCOL:-wsgi}
+    export NGINX_UNIT_APP_PROCESSES=${NGINX_UNIT_APP_PROCESSES:-4}
+    envsubst < /docker-entrypoint.d/unit.json.tpl > /docker-entrypoint.d/unit.json
+
+    # We need to run as --user root so that nginx unit can proxy the control socket for stats
+    # However each application is run as "nobody"
+    exec /usr/local/bin/docker-entrypoint.sh unitd --no-daemon --user root
+fi
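Because the switch is a single environment variable, trying the Granian path locally is one flag on the container — a sketch, with the image name and tag as placeholder values:

```bash
# Opt into the Granian server path; the image/tag here are hypothetical.
docker run -e USE_GRANIAN=true -e GRANIAN_WORKERS=2 -p 8000:8000 posthog/posthog:latest
```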
66 bin/granian_metrics.py Executable file
@@ -0,0 +1,66 @@
#!/usr/bin/env python3
"""
Metrics HTTP server for Granian multi-process setup.

Serves Prometheus metrics on port 8001 (configurable via PROMETHEUS_METRICS_EXPORT_PORT).
Aggregates metrics from all Granian worker processes using prometheus_client multiprocess mode.
Exposes Granian-equivalent metrics to maintain dashboard compatibility with previous Gunicorn setup.
"""

import os
import time
import logging

from prometheus_client import CollectorRegistry, Gauge, multiprocess, start_http_server

logger = logging.getLogger(__name__)


def create_granian_metrics(registry: CollectorRegistry) -> None:
    """
    Create Granian-equivalent metrics to maintain compatibility with existing Grafana dashboards.

    These metrics mirror the Gunicorn metrics that were previously exposed, allowing existing
    dashboards and alerts to continue working with minimal changes.
    """
    # Read Granian configuration from environment
    workers = int(os.environ.get("GRANIAN_WORKERS", 4))
    threads = int(os.environ.get("GRANIAN_THREADS", 2))

    # Expose static configuration as gauges
    # These provide equivalent metrics to what gunicorn/unit previously exposed
    max_worker_threads = Gauge(
        "granian_max_worker_threads",
        "Maximum number of threads per worker",
        registry=registry,
    )
    max_worker_threads.set(threads)

    total_workers = Gauge(
        "granian_workers_total",
        "Total number of Granian workers configured",
        registry=registry,
    )
    total_workers.set(workers)


def main():
    """Start HTTP server to expose Prometheus metrics from all workers."""
    registry = CollectorRegistry()
    multiprocess.MultiProcessCollector(registry)

    # Create Granian-specific metrics for dashboard compatibility
    create_granian_metrics(registry)

    port = int(os.environ.get("PROMETHEUS_METRICS_EXPORT_PORT", 8001))

    logger.info(f"Starting Prometheus metrics server on port {port}")
    start_http_server(port=port, registry=registry)

    # Keep the server running
    while True:
        time.sleep(3600)


if __name__ == "__main__":
    main()
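Once the metrics server is up, a quick sanity check is to hit the endpoint and look for the two gauges defined above (port 8001 is the PROMETHEUS_METRICS_EXPORT_PORT default):

```bash
# prometheus_client's start_http_server serves the registry at /metrics;
# the grep pattern is just the two gauge names from granian_metrics.py.
curl -s http://localhost:8001/metrics | grep -E '^granian_(workers_total|max_worker_threads)'
```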
@@ -61,5 +61,9 @@ procs:
     shell: 'pnpm --filter=@posthog/storybook install && pnpm run storybook'
     autostart: false

+  hedgebox-dummy:
+    shell: 'bin/check_postgres_up && cd hedgebox-dummy && pnpm install && pnpm run dev'
+    autostart: false
+
 mouse_scroll_speed: 1
 scrollback: 10000
@@ -25,4 +25,13 @@ else
|
||||
echo "🐧 Linux detected, binding to Docker bridge gateway: $HOST_BIND"
|
||||
fi
|
||||
|
||||
python ${DEBUG:+ -m debugpy --listen 127.0.0.1:5678} -m uvicorn --reload posthog.asgi:application --host $HOST_BIND --log-level debug --reload-include "posthog/" --reload-include "ee/" --reload-include "products/"
|
||||
python ${DEBUG:+ -m debugpy --listen 127.0.0.1:5678} -m granian \
|
||||
--interface asgi \
|
||||
posthog.asgi:application \
|
||||
--reload \
|
||||
--reload-paths ./posthog \
|
||||
--reload-paths ./ee \
|
||||
--reload-paths ./products \
|
||||
--host $HOST_BIND \
|
||||
--log-level debug \
|
||||
--workers 1
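
The `${DEBUG:+ ...}` fragment is plain shell parameter expansion: it expands to the debugpy flags only when `DEBUG` is set and non-empty, so the same line serves both debug and normal startup. A quick illustration of the expansion by itself:

```bash
DEBUG=1
echo python ${DEBUG:+-m debugpy --listen 127.0.0.1:5678} -m granian
# -> python -m debugpy --listen 127.0.0.1:5678 -m granian
unset DEBUG
echo python ${DEBUG:+-m debugpy --listen 127.0.0.1:5678} -m granian
# -> python -m granian
```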

@@ -1,5 +1,10 @@
# posthog-cli

# 0.5.11

- Do not read bundle files as part of hermes sourcemap commands
- Change hermes clone command to take two file paths (for the minified and composed maps respectively)

# 0.5.10

- Add terminal checks for login and query command

@@ -5,7 +5,9 @@ and bump the package version number at the same time.

```bash
git checkout -b "cli/release-v0.1.0-pre1"
# Bump version number in Cargo.toml and build to update Cargo.lock
# Bump version number in Cargo.toml
# Build to update Cargo.lock (cargo build)
# Update the CHANGELOG.md
git add .
git commit -m "Bump version number"
git tag "posthog-cli-v0.1.0-prerelease.1"

2
cli/Cargo.lock
generated
@@ -1521,7 +1521,7 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"

[[package]]
name = "posthog-cli"
version = "0.5.10"
version = "0.5.11"
dependencies = [
    "anyhow",
    "chrono",

@@ -1,6 +1,6 @@
[package]
name = "posthog-cli"
version = "0.5.10"
version = "0.5.11"
authors = [
    "David <david@posthog.com>",
    "Olly <oliver@posthog.com>",

@@ -16,7 +16,7 @@ use crate::{
pub struct SourceMapContent {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub release_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(skip_serializing_if = "Option::is_none", alias = "debugId")]
    pub chunk_id: Option<String>,
    #[serde(flatten)]
    pub fields: BTreeMap<String, Value>,
@@ -49,6 +49,10 @@ impl SourceMapFile {
        self.inner.content.release_id.clone()
    }

    pub fn has_release_id(&self) -> bool {
        self.get_release_id().is_some()
    }

    pub fn apply_adjustment(&mut self, adjustment: SourceMap) -> Result<()> {
        let new_content = {
            let content = serde_json::to_string(&self.inner.content)?.into_bytes();

@@ -1,27 +1,22 @@
use std::path::PathBuf;

use anyhow::{anyhow, bail, Result};
use tracing::{info, warn};
use anyhow::{anyhow, Result};
use tracing::info;

use crate::{
    invocation_context::context,
    sourcemaps::{
        content::SourceMapFile,
        hermes::{get_composed_map, inject::is_metro_bundle},
        inject::get_release_for_pairs,
        source_pairs::read_pairs,
    },
    sourcemaps::{content::SourceMapFile, inject::get_release_for_maps},
};

#[derive(clap::Args)]
pub struct CloneArgs {
    /// The directory containing the bundled chunks
    /// The path of the minified source map
    #[arg(short, long)]
    pub directory: PathBuf,
    pub minified_map_path: PathBuf,

    /// One or more directory glob patterns to ignore
    /// The path of the composed source map
    #[arg(short, long)]
    pub ignore: Vec<String>,
    pub composed_map_path: PathBuf,

    /// The project name associated with the uploaded chunks. Required to have the uploaded chunks associated with
    /// a specific release. We will try to auto-derive this from git information if not provided. Strongly recommended
@@ -39,31 +34,30 @@ pub fn clone(args: &CloneArgs) -> Result<()> {
    context().capture_command_invoked("hermes_clone");

    let CloneArgs {
        directory,
        ignore,
        minified_map_path,
        composed_map_path,
        project,
        version,
    } = args;

    let directory = directory.canonicalize().map_err(|e| {
    let mut minified_map = SourceMapFile::load(minified_map_path).map_err(|e| {
        anyhow!(
            "Directory '{}' not found or inaccessible: {}",
            directory.display(),
            "Failed to load minified map at '{}': {}",
            minified_map_path.display(),
            e
        )
    })?;

    info!("Processing directory: {}", directory.display());
    let pairs = read_pairs(&directory, ignore, is_metro_bundle, &None)?;
    let mut composed_map = SourceMapFile::load(composed_map_path).map_err(|e| {
        anyhow!(
            "Failed to load composed map at '{}': {}",
            composed_map_path.display(),
            e
        )
    })?;

    if pairs.is_empty() {
        bail!("No source files found");
    }

    info!("Found {} pairs", pairs.len());

    let release_id =
        get_release_for_pairs(&directory, project, version, &pairs)?.map(|r| r.id.to_string());
    let release_id = get_release_for_maps(minified_map_path, project, version, [&minified_map])?
        .map(|r| r.id.to_string());

    // The flow here differs from plain sourcemap injection a bit - here, we don't ever
    // overwrite the chunk ID, because at this point in the build process, we no longer
@@ -73,47 +67,26 @@ pub fn clone(args: &CloneArgs) -> Result<()> {
    // tries to run `clone` twice, changing release but not posthog env, we'll error out. The
    // correct way to upload the same set of artefacts to the same posthog env as part of
    // two different releases is, 1, not to, but failing that, 2, to re-run the bundling process
    let mut pairs = pairs;
    for pair in &mut pairs {
        if !pair.has_release_id() || pair.get_release_id() != release_id {
            pair.set_release_id(release_id.clone());
            pair.save()?;
        }
    if !minified_map.has_release_id() || minified_map.get_release_id() != release_id {
        minified_map.set_release_id(release_id.clone());
        minified_map.save()?;
    }
    let pairs = pairs;

    let maps: Result<Vec<(&SourceMapFile, Option<SourceMapFile>)>> = pairs
        .iter()
        .map(|p| get_composed_map(p).map(|c| (&p.sourcemap, c)))
        .collect();

    let maps = maps?;

    for (minified, composed) in maps {
        let Some(mut composed) = composed else {
            warn!(
                "Could not find composed map for minified sourcemap {}",
                minified.inner.path.display()
            );
            continue;
        };

        // Copy metadata from source map to composed map
        if let Some(chunk_id) = minified.get_chunk_id() {
            composed.set_chunk_id(Some(chunk_id));
        }

        if let Some(release_id) = minified.get_release_id() {
            composed.set_release_id(Some(release_id));
        }

        composed.save()?;
        info!(
            "Successfully cloned metadata to {}",
            composed.inner.path.display()
        );
    // Copy metadata from source map to composed map
    if let Some(chunk_id) = minified_map.get_chunk_id() {
        composed_map.set_chunk_id(Some(chunk_id));
    }

    if let Some(release_id) = minified_map.get_release_id() {
        composed_map.set_release_id(Some(release_id));
    }

    composed_map.save()?;
    info!(
        "Successfully cloned metadata to {}",
        composed_map.inner.path.display()
    );

    info!("Finished cloning metadata");
    Ok(())
}

@@ -2,28 +2,23 @@ use std::path::PathBuf;

use anyhow::{anyhow, Ok, Result};
use tracing::{info, warn};
use walkdir::WalkDir;

use crate::api::symbol_sets::{self, SymbolSetUpload};
use crate::invocation_context::context;

use crate::sourcemaps::hermes::get_composed_map;
use crate::sourcemaps::hermes::inject::is_metro_bundle;
use crate::sourcemaps::source_pairs::read_pairs;
use crate::sourcemaps::content::SourceMapFile;

#[derive(clap::Args, Clone)]
pub struct Args {
    /// The directory containing the bundled chunks
    #[arg(short, long)]
    pub directory: PathBuf,

    /// One or more directory glob patterns to ignore
    #[arg(short, long)]
    pub ignore: Vec<String>,
}

pub fn upload(args: &Args) -> Result<()> {
    context().capture_command_invoked("hermes_upload");
    let Args { directory, ignore } = args;
    let Args { directory } = args;

    let directory = directory.canonicalize().map_err(|e| {
        anyhow!(
@@ -34,17 +29,10 @@ pub fn upload(args: &Args) -> Result<()> {
    })?;

    info!("Processing directory: {}", directory.display());
    let pairs = read_pairs(&directory, ignore, is_metro_bundle, &None)?;

    let maps: Result<Vec<_>> = pairs.iter().map(get_composed_map).collect();
    let maps = maps?;
    let maps = read_maps(&directory);

    let mut uploads: Vec<SymbolSetUpload> = Vec::new();
    for map in maps.into_iter() {
        let Some(map) = map else {
            continue;
        };

        if map.get_chunk_id().is_none() {
            warn!("Skipping map {}, no chunk ID", map.inner.path.display());
            continue;
@@ -53,9 +41,22 @@ pub fn upload(args: &Args) -> Result<()> {
        uploads.push(map.try_into()?);
    }

    info!("Found {} bundles to upload", uploads.len());
    info!("Found {} maps to upload", uploads.len());

    symbol_sets::upload(&uploads, 100)?;

    Ok(())
}

fn read_maps(directory: &PathBuf) -> Vec<SourceMapFile> {
    WalkDir::new(directory)
        .into_iter()
        .filter_map(Result::ok)
        .filter(|e| e.file_type().is_file())
        .map(|e| {
            let path = e.path().canonicalize()?;
            SourceMapFile::load(&path)
        })
        .filter_map(Result::ok)
        .collect()
}

@@ -5,7 +5,10 @@ use walkdir::DirEntry;

use crate::{
    api::releases::{Release, ReleaseBuilder},
    sourcemaps::source_pairs::{read_pairs, SourcePair},
    sourcemaps::{
        content::SourceMapFile,
        source_pairs::{read_pairs, SourcePair},
    },
    utils::git::get_git_info,
};

@@ -61,9 +64,14 @@ pub fn inject_impl(args: &InjectArgs, matcher: impl Fn(&DirEntry) -> bool) -> Re
        bail!("no source files found");
    }

    let created_release_id = get_release_for_pairs(&directory, project, version, &pairs)?
        .as_ref()
        .map(|r| r.id.to_string());
    let created_release_id = get_release_for_maps(
        &directory,
        project,
        version,
        pairs.iter().map(|p| &p.sourcemap),
    )?
    .as_ref()
    .map(|r| r.id.to_string());

    pairs = inject_pairs(pairs, created_release_id)?;

@@ -97,16 +105,16 @@ pub fn inject_pairs(
    Ok(pairs)
}

pub fn get_release_for_pairs<'a>(
pub fn get_release_for_maps<'a>(
    directory: &Path,
    project: &Option<String>,
    version: &Option<String>,
    pairs: impl IntoIterator<Item = &'a SourcePair>,
    maps: impl IntoIterator<Item = &'a SourceMapFile>,
) -> Result<Option<Release>> {
    // We need to fetch or create a release if: the user specified one, any pair is missing one, or the user
    // forced release overriding
    let needs_release =
        project.is_some() || version.is_some() || pairs.into_iter().any(|p| !p.has_release_id());
        project.is_some() || version.is_some() || maps.into_iter().any(|p| !p.has_release_id());

    let mut created_release = None;
    if needs_release {

@@ -28,7 +28,7 @@ impl SourcePair {
    }

    pub fn has_release_id(&self) -> bool {
        self.get_release_id().is_some()
        self.sourcemap.has_release_id()
    }

    pub fn get_release_id(&self) -> Option<String> {

@@ -206,6 +206,7 @@ export const commonConfig = {
        '.woff2': 'file',
        '.mp3': 'file',
        '.lottie': 'file',
        '.sql': 'text',
    },
    metafile: true,
}

@@ -86,12 +86,6 @@ core:
    dev:up:
        description: Start full PostHog dev stack via mprocs
        hidden: true
    dev:setup:
        cmd: python manage.py setup_dev
        description: Initialize local development environment (one-time setup)
        services:
            - postgresql
            - clickhouse
    dev:demo-data:
        cmd: python manage.py generate_demo_data
        description: Generate demo data for local testing
@@ -100,14 +94,14 @@ core:
            - clickhouse
    dev:reset:
        steps:
            - docker:services:down
            - docker:services:remove
            - docker:services:up
            - check:postgres
            - check:clickhouse
            - migrations:run
            - dev:demo-data
            - migrations:sync-flags
        description: Full reset - stop services, migrate, load demo data, sync flags
        description: Full reset - wipe volumes, migrate, load demo data, sync flags
health_checks:
    check:clickhouse:
        bin_script: check_clickhouse_up
@@ -272,6 +266,11 @@ docker:
        description: Stop Docker infrastructure services
        services:
            - docker
    docker:services:remove:
        cmd: docker compose -f docker-compose.dev.yml down -v
        description: Stop Docker infrastructure services and remove all volumes (complete wipe)
        services:
            - docker
    docker:deprecated:
        bin_script: docker
        description: '[DEPRECATED] Use `hogli start` instead'
@@ -302,7 +301,7 @@ docker:
        hidden: true
    docker:server:
        bin_script: docker-server
        description: Run gunicorn application server with Prometheus metrics support
        description: Run dual-mode server (Granian/Unit) with Prometheus metrics support
        hidden: true
    docker:server-unit:
        bin_script: docker-server-unit
@@ -433,6 +432,10 @@ tools:
        bin_script: create-notebook-node.sh
        description: Create a new NotebookNode file and update types and editor references
        hidden: true
    tool:granian-metrics:
        bin_script: granian_metrics.py
        description: HTTP server that aggregates Prometheus metrics from Granian workers
        hidden: true
    sync:storage:
        bin_script: sync-storage
        description: 'TODO: add description for sync-storage'

@@ -162,6 +162,11 @@ function createEntry(entry) {
                    test: /monaco-editor\/.*\.m?js/,
                    loader: 'babel-loader',
                },
                {
                    // Apply rule for .sql files
                    test: /\.sql$/,
                    type: 'asset/source',
                },
            ],
        },
        // add devServer config only to 'main' entry

@@ -8,4 +8,15 @@
    --text-xxs--line-height: 0.75rem;

    --spacing-button-padding-x: 7px; /* Match the padding of the button primitives base size (6px) + 1px border */

    --animate-input-focus-pulse: inputFocusPulse 1s ease-in-out forwards;

    @keyframes inputFocusPulse {
        0% {
            box-shadow: 0 0 0px 2px var(--color-accent);
        }
        100% {
            box-shadow: 0 0 0px 2px transparent;
        }
    }
}

110
dags/README.md
@@ -20,6 +20,116 @@ Dagster is an open-source data orchestration tool designed to help you define an
- Individual DAG files (e.g., `exchange_rate.py`, `deletes.py`, `person_overrides.py`)
- `tests/`: Tests for the DAGs

### Cloud access for PostHog employees

Ask someone in #team-infrastructure or #team-clickhouse to add you to Dagster Cloud. You might also want to join the #dagster-posthog Slack channel.

### Adding a New Team

To set up a new team with their own Dagster definitions and Slack alerts, follow these steps:

1. **Create a new definitions file** in `locations/<team_name>.py`:

    ```python
    import dagster

    from dags import my_module  # Import your DAGs

    from . import resources  # Import shared resources (if needed)

    defs = dagster.Definitions(
        assets=[
            # List your assets here
            my_module.my_asset,
        ],
        jobs=[
            # List your jobs here
            my_module.my_job,
        ],
        schedules=[
            # List your schedules here
            my_module.my_schedule,
        ],
        resources=resources,  # Include shared resources (ClickHouse, S3, Slack, etc.)
    )
    ```

    **Examples**: See `locations/analytics_platform.py` (simple) or `locations/web_analytics.py` (complex with conditional schedules)

2. **Register the location in the workspace** (for local development):

    Add your module to `.dagster_home/workspace.yaml`:

    ```yaml
    load_from:
        - python_module: dags.locations.your_team
    ```

    **Note**: Only add locations that should run locally. Heavy operations should remain commented out.

3. **Configure production deployment**:

    For PostHog employees, add the new location to the Dagster configuration in the [charts repository](https://github.com/PostHog/charts) (see `config/dagster/`).

    Sample PR: https://github.com/PostHog/charts/pull/6366

4. **Add team to the `JobOwners` enum** in `common/common.py`:

    ```python
    class JobOwners(str, Enum):
        TEAM_ANALYTICS_PLATFORM = "team-analytics-platform"
        TEAM_YOUR_TEAM = "team-your-team"  # Add your team here (alphabetically sorted)
        # ... other teams
    ```

5. **Add Slack channel mapping** in `slack_alerts.py`:

    ```python
    notification_channel_per_team = {
        JobOwners.TEAM_ANALYTICS_PLATFORM.value: "#alerts-analytics-platform",
        JobOwners.TEAM_YOUR_TEAM.value: "#alerts-your-team",  # Add mapping here (alphabetically sorted)
        # ... other teams
    }
    ```

6. **Create the Slack channel** (if it doesn't exist) and ensure the Alertmanager/Max Slack bot is invited to the channel

7. **Apply owner tags to your team's assets and jobs** (see next section)

### How Slack alerts work

- The `notify_slack_on_failure` sensor (defined in `slack_alerts.py`) monitors all job failures across all code locations (a minimal sketch of the routing follows this list)
- Alerts are only sent in production (when `CLOUD_DEPLOYMENT` environment variable is set)
- Each team has a dedicated Slack channel where their alerts are routed based on job ownership
- Failed jobs send a message to the appropriate team channel with a link to the Dagster run
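
A minimal sketch of this routing pattern — illustrative only, not the actual `slack_alerts.py` implementation; the `owner` tag name and channel map here are assumptions:

```python
import dagster

# Hypothetical mapping; the real one is notification_channel_per_team in slack_alerts.py
CHANNEL_PER_TEAM = {"team-clickhouse": "#alerts-clickhouse"}

@dagster.run_failure_sensor(default_status=dagster.DefaultSensorStatus.RUNNING)
def notify_slack_on_failure_sketch(context: dagster.RunFailureSensorContext) -> None:
    run = context.dagster_run
    if run.tags.get("disable_slack_notifications") == "true":
        return  # job opted out of alerts
    channel = CHANNEL_PER_TEAM.get(run.tags.get("owner", ""))
    if channel:
        # The real sensor posts a Slack message with a link to the Dagster run
        context.log.info(f"Would alert {channel} about failed run of {run.job_name}")
```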

#### Consecutive Failure Thresholds

Some jobs are configured to only alert after multiple consecutive failures to avoid alert fatigue. Configure this in `slack_alerts.py`:

```python
CONSECUTIVE_FAILURE_THRESHOLDS = {
    "web_pre_aggregate_current_day_hourly_job": 3,  # Alert after 3 consecutive failures
    "your_job_name": 2,  # Add your threshold here
}
```

#### Disabling Notifications

To disable Slack notifications for a specific job, add the `disable_slack_notifications` tag:

```python
@dagster.job(tags={"disable_slack_notifications": "true"})
def quiet_job():
    pass
```

#### Testing Alerts Locally

When running Dagster locally (with `DEBUG=1`), the Slack resource is replaced with a dummy resource, so no actual notifications are sent. This prevents test alerts from being sent to production Slack channels during development.

To test the alert routing logic, write unit tests in `tests/test_slack_alerts.py`.
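
For example, a test might assert that every owner in `JobOwners` has a channel mapping — a hypothetical test shape (the real assertions live in `tests/test_slack_alerts.py`, and the `dags.slack_alerts` import path is assumed):

```python
from dags.common import JobOwners
from dags.slack_alerts import notification_channel_per_team  # assumed module path

def test_every_team_has_a_slack_channel():
    # Guards against adding a JobOwners entry without a matching alert channel
    for owner in JobOwners:
        assert owner.value in notification_channel_per_team
```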

## Local Development

### Environment Setup

206
dags/backups.py
@@ -2,7 +2,7 @@ import re
import time
from collections.abc import Callable
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from datetime import UTC, datetime
from typing import Any, Optional

from django.conf import settings

@@ -88,6 +88,26 @@ NON_SHARDED_TABLES = [
]


@dataclass
class Table:
    name: str

    def is_backup_in_progress(self, client: Client) -> bool:
        # We query the processes table to check if a backup for the requested table is in progress
        rows = client.execute(
            f"""
            SELECT EXISTS(
                SELECT 1
                FROM system.processes
                WHERE query_kind = 'Backup' AND query like '%{self.name}%'
            )
            """
        )

        [[exists]] = rows
        return exists


@dataclass
class BackupStatus:
    hostname: str
@@ -95,6 +115,12 @@ class BackupStatus:
    event_time_microseconds: datetime
    error: Optional[str] = None

    def created(self) -> bool:
        return self.status == "BACKUP_CREATED"

    def creating(self) -> bool:
        return self.status == "CREATING_BACKUP"


@dataclass
class Backup:
@@ -141,7 +167,8 @@ class Backup:
        backup_settings = {
            "async": "1",
            "max_backup_bandwidth": get_max_backup_bandwidth(),
            "s3_disable_checksum": "1",  # There is a CH issue that makes bandwidth be half of what is configured: https://github.com/ClickHouse/ClickHouse/issues/78213
            # There is a CH issue that makes bandwidth be half of what is configured: https://github.com/ClickHouse/ClickHouse/issues/78213
            "s3_disable_checksum": "1",
            # According to CH docs, disabling this is safe enough as checksums are already made: https://clickhouse.com/docs/operations/settings/settings#s3_disable_checksum
        }
        if self.base_backup:
@@ -228,12 +255,6 @@ class BackupConfig(dagster.Config):
        default="",
        description="The table to backup. If not specified, the entire database will be backed up.",
    )
    date: str = pydantic.Field(
        default=datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
        description="The date to backup. If not specified, the current date will be used.",
        pattern=r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z$",
        validate_default=True,
    )
    workload: Workload = Workload.OFFLINE


@@ -257,13 +278,38 @@ def get_shards(cluster: dagster.ResourceParam[ClickhouseCluster]):


@dagster.op
def get_latest_backup(
def check_running_backup_for_table(
    config: BackupConfig, cluster: dagster.ResourceParam[ClickhouseCluster]
) -> Optional[Table]:
    """
    Check if a backup for the requested table is in progress (it shouldn't be, so fail if that's the case).
    """
    table = Table(name=config.table)
    is_running_backup = (
        cluster.map_hosts_by_role(table.is_backup_in_progress, node_role=NodeRole.DATA, workload=config.workload)
        .result()
        .values()
    )
    if any(is_running_backup):
        raise dagster.Failure(
            description=f"A backup for table {table.name} is still in progress, this run shouldn't have been triggered. Review concurrency limits / schedule triggering logic. If there is no Dagster job running but a backup is in progress, it's worth checking what happened."
        )

    return table


@dagster.op
def get_latest_backups(
    context: dagster.OpExecutionContext,
    config: BackupConfig,
    s3: S3Resource,
    running_backup: Optional[Table] = None,
    shard: Optional[int] = None,
) -> Optional[Backup]:
) -> list[Backup]:
    """
    Get the latest backup metadata for a ClickHouse database / table from S3.
    Get metadata for the latest 15 backups of a ClickHouse database / table from S3.

    They are sorted from most recent to oldest.
    """
    shard_path = shard if shard else NO_SHARD_PATH

@@ -277,24 +323,28 @@ def get_latest_backup(
    )

    if "CommonPrefixes" not in backups:
        return None
        return []

    latest_backup = sorted(backups["CommonPrefixes"], key=lambda x: x["Prefix"])[-1]["Prefix"]
    return Backup.from_s3_path(latest_backup)
    latest_backups = [
        Backup.from_s3_path(backup["Prefix"])
        for backup in sorted(backups["CommonPrefixes"], key=lambda x: x["Prefix"], reverse=True)
    ]
    context.log.info(f"Found {len(latest_backups)} latest backups: {latest_backups}")
    return latest_backups[:15]


@dagster.op
def check_latest_backup_status(
def get_latest_successful_backup(
    context: dagster.OpExecutionContext,
    config: BackupConfig,
    latest_backup: Optional[Backup],
    latest_backups: list[Backup],
    cluster: dagster.ResourceParam[ClickhouseCluster],
) -> Optional[Backup]:
    """
    Check if the latest backup is done.
    Find the latest successful backup to use as the base backup.
    """
    if not latest_backup:
        context.log.info("No latest backup found. Skipping status check.")
    if not latest_backups or not config.incremental:
        context.log.info("No latest backup found or a full backup was requested. Skipping status check.")
        return

    def map_hosts(func: Callable[[Client], Any]):
@@ -304,33 +354,23 @@ def check_latest_backup_status(
        )
        return cluster.map_hosts_by_role(fn=func, node_role=NodeRole.DATA, workload=config.workload)

    is_done = map_hosts(latest_backup.is_done).result().values()
    if not all(is_done):
        context.log.info(f"Latest backup {latest_backup.path} is still in progress, waiting for it to finish")
        map_hosts(latest_backup.wait).result()
    else:
    context.log.info("Finding the latest successfully created backup")
    for latest_backup in latest_backups:
        context.log.info(f"Checking status of backup: {latest_backup.path}")
        most_recent_status = get_most_recent_status(map_hosts(latest_backup.status).result().values())
        if most_recent_status and most_recent_status.status != "BACKUP_CREATED":
            # Check if the backup is stuck (CREATING_BACKUP with no active process)
            if most_recent_status.status == "CREATING_BACKUP":
                # Check how old the backup status is
                time_since_status = datetime.now(UTC) - most_recent_status.event_time_microseconds.replace(tzinfo=UTC)
                if time_since_status > timedelta(hours=2):
                    context.log.warning(
                        f"Previous backup {latest_backup.path} is stuck in CREATING_BACKUP status for {time_since_status}. "
                        f"This usually happens when the server was restarted during backup. "
                        f"Proceeding with new backup as the old one is no longer active."
                    )
                    # Don't raise an error - the backup is dead and won't interfere
                    return None
            # For other unexpected statuses (like BACKUP_FAILED), still raise an error
            raise ValueError(
                f"Latest backup {latest_backup.path} finished with an unexpected status: {most_recent_status.status} on the host {most_recent_status.hostname}. Please check the backup logs."
        if most_recent_status and not most_recent_status.created():
            context.log.warning(
                f"Backup {latest_backup.path} finished with an unexpected status: {most_recent_status.status} on the host {most_recent_status.hostname}. Checking next backup."
            )
        else:
            context.log.info(f"Latest backup {latest_backup.path} finished successfully")
            context.log.info(
                f"Backup {latest_backup.path} finished successfully. Using it as the base backup for the new backup."
            )
            return latest_backup

    return latest_backup
    raise dagster.Failure(
        f"All {len(latest_backups)} latest backups finished with an unexpected status. Please review them before launching a new one."
    )


@dagster.op
@@ -352,16 +392,12 @@ def run_backup(
        id=context.run_id,
        database=config.database,
        table=config.table,
        date=config.date,
        date=datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
        base_backup=latest_backup if config.incremental else None,
        shard=shard,
    )

    if latest_backup and latest_backup.path == backup.path:
        context.log.warning(
            f"This backup directory exists in S3. Skipping its run, if you want to run it again, remove the data from {backup.path}"
        )
        return
    context.log.info(f"Running backup for table {config.table} in path: {backup.path}")

    if backup.shard:
        cluster.map_any_host_in_shards_by_role(
@@ -376,6 +412,15 @@ def run_backup(
            workload=config.workload,
        ).result()

    context.add_output_metadata(
        {
            "table": dagster.MetadataValue.text(str(backup.table)),
            "path": dagster.MetadataValue.text(backup.path),
            "incremental": dagster.MetadataValue.bool(config.incremental),
            "date": dagster.MetadataValue.text(backup.date),
        }
    )

    return backup


@@ -383,7 +428,7 @@ def run_backup(
def wait_for_backup(
    context: dagster.OpExecutionContext,
    config: BackupConfig,
    backup: Optional[Backup],
    backup: Backup,
    cluster: dagster.ResourceParam[ClickhouseCluster],
):
    """
@@ -404,18 +449,27 @@ def wait_for_backup(
        tries += 1
        map_hosts(backup.wait).result().values()
        most_recent_status = get_most_recent_status(map_hosts(backup.status).result().values())
        if most_recent_status and most_recent_status.status == "CREATING_BACKUP":
        if most_recent_status and most_recent_status.creating() and tries < 3:
            context.log.warning(
                f"Backup {backup.path} is no longer running but status is still creating. Waiting a bit longer in case ClickHouse didn't flush logs yet..."
            )
            continue
        if most_recent_status and most_recent_status.status == "BACKUP_CREATED":
        elif most_recent_status and most_recent_status.created():
            context.log.info(f"Backup for table {backup.table} in path {backup.path} finished successfully")
            done = True
        if (most_recent_status and most_recent_status.status != "BACKUP_CREATED") or (
            most_recent_status and tries >= 5
        ):
        elif most_recent_status and not most_recent_status.created():
            raise ValueError(
                f"Backup {backup.path} finished with an unexpected status: {most_recent_status.status} on the host {most_recent_status.hostname}."
            )
    else:
        context.log.info("No backup to wait for")

    context.add_output_metadata(
        {
            "table": dagster.MetadataValue.text(str(backup.table)),
            "path": dagster.MetadataValue.text(backup.path),
            "incremental": dagster.MetadataValue.bool(config.incremental),
            "date": dagster.MetadataValue.text(backup.date),
        }
    )


@dagster.job(
@@ -431,8 +485,8 @@ def sharded_backup():
    """

    def run_backup_for_shard(shard: int):
        latest_backup = get_latest_backup(shard=shard)
        checked_backup = check_latest_backup_status(latest_backup=latest_backup)
        latest_backups = get_latest_backups(check_running_backup_for_table(), shard=shard)
        checked_backup = get_latest_successful_backup(latest_backups=latest_backups)
        new_backup = run_backup(latest_backup=checked_backup, shard=shard)
        wait_for_backup(backup=new_backup)

@@ -460,8 +514,8 @@ def non_sharded_backup():
    Since we don't want to keep state about which host was selected to run the backup, we always search backups by their name in every node.
    When we find it in one of the nodes, we keep waiting on it only in that node. This is handy when we retry the job and a backup is in progress in any node, as we'll always wait for it to finish.
    """
    latest_backup = get_latest_backup()
    new_backup = run_backup(check_latest_backup_status(latest_backup))
    latest_backups = get_latest_backups(check_running_backup_for_table())
    new_backup = run_backup(get_latest_successful_backup(latest_backups=latest_backups))
    wait_for_backup(new_backup)


@@ -469,14 +523,20 @@ def prepare_run_config(config: BackupConfig) -> dagster.RunConfig:
    return dagster.RunConfig(
        {
            op.name: {"config": config.model_dump(mode="json")}
            for op in [get_latest_backup, run_backup, check_latest_backup_status, wait_for_backup]
            for op in [
                check_running_backup_for_table,
                get_latest_backups,
                run_backup,
                get_latest_successful_backup,
                wait_for_backup,
            ]
        }
    )


def run_backup_request(
    table: str, incremental: bool, context: dagster.ScheduleEvaluationContext
) -> dagster.RunRequest | dagster.SkipReason:
) -> Optional[dagster.RunRequest]:
    skip_reason = check_for_concurrent_runs(
        context,
        tags={
@@ -484,12 +544,12 @@ def run_backup_request(
        },
    )
    if skip_reason:
        return skip_reason
        context.log.info(skip_reason.skip_message)
        return None

    timestamp = datetime.now(UTC)
    config = BackupConfig(
        database=settings.CLICKHOUSE_DATABASE,
        date=timestamp.strftime("%Y-%m-%dT%H:%M:%SZ"),
        table=table,
        incremental=incremental,
    )
@@ -508,23 +568,29 @@ def run_backup_request(
@dagster.schedule(
    job=sharded_backup,
    cron_schedule=settings.CLICKHOUSE_FULL_BACKUP_SCHEDULE,
    should_execute=lambda context: 1 <= context.scheduled_execution_time.day <= 7,
    default_status=dagster.DefaultScheduleStatus.RUNNING,
)
def full_sharded_backup_schedule(context: dagster.ScheduleEvaluationContext):
    """Launch a full backup for sharded tables"""
    for table in SHARDED_TABLES:
        yield run_backup_request(table, incremental=False, context=context)
        request = run_backup_request(table, incremental=False, context=context)
        if request:
            yield request


@dagster.schedule(
    job=non_sharded_backup,
    cron_schedule=settings.CLICKHOUSE_FULL_BACKUP_SCHEDULE,
    should_execute=lambda context: 1 <= context.scheduled_execution_time.day <= 7,
    default_status=dagster.DefaultScheduleStatus.RUNNING,
)
def full_non_sharded_backup_schedule(context: dagster.ScheduleEvaluationContext):
    """Launch a full backup for non-sharded tables"""
    for table in NON_SHARDED_TABLES:
        yield run_backup_request(table, incremental=False, context=context)
        request = run_backup_request(table, incremental=False, context=context)
        if request:
            yield request


@dagster.schedule(
@@ -535,7 +601,9 @@ def full_non_sharded_backup_schedule(context: dagster.ScheduleEvaluationContext)
def incremental_sharded_backup_schedule(context: dagster.ScheduleEvaluationContext):
    """Launch an incremental backup for sharded tables"""
    for table in SHARDED_TABLES:
        yield run_backup_request(table, incremental=True, context=context)
        request = run_backup_request(table, incremental=True, context=context)
        if request:
            yield request


@dagster.schedule(
@@ -546,4 +614,6 @@ def incremental_sharded_backup_schedule(context: dagster.ScheduleEvaluationConte
def incremental_non_sharded_backup_schedule(context: dagster.ScheduleEvaluationContext):
    """Launch an incremental backup for non-sharded tables"""
    for table in NON_SHARDED_TABLES:
        yield run_backup_request(table, incremental=True, context=context)
        request = run_backup_request(table, incremental=True, context=context)
        if request:
            yield request

@@ -1,8 +1,13 @@
import base64
from contextlib import suppress
from enum import Enum
from typing import Optional

from django.conf import settings

import dagster
import psycopg2
import psycopg2.extras
from clickhouse_driver.errors import Error, ErrorCodes

from posthog.clickhouse import query_tagging
@@ -13,14 +18,17 @@ from posthog.redis import get_client, redis


class JobOwners(str, Enum):
    TEAM_ANALYTICS_PLATFORM = "team-analytics-platform"
    TEAM_CLICKHOUSE = "team-clickhouse"
    TEAM_DATA_WAREHOUSE = "team-data-warehouse"
    TEAM_ERROR_TRACKING = "team-error-tracking"
    TEAM_EXPERIMENTS = "team-experiments"
    TEAM_GROWTH = "team-growth"
    TEAM_INGESTION = "team-ingestion"
    TEAM_LLMA = "team-llma"
    TEAM_MAX_AI = "team-max-ai"
    TEAM_REVENUE_ANALYTICS = "team-revenue-analytics"
    TEAM_WEB_ANALYTICS = "team-web-analytics"
    TEAM_ERROR_TRACKING = "team-error-tracking"
    TEAM_GROWTH = "team-growth"
    TEAM_EXPERIMENTS = "team-experiments"
    TEAM_MAX_AI = "team-max-ai"
    TEAM_DATA_WAREHOUSE = "team-data-warehouse"


class ClickhouseClusterResource(dagster.ConfigurableResource):
@@ -75,6 +83,28 @@ class RedisResource(dagster.ConfigurableResource):
        return client


class PostgresResource(dagster.ConfigurableResource):
    """
    A Postgres database connection resource that returns a psycopg2 connection.
    """

    host: str
    port: str = "5432"
    database: str
    user: str
    password: str

    def create_resource(self, context: dagster.InitResourceContext) -> psycopg2.extensions.connection:
        return psycopg2.connect(
            host=self.host,
            port=int(self.port),
            database=self.database,
            user=self.user,
            password=self.password,
            cursor_factory=psycopg2.extras.RealDictCursor,
        )


def report_job_status_metric(
    context: dagster.RunStatusSensorContext, cluster: dagster.ResourceParam[ClickhouseCluster]
) -> None:
@@ -171,3 +201,33 @@ def check_for_concurrent_runs(
        return dagster.SkipReason(f"Skipping {job_name} run because another run of the same job is already active")

    return None


def metabase_debug_query_url(run_id: str) -> Optional[str]:
    cloud_deployment = getattr(settings, "CLOUD_DEPLOYMENT", None)
    if cloud_deployment == "US":
        return f"https://metabase.prod-us.posthog.dev/question/1671-get-clickhouse-query-log-for-given-dagster-run-id?dagster_run_id={run_id}"
    if cloud_deployment == "EU":
        return f"https://metabase.prod-eu.posthog.dev/question/544-get-clickhouse-query-log-for-given-dagster-run-id?dagster_run_id={run_id}"
    sql = f"""
    SELECT
        hostName() as host,
        event_time,
        type,
        exception IS NOT NULL and exception != '' as has_exception,
        query_duration_ms,
        formatReadableSize(memory_usage) as memory_used,
        formatReadableSize(read_bytes) as data_read,
        JSONExtractString(log_comment, 'dagster', 'run_id') AS dagster_run_id,
        JSONExtractString(log_comment, 'dagster', 'job_name') AS dagster_job_name,
        JSONExtractString(log_comment, 'dagster', 'asset_key') AS dagster_asset_key,
        JSONExtractString(log_comment, 'dagster', 'op_name') AS dagster_op_name,
        exception,
        query
    FROM clusterAllReplicas('posthog', system.query_log)
    WHERE
        dagster_run_id = '{run_id}'
        AND event_date >= today() - 1
    ORDER BY event_time DESC;
    """
    return f"http://localhost:8123/play?user=default#{base64.b64encode(sql.encode('utf-8')).decode('utf-8')}"

@@ -23,6 +23,8 @@ from posthog.hogql_queries.experiments.experiment_query_runner import Experiment
from posthog.hogql_queries.experiments.utils import get_experiment_stats_method
from posthog.models.experiment import Experiment, ExperimentMetricResult

from products.experiments.stats.shared.statistics import StatisticError

from dags.common import JobOwners
from dags.experiments import (
    _parse_partition_key,
@@ -80,31 +82,32 @@ def experiment_regular_metrics_timeseries(context: dagster.AssetExecutionContext
    if not metric_dict:
        raise dagster.Failure(f"Metric {metric_uuid} not found in experiment {experiment_id}")

    try:
        metric_type = metric_dict.get("metric_type")
        metric_obj: Union[ExperimentMeanMetric, ExperimentFunnelMetric, ExperimentRatioMetric]
        if metric_type == "mean":
            metric_obj = ExperimentMeanMetric(**metric_dict)
        elif metric_type == "funnel":
            metric_obj = ExperimentFunnelMetric(**metric_dict)
        elif metric_type == "ratio":
            metric_obj = ExperimentRatioMetric(**metric_dict)
        else:
            raise dagster.Failure(f"Unknown metric type: {metric_type}")
    metric_type = metric_dict.get("metric_type")
    metric_obj: Union[ExperimentMeanMetric, ExperimentFunnelMetric, ExperimentRatioMetric]
    if metric_type == "mean":
        metric_obj = ExperimentMeanMetric(**metric_dict)
    elif metric_type == "funnel":
        metric_obj = ExperimentFunnelMetric(**metric_dict)
    elif metric_type == "ratio":
        metric_obj = ExperimentRatioMetric(**metric_dict)
    else:
        raise dagster.Failure(f"Unknown metric type: {metric_type}")

    # Validate experiment start date upfront
    if not experiment.start_date:
        raise dagster.Failure(
            f"Experiment {experiment_id} has no start_date - only launched experiments should be processed"
        )

    query_from_utc = experiment.start_date
    query_to_utc = datetime.now(ZoneInfo("UTC"))

    try:
        experiment_query = ExperimentQuery(
            experiment_id=experiment_id,
            metric=metric_obj,
        )

        # Cumulative calculation: from experiment start to current time
        if not experiment.start_date:
            raise dagster.Failure(
                f"Experiment {experiment_id} has no start_date - only launched experiments should be processed"
            )
        query_from_utc = experiment.start_date
        query_to_utc = datetime.now(ZoneInfo("UTC"))

        query_runner = ExperimentQueryRunner(query=experiment_query, team=experiment.team)
        result = query_runner._calculate()

@@ -127,7 +130,6 @@ def experiment_regular_metrics_timeseries(context: dagster.AssetExecutionContext
            },
        )

        # Add metadata for Dagster UI display
        context.add_output_metadata(
            metadata={
                "experiment_id": experiment_id,
@@ -144,7 +146,6 @@ def experiment_regular_metrics_timeseries(context: dagster.AssetExecutionContext
                "metric_definition": str(metric_dict),
                "query_from": query_from_utc.isoformat(),
                "query_to": query_to_utc.isoformat(),
                "results_status": "success",
            }
        )
        return {
@@ -157,14 +158,50 @@ def experiment_regular_metrics_timeseries(context: dagster.AssetExecutionContext
            "result": result.model_dump(),
        }

    except Exception as e:
        if not experiment.start_date:
            raise dagster.Failure(
                f"Experiment {experiment_id} has no start_date - only launched experiments should be processed"
            )
        query_from_utc = experiment.start_date
        query_to_utc = datetime.now(ZoneInfo("UTC"))
    except (StatisticError, ZeroDivisionError) as e:
        # Insufficient data - do not fail so that real failures are visible
        ExperimentMetricResult.objects.update_or_create(
            experiment_id=experiment_id,
            metric_uuid=metric_uuid,
            fingerprint=fingerprint,
            query_to=query_to_utc,
            defaults={
                "query_from": query_from_utc,
                "status": ExperimentMetricResult.Status.FAILED,
                "result": None,
                "query_id": None,
                "completed_at": None,
                "error_message": str(e),
            },
        )

        context.add_output_metadata(
            metadata={
                "experiment_id": experiment_id,
                "metric_uuid": metric_uuid,
                "fingerprint": fingerprint,
                "metric_type": metric_type,
                "metric_name": metric_dict.get("name", f"Metric {metric_uuid}"),
                "experiment_name": experiment.name,
                "query_from": query_from_utc.isoformat(),
                "query_to": query_to_utc.isoformat(),
                "error_type": type(e).__name__,
                "error_message": str(e),
            }
        )

        return {
            "experiment_id": experiment_id,
            "metric_uuid": metric_uuid,
            "fingerprint": fingerprint,
            "metric_definition": metric_dict,
            "query_from": query_from_utc.isoformat(),
            "query_to": query_to_utc.isoformat(),
            "error": str(e),
            "error_type": type(e).__name__,
        }

    except Exception as e:
        ExperimentMetricResult.objects.update_or_create(
            experiment_id=experiment_id,
            metric_uuid=metric_uuid,

@@ -21,6 +21,8 @@ from posthog.hogql_queries.experiments.experiment_query_runner import Experiment
from posthog.hogql_queries.experiments.utils import get_experiment_stats_method
from posthog.models.experiment import Experiment, ExperimentMetricResult

from products.experiments.stats.shared.statistics import StatisticError

from dags.common import JobOwners
from dags.experiments import (
    _parse_partition_key,
@@ -97,20 +99,21 @@ def experiment_saved_metrics_timeseries(context: dagster.AssetExecutionContext)
    else:
        raise dagster.Failure(f"Unknown metric type: {metric_type}")

    # Validate experiment start date upfront
    if not experiment.start_date:
        raise dagster.Failure(
            f"Experiment {experiment_id} has no start_date - only launched experiments should be processed"
        )

    query_from_utc = experiment.start_date
    query_to_utc = datetime.now(ZoneInfo("UTC"))

    try:
        experiment_query = ExperimentQuery(
            experiment_id=experiment_id,
            metric=metric_obj,
        )

        # Cumulative calculation: from experiment start to current time
        if not experiment.start_date:
            raise dagster.Failure(
                f"Experiment {experiment_id} has no start_date - only launched experiments should be processed"
            )
        query_from_utc = experiment.start_date
        query_to_utc = datetime.now(ZoneInfo("UTC"))

        query_runner = ExperimentQueryRunner(query=experiment_query, team=experiment.team, user_facing=False)
        result = query_runner._calculate()

@@ -118,7 +121,7 @@ def experiment_saved_metrics_timeseries(context: dagster.AssetExecutionContext)

        completed_at = datetime.now(ZoneInfo("UTC"))

        experiment_metric_result, created = ExperimentMetricResult.objects.update_or_create(
        experiment_metric_result, _ = ExperimentMetricResult.objects.update_or_create(
            experiment_id=experiment_id,
            metric_uuid=metric_uuid,
            fingerprint=fingerprint,
@@ -133,7 +136,6 @@ def experiment_saved_metrics_timeseries(context: dagster.AssetExecutionContext)
            },
        )

        # Add metadata for Dagster UI display
        context.add_output_metadata(
            metadata={
                "experiment_id": experiment_id,
@@ -150,7 +152,6 @@ def experiment_saved_metrics_timeseries(context: dagster.AssetExecutionContext)
                else None,
                "query_from": query_from_utc.isoformat(),
                "query_to": query_to_utc.isoformat(),
                "results_status": "success",
            }
        )
        return {
@@ -164,14 +165,52 @@ def experiment_saved_metrics_timeseries(context: dagster.AssetExecutionContext)
            "result": result.model_dump(),
        }

    except Exception as e:
        if not experiment.start_date:
            raise dagster.Failure(
                f"Experiment {experiment_id} has no start_date - only launched experiments should be processed"
            )
        query_from_utc = experiment.start_date
        query_to_utc = datetime.now(ZoneInfo("UTC"))
    except (StatisticError, ZeroDivisionError) as e:
        # Insufficient data - do not fail so that real failures are visible
        ExperimentMetricResult.objects.update_or_create(
            experiment_id=experiment_id,
            metric_uuid=metric_uuid,
            fingerprint=fingerprint,
            query_to=query_to_utc,
            defaults={
                "query_from": query_from_utc,
                "status": ExperimentMetricResult.Status.FAILED,
                "result": None,
                "query_id": None,
                "completed_at": None,
                "error_message": str(e),
            },
        )

        context.add_output_metadata(
            metadata={
                "experiment_id": experiment_id,
                "saved_metric_id": saved_metric.id,
                "saved_metric_uuid": saved_metric.query.get("uuid"),
                "saved_metric_name": saved_metric.name,
                "fingerprint": fingerprint,
                "metric_type": metric_type,
                "experiment_name": experiment.name,
                "query_from": query_from_utc.isoformat(),
                "query_to": query_to_utc.isoformat(),
                "error_type": type(e).__name__,
                "error_message": str(e),
            }
        )

        return {
            "experiment_id": experiment_id,
            "saved_metric_id": saved_metric.id,
            "saved_metric_uuid": saved_metric.query.get("uuid"),
            "saved_metric_name": saved_metric.name,
            "fingerprint": fingerprint,
            "query_from": query_from_utc.isoformat(),
            "query_to": query_to_utc.isoformat(),
            "error": str(e),
            "error_type": type(e).__name__,
        }

    except Exception as e:
        ExperimentMetricResult.objects.update_or_create(
            experiment_id=experiment_id,
            metric_uuid=metric_uuid,

62
dags/llma/README.md
Normal file
@@ -0,0 +1,62 @@
# LLMA (LLM Analytics) Dagster Location

Data pipelines for LLM analytics and observability.

## Overview

The LLMA location contains pipelines for aggregating and analyzing AI/LLM
events tracked through PostHog. These pipelines power analytics, cost tracking,
and observability features for AI products.

## Structure

```text
dags/llma/
├── README.md
├── daily_metrics/                  # Daily aggregation pipeline
│   ├── README.md                   # Detailed pipeline documentation
│   ├── config.py                   # Pipeline configuration
│   ├── main.py                     # Dagster asset and schedule
│   ├── utils.py                    # SQL generation helpers
│   └── sql/                        # Modular SQL templates
│       ├── event_counts.sql        # Event count metrics
│       ├── error_rates.sql         # Error rate metrics
│       ├── trace_counts.sql        # Unique trace count metrics
│       ├── session_counts.sql      # Unique session count metrics
│       ├── trace_error_rates.sql   # Trace-level error rates
│       └── pageview_counts.sql     # LLM Analytics pageview metrics
└── __init__.py

Tests: dags/tests/llma/daily_metrics/test_sql_metrics.py
```

## Pipelines

### Daily Metrics

Aggregates AI event metrics ($ai_trace, $ai_generation, $ai_span,
$ai_embedding) by team and date into the `llma_metrics_daily` ClickHouse
table.

Features:

- Modular SQL template system for easy metric additions
- Event counts, trace counts, session counts, and pageview metrics
- Error rates at event and trace level (proportions 0.0-1.0)
- Long-format schema, so new metrics require no schema changes
- Daily schedule at 6 AM UTC

See [daily_metrics/README.md](daily_metrics/README.md) for detailed
documentation.

## Local Development

The LLMA location is loaded in `.dagster_home/workspace.yaml` for local
development.

View in Dagster UI:

```bash
# Dagster runs on port 3030
open http://localhost:3030
```
5
dags/llma/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""
LLMA (LLM Analytics) team Dagster assets.

This module contains data pipeline assets for tracking and analyzing LLM usage metrics.
"""
206
dags/llma/daily_metrics/README.md
Normal file
@@ -0,0 +1,206 @@
# LLMA Daily Metrics Pipeline

Daily aggregation of AI event metrics into the `llma_metrics_daily` ClickHouse
table.

## Overview

Aggregates AI event metrics ($ai_trace, $ai_generation, $ai_span,
$ai_embedding) by team and date into a long-format metrics table for efficient
querying.

## Architecture

The pipeline uses a modular SQL template system:

- Each metric type lives in its own `.sql` file under `sql/`
- Templates are auto-discovered and combined with UNION ALL (sketched below)
- To add a new metric, simply drop a new `.sql` file in the directory
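
A sketch of what that discovery step can look like (an assumed implementation — the real helpers live in `utils.py`):

```python
from pathlib import Path

from jinja2 import Template

SQL_DIR = Path(__file__).parent / "sql"

def build_daily_metrics_query(render_context: dict) -> str:
    """Render every sql/*.sql template and join the results with UNION ALL."""
    rendered = [
        Template(path.read_text()).render(**render_context)
        for path in sorted(SQL_DIR.glob("*.sql"))
    ]
    return "\nUNION ALL\n".join(rendered)
```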

### SQL Template Format

Each SQL file should return these columns:

```sql
SELECT
    date(timestamp) as date,
    team_id,
    'metric_name' as metric_name,
    toFloat64(value) as metric_value
FROM events
WHERE ...
```

Templates have access to these Jinja2 variables (sample values follow the list):

- `event_types`: List of AI event types to aggregate
- `date_start`: Start date for aggregation (YYYY-MM-DD)
- `date_end`: End date for aggregation (YYYY-MM-DD)
- `pageview_mappings`: List of (url_path, metric_suffix) tuples for pageview categorization
- `include_error_rates`: Boolean flag for error rate calculation (default: true)
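
For illustration, a plausible render context — the real defaults come from `config.py`, summarized under Configuration below:

```python
render_context = {
    "event_types": ["$ai_trace", "$ai_generation", "$ai_span", "$ai_embedding"],
    "date_start": "2025-01-01",  # hypothetical partition
    "date_end": "2025-01-02",
    "pageview_mappings": [("/llm-analytics/traces", "traces")],  # (url_path, metric_suffix)
    "include_error_rates": True,
}
```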

## Output Schema

```sql
CREATE TABLE llma_metrics_daily (
    date Date,
    team_id UInt64,
    metric_name String,
    metric_value Float64
) ENGINE = ReplicatedMergeTree()
PARTITION BY toYYYYMM(date)
ORDER BY (team_id, date, metric_name)
```

Long format allows adding new metrics without schema changes.
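
Concretely, a single team-day is stored as several rows rather than several columns. Illustrative (made-up) contents:

```sql
-- Hypothetical rows for one team on one day (values invented for illustration)
-- date       | team_id | metric_name              | metric_value
-- 2025-01-01 | 42      | ai_generation_count      | 130
-- 2025-01-01 | 42      | ai_generation_error_rate | 0.05
-- 2025-01-01 | 42      | ai_trace_id_count        | 57
```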

## Current Metrics

### Event Counts

Defined in `sql/event_counts.sql`:

- `ai_generation_count`: Number of AI generation events
- `ai_trace_count`: Number of AI trace events
- `ai_span_count`: Number of AI span events
- `ai_embedding_count`: Number of AI embedding events

Each event is counted individually, even if multiple events share the same trace_id.

### Trace Counts

Defined in `sql/trace_counts.sql`:

- `ai_trace_id_count`: Number of unique traces (distinct $ai_trace_id values)

Counts unique traces across all AI event types. A trace may contain multiple events (generations, spans, etc).

### Session Counts

Defined in `sql/session_counts.sql`:

- `ai_session_id_count`: Number of unique sessions (distinct $ai_session_id values)

Counts unique sessions across all AI event types. A session can link multiple related traces together.

### Event Error Rates

Defined in `sql/error_rates.sql`:

- `ai_generation_error_rate`: Proportion of AI generation events with errors (0.0 to 1.0)
- `ai_trace_error_rate`: Proportion of AI trace events with errors (0.0 to 1.0)
- `ai_span_error_rate`: Proportion of AI span events with errors (0.0 to 1.0)
- `ai_embedding_error_rate`: Proportion of AI embedding events with errors (0.0 to 1.0)

### Trace Error Rates

Defined in `sql/trace_error_rates.sql`:

- `ai_trace_id_has_error_rate`: Proportion of unique traces that had at least one error (0.0 to 1.0)

A trace is considered errored if ANY event within it has an error. Compare with event error rates, which report the proportion of individual events with errors.

### Pageview Metrics

Defined in `sql/pageview_counts.sql`:

- `pageviews_traces`: Pageviews on /llm-analytics/traces
- `pageviews_generations`: Pageviews on /llm-analytics/generations
- `pageviews_users`: Pageviews on /llm-analytics/users
- `pageviews_sessions`: Pageviews on /llm-analytics/sessions
- `pageviews_playground`: Pageviews on /llm-analytics/playground
- `pageviews_datasets`: Pageviews on /llm-analytics/datasets
- `pageviews_evaluations`: Pageviews on /llm-analytics/evaluations

Tracks $pageview events on LLM Analytics pages. URL patterns are mapped to page types via config.pageview_mappings.

### Error Detection

All error metrics count an event as errored when either (see the sketch after this list):

- the `$ai_error` property is non-empty
- the `$ai_is_error` property is true
|
||||
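
Concretely, the predicate shared by the error-rate templates (see `sql/error_rates.sql`) is:

```sql
countIf(
    (JSONExtractRaw(properties, '$ai_error') != '' AND JSONExtractRaw(properties, '$ai_error') != 'null')
    OR JSONExtractBool(properties, '$ai_is_error') = true
)
```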

## Configuration

See `config.py` for configuration options:

- `partition_start_date`: First date to backfill (default: 2025-01-01)
- `cron_schedule`: Schedule for daily runs (default: 6 AM UTC)
- `max_partitions_per_run`: Max partitions to process in backfill (default: 14)
- `ai_event_types`: List of AI event types to track (default: $ai_trace, $ai_generation, $ai_span, $ai_embedding)
- `pageview_mappings`: URL path to metric name mappings for pageview tracking
- `include_error_rates`: Enable error rate metrics (default: true)
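
As a minimal sketch of how these options fit together (the pipeline itself uses the module-level `config` instance, so constructing your own `LLMADailyMetricsConfig` here is purely illustrative):

```python
from dags.llma.daily_metrics.config import LLMADailyMetricsConfig

# Hypothetical local instance with overridden defaults;
# list fields left as None are filled in by __post_init__.
local_config = LLMADailyMetricsConfig(
    partition_start_date="2025-06-01",
    max_partitions_per_run=7,
    include_error_rates=False,
)
print(local_config.ai_event_types)  # the four default $ai_* event types
```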

## Schedule

Runs daily at 6 AM UTC for the previous day's partition.

## Local Development

Query results in ClickHouse:

```bash
docker exec posthog-clickhouse-1 clickhouse-client --query \
  "SELECT * FROM llma_metrics_daily WHERE date = today() FORMAT Pretty"
```

Or use HogQL in the PostHog UI:

```sql
SELECT
    date,
    metric_name,
    sum(metric_value) as total
FROM llma_metrics_daily
WHERE date >= today() - INTERVAL 7 DAY
GROUP BY date, metric_name
ORDER BY date DESC, metric_name
```

## Testing

Run the test suite to validate SQL structure and logic:

```bash
python -m pytest dags/tests/llma/daily_metrics/test_sql_metrics.py -v
```

Tests validate:

- SQL templates render without errors
- All SQL files produce the correct 4-column output
- Calculation logic is correct (using mock data)
- Date filtering and grouping work as expected

Test file location: `dags/tests/llma/daily_metrics/test_sql_metrics.py`

## Adding New Metrics

1. Create a new SQL file in `sql/` (e.g., `sql/token_counts.sql`)
2. Use Jinja2 template syntax with `event_types`, `date_start`, `date_end`
3. Return columns: `date`, `team_id`, `metric_name`, `metric_value`
4. The pipeline will automatically discover and include it
5. Add test coverage in `dags/tests/llma/daily_metrics/test_sql_metrics.py` with mock data and expected output

Example:

```sql
{% for event_type in event_types %}
{% set metric_name = event_type.lstrip('$') + '_tokens' %}
SELECT
    date(timestamp) as date,
    team_id,
    '{{ metric_name }}' as metric_name,
    toFloat64(sum(JSONExtractInt(properties, '$ai_total_tokens'))) as metric_value
FROM events
WHERE event = '{{ event_type }}'
    AND timestamp >= toDateTime('{{ date_start }}', 'UTC')
    AND timestamp < toDateTime('{{ date_end }}', 'UTC')
GROUP BY date, team_id
HAVING metric_value > 0
{% if not loop.last %}
UNION ALL
{% endif %}
{% endfor %}
```

67
dags/llma/daily_metrics/config.py
Normal file
@@ -0,0 +1,67 @@
"""
|
||||
Configuration for LLMA (LLM Analytics) metrics aggregation.
|
||||
|
||||
This module defines all parameters and constants used in the LLMA metrics pipeline.
|
||||
Modify this file to add new metrics or change aggregation behavior.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMADailyMetricsConfig:
|
||||
"""Configuration for LLMA daily metrics aggregation pipeline."""
|
||||
|
||||
# ClickHouse table name
|
||||
table_name: str = "llma_metrics_daily"
|
||||
|
||||
# Start date for daily partitions (when AI events were introduced)
|
||||
partition_start_date: str = "2025-01-01"
|
||||
|
||||
# Schedule: Daily at 6 AM UTC
|
||||
cron_schedule: str = "0 6 * * *"
|
||||
|
||||
# Backfill policy: process N days per run
|
||||
max_partitions_per_run: int = 14
|
||||
|
||||
# ClickHouse query settings
|
||||
clickhouse_max_execution_time: int = 300 # 5 minutes
|
||||
|
||||
# Dagster job timeout (seconds)
|
||||
job_timeout: int = 1800 # 30 minutes
|
||||
|
||||
# Include error rate metrics (percentage of events with errors)
|
||||
include_error_rates: bool = True
|
||||
|
||||
# AI event types to track
|
||||
# Add new event types here to automatically include them in daily aggregations
|
||||
ai_event_types: list[str] | None = None
|
||||
|
||||
# Pageview URL path to metric name mappings
|
||||
# Maps URL patterns to metric names for tracking pageviews
|
||||
# Order matters: more specific patterns should come before general ones
|
||||
pageview_mappings: list[tuple[str, str]] | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Set default values for list fields."""
|
||||
if self.ai_event_types is None:
|
||||
self.ai_event_types = [
|
||||
"$ai_trace",
|
||||
"$ai_generation",
|
||||
"$ai_span",
|
||||
"$ai_embedding",
|
||||
]
|
||||
if self.pageview_mappings is None:
|
||||
self.pageview_mappings = [
|
||||
("/llm-analytics/traces", "traces"),
|
||||
("/llm-analytics/generations", "generations"),
|
||||
("/llm-analytics/users", "users"),
|
||||
("/llm-analytics/sessions", "sessions"),
|
||||
("/llm-analytics/playground", "playground"),
|
||||
("/llm-analytics/datasets", "datasets"),
|
||||
("/llm-analytics/evaluations", "evaluations"),
|
||||
]
|
||||
|
||||
|
||||
# Global config instance
|
||||
config = LLMADailyMetricsConfig()
|
||||

135
dags/llma/daily_metrics/main.py
Normal file
@@ -0,0 +1,135 @@
"""
|
||||
Daily aggregation of LLMA (LLM Analytics) metrics.
|
||||
|
||||
Aggregates AI event counts from the events table into a daily metrics table
|
||||
for efficient querying and cost analysis.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pandas as pd
|
||||
import dagster
|
||||
from dagster import BackfillPolicy, DailyPartitionsDefinition
|
||||
|
||||
from posthog.clickhouse import query_tagging
|
||||
from posthog.clickhouse.client import sync_execute
|
||||
from posthog.clickhouse.cluster import ClickhouseCluster
|
||||
|
||||
from dags.common import JobOwners, dagster_tags
|
||||
from dags.llma.daily_metrics.config import config
|
||||
from dags.llma.daily_metrics.utils import get_delete_query, get_insert_query
|
||||
|
||||
# Partition definition for daily aggregations
|
||||
partition_def = DailyPartitionsDefinition(start_date=config.partition_start_date, end_offset=1)
|
||||
|
||||
# Backfill policy: process N days per run
|
||||
backfill_policy_def = BackfillPolicy.multi_run(max_partitions_per_run=config.max_partitions_per_run)
|
||||
|
||||
# ClickHouse settings for aggregation queries
|
||||
LLMA_CLICKHOUSE_SETTINGS = {
|
||||
"max_execution_time": str(config.clickhouse_max_execution_time),
|
||||
}
|
||||
|
||||
|
||||
@dagster.asset(
|
||||
name="llma_metrics_daily",
|
||||
group_name="llma",
|
||||
partitions_def=partition_def,
|
||||
backfill_policy=backfill_policy_def,
|
||||
metadata={"table": config.table_name},
|
||||
tags={"owner": JobOwners.TEAM_LLMA.value},
|
||||
)
|
||||
def llma_metrics_daily(
|
||||
context: dagster.AssetExecutionContext,
|
||||
cluster: dagster.ResourceParam[ClickhouseCluster],
|
||||
) -> None:
|
||||
"""
|
||||
Daily aggregation of LLMA metrics.
|
||||
|
||||
Aggregates AI event counts ($ai_trace, $ai_generation, $ai_span, $ai_embedding)
|
||||
by team and date into a long-format metrics table for efficient querying.
|
||||
|
||||
Long format allows adding new metrics without schema changes.
|
||||
"""
|
||||
query_tagging.get_query_tags().with_dagster(dagster_tags(context))
|
||||
|
||||
if not context.partition_time_window:
|
||||
raise dagster.Failure("This asset should only be run with a partition_time_window")
|
||||
|
||||
start_datetime, end_datetime = context.partition_time_window
|
||||
date_start = start_datetime.strftime("%Y-%m-%d")
|
||||
date_end = end_datetime.strftime("%Y-%m-%d")
|
||||
|
||||
context.log.info(f"Aggregating LLMA metrics for {date_start} to {date_end}")
|
||||
|
||||
try:
|
||||
delete_query = get_delete_query(date_start, date_end)
|
||||
sync_execute(delete_query, settings=LLMA_CLICKHOUSE_SETTINGS)
|
||||
|
||||
insert_query = get_insert_query(date_start, date_end)
|
||||
context.log.info(f"Metrics query: \n{insert_query}")
|
||||
sync_execute(insert_query, settings=LLMA_CLICKHOUSE_SETTINGS)
|
||||
|
||||
# Query and log the metrics that were just aggregated
|
||||
metrics_query = f"""
|
||||
SELECT
|
||||
metric_name,
|
||||
count(DISTINCT team_id) as teams,
|
||||
sum(metric_value) as total_value
|
||||
FROM {config.table_name}
|
||||
WHERE date >= '{date_start}' AND date < '{date_end}'
|
||||
GROUP BY metric_name
|
||||
ORDER BY metric_name
|
||||
"""
|
||||
metrics_results = sync_execute(metrics_query)
|
||||
|
||||
if metrics_results:
|
||||
df = pd.DataFrame(metrics_results, columns=["metric_name", "teams", "total_value"])
|
||||
context.log.info(f"Aggregated {len(df)} metric types for {date_start}:\n{df.to_string(index=False)}")
|
||||
else:
|
||||
context.log.info(f"No AI events found for {date_start}")
|
||||
|
||||
context.log.info(f"Successfully aggregated LLMA metrics for {date_start}")
|
||||
|
||||
except Exception as e:
|
||||
raise dagster.Failure(f"Failed to aggregate LLMA metrics: {str(e)}") from e
|
||||
|
||||
|
||||
# Define the job that runs the asset
|
||||
llma_metrics_daily_job = dagster.define_asset_job(
|
||||
name="llma_metrics_daily_job",
|
||||
selection=["llma_metrics_daily"],
|
||||
tags={
|
||||
"owner": JobOwners.TEAM_LLMA.value,
|
||||
"dagster/max_runtime": str(config.job_timeout),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dagster.schedule(
|
||||
cron_schedule=config.cron_schedule,
|
||||
job=llma_metrics_daily_job,
|
||||
execution_timezone="UTC",
|
||||
tags={"owner": JobOwners.TEAM_LLMA.value},
|
||||
)
|
||||
def llma_metrics_daily_schedule(context: dagster.ScheduleEvaluationContext):
|
||||
"""
|
||||
Runs daily for the previous day's partition.
|
||||
|
||||
Schedule configured in dags.llma.config.
|
||||
This aggregates AI event metrics from the events table into the
|
||||
llma_metrics_daily table for efficient querying.
|
||||
"""
|
||||
# Calculate yesterday's partition
|
||||
yesterday = (datetime.now(UTC) - timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
|
||||
context.log.info(f"Scheduling LLMA metrics aggregation for {yesterday}")
|
||||
|
||||
return dagster.RunRequest(
|
||||
partition_key=yesterday,
|
||||
run_config={
|
||||
"ops": {
|
||||
"llma_metrics_daily": {"config": {}},
|
||||
}
|
||||
},
|
||||
)
|
||||

28
dags/llma/daily_metrics/sql/error_rates.sql
Normal file
@@ -0,0 +1,28 @@
/*
Event Error Rates - Proportion of events with errors by type

Calculates the proportion of events of each type that had an error (0.0 to 1.0).
An event is considered errored if $ai_error is set or $ai_is_error is true.

Produces metrics: ai_generation_error_rate, ai_embedding_error_rate, etc.

Example: If 2 out of 10 generation events had errors, this reports 0.20.
Compare with trace_error_rates.sql which reports proportion of traces with any error.
*/

SELECT
    date(timestamp) as date,
    team_id,
    concat(substring(event, 2), '_error_rate') as metric_name,
    round(
        countIf(
            (JSONExtractRaw(properties, '$ai_error') != '' AND JSONExtractRaw(properties, '$ai_error') != 'null')
            OR JSONExtractBool(properties, '$ai_is_error') = true
        ) / count(*),
        4
    ) as metric_value
FROM events
WHERE event IN ({% for event_type in event_types %}'{{ event_type }}'{% if not loop.last %}, {% endif %}{% endfor %})
    AND timestamp >= toDateTime('{{ date_start }}', 'UTC')
    AND timestamp < toDateTime('{{ date_end }}', 'UTC')
GROUP BY date, team_id, event

22
dags/llma/daily_metrics/sql/event_counts.sql
Normal file
@@ -0,0 +1,22 @@
/*
Event Counts - Individual AI events by type

Counts the total number of individual AI events ($ai_generation, $ai_embedding, etc).
Each event is counted separately, even if multiple events share the same trace_id.

Produces metrics: ai_generation_count, ai_embedding_count, ai_span_count, ai_trace_count

Example: If a trace contains 3 span events, this counts all 3 individually.
Compare with trace_counts.sql which would count this as 1 unique trace.
*/

SELECT
    date(timestamp) as date,
    team_id,
    concat(substring(event, 2), '_count') as metric_name,
    toFloat64(count(*)) as metric_value
FROM events
WHERE event IN ({% for event_type in event_types %}'{{ event_type }}'{% if not loop.last %}, {% endif %}{% endfor %})
    AND timestamp >= toDateTime('{{ date_start }}', 'UTC')
    AND timestamp < toDateTime('{{ date_end }}', 'UTC')
GROUP BY date, team_id, event

39
dags/llma/daily_metrics/sql/pageview_counts.sql
Normal file
@@ -0,0 +1,39 @@
/*
Pageview Counts - Page views by LLM Analytics page type

Counts $pageview events on LLM Analytics pages, categorized by page type.
URL patterns are mapped to page types via config.pageview_mappings.
More specific patterns should be listed before general ones in config.

Produces metrics: pageviews_traces, pageviews_generations, pageviews_users, etc.

Example: Pageview to /project/1/llm-analytics/traces → pageviews_traces metric
*/

SELECT
    date(timestamp) as date,
    team_id,
    concat('pageviews_', page_type) as metric_name,
    toFloat64(count(*)) as metric_value
FROM (
    SELECT
        timestamp,
        team_id,
        CASE
            {% for url_path, metric_suffix in pageview_mappings %}
            WHEN JSONExtractString(properties, '$current_url') LIKE '%{{ url_path }}%' THEN '{{ metric_suffix }}'
            {% endfor %}
            ELSE NULL
        END as page_type
    FROM events
    WHERE event = '$pageview'
        AND timestamp >= toDateTime('{{ date_start }}', 'UTC')
        AND timestamp < toDateTime('{{ date_end }}', 'UTC')
        AND (
            {% for url_path, metric_suffix in pageview_mappings %}
            JSONExtractString(properties, '$current_url') LIKE '%{{ url_path }}%'{% if not loop.last %} OR {% endif %}
            {% endfor %}
        )
)
WHERE page_type IS NOT NULL
GROUP BY date, team_id, page_type

23
dags/llma/daily_metrics/sql/session_counts.sql
Normal file
@@ -0,0 +1,23 @@
/*
Session Counts - Unique sessions (distinct $ai_session_id)

Counts the number of unique sessions by counting distinct $ai_session_id values
across all AI event types. A session can link multiple related traces together.

Produces metric: ai_session_id_count

Example: If there are 10 traces belonging to 3 unique session_ids, this counts 3.
Compare with trace_counts.sql which would count all 10 unique traces.
*/

SELECT
    date(timestamp) as date,
    team_id,
    'ai_session_id_count' as metric_name,
    toFloat64(count(DISTINCT JSONExtractString(properties, '$ai_session_id'))) as metric_value
FROM events
WHERE event IN ({% for event_type in event_types %}'{{ event_type }}'{% if not loop.last %}, {% endif %}{% endfor %})
    AND timestamp >= toDateTime('{{ date_start }}', 'UTC')
    AND timestamp < toDateTime('{{ date_end }}', 'UTC')
    AND JSONExtractString(properties, '$ai_session_id') != ''
GROUP BY date, team_id

23
dags/llma/daily_metrics/sql/trace_counts.sql
Normal file
@@ -0,0 +1,23 @@
/*
Trace Counts - Unique traces (distinct $ai_trace_id)

Counts the number of unique traces by counting distinct $ai_trace_id values
across all AI event types. A trace may contain multiple events (generations, spans, etc).

Produces metric: ai_trace_id_count

Example: If there are 10 events belonging to 3 unique trace_ids, this counts 3.
Compare with event_counts.sql which would count all 10 events individually.
*/

SELECT
    date(timestamp) as date,
    team_id,
    'ai_trace_id_count' as metric_name,
    toFloat64(count(DISTINCT JSONExtractString(properties, '$ai_trace_id'))) as metric_value
FROM events
WHERE event IN ({% for event_type in event_types %}'{{ event_type }}'{% if not loop.last %}, {% endif %}{% endfor %})
    AND timestamp >= toDateTime('{{ date_start }}', 'UTC')
    AND timestamp < toDateTime('{{ date_end }}', 'UTC')
    AND JSONExtractString(properties, '$ai_trace_id') != ''
GROUP BY date, team_id

32
dags/llma/daily_metrics/sql/trace_error_rates.sql
Normal file
@@ -0,0 +1,32 @@
/*
Trace Error Rates - Proportion of traces with any error

Calculates the proportion of unique traces that had at least one error event (0.0 to 1.0).
A trace is considered errored if ANY event within it has $ai_error set or $ai_is_error is true.

Produces metric: ai_trace_id_has_error_rate

Example: If 2 out of 8 unique traces had any error event, this reports 0.25.
Compare with error_rates.sql which reports proportion of individual events with errors.

Note: A single erroring event in a trace makes the entire trace count as errored.
*/

SELECT
    date(timestamp) as date,
    team_id,
    'ai_trace_id_has_error_rate' as metric_name,
    round(
        countDistinctIf(
            JSONExtractString(properties, '$ai_trace_id'),
            (JSONExtractRaw(properties, '$ai_error') != '' AND JSONExtractRaw(properties, '$ai_error') != 'null')
            OR JSONExtractBool(properties, '$ai_is_error') = true
        ) / count(DISTINCT JSONExtractString(properties, '$ai_trace_id')),
        4
    ) as metric_value
FROM events
WHERE event IN ({% for event_type in event_types %}'{{ event_type }}'{% if not loop.last %}, {% endif %}{% endfor %})
    AND timestamp >= toDateTime('{{ date_start }}', 'UTC')
    AND timestamp < toDateTime('{{ date_end }}', 'UTC')
    AND JSONExtractString(properties, '$ai_trace_id') != ''
GROUP BY date, team_id

65
dags/llma/daily_metrics/utils.py
Normal file
@@ -0,0 +1,65 @@
"""
|
||||
Utility functions for LLMA daily metrics aggregation.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from jinja2 import Template
|
||||
|
||||
from dags.llma.daily_metrics.config import config
|
||||
|
||||
# SQL template directory
|
||||
SQL_DIR = Path(__file__).parent / "sql"
|
||||
|
||||
# Metric types to include (matches SQL filenames without .sql extension)
|
||||
# Set to None to include all, or list specific ones to include
|
||||
ENABLED_METRICS: list[str] | None = None # or ["event_counts", "error_rates"]
|
||||
|
||||
|
||||
def get_insert_query(date_start: str, date_end: str) -> str:
|
||||
"""
|
||||
Generate SQL to aggregate AI event counts by team and metric type.
|
||||
|
||||
Uses long format: each metric_name is a separate row for easy schema evolution.
|
||||
Automatically discovers and combines all SQL templates in the sql/ directory.
|
||||
|
||||
To add a new metric type, simply add a new .sql file in dags/llma/sql/.
|
||||
Each SQL file should return columns: date, team_id, metric_name, metric_value
|
||||
"""
|
||||
# Discover all SQL template files
|
||||
sql_files = sorted(SQL_DIR.glob("*.sql"))
|
||||
|
||||
# Filter by enabled metrics if specified
|
||||
if ENABLED_METRICS is not None:
|
||||
sql_files = [f for f in sql_files if f.stem in ENABLED_METRICS]
|
||||
|
||||
if not sql_files:
|
||||
raise ValueError(f"No SQL template files found in {SQL_DIR}")
|
||||
|
||||
# Load and render each template
|
||||
rendered_queries = []
|
||||
template_context = {
|
||||
"event_types": config.ai_event_types,
|
||||
"pageview_mappings": config.pageview_mappings,
|
||||
"date_start": date_start,
|
||||
"date_end": date_end,
|
||||
"include_error_rates": config.include_error_rates,
|
||||
}
|
||||
|
||||
for sql_file in sql_files:
|
||||
with open(sql_file) as f:
|
||||
template = Template(f.read())
|
||||
rendered = template.render(**template_context)
|
||||
if rendered.strip(): # Only include non-empty queries
|
||||
rendered_queries.append(rendered)
|
||||
|
||||
# Combine all queries with UNION ALL
|
||||
combined_query = "\n\nUNION ALL\n\n".join(rendered_queries)
|
||||
|
||||
# Wrap in INSERT INTO statement
|
||||
return f"INSERT INTO {config.table_name} (date, team_id, metric_name, metric_value)\n\n{combined_query}"
|
||||
|
||||
|
||||
def get_delete_query(date_start: str, date_end: str) -> str:
|
||||
"""Generate SQL to delete existing data for the date range."""
|
||||
return f"ALTER TABLE {config.table_name} DELETE WHERE date >= '{date_start}' AND date < '{date_end}'"
|
||||
@@ -5,7 +5,7 @@ import dagster_slack
from dagster_aws.s3.io_manager import s3_pickle_io_manager
from dagster_aws.s3.resources import S3Resource

from dags.common import ClickhouseClusterResource, RedisResource
from dags.common import ClickhouseClusterResource, PostgresResource, RedisResource

# Define resources for different environments
resources_by_env = {
@@ -18,6 +18,14 @@ resources_by_env = {
        "s3": S3Resource(),
        # Using EnvVar instead of the Django setting to ensure that the token is not leaked anywhere in the Dagster UI
        "slack": dagster_slack.SlackResource(token=dagster.EnvVar("SLACK_TOKEN")),
        # Postgres resource (universal for all dags)
        "database": PostgresResource(
            host=dagster.EnvVar("POSTGRES_HOST"),
            port=dagster.EnvVar("POSTGRES_PORT"),
            database=dagster.EnvVar("POSTGRES_DATABASE"),
            user=dagster.EnvVar("POSTGRES_USER"),
            password=dagster.EnvVar("POSTGRES_PASSWORD"),
        ),
    },
    "local": {
        "cluster": ClickhouseClusterResource.configure_at_launch(),
@@ -29,6 +37,14 @@ resources_by_env = {
            aws_secret_access_key=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY,
        ),
        "slack": dagster.ResourceDefinition.none_resource(description="Dummy Slack resource for local development"),
        # Postgres resource (universal for all dags) - use Django settings or env vars for local dev
        "database": PostgresResource(
            host=dagster.EnvVar("POSTGRES_HOST"),
            port=dagster.EnvVar("POSTGRES_PORT"),
            database=dagster.EnvVar("POSTGRES_DATABASE"),
            user=dagster.EnvVar("POSTGRES_USER"),
            password=dagster.EnvVar("POSTGRES_PASSWORD"),
        ),
    },
}

15
dags/locations/ingestion.py
Normal file
@@ -0,0 +1,15 @@
import dagster

from dags import persons_new_backfill

from . import resources

defs = dagster.Definitions(
    assets=[
        persons_new_backfill.postgres_env_check,
    ],
    jobs=[
        persons_new_backfill.persons_new_backfill_job,
    ],
    resources=resources,
)

16
dags/locations/llma.py
Normal file
@@ -0,0 +1,16 @@
import dagster

from dags.llma.daily_metrics.main import llma_metrics_daily, llma_metrics_daily_job, llma_metrics_daily_schedule

from . import resources

defs = dagster.Definitions(
    assets=[llma_metrics_daily],
    jobs=[
        llma_metrics_daily_job,
    ],
    schedules=[
        llma_metrics_daily_schedule,
    ],
    resources=resources,
)

398
dags/persons_new_backfill.py
Normal file
@@ -0,0 +1,398 @@
"""Dagster job for backfilling posthog_persons data from source to destination Postgres database."""
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import dagster
|
||||
import psycopg2
|
||||
import psycopg2.errors
|
||||
from dagster_k8s import k8s_job_executor
|
||||
|
||||
from posthog.clickhouse.cluster import ClickhouseCluster
|
||||
from posthog.clickhouse.custom_metrics import MetricsClient
|
||||
|
||||
from dags.common import JobOwners
|
||||
|
||||
|
||||
class PersonsNewBackfillConfig(dagster.Config):
|
||||
"""Configuration for the persons new backfill job."""
|
||||
|
||||
chunk_size: int = 1_000_000 # ID range per chunk
|
||||
batch_size: int = 100_000 # Records per batch insert
|
||||
source_table: str = "posthog_persons"
|
||||
destination_table: str = "posthog_persons_new"
|
||||
max_id: int | None = None # Optional override for max ID to resume from partial state
|
||||
|
||||
|
||||
@dagster.op
|
||||
def get_id_range(
|
||||
context: dagster.OpExecutionContext,
|
||||
config: PersonsNewBackfillConfig,
|
||||
database: dagster.ResourceParam[psycopg2.extensions.connection],
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Query source database for MIN(id) and optionally MAX(id) from posthog_persons table.
|
||||
If max_id is provided in config, uses that instead of querying.
|
||||
Returns tuple (min_id, max_id).
|
||||
"""
|
||||
with database.cursor() as cursor:
|
||||
# Always query for min_id
|
||||
min_query = f"SELECT MIN(id) as min_id FROM {config.source_table}"
|
||||
context.log.info(f"Querying min ID: {min_query}")
|
||||
cursor.execute(min_query)
|
||||
min_result = cursor.fetchone()
|
||||
|
||||
if min_result is None or min_result["min_id"] is None:
|
||||
context.log.exception("Source table is empty or has no valid IDs")
|
||||
# Note: No metrics client here as this is get_id_range op, not copy_chunk
|
||||
raise dagster.Failure("Source table is empty or has no valid IDs")
|
||||
|
||||
min_id = int(min_result["min_id"])
|
||||
|
||||
# Use config max_id if provided, otherwise query database
|
||||
if config.max_id is not None:
|
||||
max_id = config.max_id
|
||||
context.log.info(f"Using configured max_id override: {max_id}")
|
||||
else:
|
||||
max_query = f"SELECT MAX(id) as max_id FROM {config.source_table}"
|
||||
context.log.info(f"Querying max ID: {max_query}")
|
||||
cursor.execute(max_query)
|
||||
max_result = cursor.fetchone()
|
||||
|
||||
if max_result is None or max_result["max_id"] is None:
|
||||
context.log.exception("Source table has no valid max ID")
|
||||
# Note: No metrics client here as this is get_id_range op, not copy_chunk
|
||||
raise dagster.Failure("Source table has no valid max ID")
|
||||
|
||||
max_id = int(max_result["max_id"])
|
||||
|
||||
# Validate that max_id >= min_id
|
||||
if max_id < min_id:
|
||||
error_msg = f"Invalid ID range: max_id ({max_id}) < min_id ({min_id})"
|
||||
context.log.error(error_msg)
|
||||
# Note: No metrics client here as this is get_id_range op, not copy_chunk
|
||||
raise dagster.Failure(error_msg)
|
||||
|
||||
context.log.info(f"ID range: min={min_id}, max={max_id}, total_ids={max_id - min_id + 1}")
|
||||
context.add_output_metadata(
|
||||
{
|
||||
"min_id": dagster.MetadataValue.int(min_id),
|
||||
"max_id": dagster.MetadataValue.int(max_id),
|
||||
"max_id_source": dagster.MetadataValue.text("config" if config.max_id is not None else "database"),
|
||||
"total_ids": dagster.MetadataValue.int(max_id - min_id + 1),
|
||||
}
|
||||
)
|
||||
|
||||
return (min_id, max_id)
|
||||
|
||||
|
||||
@dagster.op(out=dagster.DynamicOut(tuple[int, int]))
|
||||
def create_chunks(
|
||||
context: dagster.OpExecutionContext,
|
||||
config: PersonsNewBackfillConfig,
|
||||
id_range: tuple[int, int],
|
||||
):
|
||||
"""
|
||||
Divide ID space into chunks of chunk_size.
|
||||
Yields DynamicOutput for each chunk in reverse order (highest IDs first, lowest IDs last).
|
||||
This ensures that if the job fails partway through, the final chunk to process will be
|
||||
the one starting at the source table's min_id.
|
||||
"""
|
||||
min_id, max_id = id_range
|
||||
chunk_size = config.chunk_size
|
||||
|
||||
# First, collect all chunks
|
||||
chunks = []
|
||||
chunk_min = min_id
|
||||
chunk_num = 0
|
||||
|
||||
while chunk_min <= max_id:
|
||||
chunk_max = min(chunk_min + chunk_size - 1, max_id)
|
||||
chunks.append((chunk_min, chunk_max, chunk_num))
|
||||
chunk_min = chunk_max + 1
|
||||
chunk_num += 1
|
||||
|
||||
context.log.info(f"Created {chunk_num} chunks total")
|
||||
|
||||
# Yield chunks in reverse order (highest IDs first)
|
||||
for chunk_min, chunk_max, chunk_num in reversed(chunks):
|
||||
chunk_key = f"chunk_{chunk_min}_{chunk_max}"
|
||||
context.log.info(f"Yielding chunk {chunk_num}: {chunk_min} to {chunk_max}")
|
||||
yield dagster.DynamicOutput(
|
||||
value=(chunk_min, chunk_max),
|
||||
mapping_key=chunk_key,
|
||||
)
|
||||
|
||||
|
||||
@dagster.op
|
||||
def copy_chunk(
|
||||
context: dagster.OpExecutionContext,
|
||||
config: PersonsNewBackfillConfig,
|
||||
chunk: tuple[int, int],
|
||||
database: dagster.ResourceParam[psycopg2.extensions.connection],
|
||||
cluster: dagster.ResourceParam[ClickhouseCluster],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Copy a chunk of records from source to destination database.
|
||||
Processes in batches of batch_size records.
|
||||
"""
|
||||
chunk_min, chunk_max = chunk
|
||||
batch_size = config.batch_size
|
||||
source_table = config.source_table
|
||||
destination_table = config.destination_table
|
||||
chunk_id = f"chunk_{chunk_min}_{chunk_max}"
|
||||
job_name = context.run.job_name
|
||||
|
||||
# Initialize metrics client
|
||||
metrics_client = MetricsClient(cluster)
|
||||
|
||||
context.log.info(f"Starting chunk copy: {chunk_min} to {chunk_max}")
|
||||
|
||||
total_records_copied = 0
|
||||
batch_start_id = chunk_min
|
||||
failed_batch_start_id: int | None = None
|
||||
|
||||
try:
|
||||
with database.cursor() as cursor:
|
||||
# Set session-level settings once for the entire chunk
|
||||
cursor.execute("SET application_name = 'backfill_posthog_persons_to_posthog_persons_new'")
|
||||
cursor.execute("SET lock_timeout = '5s'")
|
||||
cursor.execute("SET statement_timeout = '30min'")
|
||||
cursor.execute("SET maintenance_work_mem = '12GB'")
|
||||
cursor.execute("SET work_mem = '512MB'")
|
||||
cursor.execute("SET temp_buffers = '512MB'")
|
||||
cursor.execute("SET max_parallel_workers_per_gather = 2")
|
||||
cursor.execute("SET max_parallel_maintenance_workers = 2")
|
||||
cursor.execute("SET synchronous_commit = off")
|
||||
|
||||
retry_attempt = 0
|
||||
while batch_start_id <= chunk_max:
|
||||
try:
|
||||
# Track batch start time for duration metric
|
||||
batch_start_time = time.time()
|
||||
|
||||
# Calculate batch end ID
|
||||
batch_end_id = min(batch_start_id + batch_size, chunk_max)
|
||||
|
||||
# Track records attempted - this is also our exit condition
|
||||
records_attempted = batch_end_id - batch_start_id
|
||||
if records_attempted <= 0:
|
||||
break
|
||||
# Begin transaction (settings already applied at session level)
|
||||
cursor.execute("BEGIN")
|
||||
|
||||
# Execute INSERT INTO ... SELECT with NOT EXISTS check
|
||||
insert_query = f"""
|
||||
INSERT INTO {destination_table}
|
||||
SELECT s.*
|
||||
FROM {source_table} s
|
||||
WHERE s.id >= %s AND s.id <= %s
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM {destination_table} d
|
||||
WHERE d.team_id = s.team_id
|
||||
AND d.id = s.id
|
||||
)
|
||||
ORDER BY s.id DESC
|
||||
"""
|
||||
cursor.execute(insert_query, (batch_start_id, batch_end_id))
|
||||
records_inserted = cursor.rowcount
|
||||
|
||||
# Commit the transaction
|
||||
cursor.execute("COMMIT")
|
||||
|
||||
try:
|
||||
metrics_client.increment(
|
||||
"persons_new_backfill_records_attempted_total",
|
||||
labels={"job_name": job_name, "chunk_id": chunk_id},
|
||||
value=float(records_attempted),
|
||||
).result()
|
||||
except Exception:
|
||||
pass # Don't fail on metrics error
|
||||
|
||||
batch_duration_seconds = time.time() - batch_start_time
|
||||
|
||||
try:
|
||||
metrics_client.increment(
|
||||
"persons_new_backfill_records_inserted_total",
|
||||
labels={"job_name": job_name, "chunk_id": chunk_id},
|
||||
value=float(records_inserted),
|
||||
).result()
|
||||
except Exception:
|
||||
pass # Don't fail on metrics error
|
||||
|
||||
try:
|
||||
metrics_client.increment(
|
||||
"persons_new_backfill_batches_copied_total",
|
||||
labels={"job_name": job_name, "chunk_id": chunk_id},
|
||||
value=1.0,
|
||||
).result()
|
||||
except Exception:
|
||||
pass
|
||||
# Track batch duration metric (IV)
|
||||
try:
|
||||
metrics_client.increment(
|
||||
"persons_new_backfill_batch_duration_seconds_total",
|
||||
labels={"job_name": job_name, "chunk_id": chunk_id},
|
||||
value=batch_duration_seconds,
|
||||
).result()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
total_records_copied += records_inserted
|
||||
|
||||
context.log.info(
|
||||
f"Copied batch: {records_inserted} records "
|
||||
f"(chunk {chunk_min}-{chunk_max}, batch ID range {batch_start_id} to {batch_end_id})"
|
||||
)
|
||||
|
||||
# Update batch_start_id for next iteration
|
||||
batch_start_id = batch_end_id + 1
|
||||
retry_attempt = 0
|
||||
|
||||
except Exception as batch_error:
|
||||
# Rollback transaction on error
|
||||
try:
|
||||
cursor.execute("ROLLBACK")
|
||||
except Exception as rollback_error:
|
||||
context.log.exception(
|
||||
f"Failed to rollback transaction for batch starting at ID {batch_start_id}"
|
||||
f"in chunk {chunk_min}-{chunk_max}: {str(rollback_error)}"
|
||||
)
|
||||
pass # Ignore rollback errors
|
||||
|
||||
# Check if error is a duplicate key violation, pause and retry if so
|
||||
is_unique_violation = isinstance(batch_error, psycopg2.errors.UniqueViolation) or (
|
||||
isinstance(batch_error, psycopg2.Error) and getattr(batch_error, "pgcode", None) == "23505"
|
||||
)
|
||||
if is_unique_violation:
|
||||
error_msg = (
|
||||
f"Duplicate key violation detected for batch starting at ID {batch_start_id} "
|
||||
f"in chunk {chunk_min}-{chunk_max}. Error is: {batch_error}. "
|
||||
"This is expected if records already exist in destination table. "
|
||||
)
|
||||
context.log.warning(error_msg)
|
||||
if retry_attempt < 3:
|
||||
retry_attempt += 1
|
||||
context.log.info(f"Retrying batch {retry_attempt} of 3...")
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
failed_batch_start_id = batch_start_id
|
||||
error_msg = (
|
||||
f"Failed to copy batch starting at ID {batch_start_id} "
|
||||
f"in chunk {chunk_min}-{chunk_max}: {str(batch_error)}"
|
||||
)
|
||||
context.log.exception(error_msg)
|
||||
# Report fatal error metric before raising
|
||||
try:
|
||||
metrics_client.increment(
|
||||
"persons_new_backfill_error",
|
||||
labels={"job_name": job_name, "chunk_id": chunk_id, "reason": "batch_copy_failed"},
|
||||
value=1.0,
|
||||
).result()
|
||||
except Exception:
|
||||
pass # Don't fail on metrics error
|
||||
|
||||
raise dagster.Failure(
|
||||
description=error_msg,
|
||||
metadata={
|
||||
"chunk_min_id": dagster.MetadataValue.int(chunk_min),
|
||||
"chunk_max_id": dagster.MetadataValue.int(chunk_max),
|
||||
"failed_batch_start_id": dagster.MetadataValue.int(failed_batch_start_id)
|
||||
if failed_batch_start_id
|
||||
else dagster.MetadataValue.text("N/A"),
|
||||
"error_message": dagster.MetadataValue.text(str(batch_error)),
|
||||
"records_copied_before_failure": dagster.MetadataValue.int(total_records_copied),
|
||||
},
|
||||
) from batch_error
|
||||
|
||||
except dagster.Failure:
|
||||
# Re-raise Dagster failures as-is (they already have metadata and metrics)
|
||||
raise
|
||||
except Exception as e:
|
||||
# Catch any other unexpected errors
|
||||
error_msg = f"Unexpected error copying chunk {chunk_min}-{chunk_max}: {str(e)}"
|
||||
context.log.exception(error_msg)
|
||||
# Report fatal error metric before raising
|
||||
try:
|
||||
metrics_client.increment(
|
||||
"persons_new_backfill_error",
|
||||
labels={"job_name": job_name, "chunk_id": chunk_id, "reason": "unexpected_copy_error"},
|
||||
value=1.0,
|
||||
).result()
|
||||
except Exception:
|
||||
pass # Don't fail on metrics error
|
||||
raise dagster.Failure(
|
||||
description=error_msg,
|
||||
metadata={
|
||||
"chunk_min_id": dagster.MetadataValue.int(chunk_min),
|
||||
"chunk_max_id": dagster.MetadataValue.int(chunk_max),
|
||||
"failed_batch_start_id": dagster.MetadataValue.int(failed_batch_start_id)
|
||||
if failed_batch_start_id
|
||||
else dagster.MetadataValue.int(batch_start_id),
|
||||
"error_message": dagster.MetadataValue.text(str(e)),
|
||||
"records_copied_before_failure": dagster.MetadataValue.int(total_records_copied),
|
||||
},
|
||||
) from e
|
||||
|
||||
context.log.info(f"Completed chunk {chunk_min}-{chunk_max}: copied {total_records_copied} records")
|
||||
|
||||
# Emit metric for chunk completion
|
||||
run_id = context.run.run_id
|
||||
try:
|
||||
metrics_client.increment(
|
||||
"persons_new_backfill_chunks_completed_total",
|
||||
labels={"job_name": job_name, "run_id": run_id, "chunk_id": chunk_id},
|
||||
value=1.0,
|
||||
).result()
|
||||
except Exception:
|
||||
pass # Don't fail on metrics error
|
||||
|
||||
context.add_output_metadata(
|
||||
{
|
||||
"chunk_min": dagster.MetadataValue.int(chunk_min),
|
||||
"chunk_max": dagster.MetadataValue.int(chunk_max),
|
||||
"records_copied": dagster.MetadataValue.int(total_records_copied),
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"chunk_min": chunk_min,
|
||||
"chunk_max": chunk_max,
|
||||
"records_copied": total_records_copied,
|
||||
}
|
||||
|
||||
|
||||
@dagster.asset
|
||||
def postgres_env_check(context: dagster.AssetExecutionContext) -> None:
|
||||
"""
|
||||
Simple asset that prints PostgreSQL environment variables being used.
|
||||
Useful for debugging connection configuration.
|
||||
"""
|
||||
env_vars = {
|
||||
"POSTGRES_HOST": os.getenv("POSTGRES_HOST", "not set"),
|
||||
"POSTGRES_PORT": os.getenv("POSTGRES_PORT", "not set"),
|
||||
"POSTGRES_DATABASE": os.getenv("POSTGRES_DATABASE", "not set"),
|
||||
"POSTGRES_USER": os.getenv("POSTGRES_USER", "not set"),
|
||||
"POSTGRES_PASSWORD": "***" if os.getenv("POSTGRES_PASSWORD") else "not set",
|
||||
}
|
||||
|
||||
context.log.info("PostgreSQL environment variables:")
|
||||
for key, value in env_vars.items():
|
||||
context.log.info(f" {key}: {value}")
|
||||
|
||||
|
||||
@dagster.job(
|
||||
tags={"owner": JobOwners.TEAM_INGESTION.value},
|
||||
executor_def=k8s_job_executor,
|
||||
)
|
||||
def persons_new_backfill_job():
|
||||
"""
|
||||
Backfill posthog_persons data from source to destination Postgres database.
|
||||
Divides the ID space into chunks and processes them in parallel.
|
||||
"""
|
||||
id_range = get_id_range()
|
||||
chunks = create_chunks(id_range)
|
||||
chunks.map(copy_chunk)
|
||||
@@ -1,15 +1,34 @@
from clickhouse_driver import Client
from dagster import AssetExecutionContext, BackfillPolicy, DailyPartitionsDefinition, asset

from posthog.clickhouse.client import sync_execute
from posthog.clickhouse.client.connection import Workload
from posthog.clickhouse.cluster import get_cluster
from posthog.clickhouse.query_tagging import tags_context
from posthog.git import get_git_commit_short
from posthog.models.raw_sessions.sessions_v3 import (
    RAW_SESSION_TABLE_BACKFILL_RECORDINGS_SQL_V3,
    RAW_SESSION_TABLE_BACKFILL_SQL_V3,
)

from dags.common import dagster_tags
from dags.common.common import JobOwners, metabase_debug_query_url

# This is the number of days to backfill in one SQL operation
MAX_PARTITIONS_PER_RUN = 30
MAX_PARTITIONS_PER_RUN = 1

# Keep the number of concurrent runs low to avoid overloading ClickHouse and running into the dread "Too many parts".
# This tag needs to also exist in Dagster Cloud (and the local dev dagster.yaml) for the concurrency limit to take effect.
# concurrency:
#   runs:
#     tag_concurrency_limits:
#       - key: 'sessions_backfill_concurrency'
#         limit: 3
#         value:
#           applyLimitPerUniqueValue: true
CONCURRENCY_TAG = {
    "sessions_backfill_concurrency": "sessions_v3",
}

daily_partitions = DailyPartitionsDefinition(
    start_date="2019-01-01",  # this is a year before posthog was founded, so should be early enough even including data imports
@@ -17,6 +36,17 @@ daily_partitions = DailyPartitionsDefinition(
    end_offset=1,  # include today's partition (note that will create a partition with incomplete data, but all our backfills are idempotent so this is ok providing we re-run later)
)

ONE_HOUR_IN_SECONDS = 60 * 60
ONE_GB_IN_BYTES = 1024 * 1024 * 1024

settings = {
    # see this run which took around 2hrs 10min for 1 day https://posthog.dagster.plus/prod-us/runs/0ba8afaa-f3cc-4845-97c5-96731ec8231d?focusedTime=1762898705269&selection=sessions_v3_backfill&logs=step%3Asessions_v3_backfill
    # so to give some margin, allow 4 hours per partition
    "max_execution_time": MAX_PARTITIONS_PER_RUN * 4 * ONE_HOUR_IN_SECONDS,
    "max_memory_usage": 100 * ONE_GB_IN_BYTES,
    "distributed_aggregation_memory_efficient": "1",
}


def get_partition_where_clause(context: AssetExecutionContext, timestamp_field: str) -> str:
    start_incl = context.partition_time_window.start.strftime("%Y-%m-%d")
@@ -31,13 +61,14 @@ def get_partition_where_clause(context: AssetExecutionContext, timestamp_field:
    partitions_def=daily_partitions,
    name="sessions_v3_backfill",
    backfill_policy=BackfillPolicy.multi_run(max_partitions_per_run=MAX_PARTITIONS_PER_RUN),
    tags={"owner": JobOwners.TEAM_ANALYTICS_PLATFORM.value, **CONCURRENCY_TAG},
)
def sessions_v3_backfill(context: AssetExecutionContext) -> None:
    where_clause = get_partition_where_clause(context, timestamp_field="timestamp")

    # note that this is idempotent, so we don't need to worry about running it multiple times for the same partition
    # as long as the backfill has run at least once for each partition, the data will be correct
    backfill_sql = RAW_SESSION_TABLE_BACKFILL_SQL_V3(where=where_clause)
    backfill_sql = RAW_SESSION_TABLE_BACKFILL_SQL_V3(where=where_clause, use_sharded_source=True)

    partition_range = context.partition_key_range
    partition_range_str = f"{partition_range.start} to {partition_range.end}"
@@ -45,8 +76,17 @@ def sessions_v3_backfill(context: AssetExecutionContext) -> None:
        f"Running backfill for {partition_range_str} (where='{where_clause}') using commit {get_git_commit_short() or 'unknown'} "
    )
    context.log.info(backfill_sql)
    if debug_url := metabase_debug_query_url(context.run_id):
        context.log.info(f"Debug query: {debug_url}")

    sync_execute(backfill_sql, workload=Workload.OFFLINE)
    cluster = get_cluster()
    tags = dagster_tags(context)

    def backfill_per_shard(client: Client):
        with tags_context(kind="dagster", dagster=tags):
            sync_execute(backfill_sql, settings=settings, sync_client=client)

    cluster.map_one_host_per_shard(backfill_per_shard).result()

    context.log.info(f"Successfully backfilled sessions_v3 for {partition_range_str}")

@@ -55,13 +95,14 @@ def sessions_v3_backfill(context: AssetExecutionContext) -> None:
    partitions_def=daily_partitions,
    name="sessions_v3_replay_backfill",
    backfill_policy=BackfillPolicy.multi_run(max_partitions_per_run=MAX_PARTITIONS_PER_RUN),
    tags={"owner": JobOwners.TEAM_ANALYTICS_PLATFORM.value, **CONCURRENCY_TAG},
)
def sessions_v3_backfill_replay(context: AssetExecutionContext) -> None:
    where_clause = get_partition_where_clause(context, timestamp_field="min_first_timestamp")

    # note that this is idempotent, so we don't need to worry about running it multiple times for the same partition
    # as long as the backfill has run at least once for each partition, the data will be correct
    backfill_sql = RAW_SESSION_TABLE_BACKFILL_RECORDINGS_SQL_V3(where=where_clause)
    backfill_sql = RAW_SESSION_TABLE_BACKFILL_RECORDINGS_SQL_V3(where=where_clause, use_sharded_source=True)

    partition_range = context.partition_key_range
    partition_range_str = f"{partition_range.start} to {partition_range.end}"
@@ -69,7 +110,16 @@ def sessions_v3_backfill_replay(context: AssetExecutionContext) -> None:
        f"Running backfill for {partition_range_str} (where='{where_clause}') using commit {get_git_commit_short() or 'unknown'} "
    )
    context.log.info(backfill_sql)
    if debug_url := metabase_debug_query_url(context.run_id):
        context.log.info(f"Debug query: {debug_url}")

    sync_execute(backfill_sql, workload=Workload.OFFLINE)
    cluster = get_cluster()
    tags = dagster_tags(context)

    def backfill_per_shard(client: Client):
        with tags_context(kind="dagster", dagster=tags):
            sync_execute(backfill_sql, workload=Workload.OFFLINE, settings=settings, sync_client=client)

    cluster.map_one_host_per_shard(backfill_per_shard).result()

    context.log.info(f"Successfully backfilled sessions_v3 for {partition_range_str}")
@@ -9,14 +9,16 @@ from dagster import DagsterRunStatus, RunsFilter
from dags.common import JobOwners

notification_channel_per_team = {
    JobOwners.TEAM_ANALYTICS_PLATFORM.value: "#alerts-analytics-platform",
    JobOwners.TEAM_CLICKHOUSE.value: "#alerts-clickhouse",
    JobOwners.TEAM_WEB_ANALYTICS.value: "#alerts-web-analytics",
    JobOwners.TEAM_REVENUE_ANALYTICS.value: "#alerts-revenue-analytics",
    JobOwners.TEAM_ERROR_TRACKING.value: "#alerts-error-tracking",
    JobOwners.TEAM_GROWTH.value: "#alerts-growth",
    JobOwners.TEAM_EXPERIMENTS.value: "#alerts-experiments",
    JobOwners.TEAM_MAX_AI.value: "#alerts-max-ai",
    JobOwners.TEAM_DATA_WAREHOUSE.value: "#alerts-data-warehouse",
    JobOwners.TEAM_ERROR_TRACKING.value: "#alerts-error-tracking",
    JobOwners.TEAM_EXPERIMENTS.value: "#alerts-experiments-dagster",
    JobOwners.TEAM_GROWTH.value: "#alerts-growth",
    JobOwners.TEAM_INGESTION.value: "#alerts-ingestion",
    JobOwners.TEAM_MAX_AI.value: "#alerts-max-ai",
    JobOwners.TEAM_REVENUE_ANALYTICS.value: "#alerts-revenue-analytics",
    JobOwners.TEAM_WEB_ANALYTICS.value: "#alerts-web-analytics",
}

CONSECUTIVE_FAILURE_THRESHOLDS = {

371
dags/tests/llma/daily_metrics/test_sql_metrics.py
Normal file
@@ -0,0 +1,371 @@
"""
|
||||
Tests that execute SQL templates against mock data to validate output and logic.
|
||||
|
||||
Tests both the structure and calculation logic of each metric SQL file.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from jinja2 import Template
|
||||
|
||||
from dags.llma.daily_metrics.config import config
|
||||
from dags.llma.daily_metrics.utils import SQL_DIR
|
||||
|
||||
# Expected output columns
|
||||
EXPECTED_COLUMNS = ["date", "team_id", "metric_name", "metric_value"]
|
||||
|
||||
|
||||
def get_all_sql_files():
|
||||
"""Get all SQL template files."""
|
||||
return sorted(SQL_DIR.glob("*.sql"))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def template_context():
|
||||
"""Provide sample context for rendering Jinja2 templates."""
|
||||
return {
|
||||
"event_types": config.ai_event_types,
|
||||
"pageview_mappings": config.pageview_mappings,
|
||||
"date_start": "2025-01-01",
|
||||
"date_end": "2025-01-02",
|
||||
"include_error_rates": config.include_error_rates,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_events_data():
|
||||
"""
|
||||
Mock events data for testing SQL queries.
|
||||
|
||||
Simulates the events table with various AI events for testing.
|
||||
Returns a list of event dicts with timestamp, team_id, event, and properties.
|
||||
"""
|
||||
base_time = datetime(2025, 1, 1, 12, 0, 0)
|
||||
|
||||
events = [
|
||||
# Team 1: 3 generations, 1 with error
|
||||
{
|
||||
"timestamp": base_time,
|
||||
"team_id": 1,
|
||||
"event": "$ai_generation",
|
||||
"properties": {
|
||||
"$ai_trace_id": "trace-1",
|
||||
"$ai_session_id": "session-1",
|
||||
"$ai_error": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
"timestamp": base_time,
|
||||
"team_id": 1,
|
||||
"event": "$ai_generation",
|
||||
"properties": {
|
||||
"$ai_trace_id": "trace-1",
|
||||
"$ai_session_id": "session-1",
|
||||
"$ai_error": "rate limit exceeded",
|
||||
},
|
||||
},
|
||||
{
|
||||
"timestamp": base_time,
|
||||
"team_id": 1,
|
||||
"event": "$ai_generation",
|
||||
"properties": {
|
||||
"$ai_trace_id": "trace-2",
|
||||
"$ai_session_id": "session-1",
|
||||
"$ai_error": "",
|
||||
},
|
||||
},
|
||||
# Team 1: 2 spans from same trace
|
||||
{
|
||||
"timestamp": base_time,
|
||||
"team_id": 1,
|
||||
"event": "$ai_span",
|
||||
"properties": {
|
||||
"$ai_trace_id": "trace-1",
|
||||
"$ai_session_id": "session-1",
|
||||
"$ai_is_error": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
"timestamp": base_time,
|
||||
"team_id": 1,
|
||||
"event": "$ai_span",
|
||||
"properties": {
|
||||
"$ai_trace_id": "trace-1",
|
||||
"$ai_session_id": "session-1",
|
||||
"$ai_is_error": False,
|
||||
},
|
||||
},
|
||||
# Team 2: 1 generation, no errors
|
||||
{
|
||||
"timestamp": base_time,
|
||||
"team_id": 2,
|
||||
"event": "$ai_generation",
|
||||
"properties": {
|
||||
"$ai_trace_id": "trace-3",
|
||||
"$ai_session_id": "session-2",
|
||||
"$ai_error": "",
|
||||
},
|
||||
},
|
||||
# Team 1: Pageviews on LLM Analytics
|
||||
{
|
||||
"timestamp": base_time,
|
||||
"team_id": 1,
|
||||
"event": "$pageview",
|
||||
"properties": {
|
||||
"$current_url": "https://app.posthog.com/project/123/llm-analytics/traces?filter=active",
|
||||
},
|
||||
},
|
||||
{
|
||||
"timestamp": base_time,
|
||||
"team_id": 1,
|
||||
"event": "$pageview",
|
||||
"properties": {
|
||||
"$current_url": "https://app.posthog.com/project/123/llm-analytics/traces",
|
||||
},
|
||||
},
|
||||
{
|
||||
"timestamp": base_time,
|
||||
"team_id": 1,
|
||||
"event": "$pageview",
|
||||
"properties": {
|
||||
"$current_url": "https://app.posthog.com/project/123/llm-analytics/generations",
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
return events
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_metrics(mock_events_data):
|
||||
"""
|
||||
Expected metric outputs based on mock_events_data.
|
||||
|
||||
This serves as the source of truth for what each SQL file should produce.
|
||||
"""
|
||||
return {
|
||||
"event_counts.sql": [
|
||||
{"date": "2025-01-01", "team_id": 1, "metric_name": "ai_generation_count", "metric_value": 3.0},
|
||||
{"date": "2025-01-01", "team_id": 1, "metric_name": "ai_span_count", "metric_value": 2.0},
|
||||
{"date": "2025-01-01", "team_id": 2, "metric_name": "ai_generation_count", "metric_value": 1.0},
|
||||
],
|
||||
"trace_counts.sql": [
|
||||
{
|
||||
"date": "2025-01-01",
|
||||
"team_id": 1,
|
||||
"metric_name": "ai_trace_id_count",
|
||||
"metric_value": 2.0,
|
||||
}, # trace-1, trace-2
|
||||
{"date": "2025-01-01", "team_id": 2, "metric_name": "ai_trace_id_count", "metric_value": 1.0}, # trace-3
|
||||
],
|
||||
"session_counts.sql": [
|
||||
{
|
||||
"date": "2025-01-01",
|
||||
"team_id": 1,
|
||||
"metric_name": "ai_session_id_count",
|
||||
"metric_value": 1.0,
|
||||
}, # session-1
|
||||
{
|
||||
"date": "2025-01-01",
|
||||
"team_id": 2,
|
||||
"metric_name": "ai_session_id_count",
|
||||
"metric_value": 1.0,
|
||||
}, # session-2
|
||||
],
|
||||
"error_rates.sql": [
|
||||
# Team 1: 1 errored generation out of 3 = 0.3333
|
||||
{"date": "2025-01-01", "team_id": 1, "metric_name": "ai_generation_error_rate", "metric_value": 0.3333},
|
||||
# Team 1: 1 errored span out of 2 = 0.5
|
||||
{"date": "2025-01-01", "team_id": 1, "metric_name": "ai_span_error_rate", "metric_value": 0.5},
|
||||
# Team 2: 0 errored out of 1 = 0.0
|
||||
{"date": "2025-01-01", "team_id": 2, "metric_name": "ai_generation_error_rate", "metric_value": 0.0},
|
||||
],
|
||||
"trace_error_rates.sql": [
|
||||
# Team 1: trace-1 has errors, trace-2 doesn't = 1/2 = 0.5
|
||||
{"date": "2025-01-01", "team_id": 1, "metric_name": "ai_trace_id_has_error_rate", "metric_value": 0.5},
|
||||
# Team 2: trace-3 has no errors = 0/1 = 0.0
|
||||
{"date": "2025-01-01", "team_id": 2, "metric_name": "ai_trace_id_has_error_rate", "metric_value": 0.0},
|
||||
],
|
||||
"pageview_counts.sql": [
|
||||
{"date": "2025-01-01", "team_id": 1, "metric_name": "pageviews_traces", "metric_value": 2.0},
|
||||
{"date": "2025-01-01", "team_id": 1, "metric_name": "pageviews_generations", "metric_value": 1.0},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sql_file", get_all_sql_files(), ids=lambda f: f.stem)
|
||||
def test_sql_output_structure(sql_file: Path, template_context: dict, mock_events_data: list):
|
||||
"""
|
||||
Test that each SQL file produces output with the correct structure.
|
||||
|
||||
This test verifies:
|
||||
1. SQL renders without errors
|
||||
2. Output has exactly 4 columns
|
||||
3. Columns are named correctly: date, team_id, metric_name, metric_value
|
||||
"""
|
||||
with open(sql_file) as f:
|
||||
template = Template(f.read())
|
||||
rendered = template.render(**template_context)
|
||||
|
||||
# Basic smoke test - SQL should render
|
||||
assert rendered.strip(), f"{sql_file.name} rendered to empty string"
|
||||
|
||||
# SQL should be a SELECT statement
|
||||
assert "SELECT" in rendered.upper(), f"{sql_file.name} should contain SELECT"
|
||||
|
||||
# Should have all required column aliases (or direct column references for team_id)
|
||||
for col in EXPECTED_COLUMNS:
|
||||
# team_id is often selected directly without an alias
|
||||
if col == "team_id":
|
||||
assert "team_id" in rendered.lower(), f"{sql_file.name} missing column: {col}"
|
||||
else:
|
||||
assert f"as {col}" in rendered.lower(), f"{sql_file.name} missing column alias: {col}"
|
||||
|
||||
|
||||
def test_event_counts_logic(template_context: dict, mock_events_data: list, expected_metrics: dict):
|
||||
"""Test event_counts.sql produces correct counts."""
|
||||
sql_file = SQL_DIR / "event_counts.sql"
|
||||
with open(sql_file) as f:
|
||||
template = Template(f.read())
|
||||
rendered = template.render(**template_context)
|
||||
|
||||
# Verify SQL structure expectations
|
||||
assert "count(*)" in rendered.lower(), "event_counts.sql should use count(*)"
|
||||
assert "group by date, team_id, event" in rendered.lower(), "Should group by date, team_id, event"
|
||||
|
||||
# Verify expected metrics documentation
|
||||
expected = expected_metrics["event_counts.sql"]
|
||||
assert len(expected) == 3, "Expected 3 metric rows based on mock data"
|
||||
|
||||
# Team 1 should have 3 generation events
|
||||
gen_team1 = next(m for m in expected if m["team_id"] == 1 and "generation" in m["metric_name"])
|
||||
assert gen_team1["metric_value"] == 3.0
|
||||
|
||||
# Team 1 should have 2 span events
|
||||
span_team1 = next(m for m in expected if m["team_id"] == 1 and "span" in m["metric_name"])
|
||||
assert span_team1["metric_value"] == 2.0
|
||||
|
||||
|
||||
def test_trace_counts_logic(template_context: dict, mock_events_data: list, expected_metrics: dict):
|
||||
"""Test trace_counts.sql produces correct distinct trace counts."""
|
||||
sql_file = SQL_DIR / "trace_counts.sql"
|
||||
with open(sql_file) as f:
|
||||
template = Template(f.read())
|
||||
rendered = template.render(**template_context)
|
||||
|
||||
# Verify SQL uses count(DISTINCT)
|
||||
assert "count(distinct" in rendered.lower(), "trace_counts.sql should use count(DISTINCT)"
|
||||
assert "$ai_trace_id" in rendered, "Should count distinct $ai_trace_id"
|
||||
|
||||
expected = expected_metrics["trace_counts.sql"]
|
||||
|
||||
# Team 1: Should have 2 unique traces (trace-1, trace-2)
|
||||
team1 = next(m for m in expected if m["team_id"] == 1)
|
||||
assert team1["metric_value"] == 2.0, "Team 1 should have 2 unique traces"
|
||||
|
||||
# Team 2: Should have 1 unique trace (trace-3)
|
||||
team2 = next(m for m in expected if m["team_id"] == 2)
|
||||
assert team2["metric_value"] == 1.0, "Team 2 should have 1 unique trace"
|
||||
|
||||
|
||||
def test_session_counts_logic(template_context: dict, mock_events_data: list, expected_metrics: dict):
|
||||
"""Test session_counts.sql produces correct distinct session counts."""
|
||||
sql_file = SQL_DIR / "session_counts.sql"
|
||||
with open(sql_file) as f:
|
||||
template = Template(f.read())
|
||||
rendered = template.render(**template_context)
|
||||
|
||||
# Verify SQL uses count(DISTINCT)
|
||||
assert "count(distinct" in rendered.lower(), "session_counts.sql should use count(DISTINCT)"
|
||||
assert "$ai_session_id" in rendered, "Should count distinct $ai_session_id"
|
||||
|
||||
expected = expected_metrics["session_counts.sql"]
|
||||
|
||||
# Team 1: Should have 1 unique session (session-1)
|
||||
team1 = next(m for m in expected if m["team_id"] == 1)
|
||||
assert team1["metric_value"] == 1.0, "Team 1 should have 1 unique session"
|
||||
|
||||
# Team 2: Should have 1 unique session (session-2)
|
||||
team2 = next(m for m in expected if m["team_id"] == 2)
|
||||
assert team2["metric_value"] == 1.0, "Team 2 should have 1 unique session"
|
||||
|
||||
|
||||
def test_error_rates_logic(template_context: dict, mock_events_data: list, expected_metrics: dict):
|
||||
"""Test error_rates.sql calculates proportions correctly."""
|
||||
sql_file = SQL_DIR / "error_rates.sql"
|
||||
with open(sql_file) as f:
|
||||
template = Template(f.read())
|
||||
rendered = template.render(**template_context)
|
||||
|
||||
# Verify SQL calculates proportions
|
||||
assert "countif" in rendered.lower(), "error_rates.sql should use countIf"
|
||||
assert "/ count(*)" in rendered.lower(), "Should divide by total count"
|
||||
assert "round(" in rendered.lower(), "Should round the result"
|
||||
|
||||
# Verify error detection logic
|
||||
assert "$ai_error" in rendered, "Should check $ai_error property"
|
||||
assert "$ai_is_error" in rendered, "Should check $ai_is_error property"
|
||||
|
||||
expected = expected_metrics["error_rates.sql"]
|
||||
|
||||
# Team 1 generations: 1 error out of 3 = 0.3333
|
||||
gen_team1 = next(m for m in expected if m["team_id"] == 1 and "generation" in m["metric_name"])
|
||||
assert abs(gen_team1["metric_value"] - 0.3333) < 0.0001, "Generation error rate should be ~0.3333"
|
||||
|
||||
# Team 1 spans: 1 error out of 2 = 0.5
|
||||
span_team1 = next(m for m in expected if m["team_id"] == 1 and "span" in m["metric_name"])
|
||||
assert span_team1["metric_value"] == 0.5, "Span error rate should be 0.5"
|
||||
|
||||
|
||||
def test_trace_error_rates_logic(template_context: dict, mock_events_data: list, expected_metrics: dict):
|
||||
"""Test trace_error_rates.sql calculates trace-level error proportions correctly."""
|
||||
sql_file = SQL_DIR / "trace_error_rates.sql"
|
||||
with open(sql_file) as f:
|
||||
template = Template(f.read())
|
||||
rendered = template.render(**template_context)
|
||||
|
||||
# Verify SQL uses distinct count
|
||||
assert "countdistinctif" in rendered.lower(), "Should use countDistinctIf"
|
||||
assert "count(distinct" in rendered.lower(), "Should use count(DISTINCT) for total"
|
||||
assert "$ai_trace_id" in rendered, "Should work with $ai_trace_id"
|
||||
|
||||
expected = expected_metrics["trace_error_rates.sql"]
|
||||
|
||||
# Team 1: trace-1 has errors, trace-2 doesn't = 1/2 = 0.5
|
||||
team1 = next(m for m in expected if m["team_id"] == 1)
|
||||
assert team1["metric_value"] == 0.5, "Team 1 should have 50% of traces with errors"
|
||||
|
||||
# Team 2: trace-3 has no errors = 0/1 = 0.0
|
||||
team2 = next(m for m in expected if m["team_id"] == 2)
|
||||
assert team2["metric_value"] == 0.0, "Team 2 should have 0% of traces with errors"
|
||||
|
||||
|
||||
def test_pageview_counts_logic(template_context: dict, mock_events_data: list, expected_metrics: dict):
|
||||
"""Test pageview_counts.sql categorizes and counts pageviews correctly."""
|
||||
sql_file = SQL_DIR / "pageview_counts.sql"
|
||||
with open(sql_file) as f:
|
||||
template = Template(f.read())
|
||||
rendered = template.render(**template_context)
|
||||
|
||||
# Verify SQL filters $pageview events
|
||||
assert "event = '$pageview'" in rendered, "Should filter for $pageview events"
|
||||
assert "$current_url" in rendered, "Should use $current_url property"
|
||||
assert "LIKE" in rendered, "Should use LIKE for URL matching"
|
||||
|
||||
# Verify pageview mappings are used
|
||||
assert config.pageview_mappings is not None
|
||||
for url_path, _ in config.pageview_mappings:
|
||||
assert url_path in rendered, f"Should include pageview mapping for {url_path}"
|
||||
|
||||
expected = expected_metrics["pageview_counts.sql"]
|
||||
|
||||
# Team 1: 2 trace pageviews
|
||||
traces = next(m for m in expected if "traces" in m["metric_name"])
|
||||
assert traces["metric_value"] == 2.0, "Should count 2 trace pageviews"
|
||||
|
||||
# Team 1: 1 generation pageview
|
||||
gens = next(m for m in expected if "generations" in m["metric_name"])
|
||||
assert gens["metric_value"] == 1.0, "Should count 1 generation pageview"
|
||||
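The tests above all follow the same render-and-inspect pattern. A minimal sketch of that pattern, assuming only `jinja2` is installed (`date_from` is a placeholder context key for illustration, not necessarily a real template variable):

```python
from jinja2 import Template

# Render a metrics SQL template and check the output contract:
# every query must expose date, team_id, metric_name, metric_value.
sql = Template(
    "SELECT toDate(timestamp) AS date, team_id, "
    "'demo_metric' AS metric_name, count(*) AS metric_value "
    "FROM events WHERE timestamp >= '{{ date_from }}' "
    "GROUP BY date, team_id"
).render(date_from="2025-01-01")

for col in ("date", "team_id", "metric_name", "metric_value"):
    assert col in sql.lower(), f"missing column: {col}"
```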
@@ -16,7 +16,27 @@ from posthog import settings
from posthog.clickhouse.client.connection import Workload
from posthog.clickhouse.cluster import ClickhouseCluster

from dags.backups import Backup, BackupConfig, get_latest_backup, non_sharded_backup, prepare_run_config, sharded_backup
from dags.backups import (
    Backup,
    BackupConfig,
    BackupStatus,
    get_latest_backups,
    get_latest_successful_backup,
    non_sharded_backup,
    prepare_run_config,
    sharded_backup,
)


def test_get_latest_backup_empty():
    mock_s3 = MagicMock()
    mock_s3.get_client().list_objects_v2.return_value = {}

    config = BackupConfig(database="posthog", table="dummy")
    context = dagster.build_op_context()
    result = get_latest_backups(context=context, config=config, s3=mock_s3)

    assert result == []


@pytest.mark.parametrize("table", ["", "test"])
@@ -31,15 +51,84 @@ def test_get_latest_backup(table: str):
    }

    config = BackupConfig(database="posthog", table=table)
    result = get_latest_backup(config=config, s3=mock_s3)
    context = dagster.build_op_context()
    result = get_latest_backups(context=context, config=config, s3=mock_s3)

    assert isinstance(result, Backup)
    assert result.database == "posthog"
    assert result.date == "2024-03-01T07:54:04Z"
    assert result.base_backup is None
    assert isinstance(result, list)
    assert result[0].database == "posthog"
    assert result[0].date == "2024-03-01T07:54:04Z"
    assert result[0].base_backup is None

    assert result[1].database == "posthog"
    assert result[1].date == "2024-02-01T07:54:04Z"
    assert result[1].base_backup is None

    assert result[2].database == "posthog"
    assert result[2].date == "2024-01-01T07:54:04Z"
    assert result[2].base_backup is None

    expected_table = table if table else None
    assert result.table == expected_table
    assert result[0].table == expected_table
    assert result[1].table == expected_table
    assert result[2].table == expected_table


def test_get_latest_successful_backup_returns_latest_backup():
    config = BackupConfig(database="posthog", table="test", incremental=True)
    backup1 = Backup(database="posthog", date="2024-02-01T07:54:04Z", table="test")
    backup1.is_done = MagicMock(return_value=True)  # type: ignore
    backup1.status = MagicMock(  # type: ignore
        return_value=BackupStatus(hostname="test", status="CREATING_BACKUP", event_time_microseconds=datetime.now())
    )

    backup2 = Backup(database="posthog", date="2024-01-01T07:54:04Z", table="test")
    backup2.is_done = MagicMock(return_value=True)  # type: ignore
    backup2.status = MagicMock(  # type: ignore
        return_value=BackupStatus(hostname="test", status="BACKUP_CREATED", event_time_microseconds=datetime.now())
    )

    def mock_map_hosts(fn, **kwargs):
        mock_result = MagicMock()
        mock_client = MagicMock()
        mock_result.result.return_value = {"host1": fn(mock_client)}
        return mock_result

    cluster = MagicMock()
    cluster.map_hosts_by_role.side_effect = mock_map_hosts

    result = get_latest_successful_backup(
        context=dagster.build_op_context(),
        config=config,
        latest_backups=[backup1, backup2],
        cluster=cluster,
    )

    assert result == backup2


def test_get_latest_successful_backup_fails():
    config = BackupConfig(database="posthog", table="test", incremental=True)
    backup1 = Backup(database="posthog", date="2024-02-01T07:54:04Z", table="test")
    backup1.status = MagicMock(  # type: ignore
        return_value=BackupStatus(hostname="test", status="CREATING_BACKUP", event_time_microseconds=datetime.now())
    )

    def mock_map_hosts(fn, **kwargs):
        mock_result = MagicMock()
        mock_client = MagicMock()
        mock_result.result.return_value = {"host1": fn(mock_client)}
        return mock_result

    cluster = MagicMock()
    cluster.map_hosts_by_role.side_effect = mock_map_hosts

    with pytest.raises(dagster.Failure):
        get_latest_successful_backup(
            context=dagster.build_op_context(),
            config=config,
            latest_backups=[backup1],
            cluster=cluster,
        )


def run_backup_test(
@@ -144,7 +233,6 @@ def run_backup_test(
def test_full_non_sharded_backup(cluster: ClickhouseCluster):
    config = BackupConfig(
        database=settings.CLICKHOUSE_DATABASE,
        date=datetime.now().strftime("%Y-%m-%dT%H:%M:%SZ"),
        table="person_distinct_id_overrides",
        incremental=False,
        workload=Workload.ONLINE,
@@ -161,7 +249,6 @@ def test_full_sharded_backup(cluster: ClickhouseCluster):
def test_full_sharded_backup(cluster: ClickhouseCluster):
    config = BackupConfig(
        database=settings.CLICKHOUSE_DATABASE,
        date=datetime.now().strftime("%Y-%m-%dT%H:%M:%SZ"),
        table="person_distinct_id_overrides",
        incremental=False,
        workload=Workload.ONLINE,
@@ -178,7 +265,6 @@ def test_incremental_non_sharded_backup(cluster: ClickhouseCluster):
def test_incremental_non_sharded_backup(cluster: ClickhouseCluster):
    config = BackupConfig(
        database=settings.CLICKHOUSE_DATABASE,
        date=datetime.now().strftime("%Y-%m-%dT%H:%M:%SZ"),
        table="person_distinct_id_overrides",
        incremental=True,
        workload=Workload.ONLINE,
@@ -195,7 +281,6 @@ def test_incremental_sharded_backup(cluster: ClickhouseCluster):
def test_incremental_sharded_backup(cluster: ClickhouseCluster):
    config = BackupConfig(
        database=settings.CLICKHOUSE_DATABASE,
        date=datetime.now().strftime("%Y-%m-%dT%H:%M:%SZ"),
        table="person_distinct_id_overrides",
        incremental=True,
        workload=Workload.ONLINE,

485 dags/tests/test_persons_new_backfill.py Normal file
@@ -0,0 +1,485 @@
"""Tests for the persons new backfill job."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import psycopg2.errors
|
||||
from dagster import build_op_context
|
||||
|
||||
from dags.persons_new_backfill import PersonsNewBackfillConfig, copy_chunk, create_chunks
|
||||
|
||||
|
||||
class TestCreateChunks:
|
||||
"""Test the create_chunks function."""
|
||||
|
||||
def test_create_chunks_produces_non_overlapping_ranges(self):
|
||||
"""Test that chunks produce non-overlapping ranges."""
|
||||
config = PersonsNewBackfillConfig(chunk_size=1000)
|
||||
id_range = (1, 5000) # min_id=1, max_id=5000
|
||||
|
||||
context = build_op_context()
|
||||
chunks = list(create_chunks(context, config, id_range))
|
||||
|
||||
# Extract all chunk ranges from DynamicOutput objects
|
||||
chunk_ranges = [chunk.value for chunk in chunks]
|
||||
|
||||
# Verify no overlaps
|
||||
for i, (min1, max1) in enumerate(chunk_ranges):
|
||||
for j, (min2, max2) in enumerate(chunk_ranges):
|
||||
if i != j:
|
||||
# Chunks should not overlap
|
||||
assert not (
|
||||
min1 <= min2 <= max1 or min1 <= max2 <= max1 or min2 <= min1 <= max2
|
||||
), f"Chunks overlap: ({min1}, {max1}) and ({min2}, {max2})"
|
||||
|
||||
def test_create_chunks_covers_entire_id_space(self):
|
||||
"""Test that chunks cover the entire ID space from min to max."""
|
||||
config = PersonsNewBackfillConfig(chunk_size=1000)
|
||||
min_id, max_id = 1, 5000
|
||||
id_range = (min_id, max_id)
|
||||
|
||||
context = build_op_context()
|
||||
chunks = list(create_chunks(context, config, id_range))
|
||||
|
||||
# Extract all chunk ranges from DynamicOutput objects
|
||||
chunk_ranges = [chunk.value for chunk in chunks]
|
||||
|
||||
# Find the overall min and max covered
|
||||
all_ids_covered: set[int] = set()
|
||||
for chunk_min, chunk_max in chunk_ranges:
|
||||
all_ids_covered.update(range(chunk_min, chunk_max + 1))
|
||||
|
||||
# Verify all IDs from min_id to max_id are covered
|
||||
expected_ids = set(range(min_id, max_id + 1))
|
||||
assert all_ids_covered == expected_ids, (
|
||||
f"Missing IDs: {expected_ids - all_ids_covered}, " f"Extra IDs: {all_ids_covered - expected_ids}"
|
||||
)
|
||||
|
||||
def test_create_chunks_first_chunk_includes_max_id(self):
|
||||
"""Test that the first chunk (in yielded order) includes the source table max_id."""
|
||||
config = PersonsNewBackfillConfig(chunk_size=1000)
|
||||
min_id, max_id = 1, 5000
|
||||
id_range = (min_id, max_id)
|
||||
|
||||
context = build_op_context()
|
||||
chunks = list(create_chunks(context, config, id_range))
|
||||
|
||||
# First chunk in the list (yielded first, highest IDs)
|
||||
first_chunk_min, first_chunk_max = chunks[0].value
|
||||
|
||||
assert first_chunk_max == max_id, f"First chunk max ({first_chunk_max}) should equal source max_id ({max_id})"
|
||||
assert (
|
||||
first_chunk_min <= max_id <= first_chunk_max
|
||||
), f"First chunk ({first_chunk_min}, {first_chunk_max}) should include max_id ({max_id})"
|
||||
|
||||
def test_create_chunks_final_chunk_includes_min_id(self):
|
||||
"""Test that the final chunk (in yielded order) includes the source table min_id."""
|
||||
config = PersonsNewBackfillConfig(chunk_size=1000)
|
||||
min_id, max_id = 1, 5000
|
||||
id_range = (min_id, max_id)
|
||||
|
||||
context = build_op_context()
|
||||
chunks = list(create_chunks(context, config, id_range))
|
||||
|
||||
# Last chunk in the list (yielded last, lowest IDs)
|
||||
final_chunk_min, final_chunk_max = chunks[-1].value
|
||||
|
||||
assert final_chunk_min == min_id, f"Final chunk min ({final_chunk_min}) should equal source min_id ({min_id})"
|
||||
assert (
|
||||
final_chunk_min <= min_id <= final_chunk_max
|
||||
), f"Final chunk ({final_chunk_min}, {final_chunk_max}) should include min_id ({min_id})"
|
||||
|
||||
def test_create_chunks_reverse_order(self):
|
||||
"""Test that chunks are yielded in reverse order (highest IDs first)."""
|
||||
config = PersonsNewBackfillConfig(chunk_size=1000)
|
||||
min_id, max_id = 1, 5000
|
||||
id_range = (min_id, max_id)
|
||||
|
||||
context = build_op_context()
|
||||
chunks = list(create_chunks(context, config, id_range))
|
||||
|
||||
# Verify chunks are in descending order by max_id
|
||||
for i in range(len(chunks) - 1):
|
||||
current_max = chunks[i].value[1]
|
||||
next_max = chunks[i + 1].value[1]
|
||||
assert (
|
||||
current_max > next_max
|
||||
), f"Chunks not in reverse order: chunk {i} max ({current_max}) should be > chunk {i+1} max ({next_max})"
|
||||
|
||||
def test_create_chunks_exact_multiple(self):
|
||||
"""Test chunk creation when ID range is an exact multiple of chunk_size."""
|
||||
config = PersonsNewBackfillConfig(chunk_size=1000)
|
||||
min_id, max_id = 1, 5000 # Exactly 5 chunks of 1000
|
||||
id_range = (min_id, max_id)
|
||||
|
||||
context = build_op_context()
|
||||
chunks = list(create_chunks(context, config, id_range))
|
||||
|
||||
assert len(chunks) == 5, f"Expected 5 chunks, got {len(chunks)}"
|
||||
|
||||
# Verify first chunk (highest IDs)
|
||||
assert chunks[0].value == (4001, 5000), f"First chunk should be (4001, 5000), got {chunks[0].value}"
|
||||
|
||||
# Verify last chunk (lowest IDs)
|
||||
assert chunks[-1].value == (1, 1000), f"Last chunk should be (1, 1000), got {chunks[-1].value}"
|
||||
|
||||
def test_create_chunks_non_exact_multiple(self):
|
||||
"""Test chunk creation when ID range is not an exact multiple of chunk_size."""
|
||||
config = PersonsNewBackfillConfig(chunk_size=1000)
|
||||
min_id, max_id = 1, 3750 # 3 full chunks + 1 partial chunk
|
||||
id_range = (min_id, max_id)
|
||||
|
||||
context = build_op_context()
|
||||
chunks = list(create_chunks(context, config, id_range))
|
||||
|
||||
assert len(chunks) == 4, f"Expected 4 chunks, got {len(chunks)}"
|
||||
|
||||
# Verify first chunk (highest IDs) - should be the partial chunk
|
||||
assert chunks[0].value == (3001, 3750), f"First chunk should be (3001, 3750), got {chunks[0].value}"
|
||||
|
||||
# Verify last chunk (lowest IDs)
|
||||
assert chunks[-1].value == (1, 1000), f"Last chunk should be (1, 1000), got {chunks[-1].value}"
|
||||
|
||||
def test_create_chunks_single_chunk(self):
|
||||
"""Test chunk creation when ID range fits in a single chunk."""
|
||||
config = PersonsNewBackfillConfig(chunk_size=1000)
|
||||
min_id, max_id = 100, 500
|
||||
id_range = (min_id, max_id)
|
||||
|
||||
context = build_op_context()
|
||||
chunks = list(create_chunks(context, config, id_range))
|
||||
|
||||
assert len(chunks) == 1, f"Expected 1 chunk, got {len(chunks)}"
|
||||
assert chunks[0].value == (100, 500), f"Chunk should be (100, 500), got {chunks[0].value}"
|
||||
assert chunks[0].value[0] == min_id and chunks[0].value[1] == max_id
|
||||
|
||||
|
||||
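The chunking contract these tests pin down (reverse order, top chunk absorbs the remainder, full coverage) reduces to a few lines of arithmetic. A minimal sketch — not the actual `dags.persons_new_backfill.create_chunks` implementation, just behavior consistent with the assertions above:

```python
def chunk_ranges(min_id: int, max_id: int, chunk_size: int) -> list[tuple[int, int]]:
    """Highest-ID chunk first; any partial remainder lands in the top chunk."""
    ranges = []
    upper = max_id
    while upper >= min_id:
        # Snap the lower bound to the chunk grid anchored at min_id
        lower = max(min_id, ((upper - min_id) // chunk_size) * chunk_size + min_id)
        ranges.append((lower, upper))
        upper = lower - 1
    return ranges

assert chunk_ranges(1, 3750, 1000) == [(3001, 3750), (2001, 3000), (1001, 2000), (1, 1000)]
```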
def create_mock_database_resource(rowcount_values=None):
    """
    Create a mock database resource that mimics psycopg2.extensions.connection.

    Args:
        rowcount_values: List of rowcount values to return per INSERT call.
            If None, defaults to 0. If a single int, uses that for all calls.
    """
    mock_cursor = MagicMock()
    if rowcount_values is None:
        mock_cursor.rowcount = 0
    elif isinstance(rowcount_values, int):
        mock_cursor.rowcount = rowcount_values
    else:
        # Use side_effect to return different rowcounts per call
        call_count = [0]

        def get_rowcount():
            if call_count[0] < len(rowcount_values):
                result = rowcount_values[call_count[0]]
                call_count[0] += 1
                return result
            return rowcount_values[-1] if rowcount_values else 0

        # A bare property() assigned to an instance attribute is never invoked;
        # attach a PropertyMock to the mock's type so each attribute access
        # actually calls get_rowcount()
        from unittest.mock import PropertyMock

        type(mock_cursor).rowcount = PropertyMock(side_effect=get_rowcount)

    mock_cursor.execute = MagicMock()

    # Make cursor() return a context manager
    mock_conn = MagicMock()
    mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
    mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)

    return mock_conn


def create_mock_cluster_resource():
    """Create a mock ClickhouseCluster resource."""
    return MagicMock()


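For reference, with the PropertyMock variant above, successive reads of `rowcount` walk the configured list — a usage sketch of the helper defined above:

```python
conn = create_mock_database_resource(rowcount_values=[50, 75, 25])
with conn.cursor() as cur:
    assert cur.rowcount == 50  # first access consumes the first value
    assert cur.rowcount == 75  # second access the next, and so on
```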
class TestCopyChunk:
    """Test the copy_chunk function."""

    def test_copy_chunk_single_batch_success(self):
        """Test successful copy of a single batch within a chunk."""
        config = PersonsNewBackfillConfig(
            chunk_size=1000, batch_size=100, source_table="posthog_persons", destination_table="posthog_persons_new"
        )
        chunk = (1, 100)  # Single batch covers entire chunk

        mock_db = create_mock_database_resource(rowcount_values=50)
        mock_cluster = create_mock_cluster_resource()

        context = build_op_context(
            resources={"database": mock_db, "cluster": mock_cluster},
        )
        # Patch context.run.job_name where it's accessed in copy_chunk
        from unittest.mock import PropertyMock

        with patch.object(type(context), "run", PropertyMock(return_value=MagicMock(job_name="test_job"))):
            result = copy_chunk(context, config, chunk)

        # Verify result
        assert result["chunk_min"] == 1
        assert result["chunk_max"] == 100
        assert result["records_copied"] == 50

        # Verify SET statements called once (session-level, before loop)
        set_statements = [
            "SET application_name = 'backfill_posthog_persons_to_posthog_persons_new'",
            "SET lock_timeout = '5s'",
            "SET statement_timeout = '30min'",
            "SET maintenance_work_mem = '12GB'",
            "SET work_mem = '512MB'",
            "SET temp_buffers = '512MB'",
            "SET max_parallel_workers_per_gather = 2",
            "SET max_parallel_maintenance_workers = 2",
            "SET synchronous_commit = off",
        ]

        cursor = mock_db.cursor.return_value.__enter__.return_value
        execute_calls = [call[0][0] for call in cursor.execute.call_args_list]

        # Check SET statements were called
        for stmt in set_statements:
            assert any(stmt in call for call in execute_calls), f"SET statement not found: {stmt}"

        # Verify BEGIN, INSERT, COMMIT called once
        assert execute_calls.count("BEGIN") == 1
        assert execute_calls.count("COMMIT") == 1

        # Verify INSERT query format
        insert_calls = [call for call in execute_calls if "INSERT INTO" in call]
        assert len(insert_calls) == 1
        insert_query = insert_calls[0]
        assert "INSERT INTO posthog_persons_new" in insert_query
        assert "SELECT s.*" in insert_query
        assert "FROM posthog_persons s" in insert_query
        assert "WHERE s.id >" in insert_query
        assert "AND s.id <=" in insert_query
        assert "NOT EXISTS" in insert_query
        assert "ORDER BY s.id DESC" in insert_query

    def test_copy_chunk_multiple_batches(self):
        """Test copy with multiple batches in a chunk."""
        config = PersonsNewBackfillConfig(
            chunk_size=1000, batch_size=100, source_table="posthog_persons", destination_table="posthog_persons_new"
        )
        chunk = (1, 250)  # 3 batches: (1,100), (100,200), (200,250)

        mock_db = create_mock_database_resource()
        mock_cluster = create_mock_cluster_resource()

        # Track rowcount per batch - use a list to track INSERT calls
        rowcounts = [50, 75, 25]
        insert_call_count = [0]

        cursor = mock_db.cursor.return_value.__enter__.return_value

        # Track INSERT calls and set rowcount accordingly
        def execute_with_rowcount(query, *args):
            if "INSERT INTO" in query:
                if insert_call_count[0] < len(rowcounts):
                    cursor.rowcount = rowcounts[insert_call_count[0]]
                    insert_call_count[0] += 1
                else:
                    cursor.rowcount = 0

        cursor.execute.side_effect = execute_with_rowcount

        context = build_op_context(
            resources={"database": mock_db, "cluster": mock_cluster},
        )
        # Patch context.run.job_name where it's accessed in copy_chunk
        from unittest.mock import PropertyMock

        with patch.object(type(context), "run", PropertyMock(return_value=MagicMock(job_name="test_job"))):
            result = copy_chunk(context, config, chunk)

        # Verify result
        assert result["chunk_min"] == 1
        assert result["chunk_max"] == 250
        assert result["records_copied"] == 150  # 50 + 75 + 25

        # Verify SET statements called once (before loop)
        cursor = mock_db.cursor.return_value.__enter__.return_value
        execute_calls = [call[0][0] for call in cursor.execute.call_args_list]

        # Verify BEGIN/COMMIT called 3 times (one per batch)
        assert execute_calls.count("BEGIN") == 3
        assert execute_calls.count("COMMIT") == 3

        # Verify INSERT called 3 times
        insert_calls = [call for call in execute_calls if "INSERT INTO" in call]
        assert len(insert_calls) == 3

    def test_copy_chunk_duplicate_key_violation_retry(self):
        """Test that duplicate key violation triggers retry."""
        config = PersonsNewBackfillConfig(
            chunk_size=1000, batch_size=100, source_table="posthog_persons", destination_table="posthog_persons_new"
        )
        chunk = (1, 100)

        mock_db = create_mock_database_resource()
        mock_cluster = create_mock_cluster_resource()

        cursor = mock_db.cursor.return_value.__enter__.return_value

        # Track INSERT attempts
        insert_attempts = [0]

        # First INSERT raises UniqueViolation, second succeeds
        def execute_side_effect(query, *args):
            if "INSERT INTO" in query:
                insert_attempts[0] += 1
                if insert_attempts[0] == 1:
                    # First INSERT attempt raises error
                    # Use real UniqueViolation - pgcode is readonly but isinstance check will pass
                    raise psycopg2.errors.UniqueViolation("duplicate key value violates unique constraint")
                # Subsequent calls succeed
                cursor.rowcount = 50  # Success on retry

        cursor.execute.side_effect = execute_side_effect

        context = build_op_context(
            resources={"database": mock_db, "cluster": mock_cluster},
        )
        # Need to patch time.sleep and run.job_name
        from unittest.mock import PropertyMock

        mock_run = MagicMock(job_name="test_job")
        with (
            patch("dags.persons_new_backfill.time.sleep"),
            patch.object(type(context), "run", PropertyMock(return_value=mock_run)),
        ):
            copy_chunk(context, config, chunk)

        # Verify ROLLBACK was called on error
        execute_calls = [call[0][0] for call in cursor.execute.call_args_list]
        assert "ROLLBACK" in execute_calls

        # Verify retry succeeded (should have INSERT called twice, COMMIT once)
        insert_calls = [call for call in execute_calls if "INSERT INTO" in call]
        assert len(insert_calls) >= 1  # At least one successful INSERT

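The retry contract exercised here — roll back on `UniqueViolation`, back off, re-issue the batch — looks roughly like the following. A sketch only, assuming a psycopg2 connection `conn` and a hypothetical `run_batch(cur)` callable that issues the `INSERT ... SELECT`; the real wiring lives in `dags/persons_new_backfill.py`:

```python
import time

import psycopg2.errors


def insert_batch_with_retry(conn, run_batch, max_retries=3):
    with conn.cursor() as cur:
        for attempt in range(max_retries):
            try:
                cur.execute("BEGIN")
                run_batch(cur)  # issues the INSERT INTO ... SELECT for one batch
                cur.execute("COMMIT")
                return cur.rowcount
            except psycopg2.errors.UniqueViolation:
                cur.execute("ROLLBACK")   # undo the failed batch
                time.sleep(2 ** attempt)  # back off, then retry
    raise RuntimeError("batch still failing after retries")
```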
    def test_copy_chunk_error_handling_and_rollback(self):
        """Test error handling and rollback on non-duplicate errors."""
        config = PersonsNewBackfillConfig(
            chunk_size=1000, batch_size=100, source_table="posthog_persons", destination_table="posthog_persons_new"
        )
        chunk = (1, 100)

        mock_db = create_mock_database_resource()
        mock_cluster = create_mock_cluster_resource()

        cursor = mock_db.cursor.return_value.__enter__.return_value

        # Raise generic error on INSERT
        def execute_side_effect(query, *args):
            if "INSERT INTO" in query:
                raise Exception("Connection lost")

        cursor.execute.side_effect = execute_side_effect

        context = build_op_context(
            resources={"database": mock_db, "cluster": mock_cluster},
        )
        # Patch context.run.job_name where it's accessed in copy_chunk
        from unittest.mock import PropertyMock

        mock_run = MagicMock(job_name="test_job")
        with patch.object(type(context), "run", PropertyMock(return_value=mock_run)):
            # Should raise dagster.Failure
            from dagster import Failure

            try:
                copy_chunk(context, config, chunk)
                raise AssertionError("Expected dagster.Failure to be raised")
            except Failure as e:
                # Verify error metadata
                assert e.description is not None
                assert "Failed to copy batch" in e.description

        # Verify ROLLBACK was called
        execute_calls = [call[0][0] for call in cursor.execute.call_args_list]
        assert "ROLLBACK" in execute_calls

    def test_copy_chunk_insert_query_format(self):
        """Test that INSERT query has correct format."""
        config = PersonsNewBackfillConfig(
            chunk_size=1000, batch_size=100, source_table="test_source", destination_table="test_dest"
        )
        chunk = (1, 100)

        mock_db = create_mock_database_resource(rowcount_values=10)
        mock_cluster = create_mock_cluster_resource()

        context = build_op_context(
            resources={"database": mock_db, "cluster": mock_cluster},
        )
        # Patch context.run.job_name where it's accessed in copy_chunk
        from unittest.mock import PropertyMock

        with patch.object(type(context), "run", PropertyMock(return_value=MagicMock(job_name="test_job"))):
            copy_chunk(context, config, chunk)

        cursor = mock_db.cursor.return_value.__enter__.return_value
        execute_calls = [call[0][0] for call in cursor.execute.call_args_list]

        # Find INSERT query
        insert_query = next((call for call in execute_calls if "INSERT INTO" in call), None)
        assert insert_query is not None

        # Verify query components
        assert "INSERT INTO test_dest" in insert_query
        assert "SELECT s.*" in insert_query
        assert "FROM test_source s" in insert_query
        assert "WHERE s.id >" in insert_query
        assert "AND s.id <=" in insert_query
        assert "NOT EXISTS" in insert_query
        assert "d.team_id = s.team_id" in insert_query
        assert "d.id = s.id" in insert_query
        assert "ORDER BY s.id DESC" in insert_query

    def test_copy_chunk_session_settings_applied_once(self):
        """Test that SET statements are applied once at session level before batch loop."""
        config = PersonsNewBackfillConfig(
            chunk_size=1000, batch_size=50, source_table="posthog_persons", destination_table="posthog_persons_new"
        )
        chunk = (1, 150)  # 3 batches

        mock_db = create_mock_database_resource(rowcount_values=25)
        mock_cluster = create_mock_cluster_resource()

        context = build_op_context(
            resources={"database": mock_db, "cluster": mock_cluster},
        )
        # Patch context.run.job_name where it's accessed in copy_chunk
        from unittest.mock import PropertyMock

        with patch.object(type(context), "run", PropertyMock(return_value=MagicMock(job_name="test_job"))):
            copy_chunk(context, config, chunk)

        cursor = mock_db.cursor.return_value.__enter__.return_value
        execute_calls = [call[0][0] for call in cursor.execute.call_args_list]

        # Count SET statements (should be called once each, before loop)
        set_statements = [
            "SET application_name",
            "SET lock_timeout",
            "SET statement_timeout",
            "SET maintenance_work_mem",
            "SET work_mem",
            "SET temp_buffers",
            "SET max_parallel_workers_per_gather",
            "SET max_parallel_maintenance_workers",
            "SET synchronous_commit",
        ]

        for stmt in set_statements:
            count = sum(1 for call in execute_calls if stmt in call)
            assert count == 1, f"Expected {stmt} to be called once, but it was called {count} times"

        # Verify SET statements come before BEGIN statements
        set_indices = [i for i, call in enumerate(execute_calls) if any(stmt in call for stmt in set_statements)]
        begin_indices = [i for i, call in enumerate(execute_calls) if call == "BEGIN"]

        if set_indices and begin_indices:
            assert max(set_indices) < min(begin_indices), "SET statements should come before BEGIN statements"
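Taken together, the assertions above pin down the shape of each batch statement: keyset-bounded, idempotent via `NOT EXISTS`, highest IDs first. A hedged reconstruction — only the asserted fragments are confirmed, the rest is an assumption:

```python
# One batch of the backfill as the tests constrain it; exact SQL beyond the
# asserted fragments is an assumption, not the committed implementation.
BATCH_SQL = """
INSERT INTO posthog_persons_new
SELECT s.*
FROM posthog_persons s
WHERE s.id > %(batch_min)s
  AND s.id <= %(batch_max)s
  AND NOT EXISTS (
      SELECT 1 FROM posthog_persons_new d
      WHERE d.team_id = s.team_id AND d.id = s.id
  )
ORDER BY s.id DESC
"""
```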
@@ -113,7 +113,7 @@ services:
        # Note: please keep the default version in sync across
        # `posthog` and the `charts-clickhouse` repos
        #
        image: ${CLICKHOUSE_SERVER_IMAGE:-clickhouse/clickhouse-server:25.6.13.41}
        image: ${CLICKHOUSE_SERVER_IMAGE:-clickhouse/clickhouse-server:25.8.11.66}
        restart: on-failure
        environment:
            CLICKHOUSE_SKIP_USER_SETUP: 1
@@ -296,17 +296,15 @@ services:
            MAXMIND_DB_PATH: '/share/GeoLite2-City.mmdb'
            # Shared Redis for non-critical path (analytics, billing, cookieless)
            REDIS_URL: 'redis://redis:6379/'
            # Optional: Use separate Redis URLs for read/write separation
            # Optional: Use separate Redis URL for read replicas
            # REDIS_READER_URL: 'redis://redis-replica:6379/'
            # REDIS_WRITER_URL: 'redis://redis-primary:6379/'
            # Optional: Increase Redis timeout (default is 100ms)
            # REDIS_TIMEOUT_MS: 200
            # Dedicated Redis database for critical path (team cache + flags cache)
            # Hobby deployments start in Mode 1 (shared-only). Developers override in docker-compose.dev.yml for Mode 2.
            # FLAGS_REDIS_URL: 'redis://redis:6379/1'
            # Optional: Use separate Flags Redis URLs for read/write separation
            # Optional: Use separate Flags Redis URL for read replicas
            # FLAGS_REDIS_READER_URL: 'redis://redis-replica:6379/1'
            # FLAGS_REDIS_WRITER_URL: 'redis://redis-primary:6379/1'
            ADDRESS: '0.0.0.0:3001'
            RUST_LOG: 'info'
            COOKIELESS_REDIS_HOST: redis7

@@ -111,7 +111,7 @@ services:
            service: clickhouse
        hostname: clickhouse
        # Development performance optimizations
        mem_limit: 4g
        mem_limit: 6g
        cpus: 2
        environment:
            - AWS_ACCESS_KEY_ID=object_storage_root_user

@@ -103,7 +103,7 @@ class ConversationViewSet(TeamAndOrgViewSetMixin, ListModelMixin, RetrieveModelM
            # Only for streaming
            and self.action == "create"
            # Strict limits are skipped for select US region teams (PostHog + an active user we've chatted with)
            and not (get_instance_region() == "US" and self.team_id in (2, 87921, 41124))
            and not (get_instance_region() == "US" and self.team_id in (2, 87921, 41124, 103224))
        ):
            return [AIBurstRateThrottle(), AISustainedRateThrottle()]


@@ -79,6 +79,12 @@ class AccessControlSerializer(serializers.ModelSerializer):
                f"Access level cannot be set below the minimum '{min_level}' for {resource}."
            )

        max_level = highest_access_level(resource)
        if levels.index(access_level) > levels.index(max_level):
            raise serializers.ValidationError(
                f"Access level cannot be set above the maximum '{max_level}' for {resource}."
            )

        return access_level

    def validate(self, data):
@@ -219,6 +225,7 @@ class AccessControlViewSetMixin(_GenericViewSet):
            else ordered_access_levels(resource),
            "default_access_level": "editor" if is_resource_level else default_access_level(resource),
            "minimum_access_level": minimum_access_level(resource) if not is_resource_level else "none",
            "maximum_access_level": highest_access_level(resource) if not is_resource_level else "manager",
            "user_access_level": user_access_level,
            "user_can_edit_access_levels": user_access_control.check_can_modify_access_levels_for_object(obj),
        }

@@ -146,6 +146,42 @@ class TestAccessControlMinimumLevelValidation(BaseAccessControlTest):
            )
            assert res.status_code == status.HTTP_200_OK, f"Failed for level {level}: {res.json()}"

    def test_activity_log_access_level_cannot_be_above_viewer(self):
        """Test that activity_log access level cannot be set above maximum 'viewer'"""
        self._org_membership(OrganizationMembership.Level.ADMIN)

        for level in ["editor", "manager"]:
            res = self.client.put(
                "/api/projects/@current/resource_access_controls",
                {"resource": "activity_log", "access_level": level},
            )
            assert res.status_code == status.HTTP_400_BAD_REQUEST, f"Failed for level {level}: {res.json()}"
            assert "cannot be set above the maximum 'viewer'" in res.json()["detail"]

    def test_activity_log_access_restricted_for_users_without_access(self):
        """Test that users without access to activity_log cannot access activity log endpoints"""
        self._org_membership(OrganizationMembership.Level.ADMIN)

        res = self.client.put(
            "/api/projects/@current/resource_access_controls",
            {"resource": "activity_log", "access_level": "none"},
        )
        assert res.status_code == status.HTTP_200_OK, f"Failed to set access control: {res.json()}"

        from ee.models.rbac.access_control import AccessControl

        ac = AccessControl.objects.filter(team=self.team, resource="activity_log", resource_id=None).first()
        assert ac is not None, "Access control was not created"
        assert ac.access_level == "none", f"Access level is {ac.access_level}, expected 'none'"

        self._org_membership(OrganizationMembership.Level.MEMBER)

        res = self.client.get("/api/projects/@current/activity_log/")
        assert res.status_code == status.HTTP_403_FORBIDDEN, f"Expected 403, got {res.status_code}: {res.json()}"

        res = self.client.get("/api/projects/@current/advanced_activity_logs/")
        assert res.status_code == status.HTTP_403_FORBIDDEN, f"Expected 403, got {res.status_code}: {res.json()}"


class TestAccessControlResourceLevelAPI(BaseAccessControlTest):
    def setUp(self):
@@ -186,6 +222,7 @@ class TestAccessControlResourceLevelAPI(BaseAccessControlTest):
            "default_access_level": "editor",
            "user_can_edit_access_levels": True,
            "minimum_access_level": "none",
            "maximum_access_level": "manager",
        }

    def test_change_rejected_if_not_org_admin(self):

@@ -171,8 +171,6 @@
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND type IN ('QueryFinish',
'ExceptionWhileProcessing')
AND event_date >= 'today'
AND event_time >= 'today 00:00:00'
ORDER BY event_time DESC
@@ -180,6 +178,25 @@
# ---
# name: TestCohortQuery.test_cohort_filter_with_another_cohort_with_event_sequence.1
'''
/* celery:posthog.tasks.calculate_cohort.collect_cohort_query_stats */
SELECT query_id,
query_duration_ms,
read_rows,
read_bytes,
written_rows,
memory_usage,
exception
FROM query_log_archive
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND event_date >= 'today'
AND event_time >= 'today 00:00:00'
ORDER BY event_time DESC
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_another_cohort_with_event_sequence.2
'''

SELECT count(DISTINCT person_id)
FROM cohortpeople
@@ -188,7 +205,7 @@
AND version = 0
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_another_cohort_with_event_sequence.2
# name: TestCohortQuery.test_cohort_filter_with_another_cohort_with_event_sequence.3
'''

(SELECT cohort_people.person_id AS id
@@ -269,7 +286,7 @@
allow_experimental_join_condition=1
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_another_cohort_with_event_sequence.3
# name: TestCohortQuery.test_cohort_filter_with_another_cohort_with_event_sequence.4
'''

SELECT if(funnel_query.person_id = '00000000-0000-0000-0000-000000000000', person.person_id, funnel_query.person_id) AS id
@@ -341,8 +358,6 @@
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND type IN ('QueryFinish',
'ExceptionWhileProcessing')
AND event_date >= 'today'
AND event_time >= 'today 00:00:00'
ORDER BY event_time DESC
@@ -350,12 +365,21 @@
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.1
'''

SELECT count(DISTINCT person_id)
FROM cohortpeople
WHERE team_id = 99999
AND cohort_id = 99999
AND version = 0
/* celery:posthog.tasks.calculate_cohort.collect_cohort_query_stats */
SELECT query_id,
query_duration_ms,
read_rows,
read_bytes,
written_rows,
memory_usage,
exception
FROM query_log_archive
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND event_date >= 'today'
AND event_time >= 'today 00:00:00'
ORDER BY event_time DESC
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.10
@@ -638,6 +662,16 @@
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.2
'''

SELECT count(DISTINCT person_id)
FROM cohortpeople
WHERE team_id = 99999
AND cohort_id = 99999
AND version = 0
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.3
'''

(SELECT cohort_people.person_id AS id
@@ -684,7 +718,7 @@
allow_experimental_join_condition=1
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.3
# name: TestCohortQuery.test_cohort_filter_with_extra.4
'''

SELECT if(behavior_query.person_id = '00000000-0000-0000-0000-000000000000', person.person_id, behavior_query.person_id) AS id
@@ -726,7 +760,7 @@
join_algorithm = 'auto'
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.4
# name: TestCohortQuery.test_cohort_filter_with_extra.5
'''
/* celery:posthog.tasks.calculate_cohort.collect_cohort_query_stats */
SELECT query_id,
@@ -740,14 +774,31 @@
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND type IN ('QueryFinish',
'ExceptionWhileProcessing')
AND event_date >= 'today'
AND event_time >= 'today 00:00:00'
ORDER BY event_time DESC
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.5
# name: TestCohortQuery.test_cohort_filter_with_extra.6
'''
/* celery:posthog.tasks.calculate_cohort.collect_cohort_query_stats */
SELECT query_id,
query_duration_ms,
read_rows,
read_bytes,
written_rows,
memory_usage,
exception
FROM query_log_archive
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND event_date >= 'today'
AND event_time >= 'today 00:00:00'
ORDER BY event_time DESC
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.7
'''

SELECT count(DISTINCT person_id)
@@ -757,99 +808,6 @@
AND version = 0
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.6
'''
(
(SELECT persons.id AS id
FROM
(SELECT argMax(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(person.properties, 'name'), ''), 'null'), '^"|"$', ''), person.version) AS properties___name,
person.id AS id
FROM person
WHERE and(equals(person.team_id, 99999), in(id,
(SELECT where_optimization.id AS id
FROM person AS where_optimization
WHERE and(equals(where_optimization.team_id, 99999), ifNull(equals(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(where_optimization.properties, 'name'), ''), 'null'), '^"|"$', ''), 'test'), 0)))))
GROUP BY person.id
HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0))) AS persons
WHERE ifNull(equals(persons.properties___name, 'test'), 0)
ORDER BY persons.id ASC
LIMIT 1000000000 SETTINGS optimize_aggregation_in_order=1,
join_algorithm='auto'))
UNION DISTINCT (
(SELECT source.id AS id
FROM
(SELECT actor_id AS actor_id,
count() AS event_count,
groupUniqArray(distinct_id) AS event_distinct_ids,
actor_id AS id
FROM
(SELECT if(not(empty(e__override.distinct_id)), e__override.person_id, e.person_id) AS actor_id,
toTimeZone(e.timestamp, 'UTC') AS timestamp,
e.uuid AS uuid,
e.distinct_id AS distinct_id
FROM events AS e
LEFT OUTER JOIN
(SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id,
person_distinct_id_overrides.distinct_id AS distinct_id
FROM person_distinct_id_overrides
WHERE equals(person_distinct_id_overrides.team_id, 99999)
GROUP BY person_distinct_id_overrides.distinct_id
HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS e__override ON equals(e.distinct_id, e__override.distinct_id)
WHERE and(equals(e.team_id, 99999), greaterOrEquals(timestamp, toDateTime64('explicit_redacted_timestamp', 6, 'UTC')), lessOrEquals(timestamp, toDateTime64('today', 6, 'UTC')), equals(e.event, '$pageview')))
GROUP BY actor_id) AS source
ORDER BY source.id ASC
LIMIT 1000000000 SETTINGS optimize_aggregation_in_order=1,
join_algorithm='auto')) SETTINGS readonly=2,
max_execution_time=600,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.7
'''

SELECT if(behavior_query.person_id = '00000000-0000-0000-0000-000000000000', person.person_id, behavior_query.person_id) AS id
FROM
(SELECT if(not(empty(pdi.distinct_id)), pdi.person_id, e.person_id) AS person_id,
countIf(timestamp > now() - INTERVAL 1 week
AND timestamp < now()
AND event = '$pageview'
AND 1=1) > 0 AS performed_event_condition_None_level_level_0_level_1_level_0_0
FROM events e
LEFT OUTER JOIN
(SELECT distinct_id,
argMax(person_id, version) as person_id
FROM person_distinct_id2
WHERE team_id = 99999
GROUP BY distinct_id
HAVING argMax(is_deleted, version) = 0) AS pdi ON e.distinct_id = pdi.distinct_id
WHERE team_id = 99999
AND event IN ['$pageview']
AND timestamp <= now()
AND timestamp >= now() - INTERVAL 1 week
GROUP BY person_id) behavior_query
FULL OUTER JOIN
(SELECT *,
id AS person_id
FROM
(SELECT id,
argMax(properties, version) as person_props
FROM person
WHERE team_id = 99999
GROUP BY id
HAVING max(is_deleted) = 0 SETTINGS optimize_aggregation_in_order = 1)) person ON person.person_id = behavior_query.person_id
WHERE 1 = 1
AND ((((has(['test'], replaceRegexpAll(JSONExtractRaw(person_props, 'name'), '^"|"$', ''))))
OR ((coalesce(performed_event_condition_None_level_level_0_level_1_level_0_0, false))))) SETTINGS optimize_aggregation_in_order = 1,
join_algorithm = 'auto'
'''
# ---
# name: TestCohortQuery.test_cohort_filter_with_extra.8
'''
(
@@ -1972,8 +1930,6 @@
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND type IN ('QueryFinish',
'ExceptionWhileProcessing')
AND event_date >= 'today'
AND event_time >= 'today 00:00:00'
ORDER BY event_time DESC
@@ -1981,6 +1937,25 @@
# ---
# name: TestCohortQuery.test_precalculated_cohort_filter_with_extra_filters.1
'''
/* celery:posthog.tasks.calculate_cohort.collect_cohort_query_stats */
SELECT query_id,
query_duration_ms,
read_rows,
read_bytes,
written_rows,
memory_usage,
exception
FROM query_log_archive
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND event_date >= 'today'
AND event_time >= 'today 00:00:00'
ORDER BY event_time DESC
'''
# ---
# name: TestCohortQuery.test_precalculated_cohort_filter_with_extra_filters.2
'''

SELECT count(DISTINCT person_id)
FROM cohortpeople
@@ -1989,45 +1964,6 @@
AND version = 0
'''
# ---
# name: TestCohortQuery.test_precalculated_cohort_filter_with_extra_filters.2
'''

(SELECT cohort_people.person_id AS id
FROM
(SELECT DISTINCT cohortpeople.person_id AS person_id,
cohortpeople.cohort_id AS cohort_id,
cohortpeople.team_id AS team_id
FROM cohortpeople
WHERE and(equals(cohortpeople.team_id, 99999), in(tuple(cohortpeople.cohort_id, cohortpeople.version), [(99999, 0)]))) AS cohort_people
WHERE and(ifNull(equals(cohort_people.cohort_id, 99999), 0), ifNull(equals(cohort_people.team_id, 99999), 0))
LIMIT 1000000000)
UNION DISTINCT
(SELECT persons.id AS id
FROM
(SELECT argMax(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(person.properties, 'name'), ''), 'null'), '^"|"$', ''), person.version) AS properties___name,
person.id AS id
FROM person
WHERE and(equals(person.team_id, 99999), in(id,
(SELECT where_optimization.id AS id
FROM person AS where_optimization
WHERE and(equals(where_optimization.team_id, 99999), ifNull(equals(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(where_optimization.properties, 'name'), ''), 'null'), '^"|"$', ''), 'test2'), 0)))))
GROUP BY person.id
HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0))) AS persons
WHERE ifNull(equals(persons.properties___name, 'test2'), 0)
ORDER BY persons.id ASC
LIMIT 1000000000 SETTINGS optimize_aggregation_in_order=1,
join_algorithm='auto') SETTINGS readonly=2,
max_execution_time=600,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
'''
# ---
# name: TestCohortQuery.test_precalculated_cohort_filter_with_extra_filters.3
'''

@@ -2068,6 +2004,45 @@
'''
# ---
# name: TestCohortQuery.test_precalculated_cohort_filter_with_extra_filters.4
'''

(SELECT cohort_people.person_id AS id
FROM
(SELECT DISTINCT cohortpeople.person_id AS person_id,
cohortpeople.cohort_id AS cohort_id,
cohortpeople.team_id AS team_id
FROM cohortpeople
WHERE and(equals(cohortpeople.team_id, 99999), in(tuple(cohortpeople.cohort_id, cohortpeople.version), [(99999, 0)]))) AS cohort_people
WHERE and(ifNull(equals(cohort_people.cohort_id, 99999), 0), ifNull(equals(cohort_people.team_id, 99999), 0))
LIMIT 1000000000)
UNION DISTINCT
(SELECT persons.id AS id
FROM
(SELECT argMax(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(person.properties, 'name'), ''), 'null'), '^"|"$', ''), person.version) AS properties___name,
person.id AS id
FROM person
WHERE and(equals(person.team_id, 99999), in(id,
(SELECT where_optimization.id AS id
FROM person AS where_optimization
WHERE and(equals(where_optimization.team_id, 99999), ifNull(equals(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(where_optimization.properties, 'name'), ''), 'null'), '^"|"$', ''), 'test2'), 0)))))
GROUP BY person.id
HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0))) AS persons
WHERE ifNull(equals(persons.properties___name, 'test2'), 0)
ORDER BY persons.id ASC
LIMIT 1000000000 SETTINGS optimize_aggregation_in_order=1,
join_algorithm='auto') SETTINGS readonly=2,
max_execution_time=600,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
'''
# ---
# name: TestCohortQuery.test_precalculated_cohort_filter_with_extra_filters.5
'''

SELECT person.person_id AS id
@@ -2314,8 +2289,6 @@
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND type IN ('QueryFinish',
'ExceptionWhileProcessing')
AND event_date >= 'today'
AND event_time >= 'today 00:00:00'
ORDER BY event_time DESC
@@ -2323,6 +2296,25 @@
# ---
# name: TestCohortQuery.test_unwrapping_static_cohort_filter_hidden_in_layers_of_cohorts.1
'''
/* celery:posthog.tasks.calculate_cohort.collect_cohort_query_stats */
SELECT query_id,
query_duration_ms,
read_rows,
read_bytes,
written_rows,
memory_usage,
exception
FROM query_log_archive
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND event_date >= 'today'
AND event_time >= 'today 00:00:00'
ORDER BY event_time DESC
'''
# ---
# name: TestCohortQuery.test_unwrapping_static_cohort_filter_hidden_in_layers_of_cohorts.2
'''

SELECT count(DISTINCT person_id)
FROM cohortpeople
@@ -2331,7 +2323,7 @@
AND version = 0
'''
# ---
# name: TestCohortQuery.test_unwrapping_static_cohort_filter_hidden_in_layers_of_cohorts.2
# name: TestCohortQuery.test_unwrapping_static_cohort_filter_hidden_in_layers_of_cohorts.3
'''

(SELECT cohort_people.person_id AS id
@@ -2379,7 +2371,7 @@
allow_experimental_join_condition=1
'''
# ---
# name: TestCohortQuery.test_unwrapping_static_cohort_filter_hidden_in_layers_of_cohorts.3
# name: TestCohortQuery.test_unwrapping_static_cohort_filter_hidden_in_layers_of_cohorts.4
'''

SELECT if(behavior_query.person_id = '00000000-0000-0000-0000-000000000000', person.person_id, behavior_query.person_id) AS id

@@ -13,8 +13,6 @@
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND type IN ('QueryFinish',
'ExceptionWhileProcessing')
AND event_date >= '2021-01-21'
AND event_time >= '2021-01-21 00:00:00'
ORDER BY event_time DESC
@@ -22,6 +20,25 @@
# ---
# name: TestEventQuery.test_account_filters.1
'''
/* celery:posthog.tasks.calculate_cohort.collect_cohort_query_stats */
SELECT query_id,
query_duration_ms,
read_rows,
read_bytes,
written_rows,
memory_usage,
exception
FROM query_log_archive
WHERE lc_cohort_id = 99999
AND team_id = 99999
AND query LIKE '%cohort_calc:00000000%'
AND event_date >= '2021-01-21'
AND event_time >= '2021-01-21 00:00:00'
ORDER BY event_time DESC
'''
# ---
# name: TestEventQuery.test_account_filters.2
'''

SELECT count(DISTINCT person_id)
FROM cohortpeople
@@ -30,7 +47,7 @@
AND version = 0
'''
# ---
# name: TestEventQuery.test_account_filters.2
# name: TestEventQuery.test_account_filters.3
'''
SELECT e.timestamp as timestamp,
if(notEmpty(pdi.distinct_id), pdi.person_id, e.person_id) as person_id

@@ -49,7 +49,7 @@ from posthog.errors import ch_error_type, wrap_query_error
|
||||
),
|
||||
(
|
||||
ServerException(
|
||||
"Code: 439. DB::Exception: Cannot schedule a task: cannot allocate thread (threads=36, jobs=36). (CANNOT_SCHEDULE_TASK) (version 24.8.14.39 (official build))",
|
||||
"Code: 439. DB::Exception: Cannot schedule a task: cannot allocate thread (threads=36, jobs=36). (CANNOT_SCHEDULE_TASK) (version 25.8.11.66 (official build))",
|
||||
code=439,
|
||||
),
|
||||
"ClickHouseAtCapacity",
|
||||
@@ -59,7 +59,7 @@ from posthog.errors import ch_error_type, wrap_query_error
|
||||
),
|
||||
(
|
||||
ServerException(
|
||||
"Code: 159. DB::Exception: Timeout exceeded: elapsed 60.046752587 seconds, maximum: 60. (TIMEOUT_EXCEEDED) (version 24.8.7.41 (official build))",
|
||||
"Code: 159. DB::Exception: Timeout exceeded: elapsed 60.046752587 seconds, maximum: 60. (TIMEOUT_EXCEEDED) (version 25.8.11.66 (official build))",
|
||||
code=159,
|
||||
),
|
||||
"ClickHouseQueryTimeOut",
|
||||
|
||||
@@ -14,7 +14,7 @@ from posthog.api.routing import TeamAndOrgViewSetMixin
|
||||
from posthog.api.shared import UserBasicSerializer
|
||||
from posthog.models.activity_logging.activity_log import Detail, changes_between, log_activity
|
||||
from posthog.models.experiment import ExperimentHoldout
|
||||
from posthog.models.signals import model_activity_signal
|
||||
from posthog.models.signals import model_activity_signal, mutable_receiver
|
||||
|
||||
|
||||
class ExperimentHoldoutSerializer(serializers.ModelSerializer):
|
||||
@@ -125,7 +125,7 @@ class ExperimentHoldoutViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
|
||||
return super().destroy(request, *args, **kwargs)
|
||||
|
||||
|
||||
@receiver(model_activity_signal, sender=ExperimentHoldout)
|
||||
@mutable_receiver(model_activity_signal, sender=ExperimentHoldout)
|
||||
def handle_experiment_holdout_change(
|
||||
sender, scope, before_update, after_update, activity, user=None, was_impersonated=False, **kwargs
|
||||
):
|
||||
|
||||
@@ -20,7 +20,7 @@ from posthog.api.shared import UserBasicSerializer
|
||||
from posthog.api.tagged_item import TaggedItemSerializerMixin
|
||||
from posthog.models.activity_logging.activity_log import Detail, changes_between, log_activity
|
||||
from posthog.models.experiment import ExperimentSavedMetric, ExperimentToSavedMetric
|
||||
from posthog.models.signals import model_activity_signal
|
||||
from posthog.models.signals import model_activity_signal, mutable_receiver
|
||||
|
||||
|
||||
class ExperimentToSavedMetricSerializer(serializers.ModelSerializer):
|
||||
@@ -112,7 +112,7 @@ class ExperimentSavedMetricViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet
|
||||
serializer_class = ExperimentSavedMetricSerializer
|
||||
|
||||
|
||||
@receiver(model_activity_signal, sender=ExperimentSavedMetric)
|
||||
@mutable_receiver(model_activity_signal, sender=ExperimentSavedMetric)
|
||||
def handle_experiment_saved_metric_change(
|
||||
sender, scope, before_update, after_update, activity, user, was_impersonated=False, **kwargs
|
||||
):
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from copy import deepcopy
|
||||
from datetime import date, timedelta
|
||||
from datetime import date, datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import Any, Literal
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from django.db.models import Case, F, Prefetch, Q, QuerySet, Value, When
|
||||
from django.db.models.functions import Now
|
||||
from django.dispatch import receiver
|
||||
|
||||
from rest_framework import serializers, viewsets
|
||||
from rest_framework.exceptions import ValidationError
|
||||
@@ -35,7 +34,7 @@ from posthog.models.experiment import (
|
||||
)
|
||||
from posthog.models.feature_flag.feature_flag import FeatureFlag, FeatureFlagEvaluationTag
|
||||
from posthog.models.filters.filter import Filter
|
||||
from posthog.models.signals import model_activity_signal
|
||||
from posthog.models.signals import model_activity_signal, mutable_receiver
|
||||
from posthog.models.team.team import Team
|
||||
from posthog.rbac.access_control_api_mixin import AccessControlViewSetMixin
|
||||
from posthog.rbac.user_access_control import UserAccessControlSerializerMixin
|
||||
@@ -1057,8 +1056,49 @@ class EnterpriseExperimentsViewSet(
|
||||
status=201,
|
||||
)
|
||||
|
||||
@action(methods=["GET"], detail=False, url_path="stats", required_scopes=["experiment:read"])
|
||||
def stats(self, request: Request, **kwargs: Any) -> Response:
|
||||
"""Get experimentation velocity statistics."""
|
||||
team_tz = ZoneInfo(self.team.timezone) if self.team.timezone else ZoneInfo("UTC")
|
||||
today = datetime.now(team_tz).date()
|
||||
|
||||
@receiver(model_activity_signal, sender=Experiment)
|
||||
last_30d_start = today - timedelta(days=30)
|
||||
previous_30d_start = today - timedelta(days=60)
|
||||
previous_30d_end = last_30d_start
|
||||
|
||||
base_queryset = Experiment.objects.filter(team=self.team, deleted=False, archived=False)
|
||||
|
||||
launched_last_30d = base_queryset.filter(
|
||||
start_date__gte=last_30d_start, start_date__lt=today + timedelta(days=1)
|
||||
).count()
|
||||
|
||||
launched_previous_30d = base_queryset.filter(
|
||||
start_date__gte=previous_30d_start, start_date__lt=previous_30d_end
|
||||
).count()
|
||||
|
||||
if launched_previous_30d == 0:
|
||||
percent_change = 100.0 if launched_last_30d > 0 else 0.0
|
||||
else:
|
||||
percent_change = ((launched_last_30d - launched_previous_30d) / launched_previous_30d) * 100
|
||||
|
||||
active_experiments = base_queryset.filter(start_date__isnull=False, end_date__isnull=True).count()
|
||||
|
||||
completed_last_30d = base_queryset.filter(
|
||||
end_date__gte=last_30d_start, end_date__lt=today + timedelta(days=1)
|
||||
).count()
|
||||
|
||||
return Response(
|
||||
{
|
||||
"launched_last_30d": launched_last_30d,
|
||||
"launched_previous_30d": launched_previous_30d,
|
||||
"percent_change": round(percent_change, 1),
|
||||
"active_experiments": active_experiments,
|
||||
"completed_last_30d": completed_last_30d,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@mutable_receiver(model_activity_signal, sender=Experiment)
|
||||
def handle_experiment_change(
|
||||
sender, scope, before_update, after_update, activity, user, was_impersonated=False, **kwargs
|
||||
):
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
"mobile-replay:schema:build:json": "pnpm mobile-replay:web:schema:build:json && pnpm mobile-replay:mobile:schema:build:json"
|
||||
},
|
||||
"dependencies": {
|
||||
"posthog-js": "1.290.0"
|
||||
"posthog-js": "1.292.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"ts-json-schema-generator": "^v2.4.0-next.6"
|
||||
|
||||
@@ -127,9 +127,9 @@ For a _lot_ of great detail on prompting, check out the [GPT-4.1 prompting guide
|
||||
|
||||
## Support new query types
|
||||
|
||||
Max can now read multiple query types from frontend context, such as trends, funnels, retention, and HogQL queries. To add support for a new query type, you need to extend both the QueryExecutor and the Root node.
PostHog AI can now read multiple query types from frontend context, such as trends, funnels, retention, and HogQL queries. To add support for a new query type, you need to extend both the QueryExecutor and the Root node.

NOTE: this won't extend query types generation. For that, talk to the Max AI team.
NOTE: this won't extend query types generation. For that, talk to the PostHog AI team.

### Adding a new query type

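At its core, the extension described above is a dispatch problem: register a runner per query kind, then route incoming queries to it. A minimal, hypothetical sketch of that shape — `QUERY_RUNNERS`, `register_query_runner`, and `execute` are illustrative names, not PostHog's actual QueryExecutor API:

```python
# Hypothetical sketch of a query-executor dispatch table; the real
# QueryExecutor and Root node in the PostHog codebase differ in detail.
from typing import Any, Callable

QUERY_RUNNERS: dict[str, Callable[[dict], Any]] = {}

def register_query_runner(kind: str):
    """Register a handler for one query kind (e.g. "TrendsQuery")."""
    def decorator(fn: Callable[[dict], Any]) -> Callable[[dict], Any]:
        QUERY_RUNNERS[kind] = fn
        return fn
    return decorator

@register_query_runner("TrendsQuery")
def run_trends(query: dict) -> Any:
    # Execute the query and serialize its results for the assistant context.
    ...

def execute(query: dict) -> Any:
    # Dispatch on the query kind; unsupported kinds fail loudly.
    runner = QUERY_RUNNERS.get(query.get("kind", ""))
    if runner is None:
        raise ValueError(f"Unsupported query kind: {query.get('kind')}")
    return runner(query)
```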
@@ -453,18 +453,15 @@ Query results: 42 events
configurable={
"contextual_tools": {
"search_session_recordings": {"current_filters": {}},
"navigate": {"page_key": "insights"},
}
}
)
context_manager = AssistantContextManager(self.team, self.user, config)
tools = context_manager.get_contextual_tools()

self.assertEqual(len(tools), 2)
self.assertEqual(len(tools), 1)
self.assertIn("search_session_recordings", tools)
self.assertIn("navigate", tools)
self.assertEqual(tools["search_session_recordings"], {"current_filters": {}})
self.assertEqual(tools["navigate"], {"page_key": "insights"})

def test_get_contextual_tools_empty(self):
"""Test extraction of contextual tools returns empty dict when no tools"""

@@ -187,23 +187,6 @@ async def eval_root(call_root, pytestconfig):
id="call_insight_default_props_2",
),
),
# Ensure we try and navigate to the relevant page when asked about specific topics
EvalCase(
input="What's my MRR?",
expected=AssistantToolCall(
name="navigate",
args={"page_key": "revenueAnalytics"},
id="call_navigate_1",
),
),
EvalCase(
input="Can you help me create a survey to collect NPS ratings?",
expected=AssistantToolCall(
name="navigate",
args={"page_key": "surveys"},
id="call_navigate_1",
),
),
EvalCase(
input="Give me the signup to purchase conversion rate for the dates between 8 Jul and 9 Sep",
expected=AssistantToolCall(

407 ee/hogai/eval/ci/max_tools/eval_create_experiment_tool.py Normal file
@@ -0,0 +1,407 @@
"""Evaluations for CreateExperimentTool."""

import uuid

import pytest

from autoevals.partial import ScorerWithPartial
from autoevals.ragas import AnswerSimilarity
from braintrust import EvalCase, Score

from posthog.models import Experiment, FeatureFlag

from products.experiments.backend.max_tools import CreateExperimentTool

from ee.hogai.eval.base import MaxPublicEval
from ee.hogai.utils.types import AssistantState
from ee.models.assistant import Conversation


class ExperimentOutputScorer(ScorerWithPartial):
"""Custom scorer for experiment tool output that combines semantic similarity for text and exact matching for numbers/booleans."""

def __init__(self, semantic_fields: set[str] | None = None, **kwargs):
super().__init__(**kwargs)
self.semantic_fields = semantic_fields or {"message"}

def _run_eval_sync(self, output: dict, expected: dict, **kwargs):
if not expected:
return Score(name=self._name(), score=None, metadata={"reason": "No expected value provided"})
if not output:
return Score(name=self._name(), score=0.0, metadata={"reason": "No output provided"})

total_fields = len(expected)
if total_fields == 0:
return Score(name=self._name(), score=1.0)

score_per_field = 1.0 / total_fields
total_score = 0.0
metadata = {}

for field_name, expected_value in expected.items():
actual_value = output.get(field_name)

if field_name in self.semantic_fields:
# Use semantic similarity for text fields
if actual_value is not None and expected_value is not None:
similarity_scorer = AnswerSimilarity(model="text-embedding-3-small")
result = similarity_scorer.eval(output=str(actual_value), expected=str(expected_value))
field_score = result.score * score_per_field
total_score += field_score
metadata[f"{field_name}_score"] = result.score
else:
metadata[f"{field_name}_missing"] = True
else:
# Use exact match for numeric/boolean fields
if actual_value == expected_value:
total_score += score_per_field
metadata[f"{field_name}_match"] = True
else:
metadata[f"{field_name}_mismatch"] = {
"expected": expected_value,
"actual": actual_value,
}

return Score(name=self._name(), score=total_score, metadata=metadata)


@pytest.mark.django_db
async def eval_create_experiment_basic(pytestconfig, demo_org_team_user):
"""Test basic experiment creation."""
_, team, user = demo_org_team_user

conversation = await Conversation.objects.acreate(team=team, user=user)

async def task_create_experiment(args: dict):
# Create feature flag first (required by the tool)
await FeatureFlag.objects.acreate(
team=team,
created_by=user,
key=args["feature_flag_key"],
name=f"Flag for {args['name']}",
filters={
"groups": [{"properties": [], "rollout_percentage": 100}],
"multivariate": {
"variants": [
{"key": "control", "name": "Control", "rollout_percentage": 50},
{"key": "test", "name": "Test", "rollout_percentage": 50},
]
},
},
)

tool = await CreateExperimentTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)

result_message, artifact = await tool._arun_impl(
name=args["name"],
feature_flag_key=args["feature_flag_key"],
description=args.get("description"),
type=args.get("type", "product"),
)

exp_created = await Experiment.objects.filter(team=team, name=args["name"], deleted=False).aexists()

return {
"message": result_message,
"experiment_created": exp_created,
"experiment_name": artifact.get("experiment_name") if artifact else None,
}

await MaxPublicEval(
experiment_name="create_experiment_basic",
task=task_create_experiment,  # type: ignore
scores=[ExperimentOutputScorer(semantic_fields={"message", "experiment_name"})],
data=[
EvalCase(
input={"name": "Pricing Test", "feature_flag_key": "pricing-test-flag"},
expected={
"message": "Successfully created experiment",
"experiment_created": True,
"experiment_name": "Pricing Test",
},
),
EvalCase(
input={
"name": "Homepage Redesign",
"feature_flag_key": "homepage-redesign",
"description": "Testing new homepage layout for better conversion",
},
expected={
"message": "Successfully created experiment",
"experiment_created": True,
"experiment_name": "Homepage Redesign",
},
),
],
pytestconfig=pytestconfig,
)


@pytest.mark.django_db
async def eval_create_experiment_types(pytestconfig, demo_org_team_user):
"""Test experiment creation with different types (product vs web)."""
_, team, user = demo_org_team_user

conversation = await Conversation.objects.acreate(team=team, user=user)

async def task_create_typed_experiment(args: dict):
# Create feature flag first
await FeatureFlag.objects.acreate(
team=team,
created_by=user,
key=args["feature_flag_key"],
name=f"Flag for {args['name']}",
filters={
"groups": [{"properties": [], "rollout_percentage": 100}],
"multivariate": {
"variants": [
{"key": "control", "name": "Control", "rollout_percentage": 50},
{"key": "test", "name": "Test", "rollout_percentage": 50},
]
},
},
)

tool = await CreateExperimentTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)

result_message, artifact = await tool._arun_impl(
name=args["name"],
feature_flag_key=args["feature_flag_key"],
type=args["type"],
)

# Verify experiment type
experiment = await Experiment.objects.aget(team=team, name=args["name"])

return {
"message": result_message,
"experiment_type": experiment.type,
"artifact_type": artifact.get("type") if artifact else None,
}

await MaxPublicEval(
experiment_name="create_experiment_types",
task=task_create_typed_experiment,  # type: ignore
scores=[ExperimentOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input={"name": "Product Feature Test", "feature_flag_key": "product-test", "type": "product"},
expected={
"message": "Successfully created experiment",
"experiment_type": "product",
"artifact_type": "product",
},
),
EvalCase(
input={"name": "Web UI Test", "feature_flag_key": "web-test", "type": "web"},
expected={
"message": "Successfully created experiment",
"experiment_type": "web",
"artifact_type": "web",
},
),
],
pytestconfig=pytestconfig,
)


@pytest.mark.django_db
async def eval_create_experiment_with_existing_flag(pytestconfig, demo_org_team_user):
"""Test experiment creation with an existing feature flag."""
_, team, user = demo_org_team_user

# Create an existing flag with unique key and multivariate variants
unique_key = f"reusable-flag-{uuid.uuid4().hex[:8]}"
await FeatureFlag.objects.acreate(
team=team,
key=unique_key,
name="Reusable Flag",
created_by=user,
filters={
"groups": [{"properties": [], "rollout_percentage": 100}],
"multivariate": {
"variants": [
{"key": "control", "name": "Control", "rollout_percentage": 50},
{"key": "test", "name": "Test", "rollout_percentage": 50},
]
},
},
)

conversation = await Conversation.objects.acreate(team=team, user=user)

async def task_create_experiment_reuse_flag(args: dict):
tool = await CreateExperimentTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)

result_message, artifact = await tool._arun_impl(
name=args["name"],
feature_flag_key=args["feature_flag_key"],
)

return {
"message": result_message,
"experiment_created": artifact is not None and "experiment_id" in artifact,
}

await MaxPublicEval(
experiment_name="create_experiment_with_existing_flag",
task=task_create_experiment_reuse_flag,  # type: ignore
scores=[ExperimentOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input={"name": "Reuse Flag Test", "feature_flag_key": unique_key},
expected={
"message": "Successfully created experiment",
"experiment_created": True,
},
),
],
pytestconfig=pytestconfig,
)


@pytest.mark.django_db
async def eval_create_experiment_duplicate_name_error(pytestconfig, demo_org_team_user):
"""Test that creating a duplicate experiment returns an error."""
_, team, user = demo_org_team_user

# Create an existing experiment with unique flag key
unique_flag_key = f"test-flag-{uuid.uuid4().hex[:8]}"
flag = await FeatureFlag.objects.acreate(team=team, key=unique_flag_key, created_by=user)
await Experiment.objects.acreate(team=team, name="Existing Experiment", feature_flag=flag, created_by=user)

conversation = await Conversation.objects.acreate(team=team, user=user)

async def task_create_duplicate_experiment(args: dict):
# Create a different flag for the duplicate attempt
await FeatureFlag.objects.acreate(
team=team,
created_by=user,
key=args["feature_flag_key"],
name=f"Flag for {args['name']}",
)

tool = await CreateExperimentTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)

result_message, artifact = await tool._arun_impl(
name=args["name"],
feature_flag_key=args["feature_flag_key"],
)

return {
"message": result_message,
"has_error": artifact.get("error") is not None if artifact else False,
}

await MaxPublicEval(
experiment_name="create_experiment_duplicate_name_error",
task=task_create_duplicate_experiment,  # type: ignore
scores=[ExperimentOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input={"name": "Existing Experiment", "feature_flag_key": "another-flag"},
expected={
"message": "Failed to create experiment: An experiment with name 'Existing Experiment' already exists",
"has_error": True,
},
),
],
pytestconfig=pytestconfig,
)


@pytest.mark.django_db
async def eval_create_experiment_flag_already_used_error(pytestconfig, demo_org_team_user):
"""Test that using a flag already tied to another experiment returns an error."""
_, team, user = demo_org_team_user

# Create an experiment with a flag (unique key)
unique_flag_key = f"used-flag-{uuid.uuid4().hex[:8]}"
flag = await FeatureFlag.objects.acreate(team=team, key=unique_flag_key, created_by=user)
await Experiment.objects.acreate(team=team, name="First Experiment", feature_flag=flag, created_by=user)

conversation = await Conversation.objects.acreate(team=team, user=user)

async def task_create_experiment_with_used_flag(args: dict):
tool = await CreateExperimentTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)

result_message, artifact = await tool._arun_impl(
name=args["name"],
feature_flag_key=args["feature_flag_key"],
)

return {
"message": result_message,
"has_error": artifact.get("error") is not None if artifact else False,
}

await MaxPublicEval(
experiment_name="create_experiment_flag_already_used_error",
task=task_create_experiment_with_used_flag,  # type: ignore
scores=[ExperimentOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input={"name": "Second Experiment", "feature_flag_key": unique_flag_key},
expected={
"message": "Failed to create experiment: Feature flag is already used by experiment",
"has_error": True,
},
),
],
pytestconfig=pytestconfig,
)
@@ -365,3 +365,467 @@ async def eval_create_feature_flag_duplicate_handling(pytestconfig, demo_org_tea
],
pytestconfig=pytestconfig,
)


@pytest.mark.django_db
async def eval_create_multivariate_feature_flag(pytestconfig, demo_org_team_user):
"""Test multivariate feature flag creation for A/B testing."""
_, team, user = demo_org_team_user

conversation = await Conversation.objects.acreate(team=team, user=user)

# Generate unique keys for this test run
unique_suffix = uuid.uuid4().hex[:6]

async def task_create_multivariate_flag(instructions: str):
tool = await CreateFeatureFlagTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)

result_message, artifact = await tool._arun_impl(instructions=instructions)

flag_key = artifact.get("flag_key")

# Verify multivariate schema structure if flag was created
variant_count = 0
schema_valid = False
has_multivariate = False
if flag_key:
try:
flag = await FeatureFlag.objects.aget(team=team, key=flag_key)
multivariate = flag.filters.get("multivariate")

if multivariate:
has_multivariate = True
variants = multivariate.get("variants", [])
variant_count = len(variants)

# Verify each variant has the required schema structure
schema_valid = True
for variant in variants:
# Verify variant has key and rollout_percentage (name is optional)
if not all(key in variant for key in ["key", "rollout_percentage"]):
schema_valid = False
break
# Verify rollout_percentage is a number
if not isinstance(variant["rollout_percentage"], int | float):
schema_valid = False
break
except FeatureFlag.DoesNotExist:
pass

return {
"message": result_message,
"has_multivariate": has_multivariate,
"variant_count": variant_count,
"created": flag_key is not None,
"schema_valid": schema_valid,
}

await MaxPublicEval(
experiment_name="create_multivariate_feature_flag",
task=task_create_multivariate_flag,  # type: ignore
scores=[FeatureFlagOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input=f"Create an A/B test flag called 'ab-test-{unique_suffix}' with control and test variants",
expected={
"message": f"Successfully created feature flag 'ab-test-{unique_suffix}' with A/B test",
"has_multivariate": True,
"variant_count": 2,
"created": True,
"schema_valid": True,
},
),
EvalCase(
input=f"Create a multivariate flag called 'abc-test-{unique_suffix}' with 3 variants for testing",
expected={
"message": f"Successfully created feature flag 'abc-test-{unique_suffix}' with multivariate",
"has_multivariate": True,
"variant_count": 3,
"created": True,
"schema_valid": True,
},
),
EvalCase(
input=f"Create an A/B test flag called 'pricing-test-{unique_suffix}' for testing new pricing",
expected={
"message": f"Successfully created feature flag 'pricing-test-{unique_suffix}' with A/B test",
"has_multivariate": True,
"variant_count": 2,
"created": True,
"schema_valid": True,
},
),
],
pytestconfig=pytestconfig,
)


@pytest.mark.django_db
async def eval_create_multivariate_with_rollout(pytestconfig, demo_org_team_user):
"""Test multivariate feature flags with rollout percentages for targeted experiments."""
_, team, user = demo_org_team_user

conversation = await Conversation.objects.acreate(team=team, user=user)

# Generate unique keys for this test run
unique_suffix = uuid.uuid4().hex[:6]

async def task_create_multivariate_with_rollout(instructions: str):
tool = await CreateFeatureFlagTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)

result_message, artifact = await tool._arun_impl(instructions=instructions)

flag_key = artifact.get("flag_key")

# Verify multivariate and rollout schema structure
has_multivariate = False
has_rollout = False
variant_count = 0
rollout_percentage = None
schema_valid = False

if flag_key:
try:
flag = await FeatureFlag.objects.aget(team=team, key=flag_key)

# Check multivariate config
multivariate = flag.filters.get("multivariate")
if multivariate:
has_multivariate = True
variants = multivariate.get("variants", [])
variant_count = len(variants)

# Check rollout in groups
groups = flag.filters.get("groups", [])
if groups and len(groups) > 0:
group = groups[0]
rollout_percentage = group.get("rollout_percentage")
has_rollout = rollout_percentage is not None

# Verify schema
schema_valid = has_multivariate and variant_count > 0
if has_rollout:
schema_valid = schema_valid and isinstance(rollout_percentage, int | float)

except FeatureFlag.DoesNotExist:
pass

return {
"message": result_message,
"has_multivariate": has_multivariate,
"has_rollout": has_rollout,
"variant_count": variant_count,
"created": flag_key is not None,
"schema_valid": schema_valid,
}

await MaxPublicEval(
experiment_name="create_multivariate_with_rollout",
task=task_create_multivariate_with_rollout,  # type: ignore
scores=[FeatureFlagOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input=f"Create an A/B test flag called 'ab-rollout-{unique_suffix}' with control and test variants at 50% rollout",
expected={
"message": f"Successfully created feature flag 'ab-rollout-{unique_suffix}' with A/B test and 50% rollout",
"has_multivariate": True,
"has_rollout": True,
"variant_count": 2,
"created": True,
"schema_valid": True,
},
),
EvalCase(
input=f"Create a multivariate flag called 'experiment-{unique_suffix}' with 3 variants at 10% rollout",
expected={
"message": f"Successfully created feature flag 'experiment-{unique_suffix}' with multivariate and 10% rollout",
"has_multivariate": True,
"has_rollout": True,
"variant_count": 3,
"created": True,
"schema_valid": True,
},
),
],
pytestconfig=pytestconfig,
)


@pytest.mark.django_db
async def eval_create_multivariate_with_property_filters(pytestconfig, demo_org_team_user):
"""Test multivariate feature flags with property-based targeting for segment-specific experiments."""
_, team, user = demo_org_team_user

conversation = await Conversation.objects.acreate(team=team, user=user)

# Generate unique keys for this test run
unique_suffix = uuid.uuid4().hex[:6]

async def task_create_multivariate_with_properties(instructions: str):
tool = await CreateFeatureFlagTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)

result_message, artifact = await tool._arun_impl(instructions=instructions)

flag_key = artifact.get("flag_key")

# Verify multivariate and property filter schema structure
has_multivariate = False
has_properties = False
variant_count = 0
property_count = 0
schema_valid = False

if flag_key:
try:
flag = await FeatureFlag.objects.aget(team=team, key=flag_key)

# Check multivariate config
multivariate = flag.filters.get("multivariate")
if multivariate:
has_multivariate = True
variants = multivariate.get("variants", [])
variant_count = len(variants)

# Check properties in groups
groups = flag.filters.get("groups", [])
for group in groups:
properties = group.get("properties", [])
property_count += len(properties)
# Verify each property has required schema structure
for prop in properties:
if all(key in prop for key in ["key", "type", "value", "operator"]):
has_properties = True
else:
schema_valid = False
break

# Verify schema
schema_valid = has_multivariate and variant_count > 0 and has_properties

except FeatureFlag.DoesNotExist:
pass

return {
"message": result_message,
"has_multivariate": has_multivariate,
"has_properties": has_properties,
"variant_count": variant_count,
"created": flag_key is not None,
"schema_valid": schema_valid,
}

await MaxPublicEval(
experiment_name="create_multivariate_with_property_filters",
task=task_create_multivariate_with_properties,  # type: ignore
scores=[FeatureFlagOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input=f"Create an A/B test flag called 'email-test-{unique_suffix}' for users where email contains @company.com with control and test variants",
expected={
"message": f"Successfully created feature flag 'email-test-{unique_suffix}' with A/B test for users where email contains @company.com",
"has_multivariate": True,
"has_properties": True,
"variant_count": 2,
"created": True,
"schema_valid": True,
},
),
EvalCase(
input=f"Create a multivariate flag called 'us-experiment-{unique_suffix}' with 3 variants targeting US users",
expected={
"message": f"Successfully created feature flag 'us-experiment-{unique_suffix}' with multivariate targeting US users",
"has_multivariate": True,
"has_properties": True,
"variant_count": 3,
"created": True,
"schema_valid": True,
},
),
],
pytestconfig=pytestconfig,
)


@pytest.mark.django_db
async def eval_create_multivariate_with_custom_percentages(pytestconfig, demo_org_team_user):
"""Test multivariate feature flags with custom variant percentage distributions."""
_, team, user = demo_org_team_user

conversation = await Conversation.objects.acreate(team=team, user=user)

# Generate unique keys for this test run
unique_suffix = uuid.uuid4().hex[:6]

async def task_create_multivariate_custom_percentages(instructions: str):
tool = await CreateFeatureFlagTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)

result_message, artifact = await tool._arun_impl(instructions=instructions)

flag_key = artifact.get("flag_key")

# Verify multivariate with custom percentages
has_multivariate = False
variant_count = 0
percentages_sum_to_100 = False
schema_valid = False

if flag_key:
try:
flag = await FeatureFlag.objects.aget(team=team, key=flag_key)

multivariate = flag.filters.get("multivariate")
if multivariate:
has_multivariate = True
variants = multivariate.get("variants", [])
variant_count = len(variants)

# Check if percentages sum to 100
total_percentage = sum(v.get("rollout_percentage", 0) for v in variants)
percentages_sum_to_100 = total_percentage == 100

# Verify schema
schema_valid = True
for variant in variants:
if not all(key in variant for key in ["key", "rollout_percentage"]):
schema_valid = False
break

except FeatureFlag.DoesNotExist:
pass

return {
"message": result_message,
"has_multivariate": has_multivariate,
"variant_count": variant_count,
"percentages_valid": percentages_sum_to_100,
"created": flag_key is not None,
"schema_valid": schema_valid and percentages_sum_to_100,
}

await MaxPublicEval(
experiment_name="create_multivariate_with_custom_percentages",
task=task_create_multivariate_custom_percentages,  # type: ignore
scores=[FeatureFlagOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input=f"Create an A/B test flag called 'uneven-test-{unique_suffix}' with control at 70% and test at 30%",
expected={
"message": f"Successfully created feature flag 'uneven-test-{unique_suffix}' with A/B test",
"has_multivariate": True,
"variant_count": 2,
"percentages_valid": True,
"created": True,
"schema_valid": True,
},
),
EvalCase(
input=f"Create a multivariate flag called 'weighted-test-{unique_suffix}' with control (33%), variant_a (33%), variant_b (34%)",
expected={
"message": f"Successfully created feature flag 'weighted-test-{unique_suffix}' with multivariate",
"has_multivariate": True,
"variant_count": 3,
"percentages_valid": True,
"created": True,
"schema_valid": True,
},
),
],
pytestconfig=pytestconfig,
)


@pytest.mark.django_db
async def eval_create_multivariate_error_handling(pytestconfig, demo_org_team_user):
"""Test multivariate feature flag error handling for invalid configurations."""
_, team, user = demo_org_team_user

conversation = await Conversation.objects.acreate(team=team, user=user)

# Generate unique keys for this test run
unique_suffix = uuid.uuid4().hex[:6]

async def task_create_invalid_multivariate(instructions: str):
tool = await CreateFeatureFlagTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)

result_message, artifact = await tool._arun_impl(instructions=instructions)

# Check if error was properly reported
has_error = (
"error" in artifact or "invalid" in result_message.lower() or "must sum to 100" in result_message.lower()
)

return {
"message": result_message,
"has_error": has_error,
}

await MaxPublicEval(
experiment_name="create_multivariate_error_handling",
task=task_create_invalid_multivariate,  # type: ignore
scores=[FeatureFlagOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input=f"Create an A/B test flag called 'invalid-percentage-{unique_suffix}' with control at 60% and test at 50%",
expected={
"message": "The variant percentages you provided (control: 60%, test: 50%) sum to 110%, but they must sum to exactly 100%. Please adjust the percentages so they add up to 100.",
"has_error": True,
},
),
],
pytestconfig=pytestconfig,
)

@@ -1,147 +0,0 @@
import pytest

from braintrust import EvalCase
from pydantic import BaseModel, Field

from posthog.schema import AssistantMessage, AssistantNavigateUrl, AssistantToolCall, FailureMessage, HumanMessage

from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.graph.graph import AssistantGraph
from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState
from ee.models.assistant import Conversation

from ...base import MaxPublicEval
from ...scorers import ToolRelevance

TOOLS_PROMPT = """
- **actions**: Combine several related events into one, which you can then analyze in insights and dashboards as if it were a single event.
- **cohorts**: A catalog of identified persons and your created cohorts.
- **dashboards**: Create and manage your dashboards
- **earlyAccessFeatures**: Allow your users to individually enable or disable features that are in public beta.
- **errorTracking**: Track and analyze your error tracking data to understand and fix issues. [tools: Filter issues, Find impactful issues]
- **experiments**: Experiments help you test changes to your product to see which changes will lead to optimal results. Automatic statistical calculations let you see if the results are valid or if they are likely just a chance occurrence.
- **featureFlags**: Use feature flags to safely deploy and roll back new features in an easy-to-manage way. Roll variants out to certain groups, a percentage of users, or everyone all at once.
- **notebooks**: Notebooks are a way to organize your work and share it with others.
- **persons**: A catalog of all the people behind your events
- **insights**: Track, analyze, and experiment with user behavior.
- **insightNew** [tools: Edit the insight]
- **savedInsights**: Track, analyze, and experiment with user behavior.
- **alerts**: Track, analyze, and experiment with user behavior.
- **replay**: Replay recordings of user sessions to understand how users interact with your product or website. [tools: Search recordings]
- **revenueAnalytics**: Track and analyze your revenue metrics to understand your business performance and growth. [tools: Filter revenue analytics]
- **surveys**: Create surveys to collect feedback from your users [tools: Create surveys, Analyze survey responses]
- **webAnalytics**: Analyze your web analytics data to understand website performance and user behavior.
- **webAnalyticsWebVitals**: Analyze your web analytics data to understand website performance and user behavior.
- **activity**: A catalog of all user interactions with your app or website.
- **sqlEditor**: Write and execute SQL queries against your data warehouse [tools: Write and tweak SQL]
- **heatmaps**: Heatmaps are a way to visualize user behavior on your website.
""".strip()


class EvalInput(BaseModel):
messages: str | list[AssistantMessageUnion]
current_page: str = Field(default="")


@pytest.fixture
def call_root(demo_org_team_user):
graph = (
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root(lambda state: AssistantNodeName.END)
# TRICKY: We need to set a checkpointer here because async tests create a new event loop.
.compile(checkpointer=DjangoCheckpointer())
)

async def callable(input: EvalInput) -> AssistantMessage:
conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])
initial_state = AssistantState(
messages=[HumanMessage(content=input.messages)] if isinstance(input.messages, str) else input.messages
)
raw_state = await graph.ainvoke(
initial_state,
{
"configurable": {
"thread_id": conversation.id,
"contextual_tools": {
"navigate": {"scene_descriptions": TOOLS_PROMPT, "current_page": input.current_page}
},
}
},
)
state = AssistantState.model_validate(raw_state)
assert isinstance(state.messages[-1], AssistantMessage)
return state.messages[-1]

return callable


@pytest.mark.django_db
async def eval_root_navigate_tool(call_root, pytestconfig):
await MaxPublicEval(
experiment_name="root_navigate_tool",
task=call_root,
scores=[ToolRelevance(semantic_similarity_args={"query_description"})],
data=[
# Shouldn't navigate to the insights page
EvalCase(
input=EvalInput(messages="build pageview insight"),
expected=AssistantToolCall(
id="1",
name="create_and_query_insight",
args={
"query_description": "Create a trends insight showing pageview events over time. Track the $pageview event to visualize how many pageviews are happening."
},
),
),
# Should navigate to the persons page
EvalCase(
input=EvalInput(
messages="I added tracking of persons, but I can't find where the persons are in the app"
),
expected=AssistantToolCall(
id="1",
name="navigate",
args={"page_key": AssistantNavigateUrl.PERSONS.value},
),
),
# Should navigate to the surveys page
EvalCase(
input=EvalInput(
messages="were is my survey. I jsut created a survey and save it as draft, I cannot find it now",
current_page="/project/1/surveys/new",
),
expected=AssistantToolCall(
id="1",
name="navigate",
args={"page_key": AssistantNavigateUrl.SURVEYS.value},
),
),
# Should not navigate to the SQL editor
EvalCase(
input=EvalInput(
messages="I need a query written in SQL to tell me what all of my identified events are for any given day."
),
expected=AssistantToolCall(
id="1",
name="create_and_query_insight",
args={"query_description": "All identified events for any given day"},
),
),
# Should just say that the query failed
EvalCase(
input=EvalInput(
messages=[
HumanMessage(
content="I need a query written in SQL to tell me what all of my identified events are for any given day."
),
FailureMessage(
content="An unknown failure occurred while accessing the `events` table",
),
]
),
expected=None,
),
],
pytestconfig=pytestconfig,
)
@@ -7,10 +7,12 @@ from uuid import uuid4
from langchain_anthropic import ChatAnthropic
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage as LangchainAIMessage,
BaseMessage,
HumanMessage as LangchainHumanMessage,
)
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import BaseModel

from posthog.schema import AssistantMessage, AssistantToolCallMessage, ContextMessage, HumanMessage
@@ -36,7 +38,7 @@ class ConversationCompactionManager(ABC):
Manages conversation window boundaries, message filtering, and summarization decisions.
"""

CONVERSATION_WINDOW_SIZE = 64000
CONVERSATION_WINDOW_SIZE = 100_000
"""
Determines the maximum number of tokens allowed in the conversation window.
"""
@@ -54,7 +56,7 @@ class ConversationCompactionManager(ABC):
new_window_id: str | None = None
for message in reversed(messages):
# Handle limits before assigning the window ID.
max_tokens -= self._get_estimated_tokens(message)
max_tokens -= self._get_estimated_assistant_message_tokens(message)
max_messages -= 1
if max_tokens < 0 or max_messages < 0:
break
@@ -83,12 +85,20 @@ class ConversationCompactionManager(ABC):
Determine if the conversation should be summarized based on token count.
Avoids summarizing if there are only two human messages or fewer.
"""
return await self.calculate_token_count(model, messages, tools, **kwargs) > self.CONVERSATION_WINDOW_SIZE

async def calculate_token_count(
self, model: BaseChatModel, messages: list[BaseMessage], tools: LangchainTools | None = None, **kwargs
) -> int:
"""
Calculate the token count for a conversation.
"""
# Avoid summarizing the conversation if there are only two human messages.
human_messages = [message for message in messages if isinstance(message, LangchainHumanMessage)]
if len(human_messages) <= 2:
return False
token_count = await self._get_token_count(model, messages, tools, **kwargs)
return token_count > self.CONVERSATION_WINDOW_SIZE
tool_tokens = self._get_estimated_tools_tokens(tools) if tools else 0
return sum(self._get_estimated_langchain_message_tokens(message) for message in messages) + tool_tokens
return await self._get_token_count(model, messages, tools, **kwargs)

def update_window(
self, messages: Sequence[T], summary_message: ContextMessage, start_id: str | None = None
@@ -134,7 +144,7 @@ class ConversationCompactionManager(ABC):
updated_window_start_id=window_start_id_candidate,
)

def _get_estimated_tokens(self, message: AssistantMessageUnion) -> int:
def _get_estimated_assistant_message_tokens(self, message: AssistantMessageUnion) -> int:
"""
Estimate token count for a message using character/4 heuristic.
"""
@@ -149,6 +159,24 @@ class ConversationCompactionManager(ABC):
char_count = len(message.content)
return round(char_count / self.APPROXIMATE_TOKEN_LENGTH)

def _get_estimated_langchain_message_tokens(self, message: BaseMessage) -> int:
"""
Estimate token count for a message using character/4 heuristic.
"""
char_count = 0
if isinstance(message.content, str):
char_count = len(message.content)
else:
for content in message.content:
if isinstance(content, str):
char_count += len(content)
elif isinstance(content, dict):
char_count += self._count_json_tokens(content)
if isinstance(message, LangchainAIMessage) and message.tool_calls:
for tool_call in message.tool_calls:
char_count += len(json.dumps(tool_call, separators=(",", ":")))
return round(char_count / self.APPROXIMATE_TOKEN_LENGTH)

def _get_conversation_window(self, messages: Sequence[T], start_id: str) -> Sequence[T]:
"""
Get messages from the start_id onwards.
@@ -158,6 +186,22 @@ class ConversationCompactionManager(ABC):
return messages[idx:]
return messages

def _get_estimated_tools_tokens(self, tools: LangchainTools) -> int:
"""
Estimate token count for tools by converting them to JSON schemas.
"""
if not tools:
return 0

total_chars = 0
for tool in tools:
tool_schema = convert_to_openai_tool(tool)
total_chars += self._count_json_tokens(tool_schema)
return round(total_chars / self.APPROXIMATE_TOKEN_LENGTH)

def _count_json_tokens(self, json_data: dict) -> int:
return len(json.dumps(json_data, separators=(",", ":")))

@abstractmethod
async def _get_token_count(
self,

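The compaction manager above estimates token counts with a characters-divided-by-four heuristic rather than a real tokenizer. A standalone sketch of that heuristic (hypothetical `estimate_tokens` helper; the class's APPROXIMATE_TOKEN_LENGTH is assumed to be 4, per its "character/4" docstrings):

```python
import json

APPROXIMATE_TOKEN_LENGTH = 4  # assumed chars per token, per the "character/4 heuristic"

def estimate_tokens(payload: str | dict) -> int:
    # Dicts are serialized compactly first, mirroring _count_json_tokens above.
    text = payload if isinstance(payload, str) else json.dumps(payload, separators=(",", ":"))
    return round(len(text) / APPROXIMATE_TOKEN_LENGTH)

# 400 characters ≈ 100 tokens under this heuristic.
assert estimate_tokens("a" * 400) == 100
```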
@@ -13,7 +13,6 @@ from langchain_core.messages import (
|
||||
)
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.errors import NodeInterrupt
|
||||
from langgraph.types import Send
|
||||
from posthoganalytics import capture_exception
|
||||
|
||||
@@ -27,6 +26,7 @@ from ee.hogai.graph.conversation_summarizer.nodes import AnthropicConversationSu
|
||||
from ee.hogai.graph.shared_prompts import CORE_MEMORY_PROMPT
|
||||
from ee.hogai.llm import MaxChatAnthropic
|
||||
from ee.hogai.tool import ToolMessagesArtifact
|
||||
from ee.hogai.tool_errors import MaxToolError
|
||||
from ee.hogai.tools import ReadDataTool, ReadTaxonomyTool, SearchTool, TodoWriteTool
|
||||
from ee.hogai.utils.anthropic import add_cache_control, convert_to_anthropic_messages
|
||||
from ee.hogai.utils.helpers import convert_tool_messages_to_dict, normalize_ai_message
|
||||
@@ -42,9 +42,9 @@ from ee.hogai.utils.types.base import NodePath
|
||||
|
||||
from .compaction_manager import AnthropicConversationCompactionManager
|
||||
from .prompts import (
|
||||
AGENT_CORE_MEMORY_PROMPT,
|
||||
AGENT_PROMPT,
|
||||
BASIC_FUNCTIONALITY_PROMPT,
|
||||
CORE_MEMORY_INSTRUCTIONS_PROMPT,
|
||||
DOING_TASKS_PROMPT,
|
||||
PROACTIVENESS_PROMPT,
|
||||
ROLE_PROMPT,
|
||||
@@ -187,12 +187,18 @@ class AgentExecutable(BaseAgentExecutable):
|
||||
start_id = state.start_id
|
||||
|
||||
# Summarize the conversation if it's too long.
|
||||
if await self._window_manager.should_compact_conversation(
|
||||
current_token_count = await self._window_manager.calculate_token_count(
|
||||
model, langchain_messages, tools=tools, thinking_config=self.THINKING_CONFIG
|
||||
):
|
||||
)
|
||||
if current_token_count > self._window_manager.CONVERSATION_WINDOW_SIZE:
|
||||
# Exclude the last message if it's the first turn.
|
||||
messages_to_summarize = langchain_messages[:-1] if self._is_first_turn(state) else langchain_messages
|
||||
summary = await AnthropicConversationSummarizer(self._team, self._user).summarize(messages_to_summarize)
|
||||
summary = await AnthropicConversationSummarizer(
|
||||
self._team,
|
||||
self._user,
|
||||
extend_context_window=current_token_count > 195_000,
|
||||
).summarize(messages_to_summarize)
|
||||
|
||||
summary_message = ContextMessage(
|
||||
content=ROOT_CONVERSATION_SUMMARY_PROMPT.format(summary=summary),
|
||||
id=str(uuid4()),
|
||||
@@ -212,6 +218,7 @@ class AgentExecutable(BaseAgentExecutable):
system_prompts = ChatPromptTemplate.from_messages(
[
("system", self._get_system_prompt(state, config)),
("system", AGENT_CORE_MEMORY_PROMPT),
],
template_format="mustache",
).format_messages(
@@ -221,7 +228,7 @@ class AgentExecutable(BaseAgentExecutable):
)

# Mark the longest default prefix as cacheable
add_cache_control(system_prompts[-1])
add_cache_control(system_prompts[0], ttl="1h")

message = await model.ainvoke(system_prompts + langchain_messages, config)
assistant_message = self._process_output_message(message)
@@ -263,7 +270,6 @@ class AgentExecutable(BaseAgentExecutable):
- `{{{task_management}}}`
- `{{{doing_tasks}}}`
- `{{{tool_usage_policy}}}`
- `{{{core_memory_instructions}}}`

The variables from above can have the following nested variables that will be injected:
- `{{{groups}}}` – a prompt containing the description of the groups.
@@ -291,7 +297,6 @@ class AgentExecutable(BaseAgentExecutable):
task_management=TASK_MANAGEMENT_PROMPT,
doing_tasks=DOING_TASKS_PROMPT,
tool_usage_policy=TOOL_USAGE_POLICY_PROMPT,
core_memory_instructions=CORE_MEMORY_INSTRUCTIONS_PROMPT,
)

async def _get_billing_prompt(self) -> str:
@@ -319,10 +324,11 @@ class AgentExecutable(BaseAgentExecutable):
stream_usage=True,
user=self._user,
team=self._team,
betas=["interleaved-thinking-2025-05-14"],
betas=["interleaved-thinking-2025-05-14", "context-1m-2025-08-07"],
max_tokens=8192,
thinking=self.THINKING_CONFIG,
conversation_start_dt=state.start_dt,
billable=True,
)

# The agent can operate in loops. Since insight building is an expensive operation, we want to limit a recursion depth.
@@ -465,6 +471,30 @@ class AgentToolsExecutable(BaseAgentExecutable):
raise ValueError(
f"Tool '{tool_call.name}' returned {type(result).__name__}, expected LangchainToolMessage"
)
except MaxToolError as e:
logger.exception(
"maxtool_error", extra={"tool": tool_call.name, "error": str(e), "retry_strategy": e.retry_strategy}
)
capture_exception(
e,
distinct_id=self._get_user_distinct_id(config),
properties={
**self._get_debug_props(config),
"tool": tool_call.name,
"retry_strategy": e.retry_strategy,
},
)

content = f"Tool failed: {e.to_summary()}.{e.retry_hint}"
return PartialAssistantState(
messages=[
AssistantToolCallMessage(
content=content,
id=str(uuid4()),
tool_call_id=tool_call.id,
)
],
)
except Exception as e:
logger.exception("Error calling tool", extra={"tool_name": tool_call.name, "error": str(e)})
capture_exception(
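
The new `except MaxToolError` branch above leans on a small error contract: to_summary(), retry_hint, and retry_strategy. The actual classes live in ee/hogai/tool_errors.py, which this diff doesn't show, so the following is only an inferred sketch, reconstructed from the call sites here and the tests near the end of this commit (where fatal maps to "never", transient to "once", and retryable to "adjusted"):

    # Inferred sketch of the MaxToolError contract - not the real ee/hogai/tool_errors.py.
    class MaxToolError(Exception):
        retry_strategy: str = "never"
        retry_hint: str = ""

        def to_summary(self) -> str:
            return str(self)

    class MaxToolFatalError(MaxToolError):
        retry_strategy = "never"  # the model should not retry at all

    class MaxToolTransientError(MaxToolError):
        retry_strategy = "once"
        retry_hint = " You can retry this operation once without changes."

    class MaxToolRetryableError(MaxToolError):
        retry_strategy = "adjusted"
        retry_hint = " You can retry with adjusted inputs."

With a contract like that, the f-string in the hunk produces a human-readable failure summary followed by a machine-actionable retry hint in a single tool message.
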
@@ -485,21 +515,6 @@ class AgentToolsExecutable(BaseAgentExecutable):
messages=result.artifact.messages,
)

# If this is a navigation tool call, pause the graph execution
# so that the frontend can re-initialise Max with a new set of contextual tools.
if tool_call.name == "navigate":
navigate_message = AssistantToolCallMessage(
content=str(result.content) if result.content else "",
ui_payload={tool_call.name: result.artifact},
id=str(uuid4()),
tool_call_id=tool_call.id,
)
# Raising a `NodeInterrupt` ensures the assistant graph stops here and
# surfaces the navigation confirmation to the client. The next user
# interaction will resume the graph with potentially different
# contextual tools.
raise NodeInterrupt(navigate_message)

tool_message = AssistantToolCallMessage(
content=str(result.content) if result.content else "",
ui_payload={tool_call.name: result.artifact},

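Note the hunk header: @@ -485,21 +515,6 @@ means the navigate branch above is being removed, not added (the matching test is deleted further down too). For context on the mechanism being deleted: raising NodeInterrupt inside a LangGraph node halts the run and surfaces a payload to the client. A generic illustration of that pattern, not PostHog code:

    from langgraph.errors import NodeInterrupt

    def navigate_node(state: dict) -> dict:
        # Raising NodeInterrupt stops the graph here and surfaces the payload;
        # with a checkpointer configured, the next invocation resumes from this point.
        raise NodeInterrupt({"page_key": state.get("page_key", "insights")})
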
@@ -27,6 +27,7 @@ Do not create links like "here" or "click here". All links should have relevant
We always use sentence case rather than title case, including in titles, headings, subheadings, or bold text. However if quoting provided text, we keep the original case.
When writing numbers in the thousands to the billions, it's acceptable to abbreviate them (like 10M or 100B - capital letter, no space). If you write out the full number, use commas (like 15,000,000).
You can use light Markdown formatting for readability. Never use the em-dash (—) if you can use the en-dash (–).
For headers, use sentence case rather than title case.
</writing_style>
""".strip()

@@ -128,14 +129,10 @@ The user is a product engineer and will primarily request you perform product ma
TOOL_USAGE_POLICY_PROMPT = """
<tool_usage_policy>
- You can invoke multiple tools within a single response. When a request involves several independent pieces of information, batch your tool calls together for optimal performance
- Retry failed tool calls only if the error proposes retrying, or suggests how to fix tool arguments
</tool_usage_policy>
""".strip()

CORE_MEMORY_INSTRUCTIONS_PROMPT = """
{{{core_memory}}}
New memories will automatically be added to the core memory as the conversation progresses. If users ask to save, update, or delete the core memory, say you have done it. If the '/remember [information]' command is used, the information gets appended verbatim to core memory.
""".strip()

AGENT_PROMPT = """
{{{role}}}

@@ -154,8 +151,11 @@ AGENT_PROMPT = """
{{{tool_usage_policy}}}

{{{billing_context}}}
""".strip()

{{{core_memory_instructions}}}
AGENT_CORE_MEMORY_PROMPT = """
{{{core_memory}}}
New memories will automatically be added to the core memory as the conversation progresses. If users ask to save, update, or delete the core memory, say you have done it. If the '/remember [information]' command is used, the information gets appended verbatim to core memory.
""".strip()

# Conditional prompts

@@ -114,11 +114,11 @@ class TestAnthropicConversationCompactionManager(BaseTest):
@parameterized.expand(
[
# (num_human_messages, token_count, should_compact)
[1, 70000, False], # Only 1 human message
[2, 70000, False], # Only 2 human messages
[3, 50000, False], # 3 human messages but under token limit
[3, 70000, True], # 3 human messages and over token limit
[5, 70000, True], # Many messages over limit
[1, 90000, False], # Only 1 human message, under limit
[2, 90000, False], # Only 2 human messages, under limit
[3, 80000, False], # 3 human messages but under token limit
[3, 110000, True], # 3 human messages and over token limit
[5, 110000, True], # Many messages over limit
]
)
async def test_should_compact_conversation(self, num_human_messages, token_count, should_compact):
@@ -136,19 +136,57 @@ class TestAnthropicConversationCompactionManager(BaseTest):
result = await self.window_manager.should_compact_conversation(mock_model, messages)
self.assertEqual(result, should_compact)

def test_get_estimated_tokens_human_message(self):
async def test_should_compact_conversation_with_tools_under_limit(self):
"""Test that tools are accounted for when estimating tokens with 2 or fewer human messages"""
from langchain_core.tools import tool

@tool
def test_tool(query: str) -> str:
"""A test tool"""
return f"Result for {query}"

messages: list[BaseMessage] = [
LangchainHumanMessage(content="A" * 1000), # ~250 tokens
LangchainAIMessage(content="B" * 1000), # ~250 tokens
]
tools = [test_tool]

mock_model = MagicMock()
# With 2 human messages, should use estimation and not call _get_token_count
result = await self.window_manager.should_compact_conversation(mock_model, messages, tools=tools)

# Total should be well under 100k limit
self.assertFalse(result)

async def test_should_compact_conversation_with_tools_over_limit(self):
"""Test that tools push estimation over limit with 2 or fewer human messages"""
messages: list[BaseMessage] = [
LangchainHumanMessage(content="A" * 200000), # ~50k tokens
LangchainAIMessage(content="B" * 200000), # ~50k tokens
]

# Create large tool schemas to push over 100k limit
tools = [{"type": "function", "function": {"name": f"tool_{i}", "description": "X" * 1000}} for i in range(100)]

mock_model = MagicMock()
result = await self.window_manager.should_compact_conversation(mock_model, messages, tools=tools)

# Should be over the 100k limit
self.assertTrue(result)

def test_get_estimated_assistant_message_tokens_human_message(self):
"""Test token estimation for human messages"""
message = HumanMessage(content="A" * 100, id="1") # 100 chars = ~25 tokens
tokens = self.window_manager._get_estimated_tokens(message)
tokens = self.window_manager._get_estimated_assistant_message_tokens(message)
self.assertEqual(tokens, 25)

def test_get_estimated_tokens_assistant_message(self):
def test_get_estimated_assistant_message_tokens_assistant_message(self):
"""Test token estimation for assistant messages without tool calls"""
message = AssistantMessage(content="A" * 100, id="1") # 100 chars = ~25 tokens
tokens = self.window_manager._get_estimated_tokens(message)
tokens = self.window_manager._get_estimated_assistant_message_tokens(message)
self.assertEqual(tokens, 25)

def test_get_estimated_tokens_assistant_message_with_tool_calls(self):
def test_get_estimated_assistant_message_tokens_assistant_message_with_tool_calls(self):
"""Test token estimation for assistant messages with tool calls"""
message = AssistantMessage(
content="A" * 100, # 100 chars
@@ -162,17 +200,117 @@ class TestAnthropicConversationCompactionManager(BaseTest):
],
)
# Should count content + JSON serialized args
tokens = self.window_manager._get_estimated_tokens(message)
tokens = self.window_manager._get_estimated_assistant_message_tokens(message)
# 100 chars content + ~15 chars for args = ~29 tokens
self.assertGreater(tokens, 25)
self.assertLess(tokens, 35)

def test_get_estimated_tokens_tool_call_message(self):
def test_get_estimated_assistant_message_tokens_tool_call_message(self):
"""Test token estimation for tool call messages"""
message = AssistantToolCallMessage(content="A" * 200, id="1", tool_call_id="t1")
tokens = self.window_manager._get_estimated_tokens(message)
tokens = self.window_manager._get_estimated_assistant_message_tokens(message)
self.assertEqual(tokens, 50)

def test_get_estimated_langchain_message_tokens_string_content(self):
"""Test token estimation for langchain messages with string content"""
message = LangchainHumanMessage(content="A" * 100)
tokens = self.window_manager._get_estimated_langchain_message_tokens(message)
self.assertEqual(tokens, 25)

def test_get_estimated_langchain_message_tokens_list_content_with_strings(self):
"""Test token estimation for langchain messages with list of string content"""
message = LangchainHumanMessage(content=["A" * 100, "B" * 100])
tokens = self.window_manager._get_estimated_langchain_message_tokens(message)
self.assertEqual(tokens, 50)

def test_get_estimated_langchain_message_tokens_list_content_with_dicts(self):
"""Test token estimation for langchain messages with dict content"""
message = LangchainHumanMessage(content=[{"type": "text", "text": "A" * 100}])
tokens = self.window_manager._get_estimated_langchain_message_tokens(message)
# 100 chars for text + overhead for JSON structure
self.assertGreater(tokens, 25)
self.assertLess(tokens, 40)

def test_get_estimated_langchain_message_tokens_ai_message_with_tool_calls(self):
"""Test token estimation for AI messages with tool calls"""
message = LangchainAIMessage(
content="A" * 100,
tool_calls=[
{"id": "t1", "name": "test_tool", "args": {"key": "value"}},
{"id": "t2", "name": "another_tool", "args": {"foo": "bar"}},
],
)
tokens = self.window_manager._get_estimated_langchain_message_tokens(message)
# Content + tool calls JSON
self.assertGreater(tokens, 25)
self.assertLess(tokens, 70)

def test_get_estimated_langchain_message_tokens_ai_message_without_tool_calls(self):
"""Test token estimation for AI messages without tool calls"""
message = LangchainAIMessage(content="A" * 100)
tokens = self.window_manager._get_estimated_langchain_message_tokens(message)
self.assertEqual(tokens, 25)

def test_count_json_tokens(self):
"""Test JSON token counting helper"""
json_data = {"key": "value", "nested": {"foo": "bar"}}
char_count = self.window_manager._count_json_tokens(json_data)
# Should match length of compact JSON
import json

expected = len(json.dumps(json_data, separators=(",", ":")))
self.assertEqual(char_count, expected)

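Taken together, the tests above pin down the estimation heuristic: roughly one token per four characters, with tool-call arguments and structured content serialized to compact JSON before counting. A minimal sketch of that heuristic, assuming the private helpers behave as the tests describe (the real implementation in compaction_manager.py may differ in detail):

    import json

    CHARS_PER_TOKEN = 4  # the tests assume 100 chars ~= 25 tokens

    def estimate_text_tokens(text: str) -> int:
        return len(text) // CHARS_PER_TOKEN

    def count_json_chars(data) -> int:
        # Compact separators, so the estimate isn't inflated by whitespace
        return len(json.dumps(data, separators=(",", ":")))

    # A message with 100 chars of content plus tool-call args {"key": "value"}:
    tokens = estimate_text_tokens("A" * 100) + count_json_chars({"key": "value"}) // CHARS_PER_TOKEN
    assert 25 < tokens < 35  # 25 + 3, matching the tool-call test above
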
def test_get_estimated_tools_tokens_empty(self):
"""Test tool token estimation with no tools"""
tokens = self.window_manager._get_estimated_tools_tokens([])
self.assertEqual(tokens, 0)

def test_get_estimated_tools_tokens_with_dict_tools(self):
"""Test tool token estimation with dict tools"""
tools = [
{"type": "function", "function": {"name": "test_tool", "description": "A test tool"}},
]
tokens = self.window_manager._get_estimated_tools_tokens(tools)
# Should be positive and reasonable
self.assertGreater(tokens, 0)
self.assertLess(tokens, 100)

def test_get_estimated_tools_tokens_with_base_tool(self):
"""Test tool token estimation with BaseTool"""
from langchain_core.tools import tool

@tool
def sample_tool(query: str) -> str:
"""A sample tool for testing"""
return f"Result for {query}"

tools = [sample_tool]
tokens = self.window_manager._get_estimated_tools_tokens(tools)
# Should count the tool schema
self.assertGreater(tokens, 0)
self.assertLess(tokens, 200)

def test_get_estimated_tools_tokens_multiple_tools(self):
"""Test tool token estimation with multiple tools"""
from langchain_core.tools import tool

@tool
def tool1(x: int) -> int:
"""First tool"""
return x * 2

@tool
def tool2(y: str) -> str:
"""Second tool"""
return y.upper()

tools = [tool1, tool2]
tokens = self.window_manager._get_estimated_tools_tokens(tools)
# Should count both tool schemas
self.assertGreater(tokens, 0)
self.assertLess(tokens, 400)

async def test_get_token_count_calls_model(self):
"""Test that _get_token_count properly calls the model's token counting"""
mock_model = MagicMock()

@@ -13,7 +13,6 @@ from langchain_core.messages import (
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import RunnableConfig
from langgraph.errors import NodeInterrupt
from parameterized import parameterized

from posthog.schema import (
@@ -34,6 +33,7 @@ from posthog.models.organization import OrganizationMembership
from products.replay.backend.max_tools import SearchSessionRecordingsTool

from ee.hogai.context import AssistantContextManager
from ee.hogai.tool_errors import MaxToolError, MaxToolFatalError, MaxToolRetryableError, MaxToolTransientError
from ee.hogai.tools.read_taxonomy import ReadEvents
from ee.hogai.utils.tests import FakeChatAnthropic, FakeChatOpenAI
from ee.hogai.utils.types import AssistantState, PartialAssistantState
@@ -445,6 +445,46 @@ class TestAgentNode(ClickhouseTestMixin, BaseTest):
self.assertIn("You are currently in project ", system_content)
self.assertIn("The user's name appears to be ", system_content)

async def test_node_includes_core_memory_in_system_prompt(self):
"""Test that core memory content is appended to the conversation in system prompts"""
with (
patch("os.environ", {"ANTHROPIC_API_KEY": "foo"}),
patch("langchain_anthropic.chat_models.ChatAnthropic._agenerate") as mock_generate,
patch("ee.hogai.graph.agent_modes.nodes.AgentExecutable._aget_core_memory_text") as mock_core_memory,
):
mock_core_memory.return_value = "User prefers concise responses and technical details"
mock_generate.return_value = ChatResult(
generations=[ChatGeneration(message=AIMessage(content="Response"))],
llm_output={},
)

node = _create_agent_node(self.team, self.user)
config = RunnableConfig(configurable={})
node._config = config

await node.arun(AssistantState(messages=[HumanMessage(content="Test")]), config)

# Verify _agenerate was called
mock_generate.assert_called_once()

# Get the messages passed to _agenerate
call_args = mock_generate.call_args
messages = call_args[0][0]

# Check system messages contain core memory
system_messages = [msg for msg in messages if isinstance(msg, SystemMessage)]
self.assertGreater(len(system_messages), 0)

content_parts = []
for msg in system_messages:
if isinstance(msg.content, str):
content_parts.append(msg.content)
else:
content_parts.append(str(msg.content))
system_content = "\n\n".join(content_parts)

self.assertIn("User prefers concise responses and technical details", system_content)

@parameterized.expand(
[
# (membership_level, add_context, expected_prompt)
@@ -480,13 +520,12 @@ class TestAgentNode(ClickhouseTestMixin, BaseTest):
self.assertEqual(await node._get_billing_prompt(), expected_prompt)

@patch("ee.hogai.graph.agent_modes.nodes.AgentExecutable._get_model", return_value=FakeChatOpenAI(responses=[]))
@patch(
"ee.hogai.graph.agent_modes.compaction_manager.AnthropicConversationCompactionManager.should_compact_conversation"
)
@patch("ee.hogai.graph.agent_modes.compaction_manager.AnthropicConversationCompactionManager.calculate_token_count")
@patch("ee.hogai.graph.conversation_summarizer.nodes.AnthropicConversationSummarizer.summarize")
async def test_conversation_summarization_flow(self, mock_summarize, mock_should_compact, mock_model):
async def test_conversation_summarization_flow(self, mock_summarize, mock_calculate_tokens, mock_model):
"""Test that conversation is summarized when it gets too long"""
mock_should_compact.return_value = True
# Return a token count higher than CONVERSATION_WINDOW_SIZE (100,000)
mock_calculate_tokens.return_value = 150_000
mock_summarize.return_value = "This is a summary of the conversation so far."

mock_model_instance = FakeChatOpenAI(responses=[LangchainAIMessage(content="Response after summary")])
@@ -514,13 +553,12 @@ class TestAgentNode(ClickhouseTestMixin, BaseTest):
self.assertIn("This is a summary of the conversation so far.", context_messages[0].content)

@patch("ee.hogai.graph.agent_modes.nodes.AgentExecutable._get_model", return_value=FakeChatOpenAI(responses=[]))
@patch(
"ee.hogai.graph.agent_modes.compaction_manager.AnthropicConversationCompactionManager.should_compact_conversation"
)
@patch("ee.hogai.graph.agent_modes.compaction_manager.AnthropicConversationCompactionManager.calculate_token_count")
@patch("ee.hogai.graph.conversation_summarizer.nodes.AnthropicConversationSummarizer.summarize")
async def test_conversation_summarization_on_first_turn(self, mock_summarize, mock_should_compact, mock_model):
async def test_conversation_summarization_on_first_turn(self, mock_summarize, mock_calculate_tokens, mock_model):
"""Test that on first turn, the last message is excluded from summarization"""
mock_should_compact.return_value = True
# Return a token count higher than CONVERSATION_WINDOW_SIZE (100,000)
mock_calculate_tokens.return_value = 150_000
mock_summarize.return_value = "Summary without last message"

mock_model_instance = FakeChatOpenAI(responses=[LangchainAIMessage(content="Response")])
@@ -871,40 +909,6 @@ class TestAgentToolsNode(BaseTest):
self.assertEqual(len(result.messages), 1)
self.assertIsInstance(result.messages[0], AssistantToolCallMessage)

async def test_navigate_tool_call_raises_node_interrupt(self):
"""Test that navigate tool calls raise NodeInterrupt to pause graph execution"""
node = _create_agent_tools_node(self.team, self.user)

state = AssistantState(
messages=[
AssistantMessage(
content="I'll help you navigate to insights",
id="test-id",
tool_calls=[AssistantToolCall(id="nav-123", name="navigate", args={"page_key": "insights"})],
)
],
root_tool_call_id="nav-123",
)

mock_navigate_tool = AsyncMock()
mock_navigate_tool.ainvoke.return_value = LangchainToolMessage(
content="XXX", tool_call_id="nav-123", artifact={"page_key": "insights"}
)

# The navigate tool call should raise NodeInterrupt
with self.assertRaises(NodeInterrupt) as cm:
await node(state, {"configurable": {"contextual_tools": {"navigate": {}}}})

# Verify the NodeInterrupt contains the expected message
# NodeInterrupt wraps the message in an Interrupt object
interrupt_data = cm.exception.args[0]
if isinstance(interrupt_data, list):
interrupt_data = interrupt_data[0].value
self.assertIsInstance(interrupt_data, AssistantToolCallMessage)
self.assertIn("Navigated to **insights**.", interrupt_data.content)
self.assertEqual(interrupt_data.tool_call_id, "nav-123")
self.assertEqual(interrupt_data.ui_payload, {"navigate": {"page_key": "insights"}})

async def test_arun_tool_returns_wrong_type_returns_error_message(self):
"""Test that tool returning wrong type returns an error message"""
node = _create_agent_tools_node(self.team, self.user)
@@ -955,3 +959,163 @@ class TestAgentToolsNode(BaseTest):
assert isinstance(result.messages[0], AssistantToolCallMessage)
self.assertEqual(result.messages[0].tool_call_id, "tool-123")
self.assertIn("does not exist", result.messages[0].content)

@patch("ee.hogai.tools.read_taxonomy.ReadTaxonomyTool._run_impl")
async def test_max_tool_fatal_error_returns_error_message(self, read_taxonomy_mock):
"""Test that MaxToolFatalError is caught and converted to tool message."""
read_taxonomy_mock.side_effect = MaxToolFatalError(
"Configuration error: INKEEP_API_KEY environment variable is not set"
)

node = _create_agent_tools_node(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Using tool that will fail",
id="test-id",
tool_calls=[
AssistantToolCall(id="tool-123", name="read_taxonomy", args={"query": {"kind": "events"}})
],
)
],
root_tool_call_id="tool-123",
)

result = await node.arun(state, {})

self.assertIsInstance(result, PartialAssistantState)
assert result is not None
self.assertEqual(len(result.messages), 1)
assert isinstance(result.messages[0], AssistantToolCallMessage)
self.assertEqual(result.messages[0].tool_call_id, "tool-123")
self.assertIn("Configuration error", result.messages[0].content)
self.assertIn("INKEEP_API_KEY", result.messages[0].content)
self.assertNotIn("retry", result.messages[0].content.lower())

@patch("ee.hogai.tools.read_taxonomy.ReadTaxonomyTool._run_impl")
async def test_max_tool_retryable_error_returns_error_with_retry_hint(self, read_taxonomy_mock):
"""Test that MaxToolRetryableError includes retry hint for adjusted inputs."""
read_taxonomy_mock.side_effect = MaxToolRetryableError(
"Invalid entity kind: 'unknown_entity'. Must be one of: person, session, organization"
)

node = _create_agent_tools_node(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Using tool with invalid input",
id="test-id",
tool_calls=[
AssistantToolCall(id="tool-123", name="read_taxonomy", args={"query": {"kind": "events"}})
],
)
],
root_tool_call_id="tool-123",
)

result = await node.arun(state, {})

self.assertIsInstance(result, PartialAssistantState)
assert result is not None
self.assertEqual(len(result.messages), 1)
assert isinstance(result.messages[0], AssistantToolCallMessage)
self.assertEqual(result.messages[0].tool_call_id, "tool-123")
self.assertIn("Invalid entity kind", result.messages[0].content)
self.assertIn("retry with adjusted inputs", result.messages[0].content.lower())

@patch("ee.hogai.tools.read_taxonomy.ReadTaxonomyTool._run_impl")
async def test_max_tool_transient_error_returns_error_with_once_retry_hint(self, read_taxonomy_mock):
"""Test that MaxToolTransientError includes hint to retry once without changes."""
read_taxonomy_mock.side_effect = MaxToolTransientError("Rate limit exceeded. Please try again in a few moments")

node = _create_agent_tools_node(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Using tool that hits rate limit",
id="test-id",
tool_calls=[
AssistantToolCall(id="tool-123", name="read_taxonomy", args={"query": {"kind": "events"}})
],
)
],
root_tool_call_id="tool-123",
)

result = await node.arun(state, {})

self.assertIsInstance(result, PartialAssistantState)
assert result is not None
self.assertEqual(len(result.messages), 1)
assert isinstance(result.messages[0], AssistantToolCallMessage)
self.assertEqual(result.messages[0].tool_call_id, "tool-123")
self.assertIn("Rate limit exceeded", result.messages[0].content)
self.assertIn("retry this operation once without changes", result.messages[0].content.lower())

@patch("ee.hogai.tools.read_taxonomy.ReadTaxonomyTool._run_impl")
async def test_generic_exception_returns_internal_error_message(self, read_taxonomy_mock):
"""Test that generic exceptions are caught and return internal error message."""
read_taxonomy_mock.side_effect = RuntimeError("Unexpected internal error")

node = _create_agent_tools_node(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Using tool that crashes unexpectedly",
id="test-id",
tool_calls=[
AssistantToolCall(id="tool-123", name="read_taxonomy", args={"query": {"kind": "events"}})
],
)
],
root_tool_call_id="tool-123",
)

result = await node.arun(state, {})

self.assertIsInstance(result, PartialAssistantState)
assert result is not None
self.assertEqual(len(result.messages), 1)
assert isinstance(result.messages[0], AssistantToolCallMessage)
self.assertEqual(result.messages[0].tool_call_id, "tool-123")
self.assertIn("internal error", result.messages[0].content.lower())
self.assertIn("do not immediately retry", result.messages[0].content.lower())

@parameterized.expand(
[
("fatal", MaxToolFatalError("Fatal error"), "never"),
("transient", MaxToolTransientError("Transient error"), "once"),
("retryable", MaxToolRetryableError("Retryable error"), "adjusted"),
]
)
@patch("ee.hogai.tools.read_taxonomy.ReadTaxonomyTool._run_impl")
async def test_all_error_types_are_logged_with_retry_strategy(
self, name, error, expected_strategy, read_taxonomy_mock
):
"""Test that all MaxToolError types are logged with their retry strategy."""
read_taxonomy_mock.side_effect = error

node = _create_agent_tools_node(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Using tool",
id="test-id",
tool_calls=[
AssistantToolCall(id="tool-123", name="read_taxonomy", args={"query": {"kind": "events"}})
],
)
],
root_tool_call_id="tool-123",
)

with patch("ee.hogai.graph.agent_modes.nodes.capture_exception") as mock_capture:
_ = await node.arun(state, {})

mock_capture.assert_called_once()
call_kwargs = mock_capture.call_args.kwargs
captured_error = mock_capture.call_args.args[0]

self.assertIsInstance(captured_error, MaxToolError)
self.assertEqual(call_kwargs["properties"]["retry_strategy"], expected_strategy)
self.assertEqual(call_kwargs["properties"]["tool"], "read_taxonomy")

@@ -55,15 +55,22 @@ class ConversationSummarizer:


class AnthropicConversationSummarizer(ConversationSummarizer):
def __init__(self, team: Team, user: User, extend_context_window: bool | None = False):
super().__init__(team, user)
self._extend_context_window = extend_context_window

def _get_model(self):
# Haiku has 200k token limit. Sonnet has 1M token limit.
return MaxChatAnthropic(
model="claude-haiku-4-5",
model="claude-sonnet-4-5" if self._extend_context_window else "claude-haiku-4-5",
streaming=False,
stream_usage=False,
max_tokens=8192,
disable_streaming=True,
user=self._user,
team=self._team,
billable=True,
betas=["context-1m-2025-08-07"] if self._extend_context_window else None,
)

def _construct_messages(self, messages: Sequence[BaseMessage]):
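
The summarizer change above is the other half of the extend_context_window flag introduced in the nodes.py hunk: past the threshold, summarization switches from Haiku to Sonnet and enables Anthropic's 1M-token context beta. Condensed to just the selection logic (model names and the beta string are verbatim from this hunk; the helper itself is illustrative):

    def pick_summarizer_model(extend_context_window: bool) -> tuple[str, list[str] | None]:
        # Haiku has a 200k token limit; Sonnet has a 1M token limit behind the beta flag.
        if extend_context_window:
            return "claude-sonnet-4-5", ["context-1m-2025-08-07"]
        return "claude-haiku-4-5", None

    assert pick_summarizer_model(False) == ("claude-haiku-4-5", None)
    assert pick_summarizer_model(True) == ("claude-sonnet-4-5", ["context-1m-2025-08-07"])
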
@@ -5,7 +5,6 @@ from django.db import transaction

import structlog
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI
from pydantic import BaseModel

from posthog.schema import (
@@ -28,6 +27,7 @@ from ee.hogai.graph.parallel_task_execution.mixins import (
)
from ee.hogai.graph.parallel_task_execution.nodes import BaseTaskExecutorNode, TaskExecutionInputTuple
from ee.hogai.graph.shared_prompts import HYPERLINK_USAGE_INSTRUCTIONS
from ee.hogai.llm import MaxChatOpenAI
from ee.hogai.utils.helpers import build_dashboard_url, build_insight_url, cast_assistant_query
from ee.hogai.utils.types import AssistantState, PartialAssistantState
from ee.hogai.utils.types.base import BaseStateWithTasks, InsightArtifact, InsightQuery, TaskResult
@@ -417,10 +417,14 @@ class DashboardCreationNode(AssistantNode):

@property
def _model(self):
return ChatOpenAI(
return MaxChatOpenAI(
model="gpt-4.1-mini",
temperature=0.3,
max_completion_tokens=500,
max_retries=3,
disable_streaming=True,
user=self._user,
team=self._team,
billable=True,
inject_context=False,
)
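
This is the first of several files in the commit that make the same mechanical swap: the stock langchain ChatOpenAI is replaced by the project's MaxChatOpenAI wrapper, which threads user/team attribution plus the new billable and inject_context flags into every call. MaxChatOpenAI itself (ee/hogai/llm.py) is not part of this diff, so here is only a hypothetical stand-in showing the shape of the pattern with stock langchain APIs:

    from langchain_openai import ChatOpenAI

    def make_attributed_model(*, user, team, billable=False, inject_context=True, **kwargs):
        """Hypothetical stand-in for MaxChatOpenAI: attach attribution as run metadata."""
        model = ChatOpenAI(**kwargs)
        # Tagging via metadata keeps this sketch runnable with the stock class;
        # the real wrapper subclasses ChatOpenAI and reports usage for billing.
        return model.with_config(
            {
                "metadata": {
                    "user_id": getattr(user, "pk", None),
                    "team_id": getattr(team, "pk", None),
                    "billable": billable,
                    "inject_context": inject_context,
                }
            }
        )
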
@@ -14,7 +14,6 @@ import structlog
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI

from posthog.schema import AssistantToolCallMessage, VisualizationMessage

@@ -25,6 +24,7 @@ from ee.hogai.context import SUPPORTED_QUERY_MODEL_BY_KIND
from ee.hogai.graph.base import AssistantNode
from ee.hogai.graph.query_executor.query_executor import AssistantQueryExecutor, SupportedQueryTypes
from ee.hogai.graph.shared_prompts import HYPERLINK_USAGE_INSTRUCTIONS
from ee.hogai.llm import MaxChatOpenAI
from ee.hogai.utils.helpers import build_insight_url
from ee.hogai.utils.types import AssistantState, PartialAssistantState

@@ -856,7 +856,7 @@ class InsightSearchNode(AssistantNode):

@property
def _model(self):
return ChatOpenAI(
return MaxChatOpenAI(
model="gpt-4.1-mini",
temperature=0.7,
max_completion_tokens=1000,
@@ -864,4 +864,8 @@ class InsightSearchNode(AssistantNode):
stream_usage=False,
max_retries=3,
disable_streaming=True,
user=self._user,
team=self._team,
billable=True,
inject_context=False,
)

@@ -266,7 +266,7 @@ class TestInsightSearchNode(BaseTest):

# Note: Additional visualization messages depend on query type support in test data

@patch("ee.hogai.graph.insights.nodes.ChatOpenAI")
@patch("ee.hogai.graph.insights.nodes.MaxChatOpenAI")
def test_search_insights_iteratively_single_page(self, mock_openai):
"""Test iterative search with single page (no pagination)."""

@@ -296,7 +296,7 @@ class TestInsightSearchNode(BaseTest):
self.assertIn(self.insight1.id, result)
self.assertIn(self.insight2.id, result)

@patch("ee.hogai.graph.insights.nodes.ChatOpenAI")
@patch("ee.hogai.graph.insights.nodes.MaxChatOpenAI")
def test_search_insights_iteratively_with_pagination(self, mock_openai):
"""Test iterative search with pagination returns valid IDs."""

@@ -326,7 +326,7 @@ class TestInsightSearchNode(BaseTest):
self.assertIn(existing_insight_ids[0], result)
self.assertIn(existing_insight_ids[1], result)

@patch("ee.hogai.graph.insights.nodes.ChatOpenAI")
@patch("ee.hogai.graph.insights.nodes.MaxChatOpenAI")
def test_search_insights_iteratively_fallback(self, mock_openai):
"""Test iterative search when LLM fails - should return empty list."""

@@ -586,7 +586,7 @@ class TestInsightSearchNode(BaseTest):
self.assertEqual(len(self.node._evaluation_selections), 0)
self.assertEqual(self.node._rejection_reason, "None of these match the user's needs")

@patch("ee.hogai.graph.insights.nodes.ChatOpenAI")
@patch("ee.hogai.graph.insights.nodes.MaxChatOpenAI")
async def test_evaluate_insights_with_tools_selection(self, mock_openai):
"""Test the new tool-based evaluation with insight selection."""
# Load insights
@@ -662,7 +662,7 @@ class TestInsightSearchNode(BaseTest):
query_info_empty = self.node._extract_query_metadata({})
self.assertIsNone(query_info_empty)

@patch("ee.hogai.graph.insights.nodes.ChatOpenAI")
@patch("ee.hogai.graph.insights.nodes.MaxChatOpenAI")
async def test_non_executable_insights_handling(self, mock_openai):
"""Test that non-executable insights are presented to LLM but rejected."""
# Create a mock insight that can't be visualized
@@ -707,7 +707,7 @@ class TestInsightSearchNode(BaseTest):
# The explanation should indicate why the insight was rejected
self.assertTrue(len(result["explanation"]) > 0)

@patch("ee.hogai.graph.insights.nodes.ChatOpenAI")
@patch("ee.hogai.graph.insights.nodes.MaxChatOpenAI")
async def test_evaluate_insights_with_tools_rejection(self, mock_openai):
"""Test the new tool-based evaluation with rejection."""
# Load insights
@@ -737,7 +737,7 @@ class TestInsightSearchNode(BaseTest):
result["explanation"], "User is looking for retention analysis, but these are trends and funnels"
)

@patch("ee.hogai.graph.insights.nodes.ChatOpenAI")
@patch("ee.hogai.graph.insights.nodes.MaxChatOpenAI")
async def test_evaluate_insights_with_tools_multiple_selection(self, mock_openai):
"""Test the evaluation with multiple selection mode."""
# Load insights

@@ -424,7 +424,7 @@ class MemoryCollectorNode(MemoryOnboardingShouldRunMixin):
@property
def _model(self):
return MaxChatOpenAI(
model="gpt-4.1", temperature=0.3, disable_streaming=True, user=self._user, team=self._team
model="gpt-4.1", temperature=0.3, disable_streaming=True, user=self._user, team=self._team, billable=True
).bind_tools(memory_collector_tools)

def _construct_messages(self, state: AssistantState) -> list[BaseMessage]:

@@ -155,6 +155,7 @@ class QueryPlannerNode(TaxonomyUpdateDispatcherNodeMixin, AssistantNode):
# Ref: https://forum.langchain.com/t/langgraph-openai-responses-api-400-error-web-search-call-was-provided-without-its-required-reasoning-item/1740/2
output_version="responses/v1",
disable_streaming=True,
billable=True,
).bind_tools(
[
retrieve_event_properties,

@@ -58,7 +58,13 @@ class SchemaGeneratorNode(AssistantNode, Generic[Q]):
@property
def _model(self):
return MaxChatOpenAI(
model="gpt-4.1", temperature=0.3, disable_streaming=True, user=self._user, team=self._team, max_tokens=8192
model="gpt-4.1",
temperature=0.3,
disable_streaming=True,
user=self._user,
team=self._team,
max_tokens=8192,
billable=True,
).with_structured_output(
self.OUTPUT_SCHEMA,
method="json_schema",

@@ -59,12 +59,24 @@ class SessionSummarizationNode(AssistantNode):
self._session_search = _SessionSearch(self)
self._session_summarizer = _SessionSummarizer(self)

async def _stream_progress(self, progress_message: str) -> None:
def _stream_progress(self, progress_message: str) -> None:
"""Push summarization progress as reasoning messages"""
content = prepare_reasoning_progress_message(progress_message)
if content:
self.dispatcher.update(content)

def _stream_filters(self, filters: MaxRecordingUniversalFilters) -> None:
"""Stream filters to the user"""
self.dispatcher.message(
AssistantToolCallMessage(
content="",
ui_payload={"search_session_recordings": filters.model_dump(exclude_none=True)},
# Randomized tool call ID, as we don't want this to be THE result of the actual session summarization tool call
# - it's OK because this is only dispatched ephemerally, so the tool message doesn't get added to the state
tool_call_id=str(uuid4()),
)
)

async def _stream_notebook_content(self, content: dict, state: AssistantState, partial: bool = True) -> None:
"""Stream TipTap content directly to a notebook if notebook_id is present in state."""
# Check if we have a notebook_id in the state
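
Two details in the hunk above are easy to miss: _stream_progress becomes synchronous because it now goes through the dispatcher instead of awaiting a writer, and _stream_filters sends its payload under a throwaway tool_call_id so it can never be mistaken for the real tool result. A stripped-down sketch of that ephemeral-dispatch idea (the dispatcher interface is assumed from its usage here):

    from uuid import uuid4

    def stream_ephemeral_ui_payload(dispatcher, payload: dict) -> None:
        # A random tool_call_id marks this as transient UI state: it is pushed
        # to the client but never stored as the tool call's actual result.
        dispatcher.message(
            {
                "content": "",
                "ui_payload": payload,
                "tool_call_id": str(uuid4()),
            }
        )
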
@@ -270,9 +282,7 @@ class _SessionSearch:

def _convert_current_filters_to_recordings_query(self, current_filters: dict[str, Any]) -> RecordingsQuery:
"""Convert current filters into recordings query format"""
from ee.session_recordings.playlist_counters.recordings_that_match_playlist_filters import (
convert_filters_to_recordings_query,
)
from posthog.session_recordings.playlist_counters import convert_filters_to_recordings_query

# Create a temporary playlist object to use the conversion function
temp_playlist = SessionRecordingPlaylist(filters=current_filters)
@@ -386,6 +396,7 @@ class _SessionSearch:
root_tool_call_id=None,
)
# Use filters when generated successfully
self._node._stream_filters(filter_generation_result)
replay_filters = self._convert_max_filters_to_recordings_query(filter_generation_result)
# Query the filters to get session ids
query_limit = state.session_summarization_limit
@@ -429,13 +440,13 @@ class _SessionSummarizer:
)
completed += 1
# Update the user on the progress
await self._node._stream_progress(progress_message=f"Watching sessions ({completed}/{total})")
self._node._stream_progress(progress_message=f"Watching sessions ({completed}/{total})")
return result

# Run all tasks concurrently
tasks = [_summarize(sid) for sid in session_ids]
summaries = await asyncio.gather(*tasks)
await self._node._stream_progress(progress_message=f"Generating a summary, almost there")
self._node._stream_progress(progress_message=f"Generating a summary, almost there")
# Stringify, as chat doesn't need full JSON to be context-aware, while providing it could overload the context
stringified_summaries = []
for summary in summaries:
@@ -483,7 +494,7 @@ class _SessionSummarizer:
# Update intermediate state based on step enum (no content, as it's just a status message)
self._intermediate_state.update_step_progress(content=None, step=step)
# Status message - stream to user
await self._node._stream_progress(progress_message=data)
self._node._stream_progress(progress_message=data)
# Notebook intermediate data update messages
elif update_type == SessionSummaryStreamUpdate.NOTEBOOK_UPDATE:
if not isinstance(data, dict):
@@ -543,7 +554,7 @@ class _SessionSummarizer:
base_message = f"Found sessions ({len(session_ids)})"
if len(session_ids) <= GROUP_SUMMARIES_MIN_SESSIONS:
# If small amount of sessions - there are no patterns to extract, so summarize them individually and return as is
await self._node._stream_progress(
self._node._stream_progress(
progress_message=f"{base_message}. We will do a quick summary, as the scope is small",
)
summaries_content = await self._summarize_sessions_individually(session_ids=session_ids)
@@ -558,7 +569,7 @@ class _SessionSummarizer:
state.notebook_short_id = notebook.short_id
# For large groups, process in detail, searching for patterns
# TODO: Allow users to define the pattern themselves (or rather catch it from the query)
await self._node._stream_progress(
self._node._stream_progress(
progress_message=f"{base_message}. We will analyze in detail, and store the report in a notebook",
)
summaries_content = await self._summarize_sessions_as_group(

@@ -25,18 +25,18 @@
WHERE and(equals(s.team_id, 99999), greaterOrEquals(toTimeZone(s.min_first_timestamp, 'UTC'), toDateTime64('2025-08-24 00:00:00.000000', 6, 'UTC')), lessOrEquals(toTimeZone(s.min_first_timestamp, 'UTC'), toDateTime64('2025-09-03 23:59:59.000000', 6, 'UTC')))
GROUP BY s.session_id
HAVING and(ifNull(greaterOrEquals(expiry_time, toDateTime64('2025-09-03 12:00:00.000000', 6, 'UTC')), 0), ifNull(greater(active_seconds, 8.0), 0))
ORDER BY start_time DESC
LIMIT 101
OFFSET 0 SETTINGS readonly=2,
max_execution_time=60,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
ORDER BY start_time DESC,
s.session_id DESC
LIMIT 101 SETTINGS readonly=2,
max_execution_time=60,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
'''
# ---
# name: TestSessionSummarizationNodeFilterGeneration.test_get_session_ids_respects_limit
@@ -65,18 +65,18 @@
WHERE and(equals(s.team_id, 99999), greaterOrEquals(toTimeZone(s.min_first_timestamp, 'UTC'), toDateTime64('2025-08-27 00:00:00.000000', 6, 'UTC')), lessOrEquals(toTimeZone(s.min_first_timestamp, 'UTC'), toDateTime64('2025-08-31 23:59:59.000000', 6, 'UTC')))
GROUP BY s.session_id
HAVING ifNull(greaterOrEquals(expiry_time, toDateTime64('2025-09-03 12:00:00.000000', 6, 'UTC')), 0)
ORDER BY start_time DESC
LIMIT 2
OFFSET 0 SETTINGS readonly=2,
max_execution_time=60,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
ORDER BY start_time DESC,
s.session_id DESC
LIMIT 2 SETTINGS readonly=2,
max_execution_time=60,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
'''
# ---
# name: TestSessionSummarizationNodeFilterGeneration.test_use_current_filters_with_date_range
@@ -105,18 +105,18 @@
WHERE and(equals(s.team_id, 99999), greaterOrEquals(toTimeZone(s.min_first_timestamp, 'UTC'), toDateTime64('2025-08-27 00:00:00.000000', 6, 'UTC')), lessOrEquals(toTimeZone(s.min_first_timestamp, 'UTC'), toDateTime64('2025-08-31 23:59:59.000000', 6, 'UTC')))
GROUP BY s.session_id
HAVING and(ifNull(greaterOrEquals(expiry_time, toDateTime64('2025-09-03 12:00:00.000000', 6, 'UTC')), 0), ifNull(greater(active_seconds, 7.0), 0))
ORDER BY start_time DESC
LIMIT 101
OFFSET 0 SETTINGS readonly=2,
max_execution_time=60,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
ORDER BY start_time DESC,
s.session_id DESC
LIMIT 101 SETTINGS readonly=2,
max_execution_time=60,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
'''
# ---
# name: TestSessionSummarizationNodeFilterGeneration.test_use_current_filters_with_os_and_events
@@ -149,42 +149,42 @@
GROUP BY events.`$session_id`
HAVING 1
ORDER BY min(toTimeZone(events.timestamp, 'UTC')) DESC
LIMIT 10000)), globalIn(s.session_id,
(SELECT events.`$session_id` AS session_id
FROM events
LEFT OUTER JOIN
(SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, person_distinct_id_overrides.distinct_id AS distinct_id
FROM person_distinct_id_overrides
WHERE equals(person_distinct_id_overrides.team_id, 99999)
GROUP BY person_distinct_id_overrides.distinct_id
HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS events__override ON equals(events.distinct_id, events__override.distinct_id)
LEFT JOIN
(SELECT person.id AS id, replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(person.properties, '$os'), ''), 'null'), '^"|"$', '') AS `properties___$os`
FROM person
WHERE and(equals(person.team_id, 99999), in(tuple(person.id, person.version),
(SELECT person.id AS id, max(person.version) AS version
FROM person
WHERE equals(person.team_id, 99999)
GROUP BY person.id
HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0))))) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id)
WHERE and(equals(events.team_id, 99999), notEmpty(events.`$session_id`), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), toDateTime64('explicit_redacted_timestamp', 6, 'UTC')), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), now64(6, 'UTC')), ifNull(equals(events__person.`properties___$os`, 'Mac OS X'), 0), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), toDateTime64('explicit_redacted_timestamp', 6, 'UTC')), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), toDateTime64('explicit_redacted_timestamp', 6, 'UTC')))
GROUP BY events.`$session_id`
HAVING 1
ORDER BY min(toTimeZone(events.timestamp, 'UTC')) DESC
LIMIT 10000))))
LIMIT 1000000)), globalIn(s.session_id,
(SELECT events.`$session_id` AS session_id
FROM events
LEFT OUTER JOIN
(SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, person_distinct_id_overrides.distinct_id AS distinct_id
FROM person_distinct_id_overrides
WHERE equals(person_distinct_id_overrides.team_id, 99999)
GROUP BY person_distinct_id_overrides.distinct_id
HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS events__override ON equals(events.distinct_id, events__override.distinct_id)
LEFT JOIN
(SELECT person.id AS id, replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(person.properties, '$os'), ''), 'null'), '^"|"$', '') AS `properties___$os`
FROM person
WHERE and(equals(person.team_id, 99999), in(tuple(person.id, person.version),
(SELECT person.id AS id, max(person.version) AS version
FROM person
WHERE equals(person.team_id, 99999)
GROUP BY person.id
HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0))))) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id)
WHERE and(equals(events.team_id, 99999), notEmpty(events.`$session_id`), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), toDateTime64('explicit_redacted_timestamp', 6, 'UTC')), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), now64(6, 'UTC')), ifNull(equals(events__person.`properties___$os`, 'Mac OS X'), 0), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), toDateTime64('explicit_redacted_timestamp', 6, 'UTC')), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), toDateTime64('explicit_redacted_timestamp', 6, 'UTC')))
GROUP BY events.`$session_id`
HAVING 1
ORDER BY min(toTimeZone(events.timestamp, 'UTC')) DESC
LIMIT 1000000))))
GROUP BY s.session_id
HAVING and(ifNull(greaterOrEquals(expiry_time, toDateTime64('explicit_redacted_timestamp', 6, 'UTC')), 0), ifNull(greater(active_seconds, 6.0), 0))
ORDER BY start_time DESC
LIMIT 101
OFFSET 0 SETTINGS readonly=2,
max_execution_time=60,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
ORDER BY start_time DESC,
s.session_id DESC
LIMIT 101 SETTINGS readonly=2,
max_execution_time=60,
allow_experimental_object_type=1,
format_csv_allow_double_quotes=0,
max_ast_elements=4000000,
max_expanded_ast_elements=4000000,
max_bytes_before_external_group_by=0,
transform_null_in=1,
optimize_min_equality_disjunction_chain_length=4294967295,
allow_experimental_join_condition=1
'''
# ---

@@ -38,29 +38,50 @@ Standardized events/properties such as pageview or screen start with `$`. Custom
`virtual_table` and `lazy_table` fields are connections to linked tables, e.g. the virtual table field `person` allows accessing person properties like so: `person.properties.foo`.

<person_id_join_limitation>
There is a known issue with queries that join multiple events tables where join constraints
reference person_id fields. The person_id fields are ExpressionFields that expand to
expressions referencing override tables (e.g., e_all__override). However, these expressions
are resolved during type resolution (in printer.py) BEFORE lazy table processing begins.
This creates forward references to override tables that don't exist yet.
CRITICAL: There is a known issue with queries where JOIN constraints reference events.person_id fields.

Example problematic HogQL:
SELECT MAX(e_all.timestamp) AS last_seen
FROM events e_dl
JOIN persons p ON e_dl.person_id = p.id
JOIN events e_all ON e_dl.person_id = e_all.person_id
TECHNICAL CAUSE:
The person_id fields are ExpressionFields that expand to expressions referencing override tables
(e.g., e_all__override). However, these expressions are resolved during type resolution (in printer.py)
BEFORE lazy table processing begins. This creates forward references to override tables that don't
exist yet, causing ClickHouse errors like:
"Missing columns: '_--e__override.person_id' '_--e__override.distinct_id'"

The join constraint "e_dl.person_id = e_all.person_id" expands to:
if(NOT empty(e_dl__override.distinct_id), e_dl__override.person_id, e_dl.person_id) =
if(NOT empty(e_all__override.distinct_id), e_all__override.person_id, e_all.person_id)
PROBLEMATIC PATTERNS:
1. Joining persons to events using events.person_id:
❌ FROM persons p ALL INNER JOIN events e ON p.id = e.person_id

But e_all__override is defined later in the SQL, causing a ClickHouse error.
2. Joining multiple events tables using person_id:
❌ FROM events e_dl
JOIN persons p ON e_dl.person_id = p.id
JOIN events e_all ON e_dl.person_id = e_all.person_id

WORKAROUND: Use subqueries or rewrite queries to avoid direct joins between multiple events tables:
SELECT MAX(e.timestamp) AS last_seen
FROM events e
JOIN persons p ON e.person_id = p.id
WHERE e.event IN (SELECT event FROM events WHERE ...)
The join constraint "e_dl.person_id = e_all.person_id" expands to:
if(NOT empty(e_dl__override.distinct_id), e_dl__override.person_id, e_dl.person_id) =
if(NOT empty(e_all__override.distinct_id), e_all__override.person_id, e_all.person_id)

But e_all__override is defined later in the SQL, causing the error.

REQUIRED WORKAROUNDS:
1. For accessing person data, use the person virtual table from events:
✅ SELECT e.person.id, e.person.properties.email, e.event
FROM events e
WHERE e.timestamp > now() - INTERVAL 7 DAY

2. For filtering persons by event data, use subqueries with WHERE IN:
✅ SELECT p.id, p.properties.email
FROM persons p
WHERE p.id IN (
SELECT DISTINCT person_id FROM events
WHERE event = 'purchase' AND timestamp > now() - INTERVAL 7 DAY
)

3. For multiple events tables, use subqueries to avoid direct joins:
✅ SELECT MAX(e.timestamp) AS last_seen
FROM events e
WHERE e.person_id IN (SELECT DISTINCT person_id FROM events WHERE ...)

NEVER use events.person_id directly in JOIN ON constraints - always use one of the workarounds above.
</person_id_join_limitation>

ONLY make formatting or casing changes if explicitly requested by the user.

@@ -68,8 +68,16 @@ class TaxonomyAgentNode(
return EntityType.values() + self._team_group_types

def _get_model(self, state: TaxonomyStateType):
# Check if this invocation should be billable (set by the calling tool)
billable = getattr(state, "billable", False)
return MaxChatOpenAI(
model="gpt-4.1", streaming=False, temperature=0.3, user=self._user, team=self._team, disable_streaming=True
model="gpt-4.1",
streaming=False,
temperature=0.3,
user=self._user,
team=self._team,
disable_streaming=True,
billable=billable,
).bind_tools(
self._toolkit.get_tools(),
tool_choice="required",
@@ -146,6 +154,7 @@ class TaxonomyAgentNode(
intermediate_steps=intermediate_steps,
output=state.output,
iteration_count=state.iteration_count + 1 if state.iteration_count is not None else 1,
billable=state.billable,
)


@@ -198,6 +207,7 @@ class TaxonomyAgentToolsNode(
return self._partial_state_class(
output=tool_input.arguments.answer, # type: ignore
intermediate_steps=None,
billable=state.billable,
)

if tool_input.name == "ask_user_for_help":
@@ -237,6 +247,7 @@ class TaxonomyAgentToolsNode(
tool_progress_messages=[*old_msg, *tool_msgs],
intermediate_steps=steps,
iteration_count=state.iteration_count,
billable=state.billable,
)

def router(self, state: TaxonomyStateType) -> str:
@@ -257,4 +268,5 @@ class TaxonomyAgentToolsNode(
)
]
reset_state.output = output
reset_state.billable = state.billable
return reset_state # type: ignore[return-value]
Some files were not shown because too many files have changed in this diff.