diff --git a/ee/hogai/README.md b/ee/hogai/README.md index eed6ff5005..3e57b5ea85 100644 --- a/ee/hogai/README.md +++ b/ee/hogai/README.md @@ -320,7 +320,7 @@ class YourTaxonomyGraph(TaxonomyAgent[TaxonomyAgentState, TaxonomyAgentState[Max 4. Invoke it (typically from a `MaxTool`), mirroring `products/replay/backend/max_tools.py`: ```python -graph = YourTaxonomyGraph(team=self._team, user=self._user, tool_call_id=self._tool_call_id) +graph = YourTaxonomyGraph(team=self._team, user=self._user) graph_context = { "change": "Show me recordings of users in Germany that used a mobile device while performing a payment", diff --git a/ee/hogai/assistant/base.py b/ee/hogai/assistant/base.py index a7342ca49f..7fe0e2d505 100644 --- a/ee/hogai/assistant/base.py +++ b/ee/hogai/assistant/base.py @@ -9,7 +9,6 @@ import structlog import posthoganalytics from asgiref.sync import async_to_sync from langchain_core.callbacks.base import BaseCallbackHandler -from langchain_core.messages import AIMessageChunk from langchain_core.runnables.config import RunnableConfig from langgraph.errors import GraphRecursionError from langgraph.graph.state import CompiledStateGraph @@ -31,20 +30,19 @@ from posthog.event_usage import report_user_action from posthog.models import Team, User from posthog.sync import database_sync_to_async -from ee.hogai.graph.base import BaseAssistantNode from ee.hogai.utils.exceptions import GenerationCanceled from ee.hogai.utils.helpers import extract_stream_update -from ee.hogai.utils.state import ( - GraphValueUpdateTuple, - is_message_update, - is_state_update, - is_value_update, - validate_state_update, +from ee.hogai.utils.state import validate_state_update +from ee.hogai.utils.stream_processor import AssistantStreamProcessorProtocol +from ee.hogai.utils.types.base import ( + AssistantDispatcherEvent, + AssistantMessageUnion, + AssistantMode, + AssistantOutput, + AssistantResultUnion, + LangGraphUpdateEvent, ) -from ee.hogai.utils.stream_processor import AssistantStreamProcessor -from ee.hogai.utils.types import AssistantMessageUnion, AssistantOutput -from ee.hogai.utils.types.base import AssistantDispatcherEvent, AssistantMode, AssistantResultUnion, MessageChunkAction -from ee.hogai.utils.types.composed import AssistantMaxGraphState, AssistantMaxPartialGraphState, MaxNodeName +from ee.hogai.utils.types.composed import AssistantMaxGraphState, AssistantMaxPartialGraphState from ee.models import Conversation logger = structlog.get_logger(__name__) @@ -66,7 +64,7 @@ class BaseAssistant(ABC): _trace_id: Optional[str | UUID] _billing_context: Optional[MaxBillingContext] _initial_state: Optional[AssistantMaxGraphState | AssistantMaxPartialGraphState] - _stream_processor: AssistantStreamProcessor + _stream_processor: AssistantStreamProcessorProtocol """The stream processor that processes dispatcher actions and message chunks.""" def __init__( @@ -87,6 +85,7 @@ class BaseAssistant(ABC): billing_context: Optional[MaxBillingContext] = None, initial_state: Optional[AssistantMaxGraphState | AssistantMaxPartialGraphState] = None, callback_handler: Optional[BaseCallbackHandler] = None, + stream_processor: AssistantStreamProcessorProtocol, ): self._team = team self._contextual_tools = contextual_tools or {} @@ -121,22 +120,7 @@ class BaseAssistant(ABC): self._mode = mode self._initial_state = initial_state # Initialize the stream processor with node configuration - self._stream_processor = AssistantStreamProcessor( - streaming_nodes=self.STREAMING_NODES, - visualization_nodes=self.VISUALIZATION_NODES, - 
) - - @property - @abstractmethod - def VISUALIZATION_NODES(self) -> dict[MaxNodeName, type[BaseAssistantNode]]: - """Nodes that can generate visualizations.""" - pass - - @property - @abstractmethod - def STREAMING_NODES(self) -> set[MaxNodeName]: - """Nodes that can stream messages to the client.""" - pass + self._stream_processor = stream_processor @abstractmethod def get_initial_state(self) -> AssistantMaxGraphState: @@ -175,7 +159,7 @@ class BaseAssistant(ABC): state = await self._init_or_update_state() config = self._get_config() - stream_mode: list[StreamMode] = ["values", "updates", "custom"] + stream_mode: list[StreamMode] = ["values", "custom"] if stream_message_chunks: stream_mode.append("messages") @@ -286,11 +270,11 @@ class BaseAssistant(ABC): # Add existing ids to streamed messages, so we don't send the messages again. for message in saved_state.messages: if message.id is not None: - self._stream_processor._streamed_update_ids.add(message.id) + self._stream_processor.mark_id_as_streamed(message.id) # Add the latest message id to streamed messages, so we don't send it multiple times. if self._latest_message and self._latest_message.id is not None: - self._stream_processor._streamed_update_ids.add(self._latest_message.id) + self._stream_processor.mark_id_as_streamed(self._latest_message.id) # If the graph previously hasn't reset the state, it is an interrupt. We resume from the point of interruption. if snapshot.next and self._latest_message and saved_state.graph_status == "interrupted": @@ -322,29 +306,12 @@ class BaseAssistant(ABC): async def _process_update(self, update: Any) -> list[AssistantResultUnion] | None: update = extract_stream_update(update) - new_message: AssistantResultUnion | None = None if not isinstance(update, AssistantDispatcherEvent): - if is_state_update(update): - _, new_state = update - self._state = validate_state_update(new_state, self._state_type) - elif is_value_update(update) and (new_message := await self._aprocess_value_update(update)): - return [new_message] + if updates := self._stream_processor.process_langgraph_update(LangGraphUpdateEvent(update=update)): + return updates + elif new_message := self._stream_processor.process(update): + return new_message - if is_message_update(update): - # Convert the message chunk update to a dispatcher event to prepare for a bright future without LangGraph - message, state = update[1] - if not isinstance(message, AIMessageChunk): - return None - update = AssistantDispatcherEvent( - action=MessageChunkAction(message=message), node_name=state["langgraph_node"] - ) - - if isinstance(update, AssistantDispatcherEvent) and (new_message := self._stream_processor.process(update)): - return [new_message] if new_message else None - - return None - - async def _aprocess_value_update(self, update: GraphValueUpdateTuple) -> AssistantResultUnion | None: return None def _build_root_config_for_persistence(self) -> RunnableConfig: diff --git a/ee/hogai/assistant/deep_research_assistant.py b/ee/hogai/assistant/deep_research_assistant.py index 56cdc4ad2d..4036af6430 100644 --- a/ee/hogai/assistant/deep_research_assistant.py +++ b/ee/hogai/assistant/deep_research_assistant.py @@ -1,5 +1,5 @@ from collections.abc import AsyncGenerator -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from uuid import UUID from posthog.schema import AssistantMessage, HumanMessage, MaxBillingContext, VisualizationMessage @@ -7,13 +7,26 @@ from posthog.schema import AssistantMessage, HumanMessage, MaxBillingContext, 
VisualizationMessage

 from posthog.models import Team, User

 from ee.hogai.assistant.base import BaseAssistant
-from ee.hogai.graph import DeepResearchAssistantGraph
-from ee.hogai.graph.base import BaseAssistantNode
+from ee.hogai.graph.deep_research.graph import DeepResearchAssistantGraph
 from ee.hogai.graph.deep_research.types import DeepResearchNodeName, DeepResearchState, PartialDeepResearchState
+from ee.hogai.utils.stream_processor import AssistantStreamProcessor
 from ee.hogai.utils.types import AssistantMode, AssistantOutput
-from ee.hogai.utils.types.composed import MaxNodeName
 from ee.models import Conversation

+if TYPE_CHECKING:
+    from ee.hogai.utils.types.composed import MaxNodeName
+
+
+STREAMING_NODES: set["MaxNodeName"] = {
+    DeepResearchNodeName.ONBOARDING,
+    DeepResearchNodeName.PLANNER,
+    DeepResearchNodeName.TASK_EXECUTOR,
+}
+
+VERBOSE_NODES: set["MaxNodeName"] = STREAMING_NODES | {
+    DeepResearchNodeName.PLANNER_TOOLS,
+}
+

 class DeepResearchAssistant(BaseAssistant):
     _state: Optional[DeepResearchState]
@@ -48,20 +61,13 @@ class DeepResearchAssistant(BaseAssistant):
             trace_id=trace_id,
             billing_context=billing_context,
             initial_state=initial_state,
+            stream_processor=AssistantStreamProcessor(
+                verbose_nodes=VERBOSE_NODES,
+                streaming_nodes=STREAMING_NODES,
+                state_type=DeepResearchState,
+            ),
         )

-    @property
-    def VISUALIZATION_NODES(self) -> dict[MaxNodeName, type[BaseAssistantNode]]:
-        return {}
-
-    @property
-    def STREAMING_NODES(self) -> set[MaxNodeName]:
-        return {
-            DeepResearchNodeName.ONBOARDING,
-            DeepResearchNodeName.PLANNER,
-            DeepResearchNodeName.TASK_EXECUTOR,
-        }
-
     def get_initial_state(self) -> DeepResearchState:
         if self._latest_message:
             return DeepResearchState(
diff --git a/ee/hogai/assistant/insights_assistant.py b/ee/hogai/assistant/insights_assistant.py
index 39848397f4..806116d3de 100644
--- a/ee/hogai/assistant/insights_assistant.py
+++ b/ee/hogai/assistant/insights_assistant.py
@@ -1,36 +1,32 @@
 from collections.abc import AsyncGenerator
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 from uuid import UUID

-from posthog.schema import (
-    AssistantGenerationStatusEvent,
-    AssistantGenerationStatusType,
-    AssistantMessage,
-    HumanMessage,
-    MaxBillingContext,
-    VisualizationMessage,
-)
+from posthog.schema import AssistantMessage, HumanMessage, MaxBillingContext, VisualizationMessage

 from posthog.models import Team, User

 from ee.hogai.assistant.base import BaseAssistant
-from ee.hogai.graph import FunnelGeneratorNode, RetentionGeneratorNode, SQLGeneratorNode, TrendsGeneratorNode
-from ee.hogai.graph.base import BaseAssistantNode
-from ee.hogai.graph.graph import InsightsAssistantGraph
-from ee.hogai.graph.query_executor.nodes import QueryExecutorNode
-from ee.hogai.graph.taxonomy.types import TaxonomyNodeName
-from ee.hogai.utils.state import GraphValueUpdateTuple, validate_value_update
-from ee.hogai.utils.types import (
-    AssistantMode,
-    AssistantNodeName,
-    AssistantOutput,
-    AssistantState,
-    PartialAssistantState,
-)
-from ee.hogai.utils.types.base import AssistantResultUnion
-from ee.hogai.utils.types.composed import MaxNodeName
+from ee.hogai.graph.insights_graph.graph import InsightsGraph
+from ee.hogai.utils.stream_processor import AssistantStreamProcessor
+from ee.hogai.utils.types import AssistantMode, AssistantOutput, AssistantState, PartialAssistantState
+from ee.hogai.utils.types.base import AssistantNodeName
 from ee.models import Conversation

+if TYPE_CHECKING:
+    from ee.hogai.utils.types.composed import MaxNodeName
+
+
+VERBOSE_NODES: set["MaxNodeName"] = { + AssistantNodeName.QUERY_EXECUTOR, + AssistantNodeName.FUNNEL_GENERATOR, + AssistantNodeName.RETENTION_GENERATOR, + AssistantNodeName.SQL_GENERATOR, + AssistantNodeName.TRENDS_GENERATOR, + AssistantNodeName.ROOT, + AssistantNodeName.ROOT_TOOLS, +} + class InsightsAssistant(BaseAssistant): _state: Optional[AssistantState] @@ -55,7 +51,7 @@ class InsightsAssistant(BaseAssistant): conversation, new_message=new_message, user=user, - graph=InsightsAssistantGraph(team, user).compile_full_graph(), + graph=InsightsGraph(team, user).compile_full_graph(), state_type=AssistantState, partial_state_type=PartialAssistantState, mode=AssistantMode.INSIGHTS_TOOL, @@ -65,24 +61,11 @@ class InsightsAssistant(BaseAssistant): trace_id=trace_id, billing_context=billing_context, initial_state=initial_state, + stream_processor=AssistantStreamProcessor( + verbose_nodes=VERBOSE_NODES, streaming_nodes=set(), state_type=AssistantState + ), ) - @property - def VISUALIZATION_NODES(self) -> dict[MaxNodeName, type[BaseAssistantNode]]: - return { - AssistantNodeName.TRENDS_GENERATOR: TrendsGeneratorNode, - AssistantNodeName.FUNNEL_GENERATOR: FunnelGeneratorNode, - AssistantNodeName.RETENTION_GENERATOR: RetentionGeneratorNode, - AssistantNodeName.SQL_GENERATOR: SQLGeneratorNode, - AssistantNodeName.QUERY_EXECUTOR: QueryExecutorNode, - } - - @property - def STREAMING_NODES(self) -> set[MaxNodeName]: - return { - TaxonomyNodeName.LOOP_NODE, - } - def get_initial_state(self) -> AssistantState: return AssistantState(messages=[]) @@ -127,13 +110,3 @@ class InsightsAssistant(BaseAssistant): "is_new_conversation": False, }, ) - - async def _aprocess_value_update(self, update: GraphValueUpdateTuple) -> AssistantResultUnion | None: - _, maybe_state_update = update - state_update = validate_value_update(maybe_state_update) - if intersected_nodes := state_update.keys() & self.VISUALIZATION_NODES.keys(): - node_name: MaxNodeName = intersected_nodes.pop() - node_val = state_update[node_name] - if isinstance(node_val, PartialAssistantState) and node_val.intermediate_steps: - return AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR) - return await super()._aprocess_value_update(update) diff --git a/ee/hogai/assistant/main_assistant.py b/ee/hogai/assistant/main_assistant.py index 4da7ceee09..da74e2ac28 100644 --- a/ee/hogai/assistant/main_assistant.py +++ b/ee/hogai/assistant/main_assistant.py @@ -1,29 +1,15 @@ from collections.abc import AsyncGenerator -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from uuid import UUID -from posthog.schema import ( - AssistantGenerationStatusEvent, - AssistantGenerationStatusType, - AssistantMessage, - HumanMessage, - MaxBillingContext, - VisualizationMessage, -) +from posthog.schema import AssistantMessage, HumanMessage, MaxBillingContext, VisualizationMessage from posthog.models import Team, User from ee.hogai.assistant.base import BaseAssistant -from ee.hogai.graph import ( - AssistantGraph, - FunnelGeneratorNode, - RetentionGeneratorNode, - SQLGeneratorNode, - TrendsGeneratorNode, -) -from ee.hogai.graph.base import BaseAssistantNode -from ee.hogai.graph.insights.nodes import InsightSearchNode -from ee.hogai.utils.state import GraphValueUpdateTuple, validate_value_update +from ee.hogai.graph.graph import AssistantGraph +from ee.hogai.graph.taxonomy.types import TaxonomyNodeName +from ee.hogai.utils.stream_processor import AssistantStreamProcessor from ee.hogai.utils.types import ( 
AssistantMode, AssistantNodeName, @@ -31,10 +17,36 @@ from ee.hogai.utils.types import ( AssistantState, PartialAssistantState, ) -from ee.hogai.utils.types.base import AssistantResultUnion -from ee.hogai.utils.types.composed import MaxNodeName from ee.models import Conversation +if TYPE_CHECKING: + from ee.hogai.utils.types.composed import MaxNodeName + + +STREAMING_NODES: set["MaxNodeName"] = { + AssistantNodeName.ROOT, + AssistantNodeName.INKEEP_DOCS, + AssistantNodeName.MEMORY_ONBOARDING, + AssistantNodeName.MEMORY_INITIALIZER, + AssistantNodeName.MEMORY_ONBOARDING_ENQUIRY, + AssistantNodeName.MEMORY_ONBOARDING_FINALIZE, + AssistantNodeName.DASHBOARD_CREATION, +} + + +VERBOSE_NODES: set["MaxNodeName"] = { + AssistantNodeName.TRENDS_GENERATOR, + AssistantNodeName.FUNNEL_GENERATOR, + AssistantNodeName.RETENTION_GENERATOR, + AssistantNodeName.SQL_GENERATOR, + AssistantNodeName.INSIGHTS_SEARCH, + AssistantNodeName.QUERY_EXECUTOR, + AssistantNodeName.MEMORY_INITIALIZER_INTERRUPT, + AssistantNodeName.ROOT_TOOLS, + TaxonomyNodeName.TOOLS_NODE, + TaxonomyNodeName.TASK_EXECUTOR, +} + class MainAssistant(BaseAssistant): _state: Optional[AssistantState] @@ -69,30 +81,11 @@ class MainAssistant(BaseAssistant): trace_id=trace_id, billing_context=billing_context, initial_state=initial_state, + stream_processor=AssistantStreamProcessor( + verbose_nodes=VERBOSE_NODES, streaming_nodes=STREAMING_NODES, state_type=AssistantState + ), ) - @property - def VISUALIZATION_NODES(self) -> dict[MaxNodeName, type[BaseAssistantNode]]: - return { - AssistantNodeName.TRENDS_GENERATOR: TrendsGeneratorNode, - AssistantNodeName.FUNNEL_GENERATOR: FunnelGeneratorNode, - AssistantNodeName.RETENTION_GENERATOR: RetentionGeneratorNode, - AssistantNodeName.SQL_GENERATOR: SQLGeneratorNode, - AssistantNodeName.INSIGHTS_SEARCH: InsightSearchNode, - } - - @property - def STREAMING_NODES(self) -> set[MaxNodeName]: - return { - AssistantNodeName.ROOT, - AssistantNodeName.INKEEP_DOCS, - AssistantNodeName.MEMORY_ONBOARDING, - AssistantNodeName.MEMORY_INITIALIZER, - AssistantNodeName.MEMORY_ONBOARDING_ENQUIRY, - AssistantNodeName.MEMORY_ONBOARDING_FINALIZE, - AssistantNodeName.DASHBOARD_CREATION, - } - def get_initial_state(self) -> AssistantState: if self._latest_message: return AssistantState( @@ -142,13 +135,3 @@ class MainAssistant(BaseAssistant): "is_new_conversation": self._is_new_conversation, }, ) - - async def _aprocess_value_update(self, update: GraphValueUpdateTuple) -> AssistantResultUnion | None: - _, maybe_state_update = update - state_update = validate_value_update(maybe_state_update) - if intersected_nodes := state_update.keys() & self.VISUALIZATION_NODES.keys(): - node_name: MaxNodeName = intersected_nodes.pop() - node_val = state_update[node_name] - if isinstance(node_val, PartialAssistantState) and node_val.intermediate_steps: - return AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR) - return await super()._aprocess_value_update(update) diff --git a/ee/hogai/context/context.py b/ee/hogai/context/context.py index dae30936ca..ed5f1d6222 100644 --- a/ee/hogai/context/context.py +++ b/ee/hogai/context/context.py @@ -396,7 +396,7 @@ class AssistantContextManager(AssistantContextMixin): contextual_tools_prompt = [ f"<{tool_name}>\n" - f"{get_contextual_tool_class(tool_name)(team=self._team, user=self._user, tool_call_id="").format_context_prompt_injection(tool_context)}\n" # type: ignore + f"{get_contextual_tool_class(tool_name)(team=self._team, 
user=self._user).format_context_prompt_injection(tool_context)}\n" # type: ignore f"" for tool_name, tool_context in self.get_contextual_tools().items() if get_contextual_tool_class(tool_name) is not None diff --git a/ee/hogai/context/test/test_context.py b/ee/hogai/context/test/test_context.py index 3bba6785df..65a59982a0 100644 --- a/ee/hogai/context/test/test_context.py +++ b/ee/hogai/context/test/test_context.py @@ -507,7 +507,7 @@ Query results: 42 events # Mock the tool class mock_tool = MagicMock() mock_tool.format_context_prompt_injection.return_value = "Tool system prompt" - mock_get_contextual_tool_class.return_value = lambda team, user, tool_call_id: mock_tool + mock_get_contextual_tool_class.return_value = lambda team, user: mock_tool config = RunnableConfig( configurable={"contextual_tools": {"search_session_recordings": {"current_filters": {}}}} diff --git a/ee/hogai/eval/README.md b/ee/hogai/eval/README.md index a5521c5f56..ba1b28218d 100644 --- a/ee/hogai/eval/README.md +++ b/ee/hogai/eval/README.md @@ -52,7 +52,7 @@ from ee.hogai.eval.base import MaxPrivateEval from ee.hogai.eval.offline.conftest import EvaluationContext, capture_score, get_eval_context from ee.hogai.eval.schema import DatasetInput from ee.hogai.eval.scorers.sql import SQLSemanticsCorrectness, SQLSyntaxCorrectness -from ee.hogai.graph import AssistantGraph +from ee.hogai.graph.graph import AssistantGraph from ee.hogai.utils.types import AssistantState from ee.models import Conversation diff --git a/ee/hogai/eval/ci/conftest.py b/ee/hogai/eval/ci/conftest.py index 528f8317e3..ecffff38f3 100644 --- a/ee/hogai/eval/ci/conftest.py +++ b/ee/hogai/eval/ci/conftest.py @@ -19,7 +19,7 @@ from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer # We want the PostHog set_up_evals fixture here from ee.hogai.eval.conftest import set_up_evals # noqa: F401 from ee.hogai.eval.scorers import PlanAndQueryOutput -from ee.hogai.graph.graph import AssistantGraph, InsightsAssistantGraph +from ee.hogai.graph.graph import AssistantGraph from ee.hogai.utils.types import AssistantNodeName, AssistantState from ee.models.assistant import Conversation, CoreMemory @@ -34,28 +34,10 @@ EVAL_USER_FULL_NAME = "Karen Smith" @pytest.fixture def call_root_for_insight_generation(demo_org_team_user): # This graph structure will first get a plan, then generate the SQL query. - - insights_subgraph = ( - # Insights subgraph without query execution, so we only create the queries - InsightsAssistantGraph(demo_org_team_user[1], demo_org_team_user[2]) - .add_query_creation_flow(next_node=AssistantNodeName.END) - .compile() - ) graph = ( AssistantGraph(demo_org_team_user[1], demo_org_team_user[2]) .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) - .add_root( - path_map={ - "insights": AssistantNodeName.INSIGHTS_SUBGRAPH, - "insights_search": AssistantNodeName.INSIGHTS_SEARCH, - "root": AssistantNodeName.ROOT, - "search_documentation": AssistantNodeName.END, - "end": AssistantNodeName.END, - } - ) - .add_node(AssistantNodeName.INSIGHTS_SUBGRAPH, insights_subgraph) - .add_edge(AssistantNodeName.INSIGHTS_SUBGRAPH, AssistantNodeName.END) - .add_insights_search() + .add_root() # TRICKY: We need to set a checkpointer here because async tests create a new event loop. 
.compile(checkpointer=DjangoCheckpointer())
     )
 
@@ -85,15 +67,21 @@ def call_root_for_insight_generation(demo_org_team_user):
         final_state_raw = await graph.ainvoke(final_state, {"configurable": {"thread_id": conversation.id}})
         final_state = AssistantState.model_validate(final_state_raw)
 
-        if not final_state.messages or not isinstance(final_state.messages[-1], VisualizationMessage):
+        # The conversation ends with a visualization message, a tool call message, and an assistant message, so the visualization message is third from the end.
+        if (
+            not final_state.messages
+            or len(final_state.messages) < 3
+            or not isinstance(final_state.messages[-3], VisualizationMessage)
+        ):
             return {
                 "plan": None,
                 "query": None,
                 "query_generation_retry_count": final_state.query_generation_retry_count,
             }
+
         return {
-            "plan": final_state.messages[-1].plan,
-            "query": final_state.messages[-1].answer,
+            "plan": final_state.messages[-3].plan,
+            "query": final_state.messages[-3].answer,
             "query_generation_retry_count": final_state.query_generation_retry_count,
         }
diff --git a/ee/hogai/eval/ci/eval_dashboard_creation.py b/ee/hogai/eval/ci/eval_dashboard_creation.py
index 1c489b8adb..04e7599a98 100644
--- a/ee/hogai/eval/ci/eval_dashboard_creation.py
+++ b/ee/hogai/eval/ci/eval_dashboard_creation.py
@@ -1,5 +1,4 @@
 import pytest
-from unittest.mock import MagicMock, patch
 
 from braintrust import EvalCase
 from langchain_core.runnables import RunnableConfig
@@ -7,8 +6,8 @@ from langchain_core.runnables import RunnableConfig
 from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage
 
 from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
-from ee.hogai.graph import AssistantGraph
 from ee.hogai.graph.dashboards.nodes import DashboardCreationNode
+from ee.hogai.graph.graph import AssistantGraph
 from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState, PartialAssistantState
 from ee.models.assistant import Conversation
 
@@ -21,18 +20,7 @@ def call_root_for_dashboard_creation(demo_org_team_user):
     graph = (
         AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
         .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
-        .add_root(
-            {
-                "create_dashboard": AssistantNodeName.END,
-                "insights": AssistantNodeName.END,
-                "search_documentation": AssistantNodeName.END,
-                "session_summarization": AssistantNodeName.END,
-                "insights_search": AssistantNodeName.END,
-                "create_and_query_insight": AssistantNodeName.END,
-                "root": AssistantNodeName.END,
-                "end": AssistantNodeName.END,
-            }
-        )
+        .add_root(lambda state: AssistantNodeName.END)
         .compile(checkpointer=DjangoCheckpointer())
     )
 
@@ -172,9 +160,7 @@ async def eval_tool_routing_dashboard_creation(call_root_for_dashboard_creation,
     )
 
 
-@pytest.mark.django_db
-@patch("ee.hogai.graph.base.get_stream_writer", return_value=MagicMock())
-async def eval_tool_call_dashboard_creation(patch_get_stream_writer, pytestconfig, demo_org_team_user):
+async def eval_tool_call_dashboard_creation(pytestconfig, demo_org_team_user):
     conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])
 
     dashboard_creation_node = DashboardCreationNode(demo_org_team_user[1], demo_org_team_user[2])
diff --git a/ee/hogai/eval/ci/eval_insight_search.py b/ee/hogai/eval/ci/eval_insight_search.py
index bfcd9d26b6..edd98dbde4 100644
--- a/ee/hogai/eval/ci/eval_insight_search.py
+++ b/ee/hogai/eval/ci/eval_insight_search.py
@@ -8,7 +8,7 @@ from braintrust import EvalCase
 
 from posthog.schema import HumanMessage, VisualizationMessage
 
 from ee.hogai.django_checkpoint.checkpointer import 
DjangoCheckpointer -from ee.hogai.graph import AssistantGraph +from ee.hogai.graph.graph import AssistantGraph from ee.hogai.utils.types import AssistantNodeName, AssistantState from ee.models.assistant import Conversation @@ -84,16 +84,7 @@ def call_insight_search(demo_org_team_user): graph = ( AssistantGraph(demo_org_team_user[1], demo_org_team_user[2]) .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) - .add_root( - { - "insights": AssistantNodeName.END, - "search_documentation": AssistantNodeName.END, - "root": AssistantNodeName.END, - "end": AssistantNodeName.END, - "insights_search": AssistantNodeName.INSIGHTS_SEARCH, - } - ) - .add_insights_search() + .add_root() .compile(checkpointer=DjangoCheckpointer()) ) diff --git a/ee/hogai/eval/ci/eval_memory.py b/ee/hogai/eval/ci/eval_memory.py index a75f4520ce..abfcdcd08a 100644 --- a/ee/hogai/eval/ci/eval_memory.py +++ b/ee/hogai/eval/ci/eval_memory.py @@ -9,7 +9,7 @@ from langchain_core.messages import AIMessage as LangchainAIMessage from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer -from ee.hogai.graph import AssistantGraph +from ee.hogai.graph.graph import AssistantGraph from ee.hogai.utils.types import AssistantNodeName, AssistantState from ee.models.assistant import Conversation diff --git a/ee/hogai/eval/ci/eval_memory_onboarding.py b/ee/hogai/eval/ci/eval_memory_onboarding.py index 124238495a..492ba7e20b 100644 --- a/ee/hogai/eval/ci/eval_memory_onboarding.py +++ b/ee/hogai/eval/ci/eval_memory_onboarding.py @@ -13,7 +13,7 @@ from posthog.models.user import User from posthog.sync import database_sync_to_async from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer -from ee.hogai.graph import AssistantGraph +from ee.hogai.graph.graph import AssistantGraph from ee.hogai.graph.memory.prompts import ( ENQUIRY_INITIAL_MESSAGE, SCRAPING_SUCCESS_KEY_PHRASE, diff --git a/ee/hogai/eval/ci/eval_root.py b/ee/hogai/eval/ci/eval_root.py index 7922b61944..cbb6f6540a 100644 --- a/ee/hogai/eval/ci/eval_root.py +++ b/ee/hogai/eval/ci/eval_root.py @@ -7,7 +7,7 @@ from braintrust import EvalCase from posthog.schema import AssistantMessage, AssistantToolCall, AssistantToolCallMessage, HumanMessage from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer -from ee.hogai.graph import AssistantGraph +from ee.hogai.graph.graph import AssistantGraph from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState from ee.models.assistant import Conversation @@ -20,16 +20,7 @@ def call_root(demo_org_team_user): graph = ( AssistantGraph(demo_org_team_user[1], demo_org_team_user[2]) .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) - .add_root( - { - "insights": AssistantNodeName.END, - "billing": AssistantNodeName.END, - "insights_search": AssistantNodeName.END, - "search_documentation": AssistantNodeName.END, - "root": AssistantNodeName.ROOT, - "end": AssistantNodeName.END, - } - ) + .add_root(lambda state: AssistantNodeName.END) # TRICKY: We need to set a checkpointer here because async tests create a new event loop. 
.compile(checkpointer=DjangoCheckpointer()) ) diff --git a/ee/hogai/eval/ci/eval_root_documentation.py b/ee/hogai/eval/ci/eval_root_documentation.py index 60ef3ebbe9..cb522dd44a 100644 --- a/ee/hogai/eval/ci/eval_root_documentation.py +++ b/ee/hogai/eval/ci/eval_root_documentation.py @@ -5,7 +5,7 @@ from braintrust import EvalCase from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer -from ee.hogai.graph import AssistantGraph +from ee.hogai.graph.graph import AssistantGraph from ee.hogai.utils.types import AssistantNodeName, AssistantState from ee.models.assistant import Conversation @@ -18,17 +18,7 @@ def call_root(demo_org_team_user): graph = ( AssistantGraph(demo_org_team_user[1], demo_org_team_user[2]) .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) - .add_root( - path_map={ - "insights": AssistantNodeName.END, - "billing": AssistantNodeName.END, - "insights_search": AssistantNodeName.END, - "search_documentation": AssistantNodeName.END, - "root": AssistantNodeName.ROOT, - "end": AssistantNodeName.END, - }, - tools_node=AssistantNodeName.END, - ) + .add_root(router=lambda state: AssistantNodeName.END) # TRICKY: We need to set a checkpointer here because async tests create a new event loop. .compile(checkpointer=DjangoCheckpointer()) ) diff --git a/ee/hogai/eval/ci/eval_root_entity_search.py b/ee/hogai/eval/ci/eval_root_entity_search.py index d20ff16bb4..158137f24d 100644 --- a/ee/hogai/eval/ci/eval_root_entity_search.py +++ b/ee/hogai/eval/ci/eval_root_entity_search.py @@ -5,7 +5,7 @@ from braintrust import EvalCase from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer -from ee.hogai.graph import AssistantGraph +from ee.hogai.graph.graph import AssistantGraph from ee.hogai.utils.types import AssistantNodeName, AssistantState from ee.models.assistant import Conversation @@ -18,17 +18,7 @@ def call_root(demo_org_team_user): graph = ( AssistantGraph(demo_org_team_user[1], demo_org_team_user[2]) .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) - .add_root( - path_map={ - "insights": AssistantNodeName.END, - "billing": AssistantNodeName.END, - "insights_search": AssistantNodeName.END, - "search_documentation": AssistantNodeName.END, - "root": AssistantNodeName.ROOT, - "end": AssistantNodeName.END, - }, - tools_node=AssistantNodeName.END, - ) + .add_root(router=lambda state: AssistantNodeName.END) # TRICKY: We need to set a checkpointer here because async tests create a new event loop. .compile(checkpointer=DjangoCheckpointer()) ) diff --git a/ee/hogai/eval/ci/eval_root_style.py b/ee/hogai/eval/ci/eval_root_style.py index 9e071c1b24..8b1ef101f0 100644 --- a/ee/hogai/eval/ci/eval_root_style.py +++ b/ee/hogai/eval/ci/eval_root_style.py @@ -5,7 +5,7 @@ from braintrust import EvalCase from posthog.schema import AssistantMessage, HumanMessage -from ee.hogai.graph import AssistantGraph +from ee.hogai.graph.graph import AssistantGraph from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState from ee.models.assistant import Conversation @@ -60,19 +60,7 @@ def call_root(demo_org_team_user): graph = ( AssistantGraph(demo_org_team_user[1], demo_org_team_user[2]) .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) - .add_root( - { - # Some requests will go via Inkeep, and this is realistic! 
Inkeep needs to adhere to our intended style too - "search_documentation": AssistantNodeName.INKEEP_DOCS, - "root": AssistantNodeName.ROOT, - "billing": AssistantNodeName.END, - "insights": AssistantNodeName.END, - "insights_search": AssistantNodeName.END, - "session_summarization": AssistantNodeName.END, - "end": AssistantNodeName.END, - } - ) - .add_inkeep_docs() + .add_root(lambda state: AssistantNodeName.END) .compile() ) diff --git a/ee/hogai/eval/ci/eval_session_summarization.py b/ee/hogai/eval/ci/eval_session_summarization.py index 16d3ebcea6..13818dead5 100644 --- a/ee/hogai/eval/ci/eval_session_summarization.py +++ b/ee/hogai/eval/ci/eval_session_summarization.py @@ -7,7 +7,7 @@ from langchain_core.runnables import RunnableConfig from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer -from ee.hogai.graph import AssistantGraph +from ee.hogai.graph.graph import AssistantGraph from ee.hogai.graph.session_summaries.nodes import _SessionSearch from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState from ee.hogai.utils.yaml import load_yaml_from_raw_llm_content @@ -22,16 +22,7 @@ def call_root_for_replay_sessions(demo_org_team_user): graph = ( AssistantGraph(demo_org_team_user[1], demo_org_team_user[2]) .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) - .add_root( - { - "insights": AssistantNodeName.END, - "search_documentation": AssistantNodeName.END, - "session_summarization": AssistantNodeName.END, - "insights_search": AssistantNodeName.END, - "root": AssistantNodeName.END, - "end": AssistantNodeName.END, - } - ) + .add_root(lambda state: AssistantNodeName.END) .compile(checkpointer=DjangoCheckpointer()) ) diff --git a/ee/hogai/eval/ci/eval_surveys.py b/ee/hogai/eval/ci/eval_surveys.py index b7d8d48478..3772d8eae0 100644 --- a/ee/hogai/eval/ci/eval_surveys.py +++ b/ee/hogai/eval/ci/eval_surveys.py @@ -126,9 +126,7 @@ def call_surveys_max_tool(demo_org_team_user, create_feature_flags): "change": f"Create a survey based on these instructions: {instructions}", "output": None, } - graph = FeatureFlagLookupGraph(team=team, user=user, tool_call_id="test-tool-call-id").compile_full_graph( - checkpointer=DjangoCheckpointer() - ) + graph = FeatureFlagLookupGraph(team=team, user=user).compile_full_graph(checkpointer=DjangoCheckpointer()) result = await graph.ainvoke( graph_context, config={ diff --git a/ee/hogai/eval/ci/eval_ui_context.py b/ee/hogai/eval/ci/eval_ui_context.py index 40a01776a7..3dc03e74d4 100644 --- a/ee/hogai/eval/ci/eval_ui_context.py +++ b/ee/hogai/eval/ci/eval_ui_context.py @@ -15,7 +15,7 @@ from posthog.models.action.action import Action from posthog.models.team.team import Team from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer -from ee.hogai.graph import AssistantGraph +from ee.hogai.graph.graph import AssistantGraph from ee.hogai.utils.types import AssistantNodeName, AssistantState from ee.models.assistant import Conversation @@ -29,14 +29,7 @@ def call_root_with_ui_context(demo_org_team_user): graph = ( AssistantGraph(demo_org_team_user[1], demo_org_team_user[2]) .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) - .add_root( - { - "insights": AssistantNodeName.END, - "docs": AssistantNodeName.END, - "root": AssistantNodeName.END, - "end": AssistantNodeName.END, - } - ) + .add_root(lambda state: AssistantNodeName.END) # TRICKY: We need to set a checkpointer here because async tests create a new event loop. 
.compile(checkpointer=DjangoCheckpointer()) ) diff --git a/ee/hogai/eval/ci/max_tools/eval_create_feature_flag_tool.py b/ee/hogai/eval/ci/max_tools/eval_create_feature_flag_tool.py index 05414ccb70..7b94c52384 100644 --- a/ee/hogai/eval/ci/max_tools/eval_create_feature_flag_tool.py +++ b/ee/hogai/eval/ci/max_tools/eval_create_feature_flag_tool.py @@ -81,7 +81,6 @@ async def eval_create_feature_flag_basic(pytestconfig, demo_org_team_user): tool = await CreateFeatureFlagTool.create_tool_class( team=team, user=user, - tool_call_id="test-eval-call", state=AssistantState(messages=[]), config={ "configurable": { @@ -148,7 +147,6 @@ async def eval_create_feature_flag_with_rollout(pytestconfig, demo_org_team_user tool = await CreateFeatureFlagTool.create_tool_class( team=team, user=user, - tool_call_id="test-eval-call", state=AssistantState(messages=[]), config={ "configurable": { @@ -241,7 +239,6 @@ async def eval_create_feature_flag_with_property_filters(pytestconfig, demo_org_ tool = await CreateFeatureFlagTool.create_tool_class( team=team, user=user, - tool_call_id="test-eval-call", state=AssistantState(messages=[]), config={ "configurable": { @@ -336,7 +333,6 @@ async def eval_create_feature_flag_duplicate_handling(pytestconfig, demo_org_tea tool = await CreateFeatureFlagTool.create_tool_class( team=team, user=user, - tool_call_id="test-eval-call", state=AssistantState(messages=[]), config={ "configurable": { diff --git a/ee/hogai/eval/ci/max_tools/eval_edit_dashboard_tool.py b/ee/hogai/eval/ci/max_tools/eval_edit_dashboard_tool.py index 404c6d753f..b89dc7b4bf 100644 --- a/ee/hogai/eval/ci/max_tools/eval_edit_dashboard_tool.py +++ b/ee/hogai/eval/ci/max_tools/eval_edit_dashboard_tool.py @@ -1,5 +1,4 @@ import pytest -from unittest.mock import MagicMock, patch from braintrust import EvalCase from langchain_core.runnables import RunnableConfig @@ -13,18 +12,7 @@ from ee.hogai.eval.scorers import SemanticSimilarity from ee.models.assistant import Conversation -@pytest.fixture(autouse=True) -def mock_kafka_producer(): - """Mock Kafka producer to prevent Kafka errors in tests.""" - with patch("posthog.kafka_client.client._KafkaProducer.produce") as mock_produce: - mock_future = MagicMock() - mock_produce.return_value = mock_future - yield - - -@pytest.mark.django_db -@patch("ee.hogai.graph.base.get_stream_writer", return_value=MagicMock()) -async def eval_insights_addition(patch_get_stream_writer, pytestconfig, demo_org_team_user): +async def eval_insights_addition(pytestconfig, demo_org_team_user): """Test that adding insights to dashboard executes correctly.""" dashboard = await Dashboard.objects.acreate( diff --git a/ee/hogai/eval/ci/max_tools/eval_navigate_tool.py b/ee/hogai/eval/ci/max_tools/eval_navigate_tool.py index 634ea7953f..dfa9e50c37 100644 --- a/ee/hogai/eval/ci/max_tools/eval_navigate_tool.py +++ b/ee/hogai/eval/ci/max_tools/eval_navigate_tool.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field from posthog.schema import AssistantMessage, AssistantNavigateUrl, AssistantToolCall, FailureMessage, HumanMessage from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer -from ee.hogai.graph import AssistantGraph +from ee.hogai.graph.graph import AssistantGraph from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState from ee.models.assistant import Conversation @@ -48,16 +48,7 @@ def call_root(demo_org_team_user): graph = ( AssistantGraph(demo_org_team_user[1], demo_org_team_user[2]) .add_edge(AssistantNodeName.START, 
AssistantNodeName.ROOT) - .add_root( - { - "insights": AssistantNodeName.END, - "billing": AssistantNodeName.END, - "insights_search": AssistantNodeName.END, - "search_documentation": AssistantNodeName.END, - "root": AssistantNodeName.ROOT, - "end": AssistantNodeName.END, - } - ) + .add_root(lambda state: AssistantNodeName.END) # TRICKY: We need to set a checkpointer here because async tests create a new event loop. .compile(checkpointer=DjangoCheckpointer()) ) diff --git a/ee/hogai/eval/ci/max_tools/eval_revenue_analytics_filter_generation.py b/ee/hogai/eval/ci/max_tools/eval_revenue_analytics_filter_generation.py index 63562d8aa4..18839ebf12 100644 --- a/ee/hogai/eval/ci/max_tools/eval_revenue_analytics_filter_generation.py +++ b/ee/hogai/eval/ci/max_tools/eval_revenue_analytics_filter_generation.py @@ -40,9 +40,9 @@ DUMMY_CURRENT_FILTERS = RevenueAnalyticsAssistantFilters( @pytest.fixture def call_filter_revenue_analytics(demo_org_team_user): - graph = RevenueAnalyticsFilterOptionsGraph( - demo_org_team_user[1], demo_org_team_user[2], tool_call_id="test-tool-call-id" - ).compile_full_graph(checkpointer=DjangoCheckpointer()) + graph = RevenueAnalyticsFilterOptionsGraph(demo_org_team_user[1], demo_org_team_user[2]).compile_full_graph( + checkpointer=DjangoCheckpointer() + ) async def callable(change: str) -> dict: conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2]) diff --git a/ee/hogai/eval/ci/max_tools/eval_session_replay_filter_generation.py b/ee/hogai/eval/ci/max_tools/eval_session_replay_filter_generation.py index 4633426cff..e7cb0f5088 100644 --- a/ee/hogai/eval/ci/max_tools/eval_session_replay_filter_generation.py +++ b/ee/hogai/eval/ci/max_tools/eval_session_replay_filter_generation.py @@ -48,9 +48,9 @@ DUMMY_CURRENT_FILTERS = MaxRecordingUniversalFilters( @pytest.fixture def call_search_session_recordings(demo_org_team_user): - graph = SessionReplayFilterOptionsGraph( - demo_org_team_user[1], demo_org_team_user[2], tool_call_id="test-tool-call-id" - ).compile_full_graph(checkpointer=DjangoCheckpointer()) + graph = SessionReplayFilterOptionsGraph(demo_org_team_user[1], demo_org_team_user[2]).compile_full_graph( + checkpointer=DjangoCheckpointer() + ) async def callable(change: str) -> dict: conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2]) diff --git a/ee/hogai/eval/ci/max_tools/eval_tool_filter_generation.py b/ee/hogai/eval/ci/max_tools/eval_tool_filter_generation.py index 295439cbe0..97e8e579ef 100644 --- a/ee/hogai/eval/ci/max_tools/eval_tool_filter_generation.py +++ b/ee/hogai/eval/ci/max_tools/eval_tool_filter_generation.py @@ -48,9 +48,9 @@ DUMMY_CURRENT_FILTERS = MaxRecordingUniversalFilters( @pytest.fixture def call_search_session_recordings(demo_org_team_user): - graph = SessionReplayFilterOptionsGraph( - demo_org_team_user[1], demo_org_team_user[2], tool_call_id="test-tool-call-id" - ).compile_full_graph(checkpointer=DjangoCheckpointer()) + graph = SessionReplayFilterOptionsGraph(demo_org_team_user[1], demo_org_team_user[2]).compile_full_graph( + checkpointer=DjangoCheckpointer() + ) async def callable(change: str) -> dict: conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2]) diff --git a/ee/hogai/eval/offline/eval_sql.py b/ee/hogai/eval/offline/eval_sql.py index 42de1d6636..6d8ec46f91 100644 --- a/ee/hogai/eval/offline/eval_sql.py +++ b/ee/hogai/eval/offline/eval_sql.py @@ -18,7 +18,7 @@ from 
ee.hogai.eval.base import MaxPrivateEval from ee.hogai.eval.offline.conftest import EvaluationContext, capture_score, get_eval_context from ee.hogai.eval.schema import DatasetInput from ee.hogai.eval.scorers.sql import SQLSemanticsCorrectness, SQLSyntaxCorrectness -from ee.hogai.graph import AssistantGraph +from ee.hogai.graph.graph import AssistantGraph from ee.hogai.utils.helpers import find_last_message_of_type from ee.hogai.utils.types import AssistantState from ee.hogai.utils.warehouse import serialize_database_schema diff --git a/ee/hogai/eval/scorers/__init__.py b/ee/hogai/eval/scorers/__init__.py index fe059ad7da..c9e0155fbe 100644 --- a/ee/hogai/eval/scorers/__init__.py +++ b/ee/hogai/eval/scorers/__init__.py @@ -41,27 +41,29 @@ class ToolRelevance(ScorerWithPartial): raise TypeError(f"Eval case expected must be an AssistantToolCall, not {type(expected)}") if not isinstance(output, AssistantMessage): raise TypeError(f"Eval case output must be an AssistantMessage, not {type(output)}") - if output.tool_calls and len(output.tool_calls) > 1: - raise ValueError("Parallel tool calls not supported by this scorer yet") - score = 0.0 # 0.0 to 1.0 - if output.tool_calls and len(output.tool_calls) == 1: - tool_call = output.tool_calls[0] - # 0.5 point for getting the tool right - if tool_call.name == expected.name: - score += 0.5 - if not expected.args: - score += 0.5 if not tool_call.args else 0 # If no args expected, only score for lack of args - else: - score_per_arg = 0.5 / len(expected.args) - for arg_name, expected_arg_value in expected.args.items(): - if arg_name in self.semantic_similarity_args: - arg_similarity = AnswerSimilarity(model="text-embedding-3-small").eval( - output=tool_call.args.get(arg_name), expected=expected_arg_value - ) - score += arg_similarity.score * score_per_arg - elif tool_call.args.get(arg_name) == expected_arg_value: - score += score_per_arg - return Score(name=self._name(), score=score) + + best_score = 0.0 # 0.0 to 1.0 + if output.tool_calls: + # Check all tool calls and return the best match + for tool_call in output.tool_calls: + score = 0.0 + # 0.5 point for getting the tool right + if tool_call.name == expected.name: + score += 0.5 + if not expected.args: + score += 0.5 if not tool_call.args else 0 # If no args expected, only score for lack of args + else: + score_per_arg = 0.5 / len(expected.args) + for arg_name, expected_arg_value in expected.args.items(): + if arg_name in self.semantic_similarity_args: + arg_similarity = AnswerSimilarity(model="text-embedding-3-small").eval( + output=tool_call.args.get(arg_name), expected=expected_arg_value + ) + score += arg_similarity.score * score_per_arg + elif tool_call.args.get(arg_name) == expected_arg_value: + score += score_per_arg + best_score = max(best_score, score) + return Score(name=self._name(), score=best_score) class PlanAndQueryOutput(TypedDict, total=False): diff --git a/ee/hogai/graph/__init__.py b/ee/hogai/graph/__init__.py index 4f88326575..e69de29bb2 100644 --- a/ee/hogai/graph/__init__.py +++ b/ee/hogai/graph/__init__.py @@ -1,33 +0,0 @@ -from .deep_research.graph import DeepResearchAssistantGraph -from .funnels.nodes import FunnelGeneratorNode -from .graph import AssistantGraph, InsightsAssistantGraph -from .inkeep_docs.nodes import InkeepDocsNode -from .insights.nodes import InsightSearchNode -from .memory.nodes import MemoryInitializerNode -from .query_executor.nodes import QueryExecutorNode -from .query_planner.nodes import QueryPlannerNode -from .rag.nodes import InsightRagContextNode 
-from .retention.nodes import RetentionGeneratorNode -from .root.nodes import RootNode, RootNodeTools -from .schema_generator.nodes import SchemaGeneratorNode -from .sql.nodes import SQLGeneratorNode -from .trends.nodes import TrendsGeneratorNode - -__all__ = [ - "FunnelGeneratorNode", - "InkeepDocsNode", - "MemoryInitializerNode", - "QueryExecutorNode", - "InsightRagContextNode", - "RetentionGeneratorNode", - "RootNode", - "RootNodeTools", - "SchemaGeneratorNode", - "SQLGeneratorNode", - "QueryPlannerNode", - "TrendsGeneratorNode", - "AssistantGraph", - "InsightsAssistantGraph", - "InsightSearchNode", - "DeepResearchAssistantGraph", -] diff --git a/ee/hogai/graph/base/__init__.py b/ee/hogai/graph/base/__init__.py new file mode 100644 index 0000000000..7a56583a32 --- /dev/null +++ b/ee/hogai/graph/base/__init__.py @@ -0,0 +1,9 @@ +from .graph import BaseAssistantGraph, global_checkpointer +from .node import AssistantNode, BaseAssistantNode + +__all__ = [ + "BaseAssistantNode", + "AssistantNode", + "BaseAssistantGraph", + "global_checkpointer", +] diff --git a/ee/hogai/graph/base/context.py b/ee/hogai/graph/base/context.py new file mode 100644 index 0000000000..c644e086c7 --- /dev/null +++ b/ee/hogai/graph/base/context.py @@ -0,0 +1,22 @@ +import contextvars +from contextlib import contextmanager + +from ee.hogai.utils.types.base import NodePath + +node_path_context = contextvars.ContextVar[tuple[NodePath, ...]]("node_path_context") + + +@contextmanager +def set_node_path(node_path: tuple[NodePath, ...]): + token = node_path_context.set(node_path) + try: + yield + finally: + node_path_context.reset(token) + + +def get_node_path() -> tuple[NodePath, ...] | None: + try: + return node_path_context.get() + except LookupError: + return None diff --git a/ee/hogai/graph/base/graph.py b/ee/hogai/graph/base/graph.py new file mode 100644 index 0000000000..0f6a28cba0 --- /dev/null +++ b/ee/hogai/graph/base/graph.py @@ -0,0 +1,88 @@ +from abc import ABC, abstractmethod +from collections.abc import Callable +from functools import wraps +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar + +from langgraph.graph.state import StateGraph + +from posthog.models import Team, User + +from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer +from ee.hogai.utils.types.base import AssistantGraphName, AssistantNodeName, NodePath, PartialStateType, StateType + +from .context import get_node_path, set_node_path +from .node import BaseAssistantNode + +if TYPE_CHECKING: + from ee.hogai.utils.types.composed import MaxNodeName + + +# Base checkpointer for all graphs +global_checkpointer = DjangoCheckpointer() + +T = TypeVar("T") + + +def with_node_path(func: Callable[..., T]) -> Callable[..., T]: + @wraps(func) + def wrapper(self, *args: Any, **kwargs: Any) -> T: + with set_node_path(self.node_path): + return func(self, *args, **kwargs) + + return wrapper + + +class BaseAssistantGraph(Generic[StateType, PartialStateType], ABC): + _team: Team + _user: User + _graph: StateGraph + _node_path: tuple[NodePath, ...] 
+ + def __init__( + self, + team: Team, + user: User, + ): + self._team = team + self._user = user + self._has_start_node = False + self._graph = StateGraph(self.state_type) + self._node_path = (*(get_node_path() or ()), NodePath(name=self.graph_name.value)) + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + # Wrap all public methods with the node path context + for name, method in cls.__dict__.items(): + if callable(method) and not name.startswith("_") and name not in ("graph_name", "state_type", "node_path"): + setattr(cls, name, with_node_path(method)) + + @property + @abstractmethod + def state_type(self) -> type[StateType]: ... + + @property + @abstractmethod + def graph_name(self) -> AssistantGraphName: ... + + @property + def node_path(self) -> tuple[NodePath, ...]: + return self._node_path + + def add_edge(self, from_node: "MaxNodeName", to_node: "MaxNodeName"): + if from_node == AssistantNodeName.START: + self._has_start_node = True + self._graph.add_edge(from_node, to_node) + return self + + def add_node(self, node: "MaxNodeName", action: BaseAssistantNode[StateType, PartialStateType]): + self._graph.add_node(node, action) + return self + + def compile(self, checkpointer: DjangoCheckpointer | None | Literal[False] = None): + if not self._has_start_node: + raise ValueError("Start node not added to the graph") + # TRICKY: We check `is not None` because False has a special meaning of "no checkpointer", which we want to pass on + compiled_graph = self._graph.compile( + checkpointer=checkpointer if checkpointer is not None else global_checkpointer + ) + return compiled_graph diff --git a/ee/hogai/graph/base.py b/ee/hogai/graph/base/node.py similarity index 53% rename from ee/hogai/graph/base.py rename to ee/hogai/graph/base/node.py index 73fc911e80..4e3f0f5636 100644 --- a/ee/hogai/graph/base.py +++ b/ee/hogai/graph/base/node.py @@ -1,49 +1,45 @@ -from abc import ABC, abstractmethod -from collections.abc import Sequence +from abc import ABC from typing import Generic from uuid import UUID from django.conf import settings from langchain_core.runnables import RunnableConfig -from langgraph.config import get_stream_writer -from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage +from posthog.schema import HumanMessage -from posthog.models import Team -from posthog.models.user import User +from posthog.models import Team, User from posthog.sync import database_sync_to_async from ee.hogai.context import AssistantContextManager -from ee.hogai.graph.mixins import AssistantContextMixin -from ee.hogai.utils.dispatcher import AssistantDispatcher +from ee.hogai.graph.base.context import get_node_path, set_node_path +from ee.hogai.graph.mixins import AssistantContextMixin, AssistantDispatcherMixin from ee.hogai.utils.exceptions import GenerationCanceled from ee.hogai.utils.helpers import find_start_message -from ee.hogai.utils.types import ( - AssistantMessageUnion, +from ee.hogai.utils.types.base import ( AssistantState, + NodeEndAction, + NodePath, + NodeStartAction, PartialAssistantState, PartialStateType, StateType, ) -from ee.hogai.utils.types.composed import MaxNodeName from ee.models import Conversation -class BaseAssistantNode(Generic[StateType, PartialStateType], AssistantContextMixin, ABC): +class BaseAssistantNode(Generic[StateType, PartialStateType], AssistantContextMixin, AssistantDispatcherMixin, ABC): _config: RunnableConfig | None = None _context_manager: AssistantContextManager | None = None - _dispatcher: AssistantDispatcher | 
None = None - _parent_tool_call_id: str | None = None + _node_path: tuple[NodePath, ...] - def __init__(self, team: Team, user: User): + def __init__(self, team: Team, user: User, node_path: tuple[NodePath, ...] | None = None): self._team = team self._user = user - - @property - @abstractmethod - def node_name(self) -> MaxNodeName: - raise NotImplementedError + if node_path is None: + self._node_path = (*(get_node_path() or ()), NodePath(name=self.node_name)) + else: + self._node_path = node_path async def __call__(self, state: StateType, config: RunnableConfig) -> PartialStateType | None: """ @@ -54,25 +50,19 @@ class BaseAssistantNode(Generic[StateType, PartialStateType], AssistantContextMi self._dispatcher = None self._config = config - if isinstance(state, AssistantState) and state.root_tool_call_id: - # NOTE: we set the parent tool call id as the root tool call id - # This will be deprecated once all tools become MaxTools and are removed from the graph - self._parent_tool_call_id = state.root_tool_call_id - - self.dispatcher.node_start() + self.dispatcher.dispatch(NodeStartAction()) thread_id = (config.get("configurable") or {}).get("thread_id") if thread_id and await self._is_conversation_cancelled(thread_id): raise GenerationCanceled try: - new_state = await self.arun(state, config) + new_state = await self._arun_with_context(state, config) except NotImplementedError: - new_state = await database_sync_to_async(self.run, thread_sensitive=False)(state, config) + new_state = await database_sync_to_async(self._run_with_context, thread_sensitive=False)(state, config) + + self.dispatcher.dispatch(NodeEndAction(state=new_state)) - if new_state is not None and (messages := getattr(new_state, "messages", [])): - for message in messages: - self.dispatcher.message(message) return new_state def run(self, state: StateType, config: RunnableConfig) -> PartialStateType | None: @@ -82,6 +72,14 @@ class BaseAssistantNode(Generic[StateType, PartialStateType], AssistantContextMi async def arun(self, state: StateType, config: RunnableConfig) -> PartialStateType | None: raise NotImplementedError + def _run_with_context(self, state: StateType, config: RunnableConfig) -> PartialStateType | None: + with set_node_path(self.node_path): + return self.run(state, config) + + async def _arun_with_context(self, state: StateType, config: RunnableConfig) -> PartialStateType | None: + with set_node_path(self.node_path): + return await self.arun(state, config) + @property def context_manager(self) -> AssistantContextManager: if self._context_manager is None: @@ -97,26 +95,13 @@ class BaseAssistantNode(Generic[StateType, PartialStateType], AssistantContextMi return self._context_manager @property - def dispatcher(self) -> AssistantDispatcher: - """Create a dispatcher for this node""" - if self._dispatcher: - return self._dispatcher - - # Set writer from LangGraph context - try: - writer = get_stream_writer() - except RuntimeError: - # Not in streaming context (e.g., testing) - # Use noop writer - def noop(*_args, **_kwargs): - pass - - writer = noop - - self._dispatcher = AssistantDispatcher( - writer, node_name=self.node_name, parent_tool_call_id=self._parent_tool_call_id - ) - return self._dispatcher + def node_name(self) -> str: + config_name: str | None = None + if self._config: + config_name = self._config["metadata"].get("langgraph_node") + if config_name is not None: + config_name = str(config_name) + return config_name or self.__class__.__name__ async def _is_conversation_cancelled(self, conversation_id: UUID) 
-> bool: conversation = await self._aget_conversation(conversation_id) @@ -124,15 +109,6 @@ class BaseAssistantNode(Generic[StateType, PartialStateType], AssistantContextMi raise ValueError(f"Conversation {conversation_id} not found") return conversation.status == Conversation.Status.CANCELING - def _get_tool_call(self, messages: Sequence[AssistantMessageUnion], tool_call_id: str) -> AssistantToolCall: - for message in reversed(messages): - if not isinstance(message, AssistantMessage) or not message.tool_calls: - continue - for tool_call in message.tool_calls: - if tool_call.id == tool_call_id: - return tool_call - raise ValueError(f"Tool call {tool_call_id} not found in state") - def _is_first_turn(self, state: AssistantState) -> bool: last_message = state.messages[-1] if isinstance(last_message, HumanMessage): diff --git a/ee/hogai/graph/billing/__init__.py b/ee/hogai/graph/base/test/__init__.py similarity index 100% rename from ee/hogai/graph/billing/__init__.py rename to ee/hogai/graph/base/test/__init__.py diff --git a/ee/hogai/graph/base/test/test_assistant_graph.py b/ee/hogai/graph/base/test/test_assistant_graph.py new file mode 100644 index 0000000000..6329293848 --- /dev/null +++ b/ee/hogai/graph/base/test/test_assistant_graph.py @@ -0,0 +1,47 @@ +from posthog.test.base import BaseTest + +from langgraph.checkpoint.memory import InMemorySaver + +from ee.hogai.graph.base import AssistantNode +from ee.hogai.graph.base.graph import BaseAssistantGraph +from ee.hogai.utils.types import AssistantNodeName, AssistantState, PartialAssistantState +from ee.hogai.utils.types.base import AssistantGraphName +from ee.models import Conversation + + +class TestAssistantGraph(BaseTest): + async def test_pydantic_state_resets_with_none(self): + """When a None field is set, it should be reset to None.""" + + class TestAssistantGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]): + @property + def state_type(self) -> type[AssistantState]: + return AssistantState + + @property + def graph_name(self) -> AssistantGraphName: + return AssistantGraphName.ASSISTANT + + graph = TestAssistantGraph(self.team, self.user) + + class TestNode(AssistantNode): + @property + def node_name(self): + return AssistantNodeName.ROOT + + async def arun(self, state, config): + return PartialAssistantState(start_id=None) + + compiled_graph = ( + graph.add_node(AssistantNodeName.ROOT, TestNode(self.team, self.user)) + .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) + .add_edge(AssistantNodeName.ROOT, AssistantNodeName.END) + .compile(checkpointer=InMemorySaver()) + ) + conversation = await Conversation.objects.acreate(team=self.team, user=self.user) + state = await compiled_graph.ainvoke( + AssistantState(messages=[], graph_status="resumed", start_id=None), + {"configurable": {"thread_id": conversation.id}}, + ) + self.assertEqual(state["start_id"], None) + self.assertEqual(state["graph_status"], "resumed") diff --git a/ee/hogai/graph/base/test/test_node_path.py b/ee/hogai/graph/base/test/test_node_path.py new file mode 100644 index 0000000000..d1c8961831 --- /dev/null +++ b/ee/hogai/graph/base/test/test_node_path.py @@ -0,0 +1,502 @@ +from posthog.test.base import BaseTest + +from langchain_core.runnables import RunnableConfig + +from ee.hogai.graph.base import AssistantNode +from ee.hogai.graph.base.context import get_node_path +from ee.hogai.graph.base.graph import BaseAssistantGraph +from ee.hogai.utils.types import AssistantNodeName, AssistantState, PartialAssistantState +from 
ee.hogai.utils.types.base import AssistantGraphName, NodePath +from ee.models import Conversation + + +class TestNodePath(BaseTest): + """ + Tests for node_path functionality across sync/async methods and graph compositions. + """ + + async def test_graph_to_async_node_has_two_elements(self): + """Graph -> Node (async) should have path: [graph, node]""" + captured_path = None + + class TestNode(AssistantNode): + async def arun(self, state, config): + nonlocal captured_path + captured_path = get_node_path() + return None + + class TestGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]): + @property + def state_type(self) -> type[AssistantState]: + return AssistantState + + @property + def graph_name(self) -> AssistantGraphName: + return AssistantGraphName.ASSISTANT + + def setup(self): + node = TestNode(self._team, self._user) + self.add_node(AssistantNodeName.ROOT, node) + self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) + self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END) + return self + + graph = TestGraph(self.team, self.user) + compiled = graph.setup().compile(checkpointer=False) + + conversation = await Conversation.objects.acreate(team=self.team, user=self.user) + await compiled.ainvoke(AssistantState(messages=[]), {"configurable": {"thread_id": conversation.id}}) + + assert captured_path is not None + self.assertEqual(len(captured_path), 2) + self.assertEqual(captured_path[0].name, AssistantGraphName.ASSISTANT.value) + # Node name is determined at init time, so it's the class name + self.assertEqual(captured_path[1].name, "TestNode") + + async def test_graph_to_sync_node_has_two_elements(self): + """Graph -> Node (sync) should have path: [graph, node]""" + captured_path = None + + class TestNode(AssistantNode): + def run(self, state, config): + nonlocal captured_path + captured_path = get_node_path() + return None + + class TestGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]): + @property + def state_type(self) -> type[AssistantState]: + return AssistantState + + @property + def graph_name(self) -> AssistantGraphName: + return AssistantGraphName.ASSISTANT + + def setup(self): + node = TestNode(self._team, self._user) + self.add_node(AssistantNodeName.ROOT, node) + self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) + self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END) + return self + + graph = TestGraph(self.team, self.user) + compiled = graph.setup().compile(checkpointer=False) + + conversation = await Conversation.objects.acreate(team=self.team, user=self.user) + await compiled.ainvoke(AssistantState(messages=[]), {"configurable": {"thread_id": conversation.id}}) + + assert captured_path is not None + self.assertEqual(len(captured_path), 2) + self.assertEqual(captured_path[0].name, AssistantGraphName.ASSISTANT.value) + self.assertEqual(captured_path[1].name, "TestNode") + + async def test_graph_to_node_to_async_node_has_three_elements(self): + """Graph -> Node -> Node (async) should have path: [graph, node, node]""" + captured_paths = [] + + class SecondNode(AssistantNode): + async def arun(self, state, config): + nonlocal captured_paths + captured_paths.append(get_node_path()) + return None + + class FirstNode(AssistantNode): + async def arun(self, state, config): + nonlocal captured_paths + captured_paths.append(get_node_path()) + # Call second node + second_node = SecondNode(self._team, self._user) + await second_node(state, config) + return None + + class TestGraph(BaseAssistantGraph[AssistantState, 
PartialAssistantState]):
+            @property
+            def state_type(self) -> type[AssistantState]:
+                return AssistantState
+
+            @property
+            def graph_name(self) -> AssistantGraphName:
+                return AssistantGraphName.ASSISTANT
+
+            def setup(self):
+                node = FirstNode(self._team, self._user)
+                self.add_node(AssistantNodeName.ROOT, node)
+                self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
+                self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
+                return self
+
+        graph = TestGraph(self.team, self.user)
+        compiled = graph.setup().compile(checkpointer=False)
+
+        conversation = await Conversation.objects.acreate(team=self.team, user=self.user)
+        await compiled.ainvoke(AssistantState(messages=[]), {"configurable": {"thread_id": conversation.id}})
+
+        # First node: [graph, node]
+        assert captured_paths[0] is not None
+        self.assertEqual(len(captured_paths[0]), 2)
+        self.assertEqual(captured_paths[0][0].name, AssistantGraphName.ASSISTANT.value)
+        self.assertEqual(captured_paths[0][1].name, "FirstNode")
+
+        # Second node: [graph, node, node]
+        assert captured_paths[1] is not None
+        self.assertEqual(len(captured_paths[1]), 3)
+        self.assertEqual(captured_paths[1][0].name, AssistantGraphName.ASSISTANT.value)
+        self.assertEqual(captured_paths[1][1].name, "FirstNode")
+        self.assertEqual(captured_paths[1][2].name, "SecondNode")
+
+    async def test_graph_to_node_to_sync_node_direct_run_does_not_extend_path(self):
+        """Graph -> Node -> Node (sync): calling .run() directly does not extend the path, so both nodes see [graph, node]"""
+        captured_paths = []
+
+        class SecondNode(AssistantNode):
+            def run(self, state, config):
+                nonlocal captured_paths
+                captured_paths.append(get_node_path())
+                return None
+
+        class FirstNode(AssistantNode):
+            def run(self, state, config):
+                nonlocal captured_paths
+                captured_paths.append(get_node_path())
+                # Call the second node by invoking run() directly. This bypasses
+                # _run_with_context, so the ambient node path is never extended and
+                # the second node captures FirstNode's path instead of its own.
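The `get_node_path`/`set_node_path` helpers these tests exercise are imported from `ee.hogai.graph.base.context`, whose implementation is not part of this diff. A minimal sketch of the assumed mechanics (a `contextvars.ContextVar`, so the bound path survives `await` boundaries) follows; the real module may differ:

```python
# Sketch only: assumed shape of ee/hogai/graph/base/context.py, which this
# diff imports from but does not show. The NodePath import path is taken
# from the test imports above.
from contextlib import contextmanager
from contextvars import ContextVar

from ee.hogai.utils.types.base import NodePath

_node_path: ContextVar[tuple[NodePath, ...] | None] = ContextVar("node_path", default=None)


def get_node_path() -> tuple[NodePath, ...] | None:
    """Return the node path bound to the current execution context, if any."""
    return _node_path.get()


@contextmanager
def set_node_path(path: tuple[NodePath, ...]):
    """Bind a node path for the duration of a node's run()/arun()."""
    token = _node_path.set(path)
    try:
        yield
    finally:
        _node_path.reset(token)
```

Under this model, `_run_with_context` and `_arun_with_context` are the only places the path gets bound, which is exactly why the bare `second_node.run(...)` call just below captures FirstNode's path rather than extending it.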
+ second_node = SecondNode(self._team, self._user) + second_node.run(state, config) + return None + + class TestGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]): + @property + def state_type(self) -> type[AssistantState]: + return AssistantState + + @property + def graph_name(self) -> AssistantGraphName: + return AssistantGraphName.ASSISTANT + + def setup(self): + node = FirstNode(self._team, self._user) + self.add_node(AssistantNodeName.ROOT, node) + self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) + self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END) + return self + + graph = TestGraph(self.team, self.user) + compiled = graph.setup().compile(checkpointer=False) + + conversation = await Conversation.objects.acreate(team=self.team, user=self.user) + await compiled.ainvoke(AssistantState(messages=[]), {"configurable": {"thread_id": conversation.id}}) + + # First node: [graph, node] + assert captured_paths[0] is not None + self.assertEqual(len(captured_paths[0]), 2) + self.assertEqual(captured_paths[0][0].name, AssistantGraphName.ASSISTANT.value) + self.assertEqual(captured_paths[0][1].name, "FirstNode") + + # Second node: [graph, node] - initialized within FirstNode's context, so gets same path + assert captured_paths[1] is not None + self.assertEqual(len(captured_paths[1]), 2) + self.assertEqual(captured_paths[1][0].name, AssistantGraphName.ASSISTANT.value) + self.assertEqual(captured_paths[1][1].name, "FirstNode") # Same as first because initialized in same context + + async def test_graph_to_node_to_graph_to_node_has_four_elements(self): + """Graph -> Node -> Graph -> Node should have path: [graph, node, graph, node]""" + captured_paths = [] + + class InnerNode(AssistantNode): + async def arun(self, state, config): + nonlocal captured_paths + captured_paths.append(get_node_path()) + return None + + class InnerGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]): + @property + def state_type(self) -> type[AssistantState]: + return AssistantState + + @property + def graph_name(self) -> AssistantGraphName: + return AssistantGraphName.INSIGHTS + + def setup(self): + node = InnerNode(self._team, self._user) + self.add_node(AssistantNodeName.TRENDS_GENERATOR, node) + self.add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_GENERATOR) + self.add_edge(AssistantNodeName.TRENDS_GENERATOR, AssistantNodeName.END) + return self + + class OuterNode(AssistantNode): + async def arun(self, state, config): + nonlocal captured_paths + captured_paths.append(get_node_path()) + # Call inner graph + inner_graph = InnerGraph(self._team, self._user) + compiled_inner = inner_graph.setup().compile(checkpointer=False) + await compiled_inner.ainvoke(state, config) + return None + + class OuterGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]): + @property + def state_type(self) -> type[AssistantState]: + return AssistantState + + @property + def graph_name(self) -> AssistantGraphName: + return AssistantGraphName.ASSISTANT + + def setup(self): + node = OuterNode(self._team, self._user) + self.add_node(AssistantNodeName.ROOT, node) + self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) + self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END) + return self + + outer_graph = OuterGraph(self.team, self.user) + compiled = outer_graph.setup().compile(checkpointer=False) + + conversation = await Conversation.objects.acreate(team=self.team, user=self.user) + await compiled.ainvoke(AssistantState(messages=[]), {"configurable": 
{"thread_id": conversation.id}}) + + # Outer node: [graph, node] + assert captured_paths[0] is not None + self.assertEqual(len(captured_paths[0]), 2) + self.assertEqual(captured_paths[0][0].name, AssistantGraphName.ASSISTANT.value) + self.assertEqual(captured_paths[0][1].name, "OuterNode") + + # Inner node: [graph, node, graph, node] + assert captured_paths[1] is not None + self.assertEqual(len(captured_paths[1]), 4) + self.assertEqual(captured_paths[1][0].name, AssistantGraphName.ASSISTANT.value) + self.assertEqual(captured_paths[1][1].name, "OuterNode") + self.assertEqual(captured_paths[1][2].name, AssistantGraphName.INSIGHTS.value) + self.assertEqual(captured_paths[1][3].name, "InnerNode") + + async def test_graph_to_graph_to_node_has_three_elements(self): + """Graph -> Graph -> Node should have path: [graph, graph, node]""" + captured_path = None + + class InnerNode(AssistantNode): + async def arun(self, state, config): + nonlocal captured_path + captured_path = get_node_path() + return None + + class InnerGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]): + @property + def state_type(self) -> type[AssistantState]: + return AssistantState + + @property + def graph_name(self) -> AssistantGraphName: + return AssistantGraphName.INSIGHTS + + def setup(self): + node = InnerNode(self._team, self._user) + self.add_node(AssistantNodeName.TRENDS_GENERATOR, node) + self.add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_GENERATOR) + self.add_edge(AssistantNodeName.TRENDS_GENERATOR, AssistantNodeName.END) + return self + + class OuterGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]): + @property + def state_type(self) -> type[AssistantState]: + return AssistantState + + @property + def graph_name(self) -> AssistantGraphName: + return AssistantGraphName.ASSISTANT + + async def invoke_inner_graph(self, state, config): + inner_graph = InnerGraph(self._team, self._user) + compiled_inner = inner_graph.setup().compile(checkpointer=False) + await compiled_inner.ainvoke(state, config) + + outer_graph = OuterGraph(self.team, self.user) + + conversation = await Conversation.objects.acreate(team=self.team, user=self.user) + await outer_graph.invoke_inner_graph( + AssistantState(messages=[]), RunnableConfig(configurable={"thread_id": conversation.id}) + ) + + # Inner node: [graph, node] - outer graph context is not propagated when calling graph methods directly + assert captured_path is not None + self.assertEqual(len(captured_path), 2) + self.assertEqual(captured_path[0].name, AssistantGraphName.INSIGHTS.value) + self.assertEqual(captured_path[1].name, "InnerNode") + + async def test_node_path_preserved_across_async_and_sync_methods(self): + """Test that calling .run() directly doesn't extend path""" + captured_paths = [] + + class SyncNode(AssistantNode): + def run(self, state, config): + nonlocal captured_paths + captured_paths.append(("sync", get_node_path())) + return None + + class AsyncNode(AssistantNode): + async def arun(self, state, config): + nonlocal captured_paths + captured_paths.append(("async", get_node_path())) + # Call sync node - calling .run() directly bypasses context setting + sync_node = SyncNode(self._team, self._user) + sync_node.run(state, config) + return None + + class TestGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]): + @property + def state_type(self) -> type[AssistantState]: + return AssistantState + + @property + def graph_name(self) -> AssistantGraphName: + return AssistantGraphName.ASSISTANT + + def setup(self): + node 
= AsyncNode(self._team, self._user) + self.add_node(AssistantNodeName.ROOT, node) + self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) + self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END) + return self + + graph = TestGraph(self.team, self.user) + compiled = graph.setup().compile(checkpointer=False) + + conversation = await Conversation.objects.acreate(team=self.team, user=self.user) + await compiled.ainvoke(AssistantState(messages=[]), {"configurable": {"thread_id": conversation.id}}) + + # Async node: [graph, node] + self.assertEqual(captured_paths[0][0], "async") + assert captured_paths[0][1] is not None + self.assertEqual(len(captured_paths[0][1]), 2) + self.assertEqual(captured_paths[0][1][0].name, AssistantGraphName.ASSISTANT.value) + self.assertEqual(captured_paths[0][1][1].name, "AsyncNode") + + # Sync node: [graph, node] - initialized within AsyncNode's context, so gets same path + self.assertEqual(captured_paths[1][0], "sync") + assert captured_paths[1][1] is not None + self.assertEqual(len(captured_paths[1][1]), 2) + self.assertEqual(captured_paths[1][1][0].name, AssistantGraphName.ASSISTANT.value) + self.assertEqual(captured_paths[1][1][1].name, "AsyncNode") # Same as async because initialized in same context + + def test_node_path_with_explicit_node_path_parameter(self): + """Test that explicitly passing node_path overrides default behavior""" + custom_path = (NodePath(name="custom_graph"), NodePath(name="custom_node")) + + class TestNode(AssistantNode): + def run(self, state, config): + return None + + node = TestNode(self.team, self.user, node_path=custom_path) + + self.assertEqual(len(node._node_path), 2) + self.assertEqual(node._node_path[0].name, "custom_graph") + self.assertEqual(node._node_path[1].name, "custom_node") + + async def test_multiple_nested_graphs(self): + """Test deeply nested graph composition: Graph -> Node -> Graph -> Node -> Graph -> Node""" + captured_paths = [] + + class Level3Node(AssistantNode): + async def arun(self, state, config): + nonlocal captured_paths + captured_paths.append(("level3", get_node_path())) + return None + + class Level3Graph(BaseAssistantGraph[AssistantState, PartialAssistantState]): + @property + def state_type(self) -> type[AssistantState]: + return AssistantState + + @property + def graph_name(self) -> AssistantGraphName: + return AssistantGraphName.TAXONOMY + + def setup(self): + node = Level3Node(self._team, self._user) + self.add_node(AssistantNodeName.FUNNEL_GENERATOR, node) + self.add_edge(AssistantNodeName.START, AssistantNodeName.FUNNEL_GENERATOR) + self.add_edge(AssistantNodeName.FUNNEL_GENERATOR, AssistantNodeName.END) + return self + + class Level2Node(AssistantNode): + async def arun(self, state, config): + nonlocal captured_paths + captured_paths.append(("level2", get_node_path())) + # Call level 3 graph + level3_graph = Level3Graph(self._team, self._user) + compiled = level3_graph.setup().compile(checkpointer=False) + await compiled.ainvoke(state, config) + return None + + class Level2Graph(BaseAssistantGraph[AssistantState, PartialAssistantState]): + @property + def state_type(self) -> type[AssistantState]: + return AssistantState + + @property + def graph_name(self) -> AssistantGraphName: + return AssistantGraphName.INSIGHTS + + def setup(self): + node = Level2Node(self._team, self._user) + self.add_node(AssistantNodeName.TRENDS_GENERATOR, node) + self.add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_GENERATOR) + self.add_edge(AssistantNodeName.TRENDS_GENERATOR, 
AssistantNodeName.END) + return self + + class Level1Node(AssistantNode): + async def arun(self, state, config): + nonlocal captured_paths + captured_paths.append(("level1", get_node_path())) + # Call level 2 graph + level2_graph = Level2Graph(self._team, self._user) + compiled = level2_graph.setup().compile(checkpointer=False) + await compiled.ainvoke(state, config) + return None + + class Level1Graph(BaseAssistantGraph[AssistantState, PartialAssistantState]): + @property + def state_type(self) -> type[AssistantState]: + return AssistantState + + @property + def graph_name(self) -> AssistantGraphName: + return AssistantGraphName.ASSISTANT + + def setup(self): + node = Level1Node(self._team, self._user) + self.add_node(AssistantNodeName.ROOT, node) + self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) + self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END) + return self + + level1_graph = Level1Graph(self.team, self.user) + compiled = level1_graph.setup().compile(checkpointer=False) + + conversation = await Conversation.objects.acreate(team=self.team, user=self.user) + await compiled.ainvoke(AssistantState(messages=[]), {"configurable": {"thread_id": conversation.id}}) + + # Level 1: [graph, node] + assert captured_paths[0][1] is not None + self.assertEqual(len(captured_paths[0][1]), 2) + self.assertEqual(captured_paths[0][1][0].name, AssistantGraphName.ASSISTANT.value) + self.assertEqual(captured_paths[0][1][1].name, "Level1Node") + + # Level 2: [graph, node, graph, node] + assert captured_paths[1][1] is not None + self.assertEqual(len(captured_paths[1][1]), 4) + self.assertEqual(captured_paths[1][1][0].name, AssistantGraphName.ASSISTANT.value) + self.assertEqual(captured_paths[1][1][1].name, "Level1Node") + self.assertEqual(captured_paths[1][1][2].name, AssistantGraphName.INSIGHTS.value) + self.assertEqual(captured_paths[1][1][3].name, "Level2Node") + + # Level 3: [graph, node, graph, node, graph, node] + assert captured_paths[2][1] is not None + self.assertEqual(len(captured_paths[2][1]), 6) + self.assertEqual(captured_paths[2][1][0].name, AssistantGraphName.ASSISTANT.value) + self.assertEqual(captured_paths[2][1][1].name, "Level1Node") + self.assertEqual(captured_paths[2][1][2].name, AssistantGraphName.INSIGHTS.value) + self.assertEqual(captured_paths[2][1][3].name, "Level2Node") + self.assertEqual(captured_paths[2][1][4].name, AssistantGraphName.TAXONOMY.value) + self.assertEqual(captured_paths[2][1][5].name, "Level3Node") diff --git a/ee/hogai/graph/dashboards/nodes.py b/ee/hogai/graph/dashboards/nodes.py index d5af4b494c..4d820acd2e 100644 --- a/ee/hogai/graph/dashboards/nodes.py +++ b/ee/hogai/graph/dashboards/nodes.py @@ -89,13 +89,6 @@ class DashboardCreationNode(AssistantNode): def _get_found_insight_count(self, queries_metadata: dict[str, QueryMetadata]) -> int: return sum(len(query.found_insight_ids) for query in queries_metadata.values()) - def _dispatch_update_message(self, content: str) -> None: - self.dispatcher.message( - AssistantMessage( - content=content, - ) - ) - async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: dashboard_name = ( state.dashboard_name[:50] if state.dashboard_name else "Analytics Dashboard" @@ -117,18 +110,18 @@ class DashboardCreationNode(AssistantNode): for i, query in enumerate(state.search_insights_queries) } - self._dispatch_update_message(f"Searching for {pluralize(len(state.search_insights_queries), 'insight')}") + self.dispatcher.update(f"Searching for 
{pluralize(len(state.search_insights_queries), 'insight')}") result = await self._search_insights(result, config) - self._dispatch_update_message(f"Found {pluralize(self._get_found_insight_count(result), 'insight')}") + self.dispatcher.update(f"Found {pluralize(self._get_found_insight_count(result), 'insight')}") left_to_create = { query_id: result[query_id].query for query_id in result.keys() if not result[query_id].found_insight_ids } if left_to_create: - self._dispatch_update_message(f"Will create {pluralize(len(left_to_create), 'insight')}") + self.dispatcher.update(f"Will create {pluralize(len(left_to_create), 'insight')}") result = await self._create_insights(left_to_create, result, config) @@ -187,9 +180,7 @@ class DashboardCreationNode(AssistantNode): message = AssistantMessage(content="", id=str(uuid4()), tool_calls=tool_calls) executor = DashboardCreationExecutorNode(self._team, self._user) - result = await executor.arun( - AssistantState(messages=[message], root_tool_call_id=self._parent_tool_call_id), config - ) + result = await executor.arun(AssistantState(messages=[message], root_tool_call_id=self.tool_call_id), config) query_metadata = await self._process_insight_creation_results(tool_calls, result.task_results, query_metadata) @@ -211,9 +202,7 @@ class DashboardCreationNode(AssistantNode): message = AssistantMessage(content="", id=str(uuid4()), tool_calls=tool_calls) executor = DashboardCreationExecutorNode(self._team, self._user) - result = await executor.arun( - AssistantState(messages=[message], root_tool_call_id=self._parent_tool_call_id), config - ) + result = await executor.arun(AssistantState(messages=[message], root_tool_call_id=self.tool_call_id), config) final_task_executor_state = BaseStateWithTasks.model_validate(result) for task_result in final_task_executor_state.task_results: @@ -307,7 +296,7 @@ class DashboardCreationNode(AssistantNode): self, dashboard_name: str, insights: set[int], dashboard_id: int | None = None ) -> tuple[Dashboard, list[Insight]]: """Create a dashboard and add the insights to it.""" - self._dispatch_update_message("Saving your dashboard") + self.dispatcher.update("Saving your dashboard") @database_sync_to_async @transaction.atomic diff --git a/ee/hogai/graph/dashboards/test/test_nodes.py b/ee/hogai/graph/dashboards/test/test_nodes.py index a53bf0b9ac..cb0fd2e3c9 100644 --- a/ee/hogai/graph/dashboards/test/test_nodes.py +++ b/ee/hogai/graph/dashboards/test/test_nodes.py @@ -1,5 +1,5 @@ import pytest -from unittest import TestCase +from posthog.test.base import BaseTest from unittest.mock import AsyncMock, MagicMock, patch from langchain_core.runnables import RunnableConfig @@ -11,10 +11,17 @@ from posthog.models import Dashboard, Insight, Team, User from ee.hogai.graph.dashboards.nodes import DashboardCreationExecutorNode, DashboardCreationNode, QueryMetadata from ee.hogai.utils.helpers import build_dashboard_url, build_insight_url from ee.hogai.utils.types import AssistantState, PartialAssistantState -from ee.hogai.utils.types.base import BaseStateWithTasks, InsightArtifact, InsightQuery, TaskResult +from ee.hogai.utils.types.base import ( + AssistantNodeName, + BaseStateWithTasks, + InsightArtifact, + InsightQuery, + NodePath, + TaskResult, +) -class TestQueryMetadata(TestCase): +class TestQueryMetadata(BaseTest): def test_query_metadata_initialization(self): """Test QueryMetadata initialization with all fields.""" query = InsightQuery(name="Test Query", description="Test Description") @@ -33,21 +40,30 @@ class 
TestQueryMetadata(TestCase): self.assertEqual(metadata.query, query) -class TestDashboardCreationExecutorNode: - @pytest.fixture(autouse=True) - def setup_method(self): +class TestDashboardCreationExecutorNode(BaseTest): + def setUp(self): + super().setUp() self.mock_team = MagicMock(spec=Team) self.mock_team.id = 1 self.mock_user = MagicMock(spec=User) self.mock_user.id = 1 - self.node = DashboardCreationExecutorNode(self.mock_team, self.mock_user) + self.node = DashboardCreationExecutorNode( + self.mock_team, + self.mock_user, + ( + NodePath( + name=AssistantNodeName.DASHBOARD_CREATION_EXECUTOR.value, + message_id="test_message_id", + tool_call_id="test_tool_call_id", + ), + ), + ) def test_initialization(self): """Test node initialization.""" assert self.node._team == self.mock_team assert self.node._user == self.mock_user - @pytest.mark.asyncio async def test_aget_input_tuples_search_insights(self): """Test _aget_input_tuples for search_insights tasks.""" tool_calls = [ @@ -62,7 +78,6 @@ class TestDashboardCreationExecutorNode: assert task.name == "search_insights" assert callable_func == self.node._execute_search_insights - @pytest.mark.asyncio async def test_aget_input_tuples_create_insight(self): """Test _aget_input_tuples for create_insight tasks.""" tool_calls = [AssistantToolCall(id="task_1", name="create_insight", args={"query_description": "Test prompt"})] @@ -75,7 +90,6 @@ class TestDashboardCreationExecutorNode: assert task.name == "create_insight" assert callable_func == self.node._execute_create_insight - @pytest.mark.asyncio async def test_aget_input_tuples_unsupported_task(self): """Test _aget_input_tuples raises error for unsupported task type.""" tool_calls = [AssistantToolCall(id="task_1", name="unsupported_type", args={"query": "Test prompt"})] @@ -85,7 +99,6 @@ class TestDashboardCreationExecutorNode: assert "Unsupported task type: unsupported_type" in str(exc_info.value) - @pytest.mark.asyncio async def test_aget_input_tuples_no_tasks(self): """Test _aget_input_tuples returns empty list when no tasks.""" tool_calls: list[AssistantToolCall] = [] @@ -95,14 +108,24 @@ class TestDashboardCreationExecutorNode: assert len(input_tuples) == 0 -class TestDashboardCreationNode: - @pytest.fixture(autouse=True) - def setup_method(self): +class TestDashboardCreationNode(BaseTest): + def setUp(self): + super().setUp() self.mock_team = MagicMock(spec=Team) self.mock_team.id = 1 self.mock_user = MagicMock(spec=User) self.mock_user.id = 1 - self.node = DashboardCreationNode(self.mock_team, self.mock_user) + self.node = DashboardCreationNode( + self.mock_team, + self.mock_user, + ( + NodePath( + name=AssistantNodeName.DASHBOARD_CREATION.value, + message_id="test_message_id", + tool_call_id="test_tool_call_id", + ), + ), + ) def test_get_found_insight_count(self): """Test _get_found_insight_count calculates correct count.""" @@ -140,7 +163,6 @@ class TestDashboardCreationNode: expected_url = f"/project/{self.mock_team.id}/dashboard/{dashboard_id}" assert url == expected_url - @pytest.mark.asyncio @patch("ee.hogai.graph.dashboards.nodes.DashboardCreationExecutorNode") async def test_arun_missing_search_insights_queries(self, mock_executor_node_class): """Test arun returns error when search_insights_queries is missing.""" @@ -150,7 +172,7 @@ class TestDashboardCreationNode: state = AssistantState( dashboard_name="Create dashboard", search_insights_queries=None, - root_tool_call_id="test_call", + root_tool_call_id="test_tool_call_id", ) config = RunnableConfig() @@ -161,7 +183,6 @@ class 
TestDashboardCreationNode: assert isinstance(result.messages[0], AssistantToolCallMessage) assert "Search insights queries are required" in result.messages[0].content - @pytest.mark.asyncio @patch("ee.hogai.graph.dashboards.nodes.DashboardCreationExecutorNode") @patch.object(DashboardCreationNode, "_search_insights") @patch.object(DashboardCreationNode, "_create_insights") @@ -205,7 +226,7 @@ class TestDashboardCreationNode: state = AssistantState( dashboard_name="Create dashboard", search_insights_queries=[InsightQuery(name="Query 1", description="Description 1")], - root_tool_call_id="test_call", + root_tool_call_id="test_tool_call_id", ) config = RunnableConfig() @@ -217,7 +238,6 @@ class TestDashboardCreationNode: assert "Dashboard Created" in result.messages[0].content assert "Test Dashboard" in result.messages[0].content - @pytest.mark.asyncio @patch("ee.hogai.graph.dashboards.nodes.DashboardCreationExecutorNode") @patch.object(DashboardCreationNode, "_search_insights") @patch.object(DashboardCreationNode, "_create_insights") @@ -244,7 +264,7 @@ class TestDashboardCreationNode: state = AssistantState( dashboard_name="Create dashboard", search_insights_queries=[InsightQuery(name="Query 1", description="Description 1")], - root_tool_call_id="test_call", + root_tool_call_id="test_tool_call_id", ) config = RunnableConfig() @@ -255,7 +275,6 @@ class TestDashboardCreationNode: assert isinstance(result.messages[0], AssistantToolCallMessage) assert "No existing insights matched" in result.messages[0].content - @pytest.mark.asyncio @patch("ee.hogai.graph.dashboards.nodes.DashboardCreationExecutorNode") @patch.object(DashboardCreationNode, "_search_insights") @patch("ee.hogai.graph.dashboards.nodes.logger") @@ -268,7 +287,7 @@ class TestDashboardCreationNode: state = AssistantState( dashboard_name="Create dashboard", search_insights_queries=[InsightQuery(name="Query 1", description="Description 1")], - root_tool_call_id="test_call", + root_tool_call_id="test_tool_call_id", ) config = RunnableConfig() @@ -295,7 +314,7 @@ class TestDashboardCreationNode: result = self.node._create_success_response( mock_dashboard, mock_insights, # type: ignore[arg-type] - "test_call", + "test_tool_call_id", ["Query without insights"], ) @@ -309,7 +328,7 @@ class TestDashboardCreationNode: def test_create_no_insights_response(self): """Test _create_no_insights_response creates correct no insights message.""" - result = self.node._create_no_insights_response("test_call", "No insights found") + result = self.node._create_no_insights_response("test_tool_call_id", "No insights found") assert isinstance(result, PartialAssistantState) assert len(result.messages) == 1 @@ -320,27 +339,30 @@ class TestDashboardCreationNode: def test_create_error_response(self): """Test _create_error_response creates correct error message.""" with patch("ee.hogai.graph.dashboards.nodes.capture_exception") as mock_capture: - result = self.node._create_error_response("Test error", "test_call") + result = self.node._create_error_response("Test error", "test_tool_call_id") assert isinstance(result, PartialAssistantState) assert len(result.messages) == 1 assert isinstance(result.messages[0], AssistantToolCallMessage) assert result.messages[0].content == "Test error" assert isinstance(result.messages[0], AssistantToolCallMessage) - assert result.messages[0].tool_call_id == "test_call" + assert result.messages[0].tool_call_id == "test_tool_call_id" mock_capture.assert_called_once() -class TestDashboardCreationNodeAsyncMethods: - 
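The migrations in this test file also surface the new constructor contract: nodes now accept an explicit `node_path`, so tests can pin the path (and the `tool_call_id` it carries) without a running graph. A sketch of the pattern, where `team` and `user` stand in for whatever fixtures the test provides:

```python
# Constructing a node with an explicit node path in a test. The NodePath
# fields mirror this diff; the literal IDs are arbitrary test fixtures.
from ee.hogai.graph.dashboards.nodes import DashboardCreationNode
from ee.hogai.utils.types.base import AssistantNodeName, NodePath

node = DashboardCreationNode(
    team,
    user,
    node_path=(
        NodePath(
            name=AssistantNodeName.DASHBOARD_CREATION.value,
            message_id="test_message_id",
            tool_call_id="test_tool_call_id",
        ),
    ),
)
```

This is presumably also where `self.tool_call_id` in `DashboardCreationNode.arun` now comes from, replacing the removed `_parent_tool_call_id` that used to be copied off `state.root_tool_call_id`.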
@pytest.fixture(autouse=True) - def setup_method(self): +class TestDashboardCreationNodeAsyncMethods(BaseTest): + def setUp(self): + super().setUp() self.mock_team = MagicMock(spec=Team) self.mock_team.id = 1 self.mock_user = MagicMock(spec=User) self.mock_user.id = 1 - self.node = DashboardCreationNode(self.mock_team, self.mock_user) + self.node = DashboardCreationNode( + self.mock_team, + self.mock_user, + node_path=(NodePath(name="test_node", message_id="test-id", tool_call_id="test_tool_call_id"),), + ) - @pytest.mark.asyncio @patch("ee.hogai.graph.dashboards.nodes.DashboardCreationExecutorNode") async def test_create_insights(self, mock_executor_node_class): """Test _create_insights method.""" @@ -389,7 +411,6 @@ class TestDashboardCreationNodeAsyncMethods: assert len(result["task_1"].created_insight_ids) == 2 assert len(result["task_1"].created_insight_messages) == 1 - @pytest.mark.asyncio @patch("ee.hogai.graph.dashboards.nodes.DashboardCreationExecutorNode") async def test_search_insights(self, mock_executor_node_class): """Test _search_insights method.""" @@ -430,7 +451,6 @@ class TestDashboardCreationNodeAsyncMethods: assert len(result["task_1"].found_insight_ids) == 2 assert len(result["task_1"].found_insight_messages) == 2 - @pytest.mark.asyncio @patch("ee.hogai.graph.dashboards.nodes.database_sync_to_async") async def test_create_dashboard_with_insights(self, mock_db_sync): """Test _create_dashboard_with_insights method.""" @@ -452,7 +472,6 @@ class TestDashboardCreationNodeAsyncMethods: assert result[1] == mock_insights mock_sync_func.assert_called_once() - @pytest.mark.asyncio async def test_process_insight_creation_results(self): """Test _process_insight_creation_results method.""" # Create simple mocked Team and User instances like other tests diff --git a/ee/hogai/graph/deep_research/graph.py b/ee/hogai/graph/deep_research/graph.py index 1cb3e68d73..d442f5366f 100644 --- a/ee/hogai/graph/deep_research/graph.py +++ b/ee/hogai/graph/deep_research/graph.py @@ -1,21 +1,25 @@ from typing import Literal, Optional -from posthog.models.team.team import Team -from posthog.models.user import User - from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer +from ee.hogai.graph.base import BaseAssistantGraph from ee.hogai.graph.deep_research.notebook.nodes import DeepResearchNotebookPlanningNode from ee.hogai.graph.deep_research.onboarding.nodes import DeepResearchOnboardingNode from ee.hogai.graph.deep_research.planner.nodes import DeepResearchPlannerNode, DeepResearchPlannerToolsNode from ee.hogai.graph.deep_research.report.nodes import DeepResearchReportNode from ee.hogai.graph.deep_research.task_executor.nodes import DeepResearchTaskExecutorNode -from ee.hogai.graph.deep_research.types import DeepResearchNodeName, DeepResearchState -from ee.hogai.graph.graph import BaseAssistantGraph +from ee.hogai.graph.deep_research.types import DeepResearchNodeName, DeepResearchState, PartialDeepResearchState +from ee.hogai.graph.title_generator.nodes import TitleGeneratorNode +from ee.hogai.utils.types.base import AssistantGraphName -class DeepResearchAssistantGraph(BaseAssistantGraph[DeepResearchState]): - def __init__(self, team: Team, user: User): - super().__init__(team, user, DeepResearchState) +class DeepResearchAssistantGraph(BaseAssistantGraph[DeepResearchState, PartialDeepResearchState]): + @property + def graph_name(self) -> AssistantGraphName: + return AssistantGraphName.DEEP_RESEARCH + + @property + def state_type(self) -> type[DeepResearchState]: + return 
DeepResearchState def add_onboarding_node( self, node_map: Optional[dict[Literal["onboarding", "planning", "continue"], DeepResearchNodeName]] = None @@ -80,13 +84,21 @@ class DeepResearchAssistantGraph(BaseAssistantGraph[DeepResearchState]): builder.add_edge(DeepResearchNodeName.REPORT, next_node) return self + def add_title_generator(self, end_node: DeepResearchNodeName = DeepResearchNodeName.END): + self._has_start_node = True + + title_generator = TitleGeneratorNode(self._team, self._user) + self._graph.add_node(DeepResearchNodeName.TITLE_GENERATOR, title_generator) + self._graph.add_edge(DeepResearchNodeName.START, DeepResearchNodeName.TITLE_GENERATOR) + self._graph.add_edge(DeepResearchNodeName.TITLE_GENERATOR, end_node) + return self + def compile_full_graph(self, checkpointer: DjangoCheckpointer | None = None): return ( self.add_onboarding_node() .add_notebook_nodes() .add_planner_nodes() .add_report_node() - .add_title_generator() .add_task_executor() .compile(checkpointer=checkpointer) ) diff --git a/ee/hogai/graph/deep_research/planner/nodes.py b/ee/hogai/graph/deep_research/planner/nodes.py index 74a04d947a..6c8c1e0258 100644 --- a/ee/hogai/graph/deep_research/planner/nodes.py +++ b/ee/hogai/graph/deep_research/planner/nodes.py @@ -40,8 +40,8 @@ from ee.hogai.graph.deep_research.types import ( DeepResearchNodeName, DeepResearchState, PartialDeepResearchState, + TodoItem, ) -from ee.hogai.graph.root.tools.todo_write import TodoItem from ee.hogai.notebook.notebook_serializer import NotebookSerializer from ee.hogai.utils.helpers import normalize_ai_message from ee.hogai.utils.types import WithCommentary diff --git a/ee/hogai/graph/deep_research/planner/test/test_nodes.py b/ee/hogai/graph/deep_research/planner/test/test_nodes.py index 74f772db10..2c78a6e490 100644 --- a/ee/hogai/graph/deep_research/planner/test/test_nodes.py +++ b/ee/hogai/graph/deep_research/planner/test/test_nodes.py @@ -34,8 +34,7 @@ from ee.hogai.graph.deep_research.planner.prompts import ( WRITE_RESULT_FAILED_TOOL_RESULT, WRITE_RESULT_TOOL_RESULT, ) -from ee.hogai.graph.deep_research.types import DeepResearchState, PartialDeepResearchState -from ee.hogai.graph.root.tools.todo_write import TodoItem +from ee.hogai.graph.deep_research.types import DeepResearchState, PartialDeepResearchState, TodoItem from ee.hogai.utils.types import InsightArtifact from ee.hogai.utils.types.base import TaskResult diff --git a/ee/hogai/graph/deep_research/task_executor/test/test_nodes.py b/ee/hogai/graph/deep_research/task_executor/test/test_nodes.py index ab0f6b06c6..79c8545eac 100644 --- a/ee/hogai/graph/deep_research/task_executor/test/test_nodes.py +++ b/ee/hogai/graph/deep_research/task_executor/test/test_nodes.py @@ -128,7 +128,7 @@ class TestTaskExecutorNodeArun(TestTaskExecutorNode): class TestTaskExecutorInsightsExecution(TestTaskExecutorNode): @patch("ee.hogai.graph.deep_research.task_executor.nodes.DeepResearchTaskExecutorNode.dispatcher") - @patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsAssistantGraph") + @patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsGraph") async def test_execute_task_with_insights_successful(self, mock_insights_graph_class, mock_dispatcher): """Test successful task execution through insights pipeline.""" @@ -171,7 +171,7 @@ class TestTaskExecutorInsightsExecution(TestTaskExecutorNode): self.assertEqual(len(result.artifacts), 1) @patch("ee.hogai.graph.deep_research.task_executor.nodes.DeepResearchTaskExecutorNode.dispatcher") - 
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsAssistantGraph") + @patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsGraph") async def test_execute_task_with_insights_no_artifacts(self, mock_insights_graph_class, mock_dispatcher): """Test task execution that produces no artifacts.""" task = self._create_assistant_tool_call() @@ -209,7 +209,7 @@ class TestTaskExecutorInsightsExecution(TestTaskExecutorNode): @patch("ee.hogai.graph.deep_research.task_executor.nodes.capture_exception") @patch("ee.hogai.graph.deep_research.task_executor.nodes.DeepResearchTaskExecutorNode.dispatcher") - @patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsAssistantGraph") + @patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsGraph") async def test_execute_task_with_exception(self, mock_insights_graph_class, mock_dispatcher, mock_capture): """Test task execution that encounters an exception.""" task = self._create_assistant_tool_call() diff --git a/ee/hogai/graph/deep_research/test/test_integration.py b/ee/hogai/graph/deep_research/test/test_integration.py index 811a9ad651..621c51364f 100644 --- a/ee/hogai/graph/deep_research/test/test_integration.py +++ b/ee/hogai/graph/deep_research/test/test_integration.py @@ -27,8 +27,7 @@ from ee.hogai.graph.deep_research.onboarding.nodes import DeepResearchOnboarding from ee.hogai.graph.deep_research.planner.nodes import DeepResearchPlannerNode, DeepResearchPlannerToolsNode from ee.hogai.graph.deep_research.report.nodes import DeepResearchReportNode from ee.hogai.graph.deep_research.task_executor.nodes import DeepResearchTaskExecutorNode -from ee.hogai.graph.deep_research.types import DeepResearchIntermediateResult, DeepResearchState -from ee.hogai.graph.root.tools.todo_write import TodoItem +from ee.hogai.graph.deep_research.types import DeepResearchIntermediateResult, DeepResearchState, TodoItem from ee.hogai.utils.types.base import TaskResult from ee.models.assistant import Conversation diff --git a/ee/hogai/graph/deep_research/test/test_types.py b/ee/hogai/graph/deep_research/test/test_types.py index 06a9c835c7..4927596bc5 100644 --- a/ee/hogai/graph/deep_research/test/test_types.py +++ b/ee/hogai/graph/deep_research/test/test_types.py @@ -12,9 +12,9 @@ from ee.hogai.graph.deep_research.types import ( DeepResearchIntermediateResult, DeepResearchState, PartialDeepResearchState, + TodoItem, _SharedDeepResearchState, ) -from ee.hogai.graph.root.tools.todo_write import TodoItem from ee.hogai.utils.types.base import InsightArtifact, TaskResult """ diff --git a/ee/hogai/graph/deep_research/types.py b/ee/hogai/graph/deep_research/types.py index 49c69c2e8a..c6099f3052 100644 --- a/ee/hogai/graph/deep_research/types.py +++ b/ee/hogai/graph/deep_research/types.py @@ -1,19 +1,31 @@ from collections.abc import Sequence from enum import StrEnum -from typing import Annotated, Optional +from typing import Annotated, Literal, Optional from langgraph.graph import END, START from pydantic import BaseModel, Field from posthog.schema import DeepResearchNotebook -from ee.hogai.graph.root.tools.todo_write import TodoItem -from ee.hogai.utils.types import AssistantMessageUnion, add_and_merge_messages -from ee.hogai.utils.types.base import BaseStateWithMessages, BaseStateWithTasks, append, replace +from ee.hogai.utils.types.base import ( + AssistantMessageUnion, + BaseStateWithMessages, + BaseStateWithTasks, + add_and_merge_messages, + append, + replace, +) NotebookInfo = DeepResearchNotebook +class TodoItem(BaseModel): + content: str = 
Field(..., min_length=1) + status: Literal["pending", "in_progress", "completed"] + id: str + priority: Literal["low", "medium", "high"] + + class DeepResearchIntermediateResult(BaseModel): """ An intermediate result of a batch of work, that will be used to write the final report. @@ -69,3 +81,4 @@ class DeepResearchNodeName(StrEnum): PLANNER_TOOLS = "planner_tools" TASK_EXECUTOR = "task_executor" REPORT = "report" + TITLE_GENERATOR = "title_generator" diff --git a/ee/hogai/graph/funnels/nodes.py b/ee/hogai/graph/funnels/nodes.py index 766c8fe09d..e062c15b2b 100644 --- a/ee/hogai/graph/funnels/nodes.py +++ b/ee/hogai/graph/funnels/nodes.py @@ -1,7 +1,7 @@ from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnableConfig -from posthog.schema import AssistantFunnelsQuery, AssistantMessage +from posthog.schema import AssistantFunnelsQuery from ee.hogai.utils.types import AssistantState, PartialAssistantState from ee.hogai.utils.types.base import AssistantNodeName @@ -25,7 +25,7 @@ class FunnelGeneratorNode(SchemaGeneratorNode[AssistantFunnelsQuery]): return AssistantNodeName.FUNNEL_GENERATOR async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: - self.dispatcher.message(AssistantMessage(content="Creating funnel query")) + self.dispatcher.update("Creating funnel query") prompt = ChatPromptTemplate.from_messages( [ ("system", FUNNEL_SYSTEM_PROMPT), diff --git a/ee/hogai/graph/graph.py b/ee/hogai/graph/graph.py index 38da274d11..92598694c0 100644 --- a/ee/hogai/graph/graph.py +++ b/ee/hogai/graph/graph.py @@ -1,23 +1,11 @@ -from collections.abc import Hashable -from typing import Any, Generic, Literal, Optional, cast - -from langgraph.graph.state import StateGraph - -from posthog.models.team.team import Team -from posthog.models.user import User +from collections.abc import Callable +from typing import cast from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer -from ee.hogai.graph.billing.nodes import BillingNode -from ee.hogai.graph.query_planner.nodes import QueryPlannerNode, QueryPlannerToolsNode -from ee.hogai.graph.session_summaries.nodes import SessionSummarizationNode +from ee.hogai.graph.base import BaseAssistantGraph from ee.hogai.graph.title_generator.nodes import TitleGeneratorNode -from ee.hogai.utils.types import AssistantNodeName, AssistantState, StateType -from ee.hogai.utils.types.composed import MaxNodeName +from ee.hogai.utils.types.base import AssistantGraphName, AssistantNodeName, AssistantState, PartialAssistantState -from .dashboards.nodes import DashboardCreationNode -from .funnels.nodes import FunnelGeneratorNode, FunnelGeneratorToolsNode -from .inkeep_docs.nodes import InkeepDocsNode -from .insights.nodes import InsightSearchNode from .memory.nodes import ( MemoryCollectorNode, MemoryCollectorToolsNode, @@ -28,51 +16,19 @@ from .memory.nodes import ( MemoryOnboardingFinalizeNode, MemoryOnboardingNode, ) -from .query_executor.nodes import QueryExecutorNode -from .rag.nodes import InsightRagContextNode -from .retention.nodes import RetentionGeneratorNode, RetentionGeneratorToolsNode from .root.nodes import RootNode, RootNodeTools -from .sql.nodes import SQLGeneratorNode, SQLGeneratorToolsNode -from .trends.nodes import TrendsGeneratorNode, TrendsGeneratorToolsNode - -global_checkpointer = DjangoCheckpointer() -class BaseAssistantGraph(Generic[StateType]): - _team: Team - _user: User - _graph: StateGraph - _parent_tool_call_id: str | None +class 
AssistantGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]): + @property + def graph_name(self) -> AssistantGraphName: + return AssistantGraphName.ASSISTANT - def __init__(self, team: Team, user: User, state_type: type[StateType], parent_tool_call_id: str | None = None): - self._team = team - self._user = user - self._graph = StateGraph(state_type) - self._has_start_node = False - self._parent_tool_call_id = parent_tool_call_id + @property + def state_type(self) -> type[AssistantState]: + return AssistantState - def add_edge(self, from_node: MaxNodeName, to_node: MaxNodeName): - if from_node == AssistantNodeName.START: - self._has_start_node = True - self._graph.add_edge(from_node, to_node) - return self - - def add_node(self, node: MaxNodeName, action: Any): - if self._parent_tool_call_id: - action._parent_tool_call_id = self._parent_tool_call_id - self._graph.add_node(node, action) - return self - - def compile(self, checkpointer: DjangoCheckpointer | None | Literal[False] = None): - if not self._has_start_node: - raise ValueError("Start node not added to the graph") - # TRICKY: We check `is not None` because False has a special meaning of "no checkpointer", which we want to pass on - compiled_graph = self._graph.compile( - checkpointer=checkpointer if checkpointer is not None else global_checkpointer - ) - return compiled_graph - - def add_title_generator(self, end_node: MaxNodeName = AssistantNodeName.END): + def add_title_generator(self, end_node: AssistantNodeName = AssistantNodeName.END): self._has_start_node = True title_generator = TitleGeneratorNode(self._team, self._user) @@ -81,178 +37,19 @@ class BaseAssistantGraph(Generic[StateType]): self._graph.add_edge(AssistantNodeName.TITLE_GENERATOR, end_node) return self - -class InsightsAssistantGraph(BaseAssistantGraph[AssistantState]): - def __init__(self, team: Team, user: User, tool_call_id: str | None = None): - super().__init__(team, user, AssistantState, tool_call_id) - - def add_rag_context(self): - self._has_start_node = True - retriever = InsightRagContextNode(self._team, self._user) - self.add_node(AssistantNodeName.INSIGHT_RAG_CONTEXT, retriever) - self._graph.add_edge(AssistantNodeName.START, AssistantNodeName.INSIGHT_RAG_CONTEXT) - self._graph.add_edge(AssistantNodeName.INSIGHT_RAG_CONTEXT, AssistantNodeName.QUERY_PLANNER) - return self - - def add_trends_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR): - trends_generator = TrendsGeneratorNode(self._team, self._user) - self.add_node(AssistantNodeName.TRENDS_GENERATOR, trends_generator) - - trends_generator_tools = TrendsGeneratorToolsNode(self._team, self._user) - self.add_node(AssistantNodeName.TRENDS_GENERATOR_TOOLS, trends_generator_tools) - - self._graph.add_edge(AssistantNodeName.TRENDS_GENERATOR_TOOLS, AssistantNodeName.TRENDS_GENERATOR) - self._graph.add_conditional_edges( - AssistantNodeName.TRENDS_GENERATOR, - trends_generator.router, - path_map={ - "tools": AssistantNodeName.TRENDS_GENERATOR_TOOLS, - "next": next_node, - }, - ) - - return self - - def add_funnel_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR): - funnel_generator = FunnelGeneratorNode(self._team, self._user) - self.add_node(AssistantNodeName.FUNNEL_GENERATOR, funnel_generator) - - funnel_generator_tools = FunnelGeneratorToolsNode(self._team, self._user) - self.add_node(AssistantNodeName.FUNNEL_GENERATOR_TOOLS, funnel_generator_tools) - - self._graph.add_edge(AssistantNodeName.FUNNEL_GENERATOR_TOOLS, 
AssistantNodeName.FUNNEL_GENERATOR) - self._graph.add_conditional_edges( - AssistantNodeName.FUNNEL_GENERATOR, - funnel_generator.router, - path_map={ - "tools": AssistantNodeName.FUNNEL_GENERATOR_TOOLS, - "next": next_node, - }, - ) - - return self - - def add_retention_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR): - retention_generator = RetentionGeneratorNode(self._team, self._user) - self.add_node(AssistantNodeName.RETENTION_GENERATOR, retention_generator) - - retention_generator_tools = RetentionGeneratorToolsNode(self._team, self._user) - self.add_node(AssistantNodeName.RETENTION_GENERATOR_TOOLS, retention_generator_tools) - - self._graph.add_edge(AssistantNodeName.RETENTION_GENERATOR_TOOLS, AssistantNodeName.RETENTION_GENERATOR) - self._graph.add_conditional_edges( - AssistantNodeName.RETENTION_GENERATOR, - retention_generator.router, - path_map={ - "tools": AssistantNodeName.RETENTION_GENERATOR_TOOLS, - "next": next_node, - }, - ) - - return self - - def add_query_planner( - self, - path_map: Optional[ - dict[Literal["trends", "funnel", "retention", "sql", "continue", "end"], AssistantNodeName] - ] = None, - ): - query_planner = QueryPlannerNode(self._team, self._user) - self.add_node(AssistantNodeName.QUERY_PLANNER, query_planner) - self._graph.add_edge(AssistantNodeName.QUERY_PLANNER, AssistantNodeName.QUERY_PLANNER_TOOLS) - - query_planner_tools = QueryPlannerToolsNode(self._team, self._user) - self.add_node(AssistantNodeName.QUERY_PLANNER_TOOLS, query_planner_tools) - self._graph.add_conditional_edges( - AssistantNodeName.QUERY_PLANNER_TOOLS, - query_planner_tools.router, - path_map=path_map # type: ignore - or { - "continue": AssistantNodeName.QUERY_PLANNER, - "trends": AssistantNodeName.TRENDS_GENERATOR, - "funnel": AssistantNodeName.FUNNEL_GENERATOR, - "retention": AssistantNodeName.RETENTION_GENERATOR, - "sql": AssistantNodeName.SQL_GENERATOR, - "end": AssistantNodeName.END, - }, - ) - - return self - - def add_sql_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR): - sql_generator = SQLGeneratorNode(self._team, self._user) - self.add_node(AssistantNodeName.SQL_GENERATOR, sql_generator) - - sql_generator_tools = SQLGeneratorToolsNode(self._team, self._user) - self.add_node(AssistantNodeName.SQL_GENERATOR_TOOLS, sql_generator_tools) - - self._graph.add_edge(AssistantNodeName.SQL_GENERATOR_TOOLS, AssistantNodeName.SQL_GENERATOR) - self._graph.add_conditional_edges( - AssistantNodeName.SQL_GENERATOR, - sql_generator.router, - path_map={ - "tools": AssistantNodeName.SQL_GENERATOR_TOOLS, - "next": next_node, - }, - ) - - return self - - def add_query_executor(self, next_node: AssistantNodeName = AssistantNodeName.END): - query_executor_node = QueryExecutorNode(self._team, self._user) - self.add_node(AssistantNodeName.QUERY_EXECUTOR, query_executor_node) - self._graph.add_edge(AssistantNodeName.QUERY_EXECUTOR, next_node) - return self - - def add_query_creation_flow(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR): - """Add all nodes and edges EXCEPT query execution.""" - return ( - self.add_rag_context() - .add_query_planner() - .add_trends_generator(next_node=next_node) - .add_funnel_generator(next_node=next_node) - .add_retention_generator(next_node=next_node) - .add_sql_generator(next_node=next_node) - ) - - def compile_full_graph(self, checkpointer: DjangoCheckpointer | None = None): - return self.add_query_creation_flow().add_query_executor().compile(checkpointer=checkpointer) - - -class 
AssistantGraph(BaseAssistantGraph[AssistantState]): - def __init__(self, team: Team, user: User): - super().__init__(team, user, AssistantState) - - def add_root( - self, - path_map: Optional[dict[Hashable, AssistantNodeName]] = None, - tools_node: AssistantNodeName = AssistantNodeName.ROOT_TOOLS, - ): - path_map = path_map or { - "insights": AssistantNodeName.INSIGHTS_SUBGRAPH, - "search_documentation": AssistantNodeName.INKEEP_DOCS, - "root": AssistantNodeName.ROOT, - "billing": AssistantNodeName.BILLING, - "end": AssistantNodeName.END, - "insights_search": AssistantNodeName.INSIGHTS_SEARCH, - "session_summarization": AssistantNodeName.SESSION_SUMMARIZATION, - "create_dashboard": AssistantNodeName.DASHBOARD_CREATION, - } + def add_root(self, router: Callable[[AssistantState], AssistantNodeName] | None = None): root_node = RootNode(self._team, self._user) self.add_node(AssistantNodeName.ROOT, root_node) root_node_tools = RootNodeTools(self._team, self._user) self.add_node(AssistantNodeName.ROOT_TOOLS, root_node_tools) - self._graph.add_edge(AssistantNodeName.ROOT, tools_node) self._graph.add_conditional_edges( - AssistantNodeName.ROOT_TOOLS, root_node_tools.router, path_map=cast(dict[Hashable, str], path_map) + AssistantNodeName.ROOT, router or cast(Callable[[AssistantState], AssistantNodeName], root_node.router) + ) + self._graph.add_conditional_edges( + AssistantNodeName.ROOT_TOOLS, + root_node_tools.router, + path_map={"root": AssistantNodeName.ROOT, "end": AssistantNodeName.END}, ) - return self - - def add_insights(self, next_node: AssistantNodeName = AssistantNodeName.ROOT): - insights_assistant_graph = InsightsAssistantGraph(self._team, self._user) - compiled_graph = insights_assistant_graph.compile_full_graph() - self.add_node(AssistantNodeName.INSIGHTS_SUBGRAPH, compiled_graph) - self._graph.add_edge(AssistantNodeName.INSIGHTS_SUBGRAPH, next_node) return self def add_memory_onboarding( @@ -332,55 +129,6 @@ class AssistantGraph(BaseAssistantGraph[AssistantState]): self._graph.add_edge(AssistantNodeName.MEMORY_COLLECTOR_TOOLS, AssistantNodeName.MEMORY_COLLECTOR) return self - def add_inkeep_docs(self, path_map: Optional[dict[Hashable, AssistantNodeName]] = None): - """Add the Inkeep docs search node to the graph.""" - path_map = path_map or { - "end": AssistantNodeName.END, - "root": AssistantNodeName.ROOT, - } - inkeep_docs_node = InkeepDocsNode(self._team, self._user) - self.add_node(AssistantNodeName.INKEEP_DOCS, inkeep_docs_node) - self._graph.add_conditional_edges( - AssistantNodeName.INKEEP_DOCS, - inkeep_docs_node.router, - path_map=cast(dict[Hashable, str], path_map), - ) - return self - - def add_billing(self): - billing_node = BillingNode(self._team, self._user) - self.add_node(AssistantNodeName.BILLING, billing_node) - self._graph.add_edge(AssistantNodeName.BILLING, AssistantNodeName.ROOT) - return self - - def add_insights_search(self, end_node: AssistantNodeName = AssistantNodeName.END): - path_map = { - "end": end_node, - "root": AssistantNodeName.ROOT, - } - - insights_search_node = InsightSearchNode(self._team, self._user) - self.add_node(AssistantNodeName.INSIGHTS_SEARCH, insights_search_node) - self._graph.add_conditional_edges( - AssistantNodeName.INSIGHTS_SEARCH, - insights_search_node.router, - path_map=cast(dict[Hashable, str], path_map), - ) - return self - - def add_session_summarization(self, end_node: AssistantNodeName = AssistantNodeName.END): - session_summarization_node = SessionSummarizationNode(self._team, self._user) - 
self.add_node(AssistantNodeName.SESSION_SUMMARIZATION, session_summarization_node) - self._graph.add_edge(AssistantNodeName.SESSION_SUMMARIZATION, AssistantNodeName.ROOT) - return self - - def add_dashboard_creation(self, end_node: AssistantNodeName = AssistantNodeName.END): - builder = self._graph - dashboard_creation_node = DashboardCreationNode(self._team, self._user) - builder.add_node(AssistantNodeName.DASHBOARD_CREATION, dashboard_creation_node) - builder.add_edge(AssistantNodeName.DASHBOARD_CREATION, AssistantNodeName.ROOT) - return self - def compile_full_graph(self, checkpointer: DjangoCheckpointer | None = None): return ( self.add_title_generator() @@ -388,11 +136,5 @@ class AssistantGraph(BaseAssistantGraph[AssistantState]): .add_memory_collector() .add_memory_collector_tools() .add_root() - .add_insights() - .add_inkeep_docs() - .add_billing() - .add_insights_search() - .add_session_summarization() - .add_dashboard_creation() .compile(checkpointer=checkpointer) ) diff --git a/ee/hogai/graph/inkeep_docs/nodes.py b/ee/hogai/graph/inkeep_docs/nodes.py index f4d8979789..7ab0dc6837 100644 --- a/ee/hogai/graph/inkeep_docs/nodes.py +++ b/ee/hogai/graph/inkeep_docs/nodes.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Literal +from typing import Any from uuid import uuid4 from django.conf import settings @@ -34,19 +34,22 @@ class InkeepDocsNode(RootNode): # Inheriting from RootNode to use the same mess async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: """Process the state and return documentation search results.""" - self.dispatcher.message(AssistantMessage(content="Checking PostHog documentation...", id=str(uuid4()))) + self.dispatcher.update("Checking PostHog documentation...") + messages = self._construct_messages( state.messages, state.root_conversation_start_id, state.root_tool_calls_count ) + message: LangchainAIMessage = await self._get_model().ainvoke(messages, config) - # NOTE: This is a hacky way to send these messages as part of the root tool call - # Can't think of a better interface for this at the moment. - self.dispatcher.set_as_root() + should_continue = INKEEP_DATA_CONTINUATION_PHRASE in message.content + + tool_prompt = "Checking PostHog documentation..." + if should_continue: + tool_prompt = "The documentation search results are provided in the next Assistant message.\nContinue with the user's data request." + return PartialAssistantState( messages=[ - AssistantToolCallMessage( - content="Checking PostHog documentation...", tool_call_id=state.root_tool_call_id, id=str(uuid4()) - ), + AssistantToolCallMessage(content=tool_prompt, tool_call_id=state.root_tool_call_id, id=str(uuid4())), AssistantMessage(content=message.content, id=str(uuid4())), ], root_tool_call_id=None, @@ -116,16 +119,3 @@ class InkeepDocsNode(RootNode): # Inheriting from RootNode to use the same mess ) -> list[BaseMessage]: # Original node has Anthropic messages, but Inkeep expects OpenAI messages return convert_to_openai_messages(conversation_window, tool_result_messages) - - def router(self, state: AssistantState) -> Literal["end", "root"]: - last_message = state.messages[-1] - if isinstance(last_message, AssistantMessage) and INKEEP_DATA_CONTINUATION_PHRASE in last_message.content: - # The continuation phrase solution is a little weird, but seems it's the best one for agentic capabilities - # I've found here. The alternatives that definitively don't work are: - # 1. 
Using tool calls in this node - the Inkeep API only supports providing their own pre-defined tools - # (for including extra search metadata), nothing else - # 2. Always going back to root, for root to judge whether to continue or not - GPT-4o is terrible at this, - # and I was unable to stop it from repeating the context from the last assistant message, i.e. the Inkeep - # output message (doesn't quite work to tell it to output an empty message, or to call an "end" tool) - return "root" - return "end" diff --git a/ee/hogai/graph/inkeep_docs/test/test_nodes.py b/ee/hogai/graph/inkeep_docs/test/test_nodes.py index 0afdda3d4a..5b7bed7c1d 100644 --- a/ee/hogai/graph/inkeep_docs/test/test_nodes.py +++ b/ee/hogai/graph/inkeep_docs/test/test_nodes.py @@ -67,6 +67,9 @@ class TestInkeepDocsNode(ClickhouseTestMixin, BaseTest): assert next_state is not None messages = cast(list, next_state.messages) self.assertEqual(len(messages), 2) + # Tool call message should have the continuation prompt + first_message = cast(AssistantToolCallMessage, messages[0]) + self.assertIn("Continue with the user's data request", first_message.content) second_message = cast(AssistantMessage, messages[1]) self.assertEqual(second_message.content, response_with_continuation) @@ -91,26 +94,6 @@ class TestInkeepDocsNode(ClickhouseTestMixin, BaseTest): self.assertIsInstance(messages[2], LangchainAIMessage) self.assertIsInstance(messages[3], LangchainHumanMessage) - def test_router_with_data_continuation(self): - node = InkeepDocsNode(self.team, self.user) - state = AssistantState( - messages=[ - HumanMessage(content="Explain PostHog trends, and show me an example trends insight"), - AssistantMessage(content=f"Here's the documentation: XYZ.\n{INKEEP_DATA_CONTINUATION_PHRASE}"), - ] - ) - self.assertEqual(node.router(state), "root") # Going back to root, so that the agent can continue with the task - - def test_router_without_data_continuation(self): - node = InkeepDocsNode(self.team, self.user) - state = AssistantState( - messages=[ - HumanMessage(content="How do I use feature flags?"), - AssistantMessage(content="Here's how to use feature flags..."), - ] - ) - self.assertEqual(node.router(state), "end") # Ending - async def test_tool_call_id_handling(self): """Test that tool_call_id is properly handled in both input and output states.""" test_tool_call_id = str(uuid4()) diff --git a/ee/hogai/graph/insights/nodes.py b/ee/hogai/graph/insights/nodes.py index 9939ea7831..4283db8693 100644 --- a/ee/hogai/graph/insights/nodes.py +++ b/ee/hogai/graph/insights/nodes.py @@ -4,7 +4,7 @@ import inspect import warnings from datetime import timedelta from functools import wraps -from typing import Literal, Optional, TypedDict +from typing import TYPE_CHECKING, Optional, TypedDict from uuid import uuid4 from django.db.models import Max @@ -16,7 +16,7 @@ from langchain_core.runnables import RunnableConfig from langchain_core.tools import tool from langchain_openai import ChatOpenAI -from posthog.schema import AssistantMessage, AssistantToolCallMessage, VisualizationMessage +from posthog.schema import AssistantToolCallMessage, VisualizationMessage from posthog.exceptions_capture import capture_exception from posthog.models import Insight @@ -28,18 +28,19 @@ from ee.hogai.graph.shared_prompts import HYPERLINK_USAGE_INSTRUCTIONS from ee.hogai.utils.helpers import build_insight_url from ee.hogai.utils.types import AssistantState, PartialAssistantState from ee.hogai.utils.types.base import AssistantNodeName -from ee.hogai.utils.types.composed 
import MaxNodeName from .prompts import ( - EMPTY_DATABASE_ERROR_MESSAGE, ITERATIVE_SEARCH_SYSTEM_PROMPT, ITERATIVE_SEARCH_USER_PROMPT, NO_INSIGHTS_FOUND_MESSAGE, PAGINATION_INSTRUCTIONS_TEMPLATE, - SEARCH_ERROR_INSTRUCTIONS, TOOL_BASED_EVALUATION_SYSTEM_PROMPT, ) +if TYPE_CHECKING: + from ee.hogai.utils.types.composed import MaxNodeName + + logger = structlog.get_logger(__name__) # Silence Pydantic serializer warnings for creation of VisualizationMessage/Query execution warnings.filterwarnings("ignore", category=UserWarning, message=".*Pydantic serializer.*") @@ -105,6 +106,10 @@ class InsightDict(TypedDict): short_id: str +class NoInsightsException(Exception): + """Exception indicating that the insight search cannot be done because the user does not have any insights.""" + + class InsightSearchNode(AssistantNode): PAGE_SIZE = 500 MAX_SEARCH_ITERATIONS = 6 @@ -114,7 +119,7 @@ class InsightSearchNode(AssistantNode): MAX_SERIES_TO_PROCESS = 3 @property - def node_name(self) -> MaxNodeName: + def node_name(self) -> "MaxNodeName": return AssistantNodeName.INSIGHTS_SEARCH def __init__(self, *args, **kwargs): @@ -133,9 +138,31 @@ class InsightSearchNode(AssistantNode): self._query_cache = {} self._insight_id_cache = {} - def _dispatch_update_message(self, content: str) -> None: - """Dispatch an update message to the assistant.""" - self.dispatcher.message(AssistantMessage(content=content)) + @timing_logger("InsightSearchNode.arun") + async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None: + self.dispatcher.update("Searching for insights") + search_query = state.search_insights_query + self._current_iteration = 0 + + total_count = await self._get_total_insights_count() + if total_count == 0: + raise NoInsightsException + + selected_insights = await self._search_insights_iteratively(search_query or "") + logger.warning( + f"{TIMING_LOG_PREFIX} search_insights_iteratively returned {len(selected_insights)} insights: {selected_insights}" + ) + + if selected_insights: + self.dispatcher.update(f"Evaluating {len(selected_insights)} insights to find the best match") + else: + self.dispatcher.update("No existing insights found, creating a new one") + + evaluation_result = await self._evaluate_insights_with_tools( + selected_insights, search_query or "", max_selections=1 + ) + + return self._handle_evaluation_result(evaluation_result, state) def _create_page_reader_tool(self): """Create tool for reading insights pages during agentic RAG loop.""" @@ -185,36 +212,6 @@ class InsightSearchNode(AssistantNode): return [select_insight, reject_all_insights] - @timing_logger("InsightSearchNode.arun") - async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None: - self._dispatch_update_message("Searching for insights") - search_query = state.search_insights_query - try: - self._current_iteration = 0 - - total_count = await self._get_total_insights_count() - if total_count == 0: - return self._handle_empty_database(state) - - selected_insights = await self._search_insights_iteratively(search_query or "") - logger.warning( - f"{TIMING_LOG_PREFIX} search_insights_iteratively returned {len(selected_insights)} insights: {selected_insights}" - ) - - if selected_insights: - self._dispatch_update_message(f"Evaluating {len(selected_insights)} insights to find the best match") - else: - self._dispatch_update_message("No existing insights found, creating a new one") - - evaluation_result = await self._evaluate_insights_with_tools( - 
selected_insights, search_query or "", max_selections=1 - ) - - return self._handle_evaluation_result(evaluation_result, state) - - except Exception as e: - return self._handle_search_error(e, state) - @timing_logger("InsightSearchNode._get_insights_queryset") def _get_insights_queryset(self): """Get Insight objects with latest view time annotated and cutoff date.""" @@ -235,10 +232,6 @@ class InsightSearchNode(AssistantNode): self._total_insights_count = await self._get_insights_queryset().acount() return self._total_insights_count - def _handle_empty_database(self, state: AssistantState) -> PartialAssistantState: - """Handle the case when no insights exist in the database. (Rare edge-case but still possible)""" - return self._create_error_response(EMPTY_DATABASE_ERROR_MESSAGE, state.root_tool_call_id) - def _handle_evaluation_result(self, evaluation_result: dict, state: AssistantState) -> PartialAssistantState: """Process the evaluation result and return appropriate response.""" if evaluation_result["should_use_existing"]: @@ -256,12 +249,12 @@ class InsightSearchNode(AssistantNode): return PartialAssistantState( messages=[ + *evaluation_result["visualization_messages"], AssistantToolCallMessage( content=formatted_content, tool_call_id=state.root_tool_call_id or "unknown", id=str(uuid4()), ), - *evaluation_result["visualization_messages"], ], selected_insight_ids=evaluation_result["selected_insights"], search_insights_query=None, @@ -284,16 +277,6 @@ class InsightSearchNode(AssistantNode): selected_insight_ids=None, ) - @timing_logger("InsightSearchNode._handle_search_error") - def _handle_search_error(self, e: Exception, state: AssistantState) -> PartialAssistantState: - """Handle exceptions during search process.""" - capture_exception(e) - logger.error(f"{TIMING_LOG_PREFIX} Error in InsightSearchNode: {e}", exc_info=True) - return self._create_error_response( - SEARCH_ERROR_INSTRUCTIONS, - state.root_tool_call_id, - ) - def _format_insight_for_display(self, insight: InsightDict) -> str: """Format a single insight for display.""" name = insight["name"] or insight["derived_name"] or "Unnamed" @@ -424,7 +407,7 @@ class InsightSearchNode(AssistantNode): selected_insights = [] for step in ["Searching through existing insights", "Analyzing available insights"]: - self._dispatch_update_message(step) + self.dispatcher.update(step) logger.warning(f"{TIMING_LOG_PREFIX} Starting iterative search, max_iterations={self._max_iterations}") while self._current_iteration < self._max_iterations: @@ -446,7 +429,7 @@ class InsightSearchNode(AssistantNode): logger.warning( f"{TIMING_LOG_PREFIX} STALL POINT(?): Streamed 'Finding the most relevant insights' - about to fetch page content" ) - self._dispatch_update_message("Finding the most relevant insights") + self.dispatcher.update("Finding the most relevant insights") logger.warning(f"{TIMING_LOG_PREFIX} Fetching page content for page {page_num}") tool_response = await self._get_page_content_for_tool(page_num) @@ -465,15 +448,15 @@ class InsightSearchNode(AssistantNode): content = response.content if isinstance(response.content, str) else str(response.content) selected_insights = self._parse_insight_ids(content) if selected_insights: - self._dispatch_update_message(f"Found {len(selected_insights)} relevant insights") + self.dispatcher.update(f"Found {len(selected_insights)} relevant insights") else: - self._dispatch_update_message("No matching insights found") + self.dispatcher.update("No matching insights found") break except Exception as e: 
capture_exception(e) error_message = f"Error during search" - self._dispatch_update_message(error_message) + self.dispatcher.update(error_message) break return selected_insights @@ -692,7 +675,7 @@ class InsightSearchNode(AssistantNode): """Create a VisualizationMessage to render the insight UI.""" try: for step in ["Executing insight query...", "Processing query parameters", "Running data analysis"]: - self._dispatch_update_message(step) + self.dispatcher.update(step) query_obj, _ = await self._process_insight_query(insight) @@ -786,7 +769,7 @@ class InsightSearchNode(AssistantNode): async def _run_evaluation_loop(self, user_query: str, insights_summary: list[str], max_selections: int) -> None: """Run the evaluation loop with LLM.""" for step in ["Analyzing insights to match your request", "Comparing insights for best fit"]: - self._dispatch_update_message(step) + self.dispatcher.update(step) tools = self._create_insight_evaluation_tools() llm_with_tools = self._model.bind_tools(tools) @@ -800,7 +783,7 @@ class InsightSearchNode(AssistantNode): if getattr(response, "tool_calls", None): # Only stream on first iteration to avoid noise if iteration == 0: - self._dispatch_update_message("Making evaluation decisions") + self.dispatcher.update("Making evaluation decisions") self._process_evaluation_tool_calls(response, messages, tools) else: break @@ -841,7 +824,7 @@ class InsightSearchNode(AssistantNode): num_insights = len(self._evaluation_selections) insight_word = "insight" if num_insights == 1 else "insights" - self._dispatch_update_message(f"Perfect! Found {num_insights} suitable {insight_word}") + self.dispatcher.update(f"Perfect! Found {num_insights} suitable {insight_word}") for _, selection in self._evaluation_selections.items(): insight = selection["insight"] @@ -871,7 +854,7 @@ class InsightSearchNode(AssistantNode): async def _create_rejection_result(self) -> dict: """Create result for when all insights are rejected.""" - self._dispatch_update_message("Will create a custom insight tailored to your request") + self.dispatcher.update("Will create a custom insight tailored to your request") return { "should_use_existing": False, @@ -880,9 +863,6 @@ class InsightSearchNode(AssistantNode): "visualization_messages": [], } - def router(self, state: AssistantState) -> Literal["root"]: - return "root" - @property def _model(self): return ChatOpenAI( diff --git a/ee/hogai/graph/insights/prompts.py b/ee/hogai/graph/insights/prompts.py index 10e2472df5..706aa3e0a1 100644 --- a/ee/hogai/graph/insights/prompts.py +++ b/ee/hogai/graph/insights/prompts.py @@ -36,7 +36,3 @@ Instructions: NO_INSIGHTS_FOUND_MESSAGE = ( "No existing insights found matching your query. Creating a new insight based on your request." ) - -SEARCH_ERROR_INSTRUCTIONS = "INSTRUCTIONS: Tell the user that you encountered an issue while searching for insights and suggest they try again with a different search term." - -EMPTY_DATABASE_ERROR_MESSAGE = "No insights found in the database." 
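Note: the error prompts removed here are not replaced inside the node; the empty-project edge case now propagates as `NoInsightsException`, to be handled by whatever invokes the node. A minimal sketch of such a handler (the wrapper function below is hypothetical, not part of this change):

```python
# Hypothetical caller-side handler; the fallback behavior is an assumption.
from ee.hogai.graph.insights.nodes import InsightSearchNode, NoInsightsException


async def run_insight_search(node: InsightSearchNode, state, config):
    try:
        return await node.arun(state, config)
    except NoInsightsException:
        # The project has no insights yet: fall back to creating a new one
        # instead of surfacing a search error to the user.
        return None
```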
diff --git a/ee/hogai/graph/insights/test/test_nodes.py b/ee/hogai/graph/insights/test/test_nodes.py index 410f84a141..a1a18e0c20 100644 --- a/ee/hogai/graph/insights/test/test_nodes.py +++ b/ee/hogai/graph/insights/test/test_nodes.py @@ -23,7 +23,7 @@ from posthog.schema import ( from posthog.models import Insight, InsightViewed -from ee.hogai.graph.insights.nodes import InsightDict, InsightSearchNode +from ee.hogai.graph.insights.nodes import InsightDict, InsightSearchNode, NoInsightsException from ee.hogai.utils.types import AssistantState, PartialAssistantState from ee.models.assistant import Conversation @@ -115,11 +115,6 @@ class TestInsightSearchNode(BaseTest): short_id=insight.short_id, ) - def test_router_returns_root(self): - """Test that router returns 'root' as expected.""" - result = self.node.router(AssistantState(messages=[])) - self.assertEqual(result, "root") - async def test_load_insights_page(self): """Test loading paginated insights from database.""" # Load first page @@ -350,12 +345,6 @@ class TestInsightSearchNode(BaseTest): # Should return empty list when LLM fails to select anything self.assertEqual(len(result), 0) - def test_router_always_returns_root(self): - """Test that router always returns 'root'.""" - state = AssistantState(messages=[], root_tool_insight_plan="some plan", search_insights_query=None) - result = self.node.router(state) - self.assertEqual(result, "root") - async def test_evaluation_flow_returns_creation_when_no_suitable_insights(self): """Test that when evaluation returns NO, the system transitions to creation flow.""" selected_insights = [self.insight1.id, self.insight2.id] @@ -411,21 +400,6 @@ class TestInsightSearchNode(BaseTest): "root_tool_insight_plan should be set to search_query", ) - # Test router behavior with the returned state - # Create a new state that simulates what happens after this node runs - post_evaluation_state = AssistantState( - messages=state.messages, - root_tool_insight_plan=search_query, # This gets set to search_query - search_insights_query=None, # This gets cleared - ) - - router_result = self.node.router(post_evaluation_state) - self.assertEqual( - router_result, - "root", - "Router should always return root", - ) - # Verify that _evaluate_insights_with_tools was called with the search_query mock_evaluate.assert_called_once_with(selected_insights, search_query, max_selections=1) @@ -481,7 +455,7 @@ class TestInsightSearchNode(BaseTest): self.assertIsNone(result.root_tool_call_id) def test_run_with_no_insights(self): - """Test arun method when no insights exist.""" + """Test arun method when no insights exist - should raise NoInsightsException.""" # Clear all insights (done outside async context) InsightViewed.objects.all().delete() Insight.objects.all().delete() @@ -497,13 +471,10 @@ class TestInsightSearchNode(BaseTest): async def async_test(): # Mock the database calls that happen in async context with patch.object(self.node, "_get_total_insights_count", return_value=0): - result = await self.node.arun(state, {"configurable": {"thread_id": str(conversation.id)}}) - return result + await self.node.arun(state, {"configurable": {"thread_id": str(conversation.id)}}) - result = asyncio.run(async_test()) - self.assertIsInstance(result, PartialAssistantState) - self.assertEqual(len(result.messages), 1) - self.assertIn("No insights found in the database", result.messages[0].content) + with self.assertRaises(NoInsightsException): + asyncio.run(async_test()) async def test_team_filtering(self): """Test that insights are 
filtered by team.""" diff --git a/ee/hogai/graph/billing/test/__init__.py b/ee/hogai/graph/insights_graph/__init__.py similarity index 100% rename from ee/hogai/graph/billing/test/__init__.py rename to ee/hogai/graph/insights_graph/__init__.py diff --git a/ee/hogai/graph/insights_graph/graph.py b/ee/hogai/graph/insights_graph/graph.py new file mode 100644 index 0000000000..6af85172c9 --- /dev/null +++ b/ee/hogai/graph/insights_graph/graph.py @@ -0,0 +1,155 @@ +from typing import Literal, Optional + +from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer +from ee.hogai.graph.base import BaseAssistantGraph +from ee.hogai.utils.types.base import AssistantGraphName, AssistantNodeName, AssistantState, PartialAssistantState + +from ..funnels.nodes import FunnelGeneratorNode, FunnelGeneratorToolsNode +from ..query_executor.nodes import QueryExecutorNode +from ..query_planner.nodes import QueryPlannerNode, QueryPlannerToolsNode +from ..rag.nodes import InsightRagContextNode +from ..retention.nodes import RetentionGeneratorNode, RetentionGeneratorToolsNode +from ..sql.nodes import SQLGeneratorNode, SQLGeneratorToolsNode +from ..trends.nodes import TrendsGeneratorNode, TrendsGeneratorToolsNode + + +class InsightsGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]): + @property + def graph_name(self) -> AssistantGraphName: + return AssistantGraphName.INSIGHTS + + @property + def state_type(self) -> type[AssistantState]: + return AssistantState + + def add_rag_context(self): + self._has_start_node = True + retriever = InsightRagContextNode(self._team, self._user) + self.add_node(AssistantNodeName.INSIGHT_RAG_CONTEXT, retriever) + self._graph.add_edge(AssistantNodeName.START, AssistantNodeName.INSIGHT_RAG_CONTEXT) + self._graph.add_edge(AssistantNodeName.INSIGHT_RAG_CONTEXT, AssistantNodeName.QUERY_PLANNER) + return self + + def add_trends_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR): + trends_generator = TrendsGeneratorNode(self._team, self._user) + self.add_node(AssistantNodeName.TRENDS_GENERATOR, trends_generator) + + trends_generator_tools = TrendsGeneratorToolsNode(self._team, self._user) + self.add_node(AssistantNodeName.TRENDS_GENERATOR_TOOLS, trends_generator_tools) + + self._graph.add_edge(AssistantNodeName.TRENDS_GENERATOR_TOOLS, AssistantNodeName.TRENDS_GENERATOR) + self._graph.add_conditional_edges( + AssistantNodeName.TRENDS_GENERATOR, + trends_generator.router, + path_map={ + "tools": AssistantNodeName.TRENDS_GENERATOR_TOOLS, + "next": next_node, + }, + ) + + return self + + def add_funnel_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR): + funnel_generator = FunnelGeneratorNode(self._team, self._user) + self.add_node(AssistantNodeName.FUNNEL_GENERATOR, funnel_generator) + + funnel_generator_tools = FunnelGeneratorToolsNode(self._team, self._user) + self.add_node(AssistantNodeName.FUNNEL_GENERATOR_TOOLS, funnel_generator_tools) + + self._graph.add_edge(AssistantNodeName.FUNNEL_GENERATOR_TOOLS, AssistantNodeName.FUNNEL_GENERATOR) + self._graph.add_conditional_edges( + AssistantNodeName.FUNNEL_GENERATOR, + funnel_generator.router, + path_map={ + "tools": AssistantNodeName.FUNNEL_GENERATOR_TOOLS, + "next": next_node, + }, + ) + + return self + + def add_retention_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR): + retention_generator = RetentionGeneratorNode(self._team, self._user) + self.add_node(AssistantNodeName.RETENTION_GENERATOR, retention_generator) + + 
retention_generator_tools = RetentionGeneratorToolsNode(self._team, self._user) + self.add_node(AssistantNodeName.RETENTION_GENERATOR_TOOLS, retention_generator_tools) + + self._graph.add_edge(AssistantNodeName.RETENTION_GENERATOR_TOOLS, AssistantNodeName.RETENTION_GENERATOR) + self._graph.add_conditional_edges( + AssistantNodeName.RETENTION_GENERATOR, + retention_generator.router, + path_map={ + "tools": AssistantNodeName.RETENTION_GENERATOR_TOOLS, + "next": next_node, + }, + ) + + return self + + def add_query_planner( + self, + path_map: Optional[ + dict[Literal["trends", "funnel", "retention", "sql", "continue", "end"], AssistantNodeName] + ] = None, + ): + query_planner = QueryPlannerNode(self._team, self._user) + self.add_node(AssistantNodeName.QUERY_PLANNER, query_planner) + self._graph.add_edge(AssistantNodeName.QUERY_PLANNER, AssistantNodeName.QUERY_PLANNER_TOOLS) + + query_planner_tools = QueryPlannerToolsNode(self._team, self._user) + self.add_node(AssistantNodeName.QUERY_PLANNER_TOOLS, query_planner_tools) + self._graph.add_conditional_edges( + AssistantNodeName.QUERY_PLANNER_TOOLS, + query_planner_tools.router, + path_map=path_map # type: ignore + or { + "continue": AssistantNodeName.QUERY_PLANNER, + "trends": AssistantNodeName.TRENDS_GENERATOR, + "funnel": AssistantNodeName.FUNNEL_GENERATOR, + "retention": AssistantNodeName.RETENTION_GENERATOR, + "sql": AssistantNodeName.SQL_GENERATOR, + "end": AssistantNodeName.END, + }, + ) + + return self + + def add_sql_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR): + sql_generator = SQLGeneratorNode(self._team, self._user) + self.add_node(AssistantNodeName.SQL_GENERATOR, sql_generator) + + sql_generator_tools = SQLGeneratorToolsNode(self._team, self._user) + self.add_node(AssistantNodeName.SQL_GENERATOR_TOOLS, sql_generator_tools) + + self._graph.add_edge(AssistantNodeName.SQL_GENERATOR_TOOLS, AssistantNodeName.SQL_GENERATOR) + self._graph.add_conditional_edges( + AssistantNodeName.SQL_GENERATOR, + sql_generator.router, + path_map={ + "tools": AssistantNodeName.SQL_GENERATOR_TOOLS, + "next": next_node, + }, + ) + + return self + + def add_query_executor(self, next_node: AssistantNodeName = AssistantNodeName.END): + query_executor_node = QueryExecutorNode(self._team, self._user) + self.add_node(AssistantNodeName.QUERY_EXECUTOR, query_executor_node) + self._graph.add_edge(AssistantNodeName.QUERY_EXECUTOR, next_node) + return self + + def add_query_creation_flow(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR): + """Add all nodes and edges EXCEPT query execution.""" + return ( + self.add_rag_context() + .add_query_planner() + .add_trends_generator(next_node=next_node) + .add_funnel_generator(next_node=next_node) + .add_retention_generator(next_node=next_node) + .add_sql_generator(next_node=next_node) + ) + + def compile_full_graph(self, checkpointer: DjangoCheckpointer | None = None): + return self.add_query_creation_flow().add_query_executor().compile(checkpointer=checkpointer) diff --git a/ee/hogai/graph/mixins.py b/ee/hogai/graph/mixins.py index f6596833f6..bbc309bb9d 100644 --- a/ee/hogai/graph/mixins.py +++ b/ee/hogai/graph/mixins.py @@ -1,5 +1,5 @@ import datetime -from abc import ABC +from abc import ABC, abstractmethod from typing import Any, get_args, get_origin from uuid import UUID @@ -7,15 +7,15 @@ from django.utils import timezone from langchain_core.runnables import RunnableConfig -from posthog.schema import AssistantMessage, CurrencyCode +from posthog.schema import 
CurrencyCode from posthog.event_usage import groups from posthog.models import Team from posthog.models.action.action import Action from posthog.models.user import User -from ee.hogai.utils.dispatcher import AssistantDispatcher -from ee.hogai.utils.types.base import BaseStateWithIntermediateSteps +from ee.hogai.utils.dispatcher import AssistantDispatcher, create_dispatcher_from_config +from ee.hogai.utils.types.base import BaseStateWithIntermediateSteps, NodePath from ee.models import Conversation, CoreMemory @@ -194,4 +194,33 @@ class TaxonomyUpdateDispatcherNodeMixin: content = "Picking relevant events and properties" if substeps: content = substeps[-1] - self.dispatcher.message(AssistantMessage(content=content)) + self.dispatcher.update(content) + + +class AssistantDispatcherMixin(ABC): + _node_path: tuple[NodePath, ...] + _config: RunnableConfig | None + _dispatcher: AssistantDispatcher | None = None + + @property + def node_path(self) -> tuple[NodePath, ...]: + return self._node_path + + @property + @abstractmethod + def node_name(self) -> str: ... + + @property + def tool_call_id(self) -> str: + parent_tool_call_id = next((path.tool_call_id for path in reversed(self._node_path) if path.tool_call_id), None) + if not parent_tool_call_id: + raise ValueError("No tool call ID found") + return parent_tool_call_id + + @property + def dispatcher(self) -> AssistantDispatcher: + """Create a dispatcher for this node""" + if self._dispatcher: + return self._dispatcher + self._dispatcher = create_dispatcher_from_config(self._config or {}, self.node_path) + return self._dispatcher diff --git a/ee/hogai/graph/parallel_task_execution/mixins.py b/ee/hogai/graph/parallel_task_execution/mixins.py index 4191df430e..a36c80c1fc 100644 --- a/ee/hogai/graph/parallel_task_execution/mixins.py +++ b/ee/hogai/graph/parallel_task_execution/mixins.py @@ -52,7 +52,7 @@ class WithInsightCreationTaskExecution: The type allows None for compatibility with the base class. 
""" # Import here to avoid circular dependency - from ee.hogai.graph.graph import InsightsAssistantGraph + from ee.hogai.graph.insights_graph.graph import InsightsGraph task = cast(AssistantToolCall, input_dict["task"]) artifacts = input_dict["artifacts"] @@ -60,7 +60,7 @@ class WithInsightCreationTaskExecution: self._current_task_id = task.id - # This is needed by the InsightsAssistantGraph to return an AssistantToolCallMessage + # This is needed by the InsightsGraph to return an AssistantToolCallMessage task_tool_call_id = f"task_{uuid.uuid4().hex[:8]}" query = task.args["query_description"] @@ -77,9 +77,7 @@ class WithInsightCreationTaskExecution: ) subgraph_result_messages: list[AssistantMessageUnion] = [] - assistant_graph = InsightsAssistantGraph( - self._team, self._user, tool_call_id=self._parent_tool_call_id - ).compile_full_graph() + assistant_graph = InsightsGraph(self._team, self._user).compile_full_graph() try: async for chunk in assistant_graph.astream( input_state, @@ -156,7 +154,7 @@ class WithInsightCreationTaskExecution: if isinstance(message, VisualizationMessage) and message.id: artifact = InsightArtifact( task_id=tool_call.id, - id=None, # The InsightsAssistantGraph does not create the insight objects + id=None, # The InsightsGraph does not create the insight objects content="", query=cast(AnyAssistantGeneratedQuery, message.answer), ) @@ -194,9 +192,7 @@ class WithInsightSearchTaskExecution: ) try: - result = await InsightSearchNode(self._team, self._user, tool_call_id=self._parent_tool_call_id).arun( - input_state, config - ) + result = await InsightSearchNode(self._team, self._user).arun(input_state, config) if not result or not result.messages: logger.warning("Task failed: no messages received from node executor", task_id=task.id) diff --git a/ee/hogai/graph/query_executor/nodes.py b/ee/hogai/graph/query_executor/nodes.py index af8df9d8a1..7acb0ec5c6 100644 --- a/ee/hogai/graph/query_executor/nodes.py +++ b/ee/hogai/graph/query_executor/nodes.py @@ -76,6 +76,21 @@ class QueryExecutorNode(AssistantNode): ) return PartialAssistantState(messages=[FailureMessage(content=str(err), id=str(uuid4()))]) + return PartialAssistantState( + messages=[ + AssistantToolCallMessage( + content=self._format_query_result(viz_message, results, example_prompt), + id=str(uuid4()), + tool_call_id=tool_call_id, + ) + ], + root_tool_call_id=None, + root_tool_insight_plan=None, + root_tool_insight_type=None, + rag_context=None, + ) + + def _format_query_result(self, viz_message: VisualizationMessage, results: str, example_prompt: str) -> str: query_result = QUERY_RESULTS_PROMPT.format( query_kind=viz_message.answer.kind, results=results, @@ -84,20 +99,10 @@ class QueryExecutorNode(AssistantNode): project_timezone=self.project_timezone, currency=self.project_currency, ) - formatted_query_result = f"{example_prompt}\n\n{query_result}" if isinstance(viz_message.answer, AssistantHogQLQuery): formatted_query_result = f"{example_prompt}\n\n{SQL_QUERY_PROMPT.format(query=viz_message.answer.query)}\n\n{formatted_query_result}" - - return PartialAssistantState( - messages=[ - AssistantToolCallMessage(content=formatted_query_result, id=str(uuid4()), tool_call_id=tool_call_id) - ], - root_tool_call_id=None, - root_tool_insight_plan=None, - root_tool_insight_type=None, - rag_context=None, - ) + return formatted_query_result def _get_example_prompt(self, viz_message: VisualizationMessage) -> str: if isinstance(viz_message.answer, AssistantTrendsQuery | TrendsQuery): diff --git 
a/ee/hogai/graph/retention/nodes.py b/ee/hogai/graph/retention/nodes.py index 689bf5c59a..6080c90b21 100644 --- a/ee/hogai/graph/retention/nodes.py +++ b/ee/hogai/graph/retention/nodes.py @@ -1,7 +1,7 @@ from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnableConfig -from posthog.schema import AssistantMessage, AssistantRetentionQuery +from posthog.schema import AssistantRetentionQuery from ee.hogai.utils.types import AssistantState, PartialAssistantState from ee.hogai.utils.types.base import AssistantNodeName @@ -25,7 +25,7 @@ class RetentionGeneratorNode(SchemaGeneratorNode[AssistantRetentionQuery]): return AssistantNodeName.RETENTION_GENERATOR async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: - self.dispatcher.message(AssistantMessage(content="Creating retention query")) + self.dispatcher.update("Creating retention query") prompt = ChatPromptTemplate.from_messages( [ ("system", RETENTION_SYSTEM_PROMPT), diff --git a/ee/hogai/graph/root/nodes.py b/ee/hogai/graph/root/nodes.py index 7cac956db9..b7dabd099c 100644 --- a/ee/hogai/graph/root/nodes.py +++ b/ee/hogai/graph/root/nodes.py @@ -3,27 +3,23 @@ from collections.abc import Awaitable, Mapping, Sequence from typing import TYPE_CHECKING, Literal, TypeVar, Union from uuid import uuid4 +import structlog import posthoganalytics from langchain_core.messages import ( AIMessage as LangchainAIMessage, BaseMessage, HumanMessage as LangchainHumanMessage, + ToolCall, ToolMessage as LangchainToolMessage, ) from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnableConfig from langgraph.errors import NodeInterrupt +from langgraph.types import Send from posthoganalytics import capture_exception from pydantic import BaseModel -from posthog.schema import ( - AssistantMessage, - AssistantTool, - AssistantToolCallMessage, - ContextMessage, - FailureMessage, - HumanMessage, -) +from posthog.schema import AssistantMessage, AssistantToolCallMessage, ContextMessage, FailureMessage, HumanMessage from posthog.models import Team, User @@ -36,8 +32,14 @@ from ee.hogai.tool import ToolMessagesArtifact from ee.hogai.utils.anthropic import add_cache_control, convert_to_anthropic_messages from ee.hogai.utils.helpers import convert_tool_messages_to_dict, normalize_ai_message from ee.hogai.utils.prompt import format_prompt_string -from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState, InsightQuery -from ee.hogai.utils.types.base import PartialAssistantState, ReplaceMessages +from ee.hogai.utils.types.base import ( + AssistantMessageUnion, + AssistantNodeName, + AssistantState, + NodePath, + PartialAssistantState, + ReplaceMessages, +) from ee.hogai.utils.types.composed import MaxNodeName from .prompts import ( @@ -51,13 +53,13 @@ from .prompts import ( ROOT_TOOL_DOES_NOT_EXIST, ) from .tools import ( + CreateAndQueryInsightTool, + CreateDashboardTool, ReadDataTool, ReadTaxonomyTool, SearchTool, + SessionSummarizationTool, TodoWriteTool, - create_and_query_insight, - create_dashboard, - session_summarization, ) if TYPE_CHECKING: @@ -66,24 +68,14 @@ if TYPE_CHECKING: SLASH_COMMAND_INIT = "/init" SLASH_COMMAND_REMEMBER = "/remember" -RouteName = Literal[ - "insights", - "root", - "end", - "search_documentation", - "memory_onboarding", - "insights_search", - "billing", - "session_summarization", - "create_dashboard", -] - RootMessageUnion = HumanMessage | AssistantMessage | FailureMessage | 
AssistantToolCallMessage | ContextMessage

T = TypeVar("T", RootMessageUnion, BaseMessage)

RootTool = Union[type[BaseModel], "MaxTool"]

+logger = structlog.get_logger(__name__)
+

class RootNode(AssistantNode):
    MAX_TOOL_CALLS = 24
@@ -95,8 +87,8 @@ class RootNode(AssistantNode):
    Determines the thinking configuration for the model.
    """

-    def __init__(self, team: Team, user: User):
-        super().__init__(team, user)
+    def __init__(self, team: Team, user: User, node_path: tuple[NodePath, ...] | None = None):
+        super().__init__(team, user, node_path)
        self._window_manager = AnthropicConversationCompactionManager()

    async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
@@ -168,12 +160,25 @@ class RootNode(AssistantNode):
        if messages_to_replace:
            new_messages = ReplaceMessages([*messages_to_replace, assistant_message])

+        # Set the new tool call count
+        tool_call_count = (state.root_tool_calls_count or 0) + 1 if assistant_message.tool_calls else None
+
        return PartialAssistantState(
            messages=new_messages,
+            root_tool_calls_count=tool_call_count,
            root_conversation_start_id=window_id,
            start_id=start_id,
        )

+    def router(self, state: AssistantState):
+        last_message = state.messages[-1]
+        if not isinstance(last_message, AssistantMessage) or not last_message.tool_calls:
+            return AssistantNodeName.END
+        return [
+            Send(AssistantNodeName.ROOT_TOOLS, state.model_copy(update={"root_tool_call_id": tool_call.id}))
+            for tool_call in last_message.tool_calls
+        ]
+
    @property
    def node_name(self) -> MaxNodeName:
        return AssistantNodeName.ROOT
@@ -226,11 +231,35 @@ class RootNode(AssistantNode):
        if self._is_hard_limit_reached(state.root_tool_calls_count):
            return base_model

-        return base_model.bind_tools(tools, parallel_tool_calls=False)
+        return base_model.bind_tools(tools, parallel_tool_calls=True)

    async def _get_tools(self, state: AssistantState, config: RunnableConfig) -> list[RootTool]:
        from ee.hogai.tool import get_contextual_tool_class

+        # Static toolkit
+        default_tools: list[type[MaxTool]] = [
+            ReadTaxonomyTool,
+            ReadDataTool,
+            SearchTool,
+            TodoWriteTool,
+        ]
+
+        # The contextual insights tool overrides the static one; add the static tool only when the user isn't editing an insight.
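+        # (In editing mode, the contextual variant injected with the other contextual tools below takes its place.)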
+ if not CreateAndQueryInsightTool.is_editing_mode(self.context_manager): + default_tools.append(CreateAndQueryInsightTool) + + # Add session summarization tool if enabled + if self._has_session_summarization_feature_flag(): + default_tools.append(SessionSummarizationTool) + + # Add other lower-priority tools + default_tools.extend( + [ + CreateDashboardTool, + ] + ) + + # Processed tools available_tools: list[RootTool] = [] # Initialize the static toolkit @@ -238,37 +267,14 @@ class RootNode(AssistantNode): # This is just to bound the tools to the model dynamic_tools = ( tool_class.create_tool_class( - team=self._team, - user=self._user, - tool_call_id="", - state=state, - config=config, - context_manager=self.context_manager, - ) - for tool_class in ( - ReadTaxonomyTool, - ReadDataTool, - SearchTool, - TodoWriteTool, + team=self._team, user=self._user, state=state, config=config, context_manager=self.context_manager ) + for tool_class in default_tools ) available_tools.extend(await asyncio.gather(*dynamic_tools)) - # Insights tool - tool_names = self.context_manager.get_contextual_tools().keys() - is_editing_insight = AssistantTool.EDIT_CURRENT_INSIGHT in tool_names - if not is_editing_insight: - # This is the default tool, which can be overriden by the MaxTool based tool with the same name - available_tools.append(create_and_query_insight) - - # Check if session summarization is enabled for the user - if self._has_session_summarization_feature_flag(): - available_tools.append(session_summarization) - - # Dashboard creation tool - available_tools.append(create_dashboard) - # Inject contextual tools + tool_names = self.context_manager.get_contextual_tools().keys() awaited_contextual_tools: list[Awaitable[RootTool]] = [] for tool_name in tool_names: ContextualMaxToolClass = get_contextual_tool_class(tool_name) @@ -278,7 +284,6 @@ class RootNode(AssistantNode): ContextualMaxToolClass.create_tool_class( team=self._team, user=self._user, - tool_call_id="", state=state, config=config, context_manager=self.context_manager, @@ -366,51 +371,23 @@ class RootNodeTools(AssistantNode): def node_name(self) -> MaxNodeName: return AssistantNodeName.ROOT_TOOLS - async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: + async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None: last_message = state.messages[-1] - if not isinstance(last_message, AssistantMessage) or not last_message.tool_calls: - # Reset tools. - return PartialAssistantState(root_tool_calls_count=0) - tool_call_count = state.root_tool_calls_count or 0 + reset_state = PartialAssistantState(root_tool_call_id=None) + # Should never happen, but just in case. + if not isinstance(last_message, AssistantMessage) or not last_message.id or not state.root_tool_call_id: + return reset_state - tools_calls = last_message.tool_calls - if len(tools_calls) != 1: - raise ValueError("Expected exactly one tool call.") - - tool_names = self.context_manager.get_contextual_tools().keys() - is_editing_insight = AssistantTool.EDIT_CURRENT_INSIGHT in tool_names - tool_call = tools_calls[0] + # Find the current tool call in the last message. 
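+        # (With parallel tool calls, the router fans out one ROOT_TOOLS branch per call via Send,
+        # each branch carrying its own root_tool_call_id, so the matching call is looked up explicitly.)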
+        tool_call = next(
+            (tool_call for tool_call in last_message.tool_calls or [] if tool_call.id == state.root_tool_call_id), None
+        )
+        if not tool_call:
+            return reset_state

        from ee.hogai.tool import get_contextual_tool_class

-        if tool_call.name == "create_and_query_insight" and not is_editing_insight:
-            return PartialAssistantState(
-                root_tool_call_id=tool_call.id,
-                root_tool_insight_plan=tool_call.args["query_description"],
-                root_tool_calls_count=tool_call_count + 1,
-            )
-        if tool_call.name == "session_summarization":
-            return PartialAssistantState(
-                root_tool_call_id=tool_call.id,
-                session_summarization_query=tool_call.args["session_summarization_query"],
-                # Safety net in case the argument is missing to avoid raising exceptions internally
-                should_use_current_filters=tool_call.args.get("should_use_current_filters", False),
-                summary_title=tool_call.args.get("summary_title"),
-                root_tool_calls_count=tool_call_count + 1,
-            )
-        if tool_call.name == "create_dashboard":
-            raw_queries = tool_call.args["search_insights_queries"]
-            search_insights_queries = [InsightQuery.model_validate(query) for query in raw_queries]
-
-            return PartialAssistantState(
-                root_tool_call_id=tool_call.id,
-                dashboard_name=tool_call.args.get("dashboard_name"),
-                search_insights_queries=search_insights_queries,
-                root_tool_calls_count=tool_call_count + 1,
-            )
-
-        # MaxTool flow
        ToolClass = get_contextual_tool_class(tool_call.name)

        # If the tool doesn't exist, return the message to the agent
@@ -423,25 +400,32 @@
                        tool_call_id=tool_call.id,
                    )
                ],
-                root_tool_calls_count=tool_call_count + 1,
            )

        # Initialize the tool and process it
        tool_class = await ToolClass.create_tool_class(
            team=self._team,
            user=self._user,
-            tool_call_id=tool_call.id,
+            # Tricky: set the node path to be associated with the tool call
+            node_path=(
+                *self.node_path[:-1],
+                NodePath(name=AssistantNodeName.ROOT_TOOLS, message_id=last_message.id, tool_call_id=tool_call.id),
+            ),
            state=state,
            config=config,
            context_manager=self.context_manager,
        )
+
        try:
-            result = await tool_class.ainvoke(tool_call.model_dump(), config)
+            result = await tool_class.ainvoke(
+                ToolCall(type="tool_call", name=tool_call.name, args=tool_call.args, id=tool_call.id), config=config
+            )
            if not isinstance(result, LangchainToolMessage):
                raise ValueError(
                    f"Tool '{tool_call.name}' returned {type(result).__name__}, expected LangchainToolMessage"
                )
        except Exception as e:
+            logger.exception("Error calling tool", extra={"tool_name": tool_call.name, "error": str(e)})
            capture_exception(
                e, distinct_id=self._get_user_distinct_id(config), properties=self._get_debug_props(config)
            )
@@ -453,38 +437,11 @@
                        tool_call_id=tool_call.id,
                    )
                ],
-                root_tool_calls_count=tool_call_count + 1,
            )

        if isinstance(result.artifact, ToolMessagesArtifact):
            return PartialAssistantState(
                messages=result.artifact.messages,
-                root_tool_calls_count=tool_call_count + 1,
-            )
-
-        # Handle the basic toolkit
-        if result.name == "search" and isinstance(result.artifact, dict):
-            match result.artifact.get("kind"):
-                case "insights":
-                    return PartialAssistantState(
-                        root_tool_call_id=tool_call.id,
-                        search_insights_query=result.artifact.get("query"),
-                        root_tool_calls_count=tool_call_count + 1,
-                    )
-                case "docs":
-                    return PartialAssistantState(
-                        root_tool_call_id=tool_call.id,
-                        root_tool_calls_count=tool_call_count + 1,
-                    )
-
-        if (
-            result.name == "read_data"
-            and isinstance(result.artifact, dict)
-            and result.artifact.get("kind") == "billing_info"
- ): - return PartialAssistantState( - root_tool_call_id=tool_call.id, - root_tool_calls_count=tool_call_count + 1, ) # If this is a navigation tool call, pause the graph execution @@ -496,7 +453,6 @@ class RootNodeTools(AssistantNode): id=str(uuid4()), tool_call_id=tool_call.id, ) - self.dispatcher.message(navigate_message) # Raising a `NodeInterrupt` ensures the assistant graph stops here and # surfaces the navigation confirmation to the client. The next user # interaction will resume the graph with potentially different @@ -512,29 +468,11 @@ class RootNodeTools(AssistantNode): return PartialAssistantState( messages=[tool_message], - root_tool_calls_count=tool_call_count + 1, ) - def router(self, state: AssistantState) -> RouteName: + # This is only for the Inkeep node. Remove when inkeep_docs is removed. + def router(self, state: AssistantState) -> Literal["root", "end"]: last_message = state.messages[-1] - if isinstance(last_message, AssistantToolCallMessage): return "root" # Let the root either proceed or finish, since it now can see the tool call result - if isinstance(last_message, AssistantMessage) and state.root_tool_call_id: - tool_calls = getattr(last_message, "tool_calls", None) - if tool_calls and len(tool_calls) > 0: - tool_call = tool_calls[0] - tool_call_name = tool_call.name - if tool_call_name == "read_data" and tool_call.args.get("kind") == "billing_info": - return "billing" - if tool_call_name == "create_dashboard": - return "create_dashboard" - if state.root_tool_insight_plan: - return "insights" - elif state.search_insights_query: - return "insights_search" - elif state.session_summarization_query: - return "session_summarization" - else: - return "search_documentation" return "end" diff --git a/ee/hogai/graph/root/prompts.py b/ee/hogai/graph/root/prompts.py index 411ed94ca2..de6d2671a1 100644 --- a/ee/hogai/graph/root/prompts.py +++ b/ee/hogai/graph/root/prompts.py @@ -112,6 +112,10 @@ The user is a product engineer and will primarily request you perform product ma - Tool results and user messages may include tags. tags contain useful information and reminders. They are NOT part of the user's provided input or the tool result. + +- You can invoke multiple tools within a single response. 
When a request involves several independent pieces of information, batch your tool calls together for optimal performance + + {{{billing_context}}} {{{core_memory_prompt}}} diff --git a/ee/hogai/graph/root/test/test_nodes.py b/ee/hogai/graph/root/test/test_nodes.py index d878cee623..8c060a4354 100644 --- a/ee/hogai/graph/root/test/test_nodes.py +++ b/ee/hogai/graph/root/test/test_nodes.py @@ -38,6 +38,7 @@ from ee.hogai.graph.root.prompts import ( ROOT_BILLING_CONTEXT_WITH_ACCESS_PROMPT, ROOT_BILLING_CONTEXT_WITH_NO_ACCESS_PROMPT, ) +from ee.hogai.tool import ToolMessagesArtifact from ee.hogai.utils.tests import FakeChatAnthropic, FakeChatOpenAI from ee.hogai.utils.types import AssistantState, PartialAssistantState from ee.hogai.utils.types.base import AssistantMessageUnion @@ -644,6 +645,147 @@ class TestRootNode(ClickhouseTestMixin, BaseTest): result = node._is_hard_limit_reached(tool_calls_count) self.assertEqual(result, expected) + async def test_node_increments_tool_count_on_tool_call(self): + """Test that RootNode increments tool count when assistant makes a tool call""" + with patch( + "ee.hogai.graph.root.nodes.RootNode._get_model", + return_value=FakeChatOpenAI( + responses=[ + LangchainAIMessage( + content="Let me help", + tool_calls=[ + { + "id": "tool-1", + "name": "create_and_query_insight", + "args": {"query_description": "test"}, + } + ], + ) + ] + ), + ): + node = RootNode(self.team, self.user) + + # Test starting from no tool calls + state_1 = AssistantState(messages=[HumanMessage(content="Hello")]) + result_1 = await node.arun(state_1, {}) + self.assertEqual(result_1.root_tool_calls_count, 1) + + # Test incrementing from existing count + state_2 = AssistantState( + messages=[HumanMessage(content="Hello")], + root_tool_calls_count=5, + ) + result_2 = await node.arun(state_2, {}) + self.assertEqual(result_2.root_tool_calls_count, 6) + + async def test_node_resets_tool_count_on_plain_response(self): + """Test that RootNode resets tool count when assistant responds without tool calls""" + with patch( + "ee.hogai.graph.root.nodes.RootNode._get_model", + return_value=FakeChatOpenAI(responses=[LangchainAIMessage(content="Here's your answer")]), + ): + node = RootNode(self.team, self.user) + + state = AssistantState( + messages=[HumanMessage(content="Hello")], + root_tool_calls_count=5, + ) + result = await node.arun(state, {}) + self.assertIsNone(result.root_tool_calls_count) + + def test_router_returns_end_for_plain_response(self): + """Test that router returns END when message has no tool calls""" + from ee.hogai.utils.types import AssistantNodeName + + node = RootNode(self.team, self.user) + + state = AssistantState( + messages=[ + HumanMessage(content="Hello"), + AssistantMessage(content="Hi there!"), + ] + ) + result = node.router(state) + self.assertEqual(result, AssistantNodeName.END) + + def test_router_returns_send_for_single_tool_call(self): + """Test that router returns Send for single tool call""" + from langgraph.types import Send + + from ee.hogai.utils.types import AssistantNodeName + + node = RootNode(self.team, self.user) + + state = AssistantState( + messages=[ + HumanMessage(content="Generate insights"), + AssistantMessage( + content="Let me help", + tool_calls=[ + AssistantToolCall( + id="tool-1", + name="create_and_query_insight", + args={"query_description": "test"}, + ) + ], + ), + ] + ) + result = node.router(state) + + # Verify it's a list of Send objects + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + 
self.assertIsInstance(result[0], Send) + self.assertEqual(result[0].node, AssistantNodeName.ROOT_TOOLS) + self.assertEqual(result[0].arg.root_tool_call_id, "tool-1") + + def test_router_returns_multiple_sends_for_parallel_tool_calls(self): + """Test that router returns multiple Send objects for parallel tool calls""" + from langgraph.types import Send + + from ee.hogai.utils.types import AssistantNodeName + + node = RootNode(self.team, self.user) + + state = AssistantState( + messages=[ + HumanMessage(content="Generate multiple insights"), + AssistantMessage( + content="Let me create several insights", + tool_calls=[ + AssistantToolCall( + id="tool-1", + name="create_and_query_insight", + args={"query_description": "trends"}, + ), + AssistantToolCall( + id="tool-2", + name="create_and_query_insight", + args={"query_description": "funnel"}, + ), + AssistantToolCall( + id="tool-3", + name="create_and_query_insight", + args={"query_description": "retention"}, + ), + ], + ), + ] + ) + result = node.router(state) + + # Verify it's a list of Send objects + self.assertIsInstance(result, list) + self.assertEqual(len(result), 3) + + # Verify all are Send objects to ROOT_TOOLS + for i, send in enumerate(result): + self.assertIsInstance(send, Send) + self.assertEqual(send.node, AssistantNodeName.ROOT_TOOLS) + self.assertEqual(send.arg.root_tool_call_id, f"tool-{i+1}") + class TestRootNodeTools(BaseTest): def test_node_tools_router(self): @@ -675,9 +817,13 @@ class TestRootNodeTools(BaseTest): node = RootNodeTools(self.team, self.user) state = AssistantState(messages=[HumanMessage(content="Hello")]) result = await node.arun(state, {}) - self.assertEqual(result, PartialAssistantState(root_tool_calls_count=0)) + self.assertEqual(result, PartialAssistantState(root_tool_call_id=None)) + + @patch("ee.hogai.graph.root.tools.create_and_query_insight.CreateAndQueryInsightTool._arun_impl") + async def test_run_valid_tool_call(self, create_and_query_insight_mock): + test_message = AssistantToolCallMessage(content="Tool result", tool_call_id="xyz", id="msg-1") + create_and_query_insight_mock.return_value = ("", ToolMessagesArtifact(messages=[test_message])) - async def test_run_valid_tool_call(self): node = RootNodeTools(self.team, self.user) state = AssistantState( messages=[ @@ -688,17 +834,20 @@ class TestRootNodeTools(BaseTest): AssistantToolCall( id="xyz", name="create_and_query_insight", - args={"query_kind": "trends", "query_description": "test query"}, + args={"query_description": "test query"}, ) ], ) - ] + ], + root_tool_call_id="xyz", ) result = await node.arun(state, {}) self.assertIsInstance(result, PartialAssistantState) - self.assertEqual(result.root_tool_call_id, "xyz") - self.assertEqual(result.root_tool_insight_plan, "test query") - self.assertEqual(result.root_tool_insight_type, None) # Insight type is determined by query planner node + assert result is not None + self.assertEqual(len(result.messages), 1) + assert isinstance(result.messages[0], AssistantToolCallMessage) + self.assertEqual(result.messages[0].tool_call_id, "xyz") + create_and_query_insight_mock.assert_called_once_with(query_description="test query") async def test_run_valid_contextual_tool_call(self): node = RootNodeTools(self.team, self.user) @@ -715,7 +864,8 @@ class TestRootNodeTools(BaseTest): ) ], ) - ] + ], + root_tool_call_id="xyz", ) result = await node.arun( @@ -730,69 +880,9 @@ class TestRootNodeTools(BaseTest): ) self.assertIsInstance(result, PartialAssistantState) - self.assertEqual(result.root_tool_call_id, None) # 
Tool was fully handled by the node - self.assertIsNone(result.root_tool_insight_plan) # No insight plan for contextual tools - self.assertIsNone(result.root_tool_insight_type) # No insight type for contextual tools - - async def test_run_multiple_tool_calls_raises(self): - node = RootNodeTools(self.team, self.user) - state = AssistantState( - messages=[ - AssistantMessage( - content="Hello", - id="test-id", - tool_calls=[ - AssistantToolCall( - id="xyz1", - name="create_and_query_insight", - args={"query_kind": "trends", "query_description": "test query 1"}, - ), - AssistantToolCall( - id="xyz2", - name="create_and_query_insight", - args={"query_kind": "funnel", "query_description": "test query 2"}, - ), - ], - ) - ] - ) - with self.assertRaises(ValueError) as cm: - await node.arun(state, {}) - self.assertEqual(str(cm.exception), "Expected exactly one tool call.") - - async def test_run_increments_tool_count(self): - node = RootNodeTools(self.team, self.user) - state = AssistantState( - messages=[ - AssistantMessage( - content="Hello", - id="test-id", - tool_calls=[ - AssistantToolCall( - id="xyz", - name="create_and_query_insight", - args={"query_kind": "trends", "query_description": "test query"}, - ) - ], - ) - ], - root_tool_calls_count=2, # Starting count - ) - result = await node.arun(state, {}) - self.assertEqual(result.root_tool_calls_count, 3) # Should increment by 1 - - async def test_run_resets_tool_count(self): - node = RootNodeTools(self.team, self.user) - - # Test reset when no tool calls in AssistantMessage - state_1 = AssistantState(messages=[AssistantMessage(content="Hello", tool_calls=[])], root_tool_calls_count=3) - result = await node.arun(state_1, {}) - self.assertEqual(result.root_tool_calls_count, 0) - - # Test reset when last message is HumanMessage - state_2 = AssistantState(messages=[HumanMessage(content="Hello")], root_tool_calls_count=3) - result = await node.arun(state_2, {}) - self.assertEqual(result.root_tool_calls_count, 0) + assert result is not None + self.assertEqual(len(result.messages), 1) + self.assertIsInstance(result.messages[0], AssistantToolCallMessage) async def test_navigate_tool_call_raises_node_interrupt(self): """Test that navigate tool calls raise NodeInterrupt to pause graph execution""" @@ -805,7 +895,8 @@ class TestRootNodeTools(BaseTest): id="test-id", tool_calls=[AssistantToolCall(id="nav-123", name="navigate", args={"page_key": "insights"})], ) - ] + ], + root_tool_call_id="nav-123", ) mock_navigate_tool = AsyncMock() @@ -828,319 +919,6 @@ class TestRootNodeTools(BaseTest): self.assertEqual(interrupt_data.tool_call_id, "nav-123") self.assertEqual(interrupt_data.ui_payload, {"navigate": {"page_key": "insights"}}) - def test_billing_tool_routing(self): - """Test that billing tool calls are routed correctly""" - node = RootNodeTools(self.team, self.user) - - # Create state with billing tool call (read_data with kind=billing_info) - state = AssistantState( - messages=[ - AssistantMessage( - content="Let me check your billing information", - tool_calls=[AssistantToolCall(id="billing-123", name="read_data", args={"kind": "billing_info"})], - ) - ], - root_tool_call_id="billing-123", - ) - - # Should route to billing - self.assertEqual(node.router(state), "billing") - - def test_router_insights_path(self): - """Test router routes to insights when root_tool_insight_plan is set""" - node = RootNodeTools(self.team, self.user) - - state = AssistantState( - messages=[ - AssistantMessage( - content="Creating insight", - tool_calls=[ - 
AssistantToolCall( - id="insight-123", - name="create_and_query_insight", - args={"query_kind": "trends", "query_description": "test"}, - ) - ], - ) - ], - root_tool_call_id="insight-123", - root_tool_insight_plan="test query plan", - ) - - self.assertEqual(node.router(state), "insights") - - def test_router_insights_search_path(self): - """Test router routes to insights_search when search_insights_query is set""" - node = RootNodeTools(self.team, self.user) - - state = AssistantState( - messages=[ - AssistantMessage( - content="Searching insights", - tool_calls=[AssistantToolCall(id="search-123", name="search", args={"kind": "insights"})], - ) - ], - root_tool_call_id="search-123", - search_insights_query="test search query", - ) - - self.assertEqual(node.router(state), "insights_search") - - def test_router_session_summarization_path(self): - """Test router routes to session_summarization when session_summarization_query is set""" - node = RootNodeTools(self.team, self.user) - - state = AssistantState( - messages=[ - AssistantMessage( - content="Summarizing sessions", - tool_calls=[ - AssistantToolCall( - id="session-123", name="session_summarization", args={"session_summarization_query": "test"} - ) - ], - ) - ], - root_tool_call_id="session-123", - session_summarization_query="test session query", - ) - - self.assertEqual(node.router(state), "session_summarization") - - def test_router_create_dashboard_path(self): - """Test router routes to create_dashboard when create_dashboard tool is called""" - node = RootNodeTools(self.team, self.user) - - state = AssistantState( - messages=[ - AssistantMessage( - content="Creating dashboard", - tool_calls=[ - AssistantToolCall( - id="dashboard-123", name="create_dashboard", args={"search_insights_queries": []} - ) - ], - ) - ], - root_tool_call_id="dashboard-123", - ) - - self.assertEqual(node.router(state), "create_dashboard") - - def test_router_search_documentation_fallback(self): - """Test router routes to search_documentation when root_tool_call_id is set but no specific route""" - node = RootNodeTools(self.team, self.user) - - state = AssistantState( - messages=[ - AssistantMessage( - content="Searching docs", - tool_calls=[AssistantToolCall(id="search-123", name="search", args={"kind": "docs"})], - ) - ], - root_tool_call_id="search-123", - ) - - self.assertEqual(node.router(state), "search_documentation") - - async def test_arun_session_summarization_with_all_args(self): - """Test session_summarization tool call with all arguments""" - node = RootNodeTools(self.team, self.user) - state = AssistantState( - messages=[ - AssistantMessage( - content="Summarizing sessions", - id="test-id", - tool_calls=[ - AssistantToolCall( - id="session-123", - name="session_summarization", - args={ - "session_summarization_query": "test query", - "should_use_current_filters": True, - "summary_title": "Test Summary", - }, - ) - ], - ) - ] - ) - result = await node.arun(state, {}) - self.assertIsInstance(result, PartialAssistantState) - self.assertEqual(result.root_tool_call_id, "session-123") - self.assertEqual(result.session_summarization_query, "test query") - self.assertEqual(result.should_use_current_filters, True) - self.assertEqual(result.summary_title, "Test Summary") - - async def test_arun_session_summarization_missing_optional_args(self): - """Test session_summarization tool call with missing optional arguments""" - node = RootNodeTools(self.team, self.user) - state = AssistantState( - messages=[ - AssistantMessage( - content="Summarizing 
sessions", - id="test-id", - tool_calls=[ - AssistantToolCall( - id="session-123", - name="session_summarization", - args={"session_summarization_query": "test query"}, - ) - ], - ) - ] - ) - result = await node.arun(state, {}) - self.assertIsInstance(result, PartialAssistantState) - self.assertEqual(result.should_use_current_filters, False) # Default value - self.assertIsNone(result.summary_title) - - async def test_arun_create_dashboard_with_queries(self): - """Test create_dashboard tool call with search_insights_queries""" - node = RootNodeTools(self.team, self.user) - state = AssistantState( - messages=[ - AssistantMessage( - content="Creating dashboard", - id="test-id", - tool_calls=[ - AssistantToolCall( - id="dashboard-123", - name="create_dashboard", - args={ - "dashboard_name": "Test Dashboard", - "search_insights_queries": [ - {"name": "Query 1", "description": "Trends insight description"}, - {"name": "Query 2", "description": "Funnel insight description"}, - ], - }, - ) - ], - ) - ] - ) - result = await node.arun(state, {}) - self.assertIsInstance(result, PartialAssistantState) - self.assertEqual(result.root_tool_call_id, "dashboard-123") - self.assertEqual(result.dashboard_name, "Test Dashboard") - self.assertIsNotNone(result.search_insights_queries) - assert result.search_insights_queries is not None - self.assertEqual(len(result.search_insights_queries), 2) - self.assertEqual(result.search_insights_queries[0].name, "Query 1") - self.assertEqual(result.search_insights_queries[1].name, "Query 2") - - async def test_arun_search_tool_insights_kind(self): - """Test search tool with kind=insights""" - node = RootNodeTools(self.team, self.user) - state = AssistantState( - messages=[ - AssistantMessage( - content="Searching insights", - id="test-id", - tool_calls=[ - AssistantToolCall(id="search-123", name="search", args={"query": "test", "kind": "insights"}) - ], - ) - ] - ) - - mock_tool = AsyncMock() - mock_tool.ainvoke.return_value = LangchainToolMessage( - content="Search results", - tool_call_id="search-123", - name="search", - artifact={"kind": "insights", "query": "test"}, - ) - - with mock_contextual_tool(mock_tool): - result = await node.arun(state, {"configurable": {"contextual_tools": {"search": {}}}}) - - self.assertIsInstance(result, PartialAssistantState) - self.assertEqual(result.root_tool_call_id, "search-123") - self.assertEqual(result.search_insights_query, "test") - - async def test_arun_search_tool_docs_kind(self): - """Test search tool with kind=docs""" - node = RootNodeTools(self.team, self.user) - state = AssistantState( - messages=[ - AssistantMessage( - content="Searching docs", - id="test-id", - tool_calls=[ - AssistantToolCall(id="search-123", name="search", args={"query": "test", "kind": "docs"}) - ], - ) - ] - ) - - mock_tool = AsyncMock() - mock_tool.ainvoke.return_value = LangchainToolMessage( - content="Docs results", tool_call_id="search-123", name="search", artifact={"kind": "docs"} - ) - - with mock_contextual_tool(mock_tool): - result = await node.arun(state, {"configurable": {"contextual_tools": {"search": {}}}}) - - self.assertIsInstance(result, PartialAssistantState) - self.assertEqual(result.root_tool_call_id, "search-123") - - async def test_arun_read_data_billing_info(self): - """Test read_data tool with kind=billing_info""" - node = RootNodeTools(self.team, self.user) - state = AssistantState( - messages=[ - AssistantMessage( - content="Reading billing info", - id="test-id", - tool_calls=[AssistantToolCall(id="read-123", name="read_data", 
args={"kind": "billing_info"})], - ) - ] - ) - - mock_tool = AsyncMock() - mock_tool.ainvoke.return_value = LangchainToolMessage( - content="Billing data", tool_call_id="read-123", name="read_data", artifact={"kind": "billing_info"} - ) - - with mock_contextual_tool(mock_tool): - result = await node.arun(state, {"configurable": {"contextual_tools": {"read_data": {}}}}) - - self.assertIsInstance(result, PartialAssistantState) - self.assertEqual(result.root_tool_call_id, "read-123") - - async def test_arun_tool_updates_state(self): - """Test that when a tool updates its _state, the new messages are included""" - node = RootNodeTools(self.team, self.user) - state = AssistantState( - messages=[ - AssistantMessage( - content="Using tool", - id="test-id", - tool_calls=[AssistantToolCall(id="tool-123", name="test_tool", args={})], - ) - ] - ) - - mock_tool = AsyncMock() - # Simulate tool appending a message to state - updated_state = AssistantState( - messages=[ - *state.messages, - AssistantToolCallMessage(content="Tool result", tool_call_id="tool-123", id="msg-1"), - ] - ) - mock_tool._state = updated_state - mock_tool.ainvoke.return_value = LangchainToolMessage(content="Tool result", tool_call_id="tool-123") - - with mock_contextual_tool(mock_tool): - result = await node.arun(state, {"configurable": {"contextual_tools": {"test_tool": {}}}}) - - # Should include the new message from the updated state - self.assertIsInstance(result, PartialAssistantState) - self.assertEqual(len(result.messages), 1) - self.assertIsInstance(result.messages[0], AssistantToolCallMessage) - async def test_arun_tool_returns_wrong_type_returns_error_message(self): """Test that tool returning wrong type returns an error message""" node = RootNodeTools(self.team, self.user) @@ -1151,7 +929,8 @@ class TestRootNodeTools(BaseTest): id="test-id", tool_calls=[AssistantToolCall(id="tool-123", name="test_tool", args={})], ) - ] + ], + root_tool_call_id="tool-123", ) mock_tool = AsyncMock() @@ -1161,10 +940,11 @@ class TestRootNodeTools(BaseTest): result = await node.arun(state, {"configurable": {"contextual_tools": {"test_tool": {}}}}) self.assertIsInstance(result, PartialAssistantState) + assert result is not None self.assertEqual(len(result.messages), 1) assert isinstance(result.messages[0], AssistantToolCallMessage) self.assertEqual(result.messages[0].tool_call_id, "tool-123") - self.assertEqual(result.root_tool_calls_count, 1) + self.assertIn("internal error", result.messages[0].content) async def test_arun_unknown_tool_returns_error_message(self): """Test that unknown tool name returns an error message""" @@ -1176,14 +956,16 @@ class TestRootNodeTools(BaseTest): id="test-id", tool_calls=[AssistantToolCall(id="tool-123", name="unknown_tool", args={})], ) - ] + ], + root_tool_call_id="tool-123", ) with patch("ee.hogai.tool.get_contextual_tool_class", return_value=None): result = await node.arun(state, {}) self.assertIsInstance(result, PartialAssistantState) + assert result is not None self.assertEqual(len(result.messages), 1) assert isinstance(result.messages[0], AssistantToolCallMessage) self.assertEqual(result.messages[0].tool_call_id, "tool-123") - self.assertEqual(result.root_tool_calls_count, 1) + self.assertIn("does not exist", result.messages[0].content) diff --git a/ee/hogai/graph/root/tools/__init__.py b/ee/hogai/graph/root/tools/__init__.py index 793c94b097..7bf9c6298f 100644 --- a/ee/hogai/graph/root/tools/__init__.py +++ b/ee/hogai/graph/root/tools/__init__.py @@ -1,21 +1,26 @@ -from .legacy import 
create_and_query_insight, create_dashboard, session_summarization
+from .create_and_query_insight import CreateAndQueryInsightTool, CreateAndQueryInsightToolArgs
+from .create_dashboard import CreateDashboardTool, CreateDashboardToolArgs
 from .navigate import NavigateTool, NavigateToolArgs
 from .read_data import ReadDataTool, ReadDataToolArgs
 from .read_taxonomy import ReadTaxonomyTool
 from .search import SearchTool, SearchToolArgs
+from .session_summarization import SessionSummarizationTool, SessionSummarizationToolArgs
 from .todo_write import TodoWriteTool, TodoWriteToolArgs

 __all__ = [
+    "CreateAndQueryInsightTool",
+    "CreateAndQueryInsightToolArgs",
+    "CreateDashboardTool",
+    "CreateDashboardToolArgs",
+    "NavigateTool",
+    "NavigateToolArgs",
+    "ReadDataTool",
+    "ReadDataToolArgs",
     "ReadTaxonomyTool",
     "SearchTool",
     "SearchToolArgs",
-    "ReadDataTool",
-    "ReadDataToolArgs",
+    "SessionSummarizationTool",
+    "SessionSummarizationToolArgs",
     "TodoWriteTool",
     "TodoWriteToolArgs",
-    "NavigateTool",
-    "NavigateToolArgs",
-    "create_and_query_insight",
-    "session_summarization",
-    "create_dashboard",
 ]
diff --git a/ee/hogai/graph/root/tools/create_and_query_insight.py b/ee/hogai/graph/root/tools/create_and_query_insight.py
new file mode 100644
index 0000000000..a5f6985472
--- /dev/null
+++ b/ee/hogai/graph/root/tools/create_and_query_insight.py
@@ -0,0 +1,206 @@
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+from posthog.schema import AssistantTool, AssistantToolCallMessage, VisualizationMessage
+
+from ee.hogai.context.context import AssistantContextManager
+from ee.hogai.graph.insights_graph.graph import InsightsGraph
+from ee.hogai.graph.schema_generator.nodes import SchemaGenerationException
+from ee.hogai.tool import MaxTool, ToolMessagesArtifact
+from ee.hogai.utils.prompt import format_prompt_string
+from ee.hogai.utils.types.base import AssistantState
+
+INSIGHT_TOOL_PROMPT = """
+Use this tool to create a product analytics insight for a given natural language description by spawning a subagent.
+The tool generates a query and returns formatted text results for a specific data question or iterates on a previous query. It only retrieves a single query per call. If the user asks for multiple insights, you need to decompose a query into multiple subqueries and call the tool for each subquery.
+
+Follow these guidelines when retrieving data:
+- If the same insight is already in the conversation history, reuse the retrieved data only when this does not violate these guidelines (i.e. only when a presence-check, count, or sort on existing columns is enough).
+- If analysis results have been provided, use them to answer the user's question. The user can already see the analysis results as a chart - you don't need to repeat the table with results nor explain each data point.
+- If the retrieved data and any data earlier in the conversation allow for conclusions, answer the user's question and provide actionable feedback.
+- If there is a potential data issue, retrieve a new, different analysis instead of giving a subpar summary. Note: empty data is NOT a potential data issue.
+- If the query cannot be answered with a UI-built insight type - trends, funnels, retention - choose the SQL type to answer the question (e.g. for listing events or aggregating in ways that aren't supported in trends/funnels/retention).
+
+IMPORTANT: Avoid generic advice. Take into account what you know about the product. Your answer needs to be super high-impact and no more than a few sentences.
+Remember: do NOT retrieve data for the same query more than 3 times in a row.
+
+# Data schema
+
+You can pass events, actions, properties, and property values to this tool by specifying the "Data schema" section.
+
+<example>
+User: Calculate onboarding completion rate for the last week.
+Assistant: I'm going to retrieve the existing data schema first.
+*Retrieves matching events, properties, and property values*
+Assistant: I'm going to create a new trends insight.
+*Calls this tool with the query description: "Trends insight of the onboarding completion rate. Data schema: Relevant matching data schema"*
+</example>
+
+# Supported insight types
+## Trends
+A trends insight visualizes events over time using time series. They're useful for finding patterns in historical data.
+
+The trends insights have the following features:
+- The insight can show multiple trends in one request.
+- Custom formulas can calculate derived metrics, like `A/B*100` to calculate a ratio.
+- Filter and break down data using multiple properties.
+- Compare with the previous period and sample data.
+- Apply various aggregation types, like sum, average, etc., and chart types.
+- And more.
+
+Examples of use cases include:
+- How the product's most important metrics change over time.
+- Long-term patterns or cycles in the product's usage.
+- The usage of different features side-by-side.
+- How the properties of events vary using aggregation (sum, average, etc.).
+- Users can also visualize the same data points in a variety of ways.
+
+## Funnel
+A funnel insight visualizes a sequence of events that users go through in a product. They use percentages as the primary aggregation type. Funnels use two or more series, so the conversation history should mention at least two events.
+
+The funnel insights have the following features:
+- Various visualization types (steps, time-to-convert, historical trends).
+- Filter data and apply exclusion steps.
+- Break down data using a single property.
+- Specify conversion windows, details of conversion calculation, attribution settings.
+- Sample data.
+- And more.
+
+Examples of use cases include:
+- Conversion rates.
+- Drop-off steps.
+- Steps with the highest friction and time to convert.
+- If product changes are improving their funnel over time.
+- Average/median time to convert.
+- Conversion trends over time.
+
+## Retention
+A retention insight visualizes how many users return to the product after performing some action. They're useful for understanding user engagement and retention.
+
+The retention insights have the following features: filter data, sample data, and more.
+
+Examples of use cases include:
+- How many users come back and perform an action after their first visit.
+- How many users come back to perform action X after performing action Y.
+- How often users return to use a specific feature.
+
+## SQL
+The 'sql' insight type allows you to write arbitrary SQL queries to retrieve data.
+
+The SQL insights have the following features:
+- Filter data using arbitrary SQL.
+- All ClickHouse SQL features.
+- You can nest subqueries as needed.
+""".strip()
+
+INSIGHT_TOOL_CONTEXT_PROMPT_TEMPLATE = """
+The user is currently editing an insight (aka query). Here is that insight's current definition, which can be edited using the `create_and_query_insight` tool:
+
+```json
+{current_query}
+```
+
+
+Do not remove any fields from the current insight definition. Do not change any fields other than the ones the user asked for. Keep the rest as is.
+
+""".strip()
+
+INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT = """
+<system_reminder>
+Inform the user that you've encountered an error during the creation of the insight. Afterwards, try to generate a new insight with a different query.
+Terminate if the error persists.
+</system_reminder>
+""".strip()
+
+INSIGHT_TOOL_HANDLED_FAILURE_PROMPT = """
+The agent has encountered an error while creating an insight.
+
+Generated output:
+```
+{{{output}}}
+```
+
+Error message:
+```
+{{{error_message}}}
+```
+
+{{{system_reminder}}}
+""".strip()
+
+
+INSIGHT_TOOL_UNHANDLED_FAILURE_PROMPT = """
+The agent has encountered an unknown error while creating an insight.
+{{{system_reminder}}}
+""".strip()
+
+
+class CreateAndQueryInsightToolArgs(BaseModel):
+    query_description: str = Field(
+        description=(
+            "A description of the query to generate, encapsulating the details of the user's request. "
+            "Include all relevant context from earlier messages too, as the tool won't see that conversation history. "
+            "If an existing insight has been used as a starting point, include that insight's filters and query in the description. "
+            "Don't be overly prescriptive with event or property names, unless the user indicated they mean this specific name (e.g. with quotes). "
+            "If the user seems to ask for a list of entities, rather than a count, state this explicitly."
+        )
+    )
+
+
+class CreateAndQueryInsightTool(MaxTool):
+    name: Literal["create_and_query_insight"] = "create_and_query_insight"
+    args_schema: type[BaseModel] = CreateAndQueryInsightToolArgs
+    description: str = INSIGHT_TOOL_PROMPT
+    context_prompt_template: str = INSIGHT_TOOL_CONTEXT_PROMPT_TEMPLATE
+    thinking_message: str = "Coming up with an insight"
+
+    async def _arun_impl(self, query_description: str) -> tuple[str, ToolMessagesArtifact | None]:
+        graph = InsightsGraph(self._team, self._user).compile_full_graph()
+        new_state = self._state.model_copy(
+            update={
+                "root_tool_call_id": self.tool_call_id,
+                "root_tool_insight_plan": query_description,
+            },
+            deep=True,
+        )
+        try:
+            dict_state = await graph.ainvoke(new_state)
+        except SchemaGenerationException as e:
+            return format_prompt_string(
+                INSIGHT_TOOL_HANDLED_FAILURE_PROMPT,
+                output=e.llm_output,
+                error_message=e.validation_message,
+                system_reminder=INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT,
+            ), None
+
+        updated_state = AssistantState.model_validate(dict_state)
+        # On success, the insights subgraph ends with a VisualizationMessage followed by an AssistantToolCallMessage.
+        maybe_viz_message, tool_call_message = updated_state.messages[-2:]
+
+        if not isinstance(tool_call_message, AssistantToolCallMessage):
+            return format_prompt_string(
+                INSIGHT_TOOL_UNHANDLED_FAILURE_PROMPT, system_reminder=INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT
+            ), None
+
+        # If the previous message is not a visualization message, the agent has requested human feedback.
+        if not isinstance(maybe_viz_message, VisualizationMessage):
+            return "", ToolMessagesArtifact(messages=[tool_call_message])
+
+        # If the contextual tool is available, we're editing an insight.
+        # Add the UI payload to the tool call message.
+        if self.is_editing_mode(self._context_manager):
+            tool_call_message = AssistantToolCallMessage(
+                content=tool_call_message.content,
+                ui_payload={self.get_name(): maybe_viz_message.answer.model_dump(exclude_none=True)},
+                id=tool_call_message.id,
+                tool_call_id=tool_call_message.tool_call_id,
+            )
+
+        return "", ToolMessagesArtifact(messages=[maybe_viz_message, tool_call_message])
+
+    @classmethod
+    def is_editing_mode(cls, context_manager: AssistantContextManager) -> bool:
+        """
+        Determines if the tool is in editing mode.
+        """
+        return AssistantTool.CREATE_AND_QUERY_INSIGHT.value in context_manager.get_contextual_tools()
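For context, here is a minimal sketch of the `(content, artifact)` contract that `_arun_impl` implementations above follow. It is illustrative only: the `team`/`user` fixtures are assumed to already exist, the constructor wiring mirrors the test fixtures elsewhere in this PR, and `tool_call_id` is assumed to be injected by the framework before invocation.

```python
from langchain_core.runnables import RunnableConfig

from ee.hogai.context.context import AssistantContextManager
from ee.hogai.graph.root.tools import CreateAndQueryInsightTool
from ee.hogai.utils.types.base import AssistantState


async def demo(team, user) -> None:
    # Wiring mirrors the test fixtures in this PR; team/user are assumed model instances.
    tool = CreateAndQueryInsightTool(
        team=team,
        user=user,
        state=AssistantState(messages=[]),
        config=RunnableConfig(configurable={}),
        context_manager=AssistantContextManager(team, user, {}),
    )
    content, artifact = await tool._arun_impl("Trends of weekly sign-ups over the last 90 days")
    if artifact is not None:
        # Success: the artifact carries the subgraph's messages
        # (VisualizationMessage + AssistantToolCallMessage) back to the root agent.
        print([type(m).__name__ for m in artifact.messages])
    else:
        # Failure: `content` holds the formatted failure prompt fed back to the root agent.
        print(content)
```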
diff --git a/ee/hogai/graph/root/tools/create_dashboard.py b/ee/hogai/graph/root/tools/create_dashboard.py
new file mode 100644
index 0000000000..23d6b6d906
--- /dev/null
+++ b/ee/hogai/graph/root/tools/create_dashboard.py
@@ -0,0 +1,53 @@
+from typing import Literal
+
+from langchain_core.runnables import RunnableLambda
+from pydantic import BaseModel, Field
+
+from ee.hogai.graph.dashboards.nodes import DashboardCreationNode
+from ee.hogai.tool import MaxTool, ToolMessagesArtifact
+from ee.hogai.utils.types.base import AssistantState, InsightQuery, PartialAssistantState
+
+CREATE_DASHBOARD_TOOL_PROMPT = """
+Use this tool when users ask to create, build, or make a new dashboard with insights.
+This tool searches for existing insights that match the user's requirements, creates new insights if none are found, and then combines them into a dashboard.
+Do not call this tool if the user only asks to find, search for, or look up existing insights and does not ask to create a dashboard.
+If you decide to use this tool, there is no need to call the `search` tool beforehand; this tool already handles insight search and creation.
+""".strip()
+
+
+class CreateDashboardToolArgs(BaseModel):
+    search_insights_queries: list[InsightQuery] = Field(
+        description="A list of insights to be included in the dashboard. Include all the insights that the user mentioned."
+    )
+    dashboard_name: str = Field(
+        description=(
+            "The name of the dashboard to be created based on the user request. It should be short and concise, as it will be displayed as a header in the dashboard tile."
+        )
+    )
+
+
+class CreateDashboardTool(MaxTool):
+    name: Literal["create_dashboard"] = "create_dashboard"
+    description: str = CREATE_DASHBOARD_TOOL_PROMPT
+    thinking_message: str = "Creating a dashboard"
+    context_prompt_template: str = "Creates a dashboard based on the user's request"
+    args_schema: type[BaseModel] = CreateDashboardToolArgs
+    show_tool_call_message: bool = False
+
+    async def _arun_impl(
+        self, search_insights_queries: list[InsightQuery], dashboard_name: str
+    ) -> tuple[str, ToolMessagesArtifact | None]:
+        node = DashboardCreationNode(self._team, self._user)
+        chain: RunnableLambda[AssistantState, PartialAssistantState | None] = RunnableLambda(node)
+        copied_state = self._state.model_copy(
+            deep=True,
+            update={
+                "root_tool_call_id": self.tool_call_id,
+                "search_insights_queries": search_insights_queries,
+                "dashboard_name": dashboard_name,
+            },
+        )
+        result = await chain.ainvoke(copied_state)
+        if not result or not result.messages:
+            return "Dashboard creation failed", None
+        return "", ToolMessagesArtifact(messages=result.messages)
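Analogously, a hedged sketch of calling the dashboard tool. The `InsightQuery(name=..., description=...)` shape follows the payloads used in the tests above; the dashboard and insight names here are made up for illustration.

```python
from langchain_core.runnables import RunnableConfig

from ee.hogai.context.context import AssistantContextManager
from ee.hogai.graph.root.tools import CreateDashboardTool
from ee.hogai.utils.types.base import AssistantState, InsightQuery


async def demo(team, user) -> None:
    tool = CreateDashboardTool(
        team=team,
        user=user,
        state=AssistantState(messages=[]),
        config=RunnableConfig(configurable={}),
        context_manager=AssistantContextManager(team, user, {}),
    )
    content, artifact = await tool._arun_impl(
        search_insights_queries=[
            InsightQuery(name="Weekly sign-ups", description="Trends of sign-up events per week"),
            InsightQuery(name="Onboarding funnel", description="Funnel from sign-up to first dashboard created"),
        ],
        dashboard_name="Activation overview",
    )
    # An empty `content` with a non-None artifact means DashboardCreationNode produced
    # messages; otherwise `content` is the "Dashboard creation failed" fallback.
    print(content or [type(m).__name__ for m in artifact.messages])
```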
diff --git a/ee/hogai/graph/root/tools/full_text_search/test/test_toolkit.py b/ee/hogai/graph/root/tools/full_text_search/test/test_toolkit.py
index 1112ade793..6139e99166 100644
--- a/ee/hogai/graph/root/tools/full_text_search/test/test_toolkit.py
+++ b/ee/hogai/graph/root/tools/full_text_search/test/test_toolkit.py
@@ -3,10 +3,13 @@ from unittest.mock import Mock, patch

 from django.conf import settings

+from langchain_core.runnables import RunnableConfig
 from parameterized import parameterized

-from ee.hogai.graph.root.tools.full_text_search.tool import ENTITY_MAP, EntitySearchToolkit, FTSKind
+from ee.hogai.context import AssistantContextManager
+from ee.hogai.graph.root.tools.full_text_search.tool import ENTITY_MAP, EntitySearchTool, FTSKind
 from ee.hogai.graph.shared_prompts import HYPERLINK_USAGE_INSTRUCTIONS
+from ee.hogai.utils.types.base import AssistantState


 class TestEntitySearchToolkit(NonAtomicBaseTest):
@@ -18,7 +21,13 @@ class TestEntitySearchToolkit(NonAtomicBaseTest):
         self.team.organization = Mock()
         self.team.organization.id = 789
         self.user = Mock()
-        self.toolkit = EntitySearchToolkit(self.team, self.user)
+        self.toolkit = EntitySearchTool(
+            team=self.team,
+            user=self.user,
+            state=AssistantState(messages=[]),
+            config=RunnableConfig(configurable={}),
+            context_manager=AssistantContextManager(self.team, self.user, {}),
+        )

     @parameterized.expand(
         [
diff --git a/ee/hogai/graph/root/tools/full_text_search/tool.py b/ee/hogai/graph/root/tools/full_text_search/tool.py
index 8ecd54d9e5..e2d453a70c 100644
--- a/ee/hogai/graph/root/tools/full_text_search/tool.py
+++ b/ee/hogai/graph/root/tools/full_text_search/tool.py
@@ -6,13 +6,14 @@ import yaml
 from posthoganalytics import capture_exception

 from posthog.api.search import EntityConfig, search_entities
-from posthog.models import Action, Cohort, Dashboard, Experiment, FeatureFlag, Insight, Survey, Team, User
+from posthog.models import Action, Cohort, Dashboard, Experiment, FeatureFlag, Insight, Survey
 from posthog.rbac.user_access_control import UserAccessControl
 from posthog.sync import database_sync_to_async

 from products.error_tracking.backend.models import ErrorTrackingIssue

 from ee.hogai.graph.shared_prompts import HYPERLINK_USAGE_INSTRUCTIONS
+from ee.hogai.tool import MaxSubtool

 from .prompts import ENTITY_TYPE_SUMMARY_TEMPLATE, FOUND_ENTITIES_MESSAGE_TEMPLATE

@@ -91,14 +92,10 @@ SEARCH_KIND_TO_DATABASE_ENTITY_TYPE: dict[FTSKind, str] = {
 }


-class EntitySearchToolkit:
+class EntitySearchTool(MaxSubtool): MAX_ENTITY_RESULTS = 10 MAX_CONCURRENT_SEARCHES = 10 - def __init__(self, team: Team, user: User): - self._team = team - self._user = user - async def execute(self, query: str, search_kind: FTSKind) -> str: """Search for entities by query and entity.""" try: diff --git a/ee/hogai/graph/root/tools/legacy.py b/ee/hogai/graph/root/tools/legacy.py deleted file mode 100644 index e028434d10..0000000000 --- a/ee/hogai/graph/root/tools/legacy.py +++ /dev/null @@ -1,204 +0,0 @@ -# The module contains tools that are deprecated and will be replaced in the future with MaxTool implementations. -from pydantic import BaseModel, Field - -from ee.hogai.utils.types.base import InsightQuery - - -# Lower casing matters here. Do not change it. -class create_and_query_insight(BaseModel): - """ - Use this tool to spawn a subagent that will create a product analytics insight for a given description. - The tool generates a query and returns formatted text results for a specific data question or iterates on a previous query. It only retrieves a single query per call. If the user asks for multiple insights, you need to decompose a query into multiple subqueries and call the tool for each subquery. - - Follow these guidelines when retrieving data: - - If the same insight is already in the conversation history, reuse the retrieved data only when this does not violate the section (i.e. only when a presence-check, count, or sort on existing columns is enough). - - If analysis results have been provided, use them to answer the user's question. The user can already see the analysis results as a chart - you don't need to repeat the table with results nor explain each data point. - - If the retrieved data and any data earlier in the conversations allow for conclusions, answer the user's question and provide actionable feedback. - - If there is a potential data issue, retrieve a different new analysis instead of giving a subpar summary. Note: empty data is NOT a potential data issue. - - If the query cannot be answered with a UI-built insight type - trends, funnels, retention - choose the SQL type to answer the question (e.g. for listing events or aggregating in ways that aren't supported in trends/funnels/retention). - - IMPORTANT: Avoid generic advice. Take into account what you know about the product. Your answer needs to be super high-impact and no more than a few sentences. - Remember: do NOT retrieve data for the same query more than 3 times in a row. - - # Data schema - - You can pass events, actions, properties, and property values to this tool by specifying the "Data schema" section. - - - User: Calculate onboarding completion rate for the last week. - Assistant: I'm going to retrieve the existing data schema first. - *Retrieves matching events, properties, and property values* - Assistant: I'm going to create a new trends insight. - *Calls this tool with the query description: "Trends insight of the onboarding completion rate. Data schema: Relevant matching data schema"* - - - # Supported insight types - ## Trends - A trends insight visualizes events over time using time series. They're useful for finding patterns in historical data. - - The trends insights have the following features: - - The insight can show multiple trends in one request. - - Custom formulas can calculate derived metrics, like `A/B*100` to calculate a ratio. - - Filter and break down data using multiple properties. - - Compare with the previous period and sample data. 
- - Apply various aggregation types, like sum, average, etc., and chart types. - - And more. - - Examples of use cases include: - - How the product's most important metrics change over time. - - Long-term patterns, or cycles in product's usage. - - The usage of different features side-by-side. - - How the properties of events vary using aggregation (sum, average, etc). - - Users can also visualize the same data points in a variety of ways. - - ## Funnel - A funnel insight visualizes a sequence of events that users go through in a product. They use percentages as the primary aggregation type. Funnels use two or more series, so the conversation history should mention at least two events. - - The funnel insights have the following features: - - Various visualization types (steps, time-to-convert, historical trends). - - Filter data and apply exclusion steps. - - Break down data using a single property. - - Specify conversion windows, details of conversion calculation, attribution settings. - - Sample data. - - And more. - - Examples of use cases include: - - Conversion rates. - - Drop off steps. - - Steps with the highest friction and time to convert. - - If product changes are improving their funnel over time. - - Average/median time to convert. - - Conversion trends over time. - - ## Retention - A retention insight visualizes how many users return to the product after performing some action. They're useful for understanding user engagement and retention. - - The retention insights have the following features: filter data, sample data, and more. - - Examples of use cases include: - - How many users come back and perform an action after their first visit. - - How many users come back to perform action X after performing action Y. - - How often users return to use a specific feature. - - ## SQL - The 'sql' insight type allows you to write arbitrary SQL queries to retrieve data. - - The SQL insights have the following features: - - Filter data using arbitrary SQL. - - All ClickHouse SQL features. - - You can nest subqueries as needed. - """ - - query_description: str = Field( - description=( - "A description of the query to generate, encapsulating the details of the user's request. " - "Include all relevant context from earlier messages too, as the tool won't see that conversation history. " - "If an existing insight has been used as a starting point, include that insight's filters and query in the description. " - "Don't be overly prescriptive with event or property names, unless the user indicated they mean this specific name (e.g. with quotes). " - "If the users seems to ask for a list of entities, rather than a count, state this explicitly." - ) - ) - - -class session_summarization(BaseModel): - """ - Use this tool to summarize session recordings by analysing the events within those sessions to find patterns and issues. - It will return a textual summary of the captured session recordings. 
- - # When to use the tool: - When the user asks to summarize session recordings: - - "summarize" synonyms: "watch", "analyze", "review", and similar - - "session recordings" synonyms: "sessions", "recordings", "replays", "user sessions", and similar - - # When NOT to use the tool: - - When the user asks to find, search for, or look up session recordings, but doesn't ask to summarize them - - When users asks to update, change, or adjust session recordings filters - - # Synonyms - - "summarize": "watch", "analyze", "review", and similar - - "session recordings": "sessions", "recordings", "replays", "user sessions", and similar - - # Managing context - If the conversation history contains context about the current filters or session recordings, follow these steps: - - Convert the user query into a `session_summarization_query` - - The query should be used to understand the user's intent - - Decide if the query is relevant to the current filters and set `should_use_current_filters` accordingly - - Generate the `summary_title` based on the user's query and the current filters - - Otherwise: - - Convert the user query into a `session_summarization_query` - - The query should be used to search for relevant sessions and then summarize them - - Assume the `should_use_current_filters` should be always `false` - - Generate the `summary_title` based on the user's query - - # Additional guidelines - - CRITICAL: Always pass the user's complete, unmodified query to the `session_summarization_query` parameter - - DO NOT truncate, summarize, or extract keywords from the user's query - - The query is used to find relevant sessions - context helps find better matches - - Use explicit tool definition to make a decision - """ - - session_summarization_query: str = Field( - description=""" - - The user's complete query for session recordings summarization. - - This will be used to find relevant session recordings. - - Always pass the user's complete, unmodified query. - - Examples: - * 'summarize all session recordings from yesterday' - * 'analyze mobile user session recordings from last week, even if 1 second' - * 'watch last 300 session recordings of MacOS users from US' - * and similar - """ - ) - should_use_current_filters: bool = Field( - description=""" - - Whether to use current filters from user's UI to find relevant session recordings. - - IMPORTANT: Should be always `false` if the current filters or `search_session_recordings` tool are not present in the conversation history. - - Examples: - * Set to `true` if one of the conditions is met: - - the user wants to summarize "current/selected/opened/my/all/these" session recordings - - the user wants to use "current/these" filters - - the user's query specifies filters identical to the current filters - - if the user's query doesn't specify any filters/conditions - - the user refers to what they're "looking at" or "viewing" - * Set to `false` if one of the conditions is met: - - no current filters or `search_session_recordings` tool are present in the conversation - - the user specifies date/time period different from the current filters - - the user specifies conditions (user, device, id, URL, etc.) not present in the current filters - """, - ) - summary_title: str = Field( - description=""" - - The name of the summary that is expected to be generated from the user's `session_summarization_query` and/or `current_filters` (if present). 
- - The name should cover in 3-7 words what sessions would be to be summarized in the summary - - This won't be used for any search of filtering, only to properly label the generated summary. - - Examples: - * If `should_use_current_filters` is `false`, then the `summary_title` should be generated based on the `session_summarization_query`: - - query: "I want to watch all the sessions of user `user@example.com` in the last 30 days no matter how long" -> name: "Sessions of the user user@example.com (last 30 days)" - - query: "summarize my last 100 session recordings" -> name: "Last 100 sessions" - - and similar - * If `should_use_current_filters` is `true`, then the `summary_title` should be generated based on the current filters in the context (if present): - - filters: "{"key":"$os","value":["Mac OS X"],"operator":"exact","type":"event"}" -> name: "MacOS users" - - filters: "{"date_from": "-7d", "filter_test_accounts": True}" -> name: "All sessions (last 7 days)" - - and similar - * If there's not enough context to generated the summary name - keep it an empty string ("") - """ - ) - - -class create_dashboard(BaseModel): - """ - Use this tool when users ask to create, build, or make a new dashboard with insights. - This tool will search for existing insights that match the user's requirements so no need to call `search_insights` tool, or create new insights if none are found, then combine them into a dashboard. - Do not call this tool if the user only asks to find, search for, or look up existing insights and does not ask to create a dashboard. - If you decided to use this tool, there is no need to call `search_insights` tool beforehand. The tool will search for existing insights that match the user's requirements and create new insights if none are found. - """ - - search_insights_queries: list[InsightQuery] = Field( - description="A list of insights to be included in the dashboard. Include all the insights that the user mentioned." - ) - dashboard_name: str = Field( - description=( - "The name of the dashboard to be created based on the user request. It should be short and concise as it will be displayed as a header in the dashboard tile." - ) - ) diff --git a/ee/hogai/graph/root/tools/navigate.py b/ee/hogai/graph/root/tools/navigate.py index 84a741c4e2..50607d3a37 100644 --- a/ee/hogai/graph/root/tools/navigate.py +++ b/ee/hogai/graph/root/tools/navigate.py @@ -10,7 +10,7 @@ from posthog.models import Team, User from ee.hogai.context.context import AssistantContextManager from ee.hogai.tool import MaxTool from ee.hogai.utils.prompt import format_prompt_string -from ee.hogai.utils.types.base import AssistantState +from ee.hogai.utils.types.base import AssistantState, NodePath NAVIGATION_TOOL_PROMPT = """ Use the `navigate` tool to move between different pages in the PostHog application. @@ -57,7 +57,7 @@ class NavigateTool(MaxTool): *, team: Team, user: User, - tool_call_id: str, + node_path: tuple[NodePath, ...] 
| None = None, state: AssistantState | None = None, config: RunnableConfig | None = None, context_manager: AssistantContextManager | None = None, @@ -69,7 +69,7 @@ class NavigateTool(MaxTool): return cls( team=team, user=user, - tool_call_id=tool_call_id, + node_path=node_path, state=state, config=config, context_manager=context_manager, diff --git a/ee/hogai/graph/root/tools/read_billing_tool/__init__.py b/ee/hogai/graph/root/tools/read_billing_tool/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ee/hogai/graph/billing/prompts.py b/ee/hogai/graph/root/tools/read_billing_tool/prompts.py similarity index 100% rename from ee/hogai/graph/billing/prompts.py rename to ee/hogai/graph/root/tools/read_billing_tool/prompts.py diff --git a/ee/hogai/graph/root/tools/read_billing_tool/test/__init__.py b/ee/hogai/graph/root/tools/read_billing_tool/test/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ee/hogai/graph/billing/test/test_nodes.py b/ee/hogai/graph/root/tools/read_billing_tool/test/test_nodes.py similarity index 89% rename from ee/hogai/graph/billing/test/test_nodes.py rename to ee/hogai/graph/root/tools/read_billing_tool/test/test_nodes.py index 81e87f16d6..d07ad48ebf 100644 --- a/ee/hogai/graph/billing/test/test_nodes.py +++ b/ee/hogai/graph/root/tools/read_billing_tool/test/test_nodes.py @@ -2,11 +2,12 @@ import datetime from typing import cast from uuid import uuid4 -from posthog.test.base import BaseTest, ClickhouseTestMixin +from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest from unittest.mock import patch +from langchain_core.runnables import RunnableConfig + from posthog.schema import ( - AssistantToolCallMessage, BillingSpendResponseBreakdownType, BillingUsageResponseBreakdownType, MaxAddonInfo, @@ -21,26 +22,30 @@ from posthog.schema import ( UsageHistoryItem, ) -from ee.hogai.graph.billing.nodes import BillingNode +from ee.hogai.context.context import AssistantContextManager +from ee.hogai.graph.root.tools.read_billing_tool.tool import ReadBillingTool from ee.hogai.utils.types import AssistantState -class TestBillingNode(ClickhouseTestMixin, BaseTest): +class TestBillingNode(ClickhouseTestMixin, NonAtomicBaseTest): + CLASS_DATA_LEVEL_SETUP = False + def setUp(self): super().setUp() - self.node = BillingNode(self.team, self.user) - self.tool_call_id = str(uuid4()) - self.state = AssistantState(messages=[], root_tool_call_id=self.tool_call_id) + self.tool = ReadBillingTool( + team=self.team, + user=self.user, + state=AssistantState(messages=[], root_tool_call_id=str(uuid4())), + config=RunnableConfig(configurable={}), + context_manager=AssistantContextManager(self.team, self.user, {}), + ) - def test_run_with_no_billing_context(self): - with patch.object(self.node.context_manager, "get_billing_context", return_value=None): - result = self.node.run(self.state, {}) - self.assertEqual(len(result.messages), 1) - message = result.messages[0] - self.assertIsInstance(message, AssistantToolCallMessage) - self.assertEqual(cast(AssistantToolCallMessage, message).content, "No billing information available") + async def test_run_with_no_billing_context(self): + with patch.object(self.tool._context_manager, "get_billing_context", return_value=None): + result = await self.tool.execute() + self.assertEqual(result, "No billing information available") - def test_run_with_billing_context(self): + async def test_run_with_billing_context(self): billing_context = MaxBillingContext( 
subscription_level=MaxBillingContextSubscriptionLevel.PAID, billing_plan="paid", @@ -50,17 +55,13 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest): products=[], ) with ( - patch.object(self.node.context_manager, "get_billing_context", return_value=billing_context), - patch.object(self.node, "_format_billing_context", return_value="Formatted Context"), + patch.object(self.tool._context_manager, "get_billing_context", return_value=billing_context), + patch.object(self.tool, "_format_billing_context", return_value="Formatted Context"), ): - result = self.node.run(self.state, {}) - self.assertEqual(len(result.messages), 1) - message = result.messages[0] - self.assertIsInstance(message, AssistantToolCallMessage) - self.assertEqual(cast(AssistantToolCallMessage, message).content, "Formatted Context") - self.assertEqual(cast(AssistantToolCallMessage, message).tool_call_id, self.tool_call_id) + result = await self.tool.execute() + self.assertEqual(result, "Formatted Context") - def test_format_billing_context(self): + async def test_format_billing_context(self): billing_context = MaxBillingContext( subscription_level=MaxBillingContextSubscriptionLevel.PAID, billing_plan="paid", @@ -89,8 +90,8 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest): settings=MaxBillingContextSettings(autocapture_on=True, active_destinations=2), ) - with patch.object(self.node, "_get_top_events_by_usage", return_value=[]): - formatted_string = self.node._format_billing_context(billing_context) + with patch.object(self.tool, "_get_top_events_by_usage", return_value=[]): + formatted_string = await self.tool._format_billing_context(billing_context) self.assertIn("(paid)", formatted_string) self.assertIn("Period: 2023-01-01 to 2023-01-31", formatted_string) @@ -123,19 +124,21 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest): ), ] - usage_table = self.node._format_history_table(usage_history) + usage_table = self.tool._format_history_table(usage_history) self.assertIn("### Overall (all projects)", usage_table) self.assertIn("| Data Type | 2023-01-01 | 2023-01-02 |", usage_table) self.assertIn("| Recordings | 100.00 | 200.00 |", usage_table) self.assertIn("| Events | 1.50 | 2.50 |", usage_table) - spend_table = self.node._format_history_table(spend_history) + spend_table = self.tool._format_history_table(spend_history) self.assertIn("| Mobile Recordings | 10.50 | 20.00 |", spend_table) def test_get_top_events_by_usage(self): mock_results = [("pageview", 1000), ("$autocapture", 500)] - with patch("ee.hogai.graph.billing.nodes.sync_execute", return_value=mock_results) as mock_sync_execute: - top_events = self.node._get_top_events_by_usage() + with patch( + "ee.hogai.graph.root.tools.read_billing_tool.tool.sync_execute", return_value=mock_results + ) as mock_sync_execute: + top_events = self.tool._get_top_events_by_usage() self.assertEqual(len(top_events), 2) self.assertEqual(top_events[0]["event"], "pageview") self.assertEqual(top_events[0]["count"], 1000) @@ -150,13 +153,14 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest): def test_get_top_events_by_usage_query_fails(self): with patch( - "ee.hogai.graph.billing.nodes.sync_execute", side_effect=Exception("DB connection failed") + "ee.hogai.graph.root.tools.read_billing_tool.tool.sync_execute", + side_effect=Exception("DB connection failed"), ) as mock_sync_execute: - top_events = self.node._get_top_events_by_usage() + top_events = self.tool._get_top_events_by_usage() self.assertEqual(top_events, []) mock_sync_execute.assert_called_once() - def 
test_format_billing_context_with_addons(self): + async def test_format_billing_context_with_addons(self): """Test that addons are properly nested within products in the formatted output""" billing_context = MaxBillingContext( subscription_level=MaxBillingContextSubscriptionLevel.PAID, @@ -230,14 +234,14 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest): ) with patch.object( - self.node, + self.tool, "_get_top_events_by_usage", return_value=[ {"event": "$pageview", "count": 50000, "formatted_count": "50,000"}, {"event": "$autocapture", "count": 30000, "formatted_count": "30,000"}, ], ): - formatted_string = self.node._format_billing_context(billing_context) + formatted_string = await self.tool._format_billing_context(billing_context) # Check basic info self.assertIn("paid subscription (startup)", formatted_string) @@ -274,7 +278,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest): self.assertIn("$pageview", formatted_string) self.assertIn("50,000 events", formatted_string) - def test_format_billing_context_no_subscription(self): + async def test_format_billing_context_no_subscription(self): """Test formatting when user has no active subscription (free plan)""" billing_context = MaxBillingContext( subscription_level=MaxBillingContextSubscriptionLevel.FREE, @@ -286,8 +290,8 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest): trial=MaxBillingContextTrial(is_active=True, expires_at="2023-02-01", target="teams"), ) - with patch.object(self.node, "_get_top_events_by_usage", return_value=[]): - formatted_string = self.node._format_billing_context(billing_context) + with patch.object(self.tool, "_get_top_events_by_usage", return_value=[]): + formatted_string = await self.tool._format_billing_context(billing_context) self.assertIn("free subscription", formatted_string) self.assertIn("Active subscription: No (Free plan)", formatted_string) @@ -297,7 +301,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest): def test_format_history_table_with_team_breakdown(self): """Test that history tables properly group by team when breakdown includes team IDs""" # Mock the teams map - self.node._teams_map = { + self.tool._teams_map = { 1: "Team Alpha (ID: 1)", 2: "Team Beta (ID: 2)", } @@ -337,7 +341,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest): ), ] - table = self.node._format_history_table(usage_history) + table = self.tool._format_history_table(usage_history) # Check team-specific tables self.assertIn("### Team Alpha (ID: 1)", table) @@ -349,7 +353,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest): self.assertIn("| Recordings | 100.00 | 200.00 |", table) self.assertIn("| Feature Flag Requests | 50.00 | 100.00 |", table) - def test_format_billing_context_edge_cases(self): + async def test_format_billing_context_edge_cases(self): """Test edge cases and potential security issues""" billing_context = MaxBillingContext( subscription_level=MaxBillingContextSubscriptionLevel.CUSTOM, @@ -374,8 +378,8 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest): settings=MaxBillingContextSettings(autocapture_on=True, active_destinations=0), ) - with patch.object(self.node, "_get_top_events_by_usage", return_value=[]): - formatted_string = self.node._format_billing_context(billing_context) + with patch.object(self.tool, "_get_top_events_by_usage", return_value=[]): + formatted_string = await self.tool._format_billing_context(billing_context) # Check deactivated status self.assertIn("Status: Account is deactivated", formatted_string) @@ -393,7 +397,7 @@ class 
TestBillingNode(ClickhouseTestMixin, BaseTest): # Check exceeded limit warning self.assertIn("⚠️ Usage limit exceeded", formatted_string) - def test_format_billing_context_complete_template_coverage(self): + async def test_format_billing_context_complete_template_coverage(self): """Test all possible template variables are covered""" billing_context = MaxBillingContext( subscription_level=MaxBillingContextSubscriptionLevel.PAID, @@ -472,14 +476,14 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest): ) # Mock teams map for history table - self.node._teams_map = {1: "Main Team (ID: 1)"} + self.tool._teams_map = {1: "Main Team (ID: 1)"} with patch.object( - self.node, + self.tool, "_get_top_events_by_usage", return_value=[{"event": "$identify", "count": 10000, "formatted_count": "10,000"}], ): - formatted_string = self.node._format_billing_context(billing_context) + formatted_string = await self.tool._format_billing_context(billing_context) # Verify all template sections are present self.assertIn("", formatted_string) @@ -545,12 +549,12 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest): ), ] - self.node._teams_map = { + self.tool._teams_map = { 84444: "Project 84444", 12345: "Project 12345", } - table = self.node._format_history_table(usage_history) + table = self.tool._format_history_table(usage_history) # Should always include aggregated table first self.assertIn("### Overall (all projects)", table) @@ -597,7 +601,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest): ), ] - table = self.node._format_history_table(usage_history) + table = self.tool._format_history_table(usage_history) # Should include aggregated table self.assertIn("### Overall (all projects)", table) @@ -645,7 +649,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest): ) ] - aggregated = self.node._create_aggregated_items(team_items, other_items) + aggregated = self.tool._create_aggregated_items(team_items, other_items) # Should have 2 aggregated items: events and feature flags self.assertEqual(len(aggregated), 2) @@ -687,7 +691,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest): ), ] - table = self.node._format_history_table(usage_history) + table = self.tool._format_history_table(usage_history) # Should only have aggregated table, no team-specific tables self.assertIn("### Overall (all projects)", table) @@ -721,18 +725,18 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest): ) ] - self.node._teams_map = { + self.tool._teams_map = { 123: "Project 123", } # Test usage history formatting - usage_table = self.node._format_history_table(usage_history) + usage_table = self.tool._format_history_table(usage_history) self.assertIn("### Overall (all projects)", usage_table) self.assertIn("### Project 123", usage_table) self.assertIn("| Events | 1,000.00 |", usage_table) # Test spend history formatting - spend_table = self.node._format_history_table(spend_history) + spend_table = self.tool._format_history_table(spend_history) self.assertIn("### Overall (all projects)", spend_table) self.assertIn("### Project 123", spend_table) self.assertIn("| Events | 50.00 |", spend_table) @@ -758,7 +762,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest): ), ] - table = self.node._format_history_table(usage_history) + table = self.tool._format_history_table(usage_history) # Should handle empty dates gracefully and still show valid data self.assertIn("### Overall (all projects)", table) diff --git a/ee/hogai/graph/billing/nodes.py b/ee/hogai/graph/root/tools/read_billing_tool/tool.py similarity index 91% rename 
from ee/hogai/graph/billing/nodes.py rename to ee/hogai/graph/root/tools/read_billing_tool/tool.py index 985337862b..6d002e380c 100644 --- a/ee/hogai/graph/billing/nodes.py +++ b/ee/hogai/graph/root/tools/read_billing_tool/tool.py @@ -1,18 +1,19 @@ -from typing import Any, cast -from uuid import uuid4 +from typing import Any from langchain_core.prompts import PromptTemplate from langchain_core.runnables import RunnableConfig -from posthog.schema import AssistantToolCallMessage, MaxBillingContext, SpendHistoryItem, UsageHistoryItem +from posthog.schema import MaxBillingContext, SpendHistoryItem, UsageHistoryItem from posthog.clickhouse.client import sync_execute +from posthog.models import Team, User +from posthog.sync import database_sync_to_async -from ee.hogai.graph.base import AssistantNode -from ee.hogai.graph.billing.prompts import BILLING_CONTEXT_PROMPT +from ee.hogai.context.context import AssistantContextManager +from ee.hogai.tool import MaxSubtool from ee.hogai.utils.types import AssistantState -from ee.hogai.utils.types.base import AssistantNodeName, PartialAssistantState -from ee.hogai.utils.types.composed import MaxNodeName + +from .prompts import BILLING_CONTEXT_PROMPT # sync with frontend/src/scenes/billing/constants.ts USAGE_TYPES = [ @@ -33,33 +34,27 @@ USAGE_TYPES = [ ] -class BillingNode(AssistantNode): - _teams_map: dict[int, str] = {} +class ReadBillingTool(MaxSubtool): + def __init__( + self, + *, + team: Team, + user: User, + state: AssistantState, + config: RunnableConfig, + context_manager: AssistantContextManager, + ): + super().__init__(team=team, user=user, state=state, config=config, context_manager=context_manager) + self._teams_map: dict[int, str] = {} - @property - def node_name(self) -> MaxNodeName: - return AssistantNodeName.BILLING - - def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: - tool_call_id = cast(str, state.root_tool_call_id) - billing_context = self.context_manager.get_billing_context() + async def execute(self) -> str: + billing_context = self._context_manager.get_billing_context() if not billing_context: - return PartialAssistantState( - messages=[ - AssistantToolCallMessage( - content="No billing information available", id=str(uuid4()), tool_call_id=tool_call_id - ) - ], - root_tool_call_id=None, - ) - formatted_billing_context = self._format_billing_context(billing_context) - return PartialAssistantState( - messages=[ - AssistantToolCallMessage(content=formatted_billing_context, tool_call_id=tool_call_id, id=str(uuid4())), - ], - root_tool_call_id=None, - ) + return "No billing information available" + formatted_billing_context = await self._format_billing_context(billing_context) + return formatted_billing_context + @database_sync_to_async(thread_sensitive=False) def _format_billing_context(self, billing_context: MaxBillingContext) -> str: """Format billing context into a readable prompt section.""" # Convert billing context to a format suitable for the mustache template diff --git a/ee/hogai/graph/root/tools/read_data.py b/ee/hogai/graph/root/tools/read_data.py index 712d8c22d0..577743f5c3 100644 --- a/ee/hogai/graph/root/tools/read_data.py +++ b/ee/hogai/graph/root/tools/read_data.py @@ -1,4 +1,4 @@ -from typing import Any, Literal, Self +from typing import Literal, Self from langchain_core.runnables import RunnableConfig from pydantic import BaseModel @@ -9,12 +9,14 @@ from ee.hogai.context.context import AssistantContextManager from ee.hogai.graph.sql.mixins import HogQLDatabaseMixin from 
ee.hogai.tool import MaxTool from ee.hogai.utils.prompt import format_prompt_string -from ee.hogai.utils.types.base import AssistantState +from ee.hogai.utils.types.base import AssistantState, NodePath + +from .read_billing_tool.tool import ReadBillingTool READ_DATA_BILLING_PROMPT = """ # Billing information -Use this tool with the "billing_info" kind to retrieve the billing information if the user asks about billing, their subscription, their usage, or their spending. +Use this tool with the "billing_info" kind to retrieve the billing information if the user asks about their billing, subscription, product usage, spending, or cost reduction strategies. You can use the information retrieved to check which PostHog products and add-ons the user has activated, how much they are spending, their usage history across all products in the last 30 days, as well as trials, spending limits, billing period, and more. If the user wants to reduce their spending, always call this tool to get suggestions on how to do so. If an insight shows zero data, it could mean either the query is looking at the wrong data or there was a temporary data collection issue. You can investigate potential dips in usage/captured data using the billing tool. @@ -62,7 +64,7 @@ class ReadDataTool(HogQLDatabaseMixin, MaxTool): *, team: Team, user: User, - tool_call_id: str, + node_path: tuple[NodePath, ...] | None = None, state: AssistantState | None = None, config: RunnableConfig | None = None, context_manager: AssistantContextManager | None = None, @@ -83,21 +85,29 @@ class ReadDataTool(HogQLDatabaseMixin, MaxTool): return cls( team=team, user=user, - tool_call_id=tool_call_id, state=state, + node_path=node_path, config=config, args_schema=args, description=description, context_manager=context_manager, ) - async def _arun_impl(self, kind: ReadDataAdminAccessKind | ReadDataKind) -> tuple[str, dict[str, Any] | None]: + async def _arun_impl(self, kind: ReadDataAdminAccessKind | ReadDataKind) -> tuple[str, None]: match kind: case "billing_info": has_access = await self._context_manager.check_user_has_billing_access() if not has_access: return BILLING_INSUFFICIENT_ACCESS_PROMPT, None # used for routing - return "", self.args_schema(kind=kind).model_dump() + billing_tool = ReadBillingTool( + team=self._team, + user=self._user, + state=self._state, + config=self._config, + context_manager=self._context_manager, + ) + result = await billing_tool.execute() + return result, None case "datawarehouse_schema": return await self._serialize_database_schema(), None diff --git a/ee/hogai/graph/root/tools/read_taxonomy.py b/ee/hogai/graph/root/tools/read_taxonomy.py index 0fffab82e0..d088d7b5ca 100644 --- a/ee/hogai/graph/root/tools/read_taxonomy.py +++ b/ee/hogai/graph/root/tools/read_taxonomy.py @@ -9,7 +9,7 @@ from ee.hogai.context.context import AssistantContextManager from ee.hogai.graph.query_planner.toolkit import TaxonomyAgentToolkit from ee.hogai.tool import MaxTool from ee.hogai.utils.helpers import format_events_yaml -from ee.hogai.utils.types.base import AssistantState +from ee.hogai.utils.types.base import AssistantState, NodePath READ_TAXONOMY_TOOL_DESCRIPTION = """ Use this tool to explore the user's taxonomy (i.e. data schema). @@ -155,7 +155,7 @@ class ReadTaxonomyTool(MaxTool): *, team: Team, user: User, - tool_call_id: str, + node_path: tuple[NodePath, ...] 
| None = None, state: AssistantState | None = None, config: RunnableConfig | None = None, context_manager: AssistantContextManager | None = None, @@ -200,9 +200,9 @@ class ReadTaxonomyTool(MaxTool): return cls( team=team, user=user, - tool_call_id=tool_call_id, state=state, config=config, + node_path=node_path, args_schema=ReadTaxonomyToolArgsWithGroups, context_manager=context_manager, ) diff --git a/ee/hogai/graph/root/tools/search.py b/ee/hogai/graph/root/tools/search.py index 11c32dfb99..7f7910a282 100644 --- a/ee/hogai/graph/root/tools/search.py +++ b/ee/hogai/graph/root/tools/search.py @@ -1,17 +1,19 @@ -from typing import Any, Literal +from typing import Literal from django.conf import settings import posthoganalytics from langchain_core.output_parsers import SimpleJsonOutputParser from langchain_core.prompts import ChatPromptTemplate +from langchain_core.runnables import RunnableLambda from langchain_openai import ChatOpenAI from pydantic import BaseModel, Field -from posthog.models import Team, User - -from ee.hogai.graph.root.tools.full_text_search.tool import EntitySearchToolkit, FTSKind -from ee.hogai.tool import MaxTool +from ee.hogai.graph.insights.nodes import InsightSearchNode, NoInsightsException +from ee.hogai.graph.root.tools.full_text_search.tool import EntitySearchTool, FTSKind +from ee.hogai.tool import MaxSubtool, MaxTool, ToolMessagesArtifact +from ee.hogai.utils.prompt import format_prompt_string +from ee.hogai.utils.types.base import AssistantState, PartialAssistantState SEARCH_TOOL_PROMPT = """ Use this tool to search docs, insights, dashboards, cohorts, actions, experiments, feature flags, notebooks, error tracking issues, and surveys in PostHog. @@ -57,35 +59,10 @@ If you want to search for all entities, you should use `all`. """.strip() -DOCS_SEARCH_RESULTS_TEMPLATE = """Found {count} relevant documentation page(s): - -{docs} - -Use retrieved documentation to answer the user's question if it is relevant to the user's query. -Format the response using Markdown and reference the documentation using hyperlinks. -Every link to docs clearly explicitly be labeled, for example as "(see docs)". - +INVALID_ENTITY_KIND_PROMPT = """ +Invalid entity kind: {{{kind}}}. Please provide a valid entity kind for the tool. """.strip() -DOCS_SEARCH_NO_RESULTS_TEMPLATE = """ -No documentation found. - - -Do not answer the user's question if you did not find any documentation. Try rewriting the query. -If after a couple of attempts you still do not find any documentation, suggest the user navigate to the documentation page, which is available at `https://posthog.com/docs`. 
- -""".strip() - -DOC_ITEM_TEMPLATE = """ -# {title} -URL: {url} - -{text} -""".strip() - - -FTS_SEARCH_FEATURE_FLAG = "hogai-insights-fts-search" - ENTITIES = [f"{entity}" for entity in FTSKind if entity != FTSKind.INSIGHTS] SearchKind = Literal["insights", "docs", *ENTITIES] # type: ignore @@ -126,50 +103,106 @@ class SearchTool(MaxTool): context_prompt_template: str = "Searches documentation, insights, dashboards, cohorts, actions, experiments, feature flags, notebooks, error tracking issues, and surveys in PostHog" args_schema: type[BaseModel] = SearchToolArgs - @staticmethod - def _get_fts_entities(include_insight_fts: bool) -> list[str]: - if not include_insight_fts: - entities = [e for e in FTSKind if e != FTSKind.INSIGHTS] - else: - entities = list(FTSKind) - return [*entities, FTSKind.ALL] - - async def _arun_impl(self, kind: SearchKind, query: str) -> tuple[str, dict[str, Any] | None]: + async def _arun_impl(self, kind: str, query: str) -> tuple[str, ToolMessagesArtifact | None]: if kind == "docs": if not settings.INKEEP_API_KEY: return "This tool is not available in this environment.", None - if self._has_docs_search_feature_flag(): - return await self._search_docs(query), None + docs_tool = InkeepDocsSearchTool( + team=self._team, + user=self._user, + state=self._state, + config=self._config, + context_manager=self._context_manager, + ) + return await docs_tool.execute(query, self.tool_call_id) - fts_entities = SearchTool._get_fts_entities(SearchTool._has_fts_search_feature_flag(self._user, self._team)) + if kind == "insights" and not self._has_insights_fts_search_feature_flag(): + insights_tool = InsightSearchTool( + team=self._team, + user=self._user, + state=self._state, + config=self._config, + context_manager=self._context_manager, + ) + return await insights_tool.execute(query, self.tool_call_id) - if kind in fts_entities: - entity_search_toolkit = EntitySearchToolkit(self._team, self._user) - response = await entity_search_toolkit.execute(query, FTSKind(kind)) - return response, None - # Used for routing - return "Search tool executed", SearchToolArgs(kind=kind, query=query).model_dump() + if kind not in self._fts_entities: + return format_prompt_string(INVALID_ENTITY_KIND_PROMPT, kind=kind), None - def _has_docs_search_feature_flag(self) -> bool: + entity_search_toolkit = EntitySearchTool( + team=self._team, + user=self._user, + state=self._state, + config=self._config, + context_manager=self._context_manager, + ) + response = await entity_search_toolkit.execute(query, FTSKind(kind)) + return response, None + + @property + def _fts_entities(self) -> list[str]: + entities = list(FTSKind) + return [*entities, FTSKind.ALL] + + def _has_insights_fts_search_feature_flag(self) -> bool: return posthoganalytics.feature_enabled( - "max-inkeep-rag-docs-search", + "hogai-insights-fts-search", str(self._user.distinct_id), groups={"organization": str(self._team.organization_id)}, group_properties={"organization": {"id": str(self._team.organization_id)}}, send_feature_flag_events=False, ) - @staticmethod - def _has_fts_search_feature_flag(user: User, team: Team) -> bool: - return posthoganalytics.feature_enabled( - FTS_SEARCH_FEATURE_FLAG, - str(user.distinct_id), - groups={"organization": str(team.organization_id)}, - group_properties={"organization": {"id": str(team.organization_id)}}, - send_feature_flag_events=False, - ) - async def _search_docs(self, query: str) -> str: +DOCS_SEARCH_RESULTS_TEMPLATE = """Found {count} relevant documentation page(s): + +{docs} + +Use retrieved 
documentation to answer the user's question if it is relevant to the user's query. +Format the response using Markdown and reference the documentation using hyperlinks. +Every link to docs must be clearly and explicitly labeled, for example as "(see docs)". + +""".strip() + +DOCS_SEARCH_NO_RESULTS_TEMPLATE = """ +No documentation found. + + +Do not answer the user's question if you did not find any documentation. Try rewriting the query. +If after a couple of attempts you still do not find any documentation, suggest the user navigate to the documentation page, which is available at `https://posthog.com/docs`. + +""".strip() + +DOC_ITEM_TEMPLATE = """ +# {title} +URL: {url} + +{text} +""".strip() + + +class InkeepDocsSearchTool(MaxSubtool): + async def execute(self, query: str, tool_call_id: str) -> tuple[str, ToolMessagesArtifact | None]: + if self._has_rag_docs_search_feature_flag(): + return await self._search_using_rag_endpoint(query, tool_call_id) + else: + return await self._search_using_node(query, tool_call_id) + + async def _search_using_node(self, query: str, tool_call_id: str) -> tuple[str, ToolMessagesArtifact | None]: + # Avoid circular import + from ee.hogai.graph.inkeep_docs.nodes import InkeepDocsNode + + # Initialize the node and wrap it in a runnable + node = InkeepDocsNode(self._team, self._user) + chain: RunnableLambda[AssistantState, PartialAssistantState | None] = RunnableLambda(node) + copied_state = self._state.model_copy(deep=True, update={"root_tool_call_id": tool_call_id}) + result = await chain.ainvoke(copied_state) + assert result is not None + return "", ToolMessagesArtifact(messages=result.messages) + + async def _search_using_rag_endpoint( + self, query: str, tool_call_id: str + ) -> tuple[str, ToolMessagesArtifact | None]: model = ChatOpenAI( model="inkeep-rag", base_url="https://api.inkeep.com/v1/", @@ -184,7 +217,7 @@ class SearchTool(MaxTool): rag_context_raw = await chain.ainvoke({"query": query}) if not rag_context_raw or not rag_context_raw.get("content"): - return DOCS_SEARCH_NO_RESULTS_TEMPLATE + return DOCS_SEARCH_NO_RESULTS_TEMPLATE, None rag_context = InkeepResponse.model_validate(rag_context_raw) @@ -197,7 +230,35 @@ class SearchTool(MaxTool): docs.append(DOC_ITEM_TEMPLATE.format(title=doc.title, url=doc.url, text=text)) if not docs: - return DOCS_SEARCH_NO_RESULTS_TEMPLATE + return DOCS_SEARCH_NO_RESULTS_TEMPLATE, None formatted_docs = "\n\n---\n\n".join(docs) - return DOCS_SEARCH_RESULTS_TEMPLATE.format(count=len(docs), docs=formatted_docs) + return DOCS_SEARCH_RESULTS_TEMPLATE.format(count=len(docs), docs=formatted_docs), None + + def _has_rag_docs_search_feature_flag(self) -> bool: + return posthoganalytics.feature_enabled( + "max-inkeep-rag-docs-search", + str(self._user.distinct_id), + groups={"organization": str(self._team.organization_id)}, + group_properties={"organization": {"id": str(self._team.organization_id)}}, + send_feature_flag_events=False, + ) + + +EMPTY_DATABASE_ERROR_MESSAGE = """ +The user doesn't have any insights created yet.
+""".strip() + + +class InsightSearchTool(MaxSubtool): + async def execute(self, query: str, tool_call_id: str) -> tuple[str, ToolMessagesArtifact | None]: + try: + node = InsightSearchNode(self._team, self._user) + copied_state = self._state.model_copy( + deep=True, update={"search_insights_query": query, "root_tool_call_id": tool_call_id} + ) + chain: RunnableLambda[AssistantState, PartialAssistantState | None] = RunnableLambda(node) + result = await chain.ainvoke(copied_state) + return "", ToolMessagesArtifact(messages=result.messages) if result else None + except NoInsightsException: + return EMPTY_DATABASE_ERROR_MESSAGE, None diff --git a/ee/hogai/graph/root/tools/session_summarization.py b/ee/hogai/graph/root/tools/session_summarization.py new file mode 100644 index 0000000000..594358ec6b --- /dev/null +++ b/ee/hogai/graph/root/tools/session_summarization.py @@ -0,0 +1,122 @@ +from typing import Literal + +from langchain_core.runnables import RunnableLambda +from pydantic import BaseModel, Field + +from ee.hogai.graph.session_summaries.nodes import SessionSummarizationNode +from ee.hogai.tool import MaxTool, ToolMessagesArtifact +from ee.hogai.utils.types.base import AssistantState, PartialAssistantState + +SESSION_SUMMARIZATION_TOOL_PROMPT = """ +Use this tool to summarize session recordings by analyzing the events within those sessions to find patterns and issues. +It will return a textual summary of the captured session recordings. + +# When to use the tool: +When the user asks to summarize session recordings: +- "summarize" synonyms: "watch", "analyze", "review", and similar +- "session recordings" synonyms: "sessions", "recordings", "replays", "user sessions", and similar + +# When NOT to use the tool: +- When the user asks to find, search for, or look up session recordings, but doesn't ask to summarize them +- When the user asks to update, change, or adjust session recordings filters + +# Synonyms +- "summarize": "watch", "analyze", "review", and similar +- "session recordings": "sessions", "recordings", "replays", "user sessions", and similar + +# Managing context +If the conversation history contains context about the current filters or session recordings, follow these steps: +- Convert the user query into a `session_summarization_query` +- The query should be used to understand the user's intent +- Decide if the query is relevant to the current filters and set `should_use_current_filters` accordingly +- Generate the `summary_title` based on the user's query and the current filters + +Otherwise: +- Convert the user query into a `session_summarization_query` +- The query should be used to search for relevant sessions and then summarize them +- Assume that `should_use_current_filters` should always be `false` +- Generate the `summary_title` based on the user's query + +# Additional guidelines +- CRITICAL: Always pass the user's complete, unmodified query to the `session_summarization_query` parameter +- DO NOT truncate, summarize, or extract keywords from the user's query +- The query is used to find relevant sessions - context helps find better matches +- Use the explicit tool definition to make a decision +""".strip() + + +class SessionSummarizationToolArgs(BaseModel): + session_summarization_query: str = Field( + description=""" + - The user's complete query for session recordings summarization. + - This will be used to find relevant session recordings. + - Always pass the user's complete, unmodified query.
+ - Examples: + * 'summarize all session recordings from yesterday' + * 'analyze mobile user session recordings from last week, even if 1 second' + * 'watch last 300 session recordings of MacOS users from US' + * and similar + """ + ) + should_use_current_filters: bool = Field( + description=""" + - Whether to use current filters from the user's UI to find relevant session recordings. + - IMPORTANT: Should always be `false` if the current filters or `search_session_recordings` tool are not present in the conversation history. + - Examples: + * Set to `true` if one of the conditions is met: + - the user wants to summarize "current/selected/opened/my/all/these" session recordings + - the user wants to use "current/these" filters + - the user's query specifies filters identical to the current filters + - the user's query doesn't specify any filters/conditions + - the user refers to what they're "looking at" or "viewing" + * Set to `false` if one of the conditions is met: + - no current filters or `search_session_recordings` tool are present in the conversation + - the user specifies a date/time period different from the current filters + - the user specifies conditions (user, device, id, URL, etc.) not present in the current filters + """, + ) + summary_title: str = Field( + description=""" + - The name of the summary that is expected to be generated from the user's `session_summarization_query` and/or `current_filters` (if present). + - The name should describe, in 3-7 words, which sessions are to be summarized. + - This won't be used for any search or filtering, only to properly label the generated summary. + - Examples: + * If `should_use_current_filters` is `false`, then the `summary_title` should be generated based on the `session_summarization_query`: + - query: "I want to watch all the sessions of user `user@example.com` in the last 30 days no matter how long" -> name: "Sessions of the user user@example.com (last 30 days)" + - query: "summarize my last 100 session recordings" -> name: "Last 100 sessions" + - and similar + * If `should_use_current_filters` is `true`, then the `summary_title` should be generated based on the current filters in the context (if present): + - filters: "{"key":"$os","value":["Mac OS X"],"operator":"exact","type":"event"}" -> name: "MacOS users" + - filters: "{"date_from": "-7d", "filter_test_accounts": True}" -> name: "All sessions (last 7 days)" + - and similar + * If there's not enough context to generate the summary name, keep it an empty string ("") + """ + ) + + +class SessionSummarizationTool(MaxTool): + name: Literal["session_summarization"] = "session_summarization" + description: str = SESSION_SUMMARIZATION_TOOL_PROMPT + thinking_message: str = "Summarizing session recordings" + context_prompt_template: str = "Summarizes session recordings based on the user's query and current filters" + args_schema: type[BaseModel] = SessionSummarizationToolArgs + show_tool_call_message: bool = False + + async def _arun_impl( + self, session_summarization_query: str, should_use_current_filters: bool, summary_title: str + ) -> tuple[str, ToolMessagesArtifact | None]: + node = SessionSummarizationNode(self._team, self._user) + chain: RunnableLambda[AssistantState, PartialAssistantState | None] = RunnableLambda(node) + copied_state = self._state.model_copy( + deep=True, + update={ + "root_tool_call_id": self.tool_call_id, + "session_summarization_query": session_summarization_query, + "should_use_current_filters": should_use_current_filters,
"summary_title": summary_title, + }, + ) + result = await chain.ainvoke(copied_state) + if not result or not result.messages: + return "Session summarization failed", None + return "", ToolMessagesArtifact(messages=result.messages) diff --git a/ee/hogai/graph/root/tools/test/test_create_and_query_insight.py b/ee/hogai/graph/root/tools/test/test_create_and_query_insight.py new file mode 100644 index 0000000000..c2926d53db --- /dev/null +++ b/ee/hogai/graph/root/tools/test/test_create_and_query_insight.py @@ -0,0 +1,253 @@ +from typing import Any + +from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest +from unittest.mock import AsyncMock, patch + +from langchain_core.runnables import RunnableConfig + +from posthog.schema import ( + AssistantMessage, + AssistantTool, + AssistantToolCallMessage, + AssistantTrendsQuery, + VisualizationMessage, +) + +from ee.hogai.context.context import AssistantContextManager +from ee.hogai.graph.root.tools.create_and_query_insight import ( + INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT, + CreateAndQueryInsightTool, +) +from ee.hogai.graph.schema_generator.nodes import SchemaGenerationException +from ee.hogai.utils.types import AssistantState +from ee.hogai.utils.types.base import NodePath +from ee.models.assistant import Conversation + + +class TestCreateAndQueryInsightTool(ClickhouseTestMixin, NonAtomicBaseTest): + CLASS_DATA_LEVEL_SETUP = False + + def setUp(self): + super().setUp() + self.conversation = Conversation.objects.create(team=self.team, user=self.user) + self.tool_call_id = "test_tool_call_id" + + def _create_tool( + self, state: AssistantState | None = None, contextual_tools: dict[str, dict[str, Any]] | None = None + ): + """Helper to create tool instance with optional state and contextual tools.""" + if state is None: + state = AssistantState(messages=[]) + + config: RunnableConfig = RunnableConfig() + if contextual_tools: + config = RunnableConfig(configurable={"contextual_tools": contextual_tools}) + + context_manager = AssistantContextManager(team=self.team, user=self.user, config=config) + + return CreateAndQueryInsightTool( + team=self.team, + user=self.user, + state=state, + context_manager=context_manager, + node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),), + ) + + async def test_successful_insight_creation_returns_messages(self): + """Test successful insight creation returns visualization and tool call messages.""" + tool = self._create_tool() + + query = AssistantTrendsQuery(series=[]) + viz_message = VisualizationMessage(query="test query", answer=query, plan="test plan") + tool_call_message = AssistantToolCallMessage(content="Results are here", tool_call_id=self.tool_call_id) + + mock_state = AssistantState(messages=[viz_message, tool_call_message]) + + with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile: + mock_graph = AsyncMock() + mock_graph.ainvoke = AsyncMock(return_value=mock_state.model_dump()) + mock_compile.return_value = mock_graph + + result_text, artifact = await tool._arun_impl(query_description="test description") + + self.assertEqual(result_text, "") + self.assertIsNotNone(artifact) + self.assertEqual(len(artifact.messages), 2) + self.assertIsInstance(artifact.messages[0], VisualizationMessage) + self.assertIsInstance(artifact.messages[1], AssistantToolCallMessage) + + async def test_schema_generation_exception_returns_formatted_error(self): + """Test SchemaGenerationException is caught and returns formatted error 
message.""" + tool = self._create_tool() + + exception = SchemaGenerationException( + llm_output="Invalid query structure", validation_message="Missing required field: series" + ) + + with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile: + mock_graph = AsyncMock() + mock_graph.ainvoke = AsyncMock(side_effect=exception) + mock_compile.return_value = mock_graph + + result_text, artifact = await tool._arun_impl(query_description="test description") + + self.assertIsNone(artifact) + self.assertIn("Invalid query structure", result_text) + self.assertIn("Missing required field: series", result_text) + self.assertIn(INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT, result_text) + + async def test_invalid_tool_call_message_type_returns_error(self): + """Test when the last message is not AssistantToolCallMessage, returns error.""" + tool = self._create_tool() + + viz_message = VisualizationMessage(query="test query", answer=AssistantTrendsQuery(series=[]), plan="test plan") + # Last message is AssistantMessage instead of AssistantToolCallMessage + invalid_message = AssistantMessage(content="Not a tool call message") + + mock_state = AssistantState(messages=[viz_message, invalid_message]) + + with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile: + mock_graph = AsyncMock() + mock_graph.ainvoke = AsyncMock(return_value=mock_state.model_dump()) + mock_compile.return_value = mock_graph + + result_text, artifact = await tool._arun_impl(query_description="test description") + + self.assertIsNone(artifact) + self.assertIn("unknown error", result_text) + self.assertIn(INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT, result_text) + + async def test_human_feedback_requested_returns_only_tool_call_message(self): + """Test when visualization message is not present, returns only tool call message.""" + tool = self._create_tool() + + # When agent requests human feedback, there's no VisualizationMessage + some_message = AssistantMessage(content="I need help with this query") + tool_call_message = AssistantToolCallMessage(content="Need clarification", tool_call_id=self.tool_call_id) + + mock_state = AssistantState(messages=[some_message, tool_call_message]) + + with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile: + mock_graph = AsyncMock() + mock_graph.ainvoke = AsyncMock(return_value=mock_state.model_dump()) + mock_compile.return_value = mock_graph + + result_text, artifact = await tool._arun_impl(query_description="test description") + + self.assertEqual(result_text, "") + self.assertIsNotNone(artifact) + self.assertEqual(len(artifact.messages), 1) + self.assertIsInstance(artifact.messages[0], AssistantToolCallMessage) + self.assertEqual(artifact.messages[0].content, "Need clarification") + + async def test_editing_mode_adds_ui_payload(self): + """Test that in editing mode, UI payload is added to tool call message.""" + # Create tool with contextual tool available + tool = self._create_tool(contextual_tools={AssistantTool.CREATE_AND_QUERY_INSIGHT.value: {}}) + + query = AssistantTrendsQuery(series=[]) + viz_message = VisualizationMessage(query="test query", answer=query, plan="test plan") + tool_call_message = AssistantToolCallMessage(content="Results are here", tool_call_id=self.tool_call_id) + + mock_state = AssistantState(messages=[viz_message, tool_call_message]) + + with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile: + mock_graph = 
AsyncMock() + mock_graph.ainvoke = AsyncMock(return_value=mock_state.model_dump()) + mock_compile.return_value = mock_graph + + result_text, artifact = await tool._arun_impl(query_description="test description") + + self.assertEqual(result_text, "") + self.assertIsNotNone(artifact) + self.assertEqual(len(artifact.messages), 2) + + # Check that UI payload was added to tool call message + returned_tool_call_message = artifact.messages[1] + self.assertIsInstance(returned_tool_call_message, AssistantToolCallMessage) + self.assertIsNotNone(returned_tool_call_message.ui_payload) + self.assertIn("create_and_query_insight", returned_tool_call_message.ui_payload) + self.assertEqual( + returned_tool_call_message.ui_payload["create_and_query_insight"], query.model_dump(exclude_none=True) + ) + + async def test_non_editing_mode_no_ui_payload(self): + """Test that in non-editing mode, no UI payload is added to tool call message.""" + # Create tool without contextual tools (non-editing mode) + tool = self._create_tool(contextual_tools={}) + + query = AssistantTrendsQuery(series=[]) + viz_message = VisualizationMessage(query="test query", answer=query, plan="test plan") + tool_call_message = AssistantToolCallMessage(content="Results are here", tool_call_id=self.tool_call_id) + + mock_state = AssistantState(messages=[viz_message, tool_call_message]) + + with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile: + mock_graph = AsyncMock() + mock_graph.ainvoke = AsyncMock(return_value=mock_state.model_dump()) + mock_compile.return_value = mock_graph + + result_text, artifact = await tool._arun_impl(query_description="test description") + + self.assertEqual(result_text, "") + self.assertIsNotNone(artifact) + self.assertEqual(len(artifact.messages), 2) + + # Check that the original tool call message is returned without modification + returned_tool_call_message = artifact.messages[1] + self.assertIsInstance(returned_tool_call_message, AssistantToolCallMessage) + # In non-editing mode, the original message is returned as-is + self.assertEqual(returned_tool_call_message, tool_call_message) + + async def test_state_updates_include_tool_call_metadata(self): + """Test that the state passed to graph includes root_tool_call_id and root_tool_insight_plan.""" + initial_state = AssistantState(messages=[AssistantMessage(content="initial")]) + tool = self._create_tool(state=initial_state) + + query = AssistantTrendsQuery(series=[]) + viz_message = VisualizationMessage(query="test query", answer=query, plan="test plan") + tool_call_message = AssistantToolCallMessage(content="Results", tool_call_id=self.tool_call_id) + mock_state = AssistantState(messages=[viz_message, tool_call_message]) + + invoked_state = None + + async def capture_invoked_state(state): + nonlocal invoked_state + invoked_state = state + return mock_state.model_dump() + + with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile: + mock_graph = AsyncMock() + mock_graph.ainvoke = AsyncMock(side_effect=capture_invoked_state) + mock_compile.return_value = mock_graph + + await tool._arun_impl(query_description="my test query") + + # Verify the state passed to ainvoke has the correct metadata + self.assertIsNotNone(invoked_state) + validated_state = AssistantState.model_validate(invoked_state) + self.assertEqual(validated_state.root_tool_call_id, self.tool_call_id) + self.assertEqual(validated_state.root_tool_insight_plan, "my test query") + # Original message should still be 
there + self.assertEqual(len(validated_state.messages), 1) + assert isinstance(validated_state.messages[0], AssistantMessage) + self.assertEqual(validated_state.messages[0].content, "initial") + + async def test_is_editing_mode_classmethod(self): + """Test the is_editing_mode class method correctly detects editing mode.""" + # Test with editing mode enabled + config_editing: RunnableConfig = RunnableConfig( + configurable={"contextual_tools": {AssistantTool.CREATE_AND_QUERY_INSIGHT.value: {}}} + ) + context_manager_editing = AssistantContextManager(team=self.team, user=self.user, config=config_editing) + self.assertTrue(CreateAndQueryInsightTool.is_editing_mode(context_manager_editing)) + + # Test with editing mode disabled + config_not_editing = RunnableConfig(configurable={"contextual_tools": {}}) + context_manager_not_editing = AssistantContextManager(team=self.team, user=self.user, config=config_not_editing) + self.assertFalse(CreateAndQueryInsightTool.is_editing_mode(context_manager_not_editing)) + + # Test with other contextual tools but not create_and_query_insight + config_other = RunnableConfig(configurable={"contextual_tools": {"some_other_tool": {}}}) + context_manager_other = AssistantContextManager(team=self.team, user=self.user, config=config_other) + self.assertFalse(CreateAndQueryInsightTool.is_editing_mode(context_manager_other)) diff --git a/ee/hogai/graph/root/tools/test/test_create_dashboard.py b/ee/hogai/graph/root/tools/test/test_create_dashboard.py new file mode 100644 index 0000000000..39eb8e3150 --- /dev/null +++ b/ee/hogai/graph/root/tools/test/test_create_dashboard.py @@ -0,0 +1,248 @@ +from typing import cast +from uuid import uuid4 + +from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest +from unittest.mock import AsyncMock, MagicMock, patch + +from posthog.schema import AssistantMessage + +from ee.hogai.context.context import AssistantContextManager +from ee.hogai.graph.root.tools.create_dashboard import CreateDashboardTool +from ee.hogai.utils.types import AssistantState, InsightQuery, PartialAssistantState +from ee.hogai.utils.types.base import NodePath + + +class TestCreateDashboardTool(ClickhouseTestMixin, NonAtomicBaseTest): + CLASS_DATA_LEVEL_SETUP = False + + def setUp(self): + super().setUp() + self.tool_call_id = str(uuid4()) + self.state = AssistantState(messages=[], root_tool_call_id=self.tool_call_id) + self.context_manager = AssistantContextManager(self.team, self.user, {}) + self.tool = CreateDashboardTool( + team=self.team, + user=self.user, + state=self.state, + context_manager=self.context_manager, + node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),), + ) + + async def test_execute_calls_dashboard_creation_node(self): + mock_node_instance = MagicMock() + mock_result = PartialAssistantState( + messages=[AssistantMessage(content="Dashboard created successfully with 3 insights")] + ) + + with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode", return_value=mock_node_instance): + with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda") as mock_runnable: + mock_chain = MagicMock() + mock_chain.ainvoke = AsyncMock(return_value=mock_result) + mock_runnable.return_value = mock_chain + + insight_queries = [ + InsightQuery(name="Pageviews", description="Show pageviews for last 7 days"), + InsightQuery(name="User signups", description="Show user signups funnel"), + ] + + result, artifact = await self.tool._arun_impl( + search_insights_queries=insight_queries, + 
dashboard_name="Marketing Dashboard", + ) + + self.assertEqual(result, "") + self.assertIsNotNone(artifact) + assert artifact is not None + self.assertEqual(len(artifact.messages), 1) + message = cast(AssistantMessage, artifact.messages[0]) + self.assertEqual(message.content, "Dashboard created successfully with 3 insights") + + async def test_execute_updates_state_with_all_parameters(self): + mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test response")]) + + insight_queries = [ + InsightQuery(name="Revenue Trends", description="Monthly revenue trends for Q4"), + InsightQuery(name="Churn Rate", description="Customer churn rate by cohort"), + InsightQuery(name="NPS Score", description="Net Promoter Score over time"), + ] + + async def mock_ainvoke(state): + self.assertEqual(state.search_insights_queries, insight_queries) + self.assertEqual(state.dashboard_name, "Executive Summary Q4") + self.assertEqual(state.root_tool_call_id, self.tool_call_id) + return mock_result + + mock_chain = MagicMock() + mock_chain.ainvoke = mock_ainvoke + + with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"): + with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain): + await self.tool._arun_impl( + search_insights_queries=insight_queries, + dashboard_name="Executive Summary Q4", + ) + + async def test_execute_with_single_insight_query(self): + mock_result = PartialAssistantState(messages=[AssistantMessage(content="Dashboard with one insight created")]) + + insight_queries = [InsightQuery(name="Daily Active Users", description="Count of daily active users")] + + async def mock_ainvoke(state): + self.assertEqual(len(state.search_insights_queries), 1) + self.assertEqual(state.search_insights_queries[0].name, "Daily Active Users") + self.assertEqual(state.search_insights_queries[0].description, "Count of daily active users") + self.assertEqual(state.dashboard_name, "User Activity Dashboard") + return mock_result + + mock_chain = MagicMock() + mock_chain.ainvoke = mock_ainvoke + + with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"): + with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain): + result, artifact = await self.tool._arun_impl( + search_insights_queries=insight_queries, + dashboard_name="User Activity Dashboard", + ) + + self.assertEqual(result, "") + self.assertIsNotNone(artifact) + + async def test_execute_with_many_insight_queries(self): + mock_result = PartialAssistantState(messages=[AssistantMessage(content="Large dashboard created")]) + + insight_queries = [ + InsightQuery(name=f"Insight {i}", description=f"Description for insight {i}") for i in range(10) + ] + + async def mock_ainvoke(state): + self.assertEqual(len(state.search_insights_queries), 10) + self.assertEqual(state.search_insights_queries[0].name, "Insight 0") + self.assertEqual(state.search_insights_queries[9].name, "Insight 9") + self.assertEqual(state.dashboard_name, "Comprehensive Dashboard") + return mock_result + + mock_chain = MagicMock() + mock_chain.ainvoke = mock_ainvoke + + with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"): + with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain): + result, artifact = await self.tool._arun_impl( + search_insights_queries=insight_queries, + dashboard_name="Comprehensive Dashboard", + ) + + self.assertEqual(result, "") + self.assertIsNotNone(artifact) + + async def 
test_execute_returns_failure_message_when_result_is_none(self): + async def mock_ainvoke(state): + return None + + mock_chain = MagicMock() + mock_chain.ainvoke = mock_ainvoke + + with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"): + with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain): + result, artifact = await self.tool._arun_impl( + search_insights_queries=[InsightQuery(name="Test", description="Test insight")], + dashboard_name="Test Dashboard", + ) + + self.assertEqual(result, "Dashboard creation failed") + self.assertIsNone(artifact) + + async def test_execute_returns_failure_message_when_result_has_no_messages(self): + mock_result = PartialAssistantState(messages=[]) + + async def mock_ainvoke(state): + return mock_result + + mock_chain = MagicMock() + mock_chain.ainvoke = mock_ainvoke + + with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"): + with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain): + result, artifact = await self.tool._arun_impl( + search_insights_queries=[InsightQuery(name="Test", description="Test insight")], + dashboard_name="Test Dashboard", + ) + + self.assertEqual(result, "Dashboard creation failed") + self.assertIsNone(artifact) + + async def test_execute_preserves_original_state(self): + """Test that the original state is not modified when creating the copied state""" + original_queries = [InsightQuery(name="Original", description="Original insight")] + original_state = AssistantState( + messages=[], + root_tool_call_id=self.tool_call_id, + search_insights_queries=original_queries, + dashboard_name="Original Dashboard", + ) + + tool = CreateDashboardTool( + team=self.team, + user=self.user, + state=original_state, + context_manager=self.context_manager, + node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),), + ) + + mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test")]) + + new_queries = [InsightQuery(name="New", description="New insight")] + + async def mock_ainvoke(state): + # Verify the new state has updated values + self.assertEqual(state.search_insights_queries, new_queries) + self.assertEqual(state.dashboard_name, "New Dashboard") + return mock_result + + mock_chain = MagicMock() + mock_chain.ainvoke = mock_ainvoke + + with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"): + with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain): + await tool._arun_impl( + search_insights_queries=new_queries, + dashboard_name="New Dashboard", + ) + + # Verify original state was not modified + self.assertEqual(original_state.search_insights_queries, original_queries) + self.assertEqual(original_state.dashboard_name, "Original Dashboard") + self.assertEqual(original_state.root_tool_call_id, self.tool_call_id) + + async def test_execute_with_complex_insight_descriptions(self): + """Test that complex insight descriptions with special characters are handled correctly""" + mock_result = PartialAssistantState(messages=[AssistantMessage(content="Dashboard created")]) + + insight_queries = [ + InsightQuery( + name="User Journey", + description='Show funnel from "Sign up" → "Create project" → "Invite team" for users with property email containing "@company.com"', + ), + InsightQuery( + name="Revenue (USD)", + description="Track total revenue in USD from 2024-01-01 to 2024-12-31, filtered by plan_type = 'premium' OR plan_type = 'enterprise'", + 
), + ] + + async def mock_ainvoke(state): + self.assertEqual(len(state.search_insights_queries), 2) + self.assertIn("@company.com", state.search_insights_queries[0].description) + self.assertIn("'premium'", state.search_insights_queries[1].description) + return mock_result + + mock_chain = MagicMock() + mock_chain.ainvoke = mock_ainvoke + + with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"): + with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain): + result, artifact = await self.tool._arun_impl( + search_insights_queries=insight_queries, + dashboard_name="Complex Dashboard", + ) + + self.assertEqual(result, "") + self.assertIsNotNone(artifact) diff --git a/ee/hogai/graph/root/tools/test/test_search.py b/ee/hogai/graph/root/tools/test/test_search.py index 4bbdb14b9f..5077b79eaa 100644 --- a/ee/hogai/graph/root/tools/test/test_search.py +++ b/ee/hogai/graph/root/tools/test/test_search.py @@ -1,27 +1,178 @@ -from posthog.test.base import BaseTest -from unittest.mock import patch +from typing import cast +from uuid import uuid4 + +from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest +from unittest.mock import AsyncMock, MagicMock, patch from django.test import override_settings from langchain_core import messages +from langchain_core.runnables import RunnableConfig +from posthog.schema import AssistantMessage + +from ee.hogai.context.context import AssistantContextManager from ee.hogai.graph.root.tools.search import ( DOC_ITEM_TEMPLATE, - DOCS_SEARCH_NO_RESULTS_TEMPLATE, DOCS_SEARCH_RESULTS_TEMPLATE, + EMPTY_DATABASE_ERROR_MESSAGE, + InkeepDocsSearchTool, + InsightSearchTool, SearchTool, ) from ee.hogai.utils.tests import FakeChatOpenAI +from ee.hogai.utils.types import AssistantState, PartialAssistantState +from ee.hogai.utils.types.base import NodePath -class TestSearchToolDocumentation(BaseTest): +class TestSearchTool(ClickhouseTestMixin, NonAtomicBaseTest): + CLASS_DATA_LEVEL_SETUP = False + def setUp(self): super().setUp() - self.tool = SearchTool(team=self.team, user=self.user, tool_call_id="test-tool-call-id") + self.tool_call_id = "test_tool_call_id" + self.state = AssistantState(messages=[], root_tool_call_id=str(uuid4())) + self.context_manager = AssistantContextManager(self.team, self.user, {}) + self.tool = SearchTool( + team=self.team, + user=self.user, + state=self.state, + context_manager=self.context_manager, + node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),), + ) + + async def test_run_docs_search_without_api_key(self): + with patch("ee.hogai.graph.root.tools.search.settings") as mock_settings: + mock_settings.INKEEP_API_KEY = None + result, artifact = await self.tool._arun_impl(kind="docs", query="How to use feature flags?") + self.assertEqual(result, "This tool is not available in this environment.") + self.assertIsNone(artifact) + + async def test_run_docs_search_with_api_key(self): + mock_docs_tool = MagicMock() + mock_docs_tool.execute = AsyncMock(return_value=("", MagicMock())) + + with ( + patch("ee.hogai.graph.root.tools.search.settings") as mock_settings, + patch("ee.hogai.graph.root.tools.search.InkeepDocsSearchTool", return_value=mock_docs_tool), + ): + mock_settings.INKEEP_API_KEY = "test-key" + result, artifact = await self.tool._arun_impl(kind="docs", query="How to use feature flags?") + + mock_docs_tool.execute.assert_called_once_with("How to use feature flags?", self.tool_call_id) + self.assertEqual(result, "") + self.assertIsNotNone(artifact) + 
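Both halves of this test file pivot on the branch order inside `SearchTool._arun_impl`: `docs` short-circuits first, `insights` falls back to the node-backed search while the FTS flag is off, unknown kinds are rejected, and everything else flows to full-text search. A compact sketch of that routing for orientation; the boolean parameter stands in for the organization-scoped `posthoganalytics.feature_enabled` check, and the entity set is illustrative rather than the real `FTSKind` enum:

```python
# Sketch of the branch order in SearchTool._arun_impl. FTS_ENTITIES is an
# illustrative subset, not the real FTSKind enum.
FTS_ENTITIES = {"dashboards", "cohorts", "actions", "insights", "all"}


def route(kind: str, insights_fts_enabled: bool) -> str:
    if kind == "docs":
        return "InkeepDocsSearchTool"  # docs always short-circuit first
    if kind == "insights" and not insights_fts_enabled:
        return "InsightSearchTool"  # node-backed legacy insight search
    if kind not in FTS_ENTITIES:
        return "invalid-kind error"  # rendered from INVALID_ENTITY_KIND_PROMPT
    return "EntitySearchTool"  # full-text search path


assert route("docs", insights_fts_enabled=False) == "InkeepDocsSearchTool"
assert route("insights", insights_fts_enabled=True) == "EntitySearchTool"
assert route("unknown", insights_fts_enabled=True) == "invalid-kind error"
```

Patching either the subtool class or the flag helper, as the surrounding tests do, isolates exactly one of these branches at a time.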
+ async def test_run_insights_search(self): + mock_insights_tool = MagicMock() + mock_insights_tool.execute = AsyncMock(return_value=("", MagicMock())) + + with patch("ee.hogai.graph.root.tools.search.InsightSearchTool", return_value=mock_insights_tool): + result, artifact = await self.tool._arun_impl(kind="insights", query="user signups") + + mock_insights_tool.execute.assert_called_once_with("user signups", self.tool_call_id) + self.assertEqual(result, "") + self.assertIsNotNone(artifact) + + async def test_run_unknown_kind(self): + result, artifact = await self.tool._arun_impl(kind="unknown", query="test") + self.assertEqual(result, "Invalid entity kind: unknown. Please provide a valid entity kind for the tool.") + self.assertIsNone(artifact) + + @patch("ee.hogai.graph.root.tools.search.EntitySearchTool.execute") + async def test_arun_impl_error_tracking_issues_returns_search_results(self, mock_execute): + mock_execute.return_value = "Search results for error tracking issues" + + result, artifact = await self.tool._arun_impl( + kind="error_tracking_issues", query="test error tracking issue query" + ) + + self.assertEqual(result, "Search results for error tracking issues") + self.assertIsNone(artifact) + mock_execute.assert_called_once_with("test error tracking issue query", "error_tracking_issues") + + @patch("ee.hogai.graph.root.tools.search.EntitySearchTool.execute") + @patch("ee.hogai.graph.root.tools.search.SearchTool._has_insights_fts_search_feature_flag") + async def test_arun_impl_insight_with_feature_flag_disabled( + self, mock_has_insights_fts_search_feature_flag, mock_execute + ): + mock_has_insights_fts_search_feature_flag.return_value = False + mock_execute.return_value = "Search results for insights" + + result, artifact = await self.tool._arun_impl(kind="insights", query="test insight query") + + self.assertEqual(result, "The user doesn't have any insights created yet.") + self.assertIsNone(artifact) + mock_execute.assert_not_called() + + @patch("ee.hogai.graph.root.tools.search.EntitySearchTool.execute") + @patch("ee.hogai.graph.root.tools.search.SearchTool._has_insights_fts_search_feature_flag") + async def test_arun_impl_insight_with_feature_flag_enabled( + self, mock_has_insights_fts_search_feature_flag, mock_execute + ): + mock_has_insights_fts_search_feature_flag.return_value = True + mock_execute.return_value = "Search results for insights" + + result, artifact = await self.tool._arun_impl(kind="insights", query="test insight query") + + self.assertEqual(result, "Search results for insights") + self.assertIsNone(artifact) + mock_execute.assert_called_once_with("test insight query", "insights") + + +class TestInkeepDocsSearchTool(ClickhouseTestMixin, NonAtomicBaseTest): + CLASS_DATA_LEVEL_SETUP = False + + def setUp(self): + super().setUp() + self.tool_call_id = str(uuid4()) + self.state = AssistantState(messages=[], root_tool_call_id=self.tool_call_id) + self.context_manager = AssistantContextManager(self.team, self.user, {}) + self.tool = InkeepDocsSearchTool( + team=self.team, + user=self.user, + state=self.state, + config=RunnableConfig(configurable={}), + context_manager=self.context_manager, + ) + + async def test_execute_calls_inkeep_docs_node(self): + mock_node_instance = MagicMock() + mock_result = PartialAssistantState(messages=[AssistantMessage(content="Here is the answer from docs")]) + + with patch("ee.hogai.graph.inkeep_docs.nodes.InkeepDocsNode", return_value=mock_node_instance): + with patch("ee.hogai.graph.root.tools.search.RunnableLambda") as
mock_runnable: + mock_chain = MagicMock() + mock_chain.ainvoke = AsyncMock(return_value=mock_result) + mock_runnable.return_value = mock_chain + + result, artifact = await self.tool.execute("How to track events?", "test-tool-call-id") + + self.assertEqual(result, "") + self.assertIsNotNone(artifact) + assert artifact is not None + self.assertEqual(len(artifact.messages), 1) + message = cast(AssistantMessage, artifact.messages[0]) + self.assertEqual(message.content, "Here is the answer from docs") + + async def test_execute_updates_state_with_tool_call_id(self): + mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test response")]) + + async def mock_ainvoke(state): + self.assertEqual(state.root_tool_call_id, "custom-tool-call-id") + return mock_result + + mock_chain = MagicMock() + mock_chain.ainvoke = mock_ainvoke + + with patch("ee.hogai.graph.inkeep_docs.nodes.InkeepDocsNode"): + with patch("ee.hogai.graph.root.tools.search.RunnableLambda", return_value=mock_chain): + await self.tool.execute("test query", "custom-tool-call-id") @override_settings(INKEEP_API_KEY="test-inkeep-key") @patch("ee.hogai.graph.root.tools.search.ChatOpenAI") - async def test_search_docs_with_successful_results(self, mock_llm_class): + @patch("ee.hogai.graph.root.tools.search.InkeepDocsSearchTool._has_rag_docs_search_feature_flag", return_value=True) + async def test_search_docs_with_successful_results(self, mock_has_rag_docs_search_feature_flag, mock_llm_class): response_json = """{ "content": [ { @@ -47,7 +198,7 @@ class TestSearchToolDocumentation(BaseTest): fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content=response_json)]) mock_llm_class.return_value = fake_llm - result = await self.tool._search_docs("how to use feature") + result, _ = await self.tool.execute("how to use feature", "test-tool-call-id") expected_doc_1 = DOC_ITEM_TEMPLATE.format( title="Feature Documentation", @@ -68,196 +219,76 @@ class TestSearchToolDocumentation(BaseTest): self.assertEqual(mock_llm_class.call_args.kwargs["api_key"], "test-inkeep-key") self.assertEqual(mock_llm_class.call_args.kwargs["streaming"], False) - @override_settings(INKEEP_API_KEY="test-inkeep-key") - @patch("ee.hogai.graph.root.tools.search.ChatOpenAI") - async def test_search_docs_with_no_results(self, mock_llm_class): - fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content="{}")]) - mock_llm_class.return_value = fake_llm - result = await self.tool._search_docs("nonexistent feature") +class TestInsightSearchTool(ClickhouseTestMixin, NonAtomicBaseTest): + CLASS_DATA_LEVEL_SETUP = False - self.assertEqual(result, DOCS_SEARCH_NO_RESULTS_TEMPLATE) - - @override_settings(INKEEP_API_KEY="test-inkeep-key") - @patch("ee.hogai.graph.root.tools.search.ChatOpenAI") - async def test_search_docs_with_empty_content(self, mock_llm_class): - fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content='{"content": []}')]) - mock_llm_class.return_value = fake_llm - - result = await self.tool._search_docs("query") - - self.assertEqual(result, DOCS_SEARCH_NO_RESULTS_TEMPLATE) - - @override_settings(INKEEP_API_KEY="test-inkeep-key") - @patch("ee.hogai.graph.root.tools.search.ChatOpenAI") - async def test_search_docs_filters_non_document_types(self, mock_llm_class): - response_json = """{ - "content": [ - { - "type": "snippet", - "record_type": "code", - "url": "https://posthog.com/code", - "title": "Code Snippet", - "source": {"type": "text", "content": [{"type": "text", "text": "Code example"}]} - }, - { - "type": "answer", - 
"record_type": "answer", - "url": "https://posthog.com/answer", - "title": "Answer", - "source": {"type": "text", "content": [{"type": "text", "text": "Answer text"}]} - } - ] - }""" - - fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content=response_json)]) - mock_llm_class.return_value = fake_llm - - result = await self.tool._search_docs("query") - - self.assertEqual(result, DOCS_SEARCH_NO_RESULTS_TEMPLATE) - - @override_settings(INKEEP_API_KEY="test-inkeep-key") - @patch("ee.hogai.graph.root.tools.search.ChatOpenAI") - async def test_search_docs_handles_empty_source_content(self, mock_llm_class): - response_json = """{ - "content": [ - { - "type": "document", - "record_type": "page", - "url": "https://posthog.com/docs/feature", - "title": "Feature Documentation", - "source": {"type": "text", "content": []} - } - ] - }""" - - fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content=response_json)]) - mock_llm_class.return_value = fake_llm - - result = await self.tool._search_docs("query") - - expected_doc = DOC_ITEM_TEMPLATE.format( - title="Feature Documentation", url="https://posthog.com/docs/feature", text="" - ) - expected_result = DOCS_SEARCH_RESULTS_TEMPLATE.format(count=1, docs=expected_doc) - - self.assertEqual(result, expected_result) - - @override_settings(INKEEP_API_KEY="test-inkeep-key") - @patch("ee.hogai.graph.root.tools.search.ChatOpenAI") - async def test_search_docs_handles_mixed_document_types(self, mock_llm_class): - response_json = """{ - "content": [ - { - "type": "document", - "record_type": "page", - "url": "https://posthog.com/docs/valid", - "title": "Valid Doc", - "source": {"type": "text", "content": [{"type": "text", "text": "Valid content"}]} - }, - { - "type": "snippet", - "record_type": "code", - "url": "https://posthog.com/code", - "title": "Code Snippet", - "source": {"type": "text", "content": [{"type": "text", "text": "Code"}]} - }, - { - "type": "document", - "record_type": "guide", - "url": "https://posthog.com/docs/another", - "title": "Another Valid Doc", - "source": {"type": "text", "content": [{"type": "text", "text": "More content"}]} - } - ] - }""" - - fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content=response_json)]) - mock_llm_class.return_value = fake_llm - - result = await self.tool._search_docs("query") - - expected_doc_1 = DOC_ITEM_TEMPLATE.format( - title="Valid Doc", url="https://posthog.com/docs/valid", text="Valid content" - ) - expected_doc_2 = DOC_ITEM_TEMPLATE.format( - title="Another Valid Doc", url="https://posthog.com/docs/another", text="More content" - ) - expected_result = DOCS_SEARCH_RESULTS_TEMPLATE.format( - count=2, docs=f"{expected_doc_1}\n\n---\n\n{expected_doc_2}" + def setUp(self): + super().setUp() + self.tool_call_id = str(uuid4()) + self.state = AssistantState(messages=[], root_tool_call_id=self.tool_call_id) + self.context_manager = AssistantContextManager(self.team, self.user, {}) + self.tool = InsightSearchTool( + team=self.team, + user=self.user, + state=self.state, + config=RunnableConfig(configurable={}), + context_manager=self.context_manager, ) - self.assertEqual(result, expected_result) + async def test_execute_calls_insight_search_node(self): + mock_node_instance = MagicMock() + mock_result = PartialAssistantState(messages=[AssistantMessage(content="Found 3 insights matching your query")]) - @override_settings(INKEEP_API_KEY=None) - async def test_arun_impl_docs_without_api_key(self): - result, artifact = await self.tool._arun_impl(kind="docs", query="test query") + with 
patch("ee.hogai.graph.insights.nodes.InsightSearchNode", return_value=mock_node_instance): + with patch("ee.hogai.graph.root.tools.search.RunnableLambda") as mock_runnable: + mock_chain = MagicMock() + mock_chain.ainvoke = AsyncMock(return_value=mock_result) + mock_runnable.return_value = mock_chain - self.assertEqual(result, "This tool is not available in this environment.") - self.assertIsNone(artifact) + result, artifact = await self.tool.execute("user signups by week", self.tool_call_id) - @override_settings(INKEEP_API_KEY="test-key") - @patch("ee.hogai.graph.root.tools.search.posthoganalytics.feature_enabled") - async def test_arun_impl_docs_with_feature_flag_disabled(self, mock_feature_enabled): - mock_feature_enabled.return_value = False + self.assertEqual(result, "") + self.assertIsNotNone(artifact) + assert artifact is not None + self.assertEqual(len(artifact.messages), 1) + message = cast(AssistantMessage, artifact.messages[0]) + self.assertEqual(message.content, "Found 3 insights matching your query") - result, artifact = await self.tool._arun_impl(kind="docs", query="test query") + async def test_execute_updates_state_with_search_query(self): + mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test response")]) - self.assertEqual(result, "Search tool executed") - self.assertEqual(artifact, {"kind": "docs", "query": "test query"}) + async def mock_ainvoke(state): + self.assertEqual(state.search_insights_query, "custom search query") + self.assertEqual(state.root_tool_call_id, self.tool_call_id) + return mock_result - @override_settings(INKEEP_API_KEY="test-key") - @patch("ee.hogai.graph.root.tools.search.posthoganalytics.feature_enabled") - @patch.object(SearchTool, "_search_docs") - async def test_arun_impl_docs_with_feature_flag_enabled(self, mock_search_docs, mock_feature_enabled): - mock_feature_enabled.return_value = True - mock_search_docs.return_value = "Search results" + mock_chain = MagicMock() + mock_chain.ainvoke = mock_ainvoke - result, artifact = await self.tool._arun_impl(kind="docs", query="test query") + with patch("ee.hogai.graph.insights.nodes.InsightSearchNode"): + with patch("ee.hogai.graph.root.tools.search.RunnableLambda", return_value=mock_chain): + await self.tool.execute("custom search query", self.tool_call_id) - self.assertEqual(result, "Search results") - self.assertIsNone(artifact) - mock_search_docs.assert_called_once_with("test query") + async def test_execute_handles_no_insights_exception(self): + from ee.hogai.graph.insights.nodes import NoInsightsException - async def test_arun_impl_insights_returns_routing_data(self): - result, artifact = await self.tool._arun_impl(kind="insights", query="test insight query") + with patch("ee.hogai.graph.insights.nodes.InsightSearchNode", side_effect=NoInsightsException()): + result, artifact = await self.tool.execute("user signups", self.tool_call_id) - self.assertEqual(result, "Search tool executed") - self.assertEqual(artifact, {"kind": "insights", "query": "test insight query"}) + self.assertEqual(result, EMPTY_DATABASE_ERROR_MESSAGE) + self.assertIsNone(artifact) - @patch("ee.hogai.graph.root.tools.search.EntitySearchToolkit.execute") - async def test_arun_impl_error_tracking_issues_returns_routing_data(self, mock_execute): - mock_execute.return_value = "Search results for error tracking issues" + async def test_execute_returns_none_artifact_when_result_is_none(self): + async def mock_ainvoke(state): + return None - result, artifact = await self.tool._arun_impl( - 
kind="error_tracking_issues", query="test error tracking issue query" - ) + mock_chain = MagicMock() + mock_chain.ainvoke = mock_ainvoke - self.assertEqual(result, "Search results for error tracking issues") - self.assertIsNone(artifact) - mock_execute.assert_called_once_with("test error tracking issue query", "error_tracking_issues") + with patch("ee.hogai.graph.insights.nodes.InsightSearchNode"): + with patch("ee.hogai.graph.root.tools.search.RunnableLambda", return_value=mock_chain): + result, artifact = await self.tool.execute("test query", self.tool_call_id) - @patch("ee.hogai.graph.root.tools.search.EntitySearchToolkit.execute") - @patch("ee.hogai.graph.root.tools.search.SearchTool._has_fts_search_feature_flag") - async def test_arun_impl_insight_with_feature_flag_disabled(self, mock_has_fts_search_feature_flag, mock_execute): - mock_has_fts_search_feature_flag.return_value = False - mock_execute.return_value = "Search results for insights" - - result, artifact = await self.tool._arun_impl(kind="insights", query="test insight query") - - self.assertEqual(result, "Search tool executed") - self.assertEqual(artifact, {"kind": "insights", "query": "test insight query"}) - mock_execute.assert_not_called() - - @patch("ee.hogai.graph.root.tools.search.EntitySearchToolkit.execute") - @patch("ee.hogai.graph.root.tools.search.SearchTool._has_fts_search_feature_flag") - async def test_arun_impl_insight_with_feature_flag_enabled(self, mock_has_fts_search_feature_flag, mock_execute): - mock_has_fts_search_feature_flag.return_value = True - mock_execute.return_value = "Search results for insights" - - result, artifact = await self.tool._arun_impl(kind="insights", query="test insight query") - - self.assertEqual(result, "Search results for insights") - self.assertIsNone(artifact) - mock_execute.assert_called_once_with("test insight query", "insights") + self.assertEqual(result, "") + self.assertIsNone(artifact) diff --git a/ee/hogai/graph/root/tools/test/test_session_summarization.py b/ee/hogai/graph/root/tools/test/test_session_summarization.py new file mode 100644 index 0000000000..ff6d4965c6 --- /dev/null +++ b/ee/hogai/graph/root/tools/test/test_session_summarization.py @@ -0,0 +1,193 @@ +from typing import cast +from uuid import uuid4 + +from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest +from unittest.mock import AsyncMock, MagicMock, patch + +from posthog.schema import AssistantMessage + +from ee.hogai.context.context import AssistantContextManager +from ee.hogai.graph.root.tools.session_summarization import SessionSummarizationTool +from ee.hogai.utils.types import AssistantState, PartialAssistantState +from ee.hogai.utils.types.base import NodePath + + +class TestSessionSummarizationTool(ClickhouseTestMixin, NonAtomicBaseTest): + CLASS_DATA_LEVEL_SETUP = False + + def setUp(self): + super().setUp() + self.tool_call_id = str(uuid4()) + self.state = AssistantState(messages=[], root_tool_call_id=self.tool_call_id) + self.context_manager = AssistantContextManager(self.team, self.user, {}) + self.tool = SessionSummarizationTool( + team=self.team, + user=self.user, + state=self.state, + context_manager=self.context_manager, + node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),), + ) + + async def test_execute_calls_session_summarization_node(self): + mock_node_instance = MagicMock() + mock_result = PartialAssistantState( + messages=[AssistantMessage(content="Session summary: 10 sessions analyzed with 5 key patterns found")] + ) + + with 
patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode", return_value=mock_node_instance): + with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda") as mock_runnable: + mock_chain = MagicMock() + mock_chain.ainvoke = AsyncMock(return_value=mock_result) + mock_runnable.return_value = mock_chain + + result, artifact = await self.tool._arun_impl( + session_summarization_query="summarize all sessions from yesterday", + should_use_current_filters=False, + summary_title="All sessions from yesterday", + ) + + self.assertEqual(result, "") + self.assertIsNotNone(artifact) + assert artifact is not None + self.assertEqual(len(artifact.messages), 1) + message = cast(AssistantMessage, artifact.messages[0]) + self.assertEqual(message.content, "Session summary: 10 sessions analyzed with 5 key patterns found") + + async def test_execute_updates_state_with_all_parameters(self): + mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test response")]) + + async def mock_ainvoke(state): + self.assertEqual(state.session_summarization_query, "analyze mobile user sessions") + self.assertEqual(state.should_use_current_filters, True) + self.assertEqual(state.summary_title, "Mobile user sessions") + self.assertEqual(state.root_tool_call_id, self.tool_call_id) + return mock_result + + mock_chain = MagicMock() + mock_chain.ainvoke = mock_ainvoke + + with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"): + with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain): + await self.tool._arun_impl( + session_summarization_query="analyze mobile user sessions", + should_use_current_filters=True, + summary_title="Mobile user sessions", + ) + + async def test_execute_with_should_use_current_filters_false(self): + mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test response")]) + + async def mock_ainvoke(state): + self.assertEqual(state.should_use_current_filters, False) + self.assertEqual(state.session_summarization_query, "watch last 300 session recordings") + self.assertEqual(state.summary_title, "Last 300 sessions") + return mock_result + + mock_chain = MagicMock() + mock_chain.ainvoke = mock_ainvoke + + with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"): + with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain): + await self.tool._arun_impl( + session_summarization_query="watch last 300 session recordings", + should_use_current_filters=False, + summary_title="Last 300 sessions", + ) + + async def test_execute_returns_failure_message_when_result_is_none(self): + async def mock_ainvoke(state): + return None + + mock_chain = MagicMock() + mock_chain.ainvoke = mock_ainvoke + + with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"): + with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain): + result, artifact = await self.tool._arun_impl( + session_summarization_query="test query", + should_use_current_filters=False, + summary_title="Test", + ) + + self.assertEqual(result, "Session summarization failed") + self.assertIsNone(artifact) + + async def test_execute_returns_failure_message_when_result_has_no_messages(self): + mock_result = PartialAssistantState(messages=[]) + + async def mock_ainvoke(state): + return mock_result + + mock_chain = MagicMock() + mock_chain.ainvoke = mock_ainvoke + + with 
patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"): + with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain): + result, artifact = await self.tool._arun_impl( + session_summarization_query="test query", + should_use_current_filters=False, + summary_title="Test", + ) + + self.assertEqual(result, "Session summarization failed") + self.assertIsNone(artifact) + + async def test_execute_with_empty_summary_title(self): + mock_result = PartialAssistantState(messages=[AssistantMessage(content="Summary completed")]) + + async def mock_ainvoke(state): + self.assertEqual(state.summary_title, "") + return mock_result + + mock_chain = MagicMock() + mock_chain.ainvoke = mock_ainvoke + + with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"): + with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain): + result, artifact = await self.tool._arun_impl( + session_summarization_query="summarize sessions", + should_use_current_filters=False, + summary_title="", + ) + + self.assertEqual(result, "") + self.assertIsNotNone(artifact) + + async def test_execute_preserves_original_state(self): + """Test that the original state is not modified when creating the copied state""" + original_query = "original query" + original_state = AssistantState( + messages=[], + root_tool_call_id=self.tool_call_id, + session_summarization_query=original_query, + ) + + tool = SessionSummarizationTool( + team=self.team, + user=self.user, + state=original_state, + context_manager=self.context_manager, + node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),), + ) + + mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test")]) + + async def mock_ainvoke(state): + # Verify the new state has updated values + self.assertEqual(state.session_summarization_query, "new query") + return mock_result + + mock_chain = MagicMock() + mock_chain.ainvoke = mock_ainvoke + + with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"): + with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain): + await tool._arun_impl( + session_summarization_query="new query", + should_use_current_filters=True, + summary_title="New Summary", + ) + + # Verify original state was not modified + self.assertEqual(original_state.session_summarization_query, original_query) + self.assertEqual(original_state.root_tool_call_id, self.tool_call_id) diff --git a/ee/hogai/graph/root/tools/todo_write.py b/ee/hogai/graph/root/tools/todo_write.py index 816c19b6ce..f8670eef8b 100644 --- a/ee/hogai/graph/root/tools/todo_write.py +++ b/ee/hogai/graph/root/tools/todo_write.py @@ -138,11 +138,11 @@ When unsure, use this tool. 
Proactive task management shows attentiveness and helps ensure all requirements are completed successfully.
 """.strip()


+# Has its own schema that doesn't match the Deep Research schema
 class TodoItem(BaseModel):
     content: str = Field(..., min_length=1)
     status: Literal["pending", "in_progress", "completed"]
     id: str
-    priority: Literal["low", "medium", "high"]


 class TodoWriteToolArgs(BaseModel):
diff --git a/ee/hogai/graph/schema_generator/nodes.py b/ee/hogai/graph/schema_generator/nodes.py
index f2e957ac58..2149c58071 100644
--- a/ee/hogai/graph/schema_generator/nodes.py
+++ b/ee/hogai/graph/schema_generator/nodes.py
@@ -13,7 +13,7 @@ from langchain_core.messages import (
 from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
 from langchain_core.runnables import RunnableConfig

-from posthog.schema import FailureMessage, VisualizationMessage
+from posthog.schema import VisualizationMessage

 from posthog.models.group_type_mapping import GroupTypeMapping

@@ -36,6 +36,15 @@ from .utils import Q, SchemaGeneratorOutput
 RETRIES_ALLOWED = 2


+class SchemaGenerationException(Exception):
+    """An error occurred while generating a schema in `SchemaGeneratorNode`."""
+
+    def __init__(self, llm_output: str, validation_message: str):
+        super().__init__("Failed to generate schema")
+        self.llm_output = llm_output
+        self.validation_message = validation_message
+
+
 class SchemaGeneratorNode(AssistantNode, Generic[Q]):
     INSIGHT_NAME: str
     """
@@ -87,9 +96,8 @@ class SchemaGeneratorNode(AssistantNode, Generic[Q]):

         chain = generation_prompt | merger | self._model | self._parse_output

-        result: SchemaGeneratorOutput[Q] | None = None
         try:
-            result = await chain.ainvoke(
+            result: SchemaGeneratorOutput[Q] = await chain.ainvoke(
                 {
                     "project_datetime": self.project_now,
                     "project_timezone": self.project_timezone,
@@ -120,18 +128,9 @@ class SchemaGeneratorNode(AssistantNode, Generic[Q]):
                 query_generation_retry_count=len(intermediate_steps) + 1,
             )

-        if not result:
-            # We've got no usable result after exhausting all iteration attempts - it's failure message time
-            return PartialAssistantState(
-                messages=[
-                    FailureMessage(
-                        content=f"It looks like I'm having trouble generating this {self.INSIGHT_NAME} insight."
-                    )
-                ],
-                intermediate_steps=None,
-                plan=None,
-                query_generation_retry_count=len(intermediate_steps) + 1,
-            )
+        if isinstance(e, PydanticOutputParserException):
+            raise SchemaGenerationException(e.llm_output, e.validation_message)
+        raise SchemaGenerationException(getattr(e, "llm_output", None) or "No input was provided.", str(e))

         # We've got a result that either passed the quality check or we've exhausted all attempts at iterating - return
         return PartialAssistantState(
diff --git a/ee/hogai/graph/schema_generator/test/test_nodes.py b/ee/hogai/graph/schema_generator/test/test_nodes.py
index 1e1ea5e1c1..2dc0c4507b 100644
--- a/ee/hogai/graph/schema_generator/test/test_nodes.py
+++ b/ee/hogai/graph/schema_generator/test/test_nodes.py
@@ -13,7 +13,12 @@ from langchain_core.runnables import RunnableConfig, RunnableLambda

 from posthog.schema import AssistantMessage, AssistantTrendsQuery, FailureMessage, HumanMessage, VisualizationMessage

-from ee.hogai.graph.schema_generator.nodes import RETRIES_ALLOWED, SchemaGeneratorNode, SchemaGeneratorToolsNode
+from ee.hogai.graph.schema_generator.nodes import (
+    RETRIES_ALLOWED,
+    SchemaGenerationException,
+    SchemaGeneratorNode,
+    SchemaGeneratorToolsNode,
+)
 from ee.hogai.graph.schema_generator.parsers import PydanticOutputParserException
 from ee.hogai.graph.schema_generator.utils import SchemaGeneratorOutput
 from ee.hogai.utils.types import AssistantState, PartialAssistantState
@@ -258,7 +263,7 @@
         self.assertEqual(action.log, "Field validation failed")

     async def test_quality_check_failure_with_retries_exhausted(self):
-        """Test quality check failure with retries exhausted still returns VisualizationMessage."""
+        """Test that a quality check failure with retries exhausted raises SchemaGenerationException."""
         node = DummyGeneratorNode(self.team, self.user)
         with (
             patch.object(DummyGeneratorNode, "_model") as generator_model_mock,
@@ -273,26 +278,25 @@
             )

             # Start with RETRIES_ALLOWED intermediate steps (so no more allowed)
-            new_state = await node.arun(
-                AssistantState(
-                    messages=[HumanMessage(content="Text", id="0")],
-                    start_id="0",
-                    intermediate_steps=cast(
-                        list[IntermediateStep],
-                        [
-                            (AgentAction(tool="handle_incorrect_response", tool_input="", log=""), "retry"),
-                        ],
-                    )
-                    * RETRIES_ALLOWED,
-                ),
-                {},
-            )
+            with self.assertRaises(SchemaGenerationException) as cm:
+                await node.arun(
+                    AssistantState(
+                        messages=[HumanMessage(content="Text", id="0")],
+                        start_id="0",
+                        intermediate_steps=cast(
+                            list[IntermediateStep],
+                            [
+                                (AgentAction(tool="handle_incorrect_response", tool_input="", log=""), "retry"),
+                            ],
+                        )
+                        * RETRIES_ALLOWED,
+                    ),
+                    {},
+                )

-            # Should return VisualizationMessage despite quality check failure
-            self.assertEqual(new_state.intermediate_steps, None)
-            self.assertEqual(len(new_state.messages), 1)
-            self.assertEqual(new_state.messages[0].type, "ai/viz")
-            self.assertEqual(cast(VisualizationMessage, new_state.messages[0]).answer, self.basic_trends)
+            # Verify the exception contains the expected information
+            self.assertEqual(cm.exception.llm_output, '{"query": "test"}')
+            self.assertEqual(cm.exception.validation_message, "Quality check failed")

     async def test_node_leaves_failover(self):
         node = DummyGeneratorNode(self.team, self.user)
@@ -329,20 +333,17 @@
             schema = DummySchema.model_construct(query=[]).model_dump()  # type: ignore
             generator_model_mock.return_value = RunnableLambda(lambda _: json.dumps(schema))

-            new_state = 
await node.arun( - AssistantState( - messages=[HumanMessage(content="Text")], - intermediate_steps=[ - (AgentAction(tool="", tool_input="", log="exception"), "exception"), - (AgentAction(tool="", tool_input="", log="exception"), "exception"), - ], - ), - {}, - ) - self.assertEqual(new_state.intermediate_steps, None) - self.assertEqual(len(new_state.messages), 1) - self.assertIsInstance(new_state.messages[0], FailureMessage) - self.assertEqual(new_state.plan, None) + with self.assertRaises(SchemaGenerationException): + await node.arun( + AssistantState( + messages=[HumanMessage(content="Text")], + intermediate_steps=[ + (AgentAction(tool="", tool_input="", log="exception"), "exception"), + (AgentAction(tool="", tool_input="", log="exception"), "exception"), + ], + ), + {}, + ) async def test_agent_reconstructs_conversation_with_failover(self): action = AgentAction(tool="fix", tool_input="validation error", log="exception") diff --git a/ee/hogai/graph/session_summaries/nodes.py b/ee/hogai/graph/session_summaries/nodes.py index 4d33fee929..8e5692256a 100644 --- a/ee/hogai/graph/session_summaries/nodes.py +++ b/ee/hogai/graph/session_summaries/nodes.py @@ -11,7 +11,6 @@ from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnableConfig from posthog.schema import ( - AssistantMessage, AssistantToolCallMessage, MaxRecordingUniversalFilters, NotebookUpdateMessage, @@ -70,11 +69,7 @@ class SessionSummarizationNode(AssistantNode): """Push summarization progress as reasoning messages""" content = prepare_reasoning_progress_message(progress_message) if content: - self.dispatcher.message( - AssistantMessage( - content=content, - ) - ) + self.dispatcher.update(content) async def _stream_notebook_content(self, content: dict, state: AssistantState, partial: bool = True) -> None: """Stream TipTap content directly to a notebook if notebook_id is present in state.""" @@ -202,7 +197,7 @@ class _SessionSearch: tool = await SearchSessionRecordingsTool.create_tool_class( team=self._node._team, user=self._node._user, - tool_call_id=self._node._parent_tool_call_id or "", + node_path=self._node.node_path, state=state, config=config, context_manager=self._node.context_manager, diff --git a/ee/hogai/graph/sql/nodes.py b/ee/hogai/graph/sql/nodes.py index eb9da64c9c..36d3840bf9 100644 --- a/ee/hogai/graph/sql/nodes.py +++ b/ee/hogai/graph/sql/nodes.py @@ -1,6 +1,6 @@ from langchain_core.runnables import RunnableConfig -from posthog.schema import AssistantHogQLQuery, AssistantMessage +from posthog.schema import AssistantHogQLQuery from posthog.hogql.context import HogQLContext @@ -25,7 +25,7 @@ class SQLGeneratorNode(HogQLGeneratorMixin, SchemaGeneratorNode[AssistantHogQLQu return AssistantNodeName.SQL_GENERATOR async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: - self.dispatcher.message(AssistantMessage(content="Creating SQL query")) + self.dispatcher.update("Creating SQL query") prompt = await self._construct_system_prompt() return await super()._run_with_prompt(state, prompt, config=config) diff --git a/ee/hogai/graph/taxonomy/agent.py b/ee/hogai/graph/taxonomy/agent.py index 3ae8a99dab..bfb7b4b01e 100644 --- a/ee/hogai/graph/taxonomy/agent.py +++ b/ee/hogai/graph/taxonomy/agent.py @@ -3,34 +3,43 @@ from typing import Generic from posthog.models import Team, User from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer -from ee.hogai.graph.graph import BaseAssistantGraph +from ee.hogai.graph.base import BaseAssistantGraph 
from ee.hogai.utils.types import PartialStateType, StateType +from ee.hogai.utils.types.base import AssistantGraphName from .nodes import StateClassMixin, TaxonomyAgentNode, TaxonomyAgentToolsNode from .toolkit import TaxonomyAgentToolkit from .types import TaxonomyNodeName -class TaxonomyAgent(BaseAssistantGraph[StateType], Generic[StateType, PartialStateType], StateClassMixin): +class TaxonomyAgent( + BaseAssistantGraph[StateType, PartialStateType], Generic[StateType, PartialStateType], StateClassMixin +): """Taxonomy agent that can be configured with different node classes.""" def __init__( self, team: Team, user: User, - tool_call_id: str, loop_node_class: type["TaxonomyAgentNode"], tools_node_class: type["TaxonomyAgentToolsNode"], toolkit_class: type["TaxonomyAgentToolkit"], ): - # Extract the State type from the generic parameter - state_class, _ = self._get_state_class(TaxonomyAgent) - super().__init__(team, user, state_class, parent_tool_call_id=tool_call_id) - + super().__init__(team, user) self._loop_node_class = loop_node_class self._tools_node_class = tools_node_class self._toolkit_class = toolkit_class + @property + def state_type(self) -> type[StateType]: + # Extract the State type from the generic parameter + state_type, _ = self._get_state_class(TaxonomyAgent) + return state_type + + @property + def graph_name(self) -> AssistantGraphName: + return AssistantGraphName.TAXONOMY + def compile_full_graph(self, checkpointer: DjangoCheckpointer | None = None): """Compile a complete taxonomy graph.""" return self.add_taxonomy_generator().compile(checkpointer=checkpointer) diff --git a/ee/hogai/graph/taxonomy/nodes.py b/ee/hogai/graph/taxonomy/nodes.py index 327b1cbb48..1db0ef30d8 100644 --- a/ee/hogai/graph/taxonomy/nodes.py +++ b/ee/hogai/graph/taxonomy/nodes.py @@ -40,25 +40,10 @@ TaxonomyPartialStateType = TypeVar("TaxonomyPartialStateType", bound=TaxonomyAge TaxonomyNodeBound = BaseAssistantNode[TaxonomyStateType, TaxonomyPartialStateType] -class ParentToolCallIdMixin(BaseAssistantNode[TaxonomyStateType, TaxonomyPartialStateType]): - _parent_tool_call_id: str | None = None - _toolkit: TaxonomyAgentToolkit - - async def __call__(self, state: TaxonomyStateType, config: RunnableConfig) -> TaxonomyPartialStateType | None: - """ - Run the assistant node and handle cancelled conversation before the node is run. 
- """ - if self._parent_tool_call_id: - self._toolkit._parent_tool_call_id = self._parent_tool_call_id - - return await super().__call__(state, config) - - class TaxonomyAgentNode( Generic[TaxonomyStateType, TaxonomyPartialStateType], StateClassMixin, TaxonomyUpdateDispatcherNodeMixin, - ParentToolCallIdMixin, TaxonomyNodeBound, ABC, ): @@ -173,7 +158,6 @@ class TaxonomyAgentToolsNode( Generic[TaxonomyStateType, TaxonomyPartialStateType], StateClassMixin, TaxonomyUpdateDispatcherNodeMixin, - ParentToolCallIdMixin, TaxonomyNodeBound, ): """Base tools node for taxonomy agents.""" diff --git a/ee/hogai/graph/taxonomy/test/test_agent.py b/ee/hogai/graph/taxonomy/test/test_agent.py index cb0551a864..23eaded165 100644 --- a/ee/hogai/graph/taxonomy/test/test_agent.py +++ b/ee/hogai/graph/taxonomy/test/test_agent.py @@ -36,14 +36,13 @@ class TestTaxonomyAgent(BaseTest): self.mock_graph = Mock() # Patch StateGraph in the parent class where it's actually used - self.patcher = patch("ee.hogai.graph.graph.StateGraph") + self.patcher = patch("ee.hogai.graph.base.graph.StateGraph") mock_state_graph_class = self.patcher.start() mock_state_graph_class.return_value = self.mock_graph self.agent = ConcreteTaxonomyAgent( team=self.team, user=self.user, - tool_call_id="test_tool_call_id", loop_node_class=MockTaxonomyAgentNode, tools_node_class=MockTaxonomyAgentToolsNode, toolkit_class=MockTaxonomyAgentToolkit, @@ -75,7 +74,6 @@ class TestTaxonomyAgent(BaseTest): NonGenericAgent( team=self.team, user=self.user, - tool_call_id="test_tool_call_id", loop_node_class=MockTaxonomyAgentNode, tools_node_class=MockTaxonomyAgentToolsNode, toolkit_class=MockTaxonomyAgentToolkit, diff --git a/ee/hogai/graph/taxonomy/toolkit.py b/ee/hogai/graph/taxonomy/toolkit.py index 5ccde1c15c..676a3bb12b 100644 --- a/ee/hogai/graph/taxonomy/toolkit.py +++ b/ee/hogai/graph/taxonomy/toolkit.py @@ -122,10 +122,6 @@ class TaxonomyTaskExecutorNode( Task executor node specifically for taxonomy operations. 
""" - def __init__(self, team: Team, user: User, parent_tool_call_id: str | None): - super().__init__(team, user) - self._parent_tool_call_id = parent_tool_call_id - @property def node_name(self) -> MaxNodeName: return TaxonomyNodeName.TASK_EXECUTOR @@ -148,8 +144,6 @@ class TaxonomyTaskExecutorNode( class TaxonomyAgentToolkit: """Base toolkit for taxonomy agents that handle tool execution.""" - _parent_tool_call_id: str | None - def __init__(self, team: Team, user: User): self._team = team self._user = user @@ -391,7 +385,7 @@ class TaxonomyAgentToolkit: ) ) message = AssistantMessage(content="", id=str(uuid4()), tool_calls=tool_calls) - executor = TaxonomyTaskExecutorNode(self._team, self._user, parent_tool_call_id=self._parent_tool_call_id) + executor = TaxonomyTaskExecutorNode(self._team, self._user) result = await executor.arun(AssistantState(messages=[message]), RunnableConfig()) final_result = {} @@ -651,7 +645,7 @@ class TaxonomyAgentToolkit: for event_name_or_action_id in event_name_or_action_ids ] message = AssistantMessage(content="", id=str(uuid4()), tool_calls=tool_calls) - executor = TaxonomyTaskExecutorNode(self._team, self._user, parent_tool_call_id=self._parent_tool_call_id) + executor = TaxonomyTaskExecutorNode(self._team, self._user) result = await executor.arun(AssistantState(messages=[message]), RunnableConfig()) return {task_result.id: task_result.result for task_result in result.task_results} diff --git a/ee/hogai/graph/test/test_assistant_graph.py b/ee/hogai/graph/test/test_assistant_graph.py deleted file mode 100644 index bd2a144b02..0000000000 --- a/ee/hogai/graph/test/test_assistant_graph.py +++ /dev/null @@ -1,29 +0,0 @@ -from posthog.test.base import BaseTest - -from langchain_core.runnables import RunnableLambda -from langgraph.checkpoint.memory import InMemorySaver - -from ee.hogai.graph.graph import BaseAssistantGraph -from ee.hogai.utils.types import AssistantNodeName, AssistantState, PartialAssistantState - - -class TestAssistantGraph(BaseTest): - async def test_pydantic_state_resets_with_none(self): - """When a None field is set, it should be reset to None.""" - - async def runnable(state: AssistantState) -> PartialAssistantState: - return PartialAssistantState(start_id=None) - - graph = BaseAssistantGraph(self.team, self.user, state_type=AssistantState) - compiled_graph = ( - graph.add_node(AssistantNodeName.ROOT, RunnableLambda(runnable)) - .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) - .add_edge(AssistantNodeName.ROOT, AssistantNodeName.END) - .compile(checkpointer=InMemorySaver()) - ) - state = await compiled_graph.ainvoke( - AssistantState(messages=[], graph_status="resumed", start_id=None), - {"configurable": {"thread_id": "test"}}, - ) - self.assertEqual(state["start_id"], None) - self.assertEqual(state["graph_status"], "resumed") diff --git a/ee/hogai/graph/test/test_dispatcher_integration.py b/ee/hogai/graph/test/test_dispatcher_integration.py index db8c36240f..2687579a1b 100644 --- a/ee/hogai/graph/test/test_dispatcher_integration.py +++ b/ee/hogai/graph/test/test_dispatcher_integration.py @@ -4,8 +4,7 @@ Integration tests for dispatcher usage in BaseAssistantNode and graph execution. These tests ensure that the dispatcher pattern works correctly end-to-end in real graph execution. 
""" -import uuid -from typing import Any +from typing import cast from posthog.test.base import BaseTest from unittest.mock import MagicMock, patch @@ -14,30 +13,40 @@ from langchain_core.runnables import RunnableConfig from posthog.schema import AssistantMessage -from ee.hogai.graph.base import BaseAssistantNode -from ee.hogai.utils.dispatcher import MessageAction, NodeStartAction -from ee.hogai.utils.types import AssistantNodeName -from ee.hogai.utils.types.base import AssistantDispatcherEvent, AssistantState, PartialAssistantState +from ee.hogai.graph.base.node import BaseAssistantNode +from ee.hogai.utils.types.base import ( + AssistantDispatcherEvent, + AssistantGraphName, + AssistantNodeName, + AssistantState, + MessageAction, + NodeEndAction, + NodePath, + NodeStartAction, + PartialAssistantState, + UpdateAction, +) -class MockAssistantNode(BaseAssistantNode): +class MockAssistantNode(BaseAssistantNode[AssistantState, PartialAssistantState]): """Mock node for testing dispatcher integration.""" - def __init__(self, team, user): - super().__init__(team, user) + def __init__(self, team, user, node_path=None): + super().__init__(team, user, node_path) self.arun_called = False self.messages_dispatched = [] @property - def node_name(self) -> AssistantNodeName: + def node_name(self) -> str: return AssistantNodeName.ROOT - async def arun(self, state, config: RunnableConfig) -> PartialAssistantState: + async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: self.arun_called = True # Simulate dispatching messages during execution - self.dispatcher.message(AssistantMessage(content="Processing...")) - self.dispatcher.message(AssistantMessage(content="Done!")) + self.dispatcher.update("Processing...") + self.dispatcher.message(AssistantMessage(content="Intermediate result")) + self.dispatcher.update("Done!") return PartialAssistantState(messages=[AssistantMessage(content="Final result")]) @@ -57,147 +66,169 @@ class TestDispatcherIntegration(BaseTest): node = MockAssistantNode(self.mock_team, self.mock_user) state = AssistantState(messages=[]) - config = RunnableConfig() + config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_1"}) - await node.arun(state, config) + with patch("ee.hogai.utils.dispatcher.get_stream_writer", side_effect=RuntimeError("Not streaming")): + await node(state, config) self.assertTrue(node.arun_called) self.assertIsNotNone(node.dispatcher) + async def test_node_path_propagation(self): + """Test that node_path is correctly set and propagated.""" + node_path = ( + NodePath(name=AssistantGraphName.ASSISTANT), + NodePath(name=AssistantNodeName.ROOT), + ) + + node = MockAssistantNode(self.mock_team, self.mock_user, node_path=node_path) + + self.assertEqual(node.node_path, node_path) + + async def test_dispatcher_dispatches_node_start_and_end(self): + """Test that NODE_START and NODE_END actions are dispatched.""" + node = MockAssistantNode(self.mock_team, self.mock_user) + + state = AssistantState(messages=[]) + config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_2"}) + + dispatched_actions = [] + + def capture_write(event): + if isinstance(event, AssistantDispatcherEvent): + dispatched_actions.append(event) + + with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write): + await node(state, config) + + # Should have dispatched actions in this order: + # 1. NODE_START + # 2. 
UpdateAction ("Processing...") + # 3. MessageAction (intermediate) + # 4. UpdateAction ("Done!") + # 5. NODE_END with final state + self.assertGreater(len(dispatched_actions), 0) + + # Verify NODE_START is first + self.assertIsInstance(dispatched_actions[0].action, NodeStartAction) + + # Verify NODE_END is last + last_action = dispatched_actions[-1].action + self.assertIsInstance(last_action, NodeEndAction) + self.assertIsNotNone(cast(NodeEndAction, last_action).state) + async def test_messages_dispatched_during_node_execution(self): """Test that messages dispatched during node execution are sent to writer.""" node = MockAssistantNode(self.mock_team, self.mock_user) state = AssistantState(messages=[]) - config = RunnableConfig() + config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_3"}) dispatched_actions = [] - def capture_write(update: Any): - dispatched_actions.append(update) + def capture_write(event: AssistantDispatcherEvent): + dispatched_actions.append(event) - # Mock get_stream_writer to return our test writer - with patch("ee.hogai.graph.base.get_stream_writer", return_value=capture_write): - # Call the node (not arun) to trigger __call__ which handles dispatching + with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write): await node(state, config) - # Should have: - # 1. NODE_START from __call__ - # 2. Two MESSAGE actions from arun (Processing..., Done!) - # 3. One MESSAGE action from returned state (Final result) + # Find the update and message actions (excluding NODE_START and NODE_END) + update_actions = [e for e in dispatched_actions if isinstance(e.action, UpdateAction)] + message_actions = [e for e in dispatched_actions if isinstance(e.action, MessageAction)] - self.assertEqual(len(dispatched_actions), 4) - self.assertIsInstance(dispatched_actions[0], AssistantDispatcherEvent) - self.assertIsInstance(dispatched_actions[0].action, NodeStartAction) - self.assertIsInstance(dispatched_actions[1], AssistantDispatcherEvent) - self.assertIsInstance(dispatched_actions[1].action, MessageAction) - self.assertEqual(dispatched_actions[1].action.message.content, "Processing...") - self.assertIsInstance(dispatched_actions[2], AssistantDispatcherEvent) - self.assertIsInstance(dispatched_actions[2].action, MessageAction) - self.assertEqual(dispatched_actions[2].action.message.content, "Done!") - self.assertIsInstance(dispatched_actions[3], AssistantDispatcherEvent) - self.assertIsInstance(dispatched_actions[3].action, MessageAction) - self.assertEqual(dispatched_actions[3].action.message.content, "Final result") + # Should have 2 updates: "Processing..." and "Done!" 
+ self.assertEqual(len(update_actions), 2) + self.assertEqual(cast(UpdateAction, update_actions[0].action).content, "Processing...") + self.assertEqual(cast(UpdateAction, update_actions[1].action).content, "Done!") - async def test_node_start_action_dispatched(self): - """Test that NODE_START action is dispatched at node entry.""" + # Should have 1 message: intermediate (final message is in NODE_END state, not dispatched separately) + self.assertEqual(len(message_actions), 1) + msg = cast(MessageAction, message_actions[0].action).message + self.assertEqual(cast(AssistantMessage, msg).content, "Intermediate result") + + async def test_node_path_included_in_dispatched_events(self): + """Test that node_path is included in all dispatched events.""" + node_path = ( + NodePath(name=AssistantGraphName.ASSISTANT), + NodePath(name=AssistantNodeName.ROOT), + ) + + node = MockAssistantNode(self.mock_team, self.mock_user, node_path=node_path) + + state = AssistantState(messages=[]) + config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_4"}) + + dispatched_events = [] + + def capture_write(event: AssistantDispatcherEvent): + dispatched_events.append(event) + + with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write): + await node(state, config) + + # Verify all events have the correct node_path + for event in dispatched_events: + self.assertEqual(event.node_path, node_path) + + async def test_node_run_id_included_in_dispatched_events(self): + """Test that node_run_id is included in all dispatched events.""" node = MockAssistantNode(self.mock_team, self.mock_user) state = AssistantState(messages=[]) - config = RunnableConfig() + checkpoint_ns = "checkpoint_xyz_789" + config = RunnableConfig( + metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": checkpoint_ns} + ) - dispatched_actions = [] + dispatched_events = [] - def capture_write(update): - if isinstance(update, AssistantDispatcherEvent): - dispatched_actions.append(update.action) + def capture_write(event: AssistantDispatcherEvent): + dispatched_events.append(event) - with patch("ee.hogai.graph.base.get_stream_writer", return_value=capture_write): + with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write): await node(state, config) - # Should have at least one NODE_START action - from ee.hogai.utils.dispatcher import NodeStartAction + # Verify all events have the correct node_run_id + for event in dispatched_events: + self.assertEqual(event.node_run_id, checkpoint_ns) - node_start_actions = [action for action in dispatched_actions if isinstance(action, NodeStartAction)] - self.assertGreater(len(node_start_actions), 0) - - @patch("ee.hogai.graph.base.get_stream_writer") - async def test_parent_tool_call_id_propagation(self, mock_get_stream_writer): - """Test that parent_tool_call_id is propagated to dispatched messages.""" - parent_tool_call_id = str(uuid.uuid4()) - - class NodeWithParent(BaseAssistantNode): - @property - def node_name(self): - return AssistantNodeName.ROOT - - async def arun(self, state, config: RunnableConfig) -> PartialAssistantState: - # Dispatch message - should inherit parent_tool_call_id from state - self.dispatcher.message(AssistantMessage(content="Child message")) - return PartialAssistantState(messages=[]) - - node = NodeWithParent(self.mock_team, self.mock_user) - - state = {"messages": [], "parent_tool_call_id": parent_tool_call_id} - config = RunnableConfig() - - dispatched_messages = 
[] - - def capture_write(update): - if isinstance(update, AssistantDispatcherEvent) and isinstance(update.action, MessageAction): - dispatched_messages.append(update.action.message) - - mock_get_stream_writer.return_value = capture_write - - await node.arun(state, config) - - # Verify dispatched messages have parent_tool_call_id - assistant_messages = [msg for msg in dispatched_messages if isinstance(msg, AssistantMessage)] - for msg in assistant_messages: - # If the implementation propagates it, this should be true - # Otherwise this test will help catch that as a potential issue - if msg.parent_tool_call_id: - self.assertEqual(msg.parent_tool_call_id, parent_tool_call_id) - - @patch("ee.hogai.graph.base.get_stream_writer") - async def test_dispatcher_error_handling(self, mock_get_stream_writer): + async def test_dispatcher_error_handling_does_not_crash_node(self): """Test that errors in dispatcher don't crash node execution.""" - class FailingDispatcherNode(BaseAssistantNode): + class FailingDispatcherNode(BaseAssistantNode[AssistantState, PartialAssistantState]): @property def node_name(self): return AssistantNodeName.ROOT - async def arun(self, state, config: RunnableConfig) -> PartialAssistantState: + async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: # Try to dispatch - if writer fails, should handle gracefully - self.dispatcher.message(AssistantMessage(content="Test")) + self.dispatcher.update("Test") return PartialAssistantState(messages=[AssistantMessage(content="Result")]) node = FailingDispatcherNode(self.mock_team, self.mock_user) state = AssistantState(messages=[]) - config = RunnableConfig() + config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_5"}) # Make writer raise an error def failing_writer(data): raise RuntimeError("Writer failed") - mock_get_stream_writer.return_value = failing_writer + with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=failing_writer): + # Should not crash - node should complete + result = await node(state, config) + self.assertIsNotNone(result) - # Should not crash - node should complete - result = await node.arun(state, config) - self.assertIsNotNone(result) + async def test_messages_in_partial_state_dispatched_via_node_end(self): + """Test that messages in PartialState are dispatched via NODE_END.""" - async def test_messages_in_partial_state_are_auto_dispatched(self): - """Test that messages in PartialState are automatically dispatched.""" - - class NodeReturningMessages(BaseAssistantNode): + class NodeReturningMessages(BaseAssistantNode[AssistantState, PartialAssistantState]): @property def node_name(self): return AssistantNodeName.ROOT - async def arun(self, state, config: RunnableConfig) -> PartialAssistantState: - # Return messages in state - should be auto-dispatched + async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: + # Return messages in state return PartialAssistantState( messages=[ AssistantMessage(content="Message 1"), @@ -208,39 +239,139 @@ class TestDispatcherIntegration(BaseTest): node = NodeReturningMessages(self.mock_team, self.mock_user) state = AssistantState(messages=[]) - config = RunnableConfig() + config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_6"}) - dispatched_messages = [] + dispatched_events = [] - def capture_write(update): - if isinstance(update, AssistantDispatcherEvent) and isinstance(update.action, 
MessageAction): - dispatched_messages.append(update.action.message) + def capture_write(event: AssistantDispatcherEvent): + dispatched_events.append(event) - with patch("ee.hogai.graph.base.get_stream_writer", return_value=capture_write): + with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write): await node(state, config) - # Should have dispatched the messages from PartialState (and NODE_START) - # We expect at least 2 message actions (Message 1 and Message 2) - self.assertGreaterEqual(len(dispatched_messages), 2) + # Should have NODE_END action with state containing messages + node_end_actions = [e for e in dispatched_events if isinstance(e.action, NodeEndAction)] + self.assertEqual(len(node_end_actions), 1) + + node_end_state = cast(NodeEndAction, node_end_actions[0].action).state + self.assertIsNotNone(node_end_state) + assert node_end_state is not None + self.assertEqual(len(node_end_state.messages), 2) async def test_node_returns_none_state_handling(self): """Test that node can return None state without errors.""" - class NoneStateNode(BaseAssistantNode): + class NoneStateNode(BaseAssistantNode[AssistantState, PartialAssistantState]): @property def node_name(self): return AssistantNodeName.ROOT - async def arun(self, state, config: RunnableConfig) -> None: + async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None: # Dispatch a message but return None state - self.dispatcher.message(AssistantMessage(content="Test")) + self.dispatcher.update("Test") return None node = NoneStateNode(self.mock_team, self.mock_user) state = AssistantState(messages=[]) - config = RunnableConfig() + config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_7"}) - result = await node(state, config) - # Should handle None gracefully - self.assertIsNone(result) + with patch("ee.hogai.utils.dispatcher.get_stream_writer", side_effect=RuntimeError("Not streaming")): + result = await node(state, config) + # Should handle None gracefully + self.assertIsNone(result) + + async def test_nested_node_path_in_dispatched_events(self): + """Test that nested nodes have correct node_path.""" + # Simulate a nested node path (e.g., from a tool call) + parent_path = ( + NodePath(name=AssistantGraphName.ASSISTANT), + NodePath(name=AssistantNodeName.ROOT, message_id="msg_123", tool_call_id="tc_123"), + NodePath(name=AssistantGraphName.INSIGHTS), + ) + + node = MockAssistantNode(self.mock_team, self.mock_user, node_path=parent_path) + + state = AssistantState(messages=[]) + config = RunnableConfig( + metadata={"langgraph_node": AssistantNodeName.TRENDS_GENERATOR, "langgraph_checkpoint_ns": "cp_8"} + ) + + dispatched_events = [] + + def capture_write(event: AssistantDispatcherEvent): + dispatched_events.append(event) + + with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write): + await node(state, config) + + # Verify all events have the nested path + for event in dispatched_events: + self.assertIsNotNone(event.node_path) + assert event.node_path is not None + self.assertEqual(event.node_path, parent_path) + self.assertEqual(len(event.node_path), 3) + self.assertEqual(event.node_path[1].message_id, "msg_123") + self.assertEqual(event.node_path[1].tool_call_id, "tc_123") + + async def test_update_actions_include_node_metadata(self): + """Test that update actions include correct node metadata.""" + node = MockAssistantNode(self.mock_team, self.mock_user) + + state = 
AssistantState(messages=[]) + config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_9"}) + + dispatched_events = [] + + def capture_write(event: AssistantDispatcherEvent): + dispatched_events.append(event) + + with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write): + await node(state, config) + + # Find update actions + update_events = [e for e in dispatched_events if isinstance(e.action, UpdateAction)] + + for event in update_events: + self.assertEqual(event.node_name, AssistantNodeName.ROOT) + self.assertEqual(event.node_run_id, "cp_9") + self.assertIsNotNone(event.node_path) + + async def test_concurrent_node_executions_independent_dispatchers(self): + """Test that concurrent node executions use independent dispatchers.""" + node1 = MockAssistantNode(self.mock_team, self.mock_user) + node2 = MockAssistantNode(self.mock_team, self.mock_user) + + state = AssistantState(messages=[]) + config1 = RunnableConfig( + metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_10a"} + ) + config2 = RunnableConfig( + metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_10b"} + ) + + events1 = [] + events2 = [] + + def capture_write1(event: AssistantDispatcherEvent): + events1.append(event) + + def capture_write2(event: AssistantDispatcherEvent): + events2.append(event) + + with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write1): + await node1(state, config1) + + with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write2): + await node2(state, config2) + + # Verify events went to separate lists + self.assertGreater(len(events1), 0) + self.assertGreater(len(events2), 0) + + # Verify node_run_ids are different + for event in events1: + self.assertEqual(event.node_run_id, "cp_10a") + + for event in events2: + self.assertEqual(event.node_run_id, "cp_10b") diff --git a/ee/hogai/graph/trends/nodes.py b/ee/hogai/graph/trends/nodes.py index cb11eb2ba7..6034a7ef58 100644 --- a/ee/hogai/graph/trends/nodes.py +++ b/ee/hogai/graph/trends/nodes.py @@ -1,7 +1,7 @@ from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnableConfig -from posthog.schema import AssistantMessage, AssistantTrendsQuery +from posthog.schema import AssistantTrendsQuery from ee.hogai.utils.types import AssistantState from ee.hogai.utils.types.base import AssistantNodeName, PartialAssistantState @@ -25,7 +25,7 @@ class TrendsGeneratorNode(SchemaGeneratorNode[AssistantTrendsQuery]): return AssistantNodeName.TRENDS_GENERATOR async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: - self.dispatcher.message(AssistantMessage(content="Creating trends query")) + self.dispatcher.update("Creating trends query") prompt = ChatPromptTemplate.from_messages( [ ("system", TRENDS_SYSTEM_PROMPT), diff --git a/ee/hogai/test/test_assistant.py b/ee/hogai/test/test_assistant.py index 588ba4d2ae..ccdf8e00df 100644 --- a/ee/hogai/test/test_assistant.py +++ b/ee/hogai/test/test_assistant.py @@ -9,7 +9,7 @@ from posthog.test.base import ( _create_person, flush_persons_and_events, ) -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch from django.test import override_settings @@ -61,13 +61,13 @@ from posthog.models import Action from ee.hogai.assistant.base import BaseAssistant from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer 
-from ee.hogai.graph.base import BaseAssistantNode +from ee.hogai.graph.base import AssistantNode from ee.hogai.graph.funnels.nodes import FunnelsSchemaGeneratorOutput +from ee.hogai.graph.insights_graph.graph import InsightsGraph from ee.hogai.graph.memory import prompts as memory_prompts from ee.hogai.graph.retention.nodes import RetentionSchemaGeneratorOutput from ee.hogai.graph.root.nodes import SLASH_COMMAND_INIT from ee.hogai.graph.trends.nodes import TrendsSchemaGeneratorOutput -from ee.hogai.utils.state import GraphValueUpdateTuple from ee.hogai.utils.tests import FakeAnthropicRunnableLambdaWithTokenCounter, FakeChatAnthropic, FakeChatOpenAI from ee.hogai.utils.types import ( AssistantMode, @@ -76,16 +76,21 @@ from ee.hogai.utils.types import ( AssistantState, PartialAssistantState, ) +from ee.hogai.utils.types.base import ReplaceMessages from ee.models.assistant import Conversation, CoreMemory from ..assistant import Assistant -from ..graph import AssistantGraph, InsightsAssistantGraph +from ..graph.graph import AssistantGraph title_generator_mock = patch( "ee.hogai.graph.title_generator.nodes.TitleGeneratorNode._model", return_value=FakeChatOpenAI(responses=[messages.AIMessage(content="Title")]), ) +query_executor_mock = patch( + "ee.hogai.graph.query_executor.nodes.QueryExecutorNode._format_query_result", new=MagicMock(return_value="Result") +) + class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): CLASS_DATA_LEVEL_SETUP = False @@ -118,7 +123,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): ), ).start() - self.checkpointer_patch = patch("ee.hogai.graph.graph.global_checkpointer", new=DjangoCheckpointer()) + self.checkpointer_patch = patch("ee.hogai.graph.base.graph.global_checkpointer", new=DjangoCheckpointer()) self.checkpointer_patch.start() def tearDown(self): @@ -232,14 +237,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): graph = ( AssistantGraph(self.team, self.user) .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) - .add_root( - { - "insights": AssistantNodeName.INSIGHTS_SUBGRAPH, - "root": AssistantNodeName.ROOT, - "end": AssistantNodeName.END, - } - ) - .add_insights(AssistantNodeName.ROOT) + .add_root() .compile() ) @@ -269,7 +267,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): { "id": "1", "name": "create_and_query_insight", - "args": {"query_description": "Foobar", "query_kind": insight_type}, + "args": {"query_description": "Foobar"}, } ], ) @@ -302,7 +300,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): AssistantToolCall( id="1", name="create_and_query_insight", - args={"query_description": "Foobar", "query_kind": insight_type}, + args={"query_description": "Foobar"}, ) ], ), @@ -327,12 +325,12 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): snapshot: StateSnapshot = await graph.aget_state(config) self.assertFalse(snapshot.next) self.assertFalse(snapshot.values.get("intermediate_steps")) - self.assertFalse(snapshot.values["plan"]) - self.assertFalse(snapshot.values["graph_status"]) - self.assertFalse(snapshot.values["root_tool_call_id"]) - self.assertFalse(snapshot.values["root_tool_insight_plan"]) - self.assertFalse(snapshot.values["root_tool_insight_type"]) - self.assertFalse(snapshot.values["root_tool_calls_count"]) + self.assertFalse(snapshot.values.get("plan")) + self.assertFalse(snapshot.values.get("graph_status")) + self.assertFalse(snapshot.values.get("root_tool_call_id")) + self.assertFalse(snapshot.values.get("root_tool_insight_plan")) + 
self.assertFalse(snapshot.values.get("root_tool_insight_type")) + self.assertFalse(snapshot.values.get("root_tool_calls_count")) async def test_trends_interrupt_when_asking_for_help(self): await self._test_human_in_the_loop("trends") @@ -345,7 +343,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): async def test_ai_messages_appended_after_interrupt(self): with patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model") as mock: - graph = InsightsAssistantGraph(self.team, self.user).compile_full_graph() + graph = InsightsGraph(self.team, self.user).compile_full_graph() config: RunnableConfig = { "configurable": { "thread_id": self.conversation.id, @@ -441,16 +439,13 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): self.assertIsInstance(output[1][1], FailureMessage) async def test_new_conversation_handles_serialized_conversation(self): - from ee.hogai.graph.base import BaseAssistantNode - - class TestNode(BaseAssistantNode): + class TestNode(AssistantNode): @property def node_name(self): return AssistantNodeName.ROOT async def arun(self, state, config): - self.dispatcher.message(AssistantMessage(content="Hello")) - return PartialAssistantState() + return PartialAssistantState(messages=[AssistantMessage(content="Hello", id=str(uuid4()))]) test_node = TestNode(self.team, self.user) graph = ( @@ -478,16 +473,13 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): self.assertNotEqual(output[0][0], "conversation") async def test_async_stream(self): - from ee.hogai.graph.base import BaseAssistantNode - - class TestNode(BaseAssistantNode): + class TestNode(AssistantNode): @property def node_name(self): return AssistantNodeName.ROOT async def arun(self, state, config): - self.dispatcher.message(AssistantMessage(content="bar")) - return PartialAssistantState() + return PartialAssistantState(messages=[AssistantMessage(content="bar", id=str(uuid4()))]) test_node = TestNode(self.team, self.user) graph = ( @@ -517,12 +509,13 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): self.assertConversationEqual(actual_output, expected_output) async def test_async_stream_handles_exceptions(self): - def node_handler(state): - raise ValueError() + class NodeHandler(AssistantNode): + async def arun(self, state, config): + raise ValueError graph = ( AssistantGraph(self.team, self.user) - .add_node(AssistantNodeName.ROOT, node_handler) + .add_node(AssistantNodeName.ROOT, NodeHandler(self.team, self.user)) .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) .add_edge(AssistantNodeName.ROOT, AssistantNodeName.END) .compile() @@ -536,12 +529,11 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): ("message", HumanMessage(content="foo")), ("message", FailureMessage()), ] - actual_output = [] - async for event in assistant.astream(): - actual_output.append(event) + actual_output, _ = await self._run_assistant_graph(graph, message="foo") self.assertConversationEqual(actual_output, expected_output) @title_generator_mock + @query_executor_mock @patch("ee.hogai.graph.schema_generator.nodes.SchemaGeneratorNode._model") @patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model") @patch("ee.hogai.graph.root.nodes.RootNode._get_model") @@ -556,7 +548,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): { "id": "xyz", "name": "create_and_query_insight", - "args": {"query_description": "Foobar", "query_kind": "trends"}, + "args": {"query_description": "Foobar"}, } ], ) @@ -596,7 +588,7 @@ class 
TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): AssistantToolCall( id="xyz", name="create_and_query_insight", - args={"query_description": "Foobar", "query_kind": "trends"}, + args={"query_description": "Foobar"}, ) ], ), @@ -611,7 +603,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): ), ("update", AssistantUpdateEvent(id="message_2", tool_call_id="xyz", content="Creating trends query")), ("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")), - ("message", {"tool_call_id": "xyz", "type": "tool"}), # Don't check content as it's implementation detail + ("message", AssistantToolCallMessage(tool_call_id="xyz", content="Result")), ("message", AssistantMessage(content="The results indicate a great future for you.")), ] self.assertConversationEqual(actual_output, expected_output) @@ -634,6 +626,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): ) @title_generator_mock + @query_executor_mock @patch("ee.hogai.graph.schema_generator.nodes.SchemaGeneratorNode._model") @patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model") @patch("ee.hogai.graph.root.nodes.RootNode._get_model") @@ -648,7 +641,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): { "id": "xyz", "name": "create_and_query_insight", - "args": {"query_description": "Foobar", "query_kind": "funnel"}, + "args": {"query_description": "Foobar"}, } ], ) @@ -693,7 +686,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): AssistantToolCall( id="xyz", name="create_and_query_insight", - args={"query_description": "Foobar", "query_kind": "funnel"}, + args={"query_description": "Foobar"}, ) ], ), @@ -708,7 +701,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): ), ("update", AssistantUpdateEvent(id="message_2", tool_call_id="xyz", content="Creating funnel query")), ("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")), - ("message", {"tool_call_id": "xyz", "type": "tool"}), # Don't check content as it's implementation detail + ("message", AssistantToolCallMessage(tool_call_id="xyz", content="Result")), ("message", AssistantMessage(content="The results indicate a great future for you.")), ] self.assertConversationEqual(actual_output, expected_output) @@ -731,6 +724,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): ) @title_generator_mock + @query_executor_mock @patch("ee.hogai.graph.schema_generator.nodes.SchemaGeneratorNode._model") @patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model") @patch("ee.hogai.graph.root.nodes.RootNode._get_model") @@ -747,7 +741,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): { "id": "xyz", "name": "create_and_query_insight", - "args": {"query_description": "Foobar", "query_kind": "retention"}, + "args": {"query_description": "Foobar"}, } ], ) @@ -792,7 +786,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): AssistantToolCall( id="xyz", name="create_and_query_insight", - args={"query_description": "Foobar", "query_kind": "retention"}, + args={"query_description": "Foobar"}, ) ], ), @@ -807,7 +801,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): ), ("update", AssistantUpdateEvent(id="message_2", tool_call_id="xyz", content="Creating retention query")), ("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")), - ("message", {"tool_call_id": "xyz", "type": "tool"}), # Don't check content as it's implementation detail + ("message", AssistantToolCallMessage(tool_call_id="xyz", 
content="Result")), ("message", AssistantMessage(content="The results indicate a great future for you.")), ] self.assertConversationEqual(actual_output, expected_output) @@ -830,6 +824,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): ) @title_generator_mock + @query_executor_mock @patch("ee.hogai.graph.schema_generator.nodes.SchemaGeneratorNode._model") @patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model") @patch("ee.hogai.graph.root.nodes.RootNode._get_model") @@ -844,7 +839,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): { "id": "xyz", "name": "create_and_query_insight", - "args": {"query_description": "Foobar", "query_kind": "sql"}, + "args": {"query_description": "Foobar"}, } ], ) @@ -883,7 +878,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): AssistantToolCall( id="xyz", name="create_and_query_insight", - args={"query_description": "Foobar", "query_kind": "sql"}, + args={"query_description": "Foobar"}, ) ], ), @@ -898,7 +893,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): ), ("update", AssistantUpdateEvent(id="message_2", tool_call_id="xyz", content="Creating SQL query")), ("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")), - ("message", {"tool_call_id": "xyz", "type": "tool"}), # Don't check content as it's implementation detail + ("message", AssistantToolCallMessage(tool_call_id="xyz", content="Result")), ("message", AssistantMessage(content="The results indicate a great future for you.")), ] self.assertConversationEqual(actual_output, expected_output) @@ -1096,7 +1091,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): { "id": str(uuid4()), "name": "create_and_query_insight", - "args": {"query_description": "Foobar", "query_kind": "trends"}, + "args": {"query_description": "Foobar"}, } ], ) @@ -1138,7 +1133,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): graph = ( AssistantGraph(self.team, self.user) .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) - .add_root({"root": AssistantNodeName.ROOT, "end": AssistantNodeName.END}) + .add_root() .compile() ) self.assertEqual(self.conversation.status, Conversation.Status.IDLE) @@ -1158,7 +1153,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): graph = ( AssistantGraph(self.team, self.user) .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) - .add_root({"root": AssistantNodeName.ROOT, "end": AssistantNodeName.END}) + .add_root() .compile() ) @@ -1177,7 +1172,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): { "id": "1", "name": "create_and_query_insight", - "args": {"query_description": "Foobar", "query_kind": "trends"}, + "args": {"query_description": "Foobar"}, } ], ) @@ -1209,14 +1204,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): graph = ( AssistantGraph(self.team, self.user) .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) - .add_root( - { - "search_documentation": AssistantNodeName.INKEEP_DOCS, - "root": AssistantNodeName.ROOT, - "end": AssistantNodeName.END, - } - ) - .add_inkeep_docs() + .add_root() .compile() ) @@ -1235,45 +1223,25 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): self.assertConversationEqual( output, [ - ( - "message", - HumanMessage( - content="How do I use feature flags?", - id="d93b3d57-2ff3-4c63-8635-8349955b16f0", - parent_tool_call_id=None, - type="human", - ui_context=None, - ), - ), + ("message", HumanMessage(content="How do I use feature flags?")), ( "message", 
AssistantMessage( content="", - id="ea1f4faf-85b4-4d98-b044-842b7092ba5d", - meta=None, - parent_tool_call_id=None, tool_calls=[ AssistantToolCall( args={"kind": "docs", "query": "test"}, id="1", name="search", type="tool_call" ) ], - type="ai", ), ), ( "update", - AssistantUpdateEvent(id="message_2", tool_call_id="1", content="Checking PostHog documentation..."), + AssistantUpdateEvent(content="Checking PostHog documentation...", id="1", tool_call_id="1"), ), ( "message", - AssistantToolCallMessage( - content="Checking PostHog documentation...", - id="00149847-0bcd-4b7c-a514-bab176312ae8", - parent_tool_call_id=None, - tool_call_id="1", - type="tool", - ui_payload=None, - ), + AssistantToolCallMessage(content="Checking PostHog documentation...", tool_call_id="1"), ), ( "message", @@ -1283,12 +1251,10 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): ) @title_generator_mock + @query_executor_mock @patch("ee.hogai.graph.schema_generator.nodes.SchemaGeneratorNode._model") @patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model") - @patch("ee.hogai.graph.graph.QueryExecutorNode") - async def test_insights_tool_mode_flow( - self, query_executor_mock, planner_mock, generator_mock, title_generator_mock - ): + async def test_insights_tool_mode_flow(self, planner_mock, generator_mock, title_generator_mock): """Test that the insights tool mode works correctly.""" query = AssistantTrendsQuery(series=[]) tool_call_id = str(uuid4()) @@ -1315,25 +1281,6 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): ) generator_mock.return_value = RunnableLambda(lambda _: TrendsSchemaGeneratorOutput(query=query)) - class QueryExecutorNodeMock(BaseAssistantNode): - def __init__(self, team, user): - super().__init__(team, user) - - @property - def node_name(self): - return AssistantNodeName.QUERY_EXECUTOR - - async def arun(self, state, config): - return PartialAssistantState( - messages=[ - AssistantToolCallMessage( - content="The results indicate a great future for you.", tool_call_id=tool_call_id - ) - ] - ) - - query_executor_mock.return_value = QueryExecutorNodeMock(self.team, self.user) - # Run in insights tool mode output, _ = await self._run_assistant_graph( conversation=self.conversation, @@ -1347,9 +1294,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): ("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")), ( "message", - AssistantToolCallMessage( - content="The results indicate a great future for you.", tool_call_id=tool_call_id - ), + AssistantToolCallMessage(content="Result", tool_call_id=tool_call_id), ), ] self.assertConversationEqual(output, expected_output) @@ -1387,7 +1332,6 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): async def test_merges_messages_with_same_id(self): """Test that messages with the same ID are merged into one.""" - from ee.hogai.graph.base import BaseAssistantNode message_ids = [str(uuid4()), str(uuid4())] @@ -1395,7 +1339,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): first_content = "First version of message" updated_content = "Updated version of message" - class MessageUpdatingNode(BaseAssistantNode): + class MessageUpdatingNode(AssistantNode): def __init__(self, team, user): super().__init__(team, user) self.call_count = 0 @@ -1447,7 +1391,6 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): async def test_assistant_filters_messages_correctly(self): """Test that the Assistant class correctly filters messages based on 
should_output_assistant_message.""" - from ee.hogai.graph.base import BaseAssistantNode output_messages = [ # Should be output (has content) @@ -1467,7 +1410,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): for test_message, expected_in_output in output_messages: # Create a simple graph that produces different message types to test filtering - class MessageFilteringNode(BaseAssistantNode): + class MessageFilteringNode(AssistantNode): def __init__(self, team, user, message_to_return): super().__init__(team, user) self.message_to_return = message_to_return @@ -1477,8 +1420,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): return AssistantNodeName.ROOT async def arun(self, state, config): - self.dispatcher.message(self.message_to_return) - return PartialAssistantState() + return PartialAssistantState(messages=[self.message_to_return]) # Create a graph with our test node node = MessageFilteringNode(self.team, self.user, test_message) @@ -1505,12 +1447,13 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): """Test that ui_context persists when retrieving conversation state across multiple runs.""" # Create a simple graph that just returns the initial state - def return_initial_state(state): - return {"messages": [AssistantMessage(content="Response from assistant")]} + class ReturnInitialStateNode(AssistantNode): + async def arun(self, state, config): + return PartialAssistantState(messages=[AssistantMessage(content="Response from assistant")]) graph = ( AssistantGraph(self.team, self.user) - .add_node(AssistantNodeName.ROOT, return_initial_state) + .add_node(AssistantNodeName.ROOT, ReturnInitialStateNode(self.team, self.user)) .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) .add_edge(AssistantNodeName.ROOT, AssistantNodeName.END) .compile() @@ -1585,8 +1528,8 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): tool_calls=[ { "id": "xyz", - "name": "edit_current_insight", - "args": {"query_description": "Foobar", "query_kind": "trends"}, + "name": "create_and_query_insight", + "args": {"query_description": "Foobar"}, } ], ) @@ -1622,20 +1565,13 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): output, assistant = await self._run_assistant_graph( test_graph=AssistantGraph(self.team, self.user) .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) - .add_root( - { - "root": AssistantNodeName.ROOT, - "insights": AssistantNodeName.INSIGHTS_SUBGRAPH, - "end": AssistantNodeName.END, - } - ) - .add_insights() + .add_root() .compile(), conversation=self.conversation, is_new_conversation=True, message="Hello", mode=AssistantMode.ASSISTANT, - contextual_tools={"edit_current_insight": {"current_query": "query"}}, + contextual_tools={"create_and_query_insight": {"current_query": "query"}}, ) expected_output = [ @@ -1645,23 +1581,30 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): "message", AssistantMessage( content="", + id="56076433-5d90-4248-9a46-df3fda42bd0a", tool_calls=[ AssistantToolCall( + args={"query_description": "Foobar"}, id="xyz", - name="edit_current_insight", - args={"query_description": "Foobar", "query_kind": "trends"}, + name="create_and_query_insight", + type="tool_call", ) ], ), ), + ( + "update", + AssistantUpdateEvent(content="Picking relevant events and properties", tool_call_id="xyz", id=""), + ), + ("update", AssistantUpdateEvent(content="Creating trends query", tool_call_id="xyz", id="")), ("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")), ( "message", - 
{ - "tool_call_id": "xyz", - "type": "tool", - "ui_payload": {"edit_current_insight": query.model_dump(exclude_none=True)}, - }, # Don't check content as it's implementation detail + AssistantToolCallMessage( + content="The results indicate a great future for you.", + tool_call_id="xyz", + ui_payload={"create_and_query_insight": query.model_dump(exclude_none=True)}, + ), ), ("message", AssistantMessage(content="Everything is fine")), ] @@ -1671,7 +1614,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): state = AssistantState.model_validate(snapshot.values) expected_state_messages = [ ContextMessage( - content="\nContextual tools that are available to you on this page are:\n\nThe user is currently editing an insight (aka query). Here is that insight's current definition, which can be edited using the `edit_current_insight` tool:\n\n```json\nquery\n```\n\nIMPORTANT: DO NOT REMOVE ANY FIELDS FROM THE CURRENT INSIGHT DEFINITION. DO NOT CHANGE ANY OTHER FIELDS THAN THE ONES THE USER ASKED FOR. KEEP THE REST AS IS.\n\nIMPORTANT: this context may or may not be relevant to your tasks. You should not respond to this context unless it is highly relevant to your task.\n" + content="\nContextual tools that are available to you on this page are:\n\nThe user is currently editing an insight (aka query). Here is that insight's current definition, which can be edited using the `create_and_query_insight` tool:\n\n```json\nquery\n```\n\n\nDo not remove any fields from the current insight definition. Do not change any other fields than the ones the user asked for. Keep the rest as is.\n\n\nIMPORTANT: this context may or may not be relevant to your tasks. You should not respond to this context unless it is highly relevant to your task.\n" ), HumanMessage(content="Hello"), AssistantMessage( @@ -1679,8 +1622,8 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): tool_calls=[ AssistantToolCall( id="xyz", - name="edit_current_insight", - args={"query_description": "Foobar", "query_kind": "trends"}, + name="create_and_query_insight", + args={"query_description": "Foobar"}, ) ], ), @@ -1688,7 +1631,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): AssistantToolCallMessage( content="The results indicate a great future for you.", tool_call_id="xyz", - ui_payload={"edit_current_insight": query.model_dump(exclude_none=True)}, + ui_payload={"create_and_query_insight": query.model_dump(exclude_none=True)}, ), AssistantMessage(content="Everything is fine"), ] @@ -1705,7 +1648,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): graph = ( AssistantGraph(self.team, self.user) .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) - .add_root({"root": AssistantNodeName.ROOT, "end": AssistantNodeName.END}) + .add_root() .compile() ) @@ -1757,16 +1700,14 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): # Tests for ainvoke method async def test_ainvoke_basic_functionality(self): """Test ainvoke returns all messages at once without streaming.""" - from ee.hogai.graph.base import BaseAssistantNode - class TestNode(BaseAssistantNode): + class TestNode(AssistantNode): @property def node_name(self): return AssistantNodeName.ROOT async def arun(self, state, config): - self.dispatcher.message(AssistantMessage(content="Response")) - return PartialAssistantState() + return PartialAssistantState(messages=[AssistantMessage(content="Response", id=str(uuid4()))]) test_node = TestNode(self.team, self.user) graph = ( @@ -1798,28 +1739,6 @@ class 
TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): self.assertIsInstance(item[1], AssistantMessage) self.assertEqual(cast(AssistantMessage, item[1]).content, "Response") - async def test_process_value_update_returns_none(self): - """Test that _aprocess_value_update returns None for basic state updates (ACKs are now handled by reducer).""" - - assistant = Assistant.create( - self.team, self.conversation, new_message=HumanMessage(content="Hello"), user=self.user - ) - - # Create a value update tuple that doesn't match special nodes - update = cast( - GraphValueUpdateTuple, - ( - AssistantNodeName.ROOT, - {"root": {"messages": []}}, # Empty update that doesn't match visualization or verbose nodes - ), - ) - - # Process the update - result = await assistant._aprocess_value_update(update) - - # Should return None (ACK events are now generated by the reducer) - self.assertIsNone(result) - def test_billing_context_in_config(self): billing_context = MaxBillingContext( has_active_subscription=True, @@ -1855,7 +1774,225 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): ) config = assistant._get_config() - self.assertEqual(config.get("configurable", {}).get("billing_context"), billing_context) + self.assertEqual(config["configurable"]["billing_context"], billing_context) + + @patch("ee.hogai.context.context.AssistantContextManager.check_user_has_billing_access", return_value=True) + @patch("ee.hogai.graph.root.nodes.RootNode._get_model") + async def test_billing_tool_execution(self, root_mock, access_mock): + """Test that the billing tool can be called and returns formatted billing information.""" + billing_context = MaxBillingContext( + subscription_level=MaxBillingContextSubscriptionLevel.PAID, + billing_plan="startup", + has_active_subscription=True, + is_deactivated=False, + products=[ + MaxProductInfo( + name="Product Analytics", + type="analytics", + description="Track user behavior", + current_usage=50000, + usage_limit=100000, + has_exceeded_limit=False, + is_used=True, + percentage_usage=0.5, + addons=[], + ) + ], + settings=MaxBillingContextSettings(autocapture_on=True, active_destinations=2), + ) + + # Mock the root node to call the read_data tool with billing_info kind + tool_call_id = str(uuid4()) + + def root_side_effect(msgs: list[BaseMessage]): + # Check if we've already received a tool result + last_message = msgs[-1] + if ( + isinstance(last_message.content, list) + and isinstance(last_message.content[-1], dict) + and last_message.content[-1]["type"] == "tool_result" + ): + # After tool execution, respond with final message + return messages.AIMessage(content="Your billing information shows you're on a startup plan.") + + # First call - request the billing tool + return messages.AIMessage( + content="", + tool_calls=[ + { + "id": tool_call_id, + "name": "read_data", + "args": {"kind": "billing_info"}, + } + ], + ) + + root_mock.return_value = FakeAnthropicRunnableLambdaWithTokenCounter(root_side_effect) + + # Create a minimal test graph + test_graph = ( + AssistantGraph(self.team, self.user) + .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) + .add_root() + .compile() + ) + + # Run the assistant with billing context + assistant = Assistant.create( + team=self.team, + conversation=self.conversation, + user=self.user, + new_message=HumanMessage(content="What's my current billing status?"), + billing_context=billing_context, + ) + assistant._graph = test_graph + + output: list[AssistantOutput] = [] + async for event in assistant.astream(): + 
output.append(event)
+
+        # Verify we received messages
+        self.assertGreater(len(output), 0)
+
+        # Find the assistant's final response
+        assistant_messages = [msg for event_type, msg in output if isinstance(msg, AssistantMessage)]
+        self.assertGreater(len(assistant_messages), 0)
+
+        # Verify the assistant received and used the billing information
+        # The mock returns "Your billing information shows you're on a startup plan."
+        final_message = cast(AssistantMessage, assistant_messages[-1])
+        self.assertIn("billing", final_message.content.lower())
+        self.assertIn("startup", final_message.content.lower())
+
+    async def test_messages_without_id_are_yielded(self):
+        """Test that messages without ID are always yielded."""
+
+        class MessageWithoutIdNode(AssistantNode):
+            call_count = 0
+
+            async def arun(self, state, config):
+                self.call_count += 1
+                # Return message without ID - should always be yielded
+                return PartialAssistantState(
+                    messages=[
+                        AssistantMessage(content=f"Message {self.call_count} without ID"),
+                        AssistantMessage(content=f"Message {self.call_count} without ID"),
+                    ]
+                )
+
+        node = MessageWithoutIdNode(self.team, self.user)
+        graph = (
+            AssistantGraph(self.team, self.user)
+            .add_node(AssistantNodeName.ROOT, node)
+            .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
+            .add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
+            .compile()
+        )
+
+        # Run the assistant multiple times
+        output1, _ = await self._run_assistant_graph(graph, message="First run", conversation=self.conversation)
+        output2, _ = await self._run_assistant_graph(graph, message="Second run", conversation=self.conversation)
+
+        # Both runs should yield their messages (one human message + two assistant messages each)
+        self.assertEqual(len(output1), 3)  # Human message + AI message + AI message
+        self.assertEqual(len(output2), 3)  # Human message + AI message + AI message
+
+    async def test_messages_with_id_are_deduplicated(self):
+        """Test that messages with ID are deduplicated during streaming."""
+        message_id = str(uuid4())
+
+        class DuplicateMessageNode(AssistantNode):
+            call_count = 0
+
+            async def arun(self, state, config):
+                self.call_count += 1
+                # Always return the same message with same ID
+                return PartialAssistantState(
+                    messages=[
+                        AssistantMessage(id=message_id, content=f"Call {self.call_count}"),
+                        AssistantMessage(id=message_id, content=f"Call {self.call_count}"),
+                    ]
+                )
+
+        node = DuplicateMessageNode(self.team, self.user)
+        graph = (
+            AssistantGraph(self.team, self.user)
+            .add_node(AssistantNodeName.ROOT, node)
+            .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
+            .add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
+            .compile()
+        )
+
+        # Create assistant and manually test the streaming behavior
+        assistant = Assistant.create(
+            self.team,
+            self.conversation,
+            new_message=HumanMessage(content="Test message"),
+            user=self.user,
+            is_new_conversation=False,
+        )
+        assistant._graph = graph
+
+        # Collect all streamed messages
+        streamed_messages = []
+        async for event_type, message in assistant.astream(stream_first_message=False):
+            if event_type == AssistantEventType.MESSAGE:
+                streamed_messages.append(message)
+
+        # Should only get one message despite the node being called multiple times
+        assistant_messages = [
+            msg for msg in streamed_messages if isinstance(msg, AssistantMessage) and msg.id == message_id
+        ]
+        self.assertEqual(len(assistant_messages), 1, "Message with same ID should only be yielded once")
+
+    async def test_replaced_messages_are_not_double_streamed(self):
+        """Test
that existing messages are not streamed again""" + # Create messages with IDs that should be tracked + message_id_1 = str(uuid4()) + message_id_2 = str(uuid4()) + call_count = [0] + + # Create a simple graph that returns messages with IDs + class TestNode(AssistantNode): + async def arun(self, state, config): + result = None + if call_count[0] == 0: + result = PartialAssistantState( + messages=[ + AssistantMessage(id=message_id_1, content="Message 1"), + ] + ) + else: + result = PartialAssistantState( + messages=ReplaceMessages( + [ + AssistantMessage(id=message_id_1, content="Message 1"), + AssistantMessage(id=message_id_2, content="Message 2"), + ] + ) + ) + call_count[0] += 1 + return result + + graph = ( + AssistantGraph(self.team, self.user) + .add_node(AssistantNodeName.ROOT, TestNode(self.team, self.user)) + .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) + .add_edge(AssistantNodeName.ROOT, AssistantNodeName.END) + .compile() + ) + + output, _ = await self._run_assistant_graph(graph, message="First run", conversation=self.conversation) + # Filter for assistant messages only, as the test is about tracking assistant message IDs + assistant_output = [(event_type, msg) for event_type, msg in output if isinstance(msg, AssistantMessage)] + self.assertEqual(len(assistant_output), 1) + self.assertEqual(cast(AssistantMessage, assistant_output[0][1]).id, message_id_1) + + output, _ = await self._run_assistant_graph(graph, message="Second run", conversation=self.conversation) + # Filter for assistant messages only, as the test is about tracking assistant message IDs + assistant_output = [(event_type, msg) for event_type, msg in output if isinstance(msg, AssistantMessage)] + self.assertEqual(len(assistant_output), 1) + self.assertEqual(cast(AssistantMessage, assistant_output[0][1]).id, message_id_2) @patch( "ee.hogai.graph.conversation_summarizer.nodes.AnthropicConversationSummarizer.summarize", @@ -1887,22 +2024,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): mock_tool.return_value = ("Event list" * 128000, None) mock_should_compact.side_effect = cycle([False, True]) # Also changed this - graph = ( - AssistantGraph(self.team, self.user) - .add_root( - path_map={ - "insights": AssistantNodeName.END, - "search_documentation": AssistantNodeName.END, - "root": AssistantNodeName.ROOT, - "end": AssistantNodeName.END, - "insights_search": AssistantNodeName.END, - "session_summarization": AssistantNodeName.END, - "create_dashboard": AssistantNodeName.END, - } - ) - .add_memory_onboarding() - .compile() - ) + graph = AssistantGraph(self.team, self.user).add_root().add_memory_onboarding().compile() expected_output = [ ("message", HumanMessage(content="First")), @@ -1928,3 +2050,86 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): self.assertEqual(state.start_id, new_human_message.id) # should be equal to the summary message, minus reasoning message self.assertEqual(state.root_conversation_start_id, state.messages[3].id) + + @patch("ee.hogai.graph.root.tools.search.SearchTool._arun_impl", return_value=("Docs doubt it", None)) + @patch( + "ee.hogai.graph.root.tools.read_taxonomy.ReadTaxonomyTool._run_impl", + return_value=("Hedgehogs have not talked yet", None), + ) + @patch("ee.hogai.graph.root.nodes.RootNode._get_model") + async def test_root_node_can_execute_multiple_tool_calls(self, root_mock, search_mock, read_taxonomy_mock): + """Test that the root node can execute multiple tool calls in parallel.""" + tool_call_id1, tool_call_id2 = [str(uuid4()), 
str(uuid4())] + + def root_side_effect(msgs: list[BaseMessage]): + # Check if we've already received a tool result + last_message = msgs[-1] + if ( + isinstance(last_message.content, list) + and isinstance(last_message.content[-1], dict) + and last_message.content[-1]["type"] == "tool_result" + ): + # After tool execution, respond with final message + return messages.AIMessage(content="No") + + return messages.AIMessage( + content="Not sure. Let me check.", + tool_calls=[ + { + "id": tool_call_id1, + "name": "search", + "args": {"kind": "docs", "query": "Do hedgehogs speak?"}, + }, + { + "id": tool_call_id2, + "name": "read_taxonomy", + "args": {"query": {"kind": "events"}}, + }, + ], + ) + + root_mock.return_value = FakeAnthropicRunnableLambdaWithTokenCounter(root_side_effect) + + # Create a minimal test graph + graph = ( + AssistantGraph(self.team, self.user) + .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT) + .add_root() + .compile() + ) + + expected_output = [ + (AssistantEventType.MESSAGE, HumanMessage(content="Do hedgehogs speak?")), + ( + AssistantEventType.MESSAGE, + AssistantMessage( + content="Not sure. Let me check.", + tool_calls=[ + { + "id": tool_call_id1, + "name": "search", + "args": {"kind": "docs", "query": "Do hedgehogs speak?"}, + }, + { + "id": tool_call_id2, + "name": "read_taxonomy", + "args": {"query": {"kind": "events"}}, + }, + ], + ), + ), + ( + AssistantEventType.MESSAGE, + AssistantToolCallMessage(content="Docs doubt it", tool_call_id=tool_call_id1), + ), + ( + AssistantEventType.MESSAGE, + AssistantToolCallMessage(content="Hedgehogs have not talked yet", tool_call_id=tool_call_id2), + ), + (AssistantEventType.MESSAGE, AssistantMessage(content="No")), + ] + output, _ = await self._run_assistant_graph( + graph, message="Do hedgehogs speak?", conversation=self.conversation + ) + + self.assertConversationEqual(output, expected_output) diff --git a/ee/hogai/tool.py b/ee/hogai/tool.py index 704c569ff2..3213639291 100644 --- a/ee/hogai/tool.py +++ b/ee/hogai/tool.py @@ -1,6 +1,7 @@ import json import pkgutil import importlib +from abc import ABC, abstractmethod from collections.abc import Sequence from typing import Any, Literal, Self @@ -16,9 +17,9 @@ from posthog.models import Team, User import products from ee.hogai.context.context import AssistantContextManager -from ee.hogai.graph.mixins import AssistantContextMixin -from ee.hogai.utils.types import AssistantState -from ee.hogai.utils.types.base import AssistantMessageUnion +from ee.hogai.graph.base.context import get_node_path, set_node_path +from ee.hogai.graph.mixins import AssistantContextMixin, AssistantDispatcherMixin +from ee.hogai.utils.types.base import AssistantMessageUnion, AssistantState, NodePath CONTEXTUAL_TOOL_NAME_TO_TOOL: dict[AssistantTool, type["MaxTool"]] = {} @@ -51,7 +52,7 @@ class ToolMessagesArtifact(BaseModel): messages: Sequence[AssistantMessageUnion] -class MaxTool(AssistantContextMixin, BaseTool): +class MaxTool(AssistantContextMixin, AssistantDispatcherMixin, BaseTool): # LangChain's default is just "content", but we always want to return the tool call artifact too # - it becomes the `ui_payload` response_format: Literal["content_and_artifact"] = "content_and_artifact" @@ -66,7 +67,7 @@ class MaxTool(AssistantContextMixin, BaseTool): _config: RunnableConfig _state: AssistantState _context_manager: AssistantContextManager - _tool_call_id: str + _node_path: tuple[NodePath, ...] 
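+    # The dispatch path locating this tool within the graph, e.g. (illustrative values):
+    # (NodePath(name="assistant"), NodePath(name="root"), NodePath(name="max_tool.search")).
+    # When not passed to __init__, it is derived from the get_node_path() contextvar plus this tool's node_name.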
# DEPRECATED: Use `_arun_impl` instead def _run_impl(self, *args, **kwargs) -> tuple[str, Any]: @@ -82,7 +83,7 @@ class MaxTool(AssistantContextMixin, BaseTool): *, team: Team, user: User, - tool_call_id: str, + node_path: tuple[NodePath, ...] | None = None, state: AssistantState | None = None, config: RunnableConfig | None = None, name: str | None = None, @@ -102,7 +103,10 @@ class MaxTool(AssistantContextMixin, BaseTool): super().__init__(**tool_kwargs, **kwargs) self._team = team self._user = user - self._tool_call_id = tool_call_id + if node_path is None: + self._node_path = (*(get_node_path() or ()), NodePath(name=self.node_name)) + else: + self._node_path = node_path self._state = state if state else AssistantState(messages=[]) self._config = config if config else RunnableConfig(configurable={}) self._context_manager = context_manager or AssistantContextManager(team, user, self._config) @@ -120,19 +124,35 @@ class MaxTool(AssistantContextMixin, BaseTool): CONTEXTUAL_TOOL_NAME_TO_TOOL[accepted_name] = cls def _run(self, *args, config: RunnableConfig, **kwargs): + """LangChain default runner.""" try: - return self._run_impl(*args, **kwargs) + return self._run_with_context(*args, **kwargs) except NotImplementedError: pass - return async_to_sync(self._arun_impl)(*args, **kwargs) + return async_to_sync(self._arun_with_context)(*args, **kwargs) async def _arun(self, *args, config: RunnableConfig, **kwargs): + """LangChain default runner.""" try: - return await self._arun_impl(*args, **kwargs) + return await self._arun_with_context(*args, **kwargs) except NotImplementedError: pass return await super()._arun(*args, config=config, **kwargs) + def _run_with_context(self, *args, **kwargs): + """Sets the context for the tool.""" + with set_node_path(self.node_path): + return self._run_impl(*args, **kwargs) + + async def _arun_with_context(self, *args, **kwargs): + """Sets the context for the tool.""" + with set_node_path(self.node_path): + return await self._arun_impl(*args, **kwargs) + + @property + def node_name(self) -> str: + return f"max_tool.{self.get_name()}" + @property def context(self) -> dict: return self._context_manager.get_contextual_tools().get(self.get_name(), {}) @@ -149,7 +169,7 @@ class MaxTool(AssistantContextMixin, BaseTool): *, team: Team, user: User, - tool_call_id: str, + node_path: tuple[NodePath, ...] | None = None, state: AssistantState | None = None, config: RunnableConfig | None = None, context_manager: AssistantContextManager | None = None, @@ -160,5 +180,31 @@ class MaxTool(AssistantContextMixin, BaseTool): Override this factory to dynamically modify the tool name, description, args schema, etc. 
""" return cls( - team=team, user=user, tool_call_id=tool_call_id, state=state, config=config, context_manager=context_manager + team=team, user=user, node_path=node_path, state=state, config=config, context_manager=context_manager ) + + +class MaxSubtool(AssistantDispatcherMixin, ABC): + _config: RunnableConfig + + def __init__( + self, + *, + team: Team, + user: User, + state: AssistantState, + config: RunnableConfig, + context_manager: AssistantContextManager, + ): + self._team = team + self._user = user + self._state = state + self._context_manager = context_manager + + @abstractmethod + async def execute(self, *args, **kwargs) -> Any: + pass + + @property + def node_name(self) -> str: + return f"max_subtool.{self.__class__.__name__}" diff --git a/ee/hogai/utils/dispatcher.py b/ee/hogai/utils/dispatcher.py index 0b7a3b0e75..7cfa1db995 100644 --- a/ee/hogai/utils/dispatcher.py +++ b/ee/hogai/utils/dispatcher.py @@ -1,21 +1,19 @@ from collections.abc import Callable -from typing import TYPE_CHECKING, Any +from typing import Any +from langchain_core.runnables import RunnableConfig +from langgraph.config import get_stream_writer from langgraph.types import StreamWriter -from posthog.schema import AssistantMessage, AssistantToolCallMessage - from ee.hogai.utils.types.base import ( AssistantActionUnion, AssistantDispatcherEvent, AssistantMessageUnion, MessageAction, - NodeStartAction, + NodePath, + UpdateAction, ) -if TYPE_CHECKING: - from ee.hogai.utils.types.composed import MaxNodeName - class AssistantDispatcher: """ @@ -26,13 +24,14 @@ class AssistantDispatcher: The dispatcher does NOT update state - it just emits actions to the stream. """ - _parent_tool_call_id: str | None = None + _node_path: tuple[NodePath, ...] def __init__( self, writer: StreamWriter | Callable[[Any], None], - node_name: "MaxNodeName", - parent_tool_call_id: str | None = None, + node_path: tuple[NodePath, ...], + node_name: str, + node_run_id: str, ): """ Create a dispatcher for a specific node. @@ -40,9 +39,10 @@ class AssistantDispatcher: Args: node_name: The name of the node dispatching actions (for attribution) """ - self._node_name = node_name self._writer = writer - self._parent_tool_call_id = parent_tool_call_id + self._node_path = node_path + self._node_name = node_name + self._node_run_id = node_run_id def dispatch(self, action: AssistantActionUnion) -> None: """ @@ -56,7 +56,11 @@ class AssistantDispatcher: action: Action dict with "type" and "payload" keys """ try: - self._writer(AssistantDispatcherEvent(action=action, node_name=self._node_name)) + self._writer( + AssistantDispatcherEvent( + action=action, node_path=self._node_path, node_name=self._node_name, node_run_id=self._node_run_id + ) + ) except Exception as e: # Log error but don't crash node execution # The dispatcher should be resilient to writer failures @@ -69,34 +73,29 @@ class AssistantDispatcher: """ Dispatch a message to the stream. """ - if self._parent_tool_call_id: - # If the dispatcher is initialized with a parent tool call id, set the parent tool call id on the message - # This is to ensure that the message is associated with the correct tool call - # Don't set parent_tool_call_id on: - # 1. AssistantToolCallMessage with the same tool_call_id (to avoid self-reference) - # 2. 
-            # 2. AssistantMessage with tool_calls containing the same ID (to avoid cycles)
-            should_skip = False
-            if isinstance(message, AssistantToolCallMessage) and self._parent_tool_call_id == message.tool_call_id:
-                should_skip = True
-            elif isinstance(message, AssistantMessage) and message.tool_calls:
-                # Check if any tool call has the same ID as the parent
-                for tool_call in message.tool_calls:
-                    if tool_call.id == self._parent_tool_call_id:
-                        should_skip = True
-                        break
-
-            if not should_skip:
-                message.parent_tool_call_id = self._parent_tool_call_id
         self.dispatch(MessageAction(message=message))

-    def node_start(self) -> None:
-        """
-        Dispatch a node start action to the stream.
-        """
-        self.dispatch(NodeStartAction())
+    def update(self, content: str):
+        """Dispatch a transient update message to the stream that will be associated with a tool call in the UI."""
+        self.dispatch(UpdateAction(content=content))

-    def set_as_root(self) -> None:
-        """
-        Set the dispatcher as the root.
-        """
-        self._parent_tool_call_id = None
+
+
+def create_dispatcher_from_config(config: RunnableConfig, node_path: tuple[NodePath, ...]) -> AssistantDispatcher:
+    """Create a dispatcher from a RunnableConfig and a node path."""
+    # Set writer from LangGraph context
+    try:
+        writer = get_stream_writer()
+    except RuntimeError:
+        # Not in streaming context (e.g., testing)
+        # Use noop writer
+        def noop(*_args, **_kwargs):
+            pass
+
+        writer = noop
+
+    metadata = config.get("metadata") or {}
+    node_name: str = metadata.get("langgraph_node") or ""
+    # `langgraph_checkpoint_ns` contains the full nested path to the node run, so it identifies
+    # a streaming source more precisely than `langgraph_node` does.
+    node_run_id: str = metadata.get("langgraph_checkpoint_ns") or ""
+
+    return AssistantDispatcher(writer, node_path=node_path, node_run_id=node_run_id, node_name=node_name)
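For orientation, here is a minimal usage sketch of the factory above. The node path and strings are illustrative; only `create_dispatcher_from_config`, `update`, and `message` are the API introduced in this diff:

```python
from langchain_core.runnables import RunnableConfig

from ee.hogai.utils.dispatcher import create_dispatcher_from_config
from ee.hogai.utils.types.base import NodePath
from posthog.schema import AssistantMessage


async def arun(state, config: RunnableConfig) -> None:
    # Hypothetical node body. During a streamed run LangGraph fills
    # config["metadata"] with langgraph_node / langgraph_checkpoint_ns;
    # outside a run the factory falls back to a noop writer.
    node_path = (NodePath(name="assistant"), NodePath(name="root"))  # illustrative path
    dispatcher = create_dispatcher_from_config(config, node_path)
    dispatcher.update("Crunching numbers...")  # transient update; surfaced in the UI when a parent tool call is on the path
    dispatcher.message(AssistantMessage(content="Done"))  # persisted, client-facing message
```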
diff --git a/ee/hogai/utils/helpers.py b/ee/hogai/utils/helpers.py
index 98c417b7c9..609217f800 100644
--- a/ee/hogai/utils/helpers.py
+++ b/ee/hogai/utils/helpers.py
@@ -38,8 +38,7 @@
 from posthog.hogql_queries.query_runner import ExecutionMode
 from posthog.models import Team
 from posthog.taxonomy.taxonomy import CORE_FILTER_DEFINITIONS_BY_GROUP

-from ee.hogai.utils.types import AssistantMessageUnion
-from ee.hogai.utils.types.base import AssistantDispatcherEvent
+from ee.hogai.utils.types.base import AssistantDispatcherEvent, AssistantMessageUnion


 def remove_line_breaks(line: str) -> str:
diff --git a/ee/hogai/utils/state.py b/ee/hogai/utils/state.py
index 2d94ae6cc0..4af17c1b94 100644
--- a/ee/hogai/utils/state.py
+++ b/ee/hogai/utils/state.py
@@ -5,7 +5,7 @@
 from structlog import get_logger

 from ee.hogai.graph.deep_research.types import DeepResearchNodeName, PartialDeepResearchState
 from ee.hogai.graph.taxonomy.types import TaxonomyAgentState, TaxonomyNodeName
-from ee.hogai.utils.types import PartialAssistantState
+from ee.hogai.utils.types.base import PartialAssistantState
 from ee.hogai.utils.types.composed import AssistantMaxGraphState, AssistantMaxPartialGraphState, MaxNodeName

 # A state update can have a partial state or a LangGraph's reserved dataclasses like Interrupt.
@@ -45,6 +45,7 @@ def validate_value_update(

 class LangGraphState(TypedDict):
     langgraph_node: MaxNodeName
+    langgraph_checkpoint_ns: str


 GraphMessageUpdateTuple = tuple[Literal["messages"], tuple[Union[AIMessageChunk, Any], LangGraphState]]
diff --git a/ee/hogai/utils/stream_processor.py b/ee/hogai/utils/stream_processor.py
index f995092dbe..dba0062bc9 100644
--- a/ee/hogai/utils/stream_processor.py
+++ b/ee/hogai/utils/stream_processor.py
@@ -1,4 +1,4 @@
-from typing import cast
+from typing import Generic, Protocol, TypeVar, cast, get_args

 import structlog
 from langchain_core.messages import AIMessageChunk
@@ -6,7 +6,6 @@
 from posthog.schema import (
     AssistantGenerationStatusEvent,
     AssistantGenerationStatusType,
-    AssistantMessage,
     AssistantToolCallMessage,
     AssistantUpdateEvent,
     FailureMessage,
@@ -16,21 +15,51 @@
 )

 from ee.hogai.utils.helpers import normalize_ai_message, should_output_assistant_message
-from ee.hogai.utils.state import merge_message_chunk
+from ee.hogai.utils.state import is_message_update, is_state_update, merge_message_chunk
 from ee.hogai.utils.types.base import (
     AssistantDispatcherEvent,
+    AssistantGraphName,
     AssistantMessageUnion,
     AssistantResultUnion,
+    BaseStateWithMessages,
+    LangGraphUpdateEvent,
     MessageAction,
     MessageChunkAction,
+    NodeEndAction,
+    NodePath,
     NodeStartAction,
+    UpdateAction,
 )
 from ee.hogai.utils.types.composed import MaxNodeName

 logger = structlog.get_logger(__name__)


-class AssistantStreamProcessor:
+class AssistantStreamProcessorProtocol(Protocol):
+    """Protocol defining the interface for assistant stream processors."""
+
+    _streamed_update_ids: set[str]
+    """Tracks the IDs of messages that have been streamed."""
+
+    def process(self, event: AssistantDispatcherEvent) -> list[AssistantResultUnion] | None:
+        """Process a dispatcher event and return a list of results or None."""
+        ...
+
+    def process_langgraph_update(self, event: LangGraphUpdateEvent) -> list[AssistantResultUnion] | None:
+        """Process a LangGraph update event and return a list of results or None."""
+        ...
+
+    def mark_id_as_streamed(self, message_id: str) -> None:
+        """Mark a message ID as streamed."""
+        self._streamed_update_ids.add(message_id)
+
+
+StateType = TypeVar("StateType", bound=BaseStateWithMessages)
+
+MESSAGE_TYPE_TUPLE = get_args(AssistantMessageUnion)
+
+
+class AssistantStreamProcessor(AssistantStreamProcessorProtocol, Generic[StateType]):
     """
     Reduces streamed actions to client-facing messages.
@@ -38,36 +67,34 @@
     handlers based on action type and message characteristics.
""" + _verbose_nodes: set[MaxNodeName] + """Nodes that emit messages.""" _streaming_nodes: set[MaxNodeName] """Nodes that produce streaming messages.""" - _visualization_nodes: dict[MaxNodeName, type] - """Nodes that produce visualization messages.""" - _tool_call_id_to_message: dict[str, AssistantMessage] - """Maps tool call IDs to their parent messages for message chain tracking.""" - _streamed_update_ids: set[str] - """Tracks the IDs of messages that have been streamed.""" - _chunks: AIMessageChunk + _chunks: dict[str, AIMessageChunk] """Tracks the current message chunk.""" + _state: StateType | None + """Tracks the current state.""" + _state_type: type[StateType] + """The type of the state.""" - def __init__( - self, - streaming_nodes: set[MaxNodeName], - visualization_nodes: dict[MaxNodeName, type], - ): + def __init__(self, verbose_nodes: set[MaxNodeName], streaming_nodes: set[MaxNodeName], state_type: type[StateType]): """ Initialize the stream processor with node configuration. Args: + verbose_nodes: Nodes that produce messages streaming_nodes: Nodes that produce streaming messages - visualization_nodes: Nodes that produce visualization messages """ + # If a node is streaming node, it should also be verbose. + self._verbose_nodes = verbose_nodes | streaming_nodes self._streaming_nodes = streaming_nodes - self._visualization_nodes = visualization_nodes - self._tool_call_id_to_message = {} self._streamed_update_ids = set() - self._chunks = AIMessageChunk(content="") + self._chunks = {} + self._state_type = state_type + self._state = None - def process(self, event: AssistantDispatcherEvent) -> AssistantResultUnion | None: + def process(self, event: AssistantDispatcherEvent) -> list[AssistantResultUnion] | None: """ Reduce streamed actions to client messages. @@ -75,86 +102,113 @@ class AssistantStreamProcessor: to specialized handlers based on action type and message characteristics. """ action = event.action - node_name = event.node_name - - if isinstance(action, MessageChunkAction): - return self._handle_message_stream(action.message, cast(MaxNodeName, node_name)) if isinstance(action, NodeStartAction): - return AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK) + self._chunks[event.node_run_id] = AIMessageChunk(content="") + return [AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK)] + + if isinstance(action, NodeEndAction): + if event.node_run_id in self._chunks: + del self._chunks[event.node_run_id] + return self._handle_node_end(event, action) + + if isinstance(action, MessageChunkAction) and (result := self._handle_message_stream(event, action.message)): + return [result] if isinstance(action, MessageAction): message = action.message + if result := self._handle_message(event, message): + return [result] - # Register any tool calls for later parent chain lookups - self._register_tool_calls(message) - result = self._handle_message(message, cast(MaxNodeName, node_name)) - return ( - result if result is not None else AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK) + if isinstance(action, UpdateAction) and (update_event := self._handle_update_message(event, action)): + return [update_event] + + return None + + def process_langgraph_update(self, event: LangGraphUpdateEvent) -> list[AssistantResultUnion] | None: + """ + Process a LangGraph update event. 
+ """ + if is_message_update(event.update): + # Convert the message chunk update to a dispatcher event to prepare for a bright future without LangGraph + maybe_message_chunk, state = event.update[1] + if not isinstance(maybe_message_chunk, AIMessageChunk): + return None + action = AssistantDispatcherEvent( + action=MessageChunkAction(message=maybe_message_chunk), + node_name=state["langgraph_node"], + node_run_id=state["langgraph_checkpoint_ns"], ) + return self.process(action) - def _find_parent_ids(self, message: AssistantMessage) -> tuple[str | None, str | None]: - """ - Walk up the message chain to find the root parent's message_id and tool_call_id. + if is_state_update(event.update): + new_state = self._state_type.model_validate(event.update[1]) + self._state = new_state - Returns (root_message_id, root_tool_call_id) for the root message in the chain. - Includes cycle detection and max depth protection. - """ - root_tool_call_id = message.parent_tool_call_id - if root_tool_call_id is None: - return message.id, None + return None - root_message_id = None - visited: set[str] = set() + def _handle_message( + self, action: AssistantDispatcherEvent, message: AssistantMessageUnion + ) -> AssistantResultUnion | None: + """Handle a message from a node.""" + node_name = cast(MaxNodeName, action.node_name) + produced_message: AssistantResultUnion | None = None - while root_tool_call_id is not None: - if root_tool_call_id in visited: - # Cycle detected, we skip this message - return None, None + # Output all messages from the top-level graph. + if not self._is_message_from_nested_node_or_graph(action.node_path or ()): + produced_message = self._handle_root_message(message, node_name) + # Other message types with parents (viz, notebook, failure, tool call) + else: + produced_message = self._handle_special_child_message(message, node_name) - visited.add(root_tool_call_id) - parent_message = self._tool_call_id_to_message.get(root_tool_call_id) - if parent_message is None: - # The parent message is not registered, we skip this message as it could come - # from a sub-nested graph invoked directly by a contextual tool. - return None, None + # Messages with existing IDs must be deduplicated. + # Messages WITHOUT IDs must be streamed because they're progressive. + if isinstance(produced_message, MESSAGE_TYPE_TUPLE) and produced_message.id is not None: + if produced_message.id in self._streamed_update_ids: + return None + self._streamed_update_ids.add(produced_message.id) - next_parent_tool_call_id = parent_message.parent_tool_call_id - root_message_id = parent_message.id - if next_parent_tool_call_id is None: - return root_message_id, root_tool_call_id - root_tool_call_id = next_parent_tool_call_id - raise ValueError("Should not reach here") + return produced_message - def _register_tool_calls(self, message: AssistantMessageUnion) -> None: - """Register any tool calls in the message for later lookup.""" - if isinstance(message, AssistantMessage) and message.tool_calls is not None: - for tool_call in message.tool_calls: - self._tool_call_id_to_message[tool_call.id] = message + def _is_message_from_nested_node_or_graph(self, node_path: tuple[NodePath, ...]) -> bool: + """Check if the message is from a nested node or graph.""" + if not node_path: + return False + # The first path is always the top-level graph. + # The second path is always the top-level node. + # If the path is longer than 2, it's a MaxTool or nested graphs. 
+        if len(node_path) > 2 and next((path for path in node_path[1:] if path.name in AssistantGraphName), None):
+            return True
+        return False

     def _handle_root_message(
         self, message: AssistantMessageUnion, node_name: MaxNodeName
     ) -> AssistantMessageUnion | None:
         """Handle messages with no parent (root messages)."""
-        if not should_output_assistant_message(message):
+        if node_name not in self._verbose_nodes or not should_output_assistant_message(message):
             return None
         return message

-    def _handle_assistant_message_with_parent(self, message: AssistantMessage) -> AssistantUpdateEvent | None:
+    def _handle_update_message(
+        self, event: AssistantDispatcherEvent, action: UpdateAction
+    ) -> AssistantUpdateEvent | None:
-        """Handle AssistantMessage that has a parent, creating an AssistantUpdateEvent."""
+        """Handle an UpdateAction from a node, emitting an AssistantUpdateEvent for the nearest parent tool call."""
-        parent_id, parent_tool_call_id = self._find_parent_ids(message)
-
-        if parent_tool_call_id is None or parent_id is None:
+        if not event.node_path or not action.content:
             return None

-        if message.content == "":
+        # Find the closest tool call id to the update.
+        parent_path = next((path for path in reversed(event.node_path) if path.tool_call_id), None)
+        # Updates from the top-level graph nodes are not supported.
+        if not parent_path:
             return None

-        return AssistantUpdateEvent(
-            id=parent_id,
-            tool_call_id=parent_tool_call_id,
-            content=message.content,
-        )
+        tool_call_id = parent_path.tool_call_id
+        message_id = parent_path.message_id
+
+        if not message_id or not tool_call_id:
+            return None
+
+        return AssistantUpdateEvent(id=message_id, tool_call_id=tool_call_id, content=action.content)

     def _handle_special_child_message(
         self, message: AssistantMessageUnion, node_name: MaxNodeName
@@ -164,14 +218,10 @@
         These messages are returned as-is regardless of where in the nesting hierarchy they are.
         """
-        # Return visualization messages only if from visualization nodes
-        if isinstance(message, VisualizationMessage | MultiVisualizationMessage):
-            if node_name in self._visualization_nodes:
-                return message
-            return None
-
         # These message types are always returned as-is
-        if isinstance(message, NotebookUpdateMessage | FailureMessage):
+        if isinstance(message, VisualizationMessage | MultiVisualizationMessage | NotebookUpdateMessage | FailureMessage):
             return message

         if isinstance(message, AssistantToolCallMessage):
@@ -181,37 +231,44 @@
         # Should not reach here
         raise ValueError(f"Unhandled special message type: {type(message).__name__}")

-    def _handle_message(self, message: AssistantMessageUnion, node_name: MaxNodeName) -> AssistantResultUnion | None:
-        # Messages with existing IDs must be deduplicated.
-        # Messages WITHOUT IDs must be streamed because they're progressive.
- if hasattr(message, "id") and message.id is not None: - if message.id in self._streamed_update_ids: - return None - self._streamed_update_ids.add(message.id) - - # Root messages (no parent) are filtered by VERBOSE_NODES - if message.parent_tool_call_id is None: - return self._handle_root_message(message, node_name) - - # AssistantMessage with parent creates AssistantUpdateEvent - if isinstance(message, AssistantMessage): - return self._handle_assistant_message_with_parent(message) - else: - # Other message types with parents (viz, notebook, failure, tool call) - return self._handle_special_child_message(message, node_name) - - def _handle_message_stream(self, message: AIMessageChunk, node_name: MaxNodeName) -> AssistantResultUnion | None: + def _handle_message_stream( + self, event: AssistantDispatcherEvent, message: AIMessageChunk + ) -> AssistantResultUnion | None: """ Process LLM chunks from "messages" stream mode. With dispatch pattern, complete messages are dispatched by nodes. This handles AIMessageChunk for ephemeral streaming (responsiveness). """ + node_name = cast(MaxNodeName, event.node_name) + run_id = event.node_run_id + if node_name not in self._streaming_nodes: return None + if run_id not in self._chunks: + self._chunks[run_id] = AIMessageChunk(content="") # Merge message chunks - self._chunks = merge_message_chunk(self._chunks, message) + self._chunks[run_id] = merge_message_chunk(self._chunks[run_id], message) # Stream ephemeral message (no ID = not persisted) - return normalize_ai_message(self._chunks) + return normalize_ai_message(self._chunks[run_id]) + + def _handle_node_end( + self, event: AssistantDispatcherEvent, action: NodeEndAction + ) -> list[AssistantResultUnion] | None: + """Handle the end of a node. Reset the streaming chunks.""" + if not isinstance(action.state, BaseStateWithMessages): + return None + results: list[AssistantResultUnion] = [] + for message in action.state.messages: + if new_event := self.process( + AssistantDispatcherEvent( + action=MessageAction(message=message), + node_name=event.node_name, + node_run_id=event.node_run_id, + node_path=event.node_path, + ) + ): + results.extend(new_event) + return results diff --git a/ee/hogai/utils/test/test_dispatcher.py b/ee/hogai/utils/test/test_dispatcher.py index 4152edbe08..b70088df60 100644 --- a/ee/hogai/utils/test/test_dispatcher.py +++ b/ee/hogai/utils/test/test_dispatcher.py @@ -1,165 +1,426 @@ +""" +Comprehensive tests for AssistantDispatcher. + +Tests the dispatcher logic that emits actions to LangGraph custom stream. 
+""" + from posthog.test.base import BaseTest -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock, patch from langchain_core.runnables import RunnableConfig from posthog.schema import AssistantMessage, AssistantToolCall -from ee.hogai.graph.base import BaseAssistantNode -from ee.hogai.utils.dispatcher import AssistantDispatcher -from ee.hogai.utils.types import AssistantState, PartialAssistantState -from ee.hogai.utils.types.base import AssistantDispatcherEvent, AssistantNodeName +from ee.hogai.utils.dispatcher import AssistantDispatcher, create_dispatcher_from_config +from ee.hogai.utils.types.base import ( + AssistantDispatcherEvent, + AssistantGraphName, + AssistantNodeName, + MessageAction, + NodePath, + UpdateAction, +) -class TestMessageDispatcher(BaseTest): +class TestAssistantDispatcher(BaseTest): + """Test the AssistantDispatcher in isolation.""" + def setUp(self): - self._writer = MagicMock() - self._writer.__aiter__ = AsyncMock(return_value=iter([])) - self._writer.write = AsyncMock() + super().setUp() + self.dispatched_events: list[AssistantDispatcherEvent] = [] - def test_create_dispatcher(self): - """Test creating a MessageDispatcher""" - dispatcher = AssistantDispatcher(self._writer, node_name=AssistantNodeName.ROOT) + def mock_writer(event): + self.dispatched_events.append(event) + + self.writer = mock_writer + self.node_path = ( + NodePath(name=AssistantGraphName.ASSISTANT), + NodePath(name=AssistantNodeName.ROOT), + ) + + def test_create_dispatcher_with_basic_params(self): + """Test creating a dispatcher with required parameters.""" + dispatcher = AssistantDispatcher( + writer=self.writer, + node_path=self.node_path, + node_name=AssistantNodeName.ROOT, + node_run_id="test_run_123", + ) self.assertEqual(dispatcher._node_name, AssistantNodeName.ROOT) - self.assertEqual(dispatcher._writer, self._writer) + self.assertEqual(dispatcher._node_run_id, "test_run_123") + self.assertEqual(dispatcher._node_path, self.node_path) + self.assertEqual(dispatcher._writer, self.writer) - async def test_dispatch_with_writer(self): - """Test dispatching with a writer""" - dispatched_actions = [] + def test_dispatch_message_action(self): + """Test dispatching a message via the message() method.""" + dispatcher = AssistantDispatcher( + writer=self.writer, + node_path=self.node_path, + node_name=AssistantNodeName.ROOT, + node_run_id="test_run_456", + ) - def mock_writer(data): - dispatched_actions.append(data) - - dispatcher = AssistantDispatcher(mock_writer, node_name=AssistantNodeName.ROOT) - - message = AssistantMessage(content="test message", parent_tool_call_id="tc1") + message = AssistantMessage(content="Test message") dispatcher.message(message) - self.assertEqual(len(dispatched_actions), 1) + self.assertEqual(len(self.dispatched_events), 1) + event = self.dispatched_events[0] - # Verify action structure - event = dispatched_actions[0] self.assertIsInstance(event, AssistantDispatcherEvent) - self.assertEqual(event.action.type, "MESSAGE") + self.assertIsInstance(event.action, MessageAction) + assert isinstance(event.action, MessageAction) self.assertEqual(event.action.message, message) self.assertEqual(event.node_name, AssistantNodeName.ROOT) + self.assertEqual(event.node_run_id, "test_run_456") + self.assertEqual(event.node_path, self.node_path) + + def test_dispatch_update_action(self): + """Test dispatching an update via the update() method.""" + dispatcher = AssistantDispatcher( + writer=self.writer, + node_path=self.node_path, + 
node_name=AssistantNodeName.TRENDS_GENERATOR, + node_run_id="test_run_789", + ) + + dispatcher.update("Processing query...") + + self.assertEqual(len(self.dispatched_events), 1) + event = self.dispatched_events[0] + + self.assertIsInstance(event, AssistantDispatcherEvent) + self.assertIsInstance(event.action, UpdateAction) + assert isinstance(event.action, UpdateAction) + self.assertEqual(event.action.content, "Processing query...") + self.assertEqual(event.node_name, AssistantNodeName.TRENDS_GENERATOR) + self.assertEqual(event.node_run_id, "test_run_789") + + def test_dispatch_multiple_messages(self): + """Test dispatching multiple messages in sequence.""" + dispatcher = AssistantDispatcher( + writer=self.writer, + node_path=self.node_path, + node_name=AssistantNodeName.ROOT, + node_run_id="test_run_multi", + ) + + message1 = AssistantMessage(content="First message") + message2 = AssistantMessage(content="Second message") + message3 = AssistantMessage(content="Third message") + + dispatcher.message(message1) + dispatcher.message(message2) + dispatcher.message(message3) + + self.assertEqual(len(self.dispatched_events), 3) + + contents = [] + for event in self.dispatched_events: + if isinstance(event.action, MessageAction): + msg = event.action.message + if isinstance(msg, AssistantMessage): + contents.append(msg.content) + self.assertEqual(contents, ["First message", "Second message", "Third message"]) + + def test_dispatch_message_with_tool_calls(self): + """Test dispatching a message with tool calls preserves all data.""" + dispatcher = AssistantDispatcher( + writer=self.writer, + node_path=self.node_path, + node_name=AssistantNodeName.ROOT, + node_run_id="test_run_tools", + ) + + tool_call = AssistantToolCall(id="tool_123", name="search", args={"query": "test query"}) + message = AssistantMessage(content="Running search...", tool_calls=[tool_call]) + + dispatcher.message(message) + + self.assertEqual(len(self.dispatched_events), 1) + event = self.dispatched_events[0] + + self.assertIsInstance(event.action, MessageAction) + assert isinstance(event.action, MessageAction) + dispatched_message = event.action.message + assert isinstance(dispatched_message, AssistantMessage) + self.assertEqual(dispatched_message.content, "Running search...") + self.assertIsNotNone(dispatched_message.tool_calls) + assert dispatched_message.tool_calls is not None + self.assertEqual(len(dispatched_message.tool_calls), 1) + self.assertEqual(dispatched_message.tool_calls[0].id, "tool_123") + self.assertEqual(dispatched_message.tool_calls[0].name, "search") + self.assertEqual(dispatched_message.tool_calls[0].args, {"query": "test query"}) + + def test_dispatch_with_nested_node_path(self): + """Test that nested node paths are preserved in dispatched events.""" + nested_path = ( + NodePath(name=AssistantGraphName.ASSISTANT), + NodePath(name=AssistantNodeName.ROOT, message_id="msg_123", tool_call_id="tc_123"), + NodePath(name=AssistantGraphName.INSIGHTS), + NodePath(name=AssistantNodeName.TRENDS_GENERATOR), + ) + + dispatcher = AssistantDispatcher( + writer=self.writer, node_path=nested_path, node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id="run_1" + ) + + message = AssistantMessage(content="Nested message") + dispatcher.message(message) + + self.assertEqual(len(self.dispatched_events), 1) + event = self.dispatched_events[0] + + self.assertEqual(event.node_path, nested_path) + assert event.node_path is not None + self.assertEqual(len(event.node_path), 4) + self.assertEqual(event.node_path[1].message_id, 
"msg_123") + self.assertEqual(event.node_path[1].tool_call_id, "tc_123") + + def test_dispatch_error_handling_continues_execution(self): + """Test that dispatch errors are caught and logged but don't crash.""" + + def failing_writer(event): + raise RuntimeError("Writer failed!") + + dispatcher = AssistantDispatcher( + writer=failing_writer, + node_path=self.node_path, + node_name=AssistantNodeName.ROOT, + node_run_id="test_run_error", + ) + + message = AssistantMessage(content="This should not crash") + + # Should not raise exception + with patch("logging.getLogger") as mock_get_logger: + mock_logger = MagicMock() + mock_get_logger.return_value = mock_logger + + dispatcher.message(message) + + # Verify error was logged + mock_logger.error.assert_called_once() + args, kwargs = mock_logger.error.call_args + self.assertIn("Failed to dispatch action", args[0]) + + def test_dispatch_mixed_actions(self): + """Test dispatching both messages and updates in sequence.""" + dispatcher = AssistantDispatcher( + writer=self.writer, + node_path=self.node_path, + node_name=AssistantNodeName.TRENDS_GENERATOR, + node_run_id="test_run_mixed", + ) + + dispatcher.update("Starting analysis...") + dispatcher.message(AssistantMessage(content="Found 3 insights")) + dispatcher.update("Finalizing results...") + + self.assertEqual(len(self.dispatched_events), 3) + + self.assertIsInstance(self.dispatched_events[0].action, UpdateAction) + assert isinstance(self.dispatched_events[0].action, UpdateAction) + self.assertEqual(self.dispatched_events[0].action.content, "Starting analysis...") + + self.assertIsInstance(self.dispatched_events[1].action, MessageAction) + assert isinstance(self.dispatched_events[1].action, MessageAction) + assert isinstance(self.dispatched_events[1].action.message, AssistantMessage) + self.assertEqual(self.dispatched_events[1].action.message.content, "Found 3 insights") + + self.assertIsInstance(self.dispatched_events[2].action, UpdateAction) + assert isinstance(self.dispatched_events[2].action, UpdateAction) + self.assertEqual(self.dispatched_events[2].action.content, "Finalizing results...") + + def test_dispatch_preserves_message_id(self): + """Test that message IDs are preserved through dispatch.""" + dispatcher = AssistantDispatcher( + writer=self.writer, + node_path=self.node_path, + node_name=AssistantNodeName.ROOT, + node_run_id="test_run_id_preservation", + ) + + message = AssistantMessage(id="msg_xyz_789", content="Message with ID") + dispatcher.message(message) + + event = self.dispatched_events[0] + assert isinstance(event.action, MessageAction) + self.assertEqual(event.action.message.id, "msg_xyz_789") + + def test_dispatch_with_empty_node_path(self): + """Test dispatcher with an empty node path.""" + dispatcher = AssistantDispatcher( + writer=self.writer, node_path=(), node_name=AssistantNodeName.ROOT, node_run_id="test_run_empty_path" + ) + + message = AssistantMessage(content="Root level message") + dispatcher.message(message) + + self.assertEqual(len(self.dispatched_events), 1) + event = self.dispatched_events[0] + self.assertEqual(event.node_path, ()) -class MockNode(BaseAssistantNode[AssistantState, PartialAssistantState]): - """Mock node for testing""" +class TestCreateDispatcherFromConfig(BaseTest): + """Test the create_dispatcher_from_config helper function.""" - @property - def node_name(self): - return AssistantNodeName.ROOT + def test_create_dispatcher_from_config_with_stream_writer(self): + """Test creating dispatcher from config with LangGraph stream writer.""" + 
node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT)) - async def arun(self, state, config): - # Use dispatch to add a message - self.dispatcher.message(AssistantMessage(content="Test message from node")) - return PartialAssistantState() + config = RunnableConfig( + metadata={"langgraph_node": AssistantNodeName.TRENDS_GENERATOR, "langgraph_checkpoint_ns": "checkpoint_abc"} + ) + + mock_writer = MagicMock() + + with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=mock_writer): + dispatcher = create_dispatcher_from_config(config, node_path) + + self.assertEqual(dispatcher._node_name, AssistantNodeName.TRENDS_GENERATOR) + self.assertEqual(dispatcher._node_run_id, "checkpoint_abc") + self.assertEqual(dispatcher._node_path, node_path) + self.assertEqual(dispatcher._writer, mock_writer) + + def test_create_dispatcher_from_config_without_stream_writer(self): + """Test creating dispatcher when not in streaming context (e.g., testing).""" + node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT)) + + config = RunnableConfig( + metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "checkpoint_xyz"} + ) + + # Simulate RuntimeError when get_stream_writer is called outside streaming context + with patch("ee.hogai.utils.dispatcher.get_stream_writer", side_effect=RuntimeError("Not in streaming context")): + dispatcher = create_dispatcher_from_config(config, node_path) + + # Should create a noop writer + self.assertEqual(dispatcher._node_name, AssistantNodeName.ROOT) + self.assertEqual(dispatcher._node_run_id, "checkpoint_xyz") + self.assertEqual(dispatcher._node_path, node_path) + + # Verify the noop writer doesn't raise exceptions + message = AssistantMessage(content="Test") + dispatcher.message(message) # Should not crash + + def test_create_dispatcher_with_missing_metadata(self): + """Test creating dispatcher when metadata fields are missing.""" + node_path = (NodePath(name=AssistantGraphName.ASSISTANT),) + + config = RunnableConfig(metadata={}) + + with patch("ee.hogai.utils.dispatcher.get_stream_writer", side_effect=RuntimeError("Not in streaming context")): + dispatcher = create_dispatcher_from_config(config, node_path) + + # Should use empty strings as defaults + self.assertEqual(dispatcher._node_name, "") + self.assertEqual(dispatcher._node_run_id, "") + self.assertEqual(dispatcher._node_path, node_path) + + def test_create_dispatcher_preserves_node_path(self): + """Test that node path is correctly passed through to the dispatcher.""" + nested_path = ( + NodePath(name=AssistantGraphName.ASSISTANT), + NodePath(name=AssistantNodeName.ROOT, message_id="msg_1", tool_call_id="tc_1"), + NodePath(name=AssistantGraphName.INSIGHTS), + ) + + config = RunnableConfig( + metadata={"langgraph_node": AssistantNodeName.TRENDS_GENERATOR, "langgraph_checkpoint_ns": "cp_1"} + ) + + with patch("ee.hogai.utils.dispatcher.get_stream_writer", side_effect=RuntimeError("Not in streaming context")): + dispatcher = create_dispatcher_from_config(config, nested_path) + + self.assertEqual(dispatcher._node_path, nested_path) + self.assertEqual(len(dispatcher._node_path), 3) class TestDispatcherIntegration(BaseTest): - async def test_node_dispatch_flow(self): - """Test that a node can dispatch messages""" - # Use sync test helpers since BaseTest provides them - team, user = self.team, self.user + """Integration tests for dispatcher usage patterns.""" - node = MockNode(team=team, user=user) + def 
test_dispatcher_in_node_context(self): + """Test typical usage pattern within a node.""" + dispatched_events = [] - # Track dispatched actions - dispatched_actions = [] + def mock_writer(event): + dispatched_events.append(event) - def mock_writer(data): - dispatched_actions.append(data) - - # Set up node's dispatcher - node._dispatcher = AssistantDispatcher(mock_writer, node_name=AssistantNodeName.ROOT) - - # Run the node - state = AssistantState(messages=[]) - config = RunnableConfig(configurable={}) - - await node.arun(state, config) - - # Verify action was dispatched - self.assertEqual(len(dispatched_actions), 1) - - event = dispatched_actions[0] - self.assertIsInstance(event, AssistantDispatcherEvent) - self.assertEqual(event.action.type, "MESSAGE") - self.assertEqual(event.action.message.content, "Test message from node") - self.assertEqual(event.action.message.parent_tool_call_id, None) - self.assertEqual(event.node_name, AssistantNodeName.ROOT) - - async def test_action_preservation_through_stream(self): - """Test that action data is preserved through the stream""" - - captured_updates = [] - - def mock_writer(data): - captured_updates.append(data) - - dispatcher = AssistantDispatcher(mock_writer, node_name=AssistantNodeName.ROOT) - - # Create complex message with metadata - message = AssistantMessage( - content="Complex message", - parent_tool_call_id="tc123", - tool_calls=[AssistantToolCall(id="tool1", name="search", args={"query": "test"})], + node_path = ( + NodePath(name=AssistantGraphName.ASSISTANT), + NodePath(name=AssistantNodeName.ROOT), ) - dispatcher.message(message) + dispatcher = AssistantDispatcher( + writer=mock_writer, node_path=node_path, node_name=AssistantNodeName.ROOT, node_run_id="integration_run_1" + ) - # Extract action - event = captured_updates[0] - self.assertIsInstance(event, AssistantDispatcherEvent) + # Simulate node execution pattern + dispatcher.update("Starting node execution...") - # Verify all message fields preserved - payload = event.action.message - self.assertEqual(payload.content, "Complex message") - self.assertEqual(payload.parent_tool_call_id, "tc123") - self.assertIsNotNone(payload.tool_calls) - self.assertEqual(len(payload.tool_calls), 1) - self.assertEqual(payload.tool_calls[0].name, "search") + tool_call = AssistantToolCall(id="tc_int_1", name="generate_insight", args={"type": "trends"}) + dispatcher.message(AssistantMessage(content="Generating insight", tool_calls=[tool_call])) - async def test_multiple_dispatches_from_node(self): - """Test that a node can dispatch multiple messages""" - # Use sync test helpers since BaseTest provides them - team, user = self.team, self.user + dispatcher.update("Processing data...") - class MultiDispatchNode(BaseAssistantNode[AssistantState, PartialAssistantState]): - @property - def node_name(self): - return AssistantNodeName.ROOT + dispatcher.message(AssistantMessage(content="Insight generated successfully")) - async def arun(self, state, config): - # Dispatch multiple messages - self.dispatcher.message(AssistantMessage(content="First message")) - self.dispatcher.message(AssistantMessage(content="Second message")) - self.dispatcher.message(AssistantMessage(content="Third message")) - return PartialAssistantState() + # Verify all events were dispatched + self.assertEqual(len(dispatched_events), 4) - node = MultiDispatchNode(team=team, user=user) + # Verify event types and order + self.assertIsInstance(dispatched_events[0].action, UpdateAction) + self.assertIsInstance(dispatched_events[1].action, 
MessageAction) + self.assertIsInstance(dispatched_events[2].action, UpdateAction) + self.assertIsInstance(dispatched_events[3].action, MessageAction) - dispatched_actions = [] + # Verify all events have consistent metadata + for event in dispatched_events: + self.assertEqual(event.node_name, AssistantNodeName.ROOT) + self.assertEqual(event.node_run_id, "integration_run_1") + self.assertEqual(event.node_path, node_path) - def mock_writer(data): - dispatched_actions.append(data) + def test_concurrent_dispatchers(self): + """Test multiple dispatchers can coexist without interference.""" + dispatched_events_1 = [] + dispatched_events_2 = [] - node._dispatcher = AssistantDispatcher(mock_writer, node_name=AssistantNodeName.ROOT) + def writer_1(event): + dispatched_events_1.append(event) - state = AssistantState(messages=[]) - config = RunnableConfig(configurable={}) + def writer_2(event): + dispatched_events_2.append(event) - await node.arun(state, config) + path_1 = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT)) + path_2 = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.TRENDS_GENERATOR)) - # Verify all three dispatches - self.assertEqual(len(dispatched_actions), 3) + dispatcher_1 = AssistantDispatcher( + writer=writer_1, node_path=path_1, node_name=AssistantNodeName.ROOT, node_run_id="run_1" + ) - contents = [] - for event in dispatched_actions: - self.assertIsInstance(event, AssistantDispatcherEvent) - contents.append(event.action.message.content) + dispatcher_2 = AssistantDispatcher( + writer=writer_2, node_path=path_2, node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id="run_2" + ) - self.assertEqual(contents, ["First message", "Second message", "Third message"]) + # Dispatch from both + dispatcher_1.message(AssistantMessage(content="From dispatcher 1")) + dispatcher_2.message(AssistantMessage(content="From dispatcher 2")) + dispatcher_1.update("Update from dispatcher 1") + dispatcher_2.update("Update from dispatcher 2") + + # Verify each dispatcher wrote to its own writer + self.assertEqual(len(dispatched_events_1), 2) + self.assertEqual(len(dispatched_events_2), 2) + + # Verify events went to correct writers + assert isinstance(dispatched_events_1[0].action, MessageAction) + assert isinstance(dispatched_events_1[0].action.message, AssistantMessage) + self.assertEqual(dispatched_events_1[0].action.message.content, "From dispatcher 1") + assert isinstance(dispatched_events_2[0].action, MessageAction) + assert isinstance(dispatched_events_2[0].action.message, AssistantMessage) + self.assertEqual(dispatched_events_2[0].action.message.content, "From dispatcher 2") + + # Verify node names are correct + self.assertEqual(dispatched_events_1[0].node_name, AssistantNodeName.ROOT) + self.assertEqual(dispatched_events_2[0].node_name, AssistantNodeName.TRENDS_GENERATOR) diff --git a/ee/hogai/utils/test/test_stream_processor.py b/ee/hogai/utils/test/test_stream_processor.py index 25a434e63e..f6dca2e50a 100644 --- a/ee/hogai/utils/test/test_stream_processor.py +++ b/ee/hogai/utils/test/test_stream_processor.py @@ -1,15 +1,14 @@ """ -Comprehensive tests for AssistantMessageReducer. +Comprehensive tests for AssistantStreamProcessor. -Tests the reducer logic that processes dispatcher actions -and routes messages appropriately. +Tests the stream processor logic that handles dispatcher actions, +routes messages based on node paths, and manages streaming state. 
""" from typing import cast from uuid import uuid4 from posthog.test.base import BaseTest -from unittest.mock import MagicMock from langchain_core.messages import AIMessageChunk @@ -17,7 +16,6 @@ from posthog.schema import ( AssistantGenerationStatusEvent, AssistantGenerationStatusType, AssistantMessage, - AssistantToolCall, AssistantToolCallMessage, AssistantUpdateEvent, FailureMessage, @@ -29,13 +27,20 @@ from posthog.schema import ( VisualizationMessage, ) +from ee.hogai.utils.state import GraphValueUpdateTuple from ee.hogai.utils.stream_processor import AssistantStreamProcessor from ee.hogai.utils.types.base import ( AssistantDispatcherEvent, + AssistantGraphName, AssistantNodeName, + AssistantState, + LangGraphUpdateEvent, MessageAction, MessageChunkAction, + NodeEndAction, + NodePath, NodeStartAction, + UpdateAction, ) @@ -44,406 +49,581 @@ class TestStreamProcessor(BaseTest): def setUp(self): super().setUp() - # Create a reducer with test node configuration self.stream_processor = AssistantStreamProcessor( + verbose_nodes={AssistantNodeName.ROOT, AssistantNodeName.TRENDS_GENERATOR}, streaming_nodes={AssistantNodeName.TRENDS_GENERATOR}, - visualization_nodes={AssistantNodeName.TRENDS_GENERATOR: MagicMock()}, + state_type=AssistantState, ) def _create_dispatcher_event( self, - action: MessageAction | NodeStartAction | MessageChunkAction, + action: MessageAction | NodeStartAction | MessageChunkAction | NodeEndAction | UpdateAction, node_name: AssistantNodeName = AssistantNodeName.ROOT, + node_run_id: str = "test_run_id", + node_path: tuple[NodePath, ...] | None = None, ) -> AssistantDispatcherEvent: """Helper to create a dispatcher event for testing.""" - return AssistantDispatcherEvent(action=action, node_name=node_name) + return AssistantDispatcherEvent( + action=action, node_name=node_name, node_run_id=node_run_id, node_path=node_path + ) - def test_node_start_action_returns_ack(self): - """Test NODE_START action returns ACK status event.""" - event = self._create_dispatcher_event(NodeStartAction()) + # Node lifecycle tests + + def test_node_start_initializes_chunk_and_returns_ack(self): + """Test NodeStartAction initializes a chunk for the run_id and returns ACK.""" + run_id = "test_run_123" + event = self._create_dispatcher_event(NodeStartAction(), node_run_id=run_id) result = self.stream_processor.process(event) self.assertIsNotNone(result) - result = cast(AssistantGenerationStatusEvent, result) - self.assertEqual(result.type, AssistantGenerationStatusType.ACK) + assert result is not None + self.assertEqual(len(result), 1) + first_result = result[0] + self.assertIsInstance(first_result, AssistantGenerationStatusEvent) + assert isinstance(first_result, AssistantGenerationStatusEvent) + self.assertEqual(first_result.type, AssistantGenerationStatusType.ACK) + self.assertIn(run_id, self.stream_processor._chunks) + self.assertEqual(self.stream_processor._chunks[run_id].content, "") - def test_message_with_tool_calls_stores_in_registry(self): - """Test AssistantMessage with tool_calls is stored in _tool_call_id_to_message.""" - tool_call_id = str(uuid4()) - message = AssistantMessage( - content="Test", - tool_calls=[AssistantToolCall(id=tool_call_id, name="test_tool", args={})], - ) + def test_node_end_cleans_up_chunk(self): + """Test NodeEndAction removes the chunk for the run_id.""" + run_id = "test_run_456" + self.stream_processor._chunks[run_id] = AIMessageChunk(content="test") - event = self._create_dispatcher_event(MessageAction(message=message)) + state = 
AssistantState(messages=[]) + event = self._create_dispatcher_event(NodeEndAction(state=state), node_run_id=run_id) self.stream_processor.process(event) - # Should be stored in registry - self.assertIn(tool_call_id, self.stream_processor._tool_call_id_to_message) - self.assertEqual(self.stream_processor._tool_call_id_to_message[tool_call_id], message) + self.assertNotIn(run_id, self.stream_processor._chunks) - def test_assistant_message_with_parent_creates_assistant_update_event(self): - """Test AssistantMessage with parent_tool_call_id creates AssistantUpdateEvent.""" - # First, register a parent message - parent_tool_call_id = str(uuid4()) - parent_message = AssistantMessage( - id=str(uuid4()), - content="Parent", - tool_calls=[AssistantToolCall(id=parent_tool_call_id, name="test", args={})], - parent_tool_call_id=None, + def test_node_end_processes_messages_from_state(self): + """Test NodeEndAction processes all messages from the final state.""" + run_id = "test_run_789" + message1 = AssistantMessage(id=str(uuid4()), content="Message 1") + message2 = AssistantMessage(id=str(uuid4()), content="Message 2") + state = AssistantState(messages=[message1, message2]) + + event = self._create_dispatcher_event( + NodeEndAction(state=state), node_name=AssistantNodeName.ROOT, node_run_id=run_id ) - self.stream_processor._tool_call_id_to_message[parent_tool_call_id] = parent_message + results = self.stream_processor.process(event) - # Now send a child message - child_message = AssistantMessage(content="Child content", parent_tool_call_id=parent_tool_call_id) - event = self._create_dispatcher_event(MessageAction(message=child_message)) + self.assertIsNotNone(results) + assert results is not None + self.assertEqual(len(results), 2) + self.assertEqual(results[0], message1) + self.assertEqual(results[1], message2) + + # Message streaming tests + + def test_message_chunk_streaming_for_streaming_nodes(self): + """Test MessageChunkAction streams chunks for nodes in streaming_nodes.""" + run_id = "stream_run_1" + chunk = AIMessageChunk(content="Hello ") + + event = self._create_dispatcher_event( + MessageChunkAction(message=chunk), node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id=run_id + ) result = self.stream_processor.process(event) - self.assertIsInstance(result, AssistantUpdateEvent) - result = cast(AssistantUpdateEvent, result) - self.assertEqual(result.id, parent_message.id) - self.assertEqual(result.tool_call_id, parent_tool_call_id) - self.assertEqual(result.content, "Child content") + self.assertIsNotNone(result) + assert result is not None + self.assertEqual(len(result), 1) + self.assertIsInstance(result[0], AssistantMessage) + assert isinstance(result[0], AssistantMessage) + self.assertEqual(result[0].content, "Hello ") + self.assertIsNone(result[0].id) - def test_nested_parent_chain_resolution(self): - """Test finding parent IDs through nested chain of parents.""" - # Create chain: root -> intermediate -> leaf - root_tool_call_id = str(uuid4()) - root_message = AssistantMessage( - id=str(uuid4()), - content="Root", - tool_calls=[AssistantToolCall(id=root_tool_call_id, name="root_tool", args={})], - parent_tool_call_id=None, + def test_message_chunk_ignored_for_non_streaming_nodes(self): + """Test MessageChunkAction returns None for nodes not in streaming_nodes.""" + run_id = "stream_run_2" + chunk = AIMessageChunk(content="Hello ") + + event = self._create_dispatcher_event( + MessageChunkAction(message=chunk), node_name=AssistantNodeName.ROOT, node_run_id=run_id ) - - 
intermediate_tool_call_id = str(uuid4()) - intermediate_message = AssistantMessage( - id=str(uuid4()), - content="Intermediate", - tool_calls=[AssistantToolCall(id=intermediate_tool_call_id, name="intermediate_tool", args={})], - parent_tool_call_id=root_tool_call_id, - ) - - # Register both in the registry - self.stream_processor._tool_call_id_to_message[root_tool_call_id] = root_message - self.stream_processor._tool_call_id_to_message[intermediate_tool_call_id] = intermediate_message - - # Send leaf message that references intermediate - leaf_message = AssistantMessage(content="Leaf content", parent_tool_call_id=intermediate_tool_call_id) - event = self._create_dispatcher_event(MessageAction(message=leaf_message)) result = self.stream_processor.process(event) - # Should resolve to root + self.assertIsNone(result) - self.assertIsInstance(result, AssistantUpdateEvent) - result = cast(AssistantUpdateEvent, result) - # Note: The unpacking swaps the values, so id is tool_call_id and parent_tool_call_id is message_id - self.assertEqual(result.id, root_message.id) - self.assertEqual(result.tool_call_id, root_tool_call_id) + def test_multiple_chunks_merged_correctly(self): + """Test that multiple MessageChunkActions are merged correctly.""" + run_id = "stream_run_3" - def test_missing_parent_message_returns_ack(self): - """Test that missing parent message returns ACK.""" - missing_parent_id = str(uuid4()) - child_message = AssistantMessage(content="Orphan", parent_tool_call_id=missing_parent_id) - - event = self._create_dispatcher_event(MessageAction(message=child_message)) - - result = self.stream_processor.process(event) - self.assertIsInstance(result, AssistantGenerationStatusEvent) - result = cast(AssistantGenerationStatusEvent, result) - self.assertEqual(result.type, AssistantGenerationStatusType.ACK) - - def test_parent_without_id_returns_ack(self): - """Test that parent message without ID logs warning and returns None.""" - parent_tool_call_id = str(uuid4()) - # Parent message WITHOUT an id - parent_message = AssistantMessage( - id=None, # No ID - content="Parent", - tool_calls=[AssistantToolCall(id=parent_tool_call_id, name="test", args={})], - parent_tool_call_id=None, + chunk1 = AIMessageChunk(content="Hello ") + event1 = self._create_dispatcher_event( + MessageChunkAction(message=chunk1), node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id=run_id ) - self.stream_processor._tool_call_id_to_message[parent_tool_call_id] = parent_message + result1 = self.stream_processor.process(event1) - child_message = AssistantMessage(content="Child", parent_tool_call_id=parent_tool_call_id) - event = self._create_dispatcher_event(MessageAction(message=child_message)) + chunk2 = AIMessageChunk(content="world!") + event2 = self._create_dispatcher_event( + MessageChunkAction(message=chunk2), node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id=run_id + ) + result2 = self.stream_processor.process(event2) + self.assertIsNotNone(result1) + assert result1 is not None + assert isinstance(result1[0], AssistantMessage) + self.assertEqual(result1[0].content, "Hello ") + self.assertIsNotNone(result2) + assert result2 is not None + assert isinstance(result2[0], AssistantMessage) + self.assertEqual(result2[0].content, "Hello world!") + + def test_concurrent_chunks_from_different_runs(self): + """Test that chunks from different node runs are kept separate.""" + run_id_1 = "stream_run_4a" + run_id_2 = "stream_run_4b" + + chunk1 = AIMessageChunk(content="Run 1") + event1 = self._create_dispatcher_event( + 
MessageChunkAction(message=chunk1), node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id=run_id_1 + ) + self.stream_processor.process(event1) + + chunk2 = AIMessageChunk(content="Run 2") + event2 = self._create_dispatcher_event( + MessageChunkAction(message=chunk2), node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id=run_id_2 + ) + self.stream_processor.process(event2) + + self.assertEqual(self.stream_processor._chunks[run_id_1].content, "Run 1") + self.assertEqual(self.stream_processor._chunks[run_id_2].content, "Run 2") + + def test_handles_mixed_content_types_in_chunks(self): + """Test that stream processor handles switching between string and list content formats.""" + run_id = "stream_run_5" + + # Start with string content + chunk1 = AIMessageChunk(content="initial string") + event1 = self._create_dispatcher_event( + MessageChunkAction(message=chunk1), node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id=run_id + ) + self.stream_processor.process(event1) + + # Switch to list format (OpenAI Responses API) + chunk2 = AIMessageChunk(content=[{"type": "text", "text": "list content"}]) + event2 = self._create_dispatcher_event( + MessageChunkAction(message=chunk2), node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id=run_id + ) + result = self.stream_processor.process(event2) + + # The result should normalize to string content + self.assertIsNotNone(result) + assert result is not None + assert isinstance(result[0], AssistantMessage) + self.assertEqual(result[0].content, "list content") + + # Root vs nested message handling tests + + def test_root_message_from_verbose_node_returned(self): + """Test messages from root level (node_path <= 2) in verbose nodes are returned.""" + message = AssistantMessage(id=str(uuid4()), content="Root message") + node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT)) + + event = self._create_dispatcher_event( + MessageAction(message=message), node_name=AssistantNodeName.ROOT, node_path=node_path + ) result = self.stream_processor.process(event) - self.assertIsInstance(result, AssistantGenerationStatusEvent) - result = cast(AssistantGenerationStatusEvent, result) - self.assertEqual(result.type, AssistantGenerationStatusType.ACK) - def test_visualization_message_in_visualization_nodes(self): - """Test VisualizationMessage is returned when node is in VISUALIZATION_NODES.""" + self.assertIsNotNone(result) + assert result is not None + self.assertEqual(len(result), 1) + self.assertEqual(result[0], message) + + def test_root_message_from_non_verbose_node_filtered(self): + """Test messages from root level in non-verbose nodes are filtered out.""" + message = AssistantMessage(id=str(uuid4()), content="Non-verbose message") + node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.BILLING)) + + event = self._create_dispatcher_event( + MessageAction(message=message), node_name=AssistantNodeName.BILLING, node_path=node_path + ) + result = self.stream_processor.process(event) + + self.assertIsNone(result) + + def test_nested_visualization_message_returned(self): + """Test VisualizationMessage from nested node/graph is returned.""" query = TrendsQuery(series=[]) viz_message = VisualizationMessage(query="test query", answer=query, plan="test plan") - viz_message.parent_tool_call_id = str(uuid4()) - # Register parent - parent_message = AssistantMessage( - id=str(uuid4()), - content="Parent", - tool_calls=[AssistantToolCall(id=viz_message.parent_tool_call_id, name="test", 
args={})], + # Create a deep node path indicating this is from a nested graph + node_path = ( + NodePath(name=AssistantGraphName.ASSISTANT), + NodePath(name=AssistantNodeName.ROOT, message_id=str(uuid4()), tool_call_id=str(uuid4())), + NodePath(name=AssistantGraphName.INSIGHTS), + NodePath(name=AssistantNodeName.TRENDS_GENERATOR), ) - self.stream_processor._tool_call_id_to_message[viz_message.parent_tool_call_id] = parent_message - node_name = AssistantNodeName.TRENDS_GENERATOR - event = self._create_dispatcher_event(MessageAction(message=viz_message), node_name=node_name) + event = self._create_dispatcher_event( + MessageAction(message=viz_message), node_name=AssistantNodeName.TRENDS_GENERATOR, node_path=node_path + ) result = self.stream_processor.process(event) - self.assertEqual(result, viz_message) + self.assertIsNotNone(result) + assert result is not None + self.assertEqual(len(result), 1) + self.assertEqual(result[0], viz_message) - def test_visualization_message_not_in_visualization_nodes(self): - """Test VisualizationMessage raises error when from non-visualization node.""" - query = TrendsQuery(series=[]) - viz_message = VisualizationMessage(query="test query", answer=query, plan="test plan") - viz_message.parent_tool_call_id = str(uuid4()) - - # Register parent - parent_message = AssistantMessage( - id=str(uuid4()), - content="Parent", - tool_calls=[AssistantToolCall(id=viz_message.parent_tool_call_id, name="test", args={})], - ) - self.stream_processor._tool_call_id_to_message[viz_message.parent_tool_call_id] = parent_message - - node_name = AssistantNodeName.ROOT # Not a visualization node - event = self._create_dispatcher_event(MessageAction(message=viz_message), node_name=node_name) - - result = self.stream_processor.process(event) - self.assertEqual(result, AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK)) - - def test_multi_visualization_message_in_visualization_nodes(self): - """Test MultiVisualizationMessage is returned when node is in VISUALIZATION_NODES.""" + def test_nested_multi_visualization_message_returned(self): + """Test MultiVisualizationMessage from nested node/graph is returned.""" query = TrendsQuery(series=[]) viz_item = VisualizationItem(query="test query", answer=query, plan="test plan") multi_viz_message = MultiVisualizationMessage(visualizations=[viz_item]) - multi_viz_message.parent_tool_call_id = str(uuid4()) - # Register parent - parent_message = AssistantMessage( - id=str(uuid4()), - content="Parent", - tool_calls=[AssistantToolCall(id=multi_viz_message.parent_tool_call_id, name="test", args={})], + node_path = ( + NodePath(name=AssistantGraphName.ASSISTANT), + NodePath(name=AssistantNodeName.ROOT, message_id=str(uuid4()), tool_call_id=str(uuid4())), + NodePath(name=AssistantGraphName.INSIGHTS), + NodePath(name=AssistantNodeName.TRENDS_GENERATOR), ) - self.stream_processor._tool_call_id_to_message[multi_viz_message.parent_tool_call_id] = parent_message - node_name = AssistantNodeName.TRENDS_GENERATOR - event = self._create_dispatcher_event(MessageAction(message=multi_viz_message), node_name=node_name) + event = self._create_dispatcher_event( + MessageAction(message=multi_viz_message), node_name=AssistantNodeName.TRENDS_GENERATOR, node_path=node_path + ) result = self.stream_processor.process(event) - self.assertEqual(result, multi_viz_message) + self.assertIsNotNone(result) + assert result is not None + self.assertEqual(len(result), 1) + self.assertEqual(result[0], multi_viz_message) - def 
test_notebook_update_message_returns_as_is(self): - """Test NotebookUpdateMessage is returned directly.""" + def test_nested_notebook_message_returned(self): + """Test NotebookUpdateMessage from nested node/graph is returned.""" content = ProsemirrorJSONContent(type="doc", content=[]) notebook_message = NotebookUpdateMessage(notebook_id="nb123", content=content) - notebook_message.parent_tool_call_id = str(uuid4()) - # Register parent - parent_message = AssistantMessage( - id=str(uuid4()), - content="Parent", - tool_calls=[AssistantToolCall(id=notebook_message.parent_tool_call_id, name="test", args={})], + node_path = ( + NodePath(name=AssistantGraphName.ASSISTANT), + NodePath(name=AssistantNodeName.ROOT, message_id=str(uuid4()), tool_call_id=str(uuid4())), + NodePath(name=AssistantGraphName.INSIGHTS), + NodePath(name=AssistantNodeName.TRENDS_GENERATOR), ) - self.stream_processor._tool_call_id_to_message[notebook_message.parent_tool_call_id] = parent_message - event = self._create_dispatcher_event(MessageAction(message=notebook_message)) + event = self._create_dispatcher_event( + MessageAction(message=notebook_message), node_name=AssistantNodeName.TRENDS_GENERATOR, node_path=node_path + ) result = self.stream_processor.process(event) - self.assertEqual(result, notebook_message) + self.assertIsNotNone(result) + assert result is not None + self.assertEqual(len(result), 1) + self.assertEqual(result[0], notebook_message) - def test_failure_message_returns_as_is(self): - """Test FailureMessage is returned directly.""" + def test_nested_failure_message_returned(self): + """Test FailureMessage from nested node/graph is returned.""" failure_message = FailureMessage(content="Something went wrong") - failure_message.parent_tool_call_id = str(uuid4()) - # Register parent - parent_message = AssistantMessage( - id=str(uuid4()), - content="Parent", - tool_calls=[AssistantToolCall(id=failure_message.parent_tool_call_id, name="test", args={})], + node_path = ( + NodePath(name=AssistantGraphName.ASSISTANT), + NodePath(name=AssistantNodeName.ROOT, message_id=str(uuid4()), tool_call_id=str(uuid4())), + NodePath(name=AssistantGraphName.INSIGHTS), + NodePath(name=AssistantNodeName.TRENDS_GENERATOR), ) - self.stream_processor._tool_call_id_to_message[failure_message.parent_tool_call_id] = parent_message - event = self._create_dispatcher_event(MessageAction(message=failure_message)) + event = self._create_dispatcher_event( + MessageAction(message=failure_message), node_name=AssistantNodeName.TRENDS_GENERATOR, node_path=node_path + ) result = self.stream_processor.process(event) - self.assertEqual(result, failure_message) + self.assertIsNotNone(result) + assert result is not None + self.assertEqual(len(result), 1) + self.assertEqual(result[0], failure_message) - def test_assistant_tool_call_message_returns_as_is(self): - """Test AssistantToolCallMessage with parent is filtered out (returns ACK).""" + def test_nested_tool_call_message_filtered(self): + """Test AssistantToolCallMessage from nested node/graph is filtered out.""" tool_call_message = AssistantToolCallMessage(content="Tool result", tool_call_id=str(uuid4())) - tool_call_message.parent_tool_call_id = str(uuid4()) - # Register parent - parent_message = AssistantMessage( - id=str(uuid4()), - content="Parent", - tool_calls=[AssistantToolCall(id=tool_call_message.parent_tool_call_id, name="test", args={})], - ) - self.stream_processor._tool_call_id_to_message[tool_call_message.parent_tool_call_id] = parent_message - - event = 
self._create_dispatcher_event(MessageAction(message=tool_call_message)) - result = self.stream_processor.process(event) - - # New behavior: AssistantToolCallMessages with parents are filtered out - self.assertIsInstance(result, AssistantGenerationStatusEvent) - result = cast(AssistantGenerationStatusEvent, result) - self.assertEqual(result.type, AssistantGenerationStatusType.ACK) - - def test_cycle_detection_in_parent_chain(self): - """Test that circular parent chains are detected and returns ACK.""" - # Create circular chain: A -> B -> A - tool_call_a = str(uuid4()) - tool_call_b = str(uuid4()) - - message_a = AssistantMessage( - id=str(uuid4()), - content="A", - tool_calls=[AssistantToolCall(id=tool_call_a, name="tool_a", args={})], - parent_tool_call_id=tool_call_b, # Points to B + node_path = ( + NodePath(name=AssistantGraphName.ASSISTANT), + NodePath(name=AssistantNodeName.ROOT, message_id=str(uuid4()), tool_call_id=str(uuid4())), + NodePath(name=AssistantGraphName.INSIGHTS), + NodePath(name=AssistantNodeName.TRENDS_GENERATOR), ) - message_b = AssistantMessage( - id=str(uuid4()), - content="B", - tool_calls=[AssistantToolCall(id=tool_call_b, name="tool_b", args={})], - parent_tool_call_id=tool_call_a, # Points to A - ) - - self.stream_processor._tool_call_id_to_message[tool_call_a] = message_a - self.stream_processor._tool_call_id_to_message[tool_call_b] = message_b - - # Try to process a child of B - child_message = AssistantMessage(content="Child", parent_tool_call_id=tool_call_b) - event = self._create_dispatcher_event(MessageAction(message=child_message)) - - result = self.stream_processor.process(event) - - # Cycle detection returns ACK instead of raising error - self.assertIsInstance(result, AssistantGenerationStatusEvent) - result = cast(AssistantGenerationStatusEvent, result) - self.assertEqual(result.type, AssistantGenerationStatusType.ACK) - - def test_handles_mixed_content_types_in_chunks(self): - """Test that stream processor correctly handles switching between string and list content formats.""" - # Test string to list transition - self.stream_processor._chunks = AIMessageChunk(content="initial string content") - - # Simulate a chunk from OpenAI Responses API (list format) - list_chunk = AIMessageChunk(content=[{"type": "text", "text": "new content from o3"}]) event = self._create_dispatcher_event( - MessageChunkAction(message=list_chunk), node_name=AssistantNodeName.TRENDS_GENERATOR - ) - self.stream_processor.process(event) - - # Verify the chunks were reset to list format - self.assertIsInstance(self.stream_processor._chunks.content, list) - self.assertEqual(len(self.stream_processor._chunks.content), 1) - self.assertEqual(cast(dict, self.stream_processor._chunks.content[0])["text"], "new content from o3") - - # Test list to string transition - string_chunk = AIMessageChunk(content="back to string format") - event = self._create_dispatcher_event( - MessageChunkAction(message=string_chunk), node_name=AssistantNodeName.TRENDS_GENERATOR - ) - self.stream_processor.process(event) - - # Verify the chunks were reset to string format - self.assertIsInstance(self.stream_processor._chunks.content, str) - self.assertEqual(self.stream_processor._chunks.content, "back to string format") - - def test_handles_multiple_list_chunks(self): - """Test that multiple list-format chunks are properly concatenated.""" - # Start with empty chunks - self.stream_processor._chunks = AIMessageChunk(content="") - - # Add first list chunk - chunk1 = AIMessageChunk(content=[{"type": "text", "text": 
"First part"}]) - event = self._create_dispatcher_event( - MessageChunkAction(message=chunk1), node_name=AssistantNodeName.TRENDS_GENERATOR - ) - self.stream_processor.process(event) - - # Add second list chunk - chunk2 = AIMessageChunk(content=[{"type": "text", "text": " second part"}]) - event = self._create_dispatcher_event( - MessageChunkAction(message=chunk2), node_name=AssistantNodeName.TRENDS_GENERATOR + MessageAction(message=tool_call_message), node_name=AssistantNodeName.TRENDS_GENERATOR, node_path=node_path ) result = self.stream_processor.process(event) - # Verify the result is an AssistantMessage with combined content - self.assertIsInstance(result, AssistantMessage) - result = cast(AssistantMessage, result) - self.assertEqual(result.content, "First part second part") + self.assertIsNone(result) - def test_messages_without_id_are_yielded(self): - """Test that messages without ID are always yielded.""" - # Create messages without IDs + def test_short_node_path_treated_as_root(self): + """Test that node_path with length <= 2 is treated as root level.""" + message = AssistantMessage(id=str(uuid4()), content="Short path message") + + # Path with just 2 elements (graph + node) + node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT)) + + event = self._create_dispatcher_event( + MessageAction(message=message), node_name=AssistantNodeName.ROOT, node_path=node_path + ) + result = self.stream_processor.process(event) + + self.assertIsNotNone(result) + assert result is not None + self.assertEqual(len(result), 1) + self.assertEqual(result[0], message) + + # UpdateAction tests + + def test_update_action_creates_update_event_with_parent_from_path(self): + """Test UpdateAction creates AssistantUpdateEvent using closest tool_call_id from node_path.""" + message_id = str(uuid4()) + tool_call_id = str(uuid4()) + + node_path = ( + NodePath(name=AssistantGraphName.ASSISTANT), + NodePath(name=AssistantNodeName.ROOT, message_id=message_id, tool_call_id=tool_call_id), + NodePath(name=AssistantGraphName.INSIGHTS), + ) + + event = self._create_dispatcher_event( + UpdateAction(content="Update content"), node_name=AssistantNodeName.TRENDS_GENERATOR, node_path=node_path + ) + result = self.stream_processor.process(event) + + self.assertIsNotNone(result) + assert result is not None + self.assertEqual(len(result), 1) + self.assertIsInstance(result[0], AssistantUpdateEvent) + update_event = cast(AssistantUpdateEvent, result[0]) + self.assertEqual(update_event.id, message_id) + self.assertEqual(update_event.tool_call_id, tool_call_id) + self.assertEqual(update_event.content, "Update content") + + def test_update_action_without_parent_returns_none(self): + """Test UpdateAction without parent tool_call_id in node_path returns None.""" + # No tool_call_id in any path element + node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT)) + + event = self._create_dispatcher_event( + UpdateAction(content="Update content"), node_name=AssistantNodeName.ROOT, node_path=node_path + ) + result = self.stream_processor.process(event) + + self.assertIsNone(result) + + def test_update_action_without_node_path_returns_none(self): + """Test UpdateAction without node_path returns None.""" + event = self._create_dispatcher_event(UpdateAction(content="Update content"), node_path=None) + result = self.stream_processor.process(event) + + self.assertIsNone(result) + + def test_update_action_finds_closest_tool_call_in_reversed_path(self): + """Test 
UpdateAction finds the closest (most recent) tool_call_id by reversing the path.""" + # Multiple tool calls in the path - should find the closest one (last in reversed iteration) + message_id_1 = str(uuid4()) + tool_call_id_1 = str(uuid4()) + message_id_2 = str(uuid4()) + tool_call_id_2 = str(uuid4()) + + node_path = ( + NodePath(name=AssistantGraphName.ASSISTANT), + NodePath(name=AssistantNodeName.ROOT, message_id=message_id_1, tool_call_id=tool_call_id_1), + NodePath(name=AssistantGraphName.INSIGHTS, message_id=message_id_2, tool_call_id=tool_call_id_2), + NodePath(name=AssistantNodeName.TRENDS_GENERATOR), + ) + + event = self._create_dispatcher_event( + UpdateAction(content="Update content"), node_name=AssistantNodeName.TRENDS_GENERATOR, node_path=node_path + ) + result = self.stream_processor.process(event) + + self.assertIsNotNone(result) + assert result is not None + self.assertEqual(len(result), 1) + update_event = cast(AssistantUpdateEvent, result[0]) + # Should use the closest parent (last one in reversed path) + self.assertEqual(update_event.id, message_id_2) + self.assertEqual(update_event.tool_call_id, tool_call_id_2) + + # Message deduplication tests + + def test_messages_with_id_deduplicated(self): + """Test that messages with the same ID are deduplicated.""" + message_id = str(uuid4()) + message1 = AssistantMessage(id=message_id, content="First occurrence") + message2 = AssistantMessage(id=message_id, content="Second occurrence") + + node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT)) + + # Process first message - should be returned + event1 = self._create_dispatcher_event( + MessageAction(message=message1), node_name=AssistantNodeName.ROOT, node_path=node_path + ) + result1 = self.stream_processor.process(event1) + self.assertIsNotNone(result1) + assert result1 is not None + self.assertEqual(result1[0], message1) + + # Process second message with same ID - should be filtered + event2 = self._create_dispatcher_event( + MessageAction(message=message2), node_name=AssistantNodeName.ROOT, node_path=node_path + ) + result2 = self.stream_processor.process(event2) + self.assertIsNone(result2) + + def test_messages_without_id_not_deduplicated(self): + """Test that messages without ID are always yielded (not deduplicated).""" message1 = AssistantMessage(content="Message without ID") message2 = AssistantMessage(content="Another message without ID") - # Process first message - event1 = self._create_dispatcher_event(MessageAction(message=message1)) + node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT)) + + event1 = self._create_dispatcher_event( + MessageAction(message=message1), node_name=AssistantNodeName.ROOT, node_path=node_path + ) result1 = self.stream_processor.process(event1) - self.assertEqual(result1, message1) + self.assertIsNotNone(result1) + assert result1 is not None + self.assertEqual(result1[0], message1) - # Process second message with same content - event2 = self._create_dispatcher_event(MessageAction(message=message2)) + event2 = self._create_dispatcher_event( + MessageAction(message=message2), node_name=AssistantNodeName.ROOT, node_path=node_path + ) result2 = self.stream_processor.process(event2) - self.assertEqual(result2, message2) + self.assertIsNotNone(result2) + assert result2 is not None + self.assertEqual(result2[0], message2) - # Both should be yielded since they have no IDs - - def test_messages_with_id_are_deduplicated(self): - """Test that messages with ID are 
deduplicated during streaming.""" + def test_preexisting_message_ids_filtered(self): + """Test that stream processor filters messages with IDs already in _streamed_update_ids.""" message_id = str(uuid4()) - # Create multiple messages with the same ID - message1 = AssistantMessage(id=message_id, content="First occurrence") - message2 = AssistantMessage(id=message_id, content="Second occurrence") - message3 = AssistantMessage(id=message_id, content="Third occurrence") + # Pre-populate the streamed IDs + self.stream_processor._streamed_update_ids.add(message_id) - # Process first message - should be yielded - event1 = self._create_dispatcher_event(MessageAction(message=message1)) - result1 = self.stream_processor.process(event1) - self.assertEqual(result1, message1) - self.assertIn(message_id, self.stream_processor._streamed_update_ids) + message = AssistantMessage(id=message_id, content="Already seen") + node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT)) - # Process second message with same ID - should return ACK - event2 = self._create_dispatcher_event(MessageAction(message=message2)) - result2 = self.stream_processor.process(event2) - self.assertIsInstance(result2, AssistantGenerationStatusEvent) - result2 = cast(AssistantGenerationStatusEvent, result2) - self.assertEqual(result2.type, AssistantGenerationStatusType.ACK) + event = self._create_dispatcher_event( + MessageAction(message=message), node_name=AssistantNodeName.ROOT, node_path=node_path + ) + result = self.stream_processor.process(event) - # Process third message with same ID - should also return ACK - event3 = self._create_dispatcher_event(MessageAction(message=message3)) - result3 = self.stream_processor.process(event3) - self.assertIsInstance(result3, AssistantGenerationStatusEvent) - result3 = cast(AssistantGenerationStatusEvent, result3) - self.assertEqual(result3.type, AssistantGenerationStatusType.ACK) + self.assertIsNone(result) - def test_stream_processor_with_preexisting_message_ids(self): - """Test that stream processor correctly filters messages when initialized with existing IDs.""" - message_id_1 = str(uuid4()) - message_id_2 = str(uuid4()) + # LangGraph update processing tests - # Simulate existing messages by pre-populating the streamed IDs set - self.stream_processor._streamed_update_ids.add(message_id_1) + def test_langgraph_message_chunk_processed(self): + """Test that LangGraph message chunk updates are converted and processed.""" + chunk = AIMessageChunk(content="LangGraph chunk") + state = {"langgraph_node": AssistantNodeName.TRENDS_GENERATOR, "langgraph_checkpoint_ns": "checkpoint_123"} - # Try to process message with existing ID - should be filtered out - message1 = AssistantMessage(id=message_id_1, content="Already seen") - event1 = self._create_dispatcher_event(MessageAction(message=message1)) - result1 = self.stream_processor.process(event1) - self.assertIsInstance(result1, AssistantGenerationStatusEvent) - result1 = cast(AssistantGenerationStatusEvent, result1) - self.assertEqual(result1.type, AssistantGenerationStatusType.ACK) + update = ["messages", (chunk, state)] + event = LangGraphUpdateEvent(update=update) - # Process message with new ID - should be yielded - message2 = AssistantMessage(id=message_id_2, content="New message") - event2 = self._create_dispatcher_event(MessageAction(message=message2)) - result2 = self.stream_processor.process(event2) - self.assertEqual(result2, message2) - self.assertIn(message_id_2, 
self.stream_processor._streamed_update_ids) + result = self.stream_processor.process_langgraph_update(event) + + self.assertIsNotNone(result) + assert result is not None + self.assertEqual(len(result), 1) + self.assertIsInstance(result[0], AssistantMessage) + assert isinstance(result[0], AssistantMessage) + self.assertEqual(result[0].content, "LangGraph chunk") + + def test_langgraph_state_update_stored(self): + """Test that LangGraph state updates are stored in _state.""" + new_state_dict = {"messages": [], "plan": "Test plan"} + update = cast(GraphValueUpdateTuple, ["values", new_state_dict]) + + event = LangGraphUpdateEvent(update=update) + result = self.stream_processor.process_langgraph_update(event) + + self.assertIsNone(result) + self.assertIsNotNone(self.stream_processor._state) + assert self.stream_processor._state is not None + self.assertEqual(self.stream_processor._state.plan, "Test plan") + + def test_langgraph_non_message_chunk_ignored(self): + """Test that LangGraph updates that are not AIMessageChunk are ignored.""" + regular_message = AssistantMessage(content="Not a chunk") + state = {"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "checkpoint_456"} + + update = ["messages", (regular_message, state)] + event = LangGraphUpdateEvent(update=update) + + result = self.stream_processor.process_langgraph_update(event) + + self.assertIsNone(result) + + def test_langgraph_invalid_update_format_ignored(self): + """Test that invalid LangGraph update formats are ignored.""" + update = "invalid_format" + event = LangGraphUpdateEvent(update=update) + + result = self.stream_processor.process_langgraph_update(event) + + self.assertIsNone(result) + + # Edge cases and error conditions + + def test_empty_node_path_treated_as_root(self): + """Test that empty node_path is treated as root level.""" + message = AssistantMessage(id=str(uuid4()), content="Empty path message") + + event = self._create_dispatcher_event( + MessageAction(message=message), node_name=AssistantNodeName.ROOT, node_path=() + ) + result = self.stream_processor.process(event) + + self.assertIsNotNone(result) + assert result is not None + self.assertEqual(len(result), 1) + self.assertEqual(result[0], message) + + def test_none_node_path_treated_as_root(self): + """Test that None node_path is treated as root level.""" + message = AssistantMessage(id=str(uuid4()), content="None path message") + + event = self._create_dispatcher_event( + MessageAction(message=message), node_name=AssistantNodeName.ROOT, node_path=None + ) + result = self.stream_processor.process(event) + + self.assertIsNotNone(result) + assert result is not None + self.assertEqual(len(result), 1) + self.assertEqual(result[0], message) + + def test_node_end_with_none_state_returns_none(self): + """Test NodeEndAction with None state returns None.""" + event = self._create_dispatcher_event(NodeEndAction(state=None)) + result = self.stream_processor.process(event) + + self.assertIsNone(result) + + def test_update_action_with_empty_content_returns_none(self): + """Test UpdateAction with empty content returns None.""" + node_path = ( + NodePath(name=AssistantGraphName.ASSISTANT), + NodePath(name=AssistantNodeName.ROOT, message_id=str(uuid4()), tool_call_id=str(uuid4())), + ) + + event = self._create_dispatcher_event(UpdateAction(content=""), node_path=node_path) + result = self.stream_processor.process(event) + + self.assertIsNone(result) + + def test_special_messages_from_root_level_returned(self): + """Test that special message types from root 
level are handled by root message logic.""" + # VisualizationMessage from root should be returned if from verbose node + query = TrendsQuery(series=[]) + viz_message = VisualizationMessage(query="test", answer=query, plan="plan") + + node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.TRENDS_GENERATOR)) + + event = self._create_dispatcher_event( + MessageAction(message=viz_message), node_name=AssistantNodeName.TRENDS_GENERATOR, node_path=node_path + ) + result = self.stream_processor.process(event) + + self.assertIsNotNone(result) + assert result is not None + self.assertEqual(len(result), 1) + self.assertEqual(result[0], viz_message) diff --git a/ee/hogai/utils/types/base.py b/ee/hogai/utils/types/base.py index ce2eef08f1..1b9e5afb46 100644 --- a/ee/hogai/utils/types/base.py +++ b/ee/hogai/utils/types/base.py @@ -299,7 +299,7 @@ class _SharedAssistantState(BaseStateWithMessages, BaseStateWithIntermediateStep """ The ID of the message to start from to keep the message window short enough. """ - root_tool_call_id: Optional[str] = Field(default=None) + root_tool_call_id: Annotated[Optional[str], replace] = Field(default=None) """ The ID of the tool call from the root node. """ @@ -311,7 +311,7 @@ """ The type of insight to generate. """ - root_tool_calls_count: Optional[int] = Field(default=None) + root_tool_calls_count: Annotated[Optional[int], replace] = Field(default=None) """ Tracks the number of tool calls made by the root node to terminate the loop. """ @@ -399,7 +399,6 @@ class AssistantNodeName(StrEnum): MEMORY_COLLECTOR_TOOLS = "memory_collector_tools" INKEEP_DOCS = "inkeep_docs" INSIGHT_RAG_CONTEXT = "insight_rag_context" - INSIGHTS_SUBGRAPH = "insights_subgraph" TITLE_GENERATOR = "title_generator" INSIGHTS_SEARCH = "insights_search" SESSION_SUMMARIZATION = "session_summarization" @@ -413,6 +412,13 @@ REVENUE_ANALYTICS_FILTER_OPTIONS_TOOLS = "revenue_analytics_filter_options_tools" +class AssistantGraphName(StrEnum): + ASSISTANT = "assistant_graph" + INSIGHTS = "insights_graph" + TAXONOMY = "taxonomy_graph" + DEEP_RESEARCH = "deep_research_graph" + + class AssistantMode(StrEnum): ASSISTANT = "assistant" INSIGHTS_TOOL = "insights_tool" @@ -443,9 +449,33 @@ class NodeStartAction(BaseModel): type: Literal["NODE_START"] = "NODE_START" -AssistantActionUnion = MessageAction | MessageChunkAction | NodeStartAction +class NodeEndAction(BaseModel, Generic[PartialStateType]): + type: Literal["NODE_END"] = "NODE_END" + state: PartialStateType | None = None + + +class UpdateAction(BaseModel): + type: Literal["UPDATE"] = "UPDATE" + content: str + + +AssistantActionUnion = MessageAction | MessageChunkAction | NodeStartAction | NodeEndAction | UpdateAction + + +class NodePath(BaseModel): + """Defines a vertex of the assistant graph path.""" + + name: str + message_id: str | None = None + tool_call_id: str | None = None class AssistantDispatcherEvent(BaseModel): action: AssistantActionUnion = Field(discriminator="type") + node_path: tuple[NodePath, ...] 
| None = None node_name: str + node_run_id: str + + +class LangGraphUpdateEvent(BaseModel): + update: Any diff --git a/frontend/__snapshots__/scenes-app-data-pipelines--pipeline-destination-page--light.png b/frontend/__snapshots__/scenes-app-data-pipelines--pipeline-destination-page--light.png index 77339b9eec..53f2163297 100644 Binary files a/frontend/__snapshots__/scenes-app-data-pipelines--pipeline-destination-page--light.png and b/frontend/__snapshots__/scenes-app-data-pipelines--pipeline-destination-page--light.png differ diff --git a/frontend/__snapshots__/scenes-app-insights-trendsvalue--trends-area-edit--dark.png b/frontend/__snapshots__/scenes-app-insights-trendsvalue--trends-area-edit--dark.png index 9dd9468aff..c7d0ddff18 100644 Binary files a/frontend/__snapshots__/scenes-app-insights-trendsvalue--trends-area-edit--dark.png and b/frontend/__snapshots__/scenes-app-insights-trendsvalue--trends-area-edit--dark.png differ diff --git a/frontend/__snapshots__/scenes-app-insights-trendsvalue--trends-area-edit--light.png b/frontend/__snapshots__/scenes-app-insights-trendsvalue--trends-area-edit--light.png index 176e4d0aea..fa59bb523f 100644 Binary files a/frontend/__snapshots__/scenes-app-insights-trendsvalue--trends-area-edit--light.png and b/frontend/__snapshots__/scenes-app-insights-trendsvalue--trends-area-edit--light.png differ diff --git a/frontend/__snapshots__/scenes-app-posthog-ai--max-instance-with-contextual-tools--dark.png b/frontend/__snapshots__/scenes-app-posthog-ai--max-instance-with-contextual-tools--dark.png index 6fc02e7724..be639a9d44 100644 Binary files a/frontend/__snapshots__/scenes-app-posthog-ai--max-instance-with-contextual-tools--dark.png and b/frontend/__snapshots__/scenes-app-posthog-ai--max-instance-with-contextual-tools--dark.png differ diff --git a/frontend/__snapshots__/scenes-app-posthog-ai--max-instance-with-contextual-tools--light.png b/frontend/__snapshots__/scenes-app-posthog-ai--max-instance-with-contextual-tools--light.png index 2ae147a2da..379b6d0032 100644 Binary files a/frontend/__snapshots__/scenes-app-posthog-ai--max-instance-with-contextual-tools--light.png and b/frontend/__snapshots__/scenes-app-posthog-ai--max-instance-with-contextual-tools--light.png differ diff --git a/frontend/__snapshots__/scenes-app-posthog-ai--reasoning-component--light.png b/frontend/__snapshots__/scenes-app-posthog-ai--reasoning-component--light.png index 1396f3d4fe..ce426311b2 100644 Binary files a/frontend/__snapshots__/scenes-app-posthog-ai--reasoning-component--light.png and b/frontend/__snapshots__/scenes-app-posthog-ai--reasoning-component--light.png differ diff --git a/frontend/__snapshots__/scenes-app-sidepanels--side-panel-docs--dark.png b/frontend/__snapshots__/scenes-app-sidepanels--side-panel-docs--dark.png index 397d231080..b897157a52 100644 Binary files a/frontend/__snapshots__/scenes-app-sidepanels--side-panel-docs--dark.png and b/frontend/__snapshots__/scenes-app-sidepanels--side-panel-docs--dark.png differ diff --git a/frontend/src/queries/nodes/InsightViz/EditorFilters.tsx b/frontend/src/queries/nodes/InsightViz/EditorFilters.tsx index d76a168aea..67ed5c97ca 100644 --- a/frontend/src/queries/nodes/InsightViz/EditorFilters.tsx +++ b/frontend/src/queries/nodes/InsightViz/EditorFilters.tsx @@ -409,7 +409,7 @@ export function EditorFilters({ query, showing, embedded }: EditorFiltersProps):
) : null @@ -834,9 +835,10 @@ function ReasoningAnswer({ content, completed, id, showCompletionIcon = true }: interface ToolCallsAnswerProps { toolCalls: EnhancedToolCall[] + registeredToolMap: Record<string, ToolRegistration> } -function ToolCallsAnswer({ toolCalls }: ToolCallsAnswerProps): JSX.Element { +function ToolCallsAnswer({ toolCalls, registeredToolMap }: ToolCallsAnswerProps): JSX.Element { // Separate todo_write tool calls from regular tool calls const todoWriteToolCalls = toolCalls.filter((tc) => tc.name === 'todo_write') const regularToolCalls = toolCalls.filter((tc) => tc.name !== 'todo_write') @@ -867,7 +869,7 @@ function ToolCallsAnswer({ toolCalls }: ToolCallsAnswerProps): JSX.Element { let description = `Executing ${toolCall.name}` if (definition) { if (definition.displayFormatter) { - description = definition.displayFormatter(toolCall) + description = definition.displayFormatter(toolCall, { registeredToolMap }) } if (commentary) { description = commentary diff --git a/frontend/src/scenes/max/components/ToolsDisplay.tsx b/frontend/src/scenes/max/components/ToolsDisplay.tsx index 5b5f8008db..da43f87adf 100644 --- a/frontend/src/scenes/max/components/ToolsDisplay.tsx +++ b/frontend/src/scenes/max/components/ToolsDisplay.tsx @@ -68,7 +68,8 @@ export const ToolsDisplay: React.FC = ({ isFloating, tools, b // or border-secondary (--color-posthog-3000-400) because the former is almost invisible here, and the latter too distinct {toolDef?.icon || } - {toolDef?.name} + {/* Use the registered tool's name, so that a substituted registration (create_and_query_insight) controls how its name is displayed */} + {tool.name} ) })} @@ -173,7 +174,9 @@ function ToolsExplanation({ toolsInReverse }: { toolsInReverse: ToolRegistration if (toolDef?.subtools) { tools.push(...Object.values(toolDef.subtools)) } else { - tools.push({ name: toolDef?.name, description: toolDef?.description, icon: toolDef?.icon }) + // Take the name and description from the registered tool, not the tool definition. + // This makes tool substitution (create_and_query_insight) work correctly. + tools.push({ name: tool.name, description: tool.description, icon: toolDef?.icon }) } return tools }, diff --git a/frontend/src/scenes/max/max-constants.tsx b/frontend/src/scenes/max/max-constants.tsx index d60e737480..2400d7c9ab 100644 --- a/frontend/src/scenes/max/max-constants.tsx +++ b/frontend/src/scenes/max/max-constants.tsx @@ -25,7 +25,10 @@ export interface ToolDefinition { ToolDefinition > icon: JSX.Element - displayFormatter?: (toolCall: EnhancedToolCall) => string + displayFormatter?: ( + toolCall: EnhancedToolCall, + { registeredToolMap }: { registeredToolMap: Record<string, ToolRegistration> } + ) => string /** * If only available in a specific product, specify it here. * We're using Scene instead of ProductKey, because that's more flexible (specifically for SQL editor there @@ -160,14 +163,18 @@ export const TOOL_DEFINITIONS: Record, Tool }, }, create_and_query_insight: { - name: 'Query data', - description: 'Query data by creating insights and SQL queries', + name: 'Edit the insight', + description: "Edit the insight you're viewing", icon: iconForType('product_analytics'), - displayFormatter: (toolCall) => { - if (toolCall.status === 'completed') { - return 'Created an insight' + product: Scene.Insight, + displayFormatter: (toolCall, { registeredToolMap }) => { + const isEditing = registeredToolMap.create_and_query_insight + if (isEditing) { + return toolCall.status === 'completed' + ? 'Edited the insight you are viewing' + : 'Editing the insight you are viewing...' } - return 'Creating an insight...' 
+            return toolCall.status === 'completed' ? 'Created an insight' : 'Creating an insight...'
         },
     },
     search_session_recordings: {
@@ -327,18 +334,6 @@ export const TOOL_DEFINITIONS: Record, Tool
             return 'Fixing SQL...'
         },
     },
-    edit_current_insight: {
-        name: 'Edit the insight',
-        description: "Edit the insight you're viewing",
-        icon: iconForType('product_analytics'),
-        product: Scene.Insight,
-        displayFormatter: (toolCall) => {
-            if (toolCall.status === 'completed') {
-                return 'Edited the insight you are viewing'
-            }
-            return 'Editing the insight you are viewing...'
-        },
-    },
     filter_revenue_analytics: {
         name: 'Filter revenue analytics',
         description: 'Filter revenue analytics to find the most impactful revenue insights',
diff --git a/frontend/src/scenes/max/maxGlobalLogic.tsx b/frontend/src/scenes/max/maxGlobalLogic.tsx
index 9e238fd1aa..cac6232a74 100644
--- a/frontend/src/scenes/max/maxGlobalLogic.tsx
+++ b/frontend/src/scenes/max/maxGlobalLogic.tsx
@@ -73,8 +73,8 @@ export const STATIC_TOOLS: ToolRegistration[] = [
     },
     {
         identifier: 'create_and_query_insight' as const,
-        name: TOOL_DEFINITIONS['create_and_query_insight'].name,
-        description: TOOL_DEFINITIONS['create_and_query_insight'].description,
+        name: 'Query data',
+        description: 'Query data by creating insights and SQL queries',
     },
 ]
@@ -210,17 +210,10 @@ export const maxGlobalLogic = kea([
         ],
         toolMap: [
            (s) => [s.registeredToolMap, s.availableStaticTools],
-            (registeredToolMap, availableStaticTools) => {
-                if (registeredToolMap.edit_current_insight) {
-                    availableStaticTools = availableStaticTools.filter(
-                        (tool) => tool.identifier !== 'create_and_query_insight'
-                    )
-                }
-                return {
-                    ...Object.fromEntries(availableStaticTools.map((tool) => [tool.identifier, tool])),
-                    ...registeredToolMap,
-                }
-            },
+            (registeredToolMap, availableStaticTools) => ({
+                ...Object.fromEntries(availableStaticTools.map((tool) => [tool.identifier, tool])),
+                ...registeredToolMap,
+            }),
        ],
        tools: [(s) => [s.toolMap], (toolMap): ToolRegistration[] => Object.values(toolMap)],
        editInsightToolRegistered: [
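With `edit_current_insight` deleted, the `toolMap` selector above no longer needs any special-casing: the static `create_and_query_insight` entry ("Query data") is simply shadowed whenever a scene registers a contextual tool under the same identifier. The override is plain last-write-wins map merging. Here is the same logic transposed to Python purely for illustration (the real selector is the TypeScript in `maxGlobalLogic.tsx`, and the tool shapes are simplified):

```python
# Static tools are the defaults; contextually registered tools use the same
# identifiers. Merging the registered map second makes its entries win.
static_tools = {
    "create_and_query_insight": {
        "name": "Query data",
        "description": "Query data by creating insights and SQL queries",
    },
}
registered_tool_map = {
    "create_and_query_insight": {
        "name": "Edit the insight",
        "description": "Edit the insight you're viewing",
    },
}

tool_map = {**static_tools, **registered_tool_map}  # registered entries shadow static ones
print(tool_map["create_and_query_insight"]["name"])  # -> "Edit the insight"
```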
diff --git a/posthog/schema.py b/posthog/schema.py
index 24e071dfd4..ba02d59095 100644
--- a/posthog/schema.py
+++ b/posthog/schema.py
@@ -267,7 +267,6 @@ class AssistantTool(StrEnum):
     CREATE_HOG_FUNCTION_INPUTS = "create_hog_function_inputs"
     CREATE_MESSAGE_TEMPLATE = "create_message_template"
     NAVIGATE = "navigate"
-    EDIT_CURRENT_INSIGHT = "edit_current_insight"
     FILTER_ERROR_TRACKING_ISSUES = "filter_error_tracking_issues"
     FIND_ERROR_TRACKING_IMPACTFUL_ISSUE_EVENT_LIST = "find_error_tracking_impactful_issue_event_list"
     EXPERIMENT_RESULTS_SUMMARY = "experiment_results_summary"
diff --git a/products/dashboards/backend/test/test_max_tool_integration.py b/products/dashboards/backend/test/test_max_tool_integration.py
index 80011d75dd..37d4fa86a9 100644
--- a/products/dashboards/backend/test/test_max_tool_integration.py
+++ b/products/dashboards/backend/test/test_max_tool_integration.py
@@ -96,7 +96,6 @@ async def test_dashboard_metadata_update(dashboard_setup):
     tool = EditCurrentDashboardTool(
         team=dashboard.team,
         user=conversation.user,
-        tool_call_id="test-tool-call-id",
         config=RunnableConfig(
             configurable={
                 "thread_id": conversation.id,
@@ -133,7 +132,6 @@ async def test_dashboard_metadata_update_no_permissions(dashboard_setup_no_perms
     tool = EditCurrentDashboardTool(
         team=dashboard.team,
         user=conversation.user,
-        tool_call_id="test-tool-call-id",
         config=RunnableConfig(
             configurable={
                 "thread_id": conversation.id,
@@ -155,7 +153,6 @@ async def test_dashboard_add_insights(dashboard_setup):
     tool = EditCurrentDashboardTool(
         team=dashboard.team,
         user=conversation.user,
-        tool_call_id="test-tool-call-id",
         config=RunnableConfig(
             configurable={
                 "thread_id": conversation.id,
diff --git a/products/dashboards/backend/test/test_max_tools.py b/products/dashboards/backend/test/test_max_tools.py
index 711608f997..206417d67b 100644
--- a/products/dashboards/backend/test/test_max_tools.py
+++ b/products/dashboards/backend/test/test_max_tools.py
@@ -25,9 +25,7 @@ class TestEditCurrentDashboardTool:
         configurable = {"team": mock_team, "user": mock_user}
         if context:
             configurable["contextual_tools"] = {"edit_current_dashboard": context}
-        tool = EditCurrentDashboardTool(
-            team=mock_team, user=mock_user, config={"configurable": configurable}, tool_call_id="test-tool-call-id"
-        )
+        tool = EditCurrentDashboardTool(team=mock_team, user=mock_user, config={"configurable": configurable})
         return tool
 
     @pytest.mark.asyncio
diff --git a/products/data_warehouse/backend/max_tools.py b/products/data_warehouse/backend/max_tools.py
index 92f76f5438..b7d9384d3f 100644
--- a/products/data_warehouse/backend/max_tools.py
+++ b/products/data_warehouse/backend/max_tools.py
@@ -113,11 +113,10 @@ class HogQLGeneratorToolsNode(TaxonomyAgentToolsNode[TaxonomyAgentState, Taxonom
 
 
 class HogQLGeneratorGraph(TaxonomyAgent[TaxonomyAgentState, TaxonomyAgentState[FinalAnswerArgs]]):
-    def __init__(self, team: Team, user: User, tool_call_id: str):
+    def __init__(self, team: Team, user: User):
         super().__init__(
             team,
             user,
-            tool_call_id,
             loop_node_class=HogQLGeneratorNode,
             tools_node_class=HogQLGeneratorToolsNode,
             toolkit_class=HogQLGeneratorToolkit,
@@ -134,9 +133,7 @@ class HogQLGeneratorTool(HogQLGeneratorMixin, MaxTool):
         current_query: str | None = self.context.get("current_query", "")
         user_prompt = HOGQL_GENERATOR_USER_PROMPT.format(instructions=instructions, current_query=current_query)
 
-        graph = HogQLGeneratorGraph(
-            team=self._team, user=self._user, tool_call_id=self._tool_call_id
-        ).compile_full_graph()
+        graph = HogQLGeneratorGraph(team=self._team, user=self._user).compile_full_graph()
 
         graph_context = {
             "change": user_prompt,
diff --git a/products/data_warehouse/backend/test/test_max_tools.py b/products/data_warehouse/backend/test/test_max_tools.py
index 30ba13ea3c..c780b868bf 100644
--- a/products/data_warehouse/backend/test/test_max_tools.py
+++ b/products/data_warehouse/backend/test/test_max_tools.py
@@ -43,9 +43,7 @@ class TestDataWarehouseMaxTools(NonAtomicBaseTest):
         mock_graph.ainvoke.return_value = mock_result
         mock_compile.return_value = mock_graph
 
-        tool = HogQLGeneratorTool(
-            team=self.team, user=self.user, state=AssistantState(messages=[]), tool_call_id="test-tool-call-id"
-        )
+        tool = HogQLGeneratorTool(team=self.team, user=self.user, state=AssistantState(messages=[]))
         tool_call = AssistantToolCall(
             id="1",
             name="generate_hogql_query",
@@ -83,9 +81,7 @@ class TestDataWarehouseMaxTools(NonAtomicBaseTest):
         mock_graph.ainvoke.return_value = mock_result
         mock_compile.return_value = mock_graph
 
-        tool = HogQLGeneratorTool(
-            team=self.team, user=self.user, state=AssistantState(messages=[]), tool_call_id="test-tool-call-id"
-        )
+        tool = HogQLGeneratorTool(team=self.team, user=self.user, state=AssistantState(messages=[]))
         tool_call = AssistantToolCall(
             id="1",
             name="generate_hogql_query",
@@ -119,9 +115,7 @@ class TestDataWarehouseMaxTools(NonAtomicBaseTest):
 
         mock_compile.return_value = mock_graph
 
-        tool = HogQLGeneratorTool(
-            team=self.team, user=self.user, state=AssistantState(messages=[]), tool_call_id="test-tool-call-id"
-        )
+        tool = HogQLGeneratorTool(team=self.team, user=self.user, state=AssistantState(messages=[]))
         tool_call = AssistantToolCall(
             id="1",
             name="generate_hogql_query",
@@ -171,9 +165,7 @@ class TestDataWarehouseMaxTools(NonAtomicBaseTest):
             "SELECT suspicious_query FROM events", "Suspicious query detected"
         )
 
-        tool = HogQLGeneratorTool(
-            team=self.team, user=self.user, state=AssistantState(messages=[]), tool_call_id="test-tool-call-id"
-        )
+        tool = HogQLGeneratorTool(team=self.team, user=self.user, state=AssistantState(messages=[]))
         tool_call = AssistantToolCall(
             id="1",
             name="generate_hogql_query",
@@ -223,9 +215,7 @@ class TestDataWarehouseMaxTools(NonAtomicBaseTest):
         mock_graph.ainvoke.return_value = mock_result
         mock_compile.return_value = mock_graph
 
-        tool = HogQLGeneratorTool(
-            team=self.team, user=self.user, state=AssistantState(messages=[]), tool_call_id="test-tool-call-id"
-        )
+        tool = HogQLGeneratorTool(team=self.team, user=self.user, state=AssistantState(messages=[]))
         tool_call = AssistantToolCall(
             id="1",
             name="generate_hogql_query",
@@ -263,9 +253,7 @@ class TestDataWarehouseMaxTools(NonAtomicBaseTest):
         mock_graph.ainvoke.return_value = graph_result
         mock_compile.return_value = mock_graph
 
-        tool = HogQLGeneratorTool(
-            team=self.team, user=self.user, state=AssistantState(messages=[]), tool_call_id="test-tool-call-id"
-        )
+        tool = HogQLGeneratorTool(team=self.team, user=self.user, state=AssistantState(messages=[]))
         tool_call = AssistantToolCall(
             id="1",
             name="generate_hogql_query",
@@ -301,9 +289,7 @@ class TestDataWarehouseMaxTools(NonAtomicBaseTest):
         mock_graph.ainvoke.return_value = graph_result
         mock_compile.return_value = mock_graph
 
-        tool = HogQLGeneratorTool(
-            team=self.team, user=self.user, state=AssistantState(messages=[]), tool_call_id="test-tool-call-id"
-        )
+        tool = HogQLGeneratorTool(team=self.team, user=self.user, state=AssistantState(messages=[]))
         tool_call = AssistantToolCall(
             id="1",
             name="generate_hogql_query",
@@ -339,9 +325,7 @@ class TestDataWarehouseMaxTools(NonAtomicBaseTest):
         mock_graph.ainvoke.return_value = graph_result
         mock_compile.return_value = mock_graph
 
-        tool = HogQLGeneratorTool(
-            team=self.team, user=self.user, state=AssistantState(messages=[]), tool_call_id="test-tool-call-id"
-        )
+        tool = HogQLGeneratorTool(team=self.team, user=self.user, state=AssistantState(messages=[]))
         tool_call = AssistantToolCall(
             id="1",
             name="generate_hogql_query",
@@ -353,9 +337,7 @@ class TestDataWarehouseMaxTools(NonAtomicBaseTest):
 
     def test_current_query_included_in_system_prompt_template(self):
         """Test that the system prompt template includes the current query section."""
-        tool = HogQLGeneratorTool(
-            team=self.team, user=self.user, state=AssistantState(messages=[]), tool_call_id="test-tool-call-id"
-        )
+        tool = HogQLGeneratorTool(team=self.team, user=self.user, state=AssistantState(messages=[]))
 
         # Verify the system prompt template contains the expected current query section
         self.assertIn("The current HogQL query", tool.context_prompt_template)
diff --git a/products/error_tracking/backend/max_tools.py b/products/error_tracking/backend/max_tools.py
index 95fc69075a..d994186403 100644
--- a/products/error_tracking/backend/max_tools.py
+++ b/products/error_tracking/backend/max_tools.py
@@ -155,11 +155,10 @@ class ErrorTrackingIssueImpactToolsNode(
 
 class ErrorTrackingIssueImpactGraph(
     TaxonomyAgent[TaxonomyAgentState, TaxonomyAgentState[ErrorTrackingIssueImpactToolOutput]]
 ):
-    def __init__(self, team: Team, user: User, tool_call_id: str):
+    def __init__(self, team: Team, user: User):
         super().__init__(
             team,
             user,
-            tool_call_id,
             loop_node_class=ErrorTrackingIssueImpactLoopNode,
             tools_node_class=ErrorTrackingIssueImpactToolsNode,
             toolkit_class=ErrorTrackingIssueImpactToolkit,
@@ -177,7 +176,7 @@ class ErrorTrackingIssueImpactTool(MaxTool):
     args_schema: type[BaseModel] = IssueImpactQueryArgs
 
     async def _arun_impl(self, instructions: str) -> tuple[str, ErrorTrackingIssueImpactToolOutput]:
-        graph = ErrorTrackingIssueImpactGraph(team=self._team, user=self._user, tool_call_id=self._tool_call_id)
+        graph = ErrorTrackingIssueImpactGraph(team=self._team, user=self._user)
 
         graph_context = {
             "change": f"Goal: {instructions}",
diff --git a/products/feature_flags/backend/max_tools.py b/products/feature_flags/backend/max_tools.py
index 14d8a08b21..7cfaa5bd88 100644
--- a/products/feature_flags/backend/max_tools.py
+++ b/products/feature_flags/backend/max_tools.py
@@ -439,11 +439,10 @@ class FeatureFlagGeneratorGraph(TaxonomyAgent[TaxonomyAgentState, TaxonomyAgentS
     4. Generate structured feature flag configuration
     """
 
-    def __init__(self, team: Team, user: User, tool_call_id: str):
+    def __init__(self, team: Team, user: User):
         super().__init__(
             team,
             user,
-            tool_call_id,
             loop_node_class=FeatureFlagCreationNode,
             tools_node_class=FeatureFlagCreationToolsNode,
             toolkit_class=FeatureFlagToolkit,
@@ -501,7 +500,7 @@ The tool will automatically:
 
     async def _create_flag_from_instructions(self, instructions: str) -> FeatureFlagCreationSchema:
         """Use TaxonomyAgent graph to generate structured flag configuration."""
-        graph = FeatureFlagGeneratorGraph(team=self._team, user=self._user, tool_call_id=self._tool_call_id)
+        graph = FeatureFlagGeneratorGraph(team=self._team, user=self._user)
 
         graph_context = {
             "change": f"Create a feature flag based on these instructions: {instructions}",
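Every `TaxonomyAgent` subclass in this patch drops `tool_call_id` the same way, so the constructor surface that remains is just team, user, and the node/toolkit classes. A distilled, runnable sketch of that shape follows; `TaxonomyAgentSketch` and the `object` placeholders are hypothetical stand-ins (the real base class lives in `ee/hogai` and takes real node and toolkit types, and its parameters may not be keyword-only):

```python
class TaxonomyAgentSketch:
    """Stand-in for the real TaxonomyAgent base class: note no tool_call_id."""

    def __init__(self, team, user, *, loop_node_class, tools_node_class, toolkit_class):
        self._team = team
        self._user = user
        self._loop_node_class = loop_node_class
        self._tools_node_class = tools_node_class
        self._toolkit_class = toolkit_class


class MyProductGraph(TaxonomyAgentSketch):
    """What a product-specific graph subclass reduces to after this patch."""

    def __init__(self, team, user):
        super().__init__(
            team,
            user,
            loop_node_class=object,   # placeholders for the product's node classes
            tools_node_class=object,
            toolkit_class=object,
        )


graph = MyProductGraph(team="team", user="user")  # no tool_call_id anywhere
```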
diff --git a/products/product_analytics/backend/max_tools.py b/products/product_analytics/backend/max_tools.py
deleted file mode 100644
index 6af7ccdb0c..0000000000
--- a/products/product_analytics/backend/max_tools.py
+++ /dev/null
@@ -1,137 +0,0 @@
-from uuid import uuid4
-
-from pydantic import BaseModel, Field
-
-from posthog.schema import AssistantMessage, AssistantToolCallMessage, VisualizationMessage
-
-from ee.hogai.tool import MaxTool, ToolMessagesArtifact
-from ee.hogai.utils.types import AssistantState
-
-QUERY_KIND_DESCRIPTION_PROMPT = """
-## Trends
-A trends insight visualizes events over time using time series. They're useful for finding patterns in historical data.
-
-The trends insights have the following features:
-- The insight can show multiple trends in one request.
-- Custom formulas can calculate derived metrics, like `A/B*100` to calculate a ratio.
-- Filter and break down data using multiple properties.
-- Compare with the previous period and sample data.
-- Apply various aggregation types, like sum, average, etc., and chart types.
-- And more.
-
-Examples of use cases include:
-- How the product's most important metrics change over time.
-- Long-term patterns, or cycles in product's usage.
-- The usage of different features side-by-side.
-- How the properties of events vary using aggregation (sum, average, etc).
-- Users can also visualize the same data points in a variety of ways.
-
-## Funnel
-A funnel insight visualizes a sequence of events that users go through in a product. They use percentages as the primary aggregation type. Funnels use two or more series, so the conversation history should mention at least two events.
-
-The funnel insights have the following features:
-- Various visualization types (steps, time-to-convert, historical trends).
-- Filter data and apply exclusion steps.
-- Break down data using a single property.
-- Specify conversion windows, details of conversion calculation, attribution settings.
-- Sample data.
-- And more.
-
-Examples of use cases include:
-- Conversion rates.
-- Drop off steps.
-- Steps with the highest friction and time to convert.
-- If product changes are improving their funnel over time.
-- Average/median time to convert.
-- Conversion trends over time.
-
-## Retention
-A retention insight visualizes how many users return to the product after performing some action. They're useful for understanding user engagement and retention.
-
-The retention insights have the following features: filter data, sample data, and more.
-
-Examples of use cases include:
-- How many users come back and perform an action after their first visit.
-- How many users come back to perform action X after performing action Y.
-- How often users return to use a specific feature.
-
-## SQL
-The 'sql' insight type allows you to write arbitrary SQL queries to retrieve data.
-
-The SQL insights have the following features:
-- Filter data using arbitrary SQL.
-- All ClickHouse SQL features.
-- You can nest subqueries as needed.
-""".strip()
-
-
-class EditCurrentInsightArgs(BaseModel):
-    """
-    Edits the insight visualization the user is currently working on, by creating a query or iterating on a previous query.
-    """
-
-    query_description: str = Field(
-        description="The new query to edit the current insight. Must include all details from the current insight plus any change on top of them. Include any relevant information from the current conversation, as the tool does not have access to the conversation."
-    )
-    query_kind: str = Field(description=QUERY_KIND_DESCRIPTION_PROMPT)
-
-
-class EditCurrentInsightTool(MaxTool):
-    name: str = "edit_current_insight"
-    description: str = (
-        "Update the insight the user is currently working on, based on the current insight's JSON schema."
-    )
-    context_prompt_template: str = """The user is currently editing an insight (aka query). Here is that insight's current definition, which can be edited using the `edit_current_insight` tool:
-
-```json
-{current_query}
-```
-
-IMPORTANT: DO NOT REMOVE ANY FIELDS FROM THE CURRENT INSIGHT DEFINITION. DO NOT CHANGE ANY OTHER FIELDS THAN THE ONES THE USER ASKED FOR. KEEP THE REST AS IS.
-""".strip()
-
-    args_schema: type[BaseModel] = EditCurrentInsightArgs
-
-    async def _arun_impl(self, query_kind: str, query_description: str) -> tuple[str, ToolMessagesArtifact]:
-        from ee.hogai.graph.graph import InsightsAssistantGraph  # avoid circular import
-
-        if "current_query" not in self.context:
-            raise ValueError("Context `current_query` is required for the `create_and_query_insight` tool")
-
-        graph = InsightsAssistantGraph(self._team, self._user, tool_call_id=self._tool_call_id).compile_full_graph()
-        state = self._state
-        last_message = state.messages[-1]
-        if not isinstance(last_message, AssistantMessage):
-            raise ValueError("Last message is not an AssistantMessage")
-        if last_message.tool_calls is None or len(last_message.tool_calls) == 0:
-            raise ValueError("Last message has no tool calls")
-
-        state.root_tool_insight_plan = query_description
-        root_tool_call_id = last_message.tool_calls[0].id
-
-        # We need to set a new root tool call id to sub-nest the graph within the contextual tool call
-        # and avoid duplicating messages in the stream
-        state.root_tool_call_id = str(uuid4())
-
-        state_dict = await graph.ainvoke(state, config=self._config)
-        state = AssistantState.model_validate(state_dict)
-
-        result = state.messages[-1]
-        viz_messages = [message for message in state.messages if isinstance(message, VisualizationMessage)]
-        viz_message = viz_messages[-1] if viz_messages else None
-        if not viz_message:
-            raise ValueError("Visualization was not generated")
-        if not isinstance(result, AssistantToolCallMessage):
-            raise ValueError("Last message is not an AssistantToolCallMessage")
-
-        return "", ToolMessagesArtifact(
-            messages=[
-                viz_message,
-                AssistantToolCallMessage(
-                    content=result.content,
-                    ui_payload={self.get_name(): viz_message.answer.model_dump(exclude_none=True)},
-                    id=str(uuid4()),
-                    tool_call_id=root_tool_call_id,
-                ),
-            ]
-        )
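After the deletion above, the remaining diffs all converge on the same `MaxTool` surface: tools are constructed from team and user (plus optional `state`/`config`), with no `tool_call_id`, and read per-call data from `self.context` at run time. A runnable sketch of that surface, assuming a hypothetical `MaxToolSketch` stand-in for the real `ee.hogai.tool.MaxTool` (the `EchoTool` and its return payload are illustrative only):

```python
import asyncio


class MaxToolSketch:
    """Hypothetical stand-in, reduced to the constructor surface the updated
    tests exercise: team/user plus optional state and config, no tool_call_id."""

    def __init__(self, team, user, state=None, config=None):
        self._team, self._user, self._state, self._config = team, user, state, config
        self.context: dict = {}  # contextual-tool payload, set by the caller


class EchoTool(MaxToolSketch):
    name = "echo_tool"

    async def _arun_impl(self, instructions: str) -> tuple[str, dict]:
        # Per-call data now travels via context/state rather than the constructor
        current_query = self.context.get("current_query", "")
        return f"Received: {instructions}", {"current_query": current_query}


tool = EchoTool(team="team", user="user")
tool.context["current_query"] = "SELECT 1"
print(asyncio.run(tool._arun_impl("tweak the query")))
```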
""" - graph = SessionReplayFilterOptionsGraph(team=self._team, user=self._user, tool_call_id=self._tool_call_id) + graph = SessionReplayFilterOptionsGraph(team=self._team, user=self._user) pretty_filters = json.dumps(self.context.get("current_filters", {}), indent=2) user_prompt = USER_FILTER_OPTIONS_PROMPT.format(change=change, current_filters=pretty_filters) graph_context = { diff --git a/products/revenue_analytics/backend/max_tools.py b/products/revenue_analytics/backend/max_tools.py index d34d2323a7..afc5d922b7 100644 --- a/products/revenue_analytics/backend/max_tools.py +++ b/products/revenue_analytics/backend/max_tools.py @@ -153,11 +153,10 @@ class RevenueAnalyticsFilterOptionsGraph( ): """Graph for generating filtering options for revenue analytics.""" - def __init__(self, team: Team, user: User, tool_call_id: str): + def __init__(self, team: Team, user: User): super().__init__( team, user, - tool_call_id, loop_node_class=RevenueAnalyticsFilterNode, tools_node_class=RevenueAnalyticsFilterOptionsToolsNode, toolkit_class=RevenueAnalyticsFilterOptionsToolkit, @@ -192,7 +191,7 @@ class FilterRevenueAnalyticsTool(MaxTool): Reusable method to call graph to avoid code/prompt duplication and enable different processing of the results, based on the place the tool is used. """ - graph = RevenueAnalyticsFilterOptionsGraph(team=self._team, user=self._user, tool_call_id=self._tool_call_id) + graph = RevenueAnalyticsFilterOptionsGraph(team=self._team, user=self._user) pretty_filters = json.dumps(self.context.get("current_filters", {}), indent=2) user_prompt = USER_FILTER_OPTIONS_PROMPT.format(change=change, current_filters=pretty_filters) graph_context = { diff --git a/products/surveys/backend/max_tools.py b/products/surveys/backend/max_tools.py index 547074fc43..2ddbb20c76 100644 --- a/products/surveys/backend/max_tools.py +++ b/products/surveys/backend/max_tools.py @@ -52,7 +52,7 @@ class CreateSurveyTool(MaxTool): Create a survey from natural language instructions. 
""" - graph = FeatureFlagLookupGraph(team=self._team, user=self._user, tool_call_id=self._tool_call_id) + graph = FeatureFlagLookupGraph(team=self._team, user=self._user) graph_context = { "change": f"Create a survey based on these instructions: {instructions}", @@ -286,11 +286,10 @@ class SurveyLookupToolsNode(TaxonomyAgentToolsNode[TaxonomyAgentState, TaxonomyA class FeatureFlagLookupGraph(TaxonomyAgent[TaxonomyAgentState, TaxonomyAgentState[SurveyCreationSchema]]): """Graph for feature flag lookup operations.""" - def __init__(self, team: Team, user: User, tool_call_id: str): + def __init__(self, team: Team, user: User): super().__init__( team, user, - tool_call_id, loop_node_class=SurveyLoopNode, tools_node_class=SurveyLookupToolsNode, toolkit_class=SurveyToolkit, diff --git a/products/surveys/backend/test_max_tools.py b/products/surveys/backend/test_max_tools.py index 7380676fa1..e097918e44 100644 --- a/products/surveys/backend/test_max_tools.py +++ b/products/surveys/backend/test_max_tools.py @@ -45,7 +45,7 @@ class TestSurveyCreatorTool(BaseTest): def _setup_tool(self): """Helper to create a SurveyCreatorTool instance with mocked dependencies""" - tool = CreateSurveyTool(team=self.team, user=self.user, config=self._config, tool_call_id="test-tool-call-id") + tool = CreateSurveyTool(team=self.team, user=self.user, config=self._config) return tool def test_get_team_survey_config(self): @@ -468,7 +468,6 @@ class TestSurveyAnalysisTool(BaseTest): tool = SurveyAnalysisTool( team=self.team, user=self.user, - tool_call_id="test-tool-call-id", config={ **self._config, "configurable": {