feat(ph-ai): create experiments max tool (#40753)

## Problem
This PR introduces a new `MaxTool` for creating experiments. It builds on top of the `create_feature_flag` `MaxTool` and its multivariate flag support: the new tool links a draft experiment to an existing multivariate feature flag (at least two variants, e.g. control and test).
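The tool validates the linked flag's `filters` for multivariate variants. A sketch of the shape used throughout the added tests (not a complete `FeatureFlag` definition; keys copied from the fixtures in the diff below):

```python
# Multivariate `filters` payload the tool validates against (sketch, from the test fixtures).
filters = {
    "groups": [{"properties": [], "rollout_percentage": 100}],
    "multivariate": {
        "variants": [
            {"key": "control", "name": "Control", "rollout_percentage": 50},
            {"key": "test", "name": "Test", "rollout_percentage": 50},
        ]
    },
}
```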

## Changes
- Added a new `CreateExperimentTool` that allows creating experiments via PostHog AI, registered as `create_experiment` in the assistant tool schema and enums (see the usage sketch below)
- Added a `NewExperimentButton` on the experiments page that exposes the tool through `MaxTool`, with prompt suggestions and a callback that reloads the experiments list and redirects to the new experiment
- Updated the `create_feature_flag` prompt to direct Max to `create_experiment` when a flag is being created for an A/B test
- Passed `properties=` explicitly to `capture_exception` in `ExperimentSummaryTool`
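A minimal sketch of how the tool is exercised (mirrors the added unit tests and evals below; assumes `team`, `user`, and a multivariate flag with key `pricing-test-flag` already exist):

```python
from products.experiments.backend.max_tools import CreateExperimentTool
from ee.hogai.utils.types import AssistantState

# Runs inside an async test body; `team` and `user` come from the test fixtures.
tool = await CreateExperimentTool.create_tool_class(
    team=team,
    user=user,
    state=AssistantState(messages=[]),
)
message, artifact = await tool._arun_impl(
    name="Pricing Test",
    feature_flag_key="pricing-test-flag",  # must reference an existing multivariate flag
    type="product",  # or "web" for frontend UI changes
)
# On success, `artifact` contains experiment_id, experiment_name, feature_flag_key, type, and url;
# on failure, `message` starts with "Failed to create experiment" and `artifact` carries an "error" key.
```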

## How did you test this code?
New unit tests (`TestCreateExperimentTool`) plus evals covering basic creation, experiment types, flag reuse, and error cases.

## Changelog: (features only) Is this feature complete?
Yes, this feature is complete and ready to be included in the changelog.
Emanuele Capparelli
2025-11-13 16:59:58 +00:00
committed by GitHub
parent dfaf7b28d7
commit daa062b202
10 changed files with 1041 additions and 20 deletions


@@ -0,0 +1,407 @@
"""Evaluations for CreateExperimentTool."""
import uuid
import pytest
from autoevals.partial import ScorerWithPartial
from autoevals.ragas import AnswerSimilarity
from braintrust import EvalCase, Score
from posthog.models import Experiment, FeatureFlag
from products.experiments.backend.max_tools import CreateExperimentTool
from ee.hogai.eval.base import MaxPublicEval
from ee.hogai.utils.types import AssistantState
from ee.models.assistant import Conversation
class ExperimentOutputScorer(ScorerWithPartial):
"""Custom scorer for experiment tool output that combines semantic similarity for text and exact matching for numbers/booleans."""
def __init__(self, semantic_fields: set[str] | None = None, **kwargs):
super().__init__(**kwargs)
self.semantic_fields = semantic_fields or {"message"}
def _run_eval_sync(self, output: dict, expected: dict, **kwargs):
if not expected:
return Score(name=self._name(), score=None, metadata={"reason": "No expected value provided"})
if not output:
return Score(name=self._name(), score=0.0, metadata={"reason": "No output provided"})
total_fields = len(expected)
if total_fields == 0:
return Score(name=self._name(), score=1.0)
score_per_field = 1.0 / total_fields
total_score = 0.0
metadata = {}
for field_name, expected_value in expected.items():
actual_value = output.get(field_name)
if field_name in self.semantic_fields:
# Use semantic similarity for text fields
if actual_value is not None and expected_value is not None:
similarity_scorer = AnswerSimilarity(model="text-embedding-3-small")
result = similarity_scorer.eval(output=str(actual_value), expected=str(expected_value))
field_score = result.score * score_per_field
total_score += field_score
metadata[f"{field_name}_score"] = result.score
else:
metadata[f"{field_name}_missing"] = True
else:
# Use exact match for numeric/boolean fields
if actual_value == expected_value:
total_score += score_per_field
metadata[f"{field_name}_match"] = True
else:
metadata[f"{field_name}_mismatch"] = {
"expected": expected_value,
"actual": actual_value,
}
return Score(name=self._name(), score=total_score, metadata=metadata)
@pytest.mark.django_db
async def eval_create_experiment_basic(pytestconfig, demo_org_team_user):
"""Test basic experiment creation."""
_, team, user = demo_org_team_user
conversation = await Conversation.objects.acreate(team=team, user=user)
async def task_create_experiment(args: dict):
# Create feature flag first (required by the tool)
await FeatureFlag.objects.acreate(
team=team,
created_by=user,
key=args["feature_flag_key"],
name=f"Flag for {args['name']}",
filters={
"groups": [{"properties": [], "rollout_percentage": 100}],
"multivariate": {
"variants": [
{"key": "control", "name": "Control", "rollout_percentage": 50},
{"key": "test", "name": "Test", "rollout_percentage": 50},
]
},
},
)
tool = await CreateExperimentTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)
result_message, artifact = await tool._arun_impl(
name=args["name"],
feature_flag_key=args["feature_flag_key"],
description=args.get("description"),
type=args.get("type", "product"),
)
exp_created = await Experiment.objects.filter(team=team, name=args["name"], deleted=False).aexists()
return {
"message": result_message,
"experiment_created": exp_created,
"experiment_name": artifact.get("experiment_name") if artifact else None,
}
await MaxPublicEval(
experiment_name="create_experiment_basic",
task=task_create_experiment, # type: ignore
scores=[ExperimentOutputScorer(semantic_fields={"message", "experiment_name"})],
data=[
EvalCase(
input={"name": "Pricing Test", "feature_flag_key": "pricing-test-flag"},
expected={
"message": "Successfully created experiment",
"experiment_created": True,
"experiment_name": "Pricing Test",
},
),
EvalCase(
input={
"name": "Homepage Redesign",
"feature_flag_key": "homepage-redesign",
"description": "Testing new homepage layout for better conversion",
},
expected={
"message": "Successfully created experiment",
"experiment_created": True,
"experiment_name": "Homepage Redesign",
},
),
],
pytestconfig=pytestconfig,
)
@pytest.mark.django_db
async def eval_create_experiment_types(pytestconfig, demo_org_team_user):
"""Test experiment creation with different types (product vs web)."""
_, team, user = demo_org_team_user
conversation = await Conversation.objects.acreate(team=team, user=user)
async def task_create_typed_experiment(args: dict):
# Create feature flag first
await FeatureFlag.objects.acreate(
team=team,
created_by=user,
key=args["feature_flag_key"],
name=f"Flag for {args['name']}",
filters={
"groups": [{"properties": [], "rollout_percentage": 100}],
"multivariate": {
"variants": [
{"key": "control", "name": "Control", "rollout_percentage": 50},
{"key": "test", "name": "Test", "rollout_percentage": 50},
]
},
},
)
tool = await CreateExperimentTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)
result_message, artifact = await tool._arun_impl(
name=args["name"],
feature_flag_key=args["feature_flag_key"],
type=args["type"],
)
# Verify experiment type
experiment = await Experiment.objects.aget(team=team, name=args["name"])
return {
"message": result_message,
"experiment_type": experiment.type,
"artifact_type": artifact.get("type") if artifact else None,
}
await MaxPublicEval(
experiment_name="create_experiment_types",
task=task_create_typed_experiment, # type: ignore
scores=[ExperimentOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input={"name": "Product Feature Test", "feature_flag_key": "product-test", "type": "product"},
expected={
"message": "Successfully created experiment",
"experiment_type": "product",
"artifact_type": "product",
},
),
EvalCase(
input={"name": "Web UI Test", "feature_flag_key": "web-test", "type": "web"},
expected={
"message": "Successfully created experiment",
"experiment_type": "web",
"artifact_type": "web",
},
),
],
pytestconfig=pytestconfig,
)
@pytest.mark.django_db
async def eval_create_experiment_with_existing_flag(pytestconfig, demo_org_team_user):
"""Test experiment creation with an existing feature flag."""
_, team, user = demo_org_team_user
# Create an existing flag with unique key and multivariate variants
unique_key = f"reusable-flag-{uuid.uuid4().hex[:8]}"
await FeatureFlag.objects.acreate(
team=team,
key=unique_key,
name="Reusable Flag",
created_by=user,
filters={
"groups": [{"properties": [], "rollout_percentage": 100}],
"multivariate": {
"variants": [
{"key": "control", "name": "Control", "rollout_percentage": 50},
{"key": "test", "name": "Test", "rollout_percentage": 50},
]
},
},
)
conversation = await Conversation.objects.acreate(team=team, user=user)
async def task_create_experiment_reuse_flag(args: dict):
tool = await CreateExperimentTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)
result_message, artifact = await tool._arun_impl(
name=args["name"],
feature_flag_key=args["feature_flag_key"],
)
return {
"message": result_message,
"experiment_created": artifact is not None and "experiment_id" in artifact,
}
await MaxPublicEval(
experiment_name="create_experiment_with_existing_flag",
task=task_create_experiment_reuse_flag, # type: ignore
scores=[ExperimentOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input={"name": "Reuse Flag Test", "feature_flag_key": unique_key},
expected={
"message": "Successfully created experiment",
"experiment_created": True,
},
),
],
pytestconfig=pytestconfig,
)
@pytest.mark.django_db
async def eval_create_experiment_duplicate_name_error(pytestconfig, demo_org_team_user):
"""Test that creating a duplicate experiment returns an error."""
_, team, user = demo_org_team_user
# Create an existing experiment with unique flag key
unique_flag_key = f"test-flag-{uuid.uuid4().hex[:8]}"
flag = await FeatureFlag.objects.acreate(team=team, key=unique_flag_key, created_by=user)
await Experiment.objects.acreate(team=team, name="Existing Experiment", feature_flag=flag, created_by=user)
conversation = await Conversation.objects.acreate(team=team, user=user)
async def task_create_duplicate_experiment(args: dict):
# Create a different flag for the duplicate attempt
await FeatureFlag.objects.acreate(
team=team,
created_by=user,
key=args["feature_flag_key"],
name=f"Flag for {args['name']}",
)
tool = await CreateExperimentTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)
result_message, artifact = await tool._arun_impl(
name=args["name"],
feature_flag_key=args["feature_flag_key"],
)
return {
"message": result_message,
"has_error": artifact.get("error") is not None if artifact else False,
}
await MaxPublicEval(
experiment_name="create_experiment_duplicate_name_error",
task=task_create_duplicate_experiment, # type: ignore
scores=[ExperimentOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input={"name": "Existing Experiment", "feature_flag_key": "another-flag"},
expected={
"message": "Failed to create experiment: An experiment with name 'Existing Experiment' already exists",
"has_error": True,
},
),
],
pytestconfig=pytestconfig,
)
@pytest.mark.django_db
async def eval_create_experiment_flag_already_used_error(pytestconfig, demo_org_team_user):
"""Test that using a flag already tied to another experiment returns an error."""
_, team, user = demo_org_team_user
# Create an experiment with a flag (unique key)
unique_flag_key = f"used-flag-{uuid.uuid4().hex[:8]}"
flag = await FeatureFlag.objects.acreate(team=team, key=unique_flag_key, created_by=user)
await Experiment.objects.acreate(team=team, name="First Experiment", feature_flag=flag, created_by=user)
conversation = await Conversation.objects.acreate(team=team, user=user)
async def task_create_experiment_with_used_flag(args: dict):
tool = await CreateExperimentTool.create_tool_class(
team=team,
user=user,
state=AssistantState(messages=[]),
config={
"configurable": {
"thread_id": conversation.id,
"team": team,
"user": user,
}
},
)
result_message, artifact = await tool._arun_impl(
name=args["name"],
feature_flag_key=args["feature_flag_key"],
)
return {
"message": result_message,
"has_error": artifact.get("error") is not None if artifact else False,
}
await MaxPublicEval(
experiment_name="create_experiment_flag_already_used_error",
task=task_create_experiment_with_used_flag, # type: ignore
scores=[ExperimentOutputScorer(semantic_fields={"message"})],
data=[
EvalCase(
input={"name": "Second Experiment", "feature_flag_key": unique_flag_key},
expected={
"message": "Failed to create experiment: Feature flag is already used by experiment",
"has_error": True,
},
),
],
pytestconfig=pytestconfig,
)


@@ -1850,7 +1850,8 @@
"read_data",
"todo_write",
"filter_revenue_analytics",
"create_feature_flag"
"create_feature_flag",
"create_experiment"
],
"type": "string"
},


@@ -287,6 +287,7 @@ export type AssistantTool =
| 'todo_write'
| 'filter_revenue_analytics'
| 'create_feature_flag'
| 'create_experiment'
export enum AgentMode {
ProductAnalytics = 'product_analytics',


@@ -3,7 +3,7 @@ import { router } from 'kea-router'
import { useState } from 'react'
import { match } from 'ts-pattern'
import { LemonDialog, LemonInput, LemonSelect, LemonTag, Tooltip } from '@posthog/lemon-ui'
import { LemonDialog, LemonInput, LemonSelect, LemonTag, Tooltip, lemonToast } from '@posthog/lemon-ui'
import { AccessControlAction } from 'lib/components/AccessControlAction'
import { ActivityLog } from 'lib/components/ActivityLog/ActivityLog'
@@ -20,6 +20,7 @@ import { atColumn, createdAtColumn, createdByColumn } from 'lib/lemon-ui/LemonTa
import { LemonTabs } from 'lib/lemon-ui/LemonTabs'
import { deleteWithUndo } from 'lib/utils/deleteWithUndo'
import stringWithWBR from 'lib/utils/stringWithWBR'
import MaxTool from 'scenes/max/MaxTool'
import { useMaxTool } from 'scenes/max/useMaxTool'
import { SceneExport } from 'scenes/sceneTypes'
import { urls } from 'scenes/urls'
@@ -438,14 +439,7 @@ export function Experiments(): JSX.Element {
resourceType={AccessControlResourceType.Experiment}
minAccessLevel={AccessControlLevel.Editor}
>
<LemonButton
size="small"
type="primary"
data-attr="create-experiment"
to={urls.experiment('new')}
>
New experiment
</LemonButton>
<NewExperimentButton />
</AccessControlAction>
) : undefined
}
@@ -493,3 +487,43 @@ export function Experiments(): JSX.Element {
</SceneContent>
)
}
function NewExperimentButton(): JSX.Element {
const { loadExperiments } = useActions(experimentsLogic)
useMaxTool({ identifier: 'create_feature_flag', context: {} })
return (
<MaxTool
identifier="create_experiment"
initialMaxPrompt="Create an experiment for "
suggestions={[
'Create an experiment to test our new checkout flow',
'Set up an A/B test for the pricing page redesign',
'Create an experiment to test different call-to-action buttons on the homepage',
'Create an experiment for testing our new recommendation algorithm',
]}
callback={(toolOutput: {
experiment_id?: string | number
experiment_name?: string
feature_flag_key?: string
error?: string
}) => {
if (toolOutput?.error || !toolOutput?.experiment_id) {
lemonToast.error(`Failed to create experiment: ${toolOutput?.error || 'Unknown error'}`)
return
}
// Refresh experiments list to show new experiment, then redirect to it
loadExperiments()
router.actions.push(urls.experiment(toolOutput.experiment_id))
}}
position="bottom-right"
active={true}
context={{}}
>
<LemonButton size="small" type="primary" data-attr="create-experiment" to={urls.experiment('new')}>
<span className="pr-3">New experiment</span>
</LemonButton>
</MaxTool>
)
}


@@ -359,6 +359,18 @@ export const TOOL_DEFINITIONS: Record<Exclude<AssistantTool, 'todo_write'>, Tool
return 'Creating feature flag...'
},
},
create_experiment: {
name: 'Create an experiment',
description: 'Create an experiment in seconds',
product: Scene.Experiments,
icon: iconForType('experiment'),
displayFormatter: (toolCall) => {
if (toolCall.status === 'completed') {
return 'Created experiment'
}
return 'Creating experiment...'
},
},
}
export const MAX_GENERALLY_CAN: { icon: JSX.Element; description: string }[] = [


@@ -285,6 +285,7 @@ class AssistantTool(StrEnum):
TODO_WRITE = "todo_write"
FILTER_REVENUE_ANALYTICS = "filter_revenue_analytics"
CREATE_FEATURE_FLAG = "create_feature_flag"
CREATE_EXPERIMENT = "create_experiment"
class AssistantToolCall(BaseModel):


@@ -1,20 +1,154 @@
"""
MaxTool for AI-powered experiment summary.
"""
from typing import Any
from typing import Any, Literal
from posthoganalytics import capture_exception
from pydantic import BaseModel, Field
from posthog.schema import MaxExperimentSummaryContext
from posthog.exceptions_capture import capture_exception
from posthog.models import Experiment, FeatureFlag
from posthog.sync import database_sync_to_async
from ee.hogai.llm import MaxChatOpenAI
from ee.hogai.tool import MaxTool
from .prompts import EXPERIMENT_SUMMARY_BAYESIAN_PROMPT, EXPERIMENT_SUMMARY_FREQUENTIST_PROMPT
class CreateExperimentArgs(BaseModel):
name: str = Field(description="Experiment name - should clearly describe what is being tested")
feature_flag_key: str = Field(
description="Feature flag key (letters, numbers, hyphens, underscores only). Will create a new flag if it doesn't exist."
)
description: str | None = Field(
default=None,
description="Detailed description of the experiment hypothesis, what changes are being tested, and expected outcomes",
)
type: Literal["product", "web"] = Field(
default="product",
description="Experiment type: 'product' for backend/API changes, 'web' for frontend UI changes",
)
class CreateExperimentTool(MaxTool):
name: Literal["create_experiment"] = "create_experiment"
description: str = """
Create a new A/B test experiment in the current project.
Experiments allow you to test changes with a controlled rollout and measure their impact.
Use this tool when the user wants to:
- Create a new A/B test experiment
- Set up a controlled experiment to test changes
- Test variants of a feature with users
Examples:
- "Create an experiment to test the new checkout flow"
- "Set up an A/B test for our pricing page redesign"
- "Create an experiment called 'homepage-cta-test' to test different call-to-action buttons
**IMPORTANT**: You must first find or create a multivariate feature flag using `create_feature_flag`, with at least two variants (control and test). Navigate to the feature flags page to create the flag, create the flag, then navigate back to the experiments page and use this tool to create the experiment."
""".strip()
context_prompt_template: str = "Creates a new A/B test experiment in the project"
args_schema: type[BaseModel] = CreateExperimentArgs
async def _arun_impl(
self,
name: str,
feature_flag_key: str,
description: str | None = None,
type: Literal["product", "web"] = "product",
) -> tuple[str, dict[str, Any] | None]:
# Validate inputs
if not name or not name.strip():
return "Experiment name cannot be empty", {"error": "invalid_name"}
if not feature_flag_key or not feature_flag_key.strip():
return "Feature flag key cannot be empty", {"error": "invalid_flag_key"}
@database_sync_to_async
def create_experiment() -> Experiment:
# Check if experiment with this name already exists
existing_experiment = Experiment.objects.filter(team=self._team, name=name, deleted=False).first()
if existing_experiment:
raise ValueError(f"An experiment with name '{name}' already exists")
try:
feature_flag = FeatureFlag.objects.get(team=self._team, key=feature_flag_key, deleted=False)
except FeatureFlag.DoesNotExist:
raise ValueError(f"Feature flag '{feature_flag_key}' does not exist")
# Validate that the flag has multivariate variants
multivariate = feature_flag.filters.get("multivariate")
if not multivariate or not multivariate.get("variants"):
raise ValueError(
f"Feature flag '{feature_flag_key}' must have multivariate variants to be used in an experiment. "
f"Create the flag with variants first using the create_feature_flag tool."
)
variants = multivariate["variants"]
if len(variants) < 2:
raise ValueError(
f"Feature flag '{feature_flag_key}' must have at least 2 variants for an experiment (e.g., control and test)"
)
# If flag already exists and is already used by another experiment, raise error
existing_experiment_with_flag = Experiment.objects.filter(feature_flag=feature_flag, deleted=False).first()
if existing_experiment_with_flag:
raise ValueError(
f"Feature flag '{feature_flag_key}' is already used by experiment '{existing_experiment_with_flag.name}'"
)
# Use the actual variants from the feature flag
feature_flag_variants = [
{
"key": variant["key"],
"name": variant.get("name", variant["key"]),
"rollout_percentage": variant["rollout_percentage"],
}
for variant in variants
]
# Create the experiment as a draft (no start_date)
experiment = Experiment.objects.create(
team=self._team,
created_by=self._user,
name=name,
description=description or "",
type=type,
feature_flag=feature_flag,
filters={}, # Empty filters for draft
parameters={
"feature_flag_variants": feature_flag_variants,
"minimum_detectable_effect": 30,
},
metrics=[],
metrics_secondary=[],
)
return experiment
try:
experiment = await create_experiment()
experiment_url = f"/project/{self._team.project_id}/experiments/{experiment.id}"
return (
f"Successfully created experiment '{name}'. "
f"The experiment is in draft mode - you can configure metrics and launch it at {experiment_url}",
{
"experiment_id": experiment.id,
"experiment_name": experiment.name,
"feature_flag_key": feature_flag_key,
"type": type,
"url": experiment_url,
},
)
except ValueError as e:
return f"Failed to create experiment: {str(e)}", {"error": str(e)}
except Exception as e:
capture_exception(e)
return f"Failed to create experiment: {str(e)}", {"error": "creation_failed"}
MAX_METRICS_TO_SUMMARIZE = 3
@@ -104,7 +238,8 @@ class ExperimentSummaryTool(MaxTool):
except Exception as e:
capture_exception(
e, {"team_id": self._team.id, "user_id": self._user.id, "experiment_id": context.experiment_id}
e,
properties={"team_id": self._team.id, "user_id": self._user.id, "experiment_id": context.experiment_id},
)
return ExperimentSummaryOutput(key_metrics=[f"Analysis failed: {str(e)}"])
@@ -184,7 +319,7 @@ class ExperimentSummaryTool(MaxTool):
capture_exception(
e,
{
properties={
"team_id": self._team.id,
"user_id": self._user.id,
"context_keys": list(self.context.keys()) if isinstance(self.context, dict) else None,
@@ -210,5 +345,5 @@ class ExperimentSummaryTool(MaxTool):
}
except Exception as e:
capture_exception(e, {"team_id": self._team.id, "user_id": self._user.id})
capture_exception(e, properties={"team_id": self._team.id, "user_id": self._user.id})
return f"❌ Failed to summarize experiment: {str(e)}", {"error": "summary_failed", "details": str(e)}


@@ -0,0 +1,426 @@
from posthog.test.base import APIBaseTest
from posthog.models import Experiment, FeatureFlag
from products.experiments.backend.max_tools import CreateExperimentTool
from ee.hogai.utils.types import AssistantState
class TestCreateExperimentTool(APIBaseTest):
async def test_create_experiment_minimal(self):
# Create feature flag first (as the tool expects)
await FeatureFlag.objects.acreate(
team=self.team,
created_by=self.user,
key="test-experiment-flag",
name="Test Experiment Flag",
filters={
"groups": [{"properties": [], "rollout_percentage": 100}],
"multivariate": {
"variants": [
{"key": "control", "name": "Control", "rollout_percentage": 50},
{"key": "test", "name": "Test", "rollout_percentage": 50},
]
},
},
)
tool = await CreateExperimentTool.create_tool_class(
team=self.team,
user=self.user,
state=AssistantState(messages=[]),
)
result, artifact = await tool._arun_impl(
name="Test Experiment",
feature_flag_key="test-experiment-flag",
)
assert "Successfully created" in result
assert artifact is not None
assert artifact["experiment_name"] == "Test Experiment"
assert artifact["feature_flag_key"] == "test-experiment-flag"
assert "/experiments/" in artifact["url"]
from posthog.sync import database_sync_to_async
@database_sync_to_async
def get_experiment():
return Experiment.objects.select_related("feature_flag").get(name="Test Experiment", team=self.team)
experiment = await get_experiment()
assert experiment.description == ""
assert experiment.type == "product"
assert experiment.start_date is None # Draft
assert experiment.feature_flag.key == "test-experiment-flag"
async def test_create_experiment_with_description(self):
# Create feature flag first
await FeatureFlag.objects.acreate(
team=self.team,
created_by=self.user,
key="checkout-test",
name="Checkout Test Flag",
filters={
"groups": [{"properties": [], "rollout_percentage": 100}],
"multivariate": {
"variants": [
{"key": "control", "name": "Control", "rollout_percentage": 50},
{"key": "test", "name": "Test", "rollout_percentage": 50},
]
},
},
)
tool = await CreateExperimentTool.create_tool_class(
team=self.team,
user=self.user,
state=AssistantState(messages=[]),
)
result, artifact = await tool._arun_impl(
name="Checkout Experiment",
feature_flag_key="checkout-test",
description="Testing new checkout flow to improve conversion rates",
)
assert "Successfully created" in result
experiment = await Experiment.objects.aget(name="Checkout Experiment", team=self.team)
assert experiment.description == "Testing new checkout flow to improve conversion rates"
async def test_create_experiment_web_type(self):
# Create feature flag first
await FeatureFlag.objects.acreate(
team=self.team,
created_by=self.user,
key="homepage-redesign",
name="Homepage Redesign Flag",
filters={
"groups": [{"properties": [], "rollout_percentage": 100}],
"multivariate": {
"variants": [
{"key": "control", "name": "Control", "rollout_percentage": 50},
{"key": "test", "name": "Test", "rollout_percentage": 50},
]
},
},
)
tool = await CreateExperimentTool.create_tool_class(
team=self.team,
user=self.user,
state=AssistantState(messages=[]),
)
result, artifact = await tool._arun_impl(
name="Homepage Redesign",
feature_flag_key="homepage-redesign",
type="web",
)
assert "Successfully created" in result
assert artifact is not None
assert artifact["type"] == "web"
experiment = await Experiment.objects.aget(name="Homepage Redesign", team=self.team)
assert experiment.type == "web"
async def test_create_experiment_duplicate_name(self):
flag = await FeatureFlag.objects.acreate(
team=self.team,
created_by=self.user,
key="existing-flag",
filters={
"groups": [{"properties": [], "rollout_percentage": 100}],
"multivariate": {
"variants": [
{"key": "control", "name": "Control", "rollout_percentage": 50},
{"key": "test", "name": "Test", "rollout_percentage": 50},
]
},
},
)
await Experiment.objects.acreate(
team=self.team,
created_by=self.user,
name="Existing Experiment",
feature_flag=flag,
)
# Create another flag for the duplicate attempt
await FeatureFlag.objects.acreate(
team=self.team,
created_by=self.user,
key="another-flag",
filters={
"groups": [{"properties": [], "rollout_percentage": 100}],
"multivariate": {
"variants": [
{"key": "control", "name": "Control", "rollout_percentage": 50},
{"key": "test", "name": "Test", "rollout_percentage": 50},
]
},
},
)
tool = await CreateExperimentTool.create_tool_class(
team=self.team,
user=self.user,
state=AssistantState(messages=[]),
)
result, artifact = await tool._arun_impl(
name="Existing Experiment",
feature_flag_key="another-flag",
)
assert "Failed to create" in result
assert "already exists" in result
assert artifact is not None
assert artifact.get("error")
async def test_create_experiment_with_existing_flag(self):
# Create a feature flag first
await FeatureFlag.objects.acreate(
team=self.team,
created_by=self.user,
key="existing-flag",
name="Existing Flag",
filters={
"groups": [{"properties": [], "rollout_percentage": 100}],
"multivariate": {
"variants": [
{"key": "control", "name": "Control", "rollout_percentage": 50},
{"key": "test", "name": "Test", "rollout_percentage": 50},
]
},
},
)
tool = await CreateExperimentTool.create_tool_class(
team=self.team,
user=self.user,
state=AssistantState(messages=[]),
)
result, artifact = await tool._arun_impl(
name="New Experiment",
feature_flag_key="existing-flag",
)
assert "Successfully created" in result
from posthog.sync import database_sync_to_async
@database_sync_to_async
def get_experiment():
return Experiment.objects.select_related("feature_flag").get(name="New Experiment", team=self.team)
experiment = await get_experiment()
assert experiment.feature_flag.key == "existing-flag"
async def test_create_experiment_flag_already_used(self):
# Create a flag and experiment
flag = await FeatureFlag.objects.acreate(
team=self.team,
created_by=self.user,
key="used-flag",
filters={
"groups": [{"properties": [], "rollout_percentage": 100}],
"multivariate": {
"variants": [
{"key": "control", "name": "Control", "rollout_percentage": 50},
{"key": "test", "name": "Test", "rollout_percentage": 50},
]
},
},
)
await Experiment.objects.acreate(
team=self.team,
created_by=self.user,
name="First Experiment",
feature_flag=flag,
)
tool = await CreateExperimentTool.create_tool_class(
team=self.team,
user=self.user,
state=AssistantState(messages=[]),
)
result, artifact = await tool._arun_impl(
name="Second Experiment",
feature_flag_key="used-flag",
)
assert "Failed to create" in result
assert "already used by experiment" in result
assert artifact is not None
assert artifact.get("error")
async def test_create_experiment_default_parameters(self):
# Create feature flag first
await FeatureFlag.objects.acreate(
team=self.team,
created_by=self.user,
key="param-test",
name="Param Test Flag",
filters={
"groups": [{"properties": [], "rollout_percentage": 100}],
"multivariate": {
"variants": [
{"key": "control", "name": "Control", "rollout_percentage": 50},
{"key": "test", "name": "Test", "rollout_percentage": 50},
]
},
},
)
tool = await CreateExperimentTool.create_tool_class(
team=self.team,
user=self.user,
state=AssistantState(messages=[]),
)
result, artifact = await tool._arun_impl(
name="Parameter Test",
feature_flag_key="param-test",
)
assert "Successfully created" in result
experiment = await Experiment.objects.aget(name="Parameter Test", team=self.team)
# Variants should come from the feature flag, not hardcoded
assert experiment.parameters is not None
assert experiment.parameters["feature_flag_variants"] == [
{"key": "control", "name": "Control", "rollout_percentage": 50},
{"key": "test", "name": "Test", "rollout_percentage": 50},
]
assert experiment.parameters["minimum_detectable_effect"] == 30
assert experiment.metrics == []
assert experiment.metrics_secondary == []
async def test_create_experiment_missing_flag(self):
"""Test error when trying to create experiment with non-existent flag."""
tool = await CreateExperimentTool.create_tool_class(
team=self.team,
user=self.user,
state=AssistantState(messages=[]),
)
result, artifact = await tool._arun_impl(
name="Test Experiment",
feature_flag_key="non-existent-flag",
)
assert "Failed to create" in result
assert "does not exist" in result
assert artifact is not None
assert artifact.get("error")
async def test_create_experiment_flag_without_variants(self):
"""Test error when flag doesn't have multivariate variants."""
# Create a flag without multivariate
await FeatureFlag.objects.acreate(
team=self.team,
created_by=self.user,
key="no-variants-flag",
name="No Variants Flag",
filters={"groups": [{"properties": [], "rollout_percentage": 100}]},
)
tool = await CreateExperimentTool.create_tool_class(
team=self.team,
user=self.user,
state=AssistantState(messages=[]),
)
result, artifact = await tool._arun_impl(
name="Test Experiment",
feature_flag_key="no-variants-flag",
)
assert "Failed to create" in result
assert "must have multivariate variants" in result
assert artifact is not None
assert artifact.get("error")
async def test_create_experiment_flag_with_one_variant(self):
"""Test error when flag has only 1 variant (need at least 2)."""
# Create a flag with only 1 variant
await FeatureFlag.objects.acreate(
team=self.team,
created_by=self.user,
key="one-variant-flag",
name="One Variant Flag",
filters={
"groups": [{"properties": [], "rollout_percentage": 100}],
"multivariate": {
"variants": [
{"key": "only_one", "name": "Only Variant", "rollout_percentage": 100},
]
},
},
)
tool = await CreateExperimentTool.create_tool_class(
team=self.team,
user=self.user,
state=AssistantState(messages=[]),
)
result, artifact = await tool._arun_impl(
name="Test Experiment",
feature_flag_key="one-variant-flag",
)
assert "Failed to create" in result
assert "at least 2 variants" in result
assert artifact is not None
assert artifact.get("error")
async def test_create_experiment_uses_flag_variants(self):
"""Test that experiment uses the actual variants from the feature flag."""
# Create a flag with 3 custom variants
await FeatureFlag.objects.acreate(
team=self.team,
created_by=self.user,
key="custom-variants-flag",
name="Custom Variants Flag",
filters={
"groups": [{"properties": [], "rollout_percentage": 100}],
"multivariate": {
"variants": [
{"key": "variant_a", "name": "Variant A", "rollout_percentage": 33},
{"key": "variant_b", "name": "Variant B", "rollout_percentage": 33},
{"key": "variant_c", "name": "Variant C", "rollout_percentage": 34},
]
},
},
)
tool = await CreateExperimentTool.create_tool_class(
team=self.team,
user=self.user,
state=AssistantState(messages=[]),
)
result, artifact = await tool._arun_impl(
name="Custom Variants Test",
feature_flag_key="custom-variants-flag",
)
assert "Successfully created" in result
experiment = await Experiment.objects.aget(name="Custom Variants Test", team=self.team)
assert experiment is not None
assert experiment.parameters is not None
assert len(experiment.parameters["feature_flag_variants"]) == 3
assert experiment.parameters["feature_flag_variants"][0] is not None
assert experiment.parameters["feature_flag_variants"][0]["key"] == "variant_a"
assert experiment.parameters["feature_flag_variants"][0]["name"] == "Variant A"
assert experiment.parameters["feature_flag_variants"][0]["rollout_percentage"] == 33
assert experiment.parameters["feature_flag_variants"][1]["key"] == "variant_b"
assert experiment.parameters["feature_flag_variants"][2]["key"] == "variant_c"


@@ -609,7 +609,11 @@ The tool will automatically:
**Group-based:**
- "Create a flag targeting organizations"
- "Create a flag for companies where employee count > 100"
- "Create a flag for companies where employee count > 100
**For experiments**: If creating a flag for an A/B test or experiment, after creating
the flag, you should navigate to the experiments page and use create_experiment with
this flag's key to complete the experiment setup."
""".strip()
context_prompt_template: str = "Creates a new feature flag in the project with optional property-based targeting and multivariate variants for A/B testing"
args_schema: type[BaseModel] = CreateFeatureFlagArgs