feat(max): parallel tool calls (#39954)

Co-authored-by: kappa90 <e.capparelli@gmail.com>
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Georgiy Tarasov
2025-11-05 16:08:31 +01:00
committed by GitHub
parent e49db4138f
commit beadde8e4e
128 changed files with 5035 additions and 3346 deletions

View File

@@ -320,7 +320,7 @@ class YourTaxonomyGraph(TaxonomyAgent[TaxonomyAgentState, TaxonomyAgentState[Max
4. Invoke it (typically from a `MaxTool`), mirroring `products/replay/backend/max_tools.py`:
```python
graph = YourTaxonomyGraph(team=self._team, user=self._user, tool_call_id=self._tool_call_id)
graph = YourTaxonomyGraph(team=self._team, user=self._user)
graph_context = {
"change": "Show me recordings of users in Germany that used a mobile device while performing a payment",

View File

@@ -9,7 +9,6 @@ import structlog
import posthoganalytics
from asgiref.sync import async_to_sync
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.messages import AIMessageChunk
from langchain_core.runnables.config import RunnableConfig
from langgraph.errors import GraphRecursionError
from langgraph.graph.state import CompiledStateGraph
@@ -31,20 +30,19 @@ from posthog.event_usage import report_user_action
from posthog.models import Team, User
from posthog.sync import database_sync_to_async
from ee.hogai.graph.base import BaseAssistantNode
from ee.hogai.utils.exceptions import GenerationCanceled
from ee.hogai.utils.helpers import extract_stream_update
from ee.hogai.utils.state import (
GraphValueUpdateTuple,
is_message_update,
is_state_update,
is_value_update,
validate_state_update,
from ee.hogai.utils.state import validate_state_update
from ee.hogai.utils.stream_processor import AssistantStreamProcessorProtocol
from ee.hogai.utils.types.base import (
AssistantDispatcherEvent,
AssistantMessageUnion,
AssistantMode,
AssistantOutput,
AssistantResultUnion,
LangGraphUpdateEvent,
)
from ee.hogai.utils.stream_processor import AssistantStreamProcessor
from ee.hogai.utils.types import AssistantMessageUnion, AssistantOutput
from ee.hogai.utils.types.base import AssistantDispatcherEvent, AssistantMode, AssistantResultUnion, MessageChunkAction
from ee.hogai.utils.types.composed import AssistantMaxGraphState, AssistantMaxPartialGraphState, MaxNodeName
from ee.hogai.utils.types.composed import AssistantMaxGraphState, AssistantMaxPartialGraphState
from ee.models import Conversation
logger = structlog.get_logger(__name__)
@@ -66,7 +64,7 @@ class BaseAssistant(ABC):
_trace_id: Optional[str | UUID]
_billing_context: Optional[MaxBillingContext]
_initial_state: Optional[AssistantMaxGraphState | AssistantMaxPartialGraphState]
_stream_processor: AssistantStreamProcessor
_stream_processor: AssistantStreamProcessorProtocol
"""The stream processor that processes dispatcher actions and message chunks."""
def __init__(
@@ -87,6 +85,7 @@ class BaseAssistant(ABC):
billing_context: Optional[MaxBillingContext] = None,
initial_state: Optional[AssistantMaxGraphState | AssistantMaxPartialGraphState] = None,
callback_handler: Optional[BaseCallbackHandler] = None,
stream_processor: AssistantStreamProcessorProtocol,
):
self._team = team
self._contextual_tools = contextual_tools or {}
@@ -121,22 +120,7 @@ class BaseAssistant(ABC):
self._mode = mode
self._initial_state = initial_state
# Initialize the stream processor with node configuration
self._stream_processor = AssistantStreamProcessor(
streaming_nodes=self.STREAMING_NODES,
visualization_nodes=self.VISUALIZATION_NODES,
)
@property
@abstractmethod
def VISUALIZATION_NODES(self) -> dict[MaxNodeName, type[BaseAssistantNode]]:
"""Nodes that can generate visualizations."""
pass
@property
@abstractmethod
def STREAMING_NODES(self) -> set[MaxNodeName]:
"""Nodes that can stream messages to the client."""
pass
self._stream_processor = stream_processor
@abstractmethod
def get_initial_state(self) -> AssistantMaxGraphState:
@@ -175,7 +159,7 @@ class BaseAssistant(ABC):
state = await self._init_or_update_state()
config = self._get_config()
stream_mode: list[StreamMode] = ["values", "updates", "custom"]
stream_mode: list[StreamMode] = ["values", "custom"]
if stream_message_chunks:
stream_mode.append("messages")
@@ -286,11 +270,11 @@ class BaseAssistant(ABC):
# Add existing ids to streamed messages, so we don't send the messages again.
for message in saved_state.messages:
if message.id is not None:
self._stream_processor._streamed_update_ids.add(message.id)
self._stream_processor.mark_id_as_streamed(message.id)
# Add the latest message id to streamed messages, so we don't send it multiple times.
if self._latest_message and self._latest_message.id is not None:
self._stream_processor._streamed_update_ids.add(self._latest_message.id)
self._stream_processor.mark_id_as_streamed(self._latest_message.id)
# If the graph previously hasn't reset the state, it is an interrupt. We resume from the point of interruption.
if snapshot.next and self._latest_message and saved_state.graph_status == "interrupted":
@@ -322,29 +306,12 @@ class BaseAssistant(ABC):
async def _process_update(self, update: Any) -> list[AssistantResultUnion] | None:
update = extract_stream_update(update)
new_message: AssistantResultUnion | None = None
if not isinstance(update, AssistantDispatcherEvent):
if is_state_update(update):
_, new_state = update
self._state = validate_state_update(new_state, self._state_type)
elif is_value_update(update) and (new_message := await self._aprocess_value_update(update)):
return [new_message]
if updates := self._stream_processor.process_langgraph_update(LangGraphUpdateEvent(update=update)):
return updates
elif new_message := self._stream_processor.process(update):
return new_message
if is_message_update(update):
# Convert the message chunk update to a dispatcher event to prepare for a bright future without LangGraph
message, state = update[1]
if not isinstance(message, AIMessageChunk):
return None
update = AssistantDispatcherEvent(
action=MessageChunkAction(message=message), node_name=state["langgraph_node"]
)
if isinstance(update, AssistantDispatcherEvent) and (new_message := self._stream_processor.process(update)):
return [new_message] if new_message else None
return None
async def _aprocess_value_update(self, update: GraphValueUpdateTuple) -> AssistantResultUnion | None:
return None
def _build_root_config_for_persistence(self) -> RunnableConfig:

View File

@@ -1,5 +1,5 @@
from collections.abc import AsyncGenerator
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional
from uuid import UUID
from posthog.schema import AssistantMessage, HumanMessage, MaxBillingContext, VisualizationMessage
@@ -7,13 +7,26 @@ from posthog.schema import AssistantMessage, HumanMessage, MaxBillingContext, Vi
from posthog.models import Team, User
from ee.hogai.assistant.base import BaseAssistant
from ee.hogai.graph import DeepResearchAssistantGraph
from ee.hogai.graph.base import BaseAssistantNode
from ee.hogai.graph.deep_research.graph import DeepResearchAssistantGraph
from ee.hogai.graph.deep_research.types import DeepResearchNodeName, DeepResearchState, PartialDeepResearchState
from ee.hogai.utils.stream_processor import AssistantStreamProcessor
from ee.hogai.utils.types import AssistantMode, AssistantOutput
from ee.hogai.utils.types.composed import MaxNodeName
from ee.models import Conversation
if TYPE_CHECKING:
from ee.hogai.utils.types.composed import MaxNodeName
STREAMING_NODES: set["MaxNodeName"] = {
DeepResearchNodeName.ONBOARDING,
DeepResearchNodeName.PLANNER,
DeepResearchNodeName.TASK_EXECUTOR,
}
VERBOSE_NODES: set["MaxNodeName"] = STREAMING_NODES | {
DeepResearchNodeName.PLANNER_TOOLS,
}
class DeepResearchAssistant(BaseAssistant):
_state: Optional[DeepResearchState]
@@ -48,20 +61,13 @@ class DeepResearchAssistant(BaseAssistant):
trace_id=trace_id,
billing_context=billing_context,
initial_state=initial_state,
stream_processor=AssistantStreamProcessor(
verbose_nodes=VERBOSE_NODES,
streaming_nodes=STREAMING_NODES,
state_type=DeepResearchState,
),
)
@property
def VISUALIZATION_NODES(self) -> dict[MaxNodeName, type[BaseAssistantNode]]:
return {}
@property
def STREAMING_NODES(self) -> set[MaxNodeName]:
return {
DeepResearchNodeName.ONBOARDING,
DeepResearchNodeName.PLANNER,
DeepResearchNodeName.TASK_EXECUTOR,
}
def get_initial_state(self) -> DeepResearchState:
if self._latest_message:
return DeepResearchState(

View File

@@ -1,36 +1,32 @@
from collections.abc import AsyncGenerator
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional
from uuid import UUID
from posthog.schema import (
AssistantGenerationStatusEvent,
AssistantGenerationStatusType,
AssistantMessage,
HumanMessage,
MaxBillingContext,
VisualizationMessage,
)
from posthog.schema import AssistantMessage, HumanMessage, MaxBillingContext, VisualizationMessage
from posthog.models import Team, User
from ee.hogai.assistant.base import BaseAssistant
from ee.hogai.graph import FunnelGeneratorNode, RetentionGeneratorNode, SQLGeneratorNode, TrendsGeneratorNode
from ee.hogai.graph.base import BaseAssistantNode
from ee.hogai.graph.graph import InsightsAssistantGraph
from ee.hogai.graph.query_executor.nodes import QueryExecutorNode
from ee.hogai.graph.taxonomy.types import TaxonomyNodeName
from ee.hogai.utils.state import GraphValueUpdateTuple, validate_value_update
from ee.hogai.utils.types import (
AssistantMode,
AssistantNodeName,
AssistantOutput,
AssistantState,
PartialAssistantState,
)
from ee.hogai.utils.types.base import AssistantResultUnion
from ee.hogai.utils.types.composed import MaxNodeName
from ee.hogai.graph.insights_graph.graph import InsightsGraph
from ee.hogai.utils.stream_processor import AssistantStreamProcessor
from ee.hogai.utils.types import AssistantMode, AssistantOutput, AssistantState, PartialAssistantState
from ee.hogai.utils.types.base import AssistantNodeName
from ee.models import Conversation
if TYPE_CHECKING:
from ee.hogai.utils.types.composed import MaxNodeName
VERBOSE_NODES: set["MaxNodeName"] = {
AssistantNodeName.QUERY_EXECUTOR,
AssistantNodeName.FUNNEL_GENERATOR,
AssistantNodeName.RETENTION_GENERATOR,
AssistantNodeName.SQL_GENERATOR,
AssistantNodeName.TRENDS_GENERATOR,
AssistantNodeName.ROOT,
AssistantNodeName.ROOT_TOOLS,
}
class InsightsAssistant(BaseAssistant):
_state: Optional[AssistantState]
@@ -55,7 +51,7 @@ class InsightsAssistant(BaseAssistant):
conversation,
new_message=new_message,
user=user,
graph=InsightsAssistantGraph(team, user).compile_full_graph(),
graph=InsightsGraph(team, user).compile_full_graph(),
state_type=AssistantState,
partial_state_type=PartialAssistantState,
mode=AssistantMode.INSIGHTS_TOOL,
@@ -65,24 +61,11 @@ class InsightsAssistant(BaseAssistant):
trace_id=trace_id,
billing_context=billing_context,
initial_state=initial_state,
stream_processor=AssistantStreamProcessor(
verbose_nodes=VERBOSE_NODES, streaming_nodes=set(), state_type=AssistantState
),
)
@property
def VISUALIZATION_NODES(self) -> dict[MaxNodeName, type[BaseAssistantNode]]:
return {
AssistantNodeName.TRENDS_GENERATOR: TrendsGeneratorNode,
AssistantNodeName.FUNNEL_GENERATOR: FunnelGeneratorNode,
AssistantNodeName.RETENTION_GENERATOR: RetentionGeneratorNode,
AssistantNodeName.SQL_GENERATOR: SQLGeneratorNode,
AssistantNodeName.QUERY_EXECUTOR: QueryExecutorNode,
}
@property
def STREAMING_NODES(self) -> set[MaxNodeName]:
return {
TaxonomyNodeName.LOOP_NODE,
}
def get_initial_state(self) -> AssistantState:
return AssistantState(messages=[])
@@ -127,13 +110,3 @@ class InsightsAssistant(BaseAssistant):
"is_new_conversation": False,
},
)
async def _aprocess_value_update(self, update: GraphValueUpdateTuple) -> AssistantResultUnion | None:
_, maybe_state_update = update
state_update = validate_value_update(maybe_state_update)
if intersected_nodes := state_update.keys() & self.VISUALIZATION_NODES.keys():
node_name: MaxNodeName = intersected_nodes.pop()
node_val = state_update[node_name]
if isinstance(node_val, PartialAssistantState) and node_val.intermediate_steps:
return AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR)
return await super()._aprocess_value_update(update)

View File

@@ -1,29 +1,15 @@
from collections.abc import AsyncGenerator
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional
from uuid import UUID
from posthog.schema import (
AssistantGenerationStatusEvent,
AssistantGenerationStatusType,
AssistantMessage,
HumanMessage,
MaxBillingContext,
VisualizationMessage,
)
from posthog.schema import AssistantMessage, HumanMessage, MaxBillingContext, VisualizationMessage
from posthog.models import Team, User
from ee.hogai.assistant.base import BaseAssistant
from ee.hogai.graph import (
AssistantGraph,
FunnelGeneratorNode,
RetentionGeneratorNode,
SQLGeneratorNode,
TrendsGeneratorNode,
)
from ee.hogai.graph.base import BaseAssistantNode
from ee.hogai.graph.insights.nodes import InsightSearchNode
from ee.hogai.utils.state import GraphValueUpdateTuple, validate_value_update
from ee.hogai.graph.graph import AssistantGraph
from ee.hogai.graph.taxonomy.types import TaxonomyNodeName
from ee.hogai.utils.stream_processor import AssistantStreamProcessor
from ee.hogai.utils.types import (
AssistantMode,
AssistantNodeName,
@@ -31,10 +17,36 @@ from ee.hogai.utils.types import (
AssistantState,
PartialAssistantState,
)
from ee.hogai.utils.types.base import AssistantResultUnion
from ee.hogai.utils.types.composed import MaxNodeName
from ee.models import Conversation
if TYPE_CHECKING:
from ee.hogai.utils.types.composed import MaxNodeName
STREAMING_NODES: set["MaxNodeName"] = {
AssistantNodeName.ROOT,
AssistantNodeName.INKEEP_DOCS,
AssistantNodeName.MEMORY_ONBOARDING,
AssistantNodeName.MEMORY_INITIALIZER,
AssistantNodeName.MEMORY_ONBOARDING_ENQUIRY,
AssistantNodeName.MEMORY_ONBOARDING_FINALIZE,
AssistantNodeName.DASHBOARD_CREATION,
}
VERBOSE_NODES: set["MaxNodeName"] = {
AssistantNodeName.TRENDS_GENERATOR,
AssistantNodeName.FUNNEL_GENERATOR,
AssistantNodeName.RETENTION_GENERATOR,
AssistantNodeName.SQL_GENERATOR,
AssistantNodeName.INSIGHTS_SEARCH,
AssistantNodeName.QUERY_EXECUTOR,
AssistantNodeName.MEMORY_INITIALIZER_INTERRUPT,
AssistantNodeName.ROOT_TOOLS,
TaxonomyNodeName.TOOLS_NODE,
TaxonomyNodeName.TASK_EXECUTOR,
}
class MainAssistant(BaseAssistant):
_state: Optional[AssistantState]
@@ -69,30 +81,11 @@ class MainAssistant(BaseAssistant):
trace_id=trace_id,
billing_context=billing_context,
initial_state=initial_state,
stream_processor=AssistantStreamProcessor(
verbose_nodes=VERBOSE_NODES, streaming_nodes=STREAMING_NODES, state_type=AssistantState
),
)
@property
def VISUALIZATION_NODES(self) -> dict[MaxNodeName, type[BaseAssistantNode]]:
return {
AssistantNodeName.TRENDS_GENERATOR: TrendsGeneratorNode,
AssistantNodeName.FUNNEL_GENERATOR: FunnelGeneratorNode,
AssistantNodeName.RETENTION_GENERATOR: RetentionGeneratorNode,
AssistantNodeName.SQL_GENERATOR: SQLGeneratorNode,
AssistantNodeName.INSIGHTS_SEARCH: InsightSearchNode,
}
@property
def STREAMING_NODES(self) -> set[MaxNodeName]:
return {
AssistantNodeName.ROOT,
AssistantNodeName.INKEEP_DOCS,
AssistantNodeName.MEMORY_ONBOARDING,
AssistantNodeName.MEMORY_INITIALIZER,
AssistantNodeName.MEMORY_ONBOARDING_ENQUIRY,
AssistantNodeName.MEMORY_ONBOARDING_FINALIZE,
AssistantNodeName.DASHBOARD_CREATION,
}
def get_initial_state(self) -> AssistantState:
if self._latest_message:
return AssistantState(
@@ -142,13 +135,3 @@ class MainAssistant(BaseAssistant):
"is_new_conversation": self._is_new_conversation,
},
)
async def _aprocess_value_update(self, update: GraphValueUpdateTuple) -> AssistantResultUnion | None:
_, maybe_state_update = update
state_update = validate_value_update(maybe_state_update)
if intersected_nodes := state_update.keys() & self.VISUALIZATION_NODES.keys():
node_name: MaxNodeName = intersected_nodes.pop()
node_val = state_update[node_name]
if isinstance(node_val, PartialAssistantState) and node_val.intermediate_steps:
return AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR)
return await super()._aprocess_value_update(update)

View File

@@ -396,7 +396,7 @@ class AssistantContextManager(AssistantContextMixin):
contextual_tools_prompt = [
f"<{tool_name}>\n"
f"{get_contextual_tool_class(tool_name)(team=self._team, user=self._user, tool_call_id="").format_context_prompt_injection(tool_context)}\n" # type: ignore
f"{get_contextual_tool_class(tool_name)(team=self._team, user=self._user).format_context_prompt_injection(tool_context)}\n" # type: ignore
f"</{tool_name}>"
for tool_name, tool_context in self.get_contextual_tools().items()
if get_contextual_tool_class(tool_name) is not None

View File

@@ -507,7 +507,7 @@ Query results: 42 events
# Mock the tool class
mock_tool = MagicMock()
mock_tool.format_context_prompt_injection.return_value = "Tool system prompt"
mock_get_contextual_tool_class.return_value = lambda team, user, tool_call_id: mock_tool
mock_get_contextual_tool_class.return_value = lambda team, user: mock_tool
config = RunnableConfig(
configurable={"contextual_tools": {"search_session_recordings": {"current_filters": {}}}}

View File

@@ -52,7 +52,7 @@ from ee.hogai.eval.base import MaxPrivateEval
from ee.hogai.eval.offline.conftest import EvaluationContext, capture_score, get_eval_context
from ee.hogai.eval.schema import DatasetInput
from ee.hogai.eval.scorers.sql import SQLSemanticsCorrectness, SQLSyntaxCorrectness
from ee.hogai.graph import AssistantGraph
from ee.hogai.graph.graph import AssistantGraph
from ee.hogai.utils.types import AssistantState
from ee.models import Conversation

View File

@@ -19,7 +19,7 @@ from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
# We want the PostHog set_up_evals fixture here
from ee.hogai.eval.conftest import set_up_evals # noqa: F401
from ee.hogai.eval.scorers import PlanAndQueryOutput
from ee.hogai.graph.graph import AssistantGraph, InsightsAssistantGraph
from ee.hogai.graph.graph import AssistantGraph
from ee.hogai.utils.types import AssistantNodeName, AssistantState
from ee.models.assistant import Conversation, CoreMemory
@@ -34,28 +34,10 @@ EVAL_USER_FULL_NAME = "Karen Smith"
@pytest.fixture
def call_root_for_insight_generation(demo_org_team_user):
# This graph structure will first get a plan, then generate the SQL query.
insights_subgraph = (
# Insights subgraph without query execution, so we only create the queries
InsightsAssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
.add_query_creation_flow(next_node=AssistantNodeName.END)
.compile()
)
graph = (
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root(
path_map={
"insights": AssistantNodeName.INSIGHTS_SUBGRAPH,
"insights_search": AssistantNodeName.INSIGHTS_SEARCH,
"root": AssistantNodeName.ROOT,
"search_documentation": AssistantNodeName.END,
"end": AssistantNodeName.END,
}
)
.add_node(AssistantNodeName.INSIGHTS_SUBGRAPH, insights_subgraph)
.add_edge(AssistantNodeName.INSIGHTS_SUBGRAPH, AssistantNodeName.END)
.add_insights_search()
.add_root()
# TRICKY: We need to set a checkpointer here because async tests create a new event loop.
.compile(checkpointer=DjangoCheckpointer())
)
@@ -85,15 +67,21 @@ def call_root_for_insight_generation(demo_org_team_user):
final_state_raw = await graph.ainvoke(final_state, {"configurable": {"thread_id": conversation.id}})
final_state = AssistantState.model_validate(final_state_raw)
if not final_state.messages or not isinstance(final_state.messages[-1], VisualizationMessage):
# The order is a viz message, tool call message, and assistant message.
if (
not final_state.messages
or not len(final_state.messages) >= 3
or not isinstance(final_state.messages[-3], VisualizationMessage)
):
return {
"plan": None,
"query": None,
"query_generation_retry_count": final_state.query_generation_retry_count,
}
return {
"plan": final_state.messages[-1].plan,
"query": final_state.messages[-1].answer,
"plan": final_state.messages[-3].plan,
"query": final_state.messages[-3].answer,
"query_generation_retry_count": final_state.query_generation_retry_count,
}

View File

@@ -1,5 +1,4 @@
import pytest
from unittest.mock import MagicMock, patch
from braintrust import EvalCase
from langchain_core.runnables import RunnableConfig
@@ -7,8 +6,8 @@ from langchain_core.runnables import RunnableConfig
from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.graph import AssistantGraph
from ee.hogai.graph.dashboards.nodes import DashboardCreationNode
from ee.hogai.graph.graph import AssistantGraph
from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState, PartialAssistantState
from ee.models.assistant import Conversation
@@ -21,18 +20,7 @@ def call_root_for_dashboard_creation(demo_org_team_user):
graph = (
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root(
{
"create_dashboard": AssistantNodeName.END,
"insights": AssistantNodeName.END,
"search_documentation": AssistantNodeName.END,
"session_summarization": AssistantNodeName.END,
"insights_search": AssistantNodeName.END,
"create_and_query_insight": AssistantNodeName.END,
"root": AssistantNodeName.END,
"end": AssistantNodeName.END,
}
)
.add_root(lambda state: AssistantNodeName.END)
.compile(checkpointer=DjangoCheckpointer())
)
@@ -172,9 +160,7 @@ async def eval_tool_routing_dashboard_creation(call_root_for_dashboard_creation,
)
@pytest.mark.django_db
@patch("ee.hogai.graph.base.get_stream_writer", return_value=MagicMock())
async def eval_tool_call_dashboard_creation(patch_get_stream_writer, pytestconfig, demo_org_team_user):
async def eval_tool_call_dashboard_creation(pytestconfig, demo_org_team_user):
conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])
dashboard_creation_node = DashboardCreationNode(demo_org_team_user[1], demo_org_team_user[2])

View File

@@ -8,7 +8,7 @@ from braintrust import EvalCase
from posthog.schema import HumanMessage, VisualizationMessage
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.graph import AssistantGraph
from ee.hogai.graph.graph import AssistantGraph
from ee.hogai.utils.types import AssistantNodeName, AssistantState
from ee.models.assistant import Conversation
@@ -84,16 +84,7 @@ def call_insight_search(demo_org_team_user):
graph = (
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root(
{
"insights": AssistantNodeName.END,
"search_documentation": AssistantNodeName.END,
"root": AssistantNodeName.END,
"end": AssistantNodeName.END,
"insights_search": AssistantNodeName.INSIGHTS_SEARCH,
}
)
.add_insights_search()
.add_root()
.compile(checkpointer=DjangoCheckpointer())
)

View File

@@ -9,7 +9,7 @@ from langchain_core.messages import AIMessage as LangchainAIMessage
from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.graph import AssistantGraph
from ee.hogai.graph.graph import AssistantGraph
from ee.hogai.utils.types import AssistantNodeName, AssistantState
from ee.models.assistant import Conversation

View File

@@ -13,7 +13,7 @@ from posthog.models.user import User
from posthog.sync import database_sync_to_async
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.graph import AssistantGraph
from ee.hogai.graph.graph import AssistantGraph
from ee.hogai.graph.memory.prompts import (
ENQUIRY_INITIAL_MESSAGE,
SCRAPING_SUCCESS_KEY_PHRASE,

View File

@@ -7,7 +7,7 @@ from braintrust import EvalCase
from posthog.schema import AssistantMessage, AssistantToolCall, AssistantToolCallMessage, HumanMessage
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.graph import AssistantGraph
from ee.hogai.graph.graph import AssistantGraph
from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState
from ee.models.assistant import Conversation
@@ -20,16 +20,7 @@ def call_root(demo_org_team_user):
graph = (
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root(
{
"insights": AssistantNodeName.END,
"billing": AssistantNodeName.END,
"insights_search": AssistantNodeName.END,
"search_documentation": AssistantNodeName.END,
"root": AssistantNodeName.ROOT,
"end": AssistantNodeName.END,
}
)
.add_root(lambda state: AssistantNodeName.END)
# TRICKY: We need to set a checkpointer here because async tests create a new event loop.
.compile(checkpointer=DjangoCheckpointer())
)

View File

@@ -5,7 +5,7 @@ from braintrust import EvalCase
from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.graph import AssistantGraph
from ee.hogai.graph.graph import AssistantGraph
from ee.hogai.utils.types import AssistantNodeName, AssistantState
from ee.models.assistant import Conversation
@@ -18,17 +18,7 @@ def call_root(demo_org_team_user):
graph = (
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root(
path_map={
"insights": AssistantNodeName.END,
"billing": AssistantNodeName.END,
"insights_search": AssistantNodeName.END,
"search_documentation": AssistantNodeName.END,
"root": AssistantNodeName.ROOT,
"end": AssistantNodeName.END,
},
tools_node=AssistantNodeName.END,
)
.add_root(router=lambda state: AssistantNodeName.END)
# TRICKY: We need to set a checkpointer here because async tests create a new event loop.
.compile(checkpointer=DjangoCheckpointer())
)

View File

@@ -5,7 +5,7 @@ from braintrust import EvalCase
from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.graph import AssistantGraph
from ee.hogai.graph.graph import AssistantGraph
from ee.hogai.utils.types import AssistantNodeName, AssistantState
from ee.models.assistant import Conversation
@@ -18,17 +18,7 @@ def call_root(demo_org_team_user):
graph = (
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root(
path_map={
"insights": AssistantNodeName.END,
"billing": AssistantNodeName.END,
"insights_search": AssistantNodeName.END,
"search_documentation": AssistantNodeName.END,
"root": AssistantNodeName.ROOT,
"end": AssistantNodeName.END,
},
tools_node=AssistantNodeName.END,
)
.add_root(router=lambda state: AssistantNodeName.END)
# TRICKY: We need to set a checkpointer here because async tests create a new event loop.
.compile(checkpointer=DjangoCheckpointer())
)

View File

@@ -5,7 +5,7 @@ from braintrust import EvalCase
from posthog.schema import AssistantMessage, HumanMessage
from ee.hogai.graph import AssistantGraph
from ee.hogai.graph.graph import AssistantGraph
from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState
from ee.models.assistant import Conversation
@@ -60,19 +60,7 @@ def call_root(demo_org_team_user):
graph = (
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root(
{
# Some requests will go via Inkeep, and this is realistic! Inkeep needs to adhere to our intended style too
"search_documentation": AssistantNodeName.INKEEP_DOCS,
"root": AssistantNodeName.ROOT,
"billing": AssistantNodeName.END,
"insights": AssistantNodeName.END,
"insights_search": AssistantNodeName.END,
"session_summarization": AssistantNodeName.END,
"end": AssistantNodeName.END,
}
)
.add_inkeep_docs()
.add_root(lambda state: AssistantNodeName.END)
.compile()
)

View File

@@ -7,7 +7,7 @@ from langchain_core.runnables import RunnableConfig
from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.graph import AssistantGraph
from ee.hogai.graph.graph import AssistantGraph
from ee.hogai.graph.session_summaries.nodes import _SessionSearch
from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState
from ee.hogai.utils.yaml import load_yaml_from_raw_llm_content
@@ -22,16 +22,7 @@ def call_root_for_replay_sessions(demo_org_team_user):
graph = (
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root(
{
"insights": AssistantNodeName.END,
"search_documentation": AssistantNodeName.END,
"session_summarization": AssistantNodeName.END,
"insights_search": AssistantNodeName.END,
"root": AssistantNodeName.END,
"end": AssistantNodeName.END,
}
)
.add_root(lambda state: AssistantNodeName.END)
.compile(checkpointer=DjangoCheckpointer())
)

View File

@@ -126,9 +126,7 @@ def call_surveys_max_tool(demo_org_team_user, create_feature_flags):
"change": f"Create a survey based on these instructions: {instructions}",
"output": None,
}
graph = FeatureFlagLookupGraph(team=team, user=user, tool_call_id="test-tool-call-id").compile_full_graph(
checkpointer=DjangoCheckpointer()
)
graph = FeatureFlagLookupGraph(team=team, user=user).compile_full_graph(checkpointer=DjangoCheckpointer())
result = await graph.ainvoke(
graph_context,
config={

View File

@@ -15,7 +15,7 @@ from posthog.models.action.action import Action
from posthog.models.team.team import Team
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.graph import AssistantGraph
from ee.hogai.graph.graph import AssistantGraph
from ee.hogai.utils.types import AssistantNodeName, AssistantState
from ee.models.assistant import Conversation
@@ -29,14 +29,7 @@ def call_root_with_ui_context(demo_org_team_user):
graph = (
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root(
{
"insights": AssistantNodeName.END,
"docs": AssistantNodeName.END,
"root": AssistantNodeName.END,
"end": AssistantNodeName.END,
}
)
.add_root(lambda state: AssistantNodeName.END)
# TRICKY: We need to set a checkpointer here because async tests create a new event loop.
.compile(checkpointer=DjangoCheckpointer())
)

View File

@@ -81,7 +81,6 @@ async def eval_create_feature_flag_basic(pytestconfig, demo_org_team_user):
tool = await CreateFeatureFlagTool.create_tool_class(
team=team,
user=user,
tool_call_id="test-eval-call",
state=AssistantState(messages=[]),
config={
"configurable": {
@@ -148,7 +147,6 @@ async def eval_create_feature_flag_with_rollout(pytestconfig, demo_org_team_user
tool = await CreateFeatureFlagTool.create_tool_class(
team=team,
user=user,
tool_call_id="test-eval-call",
state=AssistantState(messages=[]),
config={
"configurable": {
@@ -241,7 +239,6 @@ async def eval_create_feature_flag_with_property_filters(pytestconfig, demo_org_
tool = await CreateFeatureFlagTool.create_tool_class(
team=team,
user=user,
tool_call_id="test-eval-call",
state=AssistantState(messages=[]),
config={
"configurable": {
@@ -336,7 +333,6 @@ async def eval_create_feature_flag_duplicate_handling(pytestconfig, demo_org_tea
tool = await CreateFeatureFlagTool.create_tool_class(
team=team,
user=user,
tool_call_id="test-eval-call",
state=AssistantState(messages=[]),
config={
"configurable": {

View File

@@ -1,5 +1,4 @@
import pytest
from unittest.mock import MagicMock, patch
from braintrust import EvalCase
from langchain_core.runnables import RunnableConfig
@@ -13,18 +12,7 @@ from ee.hogai.eval.scorers import SemanticSimilarity
from ee.models.assistant import Conversation
@pytest.fixture(autouse=True)
def mock_kafka_producer():
"""Mock Kafka producer to prevent Kafka errors in tests."""
with patch("posthog.kafka_client.client._KafkaProducer.produce") as mock_produce:
mock_future = MagicMock()
mock_produce.return_value = mock_future
yield
@pytest.mark.django_db
@patch("ee.hogai.graph.base.get_stream_writer", return_value=MagicMock())
async def eval_insights_addition(patch_get_stream_writer, pytestconfig, demo_org_team_user):
async def eval_insights_addition(pytestconfig, demo_org_team_user):
"""Test that adding insights to dashboard executes correctly."""
dashboard = await Dashboard.objects.acreate(

View File

@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
from posthog.schema import AssistantMessage, AssistantNavigateUrl, AssistantToolCall, FailureMessage, HumanMessage
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.graph import AssistantGraph
from ee.hogai.graph.graph import AssistantGraph
from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState
from ee.models.assistant import Conversation
@@ -48,16 +48,7 @@ def call_root(demo_org_team_user):
graph = (
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root(
{
"insights": AssistantNodeName.END,
"billing": AssistantNodeName.END,
"insights_search": AssistantNodeName.END,
"search_documentation": AssistantNodeName.END,
"root": AssistantNodeName.ROOT,
"end": AssistantNodeName.END,
}
)
.add_root(lambda state: AssistantNodeName.END)
# TRICKY: We need to set a checkpointer here because async tests create a new event loop.
.compile(checkpointer=DjangoCheckpointer())
)

View File

@@ -40,9 +40,9 @@ DUMMY_CURRENT_FILTERS = RevenueAnalyticsAssistantFilters(
@pytest.fixture
def call_filter_revenue_analytics(demo_org_team_user):
graph = RevenueAnalyticsFilterOptionsGraph(
demo_org_team_user[1], demo_org_team_user[2], tool_call_id="test-tool-call-id"
).compile_full_graph(checkpointer=DjangoCheckpointer())
graph = RevenueAnalyticsFilterOptionsGraph(demo_org_team_user[1], demo_org_team_user[2]).compile_full_graph(
checkpointer=DjangoCheckpointer()
)
async def callable(change: str) -> dict:
conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])

View File

@@ -48,9 +48,9 @@ DUMMY_CURRENT_FILTERS = MaxRecordingUniversalFilters(
@pytest.fixture
def call_search_session_recordings(demo_org_team_user):
graph = SessionReplayFilterOptionsGraph(
demo_org_team_user[1], demo_org_team_user[2], tool_call_id="test-tool-call-id"
).compile_full_graph(checkpointer=DjangoCheckpointer())
graph = SessionReplayFilterOptionsGraph(demo_org_team_user[1], demo_org_team_user[2]).compile_full_graph(
checkpointer=DjangoCheckpointer()
)
async def callable(change: str) -> dict:
conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])

View File

@@ -48,9 +48,9 @@ DUMMY_CURRENT_FILTERS = MaxRecordingUniversalFilters(
@pytest.fixture
def call_search_session_recordings(demo_org_team_user):
graph = SessionReplayFilterOptionsGraph(
demo_org_team_user[1], demo_org_team_user[2], tool_call_id="test-tool-call-id"
).compile_full_graph(checkpointer=DjangoCheckpointer())
graph = SessionReplayFilterOptionsGraph(demo_org_team_user[1], demo_org_team_user[2]).compile_full_graph(
checkpointer=DjangoCheckpointer()
)
async def callable(change: str) -> dict:
conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])

View File

@@ -18,7 +18,7 @@ from ee.hogai.eval.base import MaxPrivateEval
from ee.hogai.eval.offline.conftest import EvaluationContext, capture_score, get_eval_context
from ee.hogai.eval.schema import DatasetInput
from ee.hogai.eval.scorers.sql import SQLSemanticsCorrectness, SQLSyntaxCorrectness
from ee.hogai.graph import AssistantGraph
from ee.hogai.graph.graph import AssistantGraph
from ee.hogai.utils.helpers import find_last_message_of_type
from ee.hogai.utils.types import AssistantState
from ee.hogai.utils.warehouse import serialize_database_schema

View File

@@ -41,27 +41,29 @@ class ToolRelevance(ScorerWithPartial):
raise TypeError(f"Eval case expected must be an AssistantToolCall, not {type(expected)}")
if not isinstance(output, AssistantMessage):
raise TypeError(f"Eval case output must be an AssistantMessage, not {type(output)}")
if output.tool_calls and len(output.tool_calls) > 1:
raise ValueError("Parallel tool calls not supported by this scorer yet")
score = 0.0 # 0.0 to 1.0
if output.tool_calls and len(output.tool_calls) == 1:
tool_call = output.tool_calls[0]
# 0.5 point for getting the tool right
if tool_call.name == expected.name:
score += 0.5
if not expected.args:
score += 0.5 if not tool_call.args else 0 # If no args expected, only score for lack of args
else:
score_per_arg = 0.5 / len(expected.args)
for arg_name, expected_arg_value in expected.args.items():
if arg_name in self.semantic_similarity_args:
arg_similarity = AnswerSimilarity(model="text-embedding-3-small").eval(
output=tool_call.args.get(arg_name), expected=expected_arg_value
)
score += arg_similarity.score * score_per_arg
elif tool_call.args.get(arg_name) == expected_arg_value:
score += score_per_arg
return Score(name=self._name(), score=score)
best_score = 0.0 # 0.0 to 1.0
if output.tool_calls:
# Check all tool calls and return the best match
for tool_call in output.tool_calls:
score = 0.0
# 0.5 point for getting the tool right
if tool_call.name == expected.name:
score += 0.5
if not expected.args:
score += 0.5 if not tool_call.args else 0 # If no args expected, only score for lack of args
else:
score_per_arg = 0.5 / len(expected.args)
for arg_name, expected_arg_value in expected.args.items():
if arg_name in self.semantic_similarity_args:
arg_similarity = AnswerSimilarity(model="text-embedding-3-small").eval(
output=tool_call.args.get(arg_name), expected=expected_arg_value
)
score += arg_similarity.score * score_per_arg
elif tool_call.args.get(arg_name) == expected_arg_value:
score += score_per_arg
best_score = max(best_score, score)
return Score(name=self._name(), score=best_score)
class PlanAndQueryOutput(TypedDict, total=False):

View File

@@ -1,33 +0,0 @@
from .deep_research.graph import DeepResearchAssistantGraph
from .funnels.nodes import FunnelGeneratorNode
from .graph import AssistantGraph, InsightsAssistantGraph
from .inkeep_docs.nodes import InkeepDocsNode
from .insights.nodes import InsightSearchNode
from .memory.nodes import MemoryInitializerNode
from .query_executor.nodes import QueryExecutorNode
from .query_planner.nodes import QueryPlannerNode
from .rag.nodes import InsightRagContextNode
from .retention.nodes import RetentionGeneratorNode
from .root.nodes import RootNode, RootNodeTools
from .schema_generator.nodes import SchemaGeneratorNode
from .sql.nodes import SQLGeneratorNode
from .trends.nodes import TrendsGeneratorNode
__all__ = [
"FunnelGeneratorNode",
"InkeepDocsNode",
"MemoryInitializerNode",
"QueryExecutorNode",
"InsightRagContextNode",
"RetentionGeneratorNode",
"RootNode",
"RootNodeTools",
"SchemaGeneratorNode",
"SQLGeneratorNode",
"QueryPlannerNode",
"TrendsGeneratorNode",
"AssistantGraph",
"InsightsAssistantGraph",
"InsightSearchNode",
"DeepResearchAssistantGraph",
]

View File

@@ -0,0 +1,9 @@
from .graph import BaseAssistantGraph, global_checkpointer
from .node import AssistantNode, BaseAssistantNode
__all__ = [
"BaseAssistantNode",
"AssistantNode",
"BaseAssistantGraph",
"global_checkpointer",
]

View File

@@ -0,0 +1,22 @@
from __future__ import annotations

import contextvars
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported lazily: only needed for annotations, keeps this leaf module free
    # of runtime project imports (avoids circular-import risk).
    from ee.hogai.utils.types.base import NodePath

# Tracks the chain of graph/node names leading to the currently executing node.
# Unset until the first `set_node_path` call in the current execution context.
node_path_context: contextvars.ContextVar[tuple[NodePath, ...]] = contextvars.ContextVar("node_path_context")


@contextmanager
def set_node_path(node_path: tuple[NodePath, ...]) -> Iterator[None]:
    """Temporarily set the current node path for the enclosed code.

    The previous value (or unset state) is restored on exit — even when the
    body raises — so nested graphs/nodes can safely extend the path.
    """
    token = node_path_context.set(node_path)
    try:
        yield
    finally:
        node_path_context.reset(token)


def get_node_path() -> tuple[NodePath, ...] | None:
    """Return the current node path, or ``None`` when no path has been set."""
    # ContextVar.get(default) avoids the manual try/except LookupError dance.
    return node_path_context.get(None)

View File

@@ -0,0 +1,88 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
from langgraph.graph.state import StateGraph
from posthog.models import Team, User
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.utils.types.base import AssistantGraphName, AssistantNodeName, NodePath, PartialStateType, StateType
from .context import get_node_path, set_node_path
from .node import BaseAssistantNode
if TYPE_CHECKING:
from ee.hogai.utils.types.composed import MaxNodeName
# Base checkpointer for all graphs
global_checkpointer = DjangoCheckpointer()

T = TypeVar("T")


def with_node_path(func: Callable[..., T]) -> Callable[..., T]:
    """Decorator that executes *func* with the owner's node path set in the context.

    Intended for methods of objects exposing a ``node_path`` property. While the
    wrapped method runs, ``get_node_path()`` resolves to ``self.node_path``, so any
    node or graph constructed inside inherits that path as its parent prefix.

    NOTE(review): for an async method this wrapper sets the path only while the
    coroutine object is *created*, not while it is awaited — confirm that is the
    intended behavior for async callers.
    """

    @wraps(func)
    def wrapper(self, *args: Any, **kwargs: Any) -> T:
        # Scope the path to this call; the previous value is restored afterwards.
        with set_node_path(self.node_path):
            return func(self, *args, **kwargs)

    return wrapper
class BaseAssistantGraph(Generic[StateType, PartialStateType], ABC):
    """Abstract builder for assistant LangGraph state graphs.

    Subclasses declare their state model (``state_type``) and identity
    (``graph_name``); the builder records a ``node_path`` that extends whatever
    path is active in the context at construction time, so nested graphs get a
    hierarchical path like ``(outer_graph, outer_node, this_graph)``.
    """

    _team: Team
    _user: User
    _graph: StateGraph
    _node_path: tuple[NodePath, ...]

    def __init__(
        self,
        team: Team,
        user: User,
    ):
        self._team = team
        self._user = user
        self._has_start_node = False
        self._graph = StateGraph(self.state_type)
        # Extend the currently active path (if any) with this graph's own name.
        # The path is fixed at construction time.
        self._node_path = (*(get_node_path() or ()), NodePath(name=self.graph_name.value))

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Wrap all public methods with the node path context
        # NOTE(review): `property` objects are not callable, so the explicit name
        # exclusions below look like belt-and-braces; wrapping happens once per
        # defining subclass, at class-creation time — confirm no double-wrapping
        # is intended for overridden methods.
        for name, method in cls.__dict__.items():
            if callable(method) and not name.startswith("_") and name not in ("graph_name", "state_type", "node_path"):
                setattr(cls, name, with_node_path(method))

    @property
    @abstractmethod
    def state_type(self) -> type[StateType]: ...

    @property
    @abstractmethod
    def graph_name(self) -> AssistantGraphName: ...

    @property
    def node_path(self) -> tuple[NodePath, ...]:
        """The hierarchical path of this graph, computed at construction time."""
        return self._node_path

    def add_edge(self, from_node: "MaxNodeName", to_node: "MaxNodeName"):
        """Add an edge; remembers whether START was ever connected (required by compile)."""
        if from_node == AssistantNodeName.START:
            self._has_start_node = True
        self._graph.add_edge(from_node, to_node)
        return self

    def add_node(self, node: "MaxNodeName", action: BaseAssistantNode[StateType, PartialStateType]):
        """Register a node action under the given name. Returns self for chaining."""
        self._graph.add_node(node, action)
        return self

    def compile(self, checkpointer: DjangoCheckpointer | None | Literal[False] = None):
        """Compile the graph.

        ``checkpointer`` semantics: ``None`` (default) falls back to the shared
        ``global_checkpointer``; ``False`` explicitly disables checkpointing and
        is forwarded as-is to LangGraph.

        Raises:
            ValueError: if no edge from the START node was ever added.
        """
        if not self._has_start_node:
            raise ValueError("Start node not added to the graph")
        # TRICKY: We check `is not None` because False has a special meaning of "no checkpointer", which we want to pass on
        compiled_graph = self._graph.compile(
            checkpointer=checkpointer if checkpointer is not None else global_checkpointer
        )
        return compiled_graph

View File

@@ -1,49 +1,45 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from abc import ABC
from typing import Generic
from uuid import UUID
from django.conf import settings
from langchain_core.runnables import RunnableConfig
from langgraph.config import get_stream_writer
from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage
from posthog.schema import HumanMessage
from posthog.models import Team
from posthog.models.user import User
from posthog.models import Team, User
from posthog.sync import database_sync_to_async
from ee.hogai.context import AssistantContextManager
from ee.hogai.graph.mixins import AssistantContextMixin
from ee.hogai.utils.dispatcher import AssistantDispatcher
from ee.hogai.graph.base.context import get_node_path, set_node_path
from ee.hogai.graph.mixins import AssistantContextMixin, AssistantDispatcherMixin
from ee.hogai.utils.exceptions import GenerationCanceled
from ee.hogai.utils.helpers import find_start_message
from ee.hogai.utils.types import (
AssistantMessageUnion,
from ee.hogai.utils.types.base import (
AssistantState,
NodeEndAction,
NodePath,
NodeStartAction,
PartialAssistantState,
PartialStateType,
StateType,
)
from ee.hogai.utils.types.composed import MaxNodeName
from ee.models import Conversation
class BaseAssistantNode(Generic[StateType, PartialStateType], AssistantContextMixin, ABC):
class BaseAssistantNode(Generic[StateType, PartialStateType], AssistantContextMixin, AssistantDispatcherMixin, ABC):
_config: RunnableConfig | None = None
_context_manager: AssistantContextManager | None = None
_dispatcher: AssistantDispatcher | None = None
_parent_tool_call_id: str | None = None
_node_path: tuple[NodePath, ...]
def __init__(self, team: Team, user: User):
def __init__(self, team: Team, user: User, node_path: tuple[NodePath, ...] | None = None):
self._team = team
self._user = user
@property
@abstractmethod
def node_name(self) -> MaxNodeName:
raise NotImplementedError
if node_path is None:
self._node_path = (*(get_node_path() or ()), NodePath(name=self.node_name))
else:
self._node_path = node_path
async def __call__(self, state: StateType, config: RunnableConfig) -> PartialStateType | None:
"""
@@ -54,25 +50,19 @@ class BaseAssistantNode(Generic[StateType, PartialStateType], AssistantContextMi
self._dispatcher = None
self._config = config
if isinstance(state, AssistantState) and state.root_tool_call_id:
# NOTE: we set the parent tool call id as the root tool call id
# This will be deprecated once all tools become MaxTools and are removed from the graph
self._parent_tool_call_id = state.root_tool_call_id
self.dispatcher.node_start()
self.dispatcher.dispatch(NodeStartAction())
thread_id = (config.get("configurable") or {}).get("thread_id")
if thread_id and await self._is_conversation_cancelled(thread_id):
raise GenerationCanceled
try:
new_state = await self.arun(state, config)
new_state = await self._arun_with_context(state, config)
except NotImplementedError:
new_state = await database_sync_to_async(self.run, thread_sensitive=False)(state, config)
new_state = await database_sync_to_async(self._run_with_context, thread_sensitive=False)(state, config)
self.dispatcher.dispatch(NodeEndAction(state=new_state))
if new_state is not None and (messages := getattr(new_state, "messages", [])):
for message in messages:
self.dispatcher.message(message)
return new_state
def run(self, state: StateType, config: RunnableConfig) -> PartialStateType | None:
@@ -82,6 +72,14 @@ class BaseAssistantNode(Generic[StateType, PartialStateType], AssistantContextMi
async def arun(self, state: StateType, config: RunnableConfig) -> PartialStateType | None:
raise NotImplementedError
def _run_with_context(self, state: StateType, config: RunnableConfig) -> PartialStateType | None:
with set_node_path(self.node_path):
return self.run(state, config)
async def _arun_with_context(self, state: StateType, config: RunnableConfig) -> PartialStateType | None:
with set_node_path(self.node_path):
return await self.arun(state, config)
@property
def context_manager(self) -> AssistantContextManager:
if self._context_manager is None:
@@ -97,26 +95,13 @@ class BaseAssistantNode(Generic[StateType, PartialStateType], AssistantContextMi
return self._context_manager
@property
def dispatcher(self) -> AssistantDispatcher:
"""Create a dispatcher for this node"""
if self._dispatcher:
return self._dispatcher
# Set writer from LangGraph context
try:
writer = get_stream_writer()
except RuntimeError:
# Not in streaming context (e.g., testing)
# Use noop writer
def noop(*_args, **_kwargs):
pass
writer = noop
self._dispatcher = AssistantDispatcher(
writer, node_name=self.node_name, parent_tool_call_id=self._parent_tool_call_id
)
return self._dispatcher
def node_name(self) -> str:
config_name: str | None = None
if self._config:
config_name = self._config["metadata"].get("langgraph_node")
if config_name is not None:
config_name = str(config_name)
return config_name or self.__class__.__name__
async def _is_conversation_cancelled(self, conversation_id: UUID) -> bool:
conversation = await self._aget_conversation(conversation_id)
@@ -124,15 +109,6 @@ class BaseAssistantNode(Generic[StateType, PartialStateType], AssistantContextMi
raise ValueError(f"Conversation {conversation_id} not found")
return conversation.status == Conversation.Status.CANCELING
def _get_tool_call(self, messages: Sequence[AssistantMessageUnion], tool_call_id: str) -> AssistantToolCall:
for message in reversed(messages):
if not isinstance(message, AssistantMessage) or not message.tool_calls:
continue
for tool_call in message.tool_calls:
if tool_call.id == tool_call_id:
return tool_call
raise ValueError(f"Tool call {tool_call_id} not found in state")
def _is_first_turn(self, state: AssistantState) -> bool:
last_message = state.messages[-1]
if isinstance(last_message, HumanMessage):

View File

@@ -0,0 +1,47 @@
from posthog.test.base import BaseTest
from langgraph.checkpoint.memory import InMemorySaver
from ee.hogai.graph.base import AssistantNode
from ee.hogai.graph.base.graph import BaseAssistantGraph
from ee.hogai.utils.types import AssistantNodeName, AssistantState, PartialAssistantState
from ee.hogai.utils.types.base import AssistantGraphName
from ee.models import Conversation
class TestAssistantGraph(BaseTest):
    async def test_pydantic_state_resets_with_none(self):
        """When a None field is set, it should be reset to None."""

        # Minimal concrete graph: only the two abstract properties are provided.
        class TestAssistantGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
            @property
            def state_type(self) -> type[AssistantState]:
                return AssistantState

            @property
            def graph_name(self) -> AssistantGraphName:
                return AssistantGraphName.ASSISTANT

        graph = TestAssistantGraph(self.team, self.user)

        # Node that explicitly returns start_id=None in its partial state update.
        class TestNode(AssistantNode):
            @property
            def node_name(self):
                return AssistantNodeName.ROOT

            async def arun(self, state, config):
                return PartialAssistantState(start_id=None)

        compiled_graph = (
            graph.add_node(AssistantNodeName.ROOT, TestNode(self.team, self.user))
            .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
            .add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
            .compile(checkpointer=InMemorySaver())
        )
        conversation = await Conversation.objects.acreate(team=self.team, user=self.user)
        state = await compiled_graph.ainvoke(
            AssistantState(messages=[], graph_status="resumed", start_id=None),
            {"configurable": {"thread_id": conversation.id}},
        )
        # start_id must come back as None (explicitly reset), while the untouched
        # graph_status field is preserved.
        self.assertEqual(state["start_id"], None)
        self.assertEqual(state["graph_status"], "resumed")

View File

@@ -0,0 +1,502 @@
from posthog.test.base import BaseTest
from langchain_core.runnables import RunnableConfig
from ee.hogai.graph.base import AssistantNode
from ee.hogai.graph.base.context import get_node_path
from ee.hogai.graph.base.graph import BaseAssistantGraph
from ee.hogai.utils.types import AssistantNodeName, AssistantState, PartialAssistantState
from ee.hogai.utils.types.base import AssistantGraphName, NodePath
from ee.models import Conversation
class TestNodePath(BaseTest):
"""
Tests for node_path functionality across sync/async methods and graph compositions.
"""
    async def test_graph_to_async_node_has_two_elements(self):
        """Graph -> Node (async) should have path: [graph, node]"""
        captured_path = None

        class TestNode(AssistantNode):
            async def arun(self, state, config):
                nonlocal captured_path
                # Capture the context-variable path as seen from inside the node.
                captured_path = get_node_path()
                return None

        class TestGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
            @property
            def state_type(self) -> type[AssistantState]:
                return AssistantState

            @property
            def graph_name(self) -> AssistantGraphName:
                return AssistantGraphName.ASSISTANT

            def setup(self):
                node = TestNode(self._team, self._user)
                self.add_node(AssistantNodeName.ROOT, node)
                self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
                self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
                return self

        graph = TestGraph(self.team, self.user)
        compiled = graph.setup().compile(checkpointer=False)
        conversation = await Conversation.objects.acreate(team=self.team, user=self.user)
        await compiled.ainvoke(AssistantState(messages=[]), {"configurable": {"thread_id": conversation.id}})

        assert captured_path is not None
        self.assertEqual(len(captured_path), 2)
        self.assertEqual(captured_path[0].name, AssistantGraphName.ASSISTANT.value)
        # Node name is determined at init time, so it's the class name
        self.assertEqual(captured_path[1].name, "TestNode")
    async def test_graph_to_sync_node_has_two_elements(self):
        """Graph -> Node (sync) should have path: [graph, node]"""
        captured_path = None

        class TestNode(AssistantNode):
            def run(self, state, config):
                nonlocal captured_path
                # Capture the context-variable path as seen from inside the sync node.
                captured_path = get_node_path()
                return None

        class TestGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
            @property
            def state_type(self) -> type[AssistantState]:
                return AssistantState

            @property
            def graph_name(self) -> AssistantGraphName:
                return AssistantGraphName.ASSISTANT

            def setup(self):
                node = TestNode(self._team, self._user)
                self.add_node(AssistantNodeName.ROOT, node)
                self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
                self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
                return self

        graph = TestGraph(self.team, self.user)
        compiled = graph.setup().compile(checkpointer=False)
        conversation = await Conversation.objects.acreate(team=self.team, user=self.user)
        await compiled.ainvoke(AssistantState(messages=[]), {"configurable": {"thread_id": conversation.id}})

        # Sync nodes must see the same two-element path as async ones.
        assert captured_path is not None
        self.assertEqual(len(captured_path), 2)
        self.assertEqual(captured_path[0].name, AssistantGraphName.ASSISTANT.value)
        self.assertEqual(captured_path[1].name, "TestNode")
    async def test_graph_to_node_to_async_node_has_three_elements(self):
        """Graph -> Node -> Node (async) should have path: [graph, node, node]"""
        captured_paths = []

        class SecondNode(AssistantNode):
            async def arun(self, state, config):
                nonlocal captured_paths
                captured_paths.append(get_node_path())
                return None

        class FirstNode(AssistantNode):
            async def arun(self, state, config):
                nonlocal captured_paths
                captured_paths.append(get_node_path())
                # Call second node
                # Invoking the node object (__call__) — not .arun() directly — is what
                # extends the path for the nested node.
                second_node = SecondNode(self._team, self._user)
                await second_node(state, config)
                return None

        class TestGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
            @property
            def state_type(self) -> type[AssistantState]:
                return AssistantState

            @property
            def graph_name(self) -> AssistantGraphName:
                return AssistantGraphName.ASSISTANT

            def setup(self):
                node = FirstNode(self._team, self._user)
                self.add_node(AssistantNodeName.ROOT, node)
                self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
                self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
                return self

        graph = TestGraph(self.team, self.user)
        compiled = graph.setup().compile(checkpointer=False)
        conversation = await Conversation.objects.acreate(team=self.team, user=self.user)
        await compiled.ainvoke(AssistantState(messages=[]), {"configurable": {"thread_id": conversation.id}})

        # First node: [graph, node]
        assert captured_paths[0] is not None
        self.assertEqual(len(captured_paths[0]), 2)
        self.assertEqual(captured_paths[0][0].name, AssistantGraphName.ASSISTANT.value)
        self.assertEqual(captured_paths[0][1].name, "FirstNode")
        # Second node: [graph, node, node]
        assert captured_paths[1] is not None
        self.assertEqual(len(captured_paths[1]), 3)
        self.assertEqual(captured_paths[1][0].name, AssistantGraphName.ASSISTANT.value)
        self.assertEqual(captured_paths[1][1].name, "FirstNode")
        self.assertEqual(captured_paths[1][2].name, "SecondNode")
    async def test_graph_to_node_to_sync_node_has_three_elements(self):
        """Graph -> Node -> Node (sync) - calling .run() directly doesn't extend path"""
        captured_paths = []

        class SecondNode(AssistantNode):
            def run(self, state, config):
                nonlocal captured_paths
                captured_paths.append(get_node_path())
                return None

        class FirstNode(AssistantNode):
            def run(self, state, config):
                nonlocal captured_paths
                captured_paths.append(get_node_path())
                # Call second node - note: calling run() directly bypasses context setting,
                # so the second node won't have the proper path. This tests that direct .run()
                # calls don't propagate context properly.
                second_node = SecondNode(self._team, self._user)
                second_node.run(state, config)
                return None

        class TestGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
            @property
            def state_type(self) -> type[AssistantState]:
                return AssistantState

            @property
            def graph_name(self) -> AssistantGraphName:
                return AssistantGraphName.ASSISTANT

            def setup(self):
                node = FirstNode(self._team, self._user)
                self.add_node(AssistantNodeName.ROOT, node)
                self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
                self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
                return self

        graph = TestGraph(self.team, self.user)
        compiled = graph.setup().compile(checkpointer=False)
        conversation = await Conversation.objects.acreate(team=self.team, user=self.user)
        await compiled.ainvoke(AssistantState(messages=[]), {"configurable": {"thread_id": conversation.id}})

        # First node: [graph, node]
        assert captured_paths[0] is not None
        self.assertEqual(len(captured_paths[0]), 2)
        self.assertEqual(captured_paths[0][0].name, AssistantGraphName.ASSISTANT.value)
        self.assertEqual(captured_paths[0][1].name, "FirstNode")
        # Second node: [graph, node] - initialized within FirstNode's context, so gets same path
        assert captured_paths[1] is not None
        self.assertEqual(len(captured_paths[1]), 2)
        self.assertEqual(captured_paths[1][0].name, AssistantGraphName.ASSISTANT.value)
        self.assertEqual(captured_paths[1][1].name, "FirstNode")  # Same as first because initialized in same context
    async def test_graph_to_node_to_graph_to_node_has_four_elements(self):
        """Graph -> Node -> Graph -> Node should have path: [graph, node, graph, node]"""
        captured_paths = []

        class InnerNode(AssistantNode):
            async def arun(self, state, config):
                nonlocal captured_paths
                captured_paths.append(get_node_path())
                return None

        class InnerGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
            @property
            def state_type(self) -> type[AssistantState]:
                return AssistantState

            @property
            def graph_name(self) -> AssistantGraphName:
                return AssistantGraphName.INSIGHTS

            def setup(self):
                node = InnerNode(self._team, self._user)
                self.add_node(AssistantNodeName.TRENDS_GENERATOR, node)
                self.add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_GENERATOR)
                self.add_edge(AssistantNodeName.TRENDS_GENERATOR, AssistantNodeName.END)
                return self

        class OuterNode(AssistantNode):
            async def arun(self, state, config):
                nonlocal captured_paths
                captured_paths.append(get_node_path())
                # Call inner graph
                # Constructed inside the outer node's execution, so the inner graph's
                # path is prefixed with [ASSISTANT, OuterNode].
                inner_graph = InnerGraph(self._team, self._user)
                compiled_inner = inner_graph.setup().compile(checkpointer=False)
                await compiled_inner.ainvoke(state, config)
                return None

        class OuterGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
            @property
            def state_type(self) -> type[AssistantState]:
                return AssistantState

            @property
            def graph_name(self) -> AssistantGraphName:
                return AssistantGraphName.ASSISTANT

            def setup(self):
                node = OuterNode(self._team, self._user)
                self.add_node(AssistantNodeName.ROOT, node)
                self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
                self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
                return self

        outer_graph = OuterGraph(self.team, self.user)
        compiled = outer_graph.setup().compile(checkpointer=False)
        conversation = await Conversation.objects.acreate(team=self.team, user=self.user)
        await compiled.ainvoke(AssistantState(messages=[]), {"configurable": {"thread_id": conversation.id}})

        # Outer node: [graph, node]
        assert captured_paths[0] is not None
        self.assertEqual(len(captured_paths[0]), 2)
        self.assertEqual(captured_paths[0][0].name, AssistantGraphName.ASSISTANT.value)
        self.assertEqual(captured_paths[0][1].name, "OuterNode")
        # Inner node: [graph, node, graph, node]
        assert captured_paths[1] is not None
        self.assertEqual(len(captured_paths[1]), 4)
        self.assertEqual(captured_paths[1][0].name, AssistantGraphName.ASSISTANT.value)
        self.assertEqual(captured_paths[1][1].name, "OuterNode")
        self.assertEqual(captured_paths[1][2].name, AssistantGraphName.INSIGHTS.value)
        self.assertEqual(captured_paths[1][3].name, "InnerNode")
    async def test_graph_to_graph_to_node_has_three_elements(self):
        """Graph -> Graph -> Node should have path: [graph, graph, node]"""
        captured_path = None

        class InnerNode(AssistantNode):
            async def arun(self, state, config):
                nonlocal captured_path
                captured_path = get_node_path()
                return None

        class InnerGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
            @property
            def state_type(self) -> type[AssistantState]:
                return AssistantState

            @property
            def graph_name(self) -> AssistantGraphName:
                return AssistantGraphName.INSIGHTS

            def setup(self):
                node = InnerNode(self._team, self._user)
                self.add_node(AssistantNodeName.TRENDS_GENERATOR, node)
                self.add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_GENERATOR)
                self.add_edge(AssistantNodeName.TRENDS_GENERATOR, AssistantNodeName.END)
                return self

        class OuterGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
            @property
            def state_type(self) -> type[AssistantState]:
                return AssistantState

            @property
            def graph_name(self) -> AssistantGraphName:
                return AssistantGraphName.ASSISTANT

            async def invoke_inner_graph(self, state, config):
                # NOTE(review): the inner graph is constructed inside this coroutine,
                # which runs after the public-method wrapper has already reset the
                # path context — presumably why the outer prefix is absent below.
                inner_graph = InnerGraph(self._team, self._user)
                compiled_inner = inner_graph.setup().compile(checkpointer=False)
                await compiled_inner.ainvoke(state, config)

        outer_graph = OuterGraph(self.team, self.user)
        conversation = await Conversation.objects.acreate(team=self.team, user=self.user)
        await outer_graph.invoke_inner_graph(
            AssistantState(messages=[]), RunnableConfig(configurable={"thread_id": conversation.id})
        )

        # Inner node: [graph, node] - outer graph context is not propagated when calling graph methods directly
        assert captured_path is not None
        self.assertEqual(len(captured_path), 2)
        self.assertEqual(captured_path[0].name, AssistantGraphName.INSIGHTS.value)
        self.assertEqual(captured_path[1].name, "InnerNode")
    async def test_node_path_preserved_across_async_and_sync_methods(self):
        """Test that calling .run() directly doesn't extend path"""
        captured_paths = []

        class SyncNode(AssistantNode):
            def run(self, state, config):
                nonlocal captured_paths
                captured_paths.append(("sync", get_node_path()))
                return None

        class AsyncNode(AssistantNode):
            async def arun(self, state, config):
                nonlocal captured_paths
                captured_paths.append(("async", get_node_path()))
                # Call sync node - calling .run() directly bypasses context setting
                sync_node = SyncNode(self._team, self._user)
                sync_node.run(state, config)
                return None

        class TestGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
            @property
            def state_type(self) -> type[AssistantState]:
                return AssistantState

            @property
            def graph_name(self) -> AssistantGraphName:
                return AssistantGraphName.ASSISTANT

            def setup(self):
                node = AsyncNode(self._team, self._user)
                self.add_node(AssistantNodeName.ROOT, node)
                self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
                self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
                return self

        graph = TestGraph(self.team, self.user)
        compiled = graph.setup().compile(checkpointer=False)
        conversation = await Conversation.objects.acreate(team=self.team, user=self.user)
        await compiled.ainvoke(AssistantState(messages=[]), {"configurable": {"thread_id": conversation.id}})

        # Async node: [graph, node]
        self.assertEqual(captured_paths[0][0], "async")
        assert captured_paths[0][1] is not None
        self.assertEqual(len(captured_paths[0][1]), 2)
        self.assertEqual(captured_paths[0][1][0].name, AssistantGraphName.ASSISTANT.value)
        self.assertEqual(captured_paths[0][1][1].name, "AsyncNode")
        # Sync node: [graph, node] - initialized within AsyncNode's context, so gets same path
        self.assertEqual(captured_paths[1][0], "sync")
        assert captured_paths[1][1] is not None
        self.assertEqual(len(captured_paths[1][1]), 2)
        self.assertEqual(captured_paths[1][1][0].name, AssistantGraphName.ASSISTANT.value)
        self.assertEqual(captured_paths[1][1][1].name, "AsyncNode")  # Same as async because initialized in same context
    def test_node_path_with_explicit_node_path_parameter(self):
        """Test that explicitly passing node_path overrides default behavior"""
        custom_path = (NodePath(name="custom_graph"), NodePath(name="custom_node"))

        class TestNode(AssistantNode):
            def run(self, state, config):
                return None

        # An explicit node_path must be stored verbatim, ignoring the ambient context.
        node = TestNode(self.team, self.user, node_path=custom_path)
        self.assertEqual(len(node._node_path), 2)
        self.assertEqual(node._node_path[0].name, "custom_graph")
        self.assertEqual(node._node_path[1].name, "custom_node")
async def test_multiple_nested_graphs(self):
    """Deeply nested composition (graph -> node -> graph -> node -> graph -> node)
    accumulates one (graph, node) pair per nesting level in the node path."""
    recorded = []

    class Level3Node(AssistantNode):
        async def arun(self, state, config):
            nonlocal recorded
            recorded.append(("level3", get_node_path()))
            return None

    class Level3Graph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
        @property
        def state_type(self) -> type[AssistantState]:
            return AssistantState

        @property
        def graph_name(self) -> AssistantGraphName:
            return AssistantGraphName.TAXONOMY

        def setup(self):
            self.add_node(AssistantNodeName.FUNNEL_GENERATOR, Level3Node(self._team, self._user))
            self.add_edge(AssistantNodeName.START, AssistantNodeName.FUNNEL_GENERATOR)
            self.add_edge(AssistantNodeName.FUNNEL_GENERATOR, AssistantNodeName.END)
            return self

    class Level2Node(AssistantNode):
        async def arun(self, state, config):
            nonlocal recorded
            recorded.append(("level2", get_node_path()))
            # Descend one more level through a freshly compiled subgraph.
            subgraph = Level3Graph(self._team, self._user).setup().compile(checkpointer=False)
            await subgraph.ainvoke(state, config)
            return None

    class Level2Graph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
        @property
        def state_type(self) -> type[AssistantState]:
            return AssistantState

        @property
        def graph_name(self) -> AssistantGraphName:
            return AssistantGraphName.INSIGHTS

        def setup(self):
            self.add_node(AssistantNodeName.TRENDS_GENERATOR, Level2Node(self._team, self._user))
            self.add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_GENERATOR)
            self.add_edge(AssistantNodeName.TRENDS_GENERATOR, AssistantNodeName.END)
            return self

    class Level1Node(AssistantNode):
        async def arun(self, state, config):
            nonlocal recorded
            recorded.append(("level1", get_node_path()))
            # Descend into the mid-level graph.
            subgraph = Level2Graph(self._team, self._user).setup().compile(checkpointer=False)
            await subgraph.ainvoke(state, config)
            return None

    class Level1Graph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
        @property
        def state_type(self) -> type[AssistantState]:
            return AssistantState

        @property
        def graph_name(self) -> AssistantGraphName:
            return AssistantGraphName.ASSISTANT

        def setup(self):
            self.add_node(AssistantNodeName.ROOT, Level1Node(self._team, self._user))
            self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
            self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
            return self

    root = Level1Graph(self.team, self.user).setup().compile(checkpointer=False)
    conversation = await Conversation.objects.acreate(team=self.team, user=self.user)
    await root.ainvoke(AssistantState(messages=[]), {"configurable": {"thread_id": conversation.id}})

    # Each nesting level appends its own [graph, node] pair to the inherited path.
    expected_paths = [
        [AssistantGraphName.ASSISTANT.value, "Level1Node"],
        [AssistantGraphName.ASSISTANT.value, "Level1Node", AssistantGraphName.INSIGHTS.value, "Level2Node"],
        [
            AssistantGraphName.ASSISTANT.value,
            "Level1Node",
            AssistantGraphName.INSIGHTS.value,
            "Level2Node",
            AssistantGraphName.TAXONOMY.value,
            "Level3Node",
        ],
    ]
    for level, expected_names in enumerate(expected_paths):
        path = recorded[level][1]
        assert path is not None
        self.assertEqual([segment.name for segment in path], expected_names)

View File

@@ -89,13 +89,6 @@ class DashboardCreationNode(AssistantNode):
def _get_found_insight_count(self, queries_metadata: dict[str, QueryMetadata]) -> int:
return sum(len(query.found_insight_ids) for query in queries_metadata.values())
def _dispatch_update_message(self, content: str) -> None:
self.dispatcher.message(
AssistantMessage(
content=content,
)
)
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
dashboard_name = (
state.dashboard_name[:50] if state.dashboard_name else "Analytics Dashboard"
@@ -117,18 +110,18 @@ class DashboardCreationNode(AssistantNode):
for i, query in enumerate(state.search_insights_queries)
}
self._dispatch_update_message(f"Searching for {pluralize(len(state.search_insights_queries), 'insight')}")
self.dispatcher.update(f"Searching for {pluralize(len(state.search_insights_queries), 'insight')}")
result = await self._search_insights(result, config)
self._dispatch_update_message(f"Found {pluralize(self._get_found_insight_count(result), 'insight')}")
self.dispatcher.update(f"Found {pluralize(self._get_found_insight_count(result), 'insight')}")
left_to_create = {
query_id: result[query_id].query for query_id in result.keys() if not result[query_id].found_insight_ids
}
if left_to_create:
self._dispatch_update_message(f"Will create {pluralize(len(left_to_create), 'insight')}")
self.dispatcher.update(f"Will create {pluralize(len(left_to_create), 'insight')}")
result = await self._create_insights(left_to_create, result, config)
@@ -187,9 +180,7 @@ class DashboardCreationNode(AssistantNode):
message = AssistantMessage(content="", id=str(uuid4()), tool_calls=tool_calls)
executor = DashboardCreationExecutorNode(self._team, self._user)
result = await executor.arun(
AssistantState(messages=[message], root_tool_call_id=self._parent_tool_call_id), config
)
result = await executor.arun(AssistantState(messages=[message], root_tool_call_id=self.tool_call_id), config)
query_metadata = await self._process_insight_creation_results(tool_calls, result.task_results, query_metadata)
@@ -211,9 +202,7 @@ class DashboardCreationNode(AssistantNode):
message = AssistantMessage(content="", id=str(uuid4()), tool_calls=tool_calls)
executor = DashboardCreationExecutorNode(self._team, self._user)
result = await executor.arun(
AssistantState(messages=[message], root_tool_call_id=self._parent_tool_call_id), config
)
result = await executor.arun(AssistantState(messages=[message], root_tool_call_id=self.tool_call_id), config)
final_task_executor_state = BaseStateWithTasks.model_validate(result)
for task_result in final_task_executor_state.task_results:
@@ -307,7 +296,7 @@ class DashboardCreationNode(AssistantNode):
self, dashboard_name: str, insights: set[int], dashboard_id: int | None = None
) -> tuple[Dashboard, list[Insight]]:
"""Create a dashboard and add the insights to it."""
self._dispatch_update_message("Saving your dashboard")
self.dispatcher.update("Saving your dashboard")
@database_sync_to_async
@transaction.atomic

View File

@@ -1,5 +1,5 @@
import pytest
from unittest import TestCase
from posthog.test.base import BaseTest
from unittest.mock import AsyncMock, MagicMock, patch
from langchain_core.runnables import RunnableConfig
@@ -11,10 +11,17 @@ from posthog.models import Dashboard, Insight, Team, User
from ee.hogai.graph.dashboards.nodes import DashboardCreationExecutorNode, DashboardCreationNode, QueryMetadata
from ee.hogai.utils.helpers import build_dashboard_url, build_insight_url
from ee.hogai.utils.types import AssistantState, PartialAssistantState
from ee.hogai.utils.types.base import BaseStateWithTasks, InsightArtifact, InsightQuery, TaskResult
from ee.hogai.utils.types.base import (
AssistantNodeName,
BaseStateWithTasks,
InsightArtifact,
InsightQuery,
NodePath,
TaskResult,
)
class TestQueryMetadata(TestCase):
class TestQueryMetadata(BaseTest):
def test_query_metadata_initialization(self):
"""Test QueryMetadata initialization with all fields."""
query = InsightQuery(name="Test Query", description="Test Description")
@@ -33,21 +40,30 @@ class TestQueryMetadata(TestCase):
self.assertEqual(metadata.query, query)
class TestDashboardCreationExecutorNode:
@pytest.fixture(autouse=True)
def setup_method(self):
class TestDashboardCreationExecutorNode(BaseTest):
def setUp(self):
super().setUp()
self.mock_team = MagicMock(spec=Team)
self.mock_team.id = 1
self.mock_user = MagicMock(spec=User)
self.mock_user.id = 1
self.node = DashboardCreationExecutorNode(self.mock_team, self.mock_user)
self.node = DashboardCreationExecutorNode(
self.mock_team,
self.mock_user,
(
NodePath(
name=AssistantNodeName.DASHBOARD_CREATION_EXECUTOR.value,
message_id="test_message_id",
tool_call_id="test_tool_call_id",
),
),
)
def test_initialization(self):
"""Test node initialization."""
assert self.node._team == self.mock_team
assert self.node._user == self.mock_user
@pytest.mark.asyncio
async def test_aget_input_tuples_search_insights(self):
"""Test _aget_input_tuples for search_insights tasks."""
tool_calls = [
@@ -62,7 +78,6 @@ class TestDashboardCreationExecutorNode:
assert task.name == "search_insights"
assert callable_func == self.node._execute_search_insights
@pytest.mark.asyncio
async def test_aget_input_tuples_create_insight(self):
"""Test _aget_input_tuples for create_insight tasks."""
tool_calls = [AssistantToolCall(id="task_1", name="create_insight", args={"query_description": "Test prompt"})]
@@ -75,7 +90,6 @@ class TestDashboardCreationExecutorNode:
assert task.name == "create_insight"
assert callable_func == self.node._execute_create_insight
@pytest.mark.asyncio
async def test_aget_input_tuples_unsupported_task(self):
"""Test _aget_input_tuples raises error for unsupported task type."""
tool_calls = [AssistantToolCall(id="task_1", name="unsupported_type", args={"query": "Test prompt"})]
@@ -85,7 +99,6 @@ class TestDashboardCreationExecutorNode:
assert "Unsupported task type: unsupported_type" in str(exc_info.value)
@pytest.mark.asyncio
async def test_aget_input_tuples_no_tasks(self):
"""Test _aget_input_tuples returns empty list when no tasks."""
tool_calls: list[AssistantToolCall] = []
@@ -95,14 +108,24 @@ class TestDashboardCreationExecutorNode:
assert len(input_tuples) == 0
class TestDashboardCreationNode:
@pytest.fixture(autouse=True)
def setup_method(self):
class TestDashboardCreationNode(BaseTest):
def setUp(self):
super().setUp()
self.mock_team = MagicMock(spec=Team)
self.mock_team.id = 1
self.mock_user = MagicMock(spec=User)
self.mock_user.id = 1
self.node = DashboardCreationNode(self.mock_team, self.mock_user)
self.node = DashboardCreationNode(
self.mock_team,
self.mock_user,
(
NodePath(
name=AssistantNodeName.DASHBOARD_CREATION.value,
message_id="test_message_id",
tool_call_id="test_tool_call_id",
),
),
)
def test_get_found_insight_count(self):
"""Test _get_found_insight_count calculates correct count."""
@@ -140,7 +163,6 @@ class TestDashboardCreationNode:
expected_url = f"/project/{self.mock_team.id}/dashboard/{dashboard_id}"
assert url == expected_url
@pytest.mark.asyncio
@patch("ee.hogai.graph.dashboards.nodes.DashboardCreationExecutorNode")
async def test_arun_missing_search_insights_queries(self, mock_executor_node_class):
"""Test arun returns error when search_insights_queries is missing."""
@@ -150,7 +172,7 @@ class TestDashboardCreationNode:
state = AssistantState(
dashboard_name="Create dashboard",
search_insights_queries=None,
root_tool_call_id="test_call",
root_tool_call_id="test_tool_call_id",
)
config = RunnableConfig()
@@ -161,7 +183,6 @@ class TestDashboardCreationNode:
assert isinstance(result.messages[0], AssistantToolCallMessage)
assert "Search insights queries are required" in result.messages[0].content
@pytest.mark.asyncio
@patch("ee.hogai.graph.dashboards.nodes.DashboardCreationExecutorNode")
@patch.object(DashboardCreationNode, "_search_insights")
@patch.object(DashboardCreationNode, "_create_insights")
@@ -205,7 +226,7 @@ class TestDashboardCreationNode:
state = AssistantState(
dashboard_name="Create dashboard",
search_insights_queries=[InsightQuery(name="Query 1", description="Description 1")],
root_tool_call_id="test_call",
root_tool_call_id="test_tool_call_id",
)
config = RunnableConfig()
@@ -217,7 +238,6 @@ class TestDashboardCreationNode:
assert "Dashboard Created" in result.messages[0].content
assert "Test Dashboard" in result.messages[0].content
@pytest.mark.asyncio
@patch("ee.hogai.graph.dashboards.nodes.DashboardCreationExecutorNode")
@patch.object(DashboardCreationNode, "_search_insights")
@patch.object(DashboardCreationNode, "_create_insights")
@@ -244,7 +264,7 @@ class TestDashboardCreationNode:
state = AssistantState(
dashboard_name="Create dashboard",
search_insights_queries=[InsightQuery(name="Query 1", description="Description 1")],
root_tool_call_id="test_call",
root_tool_call_id="test_tool_call_id",
)
config = RunnableConfig()
@@ -255,7 +275,6 @@ class TestDashboardCreationNode:
assert isinstance(result.messages[0], AssistantToolCallMessage)
assert "No existing insights matched" in result.messages[0].content
@pytest.mark.asyncio
@patch("ee.hogai.graph.dashboards.nodes.DashboardCreationExecutorNode")
@patch.object(DashboardCreationNode, "_search_insights")
@patch("ee.hogai.graph.dashboards.nodes.logger")
@@ -268,7 +287,7 @@ class TestDashboardCreationNode:
state = AssistantState(
dashboard_name="Create dashboard",
search_insights_queries=[InsightQuery(name="Query 1", description="Description 1")],
root_tool_call_id="test_call",
root_tool_call_id="test_tool_call_id",
)
config = RunnableConfig()
@@ -295,7 +314,7 @@ class TestDashboardCreationNode:
result = self.node._create_success_response(
mock_dashboard,
mock_insights, # type: ignore[arg-type]
"test_call",
"test_tool_call_id",
["Query without insights"],
)
@@ -309,7 +328,7 @@ class TestDashboardCreationNode:
def test_create_no_insights_response(self):
"""Test _create_no_insights_response creates correct no insights message."""
result = self.node._create_no_insights_response("test_call", "No insights found")
result = self.node._create_no_insights_response("test_tool_call_id", "No insights found")
assert isinstance(result, PartialAssistantState)
assert len(result.messages) == 1
@@ -320,27 +339,30 @@ class TestDashboardCreationNode:
def test_create_error_response(self):
"""Test _create_error_response creates correct error message."""
with patch("ee.hogai.graph.dashboards.nodes.capture_exception") as mock_capture:
result = self.node._create_error_response("Test error", "test_call")
result = self.node._create_error_response("Test error", "test_tool_call_id")
assert isinstance(result, PartialAssistantState)
assert len(result.messages) == 1
assert isinstance(result.messages[0], AssistantToolCallMessage)
assert result.messages[0].content == "Test error"
assert isinstance(result.messages[0], AssistantToolCallMessage)
assert result.messages[0].tool_call_id == "test_call"
assert result.messages[0].tool_call_id == "test_tool_call_id"
mock_capture.assert_called_once()
class TestDashboardCreationNodeAsyncMethods:
@pytest.fixture(autouse=True)
def setup_method(self):
class TestDashboardCreationNodeAsyncMethods(BaseTest):
def setUp(self):
super().setUp()
self.mock_team = MagicMock(spec=Team)
self.mock_team.id = 1
self.mock_user = MagicMock(spec=User)
self.mock_user.id = 1
self.node = DashboardCreationNode(self.mock_team, self.mock_user)
self.node = DashboardCreationNode(
self.mock_team,
self.mock_user,
node_path=(NodePath(name="test_node", message_id="test-id", tool_call_id="test_tool_call_id"),),
)
@pytest.mark.asyncio
@patch("ee.hogai.graph.dashboards.nodes.DashboardCreationExecutorNode")
async def test_create_insights(self, mock_executor_node_class):
"""Test _create_insights method."""
@@ -389,7 +411,6 @@ class TestDashboardCreationNodeAsyncMethods:
assert len(result["task_1"].created_insight_ids) == 2
assert len(result["task_1"].created_insight_messages) == 1
@pytest.mark.asyncio
@patch("ee.hogai.graph.dashboards.nodes.DashboardCreationExecutorNode")
async def test_search_insights(self, mock_executor_node_class):
"""Test _search_insights method."""
@@ -430,7 +451,6 @@ class TestDashboardCreationNodeAsyncMethods:
assert len(result["task_1"].found_insight_ids) == 2
assert len(result["task_1"].found_insight_messages) == 2
@pytest.mark.asyncio
@patch("ee.hogai.graph.dashboards.nodes.database_sync_to_async")
async def test_create_dashboard_with_insights(self, mock_db_sync):
"""Test _create_dashboard_with_insights method."""
@@ -452,7 +472,6 @@ class TestDashboardCreationNodeAsyncMethods:
assert result[1] == mock_insights
mock_sync_func.assert_called_once()
@pytest.mark.asyncio
async def test_process_insight_creation_results(self):
"""Test _process_insight_creation_results method."""
# Create simple mocked Team and User instances like other tests

View File

@@ -1,21 +1,25 @@
from typing import Literal, Optional
from posthog.models.team.team import Team
from posthog.models.user import User
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.graph.base import BaseAssistantGraph
from ee.hogai.graph.deep_research.notebook.nodes import DeepResearchNotebookPlanningNode
from ee.hogai.graph.deep_research.onboarding.nodes import DeepResearchOnboardingNode
from ee.hogai.graph.deep_research.planner.nodes import DeepResearchPlannerNode, DeepResearchPlannerToolsNode
from ee.hogai.graph.deep_research.report.nodes import DeepResearchReportNode
from ee.hogai.graph.deep_research.task_executor.nodes import DeepResearchTaskExecutorNode
from ee.hogai.graph.deep_research.types import DeepResearchNodeName, DeepResearchState
from ee.hogai.graph.graph import BaseAssistantGraph
from ee.hogai.graph.deep_research.types import DeepResearchNodeName, DeepResearchState, PartialDeepResearchState
from ee.hogai.graph.title_generator.nodes import TitleGeneratorNode
from ee.hogai.utils.types.base import AssistantGraphName
class DeepResearchAssistantGraph(BaseAssistantGraph[DeepResearchState]):
def __init__(self, team: Team, user: User):
super().__init__(team, user, DeepResearchState)
class DeepResearchAssistantGraph(BaseAssistantGraph[DeepResearchState, PartialDeepResearchState]):
@property
def graph_name(self) -> AssistantGraphName:
return AssistantGraphName.DEEP_RESEARCH
@property
def state_type(self) -> type[DeepResearchState]:
return DeepResearchState
def add_onboarding_node(
self, node_map: Optional[dict[Literal["onboarding", "planning", "continue"], DeepResearchNodeName]] = None
@@ -80,13 +84,21 @@ class DeepResearchAssistantGraph(BaseAssistantGraph[DeepResearchState]):
builder.add_edge(DeepResearchNodeName.REPORT, next_node)
return self
def add_title_generator(self, end_node: DeepResearchNodeName = DeepResearchNodeName.END):
    """Attach the title-generator node so it runs from START in parallel with the main flow.

    Adds a START -> TITLE_GENERATOR -> ``end_node`` path to the graph and marks the
    graph as having a start node. Returns ``self`` for builder-style chaining.
    """
    self._has_start_node = True
    title_generator = TitleGeneratorNode(self._team, self._user)
    self._graph.add_node(DeepResearchNodeName.TITLE_GENERATOR, title_generator)
    self._graph.add_edge(DeepResearchNodeName.START, DeepResearchNodeName.TITLE_GENERATOR)
    self._graph.add_edge(DeepResearchNodeName.TITLE_GENERATOR, end_node)
    return self
def compile_full_graph(self, checkpointer: DjangoCheckpointer | None = None):
    """Wire every deep-research stage into the graph and compile it.

    Chains the onboarding, notebook, planner, report, title-generator, and
    task-executor builders, then compiles with ``checkpointer`` (``None`` falls
    back to whatever default ``compile`` applies).
    """
    return (
        self.add_onboarding_node()
        .add_notebook_nodes()
        .add_planner_nodes()
        .add_report_node()
        .add_title_generator()
        .add_task_executor()
        .compile(checkpointer=checkpointer)
    )

View File

@@ -40,8 +40,8 @@ from ee.hogai.graph.deep_research.types import (
DeepResearchNodeName,
DeepResearchState,
PartialDeepResearchState,
TodoItem,
)
from ee.hogai.graph.root.tools.todo_write import TodoItem
from ee.hogai.notebook.notebook_serializer import NotebookSerializer
from ee.hogai.utils.helpers import normalize_ai_message
from ee.hogai.utils.types import WithCommentary

View File

@@ -34,8 +34,7 @@ from ee.hogai.graph.deep_research.planner.prompts import (
WRITE_RESULT_FAILED_TOOL_RESULT,
WRITE_RESULT_TOOL_RESULT,
)
from ee.hogai.graph.deep_research.types import DeepResearchState, PartialDeepResearchState
from ee.hogai.graph.root.tools.todo_write import TodoItem
from ee.hogai.graph.deep_research.types import DeepResearchState, PartialDeepResearchState, TodoItem
from ee.hogai.utils.types import InsightArtifact
from ee.hogai.utils.types.base import TaskResult

View File

@@ -128,7 +128,7 @@ class TestTaskExecutorNodeArun(TestTaskExecutorNode):
class TestTaskExecutorInsightsExecution(TestTaskExecutorNode):
@patch("ee.hogai.graph.deep_research.task_executor.nodes.DeepResearchTaskExecutorNode.dispatcher")
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsAssistantGraph")
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsGraph")
async def test_execute_task_with_insights_successful(self, mock_insights_graph_class, mock_dispatcher):
"""Test successful task execution through insights pipeline."""
@@ -171,7 +171,7 @@ class TestTaskExecutorInsightsExecution(TestTaskExecutorNode):
self.assertEqual(len(result.artifacts), 1)
@patch("ee.hogai.graph.deep_research.task_executor.nodes.DeepResearchTaskExecutorNode.dispatcher")
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsAssistantGraph")
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsGraph")
async def test_execute_task_with_insights_no_artifacts(self, mock_insights_graph_class, mock_dispatcher):
"""Test task execution that produces no artifacts."""
task = self._create_assistant_tool_call()
@@ -209,7 +209,7 @@ class TestTaskExecutorInsightsExecution(TestTaskExecutorNode):
@patch("ee.hogai.graph.deep_research.task_executor.nodes.capture_exception")
@patch("ee.hogai.graph.deep_research.task_executor.nodes.DeepResearchTaskExecutorNode.dispatcher")
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsAssistantGraph")
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsGraph")
async def test_execute_task_with_exception(self, mock_insights_graph_class, mock_dispatcher, mock_capture):
"""Test task execution that encounters an exception."""
task = self._create_assistant_tool_call()

View File

@@ -27,8 +27,7 @@ from ee.hogai.graph.deep_research.onboarding.nodes import DeepResearchOnboarding
from ee.hogai.graph.deep_research.planner.nodes import DeepResearchPlannerNode, DeepResearchPlannerToolsNode
from ee.hogai.graph.deep_research.report.nodes import DeepResearchReportNode
from ee.hogai.graph.deep_research.task_executor.nodes import DeepResearchTaskExecutorNode
from ee.hogai.graph.deep_research.types import DeepResearchIntermediateResult, DeepResearchState
from ee.hogai.graph.root.tools.todo_write import TodoItem
from ee.hogai.graph.deep_research.types import DeepResearchIntermediateResult, DeepResearchState, TodoItem
from ee.hogai.utils.types.base import TaskResult
from ee.models.assistant import Conversation

View File

@@ -12,9 +12,9 @@ from ee.hogai.graph.deep_research.types import (
DeepResearchIntermediateResult,
DeepResearchState,
PartialDeepResearchState,
TodoItem,
_SharedDeepResearchState,
)
from ee.hogai.graph.root.tools.todo_write import TodoItem
from ee.hogai.utils.types.base import InsightArtifact, TaskResult
"""

View File

@@ -1,19 +1,31 @@
from collections.abc import Sequence
from enum import StrEnum
from typing import Annotated, Optional
from typing import Annotated, Literal, Optional
from langgraph.graph import END, START
from pydantic import BaseModel, Field
from posthog.schema import DeepResearchNotebook
from ee.hogai.graph.root.tools.todo_write import TodoItem
from ee.hogai.utils.types import AssistantMessageUnion, add_and_merge_messages
from ee.hogai.utils.types.base import BaseStateWithMessages, BaseStateWithTasks, append, replace
from ee.hogai.utils.types.base import (
AssistantMessageUnion,
BaseStateWithMessages,
BaseStateWithTasks,
add_and_merge_messages,
append,
replace,
)
NotebookInfo = DeepResearchNotebook
class TodoItem(BaseModel):
    """A single entry in the deep-research planner's todo list."""

    content: str = Field(..., min_length=1)  # human-readable task description; must be non-empty
    status: Literal["pending", "in_progress", "completed"]  # lifecycle state of the task
    id: str  # identifier used to reference this item
    priority: Literal["low", "medium", "high"]  # relative importance of the task
class DeepResearchIntermediateResult(BaseModel):
"""
An intermediate result of a batch of work, that will be used to write the final report.
@@ -69,3 +81,4 @@ class DeepResearchNodeName(StrEnum):
PLANNER_TOOLS = "planner_tools"
TASK_EXECUTOR = "task_executor"
REPORT = "report"
TITLE_GENERATOR = "title_generator"

View File

@@ -1,7 +1,7 @@
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from posthog.schema import AssistantFunnelsQuery, AssistantMessage
from posthog.schema import AssistantFunnelsQuery
from ee.hogai.utils.types import AssistantState, PartialAssistantState
from ee.hogai.utils.types.base import AssistantNodeName
@@ -25,7 +25,7 @@ class FunnelGeneratorNode(SchemaGeneratorNode[AssistantFunnelsQuery]):
return AssistantNodeName.FUNNEL_GENERATOR
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
self.dispatcher.message(AssistantMessage(content="Creating funnel query"))
self.dispatcher.update("Creating funnel query")
prompt = ChatPromptTemplate.from_messages(
[
("system", FUNNEL_SYSTEM_PROMPT),

View File

@@ -1,23 +1,11 @@
from collections.abc import Hashable
from typing import Any, Generic, Literal, Optional, cast
from langgraph.graph.state import StateGraph
from posthog.models.team.team import Team
from posthog.models.user import User
from collections.abc import Callable
from typing import cast
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.graph.billing.nodes import BillingNode
from ee.hogai.graph.query_planner.nodes import QueryPlannerNode, QueryPlannerToolsNode
from ee.hogai.graph.session_summaries.nodes import SessionSummarizationNode
from ee.hogai.graph.base import BaseAssistantGraph
from ee.hogai.graph.title_generator.nodes import TitleGeneratorNode
from ee.hogai.utils.types import AssistantNodeName, AssistantState, StateType
from ee.hogai.utils.types.composed import MaxNodeName
from ee.hogai.utils.types.base import AssistantGraphName, AssistantNodeName, AssistantState, PartialAssistantState
from .dashboards.nodes import DashboardCreationNode
from .funnels.nodes import FunnelGeneratorNode, FunnelGeneratorToolsNode
from .inkeep_docs.nodes import InkeepDocsNode
from .insights.nodes import InsightSearchNode
from .memory.nodes import (
MemoryCollectorNode,
MemoryCollectorToolsNode,
@@ -28,51 +16,19 @@ from .memory.nodes import (
MemoryOnboardingFinalizeNode,
MemoryOnboardingNode,
)
from .query_executor.nodes import QueryExecutorNode
from .rag.nodes import InsightRagContextNode
from .retention.nodes import RetentionGeneratorNode, RetentionGeneratorToolsNode
from .root.nodes import RootNode, RootNodeTools
from .sql.nodes import SQLGeneratorNode, SQLGeneratorToolsNode
from .trends.nodes import TrendsGeneratorNode, TrendsGeneratorToolsNode
global_checkpointer = DjangoCheckpointer()
class BaseAssistantGraph(Generic[StateType]):
_team: Team
_user: User
_graph: StateGraph
_parent_tool_call_id: str | None
class AssistantGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
@property
def graph_name(self) -> AssistantGraphName:
return AssistantGraphName.ASSISTANT
def __init__(self, team: Team, user: User, state_type: type[StateType], parent_tool_call_id: str | None = None):
self._team = team
self._user = user
self._graph = StateGraph(state_type)
self._has_start_node = False
self._parent_tool_call_id = parent_tool_call_id
@property
def state_type(self) -> type[AssistantState]:
return AssistantState
def add_edge(self, from_node: MaxNodeName, to_node: MaxNodeName):
if from_node == AssistantNodeName.START:
self._has_start_node = True
self._graph.add_edge(from_node, to_node)
return self
def add_node(self, node: MaxNodeName, action: Any):
if self._parent_tool_call_id:
action._parent_tool_call_id = self._parent_tool_call_id
self._graph.add_node(node, action)
return self
def compile(self, checkpointer: DjangoCheckpointer | None | Literal[False] = None):
if not self._has_start_node:
raise ValueError("Start node not added to the graph")
# TRICKY: We check `is not None` because False has a special meaning of "no checkpointer", which we want to pass on
compiled_graph = self._graph.compile(
checkpointer=checkpointer if checkpointer is not None else global_checkpointer
)
return compiled_graph
def add_title_generator(self, end_node: MaxNodeName = AssistantNodeName.END):
def add_title_generator(self, end_node: AssistantNodeName = AssistantNodeName.END):
self._has_start_node = True
title_generator = TitleGeneratorNode(self._team, self._user)
@@ -81,178 +37,19 @@ class BaseAssistantGraph(Generic[StateType]):
self._graph.add_edge(AssistantNodeName.TITLE_GENERATOR, end_node)
return self
class InsightsAssistantGraph(BaseAssistantGraph[AssistantState]):
def __init__(self, team: Team, user: User, tool_call_id: str | None = None):
super().__init__(team, user, AssistantState, tool_call_id)
def add_rag_context(self):
self._has_start_node = True
retriever = InsightRagContextNode(self._team, self._user)
self.add_node(AssistantNodeName.INSIGHT_RAG_CONTEXT, retriever)
self._graph.add_edge(AssistantNodeName.START, AssistantNodeName.INSIGHT_RAG_CONTEXT)
self._graph.add_edge(AssistantNodeName.INSIGHT_RAG_CONTEXT, AssistantNodeName.QUERY_PLANNER)
return self
def add_trends_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR):
trends_generator = TrendsGeneratorNode(self._team, self._user)
self.add_node(AssistantNodeName.TRENDS_GENERATOR, trends_generator)
trends_generator_tools = TrendsGeneratorToolsNode(self._team, self._user)
self.add_node(AssistantNodeName.TRENDS_GENERATOR_TOOLS, trends_generator_tools)
self._graph.add_edge(AssistantNodeName.TRENDS_GENERATOR_TOOLS, AssistantNodeName.TRENDS_GENERATOR)
self._graph.add_conditional_edges(
AssistantNodeName.TRENDS_GENERATOR,
trends_generator.router,
path_map={
"tools": AssistantNodeName.TRENDS_GENERATOR_TOOLS,
"next": next_node,
},
)
return self
def add_funnel_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR):
funnel_generator = FunnelGeneratorNode(self._team, self._user)
self.add_node(AssistantNodeName.FUNNEL_GENERATOR, funnel_generator)
funnel_generator_tools = FunnelGeneratorToolsNode(self._team, self._user)
self.add_node(AssistantNodeName.FUNNEL_GENERATOR_TOOLS, funnel_generator_tools)
self._graph.add_edge(AssistantNodeName.FUNNEL_GENERATOR_TOOLS, AssistantNodeName.FUNNEL_GENERATOR)
self._graph.add_conditional_edges(
AssistantNodeName.FUNNEL_GENERATOR,
funnel_generator.router,
path_map={
"tools": AssistantNodeName.FUNNEL_GENERATOR_TOOLS,
"next": next_node,
},
)
return self
def add_retention_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR):
retention_generator = RetentionGeneratorNode(self._team, self._user)
self.add_node(AssistantNodeName.RETENTION_GENERATOR, retention_generator)
retention_generator_tools = RetentionGeneratorToolsNode(self._team, self._user)
self.add_node(AssistantNodeName.RETENTION_GENERATOR_TOOLS, retention_generator_tools)
self._graph.add_edge(AssistantNodeName.RETENTION_GENERATOR_TOOLS, AssistantNodeName.RETENTION_GENERATOR)
self._graph.add_conditional_edges(
AssistantNodeName.RETENTION_GENERATOR,
retention_generator.router,
path_map={
"tools": AssistantNodeName.RETENTION_GENERATOR_TOOLS,
"next": next_node,
},
)
return self
def add_query_planner(
    self,
    path_map: Optional[
        dict[Literal["trends", "funnel", "retention", "sql", "continue", "end"], AssistantNodeName]
    ] = None,
):
    """Add the query planner and its tools node.

    The planner always hands off to its tools node, which then fans out to the
    per-kind generator (or loops back / ends) according to ``path_map``. A
    falsy ``path_map`` falls back to the default routing table.
    """
    planner = QueryPlannerNode(self._team, self._user)
    planner_tools = QueryPlannerToolsNode(self._team, self._user)
    self.add_node(AssistantNodeName.QUERY_PLANNER, planner)
    self.add_node(AssistantNodeName.QUERY_PLANNER_TOOLS, planner_tools)
    self._graph.add_edge(AssistantNodeName.QUERY_PLANNER, AssistantNodeName.QUERY_PLANNER_TOOLS)
    default_routes = {
        "continue": AssistantNodeName.QUERY_PLANNER,
        "trends": AssistantNodeName.TRENDS_GENERATOR,
        "funnel": AssistantNodeName.FUNNEL_GENERATOR,
        "retention": AssistantNodeName.RETENTION_GENERATOR,
        "sql": AssistantNodeName.SQL_GENERATOR,
        "end": AssistantNodeName.END,
    }
    self._graph.add_conditional_edges(
        AssistantNodeName.QUERY_PLANNER_TOOLS,
        planner_tools.router,
        path_map=path_map or default_routes,  # type: ignore
    )
    return self
def add_sql_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR):
    """Wire the SQL generator and its tools node into the graph.

    The tools node always feeds back into the generator; the generator routes
    either to its tools node ("tools") or onward to ``next_node`` ("next").
    """
    generator = SQLGeneratorNode(self._team, self._user)
    generator_tools = SQLGeneratorToolsNode(self._team, self._user)
    self.add_node(AssistantNodeName.SQL_GENERATOR, generator)
    self.add_node(AssistantNodeName.SQL_GENERATOR_TOOLS, generator_tools)
    self._graph.add_edge(AssistantNodeName.SQL_GENERATOR_TOOLS, AssistantNodeName.SQL_GENERATOR)
    routes = {"tools": AssistantNodeName.SQL_GENERATOR_TOOLS, "next": next_node}
    self._graph.add_conditional_edges(AssistantNodeName.SQL_GENERATOR, generator.router, path_map=routes)
    return self
def add_query_executor(self, next_node: AssistantNodeName = AssistantNodeName.END):
    """Register the query executor node and connect it straight to ``next_node``."""
    executor = QueryExecutorNode(self._team, self._user)
    self.add_node(AssistantNodeName.QUERY_EXECUTOR, executor)
    self._graph.add_edge(AssistantNodeName.QUERY_EXECUTOR, next_node)
    return self
def add_query_creation_flow(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR):
    """Add all nodes and edges EXCEPT query execution."""
    self.add_rag_context()
    self.add_query_planner()
    # Each generator exits to the same next_node once its query is built.
    for add_generator in (
        self.add_trends_generator,
        self.add_funnel_generator,
        self.add_retention_generator,
        self.add_sql_generator,
    ):
        add_generator(next_node=next_node)
    return self
def compile_full_graph(self, checkpointer: DjangoCheckpointer | None = None):
    """Build the complete flow (creation plus execution) and compile it."""
    self.add_query_creation_flow()
    self.add_query_executor()
    return self.compile(checkpointer=checkpointer)
class AssistantGraph(BaseAssistantGraph[AssistantState]):
    """Top-level Max assistant graph, parameterized over ``AssistantState``."""

    def __init__(self, team: Team, user: User):
        # Delegates all wiring state to the base class; AssistantState is the
        # state schema for this graph.
        super().__init__(team, user, AssistantState)
def add_root(self, router: Callable[[AssistantState], AssistantNodeName] | None = None):
    """Wire the root node and its tools node into the graph.

    NOTE(review): this span previously contained two conflicting ``add_root``
    definitions left over from a merge/diff artifact — an older
    path_map/tools_node signature immediately shadowed by this router-based
    one, whose body the span actually carried. The coherent router-based
    variant is kept; confirm against the committed version of this file.

    Args:
        router: Optional routing override for the ROOT node. Defaults to the
            root node's own ``router``.
    """
    root_node = RootNode(self._team, self._user)
    self.add_node(AssistantNodeName.ROOT, root_node)
    root_node_tools = RootNodeTools(self._team, self._user)
    self.add_node(AssistantNodeName.ROOT_TOOLS, root_node_tools)
    # ROOT routes via the (possibly overridden) router; ROOT_TOOLS either
    # loops back into ROOT or ends the run.
    self._graph.add_conditional_edges(
        AssistantNodeName.ROOT, router or cast(Callable[[AssistantState], AssistantNodeName], root_node.router)
    )
    self._graph.add_conditional_edges(
        AssistantNodeName.ROOT_TOOLS,
        root_node_tools.router,
        path_map={"root": AssistantNodeName.ROOT, "end": AssistantNodeName.END},
    )
    return self
def add_insights(self, next_node: AssistantNodeName = AssistantNodeName.ROOT):
    """Compile the insights subgraph and mount it, exiting to ``next_node``."""
    compiled_subgraph = InsightsAssistantGraph(self._team, self._user).compile_full_graph()
    self.add_node(AssistantNodeName.INSIGHTS_SUBGRAPH, compiled_subgraph)
    self._graph.add_edge(AssistantNodeName.INSIGHTS_SUBGRAPH, next_node)
    return self
def add_memory_onboarding(
@@ -332,55 +129,6 @@ class AssistantGraph(BaseAssistantGraph[AssistantState]):
self._graph.add_edge(AssistantNodeName.MEMORY_COLLECTOR_TOOLS, AssistantNodeName.MEMORY_COLLECTOR)
return self
def add_inkeep_docs(self, path_map: Optional[dict[Hashable, AssistantNodeName]] = None):
    """Add the Inkeep docs search node to the graph.

    A falsy ``path_map`` falls back to the default end/root routing.
    """
    docs_node = InkeepDocsNode(self._team, self._user)
    self.add_node(AssistantNodeName.INKEEP_DOCS, docs_node)
    routes = path_map or {
        "end": AssistantNodeName.END,
        "root": AssistantNodeName.ROOT,
    }
    self._graph.add_conditional_edges(
        AssistantNodeName.INKEEP_DOCS,
        docs_node.router,
        path_map=cast(dict[Hashable, str], routes),
    )
    return self
def add_billing(self):
    """Add the billing node; it always hands control back to ROOT."""
    billing = BillingNode(self._team, self._user)
    self.add_node(AssistantNodeName.BILLING, billing)
    self._graph.add_edge(AssistantNodeName.BILLING, AssistantNodeName.ROOT)
    return self
def add_insights_search(self, end_node: AssistantNodeName = AssistantNodeName.END):
    """Add the insight search node; its router picks ``end_node`` or ROOT."""
    search_node = InsightSearchNode(self._team, self._user)
    self.add_node(AssistantNodeName.INSIGHTS_SEARCH, search_node)
    routes = {
        "end": end_node,
        "root": AssistantNodeName.ROOT,
    }
    self._graph.add_conditional_edges(
        AssistantNodeName.INSIGHTS_SEARCH,
        search_node.router,
        path_map=cast(dict[Hashable, str], routes),
    )
    return self
def add_session_summarization(self, end_node: AssistantNodeName = AssistantNodeName.END):
    """Add the session summarization node, returning control to ROOT afterwards.

    NOTE(review): ``end_node`` is accepted for signature parity with sibling
    builders but is currently unused.
    """
    summarizer = SessionSummarizationNode(self._team, self._user)
    self.add_node(AssistantNodeName.SESSION_SUMMARIZATION, summarizer)
    self._graph.add_edge(AssistantNodeName.SESSION_SUMMARIZATION, AssistantNodeName.ROOT)
    return self
def add_dashboard_creation(self, end_node: AssistantNodeName = AssistantNodeName.END):
    """Add the dashboard creation node, looping back to ROOT when done.

    NOTE(review): ``end_node`` is accepted for signature parity with sibling
    builders but is currently unused.
    """
    dashboard_node = DashboardCreationNode(self._team, self._user)
    # Registered directly on the underlying graph builder (not via
    # self.add_node), matching the original wiring.
    self._graph.add_node(AssistantNodeName.DASHBOARD_CREATION, dashboard_node)
    self._graph.add_edge(AssistantNodeName.DASHBOARD_CREATION, AssistantNodeName.ROOT)
    return self
def compile_full_graph(self, checkpointer: DjangoCheckpointer | None = None):
return (
self.add_title_generator()
@@ -388,11 +136,5 @@ class AssistantGraph(BaseAssistantGraph[AssistantState]):
.add_memory_collector()
.add_memory_collector_tools()
.add_root()
.add_insights()
.add_inkeep_docs()
.add_billing()
.add_insights_search()
.add_session_summarization()
.add_dashboard_creation()
.compile(checkpointer=checkpointer)
)

View File

@@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from typing import Any
from uuid import uuid4
from django.conf import settings
@@ -34,19 +34,22 @@ class InkeepDocsNode(RootNode): # Inheriting from RootNode to use the same mess
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
"""Process the state and return documentation search results."""
self.dispatcher.message(AssistantMessage(content="Checking PostHog documentation...", id=str(uuid4())))
self.dispatcher.update("Checking PostHog documentation...")
messages = self._construct_messages(
state.messages, state.root_conversation_start_id, state.root_tool_calls_count
)
message: LangchainAIMessage = await self._get_model().ainvoke(messages, config)
# NOTE: This is a hacky way to send these messages as part of the root tool call
# Can't think of a better interface for this at the moment.
self.dispatcher.set_as_root()
should_continue = INKEEP_DATA_CONTINUATION_PHRASE in message.content
tool_prompt = "Checking PostHog documentation..."
if should_continue:
tool_prompt = "The documentation search results are provided in the next Assistant message.\n<system_reminder>Continue with the user's data request.</system_reminder>"
return PartialAssistantState(
messages=[
AssistantToolCallMessage(
content="Checking PostHog documentation...", tool_call_id=state.root_tool_call_id, id=str(uuid4())
),
AssistantToolCallMessage(content=tool_prompt, tool_call_id=state.root_tool_call_id, id=str(uuid4())),
AssistantMessage(content=message.content, id=str(uuid4())),
],
root_tool_call_id=None,
@@ -116,16 +119,3 @@ class InkeepDocsNode(RootNode): # Inheriting from RootNode to use the same mess
) -> list[BaseMessage]:
# Original node has Anthropic messages, but Inkeep expects OpenAI messages
return convert_to_openai_messages(conversation_window, tool_result_messages)
def router(self, state: AssistantState) -> Literal["end", "root"]:
last_message = state.messages[-1]
if isinstance(last_message, AssistantMessage) and INKEEP_DATA_CONTINUATION_PHRASE in last_message.content:
# The continuation phrase solution is a little weird, but seems it's the best one for agentic capabilities
# I've found here. The alternatives that definitively don't work are:
# 1. Using tool calls in this node - the Inkeep API only supports providing their own pre-defined tools
# (for including extra search metadata), nothing else
# 2. Always going back to root, for root to judge whether to continue or not - GPT-4o is terrible at this,
# and I was unable to stop it from repeating the context from the last assistant message, i.e. the Inkeep
# output message (doesn't quite work to tell it to output an empty message, or to call an "end" tool)
return "root"
return "end"

View File

@@ -67,6 +67,9 @@ class TestInkeepDocsNode(ClickhouseTestMixin, BaseTest):
assert next_state is not None
messages = cast(list, next_state.messages)
self.assertEqual(len(messages), 2)
# Tool call message should have the continuation prompt
first_message = cast(AssistantToolCallMessage, messages[0])
self.assertIn("Continue with the user's data request", first_message.content)
second_message = cast(AssistantMessage, messages[1])
self.assertEqual(second_message.content, response_with_continuation)
@@ -91,26 +94,6 @@ class TestInkeepDocsNode(ClickhouseTestMixin, BaseTest):
self.assertIsInstance(messages[2], LangchainAIMessage)
self.assertIsInstance(messages[3], LangchainHumanMessage)
def test_router_with_data_continuation(self):
node = InkeepDocsNode(self.team, self.user)
state = AssistantState(
messages=[
HumanMessage(content="Explain PostHog trends, and show me an example trends insight"),
AssistantMessage(content=f"Here's the documentation: XYZ.\n{INKEEP_DATA_CONTINUATION_PHRASE}"),
]
)
self.assertEqual(node.router(state), "root") # Going back to root, so that the agent can continue with the task
def test_router_without_data_continuation(self):
node = InkeepDocsNode(self.team, self.user)
state = AssistantState(
messages=[
HumanMessage(content="How do I use feature flags?"),
AssistantMessage(content="Here's how to use feature flags..."),
]
)
self.assertEqual(node.router(state), "end") # Ending
async def test_tool_call_id_handling(self):
"""Test that tool_call_id is properly handled in both input and output states."""
test_tool_call_id = str(uuid4())

View File

@@ -4,7 +4,7 @@ import inspect
import warnings
from datetime import timedelta
from functools import wraps
from typing import Literal, Optional, TypedDict
from typing import TYPE_CHECKING, Optional, TypedDict
from uuid import uuid4
from django.db.models import Max
@@ -16,7 +16,7 @@ from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from posthog.schema import AssistantMessage, AssistantToolCallMessage, VisualizationMessage
from posthog.schema import AssistantToolCallMessage, VisualizationMessage
from posthog.exceptions_capture import capture_exception
from posthog.models import Insight
@@ -28,18 +28,19 @@ from ee.hogai.graph.shared_prompts import HYPERLINK_USAGE_INSTRUCTIONS
from ee.hogai.utils.helpers import build_insight_url
from ee.hogai.utils.types import AssistantState, PartialAssistantState
from ee.hogai.utils.types.base import AssistantNodeName
from ee.hogai.utils.types.composed import MaxNodeName
from .prompts import (
EMPTY_DATABASE_ERROR_MESSAGE,
ITERATIVE_SEARCH_SYSTEM_PROMPT,
ITERATIVE_SEARCH_USER_PROMPT,
NO_INSIGHTS_FOUND_MESSAGE,
PAGINATION_INSTRUCTIONS_TEMPLATE,
SEARCH_ERROR_INSTRUCTIONS,
TOOL_BASED_EVALUATION_SYSTEM_PROMPT,
)
if TYPE_CHECKING:
from ee.hogai.utils.types.composed import MaxNodeName
logger = structlog.get_logger(__name__)
# Silence Pydantic serializer warnings for creation of VisualizationMessage/Query execution
warnings.filterwarnings("ignore", category=UserWarning, message=".*Pydantic serializer.*")
@@ -105,6 +106,10 @@ class InsightDict(TypedDict):
short_id: str
class NoInsightsException(Exception):
"""Exception indicating that the insight search cannot be done because the user does not have any insights."""
class InsightSearchNode(AssistantNode):
PAGE_SIZE = 500
MAX_SEARCH_ITERATIONS = 6
@@ -114,7 +119,7 @@ class InsightSearchNode(AssistantNode):
MAX_SERIES_TO_PROCESS = 3
@property
def node_name(self) -> MaxNodeName:
def node_name(self) -> "MaxNodeName":
return AssistantNodeName.INSIGHTS_SEARCH
def __init__(self, *args, **kwargs):
@@ -133,9 +138,31 @@ class InsightSearchNode(AssistantNode):
self._query_cache = {}
self._insight_id_cache = {}
def _dispatch_update_message(self, content: str) -> None:
"""Dispatch an update message to the assistant."""
self.dispatcher.message(AssistantMessage(content=content))
@timing_logger("InsightSearchNode.arun")
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
self.dispatcher.update("Searching for insights")
search_query = state.search_insights_query
self._current_iteration = 0
total_count = await self._get_total_insights_count()
if total_count == 0:
raise NoInsightsException
selected_insights = await self._search_insights_iteratively(search_query or "")
logger.warning(
f"{TIMING_LOG_PREFIX} search_insights_iteratively returned {len(selected_insights)} insights: {selected_insights}"
)
if selected_insights:
self.dispatcher.update(f"Evaluating {len(selected_insights)} insights to find the best match")
else:
self.dispatcher.update("No existing insights found, creating a new one")
evaluation_result = await self._evaluate_insights_with_tools(
selected_insights, search_query or "", max_selections=1
)
return self._handle_evaluation_result(evaluation_result, state)
def _create_page_reader_tool(self):
"""Create tool for reading insights pages during agentic RAG loop."""
@@ -185,36 +212,6 @@ class InsightSearchNode(AssistantNode):
return [select_insight, reject_all_insights]
@timing_logger("InsightSearchNode.arun")
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
self._dispatch_update_message("Searching for insights")
search_query = state.search_insights_query
try:
self._current_iteration = 0
total_count = await self._get_total_insights_count()
if total_count == 0:
return self._handle_empty_database(state)
selected_insights = await self._search_insights_iteratively(search_query or "")
logger.warning(
f"{TIMING_LOG_PREFIX} search_insights_iteratively returned {len(selected_insights)} insights: {selected_insights}"
)
if selected_insights:
self._dispatch_update_message(f"Evaluating {len(selected_insights)} insights to find the best match")
else:
self._dispatch_update_message("No existing insights found, creating a new one")
evaluation_result = await self._evaluate_insights_with_tools(
selected_insights, search_query or "", max_selections=1
)
return self._handle_evaluation_result(evaluation_result, state)
except Exception as e:
return self._handle_search_error(e, state)
@timing_logger("InsightSearchNode._get_insights_queryset")
def _get_insights_queryset(self):
"""Get Insight objects with latest view time annotated and cutoff date."""
@@ -235,10 +232,6 @@ class InsightSearchNode(AssistantNode):
self._total_insights_count = await self._get_insights_queryset().acount()
return self._total_insights_count
def _handle_empty_database(self, state: AssistantState) -> PartialAssistantState:
"""Handle the case when no insights exist in the database. (Rare edge-case but still possible)"""
return self._create_error_response(EMPTY_DATABASE_ERROR_MESSAGE, state.root_tool_call_id)
def _handle_evaluation_result(self, evaluation_result: dict, state: AssistantState) -> PartialAssistantState:
"""Process the evaluation result and return appropriate response."""
if evaluation_result["should_use_existing"]:
@@ -256,12 +249,12 @@ class InsightSearchNode(AssistantNode):
return PartialAssistantState(
messages=[
*evaluation_result["visualization_messages"],
AssistantToolCallMessage(
content=formatted_content,
tool_call_id=state.root_tool_call_id or "unknown",
id=str(uuid4()),
),
*evaluation_result["visualization_messages"],
],
selected_insight_ids=evaluation_result["selected_insights"],
search_insights_query=None,
@@ -284,16 +277,6 @@ class InsightSearchNode(AssistantNode):
selected_insight_ids=None,
)
@timing_logger("InsightSearchNode._handle_search_error")
def _handle_search_error(self, e: Exception, state: AssistantState) -> PartialAssistantState:
"""Handle exceptions during search process."""
capture_exception(e)
logger.error(f"{TIMING_LOG_PREFIX} Error in InsightSearchNode: {e}", exc_info=True)
return self._create_error_response(
SEARCH_ERROR_INSTRUCTIONS,
state.root_tool_call_id,
)
def _format_insight_for_display(self, insight: InsightDict) -> str:
"""Format a single insight for display."""
name = insight["name"] or insight["derived_name"] or "Unnamed"
@@ -424,7 +407,7 @@ class InsightSearchNode(AssistantNode):
selected_insights = []
for step in ["Searching through existing insights", "Analyzing available insights"]:
self._dispatch_update_message(step)
self.dispatcher.update(step)
logger.warning(f"{TIMING_LOG_PREFIX} Starting iterative search, max_iterations={self._max_iterations}")
while self._current_iteration < self._max_iterations:
@@ -446,7 +429,7 @@ class InsightSearchNode(AssistantNode):
logger.warning(
f"{TIMING_LOG_PREFIX} STALL POINT(?): Streamed 'Finding the most relevant insights' - about to fetch page content"
)
self._dispatch_update_message("Finding the most relevant insights")
self.dispatcher.update("Finding the most relevant insights")
logger.warning(f"{TIMING_LOG_PREFIX} Fetching page content for page {page_num}")
tool_response = await self._get_page_content_for_tool(page_num)
@@ -465,15 +448,15 @@ class InsightSearchNode(AssistantNode):
content = response.content if isinstance(response.content, str) else str(response.content)
selected_insights = self._parse_insight_ids(content)
if selected_insights:
self._dispatch_update_message(f"Found {len(selected_insights)} relevant insights")
self.dispatcher.update(f"Found {len(selected_insights)} relevant insights")
else:
self._dispatch_update_message("No matching insights found")
self.dispatcher.update("No matching insights found")
break
except Exception as e:
capture_exception(e)
error_message = f"Error during search"
self._dispatch_update_message(error_message)
self.dispatcher.update(error_message)
break
return selected_insights
@@ -692,7 +675,7 @@ class InsightSearchNode(AssistantNode):
"""Create a VisualizationMessage to render the insight UI."""
try:
for step in ["Executing insight query...", "Processing query parameters", "Running data analysis"]:
self._dispatch_update_message(step)
self.dispatcher.update(step)
query_obj, _ = await self._process_insight_query(insight)
@@ -786,7 +769,7 @@ class InsightSearchNode(AssistantNode):
async def _run_evaluation_loop(self, user_query: str, insights_summary: list[str], max_selections: int) -> None:
"""Run the evaluation loop with LLM."""
for step in ["Analyzing insights to match your request", "Comparing insights for best fit"]:
self._dispatch_update_message(step)
self.dispatcher.update(step)
tools = self._create_insight_evaluation_tools()
llm_with_tools = self._model.bind_tools(tools)
@@ -800,7 +783,7 @@ class InsightSearchNode(AssistantNode):
if getattr(response, "tool_calls", None):
# Only stream on first iteration to avoid noise
if iteration == 0:
self._dispatch_update_message("Making evaluation decisions")
self.dispatcher.update("Making evaluation decisions")
self._process_evaluation_tool_calls(response, messages, tools)
else:
break
@@ -841,7 +824,7 @@ class InsightSearchNode(AssistantNode):
num_insights = len(self._evaluation_selections)
insight_word = "insight" if num_insights == 1 else "insights"
self._dispatch_update_message(f"Perfect! Found {num_insights} suitable {insight_word}")
self.dispatcher.update(f"Perfect! Found {num_insights} suitable {insight_word}")
for _, selection in self._evaluation_selections.items():
insight = selection["insight"]
@@ -871,7 +854,7 @@ class InsightSearchNode(AssistantNode):
async def _create_rejection_result(self) -> dict:
"""Create result for when all insights are rejected."""
self._dispatch_update_message("Will create a custom insight tailored to your request")
self.dispatcher.update("Will create a custom insight tailored to your request")
return {
"should_use_existing": False,
@@ -880,9 +863,6 @@ class InsightSearchNode(AssistantNode):
"visualization_messages": [],
}
def router(self, state: AssistantState) -> Literal["root"]:
return "root"
@property
def _model(self):
return ChatOpenAI(

View File

@@ -36,7 +36,3 @@ Instructions:
NO_INSIGHTS_FOUND_MESSAGE = (
"No existing insights found matching your query. Creating a new insight based on your request."
)
SEARCH_ERROR_INSTRUCTIONS = "INSTRUCTIONS: Tell the user that you encountered an issue while searching for insights and suggest they try again with a different search term."
EMPTY_DATABASE_ERROR_MESSAGE = "No insights found in the database."

View File

@@ -23,7 +23,7 @@ from posthog.schema import (
from posthog.models import Insight, InsightViewed
from ee.hogai.graph.insights.nodes import InsightDict, InsightSearchNode
from ee.hogai.graph.insights.nodes import InsightDict, InsightSearchNode, NoInsightsException
from ee.hogai.utils.types import AssistantState, PartialAssistantState
from ee.models.assistant import Conversation
@@ -115,11 +115,6 @@ class TestInsightSearchNode(BaseTest):
short_id=insight.short_id,
)
def test_router_returns_root(self):
"""Test that router returns 'root' as expected."""
result = self.node.router(AssistantState(messages=[]))
self.assertEqual(result, "root")
async def test_load_insights_page(self):
"""Test loading paginated insights from database."""
# Load first page
@@ -350,12 +345,6 @@ class TestInsightSearchNode(BaseTest):
# Should return empty list when LLM fails to select anything
self.assertEqual(len(result), 0)
def test_router_always_returns_root(self):
"""Test that router always returns 'root'."""
state = AssistantState(messages=[], root_tool_insight_plan="some plan", search_insights_query=None)
result = self.node.router(state)
self.assertEqual(result, "root")
async def test_evaluation_flow_returns_creation_when_no_suitable_insights(self):
"""Test that when evaluation returns NO, the system transitions to creation flow."""
selected_insights = [self.insight1.id, self.insight2.id]
@@ -411,21 +400,6 @@ class TestInsightSearchNode(BaseTest):
"root_tool_insight_plan should be set to search_query",
)
# Test router behavior with the returned state
# Create a new state that simulates what happens after this node runs
post_evaluation_state = AssistantState(
messages=state.messages,
root_tool_insight_plan=search_query, # This gets set to search_query
search_insights_query=None, # This gets cleared
)
router_result = self.node.router(post_evaluation_state)
self.assertEqual(
router_result,
"root",
"Router should always return root",
)
# Verify that _evaluate_insights_with_tools was called with the search_query
mock_evaluate.assert_called_once_with(selected_insights, search_query, max_selections=1)
@@ -481,7 +455,7 @@ class TestInsightSearchNode(BaseTest):
self.assertIsNone(result.root_tool_call_id)
def test_run_with_no_insights(self):
"""Test arun method when no insights exist."""
"""Test arun method when no insights exist - should raise NoInsightsException."""
# Clear all insights (done outside async context)
InsightViewed.objects.all().delete()
Insight.objects.all().delete()
@@ -497,13 +471,10 @@ class TestInsightSearchNode(BaseTest):
async def async_test():
# Mock the database calls that happen in async context
with patch.object(self.node, "_get_total_insights_count", return_value=0):
result = await self.node.arun(state, {"configurable": {"thread_id": str(conversation.id)}})
return result
await self.node.arun(state, {"configurable": {"thread_id": str(conversation.id)}})
result = asyncio.run(async_test())
self.assertIsInstance(result, PartialAssistantState)
self.assertEqual(len(result.messages), 1)
self.assertIn("No insights found in the database", result.messages[0].content)
with self.assertRaises(NoInsightsException):
asyncio.run(async_test())
async def test_team_filtering(self):
"""Test that insights are filtered by team."""

View File

@@ -0,0 +1,155 @@
from typing import Literal, Optional
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.graph.base import BaseAssistantGraph
from ee.hogai.utils.types.base import AssistantGraphName, AssistantNodeName, AssistantState, PartialAssistantState
from ..funnels.nodes import FunnelGeneratorNode, FunnelGeneratorToolsNode
from ..query_executor.nodes import QueryExecutorNode
from ..query_planner.nodes import QueryPlannerNode, QueryPlannerToolsNode
from ..rag.nodes import InsightRagContextNode
from ..retention.nodes import RetentionGeneratorNode, RetentionGeneratorToolsNode
from ..sql.nodes import SQLGeneratorNode, SQLGeneratorToolsNode
from ..trends.nodes import TrendsGeneratorNode, TrendsGeneratorToolsNode
class InsightsGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
    """Standalone graph that turns a natural-language request into an executed insight.

    Flow: RAG context -> query planner (with a tools loop) -> one of the
    per-kind generators (trends / funnel / retention / SQL, each with its own
    tools loop) -> query executor.

    The four ``add_*_generator`` builders previously duplicated the same
    wiring; it is factored into ``_add_generator_pair`` with identical
    behavior and unchanged public signatures.
    """

    @property
    def graph_name(self) -> AssistantGraphName:
        return AssistantGraphName.INSIGHTS

    @property
    def state_type(self) -> type[AssistantState]:
        return AssistantState

    def add_rag_context(self):
        """Make RAG context retrieval the graph entrypoint, feeding the planner."""
        self._has_start_node = True
        retriever = InsightRagContextNode(self._team, self._user)
        self.add_node(AssistantNodeName.INSIGHT_RAG_CONTEXT, retriever)
        self._graph.add_edge(AssistantNodeName.START, AssistantNodeName.INSIGHT_RAG_CONTEXT)
        self._graph.add_edge(AssistantNodeName.INSIGHT_RAG_CONTEXT, AssistantNodeName.QUERY_PLANNER)
        return self

    def _add_generator_pair(
        self,
        generator_name: AssistantNodeName,
        generator,
        tools_name: AssistantNodeName,
        tools,
        next_node: AssistantNodeName,
    ):
        """Shared wiring for a generator node and its tools node.

        The tools node always feeds back into the generator; the generator
        routes either to its tools node ("tools") or onward to ``next_node``
        ("next").
        """
        self.add_node(generator_name, generator)
        self.add_node(tools_name, tools)
        self._graph.add_edge(tools_name, generator_name)
        self._graph.add_conditional_edges(
            generator_name,
            generator.router,
            path_map={"tools": tools_name, "next": next_node},
        )
        return self

    def add_trends_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR):
        """Wire the trends generator and its tools node into the graph."""
        return self._add_generator_pair(
            AssistantNodeName.TRENDS_GENERATOR,
            TrendsGeneratorNode(self._team, self._user),
            AssistantNodeName.TRENDS_GENERATOR_TOOLS,
            TrendsGeneratorToolsNode(self._team, self._user),
            next_node,
        )

    def add_funnel_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR):
        """Wire the funnel generator and its tools node into the graph."""
        return self._add_generator_pair(
            AssistantNodeName.FUNNEL_GENERATOR,
            FunnelGeneratorNode(self._team, self._user),
            AssistantNodeName.FUNNEL_GENERATOR_TOOLS,
            FunnelGeneratorToolsNode(self._team, self._user),
            next_node,
        )

    def add_retention_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR):
        """Wire the retention generator and its tools node into the graph."""
        return self._add_generator_pair(
            AssistantNodeName.RETENTION_GENERATOR,
            RetentionGeneratorNode(self._team, self._user),
            AssistantNodeName.RETENTION_GENERATOR_TOOLS,
            RetentionGeneratorToolsNode(self._team, self._user),
            next_node,
        )

    def add_sql_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR):
        """Wire the SQL generator and its tools node into the graph."""
        return self._add_generator_pair(
            AssistantNodeName.SQL_GENERATOR,
            SQLGeneratorNode(self._team, self._user),
            AssistantNodeName.SQL_GENERATOR_TOOLS,
            SQLGeneratorToolsNode(self._team, self._user),
            next_node,
        )

    def add_query_planner(
        self,
        path_map: Optional[
            dict[Literal["trends", "funnel", "retention", "sql", "continue", "end"], AssistantNodeName]
        ] = None,
    ):
        """Add the planner and its tools node.

        The planner always hands off to its tools node, which fans out to the
        per-kind generator (or loops back / ends) according to ``path_map``.
        A falsy ``path_map`` falls back to the default routing table.
        """
        query_planner = QueryPlannerNode(self._team, self._user)
        self.add_node(AssistantNodeName.QUERY_PLANNER, query_planner)
        self._graph.add_edge(AssistantNodeName.QUERY_PLANNER, AssistantNodeName.QUERY_PLANNER_TOOLS)
        query_planner_tools = QueryPlannerToolsNode(self._team, self._user)
        self.add_node(AssistantNodeName.QUERY_PLANNER_TOOLS, query_planner_tools)
        self._graph.add_conditional_edges(
            AssistantNodeName.QUERY_PLANNER_TOOLS,
            query_planner_tools.router,
            path_map=path_map  # type: ignore
            or {
                "continue": AssistantNodeName.QUERY_PLANNER,
                "trends": AssistantNodeName.TRENDS_GENERATOR,
                "funnel": AssistantNodeName.FUNNEL_GENERATOR,
                "retention": AssistantNodeName.RETENTION_GENERATOR,
                "sql": AssistantNodeName.SQL_GENERATOR,
                "end": AssistantNodeName.END,
            },
        )
        return self

    def add_query_executor(self, next_node: AssistantNodeName = AssistantNodeName.END):
        """Register the query executor node and connect it to ``next_node``."""
        query_executor_node = QueryExecutorNode(self._team, self._user)
        self.add_node(AssistantNodeName.QUERY_EXECUTOR, query_executor_node)
        self._graph.add_edge(AssistantNodeName.QUERY_EXECUTOR, next_node)
        return self

    def add_query_creation_flow(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR):
        """Add all nodes and edges EXCEPT query execution."""
        return (
            self.add_rag_context()
            .add_query_planner()
            .add_trends_generator(next_node=next_node)
            .add_funnel_generator(next_node=next_node)
            .add_retention_generator(next_node=next_node)
            .add_sql_generator(next_node=next_node)
        )

    def compile_full_graph(self, checkpointer: DjangoCheckpointer | None = None):
        """Build the complete creation + execution flow and compile it."""
        return self.add_query_creation_flow().add_query_executor().compile(checkpointer=checkpointer)

View File

@@ -1,5 +1,5 @@
import datetime
from abc import ABC
from abc import ABC, abstractmethod
from typing import Any, get_args, get_origin
from uuid import UUID
@@ -7,15 +7,15 @@ from django.utils import timezone
from langchain_core.runnables import RunnableConfig
from posthog.schema import AssistantMessage, CurrencyCode
from posthog.schema import CurrencyCode
from posthog.event_usage import groups
from posthog.models import Team
from posthog.models.action.action import Action
from posthog.models.user import User
from ee.hogai.utils.dispatcher import AssistantDispatcher
from ee.hogai.utils.types.base import BaseStateWithIntermediateSteps
from ee.hogai.utils.dispatcher import AssistantDispatcher, create_dispatcher_from_config
from ee.hogai.utils.types.base import BaseStateWithIntermediateSteps, NodePath
from ee.models import Conversation, CoreMemory
@@ -194,4 +194,33 @@ class TaxonomyUpdateDispatcherNodeMixin:
content = "Picking relevant events and properties"
if substeps:
content = substeps[-1]
self.dispatcher.message(AssistantMessage(content=content))
self.dispatcher.update(content)
class AssistantDispatcherMixin(ABC):
    """Mixin providing node-path helpers and a lazily created `AssistantDispatcher`."""

    # Path of graph nodes leading to this one; set by the concrete class.
    _node_path: tuple[NodePath, ...]
    # Runnable config the dispatcher is built from; may be None.
    _config: RunnableConfig | None
    # Cached dispatcher instance, built on first access.
    _dispatcher: AssistantDispatcher | None = None

    @property
    def node_path(self) -> tuple[NodePath, ...]:
        return self._node_path

    @property
    @abstractmethod
    def node_name(self) -> str: ...

    @property
    def tool_call_id(self) -> str:
        # Walk the node path from the innermost entry outwards and return the
        # closest ancestor that carries a tool call ID.
        for path_entry in reversed(self._node_path):
            if path_entry.tool_call_id:
                return path_entry.tool_call_id
        raise ValueError("No tool call ID found")

    @property
    def dispatcher(self) -> AssistantDispatcher:
        """Create a dispatcher for this node"""
        if not self._dispatcher:
            self._dispatcher = create_dispatcher_from_config(self._config or {}, self.node_path)
        return self._dispatcher

View File

@@ -52,7 +52,7 @@ class WithInsightCreationTaskExecution:
The type allows None for compatibility with the base class.
"""
# Import here to avoid circular dependency
from ee.hogai.graph.graph import InsightsAssistantGraph
from ee.hogai.graph.insights_graph.graph import InsightsGraph
task = cast(AssistantToolCall, input_dict["task"])
artifacts = input_dict["artifacts"]
@@ -60,7 +60,7 @@ class WithInsightCreationTaskExecution:
self._current_task_id = task.id
# This is needed by the InsightsAssistantGraph to return an AssistantToolCallMessage
# This is needed by the InsightsGraph to return an AssistantToolCallMessage
task_tool_call_id = f"task_{uuid.uuid4().hex[:8]}"
query = task.args["query_description"]
@@ -77,9 +77,7 @@ class WithInsightCreationTaskExecution:
)
subgraph_result_messages: list[AssistantMessageUnion] = []
assistant_graph = InsightsAssistantGraph(
self._team, self._user, tool_call_id=self._parent_tool_call_id
).compile_full_graph()
assistant_graph = InsightsGraph(self._team, self._user).compile_full_graph()
try:
async for chunk in assistant_graph.astream(
input_state,
@@ -156,7 +154,7 @@ class WithInsightCreationTaskExecution:
if isinstance(message, VisualizationMessage) and message.id:
artifact = InsightArtifact(
task_id=tool_call.id,
id=None, # The InsightsAssistantGraph does not create the insight objects
id=None, # The InsightsGraph does not create the insight objects
content="",
query=cast(AnyAssistantGeneratedQuery, message.answer),
)
@@ -194,9 +192,7 @@ class WithInsightSearchTaskExecution:
)
try:
result = await InsightSearchNode(self._team, self._user, tool_call_id=self._parent_tool_call_id).arun(
input_state, config
)
result = await InsightSearchNode(self._team, self._user).arun(input_state, config)
if not result or not result.messages:
logger.warning("Task failed: no messages received from node executor", task_id=task.id)

View File

@@ -76,6 +76,21 @@ class QueryExecutorNode(AssistantNode):
)
return PartialAssistantState(messages=[FailureMessage(content=str(err), id=str(uuid4()))])
return PartialAssistantState(
messages=[
AssistantToolCallMessage(
content=self._format_query_result(viz_message, results, example_prompt),
id=str(uuid4()),
tool_call_id=tool_call_id,
)
],
root_tool_call_id=None,
root_tool_insight_plan=None,
root_tool_insight_type=None,
rag_context=None,
)
def _format_query_result(self, viz_message: VisualizationMessage, results: str, example_prompt: str) -> str:
query_result = QUERY_RESULTS_PROMPT.format(
query_kind=viz_message.answer.kind,
results=results,
@@ -84,20 +99,10 @@ class QueryExecutorNode(AssistantNode):
project_timezone=self.project_timezone,
currency=self.project_currency,
)
formatted_query_result = f"{example_prompt}\n\n{query_result}"
if isinstance(viz_message.answer, AssistantHogQLQuery):
formatted_query_result = f"{example_prompt}\n\n{SQL_QUERY_PROMPT.format(query=viz_message.answer.query)}\n\n{formatted_query_result}"
return PartialAssistantState(
messages=[
AssistantToolCallMessage(content=formatted_query_result, id=str(uuid4()), tool_call_id=tool_call_id)
],
root_tool_call_id=None,
root_tool_insight_plan=None,
root_tool_insight_type=None,
rag_context=None,
)
return formatted_query_result
def _get_example_prompt(self, viz_message: VisualizationMessage) -> str:
if isinstance(viz_message.answer, AssistantTrendsQuery | TrendsQuery):

View File

@@ -1,7 +1,7 @@
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from posthog.schema import AssistantMessage, AssistantRetentionQuery
from posthog.schema import AssistantRetentionQuery
from ee.hogai.utils.types import AssistantState, PartialAssistantState
from ee.hogai.utils.types.base import AssistantNodeName
@@ -25,7 +25,7 @@ class RetentionGeneratorNode(SchemaGeneratorNode[AssistantRetentionQuery]):
return AssistantNodeName.RETENTION_GENERATOR
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
self.dispatcher.message(AssistantMessage(content="Creating retention query"))
self.dispatcher.update("Creating retention query")
prompt = ChatPromptTemplate.from_messages(
[
("system", RETENTION_SYSTEM_PROMPT),

View File

@@ -3,27 +3,23 @@ from collections.abc import Awaitable, Mapping, Sequence
from typing import TYPE_CHECKING, Literal, TypeVar, Union
from uuid import uuid4
import structlog
import posthoganalytics
from langchain_core.messages import (
AIMessage as LangchainAIMessage,
BaseMessage,
HumanMessage as LangchainHumanMessage,
ToolCall,
ToolMessage as LangchainToolMessage,
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langgraph.errors import NodeInterrupt
from langgraph.types import Send
from posthoganalytics import capture_exception
from pydantic import BaseModel
from posthog.schema import (
AssistantMessage,
AssistantTool,
AssistantToolCallMessage,
ContextMessage,
FailureMessage,
HumanMessage,
)
from posthog.schema import AssistantMessage, AssistantToolCallMessage, ContextMessage, FailureMessage, HumanMessage
from posthog.models import Team, User
@@ -36,8 +32,14 @@ from ee.hogai.tool import ToolMessagesArtifact
from ee.hogai.utils.anthropic import add_cache_control, convert_to_anthropic_messages
from ee.hogai.utils.helpers import convert_tool_messages_to_dict, normalize_ai_message
from ee.hogai.utils.prompt import format_prompt_string
from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState, InsightQuery
from ee.hogai.utils.types.base import PartialAssistantState, ReplaceMessages
from ee.hogai.utils.types.base import (
AssistantMessageUnion,
AssistantNodeName,
AssistantState,
NodePath,
PartialAssistantState,
ReplaceMessages,
)
from ee.hogai.utils.types.composed import MaxNodeName
from .prompts import (
@@ -51,13 +53,13 @@ from .prompts import (
ROOT_TOOL_DOES_NOT_EXIST,
)
from .tools import (
CreateAndQueryInsightTool,
CreateDashboardTool,
ReadDataTool,
ReadTaxonomyTool,
SearchTool,
SessionSummarizationTool,
TodoWriteTool,
create_and_query_insight,
create_dashboard,
session_summarization,
)
if TYPE_CHECKING:
@@ -66,24 +68,14 @@ if TYPE_CHECKING:
SLASH_COMMAND_INIT = "/init"
SLASH_COMMAND_REMEMBER = "/remember"
RouteName = Literal[
"insights",
"root",
"end",
"search_documentation",
"memory_onboarding",
"insights_search",
"billing",
"session_summarization",
"create_dashboard",
]
RootMessageUnion = HumanMessage | AssistantMessage | FailureMessage | AssistantToolCallMessage | ContextMessage
T = TypeVar("T", RootMessageUnion, BaseMessage)
RootTool = Union[type[BaseModel], "MaxTool"]
logger = structlog.get_logger(__name__)
class RootNode(AssistantNode):
MAX_TOOL_CALLS = 24
@@ -95,8 +87,8 @@ class RootNode(AssistantNode):
Determines the thinking configuration for the model.
"""
def __init__(self, team: Team, user: User):
super().__init__(team, user)
def __init__(self, team: Team, user: User, node_path: tuple[NodePath, ...] | None = None):
    """Initialize the root node and its Anthropic conversation-compaction manager."""
    super().__init__(team, user, node_path)
    self._window_manager = AnthropicConversationCompactionManager()
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
@@ -168,12 +160,25 @@ class RootNode(AssistantNode):
if messages_to_replace:
new_messages = ReplaceMessages([*messages_to_replace, assistant_message])
# Set new tool call count
tool_call_count = (state.root_tool_calls_count or 0) + 1 if assistant_message.tool_calls else None
return PartialAssistantState(
messages=new_messages,
root_tool_calls_count=tool_call_count,
root_conversation_start_id=window_id,
start_id=start_id,
)
def router(self, state: AssistantState):
    """Fan out one `Send` to ROOT_TOOLS per tool call, or route to END when there are none.

    Each Send carries a copy of the state with `root_tool_call_id` set to the
    individual tool call's ID, enabling parallel tool execution.
    """
    final_message = state.messages[-1]
    if isinstance(final_message, AssistantMessage) and final_message.tool_calls:
        sends = []
        for call in final_message.tool_calls:
            branch_state = state.model_copy(update={"root_tool_call_id": call.id})
            sends.append(Send(AssistantNodeName.ROOT_TOOLS, branch_state))
        return sends
    return AssistantNodeName.END
@property
def node_name(self) -> MaxNodeName:
    """Identifier of this node within the assistant graph."""
    return AssistantNodeName.ROOT
@@ -226,11 +231,35 @@ class RootNode(AssistantNode):
if self._is_hard_limit_reached(state.root_tool_calls_count):
return base_model
return base_model.bind_tools(tools, parallel_tool_calls=False)
return base_model.bind_tools(tools, parallel_tool_calls=True)
async def _get_tools(self, state: AssistantState, config: RunnableConfig) -> list[RootTool]:
from ee.hogai.tool import get_contextual_tool_class
# Static toolkit
default_tools: list[type[MaxTool]] = [
ReadTaxonomyTool,
ReadDataTool,
SearchTool,
TodoWriteTool,
]
# The contextual insights tool overrides the static tool. Only inject if it's injected.
if not CreateAndQueryInsightTool.is_editing_mode(self.context_manager):
default_tools.append(CreateAndQueryInsightTool)
# Add session summarization tool if enabled
if self._has_session_summarization_feature_flag():
default_tools.append(SessionSummarizationTool)
# Add other lower-priority tools
default_tools.extend(
[
CreateDashboardTool,
]
)
# Processed tools
available_tools: list[RootTool] = []
# Initialize the static toolkit
@@ -238,37 +267,14 @@ class RootNode(AssistantNode):
# This is just to bound the tools to the model
dynamic_tools = (
tool_class.create_tool_class(
team=self._team,
user=self._user,
tool_call_id="",
state=state,
config=config,
context_manager=self.context_manager,
)
for tool_class in (
ReadTaxonomyTool,
ReadDataTool,
SearchTool,
TodoWriteTool,
team=self._team, user=self._user, state=state, config=config, context_manager=self.context_manager
)
for tool_class in default_tools
)
available_tools.extend(await asyncio.gather(*dynamic_tools))
# Insights tool
tool_names = self.context_manager.get_contextual_tools().keys()
is_editing_insight = AssistantTool.EDIT_CURRENT_INSIGHT in tool_names
if not is_editing_insight:
# This is the default tool, which can be overriden by the MaxTool based tool with the same name
available_tools.append(create_and_query_insight)
# Check if session summarization is enabled for the user
if self._has_session_summarization_feature_flag():
available_tools.append(session_summarization)
# Dashboard creation tool
available_tools.append(create_dashboard)
# Inject contextual tools
tool_names = self.context_manager.get_contextual_tools().keys()
awaited_contextual_tools: list[Awaitable[RootTool]] = []
for tool_name in tool_names:
ContextualMaxToolClass = get_contextual_tool_class(tool_name)
@@ -278,7 +284,6 @@ class RootNode(AssistantNode):
ContextualMaxToolClass.create_tool_class(
team=self._team,
user=self._user,
tool_call_id="",
state=state,
config=config,
context_manager=self.context_manager,
@@ -366,51 +371,23 @@ class RootNodeTools(AssistantNode):
def node_name(self) -> MaxNodeName:
return AssistantNodeName.ROOT_TOOLS
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
last_message = state.messages[-1]
if not isinstance(last_message, AssistantMessage) or not last_message.tool_calls:
# Reset tools.
return PartialAssistantState(root_tool_calls_count=0)
tool_call_count = state.root_tool_calls_count or 0
reset_state = PartialAssistantState(root_tool_call_id=None)
# Should never happen, but just in case.
if not isinstance(last_message, AssistantMessage) or not last_message.id or not state.root_tool_call_id:
return reset_state
tools_calls = last_message.tool_calls
if len(tools_calls) != 1:
raise ValueError("Expected exactly one tool call.")
tool_names = self.context_manager.get_contextual_tools().keys()
is_editing_insight = AssistantTool.EDIT_CURRENT_INSIGHT in tool_names
tool_call = tools_calls[0]
# Find the current tool call in the last message.
tool_call = next(
(tool_call for tool_call in last_message.tool_calls or [] if tool_call.id == state.root_tool_call_id), None
)
if not tool_call:
return reset_state
from ee.hogai.tool import get_contextual_tool_class
if tool_call.name == "create_and_query_insight" and not is_editing_insight:
return PartialAssistantState(
root_tool_call_id=tool_call.id,
root_tool_insight_plan=tool_call.args["query_description"],
root_tool_calls_count=tool_call_count + 1,
)
if tool_call.name == "session_summarization":
return PartialAssistantState(
root_tool_call_id=tool_call.id,
session_summarization_query=tool_call.args["session_summarization_query"],
# Safety net in case the argument is missing to avoid raising exceptions internally
should_use_current_filters=tool_call.args.get("should_use_current_filters", False),
summary_title=tool_call.args.get("summary_title"),
root_tool_calls_count=tool_call_count + 1,
)
if tool_call.name == "create_dashboard":
raw_queries = tool_call.args["search_insights_queries"]
search_insights_queries = [InsightQuery.model_validate(query) for query in raw_queries]
return PartialAssistantState(
root_tool_call_id=tool_call.id,
dashboard_name=tool_call.args.get("dashboard_name"),
search_insights_queries=search_insights_queries,
root_tool_calls_count=tool_call_count + 1,
)
# MaxTool flow
ToolClass = get_contextual_tool_class(tool_call.name)
# If the tool doesn't exist, return the message to the agent
@@ -423,25 +400,32 @@ class RootNodeTools(AssistantNode):
tool_call_id=tool_call.id,
)
],
root_tool_calls_count=tool_call_count + 1,
)
# Initialize the tool and process it
tool_class = await ToolClass.create_tool_class(
team=self._team,
user=self._user,
tool_call_id=tool_call.id,
# Tricky: set the node path to associated with the tool call
node_path=(
*self.node_path[:-1],
NodePath(name=AssistantNodeName.ROOT_TOOLS, message_id=last_message.id, tool_call_id=tool_call.id),
),
state=state,
config=config,
context_manager=self.context_manager,
)
try:
result = await tool_class.ainvoke(tool_call.model_dump(), config)
result = await tool_class.ainvoke(
ToolCall(type="tool_call", name=tool_call.name, args=tool_call.args, id=tool_call.id), config=config
)
if not isinstance(result, LangchainToolMessage):
raise ValueError(
f"Tool '{tool_call.name}' returned {type(result).__name__}, expected LangchainToolMessage"
)
except Exception as e:
logger.exception("Error calling tool", extra={"tool_name": tool_call.name, "error": str(e)})
capture_exception(
e, distinct_id=self._get_user_distinct_id(config), properties=self._get_debug_props(config)
)
@@ -453,38 +437,11 @@ class RootNodeTools(AssistantNode):
tool_call_id=tool_call.id,
)
],
root_tool_calls_count=tool_call_count + 1,
)
if isinstance(result.artifact, ToolMessagesArtifact):
return PartialAssistantState(
messages=result.artifact.messages,
root_tool_calls_count=tool_call_count + 1,
)
# Handle the basic toolkit
if result.name == "search" and isinstance(result.artifact, dict):
match result.artifact.get("kind"):
case "insights":
return PartialAssistantState(
root_tool_call_id=tool_call.id,
search_insights_query=result.artifact.get("query"),
root_tool_calls_count=tool_call_count + 1,
)
case "docs":
return PartialAssistantState(
root_tool_call_id=tool_call.id,
root_tool_calls_count=tool_call_count + 1,
)
if (
result.name == "read_data"
and isinstance(result.artifact, dict)
and result.artifact.get("kind") == "billing_info"
):
return PartialAssistantState(
root_tool_call_id=tool_call.id,
root_tool_calls_count=tool_call_count + 1,
)
# If this is a navigation tool call, pause the graph execution
@@ -496,7 +453,6 @@ class RootNodeTools(AssistantNode):
id=str(uuid4()),
tool_call_id=tool_call.id,
)
self.dispatcher.message(navigate_message)
# Raising a `NodeInterrupt` ensures the assistant graph stops here and
# surfaces the navigation confirmation to the client. The next user
# interaction will resume the graph with potentially different
@@ -512,29 +468,11 @@ class RootNodeTools(AssistantNode):
return PartialAssistantState(
messages=[tool_message],
root_tool_calls_count=tool_call_count + 1,
)
def router(self, state: AssistantState) -> RouteName:
# This is only for the Inkeep node. Remove when inkeep_docs is removed.
def router(self, state: AssistantState) -> Literal["root", "end"]:
last_message = state.messages[-1]
if isinstance(last_message, AssistantToolCallMessage):
return "root" # Let the root either proceed or finish, since it now can see the tool call result
if isinstance(last_message, AssistantMessage) and state.root_tool_call_id:
tool_calls = getattr(last_message, "tool_calls", None)
if tool_calls and len(tool_calls) > 0:
tool_call = tool_calls[0]
tool_call_name = tool_call.name
if tool_call_name == "read_data" and tool_call.args.get("kind") == "billing_info":
return "billing"
if tool_call_name == "create_dashboard":
return "create_dashboard"
if state.root_tool_insight_plan:
return "insights"
elif state.search_insights_query:
return "insights_search"
elif state.session_summarization_query:
return "session_summarization"
else:
return "search_documentation"
return "end"

View File

@@ -112,6 +112,10 @@ The user is a product engineer and will primarily request you perform product ma
- Tool results and user messages may include <system_reminder> tags. <system_reminder> tags contain useful information and reminders. They are NOT part of the user's provided input or the tool result.
</doing_tasks>
<tool_usage_policy>
- You can invoke multiple tools within a single response. When a request involves several independent pieces of information, batch your tool calls together for optimal performance
</tool_usage_policy>
{{{billing_context}}}
{{{core_memory_prompt}}}

View File

@@ -38,6 +38,7 @@ from ee.hogai.graph.root.prompts import (
ROOT_BILLING_CONTEXT_WITH_ACCESS_PROMPT,
ROOT_BILLING_CONTEXT_WITH_NO_ACCESS_PROMPT,
)
from ee.hogai.tool import ToolMessagesArtifact
from ee.hogai.utils.tests import FakeChatAnthropic, FakeChatOpenAI
from ee.hogai.utils.types import AssistantState, PartialAssistantState
from ee.hogai.utils.types.base import AssistantMessageUnion
@@ -644,6 +645,147 @@ class TestRootNode(ClickhouseTestMixin, BaseTest):
result = node._is_hard_limit_reached(tool_calls_count)
self.assertEqual(result, expected)
async def test_node_increments_tool_count_on_tool_call(self):
    """Test that RootNode increments tool count when assistant makes a tool call"""
    fake_model = FakeChatOpenAI(
        responses=[
            LangchainAIMessage(
                content="Let me help",
                tool_calls=[
                    {
                        "id": "tool-1",
                        "name": "create_and_query_insight",
                        "args": {"query_description": "test"},
                    }
                ],
            )
        ]
    )
    with patch("ee.hogai.graph.root.nodes.RootNode._get_model", return_value=fake_model):
        node = RootNode(self.team, self.user)

        # Starting with no prior tool calls, the count becomes 1.
        first_state = AssistantState(messages=[HumanMessage(content="Hello")])
        first_result = await node.arun(first_state, {})
        self.assertEqual(first_result.root_tool_calls_count, 1)

        # An existing count of 5 is bumped to 6.
        second_state = AssistantState(
            messages=[HumanMessage(content="Hello")],
            root_tool_calls_count=5,
        )
        second_result = await node.arun(second_state, {})
        self.assertEqual(second_result.root_tool_calls_count, 6)
async def test_node_resets_tool_count_on_plain_response(self):
    """Test that RootNode resets tool count when assistant responds without tool calls"""
    fake_model = FakeChatOpenAI(responses=[LangchainAIMessage(content="Here's your answer")])
    with patch("ee.hogai.graph.root.nodes.RootNode._get_model", return_value=fake_model):
        node = RootNode(self.team, self.user)
        state = AssistantState(
            messages=[HumanMessage(content="Hello")],
            root_tool_calls_count=5,
        )
        outcome = await node.arun(state, {})
        # A plain response (no tool calls) clears the running counter.
        self.assertIsNone(outcome.root_tool_calls_count)
def test_router_returns_end_for_plain_response(self):
    """Test that router returns END when message has no tool calls"""
    from ee.hogai.utils.types import AssistantNodeName

    node = RootNode(self.team, self.user)
    conversation = [
        HumanMessage(content="Hello"),
        AssistantMessage(content="Hi there!"),
    ]
    routed = node.router(AssistantState(messages=conversation))
    self.assertEqual(routed, AssistantNodeName.END)
def test_router_returns_send_for_single_tool_call(self):
    """Test that router returns Send for single tool call"""
    from langgraph.types import Send

    from ee.hogai.utils.types import AssistantNodeName

    node = RootNode(self.team, self.user)
    single_call = AssistantToolCall(
        id="tool-1",
        name="create_and_query_insight",
        args={"query_description": "test"},
    )
    state = AssistantState(
        messages=[
            HumanMessage(content="Generate insights"),
            AssistantMessage(content="Let me help", tool_calls=[single_call]),
        ]
    )
    routed = node.router(state)

    # A single tool call yields a one-element list of Send objects targeting ROOT_TOOLS.
    self.assertIsInstance(routed, list)
    self.assertEqual(len(routed), 1)
    self.assertIsInstance(routed[0], Send)
    self.assertEqual(routed[0].node, AssistantNodeName.ROOT_TOOLS)
    self.assertEqual(routed[0].arg.root_tool_call_id, "tool-1")
def test_router_returns_multiple_sends_for_parallel_tool_calls(self):
    """Test that router returns multiple Send objects for parallel tool calls"""
    from langgraph.types import Send

    from ee.hogai.utils.types import AssistantNodeName

    node = RootNode(self.team, self.user)
    parallel_calls = [
        AssistantToolCall(
            id="tool-1",
            name="create_and_query_insight",
            args={"query_description": "trends"},
        ),
        AssistantToolCall(
            id="tool-2",
            name="create_and_query_insight",
            args={"query_description": "funnel"},
        ),
        AssistantToolCall(
            id="tool-3",
            name="create_and_query_insight",
            args={"query_description": "retention"},
        ),
    ]
    state = AssistantState(
        messages=[
            HumanMessage(content="Generate multiple insights"),
            AssistantMessage(content="Let me create several insights", tool_calls=parallel_calls),
        ]
    )
    routed = node.router(state)

    # Expect one Send per tool call, all targeting ROOT_TOOLS, in order.
    self.assertIsInstance(routed, list)
    self.assertEqual(len(routed), 3)
    for position, send in enumerate(routed):
        self.assertIsInstance(send, Send)
        self.assertEqual(send.node, AssistantNodeName.ROOT_TOOLS)
        self.assertEqual(send.arg.root_tool_call_id, f"tool-{position + 1}")
class TestRootNodeTools(BaseTest):
def test_node_tools_router(self):
@@ -675,9 +817,13 @@ class TestRootNodeTools(BaseTest):
node = RootNodeTools(self.team, self.user)
state = AssistantState(messages=[HumanMessage(content="Hello")])
result = await node.arun(state, {})
self.assertEqual(result, PartialAssistantState(root_tool_calls_count=0))
self.assertEqual(result, PartialAssistantState(root_tool_call_id=None))
@patch("ee.hogai.graph.root.tools.create_and_query_insight.CreateAndQueryInsightTool._arun_impl")
async def test_run_valid_tool_call(self, create_and_query_insight_mock):
test_message = AssistantToolCallMessage(content="Tool result", tool_call_id="xyz", id="msg-1")
create_and_query_insight_mock.return_value = ("", ToolMessagesArtifact(messages=[test_message]))
async def test_run_valid_tool_call(self):
node = RootNodeTools(self.team, self.user)
state = AssistantState(
messages=[
@@ -688,17 +834,20 @@ class TestRootNodeTools(BaseTest):
AssistantToolCall(
id="xyz",
name="create_and_query_insight",
args={"query_kind": "trends", "query_description": "test query"},
args={"query_description": "test query"},
)
],
)
]
],
root_tool_call_id="xyz",
)
result = await node.arun(state, {})
self.assertIsInstance(result, PartialAssistantState)
self.assertEqual(result.root_tool_call_id, "xyz")
self.assertEqual(result.root_tool_insight_plan, "test query")
self.assertEqual(result.root_tool_insight_type, None) # Insight type is determined by query planner node
assert result is not None
self.assertEqual(len(result.messages), 1)
assert isinstance(result.messages[0], AssistantToolCallMessage)
self.assertEqual(result.messages[0].tool_call_id, "xyz")
create_and_query_insight_mock.assert_called_once_with(query_description="test query")
async def test_run_valid_contextual_tool_call(self):
node = RootNodeTools(self.team, self.user)
@@ -715,7 +864,8 @@ class TestRootNodeTools(BaseTest):
)
],
)
]
],
root_tool_call_id="xyz",
)
result = await node.arun(
@@ -730,69 +880,9 @@ class TestRootNodeTools(BaseTest):
)
self.assertIsInstance(result, PartialAssistantState)
self.assertEqual(result.root_tool_call_id, None) # Tool was fully handled by the node
self.assertIsNone(result.root_tool_insight_plan) # No insight plan for contextual tools
self.assertIsNone(result.root_tool_insight_type) # No insight type for contextual tools
async def test_run_multiple_tool_calls_raises(self):
node = RootNodeTools(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Hello",
id="test-id",
tool_calls=[
AssistantToolCall(
id="xyz1",
name="create_and_query_insight",
args={"query_kind": "trends", "query_description": "test query 1"},
),
AssistantToolCall(
id="xyz2",
name="create_and_query_insight",
args={"query_kind": "funnel", "query_description": "test query 2"},
),
],
)
]
)
with self.assertRaises(ValueError) as cm:
await node.arun(state, {})
self.assertEqual(str(cm.exception), "Expected exactly one tool call.")
async def test_run_increments_tool_count(self):
node = RootNodeTools(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Hello",
id="test-id",
tool_calls=[
AssistantToolCall(
id="xyz",
name="create_and_query_insight",
args={"query_kind": "trends", "query_description": "test query"},
)
],
)
],
root_tool_calls_count=2, # Starting count
)
result = await node.arun(state, {})
self.assertEqual(result.root_tool_calls_count, 3) # Should increment by 1
async def test_run_resets_tool_count(self):
node = RootNodeTools(self.team, self.user)
# Test reset when no tool calls in AssistantMessage
state_1 = AssistantState(messages=[AssistantMessage(content="Hello", tool_calls=[])], root_tool_calls_count=3)
result = await node.arun(state_1, {})
self.assertEqual(result.root_tool_calls_count, 0)
# Test reset when last message is HumanMessage
state_2 = AssistantState(messages=[HumanMessage(content="Hello")], root_tool_calls_count=3)
result = await node.arun(state_2, {})
self.assertEqual(result.root_tool_calls_count, 0)
assert result is not None
self.assertEqual(len(result.messages), 1)
self.assertIsInstance(result.messages[0], AssistantToolCallMessage)
async def test_navigate_tool_call_raises_node_interrupt(self):
"""Test that navigate tool calls raise NodeInterrupt to pause graph execution"""
@@ -805,7 +895,8 @@ class TestRootNodeTools(BaseTest):
id="test-id",
tool_calls=[AssistantToolCall(id="nav-123", name="navigate", args={"page_key": "insights"})],
)
]
],
root_tool_call_id="nav-123",
)
mock_navigate_tool = AsyncMock()
@@ -828,319 +919,6 @@ class TestRootNodeTools(BaseTest):
self.assertEqual(interrupt_data.tool_call_id, "nav-123")
self.assertEqual(interrupt_data.ui_payload, {"navigate": {"page_key": "insights"}})
def test_billing_tool_routing(self):
"""Test that billing tool calls are routed correctly"""
node = RootNodeTools(self.team, self.user)
# Create state with billing tool call (read_data with kind=billing_info)
state = AssistantState(
messages=[
AssistantMessage(
content="Let me check your billing information",
tool_calls=[AssistantToolCall(id="billing-123", name="read_data", args={"kind": "billing_info"})],
)
],
root_tool_call_id="billing-123",
)
# Should route to billing
self.assertEqual(node.router(state), "billing")
def test_router_insights_path(self):
"""Test router routes to insights when root_tool_insight_plan is set"""
node = RootNodeTools(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Creating insight",
tool_calls=[
AssistantToolCall(
id="insight-123",
name="create_and_query_insight",
args={"query_kind": "trends", "query_description": "test"},
)
],
)
],
root_tool_call_id="insight-123",
root_tool_insight_plan="test query plan",
)
self.assertEqual(node.router(state), "insights")
def test_router_insights_search_path(self):
"""Test router routes to insights_search when search_insights_query is set"""
node = RootNodeTools(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Searching insights",
tool_calls=[AssistantToolCall(id="search-123", name="search", args={"kind": "insights"})],
)
],
root_tool_call_id="search-123",
search_insights_query="test search query",
)
self.assertEqual(node.router(state), "insights_search")
def test_router_session_summarization_path(self):
"""Test router routes to session_summarization when session_summarization_query is set"""
node = RootNodeTools(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Summarizing sessions",
tool_calls=[
AssistantToolCall(
id="session-123", name="session_summarization", args={"session_summarization_query": "test"}
)
],
)
],
root_tool_call_id="session-123",
session_summarization_query="test session query",
)
self.assertEqual(node.router(state), "session_summarization")
def test_router_create_dashboard_path(self):
"""Test router routes to create_dashboard when create_dashboard tool is called"""
node = RootNodeTools(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Creating dashboard",
tool_calls=[
AssistantToolCall(
id="dashboard-123", name="create_dashboard", args={"search_insights_queries": []}
)
],
)
],
root_tool_call_id="dashboard-123",
)
self.assertEqual(node.router(state), "create_dashboard")
def test_router_search_documentation_fallback(self):
"""Test router routes to search_documentation when root_tool_call_id is set but no specific route"""
node = RootNodeTools(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Searching docs",
tool_calls=[AssistantToolCall(id="search-123", name="search", args={"kind": "docs"})],
)
],
root_tool_call_id="search-123",
)
self.assertEqual(node.router(state), "search_documentation")
async def test_arun_session_summarization_with_all_args(self):
"""Test session_summarization tool call with all arguments"""
node = RootNodeTools(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Summarizing sessions",
id="test-id",
tool_calls=[
AssistantToolCall(
id="session-123",
name="session_summarization",
args={
"session_summarization_query": "test query",
"should_use_current_filters": True,
"summary_title": "Test Summary",
},
)
],
)
]
)
result = await node.arun(state, {})
self.assertIsInstance(result, PartialAssistantState)
self.assertEqual(result.root_tool_call_id, "session-123")
self.assertEqual(result.session_summarization_query, "test query")
self.assertEqual(result.should_use_current_filters, True)
self.assertEqual(result.summary_title, "Test Summary")
async def test_arun_session_summarization_missing_optional_args(self):
"""Test session_summarization tool call with missing optional arguments"""
node = RootNodeTools(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Summarizing sessions",
id="test-id",
tool_calls=[
AssistantToolCall(
id="session-123",
name="session_summarization",
args={"session_summarization_query": "test query"},
)
],
)
]
)
result = await node.arun(state, {})
self.assertIsInstance(result, PartialAssistantState)
self.assertEqual(result.should_use_current_filters, False) # Default value
self.assertIsNone(result.summary_title)
async def test_arun_create_dashboard_with_queries(self):
"""Test create_dashboard tool call with search_insights_queries"""
node = RootNodeTools(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Creating dashboard",
id="test-id",
tool_calls=[
AssistantToolCall(
id="dashboard-123",
name="create_dashboard",
args={
"dashboard_name": "Test Dashboard",
"search_insights_queries": [
{"name": "Query 1", "description": "Trends insight description"},
{"name": "Query 2", "description": "Funnel insight description"},
],
},
)
],
)
]
)
result = await node.arun(state, {})
self.assertIsInstance(result, PartialAssistantState)
self.assertEqual(result.root_tool_call_id, "dashboard-123")
self.assertEqual(result.dashboard_name, "Test Dashboard")
self.assertIsNotNone(result.search_insights_queries)
assert result.search_insights_queries is not None
self.assertEqual(len(result.search_insights_queries), 2)
self.assertEqual(result.search_insights_queries[0].name, "Query 1")
self.assertEqual(result.search_insights_queries[1].name, "Query 2")
    async def test_arun_search_tool_insights_kind(self):
        """Test search tool with kind=insights"""
        node = RootNodeTools(self.team, self.user)
        state = AssistantState(
            messages=[
                AssistantMessage(
                    content="Searching insights",
                    id="test-id",
                    tool_calls=[
                        AssistantToolCall(id="search-123", name="search", args={"query": "test", "kind": "insights"})
                    ],
                )
            ]
        )
        # Stub the contextual tool so no real search backend is hit; the artifact's
        # kind/query drive how the node populates the partial state below.
        mock_tool = AsyncMock()
        mock_tool.ainvoke.return_value = LangchainToolMessage(
            content="Search results",
            tool_call_id="search-123",
            name="search",
            artifact={"kind": "insights", "query": "test"},
        )
        with mock_contextual_tool(mock_tool):
            result = await node.arun(state, {"configurable": {"contextual_tools": {"search": {}}}})
        self.assertIsInstance(result, PartialAssistantState)
        self.assertEqual(result.root_tool_call_id, "search-123")
        # For kind=insights the artifact's query is promoted to search_insights_query.
        self.assertEqual(result.search_insights_query, "test")
async def test_arun_search_tool_docs_kind(self):
"""Test search tool with kind=docs"""
node = RootNodeTools(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Searching docs",
id="test-id",
tool_calls=[
AssistantToolCall(id="search-123", name="search", args={"query": "test", "kind": "docs"})
],
)
]
)
mock_tool = AsyncMock()
mock_tool.ainvoke.return_value = LangchainToolMessage(
content="Docs results", tool_call_id="search-123", name="search", artifact={"kind": "docs"}
)
with mock_contextual_tool(mock_tool):
result = await node.arun(state, {"configurable": {"contextual_tools": {"search": {}}}})
self.assertIsInstance(result, PartialAssistantState)
self.assertEqual(result.root_tool_call_id, "search-123")
async def test_arun_read_data_billing_info(self):
"""Test read_data tool with kind=billing_info"""
node = RootNodeTools(self.team, self.user)
state = AssistantState(
messages=[
AssistantMessage(
content="Reading billing info",
id="test-id",
tool_calls=[AssistantToolCall(id="read-123", name="read_data", args={"kind": "billing_info"})],
)
]
)
mock_tool = AsyncMock()
mock_tool.ainvoke.return_value = LangchainToolMessage(
content="Billing data", tool_call_id="read-123", name="read_data", artifact={"kind": "billing_info"}
)
with mock_contextual_tool(mock_tool):
result = await node.arun(state, {"configurable": {"contextual_tools": {"read_data": {}}}})
self.assertIsInstance(result, PartialAssistantState)
self.assertEqual(result.root_tool_call_id, "read-123")
    async def test_arun_tool_updates_state(self):
        """Test that when a tool updates its _state, the new messages are included"""
        node = RootNodeTools(self.team, self.user)
        state = AssistantState(
            messages=[
                AssistantMessage(
                    content="Using tool",
                    id="test-id",
                    tool_calls=[AssistantToolCall(id="tool-123", name="test_tool", args={})],
                )
            ]
        )
        mock_tool = AsyncMock()
        # Simulate tool appending a message to state
        updated_state = AssistantState(
            messages=[
                *state.messages,
                AssistantToolCallMessage(content="Tool result", tool_call_id="tool-123", id="msg-1"),
            ]
        )
        # The node reads the tool's private _state to pick up messages the tool
        # appended while running — mimic that side channel on the mock.
        mock_tool._state = updated_state
        mock_tool.ainvoke.return_value = LangchainToolMessage(content="Tool result", tool_call_id="tool-123")
        with mock_contextual_tool(mock_tool):
            result = await node.arun(state, {"configurable": {"contextual_tools": {"test_tool": {}}}})
        # Should include the new message from the updated state
        self.assertIsInstance(result, PartialAssistantState)
        self.assertEqual(len(result.messages), 1)
        self.assertIsInstance(result.messages[0], AssistantToolCallMessage)
async def test_arun_tool_returns_wrong_type_returns_error_message(self):
"""Test that tool returning wrong type returns an error message"""
node = RootNodeTools(self.team, self.user)
@@ -1151,7 +929,8 @@ class TestRootNodeTools(BaseTest):
id="test-id",
tool_calls=[AssistantToolCall(id="tool-123", name="test_tool", args={})],
)
]
],
root_tool_call_id="tool-123",
)
mock_tool = AsyncMock()
@@ -1161,10 +940,11 @@ class TestRootNodeTools(BaseTest):
result = await node.arun(state, {"configurable": {"contextual_tools": {"test_tool": {}}}})
self.assertIsInstance(result, PartialAssistantState)
assert result is not None
self.assertEqual(len(result.messages), 1)
assert isinstance(result.messages[0], AssistantToolCallMessage)
self.assertEqual(result.messages[0].tool_call_id, "tool-123")
self.assertEqual(result.root_tool_calls_count, 1)
self.assertIn("internal error", result.messages[0].content)
async def test_arun_unknown_tool_returns_error_message(self):
"""Test that unknown tool name returns an error message"""
@@ -1176,14 +956,16 @@ class TestRootNodeTools(BaseTest):
id="test-id",
tool_calls=[AssistantToolCall(id="tool-123", name="unknown_tool", args={})],
)
]
],
root_tool_call_id="tool-123",
)
with patch("ee.hogai.tool.get_contextual_tool_class", return_value=None):
result = await node.arun(state, {})
self.assertIsInstance(result, PartialAssistantState)
assert result is not None
self.assertEqual(len(result.messages), 1)
assert isinstance(result.messages[0], AssistantToolCallMessage)
self.assertEqual(result.messages[0].tool_call_id, "tool-123")
self.assertEqual(result.root_tool_calls_count, 1)
self.assertIn("does not exist", result.messages[0].content)

View File

@@ -1,21 +1,26 @@
from .legacy import create_and_query_insight, create_dashboard, session_summarization
from .create_and_query_insight import CreateAndQueryInsightTool, CreateAndQueryInsightToolArgs
from .create_dashboard import CreateDashboardTool, CreateDashboardToolArgs
from .navigate import NavigateTool, NavigateToolArgs
from .read_data import ReadDataTool, ReadDataToolArgs
from .read_taxonomy import ReadTaxonomyTool
from .search import SearchTool, SearchToolArgs
from .session_summarization import SessionSummarizationTool, SessionSummarizationToolArgs
from .todo_write import TodoWriteTool, TodoWriteToolArgs
__all__ = [
"CreateAndQueryInsightTool",
"CreateAndQueryInsightToolArgs",
"CreateDashboardTool",
"CreateDashboardToolArgs",
"NavigateTool",
"NavigateToolArgs",
"ReadDataTool",
"ReadDataToolArgs",
"ReadTaxonomyTool",
"SearchTool",
"SearchToolArgs",
"ReadDataTool",
"ReadDataToolArgs",
"SessionSummarizationTool",
"SessionSummarizationToolArgs",
"TodoWriteTool",
"TodoWriteToolArgs",
"NavigateTool",
"NavigateToolArgs",
"create_and_query_insight",
"session_summarization",
"create_dashboard",
]

View File

@@ -0,0 +1,206 @@
from typing import Literal
from pydantic import BaseModel, Field
from posthog.schema import AssistantTool, AssistantToolCallMessage, VisualizationMessage
from ee.hogai.context.context import AssistantContextManager
from ee.hogai.graph.insights_graph.graph import InsightsGraph
from ee.hogai.graph.schema_generator.nodes import SchemaGenerationException
from ee.hogai.tool import MaxTool, ToolMessagesArtifact
from ee.hogai.utils.prompt import format_prompt_string
from ee.hogai.utils.types.base import AssistantState
# System prompt describing the `create_and_query_insight` tool to the root agent:
# when to call it, how to decompose multi-insight requests, and which insight
# types (trends, funnels, retention, SQL) the subagent supports.
INSIGHT_TOOL_PROMPT = """
Use this tool to create a product analytics insight for a given natural language description by spawning a subagent.
The tool generates a query and returns formatted text results for a specific data question or iterates on a previous query. It only retrieves a single query per call. If the user asks for multiple insights, you need to decompose a query into multiple subqueries and call the tool for each subquery.
Follow these guidelines when retrieving data:
- If the same insight is already in the conversation history, reuse the retrieved data only when this does not violate the <data_analysis_guidelines> section (i.e. only when a presence-check, count, or sort on existing columns is enough).
- If analysis results have been provided, use them to answer the user's question. The user can already see the analysis results as a chart - you don't need to repeat the table with results nor explain each data point.
- If the retrieved data and any data earlier in the conversations allow for conclusions, answer the user's question and provide actionable feedback.
- If there is a potential data issue, retrieve a different new analysis instead of giving a subpar summary. Note: empty data is NOT a potential data issue.
- If the query cannot be answered with a UI-built insight type - trends, funnels, retention - choose the SQL type to answer the question (e.g. for listing events or aggregating in ways that aren't supported in trends/funnels/retention).
IMPORTANT: Avoid generic advice. Take into account what you know about the product. Your answer needs to be super high-impact and no more than a few sentences.
Remember: do NOT retrieve data for the same query more than 3 times in a row.
# Data schema
You can pass events, actions, properties, and property values to this tool by specifying the "Data schema" section.
<example>
User: Calculate onboarding completion rate for the last week.
Assistant: I'm going to retrieve the existing data schema first.
*Retrieves matching events, properties, and property values*
Assistant: I'm going to create a new trends insight.
*Calls this tool with the query description: "Trends insight of the onboarding completion rate. Data schema: Relevant matching data schema"*
</example>
# Supported insight types
## Trends
A trends insight visualizes events over time using time series. They're useful for finding patterns in historical data.
The trends insights have the following features:
- The insight can show multiple trends in one request.
- Custom formulas can calculate derived metrics, like `A/B*100` to calculate a ratio.
- Filter and break down data using multiple properties.
- Compare with the previous period and sample data.
- Apply various aggregation types, like sum, average, etc., and chart types.
- And more.
Examples of use cases include:
- How the product's most important metrics change over time.
- Long-term patterns, or cycles in product's usage.
- The usage of different features side-by-side.
- How the properties of events vary using aggregation (sum, average, etc).
- Users can also visualize the same data points in a variety of ways.
## Funnel
A funnel insight visualizes a sequence of events that users go through in a product. They use percentages as the primary aggregation type. Funnels use two or more series, so the conversation history should mention at least two events.
The funnel insights have the following features:
- Various visualization types (steps, time-to-convert, historical trends).
- Filter data and apply exclusion steps.
- Break down data using a single property.
- Specify conversion windows, details of conversion calculation, attribution settings.
- Sample data.
- And more.
Examples of use cases include:
- Conversion rates.
- Drop off steps.
- Steps with the highest friction and time to convert.
- If product changes are improving their funnel over time.
- Average/median time to convert.
- Conversion trends over time.
## Retention
A retention insight visualizes how many users return to the product after performing some action. They're useful for understanding user engagement and retention.
The retention insights have the following features: filter data, sample data, and more.
Examples of use cases include:
- How many users come back and perform an action after their first visit.
- How many users come back to perform action X after performing action Y.
- How often users return to use a specific feature.
## SQL
The 'sql' insight type allows you to write arbitrary SQL queries to retrieve data.
The SQL insights have the following features:
- Filter data using arbitrary SQL.
- All ClickHouse SQL features.
- You can nest subqueries as needed.
""".strip()

# Contextual prompt injected when the user is editing an existing insight; the
# placeholder is filled with the current insight definition as JSON.
INSIGHT_TOOL_CONTEXT_PROMPT_TEMPLATE = """
The user is currently editing an insight (aka query). Here is that insight's current definition, which can be edited using the `create_and_query_insight` tool:
```json
{current_query}
```
<system_reminder>
Do not remove any fields from the current insight definition. Do not change any other fields than the ones the user asked for. Keep the rest as is.
</system_reminder>
""".strip()

# Reminder appended to both failure prompts: tell the user, retry differently,
# and stop if the error keeps recurring.
INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT = """
<system_reminder>
Inform the user that you've encountered an error during the creation of the insight. Afterwards, try to generate a new insight with a different query.
Terminate if the error persists.
</system_reminder>
""".strip()

# Failure prompt for known schema-generation errors; carries the raw LLM output
# and the validation message so the agent can adjust its next attempt.
INSIGHT_TOOL_HANDLED_FAILURE_PROMPT = """
The agent has encountered the error while creating an insight.
Generated output:
```
{{{output}}}
```
Error message:
```
{{{error_message}}}
```
{{{system_reminder}}}
""".strip()

# Failure prompt for unexpected errors where no diagnostic details are available.
INSIGHT_TOOL_UNHANDLED_FAILURE_PROMPT = """
The agent has encountered an unknown error while creating an insight.
{{{system_reminder}}}
""".strip()
class CreateAndQueryInsightToolArgs(BaseModel):
    """Arguments for `create_and_query_insight`: a single self-contained
    natural-language description of the insight to generate."""

    query_description: str = Field(
        description=(
            "A description of the query to generate, encapsulating the details of the user's request. "
            "Include all relevant context from earlier messages too, as the tool won't see that conversation history. "
            "If an existing insight has been used as a starting point, include that insight's filters and query in the description. "
            "Don't be overly prescriptive with event or property names, unless the user indicated they mean this specific name (e.g. with quotes). "
            "If the users seems to ask for a list of entities, rather than a count, state this explicitly."
        )
    )
class CreateAndQueryInsightTool(MaxTool):
    """Spawns the insights subgraph to create (or edit) a product analytics insight
    from a natural-language description and returns the resulting messages."""

    name: Literal["create_and_query_insight"] = "create_and_query_insight"
    args_schema: type[BaseModel] = CreateAndQueryInsightToolArgs
    description: str = INSIGHT_TOOL_PROMPT
    context_prompt_template: str = INSIGHT_TOOL_CONTEXT_PROMPT_TEMPLATE
    thinking_message: str = "Coming up with an insight"

    async def _arun_impl(self, query_description: str) -> tuple[str, ToolMessagesArtifact | None]:
        """Run the insights graph for `query_description`.

        Returns a `(content, artifact)` tuple: on failure, `content` is a prompt
        telling the agent what went wrong and `artifact` is None; on success,
        `content` is empty and `artifact` carries the generated messages.
        """
        graph = InsightsGraph(self._team, self._user).compile_full_graph()
        # Deep copy so the subgraph cannot mutate the caller's state in place.
        new_state = self._state.model_copy(
            update={
                "root_tool_call_id": self.tool_call_id,
                "root_tool_insight_plan": query_description,
            },
            deep=True,
        )
        try:
            dict_state = await graph.ainvoke(new_state)
        except SchemaGenerationException as e:
            # Known validation failure: surface the LLM output and validation
            # message so the agent can retry with a different query.
            return format_prompt_string(
                INSIGHT_TOOL_HANDLED_FAILURE_PROMPT,
                output=e.llm_output,
                error_message=e.validation_message,
                system_reminder=INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT,
            ), None
        updated_state = AssistantState.model_validate(dict_state)
        # Guard against a malformed result: the tuple unpacking below requires at
        # least two trailing messages, otherwise it would raise a ValueError.
        if len(updated_state.messages) < 2:
            return format_prompt_string(
                INSIGHT_TOOL_UNHANDLED_FAILURE_PROMPT, system_reminder=INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT
            ), None
        maybe_viz_message, tool_call_message = updated_state.messages[-2:]
        if not isinstance(tool_call_message, AssistantToolCallMessage):
            return format_prompt_string(
                INSIGHT_TOOL_UNHANDLED_FAILURE_PROMPT, system_reminder=INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT
            ), None
        # If the previous message is not a visualization message, the agent has requested human feedback.
        if not isinstance(maybe_viz_message, VisualizationMessage):
            return "", ToolMessagesArtifact(messages=[tool_call_message])
        # If the contextual tool is available, we're editing an insight.
        # Add the UI payload to the tool call message.
        if self.is_editing_mode(self._context_manager):
            tool_call_message = AssistantToolCallMessage(
                content=tool_call_message.content,
                ui_payload={self.get_name(): maybe_viz_message.answer.model_dump(exclude_none=True)},
                id=tool_call_message.id,
                tool_call_id=tool_call_message.tool_call_id,
            )
        return "", ToolMessagesArtifact(messages=[maybe_viz_message, tool_call_message])

    @classmethod
    def is_editing_mode(cls, context_manager: AssistantContextManager) -> bool:
        """
        Determines if the tool is in editing mode (the contextual
        `create_and_query_insight` tool is registered for this conversation).
        """
        return AssistantTool.CREATE_AND_QUERY_INSIGHT.value in context_manager.get_contextual_tools()

View File

@@ -0,0 +1,53 @@
from typing import Literal
from langchain_core.runnables import RunnableLambda
from pydantic import BaseModel, Field
from ee.hogai.graph.dashboards.nodes import DashboardCreationNode
from ee.hogai.tool import MaxTool, ToolMessagesArtifact
from ee.hogai.utils.types.base import AssistantState, InsightQuery, PartialAssistantState
# System prompt describing the `create_dashboard` tool to the root agent: use it
# only when the user explicitly asks to build a dashboard, not for plain insight
# searches; it performs its own insight search internally.
CREATE_DASHBOARD_TOOL_PROMPT = """
Use this tool when users ask to create, build, or make a new dashboard with insights.
This tool will search for existing insights that match the user's requirements so no need to call `search` tool, or create new insights if none are found, then combine them into a dashboard.
Do not call this tool if the user only asks to find, search for, or look up existing insights and does not ask to create a dashboard.
If you decided to use this tool, there is no need to call `search_insights` tool beforehand. The tool will search for existing insights that match the user's requirements and create new insights if none are found.
""".strip()
class CreateDashboardToolArgs(BaseModel):
    """Arguments for `create_dashboard`: the insight queries to include and the
    dashboard's display name."""

    search_insights_queries: list[InsightQuery] = Field(
        description="A list of insights to be included in the dashboard. Include all the insights that the user mentioned."
    )
    dashboard_name: str = Field(
        description=(
            "The name of the dashboard to be created based on the user request. It should be short and concise as it will be displayed as a header in the dashboard tile."
        )
    )
class CreateDashboardTool(MaxTool):
    """Creates a dashboard from the user's request by delegating to DashboardCreationNode."""

    name: Literal["create_dashboard"] = "create_dashboard"
    description: str = CREATE_DASHBOARD_TOOL_PROMPT
    thinking_message: str = "Creating a dashboard"
    context_prompt_template: str = "Creates a dashboard based on the user's request"
    args_schema: type[BaseModel] = CreateDashboardToolArgs
    show_tool_call_message: bool = False

    async def _arun_impl(
        self, search_insights_queries: list[InsightQuery], dashboard_name: str
    ) -> tuple[str, ToolMessagesArtifact | None]:
        """Invoke the dashboard creation node against a copied state and wrap its
        output messages in a ToolMessagesArtifact."""
        creation_node = DashboardCreationNode(self._team, self._user)
        runnable: RunnableLambda[AssistantState, PartialAssistantState | None] = RunnableLambda(creation_node)
        # Deep copy so the node cannot mutate the caller's state in place.
        node_state = self._state.model_copy(
            deep=True,
            update={
                "root_tool_call_id": self.tool_call_id,
                "search_insights_queries": search_insights_queries,
                "dashboard_name": dashboard_name,
            },
        )
        node_result = await runnable.ainvoke(node_state)
        if not node_result or not node_result.messages:
            return "Dashboard creation failed", None
        return "", ToolMessagesArtifact(messages=node_result.messages)

View File

@@ -3,10 +3,13 @@ from unittest.mock import Mock, patch
from django.conf import settings
from langchain_core.runnables import RunnableConfig
from parameterized import parameterized
from ee.hogai.graph.root.tools.full_text_search.tool import ENTITY_MAP, EntitySearchToolkit, FTSKind
from ee.hogai.context import AssistantContextManager
from ee.hogai.graph.root.tools.full_text_search.tool import ENTITY_MAP, EntitySearchTool, FTSKind
from ee.hogai.graph.shared_prompts import HYPERLINK_USAGE_INSTRUCTIONS
from ee.hogai.utils.types.base import AssistantState
class TestEntitySearchToolkit(NonAtomicBaseTest):
@@ -18,7 +21,13 @@ class TestEntitySearchToolkit(NonAtomicBaseTest):
self.team.organization = Mock()
self.team.organization.id = 789
self.user = Mock()
self.toolkit = EntitySearchToolkit(self.team, self.user)
self.toolkit = EntitySearchTool(
team=self.team,
user=self.user,
state=AssistantState(messages=[]),
config=RunnableConfig(configurable={}),
context_manager=AssistantContextManager(self.team, self.user, {}),
)
@parameterized.expand(
[

View File

@@ -6,13 +6,14 @@ import yaml
from posthoganalytics import capture_exception
from posthog.api.search import EntityConfig, search_entities
from posthog.models import Action, Cohort, Dashboard, Experiment, FeatureFlag, Insight, Survey, Team, User
from posthog.models import Action, Cohort, Dashboard, Experiment, FeatureFlag, Insight, Survey
from posthog.rbac.user_access_control import UserAccessControl
from posthog.sync import database_sync_to_async
from products.error_tracking.backend.models import ErrorTrackingIssue
from ee.hogai.graph.shared_prompts import HYPERLINK_USAGE_INSTRUCTIONS
from ee.hogai.tool import MaxSubtool
from .prompts import ENTITY_TYPE_SUMMARY_TEMPLATE, FOUND_ENTITIES_MESSAGE_TEMPLATE
@@ -91,14 +92,10 @@ SEARCH_KIND_TO_DATABASE_ENTITY_TYPE: dict[FTSKind, str] = {
}
class EntitySearchToolkit:
class EntitySearchTool(MaxSubtool):
MAX_ENTITY_RESULTS = 10
MAX_CONCURRENT_SEARCHES = 10
def __init__(self, team: Team, user: User):
self._team = team
self._user = user
async def execute(self, query: str, search_kind: FTSKind) -> str:
"""Search for entities by query and entity."""
try:

View File

@@ -1,204 +0,0 @@
# The module contains tools that are deprecated and will be replaced in the future with MaxTool implementations.
from pydantic import BaseModel, Field
from ee.hogai.utils.types.base import InsightQuery
# Lower casing matters here. Do not change it.
class create_and_query_insight(BaseModel):
"""
Use this tool to spawn a subagent that will create a product analytics insight for a given description.
The tool generates a query and returns formatted text results for a specific data question or iterates on a previous query. It only retrieves a single query per call. If the user asks for multiple insights, you need to decompose a query into multiple subqueries and call the tool for each subquery.
Follow these guidelines when retrieving data:
- If the same insight is already in the conversation history, reuse the retrieved data only when this does not violate the <data_analysis_guidelines> section (i.e. only when a presence-check, count, or sort on existing columns is enough).
- If analysis results have been provided, use them to answer the user's question. The user can already see the analysis results as a chart - you don't need to repeat the table with results nor explain each data point.
- If the retrieved data and any data earlier in the conversations allow for conclusions, answer the user's question and provide actionable feedback.
- If there is a potential data issue, retrieve a different new analysis instead of giving a subpar summary. Note: empty data is NOT a potential data issue.
- If the query cannot be answered with a UI-built insight type - trends, funnels, retention - choose the SQL type to answer the question (e.g. for listing events or aggregating in ways that aren't supported in trends/funnels/retention).
IMPORTANT: Avoid generic advice. Take into account what you know about the product. Your answer needs to be super high-impact and no more than a few sentences.
Remember: do NOT retrieve data for the same query more than 3 times in a row.
# Data schema
You can pass events, actions, properties, and property values to this tool by specifying the "Data schema" section.
<example>
User: Calculate onboarding completion rate for the last week.
Assistant: I'm going to retrieve the existing data schema first.
*Retrieves matching events, properties, and property values*
Assistant: I'm going to create a new trends insight.
*Calls this tool with the query description: "Trends insight of the onboarding completion rate. Data schema: Relevant matching data schema"*
</example>
# Supported insight types
## Trends
A trends insight visualizes events over time using time series. They're useful for finding patterns in historical data.
The trends insights have the following features:
- The insight can show multiple trends in one request.
- Custom formulas can calculate derived metrics, like `A/B*100` to calculate a ratio.
- Filter and break down data using multiple properties.
- Compare with the previous period and sample data.
- Apply various aggregation types, like sum, average, etc., and chart types.
- And more.
Examples of use cases include:
- How the product's most important metrics change over time.
- Long-term patterns, or cycles in product's usage.
- The usage of different features side-by-side.
- How the properties of events vary using aggregation (sum, average, etc).
- Users can also visualize the same data points in a variety of ways.
## Funnel
A funnel insight visualizes a sequence of events that users go through in a product. They use percentages as the primary aggregation type. Funnels use two or more series, so the conversation history should mention at least two events.
The funnel insights have the following features:
- Various visualization types (steps, time-to-convert, historical trends).
- Filter data and apply exclusion steps.
- Break down data using a single property.
- Specify conversion windows, details of conversion calculation, attribution settings.
- Sample data.
- And more.
Examples of use cases include:
- Conversion rates.
- Drop off steps.
- Steps with the highest friction and time to convert.
- If product changes are improving their funnel over time.
- Average/median time to convert.
- Conversion trends over time.
## Retention
A retention insight visualizes how many users return to the product after performing some action. They're useful for understanding user engagement and retention.
The retention insights have the following features: filter data, sample data, and more.
Examples of use cases include:
- How many users come back and perform an action after their first visit.
- How many users come back to perform action X after performing action Y.
- How often users return to use a specific feature.
## SQL
The 'sql' insight type allows you to write arbitrary SQL queries to retrieve data.
The SQL insights have the following features:
- Filter data using arbitrary SQL.
- All ClickHouse SQL features.
- You can nest subqueries as needed.
"""
query_description: str = Field(
description=(
"A description of the query to generate, encapsulating the details of the user's request. "
"Include all relevant context from earlier messages too, as the tool won't see that conversation history. "
"If an existing insight has been used as a starting point, include that insight's filters and query in the description. "
"Don't be overly prescriptive with event or property names, unless the user indicated they mean this specific name (e.g. with quotes). "
"If the users seems to ask for a list of entities, rather than a count, state this explicitly."
)
)
class session_summarization(BaseModel):
"""
Use this tool to summarize session recordings by analysing the events within those sessions to find patterns and issues.
It will return a textual summary of the captured session recordings.
# When to use the tool:
When the user asks to summarize session recordings:
- "summarize" synonyms: "watch", "analyze", "review", and similar
- "session recordings" synonyms: "sessions", "recordings", "replays", "user sessions", and similar
# When NOT to use the tool:
- When the user asks to find, search for, or look up session recordings, but doesn't ask to summarize them
- When users asks to update, change, or adjust session recordings filters
# Synonyms
- "summarize": "watch", "analyze", "review", and similar
- "session recordings": "sessions", "recordings", "replays", "user sessions", and similar
# Managing context
If the conversation history contains context about the current filters or session recordings, follow these steps:
- Convert the user query into a `session_summarization_query`
- The query should be used to understand the user's intent
- Decide if the query is relevant to the current filters and set `should_use_current_filters` accordingly
- Generate the `summary_title` based on the user's query and the current filters
Otherwise:
- Convert the user query into a `session_summarization_query`
- The query should be used to search for relevant sessions and then summarize them
- Assume the `should_use_current_filters` should be always `false`
- Generate the `summary_title` based on the user's query
# Additional guidelines
- CRITICAL: Always pass the user's complete, unmodified query to the `session_summarization_query` parameter
- DO NOT truncate, summarize, or extract keywords from the user's query
- The query is used to find relevant sessions - context helps find better matches
- Use explicit tool definition to make a decision
"""
session_summarization_query: str = Field(
description="""
- The user's complete query for session recordings summarization.
- This will be used to find relevant session recordings.
- Always pass the user's complete, unmodified query.
- Examples:
* 'summarize all session recordings from yesterday'
* 'analyze mobile user session recordings from last week, even if 1 second'
* 'watch last 300 session recordings of MacOS users from US'
* and similar
"""
)
should_use_current_filters: bool = Field(
description="""
- Whether to use current filters from user's UI to find relevant session recordings.
- IMPORTANT: Should be always `false` if the current filters or `search_session_recordings` tool are not present in the conversation history.
- Examples:
* Set to `true` if one of the conditions is met:
- the user wants to summarize "current/selected/opened/my/all/these" session recordings
- the user wants to use "current/these" filters
- the user's query specifies filters identical to the current filters
- if the user's query doesn't specify any filters/conditions
- the user refers to what they're "looking at" or "viewing"
* Set to `false` if one of the conditions is met:
- no current filters or `search_session_recordings` tool are present in the conversation
- the user specifies date/time period different from the current filters
- the user specifies conditions (user, device, id, URL, etc.) not present in the current filters
""",
)
summary_title: str = Field(
description="""
- The name of the summary that is expected to be generated from the user's `session_summarization_query` and/or `current_filters` (if present).
- The name should cover in 3-7 words what sessions would be to be summarized in the summary
- This won't be used for any search of filtering, only to properly label the generated summary.
- Examples:
* If `should_use_current_filters` is `false`, then the `summary_title` should be generated based on the `session_summarization_query`:
- query: "I want to watch all the sessions of user `user@example.com` in the last 30 days no matter how long" -> name: "Sessions of the user user@example.com (last 30 days)"
- query: "summarize my last 100 session recordings" -> name: "Last 100 sessions"
- and similar
* If `should_use_current_filters` is `true`, then the `summary_title` should be generated based on the current filters in the context (if present):
- filters: "{"key":"$os","value":["Mac OS X"],"operator":"exact","type":"event"}" -> name: "MacOS users"
- filters: "{"date_from": "-7d", "filter_test_accounts": True}" -> name: "All sessions (last 7 days)"
- and similar
* If there's not enough context to generated the summary name - keep it an empty string ("")
"""
)
# Tool-args schema exposed to the LLM: the class docstring is the tool
# description and the Field descriptions are the argument docs, so their
# text is behavior — do not reword them casually.
class create_dashboard(BaseModel):
    """
    Use this tool when users ask to create, build, or make a new dashboard with insights.
    This tool will search for existing insights that match the user's requirements so no need to call `search_insights` tool, or create new insights if none are found, then combine them into a dashboard.
    Do not call this tool if the user only asks to find, search for, or look up existing insights and does not ask to create a dashboard.
    If you decided to use this tool, there is no need to call `search_insights` tool beforehand. The tool will search for existing insights that match the user's requirements and create new insights if none are found.
    """

    # Insights the dashboard should contain; the tool reuses existing matching
    # insights and creates new ones when no match is found.
    search_insights_queries: list[InsightQuery] = Field(
        description="A list of insights to be included in the dashboard. Include all the insights that the user mentioned."
    )
    # Short label rendered as the dashboard tile header.
    dashboard_name: str = Field(
        description=(
            "The name of the dashboard to be created based on the user request. It should be short and concise as it will be displayed as a header in the dashboard tile."
        )
    )

View File

@@ -10,7 +10,7 @@ from posthog.models import Team, User
from ee.hogai.context.context import AssistantContextManager
from ee.hogai.tool import MaxTool
from ee.hogai.utils.prompt import format_prompt_string
from ee.hogai.utils.types.base import AssistantState
from ee.hogai.utils.types.base import AssistantState, NodePath
NAVIGATION_TOOL_PROMPT = """
Use the `navigate` tool to move between different pages in the PostHog application.
@@ -57,7 +57,7 @@ class NavigateTool(MaxTool):
*,
team: Team,
user: User,
tool_call_id: str,
node_path: tuple[NodePath, ...] | None = None,
state: AssistantState | None = None,
config: RunnableConfig | None = None,
context_manager: AssistantContextManager | None = None,
@@ -69,7 +69,7 @@ class NavigateTool(MaxTool):
return cls(
team=team,
user=user,
tool_call_id=tool_call_id,
node_path=node_path,
state=state,
config=config,
context_manager=context_manager,

View File

@@ -2,11 +2,12 @@ import datetime
from typing import cast
from uuid import uuid4
from posthog.test.base import BaseTest, ClickhouseTestMixin
from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest
from unittest.mock import patch
from langchain_core.runnables import RunnableConfig
from posthog.schema import (
AssistantToolCallMessage,
BillingSpendResponseBreakdownType,
BillingUsageResponseBreakdownType,
MaxAddonInfo,
@@ -21,26 +22,30 @@ from posthog.schema import (
UsageHistoryItem,
)
from ee.hogai.graph.billing.nodes import BillingNode
from ee.hogai.context.context import AssistantContextManager
from ee.hogai.graph.root.tools.read_billing_tool.tool import ReadBillingTool
from ee.hogai.utils.types import AssistantState
class TestBillingNode(ClickhouseTestMixin, BaseTest):
class TestBillingNode(ClickhouseTestMixin, NonAtomicBaseTest):
CLASS_DATA_LEVEL_SETUP = False
def setUp(self):
super().setUp()
self.node = BillingNode(self.team, self.user)
self.tool_call_id = str(uuid4())
self.state = AssistantState(messages=[], root_tool_call_id=self.tool_call_id)
self.tool = ReadBillingTool(
team=self.team,
user=self.user,
state=AssistantState(messages=[], root_tool_call_id=str(uuid4())),
config=RunnableConfig(configurable={}),
context_manager=AssistantContextManager(self.team, self.user, {}),
)
def test_run_with_no_billing_context(self):
with patch.object(self.node.context_manager, "get_billing_context", return_value=None):
result = self.node.run(self.state, {})
self.assertEqual(len(result.messages), 1)
message = result.messages[0]
self.assertIsInstance(message, AssistantToolCallMessage)
self.assertEqual(cast(AssistantToolCallMessage, message).content, "No billing information available")
async def test_run_with_no_billing_context(self):
with patch.object(self.tool._context_manager, "get_billing_context", return_value=None):
result = await self.tool.execute()
self.assertEqual(result, "No billing information available")
def test_run_with_billing_context(self):
async def test_run_with_billing_context(self):
billing_context = MaxBillingContext(
subscription_level=MaxBillingContextSubscriptionLevel.PAID,
billing_plan="paid",
@@ -50,17 +55,13 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
products=[],
)
with (
patch.object(self.node.context_manager, "get_billing_context", return_value=billing_context),
patch.object(self.node, "_format_billing_context", return_value="Formatted Context"),
patch.object(self.tool._context_manager, "get_billing_context", return_value=billing_context),
patch.object(self.tool, "_format_billing_context", return_value="Formatted Context"),
):
result = self.node.run(self.state, {})
self.assertEqual(len(result.messages), 1)
message = result.messages[0]
self.assertIsInstance(message, AssistantToolCallMessage)
self.assertEqual(cast(AssistantToolCallMessage, message).content, "Formatted Context")
self.assertEqual(cast(AssistantToolCallMessage, message).tool_call_id, self.tool_call_id)
result = await self.tool.execute()
self.assertEqual(result, "Formatted Context")
def test_format_billing_context(self):
async def test_format_billing_context(self):
billing_context = MaxBillingContext(
subscription_level=MaxBillingContextSubscriptionLevel.PAID,
billing_plan="paid",
@@ -89,8 +90,8 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
settings=MaxBillingContextSettings(autocapture_on=True, active_destinations=2),
)
with patch.object(self.node, "_get_top_events_by_usage", return_value=[]):
formatted_string = self.node._format_billing_context(billing_context)
with patch.object(self.tool, "_get_top_events_by_usage", return_value=[]):
formatted_string = await self.tool._format_billing_context(billing_context)
self.assertIn("(paid)", formatted_string)
self.assertIn("Period: 2023-01-01 to 2023-01-31", formatted_string)
@@ -123,19 +124,21 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
),
]
usage_table = self.node._format_history_table(usage_history)
usage_table = self.tool._format_history_table(usage_history)
self.assertIn("### Overall (all projects)", usage_table)
self.assertIn("| Data Type | 2023-01-01 | 2023-01-02 |", usage_table)
self.assertIn("| Recordings | 100.00 | 200.00 |", usage_table)
self.assertIn("| Events | 1.50 | 2.50 |", usage_table)
spend_table = self.node._format_history_table(spend_history)
spend_table = self.tool._format_history_table(spend_history)
self.assertIn("| Mobile Recordings | 10.50 | 20.00 |", spend_table)
def test_get_top_events_by_usage(self):
mock_results = [("pageview", 1000), ("$autocapture", 500)]
with patch("ee.hogai.graph.billing.nodes.sync_execute", return_value=mock_results) as mock_sync_execute:
top_events = self.node._get_top_events_by_usage()
with patch(
"ee.hogai.graph.root.tools.read_billing_tool.tool.sync_execute", return_value=mock_results
) as mock_sync_execute:
top_events = self.tool._get_top_events_by_usage()
self.assertEqual(len(top_events), 2)
self.assertEqual(top_events[0]["event"], "pageview")
self.assertEqual(top_events[0]["count"], 1000)
@@ -150,13 +153,14 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
def test_get_top_events_by_usage_query_fails(self):
with patch(
"ee.hogai.graph.billing.nodes.sync_execute", side_effect=Exception("DB connection failed")
"ee.hogai.graph.root.tools.read_billing_tool.tool.sync_execute",
side_effect=Exception("DB connection failed"),
) as mock_sync_execute:
top_events = self.node._get_top_events_by_usage()
top_events = self.tool._get_top_events_by_usage()
self.assertEqual(top_events, [])
mock_sync_execute.assert_called_once()
def test_format_billing_context_with_addons(self):
async def test_format_billing_context_with_addons(self):
"""Test that addons are properly nested within products in the formatted output"""
billing_context = MaxBillingContext(
subscription_level=MaxBillingContextSubscriptionLevel.PAID,
@@ -230,14 +234,14 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
)
with patch.object(
self.node,
self.tool,
"_get_top_events_by_usage",
return_value=[
{"event": "$pageview", "count": 50000, "formatted_count": "50,000"},
{"event": "$autocapture", "count": 30000, "formatted_count": "30,000"},
],
):
formatted_string = self.node._format_billing_context(billing_context)
formatted_string = await self.tool._format_billing_context(billing_context)
# Check basic info
self.assertIn("paid subscription (startup)", formatted_string)
@@ -274,7 +278,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
self.assertIn("$pageview", formatted_string)
self.assertIn("50,000 events", formatted_string)
def test_format_billing_context_no_subscription(self):
async def test_format_billing_context_no_subscription(self):
"""Test formatting when user has no active subscription (free plan)"""
billing_context = MaxBillingContext(
subscription_level=MaxBillingContextSubscriptionLevel.FREE,
@@ -286,8 +290,8 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
trial=MaxBillingContextTrial(is_active=True, expires_at="2023-02-01", target="teams"),
)
with patch.object(self.node, "_get_top_events_by_usage", return_value=[]):
formatted_string = self.node._format_billing_context(billing_context)
with patch.object(self.tool, "_get_top_events_by_usage", return_value=[]):
formatted_string = await self.tool._format_billing_context(billing_context)
self.assertIn("free subscription", formatted_string)
self.assertIn("Active subscription: No (Free plan)", formatted_string)
@@ -297,7 +301,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
def test_format_history_table_with_team_breakdown(self):
"""Test that history tables properly group by team when breakdown includes team IDs"""
# Mock the teams map
self.node._teams_map = {
self.tool._teams_map = {
1: "Team Alpha (ID: 1)",
2: "Team Beta (ID: 2)",
}
@@ -337,7 +341,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
),
]
table = self.node._format_history_table(usage_history)
table = self.tool._format_history_table(usage_history)
# Check team-specific tables
self.assertIn("### Team Alpha (ID: 1)", table)
@@ -349,7 +353,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
self.assertIn("| Recordings | 100.00 | 200.00 |", table)
self.assertIn("| Feature Flag Requests | 50.00 | 100.00 |", table)
def test_format_billing_context_edge_cases(self):
async def test_format_billing_context_edge_cases(self):
"""Test edge cases and potential security issues"""
billing_context = MaxBillingContext(
subscription_level=MaxBillingContextSubscriptionLevel.CUSTOM,
@@ -374,8 +378,8 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
settings=MaxBillingContextSettings(autocapture_on=True, active_destinations=0),
)
with patch.object(self.node, "_get_top_events_by_usage", return_value=[]):
formatted_string = self.node._format_billing_context(billing_context)
with patch.object(self.tool, "_get_top_events_by_usage", return_value=[]):
formatted_string = await self.tool._format_billing_context(billing_context)
# Check deactivated status
self.assertIn("Status: Account is deactivated", formatted_string)
@@ -393,7 +397,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
# Check exceeded limit warning
self.assertIn("⚠️ Usage limit exceeded", formatted_string)
def test_format_billing_context_complete_template_coverage(self):
async def test_format_billing_context_complete_template_coverage(self):
"""Test all possible template variables are covered"""
billing_context = MaxBillingContext(
subscription_level=MaxBillingContextSubscriptionLevel.PAID,
@@ -472,14 +476,14 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
)
# Mock teams map for history table
self.node._teams_map = {1: "Main Team (ID: 1)"}
self.tool._teams_map = {1: "Main Team (ID: 1)"}
with patch.object(
self.node,
self.tool,
"_get_top_events_by_usage",
return_value=[{"event": "$identify", "count": 10000, "formatted_count": "10,000"}],
):
formatted_string = self.node._format_billing_context(billing_context)
formatted_string = await self.tool._format_billing_context(billing_context)
# Verify all template sections are present
self.assertIn("<billing_context>", formatted_string)
@@ -545,12 +549,12 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
),
]
self.node._teams_map = {
self.tool._teams_map = {
84444: "Project 84444",
12345: "Project 12345",
}
table = self.node._format_history_table(usage_history)
table = self.tool._format_history_table(usage_history)
# Should always include aggregated table first
self.assertIn("### Overall (all projects)", table)
@@ -597,7 +601,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
),
]
table = self.node._format_history_table(usage_history)
table = self.tool._format_history_table(usage_history)
# Should include aggregated table
self.assertIn("### Overall (all projects)", table)
@@ -645,7 +649,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
)
]
aggregated = self.node._create_aggregated_items(team_items, other_items)
aggregated = self.tool._create_aggregated_items(team_items, other_items)
# Should have 2 aggregated items: events and feature flags
self.assertEqual(len(aggregated), 2)
@@ -687,7 +691,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
),
]
table = self.node._format_history_table(usage_history)
table = self.tool._format_history_table(usage_history)
# Should only have aggregated table, no team-specific tables
self.assertIn("### Overall (all projects)", table)
@@ -721,18 +725,18 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
)
]
self.node._teams_map = {
self.tool._teams_map = {
123: "Project 123",
}
# Test usage history formatting
usage_table = self.node._format_history_table(usage_history)
usage_table = self.tool._format_history_table(usage_history)
self.assertIn("### Overall (all projects)", usage_table)
self.assertIn("### Project 123", usage_table)
self.assertIn("| Events | 1,000.00 |", usage_table)
# Test spend history formatting
spend_table = self.node._format_history_table(spend_history)
spend_table = self.tool._format_history_table(spend_history)
self.assertIn("### Overall (all projects)", spend_table)
self.assertIn("### Project 123", spend_table)
self.assertIn("| Events | 50.00 |", spend_table)
@@ -758,7 +762,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
),
]
table = self.node._format_history_table(usage_history)
table = self.tool._format_history_table(usage_history)
# Should handle empty dates gracefully and still show valid data
self.assertIn("### Overall (all projects)", table)

View File

@@ -1,18 +1,19 @@
from typing import Any, cast
from uuid import uuid4
from typing import Any
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableConfig
from posthog.schema import AssistantToolCallMessage, MaxBillingContext, SpendHistoryItem, UsageHistoryItem
from posthog.schema import MaxBillingContext, SpendHistoryItem, UsageHistoryItem
from posthog.clickhouse.client import sync_execute
from posthog.models import Team, User
from posthog.sync import database_sync_to_async
from ee.hogai.graph.base import AssistantNode
from ee.hogai.graph.billing.prompts import BILLING_CONTEXT_PROMPT
from ee.hogai.context.context import AssistantContextManager
from ee.hogai.tool import MaxSubtool
from ee.hogai.utils.types import AssistantState
from ee.hogai.utils.types.base import AssistantNodeName, PartialAssistantState
from ee.hogai.utils.types.composed import MaxNodeName
from .prompts import BILLING_CONTEXT_PROMPT
# sync with frontend/src/scenes/billing/constants.ts
USAGE_TYPES = [
@@ -33,33 +34,27 @@ USAGE_TYPES = [
]
class BillingNode(AssistantNode):
_teams_map: dict[int, str] = {}
class ReadBillingTool(MaxSubtool):
def __init__(
self,
*,
team: Team,
user: User,
state: AssistantState,
config: RunnableConfig,
context_manager: AssistantContextManager,
):
super().__init__(team=team, user=user, state=state, config=config, context_manager=context_manager)
self._teams_map: dict[int, str] = {}
@property
def node_name(self) -> MaxNodeName:
return AssistantNodeName.BILLING
def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
tool_call_id = cast(str, state.root_tool_call_id)
billing_context = self.context_manager.get_billing_context()
async def execute(self) -> str:
billing_context = self._context_manager.get_billing_context()
if not billing_context:
return PartialAssistantState(
messages=[
AssistantToolCallMessage(
content="No billing information available", id=str(uuid4()), tool_call_id=tool_call_id
)
],
root_tool_call_id=None,
)
formatted_billing_context = self._format_billing_context(billing_context)
return PartialAssistantState(
messages=[
AssistantToolCallMessage(content=formatted_billing_context, tool_call_id=tool_call_id, id=str(uuid4())),
],
root_tool_call_id=None,
)
return "No billing information available"
formatted_billing_context = await self._format_billing_context(billing_context)
return formatted_billing_context
@database_sync_to_async(thread_sensitive=False)
def _format_billing_context(self, billing_context: MaxBillingContext) -> str:
"""Format billing context into a readable prompt section."""
# Convert billing context to a format suitable for the mustache template

View File

@@ -1,4 +1,4 @@
from typing import Any, Literal, Self
from typing import Literal, Self
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel
@@ -9,12 +9,14 @@ from ee.hogai.context.context import AssistantContextManager
from ee.hogai.graph.sql.mixins import HogQLDatabaseMixin
from ee.hogai.tool import MaxTool
from ee.hogai.utils.prompt import format_prompt_string
from ee.hogai.utils.types.base import AssistantState
from ee.hogai.utils.types.base import AssistantState, NodePath
from .read_billing_tool.tool import ReadBillingTool
READ_DATA_BILLING_PROMPT = """
# Billing information
Use this tool with the "billing_info" kind to retrieve the billing information if the user asks about billing, their subscription, their usage, or their spending.
Use this tool with the "billing_info" kind to retrieve the billing information if the user asks about their billing, subscription, product usage, spending, or cost reduction strategies.
You can use the information retrieved to check which PostHog products and add-ons the user has activated, how much they are spending, their usage history across all products in the last 30 days, as well as trials, spending limits, billing period, and more.
If the user wants to reduce their spending, always call this tool to get suggestions on how to do so.
If an insight shows zero data, it could mean either the query is looking at the wrong data or there was a temporary data collection issue. You can investigate potential dips in usage/captured data using the billing tool.
@@ -62,7 +64,7 @@ class ReadDataTool(HogQLDatabaseMixin, MaxTool):
*,
team: Team,
user: User,
tool_call_id: str,
node_path: tuple[NodePath, ...] | None = None,
state: AssistantState | None = None,
config: RunnableConfig | None = None,
context_manager: AssistantContextManager | None = None,
@@ -83,21 +85,29 @@ class ReadDataTool(HogQLDatabaseMixin, MaxTool):
return cls(
team=team,
user=user,
tool_call_id=tool_call_id,
state=state,
node_path=node_path,
config=config,
args_schema=args,
description=description,
context_manager=context_manager,
)
async def _arun_impl(self, kind: ReadDataAdminAccessKind | ReadDataKind) -> tuple[str, dict[str, Any] | None]:
async def _arun_impl(self, kind: ReadDataAdminAccessKind | ReadDataKind) -> tuple[str, None]:
match kind:
case "billing_info":
has_access = await self._context_manager.check_user_has_billing_access()
if not has_access:
return BILLING_INSUFFICIENT_ACCESS_PROMPT, None
# used for routing
return "", self.args_schema(kind=kind).model_dump()
billing_tool = ReadBillingTool(
team=self._team,
user=self._user,
state=self._state,
config=self._config,
context_manager=self._context_manager,
)
result = await billing_tool.execute()
return result, None
case "datawarehouse_schema":
return await self._serialize_database_schema(), None

View File

@@ -9,7 +9,7 @@ from ee.hogai.context.context import AssistantContextManager
from ee.hogai.graph.query_planner.toolkit import TaxonomyAgentToolkit
from ee.hogai.tool import MaxTool
from ee.hogai.utils.helpers import format_events_yaml
from ee.hogai.utils.types.base import AssistantState
from ee.hogai.utils.types.base import AssistantState, NodePath
READ_TAXONOMY_TOOL_DESCRIPTION = """
Use this tool to explore the user's taxonomy (i.e. data schema).
@@ -155,7 +155,7 @@ class ReadTaxonomyTool(MaxTool):
*,
team: Team,
user: User,
tool_call_id: str,
node_path: tuple[NodePath, ...] | None = None,
state: AssistantState | None = None,
config: RunnableConfig | None = None,
context_manager: AssistantContextManager | None = None,
@@ -200,9 +200,9 @@ class ReadTaxonomyTool(MaxTool):
return cls(
team=team,
user=user,
tool_call_id=tool_call_id,
state=state,
config=config,
node_path=node_path,
args_schema=ReadTaxonomyToolArgsWithGroups,
context_manager=context_manager,
)

View File

@@ -1,17 +1,19 @@
from typing import Any, Literal
from typing import Literal
from django.conf import settings
import posthoganalytics
from langchain_core.output_parsers import SimpleJsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from posthog.models import Team, User
from ee.hogai.graph.root.tools.full_text_search.tool import EntitySearchToolkit, FTSKind
from ee.hogai.tool import MaxTool
from ee.hogai.graph.insights.nodes import InsightSearchNode, NoInsightsException
from ee.hogai.graph.root.tools.full_text_search.tool import EntitySearchTool, FTSKind
from ee.hogai.tool import MaxSubtool, MaxTool, ToolMessagesArtifact
from ee.hogai.utils.prompt import format_prompt_string
from ee.hogai.utils.types.base import AssistantState, PartialAssistantState
SEARCH_TOOL_PROMPT = """
Use this tool to search docs, insights, dashboards, cohorts, actions, experiments, feature flags, notebooks, error tracking issues, and surveys in PostHog.
@@ -57,35 +59,10 @@ If you want to search for all entities, you should use `all`.
""".strip()
DOCS_SEARCH_RESULTS_TEMPLATE = """Found {count} relevant documentation page(s):
{docs}
<system_reminder>
Use retrieved documentation to answer the user's question if it is relevant to the user's query.
Format the response using Markdown and reference the documentation using hyperlinks.
Every link to docs clearly explicitly be labeled, for example as "(see docs)".
</system_reminder>
INVALID_ENTITY_KIND_PROMPT = """
Invalid entity kind: {{{kind}}}. Please provide a valid entity kind for the tool.
""".strip()
DOCS_SEARCH_NO_RESULTS_TEMPLATE = """
No documentation found.
<system_reminder>
Do not answer the user's question if you did not find any documentation. Try rewriting the query.
If after a couple of attempts you still do not find any documentation, suggest the user navigate to the documentation page, which is available at `https://posthog.com/docs`.
</system_reminder>
""".strip()
DOC_ITEM_TEMPLATE = """
# {title}
URL: {url}
{text}
""".strip()
FTS_SEARCH_FEATURE_FLAG = "hogai-insights-fts-search"
ENTITIES = [f"{entity}" for entity in FTSKind if entity != FTSKind.INSIGHTS]
SearchKind = Literal["insights", "docs", *ENTITIES] # type: ignore
@@ -126,50 +103,106 @@ class SearchTool(MaxTool):
context_prompt_template: str = "Searches documentation, insights, dashboards, cohorts, actions, experiments, feature flags, notebooks, error tracking issues, and surveys in PostHog"
args_schema: type[BaseModel] = SearchToolArgs
@staticmethod
def _get_fts_entities(include_insight_fts: bool) -> list[str]:
if not include_insight_fts:
entities = [e for e in FTSKind if e != FTSKind.INSIGHTS]
else:
entities = list(FTSKind)
return [*entities, FTSKind.ALL]
async def _arun_impl(self, kind: SearchKind, query: str) -> tuple[str, dict[str, Any] | None]:
async def _arun_impl(self, kind: str, query: str) -> tuple[str, ToolMessagesArtifact | None]:
if kind == "docs":
if not settings.INKEEP_API_KEY:
return "This tool is not available in this environment.", None
if self._has_docs_search_feature_flag():
return await self._search_docs(query), None
docs_tool = InkeepDocsSearchTool(
team=self._team,
user=self._user,
state=self._state,
config=self._config,
context_manager=self._context_manager,
)
return await docs_tool.execute(query, self.tool_call_id)
fts_entities = SearchTool._get_fts_entities(SearchTool._has_fts_search_feature_flag(self._user, self._team))
if kind == "insights" and not self._has_insights_fts_search_feature_flag():
insights_tool = InsightSearchTool(
team=self._team,
user=self._user,
state=self._state,
config=self._config,
context_manager=self._context_manager,
)
return await insights_tool.execute(query, self.tool_call_id)
if kind in fts_entities:
entity_search_toolkit = EntitySearchToolkit(self._team, self._user)
response = await entity_search_toolkit.execute(query, FTSKind(kind))
return response, None
# Used for routing
return "Search tool executed", SearchToolArgs(kind=kind, query=query).model_dump()
if kind not in self._fts_entities:
return format_prompt_string(INVALID_ENTITY_KIND_PROMPT, kind=kind), None
def _has_docs_search_feature_flag(self) -> bool:
entity_search_toolkit = EntitySearchTool(
team=self._team,
user=self._user,
state=self._state,
config=self._config,
context_manager=self._context_manager,
)
response = await entity_search_toolkit.execute(query, FTSKind(kind))
return response, None
@property
def _fts_entities(self) -> list[str]:
entities = list(FTSKind)
return [*entities, FTSKind.ALL]
def _has_insights_fts_search_feature_flag(self) -> bool:
return posthoganalytics.feature_enabled(
"max-inkeep-rag-docs-search",
"hogai-insights-fts-search",
str(self._user.distinct_id),
groups={"organization": str(self._team.organization_id)},
group_properties={"organization": {"id": str(self._team.organization_id)}},
send_feature_flag_events=False,
)
@staticmethod
def _has_fts_search_feature_flag(user: User, team: Team) -> bool:
return posthoganalytics.feature_enabled(
FTS_SEARCH_FEATURE_FLAG,
str(user.distinct_id),
groups={"organization": str(team.organization_id)},
group_properties={"organization": {"id": str(team.organization_id)}},
send_feature_flag_events=False,
)
async def _search_docs(self, query: str) -> str:
DOCS_SEARCH_RESULTS_TEMPLATE = """Found {count} relevant documentation page(s):
{docs}
<system_reminder>
Use retrieved documentation to answer the user's question if it is relevant to the user's query.
Format the response using Markdown and reference the documentation using hyperlinks.
Every link to docs clearly explicitly be labeled, for example as "(see docs)".
</system_reminder>
""".strip()
DOCS_SEARCH_NO_RESULTS_TEMPLATE = """
No documentation found.
<system_reminder>
Do not answer the user's question if you did not find any documentation. Try rewriting the query.
If after a couple of attempts you still do not find any documentation, suggest the user navigate to the documentation page, which is available at `https://posthog.com/docs`.
</system_reminder>
""".strip()
DOC_ITEM_TEMPLATE = """
# {title}
URL: {url}
{text}
""".strip()
class InkeepDocsSearchTool(MaxSubtool):
async def execute(self, query: str, tool_call_id: str) -> tuple[str, ToolMessagesArtifact | None]:
if self._has_rag_docs_search_feature_flag():
return await self._search_using_rag_endpoint(query, tool_call_id)
else:
return await self._search_using_node(query, tool_call_id)
async def _search_using_node(self, query: str, tool_call_id: str) -> tuple[str, ToolMessagesArtifact | None]:
# Avoid circular import
from ee.hogai.graph.inkeep_docs.nodes import InkeepDocsNode
# Init the graph
node = InkeepDocsNode(self._team, self._user)
chain: RunnableLambda[AssistantState, PartialAssistantState | None] = RunnableLambda(node)
copied_state = self._state.model_copy(deep=True, update={"root_tool_call_id": tool_call_id})
result = await chain.ainvoke(copied_state)
assert result is not None
return "", ToolMessagesArtifact(messages=result.messages)
async def _search_using_rag_endpoint(
self, query: str, tool_call_id: str
) -> tuple[str, ToolMessagesArtifact | None]:
model = ChatOpenAI(
model="inkeep-rag",
base_url="https://api.inkeep.com/v1/",
@@ -184,7 +217,7 @@ class SearchTool(MaxTool):
rag_context_raw = await chain.ainvoke({"query": query})
if not rag_context_raw or not rag_context_raw.get("content"):
return DOCS_SEARCH_NO_RESULTS_TEMPLATE
return DOCS_SEARCH_NO_RESULTS_TEMPLATE, None
rag_context = InkeepResponse.model_validate(rag_context_raw)
@@ -197,7 +230,35 @@ class SearchTool(MaxTool):
docs.append(DOC_ITEM_TEMPLATE.format(title=doc.title, url=doc.url, text=text))
if not docs:
return DOCS_SEARCH_NO_RESULTS_TEMPLATE
return DOCS_SEARCH_NO_RESULTS_TEMPLATE, None
formatted_docs = "\n\n---\n\n".join(docs)
return DOCS_SEARCH_RESULTS_TEMPLATE.format(count=len(docs), docs=formatted_docs)
return DOCS_SEARCH_RESULTS_TEMPLATE.format(count=len(docs), docs=formatted_docs), None
def _has_rag_docs_search_feature_flag(self) -> bool:
return posthoganalytics.feature_enabled(
"max-inkeep-rag-docs-search",
str(self._user.distinct_id),
groups={"organization": str(self._team.organization_id)},
group_properties={"organization": {"id": str(self._team.organization_id)}},
send_feature_flag_events=False,
)
EMPTY_DATABASE_ERROR_MESSAGE = """
The user doesn't have any insights created yet.
""".strip()
class InsightSearchTool(MaxSubtool):
async def execute(self, query: str, tool_call_id: str) -> tuple[str, ToolMessagesArtifact | None]:
try:
node = InsightSearchNode(self._team, self._user)
copied_state = self._state.model_copy(
deep=True, update={"search_insights_query": query, "root_tool_call_id": tool_call_id}
)
chain: RunnableLambda[AssistantState, PartialAssistantState | None] = RunnableLambda(node)
result = await chain.ainvoke(copied_state)
return "", ToolMessagesArtifact(messages=result.messages) if result else None
except NoInsightsException:
return EMPTY_DATABASE_ERROR_MESSAGE, None

View File

@@ -0,0 +1,122 @@
from typing import Literal
from langchain_core.runnables import RunnableLambda
from pydantic import BaseModel, Field
from ee.hogai.graph.session_summaries.nodes import SessionSummarizationNode
from ee.hogai.tool import MaxTool, ToolMessagesArtifact
from ee.hogai.utils.types.base import AssistantState, PartialAssistantState
# Tool description for the `session_summarization` tool, surfaced to the root
# agent so the model knows when to call the tool and how to fill in its
# arguments (see SessionSummarizationToolArgs). The text itself is behavior:
# changes here change how the LLM routes requests.
# Fix: "When users asks" -> "When the user asks" (grammar error in the prompt).
SESSION_SUMMARIZATION_TOOL_PROMPT = """
Use this tool to summarize session recordings by analysing the events within those sessions to find patterns and issues.
It will return a textual summary of the captured session recordings.

# When to use the tool:
When the user asks to summarize session recordings:
- "summarize" synonyms: "watch", "analyze", "review", and similar
- "session recordings" synonyms: "sessions", "recordings", "replays", "user sessions", and similar

# When NOT to use the tool:
- When the user asks to find, search for, or look up session recordings, but doesn't ask to summarize them
- When the user asks to update, change, or adjust session recordings filters

# Synonyms
- "summarize": "watch", "analyze", "review", and similar
- "session recordings": "sessions", "recordings", "replays", "user sessions", and similar

# Managing context
If the conversation history contains context about the current filters or session recordings, follow these steps:
- Convert the user query into a `session_summarization_query`
- The query should be used to understand the user's intent
- Decide if the query is relevant to the current filters and set `should_use_current_filters` accordingly
- Generate the `summary_title` based on the user's query and the current filters

Otherwise:
- Convert the user query into a `session_summarization_query`
- The query should be used to search for relevant sessions and then summarize them
- Assume the `should_use_current_filters` should be always `false`
- Generate the `summary_title` based on the user's query

# Additional guidelines
- CRITICAL: Always pass the user's complete, unmodified query to the `session_summarization_query` parameter
- DO NOT truncate, summarize, or extract keywords from the user's query
- The query is used to find relevant sessions - context helps find better matches
- Use explicit tool definition to make a decision
""".strip()
class SessionSummarizationToolArgs(BaseModel):
    """Arguments the LLM fills in when calling the session summarization tool.

    The Field descriptions below are LLM-facing instructions; grammar errors in
    them were fixed ("would be to be summarized", "search of filtering",
    "to generated") so the model receives unambiguous guidance.
    """

    # The full, verbatim user query — used downstream to find matching sessions.
    session_summarization_query: str = Field(
        description="""
        - The user's complete query for session recordings summarization.
        - This will be used to find relevant session recordings.
        - Always pass the user's complete, unmodified query.
        - Examples:
            * 'summarize all session recordings from yesterday'
            * 'analyze mobile user session recordings from last week, even if 1 second'
            * 'watch last 300 session recordings of MacOS users from US'
            * and similar
        """
    )
    # Whether to reuse filters already present in the user's UI context.
    should_use_current_filters: bool = Field(
        description="""
        - Whether to use current filters from user's UI to find relevant session recordings.
        - IMPORTANT: Should be always `false` if the current filters or `search_session_recordings` tool are not present in the conversation history.
        - Examples:
            * Set to `true` if one of the conditions is met:
                - the user wants to summarize "current/selected/opened/my/all/these" session recordings
                - the user wants to use "current/these" filters
                - the user's query specifies filters identical to the current filters
                - if the user's query doesn't specify any filters/conditions
                - the user refers to what they're "looking at" or "viewing"
            * Set to `false` if one of the conditions is met:
                - no current filters or `search_session_recordings` tool are present in the conversation
                - the user specifies date/time period different from the current filters
                - the user specifies conditions (user, device, id, URL, etc.) not present in the current filters
        """,
    )
    # Display label for the generated summary; never used for search/filtering.
    summary_title: str = Field(
        description="""
        - The name of the summary that is expected to be generated from the user's `session_summarization_query` and/or `current_filters` (if present).
        - The name should cover in 3-7 words what sessions are to be summarized in the summary
        - This won't be used for any search or filtering, only to properly label the generated summary.
        - Examples:
            * If `should_use_current_filters` is `false`, then the `summary_title` should be generated based on the `session_summarization_query`:
                - query: "I want to watch all the sessions of user `user@example.com` in the last 30 days no matter how long" -> name: "Sessions of the user user@example.com (last 30 days)"
                - query: "summarize my last 100 session recordings" -> name: "Last 100 sessions"
                - and similar
            * If `should_use_current_filters` is `true`, then the `summary_title` should be generated based on the current filters in the context (if present):
                - filters: "{"key":"$os","value":["Mac OS X"],"operator":"exact","type":"event"}" -> name: "MacOS users"
                - filters: "{"date_from": "-7d", "filter_test_accounts": True}" -> name: "All sessions (last 7 days)"
                - and similar
            * If there's not enough context to generate the summary name - keep it an empty string ("")
        """
    )
class SessionSummarizationTool(MaxTool):
    """Max tool that delegates session-recording summarization to `SessionSummarizationNode`.

    The node is invoked through a `RunnableLambda` with a deep copy of the current
    assistant state, augmented with this tool call's arguments.
    """

    name: Literal["session_summarization"] = "session_summarization"
    description: str = SESSION_SUMMARIZATION_TOOL_PROMPT
    thinking_message: str = "Summarizing session recordings"
    context_prompt_template: str = "Summarizes session recordings based on the user's query and current filters"
    args_schema: type[BaseModel] = SessionSummarizationToolArgs
    show_tool_call_message: bool = False

    async def _arun_impl(
        self, session_summarization_query: str, should_use_current_filters: bool, summary_title: str
    ) -> tuple[str, ToolMessagesArtifact | None]:
        """Run the summarization node and wrap its messages into a tool artifact.

        Returns ("", artifact) on success, or an error string with no artifact
        when the node produced nothing.
        """
        summarization_node = SessionSummarizationNode(self._team, self._user)
        runnable: RunnableLambda[AssistantState, PartialAssistantState | None] = RunnableLambda(summarization_node)
        # Deep-copy so the caller's state object is never mutated by this tool call.
        node_state = self._state.model_copy(
            deep=True,
            update={
                "root_tool_call_id": self.tool_call_id,
                "session_summarization_query": session_summarization_query,
                "should_use_current_filters": should_use_current_filters,
                "summary_title": summary_title,
            },
        )
        result = await runnable.ainvoke(node_state)
        if result and result.messages:
            return "", ToolMessagesArtifact(messages=result.messages)
        return "Session summarization failed", None

View File

@@ -0,0 +1,253 @@
from typing import Any
from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest
from unittest.mock import AsyncMock, patch
from langchain_core.runnables import RunnableConfig
from posthog.schema import (
AssistantMessage,
AssistantTool,
AssistantToolCallMessage,
AssistantTrendsQuery,
VisualizationMessage,
)
from ee.hogai.context.context import AssistantContextManager
from ee.hogai.graph.root.tools.create_and_query_insight import (
INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT,
CreateAndQueryInsightTool,
)
from ee.hogai.graph.schema_generator.nodes import SchemaGenerationException
from ee.hogai.utils.types import AssistantState
from ee.hogai.utils.types.base import NodePath
from ee.models.assistant import Conversation
class TestCreateAndQueryInsightTool(ClickhouseTestMixin, NonAtomicBaseTest):
    """Tests for `CreateAndQueryInsightTool._arun_impl` with the insights graph mocked out."""

    # Each test builds its own conversation/state, so skip class-level data setup.
    CLASS_DATA_LEVEL_SETUP = False

    def setUp(self):
        """Create a conversation and a fixed tool call id shared by the tests."""
        super().setUp()
        self.conversation = Conversation.objects.create(team=self.team, user=self.user)
        self.tool_call_id = "test_tool_call_id"

    def _create_tool(
        self, state: AssistantState | None = None, contextual_tools: dict[str, dict[str, Any]] | None = None
    ):
        """Helper to create tool instance with optional state and contextual tools."""
        if state is None:
            state = AssistantState(messages=[])
        config: RunnableConfig = RunnableConfig()
        if contextual_tools:
            # Contextual tools in the config put the tool into "editing mode".
            config = RunnableConfig(configurable={"contextual_tools": contextual_tools})
        context_manager = AssistantContextManager(team=self.team, user=self.user, config=config)
        return CreateAndQueryInsightTool(
            team=self.team,
            user=self.user,
            state=state,
            context_manager=context_manager,
            node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),),
        )

    async def test_successful_insight_creation_returns_messages(self):
        """Test successful insight creation returns visualization and tool call messages."""
        tool = self._create_tool()
        query = AssistantTrendsQuery(series=[])
        viz_message = VisualizationMessage(query="test query", answer=query, plan="test plan")
        tool_call_message = AssistantToolCallMessage(content="Results are here", tool_call_id=self.tool_call_id)
        mock_state = AssistantState(messages=[viz_message, tool_call_message])
        with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
            mock_graph = AsyncMock()
            mock_graph.ainvoke = AsyncMock(return_value=mock_state.model_dump())
            mock_compile.return_value = mock_graph
            result_text, artifact = await tool._arun_impl(query_description="test description")
        # Empty result text signals success; messages are carried in the artifact.
        self.assertEqual(result_text, "")
        self.assertIsNotNone(artifact)
        self.assertEqual(len(artifact.messages), 2)
        self.assertIsInstance(artifact.messages[0], VisualizationMessage)
        self.assertIsInstance(artifact.messages[1], AssistantToolCallMessage)

    async def test_schema_generation_exception_returns_formatted_error(self):
        """Test SchemaGenerationException is caught and returns formatted error message."""
        tool = self._create_tool()
        exception = SchemaGenerationException(
            llm_output="Invalid query structure", validation_message="Missing required field: series"
        )
        with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
            mock_graph = AsyncMock()
            mock_graph.ainvoke = AsyncMock(side_effect=exception)
            mock_compile.return_value = mock_graph
            result_text, artifact = await tool._arun_impl(query_description="test description")
        # The error message surfaces both the LLM output and the validation detail.
        self.assertIsNone(artifact)
        self.assertIn("Invalid query structure", result_text)
        self.assertIn("Missing required field: series", result_text)
        self.assertIn(INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT, result_text)

    async def test_invalid_tool_call_message_type_returns_error(self):
        """Test when the last message is not AssistantToolCallMessage, returns error."""
        tool = self._create_tool()
        viz_message = VisualizationMessage(query="test query", answer=AssistantTrendsQuery(series=[]), plan="test plan")
        # Last message is AssistantMessage instead of AssistantToolCallMessage
        invalid_message = AssistantMessage(content="Not a tool call message")
        mock_state = AssistantState(messages=[viz_message, invalid_message])
        with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
            mock_graph = AsyncMock()
            mock_graph.ainvoke = AsyncMock(return_value=mock_state.model_dump())
            mock_compile.return_value = mock_graph
            result_text, artifact = await tool._arun_impl(query_description="test description")
        self.assertIsNone(artifact)
        self.assertIn("unknown error", result_text)
        self.assertIn(INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT, result_text)

    async def test_human_feedback_requested_returns_only_tool_call_message(self):
        """Test when visualization message is not present, returns only tool call message."""
        tool = self._create_tool()
        # When agent requests human feedback, there's no VisualizationMessage
        some_message = AssistantMessage(content="I need help with this query")
        tool_call_message = AssistantToolCallMessage(content="Need clarification", tool_call_id=self.tool_call_id)
        mock_state = AssistantState(messages=[some_message, tool_call_message])
        with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
            mock_graph = AsyncMock()
            mock_graph.ainvoke = AsyncMock(return_value=mock_state.model_dump())
            mock_compile.return_value = mock_graph
            result_text, artifact = await tool._arun_impl(query_description="test description")
        self.assertEqual(result_text, "")
        self.assertIsNotNone(artifact)
        self.assertEqual(len(artifact.messages), 1)
        self.assertIsInstance(artifact.messages[0], AssistantToolCallMessage)
        self.assertEqual(artifact.messages[0].content, "Need clarification")

    async def test_editing_mode_adds_ui_payload(self):
        """Test that in editing mode, UI payload is added to tool call message."""
        # Create tool with contextual tool available
        tool = self._create_tool(contextual_tools={AssistantTool.CREATE_AND_QUERY_INSIGHT.value: {}})
        query = AssistantTrendsQuery(series=[])
        viz_message = VisualizationMessage(query="test query", answer=query, plan="test plan")
        tool_call_message = AssistantToolCallMessage(content="Results are here", tool_call_id=self.tool_call_id)
        mock_state = AssistantState(messages=[viz_message, tool_call_message])
        with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
            mock_graph = AsyncMock()
            mock_graph.ainvoke = AsyncMock(return_value=mock_state.model_dump())
            mock_compile.return_value = mock_graph
            result_text, artifact = await tool._arun_impl(query_description="test description")
        self.assertEqual(result_text, "")
        self.assertIsNotNone(artifact)
        self.assertEqual(len(artifact.messages), 2)
        # Check that UI payload was added to tool call message
        returned_tool_call_message = artifact.messages[1]
        self.assertIsInstance(returned_tool_call_message, AssistantToolCallMessage)
        self.assertIsNotNone(returned_tool_call_message.ui_payload)
        self.assertIn("create_and_query_insight", returned_tool_call_message.ui_payload)
        self.assertEqual(
            returned_tool_call_message.ui_payload["create_and_query_insight"], query.model_dump(exclude_none=True)
        )

    async def test_non_editing_mode_no_ui_payload(self):
        """Test that in non-editing mode, no UI payload is added to tool call message."""
        # Create tool without contextual tools (non-editing mode)
        tool = self._create_tool(contextual_tools={})
        query = AssistantTrendsQuery(series=[])
        viz_message = VisualizationMessage(query="test query", answer=query, plan="test plan")
        tool_call_message = AssistantToolCallMessage(content="Results are here", tool_call_id=self.tool_call_id)
        mock_state = AssistantState(messages=[viz_message, tool_call_message])
        with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
            mock_graph = AsyncMock()
            mock_graph.ainvoke = AsyncMock(return_value=mock_state.model_dump())
            mock_compile.return_value = mock_graph
            result_text, artifact = await tool._arun_impl(query_description="test description")
        self.assertEqual(result_text, "")
        self.assertIsNotNone(artifact)
        self.assertEqual(len(artifact.messages), 2)
        # Check that the original tool call message is returned without modification
        returned_tool_call_message = artifact.messages[1]
        self.assertIsInstance(returned_tool_call_message, AssistantToolCallMessage)
        # In non-editing mode, the original message is returned as-is
        self.assertEqual(returned_tool_call_message, tool_call_message)

    async def test_state_updates_include_tool_call_metadata(self):
        """Test that the state passed to graph includes root_tool_call_id and root_tool_insight_plan."""
        initial_state = AssistantState(messages=[AssistantMessage(content="initial")])
        tool = self._create_tool(state=initial_state)
        query = AssistantTrendsQuery(series=[])
        viz_message = VisualizationMessage(query="test query", answer=query, plan="test plan")
        tool_call_message = AssistantToolCallMessage(content="Results", tool_call_id=self.tool_call_id)
        mock_state = AssistantState(messages=[viz_message, tool_call_message])
        invoked_state = None

        # Capture the state the tool passes to the graph so we can inspect it below.
        async def capture_invoked_state(state):
            nonlocal invoked_state
            invoked_state = state
            return mock_state.model_dump()

        with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
            mock_graph = AsyncMock()
            mock_graph.ainvoke = AsyncMock(side_effect=capture_invoked_state)
            mock_compile.return_value = mock_graph
            await tool._arun_impl(query_description="my test query")
        # Verify the state passed to ainvoke has the correct metadata
        self.assertIsNotNone(invoked_state)
        validated_state = AssistantState.model_validate(invoked_state)
        self.assertEqual(validated_state.root_tool_call_id, self.tool_call_id)
        self.assertEqual(validated_state.root_tool_insight_plan, "my test query")
        # Original message should still be there
        self.assertEqual(len(validated_state.messages), 1)
        assert isinstance(validated_state.messages[0], AssistantMessage)
        self.assertEqual(validated_state.messages[0].content, "initial")

    async def test_is_editing_mode_classmethod(self):
        """Test the is_editing_mode class method correctly detects editing mode."""
        # Test with editing mode enabled
        config_editing: RunnableConfig = RunnableConfig(
            configurable={"contextual_tools": {AssistantTool.CREATE_AND_QUERY_INSIGHT.value: {}}}
        )
        context_manager_editing = AssistantContextManager(team=self.team, user=self.user, config=config_editing)
        self.assertTrue(CreateAndQueryInsightTool.is_editing_mode(context_manager_editing))
        # Test with editing mode disabled
        config_not_editing = RunnableConfig(configurable={"contextual_tools": {}})
        context_manager_not_editing = AssistantContextManager(team=self.team, user=self.user, config=config_not_editing)
        self.assertFalse(CreateAndQueryInsightTool.is_editing_mode(context_manager_not_editing))
        # Test with other contextual tools but not create_and_query_insight
        config_other = RunnableConfig(configurable={"contextual_tools": {"some_other_tool": {}}})
        context_manager_other = AssistantContextManager(team=self.team, user=self.user, config=config_other)
        self.assertFalse(CreateAndQueryInsightTool.is_editing_mode(context_manager_other))

View File

@@ -0,0 +1,248 @@
from typing import cast
from uuid import uuid4
from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest
from unittest.mock import AsyncMock, MagicMock, patch
from posthog.schema import AssistantMessage
from ee.hogai.context.context import AssistantContextManager
from ee.hogai.graph.root.tools.create_dashboard import CreateDashboardTool
from ee.hogai.utils.types import AssistantState, InsightQuery, PartialAssistantState
from ee.hogai.utils.types.base import NodePath
class TestCreateDashboardTool(ClickhouseTestMixin, NonAtomicBaseTest):
    """Tests for `CreateDashboardTool._arun_impl` with the dashboard creation node mocked out."""

    # Each test builds its own state; no class-level data setup needed.
    CLASS_DATA_LEVEL_SETUP = False

    def setUp(self):
        """Create a default tool instance with a fresh tool call id and empty state."""
        super().setUp()
        self.tool_call_id = str(uuid4())
        self.state = AssistantState(messages=[], root_tool_call_id=self.tool_call_id)
        self.context_manager = AssistantContextManager(self.team, self.user, {})
        self.tool = CreateDashboardTool(
            team=self.team,
            user=self.user,
            state=self.state,
            context_manager=self.context_manager,
            node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),),
        )

    async def test_execute_calls_dashboard_creation_node(self):
        """The tool invokes the dashboard node and returns its messages in the artifact."""
        mock_node_instance = MagicMock()
        mock_result = PartialAssistantState(
            messages=[AssistantMessage(content="Dashboard created successfully with 3 insights")]
        )
        with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode", return_value=mock_node_instance):
            with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda") as mock_runnable:
                mock_chain = MagicMock()
                mock_chain.ainvoke = AsyncMock(return_value=mock_result)
                mock_runnable.return_value = mock_chain
                insight_queries = [
                    InsightQuery(name="Pageviews", description="Show pageviews for last 7 days"),
                    InsightQuery(name="User signups", description="Show user signups funnel"),
                ]
                result, artifact = await self.tool._arun_impl(
                    search_insights_queries=insight_queries,
                    dashboard_name="Marketing Dashboard",
                )
                # Empty result text signals success; messages are carried in the artifact.
                self.assertEqual(result, "")
                self.assertIsNotNone(artifact)
                assert artifact is not None
                self.assertEqual(len(artifact.messages), 1)
                message = cast(AssistantMessage, artifact.messages[0])
                self.assertEqual(message.content, "Dashboard created successfully with 3 insights")

    async def test_execute_updates_state_with_all_parameters(self):
        """The state handed to the node carries the queries, dashboard name, and tool call id."""
        mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test response")])
        insight_queries = [
            InsightQuery(name="Revenue Trends", description="Monthly revenue trends for Q4"),
            InsightQuery(name="Churn Rate", description="Customer churn rate by cohort"),
            InsightQuery(name="NPS Score", description="Net Promoter Score over time"),
        ]

        # Assertions run inside the mocked chain to inspect the state at invocation time.
        async def mock_ainvoke(state):
            self.assertEqual(state.search_insights_queries, insight_queries)
            self.assertEqual(state.dashboard_name, "Executive Summary Q4")
            self.assertEqual(state.root_tool_call_id, self.tool_call_id)
            return mock_result

        mock_chain = MagicMock()
        mock_chain.ainvoke = mock_ainvoke
        with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
            with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
                await self.tool._arun_impl(
                    search_insights_queries=insight_queries,
                    dashboard_name="Executive Summary Q4",
                )

    async def test_execute_with_single_insight_query(self):
        """A single insight query is passed through to the node unchanged."""
        mock_result = PartialAssistantState(messages=[AssistantMessage(content="Dashboard with one insight created")])
        insight_queries = [InsightQuery(name="Daily Active Users", description="Count of daily active users")]

        async def mock_ainvoke(state):
            self.assertEqual(len(state.search_insights_queries), 1)
            self.assertEqual(state.search_insights_queries[0].name, "Daily Active Users")
            self.assertEqual(state.search_insights_queries[0].description, "Count of daily active users")
            self.assertEqual(state.dashboard_name, "User Activity Dashboard")
            return mock_result

        mock_chain = MagicMock()
        mock_chain.ainvoke = mock_ainvoke
        with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
            with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
                result, artifact = await self.tool._arun_impl(
                    search_insights_queries=insight_queries,
                    dashboard_name="User Activity Dashboard",
                )
                self.assertEqual(result, "")
                self.assertIsNotNone(artifact)

    async def test_execute_with_many_insight_queries(self):
        """A larger batch of insight queries is forwarded in order."""
        mock_result = PartialAssistantState(messages=[AssistantMessage(content="Large dashboard created")])
        insight_queries = [
            InsightQuery(name=f"Insight {i}", description=f"Description for insight {i}") for i in range(10)
        ]

        async def mock_ainvoke(state):
            self.assertEqual(len(state.search_insights_queries), 10)
            self.assertEqual(state.search_insights_queries[0].name, "Insight 0")
            self.assertEqual(state.search_insights_queries[9].name, "Insight 9")
            self.assertEqual(state.dashboard_name, "Comprehensive Dashboard")
            return mock_result

        mock_chain = MagicMock()
        mock_chain.ainvoke = mock_ainvoke
        with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
            with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
                result, artifact = await self.tool._arun_impl(
                    search_insights_queries=insight_queries,
                    dashboard_name="Comprehensive Dashboard",
                )
                self.assertEqual(result, "")
                self.assertIsNotNone(artifact)

    async def test_execute_returns_failure_message_when_result_is_none(self):
        """A None node result is surfaced as a failure message with no artifact."""
        async def mock_ainvoke(state):
            return None

        mock_chain = MagicMock()
        mock_chain.ainvoke = mock_ainvoke
        with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
            with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
                result, artifact = await self.tool._arun_impl(
                    search_insights_queries=[InsightQuery(name="Test", description="Test insight")],
                    dashboard_name="Test Dashboard",
                )
                self.assertEqual(result, "Dashboard creation failed")
                self.assertIsNone(artifact)

    async def test_execute_returns_failure_message_when_result_has_no_messages(self):
        """An empty message list from the node is also treated as failure."""
        mock_result = PartialAssistantState(messages=[])

        async def mock_ainvoke(state):
            return mock_result

        mock_chain = MagicMock()
        mock_chain.ainvoke = mock_ainvoke
        with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
            with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
                result, artifact = await self.tool._arun_impl(
                    search_insights_queries=[InsightQuery(name="Test", description="Test insight")],
                    dashboard_name="Test Dashboard",
                )
                self.assertEqual(result, "Dashboard creation failed")
                self.assertIsNone(artifact)

    async def test_execute_preserves_original_state(self):
        """Test that the original state is not modified when creating the copied state"""
        original_queries = [InsightQuery(name="Original", description="Original insight")]
        original_state = AssistantState(
            messages=[],
            root_tool_call_id=self.tool_call_id,
            search_insights_queries=original_queries,
            dashboard_name="Original Dashboard",
        )
        tool = CreateDashboardTool(
            team=self.team,
            user=self.user,
            state=original_state,
            context_manager=self.context_manager,
            node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),),
        )
        mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test")])
        new_queries = [InsightQuery(name="New", description="New insight")]

        async def mock_ainvoke(state):
            # Verify the new state has updated values
            self.assertEqual(state.search_insights_queries, new_queries)
            self.assertEqual(state.dashboard_name, "New Dashboard")
            return mock_result

        mock_chain = MagicMock()
        mock_chain.ainvoke = mock_ainvoke
        with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
            with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
                await tool._arun_impl(
                    search_insights_queries=new_queries,
                    dashboard_name="New Dashboard",
                )
        # Verify original state was not modified
        self.assertEqual(original_state.search_insights_queries, original_queries)
        self.assertEqual(original_state.dashboard_name, "Original Dashboard")
        self.assertEqual(original_state.root_tool_call_id, self.tool_call_id)

    async def test_execute_with_complex_insight_descriptions(self):
        """Test that complex insight descriptions with special characters are handled correctly"""
        mock_result = PartialAssistantState(messages=[AssistantMessage(content="Dashboard created")])
        insight_queries = [
            InsightQuery(
                name="User Journey",
                description='Show funnel from "Sign up""Create project""Invite team" for users with property email containing "@company.com"',
            ),
            InsightQuery(
                name="Revenue (USD)",
                description="Track total revenue in USD from 2024-01-01 to 2024-12-31, filtered by plan_type = 'premium' OR plan_type = 'enterprise'",
            ),
        ]

        async def mock_ainvoke(state):
            self.assertEqual(len(state.search_insights_queries), 2)
            self.assertIn("@company.com", state.search_insights_queries[0].description)
            self.assertIn("'premium'", state.search_insights_queries[1].description)
            return mock_result

        mock_chain = MagicMock()
        mock_chain.ainvoke = mock_ainvoke
        with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
            with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
                result, artifact = await self.tool._arun_impl(
                    search_insights_queries=insight_queries,
                    dashboard_name="Complex Dashboard",
                )
                self.assertEqual(result, "")
                self.assertIsNotNone(artifact)

View File

@@ -1,27 +1,178 @@
from posthog.test.base import BaseTest
from unittest.mock import patch
from typing import cast
from uuid import uuid4
from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest
from unittest.mock import AsyncMock, MagicMock, patch
from django.test import override_settings
from langchain_core import messages
from langchain_core.runnables import RunnableConfig
from posthog.schema import AssistantMessage
from ee.hogai.context.context import AssistantContextManager
from ee.hogai.graph.root.tools.search import (
DOC_ITEM_TEMPLATE,
DOCS_SEARCH_NO_RESULTS_TEMPLATE,
DOCS_SEARCH_RESULTS_TEMPLATE,
EMPTY_DATABASE_ERROR_MESSAGE,
InkeepDocsSearchTool,
InsightSearchTool,
SearchTool,
)
from ee.hogai.utils.tests import FakeChatOpenAI
from ee.hogai.utils.types import AssistantState, PartialAssistantState
from ee.hogai.utils.types.base import NodePath
class TestSearchToolDocumentation(BaseTest):
class TestSearchTool(ClickhouseTestMixin, NonAtomicBaseTest):
CLASS_DATA_LEVEL_SETUP = False
def setUp(self):
super().setUp()
self.tool = SearchTool(team=self.team, user=self.user, tool_call_id="test-tool-call-id")
self.tool_call_id = "test_tool_call_id"
self.state = AssistantState(messages=[], root_tool_call_id=str(uuid4()))
self.context_manager = AssistantContextManager(self.team, self.user, {})
self.tool = SearchTool(
team=self.team,
user=self.user,
state=self.state,
context_manager=self.context_manager,
node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),),
)
async def test_run_docs_search_without_api_key(self):
with patch("ee.hogai.graph.root.tools.search.settings") as mock_settings:
mock_settings.INKEEP_API_KEY = None
result, artifact = await self.tool._arun_impl(kind="docs", query="How to use feature flags?")
self.assertEqual(result, "This tool is not available in this environment.")
self.assertIsNone(artifact)
async def test_run_docs_search_with_api_key(self):
mock_docs_tool = MagicMock()
mock_docs_tool.execute = AsyncMock(return_value=("", MagicMock()))
with (
patch("ee.hogai.graph.root.tools.search.settings") as mock_settings,
patch("ee.hogai.graph.root.tools.search.InkeepDocsSearchTool", return_value=mock_docs_tool),
):
mock_settings.INKEEP_API_KEY = "test-key"
result, artifact = await self.tool._arun_impl(kind="docs", query="How to use feature flags?")
mock_docs_tool.execute.assert_called_once_with("How to use feature flags?", self.tool_call_id)
self.assertEqual(result, "")
self.assertIsNotNone(artifact)
async def test_run_insights_search(self):
mock_insights_tool = MagicMock()
mock_insights_tool.execute = AsyncMock(return_value=("", MagicMock()))
with patch("ee.hogai.graph.root.tools.search.InsightSearchTool", return_value=mock_insights_tool):
result, artifact = await self.tool._arun_impl(kind="insights", query="user signups")
mock_insights_tool.execute.assert_called_once_with("user signups", self.tool_call_id)
self.assertEqual(result, "")
self.assertIsNotNone(artifact)
async def test_run_unknown_kind(self):
result, artifact = await self.tool._arun_impl(kind="unknown", query="test")
self.assertEqual(result, "Invalid entity kind: unknown. Please provide a valid entity kind for the tool.")
self.assertIsNone(artifact)
@patch("ee.hogai.graph.root.tools.search.EntitySearchTool.execute")
async def test_arun_impl_error_tracking_issues_returns_routing_data(self, mock_execute):
mock_execute.return_value = "Search results for error tracking issues"
result, artifact = await self.tool._arun_impl(
kind="error_tracking_issues", query="test error tracking issue query"
)
self.assertEqual(result, "Search results for error tracking issues")
self.assertIsNone(artifact)
mock_execute.assert_called_once_with("test error tracking issue query", "error_tracking_issues")
@patch("ee.hogai.graph.root.tools.search.EntitySearchTool.execute")
@patch("ee.hogai.graph.root.tools.search.SearchTool._has_insights_fts_search_feature_flag")
async def test_arun_impl_insight_with_feature_flag_disabled(
self, mock_has_insights_fts_search_feature_flag, mock_execute
):
mock_has_insights_fts_search_feature_flag.return_value = False
mock_execute.return_value = "Search results for insights"
result, artifact = await self.tool._arun_impl(kind="insights", query="test insight query")
self.assertEqual(result, "The user doesn't have any insights created yet.")
self.assertIsNone(artifact)
mock_execute.assert_not_called()
@patch("ee.hogai.graph.root.tools.search.EntitySearchTool.execute")
@patch("ee.hogai.graph.root.tools.search.SearchTool._has_insights_fts_search_feature_flag")
async def test_arun_impl_insight_with_feature_flag_enabled(
self, mock_has_insights_fts_search_feature_flag, mock_execute
):
mock_has_insights_fts_search_feature_flag.return_value = True
mock_execute.return_value = "Search results for insights"
result, artifact = await self.tool._arun_impl(kind="insights", query="test insight query")
self.assertEqual(result, "Search results for insights")
self.assertIsNone(artifact)
mock_execute.assert_called_once_with("test insight query", "insights")
class TestInkeepDocsSearchTool(ClickhouseTestMixin, NonAtomicBaseTest):
CLASS_DATA_LEVEL_SETUP = False
def setUp(self):
super().setUp()
self.tool_call_id = str(uuid4())
self.state = AssistantState(messages=[], root_tool_call_id=self.tool_call_id)
self.context_manager = AssistantContextManager(self.team, self.user, {})
self.tool = InkeepDocsSearchTool(
team=self.team,
user=self.user,
state=self.state,
config=RunnableConfig(configurable={}),
context_manager=self.context_manager,
)
async def test_execute_calls_inkeep_docs_node(self):
mock_node_instance = MagicMock()
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Here is the answer from docs")])
with patch("ee.hogai.graph.inkeep_docs.nodes.InkeepDocsNode", return_value=mock_node_instance):
with patch("ee.hogai.graph.root.tools.search.RunnableLambda") as mock_runnable:
mock_chain = MagicMock()
mock_chain.ainvoke = AsyncMock(return_value=mock_result)
mock_runnable.return_value = mock_chain
result, artifact = await self.tool.execute("How to track events?", "test-tool-call-id")
self.assertEqual(result, "")
self.assertIsNotNone(artifact)
assert artifact is not None
self.assertEqual(len(artifact.messages), 1)
message = cast(AssistantMessage, artifact.messages[0])
self.assertEqual(message.content, "Here is the answer from docs")
async def test_execute_updates_state_with_tool_call_id(self):
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test response")])
async def mock_ainvoke(state):
self.assertEqual(state.root_tool_call_id, "custom-tool-call-id")
return mock_result
mock_chain = MagicMock()
mock_chain.ainvoke = mock_ainvoke
with patch("ee.hogai.graph.inkeep_docs.nodes.InkeepDocsNode"):
with patch("ee.hogai.graph.root.tools.search.RunnableLambda", return_value=mock_chain):
await self.tool.execute("test query", "custom-tool-call-id")
@override_settings(INKEEP_API_KEY="test-inkeep-key")
@patch("ee.hogai.graph.root.tools.search.ChatOpenAI")
async def test_search_docs_with_successful_results(self, mock_llm_class):
@patch("ee.hogai.graph.root.tools.search.InkeepDocsSearchTool._has_rag_docs_search_feature_flag", return_value=True)
async def test_search_docs_with_successful_results(self, mock_has_rag_docs_search_feature_flag, mock_llm_class):
response_json = """{
"content": [
{
@@ -47,7 +198,7 @@ class TestSearchToolDocumentation(BaseTest):
fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content=response_json)])
mock_llm_class.return_value = fake_llm
result = await self.tool._search_docs("how to use feature")
result, _ = await self.tool.execute("how to use feature", "test-tool-call-id")
expected_doc_1 = DOC_ITEM_TEMPLATE.format(
title="Feature Documentation",
@@ -68,196 +219,76 @@ class TestSearchToolDocumentation(BaseTest):
self.assertEqual(mock_llm_class.call_args.kwargs["api_key"], "test-inkeep-key")
self.assertEqual(mock_llm_class.call_args.kwargs["streaming"], False)
@override_settings(INKEEP_API_KEY="test-inkeep-key")
@patch("ee.hogai.graph.root.tools.search.ChatOpenAI")
async def test_search_docs_with_no_results(self, mock_llm_class):
fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content="{}")])
mock_llm_class.return_value = fake_llm
result = await self.tool._search_docs("nonexistent feature")
class TestInsightSearchTool(ClickhouseTestMixin, NonAtomicBaseTest):
CLASS_DATA_LEVEL_SETUP = False
self.assertEqual(result, DOCS_SEARCH_NO_RESULTS_TEMPLATE)
@override_settings(INKEEP_API_KEY="test-inkeep-key")
@patch("ee.hogai.graph.root.tools.search.ChatOpenAI")
async def test_search_docs_with_empty_content(self, mock_llm_class):
fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content='{"content": []}')])
mock_llm_class.return_value = fake_llm
result = await self.tool._search_docs("query")
self.assertEqual(result, DOCS_SEARCH_NO_RESULTS_TEMPLATE)
@override_settings(INKEEP_API_KEY="test-inkeep-key")
@patch("ee.hogai.graph.root.tools.search.ChatOpenAI")
async def test_search_docs_filters_non_document_types(self, mock_llm_class):
response_json = """{
"content": [
{
"type": "snippet",
"record_type": "code",
"url": "https://posthog.com/code",
"title": "Code Snippet",
"source": {"type": "text", "content": [{"type": "text", "text": "Code example"}]}
},
{
"type": "answer",
"record_type": "answer",
"url": "https://posthog.com/answer",
"title": "Answer",
"source": {"type": "text", "content": [{"type": "text", "text": "Answer text"}]}
}
]
}"""
fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content=response_json)])
mock_llm_class.return_value = fake_llm
result = await self.tool._search_docs("query")
self.assertEqual(result, DOCS_SEARCH_NO_RESULTS_TEMPLATE)
@override_settings(INKEEP_API_KEY="test-inkeep-key")
@patch("ee.hogai.graph.root.tools.search.ChatOpenAI")
async def test_search_docs_handles_empty_source_content(self, mock_llm_class):
response_json = """{
"content": [
{
"type": "document",
"record_type": "page",
"url": "https://posthog.com/docs/feature",
"title": "Feature Documentation",
"source": {"type": "text", "content": []}
}
]
}"""
fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content=response_json)])
mock_llm_class.return_value = fake_llm
result = await self.tool._search_docs("query")
expected_doc = DOC_ITEM_TEMPLATE.format(
title="Feature Documentation", url="https://posthog.com/docs/feature", text=""
)
expected_result = DOCS_SEARCH_RESULTS_TEMPLATE.format(count=1, docs=expected_doc)
self.assertEqual(result, expected_result)
@override_settings(INKEEP_API_KEY="test-inkeep-key")
@patch("ee.hogai.graph.root.tools.search.ChatOpenAI")
async def test_search_docs_handles_mixed_document_types(self, mock_llm_class):
response_json = """{
"content": [
{
"type": "document",
"record_type": "page",
"url": "https://posthog.com/docs/valid",
"title": "Valid Doc",
"source": {"type": "text", "content": [{"type": "text", "text": "Valid content"}]}
},
{
"type": "snippet",
"record_type": "code",
"url": "https://posthog.com/code",
"title": "Code Snippet",
"source": {"type": "text", "content": [{"type": "text", "text": "Code"}]}
},
{
"type": "document",
"record_type": "guide",
"url": "https://posthog.com/docs/another",
"title": "Another Valid Doc",
"source": {"type": "text", "content": [{"type": "text", "text": "More content"}]}
}
]
}"""
fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content=response_json)])
mock_llm_class.return_value = fake_llm
result = await self.tool._search_docs("query")
expected_doc_1 = DOC_ITEM_TEMPLATE.format(
title="Valid Doc", url="https://posthog.com/docs/valid", text="Valid content"
)
expected_doc_2 = DOC_ITEM_TEMPLATE.format(
title="Another Valid Doc", url="https://posthog.com/docs/another", text="More content"
)
expected_result = DOCS_SEARCH_RESULTS_TEMPLATE.format(
count=2, docs=f"{expected_doc_1}\n\n---\n\n{expected_doc_2}"
def setUp(self):
super().setUp()
self.tool_call_id = str(uuid4())
self.state = AssistantState(messages=[], root_tool_call_id=self.tool_call_id)
self.context_manager = AssistantContextManager(self.team, self.user, {})
self.tool = InsightSearchTool(
team=self.team,
user=self.user,
state=self.state,
config=RunnableConfig(configurable={}),
context_manager=self.context_manager,
)
self.assertEqual(result, expected_result)
async def test_execute_calls_insight_search_node(self):
mock_node_instance = MagicMock()
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Found 3 insights matching your query")])
@override_settings(INKEEP_API_KEY=None)
async def test_arun_impl_docs_without_api_key(self):
result, artifact = await self.tool._arun_impl(kind="docs", query="test query")
with patch("ee.hogai.graph.insights.nodes.InsightSearchNode", return_value=mock_node_instance):
with patch("ee.hogai.graph.root.tools.search.RunnableLambda") as mock_runnable:
mock_chain = MagicMock()
mock_chain.ainvoke = AsyncMock(return_value=mock_result)
mock_runnable.return_value = mock_chain
self.assertEqual(result, "This tool is not available in this environment.")
self.assertIsNone(artifact)
result, artifact = await self.tool.execute("user signups by week", self.tool_call_id)
@override_settings(INKEEP_API_KEY="test-key")
@patch("ee.hogai.graph.root.tools.search.posthoganalytics.feature_enabled")
async def test_arun_impl_docs_with_feature_flag_disabled(self, mock_feature_enabled):
mock_feature_enabled.return_value = False
self.assertEqual(result, "")
self.assertIsNotNone(artifact)
assert artifact is not None
self.assertEqual(len(artifact.messages), 1)
message = cast(AssistantMessage, artifact.messages[0])
self.assertEqual(message.content, "Found 3 insights matching your query")
result, artifact = await self.tool._arun_impl(kind="docs", query="test query")
async def test_execute_updates_state_with_search_query(self):
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test response")])
self.assertEqual(result, "Search tool executed")
self.assertEqual(artifact, {"kind": "docs", "query": "test query"})
async def mock_ainvoke(state):
self.assertEqual(state.search_insights_query, "custom search query")
self.assertEqual(state.root_tool_call_id, self.tool_call_id)
return mock_result
@override_settings(INKEEP_API_KEY="test-key")
@patch("ee.hogai.graph.root.tools.search.posthoganalytics.feature_enabled")
@patch.object(SearchTool, "_search_docs")
async def test_arun_impl_docs_with_feature_flag_enabled(self, mock_search_docs, mock_feature_enabled):
mock_feature_enabled.return_value = True
mock_search_docs.return_value = "Search results"
mock_chain = MagicMock()
mock_chain.ainvoke = mock_ainvoke
result, artifact = await self.tool._arun_impl(kind="docs", query="test query")
with patch("ee.hogai.graph.insights.nodes.InsightSearchNode"):
with patch("ee.hogai.graph.root.tools.search.RunnableLambda", return_value=mock_chain):
await self.tool.execute("custom search query", self.tool_call_id)
self.assertEqual(result, "Search results")
self.assertIsNone(artifact)
mock_search_docs.assert_called_once_with("test query")
async def test_execute_handles_no_insights_exception(self):
from ee.hogai.graph.insights.nodes import NoInsightsException
async def test_arun_impl_insights_returns_routing_data(self):
result, artifact = await self.tool._arun_impl(kind="insights", query="test insight query")
with patch("ee.hogai.graph.insights.nodes.InsightSearchNode", side_effect=NoInsightsException()):
result, artifact = await self.tool.execute("user signups", self.tool_call_id)
self.assertEqual(result, "Search tool executed")
self.assertEqual(artifact, {"kind": "insights", "query": "test insight query"})
self.assertEqual(result, EMPTY_DATABASE_ERROR_MESSAGE)
self.assertIsNone(artifact)
@patch("ee.hogai.graph.root.tools.search.EntitySearchToolkit.execute")
async def test_arun_impl_error_tracking_issues_returns_routing_data(self, mock_execute):
mock_execute.return_value = "Search results for error tracking issues"
async def test_execute_returns_none_artifact_when_result_is_none(self):
async def mock_ainvoke(state):
return None
result, artifact = await self.tool._arun_impl(
kind="error_tracking_issues", query="test error tracking issue query"
)
mock_chain = MagicMock()
mock_chain.ainvoke = mock_ainvoke
self.assertEqual(result, "Search results for error tracking issues")
self.assertIsNone(artifact)
mock_execute.assert_called_once_with("test error tracking issue query", "error_tracking_issues")
with patch("ee.hogai.graph.insights.nodes.InsightSearchNode"):
with patch("ee.hogai.graph.root.tools.search.RunnableLambda", return_value=mock_chain):
result, artifact = await self.tool.execute("test query", self.tool_call_id)
@patch("ee.hogai.graph.root.tools.search.EntitySearchToolkit.execute")
@patch("ee.hogai.graph.root.tools.search.SearchTool._has_fts_search_feature_flag")
async def test_arun_impl_insight_with_feature_flag_disabled(self, mock_has_fts_search_feature_flag, mock_execute):
mock_has_fts_search_feature_flag.return_value = False
mock_execute.return_value = "Search results for insights"
result, artifact = await self.tool._arun_impl(kind="insights", query="test insight query")
self.assertEqual(result, "Search tool executed")
self.assertEqual(artifact, {"kind": "insights", "query": "test insight query"})
mock_execute.assert_not_called()
@patch("ee.hogai.graph.root.tools.search.EntitySearchToolkit.execute")
@patch("ee.hogai.graph.root.tools.search.SearchTool._has_fts_search_feature_flag")
async def test_arun_impl_insight_with_feature_flag_enabled(self, mock_has_fts_search_feature_flag, mock_execute):
mock_has_fts_search_feature_flag.return_value = True
mock_execute.return_value = "Search results for insights"
result, artifact = await self.tool._arun_impl(kind="insights", query="test insight query")
self.assertEqual(result, "Search results for insights")
self.assertIsNone(artifact)
mock_execute.assert_called_once_with("test insight query", "insights")
self.assertEqual(result, "")
self.assertIsNone(artifact)

View File

@@ -0,0 +1,193 @@
from typing import cast
from uuid import uuid4
from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest
from unittest.mock import AsyncMock, MagicMock, patch
from posthog.schema import AssistantMessage
from ee.hogai.context.context import AssistantContextManager
from ee.hogai.graph.root.tools.session_summarization import SessionSummarizationTool
from ee.hogai.utils.types import AssistantState, PartialAssistantState
from ee.hogai.utils.types.base import NodePath
class TestSessionSummarizationTool(ClickhouseTestMixin, NonAtomicBaseTest):
CLASS_DATA_LEVEL_SETUP = False
def setUp(self):
super().setUp()
self.tool_call_id = str(uuid4())
self.state = AssistantState(messages=[], root_tool_call_id=self.tool_call_id)
self.context_manager = AssistantContextManager(self.team, self.user, {})
self.tool = SessionSummarizationTool(
team=self.team,
user=self.user,
state=self.state,
context_manager=self.context_manager,
node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),),
)
async def test_execute_calls_session_summarization_node(self):
mock_node_instance = MagicMock()
mock_result = PartialAssistantState(
messages=[AssistantMessage(content="Session summary: 10 sessions analyzed with 5 key patterns found")]
)
with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode", return_value=mock_node_instance):
with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda") as mock_runnable:
mock_chain = MagicMock()
mock_chain.ainvoke = AsyncMock(return_value=mock_result)
mock_runnable.return_value = mock_chain
result, artifact = await self.tool._arun_impl(
session_summarization_query="summarize all sessions from yesterday",
should_use_current_filters=False,
summary_title="All sessions from yesterday",
)
self.assertEqual(result, "")
self.assertIsNotNone(artifact)
assert artifact is not None
self.assertEqual(len(artifact.messages), 1)
message = cast(AssistantMessage, artifact.messages[0])
self.assertEqual(message.content, "Session summary: 10 sessions analyzed with 5 key patterns found")
async def test_execute_updates_state_with_all_parameters(self):
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test response")])
async def mock_ainvoke(state):
self.assertEqual(state.session_summarization_query, "analyze mobile user sessions")
self.assertEqual(state.should_use_current_filters, True)
self.assertEqual(state.summary_title, "Mobile user sessions")
self.assertEqual(state.root_tool_call_id, self.tool_call_id)
return mock_result
mock_chain = MagicMock()
mock_chain.ainvoke = mock_ainvoke
with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"):
with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain):
await self.tool._arun_impl(
session_summarization_query="analyze mobile user sessions",
should_use_current_filters=True,
summary_title="Mobile user sessions",
)
async def test_execute_with_should_use_current_filters_false(self):
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test response")])
async def mock_ainvoke(state):
self.assertEqual(state.should_use_current_filters, False)
self.assertEqual(state.session_summarization_query, "watch last 300 session recordings")
self.assertEqual(state.summary_title, "Last 300 sessions")
return mock_result
mock_chain = MagicMock()
mock_chain.ainvoke = mock_ainvoke
with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"):
with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain):
await self.tool._arun_impl(
session_summarization_query="watch last 300 session recordings",
should_use_current_filters=False,
summary_title="Last 300 sessions",
)
async def test_execute_returns_failure_message_when_result_is_none(self):
async def mock_ainvoke(state):
return None
mock_chain = MagicMock()
mock_chain.ainvoke = mock_ainvoke
with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"):
with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain):
result, artifact = await self.tool._arun_impl(
session_summarization_query="test query",
should_use_current_filters=False,
summary_title="Test",
)
self.assertEqual(result, "Session summarization failed")
self.assertIsNone(artifact)
async def test_execute_returns_failure_message_when_result_has_no_messages(self):
mock_result = PartialAssistantState(messages=[])
async def mock_ainvoke(state):
return mock_result
mock_chain = MagicMock()
mock_chain.ainvoke = mock_ainvoke
with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"):
with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain):
result, artifact = await self.tool._arun_impl(
session_summarization_query="test query",
should_use_current_filters=False,
summary_title="Test",
)
self.assertEqual(result, "Session summarization failed")
self.assertIsNone(artifact)
async def test_execute_with_empty_summary_title(self):
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Summary completed")])
async def mock_ainvoke(state):
self.assertEqual(state.summary_title, "")
return mock_result
mock_chain = MagicMock()
mock_chain.ainvoke = mock_ainvoke
with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"):
with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain):
result, artifact = await self.tool._arun_impl(
session_summarization_query="summarize sessions",
should_use_current_filters=False,
summary_title="",
)
self.assertEqual(result, "")
self.assertIsNotNone(artifact)
async def test_execute_preserves_original_state(self):
"""Test that the original state is not modified when creating the copied state"""
original_query = "original query"
original_state = AssistantState(
messages=[],
root_tool_call_id=self.tool_call_id,
session_summarization_query=original_query,
)
tool = SessionSummarizationTool(
team=self.team,
user=self.user,
state=original_state,
context_manager=self.context_manager,
node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),),
)
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test")])
async def mock_ainvoke(state):
# Verify the new state has updated values
self.assertEqual(state.session_summarization_query, "new query")
return mock_result
mock_chain = MagicMock()
mock_chain.ainvoke = mock_ainvoke
with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"):
with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain):
await tool._arun_impl(
session_summarization_query="new query",
should_use_current_filters=True,
summary_title="New Summary",
)
# Verify original state was not modified
self.assertEqual(original_state.session_summarization_query, original_query)
self.assertEqual(original_state.root_tool_call_id, self.tool_call_id)

View File

@@ -138,11 +138,11 @@ When unsure, use this tool. Proactive task management shows attentiveness and he
""".strip()
# Has its unique schema that doesn't match the Deep Research schema
class TodoItem(BaseModel):
content: str = Field(..., min_length=1)
status: Literal["pending", "in_progress", "completed"]
id: str
priority: Literal["low", "medium", "high"]
class TodoWriteToolArgs(BaseModel):

View File

@@ -13,7 +13,7 @@ from langchain_core.messages import (
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.runnables import RunnableConfig
from posthog.schema import FailureMessage, VisualizationMessage
from posthog.schema import VisualizationMessage
from posthog.models.group_type_mapping import GroupTypeMapping
@@ -36,6 +36,15 @@ from .utils import Q, SchemaGeneratorOutput
RETRIES_ALLOWED = 2
class SchemaGenerationException(Exception):
"""An error occurred while generating a schema in the `SchemaGeneratorNode` node."""
def __init__(self, llm_output: str, validation_message: str):
super().__init__("Failed to generate schema")
self.llm_output = llm_output
self.validation_message = validation_message
class SchemaGeneratorNode(AssistantNode, Generic[Q]):
INSIGHT_NAME: str
"""
@@ -87,9 +96,8 @@ class SchemaGeneratorNode(AssistantNode, Generic[Q]):
chain = generation_prompt | merger | self._model | self._parse_output
result: SchemaGeneratorOutput[Q] | None = None
try:
result = await chain.ainvoke(
result: SchemaGeneratorOutput[Q] = await chain.ainvoke(
{
"project_datetime": self.project_now,
"project_timezone": self.project_timezone,
@@ -120,18 +128,9 @@ class SchemaGeneratorNode(AssistantNode, Generic[Q]):
query_generation_retry_count=len(intermediate_steps) + 1,
)
if not result:
# We've got no usable result after exhausting all iteration attempts - it's failure message time
return PartialAssistantState(
messages=[
FailureMessage(
content=f"It looks like I'm having trouble generating this {self.INSIGHT_NAME} insight."
)
],
intermediate_steps=None,
plan=None,
query_generation_retry_count=len(intermediate_steps) + 1,
)
if isinstance(e, PydanticOutputParserException):
raise SchemaGenerationException(e.llm_output, e.validation_message)
raise SchemaGenerationException(e.llm_output or "No input was provided.", str(e))
# We've got a result that either passed the quality check or we've exhausted all attempts at iterating - return
return PartialAssistantState(

View File

@@ -13,7 +13,12 @@ from langchain_core.runnables import RunnableConfig, RunnableLambda
from posthog.schema import AssistantMessage, AssistantTrendsQuery, FailureMessage, HumanMessage, VisualizationMessage
from ee.hogai.graph.schema_generator.nodes import RETRIES_ALLOWED, SchemaGeneratorNode, SchemaGeneratorToolsNode
from ee.hogai.graph.schema_generator.nodes import (
RETRIES_ALLOWED,
SchemaGenerationException,
SchemaGeneratorNode,
SchemaGeneratorToolsNode,
)
from ee.hogai.graph.schema_generator.parsers import PydanticOutputParserException
from ee.hogai.graph.schema_generator.utils import SchemaGeneratorOutput
from ee.hogai.utils.types import AssistantState, PartialAssistantState
@@ -258,7 +263,7 @@ class TestSchemaGeneratorNode(BaseTest):
self.assertEqual(action.log, "Field validation failed")
async def test_quality_check_failure_with_retries_exhausted(self):
"""Test quality check failure with retries exhausted still returns VisualizationMessage."""
"""Test quality check failure with retries exhausted raises SchemaGenerationException."""
node = DummyGeneratorNode(self.team, self.user)
with (
patch.object(DummyGeneratorNode, "_model") as generator_model_mock,
@@ -273,26 +278,25 @@ class TestSchemaGeneratorNode(BaseTest):
)
# Start with RETRIES_ALLOWED intermediate steps (so no more allowed)
new_state = await node.arun(
AssistantState(
messages=[HumanMessage(content="Text", id="0")],
start_id="0",
intermediate_steps=cast(
list[IntermediateStep],
[
(AgentAction(tool="handle_incorrect_response", tool_input="", log=""), "retry"),
],
)
* RETRIES_ALLOWED,
),
{},
)
with self.assertRaises(SchemaGenerationException) as cm:
await node.arun(
AssistantState(
messages=[HumanMessage(content="Text", id="0")],
start_id="0",
intermediate_steps=cast(
list[IntermediateStep],
[
(AgentAction(tool="handle_incorrect_response", tool_input="", log=""), "retry"),
],
)
* RETRIES_ALLOWED,
),
{},
)
# Should return VisualizationMessage despite quality check failure
self.assertEqual(new_state.intermediate_steps, None)
self.assertEqual(len(new_state.messages), 1)
self.assertEqual(new_state.messages[0].type, "ai/viz")
self.assertEqual(cast(VisualizationMessage, new_state.messages[0]).answer, self.basic_trends)
# Verify the exception contains the expected information
self.assertEqual(cm.exception.llm_output, '{"query": "test"}')
self.assertEqual(cm.exception.validation_message, "Quality check failed")
async def test_node_leaves_failover(self):
node = DummyGeneratorNode(self.team, self.user)
@@ -329,20 +333,17 @@ class TestSchemaGeneratorNode(BaseTest):
schema = DummySchema.model_construct(query=[]).model_dump() # type: ignore
generator_model_mock.return_value = RunnableLambda(lambda _: json.dumps(schema))
new_state = await node.arun(
AssistantState(
messages=[HumanMessage(content="Text")],
intermediate_steps=[
(AgentAction(tool="", tool_input="", log="exception"), "exception"),
(AgentAction(tool="", tool_input="", log="exception"), "exception"),
],
),
{},
)
self.assertEqual(new_state.intermediate_steps, None)
self.assertEqual(len(new_state.messages), 1)
self.assertIsInstance(new_state.messages[0], FailureMessage)
self.assertEqual(new_state.plan, None)
with self.assertRaises(SchemaGenerationException):
await node.arun(
AssistantState(
messages=[HumanMessage(content="Text")],
intermediate_steps=[
(AgentAction(tool="", tool_input="", log="exception"), "exception"),
(AgentAction(tool="", tool_input="", log="exception"), "exception"),
],
),
{},
)
async def test_agent_reconstructs_conversation_with_failover(self):
action = AgentAction(tool="fix", tool_input="validation error", log="exception")

View File

@@ -11,7 +11,6 @@ from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from posthog.schema import (
AssistantMessage,
AssistantToolCallMessage,
MaxRecordingUniversalFilters,
NotebookUpdateMessage,
@@ -70,11 +69,7 @@ class SessionSummarizationNode(AssistantNode):
"""Push summarization progress as reasoning messages"""
content = prepare_reasoning_progress_message(progress_message)
if content:
self.dispatcher.message(
AssistantMessage(
content=content,
)
)
self.dispatcher.update(content)
async def _stream_notebook_content(self, content: dict, state: AssistantState, partial: bool = True) -> None:
"""Stream TipTap content directly to a notebook if notebook_id is present in state."""
@@ -202,7 +197,7 @@ class _SessionSearch:
tool = await SearchSessionRecordingsTool.create_tool_class(
team=self._node._team,
user=self._node._user,
tool_call_id=self._node._parent_tool_call_id or "",
node_path=self._node.node_path,
state=state,
config=config,
context_manager=self._node.context_manager,

View File

@@ -1,6 +1,6 @@
from langchain_core.runnables import RunnableConfig
from posthog.schema import AssistantHogQLQuery, AssistantMessage
from posthog.schema import AssistantHogQLQuery
from posthog.hogql.context import HogQLContext
@@ -25,7 +25,7 @@ class SQLGeneratorNode(HogQLGeneratorMixin, SchemaGeneratorNode[AssistantHogQLQu
return AssistantNodeName.SQL_GENERATOR
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
self.dispatcher.message(AssistantMessage(content="Creating SQL query"))
self.dispatcher.update("Creating SQL query")
prompt = await self._construct_system_prompt()
return await super()._run_with_prompt(state, prompt, config=config)

View File

@@ -3,34 +3,43 @@ from typing import Generic
from posthog.models import Team, User
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.graph.graph import BaseAssistantGraph
from ee.hogai.graph.base import BaseAssistantGraph
from ee.hogai.utils.types import PartialStateType, StateType
from ee.hogai.utils.types.base import AssistantGraphName
from .nodes import StateClassMixin, TaxonomyAgentNode, TaxonomyAgentToolsNode
from .toolkit import TaxonomyAgentToolkit
from .types import TaxonomyNodeName
class TaxonomyAgent(BaseAssistantGraph[StateType], Generic[StateType, PartialStateType], StateClassMixin):
class TaxonomyAgent(
BaseAssistantGraph[StateType, PartialStateType], Generic[StateType, PartialStateType], StateClassMixin
):
"""Taxonomy agent that can be configured with different node classes."""
def __init__(
self,
team: Team,
user: User,
tool_call_id: str,
loop_node_class: type["TaxonomyAgentNode"],
tools_node_class: type["TaxonomyAgentToolsNode"],
toolkit_class: type["TaxonomyAgentToolkit"],
):
# Extract the State type from the generic parameter
state_class, _ = self._get_state_class(TaxonomyAgent)
super().__init__(team, user, state_class, parent_tool_call_id=tool_call_id)
super().__init__(team, user)
self._loop_node_class = loop_node_class
self._tools_node_class = tools_node_class
self._toolkit_class = toolkit_class
@property
def state_type(self) -> type[StateType]:
# Extract the State type from the generic parameter
state_type, _ = self._get_state_class(TaxonomyAgent)
return state_type
@property
def graph_name(self) -> AssistantGraphName:
return AssistantGraphName.TAXONOMY
def compile_full_graph(self, checkpointer: DjangoCheckpointer | None = None):
"""Compile a complete taxonomy graph."""
return self.add_taxonomy_generator().compile(checkpointer=checkpointer)

View File

@@ -40,25 +40,10 @@ TaxonomyPartialStateType = TypeVar("TaxonomyPartialStateType", bound=TaxonomyAge
TaxonomyNodeBound = BaseAssistantNode[TaxonomyStateType, TaxonomyPartialStateType]
class ParentToolCallIdMixin(BaseAssistantNode[TaxonomyStateType, TaxonomyPartialStateType]):
_parent_tool_call_id: str | None = None
_toolkit: TaxonomyAgentToolkit
async def __call__(self, state: TaxonomyStateType, config: RunnableConfig) -> TaxonomyPartialStateType | None:
"""
Run the assistant node and handle cancelled conversation before the node is run.
"""
if self._parent_tool_call_id:
self._toolkit._parent_tool_call_id = self._parent_tool_call_id
return await super().__call__(state, config)
class TaxonomyAgentNode(
Generic[TaxonomyStateType, TaxonomyPartialStateType],
StateClassMixin,
TaxonomyUpdateDispatcherNodeMixin,
ParentToolCallIdMixin,
TaxonomyNodeBound,
ABC,
):
@@ -173,7 +158,6 @@ class TaxonomyAgentToolsNode(
Generic[TaxonomyStateType, TaxonomyPartialStateType],
StateClassMixin,
TaxonomyUpdateDispatcherNodeMixin,
ParentToolCallIdMixin,
TaxonomyNodeBound,
):
"""Base tools node for taxonomy agents."""

View File

@@ -36,14 +36,13 @@ class TestTaxonomyAgent(BaseTest):
self.mock_graph = Mock()
# Patch StateGraph in the parent class where it's actually used
self.patcher = patch("ee.hogai.graph.graph.StateGraph")
self.patcher = patch("ee.hogai.graph.base.graph.StateGraph")
mock_state_graph_class = self.patcher.start()
mock_state_graph_class.return_value = self.mock_graph
self.agent = ConcreteTaxonomyAgent(
team=self.team,
user=self.user,
tool_call_id="test_tool_call_id",
loop_node_class=MockTaxonomyAgentNode,
tools_node_class=MockTaxonomyAgentToolsNode,
toolkit_class=MockTaxonomyAgentToolkit,
@@ -75,7 +74,6 @@ class TestTaxonomyAgent(BaseTest):
NonGenericAgent(
team=self.team,
user=self.user,
tool_call_id="test_tool_call_id",
loop_node_class=MockTaxonomyAgentNode,
tools_node_class=MockTaxonomyAgentToolsNode,
toolkit_class=MockTaxonomyAgentToolkit,

View File

@@ -122,10 +122,6 @@ class TaxonomyTaskExecutorNode(
Task executor node specifically for taxonomy operations.
"""
def __init__(self, team: Team, user: User, parent_tool_call_id: str | None):
super().__init__(team, user)
self._parent_tool_call_id = parent_tool_call_id
@property
def node_name(self) -> MaxNodeName:
return TaxonomyNodeName.TASK_EXECUTOR
@@ -148,8 +144,6 @@ class TaxonomyTaskExecutorNode(
class TaxonomyAgentToolkit:
"""Base toolkit for taxonomy agents that handle tool execution."""
_parent_tool_call_id: str | None
def __init__(self, team: Team, user: User):
self._team = team
self._user = user
@@ -391,7 +385,7 @@ class TaxonomyAgentToolkit:
)
)
message = AssistantMessage(content="", id=str(uuid4()), tool_calls=tool_calls)
executor = TaxonomyTaskExecutorNode(self._team, self._user, parent_tool_call_id=self._parent_tool_call_id)
executor = TaxonomyTaskExecutorNode(self._team, self._user)
result = await executor.arun(AssistantState(messages=[message]), RunnableConfig())
final_result = {}
@@ -651,7 +645,7 @@ class TaxonomyAgentToolkit:
for event_name_or_action_id in event_name_or_action_ids
]
message = AssistantMessage(content="", id=str(uuid4()), tool_calls=tool_calls)
executor = TaxonomyTaskExecutorNode(self._team, self._user, parent_tool_call_id=self._parent_tool_call_id)
executor = TaxonomyTaskExecutorNode(self._team, self._user)
result = await executor.arun(AssistantState(messages=[message]), RunnableConfig())
return {task_result.id: task_result.result for task_result in result.task_results}

View File

@@ -1,29 +0,0 @@
from posthog.test.base import BaseTest
from langchain_core.runnables import RunnableLambda
from langgraph.checkpoint.memory import InMemorySaver
from ee.hogai.graph.graph import BaseAssistantGraph
from ee.hogai.utils.types import AssistantNodeName, AssistantState, PartialAssistantState
class TestAssistantGraph(BaseTest):
async def test_pydantic_state_resets_with_none(self):
"""When a None field is set, it should be reset to None."""
async def runnable(state: AssistantState) -> PartialAssistantState:
return PartialAssistantState(start_id=None)
graph = BaseAssistantGraph(self.team, self.user, state_type=AssistantState)
compiled_graph = (
graph.add_node(AssistantNodeName.ROOT, RunnableLambda(runnable))
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
.compile(checkpointer=InMemorySaver())
)
state = await compiled_graph.ainvoke(
AssistantState(messages=[], graph_status="resumed", start_id=None),
{"configurable": {"thread_id": "test"}},
)
self.assertEqual(state["start_id"], None)
self.assertEqual(state["graph_status"], "resumed")

View File

@@ -4,8 +4,7 @@ Integration tests for dispatcher usage in BaseAssistantNode and graph execution.
These tests ensure that the dispatcher pattern works correctly end-to-end in real graph execution.
"""
import uuid
from typing import Any
from typing import cast
from posthog.test.base import BaseTest
from unittest.mock import MagicMock, patch
@@ -14,30 +13,40 @@ from langchain_core.runnables import RunnableConfig
from posthog.schema import AssistantMessage
from ee.hogai.graph.base import BaseAssistantNode
from ee.hogai.utils.dispatcher import MessageAction, NodeStartAction
from ee.hogai.utils.types import AssistantNodeName
from ee.hogai.utils.types.base import AssistantDispatcherEvent, AssistantState, PartialAssistantState
from ee.hogai.graph.base.node import BaseAssistantNode
from ee.hogai.utils.types.base import (
AssistantDispatcherEvent,
AssistantGraphName,
AssistantNodeName,
AssistantState,
MessageAction,
NodeEndAction,
NodePath,
NodeStartAction,
PartialAssistantState,
UpdateAction,
)
class MockAssistantNode(BaseAssistantNode):
class MockAssistantNode(BaseAssistantNode[AssistantState, PartialAssistantState]):
"""Mock node for testing dispatcher integration."""
def __init__(self, team, user):
super().__init__(team, user)
def __init__(self, team, user, node_path=None):
super().__init__(team, user, node_path)
self.arun_called = False
self.messages_dispatched = []
@property
def node_name(self) -> AssistantNodeName:
def node_name(self) -> str:
return AssistantNodeName.ROOT
async def arun(self, state, config: RunnableConfig) -> PartialAssistantState:
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
self.arun_called = True
# Simulate dispatching messages during execution
self.dispatcher.message(AssistantMessage(content="Processing..."))
self.dispatcher.message(AssistantMessage(content="Done!"))
self.dispatcher.update("Processing...")
self.dispatcher.message(AssistantMessage(content="Intermediate result"))
self.dispatcher.update("Done!")
return PartialAssistantState(messages=[AssistantMessage(content="Final result")])
@@ -57,147 +66,169 @@ class TestDispatcherIntegration(BaseTest):
node = MockAssistantNode(self.mock_team, self.mock_user)
state = AssistantState(messages=[])
config = RunnableConfig()
config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_1"})
await node.arun(state, config)
with patch("ee.hogai.utils.dispatcher.get_stream_writer", side_effect=RuntimeError("Not streaming")):
await node(state, config)
self.assertTrue(node.arun_called)
self.assertIsNotNone(node.dispatcher)
async def test_node_path_propagation(self):
"""Test that node_path is correctly set and propagated."""
node_path = (
NodePath(name=AssistantGraphName.ASSISTANT),
NodePath(name=AssistantNodeName.ROOT),
)
node = MockAssistantNode(self.mock_team, self.mock_user, node_path=node_path)
self.assertEqual(node.node_path, node_path)
async def test_dispatcher_dispatches_node_start_and_end(self):
"""Test that NODE_START and NODE_END actions are dispatched."""
node = MockAssistantNode(self.mock_team, self.mock_user)
state = AssistantState(messages=[])
config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_2"})
dispatched_actions = []
def capture_write(event):
if isinstance(event, AssistantDispatcherEvent):
dispatched_actions.append(event)
with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
await node(state, config)
# Should have dispatched actions in this order:
# 1. NODE_START
# 2. UpdateAction ("Processing...")
# 3. MessageAction (intermediate)
# 4. UpdateAction ("Done!")
# 5. NODE_END with final state
self.assertGreater(len(dispatched_actions), 0)
# Verify NODE_START is first
self.assertIsInstance(dispatched_actions[0].action, NodeStartAction)
# Verify NODE_END is last
last_action = dispatched_actions[-1].action
self.assertIsInstance(last_action, NodeEndAction)
self.assertIsNotNone(cast(NodeEndAction, last_action).state)
async def test_messages_dispatched_during_node_execution(self):
"""Test that messages dispatched during node execution are sent to writer."""
node = MockAssistantNode(self.mock_team, self.mock_user)
state = AssistantState(messages=[])
config = RunnableConfig()
config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_3"})
dispatched_actions = []
def capture_write(update: Any):
dispatched_actions.append(update)
def capture_write(event: AssistantDispatcherEvent):
dispatched_actions.append(event)
# Mock get_stream_writer to return our test writer
with patch("ee.hogai.graph.base.get_stream_writer", return_value=capture_write):
# Call the node (not arun) to trigger __call__ which handles dispatching
with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
await node(state, config)
# Should have:
# 1. NODE_START from __call__
# 2. Two MESSAGE actions from arun (Processing..., Done!)
# 3. One MESSAGE action from returned state (Final result)
# Find the update and message actions (excluding NODE_START and NODE_END)
update_actions = [e for e in dispatched_actions if isinstance(e.action, UpdateAction)]
message_actions = [e for e in dispatched_actions if isinstance(e.action, MessageAction)]
self.assertEqual(len(dispatched_actions), 4)
self.assertIsInstance(dispatched_actions[0], AssistantDispatcherEvent)
self.assertIsInstance(dispatched_actions[0].action, NodeStartAction)
self.assertIsInstance(dispatched_actions[1], AssistantDispatcherEvent)
self.assertIsInstance(dispatched_actions[1].action, MessageAction)
self.assertEqual(dispatched_actions[1].action.message.content, "Processing...")
self.assertIsInstance(dispatched_actions[2], AssistantDispatcherEvent)
self.assertIsInstance(dispatched_actions[2].action, MessageAction)
self.assertEqual(dispatched_actions[2].action.message.content, "Done!")
self.assertIsInstance(dispatched_actions[3], AssistantDispatcherEvent)
self.assertIsInstance(dispatched_actions[3].action, MessageAction)
self.assertEqual(dispatched_actions[3].action.message.content, "Final result")
# Should have 2 updates: "Processing..." and "Done!"
self.assertEqual(len(update_actions), 2)
self.assertEqual(cast(UpdateAction, update_actions[0].action).content, "Processing...")
self.assertEqual(cast(UpdateAction, update_actions[1].action).content, "Done!")
async def test_node_start_action_dispatched(self):
"""Test that NODE_START action is dispatched at node entry."""
# Should have 1 message: intermediate (final message is in NODE_END state, not dispatched separately)
self.assertEqual(len(message_actions), 1)
msg = cast(MessageAction, message_actions[0].action).message
self.assertEqual(cast(AssistantMessage, msg).content, "Intermediate result")
async def test_node_path_included_in_dispatched_events(self):
"""Test that node_path is included in all dispatched events."""
node_path = (
NodePath(name=AssistantGraphName.ASSISTANT),
NodePath(name=AssistantNodeName.ROOT),
)
node = MockAssistantNode(self.mock_team, self.mock_user, node_path=node_path)
state = AssistantState(messages=[])
config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_4"})
dispatched_events = []
def capture_write(event: AssistantDispatcherEvent):
dispatched_events.append(event)
with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
await node(state, config)
# Verify all events have the correct node_path
for event in dispatched_events:
self.assertEqual(event.node_path, node_path)
async def test_node_run_id_included_in_dispatched_events(self):
"""Test that node_run_id is included in all dispatched events."""
node = MockAssistantNode(self.mock_team, self.mock_user)
state = AssistantState(messages=[])
config = RunnableConfig()
checkpoint_ns = "checkpoint_xyz_789"
config = RunnableConfig(
metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": checkpoint_ns}
)
dispatched_actions = []
dispatched_events = []
def capture_write(update):
if isinstance(update, AssistantDispatcherEvent):
dispatched_actions.append(update.action)
def capture_write(event: AssistantDispatcherEvent):
dispatched_events.append(event)
with patch("ee.hogai.graph.base.get_stream_writer", return_value=capture_write):
with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
await node(state, config)
# Should have at least one NODE_START action
from ee.hogai.utils.dispatcher import NodeStartAction
# Verify all events have the correct node_run_id
for event in dispatched_events:
self.assertEqual(event.node_run_id, checkpoint_ns)
node_start_actions = [action for action in dispatched_actions if isinstance(action, NodeStartAction)]
self.assertGreater(len(node_start_actions), 0)
@patch("ee.hogai.graph.base.get_stream_writer")
async def test_parent_tool_call_id_propagation(self, mock_get_stream_writer):
"""Test that parent_tool_call_id is propagated to dispatched messages."""
parent_tool_call_id = str(uuid.uuid4())
class NodeWithParent(BaseAssistantNode):
@property
def node_name(self):
return AssistantNodeName.ROOT
async def arun(self, state, config: RunnableConfig) -> PartialAssistantState:
# Dispatch message - should inherit parent_tool_call_id from state
self.dispatcher.message(AssistantMessage(content="Child message"))
return PartialAssistantState(messages=[])
node = NodeWithParent(self.mock_team, self.mock_user)
state = {"messages": [], "parent_tool_call_id": parent_tool_call_id}
config = RunnableConfig()
dispatched_messages = []
def capture_write(update):
if isinstance(update, AssistantDispatcherEvent) and isinstance(update.action, MessageAction):
dispatched_messages.append(update.action.message)
mock_get_stream_writer.return_value = capture_write
await node.arun(state, config)
# Verify dispatched messages have parent_tool_call_id
assistant_messages = [msg for msg in dispatched_messages if isinstance(msg, AssistantMessage)]
for msg in assistant_messages:
# If the implementation propagates it, this should be true
# Otherwise this test will help catch that as a potential issue
if msg.parent_tool_call_id:
self.assertEqual(msg.parent_tool_call_id, parent_tool_call_id)
@patch("ee.hogai.graph.base.get_stream_writer")
async def test_dispatcher_error_handling(self, mock_get_stream_writer):
async def test_dispatcher_error_handling_does_not_crash_node(self):
"""Test that errors in dispatcher don't crash node execution."""
class FailingDispatcherNode(BaseAssistantNode):
class FailingDispatcherNode(BaseAssistantNode[AssistantState, PartialAssistantState]):
@property
def node_name(self):
return AssistantNodeName.ROOT
async def arun(self, state, config: RunnableConfig) -> PartialAssistantState:
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
# Try to dispatch - if writer fails, should handle gracefully
self.dispatcher.message(AssistantMessage(content="Test"))
self.dispatcher.update("Test")
return PartialAssistantState(messages=[AssistantMessage(content="Result")])
node = FailingDispatcherNode(self.mock_team, self.mock_user)
state = AssistantState(messages=[])
config = RunnableConfig()
config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_5"})
# Make writer raise an error
def failing_writer(data):
raise RuntimeError("Writer failed")
mock_get_stream_writer.return_value = failing_writer
with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=failing_writer):
# Should not crash - node should complete
result = await node(state, config)
self.assertIsNotNone(result)
# Should not crash - node should complete
result = await node.arun(state, config)
self.assertIsNotNone(result)
async def test_messages_in_partial_state_dispatched_via_node_end(self):
"""Test that messages in PartialState are dispatched via NODE_END."""
async def test_messages_in_partial_state_are_auto_dispatched(self):
"""Test that messages in PartialState are automatically dispatched."""
class NodeReturningMessages(BaseAssistantNode):
class NodeReturningMessages(BaseAssistantNode[AssistantState, PartialAssistantState]):
@property
def node_name(self):
return AssistantNodeName.ROOT
async def arun(self, state, config: RunnableConfig) -> PartialAssistantState:
# Return messages in state - should be auto-dispatched
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
# Return messages in state
return PartialAssistantState(
messages=[
AssistantMessage(content="Message 1"),
@@ -208,39 +239,139 @@ class TestDispatcherIntegration(BaseTest):
node = NodeReturningMessages(self.mock_team, self.mock_user)
state = AssistantState(messages=[])
config = RunnableConfig()
config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_6"})
dispatched_messages = []
dispatched_events = []
def capture_write(update):
if isinstance(update, AssistantDispatcherEvent) and isinstance(update.action, MessageAction):
dispatched_messages.append(update.action.message)
def capture_write(event: AssistantDispatcherEvent):
dispatched_events.append(event)
with patch("ee.hogai.graph.base.get_stream_writer", return_value=capture_write):
with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
await node(state, config)
# Should have dispatched the messages from PartialState (and NODE_START)
# We expect at least 2 message actions (Message 1 and Message 2)
self.assertGreaterEqual(len(dispatched_messages), 2)
# Should have NODE_END action with state containing messages
node_end_actions = [e for e in dispatched_events if isinstance(e.action, NodeEndAction)]
self.assertEqual(len(node_end_actions), 1)
node_end_state = cast(NodeEndAction, node_end_actions[0].action).state
self.assertIsNotNone(node_end_state)
assert node_end_state is not None
self.assertEqual(len(node_end_state.messages), 2)
async def test_node_returns_none_state_handling(self):
"""Test that node can return None state without errors."""
class NoneStateNode(BaseAssistantNode):
class NoneStateNode(BaseAssistantNode[AssistantState, PartialAssistantState]):
@property
def node_name(self):
return AssistantNodeName.ROOT
async def arun(self, state, config: RunnableConfig) -> None:
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
# Dispatch a message but return None state
self.dispatcher.message(AssistantMessage(content="Test"))
self.dispatcher.update("Test")
return None
node = NoneStateNode(self.mock_team, self.mock_user)
state = AssistantState(messages=[])
config = RunnableConfig()
config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_7"})
result = await node(state, config)
# Should handle None gracefully
self.assertIsNone(result)
with patch("ee.hogai.utils.dispatcher.get_stream_writer", side_effect=RuntimeError("Not streaming")):
result = await node(state, config)
# Should handle None gracefully
self.assertIsNone(result)
async def test_nested_node_path_in_dispatched_events(self):
"""Test that nested nodes have correct node_path."""
# Simulate a nested node path (e.g., from a tool call)
parent_path = (
NodePath(name=AssistantGraphName.ASSISTANT),
NodePath(name=AssistantNodeName.ROOT, message_id="msg_123", tool_call_id="tc_123"),
NodePath(name=AssistantGraphName.INSIGHTS),
)
node = MockAssistantNode(self.mock_team, self.mock_user, node_path=parent_path)
state = AssistantState(messages=[])
config = RunnableConfig(
metadata={"langgraph_node": AssistantNodeName.TRENDS_GENERATOR, "langgraph_checkpoint_ns": "cp_8"}
)
dispatched_events = []
def capture_write(event: AssistantDispatcherEvent):
dispatched_events.append(event)
with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
await node(state, config)
# Verify all events have the nested path
for event in dispatched_events:
self.assertIsNotNone(event.node_path)
assert event.node_path is not None
self.assertEqual(event.node_path, parent_path)
self.assertEqual(len(event.node_path), 3)
self.assertEqual(event.node_path[1].message_id, "msg_123")
self.assertEqual(event.node_path[1].tool_call_id, "tc_123")
async def test_update_actions_include_node_metadata(self):
"""Test that update actions include correct node metadata."""
node = MockAssistantNode(self.mock_team, self.mock_user)
state = AssistantState(messages=[])
config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_9"})
dispatched_events = []
def capture_write(event: AssistantDispatcherEvent):
dispatched_events.append(event)
with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
await node(state, config)
# Find update actions
update_events = [e for e in dispatched_events if isinstance(e.action, UpdateAction)]
for event in update_events:
self.assertEqual(event.node_name, AssistantNodeName.ROOT)
self.assertEqual(event.node_run_id, "cp_9")
self.assertIsNotNone(event.node_path)
async def test_concurrent_node_executions_independent_dispatchers(self):
"""Test that concurrent node executions use independent dispatchers."""
node1 = MockAssistantNode(self.mock_team, self.mock_user)
node2 = MockAssistantNode(self.mock_team, self.mock_user)
state = AssistantState(messages=[])
config1 = RunnableConfig(
metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_10a"}
)
config2 = RunnableConfig(
metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_10b"}
)
events1 = []
events2 = []
def capture_write1(event: AssistantDispatcherEvent):
events1.append(event)
def capture_write2(event: AssistantDispatcherEvent):
events2.append(event)
with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write1):
await node1(state, config1)
with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write2):
await node2(state, config2)
# Verify events went to separate lists
self.assertGreater(len(events1), 0)
self.assertGreater(len(events2), 0)
# Verify node_run_ids are different
for event in events1:
self.assertEqual(event.node_run_id, "cp_10a")
for event in events2:
self.assertEqual(event.node_run_id, "cp_10b")

View File

@@ -1,7 +1,7 @@
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from posthog.schema import AssistantMessage, AssistantTrendsQuery
from posthog.schema import AssistantTrendsQuery
from ee.hogai.utils.types import AssistantState
from ee.hogai.utils.types.base import AssistantNodeName, PartialAssistantState
@@ -25,7 +25,7 @@ class TrendsGeneratorNode(SchemaGeneratorNode[AssistantTrendsQuery]):
return AssistantNodeName.TRENDS_GENERATOR
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
self.dispatcher.message(AssistantMessage(content="Creating trends query"))
self.dispatcher.update("Creating trends query")
prompt = ChatPromptTemplate.from_messages(
[
("system", TRENDS_SYSTEM_PROMPT),

View File

@@ -9,7 +9,7 @@ from posthog.test.base import (
_create_person,
flush_persons_and_events,
)
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
from django.test import override_settings
@@ -61,13 +61,13 @@ from posthog.models import Action
from ee.hogai.assistant.base import BaseAssistant
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.graph.base import BaseAssistantNode
from ee.hogai.graph.base import AssistantNode
from ee.hogai.graph.funnels.nodes import FunnelsSchemaGeneratorOutput
from ee.hogai.graph.insights_graph.graph import InsightsGraph
from ee.hogai.graph.memory import prompts as memory_prompts
from ee.hogai.graph.retention.nodes import RetentionSchemaGeneratorOutput
from ee.hogai.graph.root.nodes import SLASH_COMMAND_INIT
from ee.hogai.graph.trends.nodes import TrendsSchemaGeneratorOutput
from ee.hogai.utils.state import GraphValueUpdateTuple
from ee.hogai.utils.tests import FakeAnthropicRunnableLambdaWithTokenCounter, FakeChatAnthropic, FakeChatOpenAI
from ee.hogai.utils.types import (
AssistantMode,
@@ -76,16 +76,21 @@ from ee.hogai.utils.types import (
AssistantState,
PartialAssistantState,
)
from ee.hogai.utils.types.base import ReplaceMessages
from ee.models.assistant import Conversation, CoreMemory
from ..assistant import Assistant
from ..graph import AssistantGraph, InsightsAssistantGraph
from ..graph.graph import AssistantGraph
title_generator_mock = patch(
"ee.hogai.graph.title_generator.nodes.TitleGeneratorNode._model",
return_value=FakeChatOpenAI(responses=[messages.AIMessage(content="Title")]),
)
query_executor_mock = patch(
"ee.hogai.graph.query_executor.nodes.QueryExecutorNode._format_query_result", new=MagicMock(return_value="Result")
)
class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
CLASS_DATA_LEVEL_SETUP = False
@@ -118,7 +123,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
),
).start()
self.checkpointer_patch = patch("ee.hogai.graph.graph.global_checkpointer", new=DjangoCheckpointer())
self.checkpointer_patch = patch("ee.hogai.graph.base.graph.global_checkpointer", new=DjangoCheckpointer())
self.checkpointer_patch.start()
def tearDown(self):
@@ -232,14 +237,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
graph = (
AssistantGraph(self.team, self.user)
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root(
{
"insights": AssistantNodeName.INSIGHTS_SUBGRAPH,
"root": AssistantNodeName.ROOT,
"end": AssistantNodeName.END,
}
)
.add_insights(AssistantNodeName.ROOT)
.add_root()
.compile()
)
@@ -269,7 +267,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
{
"id": "1",
"name": "create_and_query_insight",
"args": {"query_description": "Foobar", "query_kind": insight_type},
"args": {"query_description": "Foobar"},
}
],
)
@@ -302,7 +300,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
AssistantToolCall(
id="1",
name="create_and_query_insight",
args={"query_description": "Foobar", "query_kind": insight_type},
args={"query_description": "Foobar"},
)
],
),
@@ -327,12 +325,12 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
snapshot: StateSnapshot = await graph.aget_state(config)
self.assertFalse(snapshot.next)
self.assertFalse(snapshot.values.get("intermediate_steps"))
self.assertFalse(snapshot.values["plan"])
self.assertFalse(snapshot.values["graph_status"])
self.assertFalse(snapshot.values["root_tool_call_id"])
self.assertFalse(snapshot.values["root_tool_insight_plan"])
self.assertFalse(snapshot.values["root_tool_insight_type"])
self.assertFalse(snapshot.values["root_tool_calls_count"])
self.assertFalse(snapshot.values.get("plan"))
self.assertFalse(snapshot.values.get("graph_status"))
self.assertFalse(snapshot.values.get("root_tool_call_id"))
self.assertFalse(snapshot.values.get("root_tool_insight_plan"))
self.assertFalse(snapshot.values.get("root_tool_insight_type"))
self.assertFalse(snapshot.values.get("root_tool_calls_count"))
async def test_trends_interrupt_when_asking_for_help(self):
await self._test_human_in_the_loop("trends")
@@ -345,7 +343,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
async def test_ai_messages_appended_after_interrupt(self):
with patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model") as mock:
graph = InsightsAssistantGraph(self.team, self.user).compile_full_graph()
graph = InsightsGraph(self.team, self.user).compile_full_graph()
config: RunnableConfig = {
"configurable": {
"thread_id": self.conversation.id,
@@ -441,16 +439,13 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
self.assertIsInstance(output[1][1], FailureMessage)
async def test_new_conversation_handles_serialized_conversation(self):
from ee.hogai.graph.base import BaseAssistantNode
class TestNode(BaseAssistantNode):
class TestNode(AssistantNode):
@property
def node_name(self):
return AssistantNodeName.ROOT
async def arun(self, state, config):
self.dispatcher.message(AssistantMessage(content="Hello"))
return PartialAssistantState()
return PartialAssistantState(messages=[AssistantMessage(content="Hello", id=str(uuid4()))])
test_node = TestNode(self.team, self.user)
graph = (
@@ -478,16 +473,13 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
self.assertNotEqual(output[0][0], "conversation")
async def test_async_stream(self):
from ee.hogai.graph.base import BaseAssistantNode
class TestNode(BaseAssistantNode):
class TestNode(AssistantNode):
@property
def node_name(self):
return AssistantNodeName.ROOT
async def arun(self, state, config):
self.dispatcher.message(AssistantMessage(content="bar"))
return PartialAssistantState()
return PartialAssistantState(messages=[AssistantMessage(content="bar", id=str(uuid4()))])
test_node = TestNode(self.team, self.user)
graph = (
@@ -517,12 +509,13 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
self.assertConversationEqual(actual_output, expected_output)
async def test_async_stream_handles_exceptions(self):
def node_handler(state):
raise ValueError()
class NodeHandler(AssistantNode):
async def arun(self, state, config):
raise ValueError
graph = (
AssistantGraph(self.team, self.user)
.add_node(AssistantNodeName.ROOT, node_handler)
.add_node(AssistantNodeName.ROOT, NodeHandler(self.team, self.user))
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
.compile()
@@ -536,12 +529,11 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
("message", HumanMessage(content="foo")),
("message", FailureMessage()),
]
actual_output = []
async for event in assistant.astream():
actual_output.append(event)
actual_output, _ = await self._run_assistant_graph(graph, message="foo")
self.assertConversationEqual(actual_output, expected_output)
@title_generator_mock
@query_executor_mock
@patch("ee.hogai.graph.schema_generator.nodes.SchemaGeneratorNode._model")
@patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model")
@patch("ee.hogai.graph.root.nodes.RootNode._get_model")
@@ -556,7 +548,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
{
"id": "xyz",
"name": "create_and_query_insight",
"args": {"query_description": "Foobar", "query_kind": "trends"},
"args": {"query_description": "Foobar"},
}
],
)
@@ -596,7 +588,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
AssistantToolCall(
id="xyz",
name="create_and_query_insight",
args={"query_description": "Foobar", "query_kind": "trends"},
args={"query_description": "Foobar"},
)
],
),
@@ -611,7 +603,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
),
("update", AssistantUpdateEvent(id="message_2", tool_call_id="xyz", content="Creating trends query")),
("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")),
("message", {"tool_call_id": "xyz", "type": "tool"}), # Don't check content as it's implementation detail
("message", AssistantToolCallMessage(tool_call_id="xyz", content="Result")),
("message", AssistantMessage(content="The results indicate a great future for you.")),
]
self.assertConversationEqual(actual_output, expected_output)
@@ -634,6 +626,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
)
@title_generator_mock
@query_executor_mock
@patch("ee.hogai.graph.schema_generator.nodes.SchemaGeneratorNode._model")
@patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model")
@patch("ee.hogai.graph.root.nodes.RootNode._get_model")
@@ -648,7 +641,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
{
"id": "xyz",
"name": "create_and_query_insight",
"args": {"query_description": "Foobar", "query_kind": "funnel"},
"args": {"query_description": "Foobar"},
}
],
)
@@ -693,7 +686,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
AssistantToolCall(
id="xyz",
name="create_and_query_insight",
args={"query_description": "Foobar", "query_kind": "funnel"},
args={"query_description": "Foobar"},
)
],
),
@@ -708,7 +701,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
),
("update", AssistantUpdateEvent(id="message_2", tool_call_id="xyz", content="Creating funnel query")),
("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")),
("message", {"tool_call_id": "xyz", "type": "tool"}), # Don't check content as it's implementation detail
("message", AssistantToolCallMessage(tool_call_id="xyz", content="Result")),
("message", AssistantMessage(content="The results indicate a great future for you.")),
]
self.assertConversationEqual(actual_output, expected_output)
@@ -731,6 +724,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
)
@title_generator_mock
@query_executor_mock
@patch("ee.hogai.graph.schema_generator.nodes.SchemaGeneratorNode._model")
@patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model")
@patch("ee.hogai.graph.root.nodes.RootNode._get_model")
@@ -747,7 +741,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
{
"id": "xyz",
"name": "create_and_query_insight",
"args": {"query_description": "Foobar", "query_kind": "retention"},
"args": {"query_description": "Foobar"},
}
],
)
@@ -792,7 +786,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
AssistantToolCall(
id="xyz",
name="create_and_query_insight",
args={"query_description": "Foobar", "query_kind": "retention"},
args={"query_description": "Foobar"},
)
],
),
@@ -807,7 +801,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
),
("update", AssistantUpdateEvent(id="message_2", tool_call_id="xyz", content="Creating retention query")),
("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")),
("message", {"tool_call_id": "xyz", "type": "tool"}), # Don't check content as it's implementation detail
("message", AssistantToolCallMessage(tool_call_id="xyz", content="Result")),
("message", AssistantMessage(content="The results indicate a great future for you.")),
]
self.assertConversationEqual(actual_output, expected_output)
@@ -830,6 +824,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
)
@title_generator_mock
@query_executor_mock
@patch("ee.hogai.graph.schema_generator.nodes.SchemaGeneratorNode._model")
@patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model")
@patch("ee.hogai.graph.root.nodes.RootNode._get_model")
@@ -844,7 +839,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
{
"id": "xyz",
"name": "create_and_query_insight",
"args": {"query_description": "Foobar", "query_kind": "sql"},
"args": {"query_description": "Foobar"},
}
],
)
@@ -883,7 +878,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
AssistantToolCall(
id="xyz",
name="create_and_query_insight",
args={"query_description": "Foobar", "query_kind": "sql"},
args={"query_description": "Foobar"},
)
],
),
@@ -898,7 +893,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
),
("update", AssistantUpdateEvent(id="message_2", tool_call_id="xyz", content="Creating SQL query")),
("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")),
("message", {"tool_call_id": "xyz", "type": "tool"}), # Don't check content as it's implementation detail
("message", AssistantToolCallMessage(tool_call_id="xyz", content="Result")),
("message", AssistantMessage(content="The results indicate a great future for you.")),
]
self.assertConversationEqual(actual_output, expected_output)
@@ -1096,7 +1091,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
{
"id": str(uuid4()),
"name": "create_and_query_insight",
"args": {"query_description": "Foobar", "query_kind": "trends"},
"args": {"query_description": "Foobar"},
}
],
)
@@ -1138,7 +1133,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
graph = (
AssistantGraph(self.team, self.user)
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root({"root": AssistantNodeName.ROOT, "end": AssistantNodeName.END})
.add_root()
.compile()
)
self.assertEqual(self.conversation.status, Conversation.Status.IDLE)
@@ -1158,7 +1153,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
graph = (
AssistantGraph(self.team, self.user)
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root({"root": AssistantNodeName.ROOT, "end": AssistantNodeName.END})
.add_root()
.compile()
)
@@ -1177,7 +1172,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
{
"id": "1",
"name": "create_and_query_insight",
"args": {"query_description": "Foobar", "query_kind": "trends"},
"args": {"query_description": "Foobar"},
}
],
)
@@ -1209,14 +1204,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
graph = (
AssistantGraph(self.team, self.user)
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root(
{
"search_documentation": AssistantNodeName.INKEEP_DOCS,
"root": AssistantNodeName.ROOT,
"end": AssistantNodeName.END,
}
)
.add_inkeep_docs()
.add_root()
.compile()
)
@@ -1235,45 +1223,25 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
self.assertConversationEqual(
output,
[
(
"message",
HumanMessage(
content="How do I use feature flags?",
id="d93b3d57-2ff3-4c63-8635-8349955b16f0",
parent_tool_call_id=None,
type="human",
ui_context=None,
),
),
("message", HumanMessage(content="How do I use feature flags?")),
(
"message",
AssistantMessage(
content="",
id="ea1f4faf-85b4-4d98-b044-842b7092ba5d",
meta=None,
parent_tool_call_id=None,
tool_calls=[
AssistantToolCall(
args={"kind": "docs", "query": "test"}, id="1", name="search", type="tool_call"
)
],
type="ai",
),
),
(
"update",
AssistantUpdateEvent(id="message_2", tool_call_id="1", content="Checking PostHog documentation..."),
AssistantUpdateEvent(content="Checking PostHog documentation...", id="1", tool_call_id="1"),
),
(
"message",
AssistantToolCallMessage(
content="Checking PostHog documentation...",
id="00149847-0bcd-4b7c-a514-bab176312ae8",
parent_tool_call_id=None,
tool_call_id="1",
type="tool",
ui_payload=None,
),
AssistantToolCallMessage(content="Checking PostHog documentation...", tool_call_id="1"),
),
(
"message",
@@ -1283,12 +1251,10 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
)
@title_generator_mock
@query_executor_mock
@patch("ee.hogai.graph.schema_generator.nodes.SchemaGeneratorNode._model")
@patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model")
@patch("ee.hogai.graph.graph.QueryExecutorNode")
async def test_insights_tool_mode_flow(
self, query_executor_mock, planner_mock, generator_mock, title_generator_mock
):
async def test_insights_tool_mode_flow(self, planner_mock, generator_mock, title_generator_mock):
"""Test that the insights tool mode works correctly."""
query = AssistantTrendsQuery(series=[])
tool_call_id = str(uuid4())
@@ -1315,25 +1281,6 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
)
generator_mock.return_value = RunnableLambda(lambda _: TrendsSchemaGeneratorOutput(query=query))
class QueryExecutorNodeMock(BaseAssistantNode):
def __init__(self, team, user):
super().__init__(team, user)
@property
def node_name(self):
return AssistantNodeName.QUERY_EXECUTOR
async def arun(self, state, config):
return PartialAssistantState(
messages=[
AssistantToolCallMessage(
content="The results indicate a great future for you.", tool_call_id=tool_call_id
)
]
)
query_executor_mock.return_value = QueryExecutorNodeMock(self.team, self.user)
# Run in insights tool mode
output, _ = await self._run_assistant_graph(
conversation=self.conversation,
@@ -1347,9 +1294,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")),
(
"message",
AssistantToolCallMessage(
content="The results indicate a great future for you.", tool_call_id=tool_call_id
),
AssistantToolCallMessage(content="Result", tool_call_id=tool_call_id),
),
]
self.assertConversationEqual(output, expected_output)
@@ -1387,7 +1332,6 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
async def test_merges_messages_with_same_id(self):
"""Test that messages with the same ID are merged into one."""
from ee.hogai.graph.base import BaseAssistantNode
message_ids = [str(uuid4()), str(uuid4())]
@@ -1395,7 +1339,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
first_content = "First version of message"
updated_content = "Updated version of message"
class MessageUpdatingNode(BaseAssistantNode):
class MessageUpdatingNode(AssistantNode):
def __init__(self, team, user):
super().__init__(team, user)
self.call_count = 0
@@ -1447,7 +1391,6 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
async def test_assistant_filters_messages_correctly(self):
"""Test that the Assistant class correctly filters messages based on should_output_assistant_message."""
from ee.hogai.graph.base import BaseAssistantNode
output_messages = [
# Should be output (has content)
@@ -1467,7 +1410,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
for test_message, expected_in_output in output_messages:
# Create a simple graph that produces different message types to test filtering
class MessageFilteringNode(BaseAssistantNode):
class MessageFilteringNode(AssistantNode):
def __init__(self, team, user, message_to_return):
super().__init__(team, user)
self.message_to_return = message_to_return
@@ -1477,8 +1420,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
return AssistantNodeName.ROOT
async def arun(self, state, config):
self.dispatcher.message(self.message_to_return)
return PartialAssistantState()
return PartialAssistantState(messages=[self.message_to_return])
# Create a graph with our test node
node = MessageFilteringNode(self.team, self.user, test_message)
@@ -1505,12 +1447,13 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
"""Test that ui_context persists when retrieving conversation state across multiple runs."""
# Create a simple graph that just returns the initial state
def return_initial_state(state):
return {"messages": [AssistantMessage(content="Response from assistant")]}
class ReturnInitialStateNode(AssistantNode):
async def arun(self, state, config):
return PartialAssistantState(messages=[AssistantMessage(content="Response from assistant")])
graph = (
AssistantGraph(self.team, self.user)
.add_node(AssistantNodeName.ROOT, return_initial_state)
.add_node(AssistantNodeName.ROOT, ReturnInitialStateNode(self.team, self.user))
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
.compile()
@@ -1585,8 +1528,8 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
tool_calls=[
{
"id": "xyz",
"name": "edit_current_insight",
"args": {"query_description": "Foobar", "query_kind": "trends"},
"name": "create_and_query_insight",
"args": {"query_description": "Foobar"},
}
],
)
@@ -1622,20 +1565,13 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
output, assistant = await self._run_assistant_graph(
test_graph=AssistantGraph(self.team, self.user)
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root(
{
"root": AssistantNodeName.ROOT,
"insights": AssistantNodeName.INSIGHTS_SUBGRAPH,
"end": AssistantNodeName.END,
}
)
.add_insights()
.add_root()
.compile(),
conversation=self.conversation,
is_new_conversation=True,
message="Hello",
mode=AssistantMode.ASSISTANT,
contextual_tools={"edit_current_insight": {"current_query": "query"}},
contextual_tools={"create_and_query_insight": {"current_query": "query"}},
)
expected_output = [
@@ -1645,23 +1581,30 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
"message",
AssistantMessage(
content="",
id="56076433-5d90-4248-9a46-df3fda42bd0a",
tool_calls=[
AssistantToolCall(
args={"query_description": "Foobar"},
id="xyz",
name="edit_current_insight",
args={"query_description": "Foobar", "query_kind": "trends"},
name="create_and_query_insight",
type="tool_call",
)
],
),
),
(
"update",
AssistantUpdateEvent(content="Picking relevant events and properties", tool_call_id="xyz", id=""),
),
("update", AssistantUpdateEvent(content="Creating trends query", tool_call_id="xyz", id="")),
("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")),
(
"message",
{
"tool_call_id": "xyz",
"type": "tool",
"ui_payload": {"edit_current_insight": query.model_dump(exclude_none=True)},
}, # Don't check content as it's implementation detail
AssistantToolCallMessage(
content="The results indicate a great future for you.",
tool_call_id="xyz",
ui_payload={"create_and_query_insight": query.model_dump(exclude_none=True)},
),
),
("message", AssistantMessage(content="Everything is fine")),
]
@@ -1671,7 +1614,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
state = AssistantState.model_validate(snapshot.values)
expected_state_messages = [
ContextMessage(
content="<system_reminder>\nContextual tools that are available to you on this page are:\n<edit_current_insight>\nThe user is currently editing an insight (aka query). Here is that insight's current definition, which can be edited using the `edit_current_insight` tool:\n\n```json\nquery\n```\n\nIMPORTANT: DO NOT REMOVE ANY FIELDS FROM THE CURRENT INSIGHT DEFINITION. DO NOT CHANGE ANY OTHER FIELDS THAN THE ONES THE USER ASKED FOR. KEEP THE REST AS IS.\n</edit_current_insight>\nIMPORTANT: this context may or may not be relevant to your tasks. You should not respond to this context unless it is highly relevant to your task.\n</system_reminder>"
content="<system_reminder>\nContextual tools that are available to you on this page are:\n<create_and_query_insight>\nThe user is currently editing an insight (aka query). Here is that insight's current definition, which can be edited using the `create_and_query_insight` tool:\n\n```json\nquery\n```\n\n<system_reminder>\nDo not remove any fields from the current insight definition. Do not change any other fields than the ones the user asked for. Keep the rest as is.\n</system_reminder>\n</create_and_query_insight>\nIMPORTANT: this context may or may not be relevant to your tasks. You should not respond to this context unless it is highly relevant to your task.\n</system_reminder>"
),
HumanMessage(content="Hello"),
AssistantMessage(
@@ -1679,8 +1622,8 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
tool_calls=[
AssistantToolCall(
id="xyz",
name="edit_current_insight",
args={"query_description": "Foobar", "query_kind": "trends"},
name="create_and_query_insight",
args={"query_description": "Foobar"},
)
],
),
@@ -1688,7 +1631,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
AssistantToolCallMessage(
content="The results indicate a great future for you.",
tool_call_id="xyz",
ui_payload={"edit_current_insight": query.model_dump(exclude_none=True)},
ui_payload={"create_and_query_insight": query.model_dump(exclude_none=True)},
),
AssistantMessage(content="Everything is fine"),
]
@@ -1705,7 +1648,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
graph = (
AssistantGraph(self.team, self.user)
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root({"root": AssistantNodeName.ROOT, "end": AssistantNodeName.END})
.add_root()
.compile()
)
@@ -1757,16 +1700,14 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
# Tests for ainvoke method
async def test_ainvoke_basic_functionality(self):
"""Test ainvoke returns all messages at once without streaming."""
from ee.hogai.graph.base import BaseAssistantNode
class TestNode(BaseAssistantNode):
class TestNode(AssistantNode):
@property
def node_name(self):
return AssistantNodeName.ROOT
async def arun(self, state, config):
self.dispatcher.message(AssistantMessage(content="Response"))
return PartialAssistantState()
return PartialAssistantState(messages=[AssistantMessage(content="Response", id=str(uuid4()))])
test_node = TestNode(self.team, self.user)
graph = (
@@ -1798,28 +1739,6 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
self.assertIsInstance(item[1], AssistantMessage)
self.assertEqual(cast(AssistantMessage, item[1]).content, "Response")
async def test_process_value_update_returns_none(self):
"""Test that _aprocess_value_update returns None for basic state updates (ACKs are now handled by reducer)."""
assistant = Assistant.create(
self.team, self.conversation, new_message=HumanMessage(content="Hello"), user=self.user
)
# Create a value update tuple that doesn't match special nodes
update = cast(
GraphValueUpdateTuple,
(
AssistantNodeName.ROOT,
{"root": {"messages": []}}, # Empty update that doesn't match visualization or verbose nodes
),
)
# Process the update
result = await assistant._aprocess_value_update(update)
# Should return None (ACK events are now generated by the reducer)
self.assertIsNone(result)
def test_billing_context_in_config(self):
billing_context = MaxBillingContext(
has_active_subscription=True,
@@ -1855,7 +1774,225 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
)
config = assistant._get_config()
self.assertEqual(config.get("configurable", {}).get("billing_context"), billing_context)
self.assertEqual(config["configurable"]["billing_context"], billing_context)
@patch("ee.hogai.context.context.AssistantContextManager.check_user_has_billing_access", return_value=True)
@patch("ee.hogai.graph.root.nodes.RootNode._get_model")
async def test_billing_tool_execution(self, root_mock, access_mock):
"""Test that the billing tool can be called and returns formatted billing information."""
billing_context = MaxBillingContext(
subscription_level=MaxBillingContextSubscriptionLevel.PAID,
billing_plan="startup",
has_active_subscription=True,
is_deactivated=False,
products=[
MaxProductInfo(
name="Product Analytics",
type="analytics",
description="Track user behavior",
current_usage=50000,
usage_limit=100000,
has_exceeded_limit=False,
is_used=True,
percentage_usage=0.5,
addons=[],
)
],
settings=MaxBillingContextSettings(autocapture_on=True, active_destinations=2),
)
# Mock the root node to call the read_data tool with billing_info kind
tool_call_id = str(uuid4())
def root_side_effect(msgs: list[BaseMessage]):
# Check if we've already received a tool result
last_message = msgs[-1]
if (
isinstance(last_message.content, list)
and isinstance(last_message.content[-1], dict)
and last_message.content[-1]["type"] == "tool_result"
):
# After tool execution, respond with final message
return messages.AIMessage(content="Your billing information shows you're on a startup plan.")
# First call - request the billing tool
return messages.AIMessage(
content="",
tool_calls=[
{
"id": tool_call_id,
"name": "read_data",
"args": {"kind": "billing_info"},
}
],
)
root_mock.return_value = FakeAnthropicRunnableLambdaWithTokenCounter(root_side_effect)
# Create a minimal test graph
test_graph = (
AssistantGraph(self.team, self.user)
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root()
.compile()
)
# Run the assistant with billing context
assistant = Assistant.create(
team=self.team,
conversation=self.conversation,
user=self.user,
new_message=HumanMessage(content="What's my current billing status?"),
billing_context=billing_context,
)
assistant._graph = test_graph
output: list[AssistantOutput] = []
async for event in assistant.astream():
output.append(event)
# Verify we received messages
self.assertGreater(len(output), 0)
# Find the assistant's final response
assistant_messages = [msg for event_type, msg in output if isinstance(msg, AssistantMessage)]
self.assertGreater(len(assistant_messages), 0)
# Verify the assistant received and used the billing information
# The mock returns "Your billing information shows you're on a startup plan."
final_message = cast(AssistantMessage, assistant_messages[-1])
self.assertIn("billing", final_message.content.lower())
self.assertIn("startup", final_message.content.lower())
async def test_messages_without_id_are_yielded(self):
"""Test that messages without ID are always yielded."""
class MessageWithoutIdNode(AssistantNode):
call_count = 0
async def arun(self, state, config):
self.call_count += 1
# Return message without ID - should always be yielded
return PartialAssistantState(
messages=[
AssistantMessage(content=f"Message {self.call_count} without ID"),
AssistantMessage(content=f"Message {self.call_count} without ID"),
]
)
node = MessageWithoutIdNode(self.team, self.user)
graph = (
AssistantGraph(self.team, self.user)
.add_node(AssistantNodeName.ROOT, node)
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
.compile()
)
# Run the assistant multiple times
output1, _ = await self._run_assistant_graph(graph, message="First run", conversation=self.conversation)
output2, _ = await self._run_assistant_graph(graph, message="Second run", conversation=self.conversation)
# Both runs should yield their messages (human + assistant message each)
self.assertEqual(len(output1), 3) # Human message + AI message + AI message
self.assertEqual(len(output2), 3) # Human message + AI message + AI message
async def test_messages_with_id_are_deduplicated(self):
"""Test that messages with ID are deduplicated during streaming."""
message_id = str(uuid4())
class DuplicateMessageNode(AssistantNode):
call_count = 0
async def arun(self, state, config):
self.call_count += 1
# Always return the same message with same ID
return PartialAssistantState(
messages=[
AssistantMessage(id=message_id, content=f"Call {self.call_count}"),
AssistantMessage(id=message_id, content=f"Call {self.call_count}"),
]
)
node = DuplicateMessageNode(self.team, self.user)
graph = (
AssistantGraph(self.team, self.user)
.add_node(AssistantNodeName.ROOT, node)
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
.compile()
)
# Create assistant and manually test the streaming behavior
assistant = Assistant.create(
self.team,
self.conversation,
new_message=HumanMessage(content="Test message"),
user=self.user,
is_new_conversation=False,
)
assistant._graph = graph
# Collect all streamed messages
streamed_messages = []
async for event_type, message in assistant.astream(stream_first_message=False):
if event_type == AssistantEventType.MESSAGE:
streamed_messages.append(message)
# Should only get one message despite the node being called multiple times
assistant_messages = [
msg for msg in streamed_messages if isinstance(msg, AssistantMessage) and msg.id == message_id
]
self.assertEqual(len(assistant_messages), 1, "Message with same ID should only be yielded once")
async def test_replaced_messaged_are_not_double_streamed(self):
"""Test that existing messages are not streamed again"""
# Create messages with IDs that should be tracked
message_id_1 = str(uuid4())
message_id_2 = str(uuid4())
call_count = [0]
# Create a simple graph that returns messages with IDs
class TestNode(AssistantNode):
async def arun(self, state, config):
result = None
if call_count[0] == 0:
result = PartialAssistantState(
messages=[
AssistantMessage(id=message_id_1, content="Message 1"),
]
)
else:
result = PartialAssistantState(
messages=ReplaceMessages(
[
AssistantMessage(id=message_id_1, content="Message 1"),
AssistantMessage(id=message_id_2, content="Message 2"),
]
)
)
call_count[0] += 1
return result
graph = (
AssistantGraph(self.team, self.user)
.add_node(AssistantNodeName.ROOT, TestNode(self.team, self.user))
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
.compile()
)
output, _ = await self._run_assistant_graph(graph, message="First run", conversation=self.conversation)
# Filter for assistant messages only, as the test is about tracking assistant message IDs
assistant_output = [(event_type, msg) for event_type, msg in output if isinstance(msg, AssistantMessage)]
self.assertEqual(len(assistant_output), 1)
self.assertEqual(cast(AssistantMessage, assistant_output[0][1]).id, message_id_1)
output, _ = await self._run_assistant_graph(graph, message="Second run", conversation=self.conversation)
# Filter for assistant messages only, as the test is about tracking assistant message IDs
assistant_output = [(event_type, msg) for event_type, msg in output if isinstance(msg, AssistantMessage)]
self.assertEqual(len(assistant_output), 1)
self.assertEqual(cast(AssistantMessage, assistant_output[0][1]).id, message_id_2)
@patch(
"ee.hogai.graph.conversation_summarizer.nodes.AnthropicConversationSummarizer.summarize",
@@ -1887,22 +2024,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
mock_tool.return_value = ("Event list" * 128000, None)
mock_should_compact.side_effect = cycle([False, True]) # Also changed this
graph = (
AssistantGraph(self.team, self.user)
.add_root(
path_map={
"insights": AssistantNodeName.END,
"search_documentation": AssistantNodeName.END,
"root": AssistantNodeName.ROOT,
"end": AssistantNodeName.END,
"insights_search": AssistantNodeName.END,
"session_summarization": AssistantNodeName.END,
"create_dashboard": AssistantNodeName.END,
}
)
.add_memory_onboarding()
.compile()
)
graph = AssistantGraph(self.team, self.user).add_root().add_memory_onboarding().compile()
expected_output = [
("message", HumanMessage(content="First")),
@@ -1928,3 +2050,86 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
self.assertEqual(state.start_id, new_human_message.id)
# should be equal to the summary message, minus reasoning message
self.assertEqual(state.root_conversation_start_id, state.messages[3].id)
@patch("ee.hogai.graph.root.tools.search.SearchTool._arun_impl", return_value=("Docs doubt it", None))
@patch(
"ee.hogai.graph.root.tools.read_taxonomy.ReadTaxonomyTool._run_impl",
return_value=("Hedgehogs have not talked yet", None),
)
@patch("ee.hogai.graph.root.nodes.RootNode._get_model")
async def test_root_node_can_execute_multiple_tool_calls(self, root_mock, search_mock, read_taxonomy_mock):
"""Test that the root node can execute multiple tool calls in parallel."""
tool_call_id1, tool_call_id2 = [str(uuid4()), str(uuid4())]
def root_side_effect(msgs: list[BaseMessage]):
# Check if we've already received a tool result
last_message = msgs[-1]
if (
isinstance(last_message.content, list)
and isinstance(last_message.content[-1], dict)
and last_message.content[-1]["type"] == "tool_result"
):
# After tool execution, respond with final message
return messages.AIMessage(content="No")
return messages.AIMessage(
content="Not sure. Let me check.",
tool_calls=[
{
"id": tool_call_id1,
"name": "search",
"args": {"kind": "docs", "query": "Do hedgehogs speak?"},
},
{
"id": tool_call_id2,
"name": "read_taxonomy",
"args": {"query": {"kind": "events"}},
},
],
)
root_mock.return_value = FakeAnthropicRunnableLambdaWithTokenCounter(root_side_effect)
# Create a minimal test graph
graph = (
AssistantGraph(self.team, self.user)
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_root()
.compile()
)
expected_output = [
(AssistantEventType.MESSAGE, HumanMessage(content="Do hedgehogs speak?")),
(
AssistantEventType.MESSAGE,
AssistantMessage(
content="Not sure. Let me check.",
tool_calls=[
{
"id": tool_call_id1,
"name": "search",
"args": {"kind": "docs", "query": "Do hedgehogs speak?"},
},
{
"id": tool_call_id2,
"name": "read_taxonomy",
"args": {"query": {"kind": "events"}},
},
],
),
),
(
AssistantEventType.MESSAGE,
AssistantToolCallMessage(content="Docs doubt it", tool_call_id=tool_call_id1),
),
(
AssistantEventType.MESSAGE,
AssistantToolCallMessage(content="Hedgehogs have not talked yet", tool_call_id=tool_call_id2),
),
(AssistantEventType.MESSAGE, AssistantMessage(content="No")),
]
output, _ = await self._run_assistant_graph(
graph, message="Do hedgehogs speak?", conversation=self.conversation
)
self.assertConversationEqual(output, expected_output)

View File

@@ -1,6 +1,7 @@
import json
import pkgutil
import importlib
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Literal, Self
@@ -16,9 +17,9 @@ from posthog.models import Team, User
import products
from ee.hogai.context.context import AssistantContextManager
from ee.hogai.graph.mixins import AssistantContextMixin
from ee.hogai.utils.types import AssistantState
from ee.hogai.utils.types.base import AssistantMessageUnion
from ee.hogai.graph.base.context import get_node_path, set_node_path
from ee.hogai.graph.mixins import AssistantContextMixin, AssistantDispatcherMixin
from ee.hogai.utils.types.base import AssistantMessageUnion, AssistantState, NodePath
CONTEXTUAL_TOOL_NAME_TO_TOOL: dict[AssistantTool, type["MaxTool"]] = {}
@@ -51,7 +52,7 @@ class ToolMessagesArtifact(BaseModel):
messages: Sequence[AssistantMessageUnion]
class MaxTool(AssistantContextMixin, BaseTool):
class MaxTool(AssistantContextMixin, AssistantDispatcherMixin, BaseTool):
# LangChain's default is just "content", but we always want to return the tool call artifact too
# - it becomes the `ui_payload`
response_format: Literal["content_and_artifact"] = "content_and_artifact"
@@ -66,7 +67,7 @@ class MaxTool(AssistantContextMixin, BaseTool):
_config: RunnableConfig
_state: AssistantState
_context_manager: AssistantContextManager
_tool_call_id: str
_node_path: tuple[NodePath, ...]
# DEPRECATED: Use `_arun_impl` instead
def _run_impl(self, *args, **kwargs) -> tuple[str, Any]:
@@ -82,7 +83,7 @@ class MaxTool(AssistantContextMixin, BaseTool):
*,
team: Team,
user: User,
tool_call_id: str,
node_path: tuple[NodePath, ...] | None = None,
state: AssistantState | None = None,
config: RunnableConfig | None = None,
name: str | None = None,
@@ -102,7 +103,10 @@ class MaxTool(AssistantContextMixin, BaseTool):
super().__init__(**tool_kwargs, **kwargs)
self._team = team
self._user = user
self._tool_call_id = tool_call_id
if node_path is None:
self._node_path = (*(get_node_path() or ()), NodePath(name=self.node_name))
else:
self._node_path = node_path
self._state = state if state else AssistantState(messages=[])
self._config = config if config else RunnableConfig(configurable={})
self._context_manager = context_manager or AssistantContextManager(team, user, self._config)
@@ -120,19 +124,35 @@ class MaxTool(AssistantContextMixin, BaseTool):
CONTEXTUAL_TOOL_NAME_TO_TOOL[accepted_name] = cls
def _run(self, *args, config: RunnableConfig, **kwargs):
"""LangChain default runner."""
try:
return self._run_impl(*args, **kwargs)
return self._run_with_context(*args, **kwargs)
except NotImplementedError:
pass
return async_to_sync(self._arun_impl)(*args, **kwargs)
return async_to_sync(self._arun_with_context)(*args, **kwargs)
async def _arun(self, *args, config: RunnableConfig, **kwargs):
"""LangChain default runner."""
try:
return await self._arun_impl(*args, **kwargs)
return await self._arun_with_context(*args, **kwargs)
except NotImplementedError:
pass
return await super()._arun(*args, config=config, **kwargs)
def _run_with_context(self, *args, **kwargs):
"""Sets the context for the tool."""
with set_node_path(self.node_path):
return self._run_impl(*args, **kwargs)
async def _arun_with_context(self, *args, **kwargs):
"""Sets the context for the tool."""
with set_node_path(self.node_path):
return await self._arun_impl(*args, **kwargs)
@property
def node_name(self) -> str:
return f"max_tool.{self.get_name()}"
@property
def context(self) -> dict:
return self._context_manager.get_contextual_tools().get(self.get_name(), {})
@@ -149,7 +169,7 @@ class MaxTool(AssistantContextMixin, BaseTool):
*,
team: Team,
user: User,
tool_call_id: str,
node_path: tuple[NodePath, ...] | None = None,
state: AssistantState | None = None,
config: RunnableConfig | None = None,
context_manager: AssistantContextManager | None = None,
@@ -160,5 +180,31 @@ class MaxTool(AssistantContextMixin, BaseTool):
Override this factory to dynamically modify the tool name, description, args schema, etc.
"""
return cls(
team=team, user=user, tool_call_id=tool_call_id, state=state, config=config, context_manager=context_manager
team=team, user=user, node_path=node_path, state=state, config=config, context_manager=context_manager
)
class MaxSubtool(AssistantDispatcherMixin, ABC):
_config: RunnableConfig
def __init__(
self,
*,
team: Team,
user: User,
state: AssistantState,
config: RunnableConfig,
context_manager: AssistantContextManager,
):
self._team = team
self._user = user
self._state = state
self._context_manager = context_manager
@abstractmethod
async def execute(self, *args, **kwargs) -> Any:
pass
@property
def node_name(self) -> str:
return f"max_subtool.{self.__class__.__name__}"

View File

@@ -1,21 +1,19 @@
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from typing import Any
from langchain_core.runnables import RunnableConfig
from langgraph.config import get_stream_writer
from langgraph.types import StreamWriter
from posthog.schema import AssistantMessage, AssistantToolCallMessage
from ee.hogai.utils.types.base import (
AssistantActionUnion,
AssistantDispatcherEvent,
AssistantMessageUnion,
MessageAction,
NodeStartAction,
NodePath,
UpdateAction,
)
if TYPE_CHECKING:
from ee.hogai.utils.types.composed import MaxNodeName
class AssistantDispatcher:
"""
@@ -26,13 +24,14 @@ class AssistantDispatcher:
The dispatcher does NOT update state - it just emits actions to the stream.
"""
_parent_tool_call_id: str | None = None
_node_path: tuple[NodePath, ...]
def __init__(
self,
writer: StreamWriter | Callable[[Any], None],
node_name: "MaxNodeName",
parent_tool_call_id: str | None = None,
node_path: tuple[NodePath, ...],
node_name: str,
node_run_id: str,
):
"""
Create a dispatcher for a specific node.
@@ -40,9 +39,10 @@ class AssistantDispatcher:
Args:
node_name: The name of the node dispatching actions (for attribution)
"""
self._node_name = node_name
self._writer = writer
self._parent_tool_call_id = parent_tool_call_id
self._node_path = node_path
self._node_name = node_name
self._node_run_id = node_run_id
def dispatch(self, action: AssistantActionUnion) -> None:
"""
@@ -56,7 +56,11 @@ class AssistantDispatcher:
action: Action dict with "type" and "payload" keys
"""
try:
self._writer(AssistantDispatcherEvent(action=action, node_name=self._node_name))
self._writer(
AssistantDispatcherEvent(
action=action, node_path=self._node_path, node_name=self._node_name, node_run_id=self._node_run_id
)
)
except Exception as e:
# Log error but don't crash node execution
# The dispatcher should be resilient to writer failures
@@ -69,34 +73,29 @@ class AssistantDispatcher:
"""
Dispatch a message to the stream.
"""
if self._parent_tool_call_id:
# If the dispatcher is initialized with a parent tool call id, set the parent tool call id on the message
# This is to ensure that the message is associated with the correct tool call
# Don't set parent_tool_call_id on:
# 1. AssistantToolCallMessage with the same tool_call_id (to avoid self-reference)
# 2. AssistantMessage with tool_calls containing the same ID (to avoid cycles)
should_skip = False
if isinstance(message, AssistantToolCallMessage) and self._parent_tool_call_id == message.tool_call_id:
should_skip = True
elif isinstance(message, AssistantMessage) and message.tool_calls:
# Check if any tool call has the same ID as the parent
for tool_call in message.tool_calls:
if tool_call.id == self._parent_tool_call_id:
should_skip = True
break
if not should_skip:
message.parent_tool_call_id = self._parent_tool_call_id
self.dispatch(MessageAction(message=message))
def node_start(self) -> None:
"""
Dispatch a node start action to the stream.
"""
self.dispatch(NodeStartAction())
def update(self, content: str):
"""Dispatch a transient update message to the stream that will be associated with a tool call in the UI."""
self.dispatch(UpdateAction(content=content))
def set_as_root(self) -> None:
"""
Set the dispatcher as the root.
"""
self._parent_tool_call_id = None
def create_dispatcher_from_config(config: RunnableConfig, node_path: tuple[NodePath, ...]) -> AssistantDispatcher:
"""Create a dispatcher from a RunnableConfig and node path"""
# Set writer from LangGraph context
try:
writer = get_stream_writer()
except RuntimeError:
# Not in streaming context (e.g., testing)
# Use noop writer
def noop(*_args, **_kwargs):
pass
writer = noop
metadata = config.get("metadata") or {}
node_name: str = metadata.get("langgraph_node") or ""
# `langgraph_checkpoint_ns` contains the nested path to the node, so it's more accurate for streaming.
node_run_id: str = metadata.get("langgraph_checkpoint_ns") or ""
return AssistantDispatcher(writer, node_path=node_path, node_run_id=node_run_id, node_name=node_name)

View File

@@ -38,8 +38,7 @@ from posthog.hogql_queries.query_runner import ExecutionMode
from posthog.models import Team
from posthog.taxonomy.taxonomy import CORE_FILTER_DEFINITIONS_BY_GROUP
from ee.hogai.utils.types import AssistantMessageUnion
from ee.hogai.utils.types.base import AssistantDispatcherEvent
from ee.hogai.utils.types.base import AssistantDispatcherEvent, AssistantMessageUnion
def remove_line_breaks(line: str) -> str:

View File

@@ -5,7 +5,7 @@ from structlog import get_logger
from ee.hogai.graph.deep_research.types import DeepResearchNodeName, PartialDeepResearchState
from ee.hogai.graph.taxonomy.types import TaxonomyAgentState, TaxonomyNodeName
from ee.hogai.utils.types import PartialAssistantState
from ee.hogai.utils.types.base import PartialAssistantState
from ee.hogai.utils.types.composed import AssistantMaxGraphState, AssistantMaxPartialGraphState, MaxNodeName
# A state update can have a partial state or a LangGraph's reserved dataclasses like Interrupt.
@@ -45,6 +45,7 @@ def validate_value_update(
class LangGraphState(TypedDict):
langgraph_node: MaxNodeName
langgraph_checkpoint_ns: str
GraphMessageUpdateTuple = tuple[Literal["messages"], tuple[Union[AIMessageChunk, Any], LangGraphState]]

View File

@@ -1,4 +1,4 @@
from typing import cast
from typing import Generic, Protocol, TypeVar, cast, get_args
import structlog
from langchain_core.messages import AIMessageChunk
@@ -6,7 +6,6 @@ from langchain_core.messages import AIMessageChunk
from posthog.schema import (
AssistantGenerationStatusEvent,
AssistantGenerationStatusType,
AssistantMessage,
AssistantToolCallMessage,
AssistantUpdateEvent,
FailureMessage,
@@ -16,21 +15,51 @@ from posthog.schema import (
)
from ee.hogai.utils.helpers import normalize_ai_message, should_output_assistant_message
from ee.hogai.utils.state import merge_message_chunk
from ee.hogai.utils.state import is_message_update, is_state_update, merge_message_chunk
from ee.hogai.utils.types.base import (
AssistantDispatcherEvent,
AssistantGraphName,
AssistantMessageUnion,
AssistantResultUnion,
BaseStateWithMessages,
LangGraphUpdateEvent,
MessageAction,
MessageChunkAction,
NodeEndAction,
NodePath,
NodeStartAction,
UpdateAction,
)
from ee.hogai.utils.types.composed import MaxNodeName
logger = structlog.get_logger(__name__)
class AssistantStreamProcessor:
class AssistantStreamProcessorProtocol(Protocol):
"""Protocol defining the interface for assistant stream processors."""
_streamed_update_ids: set[str]
"""Tracks the IDs of messages that have been streamed."""
def process(self, event: AssistantDispatcherEvent) -> list[AssistantResultUnion] | None:
"""Process a dispatcher event and return a result or None."""
...
def process_langgraph_update(self, event: LangGraphUpdateEvent) -> list[AssistantResultUnion] | None:
"""Process a LangGraph update event and return a list of results or None."""
...
def mark_id_as_streamed(self, message_id: str) -> None:
"""Mark a message ID as streamed."""
self._streamed_update_ids.add(message_id)
StateType = TypeVar("StateType", bound=BaseStateWithMessages)
MESSAGE_TYPE_TUPLE = get_args(AssistantMessageUnion)
class AssistantStreamProcessor(AssistantStreamProcessorProtocol, Generic[StateType]):
"""
Reduces streamed actions to client-facing messages.
@@ -38,36 +67,34 @@ class AssistantStreamProcessor:
handlers based on action type and message characteristics.
"""
_verbose_nodes: set[MaxNodeName]
"""Nodes that emit messages."""
_streaming_nodes: set[MaxNodeName]
"""Nodes that produce streaming messages."""
_visualization_nodes: dict[MaxNodeName, type]
"""Nodes that produce visualization messages."""
_tool_call_id_to_message: dict[str, AssistantMessage]
"""Maps tool call IDs to their parent messages for message chain tracking."""
_streamed_update_ids: set[str]
"""Tracks the IDs of messages that have been streamed."""
_chunks: AIMessageChunk
_chunks: dict[str, AIMessageChunk]
"""Tracks the current message chunk."""
_state: StateType | None
"""Tracks the current state."""
_state_type: type[StateType]
"""The type of the state."""
def __init__(
self,
streaming_nodes: set[MaxNodeName],
visualization_nodes: dict[MaxNodeName, type],
):
def __init__(self, verbose_nodes: set[MaxNodeName], streaming_nodes: set[MaxNodeName], state_type: type[StateType]):
"""
Initialize the stream processor with node configuration.
Args:
verbose_nodes: Nodes that produce messages
streaming_nodes: Nodes that produce streaming messages
visualization_nodes: Nodes that produce visualization messages
"""
# If a node is streaming node, it should also be verbose.
self._verbose_nodes = verbose_nodes | streaming_nodes
self._streaming_nodes = streaming_nodes
self._visualization_nodes = visualization_nodes
self._tool_call_id_to_message = {}
self._streamed_update_ids = set()
self._chunks = AIMessageChunk(content="")
self._chunks = {}
self._state_type = state_type
self._state = None
def process(self, event: AssistantDispatcherEvent) -> AssistantResultUnion | None:
def process(self, event: AssistantDispatcherEvent) -> list[AssistantResultUnion] | None:
"""
Reduce streamed actions to client messages.
@@ -75,86 +102,113 @@ class AssistantStreamProcessor:
to specialized handlers based on action type and message characteristics.
"""
action = event.action
node_name = event.node_name
if isinstance(action, MessageChunkAction):
return self._handle_message_stream(action.message, cast(MaxNodeName, node_name))
if isinstance(action, NodeStartAction):
return AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK)
self._chunks[event.node_run_id] = AIMessageChunk(content="")
return [AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK)]
if isinstance(action, NodeEndAction):
if event.node_run_id in self._chunks:
del self._chunks[event.node_run_id]
return self._handle_node_end(event, action)
if isinstance(action, MessageChunkAction) and (result := self._handle_message_stream(event, action.message)):
return [result]
if isinstance(action, MessageAction):
message = action.message
if result := self._handle_message(event, message):
return [result]
# Register any tool calls for later parent chain lookups
self._register_tool_calls(message)
result = self._handle_message(message, cast(MaxNodeName, node_name))
return (
result if result is not None else AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK)
if isinstance(action, UpdateAction) and (update_event := self._handle_update_message(event, action)):
return [update_event]
return None
def process_langgraph_update(self, event: LangGraphUpdateEvent) -> list[AssistantResultUnion] | None:
"""
Process a LangGraph update event.
"""
if is_message_update(event.update):
# Convert the message chunk update to a dispatcher event to prepare for a bright future without LangGraph
maybe_message_chunk, state = event.update[1]
if not isinstance(maybe_message_chunk, AIMessageChunk):
return None
action = AssistantDispatcherEvent(
action=MessageChunkAction(message=maybe_message_chunk),
node_name=state["langgraph_node"],
node_run_id=state["langgraph_checkpoint_ns"],
)
return self.process(action)
def _find_parent_ids(self, message: AssistantMessage) -> tuple[str | None, str | None]:
"""
Walk up the message chain to find the root parent's message_id and tool_call_id.
if is_state_update(event.update):
new_state = self._state_type.model_validate(event.update[1])
self._state = new_state
Returns (root_message_id, root_tool_call_id) for the root message in the chain.
Includes cycle detection and max depth protection.
"""
root_tool_call_id = message.parent_tool_call_id
if root_tool_call_id is None:
return message.id, None
return None
root_message_id = None
visited: set[str] = set()
def _handle_message(
self, action: AssistantDispatcherEvent, message: AssistantMessageUnion
) -> AssistantResultUnion | None:
"""Handle a message from a node."""
node_name = cast(MaxNodeName, action.node_name)
produced_message: AssistantResultUnion | None = None
while root_tool_call_id is not None:
if root_tool_call_id in visited:
# Cycle detected, we skip this message
return None, None
# Output all messages from the top-level graph.
if not self._is_message_from_nested_node_or_graph(action.node_path or ()):
produced_message = self._handle_root_message(message, node_name)
# Other message types with parents (viz, notebook, failure, tool call)
else:
produced_message = self._handle_special_child_message(message, node_name)
visited.add(root_tool_call_id)
parent_message = self._tool_call_id_to_message.get(root_tool_call_id)
if parent_message is None:
# The parent message is not registered, we skip this message as it could come
# from a sub-nested graph invoked directly by a contextual tool.
return None, None
# Messages with existing IDs must be deduplicated.
# Messages WITHOUT IDs must be streamed because they're progressive.
if isinstance(produced_message, MESSAGE_TYPE_TUPLE) and produced_message.id is not None:
if produced_message.id in self._streamed_update_ids:
return None
self._streamed_update_ids.add(produced_message.id)
next_parent_tool_call_id = parent_message.parent_tool_call_id
root_message_id = parent_message.id
if next_parent_tool_call_id is None:
return root_message_id, root_tool_call_id
root_tool_call_id = next_parent_tool_call_id
raise ValueError("Should not reach here")
return produced_message
def _register_tool_calls(self, message: AssistantMessageUnion) -> None:
"""Register any tool calls in the message for later lookup."""
if isinstance(message, AssistantMessage) and message.tool_calls is not None:
for tool_call in message.tool_calls:
self._tool_call_id_to_message[tool_call.id] = message
def _is_message_from_nested_node_or_graph(self, node_path: tuple[NodePath, ...]) -> bool:
"""Check if the message is from a nested node or graph."""
if not node_path:
return False
# The first path is always the top-level graph.
# The second path is always the top-level node.
# If the path is longer than 2, it's a MaxTool or nested graphs.
if len(node_path) > 2 and next((path for path in node_path[1:] if path.name in AssistantGraphName), None):
return True
return False
def _handle_root_message(
self, message: AssistantMessageUnion, node_name: MaxNodeName
) -> AssistantMessageUnion | None:
"""Handle messages with no parent (root messages)."""
if not should_output_assistant_message(message):
if node_name not in self._verbose_nodes or not should_output_assistant_message(message):
return None
return message
def _handle_assistant_message_with_parent(self, message: AssistantMessage) -> AssistantUpdateEvent | None:
def _handle_update_message(
self, event: AssistantDispatcherEvent, action: UpdateAction
) -> AssistantUpdateEvent | None:
"""Handle AssistantMessage that has a parent, creating an AssistantUpdateEvent."""
parent_id, parent_tool_call_id = self._find_parent_ids(message)
if parent_tool_call_id is None or parent_id is None:
if not event.node_path or not action.content:
return None
if message.content == "":
# Find the closest tool call id to the update.
parent_path = next((path for path in reversed(event.node_path) if path.tool_call_id), None)
# Updates from the top-level graph nodes are not supported.
if not parent_path:
return None
return AssistantUpdateEvent(
id=parent_id,
tool_call_id=parent_tool_call_id,
content=message.content,
)
tool_call_id = parent_path.tool_call_id
message_id = parent_path.message_id
if not message_id or not tool_call_id:
return None
return AssistantUpdateEvent(id=message_id, tool_call_id=tool_call_id, content=action.content)
def _handle_special_child_message(
self, message: AssistantMessageUnion, node_name: MaxNodeName
@@ -164,14 +218,10 @@ class AssistantStreamProcessor:
These messages are returned as-is regardless of where in the nesting hierarchy they are.
"""
# Return visualization messages only if from visualization nodes
if isinstance(message, VisualizationMessage | MultiVisualizationMessage):
if node_name in self._visualization_nodes:
return message
return None
# These message types are always returned as-is
if isinstance(message, NotebookUpdateMessage | FailureMessage):
if isinstance(message, VisualizationMessage | MultiVisualizationMessage) or isinstance(
message, NotebookUpdateMessage | FailureMessage
):
return message
if isinstance(message, AssistantToolCallMessage):
@@ -181,37 +231,44 @@ class AssistantStreamProcessor:
# Should not reach here
raise ValueError(f"Unhandled special message type: {type(message).__name__}")
def _handle_message(self, message: AssistantMessageUnion, node_name: MaxNodeName) -> AssistantResultUnion | None:
# Messages with existing IDs must be deduplicated.
# Messages WITHOUT IDs must be streamed because they're progressive.
if hasattr(message, "id") and message.id is not None:
if message.id in self._streamed_update_ids:
return None
self._streamed_update_ids.add(message.id)
# Root messages (no parent) are filtered by VERBOSE_NODES
if message.parent_tool_call_id is None:
return self._handle_root_message(message, node_name)
# AssistantMessage with parent creates AssistantUpdateEvent
if isinstance(message, AssistantMessage):
return self._handle_assistant_message_with_parent(message)
else:
# Other message types with parents (viz, notebook, failure, tool call)
return self._handle_special_child_message(message, node_name)
def _handle_message_stream(self, message: AIMessageChunk, node_name: MaxNodeName) -> AssistantResultUnion | None:
def _handle_message_stream(
self, event: AssistantDispatcherEvent, message: AIMessageChunk
) -> AssistantResultUnion | None:
"""
Process LLM chunks from "messages" stream mode.
With dispatch pattern, complete messages are dispatched by nodes.
This handles AIMessageChunk for ephemeral streaming (responsiveness).
"""
node_name = cast(MaxNodeName, event.node_name)
run_id = event.node_run_id
if node_name not in self._streaming_nodes:
return None
if run_id not in self._chunks:
self._chunks[run_id] = AIMessageChunk(content="")
# Merge message chunks
self._chunks = merge_message_chunk(self._chunks, message)
self._chunks[run_id] = merge_message_chunk(self._chunks[run_id], message)
# Stream ephemeral message (no ID = not persisted)
return normalize_ai_message(self._chunks)
return normalize_ai_message(self._chunks[run_id])
def _handle_node_end(
self, event: AssistantDispatcherEvent, action: NodeEndAction
) -> list[AssistantResultUnion] | None:
"""Handle the end of a node. Reset the streaming chunks."""
if not isinstance(action.state, BaseStateWithMessages):
return None
results: list[AssistantResultUnion] = []
for message in action.state.messages:
if new_event := self.process(
AssistantDispatcherEvent(
action=MessageAction(message=message),
node_name=event.node_name,
node_run_id=event.node_run_id,
node_path=event.node_path,
)
):
results.extend(new_event)
return results

View File

@@ -1,165 +1,426 @@
"""
Comprehensive tests for AssistantDispatcher.
Tests the dispatcher logic that emits actions to LangGraph custom stream.
"""
from posthog.test.base import BaseTest
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import MagicMock, patch
from langchain_core.runnables import RunnableConfig
from posthog.schema import AssistantMessage, AssistantToolCall
from ee.hogai.graph.base import BaseAssistantNode
from ee.hogai.utils.dispatcher import AssistantDispatcher
from ee.hogai.utils.types import AssistantState, PartialAssistantState
from ee.hogai.utils.types.base import AssistantDispatcherEvent, AssistantNodeName
from ee.hogai.utils.dispatcher import AssistantDispatcher, create_dispatcher_from_config
from ee.hogai.utils.types.base import (
AssistantDispatcherEvent,
AssistantGraphName,
AssistantNodeName,
MessageAction,
NodePath,
UpdateAction,
)
class TestMessageDispatcher(BaseTest):
class TestAssistantDispatcher(BaseTest):
"""Test the AssistantDispatcher in isolation."""
def setUp(self):
self._writer = MagicMock()
self._writer.__aiter__ = AsyncMock(return_value=iter([]))
self._writer.write = AsyncMock()
super().setUp()
self.dispatched_events: list[AssistantDispatcherEvent] = []
def test_create_dispatcher(self):
"""Test creating a MessageDispatcher"""
dispatcher = AssistantDispatcher(self._writer, node_name=AssistantNodeName.ROOT)
def mock_writer(event):
self.dispatched_events.append(event)
self.writer = mock_writer
self.node_path = (
NodePath(name=AssistantGraphName.ASSISTANT),
NodePath(name=AssistantNodeName.ROOT),
)
def test_create_dispatcher_with_basic_params(self):
"""Test creating a dispatcher with required parameters."""
dispatcher = AssistantDispatcher(
writer=self.writer,
node_path=self.node_path,
node_name=AssistantNodeName.ROOT,
node_run_id="test_run_123",
)
self.assertEqual(dispatcher._node_name, AssistantNodeName.ROOT)
self.assertEqual(dispatcher._writer, self._writer)
self.assertEqual(dispatcher._node_run_id, "test_run_123")
self.assertEqual(dispatcher._node_path, self.node_path)
self.assertEqual(dispatcher._writer, self.writer)
async def test_dispatch_with_writer(self):
"""Test dispatching with a writer"""
dispatched_actions = []
def test_dispatch_message_action(self):
"""Test dispatching a message via the message() method."""
dispatcher = AssistantDispatcher(
writer=self.writer,
node_path=self.node_path,
node_name=AssistantNodeName.ROOT,
node_run_id="test_run_456",
)
def mock_writer(data):
dispatched_actions.append(data)
dispatcher = AssistantDispatcher(mock_writer, node_name=AssistantNodeName.ROOT)
message = AssistantMessage(content="test message", parent_tool_call_id="tc1")
message = AssistantMessage(content="Test message")
dispatcher.message(message)
self.assertEqual(len(dispatched_actions), 1)
self.assertEqual(len(self.dispatched_events), 1)
event = self.dispatched_events[0]
# Verify action structure
event = dispatched_actions[0]
self.assertIsInstance(event, AssistantDispatcherEvent)
self.assertEqual(event.action.type, "MESSAGE")
self.assertIsInstance(event.action, MessageAction)
assert isinstance(event.action, MessageAction)
self.assertEqual(event.action.message, message)
self.assertEqual(event.node_name, AssistantNodeName.ROOT)
self.assertEqual(event.node_run_id, "test_run_456")
self.assertEqual(event.node_path, self.node_path)
def test_dispatch_update_action(self):
    """Test dispatching an update via the update() method."""
    dispatcher = AssistantDispatcher(
        writer=self.writer,
        node_path=self.node_path,
        node_name=AssistantNodeName.TRENDS_GENERATOR,
        node_run_id="test_run_789",
    )
    dispatcher.update("Processing query...")
    # Exactly one event should reach the writer.
    self.assertEqual(len(self.dispatched_events), 1)
    event = self.dispatched_events[0]
    self.assertIsInstance(event, AssistantDispatcherEvent)
    self.assertIsInstance(event.action, UpdateAction)
    # Plain `assert` narrows the type for the attribute accesses below.
    assert isinstance(event.action, UpdateAction)
    self.assertEqual(event.action.content, "Processing query...")
    # Node metadata is carried on the dispatched event itself.
    self.assertEqual(event.node_name, AssistantNodeName.TRENDS_GENERATOR)
    self.assertEqual(event.node_run_id, "test_run_789")
def test_dispatch_multiple_messages(self):
    """Test dispatching multiple messages in sequence."""
    dispatcher = AssistantDispatcher(
        writer=self.writer,
        node_path=self.node_path,
        node_name=AssistantNodeName.ROOT,
        node_run_id="test_run_multi",
    )
    message1 = AssistantMessage(content="First message")
    message2 = AssistantMessage(content="Second message")
    message3 = AssistantMessage(content="Third message")
    dispatcher.message(message1)
    dispatcher.message(message2)
    dispatcher.message(message3)
    self.assertEqual(len(self.dispatched_events), 3)
    # Collect the contents in dispatch order to verify ordering is preserved.
    contents = []
    for event in self.dispatched_events:
        if isinstance(event.action, MessageAction):
            msg = event.action.message
            if isinstance(msg, AssistantMessage):
                contents.append(msg.content)
    self.assertEqual(contents, ["First message", "Second message", "Third message"])
def test_dispatch_message_with_tool_calls(self):
    """Test dispatching a message with tool calls preserves all data."""
    dispatcher = AssistantDispatcher(
        writer=self.writer,
        node_path=self.node_path,
        node_name=AssistantNodeName.ROOT,
        node_run_id="test_run_tools",
    )
    tool_call = AssistantToolCall(id="tool_123", name="search", args={"query": "test query"})
    message = AssistantMessage(content="Running search...", tool_calls=[tool_call])
    dispatcher.message(message)
    self.assertEqual(len(self.dispatched_events), 1)
    event = self.dispatched_events[0]
    self.assertIsInstance(event.action, MessageAction)
    # Plain `assert` narrows the type for the attribute accesses below.
    assert isinstance(event.action, MessageAction)
    dispatched_message = event.action.message
    assert isinstance(dispatched_message, AssistantMessage)
    self.assertEqual(dispatched_message.content, "Running search...")
    self.assertIsNotNone(dispatched_message.tool_calls)
    assert dispatched_message.tool_calls is not None
    # Tool call id, name, and args must all survive the dispatch unchanged.
    self.assertEqual(len(dispatched_message.tool_calls), 1)
    self.assertEqual(dispatched_message.tool_calls[0].id, "tool_123")
    self.assertEqual(dispatched_message.tool_calls[0].name, "search")
    self.assertEqual(dispatched_message.tool_calls[0].args, {"query": "test query"})
def test_dispatch_with_nested_node_path(self):
    """Test that nested node paths are preserved in dispatched events."""
    # Four-level path: top graph -> node (with ids) -> nested graph -> nested node.
    nested_path = (
        NodePath(name=AssistantGraphName.ASSISTANT),
        NodePath(name=AssistantNodeName.ROOT, message_id="msg_123", tool_call_id="tc_123"),
        NodePath(name=AssistantGraphName.INSIGHTS),
        NodePath(name=AssistantNodeName.TRENDS_GENERATOR),
    )
    dispatcher = AssistantDispatcher(
        writer=self.writer, node_path=nested_path, node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id="run_1"
    )
    message = AssistantMessage(content="Nested message")
    dispatcher.message(message)
    self.assertEqual(len(self.dispatched_events), 1)
    event = self.dispatched_events[0]
    self.assertEqual(event.node_path, nested_path)
    assert event.node_path is not None
    self.assertEqual(len(event.node_path), 4)
    # The intermediate path entry keeps its message/tool-call identifiers.
    self.assertEqual(event.node_path[1].message_id, "msg_123")
    self.assertEqual(event.node_path[1].tool_call_id, "tc_123")
def test_dispatch_error_handling_continues_execution(self):
    """Test that dispatch errors are caught and logged but don't crash."""

    def failing_writer(event):
        raise RuntimeError("Writer failed!")

    dispatcher = AssistantDispatcher(
        writer=failing_writer,
        node_path=self.node_path,
        node_name=AssistantNodeName.ROOT,
        node_run_id="test_run_error",
    )
    message = AssistantMessage(content="This should not crash")
    # Should not raise exception
    with patch("logging.getLogger") as mock_get_logger:
        mock_logger = MagicMock()
        mock_get_logger.return_value = mock_logger
        dispatcher.message(message)
    # Verify error was logged
    mock_logger.error.assert_called_once()
    args, kwargs = mock_logger.error.call_args
    self.assertIn("Failed to dispatch action", args[0])
def test_dispatch_mixed_actions(self):
    """Test dispatching both messages and updates in sequence."""
    dispatcher = AssistantDispatcher(
        writer=self.writer,
        node_path=self.node_path,
        node_name=AssistantNodeName.TRENDS_GENERATOR,
        node_run_id="test_run_mixed",
    )
    # Interleave updates and a message, as a node would during execution.
    dispatcher.update("Starting analysis...")
    dispatcher.message(AssistantMessage(content="Found 3 insights"))
    dispatcher.update("Finalizing results...")
    self.assertEqual(len(self.dispatched_events), 3)
    # Events must arrive in dispatch order with their original action types.
    self.assertIsInstance(self.dispatched_events[0].action, UpdateAction)
    assert isinstance(self.dispatched_events[0].action, UpdateAction)
    self.assertEqual(self.dispatched_events[0].action.content, "Starting analysis...")
    self.assertIsInstance(self.dispatched_events[1].action, MessageAction)
    assert isinstance(self.dispatched_events[1].action, MessageAction)
    assert isinstance(self.dispatched_events[1].action.message, AssistantMessage)
    self.assertEqual(self.dispatched_events[1].action.message.content, "Found 3 insights")
    self.assertIsInstance(self.dispatched_events[2].action, UpdateAction)
    assert isinstance(self.dispatched_events[2].action, UpdateAction)
    self.assertEqual(self.dispatched_events[2].action.content, "Finalizing results...")
def test_dispatch_preserves_message_id(self):
    """Test that message IDs are preserved through dispatch."""
    dispatcher = AssistantDispatcher(
        writer=self.writer,
        node_path=self.node_path,
        node_name=AssistantNodeName.ROOT,
        node_run_id="test_run_id_preservation",
    )
    message = AssistantMessage(id="msg_xyz_789", content="Message with ID")
    dispatcher.message(message)
    event = self.dispatched_events[0]
    # Plain `assert` narrows the type for the attribute access below.
    assert isinstance(event.action, MessageAction)
    self.assertEqual(event.action.message.id, "msg_xyz_789")
def test_dispatch_with_empty_node_path(self):
    """Test dispatcher with an empty node path."""
    dispatcher = AssistantDispatcher(
        writer=self.writer, node_path=(), node_name=AssistantNodeName.ROOT, node_run_id="test_run_empty_path"
    )
    message = AssistantMessage(content="Root level message")
    dispatcher.message(message)
    self.assertEqual(len(self.dispatched_events), 1)
    event = self.dispatched_events[0]
    # An empty path should be passed through unchanged, not replaced.
    self.assertEqual(event.node_path, ())
class MockNode(BaseAssistantNode[AssistantState, PartialAssistantState]):
"""Mock node for testing"""
class TestCreateDispatcherFromConfig(BaseTest):
"""Test the create_dispatcher_from_config helper function."""
@property
def node_name(self):
return AssistantNodeName.ROOT
def test_create_dispatcher_from_config_with_stream_writer(self):
"""Test creating dispatcher from config with LangGraph stream writer."""
node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT))
async def arun(self, state, config):
# Use dispatch to add a message
self.dispatcher.message(AssistantMessage(content="Test message from node"))
return PartialAssistantState()
config = RunnableConfig(
metadata={"langgraph_node": AssistantNodeName.TRENDS_GENERATOR, "langgraph_checkpoint_ns": "checkpoint_abc"}
)
mock_writer = MagicMock()
with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=mock_writer):
dispatcher = create_dispatcher_from_config(config, node_path)
self.assertEqual(dispatcher._node_name, AssistantNodeName.TRENDS_GENERATOR)
self.assertEqual(dispatcher._node_run_id, "checkpoint_abc")
self.assertEqual(dispatcher._node_path, node_path)
self.assertEqual(dispatcher._writer, mock_writer)
def test_create_dispatcher_from_config_without_stream_writer(self):
    """Test creating dispatcher when not in streaming context (e.g., testing)."""
    node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT))
    config = RunnableConfig(
        metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "checkpoint_xyz"}
    )
    # Simulate RuntimeError when get_stream_writer is called outside streaming context
    with patch("ee.hogai.utils.dispatcher.get_stream_writer", side_effect=RuntimeError("Not in streaming context")):
        dispatcher = create_dispatcher_from_config(config, node_path)
    # Should create a noop writer
    self.assertEqual(dispatcher._node_name, AssistantNodeName.ROOT)
    self.assertEqual(dispatcher._node_run_id, "checkpoint_xyz")
    self.assertEqual(dispatcher._node_path, node_path)
    # Verify the noop writer doesn't raise exceptions
    message = AssistantMessage(content="Test")
    dispatcher.message(message)  # Should not crash
def test_create_dispatcher_with_missing_metadata(self):
    """Test creating dispatcher when metadata fields are missing."""
    node_path = (NodePath(name=AssistantGraphName.ASSISTANT),)
    config = RunnableConfig(metadata={})
    with patch("ee.hogai.utils.dispatcher.get_stream_writer", side_effect=RuntimeError("Not in streaming context")):
        dispatcher = create_dispatcher_from_config(config, node_path)
    # Should use empty strings as defaults
    self.assertEqual(dispatcher._node_name, "")
    self.assertEqual(dispatcher._node_run_id, "")
    self.assertEqual(dispatcher._node_path, node_path)
def test_create_dispatcher_preserves_node_path(self):
    """Test that node path is correctly passed through to the dispatcher."""
    # Nested path spanning two graphs, with ids on the intermediate node entry.
    nested_path = (
        NodePath(name=AssistantGraphName.ASSISTANT),
        NodePath(name=AssistantNodeName.ROOT, message_id="msg_1", tool_call_id="tc_1"),
        NodePath(name=AssistantGraphName.INSIGHTS),
    )
    config = RunnableConfig(
        metadata={"langgraph_node": AssistantNodeName.TRENDS_GENERATOR, "langgraph_checkpoint_ns": "cp_1"}
    )
    with patch("ee.hogai.utils.dispatcher.get_stream_writer", side_effect=RuntimeError("Not in streaming context")):
        dispatcher = create_dispatcher_from_config(config, nested_path)
    self.assertEqual(dispatcher._node_path, nested_path)
    self.assertEqual(len(dispatcher._node_path), 3)
class TestDispatcherIntegration(BaseTest):
async def test_node_dispatch_flow(self):
"""Test that a node can dispatch messages"""
# Use sync test helpers since BaseTest provides them
team, user = self.team, self.user
"""Integration tests for dispatcher usage patterns."""
node = MockNode(team=team, user=user)
def test_dispatcher_in_node_context(self):
"""Test typical usage pattern within a node."""
dispatched_events = []
# Track dispatched actions
dispatched_actions = []
def mock_writer(event):
dispatched_events.append(event)
def mock_writer(data):
dispatched_actions.append(data)
# Set up node's dispatcher
node._dispatcher = AssistantDispatcher(mock_writer, node_name=AssistantNodeName.ROOT)
# Run the node
state = AssistantState(messages=[])
config = RunnableConfig(configurable={})
await node.arun(state, config)
# Verify action was dispatched
self.assertEqual(len(dispatched_actions), 1)
event = dispatched_actions[0]
self.assertIsInstance(event, AssistantDispatcherEvent)
self.assertEqual(event.action.type, "MESSAGE")
self.assertEqual(event.action.message.content, "Test message from node")
self.assertEqual(event.action.message.parent_tool_call_id, None)
self.assertEqual(event.node_name, AssistantNodeName.ROOT)
async def test_action_preservation_through_stream(self):
"""Test that action data is preserved through the stream"""
captured_updates = []
def mock_writer(data):
captured_updates.append(data)
dispatcher = AssistantDispatcher(mock_writer, node_name=AssistantNodeName.ROOT)
# Create complex message with metadata
message = AssistantMessage(
content="Complex message",
parent_tool_call_id="tc123",
tool_calls=[AssistantToolCall(id="tool1", name="search", args={"query": "test"})],
node_path = (
NodePath(name=AssistantGraphName.ASSISTANT),
NodePath(name=AssistantNodeName.ROOT),
)
dispatcher.message(message)
dispatcher = AssistantDispatcher(
writer=mock_writer, node_path=node_path, node_name=AssistantNodeName.ROOT, node_run_id="integration_run_1"
)
# Extract action
event = captured_updates[0]
self.assertIsInstance(event, AssistantDispatcherEvent)
# Simulate node execution pattern
dispatcher.update("Starting node execution...")
# Verify all message fields preserved
payload = event.action.message
self.assertEqual(payload.content, "Complex message")
self.assertEqual(payload.parent_tool_call_id, "tc123")
self.assertIsNotNone(payload.tool_calls)
self.assertEqual(len(payload.tool_calls), 1)
self.assertEqual(payload.tool_calls[0].name, "search")
tool_call = AssistantToolCall(id="tc_int_1", name="generate_insight", args={"type": "trends"})
dispatcher.message(AssistantMessage(content="Generating insight", tool_calls=[tool_call]))
# NOTE(review): this region was mangled during extraction — lines from at least
# three test methods were interleaved and indentation was lost. The two methods
# below are reconstructed from their scattered lines (order recovered from the
# assertions each method makes). Fragments of a third test, whose `def` header
# and setup lines are not visible in this chunk, are preserved as comments at
# the end — TODO: restore them from the original file rather than guessing a
# signature.

async def test_multiple_dispatches_from_node(self):
    """Test that a node can dispatch multiple messages"""
    # Use sync test helpers since BaseTest provides them
    team, user = self.team, self.user

    class MultiDispatchNode(BaseAssistantNode[AssistantState, PartialAssistantState]):
        @property
        def node_name(self):
            return AssistantNodeName.ROOT

        async def arun(self, state, config):
            # Dispatch multiple messages
            self.dispatcher.message(AssistantMessage(content="First message"))
            self.dispatcher.message(AssistantMessage(content="Second message"))
            self.dispatcher.message(AssistantMessage(content="Third message"))
            return PartialAssistantState()

    node = MultiDispatchNode(team=team, user=user)

    # Capture everything the node's dispatcher writes.
    dispatched_actions = []

    def mock_writer(data):
        dispatched_actions.append(data)

    node._dispatcher = AssistantDispatcher(mock_writer, node_name=AssistantNodeName.ROOT)

    state = AssistantState(messages=[])
    config = RunnableConfig(configurable={})

    await node.arun(state, config)

    # Verify all three dispatches
    self.assertEqual(len(dispatched_actions), 3)

    contents = []
    for event in dispatched_actions:
        self.assertIsInstance(event, AssistantDispatcherEvent)
        contents.append(event.action.message.content)

    self.assertEqual(contents, ["First message", "Second message", "Third message"])

def test_concurrent_dispatchers(self):
    """Test multiple dispatchers can coexist without interference."""
    dispatched_events_1 = []
    dispatched_events_2 = []

    def writer_1(event):
        dispatched_events_1.append(event)

    def writer_2(event):
        dispatched_events_2.append(event)

    path_1 = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT))
    path_2 = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.TRENDS_GENERATOR))

    dispatcher_1 = AssistantDispatcher(
        writer=writer_1, node_path=path_1, node_name=AssistantNodeName.ROOT, node_run_id="run_1"
    )
    dispatcher_2 = AssistantDispatcher(
        writer=writer_2, node_path=path_2, node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id="run_2"
    )

    # Dispatch from both
    dispatcher_1.message(AssistantMessage(content="From dispatcher 1"))
    dispatcher_2.message(AssistantMessage(content="From dispatcher 2"))
    dispatcher_1.update("Update from dispatcher 1")
    dispatcher_2.update("Update from dispatcher 2")

    # Verify each dispatcher wrote to its own writer
    self.assertEqual(len(dispatched_events_1), 2)
    self.assertEqual(len(dispatched_events_2), 2)

    # Verify events went to correct writers
    assert isinstance(dispatched_events_1[0].action, MessageAction)
    assert isinstance(dispatched_events_1[0].action.message, AssistantMessage)
    self.assertEqual(dispatched_events_1[0].action.message.content, "From dispatcher 1")
    assert isinstance(dispatched_events_2[0].action, MessageAction)
    assert isinstance(dispatched_events_2[0].action.message, AssistantMessage)
    self.assertEqual(dispatched_events_2[0].action.message.content, "From dispatcher 2")

    # Verify node names are correct
    self.assertEqual(dispatched_events_1[0].node_name, AssistantNodeName.ROOT)
    self.assertEqual(dispatched_events_2[0].node_name, AssistantNodeName.TRENDS_GENERATOR)

# Displaced fragments of a third test that were interleaved into this region.
# Its `def` header and setup (creating `dispatcher`, `dispatched_events`, and
# `node_path`) are not visible in this chunk, so it cannot be reconstructed
# here without guessing — presumably an integration test asserting an
# update/message/update/message sequence (node_run_id "integration_run_1").
# TODO(review): restore from the original file.
#
#     dispatcher.update("Processing data...")
#     dispatcher.message(AssistantMessage(content="Insight generated successfully"))
#     # Verify all events were dispatched
#     self.assertEqual(len(dispatched_events), 4)
#     # Verify event types and order
#     self.assertIsInstance(dispatched_events[0].action, UpdateAction)
#     self.assertIsInstance(dispatched_events[1].action, MessageAction)
#     self.assertIsInstance(dispatched_events[2].action, UpdateAction)
#     self.assertIsInstance(dispatched_events[3].action, MessageAction)
#     # Verify all events have consistent metadata
#     for event in dispatched_events:
#         self.assertEqual(event.node_name, AssistantNodeName.ROOT)
#         self.assertEqual(event.node_run_id, "integration_run_1")
#         self.assertEqual(event.node_path, node_path)

Some files were not shown because too many files have changed in this diff Show More