", formatted_string)
@@ -545,12 +549,12 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
),
]
- self.node._teams_map = {
+ self.tool._teams_map = {
84444: "Project 84444",
12345: "Project 12345",
}
- table = self.node._format_history_table(usage_history)
+ table = self.tool._format_history_table(usage_history)
# Should always include aggregated table first
self.assertIn("### Overall (all projects)", table)
@@ -597,7 +601,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
),
]
- table = self.node._format_history_table(usage_history)
+ table = self.tool._format_history_table(usage_history)
# Should include aggregated table
self.assertIn("### Overall (all projects)", table)
@@ -645,7 +649,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
)
]
- aggregated = self.node._create_aggregated_items(team_items, other_items)
+ aggregated = self.tool._create_aggregated_items(team_items, other_items)
# Should have 2 aggregated items: events and feature flags
self.assertEqual(len(aggregated), 2)
@@ -687,7 +691,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
),
]
- table = self.node._format_history_table(usage_history)
+ table = self.tool._format_history_table(usage_history)
# Should only have aggregated table, no team-specific tables
self.assertIn("### Overall (all projects)", table)
@@ -721,18 +725,18 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
)
]
- self.node._teams_map = {
+ self.tool._teams_map = {
123: "Project 123",
}
# Test usage history formatting
- usage_table = self.node._format_history_table(usage_history)
+ usage_table = self.tool._format_history_table(usage_history)
self.assertIn("### Overall (all projects)", usage_table)
self.assertIn("### Project 123", usage_table)
self.assertIn("| Events | 1,000.00 |", usage_table)
# Test spend history formatting
- spend_table = self.node._format_history_table(spend_history)
+ spend_table = self.tool._format_history_table(spend_history)
self.assertIn("### Overall (all projects)", spend_table)
self.assertIn("### Project 123", spend_table)
self.assertIn("| Events | 50.00 |", spend_table)
@@ -758,7 +762,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
),
]
- table = self.node._format_history_table(usage_history)
+ table = self.tool._format_history_table(usage_history)
# Should handle empty dates gracefully and still show valid data
self.assertIn("### Overall (all projects)", table)
diff --git a/ee/hogai/graph/billing/nodes.py b/ee/hogai/graph/root/tools/read_billing_tool/tool.py
similarity index 91%
rename from ee/hogai/graph/billing/nodes.py
rename to ee/hogai/graph/root/tools/read_billing_tool/tool.py
index 985337862b..6d002e380c 100644
--- a/ee/hogai/graph/billing/nodes.py
+++ b/ee/hogai/graph/root/tools/read_billing_tool/tool.py
@@ -1,18 +1,19 @@
-from typing import Any, cast
-from uuid import uuid4
+from typing import Any
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableConfig
-from posthog.schema import AssistantToolCallMessage, MaxBillingContext, SpendHistoryItem, UsageHistoryItem
+from posthog.schema import MaxBillingContext, SpendHistoryItem, UsageHistoryItem
from posthog.clickhouse.client import sync_execute
+from posthog.models import Team, User
+from posthog.sync import database_sync_to_async
-from ee.hogai.graph.base import AssistantNode
-from ee.hogai.graph.billing.prompts import BILLING_CONTEXT_PROMPT
+from ee.hogai.context.context import AssistantContextManager
+from ee.hogai.tool import MaxSubtool
from ee.hogai.utils.types import AssistantState
-from ee.hogai.utils.types.base import AssistantNodeName, PartialAssistantState
-from ee.hogai.utils.types.composed import MaxNodeName
+
+from .prompts import BILLING_CONTEXT_PROMPT
# sync with frontend/src/scenes/billing/constants.ts
USAGE_TYPES = [
@@ -33,33 +34,27 @@ USAGE_TYPES = [
]
-class BillingNode(AssistantNode):
- _teams_map: dict[int, str] = {}
+class ReadBillingTool(MaxSubtool):
+ def __init__(
+ self,
+ *,
+ team: Team,
+ user: User,
+ state: AssistantState,
+ config: RunnableConfig,
+ context_manager: AssistantContextManager,
+ ):
+ super().__init__(team=team, user=user, state=state, config=config, context_manager=context_manager)
+ self._teams_map: dict[int, str] = {}
- @property
- def node_name(self) -> MaxNodeName:
- return AssistantNodeName.BILLING
-
- def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
- tool_call_id = cast(str, state.root_tool_call_id)
- billing_context = self.context_manager.get_billing_context()
+ async def execute(self) -> str:
+ billing_context = self._context_manager.get_billing_context()
if not billing_context:
- return PartialAssistantState(
- messages=[
- AssistantToolCallMessage(
- content="No billing information available", id=str(uuid4()), tool_call_id=tool_call_id
- )
- ],
- root_tool_call_id=None,
- )
- formatted_billing_context = self._format_billing_context(billing_context)
- return PartialAssistantState(
- messages=[
- AssistantToolCallMessage(content=formatted_billing_context, tool_call_id=tool_call_id, id=str(uuid4())),
- ],
- root_tool_call_id=None,
- )
+ return "No billing information available"
+ formatted_billing_context = await self._format_billing_context(billing_context)
+ return formatted_billing_context
+ @database_sync_to_async(thread_sensitive=False)
def _format_billing_context(self, billing_context: MaxBillingContext) -> str:
"""Format billing context into a readable prompt section."""
# Convert billing context to a format suitable for the mustache template
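For orientation, the node-to-subtool conversion above means the billing formatting is now invoked directly instead of through graph routing. A minimal usage sketch, assuming only the constructor and `execute()` signatures visible in this diff (`summarize_billing` is a hypothetical helper, not part of the change):

```python
# Sketch only: exercising the new ReadBillingTool subtool outside the graph.
from langchain_core.runnables import RunnableConfig

from ee.hogai.context.context import AssistantContextManager
from ee.hogai.graph.root.tools.read_billing_tool.tool import ReadBillingTool
from ee.hogai.utils.types import AssistantState


async def summarize_billing(team, user) -> str:
    config = RunnableConfig(configurable={})
    context_manager = AssistantContextManager(team=team, user=user, config=config)
    tool = ReadBillingTool(
        team=team,
        user=user,
        state=AssistantState(messages=[]),
        config=config,
        context_manager=context_manager,
    )
    # Returns "No billing information available" when no billing context is set.
    return await tool.execute()
```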
diff --git a/ee/hogai/graph/root/tools/read_data.py b/ee/hogai/graph/root/tools/read_data.py
index 712d8c22d0..577743f5c3 100644
--- a/ee/hogai/graph/root/tools/read_data.py
+++ b/ee/hogai/graph/root/tools/read_data.py
@@ -1,4 +1,4 @@
-from typing import Any, Literal, Self
+from typing import Literal, Self
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel
@@ -9,12 +9,14 @@ from ee.hogai.context.context import AssistantContextManager
from ee.hogai.graph.sql.mixins import HogQLDatabaseMixin
from ee.hogai.tool import MaxTool
from ee.hogai.utils.prompt import format_prompt_string
-from ee.hogai.utils.types.base import AssistantState
+from ee.hogai.utils.types.base import AssistantState, NodePath
+
+from .read_billing_tool.tool import ReadBillingTool
READ_DATA_BILLING_PROMPT = """
# Billing information
-Use this tool with the "billing_info" kind to retrieve the billing information if the user asks about billing, their subscription, their usage, or their spending.
+Use this tool with the "billing_info" kind to retrieve the billing information if the user asks about their billing, subscription, product usage, spending, or cost reduction strategies.
You can use the information retrieved to check which PostHog products and add-ons the user has activated, how much they are spending, their usage history across all products in the last 30 days, as well as trials, spending limits, billing period, and more.
If the user wants to reduce their spending, always call this tool to get suggestions on how to do so.
If an insight shows zero data, it could mean either the query is looking at the wrong data or there was a temporary data collection issue. You can investigate potential dips in usage/captured data using the billing tool.
@@ -62,7 +64,7 @@ class ReadDataTool(HogQLDatabaseMixin, MaxTool):
*,
team: Team,
user: User,
- tool_call_id: str,
+ node_path: tuple[NodePath, ...] | None = None,
state: AssistantState | None = None,
config: RunnableConfig | None = None,
context_manager: AssistantContextManager | None = None,
@@ -83,21 +85,29 @@ class ReadDataTool(HogQLDatabaseMixin, MaxTool):
return cls(
team=team,
user=user,
- tool_call_id=tool_call_id,
state=state,
+ node_path=node_path,
config=config,
args_schema=args,
description=description,
context_manager=context_manager,
)
- async def _arun_impl(self, kind: ReadDataAdminAccessKind | ReadDataKind) -> tuple[str, dict[str, Any] | None]:
+ async def _arun_impl(self, kind: ReadDataAdminAccessKind | ReadDataKind) -> tuple[str, None]:
match kind:
case "billing_info":
has_access = await self._context_manager.check_user_has_billing_access()
if not has_access:
return BILLING_INSUFFICIENT_ACCESS_PROMPT, None
# used for routing
- return "", self.args_schema(kind=kind).model_dump()
+ billing_tool = ReadBillingTool(
+ team=self._team,
+ user=self._user,
+ state=self._state,
+ config=self._config,
+ context_manager=self._context_manager,
+ )
+ result = await billing_tool.execute()
+ return result, None
case "datawarehouse_schema":
return await self._serialize_database_schema(), None
diff --git a/ee/hogai/graph/root/tools/read_taxonomy.py b/ee/hogai/graph/root/tools/read_taxonomy.py
index 0fffab82e0..d088d7b5ca 100644
--- a/ee/hogai/graph/root/tools/read_taxonomy.py
+++ b/ee/hogai/graph/root/tools/read_taxonomy.py
@@ -9,7 +9,7 @@ from ee.hogai.context.context import AssistantContextManager
from ee.hogai.graph.query_planner.toolkit import TaxonomyAgentToolkit
from ee.hogai.tool import MaxTool
from ee.hogai.utils.helpers import format_events_yaml
-from ee.hogai.utils.types.base import AssistantState
+from ee.hogai.utils.types.base import AssistantState, NodePath
READ_TAXONOMY_TOOL_DESCRIPTION = """
Use this tool to explore the user's taxonomy (i.e. data schema).
@@ -155,7 +155,7 @@ class ReadTaxonomyTool(MaxTool):
*,
team: Team,
user: User,
- tool_call_id: str,
+ node_path: tuple[NodePath, ...] | None = None,
state: AssistantState | None = None,
config: RunnableConfig | None = None,
context_manager: AssistantContextManager | None = None,
@@ -200,9 +200,9 @@ class ReadTaxonomyTool(MaxTool):
return cls(
team=team,
user=user,
- tool_call_id=tool_call_id,
state=state,
config=config,
+ node_path=node_path,
args_schema=ReadTaxonomyToolArgsWithGroups,
context_manager=context_manager,
)
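Both `ReadDataTool` and `ReadTaxonomyTool` now accept an optional `node_path` tuple in place of the removed `tool_call_id` constructor argument. A sketch of what a caller passes, assuming the `NodePath` fields used by the new tests later in this patch (the identifier values are illustrative):

```python
from ee.hogai.utils.types.base import NodePath

# The path element carries the tool_call_id that used to be a direct constructor argument.
node_path = (NodePath(name="root", tool_call_id="call_123", message_id="msg_456"),)
```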
diff --git a/ee/hogai/graph/root/tools/search.py b/ee/hogai/graph/root/tools/search.py
index 11c32dfb99..7f7910a282 100644
--- a/ee/hogai/graph/root/tools/search.py
+++ b/ee/hogai/graph/root/tools/search.py
@@ -1,17 +1,19 @@
-from typing import Any, Literal
+from typing import Literal
from django.conf import settings
import posthoganalytics
from langchain_core.output_parsers import SimpleJsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import RunnableLambda
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
-from posthog.models import Team, User
-
-from ee.hogai.graph.root.tools.full_text_search.tool import EntitySearchToolkit, FTSKind
-from ee.hogai.tool import MaxTool
+from ee.hogai.graph.insights.nodes import InsightSearchNode, NoInsightsException
+from ee.hogai.graph.root.tools.full_text_search.tool import EntitySearchTool, FTSKind
+from ee.hogai.tool import MaxSubtool, MaxTool, ToolMessagesArtifact
+from ee.hogai.utils.prompt import format_prompt_string
+from ee.hogai.utils.types.base import AssistantState, PartialAssistantState
SEARCH_TOOL_PROMPT = """
Use this tool to search docs, insights, dashboards, cohorts, actions, experiments, feature flags, notebooks, error tracking issues, and surveys in PostHog.
@@ -57,35 +59,10 @@ If you want to search for all entities, you should use `all`.
""".strip()
-DOCS_SEARCH_RESULTS_TEMPLATE = """Found {count} relevant documentation page(s):
-
-{docs}
-
-Use retrieved documentation to answer the user's question if it is relevant to the user's query.
-Format the response using Markdown and reference the documentation using hyperlinks.
-Every link to docs clearly explicitly be labeled, for example as "(see docs)".
-
+INVALID_ENTITY_KIND_PROMPT = """
+Invalid entity kind: {{{kind}}}. Please provide a valid entity kind for the tool.
""".strip()
-DOCS_SEARCH_NO_RESULTS_TEMPLATE = """
-No documentation found.
-
-
-Do not answer the user's question if you did not find any documentation. Try rewriting the query.
-If after a couple of attempts you still do not find any documentation, suggest the user navigate to the documentation page, which is available at `https://posthog.com/docs`.
-
-""".strip()
-
-DOC_ITEM_TEMPLATE = """
-# {title}
-URL: {url}
-
-{text}
-""".strip()
-
-
-FTS_SEARCH_FEATURE_FLAG = "hogai-insights-fts-search"
-
ENTITIES = [f"{entity}" for entity in FTSKind if entity != FTSKind.INSIGHTS]
SearchKind = Literal["insights", "docs", *ENTITIES] # type: ignore
@@ -126,50 +103,106 @@ class SearchTool(MaxTool):
context_prompt_template: str = "Searches documentation, insights, dashboards, cohorts, actions, experiments, feature flags, notebooks, error tracking issues, and surveys in PostHog"
args_schema: type[BaseModel] = SearchToolArgs
- @staticmethod
- def _get_fts_entities(include_insight_fts: bool) -> list[str]:
- if not include_insight_fts:
- entities = [e for e in FTSKind if e != FTSKind.INSIGHTS]
- else:
- entities = list(FTSKind)
- return [*entities, FTSKind.ALL]
-
- async def _arun_impl(self, kind: SearchKind, query: str) -> tuple[str, dict[str, Any] | None]:
+ async def _arun_impl(self, kind: str, query: str) -> tuple[str, ToolMessagesArtifact | None]:
if kind == "docs":
if not settings.INKEEP_API_KEY:
return "This tool is not available in this environment.", None
- if self._has_docs_search_feature_flag():
- return await self._search_docs(query), None
+ docs_tool = InkeepDocsSearchTool(
+ team=self._team,
+ user=self._user,
+ state=self._state,
+ config=self._config,
+ context_manager=self._context_manager,
+ )
+ return await docs_tool.execute(query, self.tool_call_id)
- fts_entities = SearchTool._get_fts_entities(SearchTool._has_fts_search_feature_flag(self._user, self._team))
+ if kind == "insights" and not self._has_insights_fts_search_feature_flag():
+ insights_tool = InsightSearchTool(
+ team=self._team,
+ user=self._user,
+ state=self._state,
+ config=self._config,
+ context_manager=self._context_manager,
+ )
+ return await insights_tool.execute(query, self.tool_call_id)
- if kind in fts_entities:
- entity_search_toolkit = EntitySearchToolkit(self._team, self._user)
- response = await entity_search_toolkit.execute(query, FTSKind(kind))
- return response, None
- # Used for routing
- return "Search tool executed", SearchToolArgs(kind=kind, query=query).model_dump()
+ if kind not in self._fts_entities:
+ return format_prompt_string(INVALID_ENTITY_KIND_PROMPT, kind=kind), None
- def _has_docs_search_feature_flag(self) -> bool:
+ entity_search_toolkit = EntitySearchTool(
+ team=self._team,
+ user=self._user,
+ state=self._state,
+ config=self._config,
+ context_manager=self._context_manager,
+ )
+ response = await entity_search_toolkit.execute(query, FTSKind(kind))
+ return response, None
+
+ @property
+ def _fts_entities(self) -> list[str]:
+ entities = list(FTSKind)
+ return [*entities, FTSKind.ALL]
+
+ def _has_insights_fts_search_feature_flag(self) -> bool:
return posthoganalytics.feature_enabled(
- "max-inkeep-rag-docs-search",
+ "hogai-insights-fts-search",
str(self._user.distinct_id),
groups={"organization": str(self._team.organization_id)},
group_properties={"organization": {"id": str(self._team.organization_id)}},
send_feature_flag_events=False,
)
- @staticmethod
- def _has_fts_search_feature_flag(user: User, team: Team) -> bool:
- return posthoganalytics.feature_enabled(
- FTS_SEARCH_FEATURE_FLAG,
- str(user.distinct_id),
- groups={"organization": str(team.organization_id)},
- group_properties={"organization": {"id": str(team.organization_id)}},
- send_feature_flag_events=False,
- )
- async def _search_docs(self, query: str) -> str:
+DOCS_SEARCH_RESULTS_TEMPLATE = """Found {count} relevant documentation page(s):
+
+{docs}
+
+Use retrieved documentation to answer the user's question if it is relevant to the user's query.
+Format the response using Markdown and reference the documentation using hyperlinks.
+Every link to docs must be clearly labeled, for example as "(see docs)".
+
+""".strip()
+
+DOCS_SEARCH_NO_RESULTS_TEMPLATE = """
+No documentation found.
+
+
+Do not answer the user's question if you did not find any documentation. Try rewriting the query.
+If after a couple of attempts you still do not find any documentation, suggest the user navigate to the documentation page, which is available at `https://posthog.com/docs`.
+
+""".strip()
+
+DOC_ITEM_TEMPLATE = """
+# {title}
+URL: {url}
+
+{text}
+""".strip()
+
+
+class InkeepDocsSearchTool(MaxSubtool):
+ async def execute(self, query: str, tool_call_id: str) -> tuple[str, ToolMessagesArtifact | None]:
+ if self._has_rag_docs_search_feature_flag():
+ return await self._search_using_rag_endpoint(query, tool_call_id)
+ else:
+ return await self._search_using_node(query, tool_call_id)
+
+ async def _search_using_node(self, query: str, tool_call_id: str) -> tuple[str, ToolMessagesArtifact | None]:
+ # Avoid circular import
+ from ee.hogai.graph.inkeep_docs.nodes import InkeepDocsNode
+
+ # Init the graph
+ node = InkeepDocsNode(self._team, self._user)
+ chain: RunnableLambda[AssistantState, PartialAssistantState | None] = RunnableLambda(node)
+ copied_state = self._state.model_copy(deep=True, update={"root_tool_call_id": tool_call_id})
+ result = await chain.ainvoke(copied_state)
+ assert result is not None
+ return "", ToolMessagesArtifact(messages=result.messages)
+
+ async def _search_using_rag_endpoint(
+ self, query: str, tool_call_id: str
+ ) -> tuple[str, ToolMessagesArtifact | None]:
model = ChatOpenAI(
model="inkeep-rag",
base_url="https://api.inkeep.com/v1/",
@@ -184,7 +217,7 @@ class SearchTool(MaxTool):
rag_context_raw = await chain.ainvoke({"query": query})
if not rag_context_raw or not rag_context_raw.get("content"):
- return DOCS_SEARCH_NO_RESULTS_TEMPLATE
+ return DOCS_SEARCH_NO_RESULTS_TEMPLATE, None
rag_context = InkeepResponse.model_validate(rag_context_raw)
@@ -197,7 +230,35 @@ class SearchTool(MaxTool):
docs.append(DOC_ITEM_TEMPLATE.format(title=doc.title, url=doc.url, text=text))
if not docs:
- return DOCS_SEARCH_NO_RESULTS_TEMPLATE
+ return DOCS_SEARCH_NO_RESULTS_TEMPLATE, None
formatted_docs = "\n\n---\n\n".join(docs)
- return DOCS_SEARCH_RESULTS_TEMPLATE.format(count=len(docs), docs=formatted_docs)
+ return DOCS_SEARCH_RESULTS_TEMPLATE.format(count=len(docs), docs=formatted_docs), None
+
+ def _has_rag_docs_search_feature_flag(self) -> bool:
+ return posthoganalytics.feature_enabled(
+ "max-inkeep-rag-docs-search",
+ str(self._user.distinct_id),
+ groups={"organization": str(self._team.organization_id)},
+ group_properties={"organization": {"id": str(self._team.organization_id)}},
+ send_feature_flag_events=False,
+ )
+
+
+EMPTY_DATABASE_ERROR_MESSAGE = """
+The user doesn't have any insights created yet.
+""".strip()
+
+
+class InsightSearchTool(MaxSubtool):
+ async def execute(self, query: str, tool_call_id: str) -> tuple[str, ToolMessagesArtifact | None]:
+ try:
+ node = InsightSearchNode(self._team, self._user)
+ copied_state = self._state.model_copy(
+ deep=True, update={"search_insights_query": query, "root_tool_call_id": tool_call_id}
+ )
+ chain: RunnableLambda[AssistantState, PartialAssistantState | None] = RunnableLambda(node)
+ result = await chain.ainvoke(copied_state)
+ return "", ToolMessagesArtifact(messages=result.messages) if result else None
+ except NoInsightsException:
+ return EMPTY_DATABASE_ERROR_MESSAGE, None
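The insights branch now resolves entirely inside `InsightSearchTool`, including the empty-database case. A hedged sketch of how a caller consumes the `(text, artifact)` pair it returns (the helper name and message handling are illustrative, not from this patch):

```python
from ee.hogai.graph.root.tools.search import InsightSearchTool


async def run_insight_search(tool: InsightSearchTool, query: str, tool_call_id: str) -> str:
    text, artifact = await tool.execute(query, tool_call_id)
    if artifact is not None:
        # Graph-produced messages are forwarded via the artifact rather than the text slot.
        return "\n".join(str(message.content) for message in artifact.messages)
    # Plain-text outcomes, e.g. EMPTY_DATABASE_ERROR_MESSAGE.
    return text
```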
diff --git a/ee/hogai/graph/root/tools/session_summarization.py b/ee/hogai/graph/root/tools/session_summarization.py
new file mode 100644
index 0000000000..594358ec6b
--- /dev/null
+++ b/ee/hogai/graph/root/tools/session_summarization.py
@@ -0,0 +1,122 @@
+from typing import Literal
+
+from langchain_core.runnables import RunnableLambda
+from pydantic import BaseModel, Field
+
+from ee.hogai.graph.session_summaries.nodes import SessionSummarizationNode
+from ee.hogai.tool import MaxTool, ToolMessagesArtifact
+from ee.hogai.utils.types.base import AssistantState, PartialAssistantState
+
+SESSION_SUMMARIZATION_TOOL_PROMPT = """
+Use this tool to summarize session recordings by analyzing the events within those sessions to find patterns and issues.
+It will return a textual summary of the captured session recordings.
+
+# When to use the tool:
+When the user asks to summarize session recordings:
+- "summarize" synonyms: "watch", "analyze", "review", and similar
+- "session recordings" synonyms: "sessions", "recordings", "replays", "user sessions", and similar
+
+# When NOT to use the tool:
+- When the user asks to find, search for, or look up session recordings, but doesn't ask to summarize them
+- When the user asks to update, change, or adjust session recording filters
+
+# Managing context
+If the conversation history contains context about the current filters or session recordings, follow these steps:
+- Convert the user query into a `session_summarization_query`
+- The query should be used to understand the user's intent
+- Decide if the query is relevant to the current filters and set `should_use_current_filters` accordingly
+- Generate the `summary_title` based on the user's query and the current filters
+
+Otherwise:
+- Convert the user query into a `session_summarization_query`
+- The query should be used to search for relevant sessions and then summarize them
+- Assume `should_use_current_filters` should always be `false`
+- Generate the `summary_title` based on the user's query
+
+# Additional guidelines
+- CRITICAL: Always pass the user's complete, unmodified query to the `session_summarization_query` parameter
+- DO NOT truncate, summarize, or extract keywords from the user's query
+- The query is used to find relevant sessions - context helps find better matches
+- Use the explicit tool definition to make a decision
+""".strip()
+
+
+class SessionSummarizationToolArgs(BaseModel):
+ session_summarization_query: str = Field(
+ description="""
+ - The user's complete query for session recordings summarization.
+ - This will be used to find relevant session recordings.
+ - Always pass the user's complete, unmodified query.
+ - Examples:
+ * 'summarize all session recordings from yesterday'
+ * 'analyze mobile user session recordings from last week, even if 1 second'
+ * 'watch last 300 session recordings of MacOS users from US'
+ * and similar
+ """
+ )
+ should_use_current_filters: bool = Field(
+ description="""
+ - Whether to use current filters from user's UI to find relevant session recordings.
+    - IMPORTANT: Should always be `false` if the current filters or the `search_session_recordings` tool are not present in the conversation history.
+ - Examples:
+ * Set to `true` if one of the conditions is met:
+ - the user wants to summarize "current/selected/opened/my/all/these" session recordings
+ - the user wants to use "current/these" filters
+ - the user's query specifies filters identical to the current filters
+ - if the user's query doesn't specify any filters/conditions
+ - the user refers to what they're "looking at" or "viewing"
+ * Set to `false` if one of the conditions is met:
+ - no current filters or `search_session_recordings` tool are present in the conversation
+ - the user specifies date/time period different from the current filters
+ - the user specifies conditions (user, device, id, URL, etc.) not present in the current filters
+ """,
+ )
+ summary_title: str = Field(
+ description="""
+ - The name of the summary that is expected to be generated from the user's `session_summarization_query` and/or `current_filters` (if present).
+    - The name should cover, in 3-7 words, which sessions are to be summarized
+    - This won't be used for any search or filtering, only to properly label the generated summary.
+ - Examples:
+ * If `should_use_current_filters` is `false`, then the `summary_title` should be generated based on the `session_summarization_query`:
+ - query: "I want to watch all the sessions of user `user@example.com` in the last 30 days no matter how long" -> name: "Sessions of the user user@example.com (last 30 days)"
+ - query: "summarize my last 100 session recordings" -> name: "Last 100 sessions"
+ - and similar
+ * If `should_use_current_filters` is `true`, then the `summary_title` should be generated based on the current filters in the context (if present):
+ - filters: "{"key":"$os","value":["Mac OS X"],"operator":"exact","type":"event"}" -> name: "MacOS users"
+ - filters: "{"date_from": "-7d", "filter_test_accounts": True}" -> name: "All sessions (last 7 days)"
+ - and similar
+    * If there's not enough context to generate the summary name, keep it as an empty string ("")
+ """
+ )
+
+
+class SessionSummarizationTool(MaxTool):
+ name: Literal["session_summarization"] = "session_summarization"
+ description: str = SESSION_SUMMARIZATION_TOOL_PROMPT
+ thinking_message: str = "Summarizing session recordings"
+ context_prompt_template: str = "Summarizes session recordings based on the user's query and current filters"
+ args_schema: type[BaseModel] = SessionSummarizationToolArgs
+ show_tool_call_message: bool = False
+
+ async def _arun_impl(
+ self, session_summarization_query: str, should_use_current_filters: bool, summary_title: str
+ ) -> tuple[str, ToolMessagesArtifact | None]:
+ node = SessionSummarizationNode(self._team, self._user)
+ chain: RunnableLambda[AssistantState, PartialAssistantState | None] = RunnableLambda(node)
+ copied_state = self._state.model_copy(
+ deep=True,
+ update={
+ "root_tool_call_id": self.tool_call_id,
+ "session_summarization_query": session_summarization_query,
+ "should_use_current_filters": should_use_current_filters,
+ "summary_title": summary_title,
+ },
+ )
+ result = await chain.ainvoke(copied_state)
+ if not result or not result.messages:
+ return "Session summarization failed", None
+ return "", ToolMessagesArtifact(messages=result.messages)
diff --git a/ee/hogai/graph/root/tools/test/test_create_and_query_insight.py b/ee/hogai/graph/root/tools/test/test_create_and_query_insight.py
new file mode 100644
index 0000000000..c2926d53db
--- /dev/null
+++ b/ee/hogai/graph/root/tools/test/test_create_and_query_insight.py
@@ -0,0 +1,253 @@
+from typing import Any
+
+from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest
+from unittest.mock import AsyncMock, patch
+
+from langchain_core.runnables import RunnableConfig
+
+from posthog.schema import (
+ AssistantMessage,
+ AssistantTool,
+ AssistantToolCallMessage,
+ AssistantTrendsQuery,
+ VisualizationMessage,
+)
+
+from ee.hogai.context.context import AssistantContextManager
+from ee.hogai.graph.root.tools.create_and_query_insight import (
+ INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT,
+ CreateAndQueryInsightTool,
+)
+from ee.hogai.graph.schema_generator.nodes import SchemaGenerationException
+from ee.hogai.utils.types import AssistantState
+from ee.hogai.utils.types.base import NodePath
+from ee.models.assistant import Conversation
+
+
+class TestCreateAndQueryInsightTool(ClickhouseTestMixin, NonAtomicBaseTest):
+ CLASS_DATA_LEVEL_SETUP = False
+
+ def setUp(self):
+ super().setUp()
+ self.conversation = Conversation.objects.create(team=self.team, user=self.user)
+ self.tool_call_id = "test_tool_call_id"
+
+ def _create_tool(
+ self, state: AssistantState | None = None, contextual_tools: dict[str, dict[str, Any]] | None = None
+ ):
+ """Helper to create tool instance with optional state and contextual tools."""
+ if state is None:
+ state = AssistantState(messages=[])
+
+ config: RunnableConfig = RunnableConfig()
+ if contextual_tools:
+ config = RunnableConfig(configurable={"contextual_tools": contextual_tools})
+
+ context_manager = AssistantContextManager(team=self.team, user=self.user, config=config)
+
+ return CreateAndQueryInsightTool(
+ team=self.team,
+ user=self.user,
+ state=state,
+ context_manager=context_manager,
+ node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),),
+ )
+
+ async def test_successful_insight_creation_returns_messages(self):
+ """Test successful insight creation returns visualization and tool call messages."""
+ tool = self._create_tool()
+
+ query = AssistantTrendsQuery(series=[])
+ viz_message = VisualizationMessage(query="test query", answer=query, plan="test plan")
+ tool_call_message = AssistantToolCallMessage(content="Results are here", tool_call_id=self.tool_call_id)
+
+ mock_state = AssistantState(messages=[viz_message, tool_call_message])
+
+ with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
+ mock_graph = AsyncMock()
+ mock_graph.ainvoke = AsyncMock(return_value=mock_state.model_dump())
+ mock_compile.return_value = mock_graph
+
+ result_text, artifact = await tool._arun_impl(query_description="test description")
+
+ self.assertEqual(result_text, "")
+ self.assertIsNotNone(artifact)
+ self.assertEqual(len(artifact.messages), 2)
+ self.assertIsInstance(artifact.messages[0], VisualizationMessage)
+ self.assertIsInstance(artifact.messages[1], AssistantToolCallMessage)
+
+ async def test_schema_generation_exception_returns_formatted_error(self):
+ """Test SchemaGenerationException is caught and returns formatted error message."""
+ tool = self._create_tool()
+
+ exception = SchemaGenerationException(
+ llm_output="Invalid query structure", validation_message="Missing required field: series"
+ )
+
+ with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
+ mock_graph = AsyncMock()
+ mock_graph.ainvoke = AsyncMock(side_effect=exception)
+ mock_compile.return_value = mock_graph
+
+ result_text, artifact = await tool._arun_impl(query_description="test description")
+
+ self.assertIsNone(artifact)
+ self.assertIn("Invalid query structure", result_text)
+ self.assertIn("Missing required field: series", result_text)
+ self.assertIn(INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT, result_text)
+
+ async def test_invalid_tool_call_message_type_returns_error(self):
+ """Test when the last message is not AssistantToolCallMessage, returns error."""
+ tool = self._create_tool()
+
+ viz_message = VisualizationMessage(query="test query", answer=AssistantTrendsQuery(series=[]), plan="test plan")
+ # Last message is AssistantMessage instead of AssistantToolCallMessage
+ invalid_message = AssistantMessage(content="Not a tool call message")
+
+ mock_state = AssistantState(messages=[viz_message, invalid_message])
+
+ with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
+ mock_graph = AsyncMock()
+ mock_graph.ainvoke = AsyncMock(return_value=mock_state.model_dump())
+ mock_compile.return_value = mock_graph
+
+ result_text, artifact = await tool._arun_impl(query_description="test description")
+
+ self.assertIsNone(artifact)
+ self.assertIn("unknown error", result_text)
+ self.assertIn(INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT, result_text)
+
+ async def test_human_feedback_requested_returns_only_tool_call_message(self):
+ """Test when visualization message is not present, returns only tool call message."""
+ tool = self._create_tool()
+
+ # When agent requests human feedback, there's no VisualizationMessage
+ some_message = AssistantMessage(content="I need help with this query")
+ tool_call_message = AssistantToolCallMessage(content="Need clarification", tool_call_id=self.tool_call_id)
+
+ mock_state = AssistantState(messages=[some_message, tool_call_message])
+
+ with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
+ mock_graph = AsyncMock()
+ mock_graph.ainvoke = AsyncMock(return_value=mock_state.model_dump())
+ mock_compile.return_value = mock_graph
+
+ result_text, artifact = await tool._arun_impl(query_description="test description")
+
+ self.assertEqual(result_text, "")
+ self.assertIsNotNone(artifact)
+ self.assertEqual(len(artifact.messages), 1)
+ self.assertIsInstance(artifact.messages[0], AssistantToolCallMessage)
+ self.assertEqual(artifact.messages[0].content, "Need clarification")
+
+ async def test_editing_mode_adds_ui_payload(self):
+ """Test that in editing mode, UI payload is added to tool call message."""
+ # Create tool with contextual tool available
+ tool = self._create_tool(contextual_tools={AssistantTool.CREATE_AND_QUERY_INSIGHT.value: {}})
+
+ query = AssistantTrendsQuery(series=[])
+ viz_message = VisualizationMessage(query="test query", answer=query, plan="test plan")
+ tool_call_message = AssistantToolCallMessage(content="Results are here", tool_call_id=self.tool_call_id)
+
+ mock_state = AssistantState(messages=[viz_message, tool_call_message])
+
+ with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
+ mock_graph = AsyncMock()
+ mock_graph.ainvoke = AsyncMock(return_value=mock_state.model_dump())
+ mock_compile.return_value = mock_graph
+
+ result_text, artifact = await tool._arun_impl(query_description="test description")
+
+ self.assertEqual(result_text, "")
+ self.assertIsNotNone(artifact)
+ self.assertEqual(len(artifact.messages), 2)
+
+ # Check that UI payload was added to tool call message
+ returned_tool_call_message = artifact.messages[1]
+ self.assertIsInstance(returned_tool_call_message, AssistantToolCallMessage)
+ self.assertIsNotNone(returned_tool_call_message.ui_payload)
+ self.assertIn("create_and_query_insight", returned_tool_call_message.ui_payload)
+ self.assertEqual(
+ returned_tool_call_message.ui_payload["create_and_query_insight"], query.model_dump(exclude_none=True)
+ )
+
+ async def test_non_editing_mode_no_ui_payload(self):
+ """Test that in non-editing mode, no UI payload is added to tool call message."""
+ # Create tool without contextual tools (non-editing mode)
+ tool = self._create_tool(contextual_tools={})
+
+ query = AssistantTrendsQuery(series=[])
+ viz_message = VisualizationMessage(query="test query", answer=query, plan="test plan")
+ tool_call_message = AssistantToolCallMessage(content="Results are here", tool_call_id=self.tool_call_id)
+
+ mock_state = AssistantState(messages=[viz_message, tool_call_message])
+
+ with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
+ mock_graph = AsyncMock()
+ mock_graph.ainvoke = AsyncMock(return_value=mock_state.model_dump())
+ mock_compile.return_value = mock_graph
+
+ result_text, artifact = await tool._arun_impl(query_description="test description")
+
+ self.assertEqual(result_text, "")
+ self.assertIsNotNone(artifact)
+ self.assertEqual(len(artifact.messages), 2)
+
+ # Check that the original tool call message is returned without modification
+ returned_tool_call_message = artifact.messages[1]
+ self.assertIsInstance(returned_tool_call_message, AssistantToolCallMessage)
+ # In non-editing mode, the original message is returned as-is
+ self.assertEqual(returned_tool_call_message, tool_call_message)
+
+ async def test_state_updates_include_tool_call_metadata(self):
+ """Test that the state passed to graph includes root_tool_call_id and root_tool_insight_plan."""
+ initial_state = AssistantState(messages=[AssistantMessage(content="initial")])
+ tool = self._create_tool(state=initial_state)
+
+ query = AssistantTrendsQuery(series=[])
+ viz_message = VisualizationMessage(query="test query", answer=query, plan="test plan")
+ tool_call_message = AssistantToolCallMessage(content="Results", tool_call_id=self.tool_call_id)
+ mock_state = AssistantState(messages=[viz_message, tool_call_message])
+
+ invoked_state = None
+
+ async def capture_invoked_state(state):
+ nonlocal invoked_state
+ invoked_state = state
+ return mock_state.model_dump()
+
+ with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
+ mock_graph = AsyncMock()
+ mock_graph.ainvoke = AsyncMock(side_effect=capture_invoked_state)
+ mock_compile.return_value = mock_graph
+
+ await tool._arun_impl(query_description="my test query")
+
+ # Verify the state passed to ainvoke has the correct metadata
+ self.assertIsNotNone(invoked_state)
+ validated_state = AssistantState.model_validate(invoked_state)
+ self.assertEqual(validated_state.root_tool_call_id, self.tool_call_id)
+ self.assertEqual(validated_state.root_tool_insight_plan, "my test query")
+ # Original message should still be there
+ self.assertEqual(len(validated_state.messages), 1)
+ assert isinstance(validated_state.messages[0], AssistantMessage)
+ self.assertEqual(validated_state.messages[0].content, "initial")
+
+ async def test_is_editing_mode_classmethod(self):
+ """Test the is_editing_mode class method correctly detects editing mode."""
+ # Test with editing mode enabled
+ config_editing: RunnableConfig = RunnableConfig(
+ configurable={"contextual_tools": {AssistantTool.CREATE_AND_QUERY_INSIGHT.value: {}}}
+ )
+ context_manager_editing = AssistantContextManager(team=self.team, user=self.user, config=config_editing)
+ self.assertTrue(CreateAndQueryInsightTool.is_editing_mode(context_manager_editing))
+
+ # Test with editing mode disabled
+ config_not_editing = RunnableConfig(configurable={"contextual_tools": {}})
+ context_manager_not_editing = AssistantContextManager(team=self.team, user=self.user, config=config_not_editing)
+ self.assertFalse(CreateAndQueryInsightTool.is_editing_mode(context_manager_not_editing))
+
+ # Test with other contextual tools but not create_and_query_insight
+ config_other = RunnableConfig(configurable={"contextual_tools": {"some_other_tool": {}}})
+ context_manager_other = AssistantContextManager(team=self.team, user=self.user, config=config_other)
+ self.assertFalse(CreateAndQueryInsightTool.is_editing_mode(context_manager_other))
diff --git a/ee/hogai/graph/root/tools/test/test_create_dashboard.py b/ee/hogai/graph/root/tools/test/test_create_dashboard.py
new file mode 100644
index 0000000000..39eb8e3150
--- /dev/null
+++ b/ee/hogai/graph/root/tools/test/test_create_dashboard.py
@@ -0,0 +1,248 @@
+from typing import cast
+from uuid import uuid4
+
+from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from posthog.schema import AssistantMessage
+
+from ee.hogai.context.context import AssistantContextManager
+from ee.hogai.graph.root.tools.create_dashboard import CreateDashboardTool
+from ee.hogai.utils.types import AssistantState, InsightQuery, PartialAssistantState
+from ee.hogai.utils.types.base import NodePath
+
+
+class TestCreateDashboardTool(ClickhouseTestMixin, NonAtomicBaseTest):
+ CLASS_DATA_LEVEL_SETUP = False
+
+ def setUp(self):
+ super().setUp()
+ self.tool_call_id = str(uuid4())
+ self.state = AssistantState(messages=[], root_tool_call_id=self.tool_call_id)
+ self.context_manager = AssistantContextManager(self.team, self.user, {})
+ self.tool = CreateDashboardTool(
+ team=self.team,
+ user=self.user,
+ state=self.state,
+ context_manager=self.context_manager,
+ node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),),
+ )
+
+ async def test_execute_calls_dashboard_creation_node(self):
+ mock_node_instance = MagicMock()
+ mock_result = PartialAssistantState(
+ messages=[AssistantMessage(content="Dashboard created successfully with 3 insights")]
+ )
+
+ with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode", return_value=mock_node_instance):
+ with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda") as mock_runnable:
+ mock_chain = MagicMock()
+ mock_chain.ainvoke = AsyncMock(return_value=mock_result)
+ mock_runnable.return_value = mock_chain
+
+ insight_queries = [
+ InsightQuery(name="Pageviews", description="Show pageviews for last 7 days"),
+ InsightQuery(name="User signups", description="Show user signups funnel"),
+ ]
+
+ result, artifact = await self.tool._arun_impl(
+ search_insights_queries=insight_queries,
+ dashboard_name="Marketing Dashboard",
+ )
+
+ self.assertEqual(result, "")
+ self.assertIsNotNone(artifact)
+ assert artifact is not None
+ self.assertEqual(len(artifact.messages), 1)
+ message = cast(AssistantMessage, artifact.messages[0])
+ self.assertEqual(message.content, "Dashboard created successfully with 3 insights")
+
+ async def test_execute_updates_state_with_all_parameters(self):
+ mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test response")])
+
+ insight_queries = [
+ InsightQuery(name="Revenue Trends", description="Monthly revenue trends for Q4"),
+ InsightQuery(name="Churn Rate", description="Customer churn rate by cohort"),
+ InsightQuery(name="NPS Score", description="Net Promoter Score over time"),
+ ]
+
+ async def mock_ainvoke(state):
+ self.assertEqual(state.search_insights_queries, insight_queries)
+ self.assertEqual(state.dashboard_name, "Executive Summary Q4")
+ self.assertEqual(state.root_tool_call_id, self.tool_call_id)
+ return mock_result
+
+ mock_chain = MagicMock()
+ mock_chain.ainvoke = mock_ainvoke
+
+ with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
+ with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
+ await self.tool._arun_impl(
+ search_insights_queries=insight_queries,
+ dashboard_name="Executive Summary Q4",
+ )
+
+ async def test_execute_with_single_insight_query(self):
+ mock_result = PartialAssistantState(messages=[AssistantMessage(content="Dashboard with one insight created")])
+
+ insight_queries = [InsightQuery(name="Daily Active Users", description="Count of daily active users")]
+
+ async def mock_ainvoke(state):
+ self.assertEqual(len(state.search_insights_queries), 1)
+ self.assertEqual(state.search_insights_queries[0].name, "Daily Active Users")
+ self.assertEqual(state.search_insights_queries[0].description, "Count of daily active users")
+ self.assertEqual(state.dashboard_name, "User Activity Dashboard")
+ return mock_result
+
+ mock_chain = MagicMock()
+ mock_chain.ainvoke = mock_ainvoke
+
+ with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
+ with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
+ result, artifact = await self.tool._arun_impl(
+ search_insights_queries=insight_queries,
+ dashboard_name="User Activity Dashboard",
+ )
+
+ self.assertEqual(result, "")
+ self.assertIsNotNone(artifact)
+
+ async def test_execute_with_many_insight_queries(self):
+ mock_result = PartialAssistantState(messages=[AssistantMessage(content="Large dashboard created")])
+
+ insight_queries = [
+ InsightQuery(name=f"Insight {i}", description=f"Description for insight {i}") for i in range(10)
+ ]
+
+ async def mock_ainvoke(state):
+ self.assertEqual(len(state.search_insights_queries), 10)
+ self.assertEqual(state.search_insights_queries[0].name, "Insight 0")
+ self.assertEqual(state.search_insights_queries[9].name, "Insight 9")
+ self.assertEqual(state.dashboard_name, "Comprehensive Dashboard")
+ return mock_result
+
+ mock_chain = MagicMock()
+ mock_chain.ainvoke = mock_ainvoke
+
+ with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
+ with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
+ result, artifact = await self.tool._arun_impl(
+ search_insights_queries=insight_queries,
+ dashboard_name="Comprehensive Dashboard",
+ )
+
+ self.assertEqual(result, "")
+ self.assertIsNotNone(artifact)
+
+ async def test_execute_returns_failure_message_when_result_is_none(self):
+ async def mock_ainvoke(state):
+ return None
+
+ mock_chain = MagicMock()
+ mock_chain.ainvoke = mock_ainvoke
+
+ with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
+ with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
+ result, artifact = await self.tool._arun_impl(
+ search_insights_queries=[InsightQuery(name="Test", description="Test insight")],
+ dashboard_name="Test Dashboard",
+ )
+
+ self.assertEqual(result, "Dashboard creation failed")
+ self.assertIsNone(artifact)
+
+ async def test_execute_returns_failure_message_when_result_has_no_messages(self):
+ mock_result = PartialAssistantState(messages=[])
+
+ async def mock_ainvoke(state):
+ return mock_result
+
+ mock_chain = MagicMock()
+ mock_chain.ainvoke = mock_ainvoke
+
+ with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
+ with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
+ result, artifact = await self.tool._arun_impl(
+ search_insights_queries=[InsightQuery(name="Test", description="Test insight")],
+ dashboard_name="Test Dashboard",
+ )
+
+ self.assertEqual(result, "Dashboard creation failed")
+ self.assertIsNone(artifact)
+
+ async def test_execute_preserves_original_state(self):
+ """Test that the original state is not modified when creating the copied state"""
+ original_queries = [InsightQuery(name="Original", description="Original insight")]
+ original_state = AssistantState(
+ messages=[],
+ root_tool_call_id=self.tool_call_id,
+ search_insights_queries=original_queries,
+ dashboard_name="Original Dashboard",
+ )
+
+ tool = CreateDashboardTool(
+ team=self.team,
+ user=self.user,
+ state=original_state,
+ context_manager=self.context_manager,
+ node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),),
+ )
+
+ mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test")])
+
+ new_queries = [InsightQuery(name="New", description="New insight")]
+
+ async def mock_ainvoke(state):
+ # Verify the new state has updated values
+ self.assertEqual(state.search_insights_queries, new_queries)
+ self.assertEqual(state.dashboard_name, "New Dashboard")
+ return mock_result
+
+ mock_chain = MagicMock()
+ mock_chain.ainvoke = mock_ainvoke
+
+ with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
+ with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
+ await tool._arun_impl(
+ search_insights_queries=new_queries,
+ dashboard_name="New Dashboard",
+ )
+
+ # Verify original state was not modified
+ self.assertEqual(original_state.search_insights_queries, original_queries)
+ self.assertEqual(original_state.dashboard_name, "Original Dashboard")
+ self.assertEqual(original_state.root_tool_call_id, self.tool_call_id)
+
+ async def test_execute_with_complex_insight_descriptions(self):
+ """Test that complex insight descriptions with special characters are handled correctly"""
+ mock_result = PartialAssistantState(messages=[AssistantMessage(content="Dashboard created")])
+
+ insight_queries = [
+ InsightQuery(
+ name="User Journey",
+ description='Show funnel from "Sign up" → "Create project" → "Invite team" for users with property email containing "@company.com"',
+ ),
+ InsightQuery(
+ name="Revenue (USD)",
+ description="Track total revenue in USD from 2024-01-01 to 2024-12-31, filtered by plan_type = 'premium' OR plan_type = 'enterprise'",
+ ),
+ ]
+
+ async def mock_ainvoke(state):
+ self.assertEqual(len(state.search_insights_queries), 2)
+ self.assertIn("@company.com", state.search_insights_queries[0].description)
+ self.assertIn("'premium'", state.search_insights_queries[1].description)
+ return mock_result
+
+ mock_chain = MagicMock()
+ mock_chain.ainvoke = mock_ainvoke
+
+ with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
+ with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
+ result, artifact = await self.tool._arun_impl(
+ search_insights_queries=insight_queries,
+ dashboard_name="Complex Dashboard",
+ )
+
+ self.assertEqual(result, "")
+ self.assertIsNotNone(artifact)
diff --git a/ee/hogai/graph/root/tools/test/test_search.py b/ee/hogai/graph/root/tools/test/test_search.py
index 4bbdb14b9f..5077b79eaa 100644
--- a/ee/hogai/graph/root/tools/test/test_search.py
+++ b/ee/hogai/graph/root/tools/test/test_search.py
@@ -1,27 +1,178 @@
-from posthog.test.base import BaseTest
-from unittest.mock import patch
+from typing import cast
+from uuid import uuid4
+
+from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest
+from unittest.mock import AsyncMock, MagicMock, patch
from django.test import override_settings
from langchain_core import messages
+from langchain_core.runnables import RunnableConfig
+from posthog.schema import AssistantMessage
+
+from ee.hogai.context.context import AssistantContextManager
from ee.hogai.graph.root.tools.search import (
DOC_ITEM_TEMPLATE,
- DOCS_SEARCH_NO_RESULTS_TEMPLATE,
DOCS_SEARCH_RESULTS_TEMPLATE,
+ EMPTY_DATABASE_ERROR_MESSAGE,
+ InkeepDocsSearchTool,
+ InsightSearchTool,
SearchTool,
)
from ee.hogai.utils.tests import FakeChatOpenAI
+from ee.hogai.utils.types import AssistantState, PartialAssistantState
+from ee.hogai.utils.types.base import NodePath
-class TestSearchToolDocumentation(BaseTest):
+class TestSearchTool(ClickhouseTestMixin, NonAtomicBaseTest):
+ CLASS_DATA_LEVEL_SETUP = False
+
def setUp(self):
super().setUp()
- self.tool = SearchTool(team=self.team, user=self.user, tool_call_id="test-tool-call-id")
+ self.tool_call_id = "test_tool_call_id"
+ self.state = AssistantState(messages=[], root_tool_call_id=str(uuid4()))
+ self.context_manager = AssistantContextManager(self.team, self.user, {})
+ self.tool = SearchTool(
+ team=self.team,
+ user=self.user,
+ state=self.state,
+ context_manager=self.context_manager,
+ node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),),
+ )
+
+ async def test_run_docs_search_without_api_key(self):
+ with patch("ee.hogai.graph.root.tools.search.settings") as mock_settings:
+ mock_settings.INKEEP_API_KEY = None
+ result, artifact = await self.tool._arun_impl(kind="docs", query="How to use feature flags?")
+ self.assertEqual(result, "This tool is not available in this environment.")
+ self.assertIsNone(artifact)
+
+ async def test_run_docs_search_with_api_key(self):
+ mock_docs_tool = MagicMock()
+ mock_docs_tool.execute = AsyncMock(return_value=("", MagicMock()))
+
+ with (
+ patch("ee.hogai.graph.root.tools.search.settings") as mock_settings,
+ patch("ee.hogai.graph.root.tools.search.InkeepDocsSearchTool", return_value=mock_docs_tool),
+ ):
+ mock_settings.INKEEP_API_KEY = "test-key"
+ result, artifact = await self.tool._arun_impl(kind="docs", query="How to use feature flags?")
+
+ mock_docs_tool.execute.assert_called_once_with("How to use feature flags?", self.tool_call_id)
+ self.assertEqual(result, "")
+ self.assertIsNotNone(artifact)
+
+ async def test_run_insights_search(self):
+ mock_insights_tool = MagicMock()
+ mock_insights_tool.execute = AsyncMock(return_value=("", MagicMock()))
+
+ with patch("ee.hogai.graph.root.tools.search.InsightSearchTool", return_value=mock_insights_tool):
+ result, artifact = await self.tool._arun_impl(kind="insights", query="user signups")
+
+ mock_insights_tool.execute.assert_called_once_with("user signups", self.tool_call_id)
+ self.assertEqual(result, "")
+ self.assertIsNotNone(artifact)
+
+ async def test_run_unknown_kind(self):
+ result, artifact = await self.tool._arun_impl(kind="unknown", query="test")
+ self.assertEqual(result, "Invalid entity kind: unknown. Please provide a valid entity kind for the tool.")
+ self.assertIsNone(artifact)
+
+ @patch("ee.hogai.graph.root.tools.search.EntitySearchTool.execute")
+ async def test_arun_impl_error_tracking_issues_returns_routing_data(self, mock_execute):
+ mock_execute.return_value = "Search results for error tracking issues"
+
+ result, artifact = await self.tool._arun_impl(
+ kind="error_tracking_issues", query="test error tracking issue query"
+ )
+
+ self.assertEqual(result, "Search results for error tracking issues")
+ self.assertIsNone(artifact)
+ mock_execute.assert_called_once_with("test error tracking issue query", "error_tracking_issues")
+
+ @patch("ee.hogai.graph.root.tools.search.EntitySearchTool.execute")
+ @patch("ee.hogai.graph.root.tools.search.SearchTool._has_insights_fts_search_feature_flag")
+ async def test_arun_impl_insight_with_feature_flag_disabled(
+ self, mock_has_insights_fts_search_feature_flag, mock_execute
+ ):
+ mock_has_insights_fts_search_feature_flag.return_value = False
+ mock_execute.return_value = "Search results for insights"
+
+ result, artifact = await self.tool._arun_impl(kind="insights", query="test insight query")
+
+ self.assertEqual(result, "The user doesn't have any insights created yet.")
+ self.assertIsNone(artifact)
+ mock_execute.assert_not_called()
+
+ @patch("ee.hogai.graph.root.tools.search.EntitySearchTool.execute")
+ @patch("ee.hogai.graph.root.tools.search.SearchTool._has_insights_fts_search_feature_flag")
+ async def test_arun_impl_insight_with_feature_flag_enabled(
+ self, mock_has_insights_fts_search_feature_flag, mock_execute
+ ):
+ mock_has_insights_fts_search_feature_flag.return_value = True
+ mock_execute.return_value = "Search results for insights"
+
+ result, artifact = await self.tool._arun_impl(kind="insights", query="test insight query")
+
+ self.assertEqual(result, "Search results for insights")
+ self.assertIsNone(artifact)
+ mock_execute.assert_called_once_with("test insight query", "insights")
+
+
+class TestInkeepDocsSearchTool(ClickhouseTestMixin, NonAtomicBaseTest):
+ CLASS_DATA_LEVEL_SETUP = False
+
+ def setUp(self):
+ super().setUp()
+ self.tool_call_id = str(uuid4())
+ self.state = AssistantState(messages=[], root_tool_call_id=self.tool_call_id)
+ self.context_manager = AssistantContextManager(self.team, self.user, {})
+ self.tool = InkeepDocsSearchTool(
+ team=self.team,
+ user=self.user,
+ state=self.state,
+ config=RunnableConfig(configurable={}),
+ context_manager=self.context_manager,
+ )
+
+ async def test_execute_calls_inkeep_docs_node(self):
+ mock_node_instance = MagicMock()
+ mock_result = PartialAssistantState(messages=[AssistantMessage(content="Here is the answer from docs")])
+
+ with patch("ee.hogai.graph.inkeep_docs.nodes.InkeepDocsNode", return_value=mock_node_instance):
+ with patch("ee.hogai.graph.root.tools.search.RunnableLambda") as mock_runnable:
+ mock_chain = MagicMock()
+ mock_chain.ainvoke = AsyncMock(return_value=mock_result)
+ mock_runnable.return_value = mock_chain
+
+ result, artifact = await self.tool.execute("How to track events?", "test-tool-call-id")
+
+ self.assertEqual(result, "")
+ self.assertIsNotNone(artifact)
+ assert artifact is not None
+ self.assertEqual(len(artifact.messages), 1)
+ message = cast(AssistantMessage, artifact.messages[0])
+ self.assertEqual(message.content, "Here is the answer from docs")
+
+ async def test_execute_updates_state_with_tool_call_id(self):
+ mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test response")])
+
+ async def mock_ainvoke(state):
+ self.assertEqual(state.root_tool_call_id, "custom-tool-call-id")
+ return mock_result
+
+ mock_chain = MagicMock()
+ mock_chain.ainvoke = mock_ainvoke
+
+ with patch("ee.hogai.graph.inkeep_docs.nodes.InkeepDocsNode"):
+ with patch("ee.hogai.graph.root.tools.search.RunnableLambda", return_value=mock_chain):
+ await self.tool.execute("test query", "custom-tool-call-id")
@override_settings(INKEEP_API_KEY="test-inkeep-key")
@patch("ee.hogai.graph.root.tools.search.ChatOpenAI")
- async def test_search_docs_with_successful_results(self, mock_llm_class):
+ @patch("ee.hogai.graph.root.tools.search.InkeepDocsSearchTool._has_rag_docs_search_feature_flag", return_value=True)
+ async def test_search_docs_with_successful_results(self, mock_has_rag_docs_search_feature_flag, mock_llm_class):
response_json = """{
"content": [
{
@@ -47,7 +198,7 @@ class TestSearchToolDocumentation(BaseTest):
fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content=response_json)])
mock_llm_class.return_value = fake_llm
- result = await self.tool._search_docs("how to use feature")
+ result, _ = await self.tool.execute("how to use feature", "test-tool-call-id")
expected_doc_1 = DOC_ITEM_TEMPLATE.format(
title="Feature Documentation",
@@ -68,196 +219,76 @@ class TestSearchToolDocumentation(BaseTest):
self.assertEqual(mock_llm_class.call_args.kwargs["api_key"], "test-inkeep-key")
self.assertEqual(mock_llm_class.call_args.kwargs["streaming"], False)
- @override_settings(INKEEP_API_KEY="test-inkeep-key")
- @patch("ee.hogai.graph.root.tools.search.ChatOpenAI")
- async def test_search_docs_with_no_results(self, mock_llm_class):
- fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content="{}")])
- mock_llm_class.return_value = fake_llm
- result = await self.tool._search_docs("nonexistent feature")
+class TestInsightSearchTool(ClickhouseTestMixin, NonAtomicBaseTest):
+ CLASS_DATA_LEVEL_SETUP = False
- self.assertEqual(result, DOCS_SEARCH_NO_RESULTS_TEMPLATE)
-
- @override_settings(INKEEP_API_KEY="test-inkeep-key")
- @patch("ee.hogai.graph.root.tools.search.ChatOpenAI")
- async def test_search_docs_with_empty_content(self, mock_llm_class):
- fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content='{"content": []}')])
- mock_llm_class.return_value = fake_llm
-
- result = await self.tool._search_docs("query")
-
- self.assertEqual(result, DOCS_SEARCH_NO_RESULTS_TEMPLATE)
-
- @override_settings(INKEEP_API_KEY="test-inkeep-key")
- @patch("ee.hogai.graph.root.tools.search.ChatOpenAI")
- async def test_search_docs_filters_non_document_types(self, mock_llm_class):
- response_json = """{
- "content": [
- {
- "type": "snippet",
- "record_type": "code",
- "url": "https://posthog.com/code",
- "title": "Code Snippet",
- "source": {"type": "text", "content": [{"type": "text", "text": "Code example"}]}
- },
- {
- "type": "answer",
- "record_type": "answer",
- "url": "https://posthog.com/answer",
- "title": "Answer",
- "source": {"type": "text", "content": [{"type": "text", "text": "Answer text"}]}
- }
- ]
- }"""
-
- fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content=response_json)])
- mock_llm_class.return_value = fake_llm
-
- result = await self.tool._search_docs("query")
-
- self.assertEqual(result, DOCS_SEARCH_NO_RESULTS_TEMPLATE)
-
- @override_settings(INKEEP_API_KEY="test-inkeep-key")
- @patch("ee.hogai.graph.root.tools.search.ChatOpenAI")
- async def test_search_docs_handles_empty_source_content(self, mock_llm_class):
- response_json = """{
- "content": [
- {
- "type": "document",
- "record_type": "page",
- "url": "https://posthog.com/docs/feature",
- "title": "Feature Documentation",
- "source": {"type": "text", "content": []}
- }
- ]
- }"""
-
- fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content=response_json)])
- mock_llm_class.return_value = fake_llm
-
- result = await self.tool._search_docs("query")
-
- expected_doc = DOC_ITEM_TEMPLATE.format(
- title="Feature Documentation", url="https://posthog.com/docs/feature", text=""
- )
- expected_result = DOCS_SEARCH_RESULTS_TEMPLATE.format(count=1, docs=expected_doc)
-
- self.assertEqual(result, expected_result)
-
- @override_settings(INKEEP_API_KEY="test-inkeep-key")
- @patch("ee.hogai.graph.root.tools.search.ChatOpenAI")
- async def test_search_docs_handles_mixed_document_types(self, mock_llm_class):
- response_json = """{
- "content": [
- {
- "type": "document",
- "record_type": "page",
- "url": "https://posthog.com/docs/valid",
- "title": "Valid Doc",
- "source": {"type": "text", "content": [{"type": "text", "text": "Valid content"}]}
- },
- {
- "type": "snippet",
- "record_type": "code",
- "url": "https://posthog.com/code",
- "title": "Code Snippet",
- "source": {"type": "text", "content": [{"type": "text", "text": "Code"}]}
- },
- {
- "type": "document",
- "record_type": "guide",
- "url": "https://posthog.com/docs/another",
- "title": "Another Valid Doc",
- "source": {"type": "text", "content": [{"type": "text", "text": "More content"}]}
- }
- ]
- }"""
-
- fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content=response_json)])
- mock_llm_class.return_value = fake_llm
-
- result = await self.tool._search_docs("query")
-
- expected_doc_1 = DOC_ITEM_TEMPLATE.format(
- title="Valid Doc", url="https://posthog.com/docs/valid", text="Valid content"
- )
- expected_doc_2 = DOC_ITEM_TEMPLATE.format(
- title="Another Valid Doc", url="https://posthog.com/docs/another", text="More content"
- )
- expected_result = DOCS_SEARCH_RESULTS_TEMPLATE.format(
- count=2, docs=f"{expected_doc_1}\n\n---\n\n{expected_doc_2}"
+ def setUp(self):
+ super().setUp()
+ self.tool_call_id = str(uuid4())
+ self.state = AssistantState(messages=[], root_tool_call_id=self.tool_call_id)
+ self.context_manager = AssistantContextManager(self.team, self.user, {})
+ self.tool = InsightSearchTool(
+ team=self.team,
+ user=self.user,
+ state=self.state,
+ config=RunnableConfig(configurable={}),
+ context_manager=self.context_manager,
)

- self.assertEqual(result, expected_result)
+ async def test_execute_calls_insight_search_node(self):
+ mock_node_instance = MagicMock()
+ mock_result = PartialAssistantState(messages=[AssistantMessage(content="Found 3 insights matching your query")])
- @override_settings(INKEEP_API_KEY=None)
- async def test_arun_impl_docs_without_api_key(self):
- result, artifact = await self.tool._arun_impl(kind="docs", query="test query")
+ with patch("ee.hogai.graph.insights.nodes.InsightSearchNode", return_value=mock_node_instance):
+ with patch("ee.hogai.graph.root.tools.search.RunnableLambda") as mock_runnable:
+ mock_chain = MagicMock()
+ mock_chain.ainvoke = AsyncMock(return_value=mock_result)
+ mock_runnable.return_value = mock_chain
- self.assertEqual(result, "This tool is not available in this environment.")
- self.assertIsNone(artifact)
+ result, artifact = await self.tool.execute("user signups by week", self.tool_call_id)
- @override_settings(INKEEP_API_KEY="test-key")
- @patch("ee.hogai.graph.root.tools.search.posthoganalytics.feature_enabled")
- async def test_arun_impl_docs_with_feature_flag_disabled(self, mock_feature_enabled):
- mock_feature_enabled.return_value = False
+ self.assertEqual(result, "")
+ self.assertIsNotNone(artifact)
+ assert artifact is not None
+ self.assertEqual(len(artifact.messages), 1)
+ message = cast(AssistantMessage, artifact.messages[0])
+ self.assertEqual(message.content, "Found 3 insights matching your query")
- result, artifact = await self.tool._arun_impl(kind="docs", query="test query")
+ async def test_execute_updates_state_with_search_query(self):
+ mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test response")])
- self.assertEqual(result, "Search tool executed")
- self.assertEqual(artifact, {"kind": "docs", "query": "test query"})
+ async def mock_ainvoke(state):
+ self.assertEqual(state.search_insights_query, "custom search query")
+ self.assertEqual(state.root_tool_call_id, self.tool_call_id)
+ return mock_result
- @override_settings(INKEEP_API_KEY="test-key")
- @patch("ee.hogai.graph.root.tools.search.posthoganalytics.feature_enabled")
- @patch.object(SearchTool, "_search_docs")
- async def test_arun_impl_docs_with_feature_flag_enabled(self, mock_search_docs, mock_feature_enabled):
- mock_feature_enabled.return_value = True
- mock_search_docs.return_value = "Search results"
+ mock_chain = MagicMock()
+ mock_chain.ainvoke = mock_ainvoke
- result, artifact = await self.tool._arun_impl(kind="docs", query="test query")
+ with patch("ee.hogai.graph.insights.nodes.InsightSearchNode"):
+ with patch("ee.hogai.graph.root.tools.search.RunnableLambda", return_value=mock_chain):
+ await self.tool.execute("custom search query", self.tool_call_id)
- self.assertEqual(result, "Search results")
- self.assertIsNone(artifact)
- mock_search_docs.assert_called_once_with("test query")
+ async def test_execute_handles_no_insights_exception(self):
+ from ee.hogai.graph.insights.nodes import NoInsightsException
- async def test_arun_impl_insights_returns_routing_data(self):
- result, artifact = await self.tool._arun_impl(kind="insights", query="test insight query")
+ with patch("ee.hogai.graph.insights.nodes.InsightSearchNode", side_effect=NoInsightsException()):
+ result, artifact = await self.tool.execute("user signups", self.tool_call_id)
- self.assertEqual(result, "Search tool executed")
- self.assertEqual(artifact, {"kind": "insights", "query": "test insight query"})
+ self.assertEqual(result, EMPTY_DATABASE_ERROR_MESSAGE)
+ self.assertIsNone(artifact)
- @patch("ee.hogai.graph.root.tools.search.EntitySearchToolkit.execute")
- async def test_arun_impl_error_tracking_issues_returns_routing_data(self, mock_execute):
- mock_execute.return_value = "Search results for error tracking issues"
+ async def test_execute_returns_none_artifact_when_result_is_none(self):
+ async def mock_ainvoke(state):
+ return None
- result, artifact = await self.tool._arun_impl(
- kind="error_tracking_issues", query="test error tracking issue query"
- )
+ mock_chain = MagicMock()
+ mock_chain.ainvoke = mock_ainvoke
- self.assertEqual(result, "Search results for error tracking issues")
- self.assertIsNone(artifact)
- mock_execute.assert_called_once_with("test error tracking issue query", "error_tracking_issues")
+ with patch("ee.hogai.graph.insights.nodes.InsightSearchNode"):
+ with patch("ee.hogai.graph.root.tools.search.RunnableLambda", return_value=mock_chain):
+ result, artifact = await self.tool.execute("test query", self.tool_call_id)
- @patch("ee.hogai.graph.root.tools.search.EntitySearchToolkit.execute")
- @patch("ee.hogai.graph.root.tools.search.SearchTool._has_fts_search_feature_flag")
- async def test_arun_impl_insight_with_feature_flag_disabled(self, mock_has_fts_search_feature_flag, mock_execute):
- mock_has_fts_search_feature_flag.return_value = False
- mock_execute.return_value = "Search results for insights"
-
- result, artifact = await self.tool._arun_impl(kind="insights", query="test insight query")
-
- self.assertEqual(result, "Search tool executed")
- self.assertEqual(artifact, {"kind": "insights", "query": "test insight query"})
- mock_execute.assert_not_called()
-
- @patch("ee.hogai.graph.root.tools.search.EntitySearchToolkit.execute")
- @patch("ee.hogai.graph.root.tools.search.SearchTool._has_fts_search_feature_flag")
- async def test_arun_impl_insight_with_feature_flag_enabled(self, mock_has_fts_search_feature_flag, mock_execute):
- mock_has_fts_search_feature_flag.return_value = True
- mock_execute.return_value = "Search results for insights"
-
- result, artifact = await self.tool._arun_impl(kind="insights", query="test insight query")
-
- self.assertEqual(result, "Search results for insights")
- self.assertIsNone(artifact)
- mock_execute.assert_called_once_with("test insight query", "insights")
+ self.assertEqual(result, "")
+ self.assertIsNone(artifact)
diff --git a/ee/hogai/graph/root/tools/test/test_session_summarization.py b/ee/hogai/graph/root/tools/test/test_session_summarization.py
new file mode 100644
index 0000000000..ff6d4965c6
--- /dev/null
+++ b/ee/hogai/graph/root/tools/test/test_session_summarization.py
@@ -0,0 +1,193 @@
+from typing import cast
+from uuid import uuid4
+
+from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from posthog.schema import AssistantMessage
+
+from ee.hogai.context.context import AssistantContextManager
+from ee.hogai.graph.root.tools.session_summarization import SessionSummarizationTool
+from ee.hogai.utils.types import AssistantState, PartialAssistantState
+from ee.hogai.utils.types.base import NodePath
+
+
+class TestSessionSummarizationTool(ClickhouseTestMixin, NonAtomicBaseTest):
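+ # Opt out of class-level data setup so each test builds its own fixtures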
+ CLASS_DATA_LEVEL_SETUP = False
+
+ def setUp(self):
+ super().setUp()
+ self.tool_call_id = str(uuid4())
+ self.state = AssistantState(messages=[], root_tool_call_id=self.tool_call_id)
+ self.context_manager = AssistantContextManager(self.team, self.user, {})
+ self.tool = SessionSummarizationTool(
+ team=self.team,
+ user=self.user,
+ state=self.state,
+ context_manager=self.context_manager,
+ node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),),
+ )
+
+ async def test_execute_calls_session_summarization_node(self):
+ mock_node_instance = MagicMock()
+ mock_result = PartialAssistantState(
+ messages=[AssistantMessage(content="Session summary: 10 sessions analyzed with 5 key patterns found")]
+ )
+
+ with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode", return_value=mock_node_instance):
+ with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda") as mock_runnable:
+ mock_chain = MagicMock()
+ mock_chain.ainvoke = AsyncMock(return_value=mock_result)
+ mock_runnable.return_value = mock_chain
+
+ result, artifact = await self.tool._arun_impl(
+ session_summarization_query="summarize all sessions from yesterday",
+ should_use_current_filters=False,
+ summary_title="All sessions from yesterday",
+ )
+
+ self.assertEqual(result, "")
+ self.assertIsNotNone(artifact)
+ assert artifact is not None
+ self.assertEqual(len(artifact.messages), 1)
+ message = cast(AssistantMessage, artifact.messages[0])
+ self.assertEqual(message.content, "Session summary: 10 sessions analyzed with 5 key patterns found")
+
+ async def test_execute_updates_state_with_all_parameters(self):
+ mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test response")])
+
+ async def mock_ainvoke(state):
+ self.assertEqual(state.session_summarization_query, "analyze mobile user sessions")
+ self.assertEqual(state.should_use_current_filters, True)
+ self.assertEqual(state.summary_title, "Mobile user sessions")
+ self.assertEqual(state.root_tool_call_id, self.tool_call_id)
+ return mock_result
+
+ mock_chain = MagicMock()
+ mock_chain.ainvoke = mock_ainvoke
+
+ with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"):
+ with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain):
+ await self.tool._arun_impl(
+ session_summarization_query="analyze mobile user sessions",
+ should_use_current_filters=True,
+ summary_title="Mobile user sessions",
+ )
+
+ async def test_execute_with_should_use_current_filters_false(self):
+ mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test response")])
+
+ async def mock_ainvoke(state):
+ self.assertEqual(state.should_use_current_filters, False)
+ self.assertEqual(state.session_summarization_query, "watch last 300 session recordings")
+ self.assertEqual(state.summary_title, "Last 300 sessions")
+ return mock_result
+
+ mock_chain = MagicMock()
+ mock_chain.ainvoke = mock_ainvoke
+
+ with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"):
+ with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain):
+ await self.tool._arun_impl(
+ session_summarization_query="watch last 300 session recordings",
+ should_use_current_filters=False,
+ summary_title="Last 300 sessions",
+ )
+
+ async def test_execute_returns_failure_message_when_result_is_none(self):
+ async def mock_ainvoke(state):
+ return None
+
+ mock_chain = MagicMock()
+ mock_chain.ainvoke = mock_ainvoke
+
+ with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"):
+ with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain):
+ result, artifact = await self.tool._arun_impl(
+ session_summarization_query="test query",
+ should_use_current_filters=False,
+ summary_title="Test",
+ )
+
+ self.assertEqual(result, "Session summarization failed")
+ self.assertIsNone(artifact)
+
+ async def test_execute_returns_failure_message_when_result_has_no_messages(self):
+ mock_result = PartialAssistantState(messages=[])
+
+ async def mock_ainvoke(state):
+ return mock_result
+
+ mock_chain = MagicMock()
+ mock_chain.ainvoke = mock_ainvoke
+
+ with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"):
+ with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain):
+ result, artifact = await self.tool._arun_impl(
+ session_summarization_query="test query",
+ should_use_current_filters=False,
+ summary_title="Test",
+ )
+
+ self.assertEqual(result, "Session summarization failed")
+ self.assertIsNone(artifact)
+
+ async def test_execute_with_empty_summary_title(self):
+ mock_result = PartialAssistantState(messages=[AssistantMessage(content="Summary completed")])
+
+ async def mock_ainvoke(state):
+ self.assertEqual(state.summary_title, "")
+ return mock_result
+
+ mock_chain = MagicMock()
+ mock_chain.ainvoke = mock_ainvoke
+
+ with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"):
+ with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain):
+ result, artifact = await self.tool._arun_impl(
+ session_summarization_query="summarize sessions",
+ should_use_current_filters=False,
+ summary_title="",
+ )
+
+ self.assertEqual(result, "")
+ self.assertIsNotNone(artifact)
+
+ async def test_execute_preserves_original_state(self):
+ """Test that the original state is not modified when creating the copied state"""
+ original_query = "original query"
+ original_state = AssistantState(
+ messages=[],
+ root_tool_call_id=self.tool_call_id,
+ session_summarization_query=original_query,
+ )
+
+ tool = SessionSummarizationTool(
+ team=self.team,
+ user=self.user,
+ state=original_state,
+ context_manager=self.context_manager,
+ node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),),
+ )
+
+ mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test")])
+
+ async def mock_ainvoke(state):
+ # Verify the new state has updated values
+ self.assertEqual(state.session_summarization_query, "new query")
+ return mock_result
+
+ mock_chain = MagicMock()
+ mock_chain.ainvoke = mock_ainvoke
+
+ with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"):
+ with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain):
+ await tool._arun_impl(
+ session_summarization_query="new query",
+ should_use_current_filters=True,
+ summary_title="New Summary",
+ )
+
+ # Verify original state was not modified
+ self.assertEqual(original_state.session_summarization_query, original_query)
+ self.assertEqual(original_state.root_tool_call_id, self.tool_call_id)
diff --git a/ee/hogai/graph/root/tools/todo_write.py b/ee/hogai/graph/root/tools/todo_write.py
index 816c19b6ce..f8670eef8b 100644
--- a/ee/hogai/graph/root/tools/todo_write.py
+++ b/ee/hogai/graph/root/tools/todo_write.py
@@ -138,11 +138,11 @@ When unsure, use this tool. Proactive task management shows attentiveness and he
""".strip()
+# Has its unique schema that doesn't match the Deep Research schema
class TodoItem(BaseModel):
content: str = Field(..., min_length=1)
status: Literal["pending", "in_progress", "completed"]
id: str
- priority: Literal["low", "medium", "high"]


class TodoWriteToolArgs(BaseModel):
diff --git a/ee/hogai/graph/schema_generator/nodes.py b/ee/hogai/graph/schema_generator/nodes.py
index f2e957ac58..2149c58071 100644
--- a/ee/hogai/graph/schema_generator/nodes.py
+++ b/ee/hogai/graph/schema_generator/nodes.py
@@ -13,7 +13,7 @@ from langchain_core.messages import (
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.runnables import RunnableConfig
-from posthog.schema import FailureMessage, VisualizationMessage
+from posthog.schema import VisualizationMessage
from posthog.models.group_type_mapping import GroupTypeMapping
@@ -36,6 +36,15 @@ from .utils import Q, SchemaGeneratorOutput
RETRIES_ALLOWED = 2


+class SchemaGenerationException(Exception):
+ """An error occurred while generating a schema in the `SchemaGeneratorNode` node."""
+
+ def __init__(self, llm_output: str, validation_message: str):
+ super().__init__("Failed to generate schema")
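+ # Keep the raw LLM output and the validation error so callers can inspect or report them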
+ self.llm_output = llm_output
+ self.validation_message = validation_message
+
+
class SchemaGeneratorNode(AssistantNode, Generic[Q]):
INSIGHT_NAME: str
"""
@@ -87,9 +96,8 @@ class SchemaGeneratorNode(AssistantNode, Generic[Q]):
chain = generation_prompt | merger | self._model | self._parse_output
- result: SchemaGeneratorOutput[Q] | None = None
try:
- result = await chain.ainvoke(
+ result: SchemaGeneratorOutput[Q] = await chain.ainvoke(
{
"project_datetime": self.project_now,
"project_timezone": self.project_timezone,
@@ -120,18 +128,9 @@ class SchemaGeneratorNode(AssistantNode, Generic[Q]):
query_generation_retry_count=len(intermediate_steps) + 1,
)
- if not result:
- # We've got no usable result after exhausting all iteration attempts - it's failure message time
- return PartialAssistantState(
- messages=[
- FailureMessage(
- content=f"It looks like I'm having trouble generating this {self.INSIGHT_NAME} insight."
- )
- ],
- intermediate_steps=None,
- plan=None,
- query_generation_retry_count=len(intermediate_steps) + 1,
- )
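+ # Surface the failure to the caller as an exception instead of returning a FailureMessage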
+ if isinstance(e, PydanticOutputParserException):
+ raise SchemaGenerationException(e.llm_output, e.validation_message)
+ raise SchemaGenerationException(e.llm_output or "No input was provided.", str(e))
# We've got a result that either passed the quality check or we've exhausted all attempts at iterating - return
return PartialAssistantState(
diff --git a/ee/hogai/graph/schema_generator/test/test_nodes.py b/ee/hogai/graph/schema_generator/test/test_nodes.py
index 1e1ea5e1c1..2dc0c4507b 100644
--- a/ee/hogai/graph/schema_generator/test/test_nodes.py
+++ b/ee/hogai/graph/schema_generator/test/test_nodes.py
@@ -13,7 +13,12 @@ from langchain_core.runnables import RunnableConfig, RunnableLambda
from posthog.schema import AssistantMessage, AssistantTrendsQuery, FailureMessage, HumanMessage, VisualizationMessage
-from ee.hogai.graph.schema_generator.nodes import RETRIES_ALLOWED, SchemaGeneratorNode, SchemaGeneratorToolsNode
+from ee.hogai.graph.schema_generator.nodes import (
+ RETRIES_ALLOWED,
+ SchemaGenerationException,
+ SchemaGeneratorNode,
+ SchemaGeneratorToolsNode,
+)
from ee.hogai.graph.schema_generator.parsers import PydanticOutputParserException
from ee.hogai.graph.schema_generator.utils import SchemaGeneratorOutput
from ee.hogai.utils.types import AssistantState, PartialAssistantState
@@ -258,7 +263,7 @@ class TestSchemaGeneratorNode(BaseTest):
self.assertEqual(action.log, "Field validation failed")

async def test_quality_check_failure_with_retries_exhausted(self):
- """Test quality check failure with retries exhausted still returns VisualizationMessage."""
+ """Test quality check failure with retries exhausted raises SchemaGenerationException."""
node = DummyGeneratorNode(self.team, self.user)
with (
patch.object(DummyGeneratorNode, "_model") as generator_model_mock,
@@ -273,26 +278,25 @@ class TestSchemaGeneratorNode(BaseTest):
)
# Start with RETRIES_ALLOWED intermediate steps (so no more allowed)
- new_state = await node.arun(
- AssistantState(
- messages=[HumanMessage(content="Text", id="0")],
- start_id="0",
- intermediate_steps=cast(
- list[IntermediateStep],
- [
- (AgentAction(tool="handle_incorrect_response", tool_input="", log=""), "retry"),
- ],
- )
- * RETRIES_ALLOWED,
- ),
- {},
- )
+ with self.assertRaises(SchemaGenerationException) as cm:
+ await node.arun(
+ AssistantState(
+ messages=[HumanMessage(content="Text", id="0")],
+ start_id="0",
+ intermediate_steps=cast(
+ list[IntermediateStep],
+ [
+ (AgentAction(tool="handle_incorrect_response", tool_input="", log=""), "retry"),
+ ],
+ )
+ * RETRIES_ALLOWED,
+ ),
+ {},
+ )
- # Should return VisualizationMessage despite quality check failure
- self.assertEqual(new_state.intermediate_steps, None)
- self.assertEqual(len(new_state.messages), 1)
- self.assertEqual(new_state.messages[0].type, "ai/viz")
- self.assertEqual(cast(VisualizationMessage, new_state.messages[0]).answer, self.basic_trends)
+ # Verify the exception contains the expected information
+ self.assertEqual(cm.exception.llm_output, '{"query": "test"}')
+ self.assertEqual(cm.exception.validation_message, "Quality check failed")

async def test_node_leaves_failover(self):
node = DummyGeneratorNode(self.team, self.user)
@@ -329,20 +333,17 @@ class TestSchemaGeneratorNode(BaseTest):
schema = DummySchema.model_construct(query=[]).model_dump() # type: ignore
generator_model_mock.return_value = RunnableLambda(lambda _: json.dumps(schema))
- new_state = await node.arun(
- AssistantState(
- messages=[HumanMessage(content="Text")],
- intermediate_steps=[
- (AgentAction(tool="", tool_input="", log="exception"), "exception"),
- (AgentAction(tool="", tool_input="", log="exception"), "exception"),
- ],
- ),
- {},
- )
- self.assertEqual(new_state.intermediate_steps, None)
- self.assertEqual(len(new_state.messages), 1)
- self.assertIsInstance(new_state.messages[0], FailureMessage)
- self.assertEqual(new_state.plan, None)
+ with self.assertRaises(SchemaGenerationException):
+ await node.arun(
+ AssistantState(
+ messages=[HumanMessage(content="Text")],
+ intermediate_steps=[
+ (AgentAction(tool="", tool_input="", log="exception"), "exception"),
+ (AgentAction(tool="", tool_input="", log="exception"), "exception"),
+ ],
+ ),
+ {},
+ )

async def test_agent_reconstructs_conversation_with_failover(self):
action = AgentAction(tool="fix", tool_input="validation error", log="exception")
diff --git a/ee/hogai/graph/session_summaries/nodes.py b/ee/hogai/graph/session_summaries/nodes.py
index 4d33fee929..8e5692256a 100644
--- a/ee/hogai/graph/session_summaries/nodes.py
+++ b/ee/hogai/graph/session_summaries/nodes.py
@@ -11,7 +11,6 @@ from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from posthog.schema import (
- AssistantMessage,
AssistantToolCallMessage,
MaxRecordingUniversalFilters,
NotebookUpdateMessage,
@@ -70,11 +69,7 @@ class SessionSummarizationNode(AssistantNode):
"""Push summarization progress as reasoning messages"""
content = prepare_reasoning_progress_message(progress_message)
if content:
- self.dispatcher.message(
- AssistantMessage(
- content=content,
- )
- )
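+ # Stream progress as a lightweight update rather than a full AssistantMessage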
+ self.dispatcher.update(content)

async def _stream_notebook_content(self, content: dict, state: AssistantState, partial: bool = True) -> None:
"""Stream TipTap content directly to a notebook if notebook_id is present in state."""
@@ -202,7 +197,7 @@ class _SessionSearch:
tool = await SearchSessionRecordingsTool.create_tool_class(
team=self._node._team,
user=self._node._user,
- tool_call_id=self._node._parent_tool_call_id or "",
+ node_path=self._node.node_path,
state=state,
config=config,
context_manager=self._node.context_manager,
diff --git a/ee/hogai/graph/sql/nodes.py b/ee/hogai/graph/sql/nodes.py
index eb9da64c9c..36d3840bf9 100644
--- a/ee/hogai/graph/sql/nodes.py
+++ b/ee/hogai/graph/sql/nodes.py
@@ -1,6 +1,6 @@
from langchain_core.runnables import RunnableConfig
-from posthog.schema import AssistantHogQLQuery, AssistantMessage
+from posthog.schema import AssistantHogQLQuery
from posthog.hogql.context import HogQLContext
@@ -25,7 +25,7 @@ class SQLGeneratorNode(HogQLGeneratorMixin, SchemaGeneratorNode[AssistantHogQLQu
return AssistantNodeName.SQL_GENERATOR

async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
- self.dispatcher.message(AssistantMessage(content="Creating SQL query"))
+ self.dispatcher.update("Creating SQL query")
prompt = await self._construct_system_prompt()
return await super()._run_with_prompt(state, prompt, config=config)
diff --git a/ee/hogai/graph/taxonomy/agent.py b/ee/hogai/graph/taxonomy/agent.py
index 3ae8a99dab..bfb7b4b01e 100644
--- a/ee/hogai/graph/taxonomy/agent.py
+++ b/ee/hogai/graph/taxonomy/agent.py
@@ -3,34 +3,43 @@ from typing import Generic
from posthog.models import Team, User
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
-from ee.hogai.graph.graph import BaseAssistantGraph
+from ee.hogai.graph.base import BaseAssistantGraph
from ee.hogai.utils.types import PartialStateType, StateType
+from ee.hogai.utils.types.base import AssistantGraphName
from .nodes import StateClassMixin, TaxonomyAgentNode, TaxonomyAgentToolsNode
from .toolkit import TaxonomyAgentToolkit
from .types import TaxonomyNodeName


-class TaxonomyAgent(BaseAssistantGraph[StateType], Generic[StateType, PartialStateType], StateClassMixin):
+class TaxonomyAgent(
+ BaseAssistantGraph[StateType, PartialStateType], Generic[StateType, PartialStateType], StateClassMixin
+):
"""Taxonomy agent that can be configured with different node classes."""
def __init__(
self,
team: Team,
user: User,
- tool_call_id: str,
loop_node_class: type["TaxonomyAgentNode"],
tools_node_class: type["TaxonomyAgentToolsNode"],
toolkit_class: type["TaxonomyAgentToolkit"],
):
- # Extract the State type from the generic parameter
- state_class, _ = self._get_state_class(TaxonomyAgent)
- super().__init__(team, user, state_class, parent_tool_call_id=tool_call_id)
-
+ super().__init__(team, user)
self._loop_node_class = loop_node_class
self._tools_node_class = tools_node_class
self._toolkit_class = toolkit_class
+ @property
+ def state_type(self) -> type[StateType]:
+ # Extract the State type from the generic parameter
+ state_type, _ = self._get_state_class(TaxonomyAgent)
+ return state_type
+
+ @property
+ def graph_name(self) -> AssistantGraphName:
+ return AssistantGraphName.TAXONOMY
+
def compile_full_graph(self, checkpointer: DjangoCheckpointer | None = None):
"""Compile a complete taxonomy graph."""
return self.add_taxonomy_generator().compile(checkpointer=checkpointer)
diff --git a/ee/hogai/graph/taxonomy/nodes.py b/ee/hogai/graph/taxonomy/nodes.py
index 327b1cbb48..1db0ef30d8 100644
--- a/ee/hogai/graph/taxonomy/nodes.py
+++ b/ee/hogai/graph/taxonomy/nodes.py
@@ -40,25 +40,10 @@ TaxonomyPartialStateType = TypeVar("TaxonomyPartialStateType", bound=TaxonomyAge
TaxonomyNodeBound = BaseAssistantNode[TaxonomyStateType, TaxonomyPartialStateType]
-class ParentToolCallIdMixin(BaseAssistantNode[TaxonomyStateType, TaxonomyPartialStateType]):
- _parent_tool_call_id: str | None = None
- _toolkit: TaxonomyAgentToolkit
-
- async def __call__(self, state: TaxonomyStateType, config: RunnableConfig) -> TaxonomyPartialStateType | None:
- """
- Run the assistant node and handle cancelled conversation before the node is run.
- """
- if self._parent_tool_call_id:
- self._toolkit._parent_tool_call_id = self._parent_tool_call_id
-
- return await super().__call__(state, config)
-
-
class TaxonomyAgentNode(
Generic[TaxonomyStateType, TaxonomyPartialStateType],
StateClassMixin,
TaxonomyUpdateDispatcherNodeMixin,
- ParentToolCallIdMixin,
TaxonomyNodeBound,
ABC,
):
@@ -173,7 +158,6 @@ class TaxonomyAgentToolsNode(
Generic[TaxonomyStateType, TaxonomyPartialStateType],
StateClassMixin,
TaxonomyUpdateDispatcherNodeMixin,
- ParentToolCallIdMixin,
TaxonomyNodeBound,
):
"""Base tools node for taxonomy agents."""
diff --git a/ee/hogai/graph/taxonomy/test/test_agent.py b/ee/hogai/graph/taxonomy/test/test_agent.py
index cb0551a864..23eaded165 100644
--- a/ee/hogai/graph/taxonomy/test/test_agent.py
+++ b/ee/hogai/graph/taxonomy/test/test_agent.py
@@ -36,14 +36,13 @@ class TestTaxonomyAgent(BaseTest):
self.mock_graph = Mock()
# Patch StateGraph in the parent class where it's actually used
- self.patcher = patch("ee.hogai.graph.graph.StateGraph")
+ self.patcher = patch("ee.hogai.graph.base.graph.StateGraph")
mock_state_graph_class = self.patcher.start()
mock_state_graph_class.return_value = self.mock_graph
self.agent = ConcreteTaxonomyAgent(
team=self.team,
user=self.user,
- tool_call_id="test_tool_call_id",
loop_node_class=MockTaxonomyAgentNode,
tools_node_class=MockTaxonomyAgentToolsNode,
toolkit_class=MockTaxonomyAgentToolkit,
@@ -75,7 +74,6 @@ class TestTaxonomyAgent(BaseTest):
NonGenericAgent(
team=self.team,
user=self.user,
- tool_call_id="test_tool_call_id",
loop_node_class=MockTaxonomyAgentNode,
tools_node_class=MockTaxonomyAgentToolsNode,
toolkit_class=MockTaxonomyAgentToolkit,
diff --git a/ee/hogai/graph/taxonomy/toolkit.py b/ee/hogai/graph/taxonomy/toolkit.py
index 5ccde1c15c..676a3bb12b 100644
--- a/ee/hogai/graph/taxonomy/toolkit.py
+++ b/ee/hogai/graph/taxonomy/toolkit.py
@@ -122,10 +122,6 @@ class TaxonomyTaskExecutorNode(
Task executor node specifically for taxonomy operations.
"""
- def __init__(self, team: Team, user: User, parent_tool_call_id: str | None):
- super().__init__(team, user)
- self._parent_tool_call_id = parent_tool_call_id
-
@property
def node_name(self) -> MaxNodeName:
return TaxonomyNodeName.TASK_EXECUTOR
@@ -148,8 +144,6 @@ class TaxonomyTaskExecutorNode(
class TaxonomyAgentToolkit:
"""Base toolkit for taxonomy agents that handle tool execution."""
- _parent_tool_call_id: str | None
-
def __init__(self, team: Team, user: User):
self._team = team
self._user = user
@@ -391,7 +385,7 @@ class TaxonomyAgentToolkit:
)
)
message = AssistantMessage(content="", id=str(uuid4()), tool_calls=tool_calls)
- executor = TaxonomyTaskExecutorNode(self._team, self._user, parent_tool_call_id=self._parent_tool_call_id)
+ executor = TaxonomyTaskExecutorNode(self._team, self._user)
result = await executor.arun(AssistantState(messages=[message]), RunnableConfig())
final_result = {}
@@ -651,7 +645,7 @@ class TaxonomyAgentToolkit:
for event_name_or_action_id in event_name_or_action_ids
]
message = AssistantMessage(content="", id=str(uuid4()), tool_calls=tool_calls)
- executor = TaxonomyTaskExecutorNode(self._team, self._user, parent_tool_call_id=self._parent_tool_call_id)
+ executor = TaxonomyTaskExecutorNode(self._team, self._user)
result = await executor.arun(AssistantState(messages=[message]), RunnableConfig())
return {task_result.id: task_result.result for task_result in result.task_results}
diff --git a/ee/hogai/graph/test/test_assistant_graph.py b/ee/hogai/graph/test/test_assistant_graph.py
deleted file mode 100644
index bd2a144b02..0000000000
--- a/ee/hogai/graph/test/test_assistant_graph.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from posthog.test.base import BaseTest
-
-from langchain_core.runnables import RunnableLambda
-from langgraph.checkpoint.memory import InMemorySaver
-
-from ee.hogai.graph.graph import BaseAssistantGraph
-from ee.hogai.utils.types import AssistantNodeName, AssistantState, PartialAssistantState
-
-
-class TestAssistantGraph(BaseTest):
- async def test_pydantic_state_resets_with_none(self):
- """When a None field is set, it should be reset to None."""
-
- async def runnable(state: AssistantState) -> PartialAssistantState:
- return PartialAssistantState(start_id=None)
-
- graph = BaseAssistantGraph(self.team, self.user, state_type=AssistantState)
- compiled_graph = (
- graph.add_node(AssistantNodeName.ROOT, RunnableLambda(runnable))
- .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
- .add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
- .compile(checkpointer=InMemorySaver())
- )
- state = await compiled_graph.ainvoke(
- AssistantState(messages=[], graph_status="resumed", start_id=None),
- {"configurable": {"thread_id": "test"}},
- )
- self.assertEqual(state["start_id"], None)
- self.assertEqual(state["graph_status"], "resumed")
diff --git a/ee/hogai/graph/test/test_dispatcher_integration.py b/ee/hogai/graph/test/test_dispatcher_integration.py
index db8c36240f..2687579a1b 100644
--- a/ee/hogai/graph/test/test_dispatcher_integration.py
+++ b/ee/hogai/graph/test/test_dispatcher_integration.py
@@ -4,8 +4,7 @@ Integration tests for dispatcher usage in BaseAssistantNode and graph execution.
These tests ensure that the dispatcher pattern works correctly end-to-end in real graph execution.
"""
-import uuid
-from typing import Any
+from typing import cast
from posthog.test.base import BaseTest
from unittest.mock import MagicMock, patch
@@ -14,30 +13,40 @@ from langchain_core.runnables import RunnableConfig
from posthog.schema import AssistantMessage
-from ee.hogai.graph.base import BaseAssistantNode
-from ee.hogai.utils.dispatcher import MessageAction, NodeStartAction
-from ee.hogai.utils.types import AssistantNodeName
-from ee.hogai.utils.types.base import AssistantDispatcherEvent, AssistantState, PartialAssistantState
+from ee.hogai.graph.base.node import BaseAssistantNode
+from ee.hogai.utils.types.base import (
+ AssistantDispatcherEvent,
+ AssistantGraphName,
+ AssistantNodeName,
+ AssistantState,
+ MessageAction,
+ NodeEndAction,
+ NodePath,
+ NodeStartAction,
+ PartialAssistantState,
+ UpdateAction,
+)


-class MockAssistantNode(BaseAssistantNode):
+class MockAssistantNode(BaseAssistantNode[AssistantState, PartialAssistantState]):
"""Mock node for testing dispatcher integration."""
- def __init__(self, team, user):
- super().__init__(team, user)
+ def __init__(self, team, user, node_path=None):
+ super().__init__(team, user, node_path)
self.arun_called = False
self.messages_dispatched = []

@property
- def node_name(self) -> AssistantNodeName:
+ def node_name(self) -> str:
return AssistantNodeName.ROOT

- async def arun(self, state, config: RunnableConfig) -> PartialAssistantState:
+ async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
self.arun_called = True
# Simulate dispatching messages during execution
- self.dispatcher.message(AssistantMessage(content="Processing..."))
- self.dispatcher.message(AssistantMessage(content="Done!"))
+ self.dispatcher.update("Processing...")
+ self.dispatcher.message(AssistantMessage(content="Intermediate result"))
+ self.dispatcher.update("Done!")
return PartialAssistantState(messages=[AssistantMessage(content="Final result")])
@@ -57,147 +66,169 @@ class TestDispatcherIntegration(BaseTest):
node = MockAssistantNode(self.mock_team, self.mock_user)
state = AssistantState(messages=[])
- config = RunnableConfig()
+ config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_1"})
- await node.arun(state, config)
+ with patch("ee.hogai.utils.dispatcher.get_stream_writer", side_effect=RuntimeError("Not streaming")):
+ await node(state, config)

self.assertTrue(node.arun_called)
self.assertIsNotNone(node.dispatcher)

+ async def test_node_path_propagation(self):
+ """Test that node_path is correctly set and propagated."""
+ node_path = (
+ NodePath(name=AssistantGraphName.ASSISTANT),
+ NodePath(name=AssistantNodeName.ROOT),
+ )
+
+ node = MockAssistantNode(self.mock_team, self.mock_user, node_path=node_path)
+
+ self.assertEqual(node.node_path, node_path)
+
+ async def test_dispatcher_dispatches_node_start_and_end(self):
+ """Test that NODE_START and NODE_END actions are dispatched."""
+ node = MockAssistantNode(self.mock_team, self.mock_user)
+
+ state = AssistantState(messages=[])
+ config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_2"})
+
+ dispatched_actions = []
+
+ def capture_write(event):
+ if isinstance(event, AssistantDispatcherEvent):
+ dispatched_actions.append(event)
+
+ with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
+ await node(state, config)
+
+ # Should have dispatched actions in this order:
+ # 1. NODE_START
+ # 2. UpdateAction ("Processing...")
+ # 3. MessageAction (intermediate)
+ # 4. UpdateAction ("Done!")
+ # 5. NODE_END with final state
+ self.assertGreater(len(dispatched_actions), 0)
+
+ # Verify NODE_START is first
+ self.assertIsInstance(dispatched_actions[0].action, NodeStartAction)
+
+ # Verify NODE_END is last
+ last_action = dispatched_actions[-1].action
+ self.assertIsInstance(last_action, NodeEndAction)
+ self.assertIsNotNone(cast(NodeEndAction, last_action).state)
+
async def test_messages_dispatched_during_node_execution(self):
"""Test that messages dispatched during node execution are sent to writer."""
node = MockAssistantNode(self.mock_team, self.mock_user)
state = AssistantState(messages=[])
- config = RunnableConfig()
+ config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_3"})
dispatched_actions = []
- def capture_write(update: Any):
- dispatched_actions.append(update)
+ def capture_write(event: AssistantDispatcherEvent):
+ dispatched_actions.append(event)
- # Mock get_stream_writer to return our test writer
- with patch("ee.hogai.graph.base.get_stream_writer", return_value=capture_write):
- # Call the node (not arun) to trigger __call__ which handles dispatching
+ with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
await node(state, config)
- # Should have:
- # 1. NODE_START from __call__
- # 2. Two MESSAGE actions from arun (Processing..., Done!)
- # 3. One MESSAGE action from returned state (Final result)
+ # Find the update and message actions (excluding NODE_START and NODE_END)
+ update_actions = [e for e in dispatched_actions if isinstance(e.action, UpdateAction)]
+ message_actions = [e for e in dispatched_actions if isinstance(e.action, MessageAction)]
- self.assertEqual(len(dispatched_actions), 4)
- self.assertIsInstance(dispatched_actions[0], AssistantDispatcherEvent)
- self.assertIsInstance(dispatched_actions[0].action, NodeStartAction)
- self.assertIsInstance(dispatched_actions[1], AssistantDispatcherEvent)
- self.assertIsInstance(dispatched_actions[1].action, MessageAction)
- self.assertEqual(dispatched_actions[1].action.message.content, "Processing...")
- self.assertIsInstance(dispatched_actions[2], AssistantDispatcherEvent)
- self.assertIsInstance(dispatched_actions[2].action, MessageAction)
- self.assertEqual(dispatched_actions[2].action.message.content, "Done!")
- self.assertIsInstance(dispatched_actions[3], AssistantDispatcherEvent)
- self.assertIsInstance(dispatched_actions[3].action, MessageAction)
- self.assertEqual(dispatched_actions[3].action.message.content, "Final result")
+ # Should have 2 updates: "Processing..." and "Done!"
+ self.assertEqual(len(update_actions), 2)
+ self.assertEqual(cast(UpdateAction, update_actions[0].action).content, "Processing...")
+ self.assertEqual(cast(UpdateAction, update_actions[1].action).content, "Done!")
- async def test_node_start_action_dispatched(self):
- """Test that NODE_START action is dispatched at node entry."""
+ # Should have 1 message: intermediate (final message is in NODE_END state, not dispatched separately)
+ self.assertEqual(len(message_actions), 1)
+ msg = cast(MessageAction, message_actions[0].action).message
+ self.assertEqual(cast(AssistantMessage, msg).content, "Intermediate result")
+
+ async def test_node_path_included_in_dispatched_events(self):
+ """Test that node_path is included in all dispatched events."""
+ node_path = (
+ NodePath(name=AssistantGraphName.ASSISTANT),
+ NodePath(name=AssistantNodeName.ROOT),
+ )
+
+ node = MockAssistantNode(self.mock_team, self.mock_user, node_path=node_path)
+
+ state = AssistantState(messages=[])
+ config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_4"})
+
+ dispatched_events = []
+
+ def capture_write(event: AssistantDispatcherEvent):
+ dispatched_events.append(event)
+
+ with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
+ await node(state, config)
+
+ # Verify all events have the correct node_path
+ for event in dispatched_events:
+ self.assertEqual(event.node_path, node_path)
+
+ async def test_node_run_id_included_in_dispatched_events(self):
+ """Test that node_run_id is included in all dispatched events."""
node = MockAssistantNode(self.mock_team, self.mock_user)
state = AssistantState(messages=[])
- config = RunnableConfig()
+ checkpoint_ns = "checkpoint_xyz_789"
+ config = RunnableConfig(
+ metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": checkpoint_ns}
+ )
- dispatched_actions = []
+ dispatched_events = []
- def capture_write(update):
- if isinstance(update, AssistantDispatcherEvent):
- dispatched_actions.append(update.action)
+ def capture_write(event: AssistantDispatcherEvent):
+ dispatched_events.append(event)
- with patch("ee.hogai.graph.base.get_stream_writer", return_value=capture_write):
+ with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
await node(state, config)
- # Should have at least one NODE_START action
- from ee.hogai.utils.dispatcher import NodeStartAction
+ # Verify all events have the correct node_run_id
+ for event in dispatched_events:
+ self.assertEqual(event.node_run_id, checkpoint_ns)
- node_start_actions = [action for action in dispatched_actions if isinstance(action, NodeStartAction)]
- self.assertGreater(len(node_start_actions), 0)
-
- @patch("ee.hogai.graph.base.get_stream_writer")
- async def test_parent_tool_call_id_propagation(self, mock_get_stream_writer):
- """Test that parent_tool_call_id is propagated to dispatched messages."""
- parent_tool_call_id = str(uuid.uuid4())
-
- class NodeWithParent(BaseAssistantNode):
- @property
- def node_name(self):
- return AssistantNodeName.ROOT
-
- async def arun(self, state, config: RunnableConfig) -> PartialAssistantState:
- # Dispatch message - should inherit parent_tool_call_id from state
- self.dispatcher.message(AssistantMessage(content="Child message"))
- return PartialAssistantState(messages=[])
-
- node = NodeWithParent(self.mock_team, self.mock_user)
-
- state = {"messages": [], "parent_tool_call_id": parent_tool_call_id}
- config = RunnableConfig()
-
- dispatched_messages = []
-
- def capture_write(update):
- if isinstance(update, AssistantDispatcherEvent) and isinstance(update.action, MessageAction):
- dispatched_messages.append(update.action.message)
-
- mock_get_stream_writer.return_value = capture_write
-
- await node.arun(state, config)
-
- # Verify dispatched messages have parent_tool_call_id
- assistant_messages = [msg for msg in dispatched_messages if isinstance(msg, AssistantMessage)]
- for msg in assistant_messages:
- # If the implementation propagates it, this should be true
- # Otherwise this test will help catch that as a potential issue
- if msg.parent_tool_call_id:
- self.assertEqual(msg.parent_tool_call_id, parent_tool_call_id)
-
- @patch("ee.hogai.graph.base.get_stream_writer")
- async def test_dispatcher_error_handling(self, mock_get_stream_writer):
+ async def test_dispatcher_error_handling_does_not_crash_node(self):
"""Test that errors in dispatcher don't crash node execution."""
- class FailingDispatcherNode(BaseAssistantNode):
+ class FailingDispatcherNode(BaseAssistantNode[AssistantState, PartialAssistantState]):
@property
def node_name(self):
return AssistantNodeName.ROOT

- async def arun(self, state, config: RunnableConfig) -> PartialAssistantState:
+ async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
# Try to dispatch - if writer fails, should handle gracefully
- self.dispatcher.message(AssistantMessage(content="Test"))
+ self.dispatcher.update("Test")
return PartialAssistantState(messages=[AssistantMessage(content="Result")])
node = FailingDispatcherNode(self.mock_team, self.mock_user)
state = AssistantState(messages=[])
- config = RunnableConfig()
+ config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_5"})
# Make writer raise an error
def failing_writer(data):
raise RuntimeError("Writer failed")
- mock_get_stream_writer.return_value = failing_writer
+ with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=failing_writer):
+ # Should not crash - node should complete
+ result = await node(state, config)
+ self.assertIsNotNone(result)
- # Should not crash - node should complete
- result = await node.arun(state, config)
- self.assertIsNotNone(result)
+ async def test_messages_in_partial_state_dispatched_via_node_end(self):
+ """Test that messages in PartialState are dispatched via NODE_END."""
- async def test_messages_in_partial_state_are_auto_dispatched(self):
- """Test that messages in PartialState are automatically dispatched."""
-
- class NodeReturningMessages(BaseAssistantNode):
+ class NodeReturningMessages(BaseAssistantNode[AssistantState, PartialAssistantState]):
@property
def node_name(self):
return AssistantNodeName.ROOT

- async def arun(self, state, config: RunnableConfig) -> PartialAssistantState:
- # Return messages in state - should be auto-dispatched
+ async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
+ # Return messages in state
return PartialAssistantState(
messages=[
AssistantMessage(content="Message 1"),
@@ -208,39 +239,139 @@ class TestDispatcherIntegration(BaseTest):
node = NodeReturningMessages(self.mock_team, self.mock_user)
state = AssistantState(messages=[])
- config = RunnableConfig()
+ config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_6"})
- dispatched_messages = []
+ dispatched_events = []
- def capture_write(update):
- if isinstance(update, AssistantDispatcherEvent) and isinstance(update.action, MessageAction):
- dispatched_messages.append(update.action.message)
+ def capture_write(event: AssistantDispatcherEvent):
+ dispatched_events.append(event)
- with patch("ee.hogai.graph.base.get_stream_writer", return_value=capture_write):
+ with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
await node(state, config)
- # Should have dispatched the messages from PartialState (and NODE_START)
- # We expect at least 2 message actions (Message 1 and Message 2)
- self.assertGreaterEqual(len(dispatched_messages), 2)
+ # Should have NODE_END action with state containing messages
+ node_end_actions = [e for e in dispatched_events if isinstance(e.action, NodeEndAction)]
+ self.assertEqual(len(node_end_actions), 1)
+
+ node_end_state = cast(NodeEndAction, node_end_actions[0].action).state
+ self.assertIsNotNone(node_end_state)
+ assert node_end_state is not None
+ self.assertEqual(len(node_end_state.messages), 2)
async def test_node_returns_none_state_handling(self):
"""Test that node can return None state without errors."""
- class NoneStateNode(BaseAssistantNode):
+ class NoneStateNode(BaseAssistantNode[AssistantState, PartialAssistantState]):
@property
def node_name(self):
return AssistantNodeName.ROOT

- async def arun(self, state, config: RunnableConfig) -> None:
+ async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
# Dispatch a message but return None state
- self.dispatcher.message(AssistantMessage(content="Test"))
+ self.dispatcher.update("Test")
return None
node = NoneStateNode(self.mock_team, self.mock_user)
state = AssistantState(messages=[])
- config = RunnableConfig()
+ config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_7"})
- result = await node(state, config)
- # Should handle None gracefully
- self.assertIsNone(result)
+ with patch("ee.hogai.utils.dispatcher.get_stream_writer", side_effect=RuntimeError("Not streaming")):
+ result = await node(state, config)
+ # Should handle None gracefully
+ self.assertIsNone(result)
+
+ async def test_nested_node_path_in_dispatched_events(self):
+ """Test that nested nodes have correct node_path."""
+ # Simulate a nested node path (e.g., from a tool call)
+ parent_path = (
+ NodePath(name=AssistantGraphName.ASSISTANT),
+ NodePath(name=AssistantNodeName.ROOT, message_id="msg_123", tool_call_id="tc_123"),
+ NodePath(name=AssistantGraphName.INSIGHTS),
+ )
+
+ node = MockAssistantNode(self.mock_team, self.mock_user, node_path=parent_path)
+
+ state = AssistantState(messages=[])
+ config = RunnableConfig(
+ metadata={"langgraph_node": AssistantNodeName.TRENDS_GENERATOR, "langgraph_checkpoint_ns": "cp_8"}
+ )
+
+ dispatched_events = []
+
+ def capture_write(event: AssistantDispatcherEvent):
+ dispatched_events.append(event)
+
+ with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
+ await node(state, config)
+
+ # Verify all events have the nested path
+ for event in dispatched_events:
+ self.assertIsNotNone(event.node_path)
+ assert event.node_path is not None
+ self.assertEqual(event.node_path, parent_path)
+ self.assertEqual(len(event.node_path), 3)
+ self.assertEqual(event.node_path[1].message_id, "msg_123")
+ self.assertEqual(event.node_path[1].tool_call_id, "tc_123")
+
+ async def test_update_actions_include_node_metadata(self):
+ """Test that update actions include correct node metadata."""
+ node = MockAssistantNode(self.mock_team, self.mock_user)
+
+ state = AssistantState(messages=[])
+ config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_9"})
+
+ dispatched_events = []
+
+ def capture_write(event: AssistantDispatcherEvent):
+ dispatched_events.append(event)
+
+ with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
+ await node(state, config)
+
+ # Find update actions
+ update_events = [e for e in dispatched_events if isinstance(e.action, UpdateAction)]
+
+ for event in update_events:
+ self.assertEqual(event.node_name, AssistantNodeName.ROOT)
+ self.assertEqual(event.node_run_id, "cp_9")
+ self.assertIsNotNone(event.node_path)
+
+ async def test_concurrent_node_executions_independent_dispatchers(self):
+ """Test that concurrent node executions use independent dispatchers."""
+ node1 = MockAssistantNode(self.mock_team, self.mock_user)
+ node2 = MockAssistantNode(self.mock_team, self.mock_user)
+
+ state = AssistantState(messages=[])
+ config1 = RunnableConfig(
+ metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_10a"}
+ )
+ config2 = RunnableConfig(
+ metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_10b"}
+ )
+
+ events1 = []
+ events2 = []
+
+ def capture_write1(event: AssistantDispatcherEvent):
+ events1.append(event)
+
+ def capture_write2(event: AssistantDispatcherEvent):
+ events2.append(event)
+
+ with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write1):
+ await node1(state, config1)
+
+ with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write2):
+ await node2(state, config2)
+
+ # Verify events went to separate lists
+ self.assertGreater(len(events1), 0)
+ self.assertGreater(len(events2), 0)
+
+ # Verify node_run_ids are different
+ for event in events1:
+ self.assertEqual(event.node_run_id, "cp_10a")
+
+ for event in events2:
+ self.assertEqual(event.node_run_id, "cp_10b")
diff --git a/ee/hogai/graph/trends/nodes.py b/ee/hogai/graph/trends/nodes.py
index cb11eb2ba7..6034a7ef58 100644
--- a/ee/hogai/graph/trends/nodes.py
+++ b/ee/hogai/graph/trends/nodes.py
@@ -1,7 +1,7 @@
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
-from posthog.schema import AssistantMessage, AssistantTrendsQuery
+from posthog.schema import AssistantTrendsQuery
from ee.hogai.utils.types import AssistantState
from ee.hogai.utils.types.base import AssistantNodeName, PartialAssistantState
@@ -25,7 +25,7 @@ class TrendsGeneratorNode(SchemaGeneratorNode[AssistantTrendsQuery]):
return AssistantNodeName.TRENDS_GENERATOR

async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
- self.dispatcher.message(AssistantMessage(content="Creating trends query"))
+ self.dispatcher.update("Creating trends query")
prompt = ChatPromptTemplate.from_messages(
[
("system", TRENDS_SYSTEM_PROMPT),
diff --git a/ee/hogai/test/test_assistant.py b/ee/hogai/test/test_assistant.py
index 588ba4d2ae..ccdf8e00df 100644
--- a/ee/hogai/test/test_assistant.py
+++ b/ee/hogai/test/test_assistant.py
@@ -9,7 +9,7 @@ from posthog.test.base import (
_create_person,
flush_persons_and_events,
)
-from unittest.mock import AsyncMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
from django.test import override_settings
@@ -61,13 +61,13 @@ from posthog.models import Action
from ee.hogai.assistant.base import BaseAssistant
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
-from ee.hogai.graph.base import BaseAssistantNode
+from ee.hogai.graph.base import AssistantNode
from ee.hogai.graph.funnels.nodes import FunnelsSchemaGeneratorOutput
+from ee.hogai.graph.insights_graph.graph import InsightsGraph
from ee.hogai.graph.memory import prompts as memory_prompts
from ee.hogai.graph.retention.nodes import RetentionSchemaGeneratorOutput
from ee.hogai.graph.root.nodes import SLASH_COMMAND_INIT
from ee.hogai.graph.trends.nodes import TrendsSchemaGeneratorOutput
-from ee.hogai.utils.state import GraphValueUpdateTuple
from ee.hogai.utils.tests import FakeAnthropicRunnableLambdaWithTokenCounter, FakeChatAnthropic, FakeChatOpenAI
from ee.hogai.utils.types import (
AssistantMode,
@@ -76,16 +76,21 @@ from ee.hogai.utils.types import (
AssistantState,
PartialAssistantState,
)
+from ee.hogai.utils.types.base import ReplaceMessages
from ee.models.assistant import Conversation, CoreMemory
from ..assistant import Assistant
-from ..graph import AssistantGraph, InsightsAssistantGraph
+from ..graph.graph import AssistantGraph
title_generator_mock = patch(
"ee.hogai.graph.title_generator.nodes.TitleGeneratorNode._model",
return_value=FakeChatOpenAI(responses=[messages.AIMessage(content="Title")]),
)
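
+# Stub query result formatting so tests don't depend on real query execution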
+query_executor_mock = patch(
+ "ee.hogai.graph.query_executor.nodes.QueryExecutorNode._format_query_result", new=MagicMock(return_value="Result")
+)
+
class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
CLASS_DATA_LEVEL_SETUP = False
@@ -118,7 +123,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
),
).start()
- self.checkpointer_patch = patch("ee.hogai.graph.graph.global_checkpointer", new=DjangoCheckpointer())
+ self.checkpointer_patch = patch("ee.hogai.graph.base.graph.global_checkpointer", new=DjangoCheckpointer())
self.checkpointer_patch.start()
def tearDown(self):
@@ -232,14 +237,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
graph = (
AssistantGraph(self.team, self.user)
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
- .add_root(
- {
- "insights": AssistantNodeName.INSIGHTS_SUBGRAPH,
- "root": AssistantNodeName.ROOT,
- "end": AssistantNodeName.END,
- }
- )
- .add_insights(AssistantNodeName.ROOT)
+ .add_root()
.compile()
)
@@ -269,7 +267,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
{
"id": "1",
"name": "create_and_query_insight",
- "args": {"query_description": "Foobar", "query_kind": insight_type},
+ "args": {"query_description": "Foobar"},
}
],
)
@@ -302,7 +300,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
AssistantToolCall(
id="1",
name="create_and_query_insight",
- args={"query_description": "Foobar", "query_kind": insight_type},
+ args={"query_description": "Foobar"},
)
],
),
@@ -327,12 +325,12 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
snapshot: StateSnapshot = await graph.aget_state(config)
self.assertFalse(snapshot.next)
self.assertFalse(snapshot.values.get("intermediate_steps"))
- self.assertFalse(snapshot.values["plan"])
- self.assertFalse(snapshot.values["graph_status"])
- self.assertFalse(snapshot.values["root_tool_call_id"])
- self.assertFalse(snapshot.values["root_tool_insight_plan"])
- self.assertFalse(snapshot.values["root_tool_insight_type"])
- self.assertFalse(snapshot.values["root_tool_calls_count"])
+ self.assertFalse(snapshot.values.get("plan"))
+ self.assertFalse(snapshot.values.get("graph_status"))
+ self.assertFalse(snapshot.values.get("root_tool_call_id"))
+ self.assertFalse(snapshot.values.get("root_tool_insight_plan"))
+ self.assertFalse(snapshot.values.get("root_tool_insight_type"))
+ self.assertFalse(snapshot.values.get("root_tool_calls_count"))
async def test_trends_interrupt_when_asking_for_help(self):
await self._test_human_in_the_loop("trends")
@@ -345,7 +343,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
async def test_ai_messages_appended_after_interrupt(self):
with patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model") as mock:
- graph = InsightsAssistantGraph(self.team, self.user).compile_full_graph()
+ graph = InsightsGraph(self.team, self.user).compile_full_graph()
config: RunnableConfig = {
"configurable": {
"thread_id": self.conversation.id,
@@ -441,16 +439,13 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
self.assertIsInstance(output[1][1], FailureMessage)
async def test_new_conversation_handles_serialized_conversation(self):
- from ee.hogai.graph.base import BaseAssistantNode
-
- class TestNode(BaseAssistantNode):
+ class TestNode(AssistantNode):
@property
def node_name(self):
return AssistantNodeName.ROOT
async def arun(self, state, config):
- self.dispatcher.message(AssistantMessage(content="Hello"))
- return PartialAssistantState()
+ return PartialAssistantState(messages=[AssistantMessage(content="Hello", id=str(uuid4()))])
test_node = TestNode(self.team, self.user)
graph = (
@@ -478,16 +473,13 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
self.assertNotEqual(output[0][0], "conversation")
async def test_async_stream(self):
- from ee.hogai.graph.base import BaseAssistantNode
-
- class TestNode(BaseAssistantNode):
+ class TestNode(AssistantNode):
@property
def node_name(self):
return AssistantNodeName.ROOT
async def arun(self, state, config):
- self.dispatcher.message(AssistantMessage(content="bar"))
- return PartialAssistantState()
+ return PartialAssistantState(messages=[AssistantMessage(content="bar", id=str(uuid4()))])
test_node = TestNode(self.team, self.user)
graph = (
@@ -517,12 +509,13 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
self.assertConversationEqual(actual_output, expected_output)
async def test_async_stream_handles_exceptions(self):
- def node_handler(state):
- raise ValueError()
+ class NodeHandler(AssistantNode):
+ async def arun(self, state, config):
+ raise ValueError
graph = (
AssistantGraph(self.team, self.user)
- .add_node(AssistantNodeName.ROOT, node_handler)
+ .add_node(AssistantNodeName.ROOT, NodeHandler(self.team, self.user))
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
.compile()
@@ -536,12 +529,11 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
("message", HumanMessage(content="foo")),
("message", FailureMessage()),
]
- actual_output = []
- async for event in assistant.astream():
- actual_output.append(event)
+ actual_output, _ = await self._run_assistant_graph(graph, message="foo")
self.assertConversationEqual(actual_output, expected_output)
@title_generator_mock
+ @query_executor_mock
@patch("ee.hogai.graph.schema_generator.nodes.SchemaGeneratorNode._model")
@patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model")
@patch("ee.hogai.graph.root.nodes.RootNode._get_model")
@@ -556,7 +548,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
{
"id": "xyz",
"name": "create_and_query_insight",
- "args": {"query_description": "Foobar", "query_kind": "trends"},
+ "args": {"query_description": "Foobar"},
}
],
)
@@ -596,7 +588,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
AssistantToolCall(
id="xyz",
name="create_and_query_insight",
- args={"query_description": "Foobar", "query_kind": "trends"},
+ args={"query_description": "Foobar"},
)
],
),
@@ -611,7 +603,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
),
("update", AssistantUpdateEvent(id="message_2", tool_call_id="xyz", content="Creating trends query")),
("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")),
- ("message", {"tool_call_id": "xyz", "type": "tool"}), # Don't check content as it's implementation detail
+ ("message", AssistantToolCallMessage(tool_call_id="xyz", content="Result")),
("message", AssistantMessage(content="The results indicate a great future for you.")),
]
self.assertConversationEqual(actual_output, expected_output)
@@ -634,6 +626,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
)
@title_generator_mock
+ @query_executor_mock
@patch("ee.hogai.graph.schema_generator.nodes.SchemaGeneratorNode._model")
@patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model")
@patch("ee.hogai.graph.root.nodes.RootNode._get_model")
@@ -648,7 +641,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
{
"id": "xyz",
"name": "create_and_query_insight",
- "args": {"query_description": "Foobar", "query_kind": "funnel"},
+ "args": {"query_description": "Foobar"},
}
],
)
@@ -693,7 +686,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
AssistantToolCall(
id="xyz",
name="create_and_query_insight",
- args={"query_description": "Foobar", "query_kind": "funnel"},
+ args={"query_description": "Foobar"},
)
],
),
@@ -708,7 +701,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
),
("update", AssistantUpdateEvent(id="message_2", tool_call_id="xyz", content="Creating funnel query")),
("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")),
- ("message", {"tool_call_id": "xyz", "type": "tool"}), # Don't check content as it's implementation detail
+ ("message", AssistantToolCallMessage(tool_call_id="xyz", content="Result")),
("message", AssistantMessage(content="The results indicate a great future for you.")),
]
self.assertConversationEqual(actual_output, expected_output)
@@ -731,6 +724,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
)
@title_generator_mock
+ @query_executor_mock
@patch("ee.hogai.graph.schema_generator.nodes.SchemaGeneratorNode._model")
@patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model")
@patch("ee.hogai.graph.root.nodes.RootNode._get_model")
@@ -747,7 +741,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
{
"id": "xyz",
"name": "create_and_query_insight",
- "args": {"query_description": "Foobar", "query_kind": "retention"},
+ "args": {"query_description": "Foobar"},
}
],
)
@@ -792,7 +786,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
AssistantToolCall(
id="xyz",
name="create_and_query_insight",
- args={"query_description": "Foobar", "query_kind": "retention"},
+ args={"query_description": "Foobar"},
)
],
),
@@ -807,7 +801,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
),
("update", AssistantUpdateEvent(id="message_2", tool_call_id="xyz", content="Creating retention query")),
("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")),
- ("message", {"tool_call_id": "xyz", "type": "tool"}), # Don't check content as it's implementation detail
+ ("message", AssistantToolCallMessage(tool_call_id="xyz", content="Result")),
("message", AssistantMessage(content="The results indicate a great future for you.")),
]
self.assertConversationEqual(actual_output, expected_output)
@@ -830,6 +824,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
)
@title_generator_mock
+ @query_executor_mock
@patch("ee.hogai.graph.schema_generator.nodes.SchemaGeneratorNode._model")
@patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model")
@patch("ee.hogai.graph.root.nodes.RootNode._get_model")
@@ -844,7 +839,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
{
"id": "xyz",
"name": "create_and_query_insight",
- "args": {"query_description": "Foobar", "query_kind": "sql"},
+ "args": {"query_description": "Foobar"},
}
],
)
@@ -883,7 +878,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
AssistantToolCall(
id="xyz",
name="create_and_query_insight",
- args={"query_description": "Foobar", "query_kind": "sql"},
+ args={"query_description": "Foobar"},
)
],
),
@@ -898,7 +893,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
),
("update", AssistantUpdateEvent(id="message_2", tool_call_id="xyz", content="Creating SQL query")),
("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")),
- ("message", {"tool_call_id": "xyz", "type": "tool"}), # Don't check content as it's implementation detail
+ ("message", AssistantToolCallMessage(tool_call_id="xyz", content="Result")),
("message", AssistantMessage(content="The results indicate a great future for you.")),
]
self.assertConversationEqual(actual_output, expected_output)
@@ -1096,7 +1091,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
{
"id": str(uuid4()),
"name": "create_and_query_insight",
- "args": {"query_description": "Foobar", "query_kind": "trends"},
+ "args": {"query_description": "Foobar"},
}
],
)
@@ -1138,7 +1133,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
graph = (
AssistantGraph(self.team, self.user)
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
- .add_root({"root": AssistantNodeName.ROOT, "end": AssistantNodeName.END})
+ .add_root()
.compile()
)
self.assertEqual(self.conversation.status, Conversation.Status.IDLE)
@@ -1158,7 +1153,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
graph = (
AssistantGraph(self.team, self.user)
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
- .add_root({"root": AssistantNodeName.ROOT, "end": AssistantNodeName.END})
+ .add_root()
.compile()
)
@@ -1177,7 +1172,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
{
"id": "1",
"name": "create_and_query_insight",
- "args": {"query_description": "Foobar", "query_kind": "trends"},
+ "args": {"query_description": "Foobar"},
}
],
)
@@ -1209,14 +1204,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
graph = (
AssistantGraph(self.team, self.user)
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
- .add_root(
- {
- "search_documentation": AssistantNodeName.INKEEP_DOCS,
- "root": AssistantNodeName.ROOT,
- "end": AssistantNodeName.END,
- }
- )
- .add_inkeep_docs()
+ .add_root()
.compile()
)
@@ -1235,45 +1223,25 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
self.assertConversationEqual(
output,
[
- (
- "message",
- HumanMessage(
- content="How do I use feature flags?",
- id="d93b3d57-2ff3-4c63-8635-8349955b16f0",
- parent_tool_call_id=None,
- type="human",
- ui_context=None,
- ),
- ),
+ ("message", HumanMessage(content="How do I use feature flags?")),
(
"message",
AssistantMessage(
content="",
- id="ea1f4faf-85b4-4d98-b044-842b7092ba5d",
- meta=None,
- parent_tool_call_id=None,
tool_calls=[
AssistantToolCall(
args={"kind": "docs", "query": "test"}, id="1", name="search", type="tool_call"
)
],
- type="ai",
),
),
(
"update",
- AssistantUpdateEvent(id="message_2", tool_call_id="1", content="Checking PostHog documentation..."),
+ AssistantUpdateEvent(content="Checking PostHog documentation...", id="1", tool_call_id="1"),
),
(
"message",
- AssistantToolCallMessage(
- content="Checking PostHog documentation...",
- id="00149847-0bcd-4b7c-a514-bab176312ae8",
- parent_tool_call_id=None,
- tool_call_id="1",
- type="tool",
- ui_payload=None,
- ),
+ AssistantToolCallMessage(content="Checking PostHog documentation...", tool_call_id="1"),
),
(
"message",
@@ -1283,12 +1251,10 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
)
@title_generator_mock
+ @query_executor_mock
@patch("ee.hogai.graph.schema_generator.nodes.SchemaGeneratorNode._model")
@patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model")
- @patch("ee.hogai.graph.graph.QueryExecutorNode")
- async def test_insights_tool_mode_flow(
- self, query_executor_mock, planner_mock, generator_mock, title_generator_mock
- ):
+ async def test_insights_tool_mode_flow(self, planner_mock, generator_mock, title_generator_mock):
"""Test that the insights tool mode works correctly."""
query = AssistantTrendsQuery(series=[])
tool_call_id = str(uuid4())
@@ -1315,25 +1281,6 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
)
generator_mock.return_value = RunnableLambda(lambda _: TrendsSchemaGeneratorOutput(query=query))
- class QueryExecutorNodeMock(BaseAssistantNode):
- def __init__(self, team, user):
- super().__init__(team, user)
-
- @property
- def node_name(self):
- return AssistantNodeName.QUERY_EXECUTOR
-
- async def arun(self, state, config):
- return PartialAssistantState(
- messages=[
- AssistantToolCallMessage(
- content="The results indicate a great future for you.", tool_call_id=tool_call_id
- )
- ]
- )
-
- query_executor_mock.return_value = QueryExecutorNodeMock(self.team, self.user)
-
# Run in insights tool mode
output, _ = await self._run_assistant_graph(
conversation=self.conversation,
@@ -1347,9 +1294,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")),
(
"message",
- AssistantToolCallMessage(
- content="The results indicate a great future for you.", tool_call_id=tool_call_id
- ),
+ AssistantToolCallMessage(content="Result", tool_call_id=tool_call_id),
),
]
self.assertConversationEqual(output, expected_output)
@@ -1387,7 +1332,6 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
async def test_merges_messages_with_same_id(self):
"""Test that messages with the same ID are merged into one."""
- from ee.hogai.graph.base import BaseAssistantNode
message_ids = [str(uuid4()), str(uuid4())]
@@ -1395,7 +1339,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
first_content = "First version of message"
updated_content = "Updated version of message"
- class MessageUpdatingNode(BaseAssistantNode):
+ class MessageUpdatingNode(AssistantNode):
def __init__(self, team, user):
super().__init__(team, user)
self.call_count = 0
@@ -1447,7 +1391,6 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
async def test_assistant_filters_messages_correctly(self):
"""Test that the Assistant class correctly filters messages based on should_output_assistant_message."""
- from ee.hogai.graph.base import BaseAssistantNode
output_messages = [
# Should be output (has content)
@@ -1467,7 +1410,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
for test_message, expected_in_output in output_messages:
# Create a simple graph that produces different message types to test filtering
- class MessageFilteringNode(BaseAssistantNode):
+ class MessageFilteringNode(AssistantNode):
def __init__(self, team, user, message_to_return):
super().__init__(team, user)
self.message_to_return = message_to_return
@@ -1477,8 +1420,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
return AssistantNodeName.ROOT
async def arun(self, state, config):
- self.dispatcher.message(self.message_to_return)
- return PartialAssistantState()
+ return PartialAssistantState(messages=[self.message_to_return])
# Create a graph with our test node
node = MessageFilteringNode(self.team, self.user, test_message)
@@ -1505,12 +1447,13 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
"""Test that ui_context persists when retrieving conversation state across multiple runs."""
# Create a simple graph that just returns the initial state
- def return_initial_state(state):
- return {"messages": [AssistantMessage(content="Response from assistant")]}
+ class ReturnInitialStateNode(AssistantNode):
+ async def arun(self, state, config):
+ return PartialAssistantState(messages=[AssistantMessage(content="Response from assistant")])
graph = (
AssistantGraph(self.team, self.user)
- .add_node(AssistantNodeName.ROOT, return_initial_state)
+ .add_node(AssistantNodeName.ROOT, ReturnInitialStateNode(self.team, self.user))
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
.compile()
@@ -1585,8 +1528,8 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
tool_calls=[
{
"id": "xyz",
- "name": "edit_current_insight",
- "args": {"query_description": "Foobar", "query_kind": "trends"},
+ "name": "create_and_query_insight",
+ "args": {"query_description": "Foobar"},
}
],
)
@@ -1622,20 +1565,13 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
output, assistant = await self._run_assistant_graph(
test_graph=AssistantGraph(self.team, self.user)
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
- .add_root(
- {
- "root": AssistantNodeName.ROOT,
- "insights": AssistantNodeName.INSIGHTS_SUBGRAPH,
- "end": AssistantNodeName.END,
- }
- )
- .add_insights()
+ .add_root()
.compile(),
conversation=self.conversation,
is_new_conversation=True,
message="Hello",
mode=AssistantMode.ASSISTANT,
- contextual_tools={"edit_current_insight": {"current_query": "query"}},
+ contextual_tools={"create_and_query_insight": {"current_query": "query"}},
)
expected_output = [
@@ -1645,23 +1581,30 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
"message",
AssistantMessage(
content="",
+ id="56076433-5d90-4248-9a46-df3fda42bd0a",
tool_calls=[
AssistantToolCall(
+ args={"query_description": "Foobar"},
id="xyz",
- name="edit_current_insight",
- args={"query_description": "Foobar", "query_kind": "trends"},
+ name="create_and_query_insight",
+ type="tool_call",
)
],
),
),
+ (
+ "update",
+ AssistantUpdateEvent(content="Picking relevant events and properties", tool_call_id="xyz", id=""),
+ ),
+ ("update", AssistantUpdateEvent(content="Creating trends query", tool_call_id="xyz", id="")),
("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")),
(
"message",
- {
- "tool_call_id": "xyz",
- "type": "tool",
- "ui_payload": {"edit_current_insight": query.model_dump(exclude_none=True)},
- }, # Don't check content as it's implementation detail
+ AssistantToolCallMessage(
+ content="The results indicate a great future for you.",
+ tool_call_id="xyz",
+ ui_payload={"create_and_query_insight": query.model_dump(exclude_none=True)},
+ ),
),
("message", AssistantMessage(content="Everything is fine")),
]
@@ -1671,7 +1614,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
state = AssistantState.model_validate(snapshot.values)
expected_state_messages = [
ContextMessage(
- content="\nContextual tools that are available to you on this page are:\n\nThe user is currently editing an insight (aka query). Here is that insight's current definition, which can be edited using the `edit_current_insight` tool:\n\n```json\nquery\n```\n\nIMPORTANT: DO NOT REMOVE ANY FIELDS FROM THE CURRENT INSIGHT DEFINITION. DO NOT CHANGE ANY OTHER FIELDS THAN THE ONES THE USER ASKED FOR. KEEP THE REST AS IS.\n\nIMPORTANT: this context may or may not be relevant to your tasks. You should not respond to this context unless it is highly relevant to your task.\n"
+ content="\nContextual tools that are available to you on this page are:\n\nThe user is currently editing an insight (aka query). Here is that insight's current definition, which can be edited using the `create_and_query_insight` tool:\n\n```json\nquery\n```\n\n\nDo not remove any fields from the current insight definition. Do not change any other fields than the ones the user asked for. Keep the rest as is.\n\n\nIMPORTANT: this context may or may not be relevant to your tasks. You should not respond to this context unless it is highly relevant to your task.\n"
),
HumanMessage(content="Hello"),
AssistantMessage(
@@ -1679,8 +1622,8 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
tool_calls=[
AssistantToolCall(
id="xyz",
- name="edit_current_insight",
- args={"query_description": "Foobar", "query_kind": "trends"},
+ name="create_and_query_insight",
+ args={"query_description": "Foobar"},
)
],
),
@@ -1688,7 +1631,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
AssistantToolCallMessage(
content="The results indicate a great future for you.",
tool_call_id="xyz",
- ui_payload={"edit_current_insight": query.model_dump(exclude_none=True)},
+ ui_payload={"create_and_query_insight": query.model_dump(exclude_none=True)},
),
AssistantMessage(content="Everything is fine"),
]
@@ -1705,7 +1648,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
graph = (
AssistantGraph(self.team, self.user)
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
- .add_root({"root": AssistantNodeName.ROOT, "end": AssistantNodeName.END})
+ .add_root()
.compile()
)
@@ -1757,16 +1700,14 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
# Tests for ainvoke method
async def test_ainvoke_basic_functionality(self):
"""Test ainvoke returns all messages at once without streaming."""
- from ee.hogai.graph.base import BaseAssistantNode
- class TestNode(BaseAssistantNode):
+ class TestNode(AssistantNode):
@property
def node_name(self):
return AssistantNodeName.ROOT
async def arun(self, state, config):
- self.dispatcher.message(AssistantMessage(content="Response"))
- return PartialAssistantState()
+ return PartialAssistantState(messages=[AssistantMessage(content="Response", id=str(uuid4()))])
test_node = TestNode(self.team, self.user)
graph = (
@@ -1798,28 +1739,6 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
self.assertIsInstance(item[1], AssistantMessage)
self.assertEqual(cast(AssistantMessage, item[1]).content, "Response")
- async def test_process_value_update_returns_none(self):
- """Test that _aprocess_value_update returns None for basic state updates (ACKs are now handled by reducer)."""
-
- assistant = Assistant.create(
- self.team, self.conversation, new_message=HumanMessage(content="Hello"), user=self.user
- )
-
- # Create a value update tuple that doesn't match special nodes
- update = cast(
- GraphValueUpdateTuple,
- (
- AssistantNodeName.ROOT,
- {"root": {"messages": []}}, # Empty update that doesn't match visualization or verbose nodes
- ),
- )
-
- # Process the update
- result = await assistant._aprocess_value_update(update)
-
- # Should return None (ACK events are now generated by the reducer)
- self.assertIsNone(result)
-
def test_billing_context_in_config(self):
billing_context = MaxBillingContext(
has_active_subscription=True,
@@ -1855,7 +1774,225 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
)
config = assistant._get_config()
- self.assertEqual(config.get("configurable", {}).get("billing_context"), billing_context)
+ self.assertEqual(config["configurable"]["billing_context"], billing_context)
+
+ @patch("ee.hogai.context.context.AssistantContextManager.check_user_has_billing_access", return_value=True)
+ @patch("ee.hogai.graph.root.nodes.RootNode._get_model")
+ async def test_billing_tool_execution(self, root_mock, access_mock):
+ """Test that the billing tool can be called and returns formatted billing information."""
+ billing_context = MaxBillingContext(
+ subscription_level=MaxBillingContextSubscriptionLevel.PAID,
+ billing_plan="startup",
+ has_active_subscription=True,
+ is_deactivated=False,
+ products=[
+ MaxProductInfo(
+ name="Product Analytics",
+ type="analytics",
+ description="Track user behavior",
+ current_usage=50000,
+ usage_limit=100000,
+ has_exceeded_limit=False,
+ is_used=True,
+ percentage_usage=0.5,
+ addons=[],
+ )
+ ],
+ settings=MaxBillingContextSettings(autocapture_on=True, active_destinations=2),
+ )
+
+ # Mock the root node to call the read_data tool with billing_info kind
+ tool_call_id = str(uuid4())
+
+ def root_side_effect(msgs: list[BaseMessage]):
+ # Check if we've already received a tool result
+ last_message = msgs[-1]
+ if (
+ isinstance(last_message.content, list)
+ and isinstance(last_message.content[-1], dict)
+ and last_message.content[-1]["type"] == "tool_result"
+ ):
+ # After tool execution, respond with final message
+ return messages.AIMessage(content="Your billing information shows you're on a startup plan.")
+
+ # First call - request the billing tool
+ return messages.AIMessage(
+ content="",
+ tool_calls=[
+ {
+ "id": tool_call_id,
+ "name": "read_data",
+ "args": {"kind": "billing_info"},
+ }
+ ],
+ )
+
+ root_mock.return_value = FakeAnthropicRunnableLambdaWithTokenCounter(root_side_effect)
+
+ # Create a minimal test graph
+ test_graph = (
+ AssistantGraph(self.team, self.user)
+ .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
+ .add_root()
+ .compile()
+ )
+
+ # Run the assistant with billing context
+ assistant = Assistant.create(
+ team=self.team,
+ conversation=self.conversation,
+ user=self.user,
+ new_message=HumanMessage(content="What's my current billing status?"),
+ billing_context=billing_context,
+ )
+ assistant._graph = test_graph
+
+ output: list[AssistantOutput] = []
+ async for event in assistant.astream():
+ output.append(event)
+
+ # Verify we received messages
+ self.assertGreater(len(output), 0)
+
+ # Find the assistant's final response
+ assistant_messages = [msg for event_type, msg in output if isinstance(msg, AssistantMessage)]
+ self.assertGreater(len(assistant_messages), 0)
+
+ # Verify the assistant received and used the billing information
+ # The mock returns "Your billing information shows you're on a startup plan."
+ final_message = cast(AssistantMessage, assistant_messages[-1])
+ self.assertIn("billing", final_message.content.lower())
+ self.assertIn("startup", final_message.content.lower())
+
+ async def test_messages_without_id_are_yielded(self):
+ """Test that messages without ID are always yielded."""
+
+ class MessageWithoutIdNode(AssistantNode):
+ call_count = 0
+
+ async def arun(self, state, config):
+ self.call_count += 1
+                # Return messages without IDs - they should always be yielded
+ return PartialAssistantState(
+ messages=[
+ AssistantMessage(content=f"Message {self.call_count} without ID"),
+ AssistantMessage(content=f"Message {self.call_count} without ID"),
+ ]
+ )
+
+ node = MessageWithoutIdNode(self.team, self.user)
+ graph = (
+ AssistantGraph(self.team, self.user)
+ .add_node(AssistantNodeName.ROOT, node)
+ .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
+ .add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
+ .compile()
+ )
+
+ # Run the assistant multiple times
+ output1, _ = await self._run_assistant_graph(graph, message="First run", conversation=self.conversation)
+ output2, _ = await self._run_assistant_graph(graph, message="Second run", conversation=self.conversation)
+
+        # Both runs should yield their messages (one human plus two assistant messages each)
+ self.assertEqual(len(output1), 3) # Human message + AI message + AI message
+ self.assertEqual(len(output2), 3) # Human message + AI message + AI message
+
+ async def test_messages_with_id_are_deduplicated(self):
+ """Test that messages with ID are deduplicated during streaming."""
+ message_id = str(uuid4())
+
+ class DuplicateMessageNode(AssistantNode):
+ call_count = 0
+
+ async def arun(self, state, config):
+ self.call_count += 1
+ # Always return the same message with same ID
+ return PartialAssistantState(
+ messages=[
+ AssistantMessage(id=message_id, content=f"Call {self.call_count}"),
+ AssistantMessage(id=message_id, content=f"Call {self.call_count}"),
+ ]
+ )
+
+ node = DuplicateMessageNode(self.team, self.user)
+ graph = (
+ AssistantGraph(self.team, self.user)
+ .add_node(AssistantNodeName.ROOT, node)
+ .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
+ .add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
+ .compile()
+ )
+
+ # Create assistant and manually test the streaming behavior
+ assistant = Assistant.create(
+ self.team,
+ self.conversation,
+ new_message=HumanMessage(content="Test message"),
+ user=self.user,
+ is_new_conversation=False,
+ )
+ assistant._graph = graph
+
+ # Collect all streamed messages
+ streamed_messages = []
+ async for event_type, message in assistant.astream(stream_first_message=False):
+ if event_type == AssistantEventType.MESSAGE:
+ streamed_messages.append(message)
+
+        # Should only get one message despite the node returning it multiple times
+ assistant_messages = [
+ msg for msg in streamed_messages if isinstance(msg, AssistantMessage) and msg.id == message_id
+ ]
+ self.assertEqual(len(assistant_messages), 1, "Message with same ID should only be yielded once")
+
+    async def test_replaced_messages_are_not_double_streamed(self):
+        """Test that messages replaced via ReplaceMessages are not streamed again."""
+ # Create messages with IDs that should be tracked
+ message_id_1 = str(uuid4())
+ message_id_2 = str(uuid4())
+ call_count = [0]
+
+ # Create a simple graph that returns messages with IDs
+ class TestNode(AssistantNode):
+ async def arun(self, state, config):
+ result = None
+ if call_count[0] == 0:
+ result = PartialAssistantState(
+ messages=[
+ AssistantMessage(id=message_id_1, content="Message 1"),
+ ]
+ )
+ else:
+ result = PartialAssistantState(
+ messages=ReplaceMessages(
+ [
+ AssistantMessage(id=message_id_1, content="Message 1"),
+ AssistantMessage(id=message_id_2, content="Message 2"),
+ ]
+ )
+ )
+ call_count[0] += 1
+ return result
+
+ graph = (
+ AssistantGraph(self.team, self.user)
+ .add_node(AssistantNodeName.ROOT, TestNode(self.team, self.user))
+ .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
+ .add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
+ .compile()
+ )
+
+ output, _ = await self._run_assistant_graph(graph, message="First run", conversation=self.conversation)
+ # Filter for assistant messages only, as the test is about tracking assistant message IDs
+ assistant_output = [(event_type, msg) for event_type, msg in output if isinstance(msg, AssistantMessage)]
+ self.assertEqual(len(assistant_output), 1)
+ self.assertEqual(cast(AssistantMessage, assistant_output[0][1]).id, message_id_1)
+
+ output, _ = await self._run_assistant_graph(graph, message="Second run", conversation=self.conversation)
+ # Filter for assistant messages only, as the test is about tracking assistant message IDs
+ assistant_output = [(event_type, msg) for event_type, msg in output if isinstance(msg, AssistantMessage)]
+ self.assertEqual(len(assistant_output), 1)
+ self.assertEqual(cast(AssistantMessage, assistant_output[0][1]).id, message_id_2)
@patch(
"ee.hogai.graph.conversation_summarizer.nodes.AnthropicConversationSummarizer.summarize",
@@ -1887,22 +2024,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
mock_tool.return_value = ("Event list" * 128000, None)
mock_should_compact.side_effect = cycle([False, True]) # Also changed this
- graph = (
- AssistantGraph(self.team, self.user)
- .add_root(
- path_map={
- "insights": AssistantNodeName.END,
- "search_documentation": AssistantNodeName.END,
- "root": AssistantNodeName.ROOT,
- "end": AssistantNodeName.END,
- "insights_search": AssistantNodeName.END,
- "session_summarization": AssistantNodeName.END,
- "create_dashboard": AssistantNodeName.END,
- }
- )
- .add_memory_onboarding()
- .compile()
- )
+ graph = AssistantGraph(self.team, self.user).add_root().add_memory_onboarding().compile()
expected_output = [
("message", HumanMessage(content="First")),
@@ -1928,3 +2050,86 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
self.assertEqual(state.start_id, new_human_message.id)
# should be equal to the summary message, minus reasoning message
self.assertEqual(state.root_conversation_start_id, state.messages[3].id)
+
+ @patch("ee.hogai.graph.root.tools.search.SearchTool._arun_impl", return_value=("Docs doubt it", None))
+ @patch(
+ "ee.hogai.graph.root.tools.read_taxonomy.ReadTaxonomyTool._run_impl",
+ return_value=("Hedgehogs have not talked yet", None),
+ )
+ @patch("ee.hogai.graph.root.nodes.RootNode._get_model")
+    async def test_root_node_can_execute_multiple_tool_calls(self, root_mock, read_taxonomy_mock, search_mock):
+ """Test that the root node can execute multiple tool calls in parallel."""
+ tool_call_id1, tool_call_id2 = [str(uuid4()), str(uuid4())]
+
+ def root_side_effect(msgs: list[BaseMessage]):
+ # Check if we've already received a tool result
+ last_message = msgs[-1]
+ if (
+ isinstance(last_message.content, list)
+ and isinstance(last_message.content[-1], dict)
+ and last_message.content[-1]["type"] == "tool_result"
+ ):
+ # After tool execution, respond with final message
+ return messages.AIMessage(content="No")
+
+ return messages.AIMessage(
+ content="Not sure. Let me check.",
+ tool_calls=[
+ {
+ "id": tool_call_id1,
+ "name": "search",
+ "args": {"kind": "docs", "query": "Do hedgehogs speak?"},
+ },
+ {
+ "id": tool_call_id2,
+ "name": "read_taxonomy",
+ "args": {"query": {"kind": "events"}},
+ },
+ ],
+ )
+
+ root_mock.return_value = FakeAnthropicRunnableLambdaWithTokenCounter(root_side_effect)
+
+ # Create a minimal test graph
+ graph = (
+ AssistantGraph(self.team, self.user)
+ .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
+ .add_root()
+ .compile()
+ )
+
+ expected_output = [
+ (AssistantEventType.MESSAGE, HumanMessage(content="Do hedgehogs speak?")),
+ (
+ AssistantEventType.MESSAGE,
+ AssistantMessage(
+ content="Not sure. Let me check.",
+ tool_calls=[
+ {
+ "id": tool_call_id1,
+ "name": "search",
+ "args": {"kind": "docs", "query": "Do hedgehogs speak?"},
+ },
+ {
+ "id": tool_call_id2,
+ "name": "read_taxonomy",
+ "args": {"query": {"kind": "events"}},
+ },
+ ],
+ ),
+ ),
+ (
+ AssistantEventType.MESSAGE,
+ AssistantToolCallMessage(content="Docs doubt it", tool_call_id=tool_call_id1),
+ ),
+ (
+ AssistantEventType.MESSAGE,
+ AssistantToolCallMessage(content="Hedgehogs have not talked yet", tool_call_id=tool_call_id2),
+ ),
+ (AssistantEventType.MESSAGE, AssistantMessage(content="No")),
+ ]
+ output, _ = await self._run_assistant_graph(
+ graph, message="Do hedgehogs speak?", conversation=self.conversation
+ )
+
+ self.assertConversationEqual(output, expected_output)
diff --git a/ee/hogai/tool.py b/ee/hogai/tool.py
index 704c569ff2..3213639291 100644
--- a/ee/hogai/tool.py
+++ b/ee/hogai/tool.py
@@ -1,6 +1,7 @@
import json
import pkgutil
import importlib
+from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Literal, Self
@@ -16,9 +17,9 @@ from posthog.models import Team, User
import products
from ee.hogai.context.context import AssistantContextManager
-from ee.hogai.graph.mixins import AssistantContextMixin
-from ee.hogai.utils.types import AssistantState
-from ee.hogai.utils.types.base import AssistantMessageUnion
+from ee.hogai.graph.base.context import get_node_path, set_node_path
+from ee.hogai.graph.mixins import AssistantContextMixin, AssistantDispatcherMixin
+from ee.hogai.utils.types.base import AssistantMessageUnion, AssistantState, NodePath
CONTEXTUAL_TOOL_NAME_TO_TOOL: dict[AssistantTool, type["MaxTool"]] = {}
@@ -51,7 +52,7 @@ class ToolMessagesArtifact(BaseModel):
messages: Sequence[AssistantMessageUnion]
-class MaxTool(AssistantContextMixin, BaseTool):
+class MaxTool(AssistantContextMixin, AssistantDispatcherMixin, BaseTool):
# LangChain's default is just "content", but we always want to return the tool call artifact too
# - it becomes the `ui_payload`
response_format: Literal["content_and_artifact"] = "content_and_artifact"
@@ -66,7 +67,7 @@ class MaxTool(AssistantContextMixin, BaseTool):
_config: RunnableConfig
_state: AssistantState
_context_manager: AssistantContextManager
- _tool_call_id: str
+ _node_path: tuple[NodePath, ...]
# DEPRECATED: Use `_arun_impl` instead
def _run_impl(self, *args, **kwargs) -> tuple[str, Any]:
@@ -82,7 +83,7 @@ class MaxTool(AssistantContextMixin, BaseTool):
*,
team: Team,
user: User,
- tool_call_id: str,
+ node_path: tuple[NodePath, ...] | None = None,
state: AssistantState | None = None,
config: RunnableConfig | None = None,
name: str | None = None,
@@ -102,7 +103,10 @@ class MaxTool(AssistantContextMixin, BaseTool):
super().__init__(**tool_kwargs, **kwargs)
self._team = team
self._user = user
- self._tool_call_id = tool_call_id
+ if node_path is None:
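+            # Inherit the caller's node path from context and append this tool as the leaf node.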
+ self._node_path = (*(get_node_path() or ()), NodePath(name=self.node_name))
+ else:
+ self._node_path = node_path
self._state = state if state else AssistantState(messages=[])
self._config = config if config else RunnableConfig(configurable={})
self._context_manager = context_manager or AssistantContextManager(team, user, self._config)
@@ -120,19 +124,35 @@ class MaxTool(AssistantContextMixin, BaseTool):
CONTEXTUAL_TOOL_NAME_TO_TOOL[accepted_name] = cls
def _run(self, *args, config: RunnableConfig, **kwargs):
+ """LangChain default runner."""
try:
- return self._run_impl(*args, **kwargs)
+ return self._run_with_context(*args, **kwargs)
except NotImplementedError:
pass
- return async_to_sync(self._arun_impl)(*args, **kwargs)
+ return async_to_sync(self._arun_with_context)(*args, **kwargs)
async def _arun(self, *args, config: RunnableConfig, **kwargs):
+ """LangChain default runner."""
try:
- return await self._arun_impl(*args, **kwargs)
+ return await self._arun_with_context(*args, **kwargs)
except NotImplementedError:
pass
return await super()._arun(*args, config=config, **kwargs)
+ def _run_with_context(self, *args, **kwargs):
+ """Sets the context for the tool."""
+ with set_node_path(self.node_path):
+ return self._run_impl(*args, **kwargs)
+
+ async def _arun_with_context(self, *args, **kwargs):
+ """Sets the context for the tool."""
+ with set_node_path(self.node_path):
+ return await self._arun_impl(*args, **kwargs)
+
+ @property
+ def node_name(self) -> str:
+ return f"max_tool.{self.get_name()}"
+
@property
def context(self) -> dict:
return self._context_manager.get_contextual_tools().get(self.get_name(), {})
@@ -149,7 +169,7 @@ class MaxTool(AssistantContextMixin, BaseTool):
*,
team: Team,
user: User,
- tool_call_id: str,
+ node_path: tuple[NodePath, ...] | None = None,
state: AssistantState | None = None,
config: RunnableConfig | None = None,
context_manager: AssistantContextManager | None = None,
@@ -160,5 +180,31 @@ class MaxTool(AssistantContextMixin, BaseTool):
Override this factory to dynamically modify the tool name, description, args schema, etc.
"""
return cls(
- team=team, user=user, tool_call_id=tool_call_id, state=state, config=config, context_manager=context_manager
+ team=team, user=user, node_path=node_path, state=state, config=config, context_manager=context_manager
)
+
+
+class MaxSubtool(AssistantDispatcherMixin, ABC):
+ _config: RunnableConfig
+
+ def __init__(
+ self,
+ *,
+ team: Team,
+ user: User,
+ state: AssistantState,
+ config: RunnableConfig,
+ context_manager: AssistantContextManager,
+ ):
+        self._team = team
+        self._user = user
+        self._state = state
+        self._config = config
+        self._context_manager = context_manager
+
+ @abstractmethod
+ async def execute(self, *args, **kwargs) -> Any:
+ pass
+
+ @property
+ def node_name(self) -> str:
+ return f"max_subtool.{self.__class__.__name__}"
diff --git a/ee/hogai/utils/dispatcher.py b/ee/hogai/utils/dispatcher.py
index 0b7a3b0e75..7cfa1db995 100644
--- a/ee/hogai/utils/dispatcher.py
+++ b/ee/hogai/utils/dispatcher.py
@@ -1,21 +1,19 @@
from collections.abc import Callable
-from typing import TYPE_CHECKING, Any
+from typing import Any
+from langchain_core.runnables import RunnableConfig
+from langgraph.config import get_stream_writer
from langgraph.types import StreamWriter
-from posthog.schema import AssistantMessage, AssistantToolCallMessage
-
from ee.hogai.utils.types.base import (
AssistantActionUnion,
AssistantDispatcherEvent,
AssistantMessageUnion,
MessageAction,
- NodeStartAction,
+ NodePath,
+ UpdateAction,
)
-if TYPE_CHECKING:
- from ee.hogai.utils.types.composed import MaxNodeName
-
class AssistantDispatcher:
"""
@@ -26,13 +24,14 @@ class AssistantDispatcher:
The dispatcher does NOT update state - it just emits actions to the stream.
"""
- _parent_tool_call_id: str | None = None
+ _node_path: tuple[NodePath, ...]
def __init__(
self,
writer: StreamWriter | Callable[[Any], None],
- node_name: "MaxNodeName",
- parent_tool_call_id: str | None = None,
+ node_path: tuple[NodePath, ...],
+ node_name: str,
+ node_run_id: str,
):
"""
Create a dispatcher for a specific node.
@@ -40,9 +39,10 @@ class AssistantDispatcher:
Args:
-            node_name: The name of the node dispatching actions (for attribution)
+            writer: The stream writer used to emit dispatcher events
+            node_path: The nested path from the graph root to this node
+            node_name: The name of the node dispatching actions (for attribution)
+            node_run_id: The LangGraph checkpoint namespace identifying this node run
"""
- self._node_name = node_name
self._writer = writer
- self._parent_tool_call_id = parent_tool_call_id
+ self._node_path = node_path
+ self._node_name = node_name
+ self._node_run_id = node_run_id
def dispatch(self, action: AssistantActionUnion) -> None:
"""
@@ -56,7 +56,11 @@ class AssistantDispatcher:
action: Action dict with "type" and "payload" keys
"""
try:
- self._writer(AssistantDispatcherEvent(action=action, node_name=self._node_name))
+ self._writer(
+ AssistantDispatcherEvent(
+ action=action, node_path=self._node_path, node_name=self._node_name, node_run_id=self._node_run_id
+ )
+ )
except Exception as e:
# Log error but don't crash node execution
# The dispatcher should be resilient to writer failures
@@ -69,34 +73,29 @@ class AssistantDispatcher:
"""
Dispatch a message to the stream.
"""
- if self._parent_tool_call_id:
- # If the dispatcher is initialized with a parent tool call id, set the parent tool call id on the message
- # This is to ensure that the message is associated with the correct tool call
- # Don't set parent_tool_call_id on:
- # 1. AssistantToolCallMessage with the same tool_call_id (to avoid self-reference)
- # 2. AssistantMessage with tool_calls containing the same ID (to avoid cycles)
- should_skip = False
- if isinstance(message, AssistantToolCallMessage) and self._parent_tool_call_id == message.tool_call_id:
- should_skip = True
- elif isinstance(message, AssistantMessage) and message.tool_calls:
- # Check if any tool call has the same ID as the parent
- for tool_call in message.tool_calls:
- if tool_call.id == self._parent_tool_call_id:
- should_skip = True
- break
-
- if not should_skip:
- message.parent_tool_call_id = self._parent_tool_call_id
self.dispatch(MessageAction(message=message))
- def node_start(self) -> None:
- """
- Dispatch a node start action to the stream.
- """
- self.dispatch(NodeStartAction())
+    def update(self, content: str) -> None:
+ """Dispatch a transient update message to the stream that will be associated with a tool call in the UI."""
+ self.dispatch(UpdateAction(content=content))
- def set_as_root(self) -> None:
- """
- Set the dispatcher as the root.
- """
- self._parent_tool_call_id = None
+
+def create_dispatcher_from_config(config: RunnableConfig, node_path: tuple[NodePath, ...]) -> AssistantDispatcher:
+ """Create a dispatcher from a RunnableConfig and node path"""
+ # Set writer from LangGraph context
+ try:
+ writer = get_stream_writer()
+ except RuntimeError:
+ # Not in streaming context (e.g., testing)
+ # Use noop writer
+ def noop(*_args, **_kwargs):
+ pass
+
+ writer = noop
+
+ metadata = config.get("metadata") or {}
+ node_name: str = metadata.get("langgraph_node") or ""
+    # `langgraph_checkpoint_ns` contains the nested path to this node run, so it uniquely identifies the run when keying streamed chunks.
+ node_run_id: str = metadata.get("langgraph_checkpoint_ns") or ""
+
+ return AssistantDispatcher(writer, node_path=node_path, node_run_id=node_run_id, node_name=node_name)
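
A minimal usage sketch for the new factory (the metadata values and `NodePath` entries are illustrative; outside a LangGraph streaming context the factory falls back to the noop writer, so the snippet is safe to run standalone):

    from langchain_core.runnables import RunnableConfig

    from ee.hogai.utils.dispatcher import create_dispatcher_from_config
    from ee.hogai.utils.types.base import NodePath

    config = RunnableConfig(metadata={"langgraph_node": "root", "langgraph_checkpoint_ns": "root:1"})
    node_path = (NodePath(name="assistant_graph"), NodePath(name="root"))
    dispatcher = create_dispatcher_from_config(config, node_path)
    dispatcher.update("Creating trends query")  # emits an UpdateAction attributed to this node run
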
diff --git a/ee/hogai/utils/helpers.py b/ee/hogai/utils/helpers.py
index 98c417b7c9..609217f800 100644
--- a/ee/hogai/utils/helpers.py
+++ b/ee/hogai/utils/helpers.py
@@ -38,8 +38,7 @@ from posthog.hogql_queries.query_runner import ExecutionMode
from posthog.models import Team
from posthog.taxonomy.taxonomy import CORE_FILTER_DEFINITIONS_BY_GROUP
-from ee.hogai.utils.types import AssistantMessageUnion
-from ee.hogai.utils.types.base import AssistantDispatcherEvent
+from ee.hogai.utils.types.base import AssistantDispatcherEvent, AssistantMessageUnion
def remove_line_breaks(line: str) -> str:
diff --git a/ee/hogai/utils/state.py b/ee/hogai/utils/state.py
index 2d94ae6cc0..4af17c1b94 100644
--- a/ee/hogai/utils/state.py
+++ b/ee/hogai/utils/state.py
@@ -5,7 +5,7 @@ from structlog import get_logger
from ee.hogai.graph.deep_research.types import DeepResearchNodeName, PartialDeepResearchState
from ee.hogai.graph.taxonomy.types import TaxonomyAgentState, TaxonomyNodeName
-from ee.hogai.utils.types import PartialAssistantState
+from ee.hogai.utils.types.base import PartialAssistantState
from ee.hogai.utils.types.composed import AssistantMaxGraphState, AssistantMaxPartialGraphState, MaxNodeName
# A state update can have a partial state or a LangGraph's reserved dataclasses like Interrupt.
@@ -45,6 +45,7 @@ def validate_value_update(
class LangGraphState(TypedDict):
langgraph_node: MaxNodeName
+ langgraph_checkpoint_ns: str
GraphMessageUpdateTuple = tuple[Literal["messages"], tuple[Union[AIMessageChunk, Any], LangGraphState]]
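
For illustration, a message-mode update's metadata now carries both keys (concrete values hypothetical); the stream processor uses the checkpoint namespace, not the node name, as its per-run chunk key:

    from ee.hogai.utils.state import LangGraphState
    from ee.hogai.utils.types.base import AssistantNodeName

    state: LangGraphState = {
        "langgraph_node": AssistantNodeName.ROOT,
        "langgraph_checkpoint_ns": "root:1b2c",
    }
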
diff --git a/ee/hogai/utils/stream_processor.py b/ee/hogai/utils/stream_processor.py
index f995092dbe..dba0062bc9 100644
--- a/ee/hogai/utils/stream_processor.py
+++ b/ee/hogai/utils/stream_processor.py
@@ -1,4 +1,4 @@
-from typing import cast
+from typing import Generic, Protocol, TypeVar, cast, get_args
import structlog
from langchain_core.messages import AIMessageChunk
@@ -6,7 +6,6 @@ from langchain_core.messages import AIMessageChunk
from posthog.schema import (
AssistantGenerationStatusEvent,
AssistantGenerationStatusType,
- AssistantMessage,
AssistantToolCallMessage,
AssistantUpdateEvent,
FailureMessage,
@@ -16,21 +15,51 @@ from posthog.schema import (
)
from ee.hogai.utils.helpers import normalize_ai_message, should_output_assistant_message
-from ee.hogai.utils.state import merge_message_chunk
+from ee.hogai.utils.state import is_message_update, is_state_update, merge_message_chunk
from ee.hogai.utils.types.base import (
AssistantDispatcherEvent,
+ AssistantGraphName,
AssistantMessageUnion,
AssistantResultUnion,
+ BaseStateWithMessages,
+ LangGraphUpdateEvent,
MessageAction,
MessageChunkAction,
+ NodeEndAction,
+ NodePath,
NodeStartAction,
+ UpdateAction,
)
from ee.hogai.utils.types.composed import MaxNodeName
logger = structlog.get_logger(__name__)
-class AssistantStreamProcessor:
+class AssistantStreamProcessorProtocol(Protocol):
+ """Protocol defining the interface for assistant stream processors."""
+
+ _streamed_update_ids: set[str]
+ """Tracks the IDs of messages that have been streamed."""
+
+ def process(self, event: AssistantDispatcherEvent) -> list[AssistantResultUnion] | None:
+ """Process a dispatcher event and return a result or None."""
+ ...
+
+ def process_langgraph_update(self, event: LangGraphUpdateEvent) -> list[AssistantResultUnion] | None:
+ """Process a LangGraph update event and return a list of results or None."""
+ ...
+
+ def mark_id_as_streamed(self, message_id: str) -> None:
+ """Mark a message ID as streamed."""
+ self._streamed_update_ids.add(message_id)
+
+
+StateType = TypeVar("StateType", bound=BaseStateWithMessages)
+
+MESSAGE_TYPE_TUPLE = get_args(AssistantMessageUnion)
+
+
+class AssistantStreamProcessor(AssistantStreamProcessorProtocol, Generic[StateType]):
"""
Reduces streamed actions to client-facing messages.
@@ -38,36 +67,34 @@ class AssistantStreamProcessor:
handlers based on action type and message characteristics.
"""
+ _verbose_nodes: set[MaxNodeName]
+ """Nodes that emit messages."""
_streaming_nodes: set[MaxNodeName]
"""Nodes that produce streaming messages."""
- _visualization_nodes: dict[MaxNodeName, type]
- """Nodes that produce visualization messages."""
- _tool_call_id_to_message: dict[str, AssistantMessage]
- """Maps tool call IDs to their parent messages for message chain tracking."""
- _streamed_update_ids: set[str]
- """Tracks the IDs of messages that have been streamed."""
- _chunks: AIMessageChunk
+ _chunks: dict[str, AIMessageChunk]
"""Tracks the current message chunk."""
+ _state: StateType | None
+ """Tracks the current state."""
+ _state_type: type[StateType]
+ """The type of the state."""
- def __init__(
- self,
- streaming_nodes: set[MaxNodeName],
- visualization_nodes: dict[MaxNodeName, type],
- ):
+ def __init__(self, verbose_nodes: set[MaxNodeName], streaming_nodes: set[MaxNodeName], state_type: type[StateType]):
"""
Initialize the stream processor with node configuration.
Args:
+ verbose_nodes: Nodes that produce messages
streaming_nodes: Nodes that produce streaming messages
- visualization_nodes: Nodes that produce visualization messages
"""
+        # If a node is a streaming node, it should also be treated as verbose.
+ self._verbose_nodes = verbose_nodes | streaming_nodes
self._streaming_nodes = streaming_nodes
- self._visualization_nodes = visualization_nodes
- self._tool_call_id_to_message = {}
self._streamed_update_ids = set()
- self._chunks = AIMessageChunk(content="")
+ self._chunks = {}
+ self._state_type = state_type
+ self._state = None
- def process(self, event: AssistantDispatcherEvent) -> AssistantResultUnion | None:
+ def process(self, event: AssistantDispatcherEvent) -> list[AssistantResultUnion] | None:
"""
Reduce streamed actions to client messages.
@@ -75,86 +102,113 @@ class AssistantStreamProcessor:
to specialized handlers based on action type and message characteristics.
"""
action = event.action
- node_name = event.node_name
-
- if isinstance(action, MessageChunkAction):
- return self._handle_message_stream(action.message, cast(MaxNodeName, node_name))
if isinstance(action, NodeStartAction):
- return AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK)
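+            # Start a fresh chunk buffer for this node run so streamed tokens do not leak across runs.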
+ self._chunks[event.node_run_id] = AIMessageChunk(content="")
+ return [AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK)]
+
+ if isinstance(action, NodeEndAction):
+ if event.node_run_id in self._chunks:
+ del self._chunks[event.node_run_id]
+ return self._handle_node_end(event, action)
+
+ if isinstance(action, MessageChunkAction) and (result := self._handle_message_stream(event, action.message)):
+ return [result]
if isinstance(action, MessageAction):
message = action.message
+ if result := self._handle_message(event, message):
+ return [result]
- # Register any tool calls for later parent chain lookups
- self._register_tool_calls(message)
- result = self._handle_message(message, cast(MaxNodeName, node_name))
- return (
- result if result is not None else AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK)
+ if isinstance(action, UpdateAction) and (update_event := self._handle_update_message(event, action)):
+ return [update_event]
+
+ return None
+
+ def process_langgraph_update(self, event: LangGraphUpdateEvent) -> list[AssistantResultUnion] | None:
+ """
+ Process a LangGraph update event.
+ """
+ if is_message_update(event.update):
+ # Convert the message chunk update to a dispatcher event to prepare for a bright future without LangGraph
+ maybe_message_chunk, state = event.update[1]
+ if not isinstance(maybe_message_chunk, AIMessageChunk):
+ return None
+ action = AssistantDispatcherEvent(
+ action=MessageChunkAction(message=maybe_message_chunk),
+ node_name=state["langgraph_node"],
+ node_run_id=state["langgraph_checkpoint_ns"],
)
+ return self.process(action)
- def _find_parent_ids(self, message: AssistantMessage) -> tuple[str | None, str | None]:
- """
- Walk up the message chain to find the root parent's message_id and tool_call_id.
+ if is_state_update(event.update):
+ new_state = self._state_type.model_validate(event.update[1])
+ self._state = new_state
- Returns (root_message_id, root_tool_call_id) for the root message in the chain.
- Includes cycle detection and max depth protection.
- """
- root_tool_call_id = message.parent_tool_call_id
- if root_tool_call_id is None:
- return message.id, None
+ return None
- root_message_id = None
- visited: set[str] = set()
+ def _handle_message(
+ self, action: AssistantDispatcherEvent, message: AssistantMessageUnion
+ ) -> AssistantResultUnion | None:
+ """Handle a message from a node."""
+ node_name = cast(MaxNodeName, action.node_name)
+ produced_message: AssistantResultUnion | None = None
- while root_tool_call_id is not None:
- if root_tool_call_id in visited:
- # Cycle detected, we skip this message
- return None, None
+ # Output all messages from the top-level graph.
+ if not self._is_message_from_nested_node_or_graph(action.node_path or ()):
+ produced_message = self._handle_root_message(message, node_name)
+        else:
+            # Messages from nested nodes or graphs (viz, notebook, failure, tool call)
+            produced_message = self._handle_special_child_message(message, node_name)
- visited.add(root_tool_call_id)
- parent_message = self._tool_call_id_to_message.get(root_tool_call_id)
- if parent_message is None:
- # The parent message is not registered, we skip this message as it could come
- # from a sub-nested graph invoked directly by a contextual tool.
- return None, None
+ # Messages with existing IDs must be deduplicated.
+ # Messages WITHOUT IDs must be streamed because they're progressive.
+ if isinstance(produced_message, MESSAGE_TYPE_TUPLE) and produced_message.id is not None:
+ if produced_message.id in self._streamed_update_ids:
+ return None
+ self._streamed_update_ids.add(produced_message.id)
- next_parent_tool_call_id = parent_message.parent_tool_call_id
- root_message_id = parent_message.id
- if next_parent_tool_call_id is None:
- return root_message_id, root_tool_call_id
- root_tool_call_id = next_parent_tool_call_id
- raise ValueError("Should not reach here")
+ return produced_message
- def _register_tool_calls(self, message: AssistantMessageUnion) -> None:
- """Register any tool calls in the message for later lookup."""
- if isinstance(message, AssistantMessage) and message.tool_calls is not None:
- for tool_call in message.tool_calls:
- self._tool_call_id_to_message[tool_call.id] = message
+ def _is_message_from_nested_node_or_graph(self, node_path: tuple[NodePath, ...]) -> bool:
+ """Check if the message is from a nested node or graph."""
+ if not node_path:
+ return False
+ # The first path is always the top-level graph.
+ # The second path is always the top-level node.
+        # If the path is longer than 2, it's a MaxTool or a nested graph.
+ if len(node_path) > 2 and next((path for path in node_path[1:] if path.name in AssistantGraphName), None):
+ return True
+ return False
def _handle_root_message(
self, message: AssistantMessageUnion, node_name: MaxNodeName
) -> AssistantMessageUnion | None:
"""Handle messages with no parent (root messages)."""
- if not should_output_assistant_message(message):
+ if node_name not in self._verbose_nodes or not should_output_assistant_message(message):
return None
return message
- def _handle_assistant_message_with_parent(self, message: AssistantMessage) -> AssistantUpdateEvent | None:
+ def _handle_update_message(
+ self, event: AssistantDispatcherEvent, action: UpdateAction
+ ) -> AssistantUpdateEvent | None:
"""Handle AssistantMessage that has a parent, creating an AssistantUpdateEvent."""
- parent_id, parent_tool_call_id = self._find_parent_ids(message)
-
- if parent_tool_call_id is None or parent_id is None:
+ if not event.node_path or not action.content:
return None
- if message.content == "":
+ # Find the closest tool call id to the update.
+ parent_path = next((path for path in reversed(event.node_path) if path.tool_call_id), None)
+ # Updates from the top-level graph nodes are not supported.
+ if not parent_path:
return None
- return AssistantUpdateEvent(
- id=parent_id,
- tool_call_id=parent_tool_call_id,
- content=message.content,
- )
+ tool_call_id = parent_path.tool_call_id
+ message_id = parent_path.message_id
+
+ if not message_id or not tool_call_id:
+ return None
+
+ return AssistantUpdateEvent(id=message_id, tool_call_id=tool_call_id, content=action.content)
def _handle_special_child_message(
self, message: AssistantMessageUnion, node_name: MaxNodeName
@@ -164,14 +218,10 @@ class AssistantStreamProcessor:
These messages are returned as-is regardless of where in the nesting hierarchy they are.
"""
- # Return visualization messages only if from visualization nodes
- if isinstance(message, VisualizationMessage | MultiVisualizationMessage):
- if node_name in self._visualization_nodes:
- return message
- return None
-
# These message types are always returned as-is
- if isinstance(message, NotebookUpdateMessage | FailureMessage):
+ if isinstance(message, VisualizationMessage | MultiVisualizationMessage | NotebookUpdateMessage | FailureMessage):
return message
if isinstance(message, AssistantToolCallMessage):
@@ -181,37 +231,44 @@ class AssistantStreamProcessor:
# Should not reach here
raise ValueError(f"Unhandled special message type: {type(message).__name__}")
- def _handle_message(self, message: AssistantMessageUnion, node_name: MaxNodeName) -> AssistantResultUnion | None:
- # Messages with existing IDs must be deduplicated.
- # Messages WITHOUT IDs must be streamed because they're progressive.
- if hasattr(message, "id") and message.id is not None:
- if message.id in self._streamed_update_ids:
- return None
- self._streamed_update_ids.add(message.id)
-
- # Root messages (no parent) are filtered by VERBOSE_NODES
- if message.parent_tool_call_id is None:
- return self._handle_root_message(message, node_name)
-
- # AssistantMessage with parent creates AssistantUpdateEvent
- if isinstance(message, AssistantMessage):
- return self._handle_assistant_message_with_parent(message)
- else:
- # Other message types with parents (viz, notebook, failure, tool call)
- return self._handle_special_child_message(message, node_name)
-
- def _handle_message_stream(self, message: AIMessageChunk, node_name: MaxNodeName) -> AssistantResultUnion | None:
+ def _handle_message_stream(
+ self, event: AssistantDispatcherEvent, message: AIMessageChunk
+ ) -> AssistantResultUnion | None:
"""
Process LLM chunks from "messages" stream mode.
- With dispatch pattern, complete messages are dispatched by nodes.
+ With the dispatch pattern, complete messages are dispatched by the nodes themselves.
This handles AIMessageChunk for ephemeral streaming (responsiveness).
"""
+ node_name = cast(MaxNodeName, event.node_name)
+ run_id = event.node_run_id
+
if node_name not in self._streaming_nodes:
return None
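+ # Chunks are buffered per node run id so concurrent runs don't interleave partial content.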
+ if run_id not in self._chunks:
+ self._chunks[run_id] = AIMessageChunk(content="")
# Merge message chunks
- self._chunks = merge_message_chunk(self._chunks, message)
+ self._chunks[run_id] = merge_message_chunk(self._chunks[run_id], message)
# Stream ephemeral message (no ID = not persisted)
- return normalize_ai_message(self._chunks)
+ return normalize_ai_message(self._chunks[run_id])
+
+ def _handle_node_end(
+ self, event: AssistantDispatcherEvent, action: NodeEndAction
+ ) -> list[AssistantResultUnion] | None:
+ """Handle the end of a node. Reset the streaming chunks."""
+ if not isinstance(action.state, BaseStateWithMessages):
+ return None
+ results: list[AssistantResultUnion] = []
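+ # Re-dispatch each final-state message through process() so routing and dedup apply uniformly.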
+ for message in action.state.messages:
+ if new_event := self.process(
+ AssistantDispatcherEvent(
+ action=MessageAction(message=message),
+ node_name=event.node_name,
+ node_run_id=event.node_run_id,
+ node_path=event.node_path,
+ )
+ ):
+ results.extend(new_event)
+ return results
diff --git a/ee/hogai/utils/test/test_dispatcher.py b/ee/hogai/utils/test/test_dispatcher.py
index 4152edbe08..b70088df60 100644
--- a/ee/hogai/utils/test/test_dispatcher.py
+++ b/ee/hogai/utils/test/test_dispatcher.py
@@ -1,165 +1,426 @@
+"""
+Comprehensive tests for AssistantDispatcher.
+
+Tests the dispatcher logic that emits actions to the LangGraph custom stream.
+"""
+
from posthog.test.base import BaseTest
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import MagicMock, patch
from langchain_core.runnables import RunnableConfig
from posthog.schema import AssistantMessage, AssistantToolCall
-from ee.hogai.graph.base import BaseAssistantNode
-from ee.hogai.utils.dispatcher import AssistantDispatcher
-from ee.hogai.utils.types import AssistantState, PartialAssistantState
-from ee.hogai.utils.types.base import AssistantDispatcherEvent, AssistantNodeName
+from ee.hogai.utils.dispatcher import AssistantDispatcher, create_dispatcher_from_config
+from ee.hogai.utils.types.base import (
+ AssistantDispatcherEvent,
+ AssistantGraphName,
+ AssistantNodeName,
+ MessageAction,
+ NodePath,
+ UpdateAction,
+)
-class TestMessageDispatcher(BaseTest):
+class TestAssistantDispatcher(BaseTest):
+ """Test the AssistantDispatcher in isolation."""
+
def setUp(self):
- self._writer = MagicMock()
- self._writer.__aiter__ = AsyncMock(return_value=iter([]))
- self._writer.write = AsyncMock()
+ super().setUp()
+ self.dispatched_events: list[AssistantDispatcherEvent] = []
- def test_create_dispatcher(self):
- """Test creating a MessageDispatcher"""
- dispatcher = AssistantDispatcher(self._writer, node_name=AssistantNodeName.ROOT)
+ def mock_writer(event):
+ self.dispatched_events.append(event)
+
+ self.writer = mock_writer
+ self.node_path = (
+ NodePath(name=AssistantGraphName.ASSISTANT),
+ NodePath(name=AssistantNodeName.ROOT),
+ )
+
+ def test_create_dispatcher_with_basic_params(self):
+ """Test creating a dispatcher with required parameters."""
+ dispatcher = AssistantDispatcher(
+ writer=self.writer,
+ node_path=self.node_path,
+ node_name=AssistantNodeName.ROOT,
+ node_run_id="test_run_123",
+ )
self.assertEqual(dispatcher._node_name, AssistantNodeName.ROOT)
- self.assertEqual(dispatcher._writer, self._writer)
+ self.assertEqual(dispatcher._node_run_id, "test_run_123")
+ self.assertEqual(dispatcher._node_path, self.node_path)
+ self.assertEqual(dispatcher._writer, self.writer)
- async def test_dispatch_with_writer(self):
- """Test dispatching with a writer"""
- dispatched_actions = []
+ def test_dispatch_message_action(self):
+ """Test dispatching a message via the message() method."""
+ dispatcher = AssistantDispatcher(
+ writer=self.writer,
+ node_path=self.node_path,
+ node_name=AssistantNodeName.ROOT,
+ node_run_id="test_run_456",
+ )
- def mock_writer(data):
- dispatched_actions.append(data)
-
- dispatcher = AssistantDispatcher(mock_writer, node_name=AssistantNodeName.ROOT)
-
- message = AssistantMessage(content="test message", parent_tool_call_id="tc1")
+ message = AssistantMessage(content="Test message")
dispatcher.message(message)
- self.assertEqual(len(dispatched_actions), 1)
+ self.assertEqual(len(self.dispatched_events), 1)
+ event = self.dispatched_events[0]
- # Verify action structure
- event = dispatched_actions[0]
self.assertIsInstance(event, AssistantDispatcherEvent)
- self.assertEqual(event.action.type, "MESSAGE")
+ self.assertIsInstance(event.action, MessageAction)
+ assert isinstance(event.action, MessageAction)
self.assertEqual(event.action.message, message)
self.assertEqual(event.node_name, AssistantNodeName.ROOT)
+ self.assertEqual(event.node_run_id, "test_run_456")
+ self.assertEqual(event.node_path, self.node_path)
+
+ def test_dispatch_update_action(self):
+ """Test dispatching an update via the update() method."""
+ dispatcher = AssistantDispatcher(
+ writer=self.writer,
+ node_path=self.node_path,
+ node_name=AssistantNodeName.TRENDS_GENERATOR,
+ node_run_id="test_run_789",
+ )
+
+ dispatcher.update("Processing query...")
+
+ self.assertEqual(len(self.dispatched_events), 1)
+ event = self.dispatched_events[0]
+
+ self.assertIsInstance(event, AssistantDispatcherEvent)
+ self.assertIsInstance(event.action, UpdateAction)
+ assert isinstance(event.action, UpdateAction)
+ self.assertEqual(event.action.content, "Processing query...")
+ self.assertEqual(event.node_name, AssistantNodeName.TRENDS_GENERATOR)
+ self.assertEqual(event.node_run_id, "test_run_789")
+
+ def test_dispatch_multiple_messages(self):
+ """Test dispatching multiple messages in sequence."""
+ dispatcher = AssistantDispatcher(
+ writer=self.writer,
+ node_path=self.node_path,
+ node_name=AssistantNodeName.ROOT,
+ node_run_id="test_run_multi",
+ )
+
+ message1 = AssistantMessage(content="First message")
+ message2 = AssistantMessage(content="Second message")
+ message3 = AssistantMessage(content="Third message")
+
+ dispatcher.message(message1)
+ dispatcher.message(message2)
+ dispatcher.message(message3)
+
+ self.assertEqual(len(self.dispatched_events), 3)
+
+ contents = []
+ for event in self.dispatched_events:
+ if isinstance(event.action, MessageAction):
+ msg = event.action.message
+ if isinstance(msg, AssistantMessage):
+ contents.append(msg.content)
+ self.assertEqual(contents, ["First message", "Second message", "Third message"])
+
+ def test_dispatch_message_with_tool_calls(self):
+ """Test dispatching a message with tool calls preserves all data."""
+ dispatcher = AssistantDispatcher(
+ writer=self.writer,
+ node_path=self.node_path,
+ node_name=AssistantNodeName.ROOT,
+ node_run_id="test_run_tools",
+ )
+
+ tool_call = AssistantToolCall(id="tool_123", name="search", args={"query": "test query"})
+ message = AssistantMessage(content="Running search...", tool_calls=[tool_call])
+
+ dispatcher.message(message)
+
+ self.assertEqual(len(self.dispatched_events), 1)
+ event = self.dispatched_events[0]
+
+ self.assertIsInstance(event.action, MessageAction)
+ assert isinstance(event.action, MessageAction)
+ dispatched_message = event.action.message
+ assert isinstance(dispatched_message, AssistantMessage)
+ self.assertEqual(dispatched_message.content, "Running search...")
+ self.assertIsNotNone(dispatched_message.tool_calls)
+ assert dispatched_message.tool_calls is not None
+ self.assertEqual(len(dispatched_message.tool_calls), 1)
+ self.assertEqual(dispatched_message.tool_calls[0].id, "tool_123")
+ self.assertEqual(dispatched_message.tool_calls[0].name, "search")
+ self.assertEqual(dispatched_message.tool_calls[0].args, {"query": "test query"})
+
+ def test_dispatch_with_nested_node_path(self):
+ """Test that nested node paths are preserved in dispatched events."""
+ nested_path = (
+ NodePath(name=AssistantGraphName.ASSISTANT),
+ NodePath(name=AssistantNodeName.ROOT, message_id="msg_123", tool_call_id="tc_123"),
+ NodePath(name=AssistantGraphName.INSIGHTS),
+ NodePath(name=AssistantNodeName.TRENDS_GENERATOR),
+ )
+
+ dispatcher = AssistantDispatcher(
+ writer=self.writer, node_path=nested_path, node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id="run_1"
+ )
+
+ message = AssistantMessage(content="Nested message")
+ dispatcher.message(message)
+
+ self.assertEqual(len(self.dispatched_events), 1)
+ event = self.dispatched_events[0]
+
+ self.assertEqual(event.node_path, nested_path)
+ assert event.node_path is not None
+ self.assertEqual(len(event.node_path), 4)
+ self.assertEqual(event.node_path[1].message_id, "msg_123")
+ self.assertEqual(event.node_path[1].tool_call_id, "tc_123")
+
+ def test_dispatch_error_handling_continues_execution(self):
+ """Test that dispatch errors are caught and logged but don't crash."""
+
+ def failing_writer(event):
+ raise RuntimeError("Writer failed!")
+
+ dispatcher = AssistantDispatcher(
+ writer=failing_writer,
+ node_path=self.node_path,
+ node_name=AssistantNodeName.ROOT,
+ node_run_id="test_run_error",
+ )
+
+ message = AssistantMessage(content="This should not crash")
+
+ # Should not raise exception
+ with patch("logging.getLogger") as mock_get_logger:
+ mock_logger = MagicMock()
+ mock_get_logger.return_value = mock_logger
+
+ dispatcher.message(message)
+
+ # Verify error was logged
+ mock_logger.error.assert_called_once()
+ args, kwargs = mock_logger.error.call_args
+ self.assertIn("Failed to dispatch action", args[0])
+
+ def test_dispatch_mixed_actions(self):
+ """Test dispatching both messages and updates in sequence."""
+ dispatcher = AssistantDispatcher(
+ writer=self.writer,
+ node_path=self.node_path,
+ node_name=AssistantNodeName.TRENDS_GENERATOR,
+ node_run_id="test_run_mixed",
+ )
+
+ dispatcher.update("Starting analysis...")
+ dispatcher.message(AssistantMessage(content="Found 3 insights"))
+ dispatcher.update("Finalizing results...")
+
+ self.assertEqual(len(self.dispatched_events), 3)
+
+ self.assertIsInstance(self.dispatched_events[0].action, UpdateAction)
+ assert isinstance(self.dispatched_events[0].action, UpdateAction)
+ self.assertEqual(self.dispatched_events[0].action.content, "Starting analysis...")
+
+ self.assertIsInstance(self.dispatched_events[1].action, MessageAction)
+ assert isinstance(self.dispatched_events[1].action, MessageAction)
+ assert isinstance(self.dispatched_events[1].action.message, AssistantMessage)
+ self.assertEqual(self.dispatched_events[1].action.message.content, "Found 3 insights")
+
+ self.assertIsInstance(self.dispatched_events[2].action, UpdateAction)
+ assert isinstance(self.dispatched_events[2].action, UpdateAction)
+ self.assertEqual(self.dispatched_events[2].action.content, "Finalizing results...")
+
+ def test_dispatch_preserves_message_id(self):
+ """Test that message IDs are preserved through dispatch."""
+ dispatcher = AssistantDispatcher(
+ writer=self.writer,
+ node_path=self.node_path,
+ node_name=AssistantNodeName.ROOT,
+ node_run_id="test_run_id_preservation",
+ )
+
+ message = AssistantMessage(id="msg_xyz_789", content="Message with ID")
+ dispatcher.message(message)
+
+ event = self.dispatched_events[0]
+ assert isinstance(event.action, MessageAction)
+ self.assertEqual(event.action.message.id, "msg_xyz_789")
+
+ def test_dispatch_with_empty_node_path(self):
+ """Test dispatcher with an empty node path."""
+ dispatcher = AssistantDispatcher(
+ writer=self.writer, node_path=(), node_name=AssistantNodeName.ROOT, node_run_id="test_run_empty_path"
+ )
+
+ message = AssistantMessage(content="Root level message")
+ dispatcher.message(message)
+
+ self.assertEqual(len(self.dispatched_events), 1)
+ event = self.dispatched_events[0]
+ self.assertEqual(event.node_path, ())
-class MockNode(BaseAssistantNode[AssistantState, PartialAssistantState]):
- """Mock node for testing"""
+class TestCreateDispatcherFromConfig(BaseTest):
+ """Test the create_dispatcher_from_config helper function."""
- @property
- def node_name(self):
- return AssistantNodeName.ROOT
+ def test_create_dispatcher_from_config_with_stream_writer(self):
+ """Test creating dispatcher from config with LangGraph stream writer."""
+ node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT))
- async def arun(self, state, config):
- # Use dispatch to add a message
- self.dispatcher.message(AssistantMessage(content="Test message from node"))
- return PartialAssistantState()
+ config = RunnableConfig(
+ metadata={"langgraph_node": AssistantNodeName.TRENDS_GENERATOR, "langgraph_checkpoint_ns": "checkpoint_abc"}
+ )
+
+ mock_writer = MagicMock()
+
+ with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=mock_writer):
+ dispatcher = create_dispatcher_from_config(config, node_path)
+
+ self.assertEqual(dispatcher._node_name, AssistantNodeName.TRENDS_GENERATOR)
+ self.assertEqual(dispatcher._node_run_id, "checkpoint_abc")
+ self.assertEqual(dispatcher._node_path, node_path)
+ self.assertEqual(dispatcher._writer, mock_writer)
+
+ def test_create_dispatcher_from_config_without_stream_writer(self):
+ """Test creating dispatcher when not in streaming context (e.g., testing)."""
+ node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT))
+
+ config = RunnableConfig(
+ metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "checkpoint_xyz"}
+ )
+
+ # Simulate RuntimeError when get_stream_writer is called outside streaming context
+ with patch("ee.hogai.utils.dispatcher.get_stream_writer", side_effect=RuntimeError("Not in streaming context")):
+ dispatcher = create_dispatcher_from_config(config, node_path)
+
+ # Should create a noop writer
+ self.assertEqual(dispatcher._node_name, AssistantNodeName.ROOT)
+ self.assertEqual(dispatcher._node_run_id, "checkpoint_xyz")
+ self.assertEqual(dispatcher._node_path, node_path)
+
+ # Verify the noop writer doesn't raise exceptions
+ message = AssistantMessage(content="Test")
+ dispatcher.message(message) # Should not crash
+
+ def test_create_dispatcher_with_missing_metadata(self):
+ """Test creating dispatcher when metadata fields are missing."""
+ node_path = (NodePath(name=AssistantGraphName.ASSISTANT),)
+
+ config = RunnableConfig(metadata={})
+
+ with patch("ee.hogai.utils.dispatcher.get_stream_writer", side_effect=RuntimeError("Not in streaming context")):
+ dispatcher = create_dispatcher_from_config(config, node_path)
+
+ # Should use empty strings as defaults
+ self.assertEqual(dispatcher._node_name, "")
+ self.assertEqual(dispatcher._node_run_id, "")
+ self.assertEqual(dispatcher._node_path, node_path)
+
+ def test_create_dispatcher_preserves_node_path(self):
+ """Test that node path is correctly passed through to the dispatcher."""
+ nested_path = (
+ NodePath(name=AssistantGraphName.ASSISTANT),
+ NodePath(name=AssistantNodeName.ROOT, message_id="msg_1", tool_call_id="tc_1"),
+ NodePath(name=AssistantGraphName.INSIGHTS),
+ )
+
+ config = RunnableConfig(
+ metadata={"langgraph_node": AssistantNodeName.TRENDS_GENERATOR, "langgraph_checkpoint_ns": "cp_1"}
+ )
+
+ with patch("ee.hogai.utils.dispatcher.get_stream_writer", side_effect=RuntimeError("Not in streaming context")):
+ dispatcher = create_dispatcher_from_config(config, nested_path)
+
+ self.assertEqual(dispatcher._node_path, nested_path)
+ self.assertEqual(len(dispatcher._node_path), 3)
class TestDispatcherIntegration(BaseTest):
- async def test_node_dispatch_flow(self):
- """Test that a node can dispatch messages"""
- # Use sync test helpers since BaseTest provides them
- team, user = self.team, self.user
+ """Integration tests for dispatcher usage patterns."""
- node = MockNode(team=team, user=user)
+ def test_dispatcher_in_node_context(self):
+ """Test typical usage pattern within a node."""
+ dispatched_events = []
- # Track dispatched actions
- dispatched_actions = []
+ def mock_writer(event):
+ dispatched_events.append(event)
- def mock_writer(data):
- dispatched_actions.append(data)
-
- # Set up node's dispatcher
- node._dispatcher = AssistantDispatcher(mock_writer, node_name=AssistantNodeName.ROOT)
-
- # Run the node
- state = AssistantState(messages=[])
- config = RunnableConfig(configurable={})
-
- await node.arun(state, config)
-
- # Verify action was dispatched
- self.assertEqual(len(dispatched_actions), 1)
-
- event = dispatched_actions[0]
- self.assertIsInstance(event, AssistantDispatcherEvent)
- self.assertEqual(event.action.type, "MESSAGE")
- self.assertEqual(event.action.message.content, "Test message from node")
- self.assertEqual(event.action.message.parent_tool_call_id, None)
- self.assertEqual(event.node_name, AssistantNodeName.ROOT)
-
- async def test_action_preservation_through_stream(self):
- """Test that action data is preserved through the stream"""
-
- captured_updates = []
-
- def mock_writer(data):
- captured_updates.append(data)
-
- dispatcher = AssistantDispatcher(mock_writer, node_name=AssistantNodeName.ROOT)
-
- # Create complex message with metadata
- message = AssistantMessage(
- content="Complex message",
- parent_tool_call_id="tc123",
- tool_calls=[AssistantToolCall(id="tool1", name="search", args={"query": "test"})],
+ node_path = (
+ NodePath(name=AssistantGraphName.ASSISTANT),
+ NodePath(name=AssistantNodeName.ROOT),
)
- dispatcher.message(message)
+ dispatcher = AssistantDispatcher(
+ writer=mock_writer, node_path=node_path, node_name=AssistantNodeName.ROOT, node_run_id="integration_run_1"
+ )
- # Extract action
- event = captured_updates[0]
- self.assertIsInstance(event, AssistantDispatcherEvent)
+ # Simulate node execution pattern
+ dispatcher.update("Starting node execution...")
- # Verify all message fields preserved
- payload = event.action.message
- self.assertEqual(payload.content, "Complex message")
- self.assertEqual(payload.parent_tool_call_id, "tc123")
- self.assertIsNotNone(payload.tool_calls)
- self.assertEqual(len(payload.tool_calls), 1)
- self.assertEqual(payload.tool_calls[0].name, "search")
+ tool_call = AssistantToolCall(id="tc_int_1", name="generate_insight", args={"type": "trends"})
+ dispatcher.message(AssistantMessage(content="Generating insight", tool_calls=[tool_call]))
- async def test_multiple_dispatches_from_node(self):
- """Test that a node can dispatch multiple messages"""
- # Use sync test helpers since BaseTest provides them
- team, user = self.team, self.user
+ dispatcher.update("Processing data...")
- class MultiDispatchNode(BaseAssistantNode[AssistantState, PartialAssistantState]):
- @property
- def node_name(self):
- return AssistantNodeName.ROOT
+ dispatcher.message(AssistantMessage(content="Insight generated successfully"))
- async def arun(self, state, config):
- # Dispatch multiple messages
- self.dispatcher.message(AssistantMessage(content="First message"))
- self.dispatcher.message(AssistantMessage(content="Second message"))
- self.dispatcher.message(AssistantMessage(content="Third message"))
- return PartialAssistantState()
+ # Verify all events were dispatched
+ self.assertEqual(len(dispatched_events), 4)
- node = MultiDispatchNode(team=team, user=user)
+ # Verify event types and order
+ self.assertIsInstance(dispatched_events[0].action, UpdateAction)
+ self.assertIsInstance(dispatched_events[1].action, MessageAction)
+ self.assertIsInstance(dispatched_events[2].action, UpdateAction)
+ self.assertIsInstance(dispatched_events[3].action, MessageAction)
- dispatched_actions = []
+ # Verify all events have consistent metadata
+ for event in dispatched_events:
+ self.assertEqual(event.node_name, AssistantNodeName.ROOT)
+ self.assertEqual(event.node_run_id, "integration_run_1")
+ self.assertEqual(event.node_path, node_path)
- def mock_writer(data):
- dispatched_actions.append(data)
+ def test_concurrent_dispatchers(self):
+ """Test multiple dispatchers can coexist without interference."""
+ dispatched_events_1 = []
+ dispatched_events_2 = []
- node._dispatcher = AssistantDispatcher(mock_writer, node_name=AssistantNodeName.ROOT)
+ def writer_1(event):
+ dispatched_events_1.append(event)
- state = AssistantState(messages=[])
- config = RunnableConfig(configurable={})
+ def writer_2(event):
+ dispatched_events_2.append(event)
- await node.arun(state, config)
+ path_1 = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT))
+ path_2 = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.TRENDS_GENERATOR))
- # Verify all three dispatches
- self.assertEqual(len(dispatched_actions), 3)
+ dispatcher_1 = AssistantDispatcher(
+ writer=writer_1, node_path=path_1, node_name=AssistantNodeName.ROOT, node_run_id="run_1"
+ )
- contents = []
- for event in dispatched_actions:
- self.assertIsInstance(event, AssistantDispatcherEvent)
- contents.append(event.action.message.content)
+ dispatcher_2 = AssistantDispatcher(
+ writer=writer_2, node_path=path_2, node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id="run_2"
+ )
- self.assertEqual(contents, ["First message", "Second message", "Third message"])
+ # Dispatch from both
+ dispatcher_1.message(AssistantMessage(content="From dispatcher 1"))
+ dispatcher_2.message(AssistantMessage(content="From dispatcher 2"))
+ dispatcher_1.update("Update from dispatcher 1")
+ dispatcher_2.update("Update from dispatcher 2")
+
+ # Verify each dispatcher wrote to its own writer
+ self.assertEqual(len(dispatched_events_1), 2)
+ self.assertEqual(len(dispatched_events_2), 2)
+
+ # Verify events went to correct writers
+ assert isinstance(dispatched_events_1[0].action, MessageAction)
+ assert isinstance(dispatched_events_1[0].action.message, AssistantMessage)
+ self.assertEqual(dispatched_events_1[0].action.message.content, "From dispatcher 1")
+ assert isinstance(dispatched_events_2[0].action, MessageAction)
+ assert isinstance(dispatched_events_2[0].action.message, AssistantMessage)
+ self.assertEqual(dispatched_events_2[0].action.message.content, "From dispatcher 2")
+
+ # Verify node names are correct
+ self.assertEqual(dispatched_events_1[0].node_name, AssistantNodeName.ROOT)
+ self.assertEqual(dispatched_events_2[0].node_name, AssistantNodeName.TRENDS_GENERATOR)
diff --git a/ee/hogai/utils/test/test_stream_processor.py b/ee/hogai/utils/test/test_stream_processor.py
index 25a434e63e..f6dca2e50a 100644
--- a/ee/hogai/utils/test/test_stream_processor.py
+++ b/ee/hogai/utils/test/test_stream_processor.py
@@ -1,15 +1,14 @@
"""
-Comprehensive tests for AssistantMessageReducer.
+Comprehensive tests for AssistantStreamProcessor.
-Tests the reducer logic that processes dispatcher actions
-and routes messages appropriately.
+Tests the stream processor logic that handles dispatcher actions,
+routes messages based on node paths, and manages streaming state.
"""
from typing import cast
from uuid import uuid4
from posthog.test.base import BaseTest
-from unittest.mock import MagicMock
from langchain_core.messages import AIMessageChunk
@@ -17,7 +16,6 @@ from posthog.schema import (
AssistantGenerationStatusEvent,
AssistantGenerationStatusType,
AssistantMessage,
- AssistantToolCall,
AssistantToolCallMessage,
AssistantUpdateEvent,
FailureMessage,
@@ -29,13 +27,20 @@ from posthog.schema import (
VisualizationMessage,
)
+from ee.hogai.utils.state import GraphValueUpdateTuple
from ee.hogai.utils.stream_processor import AssistantStreamProcessor
from ee.hogai.utils.types.base import (
AssistantDispatcherEvent,
+ AssistantGraphName,
AssistantNodeName,
+ AssistantState,
+ LangGraphUpdateEvent,
MessageAction,
MessageChunkAction,
+ NodeEndAction,
+ NodePath,
NodeStartAction,
+ UpdateAction,
)
@@ -44,406 +49,581 @@ class TestStreamProcessor(BaseTest):
def setUp(self):
super().setUp()
- # Create a reducer with test node configuration
self.stream_processor = AssistantStreamProcessor(
+ verbose_nodes={AssistantNodeName.ROOT, AssistantNodeName.TRENDS_GENERATOR},
streaming_nodes={AssistantNodeName.TRENDS_GENERATOR},
- visualization_nodes={AssistantNodeName.TRENDS_GENERATOR: MagicMock()},
+ state_type=AssistantState,
)
def _create_dispatcher_event(
self,
- action: MessageAction | NodeStartAction | MessageChunkAction,
+ action: MessageAction | NodeStartAction | MessageChunkAction | NodeEndAction | UpdateAction,
node_name: AssistantNodeName = AssistantNodeName.ROOT,
+ node_run_id: str = "test_run_id",
+ node_path: tuple[NodePath, ...] | None = None,
) -> AssistantDispatcherEvent:
"""Helper to create a dispatcher event for testing."""
- return AssistantDispatcherEvent(action=action, node_name=node_name)
+ return AssistantDispatcherEvent(
+ action=action, node_name=node_name, node_run_id=node_run_id, node_path=node_path
+ )
- def test_node_start_action_returns_ack(self):
- """Test NODE_START action returns ACK status event."""
- event = self._create_dispatcher_event(NodeStartAction())
+ # Node lifecycle tests
+
+ def test_node_start_initializes_chunk_and_returns_ack(self):
+ """Test NodeStartAction initializes a chunk for the run_id and returns ACK."""
+ run_id = "test_run_123"
+ event = self._create_dispatcher_event(NodeStartAction(), node_run_id=run_id)
result = self.stream_processor.process(event)
self.assertIsNotNone(result)
- result = cast(AssistantGenerationStatusEvent, result)
- self.assertEqual(result.type, AssistantGenerationStatusType.ACK)
+ assert result is not None
+ self.assertEqual(len(result), 1)
+ first_result = result[0]
+ self.assertIsInstance(first_result, AssistantGenerationStatusEvent)
+ assert isinstance(first_result, AssistantGenerationStatusEvent)
+ self.assertEqual(first_result.type, AssistantGenerationStatusType.ACK)
+ self.assertIn(run_id, self.stream_processor._chunks)
+ self.assertEqual(self.stream_processor._chunks[run_id].content, "")
- def test_message_with_tool_calls_stores_in_registry(self):
- """Test AssistantMessage with tool_calls is stored in _tool_call_id_to_message."""
- tool_call_id = str(uuid4())
- message = AssistantMessage(
- content="Test",
- tool_calls=[AssistantToolCall(id=tool_call_id, name="test_tool", args={})],
- )
+ def test_node_end_cleans_up_chunk(self):
+ """Test NodeEndAction removes the chunk for the run_id."""
+ run_id = "test_run_456"
+ self.stream_processor._chunks[run_id] = AIMessageChunk(content="test")
- event = self._create_dispatcher_event(MessageAction(message=message))
+ state = AssistantState(messages=[])
+ event = self._create_dispatcher_event(NodeEndAction(state=state), node_run_id=run_id)
self.stream_processor.process(event)
- # Should be stored in registry
- self.assertIn(tool_call_id, self.stream_processor._tool_call_id_to_message)
- self.assertEqual(self.stream_processor._tool_call_id_to_message[tool_call_id], message)
+ self.assertNotIn(run_id, self.stream_processor._chunks)
- def test_assistant_message_with_parent_creates_assistant_update_event(self):
- """Test AssistantMessage with parent_tool_call_id creates AssistantUpdateEvent."""
- # First, register a parent message
- parent_tool_call_id = str(uuid4())
- parent_message = AssistantMessage(
- id=str(uuid4()),
- content="Parent",
- tool_calls=[AssistantToolCall(id=parent_tool_call_id, name="test", args={})],
- parent_tool_call_id=None,
+ def test_node_end_processes_messages_from_state(self):
+ """Test NodeEndAction processes all messages from the final state."""
+ run_id = "test_run_789"
+ message1 = AssistantMessage(id=str(uuid4()), content="Message 1")
+ message2 = AssistantMessage(id=str(uuid4()), content="Message 2")
+ state = AssistantState(messages=[message1, message2])
+
+ event = self._create_dispatcher_event(
+ NodeEndAction(state=state), node_name=AssistantNodeName.ROOT, node_run_id=run_id
)
- self.stream_processor._tool_call_id_to_message[parent_tool_call_id] = parent_message
+ results = self.stream_processor.process(event)
- # Now send a child message
- child_message = AssistantMessage(content="Child content", parent_tool_call_id=parent_tool_call_id)
- event = self._create_dispatcher_event(MessageAction(message=child_message))
+ self.assertIsNotNone(results)
+ assert results is not None
+ self.assertEqual(len(results), 2)
+ self.assertEqual(results[0], message1)
+ self.assertEqual(results[1], message2)
+
+ # Message streaming tests
+
+ def test_message_chunk_streaming_for_streaming_nodes(self):
+ """Test MessageChunkAction streams chunks for nodes in streaming_nodes."""
+ run_id = "stream_run_1"
+ chunk = AIMessageChunk(content="Hello ")
+
+ event = self._create_dispatcher_event(
+ MessageChunkAction(message=chunk), node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id=run_id
+ )
result = self.stream_processor.process(event)
- self.assertIsInstance(result, AssistantUpdateEvent)
- result = cast(AssistantUpdateEvent, result)
- self.assertEqual(result.id, parent_message.id)
- self.assertEqual(result.tool_call_id, parent_tool_call_id)
- self.assertEqual(result.content, "Child content")
+ self.assertIsNotNone(result)
+ assert result is not None
+ self.assertEqual(len(result), 1)
+ self.assertIsInstance(result[0], AssistantMessage)
+ assert isinstance(result[0], AssistantMessage)
+ self.assertEqual(result[0].content, "Hello ")
+ self.assertIsNone(result[0].id)
- def test_nested_parent_chain_resolution(self):
- """Test finding parent IDs through nested chain of parents."""
- # Create chain: root -> intermediate -> leaf
- root_tool_call_id = str(uuid4())
- root_message = AssistantMessage(
- id=str(uuid4()),
- content="Root",
- tool_calls=[AssistantToolCall(id=root_tool_call_id, name="root_tool", args={})],
- parent_tool_call_id=None,
+ def test_message_chunk_ignored_for_non_streaming_nodes(self):
+ """Test MessageChunkAction returns None for nodes not in streaming_nodes."""
+ run_id = "stream_run_2"
+ chunk = AIMessageChunk(content="Hello ")
+
+ event = self._create_dispatcher_event(
+ MessageChunkAction(message=chunk), node_name=AssistantNodeName.ROOT, node_run_id=run_id
)
-
- intermediate_tool_call_id = str(uuid4())
- intermediate_message = AssistantMessage(
- id=str(uuid4()),
- content="Intermediate",
- tool_calls=[AssistantToolCall(id=intermediate_tool_call_id, name="intermediate_tool", args={})],
- parent_tool_call_id=root_tool_call_id,
- )
-
- # Register both in the registry
- self.stream_processor._tool_call_id_to_message[root_tool_call_id] = root_message
- self.stream_processor._tool_call_id_to_message[intermediate_tool_call_id] = intermediate_message
-
- # Send leaf message that references intermediate
- leaf_message = AssistantMessage(content="Leaf content", parent_tool_call_id=intermediate_tool_call_id)
- event = self._create_dispatcher_event(MessageAction(message=leaf_message))
result = self.stream_processor.process(event)
- # Should resolve to root
+ self.assertIsNone(result)
- self.assertIsInstance(result, AssistantUpdateEvent)
- result = cast(AssistantUpdateEvent, result)
- # Note: The unpacking swaps the values, so id is tool_call_id and parent_tool_call_id is message_id
- self.assertEqual(result.id, root_message.id)
- self.assertEqual(result.tool_call_id, root_tool_call_id)
+ def test_multiple_chunks_merged_correctly(self):
+ """Test that multiple MessageChunkActions are merged correctly."""
+ run_id = "stream_run_3"
- def test_missing_parent_message_returns_ack(self):
- """Test that missing parent message returns ACK."""
- missing_parent_id = str(uuid4())
- child_message = AssistantMessage(content="Orphan", parent_tool_call_id=missing_parent_id)
-
- event = self._create_dispatcher_event(MessageAction(message=child_message))
-
- result = self.stream_processor.process(event)
- self.assertIsInstance(result, AssistantGenerationStatusEvent)
- result = cast(AssistantGenerationStatusEvent, result)
- self.assertEqual(result.type, AssistantGenerationStatusType.ACK)
-
- def test_parent_without_id_returns_ack(self):
- """Test that parent message without ID logs warning and returns None."""
- parent_tool_call_id = str(uuid4())
- # Parent message WITHOUT an id
- parent_message = AssistantMessage(
- id=None, # No ID
- content="Parent",
- tool_calls=[AssistantToolCall(id=parent_tool_call_id, name="test", args={})],
- parent_tool_call_id=None,
+ chunk1 = AIMessageChunk(content="Hello ")
+ event1 = self._create_dispatcher_event(
+ MessageChunkAction(message=chunk1), node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id=run_id
)
- self.stream_processor._tool_call_id_to_message[parent_tool_call_id] = parent_message
+ result1 = self.stream_processor.process(event1)
- child_message = AssistantMessage(content="Child", parent_tool_call_id=parent_tool_call_id)
- event = self._create_dispatcher_event(MessageAction(message=child_message))
+ chunk2 = AIMessageChunk(content="world!")
+ event2 = self._create_dispatcher_event(
+ MessageChunkAction(message=chunk2), node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id=run_id
+ )
+ result2 = self.stream_processor.process(event2)
+ self.assertIsNotNone(result1)
+ assert result1 is not None
+ assert isinstance(result1[0], AssistantMessage)
+ self.assertEqual(result1[0].content, "Hello ")
+ self.assertIsNotNone(result2)
+ assert result2 is not None
+ assert isinstance(result2[0], AssistantMessage)
+ self.assertEqual(result2[0].content, "Hello world!")
+
+ def test_concurrent_chunks_from_different_runs(self):
+ """Test that chunks from different node runs are kept separate."""
+ run_id_1 = "stream_run_4a"
+ run_id_2 = "stream_run_4b"
+
+ chunk1 = AIMessageChunk(content="Run 1")
+ event1 = self._create_dispatcher_event(
+ MessageChunkAction(message=chunk1), node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id=run_id_1
+ )
+ self.stream_processor.process(event1)
+
+ chunk2 = AIMessageChunk(content="Run 2")
+ event2 = self._create_dispatcher_event(
+ MessageChunkAction(message=chunk2), node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id=run_id_2
+ )
+ self.stream_processor.process(event2)
+
+ self.assertEqual(self.stream_processor._chunks[run_id_1].content, "Run 1")
+ self.assertEqual(self.stream_processor._chunks[run_id_2].content, "Run 2")
+
+ def test_handles_mixed_content_types_in_chunks(self):
+ """Test that stream processor handles switching between string and list content formats."""
+ run_id = "stream_run_5"
+
+ # Start with string content
+ chunk1 = AIMessageChunk(content="initial string")
+ event1 = self._create_dispatcher_event(
+ MessageChunkAction(message=chunk1), node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id=run_id
+ )
+ self.stream_processor.process(event1)
+
+ # Switch to list format (OpenAI Responses API)
+ chunk2 = AIMessageChunk(content=[{"type": "text", "text": "list content"}])
+ event2 = self._create_dispatcher_event(
+ MessageChunkAction(message=chunk2), node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id=run_id
+ )
+ result = self.stream_processor.process(event2)
+
+ # The result should normalize to string content
+ self.assertIsNotNone(result)
+ assert result is not None
+ assert isinstance(result[0], AssistantMessage)
+ self.assertEqual(result[0].content, "list content")
+
+ # Root vs nested message handling tests
+
+ def test_root_message_from_verbose_node_returned(self):
+ """Test messages from root level (node_path <= 2) in verbose nodes are returned."""
+ message = AssistantMessage(id=str(uuid4()), content="Root message")
+ node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT))
+
+ event = self._create_dispatcher_event(
+ MessageAction(message=message), node_name=AssistantNodeName.ROOT, node_path=node_path
+ )
result = self.stream_processor.process(event)
- self.assertIsInstance(result, AssistantGenerationStatusEvent)
- result = cast(AssistantGenerationStatusEvent, result)
- self.assertEqual(result.type, AssistantGenerationStatusType.ACK)
- def test_visualization_message_in_visualization_nodes(self):
- """Test VisualizationMessage is returned when node is in VISUALIZATION_NODES."""
+ self.assertIsNotNone(result)
+ assert result is not None
+ self.assertEqual(len(result), 1)
+ self.assertEqual(result[0], message)
+
+ def test_root_message_from_non_verbose_node_filtered(self):
+ """Test messages from root level in non-verbose nodes are filtered out."""
+ message = AssistantMessage(id=str(uuid4()), content="Non-verbose message")
+ node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.BILLING))
+
+ event = self._create_dispatcher_event(
+ MessageAction(message=message), node_name=AssistantNodeName.BILLING, node_path=node_path
+ )
+ result = self.stream_processor.process(event)
+
+ self.assertIsNone(result)
+
+ def test_nested_visualization_message_returned(self):
+ """Test VisualizationMessage from nested node/graph is returned."""
query = TrendsQuery(series=[])
viz_message = VisualizationMessage(query="test query", answer=query, plan="test plan")
- viz_message.parent_tool_call_id = str(uuid4())
- # Register parent
- parent_message = AssistantMessage(
- id=str(uuid4()),
- content="Parent",
- tool_calls=[AssistantToolCall(id=viz_message.parent_tool_call_id, name="test", args={})],
+ # Create a deep node path indicating this is from a nested graph
+ node_path = (
+ NodePath(name=AssistantGraphName.ASSISTANT),
+ NodePath(name=AssistantNodeName.ROOT, message_id=str(uuid4()), tool_call_id=str(uuid4())),
+ NodePath(name=AssistantGraphName.INSIGHTS),
+ NodePath(name=AssistantNodeName.TRENDS_GENERATOR),
)
- self.stream_processor._tool_call_id_to_message[viz_message.parent_tool_call_id] = parent_message
- node_name = AssistantNodeName.TRENDS_GENERATOR
- event = self._create_dispatcher_event(MessageAction(message=viz_message), node_name=node_name)
+ event = self._create_dispatcher_event(
+ MessageAction(message=viz_message), node_name=AssistantNodeName.TRENDS_GENERATOR, node_path=node_path
+ )
result = self.stream_processor.process(event)
- self.assertEqual(result, viz_message)
+ self.assertIsNotNone(result)
+ assert result is not None
+ self.assertEqual(len(result), 1)
+ self.assertEqual(result[0], viz_message)
- def test_visualization_message_not_in_visualization_nodes(self):
- """Test VisualizationMessage raises error when from non-visualization node."""
- query = TrendsQuery(series=[])
- viz_message = VisualizationMessage(query="test query", answer=query, plan="test plan")
- viz_message.parent_tool_call_id = str(uuid4())
-
- # Register parent
- parent_message = AssistantMessage(
- id=str(uuid4()),
- content="Parent",
- tool_calls=[AssistantToolCall(id=viz_message.parent_tool_call_id, name="test", args={})],
- )
- self.stream_processor._tool_call_id_to_message[viz_message.parent_tool_call_id] = parent_message
-
- node_name = AssistantNodeName.ROOT # Not a visualization node
- event = self._create_dispatcher_event(MessageAction(message=viz_message), node_name=node_name)
-
- result = self.stream_processor.process(event)
- self.assertEqual(result, AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK))
-
- def test_multi_visualization_message_in_visualization_nodes(self):
- """Test MultiVisualizationMessage is returned when node is in VISUALIZATION_NODES."""
+ def test_nested_multi_visualization_message_returned(self):
+ """Test MultiVisualizationMessage from nested node/graph is returned."""
query = TrendsQuery(series=[])
viz_item = VisualizationItem(query="test query", answer=query, plan="test plan")
multi_viz_message = MultiVisualizationMessage(visualizations=[viz_item])
- multi_viz_message.parent_tool_call_id = str(uuid4())
- # Register parent
- parent_message = AssistantMessage(
- id=str(uuid4()),
- content="Parent",
- tool_calls=[AssistantToolCall(id=multi_viz_message.parent_tool_call_id, name="test", args={})],
+ node_path = (
+ NodePath(name=AssistantGraphName.ASSISTANT),
+ NodePath(name=AssistantNodeName.ROOT, message_id=str(uuid4()), tool_call_id=str(uuid4())),
+ NodePath(name=AssistantGraphName.INSIGHTS),
+ NodePath(name=AssistantNodeName.TRENDS_GENERATOR),
)
- self.stream_processor._tool_call_id_to_message[multi_viz_message.parent_tool_call_id] = parent_message
- node_name = AssistantNodeName.TRENDS_GENERATOR
- event = self._create_dispatcher_event(MessageAction(message=multi_viz_message), node_name=node_name)
+ event = self._create_dispatcher_event(
+ MessageAction(message=multi_viz_message), node_name=AssistantNodeName.TRENDS_GENERATOR, node_path=node_path
+ )
result = self.stream_processor.process(event)
- self.assertEqual(result, multi_viz_message)
+ self.assertIsNotNone(result)
+ assert result is not None
+ self.assertEqual(len(result), 1)
+ self.assertEqual(result[0], multi_viz_message)
- def test_notebook_update_message_returns_as_is(self):
- """Test NotebookUpdateMessage is returned directly."""
+ def test_nested_notebook_message_returned(self):
+ """Test NotebookUpdateMessage from nested node/graph is returned."""
content = ProsemirrorJSONContent(type="doc", content=[])
notebook_message = NotebookUpdateMessage(notebook_id="nb123", content=content)
- notebook_message.parent_tool_call_id = str(uuid4())
- # Register parent
- parent_message = AssistantMessage(
- id=str(uuid4()),
- content="Parent",
- tool_calls=[AssistantToolCall(id=notebook_message.parent_tool_call_id, name="test", args={})],
+ node_path = (
+ NodePath(name=AssistantGraphName.ASSISTANT),
+ NodePath(name=AssistantNodeName.ROOT, message_id=str(uuid4()), tool_call_id=str(uuid4())),
+ NodePath(name=AssistantGraphName.INSIGHTS),
+ NodePath(name=AssistantNodeName.TRENDS_GENERATOR),
)
- self.stream_processor._tool_call_id_to_message[notebook_message.parent_tool_call_id] = parent_message
- event = self._create_dispatcher_event(MessageAction(message=notebook_message))
+ event = self._create_dispatcher_event(
+ MessageAction(message=notebook_message), node_name=AssistantNodeName.TRENDS_GENERATOR, node_path=node_path
+ )
result = self.stream_processor.process(event)
- self.assertEqual(result, notebook_message)
+ self.assertIsNotNone(result)
+ assert result is not None
+ self.assertEqual(len(result), 1)
+ self.assertEqual(result[0], notebook_message)
- def test_failure_message_returns_as_is(self):
- """Test FailureMessage is returned directly."""
+ def test_nested_failure_message_returned(self):
+ """Test FailureMessage from nested node/graph is returned."""
failure_message = FailureMessage(content="Something went wrong")
- failure_message.parent_tool_call_id = str(uuid4())
- # Register parent
- parent_message = AssistantMessage(
- id=str(uuid4()),
- content="Parent",
- tool_calls=[AssistantToolCall(id=failure_message.parent_tool_call_id, name="test", args={})],
+ node_path = (
+ NodePath(name=AssistantGraphName.ASSISTANT),
+ NodePath(name=AssistantNodeName.ROOT, message_id=str(uuid4()), tool_call_id=str(uuid4())),
+ NodePath(name=AssistantGraphName.INSIGHTS),
+ NodePath(name=AssistantNodeName.TRENDS_GENERATOR),
)
- self.stream_processor._tool_call_id_to_message[failure_message.parent_tool_call_id] = parent_message
- event = self._create_dispatcher_event(MessageAction(message=failure_message))
+ event = self._create_dispatcher_event(
+ MessageAction(message=failure_message), node_name=AssistantNodeName.TRENDS_GENERATOR, node_path=node_path
+ )
result = self.stream_processor.process(event)
- self.assertEqual(result, failure_message)
+ self.assertIsNotNone(result)
+ assert result is not None
+ self.assertEqual(len(result), 1)
+ self.assertEqual(result[0], failure_message)
- def test_assistant_tool_call_message_returns_as_is(self):
- """Test AssistantToolCallMessage with parent is filtered out (returns ACK)."""
+ def test_nested_tool_call_message_filtered(self):
+ """Test AssistantToolCallMessage from nested node/graph is filtered out."""
tool_call_message = AssistantToolCallMessage(content="Tool result", tool_call_id=str(uuid4()))
- tool_call_message.parent_tool_call_id = str(uuid4())
- # Register parent
- parent_message = AssistantMessage(
- id=str(uuid4()),
- content="Parent",
- tool_calls=[AssistantToolCall(id=tool_call_message.parent_tool_call_id, name="test", args={})],
- )
- self.stream_processor._tool_call_id_to_message[tool_call_message.parent_tool_call_id] = parent_message
-
- event = self._create_dispatcher_event(MessageAction(message=tool_call_message))
- result = self.stream_processor.process(event)
-
- # New behavior: AssistantToolCallMessages with parents are filtered out
- self.assertIsInstance(result, AssistantGenerationStatusEvent)
- result = cast(AssistantGenerationStatusEvent, result)
- self.assertEqual(result.type, AssistantGenerationStatusType.ACK)
-
- def test_cycle_detection_in_parent_chain(self):
- """Test that circular parent chains are detected and returns ACK."""
- # Create circular chain: A -> B -> A
- tool_call_a = str(uuid4())
- tool_call_b = str(uuid4())
-
- message_a = AssistantMessage(
- id=str(uuid4()),
- content="A",
- tool_calls=[AssistantToolCall(id=tool_call_a, name="tool_a", args={})],
- parent_tool_call_id=tool_call_b, # Points to B
+ node_path = (
+ NodePath(name=AssistantGraphName.ASSISTANT),
+ NodePath(name=AssistantNodeName.ROOT, message_id=str(uuid4()), tool_call_id=str(uuid4())),
+ NodePath(name=AssistantGraphName.INSIGHTS),
+ NodePath(name=AssistantNodeName.TRENDS_GENERATOR),
)
- message_b = AssistantMessage(
- id=str(uuid4()),
- content="B",
- tool_calls=[AssistantToolCall(id=tool_call_b, name="tool_b", args={})],
- parent_tool_call_id=tool_call_a, # Points to A
- )
-
- self.stream_processor._tool_call_id_to_message[tool_call_a] = message_a
- self.stream_processor._tool_call_id_to_message[tool_call_b] = message_b
-
- # Try to process a child of B
- child_message = AssistantMessage(content="Child", parent_tool_call_id=tool_call_b)
- event = self._create_dispatcher_event(MessageAction(message=child_message))
-
- result = self.stream_processor.process(event)
-
- # Cycle detection returns ACK instead of raising error
- self.assertIsInstance(result, AssistantGenerationStatusEvent)
- result = cast(AssistantGenerationStatusEvent, result)
- self.assertEqual(result.type, AssistantGenerationStatusType.ACK)
-
- def test_handles_mixed_content_types_in_chunks(self):
- """Test that stream processor correctly handles switching between string and list content formats."""
- # Test string to list transition
- self.stream_processor._chunks = AIMessageChunk(content="initial string content")
-
- # Simulate a chunk from OpenAI Responses API (list format)
- list_chunk = AIMessageChunk(content=[{"type": "text", "text": "new content from o3"}])
event = self._create_dispatcher_event(
- MessageChunkAction(message=list_chunk), node_name=AssistantNodeName.TRENDS_GENERATOR
- )
- self.stream_processor.process(event)
-
- # Verify the chunks were reset to list format
- self.assertIsInstance(self.stream_processor._chunks.content, list)
- self.assertEqual(len(self.stream_processor._chunks.content), 1)
- self.assertEqual(cast(dict, self.stream_processor._chunks.content[0])["text"], "new content from o3")
-
- # Test list to string transition
- string_chunk = AIMessageChunk(content="back to string format")
- event = self._create_dispatcher_event(
- MessageChunkAction(message=string_chunk), node_name=AssistantNodeName.TRENDS_GENERATOR
- )
- self.stream_processor.process(event)
-
- # Verify the chunks were reset to string format
- self.assertIsInstance(self.stream_processor._chunks.content, str)
- self.assertEqual(self.stream_processor._chunks.content, "back to string format")
-
- def test_handles_multiple_list_chunks(self):
- """Test that multiple list-format chunks are properly concatenated."""
- # Start with empty chunks
- self.stream_processor._chunks = AIMessageChunk(content="")
-
- # Add first list chunk
- chunk1 = AIMessageChunk(content=[{"type": "text", "text": "First part"}])
- event = self._create_dispatcher_event(
- MessageChunkAction(message=chunk1), node_name=AssistantNodeName.TRENDS_GENERATOR
- )
- self.stream_processor.process(event)
-
- # Add second list chunk
- chunk2 = AIMessageChunk(content=[{"type": "text", "text": " second part"}])
- event = self._create_dispatcher_event(
- MessageChunkAction(message=chunk2), node_name=AssistantNodeName.TRENDS_GENERATOR
+ MessageAction(message=tool_call_message), node_name=AssistantNodeName.TRENDS_GENERATOR, node_path=node_path
)
result = self.stream_processor.process(event)
- # Verify the result is an AssistantMessage with combined content
- self.assertIsInstance(result, AssistantMessage)
- result = cast(AssistantMessage, result)
- self.assertEqual(result.content, "First part second part")
+ self.assertIsNone(result)
- def test_messages_without_id_are_yielded(self):
- """Test that messages without ID are always yielded."""
- # Create messages without IDs
+ def test_short_node_path_treated_as_root(self):
+ """Test that node_path with length <= 2 is treated as root level."""
+ message = AssistantMessage(id=str(uuid4()), content="Short path message")
+
+ # Path with just 2 elements (graph + node)
+ node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT))
+
+ event = self._create_dispatcher_event(
+ MessageAction(message=message), node_name=AssistantNodeName.ROOT, node_path=node_path
+ )
+ result = self.stream_processor.process(event)
+
+ self.assertIsNotNone(result)
+ assert result is not None
+ self.assertEqual(len(result), 1)
+ self.assertEqual(result[0], message)
+
+ # UpdateAction tests
+
+ def test_update_action_creates_update_event_with_parent_from_path(self):
+ """Test UpdateAction creates AssistantUpdateEvent using closest tool_call_id from node_path."""
+ message_id = str(uuid4())
+ tool_call_id = str(uuid4())
+
+ node_path = (
+ NodePath(name=AssistantGraphName.ASSISTANT),
+ NodePath(name=AssistantNodeName.ROOT, message_id=message_id, tool_call_id=tool_call_id),
+ NodePath(name=AssistantGraphName.INSIGHTS),
+ )
+
+ event = self._create_dispatcher_event(
+ UpdateAction(content="Update content"), node_name=AssistantNodeName.TRENDS_GENERATOR, node_path=node_path
+ )
+ result = self.stream_processor.process(event)
+
+ self.assertIsNotNone(result)
+ assert result is not None
+ self.assertEqual(len(result), 1)
+ self.assertIsInstance(result[0], AssistantUpdateEvent)
+ update_event = cast(AssistantUpdateEvent, result[0])
+ self.assertEqual(update_event.id, message_id)
+ self.assertEqual(update_event.tool_call_id, tool_call_id)
+ self.assertEqual(update_event.content, "Update content")
+
+ def test_update_action_without_parent_returns_none(self):
+ """Test UpdateAction without parent tool_call_id in node_path returns None."""
+ # No tool_call_id in any path element
+ node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT))
+
+ event = self._create_dispatcher_event(
+ UpdateAction(content="Update content"), node_name=AssistantNodeName.ROOT, node_path=node_path
+ )
+ result = self.stream_processor.process(event)
+
+ self.assertIsNone(result)
+
+ def test_update_action_without_node_path_returns_none(self):
+ """Test UpdateAction without node_path returns None."""
+ event = self._create_dispatcher_event(UpdateAction(content="Update content"), node_path=None)
+ result = self.stream_processor.process(event)
+
+ self.assertIsNone(result)
+
+ def test_update_action_finds_closest_tool_call_in_reversed_path(self):
+ """Test UpdateAction finds the closest (most recent) tool_call_id by reversing the path."""
+ # Multiple tool calls in the path - should find the closest one (the first match when iterating in reverse)
+ message_id_1 = str(uuid4())
+ tool_call_id_1 = str(uuid4())
+ message_id_2 = str(uuid4())
+ tool_call_id_2 = str(uuid4())
+
+ node_path = (
+ NodePath(name=AssistantGraphName.ASSISTANT),
+ NodePath(name=AssistantNodeName.ROOT, message_id=message_id_1, tool_call_id=tool_call_id_1),
+ NodePath(name=AssistantGraphName.INSIGHTS, message_id=message_id_2, tool_call_id=tool_call_id_2),
+ NodePath(name=AssistantNodeName.TRENDS_GENERATOR),
+ )
+
+ event = self._create_dispatcher_event(
+ UpdateAction(content="Update content"), node_name=AssistantNodeName.TRENDS_GENERATOR, node_path=node_path
+ )
+ result = self.stream_processor.process(event)
+
+ self.assertIsNotNone(result)
+ assert result is not None
+ self.assertEqual(len(result), 1)
+ update_event = cast(AssistantUpdateEvent, result[0])
+ # Should use the closest parent (the deepest tool call in the original path)
+ self.assertEqual(update_event.id, message_id_2)
+ self.assertEqual(update_event.tool_call_id, tool_call_id_2)
+
+ # Message deduplication tests
+
+ def test_messages_with_id_deduplicated(self):
+ """Test that messages with the same ID are deduplicated."""
+ message_id = str(uuid4())
+ message1 = AssistantMessage(id=message_id, content="First occurrence")
+ message2 = AssistantMessage(id=message_id, content="Second occurrence")
+
+ node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT))
+
+ # Process first message - should be returned
+ event1 = self._create_dispatcher_event(
+ MessageAction(message=message1), node_name=AssistantNodeName.ROOT, node_path=node_path
+ )
+ result1 = self.stream_processor.process(event1)
+ self.assertIsNotNone(result1)
+ assert result1 is not None
+ self.assertEqual(result1[0], message1)
+
+ # Process second message with same ID - should be filtered
+ event2 = self._create_dispatcher_event(
+ MessageAction(message=message2), node_name=AssistantNodeName.ROOT, node_path=node_path
+ )
+ result2 = self.stream_processor.process(event2)
+ self.assertIsNone(result2)
+
+ def test_messages_without_id_not_deduplicated(self):
+ """Test that messages without ID are always yielded (not deduplicated)."""
message1 = AssistantMessage(content="Message without ID")
message2 = AssistantMessage(content="Another message without ID")
- # Process first message
- event1 = self._create_dispatcher_event(MessageAction(message=message1))
+ node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT))
+
+ event1 = self._create_dispatcher_event(
+ MessageAction(message=message1), node_name=AssistantNodeName.ROOT, node_path=node_path
+ )
result1 = self.stream_processor.process(event1)
- self.assertEqual(result1, message1)
+ self.assertIsNotNone(result1)
+ assert result1 is not None
+ self.assertEqual(result1[0], message1)
- # Process second message with same content
- event2 = self._create_dispatcher_event(MessageAction(message=message2))
+ event2 = self._create_dispatcher_event(
+ MessageAction(message=message2), node_name=AssistantNodeName.ROOT, node_path=node_path
+ )
result2 = self.stream_processor.process(event2)
- self.assertEqual(result2, message2)
+ self.assertIsNotNone(result2)
+ assert result2 is not None
+ self.assertEqual(result2[0], message2)
- # Both should be yielded since they have no IDs
-
- def test_messages_with_id_are_deduplicated(self):
- """Test that messages with ID are deduplicated during streaming."""
+ def test_preexisting_message_ids_filtered(self):
+ """Test that stream processor filters messages with IDs already in _streamed_update_ids."""
message_id = str(uuid4())
- # Create multiple messages with the same ID
- message1 = AssistantMessage(id=message_id, content="First occurrence")
- message2 = AssistantMessage(id=message_id, content="Second occurrence")
- message3 = AssistantMessage(id=message_id, content="Third occurrence")
+ # Pre-populate the streamed IDs
+ self.stream_processor._streamed_update_ids.add(message_id)
- # Process first message - should be yielded
- event1 = self._create_dispatcher_event(MessageAction(message=message1))
- result1 = self.stream_processor.process(event1)
- self.assertEqual(result1, message1)
- self.assertIn(message_id, self.stream_processor._streamed_update_ids)
+ message = AssistantMessage(id=message_id, content="Already seen")
+ node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT))
- # Process second message with same ID - should return ACK
- event2 = self._create_dispatcher_event(MessageAction(message=message2))
- result2 = self.stream_processor.process(event2)
- self.assertIsInstance(result2, AssistantGenerationStatusEvent)
- result2 = cast(AssistantGenerationStatusEvent, result2)
- self.assertEqual(result2.type, AssistantGenerationStatusType.ACK)
+ event = self._create_dispatcher_event(
+ MessageAction(message=message), node_name=AssistantNodeName.ROOT, node_path=node_path
+ )
+ result = self.stream_processor.process(event)
- # Process third message with same ID - should also return ACK
- event3 = self._create_dispatcher_event(MessageAction(message=message3))
- result3 = self.stream_processor.process(event3)
- self.assertIsInstance(result3, AssistantGenerationStatusEvent)
- result3 = cast(AssistantGenerationStatusEvent, result3)
- self.assertEqual(result3.type, AssistantGenerationStatusType.ACK)
+ self.assertIsNone(result)
- def test_stream_processor_with_preexisting_message_ids(self):
- """Test that stream processor correctly filters messages when initialized with existing IDs."""
- message_id_1 = str(uuid4())
- message_id_2 = str(uuid4())
+ # LangGraph update processing tests
- # Simulate existing messages by pre-populating the streamed IDs set
- self.stream_processor._streamed_update_ids.add(message_id_1)
+ def test_langgraph_message_chunk_processed(self):
+ """Test that LangGraph message chunk updates are converted and processed."""
+ chunk = AIMessageChunk(content="LangGraph chunk")
+ state = {"langgraph_node": AssistantNodeName.TRENDS_GENERATOR, "langgraph_checkpoint_ns": "checkpoint_123"}
- # Try to process message with existing ID - should be filtered out
- message1 = AssistantMessage(id=message_id_1, content="Already seen")
- event1 = self._create_dispatcher_event(MessageAction(message=message1))
- result1 = self.stream_processor.process(event1)
- self.assertIsInstance(result1, AssistantGenerationStatusEvent)
- result1 = cast(AssistantGenerationStatusEvent, result1)
- self.assertEqual(result1.type, AssistantGenerationStatusType.ACK)
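+        # Mimics LangGraph's stream_mode="messages" output: a ("messages", (chunk, metadata)) tuple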
+ update = ["messages", (chunk, state)]
+ event = LangGraphUpdateEvent(update=update)
- # Process message with new ID - should be yielded
- message2 = AssistantMessage(id=message_id_2, content="New message")
- event2 = self._create_dispatcher_event(MessageAction(message=message2))
- result2 = self.stream_processor.process(event2)
- self.assertEqual(result2, message2)
- self.assertIn(message_id_2, self.stream_processor._streamed_update_ids)
+ result = self.stream_processor.process_langgraph_update(event)
+
+ self.assertIsNotNone(result)
+ assert result is not None
+ self.assertEqual(len(result), 1)
+ self.assertIsInstance(result[0], AssistantMessage)
+ assert isinstance(result[0], AssistantMessage)
+ self.assertEqual(result[0].content, "LangGraph chunk")
+
+ def test_langgraph_state_update_stored(self):
+ """Test that LangGraph state updates are stored in _state."""
+ new_state_dict = {"messages": [], "plan": "Test plan"}
+ update = cast(GraphValueUpdateTuple, ["values", new_state_dict])
+
+ event = LangGraphUpdateEvent(update=update)
+ result = self.stream_processor.process_langgraph_update(event)
+
+ self.assertIsNone(result)
+ self.assertIsNotNone(self.stream_processor._state)
+ assert self.stream_processor._state is not None
+ self.assertEqual(self.stream_processor._state.plan, "Test plan")
+
+ def test_langgraph_non_message_chunk_ignored(self):
+ """Test that LangGraph updates that are not AIMessageChunk are ignored."""
+ regular_message = AssistantMessage(content="Not a chunk")
+ state = {"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "checkpoint_456"}
+
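+        # A plain AssistantMessage is not an AIMessageChunk, so the processor should ignore this update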
+ update = ["messages", (regular_message, state)]
+ event = LangGraphUpdateEvent(update=update)
+
+ result = self.stream_processor.process_langgraph_update(event)
+
+ self.assertIsNone(result)
+
+ def test_langgraph_invalid_update_format_ignored(self):
+ """Test that invalid LangGraph update formats are ignored."""
+ update = "invalid_format"
+ event = LangGraphUpdateEvent(update=update)
+
+ result = self.stream_processor.process_langgraph_update(event)
+
+ self.assertIsNone(result)
+
+ # Edge cases and error conditions
+
+ def test_empty_node_path_treated_as_root(self):
+ """Test that empty node_path is treated as root level."""
+ message = AssistantMessage(id=str(uuid4()), content="Empty path message")
+
+ event = self._create_dispatcher_event(
+ MessageAction(message=message), node_name=AssistantNodeName.ROOT, node_path=()
+ )
+ result = self.stream_processor.process(event)
+
+ self.assertIsNotNone(result)
+ assert result is not None
+ self.assertEqual(len(result), 1)
+ self.assertEqual(result[0], message)
+
+ def test_none_node_path_treated_as_root(self):
+ """Test that None node_path is treated as root level."""
+ message = AssistantMessage(id=str(uuid4()), content="None path message")
+
+ event = self._create_dispatcher_event(
+ MessageAction(message=message), node_name=AssistantNodeName.ROOT, node_path=None
+ )
+ result = self.stream_processor.process(event)
+
+ self.assertIsNotNone(result)
+ assert result is not None
+ self.assertEqual(len(result), 1)
+ self.assertEqual(result[0], message)
+
+ def test_node_end_with_none_state_returns_none(self):
+ """Test NodeEndAction with None state returns None."""
+ event = self._create_dispatcher_event(NodeEndAction(state=None))
+ result = self.stream_processor.process(event)
+
+ self.assertIsNone(result)
+
+ def test_update_action_with_empty_content_returns_none(self):
+ """Test UpdateAction with empty content returns None."""
+ node_path = (
+ NodePath(name=AssistantGraphName.ASSISTANT),
+ NodePath(name=AssistantNodeName.ROOT, message_id=str(uuid4()), tool_call_id=str(uuid4())),
+ )
+
+ event = self._create_dispatcher_event(UpdateAction(content=""), node_path=node_path)
+ result = self.stream_processor.process(event)
+
+ self.assertIsNone(result)
+
+ def test_special_messages_from_root_level_returned(self):
+ """Test that special message types from root level are handled by root message logic."""
+        # A VisualizationMessage from the root level should be returned when it comes from a verbose node
+ query = TrendsQuery(series=[])
+ viz_message = VisualizationMessage(query="test", answer=query, plan="plan")
+
+ node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.TRENDS_GENERATOR))
+
+ event = self._create_dispatcher_event(
+ MessageAction(message=viz_message), node_name=AssistantNodeName.TRENDS_GENERATOR, node_path=node_path
+ )
+ result = self.stream_processor.process(event)
+
+ self.assertIsNotNone(result)
+ assert result is not None
+ self.assertEqual(len(result), 1)
+ self.assertEqual(result[0], viz_message)
diff --git a/ee/hogai/utils/types/base.py b/ee/hogai/utils/types/base.py
index ce2eef08f1..1b9e5afb46 100644
--- a/ee/hogai/utils/types/base.py
+++ b/ee/hogai/utils/types/base.py
@@ -299,7 +299,7 @@ class _SharedAssistantState(BaseStateWithMessages, BaseStateWithIntermediateStep
"""
The ID of the message to start from to keep the message window short enough.
"""
- root_tool_call_id: Optional[str] = Field(default=None)
+ root_tool_call_id: Annotated[Optional[str], replace] = Field(default=None)
"""
The ID of the tool call from the root node.
"""
@@ -311,7 +311,7 @@ class _SharedAssistantState(BaseStateWithMessages, BaseStateWithIntermediateStep
"""
The type of insight to generate.
"""
- root_tool_calls_count: Optional[int] = Field(default=None)
+ root_tool_calls_count: Annotated[Optional[int], replace] = Field(default=None)
"""
Tracks the number of tool calls made by the root node to terminate the loop.
"""
@@ -399,7 +399,6 @@ class AssistantNodeName(StrEnum):
MEMORY_COLLECTOR_TOOLS = "memory_collector_tools"
INKEEP_DOCS = "inkeep_docs"
INSIGHT_RAG_CONTEXT = "insight_rag_context"
- INSIGHTS_SUBGRAPH = "insights_subgraph"
TITLE_GENERATOR = "title_generator"
INSIGHTS_SEARCH = "insights_search"
SESSION_SUMMARIZATION = "session_summarization"
@@ -413,6 +412,13 @@ class AssistantNodeName(StrEnum):
REVENUE_ANALYTICS_FILTER_OPTIONS_TOOLS = "revenue_analytics_filter_options_tools"
+class AssistantGraphName(StrEnum):
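+    # Top-level graph identifiers, used as the first entry of a NodePath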
+ ASSISTANT = "assistant_graph"
+ INSIGHTS = "insights_graph"
+ TAXONOMY = "taxonomy_graph"
+ DEEP_RESEARCH = "deep_research_graph"
+
+
class AssistantMode(StrEnum):
ASSISTANT = "assistant"
INSIGHTS_TOOL = "insights_tool"
@@ -443,9 +449,33 @@ class NodeStartAction(BaseModel):
type: Literal["NODE_START"] = "NODE_START"
-AssistantActionUnion = MessageAction | MessageChunkAction | NodeStartAction
+class NodeEndAction(BaseModel, Generic[PartialStateType]):
+ type: Literal["NODE_END"] = "NODE_END"
+ state: PartialStateType | None = None
+
+
+class UpdateAction(BaseModel):
+ type: Literal["UPDATE"] = "UPDATE"
+ content: str
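+    # Free-form progress text; the stream processor routes it to the closest parent tool call in the path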
+
+
+AssistantActionUnion = MessageAction | MessageChunkAction | NodeStartAction | NodeEndAction | UpdateAction
+
+
+class NodePath(BaseModel):
+ """Defines a vertice of the assistant graph path."""
+
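+    # message_id and tool_call_id link a path entry to the tool call that spawned that subgraph, when present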
+ name: str
+ message_id: str | None = None
+ tool_call_id: str | None = None
class AssistantDispatcherEvent(BaseModel):
action: AssistantActionUnion = Field(discriminator="type")
+ node_path: tuple[NodePath, ...] | None = None
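+    # None or an empty tuple means the event originates at the root level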
node_name: str
+ node_run_id: str
+
+
+class LangGraphUpdateEvent(BaseModel):
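+    # Raw payload from LangGraph's stream; its shape depends on the stream mode ("messages", "values", ...)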
+ update: Any
diff --git a/frontend/__snapshots__/scenes-app-data-pipelines--pipeline-destination-page--light.png b/frontend/__snapshots__/scenes-app-data-pipelines--pipeline-destination-page--light.png
index 77339b9eec..53f2163297 100644
Binary files a/frontend/__snapshots__/scenes-app-data-pipelines--pipeline-destination-page--light.png and b/frontend/__snapshots__/scenes-app-data-pipelines--pipeline-destination-page--light.png differ
diff --git a/frontend/__snapshots__/scenes-app-insights-trendsvalue--trends-area-edit--dark.png b/frontend/__snapshots__/scenes-app-insights-trendsvalue--trends-area-edit--dark.png
index 9dd9468aff..c7d0ddff18 100644
Binary files a/frontend/__snapshots__/scenes-app-insights-trendsvalue--trends-area-edit--dark.png and b/frontend/__snapshots__/scenes-app-insights-trendsvalue--trends-area-edit--dark.png differ
diff --git a/frontend/__snapshots__/scenes-app-insights-trendsvalue--trends-area-edit--light.png b/frontend/__snapshots__/scenes-app-insights-trendsvalue--trends-area-edit--light.png
index 176e4d0aea..fa59bb523f 100644
Binary files a/frontend/__snapshots__/scenes-app-insights-trendsvalue--trends-area-edit--light.png and b/frontend/__snapshots__/scenes-app-insights-trendsvalue--trends-area-edit--light.png differ
diff --git a/frontend/__snapshots__/scenes-app-posthog-ai--max-instance-with-contextual-tools--dark.png b/frontend/__snapshots__/scenes-app-posthog-ai--max-instance-with-contextual-tools--dark.png
index 6fc02e7724..be639a9d44 100644
Binary files a/frontend/__snapshots__/scenes-app-posthog-ai--max-instance-with-contextual-tools--dark.png and b/frontend/__snapshots__/scenes-app-posthog-ai--max-instance-with-contextual-tools--dark.png differ
diff --git a/frontend/__snapshots__/scenes-app-posthog-ai--max-instance-with-contextual-tools--light.png b/frontend/__snapshots__/scenes-app-posthog-ai--max-instance-with-contextual-tools--light.png
index 2ae147a2da..379b6d0032 100644
Binary files a/frontend/__snapshots__/scenes-app-posthog-ai--max-instance-with-contextual-tools--light.png and b/frontend/__snapshots__/scenes-app-posthog-ai--max-instance-with-contextual-tools--light.png differ
diff --git a/frontend/__snapshots__/scenes-app-posthog-ai--reasoning-component--light.png b/frontend/__snapshots__/scenes-app-posthog-ai--reasoning-component--light.png
index 1396f3d4fe..ce426311b2 100644
Binary files a/frontend/__snapshots__/scenes-app-posthog-ai--reasoning-component--light.png and b/frontend/__snapshots__/scenes-app-posthog-ai--reasoning-component--light.png differ
diff --git a/frontend/__snapshots__/scenes-app-sidepanels--side-panel-docs--dark.png b/frontend/__snapshots__/scenes-app-sidepanels--side-panel-docs--dark.png
index 397d231080..b897157a52 100644
Binary files a/frontend/__snapshots__/scenes-app-sidepanels--side-panel-docs--dark.png and b/frontend/__snapshots__/scenes-app-sidepanels--side-panel-docs--dark.png differ
diff --git a/frontend/src/queries/nodes/InsightViz/EditorFilters.tsx b/frontend/src/queries/nodes/InsightViz/EditorFilters.tsx
index d76a168aea..67ed5c97ca 100644
--- a/frontend/src/queries/nodes/InsightViz/EditorFilters.tsx
+++ b/frontend/src/queries/nodes/InsightViz/EditorFilters.tsx
@@ -409,7 +409,7 @@ export function EditorFilters({ query, showing, embedded }: EditorFiltersProps):
) : null
@@ -834,9 +835,10 @@ function ReasoningAnswer({ content, completed, id, showCompletionIcon = true }:
interface ToolCallsAnswerProps {
toolCalls: EnhancedToolCall[]
+    registeredToolMap: Record<string, ToolRegistration>
}
-function ToolCallsAnswer({ toolCalls }: ToolCallsAnswerProps): JSX.Element {
+function ToolCallsAnswer({ toolCalls, registeredToolMap }: ToolCallsAnswerProps): JSX.Element {
// Separate todo_write tool calls from regular tool calls
const todoWriteToolCalls = toolCalls.filter((tc) => tc.name === 'todo_write')
const regularToolCalls = toolCalls.filter((tc) => tc.name !== 'todo_write')
@@ -867,7 +869,7 @@ function ToolCallsAnswer({ toolCalls }: ToolCallsAnswerProps): JSX.Element {
let description = `Executing ${toolCall.name}`
if (definition) {
if (definition.displayFormatter) {
- description = definition.displayFormatter(toolCall)
+ description = definition.displayFormatter(toolCall, { registeredToolMap })
}
if (commentary) {
description = commentary
diff --git a/frontend/src/scenes/max/components/ToolsDisplay.tsx b/frontend/src/scenes/max/components/ToolsDisplay.tsx
index 5b5f8008db..da43f87adf 100644
--- a/frontend/src/scenes/max/components/ToolsDisplay.tsx
+++ b/frontend/src/scenes/max/components/ToolsDisplay.tsx
@@ -68,7 +68,8 @@ export const ToolsDisplay: React.FC = ({ isFloating, tools, b
// or border-secondary (--color-posthog-3000-400) because the former is almost invisible here, and the latter too distinct
{toolDef?.icon || }
- {toolDef?.name}
+  {/* Show the registered tool's name, so contextual substitution (create_and_query_insight) displays correctly */}
+ {tool.name}
)
})}
@@ -173,7 +174,9 @@ function ToolsExplanation({ toolsInReverse }: { toolsInReverse: ToolRegistration
if (toolDef?.subtools) {
tools.push(...Object.values(toolDef.subtools))
} else {
- tools.push({ name: toolDef?.name, description: toolDef?.description, icon: toolDef?.icon })
+  // Take the name and description from the registered `tool`, not the tool definition.
+  // This makes tool substitution (e.g. for create_and_query_insight) work correctly.
+ tools.push({ name: tool.name, description: tool.description, icon: toolDef?.icon })
}
return tools
},
diff --git a/frontend/src/scenes/max/max-constants.tsx b/frontend/src/scenes/max/max-constants.tsx
index d60e737480..2400d7c9ab 100644
--- a/frontend/src/scenes/max/max-constants.tsx
+++ b/frontend/src/scenes/max/max-constants.tsx
@@ -25,7 +25,10 @@ export interface ToolDefinition {
ToolDefinition
>
icon: JSX.Element
- displayFormatter?: (toolCall: EnhancedToolCall) => string
+ displayFormatter?: (
+ toolCall: EnhancedToolCall,
+        { registeredToolMap }: { registeredToolMap: Record<string, ToolRegistration> }
+ ) => string
/**
* If only available in a specific product, specify it here.
* We're using Scene instead of ProductKey, because that's more flexible (specifically for SQL editor there
@@ -160,14 +163,18 @@ export const TOOL_DEFINITIONS: Record, Tool
},
},
create_and_query_insight: {
- name: 'Query data',
- description: 'Query data by creating insights and SQL queries',
+ name: 'Edit the insight',
+ description: "Edit the insight you're viewing",
icon: iconForType('product_analytics'),
- displayFormatter: (toolCall) => {
- if (toolCall.status === 'completed') {
- return 'Created an insight'
+ product: Scene.Insight,
+ displayFormatter: (toolCall, { registeredToolMap }) => {
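+            // When a scene registers create_and_query_insight contextually, the same call edits the insight in view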
+ const isEditing = registeredToolMap.create_and_query_insight
+ if (isEditing) {
+ return toolCall.status === 'completed'
+ ? 'Edited the insight you are viewing'
+ : 'Editing the insight you are viewing...'
}
- return 'Creating an insight...'
+ return toolCall.status === 'completed' ? 'Created an insight' : 'Creating an insight...'
},
},
search_session_recordings: {
@@ -327,18 +334,6 @@ export const TOOL_DEFINITIONS: Record, Tool
return 'Fixing SQL...'
},
},
- edit_current_insight: {
- name: 'Edit the insight',
- description: "Edit the insight you're viewing",
- icon: iconForType('product_analytics'),
- product: Scene.Insight,
- displayFormatter: (toolCall) => {
- if (toolCall.status === 'completed') {
- return 'Edited the insight you are viewing'
- }
- return 'Editing the insight you are viewing...'
- },
- },
filter_revenue_analytics: {
name: 'Filter revenue analytics',
description: 'Filter revenue analytics to find the most impactful revenue insights',
diff --git a/frontend/src/scenes/max/maxGlobalLogic.tsx b/frontend/src/scenes/max/maxGlobalLogic.tsx
index 9e238fd1aa..cac6232a74 100644
--- a/frontend/src/scenes/max/maxGlobalLogic.tsx
+++ b/frontend/src/scenes/max/maxGlobalLogic.tsx
@@ -73,8 +73,8 @@ export const STATIC_TOOLS: ToolRegistration[] = [
},
{
identifier: 'create_and_query_insight' as const,
- name: TOOL_DEFINITIONS['create_and_query_insight'].name,
- description: TOOL_DEFINITIONS['create_and_query_insight'].description,
+ name: 'Query data',
+ description: 'Query data by creating insights and SQL queries',
},
]
@@ -210,17 +210,10 @@ export const maxGlobalLogic = kea([
],
toolMap: [
(s) => [s.registeredToolMap, s.availableStaticTools],
- (registeredToolMap, availableStaticTools) => {
- if (registeredToolMap.edit_current_insight) {
- availableStaticTools = availableStaticTools.filter(
- (tool) => tool.identifier !== 'create_and_query_insight'
- )
- }
- return {
- ...Object.fromEntries(availableStaticTools.map((tool) => [tool.identifier, tool])),
- ...registeredToolMap,
- }
- },
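+        // Registered contextual tools spread last, so they override static tools with the same identifier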
+ (registeredToolMap, availableStaticTools) => ({
+ ...Object.fromEntries(availableStaticTools.map((tool) => [tool.identifier, tool])),
+ ...registeredToolMap,
+ }),
],
tools: [(s) => [s.toolMap], (toolMap): ToolRegistration[] => Object.values(toolMap)],
editInsightToolRegistered: [
diff --git a/posthog/schema.py b/posthog/schema.py
index 24e071dfd4..ba02d59095 100644
--- a/posthog/schema.py
+++ b/posthog/schema.py
@@ -267,7 +267,6 @@ class AssistantTool(StrEnum):
CREATE_HOG_FUNCTION_INPUTS = "create_hog_function_inputs"
CREATE_MESSAGE_TEMPLATE = "create_message_template"
NAVIGATE = "navigate"
- EDIT_CURRENT_INSIGHT = "edit_current_insight"
FILTER_ERROR_TRACKING_ISSUES = "filter_error_tracking_issues"
FIND_ERROR_TRACKING_IMPACTFUL_ISSUE_EVENT_LIST = "find_error_tracking_impactful_issue_event_list"
EXPERIMENT_RESULTS_SUMMARY = "experiment_results_summary"
diff --git a/products/dashboards/backend/test/test_max_tool_integration.py b/products/dashboards/backend/test/test_max_tool_integration.py
index 80011d75dd..37d4fa86a9 100644
--- a/products/dashboards/backend/test/test_max_tool_integration.py
+++ b/products/dashboards/backend/test/test_max_tool_integration.py
@@ -96,7 +96,6 @@ async def test_dashboard_metadata_update(dashboard_setup):
tool = EditCurrentDashboardTool(
team=dashboard.team,
user=conversation.user,
- tool_call_id="test-tool-call-id",
config=RunnableConfig(
configurable={
"thread_id": conversation.id,
@@ -133,7 +132,6 @@ async def test_dashboard_metadata_update_no_permissions(dashboard_setup_no_perms
tool = EditCurrentDashboardTool(
team=dashboard.team,
user=conversation.user,
- tool_call_id="test-tool-call-id",
config=RunnableConfig(
configurable={
"thread_id": conversation.id,
@@ -155,7 +153,6 @@ async def test_dashboard_add_insights(dashboard_setup):
tool = EditCurrentDashboardTool(
team=dashboard.team,
user=conversation.user,
- tool_call_id="test-tool-call-id",
config=RunnableConfig(
configurable={
"thread_id": conversation.id,
diff --git a/products/dashboards/backend/test/test_max_tools.py b/products/dashboards/backend/test/test_max_tools.py
index 711608f997..206417d67b 100644
--- a/products/dashboards/backend/test/test_max_tools.py
+++ b/products/dashboards/backend/test/test_max_tools.py
@@ -25,9 +25,7 @@ class TestEditCurrentDashboardTool:
configurable = {"team": mock_team, "user": mock_user}
if context:
configurable["contextual_tools"] = {"edit_current_dashboard": context}
- tool = EditCurrentDashboardTool(
- team=mock_team, user=mock_user, config={"configurable": configurable}, tool_call_id="test-tool-call-id"
- )
+ tool = EditCurrentDashboardTool(team=mock_team, user=mock_user, config={"configurable": configurable})
return tool
@pytest.mark.asyncio
diff --git a/products/data_warehouse/backend/max_tools.py b/products/data_warehouse/backend/max_tools.py
index 92f76f5438..b7d9384d3f 100644
--- a/products/data_warehouse/backend/max_tools.py
+++ b/products/data_warehouse/backend/max_tools.py
@@ -113,11 +113,10 @@ class HogQLGeneratorToolsNode(TaxonomyAgentToolsNode[TaxonomyAgentState, Taxonom
class HogQLGeneratorGraph(TaxonomyAgent[TaxonomyAgentState, TaxonomyAgentState[FinalAnswerArgs]]):
- def __init__(self, team: Team, user: User, tool_call_id: str):
+ def __init__(self, team: Team, user: User):
super().__init__(
team,
user,
- tool_call_id,
loop_node_class=HogQLGeneratorNode,
tools_node_class=HogQLGeneratorToolsNode,
toolkit_class=HogQLGeneratorToolkit,
@@ -134,9 +133,7 @@ class HogQLGeneratorTool(HogQLGeneratorMixin, MaxTool):
current_query: str | None = self.context.get("current_query", "")
user_prompt = HOGQL_GENERATOR_USER_PROMPT.format(instructions=instructions, current_query=current_query)
- graph = HogQLGeneratorGraph(
- team=self._team, user=self._user, tool_call_id=self._tool_call_id
- ).compile_full_graph()
+ graph = HogQLGeneratorGraph(team=self._team, user=self._user).compile_full_graph()
graph_context = {
"change": user_prompt,
diff --git a/products/data_warehouse/backend/test/test_max_tools.py b/products/data_warehouse/backend/test/test_max_tools.py
index 30ba13ea3c..c780b868bf 100644
--- a/products/data_warehouse/backend/test/test_max_tools.py
+++ b/products/data_warehouse/backend/test/test_max_tools.py
@@ -43,9 +43,7 @@ class TestDataWarehouseMaxTools(NonAtomicBaseTest):
mock_graph.ainvoke.return_value = mock_result
mock_compile.return_value = mock_graph
- tool = HogQLGeneratorTool(
- team=self.team, user=self.user, state=AssistantState(messages=[]), tool_call_id="test-tool-call-id"
- )
+ tool = HogQLGeneratorTool(team=self.team, user=self.user, state=AssistantState(messages=[]))
tool_call = AssistantToolCall(
id="1",
name="generate_hogql_query",
@@ -83,9 +81,7 @@ class TestDataWarehouseMaxTools(NonAtomicBaseTest):
mock_graph.ainvoke.return_value = mock_result
mock_compile.return_value = mock_graph
- tool = HogQLGeneratorTool(
- team=self.team, user=self.user, state=AssistantState(messages=[]), tool_call_id="test-tool-call-id"
- )
+ tool = HogQLGeneratorTool(team=self.team, user=self.user, state=AssistantState(messages=[]))
tool_call = AssistantToolCall(
id="1",
name="generate_hogql_query",
@@ -119,9 +115,7 @@ class TestDataWarehouseMaxTools(NonAtomicBaseTest):
mock_compile.return_value = mock_graph
- tool = HogQLGeneratorTool(
- team=self.team, user=self.user, state=AssistantState(messages=[]), tool_call_id="test-tool-call-id"
- )
+ tool = HogQLGeneratorTool(team=self.team, user=self.user, state=AssistantState(messages=[]))
tool_call = AssistantToolCall(
id="1",
name="generate_hogql_query",
@@ -171,9 +165,7 @@ class TestDataWarehouseMaxTools(NonAtomicBaseTest):
"SELECT suspicious_query FROM events", "Suspicious query detected"
)
- tool = HogQLGeneratorTool(
- team=self.team, user=self.user, state=AssistantState(messages=[]), tool_call_id="test-tool-call-id"
- )
+ tool = HogQLGeneratorTool(team=self.team, user=self.user, state=AssistantState(messages=[]))
tool_call = AssistantToolCall(
id="1",
name="generate_hogql_query",
@@ -223,9 +215,7 @@ class TestDataWarehouseMaxTools(NonAtomicBaseTest):
mock_graph.ainvoke.return_value = mock_result
mock_compile.return_value = mock_graph
- tool = HogQLGeneratorTool(
- team=self.team, user=self.user, state=AssistantState(messages=[]), tool_call_id="test-tool-call-id"
- )
+ tool = HogQLGeneratorTool(team=self.team, user=self.user, state=AssistantState(messages=[]))
tool_call = AssistantToolCall(
id="1",
name="generate_hogql_query",
@@ -263,9 +253,7 @@ class TestDataWarehouseMaxTools(NonAtomicBaseTest):
mock_graph.ainvoke.return_value = graph_result
mock_compile.return_value = mock_graph
- tool = HogQLGeneratorTool(
- team=self.team, user=self.user, state=AssistantState(messages=[]), tool_call_id="test-tool-call-id"
- )
+ tool = HogQLGeneratorTool(team=self.team, user=self.user, state=AssistantState(messages=[]))
tool_call = AssistantToolCall(
id="1",
name="generate_hogql_query",
@@ -301,9 +289,7 @@ class TestDataWarehouseMaxTools(NonAtomicBaseTest):
mock_graph.ainvoke.return_value = graph_result
mock_compile.return_value = mock_graph
- tool = HogQLGeneratorTool(
- team=self.team, user=self.user, state=AssistantState(messages=[]), tool_call_id="test-tool-call-id"
- )
+ tool = HogQLGeneratorTool(team=self.team, user=self.user, state=AssistantState(messages=[]))
tool_call = AssistantToolCall(
id="1",
name="generate_hogql_query",
@@ -339,9 +325,7 @@ class TestDataWarehouseMaxTools(NonAtomicBaseTest):
mock_graph.ainvoke.return_value = graph_result
mock_compile.return_value = mock_graph
- tool = HogQLGeneratorTool(
- team=self.team, user=self.user, state=AssistantState(messages=[]), tool_call_id="test-tool-call-id"
- )
+ tool = HogQLGeneratorTool(team=self.team, user=self.user, state=AssistantState(messages=[]))
tool_call = AssistantToolCall(
id="1",
name="generate_hogql_query",
@@ -353,9 +337,7 @@ class TestDataWarehouseMaxTools(NonAtomicBaseTest):
def test_current_query_included_in_system_prompt_template(self):
"""Test that the system prompt template includes the current query section."""
- tool = HogQLGeneratorTool(
- team=self.team, user=self.user, state=AssistantState(messages=[]), tool_call_id="test-tool-call-id"
- )
+ tool = HogQLGeneratorTool(team=self.team, user=self.user, state=AssistantState(messages=[]))
# Verify the system prompt template contains the expected current query section
self.assertIn("The current HogQL query", tool.context_prompt_template)
diff --git a/products/error_tracking/backend/max_tools.py b/products/error_tracking/backend/max_tools.py
index 95fc69075a..d994186403 100644
--- a/products/error_tracking/backend/max_tools.py
+++ b/products/error_tracking/backend/max_tools.py
@@ -155,11 +155,10 @@ class ErrorTrackingIssueImpactToolsNode(
class ErrorTrackingIssueImpactGraph(
TaxonomyAgent[TaxonomyAgentState, TaxonomyAgentState[ErrorTrackingIssueImpactToolOutput]]
):
- def __init__(self, team: Team, user: User, tool_call_id: str):
+ def __init__(self, team: Team, user: User):
super().__init__(
team,
user,
- tool_call_id,
loop_node_class=ErrorTrackingIssueImpactLoopNode,
tools_node_class=ErrorTrackingIssueImpactToolsNode,
toolkit_class=ErrorTrackingIssueImpactToolkit,
@@ -177,7 +176,7 @@ class ErrorTrackingIssueImpactTool(MaxTool):
args_schema: type[BaseModel] = IssueImpactQueryArgs
async def _arun_impl(self, instructions: str) -> tuple[str, ErrorTrackingIssueImpactToolOutput]:
- graph = ErrorTrackingIssueImpactGraph(team=self._team, user=self._user, tool_call_id=self._tool_call_id)
+ graph = ErrorTrackingIssueImpactGraph(team=self._team, user=self._user)
graph_context = {
"change": f"Goal: {instructions}",
diff --git a/products/feature_flags/backend/max_tools.py b/products/feature_flags/backend/max_tools.py
index 14d8a08b21..7cfaa5bd88 100644
--- a/products/feature_flags/backend/max_tools.py
+++ b/products/feature_flags/backend/max_tools.py
@@ -439,11 +439,10 @@ class FeatureFlagGeneratorGraph(TaxonomyAgent[TaxonomyAgentState, TaxonomyAgentS
4. Generate structured feature flag configuration
"""
- def __init__(self, team: Team, user: User, tool_call_id: str):
+ def __init__(self, team: Team, user: User):
super().__init__(
team,
user,
- tool_call_id,
loop_node_class=FeatureFlagCreationNode,
tools_node_class=FeatureFlagCreationToolsNode,
toolkit_class=FeatureFlagToolkit,
@@ -501,7 +500,7 @@ The tool will automatically:
async def _create_flag_from_instructions(self, instructions: str) -> FeatureFlagCreationSchema:
"""Use TaxonomyAgent graph to generate structured flag configuration."""
- graph = FeatureFlagGeneratorGraph(team=self._team, user=self._user, tool_call_id=self._tool_call_id)
+ graph = FeatureFlagGeneratorGraph(team=self._team, user=self._user)
graph_context = {
"change": f"Create a feature flag based on these instructions: {instructions}",
diff --git a/products/product_analytics/backend/max_tools.py b/products/product_analytics/backend/max_tools.py
deleted file mode 100644
index 6af7ccdb0c..0000000000
--- a/products/product_analytics/backend/max_tools.py
+++ /dev/null
@@ -1,137 +0,0 @@
-from uuid import uuid4
-
-from pydantic import BaseModel, Field
-
-from posthog.schema import AssistantMessage, AssistantToolCallMessage, VisualizationMessage
-
-from ee.hogai.tool import MaxTool, ToolMessagesArtifact
-from ee.hogai.utils.types import AssistantState
-
-QUERY_KIND_DESCRIPTION_PROMPT = """
-## Trends
-A trends insight visualizes events over time using time series. They're useful for finding patterns in historical data.
-
-The trends insights have the following features:
-- The insight can show multiple trends in one request.
-- Custom formulas can calculate derived metrics, like `A/B*100` to calculate a ratio.
-- Filter and break down data using multiple properties.
-- Compare with the previous period and sample data.
-- Apply various aggregation types, like sum, average, etc., and chart types.
-- And more.
-
-Examples of use cases include:
-- How the product's most important metrics change over time.
-- Long-term patterns, or cycles in product's usage.
-- The usage of different features side-by-side.
-- How the properties of events vary using aggregation (sum, average, etc).
-- Users can also visualize the same data points in a variety of ways.
-
-## Funnel
-A funnel insight visualizes a sequence of events that users go through in a product. They use percentages as the primary aggregation type. Funnels use two or more series, so the conversation history should mention at least two events.
-
-The funnel insights have the following features:
-- Various visualization types (steps, time-to-convert, historical trends).
-- Filter data and apply exclusion steps.
-- Break down data using a single property.
-- Specify conversion windows, details of conversion calculation, attribution settings.
-- Sample data.
-- And more.
-
-Examples of use cases include:
-- Conversion rates.
-- Drop off steps.
-- Steps with the highest friction and time to convert.
-- If product changes are improving their funnel over time.
-- Average/median time to convert.
-- Conversion trends over time.
-
-## Retention
-A retention insight visualizes how many users return to the product after performing some action. They're useful for understanding user engagement and retention.
-
-The retention insights have the following features: filter data, sample data, and more.
-
-Examples of use cases include:
-- How many users come back and perform an action after their first visit.
-- How many users come back to perform action X after performing action Y.
-- How often users return to use a specific feature.
-
-## SQL
-The 'sql' insight type allows you to write arbitrary SQL queries to retrieve data.
-
-The SQL insights have the following features:
-- Filter data using arbitrary SQL.
-- All ClickHouse SQL features.
-- You can nest subqueries as needed.
-""".strip()
-
-
-class EditCurrentInsightArgs(BaseModel):
- """
- Edits the insight visualization the user is currently working on, by creating a query or iterating on a previous query.
- """
-
- query_description: str = Field(
- description="The new query to edit the current insight. Must include all details from the current insight plus any change on top of them. Include any relevant information from the current conversation, as the tool does not have access to the conversation."
- )
- query_kind: str = Field(description=QUERY_KIND_DESCRIPTION_PROMPT)
-
-
-class EditCurrentInsightTool(MaxTool):
- name: str = "edit_current_insight"
- description: str = (
- "Update the insight the user is currently working on, based on the current insight's JSON schema."
- )
- context_prompt_template: str = """The user is currently editing an insight (aka query). Here is that insight's current definition, which can be edited using the `edit_current_insight` tool:
-
-```json
-{current_query}
-```
-
-IMPORTANT: DO NOT REMOVE ANY FIELDS FROM THE CURRENT INSIGHT DEFINITION. DO NOT CHANGE ANY OTHER FIELDS THAN THE ONES THE USER ASKED FOR. KEEP THE REST AS IS.
-""".strip()
-
- args_schema: type[BaseModel] = EditCurrentInsightArgs
-
- async def _arun_impl(self, query_kind: str, query_description: str) -> tuple[str, ToolMessagesArtifact]:
- from ee.hogai.graph.graph import InsightsAssistantGraph # avoid circular import
-
- if "current_query" not in self.context:
- raise ValueError("Context `current_query` is required for the `create_and_query_insight` tool")
-
- graph = InsightsAssistantGraph(self._team, self._user, tool_call_id=self._tool_call_id).compile_full_graph()
- state = self._state
- last_message = state.messages[-1]
- if not isinstance(last_message, AssistantMessage):
- raise ValueError("Last message is not an AssistantMessage")
- if last_message.tool_calls is None or len(last_message.tool_calls) == 0:
- raise ValueError("Last message has no tool calls")
-
- state.root_tool_insight_plan = query_description
- root_tool_call_id = last_message.tool_calls[0].id
-
- # We need to set a new root tool call id to sub-nest the graph within the contextual tool call
- # and avoid duplicating messages in the stream
- state.root_tool_call_id = str(uuid4())
-
- state_dict = await graph.ainvoke(state, config=self._config)
- state = AssistantState.model_validate(state_dict)
-
- result = state.messages[-1]
- viz_messages = [message for message in state.messages if isinstance(message, VisualizationMessage)]
- viz_message = viz_messages[-1] if viz_messages else None
- if not viz_message:
- raise ValueError("Visualization was not generated")
- if not isinstance(result, AssistantToolCallMessage):
- raise ValueError("Last message is not an AssistantToolCallMessage")
-
- return "", ToolMessagesArtifact(
- messages=[
- viz_message,
- AssistantToolCallMessage(
- content=result.content,
- ui_payload={self.get_name(): viz_message.answer.model_dump(exclude_none=True)},
- id=str(uuid4()),
- tool_call_id=root_tool_call_id,
- ),
- ]
- )
diff --git a/products/replay/backend/max_tools.py b/products/replay/backend/max_tools.py
index ea88a5004e..656c22e0ab 100644
--- a/products/replay/backend/max_tools.py
+++ b/products/replay/backend/max_tools.py
@@ -90,11 +90,10 @@ class SessionReplayFilterOptionsGraph(
):
"""Graph for generating filtering options for session replay."""
- def __init__(self, team: Team, user: User, tool_call_id: str):
+ def __init__(self, team: Team, user: User):
super().__init__(
team,
user,
- tool_call_id=tool_call_id,
loop_node_class=SessionReplayFilterNode,
tools_node_class=SessionReplayFilterOptionsToolsNode,
toolkit_class=SessionReplayFilterOptionsToolkit,
@@ -131,7 +130,7 @@ class SearchSessionRecordingsTool(MaxTool):
Reusable method to call graph to avoid code/prompt duplication and enable
different processing of the results, based on the place the tool is used.
"""
- graph = SessionReplayFilterOptionsGraph(team=self._team, user=self._user, tool_call_id=self._tool_call_id)
+ graph = SessionReplayFilterOptionsGraph(team=self._team, user=self._user)
pretty_filters = json.dumps(self.context.get("current_filters", {}), indent=2)
user_prompt = USER_FILTER_OPTIONS_PROMPT.format(change=change, current_filters=pretty_filters)
graph_context = {
diff --git a/products/revenue_analytics/backend/max_tools.py b/products/revenue_analytics/backend/max_tools.py
index d34d2323a7..afc5d922b7 100644
--- a/products/revenue_analytics/backend/max_tools.py
+++ b/products/revenue_analytics/backend/max_tools.py
@@ -153,11 +153,10 @@ class RevenueAnalyticsFilterOptionsGraph(
):
"""Graph for generating filtering options for revenue analytics."""
- def __init__(self, team: Team, user: User, tool_call_id: str):
+ def __init__(self, team: Team, user: User):
super().__init__(
team,
user,
- tool_call_id,
loop_node_class=RevenueAnalyticsFilterNode,
tools_node_class=RevenueAnalyticsFilterOptionsToolsNode,
toolkit_class=RevenueAnalyticsFilterOptionsToolkit,
@@ -192,7 +191,7 @@ class FilterRevenueAnalyticsTool(MaxTool):
Reusable method to call graph to avoid code/prompt duplication and enable
different processing of the results, based on the place the tool is used.
"""
- graph = RevenueAnalyticsFilterOptionsGraph(team=self._team, user=self._user, tool_call_id=self._tool_call_id)
+ graph = RevenueAnalyticsFilterOptionsGraph(team=self._team, user=self._user)
pretty_filters = json.dumps(self.context.get("current_filters", {}), indent=2)
user_prompt = USER_FILTER_OPTIONS_PROMPT.format(change=change, current_filters=pretty_filters)
graph_context = {
diff --git a/products/surveys/backend/max_tools.py b/products/surveys/backend/max_tools.py
index 547074fc43..2ddbb20c76 100644
--- a/products/surveys/backend/max_tools.py
+++ b/products/surveys/backend/max_tools.py
@@ -52,7 +52,7 @@ class CreateSurveyTool(MaxTool):
Create a survey from natural language instructions.
"""
- graph = FeatureFlagLookupGraph(team=self._team, user=self._user, tool_call_id=self._tool_call_id)
+ graph = FeatureFlagLookupGraph(team=self._team, user=self._user)
graph_context = {
"change": f"Create a survey based on these instructions: {instructions}",
@@ -286,11 +286,10 @@ class SurveyLookupToolsNode(TaxonomyAgentToolsNode[TaxonomyAgentState, TaxonomyA
class FeatureFlagLookupGraph(TaxonomyAgent[TaxonomyAgentState, TaxonomyAgentState[SurveyCreationSchema]]):
"""Graph for feature flag lookup operations."""
- def __init__(self, team: Team, user: User, tool_call_id: str):
+ def __init__(self, team: Team, user: User):
super().__init__(
team,
user,
- tool_call_id,
loop_node_class=SurveyLoopNode,
tools_node_class=SurveyLookupToolsNode,
toolkit_class=SurveyToolkit,
diff --git a/products/surveys/backend/test_max_tools.py b/products/surveys/backend/test_max_tools.py
index 7380676fa1..e097918e44 100644
--- a/products/surveys/backend/test_max_tools.py
+++ b/products/surveys/backend/test_max_tools.py
@@ -45,7 +45,7 @@ class TestSurveyCreatorTool(BaseTest):
def _setup_tool(self):
"""Helper to create a SurveyCreatorTool instance with mocked dependencies"""
- tool = CreateSurveyTool(team=self.team, user=self.user, config=self._config, tool_call_id="test-tool-call-id")
+ tool = CreateSurveyTool(team=self.team, user=self.user, config=self._config)
return tool
def test_get_team_survey_config(self):
@@ -468,7 +468,6 @@ class TestSurveyAnalysisTool(BaseTest):
tool = SurveyAnalysisTool(
team=self.team,
user=self.user,
- tool_call_id="test-tool-call-id",
config={
**self._config,
"configurable": {