feat(max): allow for multiple entity search (#38504)

Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Dena Korita
2025-10-22 10:06:44 +02:00
committed by GitHub
parent 82da840971
commit 3a5b11594e
20 changed files with 666 additions and 221 deletions

View File

@@ -102,6 +102,7 @@ class MainAssistant(BaseAssistant):
AssistantNodeName.MEMORY_INITIALIZER_INTERRUPT,
AssistantNodeName.ROOT_TOOLS,
TaxonomyNodeName.TOOLS_NODE,
TaxonomyNodeName.TASK_EXECUTOR,
}
@property

View File

@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from functools import cached_property
from typing import Generic, TypeVar
@@ -17,6 +18,7 @@ from posthog.schema import MaxEventContext
from posthog.models import Team, User
from posthog.models.group_type_mapping import GroupTypeMapping
from ee.hogai.graph.taxonomy.tools import TaxonomyTool
from ee.hogai.llm import MaxChatOpenAI
from ee.hogai.utils.helpers import format_events_yaml
from ee.hogai.utils.types.composed import MaxNodeName
@@ -31,7 +33,6 @@ from .prompts import (
TAXONOMY_TOOL_USAGE_PROMPT,
)
from .toolkit import TaxonomyAgentToolkit
from .tools import TaxonomyTool
from .types import EntityType, TaxonomyAgentState, TaxonomyNodeName
TaxonomyStateType = TypeVar("TaxonomyStateType", bound=TaxonomyAgentState)
@@ -50,7 +51,7 @@ class TaxonomyAgentNode(
def __init__(self, team: Team, user: User, toolkit_class: type["TaxonomyAgentToolkit"]):
super().__init__(team, user)
self._toolkit = toolkit_class(team=team)
self._toolkit = toolkit_class(team=team, user=user)
self._state_class, self._partial_state_class = self._get_state_class(TaxonomyAgentNode)
@property
@@ -77,7 +78,7 @@ class TaxonomyAgentNode(
).bind_tools(
self._toolkit.get_tools(),
tool_choice="required",
parallel_tool_calls=False,
parallel_tool_calls=True,
)
def _get_default_system_prompts(self) -> list[str]:
@@ -131,9 +132,11 @@ class TaxonomyAgentNode(
if not output_message.tool_calls:
raise ValueError("No tool calls found in the output message.")
tool_call = output_message.tool_calls[0]
result = AgentAction(tool_call["name"], tool_call["args"], tool_call["id"])
intermediate_steps = state.intermediate_steps or []
tool_calls = output_message.tool_calls
intermediate_steps = []
for tool_call in tool_calls:
result = AgentAction(tool_call["name"], tool_call["args"], tool_call["id"])
intermediate_steps.append((result, None))
# Add the new AI message to the progress log
ai_message = LangchainAIMessage(
@@ -142,8 +145,9 @@ class TaxonomyAgentNode(
return self._partial_state_class(
tool_progress_messages=[*progress_messages, ai_message],
intermediate_steps=[*intermediate_steps, (result, None)],
intermediate_steps=intermediate_steps,
output=state.output,
iteration_count=state.iteration_count + 1 if state.iteration_count is not None else 1,
)
@@ -156,7 +160,7 @@ class TaxonomyAgentToolsNode(
def __init__(self, team: Team, user: User, toolkit_class: type["TaxonomyAgentToolkit"]):
super().__init__(team, user)
self._toolkit = toolkit_class(team=team)
self._toolkit = toolkit_class(team=team, user=user)
self._state_class, self._partial_state_class = self._get_state_class(TaxonomyAgentToolsNode)
@property
@@ -165,57 +169,70 @@ class TaxonomyAgentToolsNode(
async def arun(self, state: TaxonomyStateType, config: RunnableConfig) -> TaxonomyPartialStateType:
intermediate_steps = state.intermediate_steps or []
action, _output = intermediate_steps[-1]
tool_input: TaxonomyTool | None = None
output = ""
tool_result_msg: list[LangchainToolMessage] = []
tools_metadata: dict[str, list[tuple[TaxonomyTool, str]]] = defaultdict(list)
invalid_tools = []
steps = []
try:
tool_input = self._toolkit.get_tool_input_model(action)
except ValidationError as e:
output = str(
ChatPromptTemplate.from_template(REACT_PYDANTIC_VALIDATION_EXCEPTION_PROMPT, template_format="mustache")
.format_messages(exception=e.errors(include_url=False))[0]
.content
)
else:
if tool_input.name == "final_answer":
return self._partial_state_class(
output=tool_input.arguments.answer, # type: ignore
intermediate_steps=None,
for action, _ in intermediate_steps:
try:
tool_input = self._toolkit.get_tool_input_model(action)
except ValidationError as e:
output = str(
ChatPromptTemplate.from_template(
REACT_PYDANTIC_VALIDATION_EXCEPTION_PROMPT, template_format="mustache"
)
.format_messages(exception=e.errors(include_url=False))[0]
.content
)
steps.append((action, output))
invalid_tools.append(action.log)
continue
else:
if tool_input.name == "final_answer":
return self._partial_state_class(
output=tool_input.arguments.answer, # type: ignore
intermediate_steps=None,
)
# The agent has requested help, so we return a message to the root node
if tool_input.name == "ask_user_for_help":
return self._get_reset_state(
tool_input.arguments.request, # type: ignore
tool_input.name,
state,
)
if tool_input.name == "ask_user_for_help":
return self._get_reset_state(
tool_input.arguments.request, # type: ignore
tool_input.name,
state,
)
# For any other tool, collect metadata and prepare for result processing
tools_metadata[tool_input.name].append((tool_input, action.log))
# If we're still here, check if we've hit the iteration limit within this cycle
if len(intermediate_steps) >= self.MAX_ITERATIONS:
if state.iteration_count is not None and state.iteration_count >= self.MAX_ITERATIONS:
return self._get_reset_state(ITERATION_LIMIT_PROMPT, "max_iterations", state)
if tool_input and not output:
# Taxonomy is a separate graph, so it dispatches its own messages
reasoning_message = await self.get_reasoning_message(state)
if reasoning_message:
await self._write_message(reasoning_message)
# Use the toolkit to handle tool execution
_, output = await self._toolkit.handle_tools(tool_input.name, tool_input)
# Taxonomy is a separate graph, so it dispatches its own messages
reasoning_message = await self.get_reasoning_message(state)
if reasoning_message:
await self._write_message(reasoning_message)
if output:
tool_results = await self._toolkit.handle_tools(tools_metadata)
tool_msgs = []
for action, _ in intermediate_steps:
if action.log in invalid_tools:
continue
tool_result = tool_results[action.log]
tool_msg = LangchainToolMessage(
content=output,
content=tool_result,
tool_call_id=action.log,
)
tool_result_msg.append(tool_msg)
tool_msgs.append(tool_msg)
steps.append((action, tool_result))
old_msg = state.tool_progress_messages or []
return self._partial_state_class(
tool_progress_messages=[*old_msg, *tool_result_msg],
intermediate_steps=[*intermediate_steps[:-1], (action, output)],
tool_progress_messages=[*old_msg, *tool_msgs],
intermediate_steps=steps,
iteration_count=state.iteration_count,
)
def router(self, state: TaxonomyStateType) -> str:

View File

@@ -65,12 +65,22 @@ TAXONOMY_TOOL_USAGE_PROMPT = """
- *CRITICAL*: DO NOT CALL A TOOL FOR THE SAME ENTITY, EVENT, OR PROPERTY MORE THAN ONCE. IF YOU HAVE NOT FOUND A MATCH YOU MUST TRY WITH THE NEXT BEST MATCH.
4. **Property Value Matching**:
- IMPORTANT: If tool call returns property values that are related but not synonyms to the user's requested value: USE USER'S ORIGINAL VALUE.
- IMPORTANT: If tool call returns property values that are related BUT NOT SYNONYMS to the user's requested value: USE USER'S ORIGINAL VALUE.
For example, if the user asks for $browser to be "Chrome" and the tool call returns '"Firefox", "Safari"', use "Chrome" as the property value. Since "Chrome" is related to "Firefox" and "Safari" since they are all browsers.
- IMPORTANT: If tool call returns property values that are synonyms, typos, or a variant of the user's requested value: USE FOUND VALUES
For example the user asks for the city to be "New York" and the tool call returns "New York City", "NYC", use "New York City" as the property value. Since "New York" is related to "New York City" and "NYC" since they are all variants of New York.
5. **Optimization**:
- Remember that you are able to make parallel tool calls. This is a big performance improvement. Whenever it makes sense to do so, call multiple tools at once.
- Always aim to optimize your tool calls. This will help you find the correct properties and values faster.
6. **Filter Completion**:
- Always aim to complete the filter as much as possible. This will help you meet the user's expectations.
- If you have found most of the properties and values but you are still missing some, return the filter that you have found so far. The user can always ask you to add more properties and values later.
- Be careful though, if you have not found most of the properties and values, you should use the `ask_user_for_help` tool to ask the user for more information.
Example: If the user asks to filter for location, url type, date and browser type, and you could not find anything about the url you can return the filter you found.
- If the tool call returns no values, you can retry with the next best property or entity.
</tool_usage>
""".strip()

View File

@@ -50,15 +50,27 @@ class TestEntities(ClickhouseTestMixin, NonAtomicBaseTest):
properties={"name": "Test User"},
)
self.toolkit = DummyToolkit(self.team)
self.toolkit = DummyToolkit(self.team, self.user)
async def test_retrieve_entity_properties(self):
result = await self.toolkit.retrieve_entity_properties("person")
result = await self.toolkit.retrieve_entity_properties_parallel(["person"])
assert (
"<properties><String><prop><name>name</name></prop><prop><name>property_no_values</name></prop></String></properties>"
== result
== result["person"]
)
async def test_retrieve_entity_properties_entity_not_found(self):
result = await self.toolkit.retrieve_entity_properties_parallel(["test"])
assert "Entity test not found. Available entities: person, session, organization, project" == result["test"]
async def test_retrieve_entity_properties_entity_with_group(self):
result = await self.toolkit.retrieve_entity_properties_parallel(["organization", "session"])
assert "session" in result
assert (
"<properties><String><prop><name>name_group</name></prop></String></properties>" == result["organization"]
)
assert "<properties>" in result["session"]
async def test_person_property_values_exists(self):
result = await self.toolkit._get_entity_names()
expected = ["person", "session", "organization", "project"]

View File

@@ -64,11 +64,18 @@ class TestEvents(ClickhouseTestMixin, NonAtomicBaseTest):
},
)
_create_event(
event="no-properties-event",
distinct_id="user456",
team=self.team,
properties={},
)
Action.objects.create(
id=232, team=self.team, name="action1", description="Test Description", steps_json=[{"event": "event1"}]
)
self.toolkit = DummyToolkit(self.team)
self.toolkit = DummyToolkit(self.team, self.user)
async def test_events_property_values_exists(self):
result = await self.toolkit._get_entity_names()
@@ -121,14 +128,35 @@ class TestEvents(ClickhouseTestMixin, NonAtomicBaseTest):
assert "Firefox" in "\n".join(property_vals.get(232, []))
async def test_retrieve_event_or_action_properties_action_not_found(self):
result = await self.toolkit.retrieve_event_or_action_properties(999)
result = await self.toolkit.retrieve_event_or_action_properties_parallel([999])
assert (
"Action 999 does not exist in the taxonomy. Verify that the action ID is correct and try again." == result
"Action 999 does not exist in the taxonomy. Verify that the action ID is correct and try again."
== result["999"]
)
async def test_retrieve_event_or_action_properties_event_not_found(self):
result = await self.toolkit.retrieve_event_or_action_properties("test")
assert "Properties do not exist in the taxonomy for the event test." == result
result = await self.toolkit.retrieve_event_or_action_properties_parallel(["test"])
assert "Properties do not exist in the taxonomy for the event test." == result["test"]
async def test_retrieve_event_or_action_properties_action_mixed(self):
result = await self.toolkit.retrieve_event_or_action_properties_parallel([232, "event1"])
assert "event1" in result
assert (
"<properties><String><prop><name>id</name></prop><prop><name>$browser</name><description>Name of the browser the user has used.</description></prop></String></properties>"
== result["event1"]
)
assert "<properties>" in result["232"]
async def test_retrieve_event_or_action_properties_action_no_properties(self):
result = await self.toolkit.retrieve_event_or_action_properties_parallel([232, "no-properties-event"])
assert "no-properties-event" in result
assert (
"Properties do not exist in the taxonomy for the event no-properties-event."
== result["no-properties-event"]
)
assert "<properties>" in result["232"]
def tearDown(self):
flush_persons_and_events()

View File

@@ -74,7 +74,7 @@ class TestGroups(ClickhouseTestMixin, NonAtomicBaseTest):
sync=True, # Force sync to ClickHouse
)
self.toolkit = DummyToolkit(self.team)
self.toolkit = DummyToolkit(self.team, self.user)
async def test_entity_names_with_existing_groups(self):
# Test that the entity names include the groups we created in setUp
@@ -107,21 +107,31 @@ class TestGroups(ClickhouseTestMixin, NonAtomicBaseTest):
self.assertIn("project", property_vals)
async def test_retrieve_entity_properties_group(self):
result = await self.toolkit.retrieve_entity_properties("organization")
result = await self.toolkit.retrieve_entity_properties_parallel(["organization"])
assert (
"<properties><String><prop><name>name</name></prop><prop><name>industry</name></prop><prop><name>name_group</name></prop></String></properties>"
== result
== result["organization"]
)
async def test_retrieve_entity_properties_group_not_found(self):
result = await self.toolkit.retrieve_entity_properties("test")
result = await self.toolkit.retrieve_entity_properties_parallel(["test"])
assert (
"Entity test not found. Available entities: person, session, organization, project, no_properties" == result
"Entity test not found. Available entities: person, session, organization, project, no_properties"
== result["test"]
)
async def test_retrieve_entity_properties_group_nothing_found(self):
result = await self.toolkit.retrieve_entity_properties("no_properties")
result = await self.toolkit.retrieve_entity_properties_parallel(["no_properties"])
assert "Properties do not exist in the taxonomy for the entity no_properties." == result
assert "Properties do not exist in the taxonomy for the entity no_properties." == result["no_properties"]
async def test_retrieve_entity_properties_group_mixed(self):
result = await self.toolkit.retrieve_entity_properties_parallel(["organization", "no_properties", "project"])
assert "organization" in result
assert "<properties>" in result["organization"]
assert "Properties do not exist in the taxonomy for the entity no_properties." == result["no_properties"]
assert "project" in result
assert "<properties>" in result["project"]

View File

@@ -165,10 +165,10 @@ class TestTaxonomyAgentToolsNode(BaseTest):
mock_input.name = "test_tool"
mock_input.arguments = Mock()
mock_get_tool_input.return_value = mock_input
mock_handle_tools.return_value = ("test_tool", "tool output")
mock_handle_tools.return_value = {"test_tool_id": "tool output"}
# Create state with intermediate step
action = AgentAction(tool="test_tool", tool_input={"param": "value"}, log="test_log")
action = AgentAction(tool="test_tool", tool_input={"param": "value"}, log="test_tool_id")
state = TaxonomyAgentState()
state.intermediate_steps = [(action, None)]
@@ -177,6 +177,7 @@ class TestTaxonomyAgentToolsNode(BaseTest):
self.assertIsInstance(result, TaxonomyAgentState)
self.assertEqual(len(result.intermediate_steps), 1)
self.assertEqual(result.intermediate_steps[0][1], "tool output")
self.assertEqual(result.intermediate_steps[0][0].log, "test_tool_id")
@patch.object(MockTaxonomyAgentToolkit, "get_tool_input_model")
async def test_run_validation_error(self, mock_get_tool_input):
@@ -193,7 +194,7 @@ class TestTaxonomyAgentToolsNode(BaseTest):
result = await self.node.arun(state, RunnableConfig())
self.assertIsInstance(result, TaxonomyAgentState)
self.assertEqual(len(result.tool_progress_messages), 1)
self.assertEqual(len(result.tool_progress_messages), 0)
@patch.object(MockTaxonomyAgentToolkit, "get_tool_input_model")
async def test_run_final_answer(self, mock_get_tool_input):
@@ -251,6 +252,7 @@ class TestTaxonomyAgentToolsNode(BaseTest):
state = TaxonomyAgentState()
state.intermediate_steps = actions
state.iteration_count = self.node.MAX_ITERATIONS
with patch.object(self.node, "_get_reset_state") as mock_reset:
mock_reset.return_value = TaxonomyAgentState()
@@ -300,7 +302,7 @@ class TestTaxonomyAgentToolsNode(BaseTest):
result = self.node._get_reset_state("test output", "test_tool", original_state)
self.assertEqual(len(result.intermediate_steps), 1)
action, output = result.intermediate_steps[0]
action, output = result.intermediate_steps[0] # type: ignore
self.assertEqual(action.tool, "test_tool")
self.assertEqual(action.tool_input, "test output")
self.assertIsNone(output)

View File

@@ -6,6 +6,7 @@ from parameterized import parameterized
from pydantic import BaseModel
from ee.hogai.graph.taxonomy.toolkit import TaxonomyAgentToolkit, TaxonomyToolNotFoundError
from ee.hogai.graph.taxonomy.tools import TaxonomyTool
class DummyToolkit(TaxonomyAgentToolkit):
@@ -16,7 +17,7 @@ class DummyToolkit(TaxonomyAgentToolkit):
class TestTaxonomyAgentToolkit(BaseTest):
def setUp(self):
super().setUp()
self.toolkit = DummyToolkit(self.team)
self.toolkit = DummyToolkit(self.team, self.user)
async def test_toolkit_initialization(self):
self.assertEqual(self.toolkit._team, self.team)
@@ -69,12 +70,11 @@ class TestTaxonomyAgentToolkit(BaseTest):
("retrieve_entity_property_values", {"entity": "person", "property_name": "email"}, "mocked"),
("retrieve_event_properties", {"event_name": "test_event"}, "mocked"),
("retrieve_event_property_values", {"event_name": "test_event", "property_name": "$browser"}, "mocked"),
("ask_user_for_help", {"request": "Help needed"}, "Help needed"),
]
)
@patch.object(DummyToolkit, "retrieve_entity_properties", return_value="mocked")
@patch.object(DummyToolkit, "retrieve_entity_properties_parallel", return_value={"person": "mocked"})
@patch.object(DummyToolkit, "retrieve_entity_property_values", return_value={"person": ["mocked"]})
@patch.object(DummyToolkit, "retrieve_event_or_action_properties", return_value="mocked")
@patch.object(DummyToolkit, "retrieve_event_or_action_properties_parallel", return_value={"test_event": "mocked"})
@patch.object(DummyToolkit, "retrieve_event_or_action_property_values", return_value={"test_event": ["mocked"]})
async def test_handle_tools(self, tool_name, tool_args, expected_result, *mocks):
class Arguments(BaseModel):
@@ -83,25 +83,25 @@ class TestTaxonomyAgentToolkit(BaseTest):
for key, value in tool_args.items():
setattr(Arguments, key, value)
class ToolInput(BaseModel):
class ToolInput(TaxonomyTool):
name: str
arguments: Arguments
tool_input = ToolInput(name=tool_name, arguments=Arguments(**tool_args))
tool_name_result, result = await self.toolkit.handle_tools(tool_name, tool_input)
result = await self.toolkit.handle_tools({tool_name: [(tool_input, "test_call_id")]})
self.assertEqual(result, expected_result)
self.assertEqual(tool_name_result, tool_name)
self.assertEqual(len(result), 1)
self.assertEqual(result["test_call_id"], expected_result)
async def test_handle_tools_invalid_tool(self):
class ToolInput(BaseModel):
class ToolInput(TaxonomyTool):
name: str = "invalid_tool"
arguments: dict = {}
tool_input = ToolInput()
with self.assertRaises(TaxonomyToolNotFoundError):
await self.toolkit.handle_tools("invalid_tool", tool_input)
await self.toolkit.handle_tools({"invalid_tool": [(tool_input, "invalid_tool")]})
def test_format_properties_formats(self):
props = [("prop1", "String", "Test description"), ("prop2", "Numeric", None)]
@@ -125,7 +125,12 @@ class TestTaxonomyAgentToolkit(BaseTest):
@parameterized.expand(
[
("retrieve_entity_properties", {"entity": "person"}, "retrieve_entity_properties", {"entity": "person"}),
(
"retrieve_entity_properties_parallel",
{"entity": "person"},
"retrieve_entity_properties_parallel",
{"entity": "person"},
),
(
"retrieve_event_properties",
{"event_name": "test_event"},
@@ -150,7 +155,12 @@ class TestTaxonomyAgentToolkit(BaseTest):
"retrieve_event_property_values",
{"event_name": "test_event", "property_name": "$browser"},
),
("retrieve_entity_properties", {"entity": "session"}, "retrieve_entity_properties", {"entity": "session"}),
(
"retrieve_entity_properties_parallel",
{"entity": "session"},
"retrieve_entity_properties_parallel",
{"entity": "session"},
),
]
)
def test_get_tool_input_model_with_valid_tools(self, tool_name, tool_input, expected_name, expected_args):
@@ -177,7 +187,7 @@ class TestTaxonomyAgentToolkit(BaseTest):
return [CustomTool]
custom_toolkit = CustomToolkit(self.team)
custom_toolkit = CustomToolkit(self.team, self.user)
action = AgentAction(tool="custom_tool", tool_input={"custom_field": "test_value"}, log="test log")
@@ -195,7 +205,7 @@ class TestTaxonomyAgentToolkit(BaseTest):
def _get_custom_tools(self):
raise NotImplementedError("This is a test error")
basic_toolkit = BasicToolkit(self.team)
basic_toolkit = BasicToolkit(self.team, self.user)
# Should not raise NotImplementedError, should fall back to default tools
tools = basic_toolkit.get_tools()
@@ -230,7 +240,7 @@ class TestTaxonomyAgentToolkit(BaseTest):
return [custom_tool_1, custom_tool_2]
custom_toolkit = CustomToolkit(self.team)
custom_toolkit = CustomToolkit(self.team, self.user)
# Should return both default and custom tools
tools = custom_toolkit.get_tools()

View File

@@ -1,9 +1,11 @@
import json
import asyncio
from collections.abc import Iterable
from functools import cached_property
from typing import Any, Optional, Union, cast
from typing import Optional, Union, cast
from langchain_core.agents import AgentAction
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel
from posthog.schema import (
@@ -14,6 +16,8 @@ from posthog.schema import (
EventTaxonomyItem,
EventTaxonomyQuery,
QueryStatusResponse,
TaskExecutionItem,
TaskExecutionStatus,
)
from posthog.hogql.database.schema.channel_type import DEFAULT_CHANNEL_TYPES
@@ -22,7 +26,7 @@ from posthog.clickhouse.query_tagging import Product, tags_context
from posthog.hogql_queries.ai.actors_property_taxonomy_query_runner import ActorsPropertyTaxonomyQueryRunner
from posthog.hogql_queries.ai.event_taxonomy_query_runner import EventTaxonomyQueryRunner
from posthog.hogql_queries.query_runner import ExecutionMode
from posthog.models import Action, Team
from posthog.models import Action, Team, User
from posthog.models.group_type_mapping import GroupTypeMapping
from posthog.models.property_definition import PropertyDefinition, PropertyType
from posthog.sync import database_sync_to_async
@@ -34,15 +38,18 @@ from ee.hogai.graph.taxonomy.format import (
format_properties_yaml,
format_property_values,
)
from ee.hogai.utils.types.base import BaseStateWithTasks, TaskArtifact, TaskResult
from ee.hogai.utils.types.composed import MaxNodeName
from ..parallel_task_execution.nodes import BaseTaskExecutorNode, TaskExecutionInputTuple
from .tools import (
TaxonomyTool,
ask_user_for_help,
retrieve_entity_properties,
retrieve_entity_property_values,
get_dynamic_entity_tools,
retrieve_event_properties,
retrieve_event_property_values,
)
from .types import TaxonomyNodeName
class TaxonomyToolNotFoundError(Exception):
@@ -104,12 +111,45 @@ class TaxonomyErrorMessages:
return f"Properties do not exist in the taxonomy for the {event_name}."
class TaxonomyTaskExecutorNode(
    BaseTaskExecutorNode[
        BaseStateWithTasks,
        BaseStateWithTasks,
    ]
):
    """
    Task executor node dedicated to taxonomy lookups.

    Pairs every task in the incoming state with the TaxonomyAgentToolkit handler
    that matches the task's ``task_type``.
    """

    @property
    def node_name(self) -> MaxNodeName:
        return TaxonomyNodeName.TASK_EXECUTOR

    async def _aget_input_tuples(self, state: BaseStateWithTasks) -> list[TaskExecutionInputTuple]:
        """
        Build one ``(task, [], handler)`` tuple per task in ``state.tasks``.

        Raises:
            ValueError: if the state carries no tasks, or a task has a ``task_type``
                this node does not know how to handle.
        """
        toolkit = TaxonomyAgentToolkit(self._team, self._user)
        if not state.tasks:
            raise ValueError("No tasks to execute")

        # task_type -> toolkit coroutine that executes it
        handlers_by_type = {
            "retrieve_event_or_action_properties": toolkit._handle_event_or_action_properties_task,
            "retrieve_entity_properties": toolkit._handle_entity_properties_task,
            "retrieve_group_properties": toolkit._handle_group_properties_task,
        }

        prepared: list[TaskExecutionInputTuple] = []
        for pending in state.tasks:
            handler = handlers_by_type.get(pending.task_type)
            if handler is None:
                raise ValueError(f"Unsupported task type: {pending.task_type}")
            prepared.append((pending, [], handler))
        return prepared
class TaxonomyAgentToolkit:
"""Base toolkit for taxonomy agents that handle tool execution."""
def __init__(self, team: Team):
def __init__(self, team: Team, user: User):
self._team = team
self._user = user
self.MAX_ENTITIES_PER_BATCH = 6
self.MAX_PROPERTIES = 500
@property
def _groups(self):
@@ -139,6 +179,10 @@ class TaxonomyAgentToolkit:
def _get_entity_names(self) -> list[str]:
return self._entity_names
@database_sync_to_async(thread_sensitive=False)
def _get_team_group_types(self) -> list[str]:
return self._team_group_types
def _enrich_props_with_descriptions(self, entity: str, props: Iterable[tuple[str, str | None]]):
return enrich_props_with_descriptions(entity, props)
@@ -181,13 +225,18 @@ class TaxonomyAgentToolkit:
Retrieve event/action taxonomy with efficient caching.
Multiple properties are batched in a single query to maximize cache hits.
"""
is_event = isinstance(event_name_or_action_id, str)
try:
action_id = int(event_name_or_action_id)
is_event = False
except ValueError:
is_event = True
if is_event:
query = EventTaxonomyQuery(event=event_name_or_action_id, maxPropertyValues=25, properties=properties)
verbose_name = f"event {event_name_or_action_id}"
else:
query = EventTaxonomyQuery(actionId=event_name_or_action_id, maxPropertyValues=25, properties=properties)
verbose_name = f"action with ID {event_name_or_action_id}"
query = EventTaxonomyQuery(actionId=action_id, maxPropertyValues=25, properties=properties)
verbose_name = f"action with ID {action_id}"
runner = EventTaxonomyQueryRunner(query, self._team)
with tags_context(product=Product.MAX_AI, team_id=self._team.pk, org_id=self._team.organization_id):
# Use cache-first execution mode for optimal performance
@@ -226,11 +275,16 @@ class TaxonomyAgentToolkit:
def _get_default_tools(self) -> list:
"""Get default taxonomy tools."""
dynamic_retrieve_entity_properties, dynamic_retrieve_entity_property_values = get_dynamic_entity_tools(
self._team_group_types
)
return [
retrieve_event_properties,
retrieve_entity_properties,
retrieve_entity_property_values,
# retrieve_entity_properties,
# retrieve_entity_property_values,
dynamic_retrieve_entity_properties,
retrieve_event_property_values,
dynamic_retrieve_entity_property_values,
ask_user_for_help,
]
@@ -238,49 +292,6 @@ class TaxonomyAgentToolkit:
"""Get custom tools. Override in subclasses to add custom tools."""
raise NotImplementedError("_get_custom_tools must be implemented in subclasses")
async def retrieve_entity_properties(self, entity: str, max_properties: int = 500) -> str:
"""
Retrieve properties for an entitiy like person, session, or one of the groups.
"""
entity_names = await self._get_entity_names()
if entity not in entity_names:
return TaxonomyErrorMessages.entity_not_found(entity, entity_names)
props: list[Any] = []
if entity == "person":
qs = PropertyDefinition.objects.filter(team=self._team, type=PropertyDefinition.Type.PERSON).values_list(
"name", "property_type"
)
props = self._enrich_props_with_descriptions("person", [prop async for prop in qs])
elif entity == "session":
# Session properties are not in the DB.
props = self._enrich_props_with_descriptions(
"session",
[
(prop_name, prop["type"])
for prop_name, prop in CORE_FILTER_DEFINITIONS_BY_GROUP["session_properties"].items()
if prop.get("type") is not None
],
)
else:
group_type_index = None
groups = [group async for group in self._groups]
for group in groups:
if group.group_type == entity:
group_type_index = group.group_type_index
break
if group_type_index is None:
return f"Group {entity} does not exist in the taxonomy."
qs = PropertyDefinition.objects.filter(
team=self._team, type=PropertyDefinition.Type.GROUP, group_type_index=group_type_index
).values_list("name", "property_type")[:max_properties]
props = self._enrich_props_with_descriptions(entity, [prop async for prop in qs])
if not props:
return f"Properties do not exist in the taxonomy for the entity {entity}."
return self._format_properties(props)
async def retrieve_entity_property_values(self, entity_properties: dict[str, list[str]]) -> dict[str, list[str]]:
result = await self._parallel_entity_processing(entity_properties)
return result
@@ -348,6 +359,251 @@ class TaxonomyAgentToolkit:
)
return results
async def retrieve_entity_properties_parallel(self, entities: list[str]) -> dict[str, str]:
    """
    Retrieve properties for several entities in one task-executor run.

    Entities whose name matches one of the team's group types are bundled into a
    single ``"group_properties"`` task; every other entity gets its own
    ``"retrieve_entity_properties"`` task. The tasks are executed through
    TaxonomyTaskExecutorNode.

    Args:
        entities: Entity names to look up.

    Returns:
        A dict keyed by entity name. For non-group entities the value is the task's
        ``result``; for group entities it is the ``content`` of the artifact whose
        ``id`` is that entity. An empty ``entities`` list yields an empty dict.
    """
    # Nothing to look up: avoid building an executor run with zero tasks.
    if not entities:
        return {}

    # Loop-invariant: read the team's group types once rather than once per entity.
    group_types = {group.group_type async for group in self._groups}

    entity_tasks = []
    groups = []
    for entity in entities:
        if entity in group_types:
            groups.append(entity)
            continue
        entity_tasks.append(
            TaskExecutionItem(
                id=str(entity),
                prompt=entity,
                status=TaskExecutionStatus.PENDING,
                description="Retrieving entity properties",
                progress_text=f"Retrieving properties for {entity}...",
                task_type="retrieve_entity_properties",
            )
        )

    if groups:
        # All groups share one task; the handler reports each group as its own artifact.
        entity_tasks.append(
            TaskExecutionItem(
                id="group_properties",
                prompt=str(groups),
                status=TaskExecutionStatus.PENDING,
                description="Retrieving group properties",
                progress_text="Retrieving properties for groups...",
                task_type="retrieve_group_properties",
            )
        )

    task_executor_state = BaseStateWithTasks(
        tasks=entity_tasks,
    )
    config = RunnableConfig()
    executor = TaxonomyTaskExecutorNode(self._team, self._user)
    result = await executor.arun(task_executor_state, config)

    final_result: dict[str, str] = {}
    for task_result in result.task_results:
        # NOTE(review): a non-group entity literally named "group_properties" would be
        # mistaken for the bundled group task here — confirm that cannot happen.
        if task_result.id != "group_properties":
            final_result[task_result.id] = task_result.result
            continue
        for artifact in task_result.artifacts:
            if artifact.id is None:
                continue
            # str() is a no-op for string ids and normalises integer ids.
            final_result[str(artifact.id)] = artifact.content
    return final_result
async def _handle_event_or_action_properties_task(self, input_dict: dict) -> TaskResult:
    """
    Retrieve the properties of a single event or action for a task-executor run.

    Args:
        input_dict: Executor payload; ``input_dict["task"]`` is the TaskExecutionItem
            whose ``prompt`` is handed to ``_retrieve_event_or_action_taxonomy``.

    Returns:
        A COMPLETED TaskResult whose ``result`` is the formatted property list, or a
        FAILED TaskResult whose ``result`` is the matching TaxonomyErrorMessages text.
    """
    task = cast(TaskExecutionItem, input_dict["task"])

    def finish(text, status):
        # Every exit path differs only in its message and status.
        return TaskResult(
            id=task.id,
            description=task.description,
            result=text,
            artifacts=[],
            status=status,
        )

    try:
        response, verbose_name = await self._retrieve_event_or_action_taxonomy(task.prompt)
    except Action.DoesNotExist:
        # Tell "the project has no actions at all" apart from "this action is unknown".
        if await self._get_project_actions():
            message = TaxonomyErrorMessages.action_not_found(task.prompt)
        else:
            message = TaxonomyErrorMessages.no_actions_exist()
        return finish(message, TaskExecutionStatus.FAILED)

    if not isinstance(response, CachedEventTaxonomyQueryResponse):
        return finish(TaxonomyErrorMessages.generic_not_found("Properties"), TaskExecutionStatus.FAILED)

    if not response.results:
        return finish(TaxonomyErrorMessages.event_properties_not_found(verbose_name), TaskExecutionStatus.FAILED)

    taxonomy_names = [item.property for item in response.results]
    definitions = PropertyDefinition.objects.filter(
        team=self._team, type=PropertyDefinition.Type.EVENT, name__in=taxonomy_names
    )
    type_by_name = {}
    async for definition in definitions:
        type_by_name[definition.name] = definition.property_type

    # Keep only taxonomy properties that have a matching event PropertyDefinition row.
    typed_props = [(name, type_by_name[name]) for name in taxonomy_names if name in type_by_name]
    if not typed_props:
        return finish(TaxonomyErrorMessages.event_properties_not_found(verbose_name), TaskExecutionStatus.FAILED)

    enriched = self._enrich_props_with_descriptions("event", typed_props)
    return finish(self._format_properties(enriched), TaskExecutionStatus.COMPLETED)
async def _handle_group_properties_task(self, input_dict: dict) -> TaskResult:
    """
    Retrieve property definitions for one or more group types in a single task.

    ``task.prompt`` holds either a single group type name or a stringified list of
    names (e.g. ``"['organization', 'project']"``). Each requested name produces one
    ``TaskArtifact`` whose ``id`` is the name and whose ``content`` is either the
    formatted property list or a "properties not found" message.

    Args:
        input_dict: Executor payload; only the ``"task"`` entry is read.

    Returns:
        A ``TaskResult`` with an empty ``result`` string and status ``COMPLETED``;
        the per-group output lives in ``artifacts``.
    """
    task = cast(TaskExecutionItem, input_dict["task"])
    try:
        # The prompt is a Python-style list literal; swap the quotes so it parses as JSON.
        # A name that itself contains an apostrophe will not parse and falls back below.
        parsed = json.loads(task.prompt.replace("'", '"'))
    except json.JSONDecodeError:
        parsed = task.prompt

    # json.loads can legitimately return a non-list (a quoted string, a number, ...).
    # Iterating those would either walk a string character by character or raise,
    # so normalise everything to a list of entity names first.
    if isinstance(parsed, list):
        group_entities = parsed
    elif isinstance(parsed, str):
        group_entities = [parsed]
    else:
        group_entities = [task.prompt]

    # The group type -> index mapping does not depend on the entity being resolved,
    # so build it once instead of re-iterating self._groups for every entity.
    group_types = {group.group_type: group.group_type_index async for group in self._groups}

    entity_to_group_index = {}
    artifacts = []
    for group_entity in group_entities:
        group_type_index = group_types.get(group_entity)
        if group_type_index is None:
            # Unknown group type: report it as an artifact rather than failing the whole task.
            artifacts.append(
                TaskArtifact(
                    id=group_entity,
                    task_id=task.id,
                    content=TaxonomyErrorMessages.properties_not_found(group_entity),
                )
            )
        else:
            entity_to_group_index[group_entity] = group_type_index

    if entity_to_group_index:
        # Single query for all group types. Note the MAX_PROPERTIES cap applies to the
        # combined result set, not to each group type individually.
        group_qs = PropertyDefinition.objects.filter(
            team=self._team,
            type=PropertyDefinition.Type.GROUP,
            group_type_index__in=entity_to_group_index.values(),
        ).values_list("name", "property_type", "group_type_index")[: self.MAX_PROPERTIES]
        group_qs_definitions = [prop async for prop in group_qs]

        # Distribute the rows back to the entities that asked for them, in request order.
        for entity in group_entities:
            if entity not in entity_to_group_index:
                continue
            group_index = entity_to_group_index[entity]
            properties = [(name, prop_type) for name, prop_type, gti in group_qs_definitions if gti == group_index]
            result = (
                self._format_properties(self._enrich_props_with_descriptions(entity, properties))
                if properties
                else TaxonomyErrorMessages.properties_not_found(entity)
            )
            # NOTE(review): these artifacts use a fixed task_id while the not-found
            # artifacts above use task.id — confirm the difference is intended.
            artifacts.append(
                TaskArtifact(
                    id=entity,
                    task_id="group_properties",
                    content=result,
                )
            )

    return TaskResult(
        id=task.id,
        description=task.description,
        result="",
        artifacts=artifacts,
        status=TaskExecutionStatus.COMPLETED,
    )
async def _handle_entity_properties_task(self, input_dict: dict) -> TaskResult:
    """
    Retrieve the property list for a built-in entity named by ``task.prompt``.

    ``"person"`` properties are read from the team's person property definitions;
    ``"session"`` properties come from the ``session_properties`` group of
    ``CORE_FILTER_DEFINITIONS_BY_GROUP`` (entries without a type are ignored).
    Any other entity name yields a FAILED result carrying an "entity not found" message.

    Args:
        input_dict: Executor payload; only the ``"task"`` entry is read.

    Returns:
        A ``TaskResult`` without artifacts. It is COMPLETED with the formatted
        properties when any were found, otherwise FAILED with an error message.
    """
    task = cast(TaskExecutionItem, input_dict["task"])
    entity = task.prompt

    # Guard clause: anything other than the two built-in entities is rejected up front.
    if entity != "person" and entity != "session":
        return TaskResult(
            id=task.id,
            description=task.description,
            result=TaxonomyErrorMessages.entity_not_found(entity, await self._get_entity_names()),
            artifacts=[],
            status=TaskExecutionStatus.FAILED,
        )

    if entity == "person":
        label = "person"
        queryset = PropertyDefinition.objects.filter(
            team=self._team, type=PropertyDefinition.Type.PERSON
        ).values_list("name", "property_type")
        definitions = [row async for row in queryset]
    else:
        label = "session"
        session_definitions = CORE_FILTER_DEFINITIONS_BY_GROUP["session_properties"]
        definitions = [
            (name, meta["type"]) for name, meta in session_definitions.items() if meta.get("type") is not None
        ]

    if definitions:
        outcome = self._format_properties(self._enrich_props_with_descriptions(label, definitions))
        outcome_status = TaskExecutionStatus.COMPLETED
    else:
        outcome = TaxonomyErrorMessages.properties_not_found(entity)
        outcome_status = TaskExecutionStatus.FAILED

    return TaskResult(
        id=task.id,
        description=task.description,
        result=outcome,
        artifacts=[],
        status=outcome_status,
    )
@database_sync_to_async(thread_sensitive=False)
def _get_definitions_for_entity(
self, entity: str, property_names: list[str], query: ActorsPropertyTaxonomyQuery
@@ -402,47 +658,35 @@ class TaxonomyAgentToolkit:
def _get_project_actions(self) -> list[Action]:
    """Return the non-deleted actions of the team's project as a list."""
    active_actions = Action.objects.filter(team__project_id=self._team.project_id, deleted=False)
    return list(active_actions)
async def retrieve_event_or_action_properties(self, event_name_or_action_id: str | int) -> str:
"""
Retrieve properties for an event.
"""
try:
response, verbose_name = await self._retrieve_event_or_action_taxonomy(event_name_or_action_id)
except Action.DoesNotExist:
project_actions = await self._get_project_actions()
if not project_actions:
return TaxonomyErrorMessages.no_actions_exist()
return TaxonomyErrorMessages.action_not_found(event_name_or_action_id)
if not isinstance(response, CachedEventTaxonomyQueryResponse):
return TaxonomyErrorMessages.generic_not_found("Properties")
if not response.results:
return TaxonomyErrorMessages.event_properties_not_found(verbose_name)
qs = PropertyDefinition.objects.filter(
team=self._team, type=PropertyDefinition.Type.EVENT, name__in=[item.property for item in response.results]
async def retrieve_event_or_action_properties_parallel(
self, event_name_or_action_ids: list[str | int]
) -> dict[str, str]:
task_executor_state = BaseStateWithTasks(
tasks=[
TaskExecutionItem(
id=str(event_name_or_action_id),
prompt=str(event_name_or_action_id),
status=TaskExecutionStatus.PENDING,
description="Retrieving event or action properties",
progress_text=f"Retrieving properties for {event_name_or_action_id}...",
task_type="retrieve_event_or_action_properties",
)
for event_name_or_action_id in event_name_or_action_ids
],
)
property_data = [prop async for prop in qs]
property_to_type = {prop.name: prop.property_type for prop in property_data}
props = [
(item.property, property_to_type.get(item.property))
for item in response.results
# Exclude properties that exist in the taxonomy, but don't have a type.
if item.property in property_to_type
]
if not props:
return TaxonomyErrorMessages.event_properties_not_found(verbose_name)
enriched_props = self._enrich_props_with_descriptions("event", props)
return self._format_properties(enriched_props)
config = RunnableConfig()
executor = TaxonomyTaskExecutorNode(self._team, self._user)
result = await executor.arun(task_executor_state, config)
return {task.id: task.result for task in result.task_results}
async def retrieve_event_or_action_property_values(
self, event_properties: dict[str | int, list[str]]
) -> dict[str | int, list[str]]:
"""Retrieve property values for an event/action. Supports single property or list of properties."""
result = await self._parallel_event_processing(event_properties)
result = await self._parallel_event_or_action_processing(event_properties)
return result
async def _parallel_event_processing(
async def _parallel_event_or_action_processing(
self, event_properties: dict[str | int, list[str]]
) -> dict[str | int, list[str]]:
event_tasks = [
@@ -535,27 +779,102 @@ class TaxonomyAgentToolkit:
return results
async def handle_tools(self, tool_name: str, tool_input: TaxonomyTool) -> tuple[str, str]:
if tool_name == "retrieve_entity_property_values":
entity = tool_input.arguments.entity # type: ignore
property_name = tool_input.arguments.property_name # type: ignore
result = (await self.retrieve_entity_property_values({entity: [property_name]}))[entity][0]
elif tool_name == "retrieve_entity_properties":
result = await self.retrieve_entity_properties(tool_input.arguments.entity) # type: ignore
elif tool_name == "retrieve_event_property_values":
event_name_or_action_id = tool_input.arguments.event_name # type: ignore
property_name = tool_input.arguments.property_name # type: ignore
result = (await self.retrieve_event_or_action_property_values({event_name_or_action_id: [property_name]}))[
event_name_or_action_id
][0]
elif tool_name == "retrieve_event_properties":
result = await self.retrieve_event_or_action_properties(tool_input.arguments.event_name) # type: ignore
elif tool_name == "ask_user_for_help":
result = tool_input.arguments.request # type: ignore
else:
raise TaxonomyToolNotFoundError(f"Tool {tool_name} not found in taxonomy toolkit.")
def _collect_tools(self, tool_metadata: dict[str, list[tuple[TaxonomyTool, str]]]) -> dict:
"""
Collect and group tool calls by type for batch processing.
Returns grouped data and mappings for result distribution.
"""
result: dict = {
"entity_property_values": {}, # entity -> [property_names]
"entity_properties": [], # [entities]
"event_property_values": {}, # event_name -> [property_names]
"event_properties": [], # [event_names]
"entity_prop_mapping": {}, # (entity, property) -> tool_call_id
"entity_mapping": {}, # entity -> tool_call_id
"event_prop_mapping": {}, # (event, property) -> tool_call_id
"event_mapping": {}, # event -> tool_call_id
}
return tool_name, result
# tool_metadata maps a tool name to the (parsed tool input, tool_call_id) pairs issued for that tool.
for tool_name, tool_inputs in tool_metadata.items():
for tool_input, tool_call_id in tool_inputs:
if tool_name == "retrieve_entity_property_values":
entity = tool_input.arguments.entity # type: ignore
property_name = tool_input.arguments.property_name # type: ignore
if entity not in result["entity_property_values"]:
result["entity_property_values"][entity] = []
result["entity_property_values"][entity].append(property_name)
# NOTE(review): keyed by (entity, property_name) — an identical repeated call overwrites the earlier tool_call_id.
result["entity_prop_mapping"][(entity, property_name)] = tool_call_id
elif tool_name == "retrieve_entity_properties":
entity = tool_input.arguments.entity # type: ignore
result["entity_properties"].append(entity)
# NOTE(review): keyed by entity only — a second call for the same entity overwrites the earlier tool_call_id.
result["entity_mapping"][entity] = tool_call_id
elif tool_name == "retrieve_event_property_values":
event_name = tool_input.arguments.event_name # type: ignore
property_name = tool_input.arguments.property_name # type: ignore
if event_name not in result["event_property_values"]:
result["event_property_values"][event_name] = []
result["event_property_values"][event_name].append(property_name)
# NOTE(review): same overwrite behaviour as entity_prop_mapping for duplicate (event, property) calls.
result["event_prop_mapping"][(event_name, property_name)] = tool_call_id
elif tool_name == "retrieve_event_properties":
event_name = tool_input.arguments.event_name # type: ignore
result["event_properties"].append(event_name)
# NOTE(review): same overwrite behaviour as entity_mapping for duplicate event calls.
result["event_mapping"][event_name] = tool_call_id
else:
# Only the four retrieval tools above are grouped here; any other tool name is rejected.
raise TaxonomyToolNotFoundError(f"Tool {tool_name} not found in taxonomy toolkit.")
return result
async def _execute_tools(self, collected_tools: dict) -> dict[str, str]:
"""
Execute batch operations and distribute results.
Returns a dict mapping tool_call_id to result for each individual tool call.
"""
results = {}
# Execute batch operations and distribute results in single passes
if collected_tools["entity_property_values"]:
entity_property_values = await self.retrieve_entity_property_values(
collected_tools["entity_property_values"]
)
for entity, property_results in entity_property_values.items():
for i, property_name in enumerate(collected_tools["entity_property_values"][entity]):
results[collected_tools["entity_prop_mapping"][(entity, property_name)]] = property_results[i]
if collected_tools["entity_properties"]:
entity_properties = await self.retrieve_entity_properties_parallel(collected_tools["entity_properties"])
for entity, result in entity_properties.items():
results[collected_tools["entity_mapping"][entity]] = result
if collected_tools["event_property_values"]:
event_property_values = await self.retrieve_event_or_action_property_values(
collected_tools["event_property_values"]
)
for event_name, property_results in event_property_values.items():
for i, property_name in enumerate(collected_tools["event_property_values"][event_name]):
results[collected_tools["event_prop_mapping"][(event_name, property_name)]] = property_results[i]
if collected_tools["event_properties"]:
event_properties = await self.retrieve_event_or_action_properties_parallel(
collected_tools["event_properties"]
)
for event_name, result in event_properties.items():
results[collected_tools["event_mapping"][event_name]] = result
return results
async def handle_tools(self, tool_metadata: dict[str, list[tuple[TaxonomyTool, str]]]) -> dict[str, str]:
    """
    Process a batch of tool calls, grouping similar operations before running them.

    The calls are first grouped by ``_collect_tools`` and the grouped requests are
    then run by ``_execute_tools``.

    Returns a dict mapping tool_call_id to result for each individual tool call.
    """
    return await self._execute_tools(self._collect_tools(tool_metadata))
def get_tool_input_model(self, action: AgentAction) -> TaxonomyTool:
try:

View File

@@ -39,9 +39,7 @@ class retrieve_entity_properties(BaseModel):
- **Avoid using ambiguous properties** unless their relevance is explicitly confirmed.
"""
entity: Literal["person", "session"] = Field(
..., description="The type of the entity that you want to retrieve properties for."
)
entity: str = Field(..., description="The type of the entity that you want to retrieve properties for.")
class retrieve_event_property_values(BaseModel):
@@ -67,9 +65,7 @@ class retrieve_entity_property_values(BaseModel):
Use this tool to retrieve property values for a property name. Adjust filters to these values. You will receive a list of property values or a message that property values have not been found. Some properties can have many values, so the output will be truncated. Use your judgment to find a proper value.
"""
entity: Literal["person", "session"] = Field(
..., description="The type of the entity that you want to retrieve properties for."
)
entity: str = Field(..., description="The type of the entity that you want to retrieve properties for.")
property_name: str = Field(..., description="The name of the property that you want to retrieve values for.")

View File

@@ -36,6 +36,11 @@ class TaxonomyAgentState(BaseStateWithIntermediateSteps, BaseStateWithMessages,
The messages with tool calls to collect tool progress.
"""
iteration_count: int | None = Field(default=None)
"""
The number of iterations the taxonomy agent has gone through.
"""
class TaxonomyNodeName(StrEnum):
"""Generic node names for taxonomy agents."""
@@ -44,6 +49,7 @@ class TaxonomyNodeName(StrEnum):
TOOLS_NODE = "taxonomy_tools_node"
START = START
END = END
TASK_EXECUTOR = "taxonomy_task_executor"
class EntityType(str, Enum):

View File

@@ -17472,6 +17472,9 @@
},
{
"$ref": "#/definitions/RecordingPropertyFilter"
},
{
"$ref": "#/definitions/GroupPropertyFilter"
}
]
},

View File

@@ -1,6 +1,7 @@
import {
EventPropertyFilter,
FilterLogicalOperator,
GroupPropertyFilter,
PersonPropertyFilter,
RecordingDurationFilter,
RecordingPropertyFilter,
@@ -35,3 +36,4 @@ export type MaxUniversalFilterValue =
| PersonPropertyFilter
| SessionPropertyFilter
| RecordingPropertyFilter
| GroupPropertyFilter

View File

@@ -372,8 +372,6 @@ ee/hogai/graph/taxonomy/test/test_nodes.py:0: error: Value of type "list[tuple[A
ee/hogai/graph/taxonomy/test/test_nodes.py:0: error: Value of type "list[tuple[AgentAction, str | None]] | None" is not indexable [index]
ee/hogai/graph/taxonomy/test/test_toolkit.py:0: error: Argument 1 to "_format_properties_xml" of "TaxonomyAgentToolkit" has incompatible type "list[tuple[str, str, str | None]]"; expected "list[tuple[str, str | None, str | None]]" [arg-type]
ee/hogai/graph/taxonomy/test/test_toolkit.py:0: error: Argument 1 to "_format_properties_yaml" of "TaxonomyAgentToolkit" has incompatible type "list[tuple[str, str, str | None]]"; expected "list[tuple[str, str | None, str | None]]" [arg-type]
ee/hogai/graph/taxonomy/test/test_toolkit.py:0: error: Argument 2 to "handle_tools" of "TaxonomyAgentToolkit" has incompatible type "ToolInput"; expected "TaxonomyTool[Any]" [arg-type]
ee/hogai/graph/taxonomy/test/test_toolkit.py:0: error: Argument 2 to "handle_tools" of "TaxonomyAgentToolkit" has incompatible type "ToolInput"; expected "TaxonomyTool[Any]" [arg-type]
ee/hogai/graph/taxonomy/test/test_toolkit.py:0: error: Item "ask_user_for_help" of "retrieve_event_properties | retrieve_entity_properties | retrieve_entity_property_values | retrieve_event_property_values | ask_user_for_help | Any" has no attribute "custom_field" [union-attr]
ee/hogai/graph/taxonomy/test/test_toolkit.py:0: error: Item "retrieve_entity_properties" of "retrieve_event_properties | retrieve_entity_properties | retrieve_entity_property_values | retrieve_event_property_values | ask_user_for_help | Any" has no attribute "custom_field" [union-attr]
ee/hogai/graph/taxonomy/test/test_toolkit.py:0: error: Item "retrieve_entity_property_values" of "retrieve_event_properties | retrieve_entity_properties | retrieve_entity_property_values | retrieve_event_property_values | ask_user_for_help | Any" has no attribute "custom_field" [union-attr]

View File

@@ -12606,7 +12606,15 @@ class MaxInnerUniversalFiltersGroup(BaseModel):
extra="forbid",
)
type: FilterLogicalOperator
values: list[Union[EventPropertyFilter, PersonPropertyFilter, SessionPropertyFilter, RecordingPropertyFilter]]
values: list[
Union[
EventPropertyFilter,
PersonPropertyFilter,
SessionPropertyFilter,
RecordingPropertyFilter,
GroupPropertyFilter,
]
]
class MaxOuterUniversalFiltersGroup(BaseModel):

View File

@@ -41,8 +41,8 @@ class final_answer(base_final_answer[FinalAnswerArgs]):
class HogQLGeneratorToolkit(TaxonomyAgentToolkit):
def __init__(self, team: Team):
super().__init__(team)
def __init__(self, team: Team, user: User):
super().__init__(team, user)
def _get_custom_tools(self) -> list:
"""Get custom tools for the HogQLGenerator."""

View File

@@ -115,11 +115,11 @@ class final_answer(base_final_answer[ErrorTrackingIssueImpactToolOutput]):
class ErrorTrackingIssueImpactToolkit(TaxonomyAgentToolkit):
def __init__(self, team: Team):
super().__init__(team)
def __init__(self, team: Team, user: User):
super().__init__(team, user)
async def handle_tools(self, tool_name: str, tool_input: TaxonomyTool) -> tuple[str, str]:
return await super().handle_tools(tool_name, tool_input)
async def handle_tools(self, tool_metadata: dict[str, list[tuple[TaxonomyTool, str]]]) -> dict[str, str]:
return await super().handle_tools(tool_metadata)
def _get_custom_tools(self) -> list:
return [final_answer]

View File

@@ -31,8 +31,8 @@ logger.setLevel(logging.DEBUG)
class SessionReplayFilterOptionsToolkit(TaxonomyAgentToolkit):
def __init__(self, team: Team):
super().__init__(team)
def __init__(self, team: Team, user: User):
super().__init__(team, user)
def _get_custom_tools(self) -> list:
"""Get custom tools for filter options."""

View File

@@ -18,7 +18,7 @@ from ee.hogai.graph.taxonomy.agent import TaxonomyAgent
from ee.hogai.graph.taxonomy.format import enrich_props_with_descriptions, format_properties_xml
from ee.hogai.graph.taxonomy.nodes import TaxonomyAgentNode, TaxonomyAgentToolsNode
from ee.hogai.graph.taxonomy.toolkit import TaxonomyAgentToolkit, TaxonomyErrorMessages
from ee.hogai.graph.taxonomy.tools import ask_user_for_help, base_final_answer
from ee.hogai.graph.taxonomy.tools import TaxonomyTool, ask_user_for_help, base_final_answer
from ee.hogai.graph.taxonomy.types import TaxonomyAgentState
from ee.hogai.tool import MaxTool
from ee.hogai.utils.types.base import AssistantNodeName
@@ -53,16 +53,27 @@ logger.setLevel(logging.DEBUG)
class RevenueAnalyticsFilterOptionsToolkit(TaxonomyAgentToolkit):
def __init__(self, team: Team):
super().__init__(team)
def __init__(self, team: Team, user: User):
super().__init__(team, user)
async def handle_tools(self, tool_name: str, tool_input) -> tuple[str, str]:
async def handle_tools(self, tool_metadata: dict[str, list[tuple[TaxonomyTool, str]]]) -> dict[str, str]:
"""Handle custom tool execution."""
if tool_name == "retrieve_revenue_analytics_property_values":
result = await self._retrieve_revenue_analytics_property_values(tool_input.arguments.property_key)
return tool_name, result
results = {}
unhandled_tools = {}
for tool_name, tool_inputs in tool_metadata.items():
if tool_name == "retrieve_revenue_analytics_property_values":
if tool_inputs:
for tool_input, tool_call_id in tool_inputs:
result = await self._retrieve_revenue_analytics_property_values(
tool_input.arguments.property_key # type: ignore
)
results[tool_call_id] = result
else:
unhandled_tools[tool_name] = tool_inputs
return await super().handle_tools(tool_name, tool_input)
if unhandled_tools:
results.update(await super().handle_tools(unhandled_tools))
return results
def _get_custom_tools(self) -> list:
return [final_answer, retrieve_revenue_analytics_property_values]

View File

@@ -15,11 +15,12 @@ from posthog.schema import SurveyAnalysisQuestionGroup, SurveyCreationSchema
from posthog.constants import DEFAULT_SURVEY_APPEARANCE
from posthog.exceptions_capture import capture_exception
from posthog.models import FeatureFlag, Survey, Team, User
from posthog.sync import database_sync_to_async
from ee.hogai.graph.taxonomy.agent import TaxonomyAgent
from ee.hogai.graph.taxonomy.nodes import TaxonomyAgentNode, TaxonomyAgentToolsNode
from ee.hogai.graph.taxonomy.toolkit import TaxonomyAgentToolkit
from ee.hogai.graph.taxonomy.tools import base_final_answer
from ee.hogai.graph.taxonomy.tools import TaxonomyTool, ask_user_for_help, base_final_answer
from ee.hogai.graph.taxonomy.types import TaxonomyAgentState
from ee.hogai.llm import MaxChatOpenAI
from ee.hogai.tool import MaxTool
@@ -160,8 +161,8 @@ class CreateSurveyTool(MaxTool):
class SurveyToolkit(TaxonomyAgentToolkit):
"""Toolkit for survey creation and feature flag lookup operations."""
def __init__(self, team: Team):
super().__init__(team)
def __init__(self, team: Team, user: User):
super().__init__(team, user)
def get_tools(self) -> list:
"""Get all tools (default + custom). Override in subclasses to add custom tools."""
@@ -181,15 +182,26 @@ class SurveyToolkit(TaxonomyAgentToolkit):
class final_answer(base_final_answer[SurveyCreationSchema]):
__doc__ = base_final_answer.__doc__
return [lookup_feature_flag, final_answer]
return [lookup_feature_flag, final_answer, ask_user_for_help]
async def handle_tools(self, tool_name: str, tool_input) -> tuple[str, str]:
async def handle_tools(self, tool_metadata: dict[str, list[tuple[TaxonomyTool, str]]]) -> dict[str, str]:
"""Handle custom tool execution."""
if tool_name == "lookup_feature_flag":
result = self._lookup_feature_flag(tool_input.arguments.flag_key)
return tool_name, result
return await super().handle_tools(tool_name, tool_input)
results = {}
unhandled_tools = {}
for tool_name, tool_inputs in tool_metadata.items():
if tool_name == "lookup_feature_flag":
if tool_inputs:
for tool_input, tool_call_id in tool_inputs:
result = await self._lookup_feature_flag(tool_input.arguments.flag_key) # type: ignore
results[tool_call_id] = result
else:
unhandled_tools[tool_name] = tool_inputs
if unhandled_tools:
results.update(await super().handle_tools(unhandled_tools))
return results
@database_sync_to_async(thread_sensitive=False)
def _lookup_feature_flag(self, flag_key: str) -> str:
"""Look up feature flag information by key."""
try: