feat(max): allow multiple property value search (#38489)
@@ -7,13 +7,18 @@ import yaml
 from posthog.taxonomy.taxonomy import CORE_FILTER_DEFINITIONS_BY_GROUP


-def format_property_values(sample_values: list, sample_count: Optional[int] = 0, format_as_string: bool = False) -> str:
+def format_property_values(
+    property_name: str, sample_values: list, sample_count: Optional[int] = 0, format_as_string: bool = False
+) -> str:
     if len(sample_values) == 0 or sample_count == 0:
-        return f"The property does not have any values in the taxonomy."
+        data = {
+            "property": property_name,
+            "values": [],
+            "message": f"The property does not have any values in the taxonomy.",
+        }
+        return yaml.dump(data, default_flow_style=False, sort_keys=False)

-    # Add quotes to the String type, so the LLM can easily infer a type.
-    # Strings like "true" or "10" are interpreted as booleans or numbers without quotes, so the schema generation fails.
-    # Remove the floating point the value is an integer.
+    # Format values for YAML
     formatted_sample_values: list[str] = []
     for value in sample_values:
         if format_as_string:
@@ -22,16 +27,14 @@ def format_property_values(sample_values: list, sample_count: Optional[int] = 0,
                 formatted_sample_values.append(str(int(value)))
             else:
                 formatted_sample_values.append(str(value))
-    prop_values = ", ".join(formatted_sample_values)

     # If there wasn't an exact match with the user's search, we provide a hint that LLM can use an arbitrary value.
     if sample_count is None:
-        return f"{prop_values} and many more distinct values."
+        formatted_sample_values.append("and many more distinct values")
     elif sample_count > len(sample_values):
-        diff = sample_count - len(sample_values)
-        return f"{prop_values} and {diff} more distinct value{'' if diff == 1 else 's'}."
-
-    return prop_values
+        remaining = sample_count - len(sample_values)
+        formatted_sample_values.append(f"and {remaining} more distinct values")
+    data = {"property": property_name, "values": formatted_sample_values}
+    return yaml.dump(data, default_flow_style=False, sort_keys=False)


 def format_properties_xml(children: list[tuple[str, str | None, str | None]]):
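Illustration (not part of the diff): with the new signature, format_property_values returns a small YAML document keyed by the property name instead of a comma-separated string. A minimal sketch using the session-duration case exercised in the tests added later in this commit:

    # Hypothetical call, mirroring the new format_property_values() above.
    # Whole-number floats are collapsed to ints, and sample_count=None adds the
    # "and many more distinct values" hint.
    print(format_property_values("$session_duration", [30.0, 146.0, 2.0], sample_count=None))
    # property: $session_duration
    # values:
    # - '30'
    # - '146'
    # - '2'
    # - and many more distinct values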
@@ -163,7 +163,7 @@ class TaxonomyAgentToolsNode(
     def node_name(self) -> MaxNodeName:
         return TaxonomyNodeName.TOOLS_NODE

-    def run(self, state: TaxonomyStateType, config: RunnableConfig) -> TaxonomyPartialStateType:
+    async def arun(self, state: TaxonomyStateType, config: RunnableConfig) -> TaxonomyPartialStateType:
         intermediate_steps = state.intermediate_steps or []
         action, _output = intermediate_steps[-1]
         tool_input: TaxonomyTool | None = None
@@ -199,7 +199,7 @@ class TaxonomyAgentToolsNode(

         if tool_input and not output:
             # Use the toolkit to handle tool execution
-            _, output = self._toolkit.handle_tools(tool_input.name, tool_input)
+            _, output = await self._toolkit.handle_tools(tool_input.name, tool_input)

         if output:
             tool_msg = LangchainToolMessage(
ee/hogai/graph/taxonomy/test/test_entities.py (new file, 189 lines)
@@ -0,0 +1,189 @@
from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest
from unittest.mock import patch

from posthog.models.person import Person
from posthog.models.property_definition import PropertyDefinition
from posthog.test.test_utils import create_group_type_mapping_without_created_at

from ee.hogai.graph.taxonomy.toolkit import TaxonomyAgentToolkit


class DummyToolkit(TaxonomyAgentToolkit):
    def get_tools(self):
        return self._get_default_tools()


class TestEntities(ClickhouseTestMixin, NonAtomicBaseTest):
    def setUp(self):
        super().setUp()
        for i, group_type in enumerate(["organization", "project"]):
            create_group_type_mapping_without_created_at(
                team=self.team, project_id=self.team.project_id, group_type_index=i, group_type=group_type
            )

        PropertyDefinition.objects.create(
            team=self.team,
            name="name",
            property_type="String",
            is_numerical=False,
            type=PropertyDefinition.Type.PERSON,
        )
        PropertyDefinition.objects.create(
            team=self.team,
            name="name_group",
            property_type="String",
            is_numerical=False,
            type=PropertyDefinition.Type.GROUP,
            group_type_index=0,
        )
        PropertyDefinition.objects.create(
            team=self.team,
            name="property_no_values",
            property_type="String",
            is_numerical=False,
            type=PropertyDefinition.Type.PERSON,
        )

        Person.objects.create(
            team=self.team,
            distinct_ids=["test-user"],
            properties={"name": "Test User"},
        )

        self.toolkit = DummyToolkit(self.team)

    async def test_retrieve_entity_properties(self):
        result = await self.toolkit.retrieve_entity_properties("person")
        assert (
            "<properties><String><prop><name>name</name></prop><prop><name>property_no_values</name></prop></String></properties>"
            == result
        )

    async def test_person_property_values_exists(self):
        result = await self.toolkit._get_entity_names()
        expected = ["person", "session", "organization", "project"]
        self.assertEqual(result, expected)

        property_vals = await self.toolkit.retrieve_entity_property_values({"person": ["name"]})
        self.assertIn("person", property_vals)
        self.assertIn("name", "\n".join(property_vals.get("person", [])))
        self.assertTrue(any("Test User" in str(val) for val in property_vals.get("person", [])))

    async def test_person_property_values_do_not_exist(self):
        result = await self.toolkit._get_entity_names()
        expected = ["person", "session", "organization", "project"]
        self.assertEqual(result, expected)

        property_vals = await self.toolkit.retrieve_entity_property_values({"person": ["property_no_values"]})
        self.assertIn("person", property_vals)
        self.assertIn("property_no_values", "\n".join(property_vals.get("person", [])))
        self.assertTrue(
            any(
                "The property does not have any values in the taxonomy." in str(val)
                for val in property_vals.get("person", [])
            )
        )

    async def test_person_property_values_mixed(self):
        result = await self.toolkit._get_entity_names()
        expected = ["person", "session", "organization", "project"]
        self.assertEqual(result, expected)

        property_vals = await self.toolkit.retrieve_entity_property_values({"person": ["property_no_values", "name"]})

        self.assertIn("person", property_vals)
        self.assertIn("property_no_values", "\n".join(property_vals.get("person", [])))
        self.assertTrue(
            any(
                "The property does not have any values in the taxonomy." in str(val)
                for val in property_vals.get("person", [])
            )
        )
        self.assertIn("name", "\n".join(property_vals.get("person", [])))
        self.assertTrue(any("Test User" in str(val) for val in property_vals.get("person", [])))

    async def test_multiple_entities(self):
        result = await self.toolkit._get_entity_names()
        expected = ["person", "session", "organization", "project"]
        self.assertEqual(result, expected)

        property_vals = await self.toolkit.retrieve_entity_property_values(
            {
                "person": ["property_no_values"],
                "session": ["$session_duration", "$channel_type", "nonexistent_property"],
            }
        )

        self.assertIn("person", property_vals)
        self.assertIn("property_no_values", "\n".join(property_vals.get("person", [])))
        self.assertTrue(
            any(
                "The property does not have any values in the taxonomy." in str(val)
                for val in property_vals.get("person", [])
            )
        )
        self.assertIn("session", property_vals)
        self.assertIn("$session_duration", "\n".join(property_vals.get("session", [])))
        self.assertIn("$channel_type", "\n".join(property_vals.get("session", [])))
        self.assertIn("nonexistent_property", "\n".join(property_vals.get("session", [])))
        self.assertTrue(
            any(
                "values:\n- '30'\n- '146'\n- '2'\n- and many more distinct values\n" in str(val)
                for val in property_vals.get("session", [])
            )
        )
        self.assertTrue(any("Direct" in str(val) for val in property_vals.get("session", [])))
        self.assertTrue(
            any(
                "The property nonexistent_property does not exist in the taxonomy." in str(val)
                for val in property_vals.get("session", [])
            )
        )

    async def test_retrieve_entity_property_values_batching(self):
        """Test that when more than 6 entities are processed, they are sent in batches of 6"""
        # Create 8 entities (more than 6) to test batching
        entities = [f"entity_{i}" for i in range(8)]
        entity_properties = {
            entity: ["$session_duration", "$channel_type", "nonexistent_property"] for entity in entities
        }

        # Spy on the _handle_entity_batch method to track how many times it's called
        with patch.object(self.toolkit, "_handle_entity_batch") as mock_handle_batch:
            # Mock the method to return a simple result
            mock_handle_batch.return_value = {
                entity: [
                    "values:\n- '30'\n- '146'\n- '2'\n- and many more distinct values\n",
                    "Direct",
                    "The property nonexistent_property does not exist in the taxonomy.",
                ]
                for entity in entities
            }

            result = await self.toolkit.retrieve_entity_property_values(entity_properties)

            # Verify that we got results for all entities
            self.assertEqual(len(result), 8)
            for entity in entities:
                self.assertIn(entity, result)
                self.assertEqual(
                    result[entity],
                    [
                        "values:\n- '30'\n- '146'\n- '2'\n- and many more distinct values\n",
                        "Direct",
                        "The property nonexistent_property does not exist in the taxonomy.",
                    ],
                )

            # Verify that _handle_entity_batch was called twice:
            # - First batch: entities 0-5 (6 entities)
            # - Second batch: entities 6-7 (2 entities)
            self.assertEqual(mock_handle_batch.call_count, 2)

            # Verify the batch sizes
            call_args_list = mock_handle_batch.call_args_list
            first_batch = call_args_list[0][0][0]  # First argument of first call
            second_batch = call_args_list[1][0][0]  # First argument of second call

            self.assertEqual(len(first_batch), 6)  # First batch should have 6 entities
            self.assertEqual(len(second_batch), 2)  # Second batch should have 2 entities
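Illustration (not part of the diff): the batching assertion above follows from MAX_ENTITIES_PER_BATCH = 6 introduced in the toolkit below. A minimal sketch of how 8 requested entities split into two batches:

    # Hypothetical example of the chunking performed by _parallel_entity_processing.
    entity_properties = {f"entity_{i}": ["$channel_type"] for i in range(8)}
    items = list(entity_properties.items())
    batches = [dict(items[i : i + 6]) for i in range(0, len(items), 6)]
    assert [len(batch) for batch in batches] == [6, 2]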
ee/hogai/graph/taxonomy/test/test_events.py (new file, 135 lines)
@@ -0,0 +1,135 @@
from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest, _create_event, flush_persons_and_events

from posthog.models import Action
from posthog.models.property_definition import PropertyDefinition
from posthog.test.test_utils import create_group_type_mapping_without_created_at

from ee.hogai.graph.taxonomy.toolkit import TaxonomyAgentToolkit


class DummyToolkit(TaxonomyAgentToolkit):
    def get_tools(self):
        return self._get_default_tools()


class TestEvents(ClickhouseTestMixin, NonAtomicBaseTest):
    def setUp(self):
        super().setUp()
        for i, group_type in enumerate(["organization", "project"]):
            create_group_type_mapping_without_created_at(
                team=self.team, project_id=self.team.project_id, group_type_index=i, group_type=group_type
            )

        PropertyDefinition.objects.create(
            team=self.team,
            name="$browser",
            property_type="String",
            is_numerical=False,
            type=PropertyDefinition.Type.EVENT,
        )
        PropertyDefinition.objects.create(
            team=self.team,
            name="id",
            property_type="String",
            is_numerical=False,
            type=PropertyDefinition.Type.EVENT,
        )

        PropertyDefinition.objects.create(
            team=self.team,
            name="no_values",
            property_type="Boolean",
            is_numerical=False,
            type=PropertyDefinition.Type.EVENT,
        )

        # Create events that match the action conditions
        _create_event(
            event="event1",
            distinct_id="user123",
            team=self.team,
            properties={
                "$browser": "Chrome",
                "id": "123",
            },
        )

        _create_event(
            event="event1",
            distinct_id="user456",
            team=self.team,
            properties={
                "$browser": "Firefox",
                "id": "456",
            },
        )

        Action.objects.create(
            id=232, team=self.team, name="action1", description="Test Description", steps_json=[{"event": "event1"}]
        )

        self.toolkit = DummyToolkit(self.team)

    async def test_events_property_values_exists(self):
        result = await self.toolkit._get_entity_names()
        expected = ["person", "session", "organization", "project"]
        assert result == expected

        property_vals = await self.toolkit.retrieve_event_or_action_property_values({"event1": ["$browser", "id"]})
        assert "event1" in property_vals
        assert "$browser" in "\n".join(property_vals.get("event1", []))
        assert "id" in "\n".join(property_vals.get("event1", []))

    async def test_events_property_values_do_not_exist(self):
        result = await self.toolkit._get_entity_names()
        expected = ["person", "session", "organization", "project"]
        assert result == expected

        property_vals = await self.toolkit.retrieve_event_or_action_property_values({"event1": ["no_values"]})

        assert "event1" in property_vals
        assert "no_values" in "\n".join(property_vals.get("event1", []))
        assert "No values found for property no_values on entity event event1" in "\n".join(
            property_vals.get("event1", [])
        )

    async def test_events_property_values_action_values_not_found(self):
        result = await self.toolkit._get_entity_names()
        expected = ["person", "session", "organization", "project"]
        assert result == expected

        property_vals = await self.toolkit.retrieve_event_or_action_property_values({232: ["no_values"]})

        assert 232 in property_vals
        assert "no_values" in "\n".join(property_vals.get(232, []))
        assert "No values found for property no_values on entity action with ID 232" in "\n".join(
            property_vals.get(232, [])
        )

    async def test_events_property_values_action_multiple_properties(self):
        result = await self.toolkit._get_entity_names()
        expected = ["person", "session", "organization", "project"]
        assert result == expected

        property_vals = await self.toolkit.retrieve_event_or_action_property_values({232: ["no_values", "$browser"]})

        assert 232 in property_vals
        assert "no_values" in "\n".join(property_vals.get(232, []))
        assert "$browser" in "\n".join(property_vals.get(232, []))
        # Should have actual values
        assert "Chrome" in "\n".join(property_vals.get(232, []))
        assert "Firefox" in "\n".join(property_vals.get(232, []))

    async def test_retrieve_event_or_action_properties_action_not_found(self):
        result = await self.toolkit.retrieve_event_or_action_properties(999)
        assert (
            "Action 999 does not exist in the taxonomy. Verify that the action ID is correct and try again." == result
        )

    async def test_retrieve_event_or_action_properties_event_not_found(self):
        result = await self.toolkit.retrieve_event_or_action_properties("test")
        assert "Properties do not exist in the taxonomy for the event test." == result

    def tearDown(self):
        flush_persons_and_events()
        super().tearDown()
ee/hogai/graph/taxonomy/test/test_groups.py (new file, 127 lines)
@@ -0,0 +1,127 @@
from datetime import datetime

from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest

from posthog.models.group.util import raw_create_group_ch
from posthog.models.property_definition import PropertyDefinition
from posthog.test.test_utils import create_group_type_mapping_without_created_at

from ee.hogai.graph.taxonomy.toolkit import TaxonomyAgentToolkit


class DummyToolkit(TaxonomyAgentToolkit):
    def get_tools(self):
        return self._get_default_tools()


class TestGroups(ClickhouseTestMixin, NonAtomicBaseTest):
    def setUp(self):
        super().setUp()
        for i, group_type in enumerate(["organization", "project", "no_properties"]):
            create_group_type_mapping_without_created_at(
                team=self.team, project_id=self.team.project_id, group_type_index=i, group_type=group_type
            )

        # Create property definitions for organization (group_type_index=0)
        PropertyDefinition.objects.create(
            team=self.team,
            name="name",
            property_type="String",
            is_numerical=False,
            type=PropertyDefinition.Type.GROUP,
            group_type_index=0,  # organization
        )
        PropertyDefinition.objects.create(
            team=self.team,
            name="industry",
            property_type="String",
            is_numerical=False,
            type=PropertyDefinition.Type.GROUP,
            group_type_index=0,  # organization
        )
        # Create property definitions for project (group_type_index=1)
        PropertyDefinition.objects.create(
            team=self.team,
            name="size",
            property_type="Numeric",
            is_numerical=True,
            type=PropertyDefinition.Type.GROUP,
            group_type_index=1,  # project
        )
        PropertyDefinition.objects.create(
            team=self.team,
            name="name_group",
            property_type="String",
            is_numerical=False,
            type=PropertyDefinition.Type.GROUP,
            group_type_index=0,
        )

        raw_create_group_ch(
            team_id=self.team.id,
            group_type_index=0,
            group_key="acme-corp",
            properties={"name": "Acme Corp", "industry": "tech"},
            created_at=datetime.now(),
            sync=True,  # Force sync to ClickHouse
        )
        raw_create_group_ch(
            team_id=self.team.id,
            group_type_index=1,
            group_key="acme-project",
            properties={"name": "Acme Project", "size": 100},
            created_at=datetime.now(),
            sync=True,  # Force sync to ClickHouse
        )

        self.toolkit = DummyToolkit(self.team)

    async def test_entity_names_with_existing_groups(self):
        # Test that the entity names include the groups we created in setUp
        result = await self.toolkit._get_entity_names()
        expected = ["person", "session", "organization", "project", "no_properties"]
        self.assertEqual(result, expected)

        property_vals = await self.toolkit.retrieve_entity_property_values(
            {"organization": ["name", "industry"], "project": ["size"]}
        )

        # Should return the actual values from the groups we created
        self.assertIn("organization", property_vals)
        self.assertIn("project", property_vals)

        self.assertTrue(any("Acme Corp" in str(val) for val in property_vals.get("organization", [])))
        self.assertTrue(any("tech" in str(val) for val in property_vals.get("organization", [])))
        self.assertTrue(any("100" in str(val) for val in property_vals.get("project", [])))

    async def test_retrieve_entity_property_values_wrong_group(self):
        property_vals = await self.toolkit.retrieve_entity_property_values(
            {"test": ["name", "industry"], "project": ["size"]}
        )

        self.assertIn("test", property_vals)
        self.assertIn(
            "Entity test not found. Available entities: person, session, organization, project, no_properties",
            property_vals["test"],
        )
        self.assertIn("project", property_vals)

    async def test_retrieve_entity_properties_group(self):
        result = await self.toolkit.retrieve_entity_properties("organization")

        assert (
            "<properties><String><prop><name>name</name></prop><prop><name>industry</name></prop><prop><name>name_group</name></prop></String></properties>"
            == result
        )

    async def test_retrieve_entity_properties_group_not_found(self):
        result = await self.toolkit.retrieve_entity_properties("test")

        assert (
            "Entity test not found. Available entities: person, session, organization, project, no_properties" == result
        )

    async def test_retrieve_entity_properties_group_nothing_found(self):
        result = await self.toolkit.retrieve_entity_properties("no_properties")

        assert "Properties do not exist in the taxonomy for the entity no_properties." == result
@@ -159,7 +159,7 @@ class TestTaxonomyAgentToolsNode(BaseTest):

     @patch.object(MockTaxonomyAgentToolkit, "get_tool_input_model")
     @patch.object(MockTaxonomyAgentToolkit, "handle_tools")
-    def test_run_normal_tool_execution(self, mock_handle_tools, mock_get_tool_input):
+    async def test_run_normal_tool_execution(self, mock_handle_tools, mock_get_tool_input):
         # Setup mocks
         mock_input = Mock()
         mock_input.name = "test_tool"
@@ -172,14 +172,14 @@ class TestTaxonomyAgentToolsNode(BaseTest):
         state = TaxonomyAgentState()
         state.intermediate_steps = [(action, None)]

-        result = self.node.run(state, RunnableConfig())
+        result = await self.node.arun(state, RunnableConfig())

         self.assertIsInstance(result, TaxonomyAgentState)
         self.assertEqual(len(result.intermediate_steps), 1)
         self.assertEqual(result.intermediate_steps[0][1], "tool output")

     @patch.object(MockTaxonomyAgentToolkit, "get_tool_input_model")
-    def test_run_validation_error(self, mock_get_tool_input):
+    async def test_run_validation_error(self, mock_get_tool_input):
         # Setup validation error
         validation_error = ValidationError.from_exception_data(
             "TestModel", [{"type": "missing", "loc": ("field",), "msg": "Field required"}]
@@ -190,13 +190,13 @@ class TestTaxonomyAgentToolsNode(BaseTest):
         state = TaxonomyAgentState()
         state.intermediate_steps = [(action, None)]

-        result = self.node.run(state, RunnableConfig())
+        result = await self.node.arun(state, RunnableConfig())

         self.assertIsInstance(result, TaxonomyAgentState)
         self.assertEqual(len(result.tool_progress_messages), 1)

     @patch.object(MockTaxonomyAgentToolkit, "get_tool_input_model")
-    def test_run_final_answer(self, mock_get_tool_input):
+    async def test_run_final_answer(self, mock_get_tool_input):
         # Mock final answer tool
         from pydantic import BaseModel

@@ -216,14 +216,14 @@ class TestTaxonomyAgentToolsNode(BaseTest):
         state = TaxonomyAgentState()
         state.intermediate_steps = [(action, None)]

-        result = self.node.run(state, RunnableConfig())
+        result = await self.node.arun(state, RunnableConfig())

         self.assertIsInstance(result, TaxonomyAgentState)
         self.assertEqual(result.output, expected_data)
         self.assertIsNone(result.intermediate_steps)

     @patch.object(MockTaxonomyAgentToolkit, "get_tool_input_model")
-    def test_run_ask_user_for_help(self, mock_get_tool_input):
+    async def test_run_ask_user_for_help(self, mock_get_tool_input):
         # Mock ask for help tool
         mock_input = Mock()
         mock_input.name = "ask_user_for_help"
@@ -238,11 +238,11 @@ class TestTaxonomyAgentToolsNode(BaseTest):
         with patch.object(self.node, "_get_reset_state") as mock_reset:
             mock_reset.return_value = TaxonomyAgentState()

-            _ = self.node.run(state, RunnableConfig())
+            _ = await self.node.arun(state, RunnableConfig())

             mock_reset.assert_called_once_with("Need help", "ask_user_for_help", state)

-    def test_run_max_iterations(self):
+    async def test_run_max_iterations(self):
         # Create state with max iterations
         actions = []
         for i in range(self.node.MAX_ITERATIONS):
@@ -255,7 +255,7 @@ class TestTaxonomyAgentToolsNode(BaseTest):
         with patch.object(self.node, "_get_reset_state") as mock_reset:
             mock_reset.return_value = TaxonomyAgentState()

-            _ = self.node.run(state, RunnableConfig())
+            _ = await self.node.arun(state, RunnableConfig())

             mock_reset.assert_called_once()
             call_args = mock_reset.call_args
@@ -289,7 +289,7 @@ class TestTaxonomyAgentToolsNode(BaseTest):
         result = self.node.router(state)
         self.assertEqual(result, expected)

-    def test_get_reset_state(self):
+    async def test_get_reset_state(self):
         original_state = TaxonomyAgentState()
         original_state.change = "test change"
@@ -1,23 +1,10 @@
-from datetime import datetime
-
-from posthog.test.base import BaseTest, ClickhouseTestMixin
-from unittest.mock import Mock, patch
+from posthog.test.base import BaseTest
+from unittest.mock import patch

 from langchain_core.agents import AgentAction
 from parameterized import parameterized
 from pydantic import BaseModel

-from posthog.schema import (
-    ActorsPropertyTaxonomyResponse,
-    CachedActorsPropertyTaxonomyQueryResponse,
-    CachedEventTaxonomyQueryResponse,
-    EventTaxonomyItem,
-)
-
 from posthog.models import Action
 from posthog.models.property_definition import PropertyDefinition, PropertyType
 from posthog.test.test_utils import create_group_type_mapping_without_created_at

 from ee.hogai.graph.taxonomy.toolkit import TaxonomyAgentToolkit, TaxonomyToolNotFoundError


@@ -26,16 +13,14 @@ class DummyToolkit(TaxonomyAgentToolkit):
         return self._get_default_tools()


-class TestTaxonomyAgentToolkit(ClickhouseTestMixin, BaseTest):
+class TestTaxonomyAgentToolkit(BaseTest):
     def setUp(self):
         super().setUp()
         self.toolkit = DummyToolkit(self.team)
         self.action = Action.objects.create(team=self.team, name="test_action", steps_json=[{"event": "test_event"}])

-    def test_toolkit_initialization(self):
+    async def test_toolkit_initialization(self):
         self.assertEqual(self.toolkit._team, self.team)
         self.assertIsInstance(self.toolkit._team_group_types, list)
-        self.assertIsInstance(self.toolkit._entity_names, list)
+        self.assertIsInstance(await self.toolkit._get_entity_names(), list)

     @parameterized.expand(
         [
@@ -43,35 +28,10 @@ class TestTaxonomyAgentToolkit(ClickhouseTestMixin, BaseTest):
             ("session", ["person", "session"]),
         ]
     )
-    def test_entity_names_basic(self, entity, expected_base):
-        self.assertIn(entity, self.toolkit._entity_names)
+    async def test_entity_names_basic(self, entity, expected_base):
+        self.assertIn(entity, await self.toolkit._get_entity_names())
         for expected in expected_base:
-            self.assertIn(expected, self.toolkit._entity_names)
-
-    def test_entity_names_with_groups(self):
-        # Create group type mappings
-        for i, group_type in enumerate(["organization", "project"]):
-            create_group_type_mapping_without_created_at(
-                team=self.team, project_id=self.team.project_id, group_type_index=i, group_type=group_type
-            )
-
-        toolkit = DummyToolkit(self.team)
-        expected = ["person", "session", "organization", "project"]
-        self.assertEqual(toolkit._entity_names, expected)
-
-    @parameterized.expand(
-        [
-            ("$session_duration", True, "30, 146, 2"),
-            ("$channel_type", True, "Direct"),
-            ("nonexistent_property", False, "does not exist"),
-        ]
-    )
-    def test_retrieve_session_properties(self, property_name, should_contain_values, expected_content):
-        result = self.toolkit._retrieve_session_properties(property_name)
-        if should_contain_values:
-            self.assertIn(expected_content, result)
-        else:
-            self.assertIn(expected_content, result)
+            self.assertIn(expected, await self.toolkit._get_entity_names())

     def test_enrich_props_with_descriptions(self):
         props = [("$browser", "String"), ("custom_prop", "Numeric")]
@@ -85,130 +45,16 @@ class TestTaxonomyAgentToolkit(ClickhouseTestMixin, BaseTest):
     @parameterized.expand(
         [
             ([], 0, False, "The property does not have any values"),
-            (["value1", "value2"], None, False, "value1, value2 and many more"),
-            (["value1", "value2"], 5, False, "value1, value2 and 3 more"),
+            (["value1", "value2"], None, False, "- value1\n- value2\n- and many more distinct values"),
+            (["value1", "value2"], 5, False, "- value1\n- value2\n- and 3 more distinct values"),
             (["string_val"], 1, True, '"string_val"'),
-            ([1.0, 2.0], 2, False, "1, 2"),
+            ([1.0, 2.0], 2, False, "'1'\n- '2'"),
         ]
     )
     def test_format_property_values(self, sample_values, sample_count, format_as_string, expected_substring):
-        result = self.toolkit._format_property_values(sample_values, sample_count, format_as_string)
+        result = self.toolkit._format_property_values("test_property", sample_values, sample_count, format_as_string)
         self.assertIn(expected_substring, result)

-    def _create_property_definition(self, prop_type, name="test_prop", group_type_index=None):
-        """Helper to create property definitions"""
-        kwargs = {"team": self.team, "name": name, "property_type": PropertyType.String}
-        if prop_type == PropertyDefinition.Type.GROUP:
-            kwargs["type"] = PropertyDefinition.Type.GROUP
-            kwargs["group_type_index"] = group_type_index
-        else:
-            kwargs["type"] = prop_type
-
-        return PropertyDefinition.objects.create(**kwargs)
-
-    def _create_mock_taxonomy_response(self, response_type="event", **kwargs):
-        """Helper to create mock taxonomy responses"""
-        if response_type == "event":
-            return CachedEventTaxonomyQueryResponse(
-                cache_key="test",
-                is_cached=False,
-                last_refresh=datetime.now().isoformat(),
-                next_allowed_client_refresh=datetime.now().isoformat(),
-                timezone="UTC",
-                results=[EventTaxonomyItem(**kwargs)],
-            )
-        elif response_type == "actors":
-            return CachedActorsPropertyTaxonomyQueryResponse(
-                cache_key="test",
-                is_cached=False,
-                last_refresh=datetime.now().isoformat(),
-                next_allowed_client_refresh=datetime.now().isoformat(),
-                timezone="UTC",
-                results=ActorsPropertyTaxonomyResponse(**kwargs),
-            )
-
-    def test_retrieve_entity_properties_person(self):
-        self._create_property_definition(PropertyDefinition.Type.PERSON, "email")
-        result = self.toolkit.retrieve_entity_properties("person")
-        self.assertIn("email", result)
-        self.assertIn("String", result)
-
-    def test_retrieve_entity_properties_session(self):
-        result = self.toolkit.retrieve_entity_properties("session")
-        self.assertIn("$session_duration", result)
-        self.assertIn("properties", result)
-
-    def test_retrieve_entity_properties_group(self):
-        create_group_type_mapping_without_created_at(
-            team=self.team, project_id=self.team.project_id, group_type_index=0, group_type="organization"
-        )
-        self._create_property_definition(PropertyDefinition.Type.GROUP, "org_name", group_type_index=0)
-        result = self.toolkit.retrieve_entity_properties("organization")
-        self.assertIn("org_name", result)
-
-    @parameterized.expand(
-        [
-            ("invalid_entity", "Entity invalid_entity not found"),
-            ("person", "Properties do not exist in the taxonomy for the entity person."),
-        ]
-    )
-    def test_retrieve_entity_properties_edge_cases(self, entity, expected_content):
-        result = self.toolkit.retrieve_entity_properties(entity)
-        self.assertIn(expected_content, result)
-
-    @patch("ee.hogai.graph.taxonomy.toolkit.ActorsPropertyTaxonomyQueryRunner")
-    def test_retrieve_entity_property_values_person(self, mock_runner_class):
-        self._create_property_definition(PropertyDefinition.Type.PERSON, "email")
-
-        mock_response = self._create_mock_taxonomy_response(
-            response_type="actors", sample_values=["test@example.com", "user@test.com"], sample_count=2
-        )
-
-        mock_runner = Mock()
-        mock_runner.run.return_value = mock_response
-        mock_runner_class.return_value = mock_runner
-
-        result = self.toolkit.retrieve_entity_property_values("person", "email")
-        self.assertIn("test@example.com", result)
-
-    def test_retrieve_entity_property_values_invalid_entity(self):
-        result = self.toolkit.retrieve_entity_property_values("invalid", "prop")
-        self.assertIn("Entity invalid not found", result)
-
-    @patch("ee.hogai.graph.taxonomy.toolkit.EventTaxonomyQueryRunner")
-    def test_retrieve_event_or_action_properties(self, mock_runner_class):
-        self._create_property_definition(PropertyDefinition.Type.EVENT, "$browser")
-
-        mock_response = self._create_mock_taxonomy_response(property="$browser", sample_values=[], sample_count=0)
-
-        mock_runner = Mock()
-        mock_runner.run.return_value = mock_response
-        mock_runner_class.return_value = mock_runner
-
-        result = self.toolkit.retrieve_event_or_action_properties("test_event")
-        self.assertIn("$browser", result)
-
-    def test_retrieve_event_or_action_properties_action_not_found(self):
-        Action.objects.all().delete()
-        result = self.toolkit.retrieve_event_or_action_properties(999)
-        self.assertEqual(result, "No actions exist in the project.")
-
-    @patch("ee.hogai.graph.taxonomy.toolkit.EventTaxonomyQueryRunner")
-    def test_retrieve_event_or_action_property_values(self, mock_runner_class):
-        self._create_property_definition(PropertyDefinition.Type.EVENT, "$browser")
-
-        mock_response = self._create_mock_taxonomy_response(
-            property="$browser", sample_values=["Chrome", "Firefox"], sample_count=2
-        )
-
-        mock_runner = Mock()
-        mock_runner.run.return_value = mock_response
-        mock_runner_class.return_value = mock_runner
-
-        result = self.toolkit.retrieve_event_or_action_property_values("test_event", "$browser")
-        self.assertIn("Chrome", result)
-        self.assertIn("Firefox", result)

     def test_handle_incorrect_response(self):
         class TestModel(BaseModel):
             field: str = "test"
@@ -227,10 +73,10 @@ class TestTaxonomyAgentToolkit(ClickhouseTestMixin, BaseTest):
         ]
     )
     @patch.object(DummyToolkit, "retrieve_entity_properties", return_value="mocked")
-    @patch.object(DummyToolkit, "retrieve_entity_property_values", return_value="mocked")
+    @patch.object(DummyToolkit, "retrieve_entity_property_values", return_value={"person": ["mocked"]})
     @patch.object(DummyToolkit, "retrieve_event_or_action_properties", return_value="mocked")
-    @patch.object(DummyToolkit, "retrieve_event_or_action_property_values", return_value="mocked")
-    def test_handle_tools(self, tool_name, tool_args, expected_result, *mocks):
+    @patch.object(DummyToolkit, "retrieve_event_or_action_property_values", return_value={"test_event": ["mocked"]})
+    async def test_handle_tools(self, tool_name, tool_args, expected_result, *mocks):
         class Arguments(BaseModel):
             pass

@@ -242,12 +88,12 @@ class TestTaxonomyAgentToolkit(ClickhouseTestMixin, BaseTest):
             arguments: Arguments

         tool_input = ToolInput(name=tool_name, arguments=Arguments(**tool_args))
-        tool_name_result, result = self.toolkit.handle_tools(tool_name, tool_input)
+        tool_name_result, result = await self.toolkit.handle_tools(tool_name, tool_input)

         self.assertEqual(result, expected_result)
         self.assertEqual(tool_name_result, tool_name)

-    def test_handle_tools_invalid_tool(self):
+    async def test_handle_tools_invalid_tool(self):
         class ToolInput(BaseModel):
             name: str = "invalid_tool"
             arguments: dict = {}
@@ -255,7 +101,7 @@ class TestTaxonomyAgentToolkit(ClickhouseTestMixin, BaseTest):
         tool_input = ToolInput()

         with self.assertRaises(TaxonomyToolNotFoundError):
-            self.toolkit.handle_tools("invalid_tool", tool_input)
+            await self.toolkit.handle_tools("invalid_tool", tool_input)

     def test_format_properties_formats(self):
         props = [("prop1", "String", "Test description"), ("prop2", "Numeric", None)]
@@ -1,3 +1,4 @@
+import asyncio
 from collections.abc import Iterable
 from functools import cached_property
 from typing import Any, Optional, Union, cast
@@ -9,7 +10,10 @@ from posthog.schema import (
     ActorsPropertyTaxonomyQuery,
     CachedActorsPropertyTaxonomyQueryResponse,
     CachedEventTaxonomyQueryResponse,
+    CacheMissResponse,
+    EventTaxonomyItem,
     EventTaxonomyQuery,
+    QueryStatusResponse,
 )

 from posthog.hogql.database.schema.channel_type import DEFAULT_CHANNEL_TYPES
@@ -21,6 +25,7 @@ from posthog.hogql_queries.query_runner import ExecutionMode
 from posthog.models import Action, Team
 from posthog.models.group_type_mapping import GroupTypeMapping
 from posthog.models.property_definition import PropertyDefinition, PropertyType
+from posthog.sync import database_sync_to_async
 from posthog.taxonomy.taxonomy import CORE_FILTER_DEFINITIONS_BY_GROUP

 from ee.hogai.graph.taxonomy.format import (
@@ -104,6 +109,7 @@ class TaxonomyAgentToolkit:

     def __init__(self, team: Team):
         self._team = team
+        self.MAX_ENTITIES_PER_BATCH = 6

     @property
     def _groups(self):
@@ -125,17 +131,21 @@ class TaxonomyAgentToolkit:
         entities = [
             "person",
             "session",
-            *[group.group_type for group in self._groups],
+            *self._team_group_types,
         ]
         return entities

+    @database_sync_to_async(thread_sensitive=False)
+    def _get_entity_names(self) -> list[str]:
+        return self._entity_names
+
     def _enrich_props_with_descriptions(self, entity: str, props: Iterable[tuple[str, str | None]]):
         return enrich_props_with_descriptions(entity, props)

     def _format_property_values(
-        self, sample_values: list, sample_count: Optional[int] = 0, format_as_string: bool = False
+        self, property_name: str, sample_values: list, sample_count: Optional[int] = 0, format_as_string: bool = False
     ) -> str:
-        return format_property_values(sample_values, sample_count, format_as_string)
+        return format_property_values(property_name, sample_values, sample_count, format_as_string)

     def _retrieve_session_properties(self, property_name: str) -> str:
         """
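Illustration (not part of the diff): _get_entity_names above shows the pattern used throughout this commit, wrapping a synchronous ORM-backed helper with database_sync_to_async so the async toolkit methods can await it. A minimal sketch with a hypothetical class:

    from posthog.sync import database_sync_to_async

    class ExampleToolkit:  # hypothetical, for illustration only
        @database_sync_to_async(thread_sensitive=False)
        def _load_entity_names(self) -> list[str]:
            # Runs on a worker thread, so synchronous ORM access is safe here.
            return ["person", "session"]

        async def entity_names(self) -> list[str]:
            return await self._load_entity_names()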
@@ -161,18 +171,26 @@ class TaxonomyAgentToolkit:
         else:
             return TaxonomyErrorMessages.property_values_not_found(property_name, "session")

-        return self._format_property_values(sample_values, sample_count, format_as_string=is_str)
+        return self._format_property_values(property_name, sample_values, sample_count, format_as_string=is_str)

-    def _retrieve_event_or_action_taxonomy(self, event_name_or_action_id: str | int):
+    @database_sync_to_async(thread_sensitive=False)
+    def _retrieve_event_or_action_taxonomy(
+        self, event_name_or_action_id: str | int, properties: list[str] | None = None
+    ):
         """
         Retrieve event/action taxonomy with efficient caching.
+        Multiple properties are batched in a single query to maximize cache hits.
         """
         is_event = isinstance(event_name_or_action_id, str)
         if is_event:
-            query = EventTaxonomyQuery(event=event_name_or_action_id, maxPropertyValues=25)
+            query = EventTaxonomyQuery(event=event_name_or_action_id, maxPropertyValues=25, properties=properties)
             verbose_name = f"event {event_name_or_action_id}"
         else:
-            query = EventTaxonomyQuery(actionId=event_name_or_action_id, maxPropertyValues=25)
+            query = EventTaxonomyQuery(actionId=event_name_or_action_id, maxPropertyValues=25, properties=properties)
             verbose_name = f"action with ID {event_name_or_action_id}"
         runner = EventTaxonomyQueryRunner(query, self._team)
         with tags_context(product=Product.MAX_AI, team_id=self._team.pk, org_id=self._team.organization_id):
+            # Use cache-first execution mode for optimal performance
             response = runner.run(ExecutionMode.RECENT_CACHE_CALCULATE_ASYNC_IF_STALE_AND_BLOCKING_ON_MISS)
         return response, verbose_name
@@ -220,20 +238,20 @@ class TaxonomyAgentToolkit:
         """Get custom tools. Override in subclasses to add custom tools."""
         raise NotImplementedError("_get_custom_tools must be implemented in subclasses")

-    def retrieve_entity_properties(self, entity: str, max_properties: int = 500) -> str:
+    async def retrieve_entity_properties(self, entity: str, max_properties: int = 500) -> str:
         """
         Retrieve properties for an entitiy like person, session, or one of the groups.
         """

-        if entity not in ("person", "session", *[group.group_type for group in self._groups]):
-            return TaxonomyErrorMessages.entity_not_found(entity, self._entity_names)
+        entity_names = await self._get_entity_names()
+        if entity not in entity_names:
+            return TaxonomyErrorMessages.entity_not_found(entity, entity_names)

         props: list[Any] = []
         if entity == "person":
             qs = PropertyDefinition.objects.filter(team=self._team, type=PropertyDefinition.Type.PERSON).values_list(
                 "name", "property_type"
             )
-            props = self._enrich_props_with_descriptions("person", qs)
+            props = self._enrich_props_with_descriptions("person", [prop async for prop in qs])
         elif entity == "session":
             # Session properties are not in the DB.
             props = self._enrich_props_with_descriptions(
@@ -245,91 +263,153 @@ class TaxonomyAgentToolkit:
                 ],
             )
         else:
-            group_type_index = next(
-                (group.group_type_index for group in self._groups if group.group_type == entity), None
-            )
+            group_type_index = None
+            groups = [group async for group in self._groups]
+            for group in groups:
+                if group.group_type == entity:
+                    group_type_index = group.group_type_index
+                    break
+            if group_type_index is None:
+                return f"Group {entity} does not exist in the taxonomy."
             qs = PropertyDefinition.objects.filter(
                 team=self._team, type=PropertyDefinition.Type.GROUP, group_type_index=group_type_index
             ).values_list("name", "property_type")[:max_properties]
-            props = self._enrich_props_with_descriptions(entity, qs)
+            props = self._enrich_props_with_descriptions(entity, [prop async for prop in qs])

         if not props:
             return f"Properties do not exist in the taxonomy for the entity {entity}."

         return self._format_properties(props)

-    def retrieve_entity_property_values(self, entity: str, property_name: str) -> str:
-        """Retrieve property values for an entity."""
-        if entity not in self._entity_names:
-            return TaxonomyErrorMessages.entity_not_found(entity, self._entity_names)
-
-        if entity == "session":
-            return self._retrieve_session_properties(property_name)
-
-        if entity == "person":
-            query = ActorsPropertyTaxonomyQuery(properties=[property_name], maxPropertyValues=25)
-        elif entity == "event":
-            query = ActorsPropertyTaxonomyQuery(properties=[property_name], maxPropertyValues=50)
-        else:
-            group_index = next((group.group_type_index for group in self._groups if group.group_type == entity), None)
+    async def retrieve_entity_property_values(self, entity_properties: dict[str, list[str]]) -> dict[str, list[str]]:
+        result = await self._parallel_entity_processing(entity_properties)
+        return result
+
+    async def _handle_entity_batch(self, batch: dict[str, list[str]]) -> dict[str, list[str]]:
+        entity_tasks = [
+            self._retrieve_multiple_entity_property_values(entity, property_names)
+            for entity, property_names in batch.items()
+        ]
+        batch_results = await asyncio.gather(*entity_tasks)
+        return dict(zip(batch.keys(), batch_results))
+
+    async def _parallel_entity_processing(self, entity_properties: dict[str, list[str]]) -> dict[str, list[str]]:
+        entity_items = list(entity_properties.items())
+        if len(entity_items) > self.MAX_ENTITIES_PER_BATCH:
+            # Process in batches
+            results = {}
+            for i in range(0, len(entity_items), self.MAX_ENTITIES_PER_BATCH):
+                batch = dict(entity_items[i : i + self.MAX_ENTITIES_PER_BATCH])
+                batch_results = await self._handle_entity_batch(batch)
+                results.update(batch_results)
+            return results
+        else:
+            return await self._handle_entity_batch(entity_properties)
+
+    async def _retrieve_multiple_entity_property_values(self, entity: str, property_names: list[str]) -> list[str]:
+        """Retrieve property values for multiple entities and properties efficiently."""
+        results = []
+        entity_names = await self._get_entity_names()
+        if entity not in entity_names:
+            results.append(TaxonomyErrorMessages.entity_not_found(entity, entity_names))
+            return results
+        if entity == "session":
+            for property_name in property_names:
+                results.append(self._retrieve_session_properties(property_name))
+            return results
+        groups = [group async for group in self._groups]
+        query = self._build_query(entity, property_names, groups)
+        if query is None:
+            results.append(TaxonomyErrorMessages.entity_not_found(entity))
+            return results
+        property_values_response = await self._run_actors_taxonomy_query(query)
+        if not isinstance(property_values_response, CachedActorsPropertyTaxonomyQueryResponse):
+            results.append(TaxonomyErrorMessages.entity_not_found(entity))
+            return results
+
+        if not property_values_response.results:
+            for property_name in property_names:
+                results.append(TaxonomyErrorMessages.property_values_not_found(property_name, entity))
+            return results
+
+        if isinstance(property_values_response.results, list):
+            property_values_results = property_values_response.results
+        else:
+            property_values_results = [property_values_response.results]
+
+        property_definitions: dict[str, PropertyDefinition] = await self._get_definitions_for_entity(
+            entity, property_names, query
+        )
+
+        results.extend(
+            self._process_property_values(
+                property_names, property_values_results, property_definitions, entity, is_indexed=True
+            )
+        )
+        return results
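Illustration (not part of the diff): a sketch of how a caller might use the reworked dict-based API, matching the tests added earlier in this commit (entity name mapped to a list of property names, one formatted YAML or error string returned per property):

    # Hypothetical usage inside an async function.
    property_vals = await toolkit.retrieve_entity_property_values(
        {
            "person": ["name"],
            "session": ["$session_duration", "$channel_type"],
        }
    )
    # property_vals["person"][0] is a YAML string starting with "property: name".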
+    @database_sync_to_async(thread_sensitive=False)
+    def _get_definitions_for_entity(
+        self, entity: str, property_names: list[str], query: ActorsPropertyTaxonomyQuery
+    ) -> dict[str, PropertyDefinition]:
+        """Get property definitions for one entity and properties."""
+        if not property_names:
+            return {}
+
+        if query.groupTypeIndex is not None:
+            prop_type = PropertyDefinition.Type.GROUP
+            group_type_index = query.groupTypeIndex
+        elif entity == "event":
+            prop_type = PropertyDefinition.Type.EVENT
+            group_type_index = None
+        else:
+            prop_type = PropertyDefinition.Type.PERSON
+            group_type_index = None
+
+        property_definitions = PropertyDefinition.objects.filter(
+            team=self._team,
+            name__in=property_names,
+            type=prop_type,
+            group_type_index=group_type_index,
+        )
+        return {prop.name: prop for prop in property_definitions}
+
+    def _build_query(
+        self, entity: str, properties: list[str], groups: list[GroupTypeMapping]
+    ) -> ActorsPropertyTaxonomyQuery | None:
+        """Build a query for the given entity and property names."""
+        if entity == "person":
+            query = ActorsPropertyTaxonomyQuery(properties=properties, maxPropertyValues=25)
+        elif entity == "event":
+            query = ActorsPropertyTaxonomyQuery(properties=properties, maxPropertyValues=50)
+        else:
+            group_index = next((group.group_type_index for group in groups if group.group_type == entity), None)
             if group_index is None:
-                return TaxonomyErrorMessages.entity_not_found(entity)
-            query = ActorsPropertyTaxonomyQuery(
-                groupTypeIndex=group_index, properties=[property_name], maxPropertyValues=25
-            )
-
-        try:
-            if query.groupTypeIndex is not None:
-                prop_type = PropertyDefinition.Type.GROUP
-                group_type_index = query.groupTypeIndex
-            elif entity == "event":
-                prop_type = PropertyDefinition.Type.EVENT
-                group_type_index = None
-            else:
-                prop_type = PropertyDefinition.Type.PERSON
-                group_type_index = None
-            property_definition = PropertyDefinition.objects.get(
-                team=self._team,
-                name=property_name,
-                type=prop_type,
-                group_type_index=group_type_index,
-            )
-        except PropertyDefinition.DoesNotExist:
-            return TaxonomyErrorMessages.property_not_found(property_name, entity)
+                return None
+            query = ActorsPropertyTaxonomyQuery(groupTypeIndex=group_index, properties=properties, maxPropertyValues=25)
+        return query

+    @database_sync_to_async(thread_sensitive=False)
+    def _run_actors_taxonomy_query(
+        self, query
+    ) -> CachedActorsPropertyTaxonomyQueryResponse | CacheMissResponse | QueryStatusResponse:
         with tags_context(product=Product.MAX_AI, team_id=self._team.pk, org_id=self._team.organization_id):
-            response = ActorsPropertyTaxonomyQueryRunner(query, self._team).run(
+            return ActorsPropertyTaxonomyQueryRunner(query, self._team).run(
                 ExecutionMode.RECENT_CACHE_CALCULATE_ASYNC_IF_STALE_AND_BLOCKING_ON_MISS
             )

-        if not isinstance(response, CachedActorsPropertyTaxonomyQueryResponse):
-            return TaxonomyErrorMessages.entity_not_found(entity)
+    @database_sync_to_async(thread_sensitive=False)
+    def _get_project_actions(self) -> list[Action]:
+        return list(Action.objects.filter(team__project_id=self._team.project_id, deleted=False))

-        if not response.results:
-            return TaxonomyErrorMessages.property_values_not_found(property_name, entity)
-
-        # TRICKY. Remove when the toolkit supports multiple results.
-        if isinstance(response.results, list):
-            unpacked_results = response.results[0]
-        else:
-            unpacked_results = response.results
-
-        return self._format_property_values(
-            unpacked_results.sample_values,
-            unpacked_results.sample_count,
-            format_as_string=property_definition.property_type in (PropertyType.String, PropertyType.Datetime),
-        )

-    def retrieve_event_or_action_properties(self, event_name_or_action_id: str | int) -> str:
+    async def retrieve_event_or_action_properties(self, event_name_or_action_id: str | int) -> str:
         """
         Retrieve properties for an event.
         """
         try:
-            response, verbose_name = self._retrieve_event_or_action_taxonomy(event_name_or_action_id)
+            response, verbose_name = await self._retrieve_event_or_action_taxonomy(event_name_or_action_id)
         except Action.DoesNotExist:
-            project_actions = Action.objects.filter(team__project_id=self._team.project_id, deleted=False)
+            project_actions = await self._get_project_actions()
             if not project_actions:
                 return TaxonomyErrorMessages.no_actions_exist()
             return TaxonomyErrorMessages.action_not_found(event_name_or_action_id)
@@ -337,11 +417,12 @@ class TaxonomyAgentToolkit:
             return TaxonomyErrorMessages.generic_not_found("Properties")
         if not response.results:
             return TaxonomyErrorMessages.event_properties_not_found(verbose_name)
         # Intersect properties with their types.

         qs = PropertyDefinition.objects.filter(
             team=self._team, type=PropertyDefinition.Type.EVENT, name__in=[item.property for item in response.results]
         )
-        property_to_type = {property_definition.name: property_definition.property_type for property_definition in qs}
+        property_data = [prop async for prop in qs]
+        property_to_type = {prop.name: prop.property_type for prop in property_data}
         props = [
             (item.property, property_to_type.get(item.property))
             for item in response.results
@@ -351,48 +432,124 @@ class TaxonomyAgentToolkit:

         if not props:
             return TaxonomyErrorMessages.event_properties_not_found(verbose_name)
-        return self._format_properties(self._enrich_props_with_descriptions("event", props))
+        enriched_props = self._enrich_props_with_descriptions("event", props)
+        return self._format_properties(enriched_props)

-    def retrieve_event_or_action_property_values(self, event_name_or_action_id: str | int, property_name: str) -> str:
+    async def retrieve_event_or_action_property_values(
+        self, event_properties: dict[str | int, list[str]]
+    ) -> dict[str | int, list[str]]:
+        """Retrieve property values for an event/action. Supports single property or list of properties."""
+        result = await self._parallel_event_processing(event_properties)
+        return result
+
+    async def _parallel_event_processing(
+        self, event_properties: dict[str | int, list[str]]
+    ) -> dict[str | int, list[str]]:
+        event_tasks = [
+            self._retrieve_multiple_event_or_action_property_values(event_name_or_action_id, property_names)
+            for event_name_or_action_id, property_names in event_properties.items()
+        ]
+        results = await asyncio.gather(*event_tasks)
+        return dict(zip(event_properties.keys(), results))
+
+    @database_sync_to_async(thread_sensitive=False)
+    def _get_definitions_for_event_or_action(self, property_names: list[str]) -> dict[str, PropertyDefinition]:
+        return {
+            prop.name: prop
+            for prop in PropertyDefinition.objects.filter(
+                team=self._team,
+                name__in=property_names,
+                type=PropertyDefinition.Type.EVENT,
+            )
+        }
+
+    async def _retrieve_multiple_event_or_action_property_values(
+        self, event_name_or_action_id: str | int, property_names: list[str]
+    ) -> list[str]:
+        """Retrieve property values for multiple events/actions and properties efficiently."""
+        results = []
         try:
-            property_definition = PropertyDefinition.objects.get(
-                team=self._team, name=property_name, type=PropertyDefinition.Type.EVENT
+            definitions_map: dict[str, PropertyDefinition] = await self._get_definitions_for_event_or_action(
+                property_names
             )
         except PropertyDefinition.DoesNotExist:
-            return TaxonomyErrorMessages.property_not_found(property_name)
+            definitions_map = {}

-        response, verbose_name = self._retrieve_event_or_action_taxonomy(event_name_or_action_id)
+        response, verbose_name = await self._retrieve_event_or_action_taxonomy(event_name_or_action_id, property_names)
+
         if not isinstance(response, CachedEventTaxonomyQueryResponse):
-            return TaxonomyErrorMessages.event_not_found(verbose_name)
+            results.append(TaxonomyErrorMessages.event_not_found(verbose_name))
+            return results
         if not response.results:
-            return TaxonomyErrorMessages.property_values_not_found(property_name, verbose_name)
+            for property_name in property_names:
+                results.append(TaxonomyErrorMessages.property_values_not_found(property_name, verbose_name))
+            return results

-        prop = next((item for item in response.results if item.property == property_name), None)
-        if not prop:
-            return TaxonomyErrorMessages.property_not_found(property_name, verbose_name)
+        # Create a map of property name to taxonomy result for efficient lookup
+        taxonomy_results_map: dict[str, EventTaxonomyItem] = {item.property: item for item in response.results}

-        return self._format_property_values(
-            prop.sample_values,
-            prop.sample_count,
-            format_as_string=property_definition.property_type in (PropertyType.String, PropertyType.Datetime),
+        results.extend(
+            self._process_property_values(
+                property_names, list(taxonomy_results_map.values()), definitions_map, verbose_name, is_indexed=False
+            )
         )
def handle_tools(self, tool_name: str, tool_input: TaxonomyTool) -> tuple[str, str]:
|
||||
# Here we handle the tool execution for base taxonomy tools.
|
||||
return results
|
||||
|
||||
    def _process_property_values(
        self,
        property_names: list[str],
        property_results: list,
        property_definitions: dict[str, PropertyDefinition],
        entity_name: str,
        is_indexed: bool = False,
    ) -> list[str]:
        """Common logic for processing property values from taxonomy results."""
        results = []

        for i, property_name in enumerate(property_names):
            property_definition = property_definitions.get(property_name)

            if property_definition is None:
                results.append(TaxonomyErrorMessages.property_not_found(property_name, entity_name))
                continue

            if is_indexed:
                if i >= len(property_results):
                    results.append(TaxonomyErrorMessages.property_not_found(property_name, entity_name))
                    continue
                prop_result = property_results[i]
            else:
                prop_result = next((r for r in property_results if r.property == property_name), None)
                if prop_result is None:
                    results.append(TaxonomyErrorMessages.property_not_found(property_name, entity_name))
                    continue

            result = self._format_property_values(
                property_name,
                prop_result.sample_values,
                prop_result.sample_count,
                format_as_string=property_definition.property_type in (PropertyType.String, PropertyType.Datetime),
            )
            results.append(result)

        return results

    async def handle_tools(self, tool_name: str, tool_input: TaxonomyTool) -> tuple[str, str]:
        if tool_name == "retrieve_entity_property_values":
            result = self.retrieve_entity_property_values(
                tool_input.arguments.entity,  # type: ignore
                tool_input.arguments.property_name,  # type: ignore
            )
            entity = tool_input.arguments.entity  # type: ignore
            property_name = tool_input.arguments.property_name  # type: ignore
            result = (await self.retrieve_entity_property_values({entity: [property_name]}))[entity][0]
        elif tool_name == "retrieve_entity_properties":
            result = self.retrieve_entity_properties(tool_input.arguments.entity)  # type: ignore
            result = await self.retrieve_entity_properties(tool_input.arguments.entity)  # type: ignore
        elif tool_name == "retrieve_event_property_values":
            result = self.retrieve_event_or_action_property_values(
                tool_input.arguments.event_name,  # type: ignore
                tool_input.arguments.property_name,  # type: ignore
            )
            event_name_or_action_id = tool_input.arguments.event_name  # type: ignore
            property_name = tool_input.arguments.property_name  # type: ignore
            result = (await self.retrieve_event_or_action_property_values({event_name_or_action_id: [property_name]}))[
                event_name_or_action_id
            ][0]
        elif tool_name == "retrieve_event_properties":
            result = self.retrieve_event_or_action_properties(tool_input.arguments.event_name)  # type: ignore
            result = await self.retrieve_event_or_action_properties(tool_input.arguments.event_name)  # type: ignore
        elif tool_name == "ask_user_for_help":
            result = tool_input.arguments.request  # type: ignore
        else:

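A minimal usage sketch of the batched lookup added above, assuming a `toolkit` instance of TaxonomyAgentToolkit; the event and property names are used only for illustration:

# Inside an async function: results are keyed by event/action, with one
# formatted value summary per requested property, in request order.
values_by_event = await toolkit.retrieve_event_or_action_property_values(
    {"$pageview": ["$browser", "$current_url"]}
)
browser_values, url_values = values_by_event["$pageview"]
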
@@ -1028,7 +1028,7 @@
WHERE and(equals(s.team_id, 99999), greaterOrEquals(toTimeZone(s.min_first_timestamp, 'UTC'), toDateTime64('2020-12-11 13:46:23.000000', 6, 'UTC')), greaterOrEquals(addDays(dateTrunc('DAY', toTimeZone(s.min_first_timestamp, 'UTC')), 1), minus(toDateTime64('today', 6, 'UTC'), toIntervalDay(coalesce(s.retention_period_days, 365)))), greaterOrEquals(toTimeZone(s.min_first_timestamp, 'UTC'), toDateTime64('2020-12-29 00:00:00.000000', 6, 'UTC')), lessOrEquals(toTimeZone(s.min_first_timestamp, 'UTC'), toDateTime64('today', 6, 'UTC')), globalIn(s.session_id,
(SELECT events.`$session_id` AS session_id
FROM events
WHERE and(equals(events.team_id, 99999), notEmpty(events.`$session_id`), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-12-11 13:46:23.000000', 6, 'UTC')), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), now64(6, 'UTC')), ifNull(equals(nullIf(nullIf(events.mat_pp_email, ''), 'null'), 'bla'), 0), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-12-28 00:00:00.000000', 6, 'UTC')))
WHERE and(equals(events.team_id, 99999), notEmpty(events.`$session_id`), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-12-11 13:46:23.000000', 6, 'UTC')), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), now64(6, 'UTC')), ifNull(equals(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.person_properties, 'email'), ''), 'null'), '^"|"$', ''), 'bla'), 0), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-12-28 00:00:00.000000', 6, 'UTC')))
GROUP BY events.`$session_id`
HAVING 1
ORDER BY min(toTimeZone(events.timestamp, 'UTC')) DESC)))
@@ -118,8 +118,8 @@ class ErrorTrackingIssueImpactToolkit(TaxonomyAgentToolkit):
    def __init__(self, team: Team):
        super().__init__(team)

    def handle_tools(self, tool_name: str, tool_input: TaxonomyTool) -> tuple[str, str]:
        return super().handle_tools(tool_name, tool_input)
    async def handle_tools(self, tool_name: str, tool_input: TaxonomyTool) -> tuple[str, str]:
        return await super().handle_tools(tool_name, tool_input)

    def _get_custom_tools(self) -> list:
        return [final_answer]

@@ -9,6 +9,7 @@ from posthog.schema import RevenueAnalyticsAssistantFilters

from posthog.clickhouse.query_tagging import Product, tags_context
from posthog.models import Team, User
from posthog.sync import database_sync_to_async
from posthog.taxonomy.taxonomy import CORE_FILTER_DEFINITIONS_BY_GROUP

from products.revenue_analytics.backend.api import find_values_for_revenue_analytics_property
@@ -55,13 +56,13 @@ class RevenueAnalyticsFilterOptionsToolkit(TaxonomyAgentToolkit):
    def __init__(self, team: Team):
        super().__init__(team)

    def handle_tools(self, tool_name: str, tool_input) -> tuple[str, str]:
    async def handle_tools(self, tool_name: str, tool_input) -> tuple[str, str]:
        """Handle custom tool execution."""
        if tool_name == "retrieve_revenue_analytics_property_values":
            result = self._retrieve_revenue_analytics_property_values(tool_input.arguments.property_key)
            result = await self._retrieve_revenue_analytics_property_values(tool_input.arguments.property_key)
            return tool_name, result

        return super().handle_tools(tool_name, tool_input)
        return await super().handle_tools(tool_name, tool_input)

    def _get_custom_tools(self) -> list:
        return [final_answer, retrieve_revenue_analytics_property_values]
@@ -72,7 +73,7 @@ class RevenueAnalyticsFilterOptionsToolkit(TaxonomyAgentToolkit):
        """Returns the list of tools available in this toolkit."""
        return [*self._get_custom_tools(), ask_user_for_help]

    def _retrieve_revenue_analytics_property_values(self, property_name: str) -> str:
    async def _retrieve_revenue_analytics_property_values(self, property_name: str) -> str:
        """
        Revenue analytics properties come from Clickhouse so let's run a separate query here.
        """
@@ -80,9 +81,9 @@ class RevenueAnalyticsFilterOptionsToolkit(TaxonomyAgentToolkit):
            return TaxonomyErrorMessages.property_not_found(property_name, "revenue_analytics")

        with tags_context(product=Product.MAX_AI, team_id=self._team.pk, org_id=self._team.organization_id):
            values = find_values_for_revenue_analytics_property(property_name, self._team)
            values = await database_sync_to_async(find_values_for_revenue_analytics_property)(property_name, self._team)

        return self._format_property_values(values, sample_count=len(values))
        return self._format_property_values(property_name, values, sample_count=len(values))


class RevenueAnalyticsFilterNode(

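A condensed sketch of the sync-to-async bridging used for the ClickHouse lookup above; `find_values` here is a placeholder for any blocking helper:

from posthog.sync import database_sync_to_async

def find_values(property_name: str, team) -> list[str]:
    # Placeholder for a blocking ORM/ClickHouse lookup.
    return []

# Wrapping the blocking call lets async toolkit methods await it
# without blocking the event loop.
async def get_values(property_name: str, team) -> list[str]:
    return await database_sync_to_async(find_values)(property_name, team)
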
@@ -183,12 +183,12 @@ class SurveyToolkit(TaxonomyAgentToolkit):

        return [lookup_feature_flag, final_answer]

    def handle_tools(self, tool_name: str, tool_input) -> tuple[str, str]:
    async def handle_tools(self, tool_name: str, tool_input) -> tuple[str, str]:
        """Handle custom tool execution."""
        if tool_name == "lookup_feature_flag":
            result = self._lookup_feature_flag(tool_input.arguments.flag_key)
            return tool_name, result
        return super().handle_tools(tool_name, tool_input)
        return await super().handle_tools(tool_name, tool_input)

    def _lookup_feature_flag(self, flag_key: str) -> str:
        """Look up feature flag information by key."""

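A condensed sketch of the subclass pattern these hunks converge on, with a hypothetical custom tool name and helper; anything unhandled falls through to the base toolkit:

class ExampleToolkit(TaxonomyAgentToolkit):
    async def handle_tools(self, tool_name: str, tool_input) -> tuple[str, str]:
        if tool_name == "lookup_example":  # hypothetical custom tool
            return tool_name, self._lookup_example(tool_input.arguments.key)
        return await super().handle_tools(tool_name, tool_input)

    def _lookup_example(self, key: str) -> str:
        # Placeholder for a synchronous lookup.
        return f"No data found for {key}"
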