Mirror of https://github.com/BillyOutlast/posthog.git, synced 2026-02-04 03:01:23 +01:00
feat(max): allow for multiple entity search (#38504)
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
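The diff below reworks the taxonomy agent so that every tool call emitted in a single LLM turn is collected, batched by tool name, and resolved in one pass, with results handed back keyed by tool_call_id. As a minimal, self-contained sketch of that batching pattern (the ToolCall class and handle_tools_batched function here are illustrative stand-ins, not the actual PostHog classes in the diff):

from collections import defaultdict
from dataclasses import dataclass


@dataclass
class ToolCall:
    name: str     # tool name, e.g. "retrieve_entity_properties" (illustrative)
    args: dict    # tool arguments as parsed from the LLM output
    call_id: str  # id of the individual tool call


def handle_tools_batched(tool_calls: list[ToolCall]) -> dict[str, str]:
    """Group calls by tool name, run one batched lookup per group,
    and map each result back to its originating tool_call_id."""
    grouped: dict[str, list[ToolCall]] = defaultdict(list)
    for call in tool_calls:
        grouped[call.name].append(call)

    results: dict[str, str] = {}
    for name, calls in grouped.items():
        if name == "retrieve_entity_properties":
            # One lookup serves every entity requested in this turn.
            entities = [c.args["entity"] for c in calls]
            batch = {e: f"<properties for {e}>" for e in entities}  # stand-in for the real query
            for c in calls:
                results[c.call_id] = batch[c.args["entity"]]
        else:
            for c in calls:
                results[c.call_id] = f"unsupported tool: {name}"
    return results


print(handle_tools_batched([
    ToolCall("retrieve_entity_properties", {"entity": "person"}, "call_1"),
    ToolCall("retrieve_entity_properties", {"entity": "organization"}, "call_2"),
]))

Grouping by tool name is what lets one backend query serve many calls, which is why the diff also flips the LLM binding to parallel_tool_calls=True.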
@@ -102,6 +102,7 @@ class MainAssistant(BaseAssistant):
AssistantNodeName.MEMORY_INITIALIZER_INTERRUPT,
AssistantNodeName.ROOT_TOOLS,
TaxonomyNodeName.TOOLS_NODE,
TaxonomyNodeName.TASK_EXECUTOR,
}

@property
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from functools import cached_property
from typing import Generic, TypeVar

@@ -17,6 +18,7 @@ from posthog.schema import MaxEventContext
from posthog.models import Team, User
from posthog.models.group_type_mapping import GroupTypeMapping

from ee.hogai.graph.taxonomy.tools import TaxonomyTool
from ee.hogai.llm import MaxChatOpenAI
from ee.hogai.utils.helpers import format_events_yaml
from ee.hogai.utils.types.composed import MaxNodeName
@@ -31,7 +33,6 @@ from .prompts import (
TAXONOMY_TOOL_USAGE_PROMPT,
)
from .toolkit import TaxonomyAgentToolkit
from .tools import TaxonomyTool
from .types import EntityType, TaxonomyAgentState, TaxonomyNodeName

TaxonomyStateType = TypeVar("TaxonomyStateType", bound=TaxonomyAgentState)
@@ -50,7 +51,7 @@ class TaxonomyAgentNode(

def __init__(self, team: Team, user: User, toolkit_class: type["TaxonomyAgentToolkit"]):
super().__init__(team, user)
self._toolkit = toolkit_class(team=team)
self._toolkit = toolkit_class(team=team, user=user)
self._state_class, self._partial_state_class = self._get_state_class(TaxonomyAgentNode)

@property
@@ -77,7 +78,7 @@ class TaxonomyAgentNode(
).bind_tools(
self._toolkit.get_tools(),
tool_choice="required",
parallel_tool_calls=False,
parallel_tool_calls=True,
)

def _get_default_system_prompts(self) -> list[str]:
@@ -131,9 +132,11 @@ class TaxonomyAgentNode(
if not output_message.tool_calls:
raise ValueError("No tool calls found in the output message.")

tool_call = output_message.tool_calls[0]
result = AgentAction(tool_call["name"], tool_call["args"], tool_call["id"])
intermediate_steps = state.intermediate_steps or []
tool_calls = output_message.tool_calls
intermediate_steps = []
for tool_call in tool_calls:
result = AgentAction(tool_call["name"], tool_call["args"], tool_call["id"])
intermediate_steps.append((result, None))

# Add the new AI message to the progress log
ai_message = LangchainAIMessage(
@@ -142,8 +145,9 @@ class TaxonomyAgentNode(

return self._partial_state_class(
tool_progress_messages=[*progress_messages, ai_message],
intermediate_steps=[*intermediate_steps, (result, None)],
intermediate_steps=intermediate_steps,
output=state.output,
iteration_count=state.iteration_count + 1 if state.iteration_count is not None else 1,
)
@@ -156,7 +160,7 @@ class TaxonomyAgentToolsNode(

def __init__(self, team: Team, user: User, toolkit_class: type["TaxonomyAgentToolkit"]):
super().__init__(team, user)
self._toolkit = toolkit_class(team=team)
self._toolkit = toolkit_class(team=team, user=user)
self._state_class, self._partial_state_class = self._get_state_class(TaxonomyAgentToolsNode)

@property
@@ -165,57 +169,70 @@ class TaxonomyAgentToolsNode(

async def arun(self, state: TaxonomyStateType, config: RunnableConfig) -> TaxonomyPartialStateType:
intermediate_steps = state.intermediate_steps or []
action, _output = intermediate_steps[-1]
tool_input: TaxonomyTool | None = None
output = ""
tool_result_msg: list[LangchainToolMessage] = []
tools_metadata: dict[str, list[tuple[TaxonomyTool, str]]] = defaultdict(list)
invalid_tools = []
steps = []

try:
tool_input = self._toolkit.get_tool_input_model(action)
except ValidationError as e:
output = str(
ChatPromptTemplate.from_template(REACT_PYDANTIC_VALIDATION_EXCEPTION_PROMPT, template_format="mustache")
.format_messages(exception=e.errors(include_url=False))[0]
.content
)
else:
if tool_input.name == "final_answer":
return self._partial_state_class(
output=tool_input.arguments.answer, # type: ignore
intermediate_steps=None,
for action, _ in intermediate_steps:
try:
tool_input = self._toolkit.get_tool_input_model(action)
except ValidationError as e:
output = str(
ChatPromptTemplate.from_template(
REACT_PYDANTIC_VALIDATION_EXCEPTION_PROMPT, template_format="mustache"
)
.format_messages(exception=e.errors(include_url=False))[0]
.content
)
steps.append((action, output))
invalid_tools.append(action.log)
continue
else:
if tool_input.name == "final_answer":
return self._partial_state_class(
output=tool_input.arguments.answer, # type: ignore
intermediate_steps=None,
)

# The agent has requested help, so we return a message to the root node
if tool_input.name == "ask_user_for_help":
return self._get_reset_state(
tool_input.arguments.request, # type: ignore
tool_input.name,
state,
)
if tool_input.name == "ask_user_for_help":
return self._get_reset_state(
tool_input.arguments.request, # type: ignore
tool_input.name,
state,
)

# For any other tool, collect metadata and prepare for result processing
tools_metadata[tool_input.name].append((tool_input, action.log))

# If we're still here, check if we've hit the iteration limit within this cycle
if len(intermediate_steps) >= self.MAX_ITERATIONS:
if state.iteration_count is not None and state.iteration_count >= self.MAX_ITERATIONS:
return self._get_reset_state(ITERATION_LIMIT_PROMPT, "max_iterations", state)

if tool_input and not output:
# Taxonomy is a separate graph, so it dispatches its own messages
reasoning_message = await self.get_reasoning_message(state)
if reasoning_message:
await self._write_message(reasoning_message)
# Use the toolkit to handle tool execution
_, output = await self._toolkit.handle_tools(tool_input.name, tool_input)
# Taxonomy is a separate graph, so it dispatches its own messages
reasoning_message = await self.get_reasoning_message(state)
if reasoning_message:
await self._write_message(reasoning_message)

if output:
tool_results = await self._toolkit.handle_tools(tools_metadata)

tool_msgs = []
for action, _ in intermediate_steps:
if action.log in invalid_tools:
continue
tool_result = tool_results[action.log]
tool_msg = LangchainToolMessage(
content=output,
content=tool_result,
tool_call_id=action.log,
)
tool_result_msg.append(tool_msg)
tool_msgs.append(tool_msg)
steps.append((action, tool_result))

old_msg = state.tool_progress_messages or []

return self._partial_state_class(
tool_progress_messages=[*old_msg, *tool_result_msg],
intermediate_steps=[*intermediate_steps[:-1], (action, output)],
tool_progress_messages=[*old_msg, *tool_msgs],
intermediate_steps=steps,
iteration_count=state.iteration_count,
)

def router(self, state: TaxonomyStateType) -> str:
@@ -65,12 +65,22 @@ TAXONOMY_TOOL_USAGE_PROMPT = """
- *CRITICAL*: DO NOT CALL A TOOL FOR THE SAME ENTITY, EVENT, OR PROPERTY MORE THAN ONCE. IF YOU HAVE NOT FOUND A MATCH YOU MUST TRY WITH THE NEXT BEST MATCH.

4. **Property Value Matching**:
- IMPORTANT: If tool call returns property values that are related but not synonyms to the user's requested value: USE USER'S ORIGINAL VALUE.
- IMPORTANT: If tool call returns property values that are related BUT NOT SYNONYMS to the user's requested value: USE USER'S ORIGINAL VALUE.
For example, if the user asks for $browser to be "Chrome" and the tool call returns '"Firefox", "Safari"', use "Chrome" as the property value. Since "Chrome" is related to "Firefox" and "Safari" since they are all browsers.

- IMPORTANT: If tool call returns property values that are synonyms, typos, or a variant of the user's requested value: USE FOUND VALUES
For example the user asks for the city to be "New York" and the tool call returns "New York City", "NYC", use "New York City" as the property value. Since "New York" is related to "New York City" and "NYC" since they are all variants of New York.

5. **Optimization**:
- Remember that you are able to make parallel tool calls. This is a big performance improvement. Whenever it makes sense to do so, call multiple tools at once.
- Always aim to optimize your tool calls. This will help you find the correct properties and values faster.

6. **Filter Completion**:
- Always aim to complete the filter as much as possible. This will help you meet the user's expectations.
- If you have found most of the properties and values but you are still missing some, return the filter that you have found so far. The user can always ask you to add more properties and values later.
- Be careful though, if you have not found most of the properties and values, you should use the `ask_user_for_help` tool to ask the user for more information.
Example: If the user asks to filter for location, url type, date and browser type, and you could not find anything about the url you can return the filter you found.

- If the tool call returns no values, you can retry with the next best property or entity.
</tool_usage>
""".strip()
@@ -50,15 +50,27 @@ class TestEntities(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
properties={"name": "Test User"},
|
||||
)
|
||||
|
||||
self.toolkit = DummyToolkit(self.team)
|
||||
self.toolkit = DummyToolkit(self.team, self.user)
|
||||
|
||||
async def test_retrieve_entity_properties(self):
|
||||
result = await self.toolkit.retrieve_entity_properties("person")
|
||||
result = await self.toolkit.retrieve_entity_properties_parallel(["person"])
|
||||
assert (
|
||||
"<properties><String><prop><name>name</name></prop><prop><name>property_no_values</name></prop></String></properties>"
|
||||
== result
|
||||
== result["person"]
|
||||
)
|
||||
|
||||
async def test_retrieve_entity_properties_entity_not_found(self):
|
||||
result = await self.toolkit.retrieve_entity_properties_parallel(["test"])
|
||||
assert "Entity test not found. Available entities: person, session, organization, project" == result["test"]
|
||||
|
||||
async def test_retrieve_entity_properties_entity_with_group(self):
|
||||
result = await self.toolkit.retrieve_entity_properties_parallel(["organization", "session"])
|
||||
assert "session" in result
|
||||
assert (
|
||||
"<properties><String><prop><name>name_group</name></prop></String></properties>" == result["organization"]
|
||||
)
|
||||
assert "<properties>" in result["session"]
|
||||
|
||||
async def test_person_property_values_exists(self):
|
||||
result = await self.toolkit._get_entity_names()
|
||||
expected = ["person", "session", "organization", "project"]
|
||||
|
||||
@@ -64,11 +64,18 @@ class TestEvents(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
},
|
||||
)
|
||||
|
||||
_create_event(
|
||||
event="no-properties-event",
|
||||
distinct_id="user456",
|
||||
team=self.team,
|
||||
properties={},
|
||||
)
|
||||
|
||||
Action.objects.create(
|
||||
id=232, team=self.team, name="action1", description="Test Description", steps_json=[{"event": "event1"}]
|
||||
)
|
||||
|
||||
self.toolkit = DummyToolkit(self.team)
|
||||
self.toolkit = DummyToolkit(self.team, self.user)
|
||||
|
||||
async def test_events_property_values_exists(self):
|
||||
result = await self.toolkit._get_entity_names()
|
||||
@@ -121,14 +128,35 @@ class TestEvents(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
assert "Firefox" in "\n".join(property_vals.get(232, []))
|
||||
|
||||
async def test_retrieve_event_or_action_properties_action_not_found(self):
|
||||
result = await self.toolkit.retrieve_event_or_action_properties(999)
|
||||
result = await self.toolkit.retrieve_event_or_action_properties_parallel([999])
|
||||
assert (
|
||||
"Action 999 does not exist in the taxonomy. Verify that the action ID is correct and try again." == result
|
||||
"Action 999 does not exist in the taxonomy. Verify that the action ID is correct and try again."
|
||||
== result["999"]
|
||||
)
|
||||
|
||||
async def test_retrieve_event_or_action_properties_event_not_found(self):
|
||||
result = await self.toolkit.retrieve_event_or_action_properties("test")
|
||||
assert "Properties do not exist in the taxonomy for the event test." == result
|
||||
result = await self.toolkit.retrieve_event_or_action_properties_parallel(["test"])
|
||||
assert "Properties do not exist in the taxonomy for the event test." == result["test"]
|
||||
|
||||
async def test_retrieve_event_or_action_properties_action_mixed(self):
|
||||
result = await self.toolkit.retrieve_event_or_action_properties_parallel([232, "event1"])
|
||||
|
||||
assert "event1" in result
|
||||
assert (
|
||||
"<properties><String><prop><name>id</name></prop><prop><name>$browser</name><description>Name of the browser the user has used.</description></prop></String></properties>"
|
||||
== result["event1"]
|
||||
)
|
||||
assert "<properties>" in result["232"]
|
||||
|
||||
async def test_retrieve_event_or_action_properties_action_no_properties(self):
|
||||
result = await self.toolkit.retrieve_event_or_action_properties_parallel([232, "no-properties-event"])
|
||||
|
||||
assert "no-properties-event" in result
|
||||
assert (
|
||||
"Properties do not exist in the taxonomy for the event no-properties-event."
|
||||
== result["no-properties-event"]
|
||||
)
|
||||
assert "<properties>" in result["232"]
|
||||
|
||||
def tearDown(self):
|
||||
flush_persons_and_events()
|
||||
|
||||
@@ -74,7 +74,7 @@ class TestGroups(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
sync=True, # Force sync to ClickHouse
|
||||
)
|
||||
|
||||
self.toolkit = DummyToolkit(self.team)
|
||||
self.toolkit = DummyToolkit(self.team, self.user)
|
||||
|
||||
async def test_entity_names_with_existing_groups(self):
|
||||
# Test that the entity names include the groups we created in setUp
|
||||
@@ -107,21 +107,31 @@ class TestGroups(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
self.assertIn("project", property_vals)
|
||||
|
||||
async def test_retrieve_entity_properties_group(self):
|
||||
result = await self.toolkit.retrieve_entity_properties("organization")
|
||||
result = await self.toolkit.retrieve_entity_properties_parallel(["organization"])
|
||||
|
||||
assert (
|
||||
"<properties><String><prop><name>name</name></prop><prop><name>industry</name></prop><prop><name>name_group</name></prop></String></properties>"
|
||||
== result
|
||||
== result["organization"]
|
||||
)
|
||||
|
||||
async def test_retrieve_entity_properties_group_not_found(self):
|
||||
result = await self.toolkit.retrieve_entity_properties("test")
|
||||
result = await self.toolkit.retrieve_entity_properties_parallel(["test"])
|
||||
|
||||
assert (
|
||||
"Entity test not found. Available entities: person, session, organization, project, no_properties" == result
|
||||
"Entity test not found. Available entities: person, session, organization, project, no_properties"
|
||||
== result["test"]
|
||||
)
|
||||
|
||||
async def test_retrieve_entity_properties_group_nothing_found(self):
|
||||
result = await self.toolkit.retrieve_entity_properties("no_properties")
|
||||
result = await self.toolkit.retrieve_entity_properties_parallel(["no_properties"])
|
||||
|
||||
assert "Properties do not exist in the taxonomy for the entity no_properties." == result
|
||||
assert "Properties do not exist in the taxonomy for the entity no_properties." == result["no_properties"]
|
||||
|
||||
async def test_retrieve_entity_properties_group_mixed(self):
|
||||
result = await self.toolkit.retrieve_entity_properties_parallel(["organization", "no_properties", "project"])
|
||||
|
||||
assert "organization" in result
|
||||
assert "<properties>" in result["organization"]
|
||||
assert "Properties do not exist in the taxonomy for the entity no_properties." == result["no_properties"]
|
||||
assert "project" in result
|
||||
assert "<properties>" in result["project"]
|
||||
|
||||
@@ -165,10 +165,10 @@ class TestTaxonomyAgentToolsNode(BaseTest):
|
||||
mock_input.name = "test_tool"
|
||||
mock_input.arguments = Mock()
|
||||
mock_get_tool_input.return_value = mock_input
|
||||
mock_handle_tools.return_value = ("test_tool", "tool output")
|
||||
mock_handle_tools.return_value = {"test_tool_id": "tool output"}
|
||||
|
||||
# Create state with intermediate step
|
||||
action = AgentAction(tool="test_tool", tool_input={"param": "value"}, log="test_log")
|
||||
action = AgentAction(tool="test_tool", tool_input={"param": "value"}, log="test_tool_id")
|
||||
state = TaxonomyAgentState()
|
||||
state.intermediate_steps = [(action, None)]
|
||||
|
||||
@@ -177,6 +177,7 @@ class TestTaxonomyAgentToolsNode(BaseTest):
|
||||
self.assertIsInstance(result, TaxonomyAgentState)
|
||||
self.assertEqual(len(result.intermediate_steps), 1)
|
||||
self.assertEqual(result.intermediate_steps[0][1], "tool output")
|
||||
self.assertEqual(result.intermediate_steps[0][0].log, "test_tool_id")
|
||||
|
||||
@patch.object(MockTaxonomyAgentToolkit, "get_tool_input_model")
|
||||
async def test_run_validation_error(self, mock_get_tool_input):
|
||||
@@ -193,7 +194,7 @@ class TestTaxonomyAgentToolsNode(BaseTest):
|
||||
result = await self.node.arun(state, RunnableConfig())
|
||||
|
||||
self.assertIsInstance(result, TaxonomyAgentState)
|
||||
self.assertEqual(len(result.tool_progress_messages), 1)
|
||||
self.assertEqual(len(result.tool_progress_messages), 0)
|
||||
|
||||
@patch.object(MockTaxonomyAgentToolkit, "get_tool_input_model")
|
||||
async def test_run_final_answer(self, mock_get_tool_input):
|
||||
@@ -251,6 +252,7 @@ class TestTaxonomyAgentToolsNode(BaseTest):
|
||||
|
||||
state = TaxonomyAgentState()
|
||||
state.intermediate_steps = actions
|
||||
state.iteration_count = self.node.MAX_ITERATIONS
|
||||
|
||||
with patch.object(self.node, "_get_reset_state") as mock_reset:
|
||||
mock_reset.return_value = TaxonomyAgentState()
|
||||
@@ -300,7 +302,7 @@ class TestTaxonomyAgentToolsNode(BaseTest):
|
||||
result = self.node._get_reset_state("test output", "test_tool", original_state)
|
||||
|
||||
self.assertEqual(len(result.intermediate_steps), 1)
|
||||
action, output = result.intermediate_steps[0]
|
||||
action, output = result.intermediate_steps[0] # type: ignore
|
||||
self.assertEqual(action.tool, "test_tool")
|
||||
self.assertEqual(action.tool_input, "test output")
|
||||
self.assertIsNone(output)
|
||||
|
||||
@@ -6,6 +6,7 @@ from parameterized import parameterized
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ee.hogai.graph.taxonomy.toolkit import TaxonomyAgentToolkit, TaxonomyToolNotFoundError
|
||||
from ee.hogai.graph.taxonomy.tools import TaxonomyTool
|
||||
|
||||
|
||||
class DummyToolkit(TaxonomyAgentToolkit):
|
||||
@@ -16,7 +17,7 @@ class DummyToolkit(TaxonomyAgentToolkit):
|
||||
class TestTaxonomyAgentToolkit(BaseTest):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.toolkit = DummyToolkit(self.team)
|
||||
self.toolkit = DummyToolkit(self.team, self.user)
|
||||
|
||||
async def test_toolkit_initialization(self):
|
||||
self.assertEqual(self.toolkit._team, self.team)
|
||||
@@ -69,12 +70,11 @@ class TestTaxonomyAgentToolkit(BaseTest):
|
||||
("retrieve_entity_property_values", {"entity": "person", "property_name": "email"}, "mocked"),
|
||||
("retrieve_event_properties", {"event_name": "test_event"}, "mocked"),
|
||||
("retrieve_event_property_values", {"event_name": "test_event", "property_name": "$browser"}, "mocked"),
|
||||
("ask_user_for_help", {"request": "Help needed"}, "Help needed"),
|
||||
]
|
||||
)
|
||||
@patch.object(DummyToolkit, "retrieve_entity_properties", return_value="mocked")
|
||||
@patch.object(DummyToolkit, "retrieve_entity_properties_parallel", return_value={"person": "mocked"})
|
||||
@patch.object(DummyToolkit, "retrieve_entity_property_values", return_value={"person": ["mocked"]})
|
||||
@patch.object(DummyToolkit, "retrieve_event_or_action_properties", return_value="mocked")
|
||||
@patch.object(DummyToolkit, "retrieve_event_or_action_properties_parallel", return_value={"test_event": "mocked"})
|
||||
@patch.object(DummyToolkit, "retrieve_event_or_action_property_values", return_value={"test_event": ["mocked"]})
|
||||
async def test_handle_tools(self, tool_name, tool_args, expected_result, *mocks):
|
||||
class Arguments(BaseModel):
|
||||
@@ -83,25 +83,25 @@ class TestTaxonomyAgentToolkit(BaseTest):
|
||||
for key, value in tool_args.items():
|
||||
setattr(Arguments, key, value)
|
||||
|
||||
class ToolInput(BaseModel):
|
||||
class ToolInput(TaxonomyTool):
|
||||
name: str
|
||||
arguments: Arguments
|
||||
|
||||
tool_input = ToolInput(name=tool_name, arguments=Arguments(**tool_args))
|
||||
tool_name_result, result = await self.toolkit.handle_tools(tool_name, tool_input)
|
||||
result = await self.toolkit.handle_tools({tool_name: [(tool_input, "test_call_id")]})
|
||||
|
||||
self.assertEqual(result, expected_result)
|
||||
self.assertEqual(tool_name_result, tool_name)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result["test_call_id"], expected_result)
|
||||
|
||||
async def test_handle_tools_invalid_tool(self):
|
||||
class ToolInput(BaseModel):
|
||||
class ToolInput(TaxonomyTool):
|
||||
name: str = "invalid_tool"
|
||||
arguments: dict = {}
|
||||
|
||||
tool_input = ToolInput()
|
||||
|
||||
with self.assertRaises(TaxonomyToolNotFoundError):
|
||||
await self.toolkit.handle_tools("invalid_tool", tool_input)
|
||||
await self.toolkit.handle_tools({"invalid_tool": [(tool_input, "invalid_tool")]})
|
||||
|
||||
def test_format_properties_formats(self):
|
||||
props = [("prop1", "String", "Test description"), ("prop2", "Numeric", None)]
|
||||
@@ -125,7 +125,12 @@ class TestTaxonomyAgentToolkit(BaseTest):
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("retrieve_entity_properties", {"entity": "person"}, "retrieve_entity_properties", {"entity": "person"}),
|
||||
(
|
||||
"retrieve_entity_properties_parallel",
|
||||
{"entity": "person"},
|
||||
"retrieve_entity_properties_parallel",
|
||||
{"entity": "person"},
|
||||
),
|
||||
(
|
||||
"retrieve_event_properties",
|
||||
{"event_name": "test_event"},
|
||||
@@ -150,7 +155,12 @@ class TestTaxonomyAgentToolkit(BaseTest):
|
||||
"retrieve_event_property_values",
|
||||
{"event_name": "test_event", "property_name": "$browser"},
|
||||
),
|
||||
("retrieve_entity_properties", {"entity": "session"}, "retrieve_entity_properties", {"entity": "session"}),
|
||||
(
|
||||
"retrieve_entity_properties_parallel",
|
||||
{"entity": "session"},
|
||||
"retrieve_entity_properties_parallel",
|
||||
{"entity": "session"},
|
||||
),
|
||||
]
|
||||
)
|
||||
def test_get_tool_input_model_with_valid_tools(self, tool_name, tool_input, expected_name, expected_args):
|
||||
@@ -177,7 +187,7 @@ class TestTaxonomyAgentToolkit(BaseTest):
|
||||
|
||||
return [CustomTool]
|
||||
|
||||
custom_toolkit = CustomToolkit(self.team)
|
||||
custom_toolkit = CustomToolkit(self.team, self.user)
|
||||
|
||||
action = AgentAction(tool="custom_tool", tool_input={"custom_field": "test_value"}, log="test log")
|
||||
|
||||
@@ -195,7 +205,7 @@ class TestTaxonomyAgentToolkit(BaseTest):
|
||||
def _get_custom_tools(self):
|
||||
raise NotImplementedError("This is a test error")
|
||||
|
||||
basic_toolkit = BasicToolkit(self.team)
|
||||
basic_toolkit = BasicToolkit(self.team, self.user)
|
||||
|
||||
# Should not raise NotImplementedError, should fall back to default tools
|
||||
tools = basic_toolkit.get_tools()
|
||||
@@ -230,7 +240,7 @@ class TestTaxonomyAgentToolkit(BaseTest):
|
||||
|
||||
return [custom_tool_1, custom_tool_2]
|
||||
|
||||
custom_toolkit = CustomToolkit(self.team)
|
||||
custom_toolkit = CustomToolkit(self.team, self.user)
|
||||
|
||||
# Should return both default and custom tools
|
||||
tools = custom_toolkit.get_tools()
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import json
|
||||
import asyncio
|
||||
from collections.abc import Iterable
|
||||
from functools import cached_property
|
||||
from typing import Any, Optional, Union, cast
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from langchain_core.agents import AgentAction
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from pydantic import BaseModel
|
||||
|
||||
from posthog.schema import (
|
||||
@@ -14,6 +16,8 @@ from posthog.schema import (
|
||||
EventTaxonomyItem,
|
||||
EventTaxonomyQuery,
|
||||
QueryStatusResponse,
|
||||
TaskExecutionItem,
|
||||
TaskExecutionStatus,
|
||||
)
|
||||
|
||||
from posthog.hogql.database.schema.channel_type import DEFAULT_CHANNEL_TYPES
|
||||
@@ -22,7 +26,7 @@ from posthog.clickhouse.query_tagging import Product, tags_context
|
||||
from posthog.hogql_queries.ai.actors_property_taxonomy_query_runner import ActorsPropertyTaxonomyQueryRunner
|
||||
from posthog.hogql_queries.ai.event_taxonomy_query_runner import EventTaxonomyQueryRunner
|
||||
from posthog.hogql_queries.query_runner import ExecutionMode
|
||||
from posthog.models import Action, Team
|
||||
from posthog.models import Action, Team, User
|
||||
from posthog.models.group_type_mapping import GroupTypeMapping
|
||||
from posthog.models.property_definition import PropertyDefinition, PropertyType
|
||||
from posthog.sync import database_sync_to_async
|
||||
@@ -34,15 +38,18 @@ from ee.hogai.graph.taxonomy.format import (
|
||||
format_properties_yaml,
|
||||
format_property_values,
|
||||
)
|
||||
from ee.hogai.utils.types.base import BaseStateWithTasks, TaskArtifact, TaskResult
|
||||
from ee.hogai.utils.types.composed import MaxNodeName
|
||||
|
||||
from ..parallel_task_execution.nodes import BaseTaskExecutorNode, TaskExecutionInputTuple
|
||||
from .tools import (
|
||||
TaxonomyTool,
|
||||
ask_user_for_help,
|
||||
retrieve_entity_properties,
|
||||
retrieve_entity_property_values,
|
||||
get_dynamic_entity_tools,
|
||||
retrieve_event_properties,
|
||||
retrieve_event_property_values,
|
||||
)
|
||||
from .types import TaxonomyNodeName
|
||||
|
||||
|
||||
class TaxonomyToolNotFoundError(Exception):
|
||||
@@ -104,12 +111,45 @@ class TaxonomyErrorMessages:
|
||||
return f"Properties do not exist in the taxonomy for the {event_name}."
|
||||
|
||||
|
||||
class TaxonomyTaskExecutorNode(
|
||||
BaseTaskExecutorNode[
|
||||
BaseStateWithTasks,
|
||||
BaseStateWithTasks,
|
||||
]
|
||||
):
|
||||
"""
|
||||
Task executor node specifically for taxonomy operations.
|
||||
"""
|
||||
|
||||
@property
|
||||
def node_name(self) -> MaxNodeName:
|
||||
return TaxonomyNodeName.TASK_EXECUTOR
|
||||
|
||||
async def _aget_input_tuples(self, state: BaseStateWithTasks) -> list[TaskExecutionInputTuple]:
|
||||
taxonomy_toolkit = TaxonomyAgentToolkit(self._team, self._user)
|
||||
if not state.tasks:
|
||||
raise ValueError("No tasks to execute")
|
||||
input_tuples: list[TaskExecutionInputTuple] = []
|
||||
for task in state.tasks:
|
||||
if task.task_type == "retrieve_event_or_action_properties":
|
||||
input_tuples.append((task, [], taxonomy_toolkit._handle_event_or_action_properties_task))
|
||||
elif task.task_type == "retrieve_entity_properties":
|
||||
input_tuples.append((task, [], taxonomy_toolkit._handle_entity_properties_task))
|
||||
elif task.task_type == "retrieve_group_properties":
|
||||
input_tuples.append((task, [], taxonomy_toolkit._handle_group_properties_task))
|
||||
else:
|
||||
raise ValueError(f"Unsupported task type: {task.task_type}")
|
||||
return input_tuples
|
||||
|
||||
|
||||
class TaxonomyAgentToolkit:
|
||||
"""Base toolkit for taxonomy agents that handle tool execution."""
|
||||
|
||||
def __init__(self, team: Team):
|
||||
def __init__(self, team: Team, user: User):
|
||||
self._team = team
|
||||
self._user = user
|
||||
self.MAX_ENTITIES_PER_BATCH = 6
|
||||
self.MAX_PROPERTIES = 500
|
||||
|
||||
@property
|
||||
def _groups(self):
|
||||
@@ -139,6 +179,10 @@ class TaxonomyAgentToolkit:
|
||||
def _get_entity_names(self) -> list[str]:
|
||||
return self._entity_names
|
||||
|
||||
@database_sync_to_async(thread_sensitive=False)
|
||||
def _get_team_group_types(self) -> list[str]:
|
||||
return self._team_group_types
|
||||
|
||||
def _enrich_props_with_descriptions(self, entity: str, props: Iterable[tuple[str, str | None]]):
|
||||
return enrich_props_with_descriptions(entity, props)
|
||||
|
||||
@@ -181,13 +225,18 @@ class TaxonomyAgentToolkit:
|
||||
Retrieve event/action taxonomy with efficient caching.
|
||||
Multiple properties are batched in a single query to maximize cache hits.
|
||||
"""
|
||||
is_event = isinstance(event_name_or_action_id, str)
|
||||
try:
|
||||
action_id = int(event_name_or_action_id)
|
||||
is_event = False
|
||||
except ValueError:
|
||||
is_event = True
|
||||
|
||||
if is_event:
|
||||
query = EventTaxonomyQuery(event=event_name_or_action_id, maxPropertyValues=25, properties=properties)
|
||||
verbose_name = f"event {event_name_or_action_id}"
|
||||
else:
|
||||
query = EventTaxonomyQuery(actionId=event_name_or_action_id, maxPropertyValues=25, properties=properties)
|
||||
verbose_name = f"action with ID {event_name_or_action_id}"
|
||||
query = EventTaxonomyQuery(actionId=action_id, maxPropertyValues=25, properties=properties)
|
||||
verbose_name = f"action with ID {action_id}"
|
||||
runner = EventTaxonomyQueryRunner(query, self._team)
|
||||
with tags_context(product=Product.MAX_AI, team_id=self._team.pk, org_id=self._team.organization_id):
|
||||
# Use cache-first execution mode for optimal performance
|
||||
@@ -226,11 +275,16 @@ class TaxonomyAgentToolkit:
|
||||
|
||||
def _get_default_tools(self) -> list:
|
||||
"""Get default taxonomy tools."""
|
||||
dynamic_retrieve_entity_properties, dynamic_retrieve_entity_property_values = get_dynamic_entity_tools(
|
||||
self._team_group_types
|
||||
)
|
||||
return [
|
||||
retrieve_event_properties,
|
||||
retrieve_entity_properties,
|
||||
retrieve_entity_property_values,
|
||||
# retrieve_entity_properties,
|
||||
# retrieve_entity_property_values,
|
||||
dynamic_retrieve_entity_properties,
|
||||
retrieve_event_property_values,
|
||||
dynamic_retrieve_entity_property_values,
|
||||
ask_user_for_help,
|
||||
]
|
||||
|
||||
@@ -238,49 +292,6 @@ class TaxonomyAgentToolkit:
|
||||
"""Get custom tools. Override in subclasses to add custom tools."""
|
||||
raise NotImplementedError("_get_custom_tools must be implemented in subclasses")
|
||||
|
||||
async def retrieve_entity_properties(self, entity: str, max_properties: int = 500) -> str:
|
||||
"""
|
||||
Retrieve properties for an entitiy like person, session, or one of the groups.
|
||||
"""
|
||||
entity_names = await self._get_entity_names()
|
||||
if entity not in entity_names:
|
||||
return TaxonomyErrorMessages.entity_not_found(entity, entity_names)
|
||||
|
||||
props: list[Any] = []
|
||||
if entity == "person":
|
||||
qs = PropertyDefinition.objects.filter(team=self._team, type=PropertyDefinition.Type.PERSON).values_list(
|
||||
"name", "property_type"
|
||||
)
|
||||
props = self._enrich_props_with_descriptions("person", [prop async for prop in qs])
|
||||
elif entity == "session":
|
||||
# Session properties are not in the DB.
|
||||
props = self._enrich_props_with_descriptions(
|
||||
"session",
|
||||
[
|
||||
(prop_name, prop["type"])
|
||||
for prop_name, prop in CORE_FILTER_DEFINITIONS_BY_GROUP["session_properties"].items()
|
||||
if prop.get("type") is not None
|
||||
],
|
||||
)
|
||||
else:
|
||||
group_type_index = None
|
||||
groups = [group async for group in self._groups]
|
||||
for group in groups:
|
||||
if group.group_type == entity:
|
||||
group_type_index = group.group_type_index
|
||||
break
|
||||
if group_type_index is None:
|
||||
return f"Group {entity} does not exist in the taxonomy."
|
||||
qs = PropertyDefinition.objects.filter(
|
||||
team=self._team, type=PropertyDefinition.Type.GROUP, group_type_index=group_type_index
|
||||
).values_list("name", "property_type")[:max_properties]
|
||||
props = self._enrich_props_with_descriptions(entity, [prop async for prop in qs])
|
||||
|
||||
if not props:
|
||||
return f"Properties do not exist in the taxonomy for the entity {entity}."
|
||||
|
||||
return self._format_properties(props)
|
||||
|
||||
async def retrieve_entity_property_values(self, entity_properties: dict[str, list[str]]) -> dict[str, list[str]]:
|
||||
result = await self._parallel_entity_processing(entity_properties)
|
||||
return result
|
||||
@@ -348,6 +359,251 @@ class TaxonomyAgentToolkit:
|
||||
)
|
||||
return results
|
||||
|
||||
async def retrieve_entity_properties_parallel(self, entities: list[str]) -> dict[str, str]:
|
||||
entity_tasks = []
|
||||
groups = []
|
||||
|
||||
for entity in entities:
|
||||
group_types = [group.group_type async for group in self._groups]
|
||||
|
||||
if entity in group_types:
|
||||
groups.append(entity)
|
||||
else:
|
||||
entity_tasks.append(
|
||||
TaskExecutionItem(
|
||||
id=str(entity),
|
||||
prompt=entity,
|
||||
status=TaskExecutionStatus.PENDING,
|
||||
description="Retrieving entity properties",
|
||||
progress_text=f"Retrieving properties for {entity}...",
|
||||
task_type="retrieve_entity_properties",
|
||||
)
|
||||
)
|
||||
|
||||
if groups:
|
||||
entity_tasks.append(
|
||||
TaskExecutionItem(
|
||||
id="group_properties",
|
||||
prompt=str(groups),
|
||||
status=TaskExecutionStatus.PENDING,
|
||||
description="Retrieving group properties",
|
||||
progress_text="Retrieving properties for groups...",
|
||||
task_type="retrieve_group_properties",
|
||||
)
|
||||
)
|
||||
task_executor_state = BaseStateWithTasks(
|
||||
tasks=entity_tasks,
|
||||
)
|
||||
config = RunnableConfig()
|
||||
executor = TaxonomyTaskExecutorNode(self._team, self._user)
|
||||
result = await executor.arun(task_executor_state, config)
|
||||
|
||||
final_result = {}
|
||||
for task_result in result.task_results:
|
||||
if task_result.id == "group_properties":
|
||||
for artifact in task_result.artifacts:
|
||||
if artifact.id is None:
|
||||
continue
|
||||
if isinstance(artifact.id, int):
|
||||
final_result[str(artifact.id)] = artifact.content
|
||||
else:
|
||||
final_result[artifact.id] = artifact.content
|
||||
else:
|
||||
final_result[task_result.id] = task_result.result
|
||||
return final_result
|
||||
|
||||
async def _handle_event_or_action_properties_task(self, input_dict: dict) -> TaskResult:
|
||||
"""
|
||||
Retrieve properties for an event.
|
||||
"""
|
||||
|
||||
task = cast(TaskExecutionItem, input_dict["task"])
|
||||
|
||||
try:
|
||||
response, verbose_name = await self._retrieve_event_or_action_taxonomy(task.prompt)
|
||||
except Action.DoesNotExist:
|
||||
project_actions = await self._get_project_actions()
|
||||
if not project_actions:
|
||||
result = TaxonomyErrorMessages.no_actions_exist()
|
||||
return TaskResult(
|
||||
id=task.id,
|
||||
description=task.description,
|
||||
result=result,
|
||||
artifacts=[],
|
||||
status=TaskExecutionStatus.FAILED,
|
||||
)
|
||||
result = TaxonomyErrorMessages.action_not_found(task.prompt)
|
||||
return TaskResult(
|
||||
id=task.id,
|
||||
description=task.description,
|
||||
result=result,
|
||||
artifacts=[],
|
||||
status=TaskExecutionStatus.FAILED,
|
||||
)
|
||||
if not isinstance(response, CachedEventTaxonomyQueryResponse):
|
||||
result = TaxonomyErrorMessages.generic_not_found("Properties")
|
||||
return TaskResult(
|
||||
id=task.id,
|
||||
description=task.description,
|
||||
result=result,
|
||||
artifacts=[],
|
||||
status=TaskExecutionStatus.FAILED,
|
||||
)
|
||||
if not response.results:
|
||||
result = TaxonomyErrorMessages.event_properties_not_found(verbose_name)
|
||||
return TaskResult(
|
||||
id=task.id,
|
||||
description=task.description,
|
||||
result=result,
|
||||
artifacts=[],
|
||||
status=TaskExecutionStatus.FAILED,
|
||||
)
|
||||
|
||||
qs = PropertyDefinition.objects.filter(
|
||||
team=self._team, type=PropertyDefinition.Type.EVENT, name__in=[item.property for item in response.results]
|
||||
)
|
||||
property_definitions = [prop async for prop in qs]
|
||||
property_to_type = {
|
||||
property_definition.name: property_definition.property_type for property_definition in property_definitions
|
||||
}
|
||||
props = [
|
||||
(item.property, property_to_type.get(item.property))
|
||||
for item in response.results
|
||||
# Exclude properties that exist in the taxonomy, but don't have a type.
|
||||
if item.property in property_to_type
|
||||
]
|
||||
|
||||
if not props:
|
||||
result = TaxonomyErrorMessages.event_properties_not_found(verbose_name)
|
||||
return TaskResult(
|
||||
id=task.id,
|
||||
description=task.description,
|
||||
result=result,
|
||||
artifacts=[],
|
||||
status=TaskExecutionStatus.FAILED,
|
||||
)
|
||||
|
||||
formatted_properties = self._format_properties(self._enrich_props_with_descriptions("event", props))
|
||||
return TaskResult(
|
||||
id=task.id,
|
||||
description=task.description,
|
||||
result=formatted_properties,
|
||||
artifacts=[],
|
||||
status=TaskExecutionStatus.COMPLETED,
|
||||
)
|
||||
|
||||
async def _handle_group_properties_task(self, input_dict: dict) -> TaskResult:
|
||||
task = cast(TaskExecutionItem, input_dict["task"])
|
||||
try:
|
||||
# Convert Python list string to JSON format and parse
|
||||
json_str = task.prompt.replace("'", '"')
|
||||
group_entities = json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
group_entities = [task.prompt]
|
||||
|
||||
entity_to_group_index = {}
|
||||
artifacts = []
|
||||
for group_entity in group_entities:
|
||||
group_types = {group.group_type: group.group_type_index async for group in self._groups}
|
||||
group_type_index = group_types.get(group_entity, None)
|
||||
if group_type_index is not None:
|
||||
entity_to_group_index[group_entity] = group_type_index
|
||||
else:
|
||||
artifacts.append(
|
||||
TaskArtifact(
|
||||
id=group_entity,
|
||||
task_id=task.id,
|
||||
content=TaxonomyErrorMessages.properties_not_found(group_entity),
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if entity_to_group_index.values():
|
||||
# Single query for all group types
|
||||
group_qs = PropertyDefinition.objects.filter(
|
||||
team=self._team,
|
||||
type=PropertyDefinition.Type.GROUP,
|
||||
group_type_index__in=entity_to_group_index.values(),
|
||||
).values_list("name", "property_type", "group_type_index")[: self.MAX_PROPERTIES]
|
||||
group_qs_definitions = [prop async for prop in group_qs]
|
||||
# Group results by entity
|
||||
for entity in group_entities:
|
||||
if entity in entity_to_group_index.keys():
|
||||
group_index = entity_to_group_index[entity]
|
||||
properties = [
|
||||
(name, prop_type) for name, prop_type, gti in group_qs_definitions if gti == group_index
|
||||
]
|
||||
result = (
|
||||
self._format_properties(self._enrich_props_with_descriptions(entity, properties))
|
||||
if properties
|
||||
else TaxonomyErrorMessages.properties_not_found(entity)
|
||||
)
|
||||
artifacts.append(
|
||||
TaskArtifact(
|
||||
id=entity,
|
||||
task_id="group_properties",
|
||||
content=result,
|
||||
)
|
||||
)
|
||||
|
||||
return TaskResult(
|
||||
id=task.id,
|
||||
description=task.description,
|
||||
result="",
|
||||
artifacts=artifacts,
|
||||
status=TaskExecutionStatus.COMPLETED,
|
||||
)
|
||||
|
||||
async def _handle_entity_properties_task(self, input_dict: dict) -> TaskResult:
|
||||
task = cast(TaskExecutionItem, input_dict["task"])
|
||||
entity = task.prompt
|
||||
if entity == "person":
|
||||
person_qs = PropertyDefinition.objects.filter(
|
||||
team=self._team, type=PropertyDefinition.Type.PERSON
|
||||
).values_list("name", "property_type")
|
||||
person_definitions = [prop async for prop in person_qs]
|
||||
if person_definitions:
|
||||
result = self._format_properties(self._enrich_props_with_descriptions("person", person_definitions))
|
||||
status = TaskExecutionStatus.COMPLETED
|
||||
else:
|
||||
result = TaxonomyErrorMessages.properties_not_found(entity)
|
||||
status = TaskExecutionStatus.FAILED
|
||||
return TaskResult(
|
||||
id=task.id,
|
||||
description=task.description,
|
||||
result=result,
|
||||
artifacts=[],
|
||||
status=status,
|
||||
)
|
||||
elif entity == "session":
|
||||
props = [
|
||||
(prop_name, prop["type"])
|
||||
for prop_name, prop in CORE_FILTER_DEFINITIONS_BY_GROUP["session_properties"].items()
|
||||
if prop.get("type") is not None
|
||||
]
|
||||
|
||||
if props:
|
||||
result = self._format_properties(self._enrich_props_with_descriptions("session", props))
|
||||
status = TaskExecutionStatus.COMPLETED
|
||||
else:
|
||||
result = TaxonomyErrorMessages.properties_not_found(entity)
|
||||
status = TaskExecutionStatus.FAILED
|
||||
return TaskResult(
|
||||
id=task.id,
|
||||
description=task.description,
|
||||
result=result,
|
||||
artifacts=[],
|
||||
status=status,
|
||||
)
|
||||
else:
|
||||
return TaskResult(
|
||||
id=task.id,
|
||||
description=task.description,
|
||||
result=TaxonomyErrorMessages.entity_not_found(entity, await self._get_entity_names()),
|
||||
artifacts=[],
|
||||
status=TaskExecutionStatus.FAILED,
|
||||
)
|
||||
|
||||
@database_sync_to_async(thread_sensitive=False)
|
||||
def _get_definitions_for_entity(
|
||||
self, entity: str, property_names: list[str], query: ActorsPropertyTaxonomyQuery
|
||||
@@ -402,47 +658,35 @@ class TaxonomyAgentToolkit:
|
||||
def _get_project_actions(self) -> list[Action]:
|
||||
return list(Action.objects.filter(team__project_id=self._team.project_id, deleted=False))
|
||||
|
||||
async def retrieve_event_or_action_properties(self, event_name_or_action_id: str | int) -> str:
|
||||
"""
|
||||
Retrieve properties for an event.
|
||||
"""
|
||||
try:
|
||||
response, verbose_name = await self._retrieve_event_or_action_taxonomy(event_name_or_action_id)
|
||||
except Action.DoesNotExist:
|
||||
project_actions = await self._get_project_actions()
|
||||
if not project_actions:
|
||||
return TaxonomyErrorMessages.no_actions_exist()
|
||||
return TaxonomyErrorMessages.action_not_found(event_name_or_action_id)
|
||||
if not isinstance(response, CachedEventTaxonomyQueryResponse):
|
||||
return TaxonomyErrorMessages.generic_not_found("Properties")
|
||||
if not response.results:
|
||||
return TaxonomyErrorMessages.event_properties_not_found(verbose_name)
|
||||
|
||||
qs = PropertyDefinition.objects.filter(
|
||||
team=self._team, type=PropertyDefinition.Type.EVENT, name__in=[item.property for item in response.results]
|
||||
async def retrieve_event_or_action_properties_parallel(
|
||||
self, event_name_or_action_ids: list[str | int]
|
||||
) -> dict[str, str]:
|
||||
task_executor_state = BaseStateWithTasks(
|
||||
tasks=[
|
||||
TaskExecutionItem(
|
||||
id=str(event_name_or_action_id),
|
||||
prompt=str(event_name_or_action_id),
|
||||
status=TaskExecutionStatus.PENDING,
|
||||
description="Retrieving event or action properties",
|
||||
progress_text=f"Retrieving properties for {event_name_or_action_id}...",
|
||||
task_type="retrieve_event_or_action_properties",
|
||||
)
|
||||
for event_name_or_action_id in event_name_or_action_ids
|
||||
],
|
||||
)
|
||||
property_data = [prop async for prop in qs]
|
||||
property_to_type = {prop.name: prop.property_type for prop in property_data}
|
||||
props = [
|
||||
(item.property, property_to_type.get(item.property))
|
||||
for item in response.results
|
||||
# Exclude properties that exist in the taxonomy, but don't have a type.
|
||||
if item.property in property_to_type
|
||||
]
|
||||
|
||||
if not props:
|
||||
return TaxonomyErrorMessages.event_properties_not_found(verbose_name)
|
||||
enriched_props = self._enrich_props_with_descriptions("event", props)
|
||||
return self._format_properties(enriched_props)
|
||||
config = RunnableConfig()
|
||||
executor = TaxonomyTaskExecutorNode(self._team, self._user)
|
||||
result = await executor.arun(task_executor_state, config)
|
||||
return {task.id: task.result for task in result.task_results}
|
||||
|
||||
async def retrieve_event_or_action_property_values(
|
||||
self, event_properties: dict[str | int, list[str]]
|
||||
) -> dict[str | int, list[str]]:
|
||||
"""Retrieve property values for an event/action. Supports single property or list of properties."""
|
||||
result = await self._parallel_event_processing(event_properties)
|
||||
result = await self._parallel_event_or_action_processing(event_properties)
|
||||
return result
|
||||
|
||||
async def _parallel_event_processing(
|
||||
async def _parallel_event_or_action_processing(
|
||||
self, event_properties: dict[str | int, list[str]]
|
||||
) -> dict[str | int, list[str]]:
|
||||
event_tasks = [
|
||||
@@ -535,27 +779,102 @@ class TaxonomyAgentToolkit:

return results

async def handle_tools(self, tool_name: str, tool_input: TaxonomyTool) -> tuple[str, str]:
if tool_name == "retrieve_entity_property_values":
entity = tool_input.arguments.entity # type: ignore
property_name = tool_input.arguments.property_name # type: ignore
result = (await self.retrieve_entity_property_values({entity: [property_name]}))[entity][0]
elif tool_name == "retrieve_entity_properties":
result = await self.retrieve_entity_properties(tool_input.arguments.entity) # type: ignore
elif tool_name == "retrieve_event_property_values":
event_name_or_action_id = tool_input.arguments.event_name # type: ignore
property_name = tool_input.arguments.property_name # type: ignore
result = (await self.retrieve_event_or_action_property_values({event_name_or_action_id: [property_name]}))[
event_name_or_action_id
][0]
elif tool_name == "retrieve_event_properties":
result = await self.retrieve_event_or_action_properties(tool_input.arguments.event_name) # type: ignore
elif tool_name == "ask_user_for_help":
result = tool_input.arguments.request # type: ignore
else:
raise TaxonomyToolNotFoundError(f"Tool {tool_name} not found in taxonomy toolkit.")
def _collect_tools(self, tool_metadata: dict[str, list[tuple[TaxonomyTool, str]]]) -> dict:
"""
Collect and group tool calls by type for batch processing.
Returns grouped data and mappings for result distribution.
"""
result: dict = {
"entity_property_values": {}, # entity -> [property_names]
"entity_properties": [], # [entities]
"event_property_values": {}, # event_name -> [property_names]
"event_properties": [], # [event_names]
"entity_prop_mapping": {}, # (entity, property) -> tool_call_id
"entity_mapping": {}, # entity -> tool_call_id
"event_prop_mapping": {}, # (event, property) -> tool_call_id
"event_mapping": {}, # event -> tool_call_id
}

return tool_name, result
for tool_name, tool_inputs in tool_metadata.items():
for tool_input, tool_call_id in tool_inputs:
if tool_name == "retrieve_entity_property_values":
entity = tool_input.arguments.entity # type: ignore
property_name = tool_input.arguments.property_name # type: ignore
if entity not in result["entity_property_values"]:
result["entity_property_values"][entity] = []
result["entity_property_values"][entity].append(property_name)
result["entity_prop_mapping"][(entity, property_name)] = tool_call_id

elif tool_name == "retrieve_entity_properties":
entity = tool_input.arguments.entity # type: ignore
result["entity_properties"].append(entity)
result["entity_mapping"][entity] = tool_call_id

elif tool_name == "retrieve_event_property_values":
event_name = tool_input.arguments.event_name # type: ignore
property_name = tool_input.arguments.property_name # type: ignore
if event_name not in result["event_property_values"]:
result["event_property_values"][event_name] = []
result["event_property_values"][event_name].append(property_name)
result["event_prop_mapping"][(event_name, property_name)] = tool_call_id

elif tool_name == "retrieve_event_properties":
event_name = tool_input.arguments.event_name # type: ignore
result["event_properties"].append(event_name)
result["event_mapping"][event_name] = tool_call_id
else:
raise TaxonomyToolNotFoundError(f"Tool {tool_name} not found in taxonomy toolkit.")

return result

async def _execute_tools(self, collected_tools: dict) -> dict[str, str]:
"""
Execute batch operations and distribute results.
Returns a dict mapping tool_call_id to result for each individual tool call.
"""
results = {}

# Execute batch operations and distribute results in single passes
if collected_tools["entity_property_values"]:
entity_property_values = await self.retrieve_entity_property_values(
collected_tools["entity_property_values"]
)
for entity, property_results in entity_property_values.items():
for i, property_name in enumerate(collected_tools["entity_property_values"][entity]):
results[collected_tools["entity_prop_mapping"][(entity, property_name)]] = property_results[i]

if collected_tools["entity_properties"]:
entity_properties = await self.retrieve_entity_properties_parallel(collected_tools["entity_properties"])
for entity, result in entity_properties.items():
results[collected_tools["entity_mapping"][entity]] = result

if collected_tools["event_property_values"]:
event_property_values = await self.retrieve_event_or_action_property_values(
collected_tools["event_property_values"]
)
for event_name, property_results in event_property_values.items():
for i, property_name in enumerate(collected_tools["event_property_values"][event_name]):
results[collected_tools["event_prop_mapping"][(event_name, property_name)]] = property_results[i]

if collected_tools["event_properties"]:
event_properties = await self.retrieve_event_or_action_properties_parallel(
collected_tools["event_properties"]
)
for event_name, result in event_properties.items():
results[collected_tools["event_mapping"][event_name]] = result

return results

async def handle_tools(self, tool_metadata: dict[str, list[tuple[TaxonomyTool, str]]]) -> dict[str, str]:
"""
Handle multiple tool calls with maximum optimization by batching similar operations.
Returns a dict mapping tool_call_id to result for each individual tool call.
"""
# Collect and group tools
collected_tools = self._collect_tools(tool_metadata)

# Execute tools and return results
return await self._execute_tools(collected_tools)

def get_tool_input_model(self, action: AgentAction) -> TaxonomyTool:
try:
@@ -39,9 +39,7 @@ class retrieve_entity_properties(BaseModel):
- **Avoid using ambiguous properties** unless their relevance is explicitly confirmed.
"""

entity: Literal["person", "session"] = Field(
..., description="The type of the entity that you want to retrieve properties for."
)
entity: str = Field(..., description="The type of the entity that you want to retrieve properties for.")


class retrieve_event_property_values(BaseModel):
@@ -67,9 +65,7 @@ class retrieve_entity_property_values(BaseModel):
Use this tool to retrieve property values for a property name. Adjust filters to these values. You will receive a list of property values or a message that property values have not been found. Some properties can have many values, so the output will be truncated. Use your judgment to find a proper value.
"""

entity: Literal["person", "session"] = Field(
..., description="The type of the entity that you want to retrieve properties for."
)
entity: str = Field(..., description="The type of the entity that you want to retrieve properties for.")
property_name: str = Field(..., description="The name of the property that you want to retrieve values for.")
@@ -36,6 +36,11 @@ class TaxonomyAgentState(BaseStateWithIntermediateSteps, BaseStateWithMessages,
The messages with tool calls to collect tool progress.
"""

iteration_count: int | None = Field(default=None)
"""
The number of iterations the taxonomy agent has gone through.
"""


class TaxonomyNodeName(StrEnum):
"""Generic node names for taxonomy agents."""
@@ -44,6 +49,7 @@ class TaxonomyNodeName(StrEnum):
TOOLS_NODE = "taxonomy_tools_node"
START = START
END = END
TASK_EXECUTOR = "taxonomy_task_executor"


class EntityType(str, Enum):
@@ -17472,6 +17472,9 @@
|
||||
},
|
||||
{
|
||||
"$ref": "#/definitions/RecordingPropertyFilter"
|
||||
},
|
||||
{
|
||||
"$ref": "#/definitions/GroupPropertyFilter"
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import {
|
||||
EventPropertyFilter,
|
||||
FilterLogicalOperator,
|
||||
GroupPropertyFilter,
|
||||
PersonPropertyFilter,
|
||||
RecordingDurationFilter,
|
||||
RecordingPropertyFilter,
|
||||
@@ -35,3 +36,4 @@ export type MaxUniversalFilterValue =
|
||||
| PersonPropertyFilter
|
||||
| SessionPropertyFilter
|
||||
| RecordingPropertyFilter
|
||||
| GroupPropertyFilter
|
||||
|
||||
@@ -372,8 +372,6 @@ ee/hogai/graph/taxonomy/test/test_nodes.py:0: error: Value of type "list[tuple[A
|
||||
ee/hogai/graph/taxonomy/test/test_nodes.py:0: error: Value of type "list[tuple[AgentAction, str | None]] | None" is not indexable [index]
|
||||
ee/hogai/graph/taxonomy/test/test_toolkit.py:0: error: Argument 1 to "_format_properties_xml" of "TaxonomyAgentToolkit" has incompatible type "list[tuple[str, str, str | None]]"; expected "list[tuple[str, str | None, str | None]]" [arg-type]
|
||||
ee/hogai/graph/taxonomy/test/test_toolkit.py:0: error: Argument 1 to "_format_properties_yaml" of "TaxonomyAgentToolkit" has incompatible type "list[tuple[str, str, str | None]]"; expected "list[tuple[str, str | None, str | None]]" [arg-type]
|
||||
ee/hogai/graph/taxonomy/test/test_toolkit.py:0: error: Argument 2 to "handle_tools" of "TaxonomyAgentToolkit" has incompatible type "ToolInput"; expected "TaxonomyTool[Any]" [arg-type]
|
||||
ee/hogai/graph/taxonomy/test/test_toolkit.py:0: error: Argument 2 to "handle_tools" of "TaxonomyAgentToolkit" has incompatible type "ToolInput"; expected "TaxonomyTool[Any]" [arg-type]
|
||||
ee/hogai/graph/taxonomy/test/test_toolkit.py:0: error: Item "ask_user_for_help" of "retrieve_event_properties | retrieve_entity_properties | retrieve_entity_property_values | retrieve_event_property_values | ask_user_for_help | Any" has no attribute "custom_field" [union-attr]
|
||||
ee/hogai/graph/taxonomy/test/test_toolkit.py:0: error: Item "retrieve_entity_properties" of "retrieve_event_properties | retrieve_entity_properties | retrieve_entity_property_values | retrieve_event_property_values | ask_user_for_help | Any" has no attribute "custom_field" [union-attr]
|
||||
ee/hogai/graph/taxonomy/test/test_toolkit.py:0: error: Item "retrieve_entity_property_values" of "retrieve_event_properties | retrieve_entity_properties | retrieve_entity_property_values | retrieve_event_property_values | ask_user_for_help | Any" has no attribute "custom_field" [union-attr]
|
||||
|
||||
@@ -12606,7 +12606,15 @@ class MaxInnerUniversalFiltersGroup(BaseModel):
        extra="forbid",
    )
    type: FilterLogicalOperator
    values: list[Union[EventPropertyFilter, PersonPropertyFilter, SessionPropertyFilter, RecordingPropertyFilter]]
    values: list[
        Union[
            EventPropertyFilter,
            PersonPropertyFilter,
            SessionPropertyFilter,
            RecordingPropertyFilter,
            GroupPropertyFilter,
        ]
    ]


class MaxOuterUniversalFiltersGroup(BaseModel):
@@ -41,8 +41,8 @@ class final_answer(base_final_answer[FinalAnswerArgs]):


class HogQLGeneratorToolkit(TaxonomyAgentToolkit):
    def __init__(self, team: Team):
        super().__init__(team)
    def __init__(self, team: Team, user: User):
        super().__init__(team, user)

    def _get_custom_tools(self) -> list:
        """Get custom tools for the HogQLGenerator."""
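The same constructor change repeats across every toolkit below: each subclass now forwards the requesting user along with the team. A minimal usage sketch; the variable names are illustrative, not taken from the commit:

# Toolkit subclasses are now built from both the team and the user making the request.
toolkit = HogQLGeneratorToolkit(team=team, user=user)
tools = toolkit.get_tools()  # default taxonomy tools plus the toolkit's custom tools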
@@ -115,11 +115,11 @@ class final_answer(base_final_answer[ErrorTrackingIssueImpactToolOutput]):


class ErrorTrackingIssueImpactToolkit(TaxonomyAgentToolkit):
    def __init__(self, team: Team):
        super().__init__(team)
    def __init__(self, team: Team, user: User):
        super().__init__(team, user)

    async def handle_tools(self, tool_name: str, tool_input: TaxonomyTool) -> tuple[str, str]:
        return await super().handle_tools(tool_name, tool_input)
    async def handle_tools(self, tool_metadata: dict[str, list[tuple[TaxonomyTool, str]]]) -> dict[str, str]:
        return await super().handle_tools(tool_metadata)

    def _get_custom_tools(self) -> list:
        return [final_answer]
@@ -31,8 +31,8 @@ logger.setLevel(logging.DEBUG)


class SessionReplayFilterOptionsToolkit(TaxonomyAgentToolkit):
    def __init__(self, team: Team):
        super().__init__(team)
    def __init__(self, team: Team, user: User):
        super().__init__(team, user)

    def _get_custom_tools(self) -> list:
        """Get custom tools for filter options."""
@@ -18,7 +18,7 @@ from ee.hogai.graph.taxonomy.agent import TaxonomyAgent
from ee.hogai.graph.taxonomy.format import enrich_props_with_descriptions, format_properties_xml
from ee.hogai.graph.taxonomy.nodes import TaxonomyAgentNode, TaxonomyAgentToolsNode
from ee.hogai.graph.taxonomy.toolkit import TaxonomyAgentToolkit, TaxonomyErrorMessages
from ee.hogai.graph.taxonomy.tools import ask_user_for_help, base_final_answer
from ee.hogai.graph.taxonomy.tools import TaxonomyTool, ask_user_for_help, base_final_answer
from ee.hogai.graph.taxonomy.types import TaxonomyAgentState
from ee.hogai.tool import MaxTool
from ee.hogai.utils.types.base import AssistantNodeName
@@ -53,16 +53,27 @@ logger.setLevel(logging.DEBUG)


class RevenueAnalyticsFilterOptionsToolkit(TaxonomyAgentToolkit):
    def __init__(self, team: Team):
        super().__init__(team)
    def __init__(self, team: Team, user: User):
        super().__init__(team, user)

    async def handle_tools(self, tool_name: str, tool_input) -> tuple[str, str]:
    async def handle_tools(self, tool_metadata: dict[str, list[tuple[TaxonomyTool, str]]]) -> dict[str, str]:
        """Handle custom tool execution."""
        if tool_name == "retrieve_revenue_analytics_property_values":
            result = await self._retrieve_revenue_analytics_property_values(tool_input.arguments.property_key)
            return tool_name, result
        results = {}
        unhandled_tools = {}
        for tool_name, tool_inputs in tool_metadata.items():
            if tool_name == "retrieve_revenue_analytics_property_values":
                if tool_inputs:
                    for tool_input, tool_call_id in tool_inputs:
                        result = await self._retrieve_revenue_analytics_property_values(
                            tool_input.arguments.property_key  # type: ignore
                        )
                        results[tool_call_id] = result
            else:
                unhandled_tools[tool_name] = tool_inputs

        return await super().handle_tools(tool_name, tool_input)
        if unhandled_tools:
            results.update(await super().handle_tools(unhandled_tools))
        return results

    def _get_custom_tools(self) -> list:
        return [final_answer, retrieve_revenue_analytics_property_values]
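handle_tools now takes every tool call from a single LLM turn at once, keyed by tool name, and returns one result per tool_call_id. A sketch of the caller side under that assumption; the dispatcher function and variable names below are hypothetical and not part of this commit:

from collections import defaultdict

# Hypothetical dispatcher: group parallel tool calls into the
# {tool_name: [(parsed_tool, tool_call_id), ...]} shape that handle_tools expects.
async def dispatch_tool_calls(toolkit: TaxonomyAgentToolkit, parsed_calls: list) -> dict[str, str]:
    # parsed_calls holds (tool_name, parsed TaxonomyTool, tool_call_id) triples.
    tool_metadata: dict[str, list] = defaultdict(list)
    for tool_name, parsed_tool, tool_call_id in parsed_calls:
        tool_metadata[tool_name].append((parsed_tool, tool_call_id))
    # Returns {tool_call_id: result string}, one entry per original tool call.
    return await toolkit.handle_tools(dict(tool_metadata))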
@@ -15,11 +15,12 @@ from posthog.schema import SurveyAnalysisQuestionGroup, SurveyCreationSchema
from posthog.constants import DEFAULT_SURVEY_APPEARANCE
from posthog.exceptions_capture import capture_exception
from posthog.models import FeatureFlag, Survey, Team, User
from posthog.sync import database_sync_to_async

from ee.hogai.graph.taxonomy.agent import TaxonomyAgent
from ee.hogai.graph.taxonomy.nodes import TaxonomyAgentNode, TaxonomyAgentToolsNode
from ee.hogai.graph.taxonomy.toolkit import TaxonomyAgentToolkit
from ee.hogai.graph.taxonomy.tools import base_final_answer
from ee.hogai.graph.taxonomy.tools import TaxonomyTool, ask_user_for_help, base_final_answer
from ee.hogai.graph.taxonomy.types import TaxonomyAgentState
from ee.hogai.llm import MaxChatOpenAI
from ee.hogai.tool import MaxTool
@@ -160,8 +161,8 @@ class CreateSurveyTool(MaxTool):
class SurveyToolkit(TaxonomyAgentToolkit):
    """Toolkit for survey creation and feature flag lookup operations."""

    def __init__(self, team: Team):
        super().__init__(team)
    def __init__(self, team: Team, user: User):
        super().__init__(team, user)

    def get_tools(self) -> list:
        """Get all tools (default + custom). Override in subclasses to add custom tools."""
@@ -181,15 +182,26 @@ class SurveyToolkit(TaxonomyAgentToolkit):
        class final_answer(base_final_answer[SurveyCreationSchema]):
            __doc__ = base_final_answer.__doc__

        return [lookup_feature_flag, final_answer]
        return [lookup_feature_flag, final_answer, ask_user_for_help]

    async def handle_tools(self, tool_name: str, tool_input) -> tuple[str, str]:
    async def handle_tools(self, tool_metadata: dict[str, list[tuple[TaxonomyTool, str]]]) -> dict[str, str]:
        """Handle custom tool execution."""
        if tool_name == "lookup_feature_flag":
            result = self._lookup_feature_flag(tool_input.arguments.flag_key)
            return tool_name, result
        return await super().handle_tools(tool_name, tool_input)
        results = {}
        unhandled_tools = {}
        for tool_name, tool_inputs in tool_metadata.items():
            if tool_name == "lookup_feature_flag":
                if tool_inputs:
                    for tool_input, tool_call_id in tool_inputs:
                        result = await self._lookup_feature_flag(tool_input.arguments.flag_key)  # type: ignore
                        results[tool_call_id] = result
            else:
                unhandled_tools[tool_name] = tool_inputs

        if unhandled_tools:
            results.update(await super().handle_tools(unhandled_tools))
        return results
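Because results come back keyed by tool_call_id, each entry can be mapped straight onto a per-call tool message. A minimal sketch assuming the LangChain ToolMessage type; the wrapper function itself is hypothetical and not part of this commit:

from langchain_core.messages import ToolMessage

def results_to_messages(results: dict[str, str]) -> list[ToolMessage]:
    # One ToolMessage per original tool call, matched back by its tool_call_id.
    return [ToolMessage(content=content, tool_call_id=call_id) for call_id, content in results.items()]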
    @database_sync_to_async(thread_sensitive=False)
    def _lookup_feature_flag(self, flag_key: str) -> str:
        """Look up feature flag information by key."""
        try: