feat(max): allow multiple property value search (#38489)

This commit is contained in:
Dena Korita
2025-10-07 18:29:46 +02:00
committed by GitHub
parent 4f5432f08d
commit 08b0f13ef0
12 changed files with 766 additions and 308 deletions

View File

@@ -7,13 +7,18 @@ import yaml
from posthog.taxonomy.taxonomy import CORE_FILTER_DEFINITIONS_BY_GROUP
def format_property_values(sample_values: list, sample_count: Optional[int] = 0, format_as_string: bool = False) -> str:
def format_property_values(
property_name: str, sample_values: list, sample_count: Optional[int] = 0, format_as_string: bool = False
) -> str:
if len(sample_values) == 0 or sample_count == 0:
return f"The property does not have any values in the taxonomy."
data = {
"property": property_name,
"values": [],
"message": f"The property does not have any values in the taxonomy.",
}
return yaml.dump(data, default_flow_style=False, sort_keys=False)
# Add quotes to the String type, so the LLM can easily infer a type.
# Strings like "true" or "10" are interpreted as booleans or numbers without quotes, so the schema generation fails.
# Remove the floating point if the value is an integer.
# Format values for YAML
formatted_sample_values: list[str] = []
for value in sample_values:
if format_as_string:
@@ -22,16 +27,14 @@ def format_property_values(sample_values: list, sample_count: Optional[int] = 0,
formatted_sample_values.append(str(int(value)))
else:
formatted_sample_values.append(str(value))
prop_values = ", ".join(formatted_sample_values)
# If there wasn't an exact match with the user's search, we provide a hint that LLM can use an arbitrary value.
if sample_count is None:
return f"{prop_values} and many more distinct values."
formatted_sample_values.append("and many more distinct values")
elif sample_count > len(sample_values):
diff = sample_count - len(sample_values)
return f"{prop_values} and {diff} more distinct value{'' if diff == 1 else 's'}."
return prop_values
remaining = sample_count - len(sample_values)
formatted_sample_values.append(f"and {remaining} more distinct values")
data = {"property": property_name, "values": formatted_sample_values}
return yaml.dump(data, default_flow_style=False, sort_keys=False)
def format_properties_xml(children: list[tuple[str, str | None, str | None]]):

View File

@@ -163,7 +163,7 @@ class TaxonomyAgentToolsNode(
def node_name(self) -> MaxNodeName:
return TaxonomyNodeName.TOOLS_NODE
def run(self, state: TaxonomyStateType, config: RunnableConfig) -> TaxonomyPartialStateType:
async def arun(self, state: TaxonomyStateType, config: RunnableConfig) -> TaxonomyPartialStateType:
intermediate_steps = state.intermediate_steps or []
action, _output = intermediate_steps[-1]
tool_input: TaxonomyTool | None = None
@@ -199,7 +199,7 @@ class TaxonomyAgentToolsNode(
if tool_input and not output:
# Use the toolkit to handle tool execution
_, output = self._toolkit.handle_tools(tool_input.name, tool_input)
_, output = await self._toolkit.handle_tools(tool_input.name, tool_input)
if output:
tool_msg = LangchainToolMessage(

View File

@@ -0,0 +1,189 @@
from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest
from unittest.mock import patch
from posthog.models.person import Person
from posthog.models.property_definition import PropertyDefinition
from posthog.test.test_utils import create_group_type_mapping_without_created_at
from ee.hogai.graph.taxonomy.toolkit import TaxonomyAgentToolkit
class DummyToolkit(TaxonomyAgentToolkit):
    """Toolkit subclass used by these tests; `get_tools` returns the default tool set."""

    def get_tools(self):
        default_tools = self._get_default_tools()
        return default_tools
class TestEntities(ClickhouseTestMixin, NonAtomicBaseTest):
    """Tests for entity (person / session / group) property lookups on the taxonomy toolkit."""

    def setUp(self):
        super().setUp()

        # Two group types, registered at indexes 0 and 1.
        for index, name in enumerate(["organization", "project"]):
            create_group_type_mapping_without_created_at(
                team=self.team, project_id=self.team.project_id, group_type_index=index, group_type=name
            )

        # All definitions are non-numerical strings; only name/type/group index vary.
        # Creation order is kept as-is because the expected XML lists properties in this order.
        definitions = [
            {"name": "name", "type": PropertyDefinition.Type.PERSON},
            {"name": "name_group", "type": PropertyDefinition.Type.GROUP, "group_type_index": 0},
            {"name": "property_no_values", "type": PropertyDefinition.Type.PERSON},
        ]
        for definition in definitions:
            name = definition.pop("name")
            PropertyDefinition.objects.create(
                team=self.team,
                name=name,
                property_type="String",
                is_numerical=False,
                **definition,
            )

        # A single person carrying a value for the "name" property.
        Person.objects.create(
            team=self.team,
            distinct_ids=["test-user"],
            properties={"name": "Test User"},
        )

        self.toolkit = DummyToolkit(self.team)

    async def test_retrieve_entity_properties(self):
        expected_xml = "<properties><String><prop><name>name</name></prop><prop><name>property_no_values</name></prop></String></properties>"
        actual_xml = await self.toolkit.retrieve_entity_properties("person")
        assert expected_xml == actual_xml

    async def test_person_property_values_exists(self):
        entity_names = await self.toolkit._get_entity_names()
        self.assertEqual(entity_names, ["person", "session", "organization", "project"])

        property_vals = await self.toolkit.retrieve_entity_property_values({"person": ["name"]})
        self.assertIn("person", property_vals)

        person_values = property_vals.get("person", [])
        self.assertIn("name", "\n".join(person_values))
        self.assertTrue(any("Test User" in str(entry) for entry in person_values))

    async def test_person_property_values_do_not_exist(self):
        entity_names = await self.toolkit._get_entity_names()
        self.assertEqual(entity_names, ["person", "session", "organization", "project"])

        property_vals = await self.toolkit.retrieve_entity_property_values({"person": ["property_no_values"]})
        self.assertIn("person", property_vals)

        person_values = property_vals.get("person", [])
        self.assertIn("property_no_values", "\n".join(person_values))
        has_empty_message = any(
            "The property does not have any values in the taxonomy." in str(entry) for entry in person_values
        )
        self.assertTrue(has_empty_message)

    async def test_person_property_values_mixed(self):
        entity_names = await self.toolkit._get_entity_names()
        self.assertEqual(entity_names, ["person", "session", "organization", "project"])

        property_vals = await self.toolkit.retrieve_entity_property_values({"person": ["property_no_values", "name"]})
        self.assertIn("person", property_vals)

        person_values = property_vals.get("person", [])
        joined = "\n".join(person_values)

        # The property without values is reported with the "no values" message...
        self.assertIn("property_no_values", joined)
        has_empty_message = any(
            "The property does not have any values in the taxonomy." in str(entry) for entry in person_values
        )
        self.assertTrue(has_empty_message)

        # ...while the populated property returns the person's value.
        self.assertIn("name", joined)
        self.assertTrue(any("Test User" in str(entry) for entry in person_values))

    async def test_multiple_entities(self):
        entity_names = await self.toolkit._get_entity_names()
        self.assertEqual(entity_names, ["person", "session", "organization", "project"])

        request = {
            "person": ["property_no_values"],
            "session": ["$session_duration", "$channel_type", "nonexistent_property"],
        }
        property_vals = await self.toolkit.retrieve_entity_property_values(request)

        # Person entity
        self.assertIn("person", property_vals)
        person_values = property_vals.get("person", [])
        self.assertIn("property_no_values", "\n".join(person_values))
        self.assertTrue(
            any("The property does not have any values in the taxonomy." in str(entry) for entry in person_values)
        )

        # Session entity
        self.assertIn("session", property_vals)
        session_values = property_vals.get("session", [])
        session_joined = "\n".join(session_values)
        for property_name in ("$session_duration", "$channel_type", "nonexistent_property"):
            self.assertIn(property_name, session_joined)

        self.assertTrue(
            any(
                "values:\n- '30'\n- '146'\n- '2'\n- and many more distinct values\n" in str(entry)
                for entry in session_values
            )
        )
        self.assertTrue(any("Direct" in str(entry) for entry in session_values))
        self.assertTrue(
            any(
                "The property nonexistent_property does not exist in the taxonomy." in str(entry)
                for entry in session_values
            )
        )

    async def test_retrieve_entity_property_values_batching(self):
        """Requesting 8 entities must call `_handle_entity_batch` twice: once with 6 entities, once with 2."""
        entities = [f"entity_{i}" for i in range(8)]
        requested_properties = ["$session_duration", "$channel_type", "nonexistent_property"]
        entity_properties = {entity: requested_properties for entity in entities}

        per_entity_output = [
            "values:\n- '30'\n- '146'\n- '2'\n- and many more distinct values\n",
            "Direct",
            "The property nonexistent_property does not exist in the taxonomy.",
        ]

        # Replace the batch handler with a mock so the calls it receives can be inspected.
        with patch.object(self.toolkit, "_handle_entity_batch") as mock_handle_batch:
            # Every mocked call reports a result for all eight entities.
            mock_handle_batch.return_value = {entity: list(per_entity_output) for entity in entities}
            result = await self.toolkit.retrieve_entity_property_values(entity_properties)

        # A result must be present for every requested entity.
        self.assertEqual(len(result), 8)
        for entity in entities:
            self.assertIn(entity, result)
            self.assertEqual(result[entity], per_entity_output)

        # Two batches are expected: entities 0-5, then entities 6-7.
        self.assertEqual(mock_handle_batch.call_count, 2)
        first_call, second_call = mock_handle_batch.call_args_list[0], mock_handle_batch.call_args_list[1]
        first_batch = first_call.args[0]
        second_batch = second_call.args[0]
        self.assertEqual(len(first_batch), 6)
        self.assertEqual(len(second_batch), 2)

View File

@@ -0,0 +1,135 @@
from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest, _create_event, flush_persons_and_events
from posthog.models import Action
from posthog.models.property_definition import PropertyDefinition
from posthog.test.test_utils import create_group_type_mapping_without_created_at
from ee.hogai.graph.taxonomy.toolkit import TaxonomyAgentToolkit
class DummyToolkit(TaxonomyAgentToolkit):
    """Toolkit subclass used by these tests; `get_tools` returns the default tool set."""

    def get_tools(self):
        default_tools = self._get_default_tools()
        return default_tools
class TestEvents(ClickhouseTestMixin, NonAtomicBaseTest):
    """Tests for event and action property lookups on the taxonomy toolkit."""

    def setUp(self):
        super().setUp()

        for index, name in enumerate(["organization", "project"]):
            create_group_type_mapping_without_created_at(
                team=self.team, project_id=self.team.project_id, group_type_index=index, group_type=name
            )

        # (name, property_type) pairs for the event property definitions, created in this order.
        event_property_definitions = [
            ("$browser", "String"),
            ("id", "String"),
            ("no_values", "Boolean"),
        ]
        for property_name, property_type in event_property_definitions:
            PropertyDefinition.objects.create(
                team=self.team,
                name=property_name,
                property_type=property_type,
                is_numerical=False,
                type=PropertyDefinition.Type.EVENT,
            )

        # Create events that match the action conditions
        sample_events = [
            ("user123", {"$browser": "Chrome", "id": "123"}),
            ("user456", {"$browser": "Firefox", "id": "456"}),
        ]
        for distinct_id, properties in sample_events:
            _create_event(
                event="event1",
                distinct_id=distinct_id,
                team=self.team,
                properties=properties,
            )

        Action.objects.create(
            id=232, team=self.team, name="action1", description="Test Description", steps_json=[{"event": "event1"}]
        )

        self.toolkit = DummyToolkit(self.team)

    async def test_events_property_values_exists(self):
        entity_names = await self.toolkit._get_entity_names()
        assert entity_names == ["person", "session", "organization", "project"]

        property_vals = await self.toolkit.retrieve_event_or_action_property_values({"event1": ["$browser", "id"]})
        assert "event1" in property_vals

        joined = "\n".join(property_vals.get("event1", []))
        assert "$browser" in joined
        assert "id" in joined

    async def test_events_property_values_do_not_exist(self):
        entity_names = await self.toolkit._get_entity_names()
        assert entity_names == ["person", "session", "organization", "project"]

        property_vals = await self.toolkit.retrieve_event_or_action_property_values({"event1": ["no_values"]})
        assert "event1" in property_vals

        joined = "\n".join(property_vals.get("event1", []))
        assert "no_values" in joined
        assert "No values found for property no_values on entity event event1" in joined

    async def test_events_property_values_action_values_not_found(self):
        entity_names = await self.toolkit._get_entity_names()
        assert entity_names == ["person", "session", "organization", "project"]

        # Actions are addressed by their integer ID rather than by an event name.
        property_vals = await self.toolkit.retrieve_event_or_action_property_values({232: ["no_values"]})
        assert 232 in property_vals

        joined = "\n".join(property_vals.get(232, []))
        assert "no_values" in joined
        assert "No values found for property no_values on entity action with ID 232" in joined

    async def test_events_property_values_action_multiple_properties(self):
        entity_names = await self.toolkit._get_entity_names()
        assert entity_names == ["person", "session", "organization", "project"]

        property_vals = await self.toolkit.retrieve_event_or_action_property_values({232: ["no_values", "$browser"]})
        assert 232 in property_vals

        joined = "\n".join(property_vals.get(232, []))
        assert "no_values" in joined
        assert "$browser" in joined
        # Should have actual values
        assert "Chrome" in joined
        assert "Firefox" in joined

    async def test_retrieve_event_or_action_properties_action_not_found(self):
        expected_message = (
            "Action 999 does not exist in the taxonomy. Verify that the action ID is correct and try again."
        )
        actual_message = await self.toolkit.retrieve_event_or_action_properties(999)
        assert expected_message == actual_message

    async def test_retrieve_event_or_action_properties_event_not_found(self):
        actual_message = await self.toolkit.retrieve_event_or_action_properties("test")
        assert "Properties do not exist in the taxonomy for the event test." == actual_message

    def tearDown(self):
        # Flush buffered persons/events before the base class teardown runs.
        flush_persons_and_events()
        super().tearDown()

View File

@@ -0,0 +1,127 @@
from datetime import datetime
from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest
from posthog.models.group.util import raw_create_group_ch
from posthog.models.property_definition import PropertyDefinition
from posthog.test.test_utils import create_group_type_mapping_without_created_at
from ee.hogai.graph.taxonomy.toolkit import TaxonomyAgentToolkit
class DummyToolkit(TaxonomyAgentToolkit):
    """Toolkit subclass used by these tests; `get_tools` returns the default tool set."""

    def get_tools(self):
        default_tools = self._get_default_tools()
        return default_tools
class TestGroups(ClickhouseTestMixin, NonAtomicBaseTest):
    """Tests for group entity property lookups on the taxonomy toolkit."""

    def setUp(self):
        super().setUp()

        # Index 0 = organization, 1 = project, 2 = a group type with no property definitions.
        for index, name in enumerate(["organization", "project", "no_properties"]):
            create_group_type_mapping_without_created_at(
                team=self.team, project_id=self.team.project_id, group_type_index=index, group_type=name
            )

        # (name, property_type, is_numerical, group_type_index), created in this order.
        # The expected XML for "organization" lists name, industry, name_group in that sequence.
        group_property_definitions = [
            ("name", "String", False, 0),  # organization
            ("industry", "String", False, 0),  # organization
            ("size", "Numeric", True, 1),  # project
            ("name_group", "String", False, 0),
        ]
        for property_name, property_type, is_numerical, group_type_index in group_property_definitions:
            PropertyDefinition.objects.create(
                team=self.team,
                name=property_name,
                property_type=property_type,
                is_numerical=is_numerical,
                type=PropertyDefinition.Type.GROUP,
                group_type_index=group_type_index,
            )

        # (group_type_index, group_key, properties) for the groups written to ClickHouse.
        groups = [
            (0, "acme-corp", {"name": "Acme Corp", "industry": "tech"}),
            (1, "acme-project", {"name": "Acme Project", "size": 100}),
        ]
        for group_type_index, group_key, properties in groups:
            raw_create_group_ch(
                team_id=self.team.id,
                group_type_index=group_type_index,
                group_key=group_key,
                properties=properties,
                created_at=datetime.now(),
                sync=True,  # Force sync to ClickHouse
            )

        self.toolkit = DummyToolkit(self.team)

    async def test_entity_names_with_existing_groups(self):
        # The entity names must include the group types created in setUp.
        entity_names = await self.toolkit._get_entity_names()
        self.assertEqual(entity_names, ["person", "session", "organization", "project", "no_properties"])

        property_vals = await self.toolkit.retrieve_entity_property_values(
            {"organization": ["name", "industry"], "project": ["size"]}
        )

        # The values written for the groups in setUp should come back.
        self.assertIn("organization", property_vals)
        self.assertIn("project", property_vals)

        organization_values = property_vals.get("organization", [])
        project_values = property_vals.get("project", [])
        self.assertTrue(any("Acme Corp" in str(entry) for entry in organization_values))
        self.assertTrue(any("tech" in str(entry) for entry in organization_values))
        self.assertTrue(any("100" in str(entry) for entry in project_values))

    async def test_retrieve_entity_property_values_wrong_group(self):
        request = {"test": ["name", "industry"], "project": ["size"]}
        property_vals = await self.toolkit.retrieve_entity_property_values(request)

        # The unknown entity still gets an entry, carrying the "not found" message.
        self.assertIn("test", property_vals)
        not_found_message = (
            "Entity test not found. Available entities: person, session, organization, project, no_properties"
        )
        self.assertIn(not_found_message, property_vals["test"])
        self.assertIn("project", property_vals)

    async def test_retrieve_entity_properties_group(self):
        expected_xml = "<properties><String><prop><name>name</name></prop><prop><name>industry</name></prop><prop><name>name_group</name></prop></String></properties>"
        actual_xml = await self.toolkit.retrieve_entity_properties("organization")
        assert expected_xml == actual_xml

    async def test_retrieve_entity_properties_group_not_found(self):
        expected_message = (
            "Entity test not found. Available entities: person, session, organization, project, no_properties"
        )
        actual_message = await self.toolkit.retrieve_entity_properties("test")
        assert expected_message == actual_message

    async def test_retrieve_entity_properties_group_nothing_found(self):
        actual_message = await self.toolkit.retrieve_entity_properties("no_properties")
        assert "Properties do not exist in the taxonomy for the entity no_properties." == actual_message

View File

@@ -159,7 +159,7 @@ class TestTaxonomyAgentToolsNode(BaseTest):
@patch.object(MockTaxonomyAgentToolkit, "get_tool_input_model")
@patch.object(MockTaxonomyAgentToolkit, "handle_tools")
def test_run_normal_tool_execution(self, mock_handle_tools, mock_get_tool_input):
async def test_run_normal_tool_execution(self, mock_handle_tools, mock_get_tool_input):
# Setup mocks
mock_input = Mock()
mock_input.name = "test_tool"
@@ -172,14 +172,14 @@ class TestTaxonomyAgentToolsNode(BaseTest):
state = TaxonomyAgentState()
state.intermediate_steps = [(action, None)]
result = self.node.run(state, RunnableConfig())
result = await self.node.arun(state, RunnableConfig())
self.assertIsInstance(result, TaxonomyAgentState)
self.assertEqual(len(result.intermediate_steps), 1)
self.assertEqual(result.intermediate_steps[0][1], "tool output")
@patch.object(MockTaxonomyAgentToolkit, "get_tool_input_model")
def test_run_validation_error(self, mock_get_tool_input):
async def test_run_validation_error(self, mock_get_tool_input):
# Setup validation error
validation_error = ValidationError.from_exception_data(
"TestModel", [{"type": "missing", "loc": ("field",), "msg": "Field required"}]
@@ -190,13 +190,13 @@ class TestTaxonomyAgentToolsNode(BaseTest):
state = TaxonomyAgentState()
state.intermediate_steps = [(action, None)]
result = self.node.run(state, RunnableConfig())
result = await self.node.arun(state, RunnableConfig())
self.assertIsInstance(result, TaxonomyAgentState)
self.assertEqual(len(result.tool_progress_messages), 1)
@patch.object(MockTaxonomyAgentToolkit, "get_tool_input_model")
def test_run_final_answer(self, mock_get_tool_input):
async def test_run_final_answer(self, mock_get_tool_input):
# Mock final answer tool
from pydantic import BaseModel
@@ -216,14 +216,14 @@ class TestTaxonomyAgentToolsNode(BaseTest):
state = TaxonomyAgentState()
state.intermediate_steps = [(action, None)]
result = self.node.run(state, RunnableConfig())
result = await self.node.arun(state, RunnableConfig())
self.assertIsInstance(result, TaxonomyAgentState)
self.assertEqual(result.output, expected_data)
self.assertIsNone(result.intermediate_steps)
@patch.object(MockTaxonomyAgentToolkit, "get_tool_input_model")
def test_run_ask_user_for_help(self, mock_get_tool_input):
async def test_run_ask_user_for_help(self, mock_get_tool_input):
# Mock ask for help tool
mock_input = Mock()
mock_input.name = "ask_user_for_help"
@@ -238,11 +238,11 @@ class TestTaxonomyAgentToolsNode(BaseTest):
with patch.object(self.node, "_get_reset_state") as mock_reset:
mock_reset.return_value = TaxonomyAgentState()
_ = self.node.run(state, RunnableConfig())
_ = await self.node.arun(state, RunnableConfig())
mock_reset.assert_called_once_with("Need help", "ask_user_for_help", state)
def test_run_max_iterations(self):
async def test_run_max_iterations(self):
# Create state with max iterations
actions = []
for i in range(self.node.MAX_ITERATIONS):
@@ -255,7 +255,7 @@ class TestTaxonomyAgentToolsNode(BaseTest):
with patch.object(self.node, "_get_reset_state") as mock_reset:
mock_reset.return_value = TaxonomyAgentState()
_ = self.node.run(state, RunnableConfig())
_ = await self.node.arun(state, RunnableConfig())
mock_reset.assert_called_once()
call_args = mock_reset.call_args
@@ -289,7 +289,7 @@ class TestTaxonomyAgentToolsNode(BaseTest):
result = self.node.router(state)
self.assertEqual(result, expected)
def test_get_reset_state(self):
async def test_get_reset_state(self):
original_state = TaxonomyAgentState()
original_state.change = "test change"

View File

@@ -1,23 +1,10 @@
from datetime import datetime
from posthog.test.base import BaseTest, ClickhouseTestMixin
from unittest.mock import Mock, patch
from posthog.test.base import BaseTest
from unittest.mock import patch
from langchain_core.agents import AgentAction
from parameterized import parameterized
from pydantic import BaseModel
from posthog.schema import (
ActorsPropertyTaxonomyResponse,
CachedActorsPropertyTaxonomyQueryResponse,
CachedEventTaxonomyQueryResponse,
EventTaxonomyItem,
)
from posthog.models import Action
from posthog.models.property_definition import PropertyDefinition, PropertyType
from posthog.test.test_utils import create_group_type_mapping_without_created_at
from ee.hogai.graph.taxonomy.toolkit import TaxonomyAgentToolkit, TaxonomyToolNotFoundError
@@ -26,16 +13,14 @@ class DummyToolkit(TaxonomyAgentToolkit):
return self._get_default_tools()
class TestTaxonomyAgentToolkit(ClickhouseTestMixin, BaseTest):
class TestTaxonomyAgentToolkit(BaseTest):
def setUp(self):
super().setUp()
self.toolkit = DummyToolkit(self.team)
self.action = Action.objects.create(team=self.team, name="test_action", steps_json=[{"event": "test_event"}])
def test_toolkit_initialization(self):
async def test_toolkit_initialization(self):
self.assertEqual(self.toolkit._team, self.team)
self.assertIsInstance(self.toolkit._team_group_types, list)
self.assertIsInstance(self.toolkit._entity_names, list)
self.assertIsInstance(await self.toolkit._get_entity_names(), list)
@parameterized.expand(
[
@@ -43,35 +28,10 @@ class TestTaxonomyAgentToolkit(ClickhouseTestMixin, BaseTest):
("session", ["person", "session"]),
]
)
def test_entity_names_basic(self, entity, expected_base):
self.assertIn(entity, self.toolkit._entity_names)
async def test_entity_names_basic(self, entity, expected_base):
self.assertIn(entity, await self.toolkit._get_entity_names())
for expected in expected_base:
self.assertIn(expected, self.toolkit._entity_names)
def test_entity_names_with_groups(self):
# Create group type mappings
for i, group_type in enumerate(["organization", "project"]):
create_group_type_mapping_without_created_at(
team=self.team, project_id=self.team.project_id, group_type_index=i, group_type=group_type
)
toolkit = DummyToolkit(self.team)
expected = ["person", "session", "organization", "project"]
self.assertEqual(toolkit._entity_names, expected)
@parameterized.expand(
[
("$session_duration", True, "30, 146, 2"),
("$channel_type", True, "Direct"),
("nonexistent_property", False, "does not exist"),
]
)
def test_retrieve_session_properties(self, property_name, should_contain_values, expected_content):
result = self.toolkit._retrieve_session_properties(property_name)
if should_contain_values:
self.assertIn(expected_content, result)
else:
self.assertIn(expected_content, result)
self.assertIn(expected, await self.toolkit._get_entity_names())
def test_enrich_props_with_descriptions(self):
props = [("$browser", "String"), ("custom_prop", "Numeric")]
@@ -85,130 +45,16 @@ class TestTaxonomyAgentToolkit(ClickhouseTestMixin, BaseTest):
@parameterized.expand(
[
([], 0, False, "The property does not have any values"),
(["value1", "value2"], None, False, "value1, value2 and many more"),
(["value1", "value2"], 5, False, "value1, value2 and 3 more"),
(["value1", "value2"], None, False, "- value1\n- value2\n- and many more distinct values"),
(["value1", "value2"], 5, False, "- value1\n- value2\n- and 3 more distinct values"),
(["string_val"], 1, True, '"string_val"'),
([1.0, 2.0], 2, False, "1, 2"),
([1.0, 2.0], 2, False, "'1'\n- '2'"),
]
)
def test_format_property_values(self, sample_values, sample_count, format_as_string, expected_substring):
result = self.toolkit._format_property_values(sample_values, sample_count, format_as_string)
result = self.toolkit._format_property_values("test_property", sample_values, sample_count, format_as_string)
self.assertIn(expected_substring, result)
def _create_property_definition(self, prop_type, name="test_prop", group_type_index=None):
"""Helper to create property definitions"""
kwargs = {"team": self.team, "name": name, "property_type": PropertyType.String}
if prop_type == PropertyDefinition.Type.GROUP:
kwargs["type"] = PropertyDefinition.Type.GROUP
kwargs["group_type_index"] = group_type_index
else:
kwargs["type"] = prop_type
return PropertyDefinition.objects.create(**kwargs)
def _create_mock_taxonomy_response(self, response_type="event", **kwargs):
"""Helper to create mock taxonomy responses"""
if response_type == "event":
return CachedEventTaxonomyQueryResponse(
cache_key="test",
is_cached=False,
last_refresh=datetime.now().isoformat(),
next_allowed_client_refresh=datetime.now().isoformat(),
timezone="UTC",
results=[EventTaxonomyItem(**kwargs)],
)
elif response_type == "actors":
return CachedActorsPropertyTaxonomyQueryResponse(
cache_key="test",
is_cached=False,
last_refresh=datetime.now().isoformat(),
next_allowed_client_refresh=datetime.now().isoformat(),
timezone="UTC",
results=ActorsPropertyTaxonomyResponse(**kwargs),
)
def test_retrieve_entity_properties_person(self):
self._create_property_definition(PropertyDefinition.Type.PERSON, "email")
result = self.toolkit.retrieve_entity_properties("person")
self.assertIn("email", result)
self.assertIn("String", result)
def test_retrieve_entity_properties_session(self):
result = self.toolkit.retrieve_entity_properties("session")
self.assertIn("$session_duration", result)
self.assertIn("properties", result)
def test_retrieve_entity_properties_group(self):
create_group_type_mapping_without_created_at(
team=self.team, project_id=self.team.project_id, group_type_index=0, group_type="organization"
)
self._create_property_definition(PropertyDefinition.Type.GROUP, "org_name", group_type_index=0)
result = self.toolkit.retrieve_entity_properties("organization")
self.assertIn("org_name", result)
@parameterized.expand(
[
("invalid_entity", "Entity invalid_entity not found"),
("person", "Properties do not exist in the taxonomy for the entity person."),
]
)
def test_retrieve_entity_properties_edge_cases(self, entity, expected_content):
result = self.toolkit.retrieve_entity_properties(entity)
self.assertIn(expected_content, result)
@patch("ee.hogai.graph.taxonomy.toolkit.ActorsPropertyTaxonomyQueryRunner")
def test_retrieve_entity_property_values_person(self, mock_runner_class):
self._create_property_definition(PropertyDefinition.Type.PERSON, "email")
mock_response = self._create_mock_taxonomy_response(
response_type="actors", sample_values=["test@example.com", "user@test.com"], sample_count=2
)
mock_runner = Mock()
mock_runner.run.return_value = mock_response
mock_runner_class.return_value = mock_runner
result = self.toolkit.retrieve_entity_property_values("person", "email")
self.assertIn("test@example.com", result)
def test_retrieve_entity_property_values_invalid_entity(self):
result = self.toolkit.retrieve_entity_property_values("invalid", "prop")
self.assertIn("Entity invalid not found", result)
@patch("ee.hogai.graph.taxonomy.toolkit.EventTaxonomyQueryRunner")
def test_retrieve_event_or_action_properties(self, mock_runner_class):
self._create_property_definition(PropertyDefinition.Type.EVENT, "$browser")
mock_response = self._create_mock_taxonomy_response(property="$browser", sample_values=[], sample_count=0)
mock_runner = Mock()
mock_runner.run.return_value = mock_response
mock_runner_class.return_value = mock_runner
result = self.toolkit.retrieve_event_or_action_properties("test_event")
self.assertIn("$browser", result)
def test_retrieve_event_or_action_properties_action_not_found(self):
Action.objects.all().delete()
result = self.toolkit.retrieve_event_or_action_properties(999)
self.assertEqual(result, "No actions exist in the project.")
@patch("ee.hogai.graph.taxonomy.toolkit.EventTaxonomyQueryRunner")
def test_retrieve_event_or_action_property_values(self, mock_runner_class):
self._create_property_definition(PropertyDefinition.Type.EVENT, "$browser")
mock_response = self._create_mock_taxonomy_response(
property="$browser", sample_values=["Chrome", "Firefox"], sample_count=2
)
mock_runner = Mock()
mock_runner.run.return_value = mock_response
mock_runner_class.return_value = mock_runner
result = self.toolkit.retrieve_event_or_action_property_values("test_event", "$browser")
self.assertIn("Chrome", result)
self.assertIn("Firefox", result)
def test_handle_incorrect_response(self):
class TestModel(BaseModel):
field: str = "test"
@@ -227,10 +73,10 @@ class TestTaxonomyAgentToolkit(ClickhouseTestMixin, BaseTest):
]
)
@patch.object(DummyToolkit, "retrieve_entity_properties", return_value="mocked")
@patch.object(DummyToolkit, "retrieve_entity_property_values", return_value="mocked")
@patch.object(DummyToolkit, "retrieve_entity_property_values", return_value={"person": ["mocked"]})
@patch.object(DummyToolkit, "retrieve_event_or_action_properties", return_value="mocked")
@patch.object(DummyToolkit, "retrieve_event_or_action_property_values", return_value="mocked")
def test_handle_tools(self, tool_name, tool_args, expected_result, *mocks):
@patch.object(DummyToolkit, "retrieve_event_or_action_property_values", return_value={"test_event": ["mocked"]})
async def test_handle_tools(self, tool_name, tool_args, expected_result, *mocks):
class Arguments(BaseModel):
pass
@@ -242,12 +88,12 @@ class TestTaxonomyAgentToolkit(ClickhouseTestMixin, BaseTest):
arguments: Arguments
tool_input = ToolInput(name=tool_name, arguments=Arguments(**tool_args))
tool_name_result, result = self.toolkit.handle_tools(tool_name, tool_input)
tool_name_result, result = await self.toolkit.handle_tools(tool_name, tool_input)
self.assertEqual(result, expected_result)
self.assertEqual(tool_name_result, tool_name)
def test_handle_tools_invalid_tool(self):
async def test_handle_tools_invalid_tool(self):
class ToolInput(BaseModel):
name: str = "invalid_tool"
arguments: dict = {}
@@ -255,7 +101,7 @@ class TestTaxonomyAgentToolkit(ClickhouseTestMixin, BaseTest):
tool_input = ToolInput()
with self.assertRaises(TaxonomyToolNotFoundError):
self.toolkit.handle_tools("invalid_tool", tool_input)
await self.toolkit.handle_tools("invalid_tool", tool_input)
def test_format_properties_formats(self):
props = [("prop1", "String", "Test description"), ("prop2", "Numeric", None)]

View File

@@ -1,3 +1,4 @@
import asyncio
from collections.abc import Iterable
from functools import cached_property
from typing import Any, Optional, Union, cast
@@ -9,7 +10,10 @@ from posthog.schema import (
ActorsPropertyTaxonomyQuery,
CachedActorsPropertyTaxonomyQueryResponse,
CachedEventTaxonomyQueryResponse,
CacheMissResponse,
EventTaxonomyItem,
EventTaxonomyQuery,
QueryStatusResponse,
)
from posthog.hogql.database.schema.channel_type import DEFAULT_CHANNEL_TYPES
@@ -21,6 +25,7 @@ from posthog.hogql_queries.query_runner import ExecutionMode
from posthog.models import Action, Team
from posthog.models.group_type_mapping import GroupTypeMapping
from posthog.models.property_definition import PropertyDefinition, PropertyType
from posthog.sync import database_sync_to_async
from posthog.taxonomy.taxonomy import CORE_FILTER_DEFINITIONS_BY_GROUP
from ee.hogai.graph.taxonomy.format import (
@@ -104,6 +109,7 @@ class TaxonomyAgentToolkit:
def __init__(self, team: Team):
    """Bind the toolkit to a team and configure the entity batch size.

    Args:
        team: Team stored on the instance as ``_team``.
    """
    # Upper bound on how many entities are looked up together in one batch.
    self.MAX_ENTITIES_PER_BATCH = 6
    self._team = team
@property
def _groups(self):
@@ -125,17 +131,21 @@ class TaxonomyAgentToolkit:
entities = [
"person",
"session",
*[group.group_type for group in self._groups],
*self._team_group_types,
]
return entities
@database_sync_to_async(thread_sensitive=False)
def _get_entity_names(self) -> list[str]:
    """Return ``self._entity_names`` through the ``database_sync_to_async`` wrapper.

    Async methods of this class await this wrapper instead of reading the
    property directly.
    """
    entity_names = self._entity_names
    return entity_names
def _enrich_props_with_descriptions(self, entity: str, props: Iterable[tuple[str, str | None]]):
    """Delegate to the module-level ``enrich_props_with_descriptions`` helper.

    Args:
        entity: Entity the properties belong to (for example ``"person"``).
        props: Pairs of property name and property type (the type may be ``None``).

    Returns:
        Whatever ``enrich_props_with_descriptions`` returns for the same arguments.
    """
    enriched = enrich_props_with_descriptions(entity, props)
    return enriched
def _format_property_values(
self, sample_values: list, sample_count: Optional[int] = 0, format_as_string: bool = False
self, property_name: str, sample_values: list, sample_count: Optional[int] = 0, format_as_string: bool = False
) -> str:
return format_property_values(sample_values, sample_count, format_as_string)
return format_property_values(property_name, sample_values, sample_count, format_as_string)
def _retrieve_session_properties(self, property_name: str) -> str:
"""
@@ -161,18 +171,26 @@ class TaxonomyAgentToolkit:
else:
return TaxonomyErrorMessages.property_values_not_found(property_name, "session")
return self._format_property_values(sample_values, sample_count, format_as_string=is_str)
return self._format_property_values(property_name, sample_values, sample_count, format_as_string=is_str)
def _retrieve_event_or_action_taxonomy(self, event_name_or_action_id: str | int):
@database_sync_to_async(thread_sensitive=False)
def _retrieve_event_or_action_taxonomy(
self, event_name_or_action_id: str | int, properties: list[str] | None = None
):
"""
Retrieve event/action taxonomy with efficient caching.
Multiple properties are batched in a single query to maximize cache hits.
"""
is_event = isinstance(event_name_or_action_id, str)
if is_event:
query = EventTaxonomyQuery(event=event_name_or_action_id, maxPropertyValues=25)
query = EventTaxonomyQuery(event=event_name_or_action_id, maxPropertyValues=25, properties=properties)
verbose_name = f"event {event_name_or_action_id}"
else:
query = EventTaxonomyQuery(actionId=event_name_or_action_id, maxPropertyValues=25)
query = EventTaxonomyQuery(actionId=event_name_or_action_id, maxPropertyValues=25, properties=properties)
verbose_name = f"action with ID {event_name_or_action_id}"
runner = EventTaxonomyQueryRunner(query, self._team)
with tags_context(product=Product.MAX_AI, team_id=self._team.pk, org_id=self._team.organization_id):
# Use cache-first execution mode for optimal performance
response = runner.run(ExecutionMode.RECENT_CACHE_CALCULATE_ASYNC_IF_STALE_AND_BLOCKING_ON_MISS)
return response, verbose_name
@@ -220,20 +238,20 @@ class TaxonomyAgentToolkit:
"""Get custom tools. Override in subclasses to add custom tools."""
raise NotImplementedError("_get_custom_tools must be implemented in subclasses")
def retrieve_entity_properties(self, entity: str, max_properties: int = 500) -> str:
async def retrieve_entity_properties(self, entity: str, max_properties: int = 500) -> str:
"""
Retrieve properties for an entity like person, session, or one of the groups.
"""
if entity not in ("person", "session", *[group.group_type for group in self._groups]):
return TaxonomyErrorMessages.entity_not_found(entity, self._entity_names)
entity_names = await self._get_entity_names()
if entity not in entity_names:
return TaxonomyErrorMessages.entity_not_found(entity, entity_names)
props: list[Any] = []
if entity == "person":
qs = PropertyDefinition.objects.filter(team=self._team, type=PropertyDefinition.Type.PERSON).values_list(
"name", "property_type"
)
props = self._enrich_props_with_descriptions("person", qs)
props = self._enrich_props_with_descriptions("person", [prop async for prop in qs])
elif entity == "session":
# Session properties are not in the DB.
props = self._enrich_props_with_descriptions(
@@ -245,91 +263,153 @@ class TaxonomyAgentToolkit:
],
)
else:
group_type_index = next(
(group.group_type_index for group in self._groups if group.group_type == entity), None
)
group_type_index = None
groups = [group async for group in self._groups]
for group in groups:
if group.group_type == entity:
group_type_index = group.group_type_index
break
if group_type_index is None:
return f"Group {entity} does not exist in the taxonomy."
qs = PropertyDefinition.objects.filter(
team=self._team, type=PropertyDefinition.Type.GROUP, group_type_index=group_type_index
).values_list("name", "property_type")[:max_properties]
props = self._enrich_props_with_descriptions(entity, qs)
props = self._enrich_props_with_descriptions(entity, [prop async for prop in qs])
if not props:
return f"Properties do not exist in the taxonomy for the entity {entity}."
return self._format_properties(props)
def retrieve_entity_property_values(self, entity: str, property_name: str) -> str:
"""Retrieve property values for an entity."""
if entity not in self._entity_names:
return TaxonomyErrorMessages.entity_not_found(entity, self._entity_names)
async def retrieve_entity_property_values(self, entity_properties: dict[str, list[str]]) -> dict[str, list[str]]:
    """Retrieve property values for several entities at once.

    Args:
        entity_properties: Maps an entity name to the property names to look up for it.

    Returns:
        The mapping produced by ``_parallel_entity_processing`` for the same input.
    """
    return await self._parallel_entity_processing(entity_properties)
if entity == "session":
return self._retrieve_session_properties(property_name)
async def _handle_entity_batch(self, batch: dict[str, list[str]]) -> dict[str, list[str]]:
entity_tasks = [
self._retrieve_multiple_entity_property_values(entity, property_names)
for entity, property_names in batch.items()
]
batch_results = await asyncio.gather(*entity_tasks)
return dict(zip(batch.keys(), batch_results))
if entity == "person":
query = ActorsPropertyTaxonomyQuery(properties=[property_name], maxPropertyValues=25)
elif entity == "event":
query = ActorsPropertyTaxonomyQuery(properties=[property_name], maxPropertyValues=50)
async def _parallel_entity_processing(self, entity_properties: dict[str, list[str]]) -> dict[str, list[str]]:
entity_items = list(entity_properties.items())
if len(entity_items) > self.MAX_ENTITIES_PER_BATCH:
# Process in batches
results = {}
for i in range(0, len(entity_items), self.MAX_ENTITIES_PER_BATCH):
batch = dict(entity_items[i : i + self.MAX_ENTITIES_PER_BATCH])
batch_results = await self._handle_entity_batch(batch)
results.update(batch_results)
return results
else:
group_index = next((group.group_type_index for group in self._groups if group.group_type == entity), None)
return await self._handle_entity_batch(entity_properties)
async def _retrieve_multiple_entity_property_values(self, entity: str, property_names: list[str]) -> list[str]:
"""Retrieve property values for multiple entities and properties efficiently."""
results = []
entity_names = await self._get_entity_names()
if entity not in entity_names:
results.append(TaxonomyErrorMessages.entity_not_found(entity, entity_names))
return results
if entity == "session":
for property_name in property_names:
results.append(self._retrieve_session_properties(property_name))
return results
groups = [group async for group in self._groups]
query = self._build_query(entity, property_names, groups)
if query is None:
results.append(TaxonomyErrorMessages.entity_not_found(entity))
return results
property_values_response = await self._run_actors_taxonomy_query(query)
if not isinstance(property_values_response, CachedActorsPropertyTaxonomyQueryResponse):
results.append(TaxonomyErrorMessages.entity_not_found(entity))
return results
if not property_values_response.results:
for property_name in property_names:
results.append(TaxonomyErrorMessages.property_values_not_found(property_name, entity))
return results
if isinstance(property_values_response.results, list):
property_values_results = property_values_response.results
else:
property_values_results = [property_values_response.results]
property_definitions: dict[str, PropertyDefinition] = await self._get_definitions_for_entity(
entity, property_names, query
)
results.extend(
self._process_property_values(
property_names, property_values_results, property_definitions, entity, is_indexed=True
)
)
return results
@database_sync_to_async(thread_sensitive=False)
def _get_definitions_for_entity(
    self, entity: str, property_names: list[str], query: ActorsPropertyTaxonomyQuery
) -> dict[str, PropertyDefinition]:
    """Map requested property names to their ``PropertyDefinition`` rows for one entity.

    Args:
        entity: Entity name; ``"event"`` selects event definitions, anything else
            without a group index selects person definitions.
        property_names: Names to look up. An empty list short-circuits to ``{}``.
        query: Taxonomy query built for the entity; a non-``None`` ``groupTypeIndex``
            selects group definitions for that index.

    Returns:
        A dict keyed by property name containing only the definitions that exist.
    """
    if not property_names:
        return {}

    # The index is None for person/event lookups, which is also the filter value used for them.
    group_type_index = query.groupTypeIndex
    if group_type_index is not None:
        definition_type = PropertyDefinition.Type.GROUP
    elif entity == "event":
        definition_type = PropertyDefinition.Type.EVENT
    else:
        definition_type = PropertyDefinition.Type.PERSON

    matching_definitions = PropertyDefinition.objects.filter(
        team=self._team,
        name__in=property_names,
        type=definition_type,
        group_type_index=group_type_index,
    )
    return {definition.name: definition for definition in matching_definitions}
def _build_query(
self, entity: str, properties: list[str], groups: list[GroupTypeMapping]
) -> ActorsPropertyTaxonomyQuery | None:
"""Build a query for the given entity and property names."""
if entity == "person":
query = ActorsPropertyTaxonomyQuery(properties=properties, maxPropertyValues=25)
elif entity == "event":
query = ActorsPropertyTaxonomyQuery(properties=properties, maxPropertyValues=50)
else:
group_index = next((group.group_type_index for group in groups if group.group_type == entity), None)
if group_index is None:
return TaxonomyErrorMessages.entity_not_found(entity)
query = ActorsPropertyTaxonomyQuery(
groupTypeIndex=group_index, properties=[property_name], maxPropertyValues=25
)
try:
if query.groupTypeIndex is not None:
prop_type = PropertyDefinition.Type.GROUP
group_type_index = query.groupTypeIndex
elif entity == "event":
prop_type = PropertyDefinition.Type.EVENT
group_type_index = None
else:
prop_type = PropertyDefinition.Type.PERSON
group_type_index = None
property_definition = PropertyDefinition.objects.get(
team=self._team,
name=property_name,
type=prop_type,
group_type_index=group_type_index,
)
except PropertyDefinition.DoesNotExist:
return TaxonomyErrorMessages.property_not_found(property_name, entity)
return None
query = ActorsPropertyTaxonomyQuery(groupTypeIndex=group_index, properties=properties, maxPropertyValues=25)
return query
@database_sync_to_async(thread_sensitive=False)
def _run_actors_taxonomy_query(
self, query
) -> CachedActorsPropertyTaxonomyQueryResponse | CacheMissResponse | QueryStatusResponse:
with tags_context(product=Product.MAX_AI, team_id=self._team.pk, org_id=self._team.organization_id):
response = ActorsPropertyTaxonomyQueryRunner(query, self._team).run(
return ActorsPropertyTaxonomyQueryRunner(query, self._team).run(
ExecutionMode.RECENT_CACHE_CALCULATE_ASYNC_IF_STALE_AND_BLOCKING_ON_MISS
)
if not isinstance(response, CachedActorsPropertyTaxonomyQueryResponse):
return TaxonomyErrorMessages.entity_not_found(entity)
@database_sync_to_async(thread_sensitive=False)
def _get_project_actions(self) -> list[Action]:
    """Return the non-deleted actions of the team's project as a list.

    The queryset is materialised with ``list`` inside the wrapped function, so the
    awaiting caller receives a plain list.
    """
    active_actions = Action.objects.filter(team__project_id=self._team.project_id, deleted=False)
    return list(active_actions)
if not response.results:
return TaxonomyErrorMessages.property_values_not_found(property_name, entity)
# TRICKY. Remove when the toolkit supports multiple results.
if isinstance(response.results, list):
unpacked_results = response.results[0]
else:
unpacked_results = response.results
return self._format_property_values(
unpacked_results.sample_values,
unpacked_results.sample_count,
format_as_string=property_definition.property_type in (PropertyType.String, PropertyType.Datetime),
)
def retrieve_event_or_action_properties(self, event_name_or_action_id: str | int) -> str:
async def retrieve_event_or_action_properties(self, event_name_or_action_id: str | int) -> str:
"""
Retrieve properties for an event.
"""
try:
response, verbose_name = self._retrieve_event_or_action_taxonomy(event_name_or_action_id)
response, verbose_name = await self._retrieve_event_or_action_taxonomy(event_name_or_action_id)
except Action.DoesNotExist:
project_actions = Action.objects.filter(team__project_id=self._team.project_id, deleted=False)
project_actions = await self._get_project_actions()
if not project_actions:
return TaxonomyErrorMessages.no_actions_exist()
return TaxonomyErrorMessages.action_not_found(event_name_or_action_id)
@@ -337,11 +417,12 @@ class TaxonomyAgentToolkit:
return TaxonomyErrorMessages.generic_not_found("Properties")
if not response.results:
return TaxonomyErrorMessages.event_properties_not_found(verbose_name)
# Intersect properties with their types.
qs = PropertyDefinition.objects.filter(
team=self._team, type=PropertyDefinition.Type.EVENT, name__in=[item.property for item in response.results]
)
property_to_type = {property_definition.name: property_definition.property_type for property_definition in qs}
property_data = [prop async for prop in qs]
property_to_type = {prop.name: prop.property_type for prop in property_data}
props = [
(item.property, property_to_type.get(item.property))
for item in response.results
@@ -351,48 +432,124 @@ class TaxonomyAgentToolkit:
if not props:
return TaxonomyErrorMessages.event_properties_not_found(verbose_name)
return self._format_properties(self._enrich_props_with_descriptions("event", props))
enriched_props = self._enrich_props_with_descriptions("event", props)
return self._format_properties(enriched_props)
def retrieve_event_or_action_property_values(self, event_name_or_action_id: str | int, property_name: str) -> str:
async def retrieve_event_or_action_property_values(
self, event_properties: dict[str | int, list[str]]
) -> dict[str | int, list[str]]:
"""Retrieve property values for an event/action. Supports single property or list of properties."""
result = await self._parallel_event_processing(event_properties)
return result
async def _parallel_event_processing(
self, event_properties: dict[str | int, list[str]]
) -> dict[str | int, list[str]]:
event_tasks = [
self._retrieve_multiple_event_or_action_property_values(event_name_or_action_id, property_names)
for event_name_or_action_id, property_names in event_properties.items()
]
results = await asyncio.gather(*event_tasks)
return dict(zip(event_properties.keys(), results))
@database_sync_to_async(thread_sensitive=False)
def _get_definitions_for_event_or_action(self, property_names: list[str]) -> dict[str, PropertyDefinition]:
    """Map requested property names to the team's event ``PropertyDefinition`` rows.

    Args:
        property_names: Event property names to look up.

    Returns:
        A dict keyed by property name containing only the definitions that exist.
    """
    event_definitions = PropertyDefinition.objects.filter(
        team=self._team,
        name__in=property_names,
        type=PropertyDefinition.Type.EVENT,
    )
    return {definition.name: definition for definition in event_definitions}
async def _retrieve_multiple_event_or_action_property_values(
self, event_name_or_action_id: str | int, property_names: list[str]
) -> list[str]:
"""Retrieve property values for multiple events/actions and properties efficiently."""
results = []
try:
property_definition = PropertyDefinition.objects.get(
team=self._team, name=property_name, type=PropertyDefinition.Type.EVENT
definitions_map: dict[str, PropertyDefinition] = await self._get_definitions_for_event_or_action(
property_names
)
except PropertyDefinition.DoesNotExist:
return TaxonomyErrorMessages.property_not_found(property_name)
definitions_map = {}
response, verbose_name = await self._retrieve_event_or_action_taxonomy(event_name_or_action_id, property_names)
response, verbose_name = self._retrieve_event_or_action_taxonomy(event_name_or_action_id)
if not isinstance(response, CachedEventTaxonomyQueryResponse):
return TaxonomyErrorMessages.event_not_found(verbose_name)
results.append(TaxonomyErrorMessages.event_not_found(verbose_name))
return results
if not response.results:
return TaxonomyErrorMessages.property_values_not_found(property_name, verbose_name)
for property_name in property_names:
results.append(TaxonomyErrorMessages.property_values_not_found(property_name, verbose_name))
return results
prop = next((item for item in response.results if item.property == property_name), None)
if not prop:
return TaxonomyErrorMessages.property_not_found(property_name, verbose_name)
# Create a map of property name to taxonomy result for efficient lookup
taxonomy_results_map: dict[str, EventTaxonomyItem] = {item.property: item for item in response.results}
return self._format_property_values(
prop.sample_values,
prop.sample_count,
format_as_string=property_definition.property_type in (PropertyType.String, PropertyType.Datetime),
results.extend(
self._process_property_values(
property_names, list(taxonomy_results_map.values()), definitions_map, verbose_name, is_indexed=False
)
)
def handle_tools(self, tool_name: str, tool_input: TaxonomyTool) -> tuple[str, str]:
# Here we handle the tool execution for base taxonomy tools.
return results
def _process_property_values(
self,
property_names: list[str],
property_results: list,
property_definitions: dict[str, PropertyDefinition],
entity_name: str,
is_indexed: bool = False,
) -> list[str]:
"""Common logic for processing property values from taxonomy results."""
results = []
for i, property_name in enumerate(property_names):
property_definition = property_definitions.get(property_name)
if property_definition is None:
results.append(TaxonomyErrorMessages.property_not_found(property_name, entity_name))
continue
if is_indexed:
if i >= len(property_results):
results.append(TaxonomyErrorMessages.property_not_found(property_name, entity_name))
continue
prop_result = property_results[i]
else:
prop_result = next((r for r in property_results if r.property == property_name), None)
if prop_result is None:
results.append(TaxonomyErrorMessages.property_not_found(property_name, entity_name))
continue
result = self._format_property_values(
property_name,
prop_result.sample_values,
prop_result.sample_count,
format_as_string=property_definition.property_type in (PropertyType.String, PropertyType.Datetime),
)
results.append(result)
return results
async def handle_tools(self, tool_name: str, tool_input: TaxonomyTool) -> tuple[str, str]:
if tool_name == "retrieve_entity_property_values":
result = self.retrieve_entity_property_values(
tool_input.arguments.entity, # type: ignore
tool_input.arguments.property_name, # type: ignore
)
entity = tool_input.arguments.entity # type: ignore
property_name = tool_input.arguments.property_name # type: ignore
result = (await self.retrieve_entity_property_values({entity: [property_name]}))[entity][0]
elif tool_name == "retrieve_entity_properties":
result = self.retrieve_entity_properties(tool_input.arguments.entity) # type: ignore
result = await self.retrieve_entity_properties(tool_input.arguments.entity) # type: ignore
elif tool_name == "retrieve_event_property_values":
result = self.retrieve_event_or_action_property_values(
tool_input.arguments.event_name, # type: ignore
tool_input.arguments.property_name, # type: ignore
)
event_name_or_action_id = tool_input.arguments.event_name # type: ignore
property_name = tool_input.arguments.property_name # type: ignore
result = (await self.retrieve_event_or_action_property_values({event_name_or_action_id: [property_name]}))[
event_name_or_action_id
][0]
elif tool_name == "retrieve_event_properties":
result = self.retrieve_event_or_action_properties(tool_input.arguments.event_name) # type: ignore
result = await self.retrieve_event_or_action_properties(tool_input.arguments.event_name) # type: ignore
elif tool_name == "ask_user_for_help":
result = tool_input.arguments.request # type: ignore
else:

View File

@@ -1028,7 +1028,7 @@
WHERE and(equals(s.team_id, 99999), greaterOrEquals(toTimeZone(s.min_first_timestamp, 'UTC'), toDateTime64('2020-12-11 13:46:23.000000', 6, 'UTC')), greaterOrEquals(addDays(dateTrunc('DAY', toTimeZone(s.min_first_timestamp, 'UTC')), 1), minus(toDateTime64('today', 6, 'UTC'), toIntervalDay(coalesce(s.retention_period_days, 365)))), greaterOrEquals(toTimeZone(s.min_first_timestamp, 'UTC'), toDateTime64('2020-12-29 00:00:00.000000', 6, 'UTC')), lessOrEquals(toTimeZone(s.min_first_timestamp, 'UTC'), toDateTime64('today', 6, 'UTC')), globalIn(s.session_id,
(SELECT events.`$session_id` AS session_id
FROM events
WHERE and(equals(events.team_id, 99999), notEmpty(events.`$session_id`), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-12-11 13:46:23.000000', 6, 'UTC')), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), now64(6, 'UTC')), ifNull(equals(nullIf(nullIf(events.mat_pp_email, ''), 'null'), 'bla'), 0), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-12-28 00:00:00.000000', 6, 'UTC')))
WHERE and(equals(events.team_id, 99999), notEmpty(events.`$session_id`), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-12-11 13:46:23.000000', 6, 'UTC')), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), now64(6, 'UTC')), ifNull(equals(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.person_properties, 'email'), ''), 'null'), '^"|"$', ''), 'bla'), 0), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-12-28 00:00:00.000000', 6, 'UTC')))
GROUP BY events.`$session_id`
HAVING 1
ORDER BY min(toTimeZone(events.timestamp, 'UTC')) DESC)))

View File

@@ -118,8 +118,8 @@ class ErrorTrackingIssueImpactToolkit(TaxonomyAgentToolkit):
def __init__(self, team: Team):
    """Initialise the toolkit for ``team`` by delegating to the parent constructor."""
    super().__init__(team)
def handle_tools(self, tool_name: str, tool_input: TaxonomyTool) -> tuple[str, str]:
return super().handle_tools(tool_name, tool_input)
async def handle_tools(self, tool_name: str, tool_input: TaxonomyTool) -> tuple[str, str]:
return await super().handle_tools(tool_name, tool_input)
def _get_custom_tools(self) -> list:
    """Return this toolkit's custom tools: only ``final_answer``."""
    custom_tools = [final_answer]
    return custom_tools

View File

@@ -9,6 +9,7 @@ from posthog.schema import RevenueAnalyticsAssistantFilters
from posthog.clickhouse.query_tagging import Product, tags_context
from posthog.models import Team, User
from posthog.sync import database_sync_to_async
from posthog.taxonomy.taxonomy import CORE_FILTER_DEFINITIONS_BY_GROUP
from products.revenue_analytics.backend.api import find_values_for_revenue_analytics_property
@@ -55,13 +56,13 @@ class RevenueAnalyticsFilterOptionsToolkit(TaxonomyAgentToolkit):
def __init__(self, team: Team):
    """Initialise the toolkit for ``team`` by delegating to the parent constructor."""
    super().__init__(team)
def handle_tools(self, tool_name: str, tool_input) -> tuple[str, str]:
async def handle_tools(self, tool_name: str, tool_input) -> tuple[str, str]:
"""Handle custom tool execution."""
if tool_name == "retrieve_revenue_analytics_property_values":
result = self._retrieve_revenue_analytics_property_values(tool_input.arguments.property_key)
result = await self._retrieve_revenue_analytics_property_values(tool_input.arguments.property_key)
return tool_name, result
return super().handle_tools(tool_name, tool_input)
return await super().handle_tools(tool_name, tool_input)
def _get_custom_tools(self) -> list:
    """Return this toolkit's custom tools: ``final_answer`` and the revenue analytics value lookup."""
    custom_tools = [final_answer, retrieve_revenue_analytics_property_values]
    return custom_tools
@@ -72,7 +73,7 @@ class RevenueAnalyticsFilterOptionsToolkit(TaxonomyAgentToolkit):
"""Returns the list of tools available in this toolkit."""
return [*self._get_custom_tools(), ask_user_for_help]
def _retrieve_revenue_analytics_property_values(self, property_name: str) -> str:
async def _retrieve_revenue_analytics_property_values(self, property_name: str) -> str:
"""
Revenue analytics properties come from Clickhouse so let's run a separate query here.
"""
@@ -80,9 +81,9 @@ class RevenueAnalyticsFilterOptionsToolkit(TaxonomyAgentToolkit):
return TaxonomyErrorMessages.property_not_found(property_name, "revenue_analytics")
with tags_context(product=Product.MAX_AI, team_id=self._team.pk, org_id=self._team.organization_id):
values = find_values_for_revenue_analytics_property(property_name, self._team)
values = await database_sync_to_async(find_values_for_revenue_analytics_property)(property_name, self._team)
return self._format_property_values(values, sample_count=len(values))
return self._format_property_values(property_name, values, sample_count=len(values))
class RevenueAnalyticsFilterNode(

View File

@@ -183,12 +183,12 @@ class SurveyToolkit(TaxonomyAgentToolkit):
return [lookup_feature_flag, final_answer]
def handle_tools(self, tool_name: str, tool_input) -> tuple[str, str]:
async def handle_tools(self, tool_name: str, tool_input) -> tuple[str, str]:
"""Handle custom tool execution."""
if tool_name == "lookup_feature_flag":
result = self._lookup_feature_flag(tool_input.arguments.flag_key)
return tool_name, result
return super().handle_tools(tool_name, tool_input)
return await super().handle_tools(tool_name, tool_input)
def _lookup_feature_flag(self, flag_key: str) -> str:
"""Look up feature flag information by key."""