feat(ph-ai): create experiments max tool (#40753)
## Problem

This PR introduces a new `MaxTool` to create experiments. It builds on top of the `create_feature_flag` `MaxTool` with multivariate flag support.

## Changes

- Added a new `CreateExperimentTool` that allows creating experiments via PostHog AI

## How did you test this code?

New tests + evals

## Changelog (features only)

Is this feature complete? Yes, this feature is complete and ready to be included in the changelog.
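For orientation, here is a minimal sketch of the flow the new tool expects, adapted from the tests added in this PR. The helper name and flag key are illustrative only, and `team`/`user` are assumed to come from existing fixtures; in production the tool is invoked by Max, not called directly like this.

```python
# Illustrative sketch only, adapted from the tests in this PR.
# Assumes `team` and `user` come from existing test fixtures.
from posthog.models import FeatureFlag
from products.experiments.backend.max_tools import CreateExperimentTool
from ee.hogai.utils.types import AssistantState


async def create_checkout_experiment(team, user):
    # 1. The tool requires an existing multivariate flag with at least two variants.
    await FeatureFlag.objects.acreate(
        team=team,
        created_by=user,
        key="checkout-flow-test",  # hypothetical key
        filters={
            "groups": [{"properties": [], "rollout_percentage": 100}],
            "multivariate": {
                "variants": [
                    {"key": "control", "name": "Control", "rollout_percentage": 50},
                    {"key": "test", "name": "Test", "rollout_percentage": 50},
                ]
            },
        },
    )

    # 2. Instantiate the MaxTool and create the experiment as a draft.
    tool = await CreateExperimentTool.create_tool_class(
        team=team,
        user=user,
        state=AssistantState(messages=[]),
    )
    message, artifact = await tool._arun_impl(
        name="Checkout Flow Test",
        feature_flag_key="checkout-flow-test",
        type="product",
    )
    # artifact carries experiment_id, experiment_name, feature_flag_key, type, and url
    return message, artifact
```

If the flag is missing, has fewer than two variants, or is already tied to another experiment, the tool returns an error message and an `{"error": ...}` artifact instead.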
committed by GitHub
parent dfaf7b28d7
commit daa062b202
407 ee/hogai/eval/ci/max_tools/eval_create_experiment_tool.py (new file)
@@ -0,0 +1,407 @@
"""Evaluations for CreateExperimentTool."""

import uuid

import pytest

from autoevals.partial import ScorerWithPartial
from autoevals.ragas import AnswerSimilarity
from braintrust import EvalCase, Score

from posthog.models import Experiment, FeatureFlag

from products.experiments.backend.max_tools import CreateExperimentTool

from ee.hogai.eval.base import MaxPublicEval
from ee.hogai.utils.types import AssistantState
from ee.models.assistant import Conversation


class ExperimentOutputScorer(ScorerWithPartial):
    """Custom scorer for experiment tool output that combines semantic similarity for text and exact matching for numbers/booleans."""

    def __init__(self, semantic_fields: set[str] | None = None, **kwargs):
        super().__init__(**kwargs)
        self.semantic_fields = semantic_fields or {"message"}

    def _run_eval_sync(self, output: dict, expected: dict, **kwargs):
        if not expected:
            return Score(name=self._name(), score=None, metadata={"reason": "No expected value provided"})
        if not output:
            return Score(name=self._name(), score=0.0, metadata={"reason": "No output provided"})

        total_fields = len(expected)
        if total_fields == 0:
            return Score(name=self._name(), score=1.0)

        score_per_field = 1.0 / total_fields
        total_score = 0.0
        metadata = {}

        for field_name, expected_value in expected.items():
            actual_value = output.get(field_name)

            if field_name in self.semantic_fields:
                # Use semantic similarity for text fields
                if actual_value is not None and expected_value is not None:
                    similarity_scorer = AnswerSimilarity(model="text-embedding-3-small")
                    result = similarity_scorer.eval(output=str(actual_value), expected=str(expected_value))
                    field_score = result.score * score_per_field
                    total_score += field_score
                    metadata[f"{field_name}_score"] = result.score
                else:
                    metadata[f"{field_name}_missing"] = True
            else:
                # Use exact match for numeric/boolean fields
                if actual_value == expected_value:
                    total_score += score_per_field
                    metadata[f"{field_name}_match"] = True
                else:
                    metadata[f"{field_name}_mismatch"] = {
                        "expected": expected_value,
                        "actual": actual_value,
                    }

        return Score(name=self._name(), score=total_score, metadata=metadata)


@pytest.mark.django_db
async def eval_create_experiment_basic(pytestconfig, demo_org_team_user):
    """Test basic experiment creation."""
    _, team, user = demo_org_team_user

    conversation = await Conversation.objects.acreate(team=team, user=user)

    async def task_create_experiment(args: dict):
        # Create feature flag first (required by the tool)
        await FeatureFlag.objects.acreate(
            team=team,
            created_by=user,
            key=args["feature_flag_key"],
            name=f"Flag for {args['name']}",
            filters={
                "groups": [{"properties": [], "rollout_percentage": 100}],
                "multivariate": {
                    "variants": [
                        {"key": "control", "name": "Control", "rollout_percentage": 50},
                        {"key": "test", "name": "Test", "rollout_percentage": 50},
                    ]
                },
            },
        )

        tool = await CreateExperimentTool.create_tool_class(
            team=team,
            user=user,
            state=AssistantState(messages=[]),
            config={
                "configurable": {
                    "thread_id": conversation.id,
                    "team": team,
                    "user": user,
                }
            },
        )

        result_message, artifact = await tool._arun_impl(
            name=args["name"],
            feature_flag_key=args["feature_flag_key"],
            description=args.get("description"),
            type=args.get("type", "product"),
        )

        exp_created = await Experiment.objects.filter(team=team, name=args["name"], deleted=False).aexists()

        return {
            "message": result_message,
            "experiment_created": exp_created,
            "experiment_name": artifact.get("experiment_name") if artifact else None,
        }

    await MaxPublicEval(
        experiment_name="create_experiment_basic",
        task=task_create_experiment,  # type: ignore
        scores=[ExperimentOutputScorer(semantic_fields={"message", "experiment_name"})],
        data=[
            EvalCase(
                input={"name": "Pricing Test", "feature_flag_key": "pricing-test-flag"},
                expected={
                    "message": "Successfully created experiment",
                    "experiment_created": True,
                    "experiment_name": "Pricing Test",
                },
            ),
            EvalCase(
                input={
                    "name": "Homepage Redesign",
                    "feature_flag_key": "homepage-redesign",
                    "description": "Testing new homepage layout for better conversion",
                },
                expected={
                    "message": "Successfully created experiment",
                    "experiment_created": True,
                    "experiment_name": "Homepage Redesign",
                },
            ),
        ],
        pytestconfig=pytestconfig,
    )


@pytest.mark.django_db
async def eval_create_experiment_types(pytestconfig, demo_org_team_user):
    """Test experiment creation with different types (product vs web)."""
    _, team, user = demo_org_team_user

    conversation = await Conversation.objects.acreate(team=team, user=user)

    async def task_create_typed_experiment(args: dict):
        # Create feature flag first
        await FeatureFlag.objects.acreate(
            team=team,
            created_by=user,
            key=args["feature_flag_key"],
            name=f"Flag for {args['name']}",
            filters={
                "groups": [{"properties": [], "rollout_percentage": 100}],
                "multivariate": {
                    "variants": [
                        {"key": "control", "name": "Control", "rollout_percentage": 50},
                        {"key": "test", "name": "Test", "rollout_percentage": 50},
                    ]
                },
            },
        )

        tool = await CreateExperimentTool.create_tool_class(
            team=team,
            user=user,
            state=AssistantState(messages=[]),
            config={
                "configurable": {
                    "thread_id": conversation.id,
                    "team": team,
                    "user": user,
                }
            },
        )

        result_message, artifact = await tool._arun_impl(
            name=args["name"],
            feature_flag_key=args["feature_flag_key"],
            type=args["type"],
        )

        # Verify experiment type
        experiment = await Experiment.objects.aget(team=team, name=args["name"])

        return {
            "message": result_message,
            "experiment_type": experiment.type,
            "artifact_type": artifact.get("type") if artifact else None,
        }

    await MaxPublicEval(
        experiment_name="create_experiment_types",
        task=task_create_typed_experiment,  # type: ignore
        scores=[ExperimentOutputScorer(semantic_fields={"message"})],
        data=[
            EvalCase(
                input={"name": "Product Feature Test", "feature_flag_key": "product-test", "type": "product"},
                expected={
                    "message": "Successfully created experiment",
                    "experiment_type": "product",
                    "artifact_type": "product",
                },
            ),
            EvalCase(
                input={"name": "Web UI Test", "feature_flag_key": "web-test", "type": "web"},
                expected={
                    "message": "Successfully created experiment",
                    "experiment_type": "web",
                    "artifact_type": "web",
                },
            ),
        ],
        pytestconfig=pytestconfig,
    )


@pytest.mark.django_db
async def eval_create_experiment_with_existing_flag(pytestconfig, demo_org_team_user):
    """Test experiment creation with an existing feature flag."""
    _, team, user = demo_org_team_user

    # Create an existing flag with unique key and multivariate variants
    unique_key = f"reusable-flag-{uuid.uuid4().hex[:8]}"
    await FeatureFlag.objects.acreate(
        team=team,
        key=unique_key,
        name="Reusable Flag",
        created_by=user,
        filters={
            "groups": [{"properties": [], "rollout_percentage": 100}],
            "multivariate": {
                "variants": [
                    {"key": "control", "name": "Control", "rollout_percentage": 50},
                    {"key": "test", "name": "Test", "rollout_percentage": 50},
                ]
            },
        },
    )

    conversation = await Conversation.objects.acreate(team=team, user=user)

    async def task_create_experiment_reuse_flag(args: dict):
        tool = await CreateExperimentTool.create_tool_class(
            team=team,
            user=user,
            state=AssistantState(messages=[]),
            config={
                "configurable": {
                    "thread_id": conversation.id,
                    "team": team,
                    "user": user,
                }
            },
        )

        result_message, artifact = await tool._arun_impl(
            name=args["name"],
            feature_flag_key=args["feature_flag_key"],
        )

        return {
            "message": result_message,
            "experiment_created": artifact is not None and "experiment_id" in artifact,
        }

    await MaxPublicEval(
        experiment_name="create_experiment_with_existing_flag",
        task=task_create_experiment_reuse_flag,  # type: ignore
        scores=[ExperimentOutputScorer(semantic_fields={"message"})],
        data=[
            EvalCase(
                input={"name": "Reuse Flag Test", "feature_flag_key": unique_key},
                expected={
                    "message": "Successfully created experiment",
                    "experiment_created": True,
                },
            ),
        ],
        pytestconfig=pytestconfig,
    )


@pytest.mark.django_db
async def eval_create_experiment_duplicate_name_error(pytestconfig, demo_org_team_user):
    """Test that creating a duplicate experiment returns an error."""
    _, team, user = demo_org_team_user

    # Create an existing experiment with unique flag key
    unique_flag_key = f"test-flag-{uuid.uuid4().hex[:8]}"
    flag = await FeatureFlag.objects.acreate(team=team, key=unique_flag_key, created_by=user)
    await Experiment.objects.acreate(team=team, name="Existing Experiment", feature_flag=flag, created_by=user)

    conversation = await Conversation.objects.acreate(team=team, user=user)

    async def task_create_duplicate_experiment(args: dict):
        # Create a different flag for the duplicate attempt
        await FeatureFlag.objects.acreate(
            team=team,
            created_by=user,
            key=args["feature_flag_key"],
            name=f"Flag for {args['name']}",
        )

        tool = await CreateExperimentTool.create_tool_class(
            team=team,
            user=user,
            state=AssistantState(messages=[]),
            config={
                "configurable": {
                    "thread_id": conversation.id,
                    "team": team,
                    "user": user,
                }
            },
        )

        result_message, artifact = await tool._arun_impl(
            name=args["name"],
            feature_flag_key=args["feature_flag_key"],
        )

        return {
            "message": result_message,
            "has_error": artifact.get("error") is not None if artifact else False,
        }

    await MaxPublicEval(
        experiment_name="create_experiment_duplicate_name_error",
        task=task_create_duplicate_experiment,  # type: ignore
        scores=[ExperimentOutputScorer(semantic_fields={"message"})],
        data=[
            EvalCase(
                input={"name": "Existing Experiment", "feature_flag_key": "another-flag"},
                expected={
                    "message": "Failed to create experiment: An experiment with name 'Existing Experiment' already exists",
                    "has_error": True,
                },
            ),
        ],
        pytestconfig=pytestconfig,
    )


@pytest.mark.django_db
async def eval_create_experiment_flag_already_used_error(pytestconfig, demo_org_team_user):
    """Test that using a flag already tied to another experiment returns an error."""
    _, team, user = demo_org_team_user

    # Create an experiment with a flag (unique key)
    unique_flag_key = f"used-flag-{uuid.uuid4().hex[:8]}"
    flag = await FeatureFlag.objects.acreate(team=team, key=unique_flag_key, created_by=user)
    await Experiment.objects.acreate(team=team, name="First Experiment", feature_flag=flag, created_by=user)

    conversation = await Conversation.objects.acreate(team=team, user=user)

    async def task_create_experiment_with_used_flag(args: dict):
        tool = await CreateExperimentTool.create_tool_class(
            team=team,
            user=user,
            state=AssistantState(messages=[]),
            config={
                "configurable": {
                    "thread_id": conversation.id,
                    "team": team,
                    "user": user,
                }
            },
        )

        result_message, artifact = await tool._arun_impl(
            name=args["name"],
            feature_flag_key=args["feature_flag_key"],
        )

        return {
            "message": result_message,
            "has_error": artifact.get("error") is not None if artifact else False,
        }

    await MaxPublicEval(
        experiment_name="create_experiment_flag_already_used_error",
        task=task_create_experiment_with_used_flag,  # type: ignore
        scores=[ExperimentOutputScorer(semantic_fields={"message"})],
        data=[
            EvalCase(
                input={"name": "Second Experiment", "feature_flag_key": unique_flag_key},
                expected={
                    "message": "Failed to create experiment: Feature flag is already used by experiment",
                    "has_error": True,
                },
            ),
        ],
        pytestconfig=pytestconfig,
    )

@@ -1850,7 +1850,8 @@
                    "read_data",
                    "todo_write",
                    "filter_revenue_analytics",
                    "create_feature_flag"
                    "create_feature_flag",
                    "create_experiment"
                ],
                "type": "string"
            },

@@ -287,6 +287,7 @@ export type AssistantTool =
    | 'todo_write'
    | 'filter_revenue_analytics'
    | 'create_feature_flag'
    | 'create_experiment'

export enum AgentMode {
    ProductAnalytics = 'product_analytics',

@@ -3,7 +3,7 @@ import { router } from 'kea-router'
import { useState } from 'react'
import { match } from 'ts-pattern'

import { LemonDialog, LemonInput, LemonSelect, LemonTag, Tooltip } from '@posthog/lemon-ui'
import { LemonDialog, LemonInput, LemonSelect, LemonTag, Tooltip, lemonToast } from '@posthog/lemon-ui'

import { AccessControlAction } from 'lib/components/AccessControlAction'
import { ActivityLog } from 'lib/components/ActivityLog/ActivityLog'

@@ -20,6 +20,7 @@ import { atColumn, createdAtColumn, createdByColumn } from 'lib/lemon-ui/LemonTa
import { LemonTabs } from 'lib/lemon-ui/LemonTabs'
import { deleteWithUndo } from 'lib/utils/deleteWithUndo'
import stringWithWBR from 'lib/utils/stringWithWBR'
import MaxTool from 'scenes/max/MaxTool'
import { useMaxTool } from 'scenes/max/useMaxTool'
import { SceneExport } from 'scenes/sceneTypes'
import { urls } from 'scenes/urls'

@@ -438,14 +439,7 @@ export function Experiments(): JSX.Element {
                        resourceType={AccessControlResourceType.Experiment}
                        minAccessLevel={AccessControlLevel.Editor}
                    >
                        <LemonButton
                            size="small"
                            type="primary"
                            data-attr="create-experiment"
                            to={urls.experiment('new')}
                        >
                            New experiment
                        </LemonButton>
                        <NewExperimentButton />
                    </AccessControlAction>
                ) : undefined
            }

@@ -493,3 +487,43 @@ export function Experiments(): JSX.Element {
        </SceneContent>
    )
}

function NewExperimentButton(): JSX.Element {
    const { loadExperiments } = useActions(experimentsLogic)

    useMaxTool({ identifier: 'create_feature_flag', context: {} })

    return (
        <MaxTool
            identifier="create_experiment"
            initialMaxPrompt="Create an experiment for "
            suggestions={[
                'Create an experiment to test our new checkout flow',
                'Set up an A/B test for the pricing page redesign',
                'Create an experiment to test different call-to-action buttons on the homepage',
                'Create an experiment for testing our new recommendation algorithm',
            ]}
            callback={(toolOutput: {
                experiment_id?: string | number
                experiment_name?: string
                feature_flag_key?: string
                error?: string
            }) => {
                if (toolOutput?.error || !toolOutput?.experiment_id) {
                    lemonToast.error(`Failed to create experiment: ${toolOutput?.error || 'Unknown error'}`)
                    return
                }
                // Refresh experiments list to show new experiment, then redirect to it
                loadExperiments()
                router.actions.push(urls.experiment(toolOutput.experiment_id))
            }}
            position="bottom-right"
            active={true}
            context={{}}
        >
            <LemonButton size="small" type="primary" data-attr="create-experiment" to={urls.experiment('new')}>
                <span className="pr-3">New experiment</span>
            </LemonButton>
        </MaxTool>
    )
}

@@ -359,6 +359,18 @@ export const TOOL_DEFINITIONS: Record<Exclude<AssistantTool, 'todo_write'>, Tool
            return 'Creating feature flag...'
        },
    },
    create_experiment: {
        name: 'Create an experiment',
        description: 'Create an experiment in seconds',
        product: Scene.Experiments,
        icon: iconForType('experiment'),
        displayFormatter: (toolCall) => {
            if (toolCall.status === 'completed') {
                return 'Created experiment'
            }
            return 'Creating experiment...'
        },
    },
}

export const MAX_GENERALLY_CAN: { icon: JSX.Element; description: string }[] = [

@@ -285,6 +285,7 @@ class AssistantTool(StrEnum):
    TODO_WRITE = "todo_write"
    FILTER_REVENUE_ANALYTICS = "filter_revenue_analytics"
    CREATE_FEATURE_FLAG = "create_feature_flag"
    CREATE_EXPERIMENT = "create_experiment"


class AssistantToolCall(BaseModel):

@@ -1,20 +1,154 @@
"""
MaxTool for AI-powered experiment summary.
"""

from typing import Any
from typing import Any, Literal

from posthoganalytics import capture_exception
from pydantic import BaseModel, Field

from posthog.schema import MaxExperimentSummaryContext

from posthog.exceptions_capture import capture_exception
from posthog.models import Experiment, FeatureFlag
from posthog.sync import database_sync_to_async

from ee.hogai.llm import MaxChatOpenAI
from ee.hogai.tool import MaxTool

from .prompts import EXPERIMENT_SUMMARY_BAYESIAN_PROMPT, EXPERIMENT_SUMMARY_FREQUENTIST_PROMPT


class CreateExperimentArgs(BaseModel):
    name: str = Field(description="Experiment name - should clearly describe what is being tested")
    feature_flag_key: str = Field(
        description="Feature flag key (letters, numbers, hyphens, underscores only). Will create a new flag if it doesn't exist."
    )
    description: str | None = Field(
        default=None,
        description="Detailed description of the experiment hypothesis, what changes are being tested, and expected outcomes",
    )
    type: Literal["product", "web"] = Field(
        default="product",
        description="Experiment type: 'product' for backend/API changes, 'web' for frontend UI changes",
    )


class CreateExperimentTool(MaxTool):
    name: Literal["create_experiment"] = "create_experiment"
    description: str = """
    Create a new A/B test experiment in the current project.

    Experiments allow you to test changes with a controlled rollout and measure their impact.

    Use this tool when the user wants to:
    - Create a new A/B test experiment
    - Set up a controlled experiment to test changes
    - Test variants of a feature with users

    Examples:
    - "Create an experiment to test the new checkout flow"
    - "Set up an A/B test for our pricing page redesign"
    - "Create an experiment called 'homepage-cta-test' to test different call-to-action buttons"

    **IMPORTANT**: You must first find or create a multivariate feature flag using `create_feature_flag`, with at least two variants (control and test). Navigate to the feature flags page, create the flag, then navigate back to the experiments page and use this tool to create the experiment.
    """.strip()
    context_prompt_template: str = "Creates a new A/B test experiment in the project"
    args_schema: type[BaseModel] = CreateExperimentArgs

    async def _arun_impl(
        self,
        name: str,
        feature_flag_key: str,
        description: str | None = None,
        type: Literal["product", "web"] = "product",
    ) -> tuple[str, dict[str, Any] | None]:
        # Validate inputs
        if not name or not name.strip():
            return "Experiment name cannot be empty", {"error": "invalid_name"}

        if not feature_flag_key or not feature_flag_key.strip():
            return "Feature flag key cannot be empty", {"error": "invalid_flag_key"}

        @database_sync_to_async
        def create_experiment() -> Experiment:
            # Check if experiment with this name already exists
            existing_experiment = Experiment.objects.filter(team=self._team, name=name, deleted=False).first()
            if existing_experiment:
                raise ValueError(f"An experiment with name '{name}' already exists")

            try:
                feature_flag = FeatureFlag.objects.get(team=self._team, key=feature_flag_key, deleted=False)
            except FeatureFlag.DoesNotExist:
                raise ValueError(f"Feature flag '{feature_flag_key}' does not exist")

            # Validate that the flag has multivariate variants
            multivariate = feature_flag.filters.get("multivariate")
            if not multivariate or not multivariate.get("variants"):
                raise ValueError(
                    f"Feature flag '{feature_flag_key}' must have multivariate variants to be used in an experiment. "
                    f"Create the flag with variants first using the create_feature_flag tool."
                )

            variants = multivariate["variants"]
            if len(variants) < 2:
                raise ValueError(
                    f"Feature flag '{feature_flag_key}' must have at least 2 variants for an experiment (e.g., control and test)"
                )

            # If flag already exists and is already used by another experiment, raise error
            existing_experiment_with_flag = Experiment.objects.filter(feature_flag=feature_flag, deleted=False).first()
            if existing_experiment_with_flag:
                raise ValueError(
                    f"Feature flag '{feature_flag_key}' is already used by experiment '{existing_experiment_with_flag.name}'"
                )

            # Use the actual variants from the feature flag
            feature_flag_variants = [
                {
                    "key": variant["key"],
                    "name": variant.get("name", variant["key"]),
                    "rollout_percentage": variant["rollout_percentage"],
                }
                for variant in variants
            ]

            # Create the experiment as a draft (no start_date)
            experiment = Experiment.objects.create(
                team=self._team,
                created_by=self._user,
                name=name,
                description=description or "",
                type=type,
                feature_flag=feature_flag,
                filters={},  # Empty filters for draft
                parameters={
                    "feature_flag_variants": feature_flag_variants,
                    "minimum_detectable_effect": 30,
                },
                metrics=[],
                metrics_secondary=[],
            )

            return experiment

        try:
            experiment = await create_experiment()
            experiment_url = f"/project/{self._team.project_id}/experiments/{experiment.id}"

            return (
                f"Successfully created experiment '{name}'. "
                f"The experiment is in draft mode - you can configure metrics and launch it at {experiment_url}",
                {
                    "experiment_id": experiment.id,
                    "experiment_name": experiment.name,
                    "feature_flag_key": feature_flag_key,
                    "type": type,
                    "url": experiment_url,
                },
            )
        except ValueError as e:
            return f"Failed to create experiment: {str(e)}", {"error": str(e)}
        except Exception as e:
            capture_exception(e)
            return f"Failed to create experiment: {str(e)}", {"error": "creation_failed"}


MAX_METRICS_TO_SUMMARIZE = 3

@@ -104,7 +238,8 @@ class ExperimentSummaryTool(MaxTool):

        except Exception as e:
            capture_exception(
                e, {"team_id": self._team.id, "user_id": self._user.id, "experiment_id": context.experiment_id}
                e,
                properties={"team_id": self._team.id, "user_id": self._user.id, "experiment_id": context.experiment_id},
            )
            return ExperimentSummaryOutput(key_metrics=[f"Analysis failed: {str(e)}"])

@@ -184,7 +319,7 @@ class ExperimentSummaryTool(MaxTool):

            capture_exception(
                e,
                {
                properties={
                    "team_id": self._team.id,
                    "user_id": self._user.id,
                    "context_keys": list(self.context.keys()) if isinstance(self.context, dict) else None,

@@ -210,5 +345,5 @@ class ExperimentSummaryTool(MaxTool):
            }

        except Exception as e:
            capture_exception(e, {"team_id": self._team.id, "user_id": self._user.id})
            capture_exception(e, properties={"team_id": self._team.id, "user_id": self._user.id})
            return f"❌ Failed to summarize experiment: {str(e)}", {"error": "summary_failed", "details": str(e)}

0 products/experiments/backend/test/__init__.py (new file)
426 products/experiments/backend/test/test_max_tools.py (new file)
@@ -0,0 +1,426 @@
from posthog.test.base import APIBaseTest

from posthog.models import Experiment, FeatureFlag

from products.experiments.backend.max_tools import CreateExperimentTool

from ee.hogai.utils.types import AssistantState


class TestCreateExperimentTool(APIBaseTest):
    async def test_create_experiment_minimal(self):
        # Create feature flag first (as the tool expects)
        await FeatureFlag.objects.acreate(
            team=self.team,
            created_by=self.user,
            key="test-experiment-flag",
            name="Test Experiment Flag",
            filters={
                "groups": [{"properties": [], "rollout_percentage": 100}],
                "multivariate": {
                    "variants": [
                        {"key": "control", "name": "Control", "rollout_percentage": 50},
                        {"key": "test", "name": "Test", "rollout_percentage": 50},
                    ]
                },
            },
        )

        tool = await CreateExperimentTool.create_tool_class(
            team=self.team,
            user=self.user,
            state=AssistantState(messages=[]),
        )

        result, artifact = await tool._arun_impl(
            name="Test Experiment",
            feature_flag_key="test-experiment-flag",
        )

        assert "Successfully created" in result
        assert artifact is not None
        assert artifact["experiment_name"] == "Test Experiment"
        assert artifact["feature_flag_key"] == "test-experiment-flag"
        assert "/experiments/" in artifact["url"]

        from posthog.sync import database_sync_to_async

        @database_sync_to_async
        def get_experiment():
            return Experiment.objects.select_related("feature_flag").get(name="Test Experiment", team=self.team)

        experiment = await get_experiment()
        assert experiment.description == ""
        assert experiment.type == "product"
        assert experiment.start_date is None  # Draft
        assert experiment.feature_flag.key == "test-experiment-flag"

    async def test_create_experiment_with_description(self):
        # Create feature flag first
        await FeatureFlag.objects.acreate(
            team=self.team,
            created_by=self.user,
            key="checkout-test",
            name="Checkout Test Flag",
            filters={
                "groups": [{"properties": [], "rollout_percentage": 100}],
                "multivariate": {
                    "variants": [
                        {"key": "control", "name": "Control", "rollout_percentage": 50},
                        {"key": "test", "name": "Test", "rollout_percentage": 50},
                    ]
                },
            },
        )

        tool = await CreateExperimentTool.create_tool_class(
            team=self.team,
            user=self.user,
            state=AssistantState(messages=[]),
        )

        result, artifact = await tool._arun_impl(
            name="Checkout Experiment",
            feature_flag_key="checkout-test",
            description="Testing new checkout flow to improve conversion rates",
        )

        assert "Successfully created" in result

        experiment = await Experiment.objects.aget(name="Checkout Experiment", team=self.team)
        assert experiment.description == "Testing new checkout flow to improve conversion rates"

    async def test_create_experiment_web_type(self):
        # Create feature flag first
        await FeatureFlag.objects.acreate(
            team=self.team,
            created_by=self.user,
            key="homepage-redesign",
            name="Homepage Redesign Flag",
            filters={
                "groups": [{"properties": [], "rollout_percentage": 100}],
                "multivariate": {
                    "variants": [
                        {"key": "control", "name": "Control", "rollout_percentage": 50},
                        {"key": "test", "name": "Test", "rollout_percentage": 50},
                    ]
                },
            },
        )

        tool = await CreateExperimentTool.create_tool_class(
            team=self.team,
            user=self.user,
            state=AssistantState(messages=[]),
        )

        result, artifact = await tool._arun_impl(
            name="Homepage Redesign",
            feature_flag_key="homepage-redesign",
            type="web",
        )

        assert "Successfully created" in result
        assert artifact is not None
        assert artifact["type"] == "web"

        experiment = await Experiment.objects.aget(name="Homepage Redesign", team=self.team)
        assert experiment.type == "web"

    async def test_create_experiment_duplicate_name(self):
        flag = await FeatureFlag.objects.acreate(
            team=self.team,
            created_by=self.user,
            key="existing-flag",
            filters={
                "groups": [{"properties": [], "rollout_percentage": 100}],
                "multivariate": {
                    "variants": [
                        {"key": "control", "name": "Control", "rollout_percentage": 50},
                        {"key": "test", "name": "Test", "rollout_percentage": 50},
                    ]
                },
            },
        )
        await Experiment.objects.acreate(
            team=self.team,
            created_by=self.user,
            name="Existing Experiment",
            feature_flag=flag,
        )

        # Create another flag for the duplicate attempt
        await FeatureFlag.objects.acreate(
            team=self.team,
            created_by=self.user,
            key="another-flag",
            filters={
                "groups": [{"properties": [], "rollout_percentage": 100}],
                "multivariate": {
                    "variants": [
                        {"key": "control", "name": "Control", "rollout_percentage": 50},
                        {"key": "test", "name": "Test", "rollout_percentage": 50},
                    ]
                },
            },
        )

        tool = await CreateExperimentTool.create_tool_class(
            team=self.team,
            user=self.user,
            state=AssistantState(messages=[]),
        )

        result, artifact = await tool._arun_impl(
            name="Existing Experiment",
            feature_flag_key="another-flag",
        )

        assert "Failed to create" in result
        assert "already exists" in result
        assert artifact is not None
        assert artifact.get("error")

    async def test_create_experiment_with_existing_flag(self):
        # Create a feature flag first
        await FeatureFlag.objects.acreate(
            team=self.team,
            created_by=self.user,
            key="existing-flag",
            name="Existing Flag",
            filters={
                "groups": [{"properties": [], "rollout_percentage": 100}],
                "multivariate": {
                    "variants": [
                        {"key": "control", "name": "Control", "rollout_percentage": 50},
                        {"key": "test", "name": "Test", "rollout_percentage": 50},
                    ]
                },
            },
        )

        tool = await CreateExperimentTool.create_tool_class(
            team=self.team,
            user=self.user,
            state=AssistantState(messages=[]),
        )

        result, artifact = await tool._arun_impl(
            name="New Experiment",
            feature_flag_key="existing-flag",
        )

        assert "Successfully created" in result

        from posthog.sync import database_sync_to_async

        @database_sync_to_async
        def get_experiment():
            return Experiment.objects.select_related("feature_flag").get(name="New Experiment", team=self.team)

        experiment = await get_experiment()
        assert experiment.feature_flag.key == "existing-flag"

    async def test_create_experiment_flag_already_used(self):
        # Create a flag and experiment
        flag = await FeatureFlag.objects.acreate(
            team=self.team,
            created_by=self.user,
            key="used-flag",
            filters={
                "groups": [{"properties": [], "rollout_percentage": 100}],
                "multivariate": {
                    "variants": [
                        {"key": "control", "name": "Control", "rollout_percentage": 50},
                        {"key": "test", "name": "Test", "rollout_percentage": 50},
                    ]
                },
            },
        )
        await Experiment.objects.acreate(
            team=self.team,
            created_by=self.user,
            name="First Experiment",
            feature_flag=flag,
        )

        tool = await CreateExperimentTool.create_tool_class(
            team=self.team,
            user=self.user,
            state=AssistantState(messages=[]),
        )

        result, artifact = await tool._arun_impl(
            name="Second Experiment",
            feature_flag_key="used-flag",
        )

        assert "Failed to create" in result
        assert "already used by experiment" in result
        assert artifact is not None
        assert artifact.get("error")

    async def test_create_experiment_default_parameters(self):
        # Create feature flag first
        await FeatureFlag.objects.acreate(
            team=self.team,
            created_by=self.user,
            key="param-test",
            name="Param Test Flag",
            filters={
                "groups": [{"properties": [], "rollout_percentage": 100}],
                "multivariate": {
                    "variants": [
                        {"key": "control", "name": "Control", "rollout_percentage": 50},
                        {"key": "test", "name": "Test", "rollout_percentage": 50},
                    ]
                },
            },
        )

        tool = await CreateExperimentTool.create_tool_class(
            team=self.team,
            user=self.user,
            state=AssistantState(messages=[]),
        )

        result, artifact = await tool._arun_impl(
            name="Parameter Test",
            feature_flag_key="param-test",
        )

        assert "Successfully created" in result

        experiment = await Experiment.objects.aget(name="Parameter Test", team=self.team)
        # Variants should come from the feature flag, not hardcoded
        assert experiment.parameters is not None
        assert experiment.parameters["feature_flag_variants"] == [
            {"key": "control", "name": "Control", "rollout_percentage": 50},
            {"key": "test", "name": "Test", "rollout_percentage": 50},
        ]
        assert experiment.parameters["minimum_detectable_effect"] == 30
        assert experiment.metrics == []
        assert experiment.metrics_secondary == []

    async def test_create_experiment_missing_flag(self):
        """Test error when trying to create experiment with non-existent flag."""
        tool = await CreateExperimentTool.create_tool_class(
            team=self.team,
            user=self.user,
            state=AssistantState(messages=[]),
        )

        result, artifact = await tool._arun_impl(
            name="Test Experiment",
            feature_flag_key="non-existent-flag",
        )

        assert "Failed to create" in result
        assert "does not exist" in result
        assert artifact is not None
        assert artifact.get("error")

    async def test_create_experiment_flag_without_variants(self):
        """Test error when flag doesn't have multivariate variants."""
        # Create a flag without multivariate
        await FeatureFlag.objects.acreate(
            team=self.team,
            created_by=self.user,
            key="no-variants-flag",
            name="No Variants Flag",
            filters={"groups": [{"properties": [], "rollout_percentage": 100}]},
        )

        tool = await CreateExperimentTool.create_tool_class(
            team=self.team,
            user=self.user,
            state=AssistantState(messages=[]),
        )

        result, artifact = await tool._arun_impl(
            name="Test Experiment",
            feature_flag_key="no-variants-flag",
        )

        assert "Failed to create" in result
        assert "must have multivariate variants" in result
        assert artifact is not None
        assert artifact.get("error")

    async def test_create_experiment_flag_with_one_variant(self):
        """Test error when flag has only 1 variant (need at least 2)."""
        # Create a flag with only 1 variant
        await FeatureFlag.objects.acreate(
            team=self.team,
            created_by=self.user,
            key="one-variant-flag",
            name="One Variant Flag",
            filters={
                "groups": [{"properties": [], "rollout_percentage": 100}],
                "multivariate": {
                    "variants": [
                        {"key": "only_one", "name": "Only Variant", "rollout_percentage": 100},
                    ]
                },
            },
        )

        tool = await CreateExperimentTool.create_tool_class(
            team=self.team,
            user=self.user,
            state=AssistantState(messages=[]),
        )

        result, artifact = await tool._arun_impl(
            name="Test Experiment",
            feature_flag_key="one-variant-flag",
        )

        assert "Failed to create" in result
        assert "at least 2 variants" in result
        assert artifact is not None
        assert artifact.get("error")

    async def test_create_experiment_uses_flag_variants(self):
        """Test that experiment uses the actual variants from the feature flag."""
        # Create a flag with 3 custom variants
        await FeatureFlag.objects.acreate(
            team=self.team,
            created_by=self.user,
            key="custom-variants-flag",
            name="Custom Variants Flag",
            filters={
                "groups": [{"properties": [], "rollout_percentage": 100}],
                "multivariate": {
                    "variants": [
                        {"key": "variant_a", "name": "Variant A", "rollout_percentage": 33},
                        {"key": "variant_b", "name": "Variant B", "rollout_percentage": 33},
                        {"key": "variant_c", "name": "Variant C", "rollout_percentage": 34},
                    ]
                },
            },
        )

        tool = await CreateExperimentTool.create_tool_class(
            team=self.team,
            user=self.user,
            state=AssistantState(messages=[]),
        )

        result, artifact = await tool._arun_impl(
            name="Custom Variants Test",
            feature_flag_key="custom-variants-flag",
        )

        assert "Successfully created" in result

        experiment = await Experiment.objects.aget(name="Custom Variants Test", team=self.team)
        assert experiment is not None
        assert experiment.parameters is not None
        assert len(experiment.parameters["feature_flag_variants"]) == 3
        assert experiment.parameters["feature_flag_variants"][0] is not None
        assert experiment.parameters["feature_flag_variants"][0]["key"] == "variant_a"
        assert experiment.parameters["feature_flag_variants"][0]["name"] == "Variant A"
        assert experiment.parameters["feature_flag_variants"][0]["rollout_percentage"] == 33
        assert experiment.parameters["feature_flag_variants"][1]["key"] == "variant_b"
        assert experiment.parameters["feature_flag_variants"][2]["key"] == "variant_c"

@@ -609,7 +609,11 @@ The tool will automatically:

    **Group-based:**
    - "Create a flag targeting organizations"
    - "Create a flag for companies where employee count > 100"
    - "Create a flag for companies where employee count > 100

    **For experiments**: If creating a flag for an A/B test or experiment, after creating
    the flag, you should navigate to the experiments page and use create_experiment with
    this flag's key to complete the experiment setup."
    """.strip()
    context_prompt_template: str = "Creates a new feature flag in the project with optional property-based targeting and multivariate variants for A/B testing"
    args_schema: type[BaseModel] = CreateFeatureFlagArgs