mirror of https://github.com/BillyOutlast/posthog.git
synced 2026-02-04 03:01:23 +01:00
feat(max): parallel tool calls (#39954)
Co-authored-by: kappa90 <e.capparelli@gmail.com>
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
@@ -320,7 +320,7 @@ class YourTaxonomyGraph(TaxonomyAgent[TaxonomyAgentState, TaxonomyAgentState[Max
4. Invoke it (typically from a `MaxTool`), mirroring `products/replay/backend/max_tools.py`:

```python
graph = YourTaxonomyGraph(team=self._team, user=self._user, tool_call_id=self._tool_call_id)
graph = YourTaxonomyGraph(team=self._team, user=self._user)

graph_context = {
    "change": "Show me recordings of users in Germany that used a mobile device while performing a payment",
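For orientation, a sketch of how such an invocation typically continues: compile the graph and await `ainvoke` with the context. This is not part of the diff; it is modeled on the eval fixtures changed later in this commit (for example the `FeatureFlagLookupGraph` and `SessionReplayFilterOptionsGraph` fixtures), and `compile_full_graph`, `DjangoCheckpointer`, and the `conversation` record are assumptions carried over from them rather than documented API for `YourTaxonomyGraph`.

```python
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer

# Assumed continuation of the MaxTool invocation above.
compiled = graph.compile_full_graph(checkpointer=DjangoCheckpointer())
result = await compiled.ainvoke(
    graph_context,
    # `conversation` is a hypothetical Conversation row providing the thread id.
    config={"configurable": {"thread_id": conversation.id}},
)
```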
@@ -9,7 +9,6 @@ import structlog
|
||||
import posthoganalytics
|
||||
from asgiref.sync import async_to_sync
|
||||
from langchain_core.callbacks.base import BaseCallbackHandler
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.errors import GraphRecursionError
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
@@ -31,20 +30,19 @@ from posthog.event_usage import report_user_action
|
||||
from posthog.models import Team, User
|
||||
from posthog.sync import database_sync_to_async
|
||||
|
||||
from ee.hogai.graph.base import BaseAssistantNode
|
||||
from ee.hogai.utils.exceptions import GenerationCanceled
|
||||
from ee.hogai.utils.helpers import extract_stream_update
|
||||
from ee.hogai.utils.state import (
|
||||
GraphValueUpdateTuple,
|
||||
is_message_update,
|
||||
is_state_update,
|
||||
is_value_update,
|
||||
validate_state_update,
|
||||
from ee.hogai.utils.state import validate_state_update
|
||||
from ee.hogai.utils.stream_processor import AssistantStreamProcessorProtocol
|
||||
from ee.hogai.utils.types.base import (
|
||||
AssistantDispatcherEvent,
|
||||
AssistantMessageUnion,
|
||||
AssistantMode,
|
||||
AssistantOutput,
|
||||
AssistantResultUnion,
|
||||
LangGraphUpdateEvent,
|
||||
)
|
||||
from ee.hogai.utils.stream_processor import AssistantStreamProcessor
|
||||
from ee.hogai.utils.types import AssistantMessageUnion, AssistantOutput
|
||||
from ee.hogai.utils.types.base import AssistantDispatcherEvent, AssistantMode, AssistantResultUnion, MessageChunkAction
|
||||
from ee.hogai.utils.types.composed import AssistantMaxGraphState, AssistantMaxPartialGraphState, MaxNodeName
|
||||
from ee.hogai.utils.types.composed import AssistantMaxGraphState, AssistantMaxPartialGraphState
|
||||
from ee.models import Conversation
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
@@ -66,7 +64,7 @@ class BaseAssistant(ABC):
|
||||
_trace_id: Optional[str | UUID]
|
||||
_billing_context: Optional[MaxBillingContext]
|
||||
_initial_state: Optional[AssistantMaxGraphState | AssistantMaxPartialGraphState]
|
||||
_stream_processor: AssistantStreamProcessor
|
||||
_stream_processor: AssistantStreamProcessorProtocol
|
||||
"""The stream processor that processes dispatcher actions and message chunks."""
|
||||
|
||||
def __init__(
|
||||
@@ -87,6 +85,7 @@ class BaseAssistant(ABC):
|
||||
billing_context: Optional[MaxBillingContext] = None,
|
||||
initial_state: Optional[AssistantMaxGraphState | AssistantMaxPartialGraphState] = None,
|
||||
callback_handler: Optional[BaseCallbackHandler] = None,
|
||||
stream_processor: AssistantStreamProcessorProtocol,
|
||||
):
|
||||
self._team = team
|
||||
self._contextual_tools = contextual_tools or {}
|
||||
@@ -121,22 +120,7 @@ class BaseAssistant(ABC):
|
||||
self._mode = mode
|
||||
self._initial_state = initial_state
|
||||
# Initialize the stream processor with node configuration
|
||||
self._stream_processor = AssistantStreamProcessor(
|
||||
streaming_nodes=self.STREAMING_NODES,
|
||||
visualization_nodes=self.VISUALIZATION_NODES,
|
||||
)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def VISUALIZATION_NODES(self) -> dict[MaxNodeName, type[BaseAssistantNode]]:
|
||||
"""Nodes that can generate visualizations."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def STREAMING_NODES(self) -> set[MaxNodeName]:
|
||||
"""Nodes that can stream messages to the client."""
|
||||
pass
|
||||
self._stream_processor = stream_processor
|
||||
|
||||
@abstractmethod
|
||||
def get_initial_state(self) -> AssistantMaxGraphState:
|
||||
@@ -175,7 +159,7 @@ class BaseAssistant(ABC):
|
||||
state = await self._init_or_update_state()
|
||||
config = self._get_config()
|
||||
|
||||
stream_mode: list[StreamMode] = ["values", "updates", "custom"]
|
||||
stream_mode: list[StreamMode] = ["values", "custom"]
|
||||
if stream_message_chunks:
|
||||
stream_mode.append("messages")
|
||||
|
||||
@@ -286,11 +270,11 @@ class BaseAssistant(ABC):
|
||||
# Add existing ids to streamed messages, so we don't send the messages again.
|
||||
for message in saved_state.messages:
|
||||
if message.id is not None:
|
||||
self._stream_processor._streamed_update_ids.add(message.id)
|
||||
self._stream_processor.mark_id_as_streamed(message.id)
|
||||
|
||||
# Add the latest message id to streamed messages, so we don't send it multiple times.
|
||||
if self._latest_message and self._latest_message.id is not None:
|
||||
self._stream_processor._streamed_update_ids.add(self._latest_message.id)
|
||||
self._stream_processor.mark_id_as_streamed(self._latest_message.id)
|
||||
|
||||
# If the graph previously hasn't reset the state, it is an interrupt. We resume from the point of interruption.
|
||||
if snapshot.next and self._latest_message and saved_state.graph_status == "interrupted":
|
||||
@@ -322,29 +306,12 @@ class BaseAssistant(ABC):
|
||||
async def _process_update(self, update: Any) -> list[AssistantResultUnion] | None:
|
||||
update = extract_stream_update(update)
|
||||
|
||||
new_message: AssistantResultUnion | None = None
|
||||
if not isinstance(update, AssistantDispatcherEvent):
|
||||
if is_state_update(update):
|
||||
_, new_state = update
|
||||
self._state = validate_state_update(new_state, self._state_type)
|
||||
elif is_value_update(update) and (new_message := await self._aprocess_value_update(update)):
|
||||
return [new_message]
|
||||
if updates := self._stream_processor.process_langgraph_update(LangGraphUpdateEvent(update=update)):
|
||||
return updates
|
||||
elif new_message := self._stream_processor.process(update):
|
||||
return new_message
|
||||
|
||||
if is_message_update(update):
|
||||
# Convert the message chunk update to a dispatcher event to prepare for a bright future without LangGraph
|
||||
message, state = update[1]
|
||||
if not isinstance(message, AIMessageChunk):
|
||||
return None
|
||||
update = AssistantDispatcherEvent(
|
||||
action=MessageChunkAction(message=message), node_name=state["langgraph_node"]
|
||||
)
|
||||
|
||||
if isinstance(update, AssistantDispatcherEvent) and (new_message := self._stream_processor.process(update)):
|
||||
return [new_message] if new_message else None
|
||||
|
||||
return None
|
||||
|
||||
async def _aprocess_value_update(self, update: GraphValueUpdateTuple) -> AssistantResultUnion | None:
|
||||
return None
|
||||
|
||||
def _build_root_config_for_persistence(self) -> RunnableConfig:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from posthog.schema import AssistantMessage, HumanMessage, MaxBillingContext, VisualizationMessage
|
||||
@@ -7,13 +7,26 @@ from posthog.schema import AssistantMessage, HumanMessage, MaxBillingContext, Vi
|
||||
from posthog.models import Team, User
|
||||
|
||||
from ee.hogai.assistant.base import BaseAssistant
|
||||
from ee.hogai.graph import DeepResearchAssistantGraph
|
||||
from ee.hogai.graph.base import BaseAssistantNode
|
||||
from ee.hogai.graph.deep_research.graph import DeepResearchAssistantGraph
|
||||
from ee.hogai.graph.deep_research.types import DeepResearchNodeName, DeepResearchState, PartialDeepResearchState
|
||||
from ee.hogai.utils.stream_processor import AssistantStreamProcessor
|
||||
from ee.hogai.utils.types import AssistantMode, AssistantOutput
|
||||
from ee.hogai.utils.types.composed import MaxNodeName
|
||||
from ee.models import Conversation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ee.hogai.utils.types.composed import MaxNodeName
|
||||
|
||||
|
||||
STREAMING_NODES: set["MaxNodeName"] = {
|
||||
DeepResearchNodeName.ONBOARDING,
|
||||
DeepResearchNodeName.PLANNER,
|
||||
DeepResearchNodeName.TASK_EXECUTOR,
|
||||
}
|
||||
|
||||
VERBOSE_NODES: set["MaxNodeName"] = STREAMING_NODES | {
|
||||
DeepResearchNodeName.PLANNER_TOOLS,
|
||||
}
|
||||
|
||||
|
||||
class DeepResearchAssistant(BaseAssistant):
|
||||
_state: Optional[DeepResearchState]
|
||||
@@ -48,20 +61,13 @@ class DeepResearchAssistant(BaseAssistant):
|
||||
trace_id=trace_id,
|
||||
billing_context=billing_context,
|
||||
initial_state=initial_state,
|
||||
stream_processor=AssistantStreamProcessor(
|
||||
verbose_nodes=VERBOSE_NODES,
|
||||
streaming_nodes=STREAMING_NODES,
|
||||
state_type=DeepResearchState,
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def VISUALIZATION_NODES(self) -> dict[MaxNodeName, type[BaseAssistantNode]]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def STREAMING_NODES(self) -> set[MaxNodeName]:
|
||||
return {
|
||||
DeepResearchNodeName.ONBOARDING,
|
||||
DeepResearchNodeName.PLANNER,
|
||||
DeepResearchNodeName.TASK_EXECUTOR,
|
||||
}
|
||||
|
||||
def get_initial_state(self) -> DeepResearchState:
|
||||
if self._latest_message:
|
||||
return DeepResearchState(
|
||||
|
||||
@@ -1,36 +1,32 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from posthog.schema import (
|
||||
AssistantGenerationStatusEvent,
|
||||
AssistantGenerationStatusType,
|
||||
AssistantMessage,
|
||||
HumanMessage,
|
||||
MaxBillingContext,
|
||||
VisualizationMessage,
|
||||
)
|
||||
from posthog.schema import AssistantMessage, HumanMessage, MaxBillingContext, VisualizationMessage
|
||||
|
||||
from posthog.models import Team, User
|
||||
|
||||
from ee.hogai.assistant.base import BaseAssistant
|
||||
from ee.hogai.graph import FunnelGeneratorNode, RetentionGeneratorNode, SQLGeneratorNode, TrendsGeneratorNode
|
||||
from ee.hogai.graph.base import BaseAssistantNode
|
||||
from ee.hogai.graph.graph import InsightsAssistantGraph
|
||||
from ee.hogai.graph.query_executor.nodes import QueryExecutorNode
|
||||
from ee.hogai.graph.taxonomy.types import TaxonomyNodeName
|
||||
from ee.hogai.utils.state import GraphValueUpdateTuple, validate_value_update
|
||||
from ee.hogai.utils.types import (
|
||||
AssistantMode,
|
||||
AssistantNodeName,
|
||||
AssistantOutput,
|
||||
AssistantState,
|
||||
PartialAssistantState,
|
||||
)
|
||||
from ee.hogai.utils.types.base import AssistantResultUnion
|
||||
from ee.hogai.utils.types.composed import MaxNodeName
|
||||
from ee.hogai.graph.insights_graph.graph import InsightsGraph
|
||||
from ee.hogai.utils.stream_processor import AssistantStreamProcessor
|
||||
from ee.hogai.utils.types import AssistantMode, AssistantOutput, AssistantState, PartialAssistantState
|
||||
from ee.hogai.utils.types.base import AssistantNodeName
|
||||
from ee.models import Conversation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ee.hogai.utils.types.composed import MaxNodeName
|
||||
|
||||
|
||||
VERBOSE_NODES: set["MaxNodeName"] = {
|
||||
AssistantNodeName.QUERY_EXECUTOR,
|
||||
AssistantNodeName.FUNNEL_GENERATOR,
|
||||
AssistantNodeName.RETENTION_GENERATOR,
|
||||
AssistantNodeName.SQL_GENERATOR,
|
||||
AssistantNodeName.TRENDS_GENERATOR,
|
||||
AssistantNodeName.ROOT,
|
||||
AssistantNodeName.ROOT_TOOLS,
|
||||
}
|
||||
|
||||
|
||||
class InsightsAssistant(BaseAssistant):
|
||||
_state: Optional[AssistantState]
|
||||
@@ -55,7 +51,7 @@ class InsightsAssistant(BaseAssistant):
|
||||
conversation,
|
||||
new_message=new_message,
|
||||
user=user,
|
||||
graph=InsightsAssistantGraph(team, user).compile_full_graph(),
|
||||
graph=InsightsGraph(team, user).compile_full_graph(),
|
||||
state_type=AssistantState,
|
||||
partial_state_type=PartialAssistantState,
|
||||
mode=AssistantMode.INSIGHTS_TOOL,
|
||||
@@ -65,24 +61,11 @@ class InsightsAssistant(BaseAssistant):
|
||||
trace_id=trace_id,
|
||||
billing_context=billing_context,
|
||||
initial_state=initial_state,
|
||||
stream_processor=AssistantStreamProcessor(
|
||||
verbose_nodes=VERBOSE_NODES, streaming_nodes=set(), state_type=AssistantState
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def VISUALIZATION_NODES(self) -> dict[MaxNodeName, type[BaseAssistantNode]]:
|
||||
return {
|
||||
AssistantNodeName.TRENDS_GENERATOR: TrendsGeneratorNode,
|
||||
AssistantNodeName.FUNNEL_GENERATOR: FunnelGeneratorNode,
|
||||
AssistantNodeName.RETENTION_GENERATOR: RetentionGeneratorNode,
|
||||
AssistantNodeName.SQL_GENERATOR: SQLGeneratorNode,
|
||||
AssistantNodeName.QUERY_EXECUTOR: QueryExecutorNode,
|
||||
}
|
||||
|
||||
@property
|
||||
def STREAMING_NODES(self) -> set[MaxNodeName]:
|
||||
return {
|
||||
TaxonomyNodeName.LOOP_NODE,
|
||||
}
|
||||
|
||||
def get_initial_state(self) -> AssistantState:
|
||||
return AssistantState(messages=[])
|
||||
|
||||
@@ -127,13 +110,3 @@ class InsightsAssistant(BaseAssistant):
|
||||
"is_new_conversation": False,
|
||||
},
|
||||
)
|
||||
|
||||
async def _aprocess_value_update(self, update: GraphValueUpdateTuple) -> AssistantResultUnion | None:
|
||||
_, maybe_state_update = update
|
||||
state_update = validate_value_update(maybe_state_update)
|
||||
if intersected_nodes := state_update.keys() & self.VISUALIZATION_NODES.keys():
|
||||
node_name: MaxNodeName = intersected_nodes.pop()
|
||||
node_val = state_update[node_name]
|
||||
if isinstance(node_val, PartialAssistantState) and node_val.intermediate_steps:
|
||||
return AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR)
|
||||
return await super()._aprocess_value_update(update)
|
||||
|
||||
@@ -1,29 +1,15 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from posthog.schema import (
|
||||
AssistantGenerationStatusEvent,
|
||||
AssistantGenerationStatusType,
|
||||
AssistantMessage,
|
||||
HumanMessage,
|
||||
MaxBillingContext,
|
||||
VisualizationMessage,
|
||||
)
|
||||
from posthog.schema import AssistantMessage, HumanMessage, MaxBillingContext, VisualizationMessage
|
||||
|
||||
from posthog.models import Team, User
|
||||
|
||||
from ee.hogai.assistant.base import BaseAssistant
|
||||
from ee.hogai.graph import (
|
||||
AssistantGraph,
|
||||
FunnelGeneratorNode,
|
||||
RetentionGeneratorNode,
|
||||
SQLGeneratorNode,
|
||||
TrendsGeneratorNode,
|
||||
)
|
||||
from ee.hogai.graph.base import BaseAssistantNode
|
||||
from ee.hogai.graph.insights.nodes import InsightSearchNode
|
||||
from ee.hogai.utils.state import GraphValueUpdateTuple, validate_value_update
|
||||
from ee.hogai.graph.graph import AssistantGraph
|
||||
from ee.hogai.graph.taxonomy.types import TaxonomyNodeName
|
||||
from ee.hogai.utils.stream_processor import AssistantStreamProcessor
|
||||
from ee.hogai.utils.types import (
|
||||
AssistantMode,
|
||||
AssistantNodeName,
|
||||
@@ -31,10 +17,36 @@ from ee.hogai.utils.types import (
|
||||
AssistantState,
|
||||
PartialAssistantState,
|
||||
)
|
||||
from ee.hogai.utils.types.base import AssistantResultUnion
|
||||
from ee.hogai.utils.types.composed import MaxNodeName
|
||||
from ee.models import Conversation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ee.hogai.utils.types.composed import MaxNodeName
|
||||
|
||||
|
||||
STREAMING_NODES: set["MaxNodeName"] = {
|
||||
AssistantNodeName.ROOT,
|
||||
AssistantNodeName.INKEEP_DOCS,
|
||||
AssistantNodeName.MEMORY_ONBOARDING,
|
||||
AssistantNodeName.MEMORY_INITIALIZER,
|
||||
AssistantNodeName.MEMORY_ONBOARDING_ENQUIRY,
|
||||
AssistantNodeName.MEMORY_ONBOARDING_FINALIZE,
|
||||
AssistantNodeName.DASHBOARD_CREATION,
|
||||
}
|
||||
|
||||
|
||||
VERBOSE_NODES: set["MaxNodeName"] = {
|
||||
AssistantNodeName.TRENDS_GENERATOR,
|
||||
AssistantNodeName.FUNNEL_GENERATOR,
|
||||
AssistantNodeName.RETENTION_GENERATOR,
|
||||
AssistantNodeName.SQL_GENERATOR,
|
||||
AssistantNodeName.INSIGHTS_SEARCH,
|
||||
AssistantNodeName.QUERY_EXECUTOR,
|
||||
AssistantNodeName.MEMORY_INITIALIZER_INTERRUPT,
|
||||
AssistantNodeName.ROOT_TOOLS,
|
||||
TaxonomyNodeName.TOOLS_NODE,
|
||||
TaxonomyNodeName.TASK_EXECUTOR,
|
||||
}
|
||||
|
||||
|
||||
class MainAssistant(BaseAssistant):
|
||||
_state: Optional[AssistantState]
|
||||
@@ -69,30 +81,11 @@ class MainAssistant(BaseAssistant):
|
||||
trace_id=trace_id,
|
||||
billing_context=billing_context,
|
||||
initial_state=initial_state,
|
||||
stream_processor=AssistantStreamProcessor(
|
||||
verbose_nodes=VERBOSE_NODES, streaming_nodes=STREAMING_NODES, state_type=AssistantState
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def VISUALIZATION_NODES(self) -> dict[MaxNodeName, type[BaseAssistantNode]]:
|
||||
return {
|
||||
AssistantNodeName.TRENDS_GENERATOR: TrendsGeneratorNode,
|
||||
AssistantNodeName.FUNNEL_GENERATOR: FunnelGeneratorNode,
|
||||
AssistantNodeName.RETENTION_GENERATOR: RetentionGeneratorNode,
|
||||
AssistantNodeName.SQL_GENERATOR: SQLGeneratorNode,
|
||||
AssistantNodeName.INSIGHTS_SEARCH: InsightSearchNode,
|
||||
}
|
||||
|
||||
@property
|
||||
def STREAMING_NODES(self) -> set[MaxNodeName]:
|
||||
return {
|
||||
AssistantNodeName.ROOT,
|
||||
AssistantNodeName.INKEEP_DOCS,
|
||||
AssistantNodeName.MEMORY_ONBOARDING,
|
||||
AssistantNodeName.MEMORY_INITIALIZER,
|
||||
AssistantNodeName.MEMORY_ONBOARDING_ENQUIRY,
|
||||
AssistantNodeName.MEMORY_ONBOARDING_FINALIZE,
|
||||
AssistantNodeName.DASHBOARD_CREATION,
|
||||
}
|
||||
|
||||
def get_initial_state(self) -> AssistantState:
|
||||
if self._latest_message:
|
||||
return AssistantState(
|
||||
@@ -142,13 +135,3 @@ class MainAssistant(BaseAssistant):
|
||||
"is_new_conversation": self._is_new_conversation,
|
||||
},
|
||||
)
|
||||
|
||||
async def _aprocess_value_update(self, update: GraphValueUpdateTuple) -> AssistantResultUnion | None:
|
||||
_, maybe_state_update = update
|
||||
state_update = validate_value_update(maybe_state_update)
|
||||
if intersected_nodes := state_update.keys() & self.VISUALIZATION_NODES.keys():
|
||||
node_name: MaxNodeName = intersected_nodes.pop()
|
||||
node_val = state_update[node_name]
|
||||
if isinstance(node_val, PartialAssistantState) and node_val.intermediate_steps:
|
||||
return AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR)
|
||||
return await super()._aprocess_value_update(update)
|
||||
|
||||
@@ -396,7 +396,7 @@ class AssistantContextManager(AssistantContextMixin):
|
||||
|
||||
contextual_tools_prompt = [
|
||||
f"<{tool_name}>\n"
|
||||
f"{get_contextual_tool_class(tool_name)(team=self._team, user=self._user, tool_call_id="").format_context_prompt_injection(tool_context)}\n" # type: ignore
|
||||
f"{get_contextual_tool_class(tool_name)(team=self._team, user=self._user).format_context_prompt_injection(tool_context)}\n" # type: ignore
|
||||
f"</{tool_name}>"
|
||||
for tool_name, tool_context in self.get_contextual_tools().items()
|
||||
if get_contextual_tool_class(tool_name) is not None
|
||||
|
||||
@@ -507,7 +507,7 @@ Query results: 42 events
|
||||
# Mock the tool class
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.format_context_prompt_injection.return_value = "Tool system prompt"
|
||||
mock_get_contextual_tool_class.return_value = lambda team, user, tool_call_id: mock_tool
|
||||
mock_get_contextual_tool_class.return_value = lambda team, user: mock_tool
|
||||
|
||||
config = RunnableConfig(
|
||||
configurable={"contextual_tools": {"search_session_recordings": {"current_filters": {}}}}
|
||||
|
||||
@@ -52,7 +52,7 @@ from ee.hogai.eval.base import MaxPrivateEval
|
||||
from ee.hogai.eval.offline.conftest import EvaluationContext, capture_score, get_eval_context
|
||||
from ee.hogai.eval.schema import DatasetInput
|
||||
from ee.hogai.eval.scorers.sql import SQLSemanticsCorrectness, SQLSyntaxCorrectness
|
||||
from ee.hogai.graph import AssistantGraph
|
||||
from ee.hogai.graph.graph import AssistantGraph
|
||||
from ee.hogai.utils.types import AssistantState
|
||||
from ee.models import Conversation
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
|
||||
# We want the PostHog set_up_evals fixture here
|
||||
from ee.hogai.eval.conftest import set_up_evals # noqa: F401
|
||||
from ee.hogai.eval.scorers import PlanAndQueryOutput
|
||||
from ee.hogai.graph.graph import AssistantGraph, InsightsAssistantGraph
|
||||
from ee.hogai.graph.graph import AssistantGraph
|
||||
from ee.hogai.utils.types import AssistantNodeName, AssistantState
|
||||
from ee.models.assistant import Conversation, CoreMemory
|
||||
|
||||
@@ -34,28 +34,10 @@ EVAL_USER_FULL_NAME = "Karen Smith"
|
||||
@pytest.fixture
|
||||
def call_root_for_insight_generation(demo_org_team_user):
|
||||
# This graph structure will first get a plan, then generate the SQL query.
|
||||
|
||||
insights_subgraph = (
|
||||
# Insights subgraph without query execution, so we only create the queries
|
||||
InsightsAssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
|
||||
.add_query_creation_flow(next_node=AssistantNodeName.END)
|
||||
.compile()
|
||||
)
|
||||
graph = (
|
||||
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_root(
|
||||
path_map={
|
||||
"insights": AssistantNodeName.INSIGHTS_SUBGRAPH,
|
||||
"insights_search": AssistantNodeName.INSIGHTS_SEARCH,
|
||||
"root": AssistantNodeName.ROOT,
|
||||
"search_documentation": AssistantNodeName.END,
|
||||
"end": AssistantNodeName.END,
|
||||
}
|
||||
)
|
||||
.add_node(AssistantNodeName.INSIGHTS_SUBGRAPH, insights_subgraph)
|
||||
.add_edge(AssistantNodeName.INSIGHTS_SUBGRAPH, AssistantNodeName.END)
|
||||
.add_insights_search()
|
||||
.add_root()
|
||||
# TRICKY: We need to set a checkpointer here because async tests create a new event loop.
|
||||
.compile(checkpointer=DjangoCheckpointer())
|
||||
)
|
||||
@@ -85,15 +67,21 @@ def call_root_for_insight_generation(demo_org_team_user):
|
||||
final_state_raw = await graph.ainvoke(final_state, {"configurable": {"thread_id": conversation.id}})
|
||||
final_state = AssistantState.model_validate(final_state_raw)
|
||||
|
||||
if not final_state.messages or not isinstance(final_state.messages[-1], VisualizationMessage):
|
||||
# The order is a viz message, tool call message, and assistant message.
|
||||
if (
|
||||
not final_state.messages
|
||||
or not len(final_state.messages) >= 3
|
||||
or not isinstance(final_state.messages[-3], VisualizationMessage)
|
||||
):
|
||||
return {
|
||||
"plan": None,
|
||||
"query": None,
|
||||
"query_generation_retry_count": final_state.query_generation_retry_count,
|
||||
}
|
||||
|
||||
return {
|
||||
"plan": final_state.messages[-1].plan,
|
||||
"query": final_state.messages[-1].answer,
|
||||
"plan": final_state.messages[-3].plan,
|
||||
"query": final_state.messages[-3].answer,
|
||||
"query_generation_retry_count": final_state.query_generation_retry_count,
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from braintrust import EvalCase
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
@@ -7,8 +6,8 @@ from langchain_core.runnables import RunnableConfig
|
||||
from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage
|
||||
|
||||
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
|
||||
from ee.hogai.graph import AssistantGraph
|
||||
from ee.hogai.graph.dashboards.nodes import DashboardCreationNode
|
||||
from ee.hogai.graph.graph import AssistantGraph
|
||||
from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState, PartialAssistantState
|
||||
from ee.models.assistant import Conversation
|
||||
|
||||
@@ -21,18 +20,7 @@ def call_root_for_dashboard_creation(demo_org_team_user):
|
||||
graph = (
|
||||
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_root(
|
||||
{
|
||||
"create_dashboard": AssistantNodeName.END,
|
||||
"insights": AssistantNodeName.END,
|
||||
"search_documentation": AssistantNodeName.END,
|
||||
"session_summarization": AssistantNodeName.END,
|
||||
"insights_search": AssistantNodeName.END,
|
||||
"create_and_query_insight": AssistantNodeName.END,
|
||||
"root": AssistantNodeName.END,
|
||||
"end": AssistantNodeName.END,
|
||||
}
|
||||
)
|
||||
.add_root(lambda state: AssistantNodeName.END)
|
||||
.compile(checkpointer=DjangoCheckpointer())
|
||||
)
|
||||
|
||||
@@ -172,9 +160,7 @@ async def eval_tool_routing_dashboard_creation(call_root_for_dashboard_creation,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("ee.hogai.graph.base.get_stream_writer", return_value=MagicMock())
|
||||
async def eval_tool_call_dashboard_creation(patch_get_stream_writer, pytestconfig, demo_org_team_user):
|
||||
async def eval_tool_call_dashboard_creation(pytestconfig, demo_org_team_user):
|
||||
conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])
|
||||
dashboard_creation_node = DashboardCreationNode(demo_org_team_user[1], demo_org_team_user[2])
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from braintrust import EvalCase
|
||||
from posthog.schema import HumanMessage, VisualizationMessage
|
||||
|
||||
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
|
||||
from ee.hogai.graph import AssistantGraph
|
||||
from ee.hogai.graph.graph import AssistantGraph
|
||||
from ee.hogai.utils.types import AssistantNodeName, AssistantState
|
||||
from ee.models.assistant import Conversation
|
||||
|
||||
@@ -84,16 +84,7 @@ def call_insight_search(demo_org_team_user):
|
||||
graph = (
|
||||
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_root(
|
||||
{
|
||||
"insights": AssistantNodeName.END,
|
||||
"search_documentation": AssistantNodeName.END,
|
||||
"root": AssistantNodeName.END,
|
||||
"end": AssistantNodeName.END,
|
||||
"insights_search": AssistantNodeName.INSIGHTS_SEARCH,
|
||||
}
|
||||
)
|
||||
.add_insights_search()
|
||||
.add_root()
|
||||
.compile(checkpointer=DjangoCheckpointer())
|
||||
)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from langchain_core.messages import AIMessage as LangchainAIMessage
|
||||
from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage
|
||||
|
||||
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
|
||||
from ee.hogai.graph import AssistantGraph
|
||||
from ee.hogai.graph.graph import AssistantGraph
|
||||
from ee.hogai.utils.types import AssistantNodeName, AssistantState
|
||||
from ee.models.assistant import Conversation
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from posthog.models.user import User
|
||||
from posthog.sync import database_sync_to_async
|
||||
|
||||
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
|
||||
from ee.hogai.graph import AssistantGraph
|
||||
from ee.hogai.graph.graph import AssistantGraph
|
||||
from ee.hogai.graph.memory.prompts import (
|
||||
ENQUIRY_INITIAL_MESSAGE,
|
||||
SCRAPING_SUCCESS_KEY_PHRASE,
|
||||
|
||||
@@ -7,7 +7,7 @@ from braintrust import EvalCase
|
||||
from posthog.schema import AssistantMessage, AssistantToolCall, AssistantToolCallMessage, HumanMessage
|
||||
|
||||
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
|
||||
from ee.hogai.graph import AssistantGraph
|
||||
from ee.hogai.graph.graph import AssistantGraph
|
||||
from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState
|
||||
from ee.models.assistant import Conversation
|
||||
|
||||
@@ -20,16 +20,7 @@ def call_root(demo_org_team_user):
|
||||
graph = (
|
||||
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_root(
|
||||
{
|
||||
"insights": AssistantNodeName.END,
|
||||
"billing": AssistantNodeName.END,
|
||||
"insights_search": AssistantNodeName.END,
|
||||
"search_documentation": AssistantNodeName.END,
|
||||
"root": AssistantNodeName.ROOT,
|
||||
"end": AssistantNodeName.END,
|
||||
}
|
||||
)
|
||||
.add_root(lambda state: AssistantNodeName.END)
|
||||
# TRICKY: We need to set a checkpointer here because async tests create a new event loop.
|
||||
.compile(checkpointer=DjangoCheckpointer())
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from braintrust import EvalCase
|
||||
from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage
|
||||
|
||||
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
|
||||
from ee.hogai.graph import AssistantGraph
|
||||
from ee.hogai.graph.graph import AssistantGraph
|
||||
from ee.hogai.utils.types import AssistantNodeName, AssistantState
|
||||
from ee.models.assistant import Conversation
|
||||
|
||||
@@ -18,17 +18,7 @@ def call_root(demo_org_team_user):
|
||||
graph = (
|
||||
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_root(
|
||||
path_map={
|
||||
"insights": AssistantNodeName.END,
|
||||
"billing": AssistantNodeName.END,
|
||||
"insights_search": AssistantNodeName.END,
|
||||
"search_documentation": AssistantNodeName.END,
|
||||
"root": AssistantNodeName.ROOT,
|
||||
"end": AssistantNodeName.END,
|
||||
},
|
||||
tools_node=AssistantNodeName.END,
|
||||
)
|
||||
.add_root(router=lambda state: AssistantNodeName.END)
|
||||
# TRICKY: We need to set a checkpointer here because async tests create a new event loop.
|
||||
.compile(checkpointer=DjangoCheckpointer())
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from braintrust import EvalCase
|
||||
from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage
|
||||
|
||||
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
|
||||
from ee.hogai.graph import AssistantGraph
|
||||
from ee.hogai.graph.graph import AssistantGraph
|
||||
from ee.hogai.utils.types import AssistantNodeName, AssistantState
|
||||
from ee.models.assistant import Conversation
|
||||
|
||||
@@ -18,17 +18,7 @@ def call_root(demo_org_team_user):
|
||||
graph = (
|
||||
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_root(
|
||||
path_map={
|
||||
"insights": AssistantNodeName.END,
|
||||
"billing": AssistantNodeName.END,
|
||||
"insights_search": AssistantNodeName.END,
|
||||
"search_documentation": AssistantNodeName.END,
|
||||
"root": AssistantNodeName.ROOT,
|
||||
"end": AssistantNodeName.END,
|
||||
},
|
||||
tools_node=AssistantNodeName.END,
|
||||
)
|
||||
.add_root(router=lambda state: AssistantNodeName.END)
|
||||
# TRICKY: We need to set a checkpointer here because async tests create a new event loop.
|
||||
.compile(checkpointer=DjangoCheckpointer())
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from braintrust import EvalCase
|
||||
|
||||
from posthog.schema import AssistantMessage, HumanMessage
|
||||
|
||||
from ee.hogai.graph import AssistantGraph
|
||||
from ee.hogai.graph.graph import AssistantGraph
|
||||
from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState
|
||||
from ee.models.assistant import Conversation
|
||||
|
||||
@@ -60,19 +60,7 @@ def call_root(demo_org_team_user):
|
||||
graph = (
|
||||
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_root(
|
||||
{
|
||||
# Some requests will go via Inkeep, and this is realistic! Inkeep needs to adhere to our intended style too
|
||||
"search_documentation": AssistantNodeName.INKEEP_DOCS,
|
||||
"root": AssistantNodeName.ROOT,
|
||||
"billing": AssistantNodeName.END,
|
||||
"insights": AssistantNodeName.END,
|
||||
"insights_search": AssistantNodeName.END,
|
||||
"session_summarization": AssistantNodeName.END,
|
||||
"end": AssistantNodeName.END,
|
||||
}
|
||||
)
|
||||
.add_inkeep_docs()
|
||||
.add_root(lambda state: AssistantNodeName.END)
|
||||
.compile()
|
||||
)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from langchain_core.runnables import RunnableConfig
|
||||
from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage
|
||||
|
||||
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
|
||||
from ee.hogai.graph import AssistantGraph
|
||||
from ee.hogai.graph.graph import AssistantGraph
|
||||
from ee.hogai.graph.session_summaries.nodes import _SessionSearch
|
||||
from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState
|
||||
from ee.hogai.utils.yaml import load_yaml_from_raw_llm_content
|
||||
@@ -22,16 +22,7 @@ def call_root_for_replay_sessions(demo_org_team_user):
|
||||
graph = (
|
||||
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_root(
|
||||
{
|
||||
"insights": AssistantNodeName.END,
|
||||
"search_documentation": AssistantNodeName.END,
|
||||
"session_summarization": AssistantNodeName.END,
|
||||
"insights_search": AssistantNodeName.END,
|
||||
"root": AssistantNodeName.END,
|
||||
"end": AssistantNodeName.END,
|
||||
}
|
||||
)
|
||||
.add_root(lambda state: AssistantNodeName.END)
|
||||
.compile(checkpointer=DjangoCheckpointer())
|
||||
)
|
||||
|
||||
|
||||
@@ -126,9 +126,7 @@ def call_surveys_max_tool(demo_org_team_user, create_feature_flags):
|
||||
"change": f"Create a survey based on these instructions: {instructions}",
|
||||
"output": None,
|
||||
}
|
||||
graph = FeatureFlagLookupGraph(team=team, user=user, tool_call_id="test-tool-call-id").compile_full_graph(
|
||||
checkpointer=DjangoCheckpointer()
|
||||
)
|
||||
graph = FeatureFlagLookupGraph(team=team, user=user).compile_full_graph(checkpointer=DjangoCheckpointer())
|
||||
result = await graph.ainvoke(
|
||||
graph_context,
|
||||
config={
|
||||
|
||||
@@ -15,7 +15,7 @@ from posthog.models.action.action import Action
|
||||
from posthog.models.team.team import Team
|
||||
|
||||
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
|
||||
from ee.hogai.graph import AssistantGraph
|
||||
from ee.hogai.graph.graph import AssistantGraph
|
||||
from ee.hogai.utils.types import AssistantNodeName, AssistantState
|
||||
from ee.models.assistant import Conversation
|
||||
|
||||
@@ -29,14 +29,7 @@ def call_root_with_ui_context(demo_org_team_user):
|
||||
graph = (
|
||||
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_root(
|
||||
{
|
||||
"insights": AssistantNodeName.END,
|
||||
"docs": AssistantNodeName.END,
|
||||
"root": AssistantNodeName.END,
|
||||
"end": AssistantNodeName.END,
|
||||
}
|
||||
)
|
||||
.add_root(lambda state: AssistantNodeName.END)
|
||||
# TRICKY: We need to set a checkpointer here because async tests create a new event loop.
|
||||
.compile(checkpointer=DjangoCheckpointer())
|
||||
)
|
||||
|
||||
@@ -81,7 +81,6 @@ async def eval_create_feature_flag_basic(pytestconfig, demo_org_team_user):
|
||||
tool = await CreateFeatureFlagTool.create_tool_class(
|
||||
team=team,
|
||||
user=user,
|
||||
tool_call_id="test-eval-call",
|
||||
state=AssistantState(messages=[]),
|
||||
config={
|
||||
"configurable": {
|
||||
@@ -148,7 +147,6 @@ async def eval_create_feature_flag_with_rollout(pytestconfig, demo_org_team_user
|
||||
tool = await CreateFeatureFlagTool.create_tool_class(
|
||||
team=team,
|
||||
user=user,
|
||||
tool_call_id="test-eval-call",
|
||||
state=AssistantState(messages=[]),
|
||||
config={
|
||||
"configurable": {
|
||||
@@ -241,7 +239,6 @@ async def eval_create_feature_flag_with_property_filters(pytestconfig, demo_org_
|
||||
tool = await CreateFeatureFlagTool.create_tool_class(
|
||||
team=team,
|
||||
user=user,
|
||||
tool_call_id="test-eval-call",
|
||||
state=AssistantState(messages=[]),
|
||||
config={
|
||||
"configurable": {
|
||||
@@ -336,7 +333,6 @@ async def eval_create_feature_flag_duplicate_handling(pytestconfig, demo_org_tea
|
||||
tool = await CreateFeatureFlagTool.create_tool_class(
|
||||
team=team,
|
||||
user=user,
|
||||
tool_call_id="test-eval-call",
|
||||
state=AssistantState(messages=[]),
|
||||
config={
|
||||
"configurable": {
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from braintrust import EvalCase
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
@@ -13,18 +12,7 @@ from ee.hogai.eval.scorers import SemanticSimilarity
|
||||
from ee.models.assistant import Conversation
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_kafka_producer():
|
||||
"""Mock Kafka producer to prevent Kafka errors in tests."""
|
||||
with patch("posthog.kafka_client.client._KafkaProducer.produce") as mock_produce:
|
||||
mock_future = MagicMock()
|
||||
mock_produce.return_value = mock_future
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("ee.hogai.graph.base.get_stream_writer", return_value=MagicMock())
|
||||
async def eval_insights_addition(patch_get_stream_writer, pytestconfig, demo_org_team_user):
|
||||
async def eval_insights_addition(pytestconfig, demo_org_team_user):
|
||||
"""Test that adding insights to dashboard executes correctly."""
|
||||
|
||||
dashboard = await Dashboard.objects.acreate(
|
||||
|
||||
@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
|
||||
from posthog.schema import AssistantMessage, AssistantNavigateUrl, AssistantToolCall, FailureMessage, HumanMessage
|
||||
|
||||
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
|
||||
from ee.hogai.graph import AssistantGraph
|
||||
from ee.hogai.graph.graph import AssistantGraph
|
||||
from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState
|
||||
from ee.models.assistant import Conversation
|
||||
|
||||
@@ -48,16 +48,7 @@ def call_root(demo_org_team_user):
|
||||
graph = (
|
||||
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_root(
|
||||
{
|
||||
"insights": AssistantNodeName.END,
|
||||
"billing": AssistantNodeName.END,
|
||||
"insights_search": AssistantNodeName.END,
|
||||
"search_documentation": AssistantNodeName.END,
|
||||
"root": AssistantNodeName.ROOT,
|
||||
"end": AssistantNodeName.END,
|
||||
}
|
||||
)
|
||||
.add_root(lambda state: AssistantNodeName.END)
|
||||
# TRICKY: We need to set a checkpointer here because async tests create a new event loop.
|
||||
.compile(checkpointer=DjangoCheckpointer())
|
||||
)
|
||||
|
||||
@@ -40,9 +40,9 @@ DUMMY_CURRENT_FILTERS = RevenueAnalyticsAssistantFilters(
|
||||
|
||||
@pytest.fixture
|
||||
def call_filter_revenue_analytics(demo_org_team_user):
|
||||
graph = RevenueAnalyticsFilterOptionsGraph(
|
||||
demo_org_team_user[1], demo_org_team_user[2], tool_call_id="test-tool-call-id"
|
||||
).compile_full_graph(checkpointer=DjangoCheckpointer())
|
||||
graph = RevenueAnalyticsFilterOptionsGraph(demo_org_team_user[1], demo_org_team_user[2]).compile_full_graph(
|
||||
checkpointer=DjangoCheckpointer()
|
||||
)
|
||||
|
||||
async def callable(change: str) -> dict:
|
||||
conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])
|
||||
|
||||
@@ -48,9 +48,9 @@ DUMMY_CURRENT_FILTERS = MaxRecordingUniversalFilters(
|
||||
|
||||
@pytest.fixture
|
||||
def call_search_session_recordings(demo_org_team_user):
|
||||
graph = SessionReplayFilterOptionsGraph(
|
||||
demo_org_team_user[1], demo_org_team_user[2], tool_call_id="test-tool-call-id"
|
||||
).compile_full_graph(checkpointer=DjangoCheckpointer())
|
||||
graph = SessionReplayFilterOptionsGraph(demo_org_team_user[1], demo_org_team_user[2]).compile_full_graph(
|
||||
checkpointer=DjangoCheckpointer()
|
||||
)
|
||||
|
||||
async def callable(change: str) -> dict:
|
||||
conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])
|
||||
|
||||
@@ -48,9 +48,9 @@ DUMMY_CURRENT_FILTERS = MaxRecordingUniversalFilters(
|
||||
|
||||
@pytest.fixture
|
||||
def call_search_session_recordings(demo_org_team_user):
|
||||
graph = SessionReplayFilterOptionsGraph(
|
||||
demo_org_team_user[1], demo_org_team_user[2], tool_call_id="test-tool-call-id"
|
||||
).compile_full_graph(checkpointer=DjangoCheckpointer())
|
||||
graph = SessionReplayFilterOptionsGraph(demo_org_team_user[1], demo_org_team_user[2]).compile_full_graph(
|
||||
checkpointer=DjangoCheckpointer()
|
||||
)
|
||||
|
||||
async def callable(change: str) -> dict:
|
||||
conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])
|
||||
|
||||
@@ -18,7 +18,7 @@ from ee.hogai.eval.base import MaxPrivateEval
|
||||
from ee.hogai.eval.offline.conftest import EvaluationContext, capture_score, get_eval_context
|
||||
from ee.hogai.eval.schema import DatasetInput
|
||||
from ee.hogai.eval.scorers.sql import SQLSemanticsCorrectness, SQLSyntaxCorrectness
|
||||
from ee.hogai.graph import AssistantGraph
|
||||
from ee.hogai.graph.graph import AssistantGraph
|
||||
from ee.hogai.utils.helpers import find_last_message_of_type
|
||||
from ee.hogai.utils.types import AssistantState
|
||||
from ee.hogai.utils.warehouse import serialize_database_schema
|
||||
|
||||
@@ -41,27 +41,29 @@ class ToolRelevance(ScorerWithPartial):
|
||||
raise TypeError(f"Eval case expected must be an AssistantToolCall, not {type(expected)}")
|
||||
if not isinstance(output, AssistantMessage):
|
||||
raise TypeError(f"Eval case output must be an AssistantMessage, not {type(output)}")
|
||||
if output.tool_calls and len(output.tool_calls) > 1:
|
||||
raise ValueError("Parallel tool calls not supported by this scorer yet")
|
||||
score = 0.0 # 0.0 to 1.0
|
||||
if output.tool_calls and len(output.tool_calls) == 1:
|
||||
tool_call = output.tool_calls[0]
|
||||
# 0.5 point for getting the tool right
|
||||
if tool_call.name == expected.name:
|
||||
score += 0.5
|
||||
if not expected.args:
|
||||
score += 0.5 if not tool_call.args else 0 # If no args expected, only score for lack of args
|
||||
else:
|
||||
score_per_arg = 0.5 / len(expected.args)
|
||||
for arg_name, expected_arg_value in expected.args.items():
|
||||
if arg_name in self.semantic_similarity_args:
|
||||
arg_similarity = AnswerSimilarity(model="text-embedding-3-small").eval(
|
||||
output=tool_call.args.get(arg_name), expected=expected_arg_value
|
||||
)
|
||||
score += arg_similarity.score * score_per_arg
|
||||
elif tool_call.args.get(arg_name) == expected_arg_value:
|
||||
score += score_per_arg
|
||||
return Score(name=self._name(), score=score)
|
||||
|
||||
best_score = 0.0 # 0.0 to 1.0
|
||||
if output.tool_calls:
|
||||
# Check all tool calls and return the best match
|
||||
for tool_call in output.tool_calls:
|
||||
score = 0.0
|
||||
# 0.5 point for getting the tool right
|
||||
if tool_call.name == expected.name:
|
||||
score += 0.5
|
||||
if not expected.args:
|
||||
score += 0.5 if not tool_call.args else 0 # If no args expected, only score for lack of args
|
||||
else:
|
||||
score_per_arg = 0.5 / len(expected.args)
|
||||
for arg_name, expected_arg_value in expected.args.items():
|
||||
if arg_name in self.semantic_similarity_args:
|
||||
arg_similarity = AnswerSimilarity(model="text-embedding-3-small").eval(
|
||||
output=tool_call.args.get(arg_name), expected=expected_arg_value
|
||||
)
|
||||
score += arg_similarity.score * score_per_arg
|
||||
elif tool_call.args.get(arg_name) == expected_arg_value:
|
||||
score += score_per_arg
|
||||
best_score = max(best_score, score)
|
||||
return Score(name=self._name(), score=best_score)
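This scorer change is where parallel tool calls alter evaluation: instead of rejecting outputs with more than one tool call, each call is scored independently and the best match wins. Below is a simplified, self-contained sketch of that rule (exact-match args only, omitting the semantic-similarity branch); the function and the tool names in it are illustrative, not part of the diff.

```python
def best_match_score(tool_calls: list[tuple[str, dict]], expected_name: str, expected_args: dict) -> float:
    """Score every parallel tool call independently and keep the best one."""
    best = 0.0
    for name, args in tool_calls:
        score = 0.5 if name == expected_name else 0.0  # 0.5 for calling the right tool
        if not expected_args:
            score += 0.5 if not args else 0.0  # no args expected: reward only the absence of args
        else:
            per_arg = 0.5 / len(expected_args)  # remaining 0.5 split across expected args
            score += sum(per_arg for key, value in expected_args.items() if args.get(key) == value)
        best = max(best, score)
    return best


# Two parallel calls: only the second one matches the expected call, and it alone sets the score.
assert best_match_score(
    [("search_insights", {"query": "revenue"}), ("create_dashboard", {"name": "Revenue"})],
    expected_name="create_dashboard",
    expected_args={"name": "Revenue"},
) == 1.0
```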
|
||||
|
||||
|
||||
class PlanAndQueryOutput(TypedDict, total=False):
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
from .deep_research.graph import DeepResearchAssistantGraph
|
||||
from .funnels.nodes import FunnelGeneratorNode
|
||||
from .graph import AssistantGraph, InsightsAssistantGraph
|
||||
from .inkeep_docs.nodes import InkeepDocsNode
|
||||
from .insights.nodes import InsightSearchNode
|
||||
from .memory.nodes import MemoryInitializerNode
|
||||
from .query_executor.nodes import QueryExecutorNode
|
||||
from .query_planner.nodes import QueryPlannerNode
|
||||
from .rag.nodes import InsightRagContextNode
|
||||
from .retention.nodes import RetentionGeneratorNode
|
||||
from .root.nodes import RootNode, RootNodeTools
|
||||
from .schema_generator.nodes import SchemaGeneratorNode
|
||||
from .sql.nodes import SQLGeneratorNode
|
||||
from .trends.nodes import TrendsGeneratorNode
|
||||
|
||||
__all__ = [
|
||||
"FunnelGeneratorNode",
|
||||
"InkeepDocsNode",
|
||||
"MemoryInitializerNode",
|
||||
"QueryExecutorNode",
|
||||
"InsightRagContextNode",
|
||||
"RetentionGeneratorNode",
|
||||
"RootNode",
|
||||
"RootNodeTools",
|
||||
"SchemaGeneratorNode",
|
||||
"SQLGeneratorNode",
|
||||
"QueryPlannerNode",
|
||||
"TrendsGeneratorNode",
|
||||
"AssistantGraph",
|
||||
"InsightsAssistantGraph",
|
||||
"InsightSearchNode",
|
||||
"DeepResearchAssistantGraph",
|
||||
]
|
||||
|
||||
ee/hogai/graph/base/__init__.py (Normal file, 9 lines added)
@@ -0,0 +1,9 @@
from .graph import BaseAssistantGraph, global_checkpointer
from .node import AssistantNode, BaseAssistantNode

__all__ = [
    "BaseAssistantNode",
    "AssistantNode",
    "BaseAssistantGraph",
    "global_checkpointer",
]
ee/hogai/graph/base/context.py (Normal file, 22 lines added)
||||
import contextvars
|
||||
from contextlib import contextmanager
|
||||
|
||||
from ee.hogai.utils.types.base import NodePath
|
||||
|
||||
node_path_context = contextvars.ContextVar[tuple[NodePath, ...]]("node_path_context")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_node_path(node_path: tuple[NodePath, ...]):
|
||||
token = node_path_context.set(node_path)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
node_path_context.reset(token)
|
||||
|
||||
|
||||
def get_node_path() -> tuple[NodePath, ...] | None:
|
||||
try:
|
||||
return node_path_context.get()
|
||||
except LookupError:
|
||||
return None
|
||||
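A quick sketch of how this context variable composes (illustrative, not from the diff; it assumes `NodePath` is constructed with a `name` field, as the graph and node classes below do, and exposes it as an attribute):

```python
from ee.hogai.graph.base.context import get_node_path, set_node_path
from ee.hogai.utils.types.base import NodePath

assert get_node_path() is None  # nothing set outside a graph/node scope

with set_node_path((NodePath(name="assistant"),)):
    # A nested scope (e.g. a subgraph or node) extends the path and restores the outer one on exit.
    with set_node_path((*(get_node_path() or ()), NodePath(name="root"))):
        assert [p.name for p in get_node_path()] == ["assistant", "root"]
    assert [p.name for p in get_node_path()] == ["assistant"]
```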
ee/hogai/graph/base/graph.py (Normal file, 88 lines added)
@@ -0,0 +1,88 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar

from langgraph.graph.state import StateGraph

from posthog.models import Team, User

from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.utils.types.base import AssistantGraphName, AssistantNodeName, NodePath, PartialStateType, StateType

from .context import get_node_path, set_node_path
from .node import BaseAssistantNode

if TYPE_CHECKING:
    from ee.hogai.utils.types.composed import MaxNodeName


# Base checkpointer for all graphs
global_checkpointer = DjangoCheckpointer()

T = TypeVar("T")


def with_node_path(func: Callable[..., T]) -> Callable[..., T]:
    @wraps(func)
    def wrapper(self, *args: Any, **kwargs: Any) -> T:
        with set_node_path(self.node_path):
            return func(self, *args, **kwargs)

    return wrapper


class BaseAssistantGraph(Generic[StateType, PartialStateType], ABC):
    _team: Team
    _user: User
    _graph: StateGraph
    _node_path: tuple[NodePath, ...]

    def __init__(
        self,
        team: Team,
        user: User,
    ):
        self._team = team
        self._user = user
        self._has_start_node = False
        self._graph = StateGraph(self.state_type)
        self._node_path = (*(get_node_path() or ()), NodePath(name=self.graph_name.value))

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Wrap all public methods with the node path context
        for name, method in cls.__dict__.items():
            if callable(method) and not name.startswith("_") and name not in ("graph_name", "state_type", "node_path"):
                setattr(cls, name, with_node_path(method))

    @property
    @abstractmethod
    def state_type(self) -> type[StateType]: ...

    @property
    @abstractmethod
    def graph_name(self) -> AssistantGraphName: ...

    @property
    def node_path(self) -> tuple[NodePath, ...]:
        return self._node_path

    def add_edge(self, from_node: "MaxNodeName", to_node: "MaxNodeName"):
        if from_node == AssistantNodeName.START:
            self._has_start_node = True
        self._graph.add_edge(from_node, to_node)
        return self

    def add_node(self, node: "MaxNodeName", action: BaseAssistantNode[StateType, PartialStateType]):
        self._graph.add_node(node, action)
        return self

    def compile(self, checkpointer: DjangoCheckpointer | None | Literal[False] = None):
        if not self._has_start_node:
            raise ValueError("Start node not added to the graph")
        # TRICKY: We check `is not None` because False has a special meaning of "no checkpointer", which we want to pass on
        compiled_graph = self._graph.compile(
            checkpointer=checkpointer if checkpointer is not None else global_checkpointer
        )
        return compiled_graph
@@ -1,49 +1,45 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from abc import ABC
|
||||
from typing import Generic
|
||||
from uuid import UUID
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
from posthog.schema import AssistantMessage, AssistantToolCall, HumanMessage
|
||||
from posthog.schema import HumanMessage
|
||||
|
||||
from posthog.models import Team
|
||||
from posthog.models.user import User
|
||||
from posthog.models import Team, User
|
||||
from posthog.sync import database_sync_to_async
|
||||
|
||||
from ee.hogai.context import AssistantContextManager
|
||||
from ee.hogai.graph.mixins import AssistantContextMixin
|
||||
from ee.hogai.utils.dispatcher import AssistantDispatcher
|
||||
from ee.hogai.graph.base.context import get_node_path, set_node_path
|
||||
from ee.hogai.graph.mixins import AssistantContextMixin, AssistantDispatcherMixin
|
||||
from ee.hogai.utils.exceptions import GenerationCanceled
|
||||
from ee.hogai.utils.helpers import find_start_message
|
||||
from ee.hogai.utils.types import (
|
||||
AssistantMessageUnion,
|
||||
from ee.hogai.utils.types.base import (
|
||||
AssistantState,
|
||||
NodeEndAction,
|
||||
NodePath,
|
||||
NodeStartAction,
|
||||
PartialAssistantState,
|
||||
PartialStateType,
|
||||
StateType,
|
||||
)
|
||||
from ee.hogai.utils.types.composed import MaxNodeName
|
||||
from ee.models import Conversation
|
||||
|
||||
|
||||
class BaseAssistantNode(Generic[StateType, PartialStateType], AssistantContextMixin, ABC):
|
||||
class BaseAssistantNode(Generic[StateType, PartialStateType], AssistantContextMixin, AssistantDispatcherMixin, ABC):
|
||||
_config: RunnableConfig | None = None
|
||||
_context_manager: AssistantContextManager | None = None
|
||||
_dispatcher: AssistantDispatcher | None = None
|
||||
_parent_tool_call_id: str | None = None
|
||||
_node_path: tuple[NodePath, ...]
|
||||
|
||||
def __init__(self, team: Team, user: User):
|
||||
def __init__(self, team: Team, user: User, node_path: tuple[NodePath, ...] | None = None):
|
||||
self._team = team
|
||||
self._user = user
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def node_name(self) -> MaxNodeName:
|
||||
raise NotImplementedError
|
||||
if node_path is None:
|
||||
self._node_path = (*(get_node_path() or ()), NodePath(name=self.node_name))
|
||||
else:
|
||||
self._node_path = node_path
|
||||
|
||||
async def __call__(self, state: StateType, config: RunnableConfig) -> PartialStateType | None:
|
||||
"""
|
||||
@@ -54,25 +50,19 @@ class BaseAssistantNode(Generic[StateType, PartialStateType], AssistantContextMi
|
||||
self._dispatcher = None
|
||||
self._config = config
|
||||
|
||||
if isinstance(state, AssistantState) and state.root_tool_call_id:
|
||||
# NOTE: we set the parent tool call id as the root tool call id
|
||||
# This will be deprecated once all tools become MaxTools and are removed from the graph
|
||||
self._parent_tool_call_id = state.root_tool_call_id
|
||||
|
||||
self.dispatcher.node_start()
|
||||
self.dispatcher.dispatch(NodeStartAction())
|
||||
|
||||
thread_id = (config.get("configurable") or {}).get("thread_id")
|
||||
if thread_id and await self._is_conversation_cancelled(thread_id):
|
||||
raise GenerationCanceled
|
||||
|
||||
try:
|
||||
new_state = await self.arun(state, config)
|
||||
new_state = await self._arun_with_context(state, config)
|
||||
except NotImplementedError:
|
||||
new_state = await database_sync_to_async(self.run, thread_sensitive=False)(state, config)
|
||||
new_state = await database_sync_to_async(self._run_with_context, thread_sensitive=False)(state, config)
|
||||
|
||||
self.dispatcher.dispatch(NodeEndAction(state=new_state))
|
||||
|
||||
if new_state is not None and (messages := getattr(new_state, "messages", [])):
|
||||
for message in messages:
|
||||
self.dispatcher.message(message)
|
||||
return new_state
|
||||
|
||||
def run(self, state: StateType, config: RunnableConfig) -> PartialStateType | None:
|
||||
@@ -82,6 +72,14 @@ class BaseAssistantNode(Generic[StateType, PartialStateType], AssistantContextMi
|
||||
async def arun(self, state: StateType, config: RunnableConfig) -> PartialStateType | None:
|
||||
raise NotImplementedError
|
||||
|
||||
def _run_with_context(self, state: StateType, config: RunnableConfig) -> PartialStateType | None:
|
||||
with set_node_path(self.node_path):
|
||||
return self.run(state, config)
|
||||
|
||||
async def _arun_with_context(self, state: StateType, config: RunnableConfig) -> PartialStateType | None:
|
||||
with set_node_path(self.node_path):
|
||||
return await self.arun(state, config)
|
||||
|
||||
@property
|
||||
def context_manager(self) -> AssistantContextManager:
|
||||
if self._context_manager is None:
|
||||
@@ -97,26 +95,13 @@ class BaseAssistantNode(Generic[StateType, PartialStateType], AssistantContextMi
|
||||
return self._context_manager
|
||||
|
||||
@property
|
||||
def dispatcher(self) -> AssistantDispatcher:
|
||||
"""Create a dispatcher for this node"""
|
||||
if self._dispatcher:
|
||||
return self._dispatcher
|
||||
|
||||
# Set writer from LangGraph context
|
||||
try:
|
||||
writer = get_stream_writer()
|
||||
except RuntimeError:
|
||||
# Not in streaming context (e.g., testing)
|
||||
# Use noop writer
|
||||
def noop(*_args, **_kwargs):
|
||||
pass
|
||||
|
||||
writer = noop
|
||||
|
||||
self._dispatcher = AssistantDispatcher(
|
||||
writer, node_name=self.node_name, parent_tool_call_id=self._parent_tool_call_id
|
||||
)
|
||||
return self._dispatcher
    def node_name(self) -> str:
        config_name: str | None = None
        if self._config:
            config_name = self._config["metadata"].get("langgraph_node")
        if config_name is not None:
            config_name = str(config_name)
        return config_name or self.__class__.__name__

    async def _is_conversation_cancelled(self, conversation_id: UUID) -> bool:
        conversation = await self._aget_conversation(conversation_id)
@@ -124,15 +109,6 @@ class BaseAssistantNode(Generic[StateType, PartialStateType], AssistantContextMi
            raise ValueError(f"Conversation {conversation_id} not found")
        return conversation.status == Conversation.Status.CANCELING

    def _get_tool_call(self, messages: Sequence[AssistantMessageUnion], tool_call_id: str) -> AssistantToolCall:
        for message in reversed(messages):
            if not isinstance(message, AssistantMessage) or not message.tool_calls:
                continue
            for tool_call in message.tool_calls:
                if tool_call.id == tool_call_id:
                    return tool_call
        raise ValueError(f"Tool call {tool_call_id} not found in state")

    def _is_first_turn(self, state: AssistantState) -> bool:
        last_message = state.messages[-1]
        if isinstance(last_message, HumanMessage):
47 ee/hogai/graph/base/test/test_assistant_graph.py Normal file
@@ -0,0 +1,47 @@
from posthog.test.base import BaseTest

from langgraph.checkpoint.memory import InMemorySaver

from ee.hogai.graph.base import AssistantNode
from ee.hogai.graph.base.graph import BaseAssistantGraph
from ee.hogai.utils.types import AssistantNodeName, AssistantState, PartialAssistantState
from ee.hogai.utils.types.base import AssistantGraphName
from ee.models import Conversation


class TestAssistantGraph(BaseTest):
    async def test_pydantic_state_resets_with_none(self):
        """When a None field is set, it should be reset to None."""

        class TestAssistantGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
            @property
            def state_type(self) -> type[AssistantState]:
                return AssistantState

            @property
            def graph_name(self) -> AssistantGraphName:
                return AssistantGraphName.ASSISTANT

        graph = TestAssistantGraph(self.team, self.user)

        class TestNode(AssistantNode):
            @property
            def node_name(self):
                return AssistantNodeName.ROOT

            async def arun(self, state, config):
                return PartialAssistantState(start_id=None)

        compiled_graph = (
            graph.add_node(AssistantNodeName.ROOT, TestNode(self.team, self.user))
            .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
            .add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
            .compile(checkpointer=InMemorySaver())
        )
        conversation = await Conversation.objects.acreate(team=self.team, user=self.user)
        state = await compiled_graph.ainvoke(
            AssistantState(messages=[], graph_status="resumed", start_id=None),
            {"configurable": {"thread_id": conversation.id}},
        )
        self.assertEqual(state["start_id"], None)
        self.assertEqual(state["graph_status"], "resumed")
502 ee/hogai/graph/base/test/test_node_path.py Normal file
@@ -0,0 +1,502 @@
|
||||
from posthog.test.base import BaseTest
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from ee.hogai.graph.base import AssistantNode
|
||||
from ee.hogai.graph.base.context import get_node_path
|
||||
from ee.hogai.graph.base.graph import BaseAssistantGraph
|
||||
from ee.hogai.utils.types import AssistantNodeName, AssistantState, PartialAssistantState
|
||||
from ee.hogai.utils.types.base import AssistantGraphName, NodePath
|
||||
from ee.models import Conversation
|
||||
|
||||
|
||||
class TestNodePath(BaseTest):
|
||||
"""
|
||||
Tests for node_path functionality across sync/async methods and graph compositions.
|
||||
"""
|
||||
|
||||
async def test_graph_to_async_node_has_two_elements(self):
|
||||
"""Graph -> Node (async) should have path: [graph, node]"""
|
||||
captured_path = None
|
||||
|
||||
class TestNode(AssistantNode):
|
||||
async def arun(self, state, config):
|
||||
nonlocal captured_path
|
||||
captured_path = get_node_path()
|
||||
return None
|
||||
|
||||
class TestGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
|
||||
@property
|
||||
def state_type(self) -> type[AssistantState]:
|
||||
return AssistantState
|
||||
|
||||
@property
|
||||
def graph_name(self) -> AssistantGraphName:
|
||||
return AssistantGraphName.ASSISTANT
|
||||
|
||||
def setup(self):
|
||||
node = TestNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.ROOT, node)
|
||||
self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
|
||||
return self
|
||||
|
||||
graph = TestGraph(self.team, self.user)
|
||||
compiled = graph.setup().compile(checkpointer=False)
|
||||
|
||||
conversation = await Conversation.objects.acreate(team=self.team, user=self.user)
|
||||
await compiled.ainvoke(AssistantState(messages=[]), {"configurable": {"thread_id": conversation.id}})
|
||||
|
||||
assert captured_path is not None
|
||||
self.assertEqual(len(captured_path), 2)
|
||||
self.assertEqual(captured_path[0].name, AssistantGraphName.ASSISTANT.value)
|
||||
# Node name is determined at init time, so it's the class name
|
||||
self.assertEqual(captured_path[1].name, "TestNode")
|
||||
|
||||
async def test_graph_to_sync_node_has_two_elements(self):
|
||||
"""Graph -> Node (sync) should have path: [graph, node]"""
|
||||
captured_path = None
|
||||
|
||||
class TestNode(AssistantNode):
|
||||
def run(self, state, config):
|
||||
nonlocal captured_path
|
||||
captured_path = get_node_path()
|
||||
return None
|
||||
|
||||
class TestGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
|
||||
@property
|
||||
def state_type(self) -> type[AssistantState]:
|
||||
return AssistantState
|
||||
|
||||
@property
|
||||
def graph_name(self) -> AssistantGraphName:
|
||||
return AssistantGraphName.ASSISTANT
|
||||
|
||||
def setup(self):
|
||||
node = TestNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.ROOT, node)
|
||||
self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
|
||||
return self
|
||||
|
||||
graph = TestGraph(self.team, self.user)
|
||||
compiled = graph.setup().compile(checkpointer=False)
|
||||
|
||||
conversation = await Conversation.objects.acreate(team=self.team, user=self.user)
|
||||
await compiled.ainvoke(AssistantState(messages=[]), {"configurable": {"thread_id": conversation.id}})
|
||||
|
||||
assert captured_path is not None
|
||||
self.assertEqual(len(captured_path), 2)
|
||||
self.assertEqual(captured_path[0].name, AssistantGraphName.ASSISTANT.value)
|
||||
self.assertEqual(captured_path[1].name, "TestNode")
|
||||
|
||||
async def test_graph_to_node_to_async_node_has_three_elements(self):
|
||||
"""Graph -> Node -> Node (async) should have path: [graph, node, node]"""
|
||||
captured_paths = []
|
||||
|
||||
class SecondNode(AssistantNode):
|
||||
async def arun(self, state, config):
|
||||
nonlocal captured_paths
|
||||
captured_paths.append(get_node_path())
|
||||
return None
|
||||
|
||||
class FirstNode(AssistantNode):
|
||||
async def arun(self, state, config):
|
||||
nonlocal captured_paths
|
||||
captured_paths.append(get_node_path())
|
||||
# Call second node
|
||||
second_node = SecondNode(self._team, self._user)
|
||||
await second_node(state, config)
|
||||
return None
|
||||
|
||||
class TestGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
|
||||
@property
|
||||
def state_type(self) -> type[AssistantState]:
|
||||
return AssistantState
|
||||
|
||||
@property
|
||||
def graph_name(self) -> AssistantGraphName:
|
||||
return AssistantGraphName.ASSISTANT
|
||||
|
||||
def setup(self):
|
||||
node = FirstNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.ROOT, node)
|
||||
self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
|
||||
return self
|
||||
|
||||
graph = TestGraph(self.team, self.user)
|
||||
compiled = graph.setup().compile(checkpointer=False)
|
||||
|
||||
conversation = await Conversation.objects.acreate(team=self.team, user=self.user)
|
||||
await compiled.ainvoke(AssistantState(messages=[]), {"configurable": {"thread_id": conversation.id}})
|
||||
|
||||
# First node: [graph, node]
|
||||
assert captured_paths[0] is not None
|
||||
self.assertEqual(len(captured_paths[0]), 2)
|
||||
self.assertEqual(captured_paths[0][0].name, AssistantGraphName.ASSISTANT.value)
|
||||
self.assertEqual(captured_paths[0][1].name, "FirstNode")
|
||||
|
||||
# Second node: [graph, node, node]
|
||||
assert captured_paths[1] is not None
|
||||
self.assertEqual(len(captured_paths[1]), 3)
|
||||
self.assertEqual(captured_paths[1][0].name, AssistantGraphName.ASSISTANT.value)
|
||||
self.assertEqual(captured_paths[1][1].name, "FirstNode")
|
||||
self.assertEqual(captured_paths[1][2].name, "SecondNode")
|
||||
|
||||
async def test_graph_to_node_to_sync_node_has_three_elements(self):
|
||||
"""Graph -> Node -> Node (sync) - calling .run() directly doesn't extend path"""
|
||||
captured_paths = []
|
||||
|
||||
class SecondNode(AssistantNode):
|
||||
def run(self, state, config):
|
||||
nonlocal captured_paths
|
||||
captured_paths.append(get_node_path())
|
||||
return None
|
||||
|
||||
class FirstNode(AssistantNode):
|
||||
def run(self, state, config):
|
||||
nonlocal captured_paths
|
||||
captured_paths.append(get_node_path())
|
||||
# Call second node - note: calling run() directly bypasses context setting,
|
||||
# so the second node won't have the proper path. This tests that direct .run()
|
||||
# calls don't propagate context properly.
|
||||
second_node = SecondNode(self._team, self._user)
|
||||
second_node.run(state, config)
|
||||
return None
|
||||
|
||||
class TestGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
|
||||
@property
|
||||
def state_type(self) -> type[AssistantState]:
|
||||
return AssistantState
|
||||
|
||||
@property
|
||||
def graph_name(self) -> AssistantGraphName:
|
||||
return AssistantGraphName.ASSISTANT
|
||||
|
||||
def setup(self):
|
||||
node = FirstNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.ROOT, node)
|
||||
self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
|
||||
return self
|
||||
|
||||
graph = TestGraph(self.team, self.user)
|
||||
compiled = graph.setup().compile(checkpointer=False)
|
||||
|
||||
conversation = await Conversation.objects.acreate(team=self.team, user=self.user)
|
||||
await compiled.ainvoke(AssistantState(messages=[]), {"configurable": {"thread_id": conversation.id}})
|
||||
|
||||
# First node: [graph, node]
|
||||
assert captured_paths[0] is not None
|
||||
self.assertEqual(len(captured_paths[0]), 2)
|
||||
self.assertEqual(captured_paths[0][0].name, AssistantGraphName.ASSISTANT.value)
|
||||
self.assertEqual(captured_paths[0][1].name, "FirstNode")
|
||||
|
||||
# Second node: [graph, node] - initialized within FirstNode's context, so gets same path
|
||||
assert captured_paths[1] is not None
|
||||
self.assertEqual(len(captured_paths[1]), 2)
|
||||
self.assertEqual(captured_paths[1][0].name, AssistantGraphName.ASSISTANT.value)
|
||||
self.assertEqual(captured_paths[1][1].name, "FirstNode") # Same as first because initialized in same context
|
||||
|
||||
async def test_graph_to_node_to_graph_to_node_has_four_elements(self):
|
||||
"""Graph -> Node -> Graph -> Node should have path: [graph, node, graph, node]"""
|
||||
captured_paths = []
|
||||
|
||||
class InnerNode(AssistantNode):
|
||||
async def arun(self, state, config):
|
||||
nonlocal captured_paths
|
||||
captured_paths.append(get_node_path())
|
||||
return None
|
||||
|
||||
class InnerGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
|
||||
@property
|
||||
def state_type(self) -> type[AssistantState]:
|
||||
return AssistantState
|
||||
|
||||
@property
|
||||
def graph_name(self) -> AssistantGraphName:
|
||||
return AssistantGraphName.INSIGHTS
|
||||
|
||||
def setup(self):
|
||||
node = InnerNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.TRENDS_GENERATOR, node)
|
||||
self.add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_GENERATOR)
|
||||
self.add_edge(AssistantNodeName.TRENDS_GENERATOR, AssistantNodeName.END)
|
||||
return self
|
||||
|
||||
class OuterNode(AssistantNode):
|
||||
async def arun(self, state, config):
|
||||
nonlocal captured_paths
|
||||
captured_paths.append(get_node_path())
|
||||
# Call inner graph
|
||||
inner_graph = InnerGraph(self._team, self._user)
|
||||
compiled_inner = inner_graph.setup().compile(checkpointer=False)
|
||||
await compiled_inner.ainvoke(state, config)
|
||||
return None
|
||||
|
||||
class OuterGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
|
||||
@property
|
||||
def state_type(self) -> type[AssistantState]:
|
||||
return AssistantState
|
||||
|
||||
@property
|
||||
def graph_name(self) -> AssistantGraphName:
|
||||
return AssistantGraphName.ASSISTANT
|
||||
|
||||
def setup(self):
|
||||
node = OuterNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.ROOT, node)
|
||||
self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
|
||||
return self
|
||||
|
||||
outer_graph = OuterGraph(self.team, self.user)
|
||||
compiled = outer_graph.setup().compile(checkpointer=False)
|
||||
|
||||
conversation = await Conversation.objects.acreate(team=self.team, user=self.user)
|
||||
await compiled.ainvoke(AssistantState(messages=[]), {"configurable": {"thread_id": conversation.id}})
|
||||
|
||||
# Outer node: [graph, node]
|
||||
assert captured_paths[0] is not None
|
||||
self.assertEqual(len(captured_paths[0]), 2)
|
||||
self.assertEqual(captured_paths[0][0].name, AssistantGraphName.ASSISTANT.value)
|
||||
self.assertEqual(captured_paths[0][1].name, "OuterNode")
|
||||
|
||||
# Inner node: [graph, node, graph, node]
|
||||
assert captured_paths[1] is not None
|
||||
self.assertEqual(len(captured_paths[1]), 4)
|
||||
self.assertEqual(captured_paths[1][0].name, AssistantGraphName.ASSISTANT.value)
|
||||
self.assertEqual(captured_paths[1][1].name, "OuterNode")
|
||||
self.assertEqual(captured_paths[1][2].name, AssistantGraphName.INSIGHTS.value)
|
||||
self.assertEqual(captured_paths[1][3].name, "InnerNode")
|
||||
|
||||
async def test_graph_to_graph_to_node_has_three_elements(self):
|
||||
"""Graph -> Graph -> Node should have path: [graph, graph, node]"""
|
||||
captured_path = None
|
||||
|
||||
class InnerNode(AssistantNode):
|
||||
async def arun(self, state, config):
|
||||
nonlocal captured_path
|
||||
captured_path = get_node_path()
|
||||
return None
|
||||
|
||||
class InnerGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
|
||||
@property
|
||||
def state_type(self) -> type[AssistantState]:
|
||||
return AssistantState
|
||||
|
||||
@property
|
||||
def graph_name(self) -> AssistantGraphName:
|
||||
return AssistantGraphName.INSIGHTS
|
||||
|
||||
def setup(self):
|
||||
node = InnerNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.TRENDS_GENERATOR, node)
|
||||
self.add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_GENERATOR)
|
||||
self.add_edge(AssistantNodeName.TRENDS_GENERATOR, AssistantNodeName.END)
|
||||
return self
|
||||
|
||||
class OuterGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
|
||||
@property
|
||||
def state_type(self) -> type[AssistantState]:
|
||||
return AssistantState
|
||||
|
||||
@property
|
||||
def graph_name(self) -> AssistantGraphName:
|
||||
return AssistantGraphName.ASSISTANT
|
||||
|
||||
async def invoke_inner_graph(self, state, config):
|
||||
inner_graph = InnerGraph(self._team, self._user)
|
||||
compiled_inner = inner_graph.setup().compile(checkpointer=False)
|
||||
await compiled_inner.ainvoke(state, config)
|
||||
|
||||
outer_graph = OuterGraph(self.team, self.user)
|
||||
|
||||
conversation = await Conversation.objects.acreate(team=self.team, user=self.user)
|
||||
await outer_graph.invoke_inner_graph(
|
||||
AssistantState(messages=[]), RunnableConfig(configurable={"thread_id": conversation.id})
|
||||
)
|
||||
|
||||
# Inner node: [graph, node] - outer graph context is not propagated when calling graph methods directly
|
||||
assert captured_path is not None
|
||||
self.assertEqual(len(captured_path), 2)
|
||||
self.assertEqual(captured_path[0].name, AssistantGraphName.INSIGHTS.value)
|
||||
self.assertEqual(captured_path[1].name, "InnerNode")
|
||||
|
||||
async def test_node_path_preserved_across_async_and_sync_methods(self):
|
||||
"""Test that calling .run() directly doesn't extend path"""
|
||||
captured_paths = []
|
||||
|
||||
class SyncNode(AssistantNode):
|
||||
def run(self, state, config):
|
||||
nonlocal captured_paths
|
||||
captured_paths.append(("sync", get_node_path()))
|
||||
return None
|
||||
|
||||
class AsyncNode(AssistantNode):
|
||||
async def arun(self, state, config):
|
||||
nonlocal captured_paths
|
||||
captured_paths.append(("async", get_node_path()))
|
||||
# Call sync node - calling .run() directly bypasses context setting
|
||||
sync_node = SyncNode(self._team, self._user)
|
||||
sync_node.run(state, config)
|
||||
return None
|
||||
|
||||
class TestGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
|
||||
@property
|
||||
def state_type(self) -> type[AssistantState]:
|
||||
return AssistantState
|
||||
|
||||
@property
|
||||
def graph_name(self) -> AssistantGraphName:
|
||||
return AssistantGraphName.ASSISTANT
|
||||
|
||||
def setup(self):
|
||||
node = AsyncNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.ROOT, node)
|
||||
self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
|
||||
return self
|
||||
|
||||
graph = TestGraph(self.team, self.user)
|
||||
compiled = graph.setup().compile(checkpointer=False)
|
||||
|
||||
conversation = await Conversation.objects.acreate(team=self.team, user=self.user)
|
||||
await compiled.ainvoke(AssistantState(messages=[]), {"configurable": {"thread_id": conversation.id}})
|
||||
|
||||
# Async node: [graph, node]
|
||||
self.assertEqual(captured_paths[0][0], "async")
|
||||
assert captured_paths[0][1] is not None
|
||||
self.assertEqual(len(captured_paths[0][1]), 2)
|
||||
self.assertEqual(captured_paths[0][1][0].name, AssistantGraphName.ASSISTANT.value)
|
||||
self.assertEqual(captured_paths[0][1][1].name, "AsyncNode")
|
||||
|
||||
# Sync node: [graph, node] - initialized within AsyncNode's context, so gets same path
|
||||
self.assertEqual(captured_paths[1][0], "sync")
|
||||
assert captured_paths[1][1] is not None
|
||||
self.assertEqual(len(captured_paths[1][1]), 2)
|
||||
self.assertEqual(captured_paths[1][1][0].name, AssistantGraphName.ASSISTANT.value)
|
||||
self.assertEqual(captured_paths[1][1][1].name, "AsyncNode") # Same as async because initialized in same context
|
||||
|
||||
def test_node_path_with_explicit_node_path_parameter(self):
|
||||
"""Test that explicitly passing node_path overrides default behavior"""
|
||||
custom_path = (NodePath(name="custom_graph"), NodePath(name="custom_node"))
|
||||
|
||||
class TestNode(AssistantNode):
|
||||
def run(self, state, config):
|
||||
return None
|
||||
|
||||
node = TestNode(self.team, self.user, node_path=custom_path)
|
||||
|
||||
self.assertEqual(len(node._node_path), 2)
|
||||
self.assertEqual(node._node_path[0].name, "custom_graph")
|
||||
self.assertEqual(node._node_path[1].name, "custom_node")
|
||||
|
||||
async def test_multiple_nested_graphs(self):
|
||||
"""Test deeply nested graph composition: Graph -> Node -> Graph -> Node -> Graph -> Node"""
|
||||
captured_paths = []
|
||||
|
||||
class Level3Node(AssistantNode):
|
||||
async def arun(self, state, config):
|
||||
nonlocal captured_paths
|
||||
captured_paths.append(("level3", get_node_path()))
|
||||
return None
|
||||
|
||||
class Level3Graph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
|
||||
@property
|
||||
def state_type(self) -> type[AssistantState]:
|
||||
return AssistantState
|
||||
|
||||
@property
|
||||
def graph_name(self) -> AssistantGraphName:
|
||||
return AssistantGraphName.TAXONOMY
|
||||
|
||||
def setup(self):
|
||||
node = Level3Node(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.FUNNEL_GENERATOR, node)
|
||||
self.add_edge(AssistantNodeName.START, AssistantNodeName.FUNNEL_GENERATOR)
|
||||
self.add_edge(AssistantNodeName.FUNNEL_GENERATOR, AssistantNodeName.END)
|
||||
return self
|
||||
|
||||
class Level2Node(AssistantNode):
|
||||
async def arun(self, state, config):
|
||||
nonlocal captured_paths
|
||||
captured_paths.append(("level2", get_node_path()))
|
||||
# Call level 3 graph
|
||||
level3_graph = Level3Graph(self._team, self._user)
|
||||
compiled = level3_graph.setup().compile(checkpointer=False)
|
||||
await compiled.ainvoke(state, config)
|
||||
return None
|
||||
|
||||
class Level2Graph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
|
||||
@property
|
||||
def state_type(self) -> type[AssistantState]:
|
||||
return AssistantState
|
||||
|
||||
@property
|
||||
def graph_name(self) -> AssistantGraphName:
|
||||
return AssistantGraphName.INSIGHTS
|
||||
|
||||
def setup(self):
|
||||
node = Level2Node(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.TRENDS_GENERATOR, node)
|
||||
self.add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_GENERATOR)
|
||||
self.add_edge(AssistantNodeName.TRENDS_GENERATOR, AssistantNodeName.END)
|
||||
return self
|
||||
|
||||
class Level1Node(AssistantNode):
|
||||
async def arun(self, state, config):
|
||||
nonlocal captured_paths
|
||||
captured_paths.append(("level1", get_node_path()))
|
||||
# Call level 2 graph
|
||||
level2_graph = Level2Graph(self._team, self._user)
|
||||
compiled = level2_graph.setup().compile(checkpointer=False)
|
||||
await compiled.ainvoke(state, config)
|
||||
return None
|
||||
|
||||
class Level1Graph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
|
||||
@property
|
||||
def state_type(self) -> type[AssistantState]:
|
||||
return AssistantState
|
||||
|
||||
@property
|
||||
def graph_name(self) -> AssistantGraphName:
|
||||
return AssistantGraphName.ASSISTANT
|
||||
|
||||
def setup(self):
|
||||
node = Level1Node(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.ROOT, node)
|
||||
self.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
self.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
|
||||
return self
|
||||
|
||||
level1_graph = Level1Graph(self.team, self.user)
|
||||
compiled = level1_graph.setup().compile(checkpointer=False)
|
||||
|
||||
conversation = await Conversation.objects.acreate(team=self.team, user=self.user)
|
||||
await compiled.ainvoke(AssistantState(messages=[]), {"configurable": {"thread_id": conversation.id}})
|
||||
|
||||
# Level 1: [graph, node]
|
||||
assert captured_paths[0][1] is not None
|
||||
self.assertEqual(len(captured_paths[0][1]), 2)
|
||||
self.assertEqual(captured_paths[0][1][0].name, AssistantGraphName.ASSISTANT.value)
|
||||
self.assertEqual(captured_paths[0][1][1].name, "Level1Node")
|
||||
|
||||
# Level 2: [graph, node, graph, node]
|
||||
assert captured_paths[1][1] is not None
|
||||
self.assertEqual(len(captured_paths[1][1]), 4)
|
||||
self.assertEqual(captured_paths[1][1][0].name, AssistantGraphName.ASSISTANT.value)
|
||||
self.assertEqual(captured_paths[1][1][1].name, "Level1Node")
|
||||
self.assertEqual(captured_paths[1][1][2].name, AssistantGraphName.INSIGHTS.value)
|
||||
self.assertEqual(captured_paths[1][1][3].name, "Level2Node")
|
||||
|
||||
# Level 3: [graph, node, graph, node, graph, node]
|
||||
assert captured_paths[2][1] is not None
|
||||
self.assertEqual(len(captured_paths[2][1]), 6)
|
||||
self.assertEqual(captured_paths[2][1][0].name, AssistantGraphName.ASSISTANT.value)
|
||||
self.assertEqual(captured_paths[2][1][1].name, "Level1Node")
|
||||
self.assertEqual(captured_paths[2][1][2].name, AssistantGraphName.INSIGHTS.value)
|
||||
self.assertEqual(captured_paths[2][1][3].name, "Level2Node")
|
||||
self.assertEqual(captured_paths[2][1][4].name, AssistantGraphName.TAXONOMY.value)
|
||||
self.assertEqual(captured_paths[2][1][5].name, "Level3Node")
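
The `get_node_path()` / `set_node_path()` helpers exercised by these tests are defined in `ee.hogai.graph.base.context`, which is not part of this excerpt. A minimal sketch of one way such helpers are typically built, assuming a `contextvars`-based design (the `NodePath` stand-in below is illustrative, not the real model):

```python
import contextvars
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class NodePath:
    # Illustrative stand-in for the real NodePath model used in the tests above.
    name: str
    message_id: Optional[str] = None
    tool_call_id: Optional[str] = None

_node_path: contextvars.ContextVar[Optional[tuple[NodePath, ...]]] = contextvars.ContextVar(
    "node_path", default=None
)

def get_node_path() -> Optional[tuple[NodePath, ...]]:
    return _node_path.get()

@contextmanager
def set_node_path(path: tuple[NodePath, ...]):
    # Each node sets its full path for the duration of run()/arun(); because
    # context variables propagate into awaited calls and subgraph invocations,
    # nested nodes see the parent's path and can extend it.
    token = _node_path.set(path)
    try:
        yield
    finally:
        _node_path.reset(token)
```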
@@ -89,13 +89,6 @@ class DashboardCreationNode(AssistantNode):
    def _get_found_insight_count(self, queries_metadata: dict[str, QueryMetadata]) -> int:
        return sum(len(query.found_insight_ids) for query in queries_metadata.values())

    def _dispatch_update_message(self, content: str) -> None:
        self.dispatcher.message(
            AssistantMessage(
                content=content,
            )
        )

    async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
        dashboard_name = (
            state.dashboard_name[:50] if state.dashboard_name else "Analytics Dashboard"
@@ -117,18 +110,18 @@ class DashboardCreationNode(AssistantNode):
            for i, query in enumerate(state.search_insights_queries)
        }

        self._dispatch_update_message(f"Searching for {pluralize(len(state.search_insights_queries), 'insight')}")
        self.dispatcher.update(f"Searching for {pluralize(len(state.search_insights_queries), 'insight')}")

        result = await self._search_insights(result, config)

        self._dispatch_update_message(f"Found {pluralize(self._get_found_insight_count(result), 'insight')}")
        self.dispatcher.update(f"Found {pluralize(self._get_found_insight_count(result), 'insight')}")

        left_to_create = {
            query_id: result[query_id].query for query_id in result.keys() if not result[query_id].found_insight_ids
        }

        if left_to_create:
            self._dispatch_update_message(f"Will create {pluralize(len(left_to_create), 'insight')}")
            self.dispatcher.update(f"Will create {pluralize(len(left_to_create), 'insight')}")

            result = await self._create_insights(left_to_create, result, config)

@@ -187,9 +180,7 @@ class DashboardCreationNode(AssistantNode):
        message = AssistantMessage(content="", id=str(uuid4()), tool_calls=tool_calls)

        executor = DashboardCreationExecutorNode(self._team, self._user)
        result = await executor.arun(
            AssistantState(messages=[message], root_tool_call_id=self._parent_tool_call_id), config
        )
        result = await executor.arun(AssistantState(messages=[message], root_tool_call_id=self.tool_call_id), config)

        query_metadata = await self._process_insight_creation_results(tool_calls, result.task_results, query_metadata)

@@ -211,9 +202,7 @@ class DashboardCreationNode(AssistantNode):
        message = AssistantMessage(content="", id=str(uuid4()), tool_calls=tool_calls)

        executor = DashboardCreationExecutorNode(self._team, self._user)
        result = await executor.arun(
            AssistantState(messages=[message], root_tool_call_id=self._parent_tool_call_id), config
        )
        result = await executor.arun(AssistantState(messages=[message], root_tool_call_id=self.tool_call_id), config)
        final_task_executor_state = BaseStateWithTasks.model_validate(result)

        for task_result in final_task_executor_state.task_results:
@@ -307,7 +296,7 @@ class DashboardCreationNode(AssistantNode):
        self, dashboard_name: str, insights: set[int], dashboard_id: int | None = None
    ) -> tuple[Dashboard, list[Insight]]:
        """Create a dashboard and add the insights to it."""
        self._dispatch_update_message("Saving your dashboard")
        self.dispatcher.update("Saving your dashboard")

        @database_sync_to_async
        @transaction.atomic
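
The hunks above swap the node-local `_dispatch_update_message` helper for the dispatcher's `update()` shorthand. An illustrative fragment of the equivalence, assuming `update()` emits the same kind of progress event the deleted helper built by hand:

```python
# Before: wrap the progress text in a full AssistantMessage.
self.dispatcher.message(AssistantMessage(content="Saving your dashboard"))

# After: one-argument shorthand on the dispatcher.
self.dispatcher.update("Saving your dashboard")
```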
@@ -1,5 +1,5 @@
|
||||
import pytest
|
||||
from unittest import TestCase
|
||||
from posthog.test.base import BaseTest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
@@ -11,10 +11,17 @@ from posthog.models import Dashboard, Insight, Team, User
|
||||
from ee.hogai.graph.dashboards.nodes import DashboardCreationExecutorNode, DashboardCreationNode, QueryMetadata
|
||||
from ee.hogai.utils.helpers import build_dashboard_url, build_insight_url
|
||||
from ee.hogai.utils.types import AssistantState, PartialAssistantState
|
||||
from ee.hogai.utils.types.base import BaseStateWithTasks, InsightArtifact, InsightQuery, TaskResult
|
||||
from ee.hogai.utils.types.base import (
|
||||
AssistantNodeName,
|
||||
BaseStateWithTasks,
|
||||
InsightArtifact,
|
||||
InsightQuery,
|
||||
NodePath,
|
||||
TaskResult,
|
||||
)
|
||||
|
||||
|
||||
class TestQueryMetadata(TestCase):
|
||||
class TestQueryMetadata(BaseTest):
|
||||
def test_query_metadata_initialization(self):
|
||||
"""Test QueryMetadata initialization with all fields."""
|
||||
query = InsightQuery(name="Test Query", description="Test Description")
|
||||
@@ -33,21 +40,30 @@ class TestQueryMetadata(TestCase):
|
||||
self.assertEqual(metadata.query, query)
|
||||
|
||||
|
||||
class TestDashboardCreationExecutorNode:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
class TestDashboardCreationExecutorNode(BaseTest):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.mock_team = MagicMock(spec=Team)
|
||||
self.mock_team.id = 1
|
||||
self.mock_user = MagicMock(spec=User)
|
||||
self.mock_user.id = 1
|
||||
self.node = DashboardCreationExecutorNode(self.mock_team, self.mock_user)
|
||||
self.node = DashboardCreationExecutorNode(
|
||||
self.mock_team,
|
||||
self.mock_user,
|
||||
(
|
||||
NodePath(
|
||||
name=AssistantNodeName.DASHBOARD_CREATION_EXECUTOR.value,
|
||||
message_id="test_message_id",
|
||||
tool_call_id="test_tool_call_id",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test node initialization."""
|
||||
assert self.node._team == self.mock_team
|
||||
assert self.node._user == self.mock_user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aget_input_tuples_search_insights(self):
|
||||
"""Test _aget_input_tuples for search_insights tasks."""
|
||||
tool_calls = [
|
||||
@@ -62,7 +78,6 @@ class TestDashboardCreationExecutorNode:
|
||||
assert task.name == "search_insights"
|
||||
assert callable_func == self.node._execute_search_insights
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aget_input_tuples_create_insight(self):
|
||||
"""Test _aget_input_tuples for create_insight tasks."""
|
||||
tool_calls = [AssistantToolCall(id="task_1", name="create_insight", args={"query_description": "Test prompt"})]
|
||||
@@ -75,7 +90,6 @@ class TestDashboardCreationExecutorNode:
|
||||
assert task.name == "create_insight"
|
||||
assert callable_func == self.node._execute_create_insight
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aget_input_tuples_unsupported_task(self):
|
||||
"""Test _aget_input_tuples raises error for unsupported task type."""
|
||||
tool_calls = [AssistantToolCall(id="task_1", name="unsupported_type", args={"query": "Test prompt"})]
|
||||
@@ -85,7 +99,6 @@ class TestDashboardCreationExecutorNode:
|
||||
|
||||
assert "Unsupported task type: unsupported_type" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aget_input_tuples_no_tasks(self):
|
||||
"""Test _aget_input_tuples returns empty list when no tasks."""
|
||||
tool_calls: list[AssistantToolCall] = []
|
||||
@@ -95,14 +108,24 @@ class TestDashboardCreationExecutorNode:
|
||||
assert len(input_tuples) == 0
|
||||
|
||||
|
||||
class TestDashboardCreationNode:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
class TestDashboardCreationNode(BaseTest):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.mock_team = MagicMock(spec=Team)
|
||||
self.mock_team.id = 1
|
||||
self.mock_user = MagicMock(spec=User)
|
||||
self.mock_user.id = 1
|
||||
self.node = DashboardCreationNode(self.mock_team, self.mock_user)
|
||||
self.node = DashboardCreationNode(
|
||||
self.mock_team,
|
||||
self.mock_user,
|
||||
(
|
||||
NodePath(
|
||||
name=AssistantNodeName.DASHBOARD_CREATION.value,
|
||||
message_id="test_message_id",
|
||||
tool_call_id="test_tool_call_id",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def test_get_found_insight_count(self):
|
||||
"""Test _get_found_insight_count calculates correct count."""
|
||||
@@ -140,7 +163,6 @@ class TestDashboardCreationNode:
|
||||
expected_url = f"/project/{self.mock_team.id}/dashboard/{dashboard_id}"
|
||||
assert url == expected_url
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.hogai.graph.dashboards.nodes.DashboardCreationExecutorNode")
|
||||
async def test_arun_missing_search_insights_queries(self, mock_executor_node_class):
|
||||
"""Test arun returns error when search_insights_queries is missing."""
|
||||
@@ -150,7 +172,7 @@ class TestDashboardCreationNode:
|
||||
state = AssistantState(
|
||||
dashboard_name="Create dashboard",
|
||||
search_insights_queries=None,
|
||||
root_tool_call_id="test_call",
|
||||
root_tool_call_id="test_tool_call_id",
|
||||
)
|
||||
config = RunnableConfig()
|
||||
|
||||
@@ -161,7 +183,6 @@ class TestDashboardCreationNode:
|
||||
assert isinstance(result.messages[0], AssistantToolCallMessage)
|
||||
assert "Search insights queries are required" in result.messages[0].content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.hogai.graph.dashboards.nodes.DashboardCreationExecutorNode")
|
||||
@patch.object(DashboardCreationNode, "_search_insights")
|
||||
@patch.object(DashboardCreationNode, "_create_insights")
|
||||
@@ -205,7 +226,7 @@ class TestDashboardCreationNode:
|
||||
state = AssistantState(
|
||||
dashboard_name="Create dashboard",
|
||||
search_insights_queries=[InsightQuery(name="Query 1", description="Description 1")],
|
||||
root_tool_call_id="test_call",
|
||||
root_tool_call_id="test_tool_call_id",
|
||||
)
|
||||
config = RunnableConfig()
|
||||
|
||||
@@ -217,7 +238,6 @@ class TestDashboardCreationNode:
|
||||
assert "Dashboard Created" in result.messages[0].content
|
||||
assert "Test Dashboard" in result.messages[0].content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.hogai.graph.dashboards.nodes.DashboardCreationExecutorNode")
|
||||
@patch.object(DashboardCreationNode, "_search_insights")
|
||||
@patch.object(DashboardCreationNode, "_create_insights")
|
||||
@@ -244,7 +264,7 @@ class TestDashboardCreationNode:
|
||||
state = AssistantState(
|
||||
dashboard_name="Create dashboard",
|
||||
search_insights_queries=[InsightQuery(name="Query 1", description="Description 1")],
|
||||
root_tool_call_id="test_call",
|
||||
root_tool_call_id="test_tool_call_id",
|
||||
)
|
||||
config = RunnableConfig()
|
||||
|
||||
@@ -255,7 +275,6 @@ class TestDashboardCreationNode:
|
||||
assert isinstance(result.messages[0], AssistantToolCallMessage)
|
||||
assert "No existing insights matched" in result.messages[0].content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.hogai.graph.dashboards.nodes.DashboardCreationExecutorNode")
|
||||
@patch.object(DashboardCreationNode, "_search_insights")
|
||||
@patch("ee.hogai.graph.dashboards.nodes.logger")
|
||||
@@ -268,7 +287,7 @@ class TestDashboardCreationNode:
|
||||
state = AssistantState(
|
||||
dashboard_name="Create dashboard",
|
||||
search_insights_queries=[InsightQuery(name="Query 1", description="Description 1")],
|
||||
root_tool_call_id="test_call",
|
||||
root_tool_call_id="test_tool_call_id",
|
||||
)
|
||||
config = RunnableConfig()
|
||||
|
||||
@@ -295,7 +314,7 @@ class TestDashboardCreationNode:
|
||||
result = self.node._create_success_response(
|
||||
mock_dashboard,
|
||||
mock_insights, # type: ignore[arg-type]
|
||||
"test_call",
|
||||
"test_tool_call_id",
|
||||
["Query without insights"],
|
||||
)
|
||||
|
||||
@@ -309,7 +328,7 @@ class TestDashboardCreationNode:
|
||||
|
||||
def test_create_no_insights_response(self):
|
||||
"""Test _create_no_insights_response creates correct no insights message."""
|
||||
result = self.node._create_no_insights_response("test_call", "No insights found")
|
||||
result = self.node._create_no_insights_response("test_tool_call_id", "No insights found")
|
||||
|
||||
assert isinstance(result, PartialAssistantState)
|
||||
assert len(result.messages) == 1
|
||||
@@ -320,27 +339,30 @@ class TestDashboardCreationNode:
|
||||
def test_create_error_response(self):
|
||||
"""Test _create_error_response creates correct error message."""
|
||||
with patch("ee.hogai.graph.dashboards.nodes.capture_exception") as mock_capture:
|
||||
result = self.node._create_error_response("Test error", "test_call")
|
||||
result = self.node._create_error_response("Test error", "test_tool_call_id")
|
||||
|
||||
assert isinstance(result, PartialAssistantState)
|
||||
assert len(result.messages) == 1
|
||||
assert isinstance(result.messages[0], AssistantToolCallMessage)
|
||||
assert result.messages[0].content == "Test error"
|
||||
assert isinstance(result.messages[0], AssistantToolCallMessage)
|
||||
assert result.messages[0].tool_call_id == "test_call"
|
||||
assert result.messages[0].tool_call_id == "test_tool_call_id"
|
||||
mock_capture.assert_called_once()
|
||||
|
||||
|
||||
class TestDashboardCreationNodeAsyncMethods:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
class TestDashboardCreationNodeAsyncMethods(BaseTest):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.mock_team = MagicMock(spec=Team)
|
||||
self.mock_team.id = 1
|
||||
self.mock_user = MagicMock(spec=User)
|
||||
self.mock_user.id = 1
|
||||
self.node = DashboardCreationNode(self.mock_team, self.mock_user)
|
||||
self.node = DashboardCreationNode(
|
||||
self.mock_team,
|
||||
self.mock_user,
|
||||
node_path=(NodePath(name="test_node", message_id="test-id", tool_call_id="test_tool_call_id"),),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.hogai.graph.dashboards.nodes.DashboardCreationExecutorNode")
|
||||
async def test_create_insights(self, mock_executor_node_class):
|
||||
"""Test _create_insights method."""
|
||||
@@ -389,7 +411,6 @@ class TestDashboardCreationNodeAsyncMethods:
|
||||
assert len(result["task_1"].created_insight_ids) == 2
|
||||
assert len(result["task_1"].created_insight_messages) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.hogai.graph.dashboards.nodes.DashboardCreationExecutorNode")
|
||||
async def test_search_insights(self, mock_executor_node_class):
|
||||
"""Test _search_insights method."""
|
||||
@@ -430,7 +451,6 @@ class TestDashboardCreationNodeAsyncMethods:
|
||||
assert len(result["task_1"].found_insight_ids) == 2
|
||||
assert len(result["task_1"].found_insight_messages) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.hogai.graph.dashboards.nodes.database_sync_to_async")
|
||||
async def test_create_dashboard_with_insights(self, mock_db_sync):
|
||||
"""Test _create_dashboard_with_insights method."""
|
||||
@@ -452,7 +472,6 @@ class TestDashboardCreationNodeAsyncMethods:
|
||||
assert result[1] == mock_insights
|
||||
mock_sync_func.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_insight_creation_results(self):
|
||||
"""Test _process_insight_creation_results method."""
|
||||
# Create simple mocked Team and User instances like other tests
|
||||
|
||||
@@ -1,21 +1,25 @@
from typing import Literal, Optional

from posthog.models.team.team import Team
from posthog.models.user import User

from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.graph.base import BaseAssistantGraph
from ee.hogai.graph.deep_research.notebook.nodes import DeepResearchNotebookPlanningNode
from ee.hogai.graph.deep_research.onboarding.nodes import DeepResearchOnboardingNode
from ee.hogai.graph.deep_research.planner.nodes import DeepResearchPlannerNode, DeepResearchPlannerToolsNode
from ee.hogai.graph.deep_research.report.nodes import DeepResearchReportNode
from ee.hogai.graph.deep_research.task_executor.nodes import DeepResearchTaskExecutorNode
from ee.hogai.graph.deep_research.types import DeepResearchNodeName, DeepResearchState
from ee.hogai.graph.graph import BaseAssistantGraph
from ee.hogai.graph.deep_research.types import DeepResearchNodeName, DeepResearchState, PartialDeepResearchState
from ee.hogai.graph.title_generator.nodes import TitleGeneratorNode
from ee.hogai.utils.types.base import AssistantGraphName


class DeepResearchAssistantGraph(BaseAssistantGraph[DeepResearchState]):
    def __init__(self, team: Team, user: User):
        super().__init__(team, user, DeepResearchState)
class DeepResearchAssistantGraph(BaseAssistantGraph[DeepResearchState, PartialDeepResearchState]):
    @property
    def graph_name(self) -> AssistantGraphName:
        return AssistantGraphName.DEEP_RESEARCH

    @property
    def state_type(self) -> type[DeepResearchState]:
        return DeepResearchState

    def add_onboarding_node(
        self, node_map: Optional[dict[Literal["onboarding", "planning", "continue"], DeepResearchNodeName]] = None
@@ -80,13 +84,21 @@ class DeepResearchAssistantGraph(BaseAssistantGraph[DeepResearchState]):
        builder.add_edge(DeepResearchNodeName.REPORT, next_node)
        return self

    def add_title_generator(self, end_node: DeepResearchNodeName = DeepResearchNodeName.END):
        self._has_start_node = True

        title_generator = TitleGeneratorNode(self._team, self._user)
        self._graph.add_node(DeepResearchNodeName.TITLE_GENERATOR, title_generator)
        self._graph.add_edge(DeepResearchNodeName.START, DeepResearchNodeName.TITLE_GENERATOR)
        self._graph.add_edge(DeepResearchNodeName.TITLE_GENERATOR, end_node)
        return self

    def compile_full_graph(self, checkpointer: DjangoCheckpointer | None = None):
        return (
            self.add_onboarding_node()
            .add_notebook_nodes()
            .add_planner_nodes()
            .add_report_node()
            .add_title_generator()
            .add_task_executor()
            .compile(checkpointer=checkpointer)
        )
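
A minimal usage sketch of the reworked graph class above; `team`, `user`, and `conversation` are placeholders, and the initial-state shape mirrors the test invocations elsewhere in this commit:

```python
# Sketch only: compile the full deep research graph and invoke it once.
graph = DeepResearchAssistantGraph(team, user)
compiled = graph.compile_full_graph()  # no argument falls back to the global Django checkpointer

state = await compiled.ainvoke(
    DeepResearchState(messages=[]),
    {"configurable": {"thread_id": conversation.id}},
)
```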
@@ -40,8 +40,8 @@ from ee.hogai.graph.deep_research.types import (
|
||||
DeepResearchNodeName,
|
||||
DeepResearchState,
|
||||
PartialDeepResearchState,
|
||||
TodoItem,
|
||||
)
|
||||
from ee.hogai.graph.root.tools.todo_write import TodoItem
|
||||
from ee.hogai.notebook.notebook_serializer import NotebookSerializer
|
||||
from ee.hogai.utils.helpers import normalize_ai_message
|
||||
from ee.hogai.utils.types import WithCommentary
|
||||
|
||||
@@ -34,8 +34,7 @@ from ee.hogai.graph.deep_research.planner.prompts import (
|
||||
WRITE_RESULT_FAILED_TOOL_RESULT,
|
||||
WRITE_RESULT_TOOL_RESULT,
|
||||
)
|
||||
from ee.hogai.graph.deep_research.types import DeepResearchState, PartialDeepResearchState
|
||||
from ee.hogai.graph.root.tools.todo_write import TodoItem
|
||||
from ee.hogai.graph.deep_research.types import DeepResearchState, PartialDeepResearchState, TodoItem
|
||||
from ee.hogai.utils.types import InsightArtifact
|
||||
from ee.hogai.utils.types.base import TaskResult
|
||||
|
||||
|
||||
@@ -128,7 +128,7 @@ class TestTaskExecutorNodeArun(TestTaskExecutorNode):
|
||||
|
||||
class TestTaskExecutorInsightsExecution(TestTaskExecutorNode):
|
||||
@patch("ee.hogai.graph.deep_research.task_executor.nodes.DeepResearchTaskExecutorNode.dispatcher")
|
||||
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsAssistantGraph")
|
||||
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsGraph")
|
||||
async def test_execute_task_with_insights_successful(self, mock_insights_graph_class, mock_dispatcher):
|
||||
"""Test successful task execution through insights pipeline."""
|
||||
|
||||
@@ -171,7 +171,7 @@ class TestTaskExecutorInsightsExecution(TestTaskExecutorNode):
|
||||
self.assertEqual(len(result.artifacts), 1)
|
||||
|
||||
@patch("ee.hogai.graph.deep_research.task_executor.nodes.DeepResearchTaskExecutorNode.dispatcher")
|
||||
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsAssistantGraph")
|
||||
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsGraph")
|
||||
async def test_execute_task_with_insights_no_artifacts(self, mock_insights_graph_class, mock_dispatcher):
|
||||
"""Test task execution that produces no artifacts."""
|
||||
task = self._create_assistant_tool_call()
|
||||
@@ -209,7 +209,7 @@ class TestTaskExecutorInsightsExecution(TestTaskExecutorNode):
|
||||
|
||||
@patch("ee.hogai.graph.deep_research.task_executor.nodes.capture_exception")
|
||||
@patch("ee.hogai.graph.deep_research.task_executor.nodes.DeepResearchTaskExecutorNode.dispatcher")
|
||||
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsAssistantGraph")
|
||||
@patch("ee.hogai.graph.deep_research.task_executor.nodes.InsightsGraph")
|
||||
async def test_execute_task_with_exception(self, mock_insights_graph_class, mock_dispatcher, mock_capture):
|
||||
"""Test task execution that encounters an exception."""
|
||||
task = self._create_assistant_tool_call()
|
||||
|
||||
@@ -27,8 +27,7 @@ from ee.hogai.graph.deep_research.onboarding.nodes import DeepResearchOnboarding
|
||||
from ee.hogai.graph.deep_research.planner.nodes import DeepResearchPlannerNode, DeepResearchPlannerToolsNode
|
||||
from ee.hogai.graph.deep_research.report.nodes import DeepResearchReportNode
|
||||
from ee.hogai.graph.deep_research.task_executor.nodes import DeepResearchTaskExecutorNode
|
||||
from ee.hogai.graph.deep_research.types import DeepResearchIntermediateResult, DeepResearchState
|
||||
from ee.hogai.graph.root.tools.todo_write import TodoItem
|
||||
from ee.hogai.graph.deep_research.types import DeepResearchIntermediateResult, DeepResearchState, TodoItem
|
||||
from ee.hogai.utils.types.base import TaskResult
|
||||
from ee.models.assistant import Conversation
|
||||
|
||||
|
||||
@@ -12,9 +12,9 @@ from ee.hogai.graph.deep_research.types import (
|
||||
DeepResearchIntermediateResult,
|
||||
DeepResearchState,
|
||||
PartialDeepResearchState,
|
||||
TodoItem,
|
||||
_SharedDeepResearchState,
|
||||
)
|
||||
from ee.hogai.graph.root.tools.todo_write import TodoItem
|
||||
from ee.hogai.utils.types.base import InsightArtifact, TaskResult
|
||||
|
||||
"""
|
||||
|
||||
@@ -1,19 +1,31 @@
from collections.abc import Sequence
from enum import StrEnum
from typing import Annotated, Optional
from typing import Annotated, Literal, Optional

from langgraph.graph import END, START
from pydantic import BaseModel, Field

from posthog.schema import DeepResearchNotebook

from ee.hogai.graph.root.tools.todo_write import TodoItem
from ee.hogai.utils.types import AssistantMessageUnion, add_and_merge_messages
from ee.hogai.utils.types.base import BaseStateWithMessages, BaseStateWithTasks, append, replace
from ee.hogai.utils.types.base import (
    AssistantMessageUnion,
    BaseStateWithMessages,
    BaseStateWithTasks,
    add_and_merge_messages,
    append,
    replace,
)

NotebookInfo = DeepResearchNotebook


class TodoItem(BaseModel):
    content: str = Field(..., min_length=1)
    status: Literal["pending", "in_progress", "completed"]
    id: str
    priority: Literal["low", "medium", "high"]


class DeepResearchIntermediateResult(BaseModel):
    """
    An intermediate result of a batch of work, that will be used to write the final report.
@@ -69,3 +81,4 @@ class DeepResearchNodeName(StrEnum):
    PLANNER_TOOLS = "planner_tools"
    TASK_EXECUTOR = "task_executor"
    REPORT = "report"
    TITLE_GENERATOR = "title_generator"
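
For reference, constructing the relocated `TodoItem` with the fields defined above (the values are made up for illustration):

```python
# Sketch: TodoItem now lives in ee.hogai.graph.deep_research.types.
todo = TodoItem(
    id="1",
    content="Summarize the activation funnel results",
    status="pending",
    priority="high",
)
```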
@@ -1,7 +1,7 @@
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig

from posthog.schema import AssistantFunnelsQuery, AssistantMessage
from posthog.schema import AssistantFunnelsQuery

from ee.hogai.utils.types import AssistantState, PartialAssistantState
from ee.hogai.utils.types.base import AssistantNodeName
@@ -25,7 +25,7 @@ class FunnelGeneratorNode(SchemaGeneratorNode[AssistantFunnelsQuery]):
        return AssistantNodeName.FUNNEL_GENERATOR

    async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
        self.dispatcher.message(AssistantMessage(content="Creating funnel query"))
        self.dispatcher.update("Creating funnel query")
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", FUNNEL_SYSTEM_PROMPT),
|
||||
|
||||
@@ -1,23 +1,11 @@
|
||||
from collections.abc import Hashable
|
||||
from typing import Any, Generic, Literal, Optional, cast
|
||||
|
||||
from langgraph.graph.state import StateGraph
|
||||
|
||||
from posthog.models.team.team import Team
|
||||
from posthog.models.user import User
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
|
||||
from ee.hogai.graph.billing.nodes import BillingNode
|
||||
from ee.hogai.graph.query_planner.nodes import QueryPlannerNode, QueryPlannerToolsNode
|
||||
from ee.hogai.graph.session_summaries.nodes import SessionSummarizationNode
|
||||
from ee.hogai.graph.base import BaseAssistantGraph
|
||||
from ee.hogai.graph.title_generator.nodes import TitleGeneratorNode
|
||||
from ee.hogai.utils.types import AssistantNodeName, AssistantState, StateType
|
||||
from ee.hogai.utils.types.composed import MaxNodeName
|
||||
from ee.hogai.utils.types.base import AssistantGraphName, AssistantNodeName, AssistantState, PartialAssistantState
|
||||
|
||||
from .dashboards.nodes import DashboardCreationNode
|
||||
from .funnels.nodes import FunnelGeneratorNode, FunnelGeneratorToolsNode
|
||||
from .inkeep_docs.nodes import InkeepDocsNode
|
||||
from .insights.nodes import InsightSearchNode
|
||||
from .memory.nodes import (
|
||||
MemoryCollectorNode,
|
||||
MemoryCollectorToolsNode,
|
||||
@@ -28,51 +16,19 @@ from .memory.nodes import (
|
||||
MemoryOnboardingFinalizeNode,
|
||||
MemoryOnboardingNode,
|
||||
)
|
||||
from .query_executor.nodes import QueryExecutorNode
|
||||
from .rag.nodes import InsightRagContextNode
|
||||
from .retention.nodes import RetentionGeneratorNode, RetentionGeneratorToolsNode
|
||||
from .root.nodes import RootNode, RootNodeTools
|
||||
from .sql.nodes import SQLGeneratorNode, SQLGeneratorToolsNode
|
||||
from .trends.nodes import TrendsGeneratorNode, TrendsGeneratorToolsNode
|
||||
|
||||
global_checkpointer = DjangoCheckpointer()
|
||||
|
||||
|
||||
class BaseAssistantGraph(Generic[StateType]):
|
||||
_team: Team
|
||||
_user: User
|
||||
_graph: StateGraph
|
||||
_parent_tool_call_id: str | None
|
||||
class AssistantGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
|
||||
@property
|
||||
def graph_name(self) -> AssistantGraphName:
|
||||
return AssistantGraphName.ASSISTANT
|
||||
|
||||
def __init__(self, team: Team, user: User, state_type: type[StateType], parent_tool_call_id: str | None = None):
|
||||
self._team = team
|
||||
self._user = user
|
||||
self._graph = StateGraph(state_type)
|
||||
self._has_start_node = False
|
||||
self._parent_tool_call_id = parent_tool_call_id
|
||||
@property
|
||||
def state_type(self) -> type[AssistantState]:
|
||||
return AssistantState
|
||||
|
||||
def add_edge(self, from_node: MaxNodeName, to_node: MaxNodeName):
|
||||
if from_node == AssistantNodeName.START:
|
||||
self._has_start_node = True
|
||||
self._graph.add_edge(from_node, to_node)
|
||||
return self
|
||||
|
||||
def add_node(self, node: MaxNodeName, action: Any):
|
||||
if self._parent_tool_call_id:
|
||||
action._parent_tool_call_id = self._parent_tool_call_id
|
||||
self._graph.add_node(node, action)
|
||||
return self
|
||||
|
||||
def compile(self, checkpointer: DjangoCheckpointer | None | Literal[False] = None):
|
||||
if not self._has_start_node:
|
||||
raise ValueError("Start node not added to the graph")
|
||||
# TRICKY: We check `is not None` because False has a special meaning of "no checkpointer", which we want to pass on
|
||||
compiled_graph = self._graph.compile(
|
||||
checkpointer=checkpointer if checkpointer is not None else global_checkpointer
|
||||
)
|
||||
return compiled_graph
|
||||
|
||||
def add_title_generator(self, end_node: MaxNodeName = AssistantNodeName.END):
|
||||
def add_title_generator(self, end_node: AssistantNodeName = AssistantNodeName.END):
|
||||
self._has_start_node = True
|
||||
|
||||
title_generator = TitleGeneratorNode(self._team, self._user)
|
||||
@@ -81,178 +37,19 @@ class BaseAssistantGraph(Generic[StateType]):
|
||||
self._graph.add_edge(AssistantNodeName.TITLE_GENERATOR, end_node)
|
||||
return self
|
||||
|
||||
|
||||
class InsightsAssistantGraph(BaseAssistantGraph[AssistantState]):
|
||||
def __init__(self, team: Team, user: User, tool_call_id: str | None = None):
|
||||
super().__init__(team, user, AssistantState, tool_call_id)
|
||||
|
||||
def add_rag_context(self):
|
||||
self._has_start_node = True
|
||||
retriever = InsightRagContextNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.INSIGHT_RAG_CONTEXT, retriever)
|
||||
self._graph.add_edge(AssistantNodeName.START, AssistantNodeName.INSIGHT_RAG_CONTEXT)
|
||||
self._graph.add_edge(AssistantNodeName.INSIGHT_RAG_CONTEXT, AssistantNodeName.QUERY_PLANNER)
|
||||
return self
|
||||
|
||||
def add_trends_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR):
|
||||
trends_generator = TrendsGeneratorNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.TRENDS_GENERATOR, trends_generator)
|
||||
|
||||
trends_generator_tools = TrendsGeneratorToolsNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.TRENDS_GENERATOR_TOOLS, trends_generator_tools)
|
||||
|
||||
self._graph.add_edge(AssistantNodeName.TRENDS_GENERATOR_TOOLS, AssistantNodeName.TRENDS_GENERATOR)
|
||||
self._graph.add_conditional_edges(
|
||||
AssistantNodeName.TRENDS_GENERATOR,
|
||||
trends_generator.router,
|
||||
path_map={
|
||||
"tools": AssistantNodeName.TRENDS_GENERATOR_TOOLS,
|
||||
"next": next_node,
|
||||
},
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def add_funnel_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR):
|
||||
funnel_generator = FunnelGeneratorNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.FUNNEL_GENERATOR, funnel_generator)
|
||||
|
||||
funnel_generator_tools = FunnelGeneratorToolsNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.FUNNEL_GENERATOR_TOOLS, funnel_generator_tools)
|
||||
|
||||
self._graph.add_edge(AssistantNodeName.FUNNEL_GENERATOR_TOOLS, AssistantNodeName.FUNNEL_GENERATOR)
|
||||
self._graph.add_conditional_edges(
|
||||
AssistantNodeName.FUNNEL_GENERATOR,
|
||||
funnel_generator.router,
|
||||
path_map={
|
||||
"tools": AssistantNodeName.FUNNEL_GENERATOR_TOOLS,
|
||||
"next": next_node,
|
||||
},
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def add_retention_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR):
|
||||
retention_generator = RetentionGeneratorNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.RETENTION_GENERATOR, retention_generator)
|
||||
|
||||
retention_generator_tools = RetentionGeneratorToolsNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.RETENTION_GENERATOR_TOOLS, retention_generator_tools)
|
||||
|
||||
self._graph.add_edge(AssistantNodeName.RETENTION_GENERATOR_TOOLS, AssistantNodeName.RETENTION_GENERATOR)
|
||||
self._graph.add_conditional_edges(
|
||||
AssistantNodeName.RETENTION_GENERATOR,
|
||||
retention_generator.router,
|
||||
path_map={
|
||||
"tools": AssistantNodeName.RETENTION_GENERATOR_TOOLS,
|
||||
"next": next_node,
|
||||
},
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def add_query_planner(
|
||||
self,
|
||||
path_map: Optional[
|
||||
dict[Literal["trends", "funnel", "retention", "sql", "continue", "end"], AssistantNodeName]
|
||||
] = None,
|
||||
):
|
||||
query_planner = QueryPlannerNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.QUERY_PLANNER, query_planner)
|
||||
self._graph.add_edge(AssistantNodeName.QUERY_PLANNER, AssistantNodeName.QUERY_PLANNER_TOOLS)
|
||||
|
||||
query_planner_tools = QueryPlannerToolsNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.QUERY_PLANNER_TOOLS, query_planner_tools)
|
||||
self._graph.add_conditional_edges(
|
||||
AssistantNodeName.QUERY_PLANNER_TOOLS,
|
||||
query_planner_tools.router,
|
||||
path_map=path_map # type: ignore
|
||||
or {
|
||||
"continue": AssistantNodeName.QUERY_PLANNER,
|
||||
"trends": AssistantNodeName.TRENDS_GENERATOR,
|
||||
"funnel": AssistantNodeName.FUNNEL_GENERATOR,
|
||||
"retention": AssistantNodeName.RETENTION_GENERATOR,
|
||||
"sql": AssistantNodeName.SQL_GENERATOR,
|
||||
"end": AssistantNodeName.END,
|
||||
},
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def add_sql_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR):
|
||||
sql_generator = SQLGeneratorNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.SQL_GENERATOR, sql_generator)
|
||||
|
||||
sql_generator_tools = SQLGeneratorToolsNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.SQL_GENERATOR_TOOLS, sql_generator_tools)
|
||||
|
||||
self._graph.add_edge(AssistantNodeName.SQL_GENERATOR_TOOLS, AssistantNodeName.SQL_GENERATOR)
|
||||
self._graph.add_conditional_edges(
|
||||
AssistantNodeName.SQL_GENERATOR,
|
||||
sql_generator.router,
|
||||
path_map={
|
||||
"tools": AssistantNodeName.SQL_GENERATOR_TOOLS,
|
||||
"next": next_node,
|
||||
},
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def add_query_executor(self, next_node: AssistantNodeName = AssistantNodeName.END):
|
||||
query_executor_node = QueryExecutorNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.QUERY_EXECUTOR, query_executor_node)
|
||||
self._graph.add_edge(AssistantNodeName.QUERY_EXECUTOR, next_node)
|
||||
return self
|
||||
|
||||
def add_query_creation_flow(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR):
|
||||
"""Add all nodes and edges EXCEPT query execution."""
|
||||
return (
|
||||
self.add_rag_context()
|
||||
.add_query_planner()
|
||||
.add_trends_generator(next_node=next_node)
|
||||
.add_funnel_generator(next_node=next_node)
|
||||
.add_retention_generator(next_node=next_node)
|
||||
.add_sql_generator(next_node=next_node)
|
||||
)
|
||||
|
||||
def compile_full_graph(self, checkpointer: DjangoCheckpointer | None = None):
|
||||
return self.add_query_creation_flow().add_query_executor().compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
class AssistantGraph(BaseAssistantGraph[AssistantState]):
    def __init__(self, team: Team, user: User):
        super().__init__(team, user, AssistantState)

    def add_root(
        self,
        path_map: Optional[dict[Hashable, AssistantNodeName]] = None,
        tools_node: AssistantNodeName = AssistantNodeName.ROOT_TOOLS,
    ):
        path_map = path_map or {
            "insights": AssistantNodeName.INSIGHTS_SUBGRAPH,
            "search_documentation": AssistantNodeName.INKEEP_DOCS,
            "root": AssistantNodeName.ROOT,
            "billing": AssistantNodeName.BILLING,
            "end": AssistantNodeName.END,
            "insights_search": AssistantNodeName.INSIGHTS_SEARCH,
            "session_summarization": AssistantNodeName.SESSION_SUMMARIZATION,
            "create_dashboard": AssistantNodeName.DASHBOARD_CREATION,
        }
    def add_root(self, router: Callable[[AssistantState], AssistantNodeName] | None = None):
        root_node = RootNode(self._team, self._user)
        self.add_node(AssistantNodeName.ROOT, root_node)
        root_node_tools = RootNodeTools(self._team, self._user)
        self.add_node(AssistantNodeName.ROOT_TOOLS, root_node_tools)
        self._graph.add_edge(AssistantNodeName.ROOT, tools_node)
        self._graph.add_conditional_edges(
            AssistantNodeName.ROOT_TOOLS, root_node_tools.router, path_map=cast(dict[Hashable, str], path_map)
            AssistantNodeName.ROOT, router or cast(Callable[[AssistantState], AssistantNodeName], root_node.router)
        )
        self._graph.add_conditional_edges(
            AssistantNodeName.ROOT_TOOLS,
            root_node_tools.router,
            path_map={"root": AssistantNodeName.ROOT, "end": AssistantNodeName.END},
        )
        return self
|
||||
|
||||
def add_insights(self, next_node: AssistantNodeName = AssistantNodeName.ROOT):
|
||||
insights_assistant_graph = InsightsAssistantGraph(self._team, self._user)
|
||||
compiled_graph = insights_assistant_graph.compile_full_graph()
|
||||
self.add_node(AssistantNodeName.INSIGHTS_SUBGRAPH, compiled_graph)
|
||||
self._graph.add_edge(AssistantNodeName.INSIGHTS_SUBGRAPH, next_node)
|
||||
return self
|
||||
|
||||
def add_memory_onboarding(
|
||||
@@ -332,55 +129,6 @@ class AssistantGraph(BaseAssistantGraph[AssistantState]):
|
||||
self._graph.add_edge(AssistantNodeName.MEMORY_COLLECTOR_TOOLS, AssistantNodeName.MEMORY_COLLECTOR)
|
||||
return self
|
||||
|
||||
def add_inkeep_docs(self, path_map: Optional[dict[Hashable, AssistantNodeName]] = None):
|
||||
"""Add the Inkeep docs search node to the graph."""
|
||||
path_map = path_map or {
|
||||
"end": AssistantNodeName.END,
|
||||
"root": AssistantNodeName.ROOT,
|
||||
}
|
||||
inkeep_docs_node = InkeepDocsNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.INKEEP_DOCS, inkeep_docs_node)
|
||||
self._graph.add_conditional_edges(
|
||||
AssistantNodeName.INKEEP_DOCS,
|
||||
inkeep_docs_node.router,
|
||||
path_map=cast(dict[Hashable, str], path_map),
|
||||
)
|
||||
return self
|
||||
|
||||
def add_billing(self):
|
||||
billing_node = BillingNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.BILLING, billing_node)
|
||||
self._graph.add_edge(AssistantNodeName.BILLING, AssistantNodeName.ROOT)
|
||||
return self
|
||||
|
||||
def add_insights_search(self, end_node: AssistantNodeName = AssistantNodeName.END):
|
||||
path_map = {
|
||||
"end": end_node,
|
||||
"root": AssistantNodeName.ROOT,
|
||||
}
|
||||
|
||||
insights_search_node = InsightSearchNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.INSIGHTS_SEARCH, insights_search_node)
|
||||
self._graph.add_conditional_edges(
|
||||
AssistantNodeName.INSIGHTS_SEARCH,
|
||||
insights_search_node.router,
|
||||
path_map=cast(dict[Hashable, str], path_map),
|
||||
)
|
||||
return self
|
||||
|
||||
def add_session_summarization(self, end_node: AssistantNodeName = AssistantNodeName.END):
|
||||
session_summarization_node = SessionSummarizationNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.SESSION_SUMMARIZATION, session_summarization_node)
|
||||
self._graph.add_edge(AssistantNodeName.SESSION_SUMMARIZATION, AssistantNodeName.ROOT)
|
||||
return self
|
||||
|
||||
def add_dashboard_creation(self, end_node: AssistantNodeName = AssistantNodeName.END):
|
||||
builder = self._graph
|
||||
dashboard_creation_node = DashboardCreationNode(self._team, self._user)
|
||||
builder.add_node(AssistantNodeName.DASHBOARD_CREATION, dashboard_creation_node)
|
||||
builder.add_edge(AssistantNodeName.DASHBOARD_CREATION, AssistantNodeName.ROOT)
|
||||
return self
|
||||
|
||||
def compile_full_graph(self, checkpointer: DjangoCheckpointer | None = None):
|
||||
return (
|
||||
self.add_title_generator()
|
||||
@@ -388,11 +136,5 @@ class AssistantGraph(BaseAssistantGraph[AssistantState]):
|
||||
.add_memory_collector()
|
||||
.add_memory_collector_tools()
|
||||
.add_root()
|
||||
.add_insights()
|
||||
.add_inkeep_docs()
|
||||
.add_billing()
|
||||
.add_insights_search()
|
||||
.add_session_summarization()
|
||||
.add_dashboard_creation()
|
||||
.compile(checkpointer=checkpointer)
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from django.conf import settings
|
||||
@@ -34,19 +34,22 @@ class InkeepDocsNode(RootNode): # Inheriting from RootNode to use the same mess
|
||||
|
||||
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
|
||||
"""Process the state and return documentation search results."""
|
||||
self.dispatcher.message(AssistantMessage(content="Checking PostHog documentation...", id=str(uuid4())))
|
||||
self.dispatcher.update("Checking PostHog documentation...")
|
||||
|
||||
messages = self._construct_messages(
|
||||
state.messages, state.root_conversation_start_id, state.root_tool_calls_count
|
||||
)
|
||||
|
||||
message: LangchainAIMessage = await self._get_model().ainvoke(messages, config)
|
||||
# NOTE: This is a hacky way to send these messages as part of the root tool call
|
||||
# Can't think of a better interface for this at the moment.
|
||||
self.dispatcher.set_as_root()
|
||||
should_continue = INKEEP_DATA_CONTINUATION_PHRASE in message.content
|
||||
|
||||
tool_prompt = "Checking PostHog documentation..."
|
||||
if should_continue:
|
||||
tool_prompt = "The documentation search results are provided in the next Assistant message.\n<system_reminder>Continue with the user's data request.</system_reminder>"
|
||||
|
||||
return PartialAssistantState(
|
||||
messages=[
|
||||
AssistantToolCallMessage(
|
||||
content="Checking PostHog documentation...", tool_call_id=state.root_tool_call_id, id=str(uuid4())
|
||||
),
|
||||
AssistantToolCallMessage(content=tool_prompt, tool_call_id=state.root_tool_call_id, id=str(uuid4())),
|
||||
AssistantMessage(content=message.content, id=str(uuid4())),
|
||||
],
|
||||
root_tool_call_id=None,
|
||||
@@ -116,16 +119,3 @@ class InkeepDocsNode(RootNode): # Inheriting from RootNode to use the same mess
|
||||
) -> list[BaseMessage]:
|
||||
# Original node has Anthropic messages, but Inkeep expects OpenAI messages
|
||||
return convert_to_openai_messages(conversation_window, tool_result_messages)
|
||||
|
||||
    def router(self, state: AssistantState) -> Literal["end", "root"]:
        last_message = state.messages[-1]
        if isinstance(last_message, AssistantMessage) and INKEEP_DATA_CONTINUATION_PHRASE in last_message.content:
            # The continuation phrase solution is a little weird, but it seems to be the best one for agentic
            # capabilities that I've found here. The alternatives that definitely don't work are:
            # 1. Using tool calls in this node - the Inkeep API only supports providing their own pre-defined tools
            #    (for including extra search metadata), nothing else
            # 2. Always going back to root, for root to judge whether to continue or not - GPT-4o is terrible at this,
            #    and I was unable to stop it from repeating the context from the last assistant message, i.e. the Inkeep
            #    output message (doesn't quite work to tell it to output an empty message, or to call an "end" tool)
            return "root"
        return "end"
|
||||
|
||||
@@ -67,6 +67,9 @@ class TestInkeepDocsNode(ClickhouseTestMixin, BaseTest):
|
||||
assert next_state is not None
|
||||
messages = cast(list, next_state.messages)
|
||||
self.assertEqual(len(messages), 2)
|
||||
# Tool call message should have the continuation prompt
|
||||
first_message = cast(AssistantToolCallMessage, messages[0])
|
||||
self.assertIn("Continue with the user's data request", first_message.content)
|
||||
second_message = cast(AssistantMessage, messages[1])
|
||||
self.assertEqual(second_message.content, response_with_continuation)
|
||||
|
||||
@@ -91,26 +94,6 @@ class TestInkeepDocsNode(ClickhouseTestMixin, BaseTest):
|
||||
self.assertIsInstance(messages[2], LangchainAIMessage)
|
||||
self.assertIsInstance(messages[3], LangchainHumanMessage)
|
||||
|
||||
def test_router_with_data_continuation(self):
|
||||
node = InkeepDocsNode(self.team, self.user)
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
HumanMessage(content="Explain PostHog trends, and show me an example trends insight"),
|
||||
AssistantMessage(content=f"Here's the documentation: XYZ.\n{INKEEP_DATA_CONTINUATION_PHRASE}"),
|
||||
]
|
||||
)
|
||||
self.assertEqual(node.router(state), "root") # Going back to root, so that the agent can continue with the task
|
||||
|
||||
def test_router_without_data_continuation(self):
|
||||
node = InkeepDocsNode(self.team, self.user)
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
HumanMessage(content="How do I use feature flags?"),
|
||||
AssistantMessage(content="Here's how to use feature flags..."),
|
||||
]
|
||||
)
|
||||
self.assertEqual(node.router(state), "end") # Ending
|
||||
|
||||
async def test_tool_call_id_handling(self):
|
||||
"""Test that tool_call_id is properly handled in both input and output states."""
|
||||
test_tool_call_id = str(uuid4())
|
||||
|
||||
@@ -4,7 +4,7 @@ import inspect
|
||||
import warnings
|
||||
from datetime import timedelta
|
||||
from functools import wraps
|
||||
from typing import Literal, Optional, TypedDict
|
||||
from typing import TYPE_CHECKING, Optional, TypedDict
|
||||
from uuid import uuid4
|
||||
|
||||
from django.db.models import Max
|
||||
@@ -16,7 +16,7 @@ from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.tools import tool
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from posthog.schema import AssistantMessage, AssistantToolCallMessage, VisualizationMessage
|
||||
from posthog.schema import AssistantToolCallMessage, VisualizationMessage
|
||||
|
||||
from posthog.exceptions_capture import capture_exception
|
||||
from posthog.models import Insight
|
||||
@@ -28,18 +28,19 @@ from ee.hogai.graph.shared_prompts import HYPERLINK_USAGE_INSTRUCTIONS
|
||||
from ee.hogai.utils.helpers import build_insight_url
|
||||
from ee.hogai.utils.types import AssistantState, PartialAssistantState
|
||||
from ee.hogai.utils.types.base import AssistantNodeName
|
||||
from ee.hogai.utils.types.composed import MaxNodeName
|
||||
|
||||
from .prompts import (
|
||||
EMPTY_DATABASE_ERROR_MESSAGE,
|
||||
ITERATIVE_SEARCH_SYSTEM_PROMPT,
|
||||
ITERATIVE_SEARCH_USER_PROMPT,
|
||||
NO_INSIGHTS_FOUND_MESSAGE,
|
||||
PAGINATION_INSTRUCTIONS_TEMPLATE,
|
||||
SEARCH_ERROR_INSTRUCTIONS,
|
||||
TOOL_BASED_EVALUATION_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ee.hogai.utils.types.composed import MaxNodeName
|
||||
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
# Silence Pydantic serializer warnings for creation of VisualizationMessage/Query execution
|
||||
warnings.filterwarnings("ignore", category=UserWarning, message=".*Pydantic serializer.*")
|
||||
@@ -105,6 +106,10 @@ class InsightDict(TypedDict):
|
||||
short_id: str
|
||||
|
||||
|
||||
class NoInsightsException(Exception):
|
||||
"""Exception indicating that the insight search cannot be done because the user does not have any insights."""
|
||||
|
||||
|
||||
class InsightSearchNode(AssistantNode):
|
||||
PAGE_SIZE = 500
|
||||
MAX_SEARCH_ITERATIONS = 6
|
||||
@@ -114,7 +119,7 @@ class InsightSearchNode(AssistantNode):
|
||||
MAX_SERIES_TO_PROCESS = 3
|
||||
|
||||
@property
|
||||
def node_name(self) -> MaxNodeName:
|
||||
def node_name(self) -> "MaxNodeName":
|
||||
return AssistantNodeName.INSIGHTS_SEARCH
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -133,9 +138,31 @@ class InsightSearchNode(AssistantNode):
|
||||
self._query_cache = {}
|
||||
self._insight_id_cache = {}
|
||||
|
||||
def _dispatch_update_message(self, content: str) -> None:
|
||||
"""Dispatch an update message to the assistant."""
|
||||
self.dispatcher.message(AssistantMessage(content=content))
|
||||
@timing_logger("InsightSearchNode.arun")
|
||||
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
|
||||
self.dispatcher.update("Searching for insights")
|
||||
search_query = state.search_insights_query
|
||||
self._current_iteration = 0
|
||||
|
||||
total_count = await self._get_total_insights_count()
|
||||
if total_count == 0:
|
||||
raise NoInsightsException
|
||||
|
||||
selected_insights = await self._search_insights_iteratively(search_query or "")
|
||||
logger.warning(
|
||||
f"{TIMING_LOG_PREFIX} search_insights_iteratively returned {len(selected_insights)} insights: {selected_insights}"
|
||||
)
|
||||
|
||||
if selected_insights:
|
||||
self.dispatcher.update(f"Evaluating {len(selected_insights)} insights to find the best match")
|
||||
else:
|
||||
self.dispatcher.update("No existing insights found, creating a new one")
|
||||
|
||||
evaluation_result = await self._evaluate_insights_with_tools(
|
||||
selected_insights, search_query or "", max_selections=1
|
||||
)
|
||||
|
||||
return self._handle_evaluation_result(evaluation_result, state)
|
||||
|
||||
def _create_page_reader_tool(self):
|
||||
"""Create tool for reading insights pages during agentic RAG loop."""
|
||||
@@ -185,36 +212,6 @@ class InsightSearchNode(AssistantNode):
|
||||
|
||||
return [select_insight, reject_all_insights]
|
||||
|
||||
@timing_logger("InsightSearchNode.arun")
|
||||
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
|
||||
self._dispatch_update_message("Searching for insights")
|
||||
search_query = state.search_insights_query
|
||||
try:
|
||||
self._current_iteration = 0
|
||||
|
||||
total_count = await self._get_total_insights_count()
|
||||
if total_count == 0:
|
||||
return self._handle_empty_database(state)
|
||||
|
||||
selected_insights = await self._search_insights_iteratively(search_query or "")
|
||||
logger.warning(
|
||||
f"{TIMING_LOG_PREFIX} search_insights_iteratively returned {len(selected_insights)} insights: {selected_insights}"
|
||||
)
|
||||
|
||||
if selected_insights:
|
||||
self._dispatch_update_message(f"Evaluating {len(selected_insights)} insights to find the best match")
|
||||
else:
|
||||
self._dispatch_update_message("No existing insights found, creating a new one")
|
||||
|
||||
evaluation_result = await self._evaluate_insights_with_tools(
|
||||
selected_insights, search_query or "", max_selections=1
|
||||
)
|
||||
|
||||
return self._handle_evaluation_result(evaluation_result, state)
|
||||
|
||||
except Exception as e:
|
||||
return self._handle_search_error(e, state)
|
||||
|
||||
@timing_logger("InsightSearchNode._get_insights_queryset")
|
||||
def _get_insights_queryset(self):
|
||||
"""Get Insight objects with latest view time annotated and cutoff date."""
|
||||
@@ -235,10 +232,6 @@ class InsightSearchNode(AssistantNode):
|
||||
self._total_insights_count = await self._get_insights_queryset().acount()
|
||||
return self._total_insights_count
|
||||
|
||||
def _handle_empty_database(self, state: AssistantState) -> PartialAssistantState:
|
||||
"""Handle the case when no insights exist in the database. (Rare edge-case but still possible)"""
|
||||
return self._create_error_response(EMPTY_DATABASE_ERROR_MESSAGE, state.root_tool_call_id)
|
||||
|
||||
def _handle_evaluation_result(self, evaluation_result: dict, state: AssistantState) -> PartialAssistantState:
|
||||
"""Process the evaluation result and return appropriate response."""
|
||||
if evaluation_result["should_use_existing"]:
|
||||
@@ -256,12 +249,12 @@ class InsightSearchNode(AssistantNode):
|
||||
|
||||
return PartialAssistantState(
|
||||
messages=[
|
||||
*evaluation_result["visualization_messages"],
|
||||
AssistantToolCallMessage(
|
||||
content=formatted_content,
|
||||
tool_call_id=state.root_tool_call_id or "unknown",
|
||||
id=str(uuid4()),
|
||||
),
|
||||
*evaluation_result["visualization_messages"],
|
||||
],
|
||||
selected_insight_ids=evaluation_result["selected_insights"],
|
||||
search_insights_query=None,
|
||||
@@ -284,16 +277,6 @@ class InsightSearchNode(AssistantNode):
|
||||
selected_insight_ids=None,
|
||||
)
|
||||
|
||||
@timing_logger("InsightSearchNode._handle_search_error")
|
||||
def _handle_search_error(self, e: Exception, state: AssistantState) -> PartialAssistantState:
|
||||
"""Handle exceptions during search process."""
|
||||
capture_exception(e)
|
||||
logger.error(f"{TIMING_LOG_PREFIX} Error in InsightSearchNode: {e}", exc_info=True)
|
||||
return self._create_error_response(
|
||||
SEARCH_ERROR_INSTRUCTIONS,
|
||||
state.root_tool_call_id,
|
||||
)
|
||||
|
||||
def _format_insight_for_display(self, insight: InsightDict) -> str:
|
||||
"""Format a single insight for display."""
|
||||
name = insight["name"] or insight["derived_name"] or "Unnamed"
|
||||
@@ -424,7 +407,7 @@ class InsightSearchNode(AssistantNode):
|
||||
selected_insights = []
|
||||
|
||||
for step in ["Searching through existing insights", "Analyzing available insights"]:
|
||||
self._dispatch_update_message(step)
|
||||
self.dispatcher.update(step)
|
||||
logger.warning(f"{TIMING_LOG_PREFIX} Starting iterative search, max_iterations={self._max_iterations}")
|
||||
|
||||
while self._current_iteration < self._max_iterations:
|
||||
@@ -446,7 +429,7 @@ class InsightSearchNode(AssistantNode):
|
||||
logger.warning(
|
||||
f"{TIMING_LOG_PREFIX} STALL POINT(?): Streamed 'Finding the most relevant insights' - about to fetch page content"
|
||||
)
|
||||
self._dispatch_update_message("Finding the most relevant insights")
|
||||
self.dispatcher.update("Finding the most relevant insights")
|
||||
|
||||
logger.warning(f"{TIMING_LOG_PREFIX} Fetching page content for page {page_num}")
|
||||
tool_response = await self._get_page_content_for_tool(page_num)
|
||||
@@ -465,15 +448,15 @@ class InsightSearchNode(AssistantNode):
|
||||
content = response.content if isinstance(response.content, str) else str(response.content)
|
||||
selected_insights = self._parse_insight_ids(content)
|
||||
if selected_insights:
|
||||
self._dispatch_update_message(f"Found {len(selected_insights)} relevant insights")
|
||||
self.dispatcher.update(f"Found {len(selected_insights)} relevant insights")
|
||||
else:
|
||||
self._dispatch_update_message("No matching insights found")
|
||||
self.dispatcher.update("No matching insights found")
|
||||
break
|
||||
|
||||
            except Exception as e:
                capture_exception(e)
                error_message = "Error during search"
                self._dispatch_update_message(error_message)
                self.dispatcher.update(error_message)
                break
|
||||
|
||||
return selected_insights
|
||||
@@ -692,7 +675,7 @@ class InsightSearchNode(AssistantNode):
|
||||
"""Create a VisualizationMessage to render the insight UI."""
|
||||
try:
|
||||
for step in ["Executing insight query...", "Processing query parameters", "Running data analysis"]:
|
||||
self._dispatch_update_message(step)
|
||||
self.dispatcher.update(step)
|
||||
|
||||
query_obj, _ = await self._process_insight_query(insight)
|
||||
|
||||
@@ -786,7 +769,7 @@ class InsightSearchNode(AssistantNode):
|
||||
async def _run_evaluation_loop(self, user_query: str, insights_summary: list[str], max_selections: int) -> None:
|
||||
"""Run the evaluation loop with LLM."""
|
||||
for step in ["Analyzing insights to match your request", "Comparing insights for best fit"]:
|
||||
self._dispatch_update_message(step)
|
||||
self.dispatcher.update(step)
|
||||
|
||||
tools = self._create_insight_evaluation_tools()
|
||||
llm_with_tools = self._model.bind_tools(tools)
|
||||
@@ -800,7 +783,7 @@ class InsightSearchNode(AssistantNode):
|
||||
if getattr(response, "tool_calls", None):
|
||||
# Only stream on first iteration to avoid noise
|
||||
if iteration == 0:
|
||||
self._dispatch_update_message("Making evaluation decisions")
|
||||
self.dispatcher.update("Making evaluation decisions")
|
||||
self._process_evaluation_tool_calls(response, messages, tools)
|
||||
else:
|
||||
break
|
||||
@@ -841,7 +824,7 @@ class InsightSearchNode(AssistantNode):
|
||||
num_insights = len(self._evaluation_selections)
|
||||
insight_word = "insight" if num_insights == 1 else "insights"
|
||||
|
||||
self._dispatch_update_message(f"Perfect! Found {num_insights} suitable {insight_word}")
|
||||
self.dispatcher.update(f"Perfect! Found {num_insights} suitable {insight_word}")
|
||||
|
||||
for _, selection in self._evaluation_selections.items():
|
||||
insight = selection["insight"]
|
||||
@@ -871,7 +854,7 @@ class InsightSearchNode(AssistantNode):
|
||||
|
||||
async def _create_rejection_result(self) -> dict:
|
||||
"""Create result for when all insights are rejected."""
|
||||
self._dispatch_update_message("Will create a custom insight tailored to your request")
|
||||
self.dispatcher.update("Will create a custom insight tailored to your request")
|
||||
|
||||
return {
|
||||
"should_use_existing": False,
|
||||
@@ -880,9 +863,6 @@ class InsightSearchNode(AssistantNode):
|
||||
"visualization_messages": [],
|
||||
}
|
||||
|
||||
def router(self, state: AssistantState) -> Literal["root"]:
|
||||
return "root"
|
||||
|
||||
@property
|
||||
def _model(self):
|
||||
return ChatOpenAI(
|
||||
|
||||
@@ -36,7 +36,3 @@ Instructions:
|
||||
NO_INSIGHTS_FOUND_MESSAGE = (
|
||||
"No existing insights found matching your query. Creating a new insight based on your request."
|
||||
)
|
||||
|
||||
SEARCH_ERROR_INSTRUCTIONS = "INSTRUCTIONS: Tell the user that you encountered an issue while searching for insights and suggest they try again with a different search term."
|
||||
|
||||
EMPTY_DATABASE_ERROR_MESSAGE = "No insights found in the database."
|
||||
|
||||
@@ -23,7 +23,7 @@ from posthog.schema import (
|
||||
|
||||
from posthog.models import Insight, InsightViewed
|
||||
|
||||
from ee.hogai.graph.insights.nodes import InsightDict, InsightSearchNode
|
||||
from ee.hogai.graph.insights.nodes import InsightDict, InsightSearchNode, NoInsightsException
|
||||
from ee.hogai.utils.types import AssistantState, PartialAssistantState
|
||||
from ee.models.assistant import Conversation
|
||||
|
||||
@@ -115,11 +115,6 @@ class TestInsightSearchNode(BaseTest):
|
||||
short_id=insight.short_id,
|
||||
)
|
||||
|
||||
def test_router_returns_root(self):
|
||||
"""Test that router returns 'root' as expected."""
|
||||
result = self.node.router(AssistantState(messages=[]))
|
||||
self.assertEqual(result, "root")
|
||||
|
||||
async def test_load_insights_page(self):
|
||||
"""Test loading paginated insights from database."""
|
||||
# Load first page
|
||||
@@ -350,12 +345,6 @@ class TestInsightSearchNode(BaseTest):
|
||||
# Should return empty list when LLM fails to select anything
|
||||
self.assertEqual(len(result), 0)
|
||||
|
||||
def test_router_always_returns_root(self):
|
||||
"""Test that router always returns 'root'."""
|
||||
state = AssistantState(messages=[], root_tool_insight_plan="some plan", search_insights_query=None)
|
||||
result = self.node.router(state)
|
||||
self.assertEqual(result, "root")
|
||||
|
||||
async def test_evaluation_flow_returns_creation_when_no_suitable_insights(self):
|
||||
"""Test that when evaluation returns NO, the system transitions to creation flow."""
|
||||
selected_insights = [self.insight1.id, self.insight2.id]
|
||||
@@ -411,21 +400,6 @@ class TestInsightSearchNode(BaseTest):
|
||||
"root_tool_insight_plan should be set to search_query",
|
||||
)
|
||||
|
||||
# Test router behavior with the returned state
|
||||
# Create a new state that simulates what happens after this node runs
|
||||
post_evaluation_state = AssistantState(
|
||||
messages=state.messages,
|
||||
root_tool_insight_plan=search_query, # This gets set to search_query
|
||||
search_insights_query=None, # This gets cleared
|
||||
)
|
||||
|
||||
router_result = self.node.router(post_evaluation_state)
|
||||
self.assertEqual(
|
||||
router_result,
|
||||
"root",
|
||||
"Router should always return root",
|
||||
)
|
||||
|
||||
# Verify that _evaluate_insights_with_tools was called with the search_query
|
||||
mock_evaluate.assert_called_once_with(selected_insights, search_query, max_selections=1)
|
||||
|
||||
@@ -481,7 +455,7 @@ class TestInsightSearchNode(BaseTest):
|
||||
self.assertIsNone(result.root_tool_call_id)
|
||||
|
||||
def test_run_with_no_insights(self):
|
||||
"""Test arun method when no insights exist."""
|
||||
"""Test arun method when no insights exist - should raise NoInsightsException."""
|
||||
# Clear all insights (done outside async context)
|
||||
InsightViewed.objects.all().delete()
|
||||
Insight.objects.all().delete()
|
||||
@@ -497,13 +471,10 @@ class TestInsightSearchNode(BaseTest):
|
||||
async def async_test():
|
||||
# Mock the database calls that happen in async context
|
||||
with patch.object(self.node, "_get_total_insights_count", return_value=0):
|
||||
result = await self.node.arun(state, {"configurable": {"thread_id": str(conversation.id)}})
|
||||
return result
|
||||
await self.node.arun(state, {"configurable": {"thread_id": str(conversation.id)}})
|
||||
|
||||
result = asyncio.run(async_test())
|
||||
self.assertIsInstance(result, PartialAssistantState)
|
||||
self.assertEqual(len(result.messages), 1)
|
||||
self.assertIn("No insights found in the database", result.messages[0].content)
|
||||
with self.assertRaises(NoInsightsException):
|
||||
asyncio.run(async_test())
|
||||
|
||||
async def test_team_filtering(self):
|
||||
"""Test that insights are filtered by team."""
|
||||
|
||||
155
ee/hogai/graph/insights_graph/graph.py
Normal file
@@ -0,0 +1,155 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
|
||||
from ee.hogai.graph.base import BaseAssistantGraph
|
||||
from ee.hogai.utils.types.base import AssistantGraphName, AssistantNodeName, AssistantState, PartialAssistantState
|
||||
|
||||
from ..funnels.nodes import FunnelGeneratorNode, FunnelGeneratorToolsNode
|
||||
from ..query_executor.nodes import QueryExecutorNode
|
||||
from ..query_planner.nodes import QueryPlannerNode, QueryPlannerToolsNode
|
||||
from ..rag.nodes import InsightRagContextNode
|
||||
from ..retention.nodes import RetentionGeneratorNode, RetentionGeneratorToolsNode
|
||||
from ..sql.nodes import SQLGeneratorNode, SQLGeneratorToolsNode
|
||||
from ..trends.nodes import TrendsGeneratorNode, TrendsGeneratorToolsNode
|
||||
|
||||
|
||||
class InsightsGraph(BaseAssistantGraph[AssistantState, PartialAssistantState]):
|
||||
@property
|
||||
def graph_name(self) -> AssistantGraphName:
|
||||
return AssistantGraphName.INSIGHTS
|
||||
|
||||
@property
|
||||
def state_type(self) -> type[AssistantState]:
|
||||
return AssistantState
|
||||
|
||||
def add_rag_context(self):
|
||||
self._has_start_node = True
|
||||
retriever = InsightRagContextNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.INSIGHT_RAG_CONTEXT, retriever)
|
||||
self._graph.add_edge(AssistantNodeName.START, AssistantNodeName.INSIGHT_RAG_CONTEXT)
|
||||
self._graph.add_edge(AssistantNodeName.INSIGHT_RAG_CONTEXT, AssistantNodeName.QUERY_PLANNER)
|
||||
return self
|
||||
|
||||
def add_trends_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR):
|
||||
trends_generator = TrendsGeneratorNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.TRENDS_GENERATOR, trends_generator)
|
||||
|
||||
trends_generator_tools = TrendsGeneratorToolsNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.TRENDS_GENERATOR_TOOLS, trends_generator_tools)
|
||||
|
||||
self._graph.add_edge(AssistantNodeName.TRENDS_GENERATOR_TOOLS, AssistantNodeName.TRENDS_GENERATOR)
|
||||
self._graph.add_conditional_edges(
|
||||
AssistantNodeName.TRENDS_GENERATOR,
|
||||
trends_generator.router,
|
||||
path_map={
|
||||
"tools": AssistantNodeName.TRENDS_GENERATOR_TOOLS,
|
||||
"next": next_node,
|
||||
},
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def add_funnel_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR):
|
||||
funnel_generator = FunnelGeneratorNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.FUNNEL_GENERATOR, funnel_generator)
|
||||
|
||||
funnel_generator_tools = FunnelGeneratorToolsNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.FUNNEL_GENERATOR_TOOLS, funnel_generator_tools)
|
||||
|
||||
self._graph.add_edge(AssistantNodeName.FUNNEL_GENERATOR_TOOLS, AssistantNodeName.FUNNEL_GENERATOR)
|
||||
self._graph.add_conditional_edges(
|
||||
AssistantNodeName.FUNNEL_GENERATOR,
|
||||
funnel_generator.router,
|
||||
path_map={
|
||||
"tools": AssistantNodeName.FUNNEL_GENERATOR_TOOLS,
|
||||
"next": next_node,
|
||||
},
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def add_retention_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR):
|
||||
retention_generator = RetentionGeneratorNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.RETENTION_GENERATOR, retention_generator)
|
||||
|
||||
retention_generator_tools = RetentionGeneratorToolsNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.RETENTION_GENERATOR_TOOLS, retention_generator_tools)
|
||||
|
||||
self._graph.add_edge(AssistantNodeName.RETENTION_GENERATOR_TOOLS, AssistantNodeName.RETENTION_GENERATOR)
|
||||
self._graph.add_conditional_edges(
|
||||
AssistantNodeName.RETENTION_GENERATOR,
|
||||
retention_generator.router,
|
||||
path_map={
|
||||
"tools": AssistantNodeName.RETENTION_GENERATOR_TOOLS,
|
||||
"next": next_node,
|
||||
},
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def add_query_planner(
|
||||
self,
|
||||
path_map: Optional[
|
||||
dict[Literal["trends", "funnel", "retention", "sql", "continue", "end"], AssistantNodeName]
|
||||
] = None,
|
||||
):
|
||||
query_planner = QueryPlannerNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.QUERY_PLANNER, query_planner)
|
||||
self._graph.add_edge(AssistantNodeName.QUERY_PLANNER, AssistantNodeName.QUERY_PLANNER_TOOLS)
|
||||
|
||||
query_planner_tools = QueryPlannerToolsNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.QUERY_PLANNER_TOOLS, query_planner_tools)
|
||||
self._graph.add_conditional_edges(
|
||||
AssistantNodeName.QUERY_PLANNER_TOOLS,
|
||||
query_planner_tools.router,
|
||||
path_map=path_map # type: ignore
|
||||
or {
|
||||
"continue": AssistantNodeName.QUERY_PLANNER,
|
||||
"trends": AssistantNodeName.TRENDS_GENERATOR,
|
||||
"funnel": AssistantNodeName.FUNNEL_GENERATOR,
|
||||
"retention": AssistantNodeName.RETENTION_GENERATOR,
|
||||
"sql": AssistantNodeName.SQL_GENERATOR,
|
||||
"end": AssistantNodeName.END,
|
||||
},
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def add_sql_generator(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR):
|
||||
sql_generator = SQLGeneratorNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.SQL_GENERATOR, sql_generator)
|
||||
|
||||
sql_generator_tools = SQLGeneratorToolsNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.SQL_GENERATOR_TOOLS, sql_generator_tools)
|
||||
|
||||
self._graph.add_edge(AssistantNodeName.SQL_GENERATOR_TOOLS, AssistantNodeName.SQL_GENERATOR)
|
||||
self._graph.add_conditional_edges(
|
||||
AssistantNodeName.SQL_GENERATOR,
|
||||
sql_generator.router,
|
||||
path_map={
|
||||
"tools": AssistantNodeName.SQL_GENERATOR_TOOLS,
|
||||
"next": next_node,
|
||||
},
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def add_query_executor(self, next_node: AssistantNodeName = AssistantNodeName.END):
|
||||
query_executor_node = QueryExecutorNode(self._team, self._user)
|
||||
self.add_node(AssistantNodeName.QUERY_EXECUTOR, query_executor_node)
|
||||
self._graph.add_edge(AssistantNodeName.QUERY_EXECUTOR, next_node)
|
||||
return self
|
||||
|
||||
def add_query_creation_flow(self, next_node: AssistantNodeName = AssistantNodeName.QUERY_EXECUTOR):
|
||||
"""Add all nodes and edges EXCEPT query execution."""
|
||||
return (
|
||||
self.add_rag_context()
|
||||
.add_query_planner()
|
||||
.add_trends_generator(next_node=next_node)
|
||||
.add_funnel_generator(next_node=next_node)
|
||||
.add_retention_generator(next_node=next_node)
|
||||
.add_sql_generator(next_node=next_node)
|
||||
)
|
||||
|
||||
def compile_full_graph(self, checkpointer: DjangoCheckpointer | None = None):
|
||||
return self.add_query_creation_flow().add_query_executor().compile(checkpointer=checkpointer)
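For orientation, here is a minimal usage sketch of the new `InsightsGraph`, mirroring the task-execution code further down in this diff. The input-state fields shown are assumptions for illustration only; the real callers build their own state with a per-task tool call id.

```python
# Hypothetical sketch, not part of the diff: compile the insights subgraph and stream it,
# mirroring WithInsightCreationTaskExecution below.
from ee.hogai.graph.insights_graph.graph import InsightsGraph
from ee.hogai.utils.types import AssistantState


async def run_insights_graph(team, user, query_description: str) -> None:
    graph = InsightsGraph(team, user).compile_full_graph()
    # Illustrative input state; exact fields may differ in the real call sites.
    input_state = AssistantState(messages=[], root_tool_insight_plan=query_description)
    async for chunk in graph.astream(input_state, config={"configurable": {"thread_id": "example"}}):
        print(chunk)
```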
|
||||
@@ -1,5 +1,5 @@
|
||||
import datetime
|
||||
from abc import ABC
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, get_args, get_origin
|
||||
from uuid import UUID
|
||||
|
||||
@@ -7,15 +7,15 @@ from django.utils import timezone
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from posthog.schema import AssistantMessage, CurrencyCode
|
||||
from posthog.schema import CurrencyCode
|
||||
|
||||
from posthog.event_usage import groups
|
||||
from posthog.models import Team
|
||||
from posthog.models.action.action import Action
|
||||
from posthog.models.user import User
|
||||
|
||||
from ee.hogai.utils.dispatcher import AssistantDispatcher
|
||||
from ee.hogai.utils.types.base import BaseStateWithIntermediateSteps
|
||||
from ee.hogai.utils.dispatcher import AssistantDispatcher, create_dispatcher_from_config
|
||||
from ee.hogai.utils.types.base import BaseStateWithIntermediateSteps, NodePath
|
||||
from ee.models import Conversation, CoreMemory
|
||||
|
||||
|
||||
@@ -194,4 +194,33 @@ class TaxonomyUpdateDispatcherNodeMixin:
|
||||
content = "Picking relevant events and properties"
|
||||
if substeps:
|
||||
content = substeps[-1]
|
||||
self.dispatcher.message(AssistantMessage(content=content))
|
||||
self.dispatcher.update(content)
|
||||
|
||||
|
||||
class AssistantDispatcherMixin(ABC):
|
||||
_node_path: tuple[NodePath, ...]
|
||||
_config: RunnableConfig | None
|
||||
_dispatcher: AssistantDispatcher | None = None
|
||||
|
||||
@property
|
||||
def node_path(self) -> tuple[NodePath, ...]:
|
||||
return self._node_path
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def node_name(self) -> str: ...
|
||||
|
||||
@property
|
||||
def tool_call_id(self) -> str:
|
||||
parent_tool_call_id = next((path.tool_call_id for path in reversed(self._node_path) if path.tool_call_id), None)
|
||||
if not parent_tool_call_id:
|
||||
raise ValueError("No tool call ID found")
|
||||
return parent_tool_call_id
|
||||
|
||||
@property
|
||||
def dispatcher(self) -> AssistantDispatcher:
|
||||
"""Create a dispatcher for this node"""
|
||||
if self._dispatcher:
|
||||
return self._dispatcher
|
||||
self._dispatcher = create_dispatcher_from_config(self._config or {}, self.node_path)
|
||||
return self._dispatcher
|
||||
|
||||
@@ -52,7 +52,7 @@ class WithInsightCreationTaskExecution:
|
||||
The type allows None for compatibility with the base class.
|
||||
"""
|
||||
# Import here to avoid circular dependency
|
||||
from ee.hogai.graph.graph import InsightsAssistantGraph
|
||||
from ee.hogai.graph.insights_graph.graph import InsightsGraph
|
||||
|
||||
task = cast(AssistantToolCall, input_dict["task"])
|
||||
artifacts = input_dict["artifacts"]
|
||||
@@ -60,7 +60,7 @@ class WithInsightCreationTaskExecution:
|
||||
|
||||
self._current_task_id = task.id
|
||||
|
||||
# This is needed by the InsightsAssistantGraph to return an AssistantToolCallMessage
|
||||
# This is needed by the InsightsGraph to return an AssistantToolCallMessage
|
||||
task_tool_call_id = f"task_{uuid.uuid4().hex[:8]}"
|
||||
query = task.args["query_description"]
|
||||
|
||||
@@ -77,9 +77,7 @@ class WithInsightCreationTaskExecution:
|
||||
)
|
||||
|
||||
subgraph_result_messages: list[AssistantMessageUnion] = []
|
||||
assistant_graph = InsightsAssistantGraph(
|
||||
self._team, self._user, tool_call_id=self._parent_tool_call_id
|
||||
).compile_full_graph()
|
||||
assistant_graph = InsightsGraph(self._team, self._user).compile_full_graph()
|
||||
try:
|
||||
async for chunk in assistant_graph.astream(
|
||||
input_state,
|
||||
@@ -156,7 +154,7 @@ class WithInsightCreationTaskExecution:
|
||||
if isinstance(message, VisualizationMessage) and message.id:
|
||||
artifact = InsightArtifact(
|
||||
task_id=tool_call.id,
|
||||
id=None, # The InsightsAssistantGraph does not create the insight objects
|
||||
id=None, # The InsightsGraph does not create the insight objects
|
||||
content="",
|
||||
query=cast(AnyAssistantGeneratedQuery, message.answer),
|
||||
)
|
||||
@@ -194,9 +192,7 @@ class WithInsightSearchTaskExecution:
|
||||
)
|
||||
|
||||
try:
|
||||
result = await InsightSearchNode(self._team, self._user, tool_call_id=self._parent_tool_call_id).arun(
|
||||
input_state, config
|
||||
)
|
||||
result = await InsightSearchNode(self._team, self._user).arun(input_state, config)
|
||||
|
||||
if not result or not result.messages:
|
||||
logger.warning("Task failed: no messages received from node executor", task_id=task.id)
|
||||
|
||||
@@ -76,6 +76,21 @@ class QueryExecutorNode(AssistantNode):
|
||||
)
|
||||
return PartialAssistantState(messages=[FailureMessage(content=str(err), id=str(uuid4()))])
|
||||
|
||||
return PartialAssistantState(
|
||||
messages=[
|
||||
AssistantToolCallMessage(
|
||||
content=self._format_query_result(viz_message, results, example_prompt),
|
||||
id=str(uuid4()),
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
],
|
||||
root_tool_call_id=None,
|
||||
root_tool_insight_plan=None,
|
||||
root_tool_insight_type=None,
|
||||
rag_context=None,
|
||||
)
|
||||
|
||||
def _format_query_result(self, viz_message: VisualizationMessage, results: str, example_prompt: str) -> str:
|
||||
query_result = QUERY_RESULTS_PROMPT.format(
|
||||
query_kind=viz_message.answer.kind,
|
||||
results=results,
|
||||
@@ -84,20 +99,10 @@ class QueryExecutorNode(AssistantNode):
|
||||
project_timezone=self.project_timezone,
|
||||
currency=self.project_currency,
|
||||
)
|
||||
|
||||
formatted_query_result = f"{example_prompt}\n\n{query_result}"
|
||||
if isinstance(viz_message.answer, AssistantHogQLQuery):
|
||||
formatted_query_result = f"{example_prompt}\n\n{SQL_QUERY_PROMPT.format(query=viz_message.answer.query)}\n\n{formatted_query_result}"
|
||||
|
||||
return PartialAssistantState(
|
||||
messages=[
|
||||
AssistantToolCallMessage(content=formatted_query_result, id=str(uuid4()), tool_call_id=tool_call_id)
|
||||
],
|
||||
root_tool_call_id=None,
|
||||
root_tool_insight_plan=None,
|
||||
root_tool_insight_type=None,
|
||||
rag_context=None,
|
||||
)
|
||||
return formatted_query_result
|
||||
|
||||
def _get_example_prompt(self, viz_message: VisualizationMessage) -> str:
|
||||
if isinstance(viz_message.answer, AssistantTrendsQuery | TrendsQuery):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from posthog.schema import AssistantMessage, AssistantRetentionQuery
|
||||
from posthog.schema import AssistantRetentionQuery
|
||||
|
||||
from ee.hogai.utils.types import AssistantState, PartialAssistantState
|
||||
from ee.hogai.utils.types.base import AssistantNodeName
|
||||
@@ -25,7 +25,7 @@ class RetentionGeneratorNode(SchemaGeneratorNode[AssistantRetentionQuery]):
|
||||
return AssistantNodeName.RETENTION_GENERATOR
|
||||
|
||||
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
|
||||
self.dispatcher.message(AssistantMessage(content="Creating retention query"))
|
||||
self.dispatcher.update("Creating retention query")
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", RETENTION_SYSTEM_PROMPT),
|
||||
|
||||
@@ -3,27 +3,23 @@ from collections.abc import Awaitable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Literal, TypeVar, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import structlog
|
||||
import posthoganalytics
|
||||
from langchain_core.messages import (
|
||||
AIMessage as LangchainAIMessage,
|
||||
BaseMessage,
|
||||
HumanMessage as LangchainHumanMessage,
|
||||
ToolCall,
|
||||
ToolMessage as LangchainToolMessage,
|
||||
)
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.errors import NodeInterrupt
|
||||
from langgraph.types import Send
|
||||
from posthoganalytics import capture_exception
|
||||
from pydantic import BaseModel
|
||||
|
||||
from posthog.schema import (
|
||||
AssistantMessage,
|
||||
AssistantTool,
|
||||
AssistantToolCallMessage,
|
||||
ContextMessage,
|
||||
FailureMessage,
|
||||
HumanMessage,
|
||||
)
|
||||
from posthog.schema import AssistantMessage, AssistantToolCallMessage, ContextMessage, FailureMessage, HumanMessage
|
||||
|
||||
from posthog.models import Team, User
|
||||
|
||||
@@ -36,8 +32,14 @@ from ee.hogai.tool import ToolMessagesArtifact
|
||||
from ee.hogai.utils.anthropic import add_cache_control, convert_to_anthropic_messages
|
||||
from ee.hogai.utils.helpers import convert_tool_messages_to_dict, normalize_ai_message
|
||||
from ee.hogai.utils.prompt import format_prompt_string
|
||||
from ee.hogai.utils.types import AssistantMessageUnion, AssistantNodeName, AssistantState, InsightQuery
|
||||
from ee.hogai.utils.types.base import PartialAssistantState, ReplaceMessages
|
||||
from ee.hogai.utils.types.base import (
|
||||
AssistantMessageUnion,
|
||||
AssistantNodeName,
|
||||
AssistantState,
|
||||
NodePath,
|
||||
PartialAssistantState,
|
||||
ReplaceMessages,
|
||||
)
|
||||
from ee.hogai.utils.types.composed import MaxNodeName
|
||||
|
||||
from .prompts import (
|
||||
@@ -51,13 +53,13 @@ from .prompts import (
|
||||
ROOT_TOOL_DOES_NOT_EXIST,
|
||||
)
|
||||
from .tools import (
|
||||
CreateAndQueryInsightTool,
|
||||
CreateDashboardTool,
|
||||
ReadDataTool,
|
||||
ReadTaxonomyTool,
|
||||
SearchTool,
|
||||
SessionSummarizationTool,
|
||||
TodoWriteTool,
|
||||
create_and_query_insight,
|
||||
create_dashboard,
|
||||
session_summarization,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -66,24 +68,14 @@ if TYPE_CHECKING:
|
||||
SLASH_COMMAND_INIT = "/init"
|
||||
SLASH_COMMAND_REMEMBER = "/remember"
|
||||
|
||||
RouteName = Literal[
|
||||
"insights",
|
||||
"root",
|
||||
"end",
|
||||
"search_documentation",
|
||||
"memory_onboarding",
|
||||
"insights_search",
|
||||
"billing",
|
||||
"session_summarization",
|
||||
"create_dashboard",
|
||||
]
|
||||
|
||||
|
||||
RootMessageUnion = HumanMessage | AssistantMessage | FailureMessage | AssistantToolCallMessage | ContextMessage
|
||||
T = TypeVar("T", RootMessageUnion, BaseMessage)
|
||||
|
||||
RootTool = Union[type[BaseModel], "MaxTool"]
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class RootNode(AssistantNode):
|
||||
MAX_TOOL_CALLS = 24
|
||||
@@ -95,8 +87,8 @@ class RootNode(AssistantNode):
|
||||
Determines the thinking configuration for the model.
|
||||
"""
|
||||
|
||||
def __init__(self, team: Team, user: User):
|
||||
super().__init__(team, user)
|
||||
def __init__(self, team: Team, user: User, node_path: tuple[NodePath, ...] | None = None):
|
||||
super().__init__(team, user, node_path)
|
||||
self._window_manager = AnthropicConversationCompactionManager()
|
||||
|
||||
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
|
||||
@@ -168,12 +160,25 @@ class RootNode(AssistantNode):
|
||||
if messages_to_replace:
|
||||
new_messages = ReplaceMessages([*messages_to_replace, assistant_message])
|
||||
|
||||
# Set new tool call count
|
||||
tool_call_count = (state.root_tool_calls_count or 0) + 1 if assistant_message.tool_calls else None
|
||||
|
||||
return PartialAssistantState(
|
||||
messages=new_messages,
|
||||
root_tool_calls_count=tool_call_count,
|
||||
root_conversation_start_id=window_id,
|
||||
start_id=start_id,
|
||||
)
|
||||
|
||||
    def router(self, state: AssistantState):
        last_message = state.messages[-1]
        if not isinstance(last_message, AssistantMessage) or not last_message.tool_calls:
            return AssistantNodeName.END
        return [
            Send(AssistantNodeName.ROOT_TOOLS, state.model_copy(update={"root_tool_call_id": tool_call.id}))
            for tool_call in last_message.tool_calls
        ]
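This router is the crux of the parallel tool-call change: instead of returning a single route name, it returns one `Send` per tool call, so LangGraph schedules `ROOT_TOOLS` once per call with a state copy carrying that call's id. Below is a self-contained sketch of the same fan-out pattern, with hypothetical node and state names that are not part of this diff.

```python
# Standalone illustration of LangGraph's Send fan-out (hypothetical names, not from this diff).
import operator
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send


class FanOutState(TypedDict):
    items: list[str]
    results: Annotated[list[str], operator.add]  # reducer merges the parallel writes


def fan_out(state: FanOutState) -> list[Send]:
    # One Send per item: the "worker" node is scheduled once per Send and the runs happen in parallel.
    return [Send("worker", {"items": [item], "results": []}) for item in state["items"]]


def worker(state: FanOutState) -> dict:
    return {"results": [f"handled {state['items'][0]}"]}


builder = StateGraph(FanOutState)
builder.add_node("worker", worker)
builder.add_conditional_edges(START, fan_out, ["worker"])
builder.add_edge("worker", END)
graph = builder.compile()
# graph.invoke({"items": ["a", "b"], "results": []}) -> "results" contains both handled items
```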
|
||||
|
||||
@property
|
||||
def node_name(self) -> MaxNodeName:
|
||||
return AssistantNodeName.ROOT
|
||||
@@ -226,11 +231,35 @@ class RootNode(AssistantNode):
|
||||
        if self._is_hard_limit_reached(state.root_tool_calls_count):
            return base_model

        return base_model.bind_tools(tools, parallel_tool_calls=False)
        return base_model.bind_tools(tools, parallel_tool_calls=True)
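For context, the switch from `parallel_tool_calls=False` to `True` lets the model emit several tool calls in a single `AIMessage`, which is exactly what the `Send`-based router above fans out over. A hedged sketch of the binding, assuming an OpenAI-style chat model and a made-up tool schema:

```python
# Illustrative only: with parallel_tool_calls=True a single response may carry multiple
# entries in message.tool_calls, e.g.
#   [{"id": "tool-1", "name": "search", "args": {"query": "signups"}},
#    {"id": "tool-2", "name": "create_and_query_insight", "args": {"query_description": "weekly signups"}}]
from langchain_openai import ChatOpenAI
from pydantic import BaseModel


class SearchArgs(BaseModel):
    """Hypothetical tool schema for the sketch."""

    query: str


model = ChatOpenAI(model="gpt-4o").bind_tools([SearchArgs], parallel_tool_calls=True)
```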
|
||||
|
||||
async def _get_tools(self, state: AssistantState, config: RunnableConfig) -> list[RootTool]:
|
||||
from ee.hogai.tool import get_contextual_tool_class
|
||||
|
||||
# Static toolkit
|
||||
default_tools: list[type[MaxTool]] = [
|
||||
ReadTaxonomyTool,
|
||||
ReadDataTool,
|
||||
SearchTool,
|
||||
TodoWriteTool,
|
||||
]
|
||||
|
||||
        # The contextual insights tool overrides the static one, so only add the static tool when the contextual tool isn't injected.
        if not CreateAndQueryInsightTool.is_editing_mode(self.context_manager):
            default_tools.append(CreateAndQueryInsightTool)
|
||||
|
||||
# Add session summarization tool if enabled
|
||||
if self._has_session_summarization_feature_flag():
|
||||
default_tools.append(SessionSummarizationTool)
|
||||
|
||||
# Add other lower-priority tools
|
||||
default_tools.extend(
|
||||
[
|
||||
CreateDashboardTool,
|
||||
]
|
||||
)
|
||||
|
||||
# Processed tools
|
||||
available_tools: list[RootTool] = []
|
||||
|
||||
# Initialize the static toolkit
|
||||
@@ -238,37 +267,14 @@ class RootNode(AssistantNode):
|
||||
# This is just to bound the tools to the model
|
||||
dynamic_tools = (
|
||||
tool_class.create_tool_class(
|
||||
team=self._team,
|
||||
user=self._user,
|
||||
tool_call_id="",
|
||||
state=state,
|
||||
config=config,
|
||||
context_manager=self.context_manager,
|
||||
)
|
||||
for tool_class in (
|
||||
ReadTaxonomyTool,
|
||||
ReadDataTool,
|
||||
SearchTool,
|
||||
TodoWriteTool,
|
||||
team=self._team, user=self._user, state=state, config=config, context_manager=self.context_manager
|
||||
)
|
||||
for tool_class in default_tools
|
||||
)
|
||||
available_tools.extend(await asyncio.gather(*dynamic_tools))
|
||||
|
||||
# Insights tool
|
||||
tool_names = self.context_manager.get_contextual_tools().keys()
|
||||
is_editing_insight = AssistantTool.EDIT_CURRENT_INSIGHT in tool_names
|
||||
if not is_editing_insight:
|
||||
            # This is the default tool, which can be overridden by a MaxTool-based tool with the same name
            available_tools.append(create_and_query_insight)
|
||||
|
||||
# Check if session summarization is enabled for the user
|
||||
if self._has_session_summarization_feature_flag():
|
||||
available_tools.append(session_summarization)
|
||||
|
||||
# Dashboard creation tool
|
||||
available_tools.append(create_dashboard)
|
||||
|
||||
# Inject contextual tools
|
||||
tool_names = self.context_manager.get_contextual_tools().keys()
|
||||
awaited_contextual_tools: list[Awaitable[RootTool]] = []
|
||||
for tool_name in tool_names:
|
||||
ContextualMaxToolClass = get_contextual_tool_class(tool_name)
|
||||
@@ -278,7 +284,6 @@ class RootNode(AssistantNode):
|
||||
ContextualMaxToolClass.create_tool_class(
|
||||
team=self._team,
|
||||
user=self._user,
|
||||
tool_call_id="",
|
||||
state=state,
|
||||
config=config,
|
||||
context_manager=self.context_manager,
|
||||
@@ -366,51 +371,23 @@ class RootNodeTools(AssistantNode):
|
||||
def node_name(self) -> MaxNodeName:
|
||||
return AssistantNodeName.ROOT_TOOLS
|
||||
|
||||
    async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
    async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
        last_message = state.messages[-1]
        if not isinstance(last_message, AssistantMessage) or not last_message.tool_calls:
            # Reset tools.
            return PartialAssistantState(root_tool_calls_count=0)

        tool_call_count = state.root_tool_calls_count or 0
        reset_state = PartialAssistantState(root_tool_call_id=None)
        # Should never happen, but just in case.
        if not isinstance(last_message, AssistantMessage) or not last_message.id or not state.root_tool_call_id:
            return reset_state

        tools_calls = last_message.tool_calls
        if len(tools_calls) != 1:
            raise ValueError("Expected exactly one tool call.")

        tool_names = self.context_manager.get_contextual_tools().keys()
        is_editing_insight = AssistantTool.EDIT_CURRENT_INSIGHT in tool_names
        tool_call = tools_calls[0]
        # Find the current tool call in the last message.
        tool_call = next(
            (tool_call for tool_call in last_message.tool_calls or [] if tool_call.id == state.root_tool_call_id), None
        )
        if not tool_call:
            return reset_state
|
||||
|
||||
from ee.hogai.tool import get_contextual_tool_class
|
||||
|
||||
if tool_call.name == "create_and_query_insight" and not is_editing_insight:
|
||||
return PartialAssistantState(
|
||||
root_tool_call_id=tool_call.id,
|
||||
root_tool_insight_plan=tool_call.args["query_description"],
|
||||
root_tool_calls_count=tool_call_count + 1,
|
||||
)
|
||||
if tool_call.name == "session_summarization":
|
||||
return PartialAssistantState(
|
||||
root_tool_call_id=tool_call.id,
|
||||
session_summarization_query=tool_call.args["session_summarization_query"],
|
||||
# Safety net in case the argument is missing to avoid raising exceptions internally
|
||||
should_use_current_filters=tool_call.args.get("should_use_current_filters", False),
|
||||
summary_title=tool_call.args.get("summary_title"),
|
||||
root_tool_calls_count=tool_call_count + 1,
|
||||
)
|
||||
if tool_call.name == "create_dashboard":
|
||||
raw_queries = tool_call.args["search_insights_queries"]
|
||||
search_insights_queries = [InsightQuery.model_validate(query) for query in raw_queries]
|
||||
|
||||
return PartialAssistantState(
|
||||
root_tool_call_id=tool_call.id,
|
||||
dashboard_name=tool_call.args.get("dashboard_name"),
|
||||
search_insights_queries=search_insights_queries,
|
||||
root_tool_calls_count=tool_call_count + 1,
|
||||
)
|
||||
|
||||
# MaxTool flow
|
||||
ToolClass = get_contextual_tool_class(tool_call.name)
|
||||
|
||||
# If the tool doesn't exist, return the message to the agent
|
||||
@@ -423,25 +400,32 @@ class RootNodeTools(AssistantNode):
|
||||
tool_call_id=tool_call.id,
|
||||
)
|
||||
],
|
||||
root_tool_calls_count=tool_call_count + 1,
|
||||
)
|
||||
|
||||
# Initialize the tool and process it
|
||||
tool_class = await ToolClass.create_tool_class(
|
||||
team=self._team,
|
||||
user=self._user,
|
||||
tool_call_id=tool_call.id,
|
||||
            # Tricky: set the node path associated with this tool call
            node_path=(
|
||||
*self.node_path[:-1],
|
||||
NodePath(name=AssistantNodeName.ROOT_TOOLS, message_id=last_message.id, tool_call_id=tool_call.id),
|
||||
),
|
||||
state=state,
|
||||
config=config,
|
||||
context_manager=self.context_manager,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await tool_class.ainvoke(tool_call.model_dump(), config)
|
||||
result = await tool_class.ainvoke(
|
||||
ToolCall(type="tool_call", name=tool_call.name, args=tool_call.args, id=tool_call.id), config=config
|
||||
)
|
||||
if not isinstance(result, LangchainToolMessage):
|
||||
raise ValueError(
|
||||
f"Tool '{tool_call.name}' returned {type(result).__name__}, expected LangchainToolMessage"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error calling tool", extra={"tool_name": tool_call.name, "error": str(e)})
|
||||
capture_exception(
|
||||
e, distinct_id=self._get_user_distinct_id(config), properties=self._get_debug_props(config)
|
||||
)
|
||||
@@ -453,38 +437,11 @@ class RootNodeTools(AssistantNode):
|
||||
tool_call_id=tool_call.id,
|
||||
)
|
||||
],
|
||||
root_tool_calls_count=tool_call_count + 1,
|
||||
)
|
||||
|
||||
if isinstance(result.artifact, ToolMessagesArtifact):
|
||||
return PartialAssistantState(
|
||||
messages=result.artifact.messages,
|
||||
root_tool_calls_count=tool_call_count + 1,
|
||||
)
|
||||
|
||||
# Handle the basic toolkit
|
||||
if result.name == "search" and isinstance(result.artifact, dict):
|
||||
match result.artifact.get("kind"):
|
||||
case "insights":
|
||||
return PartialAssistantState(
|
||||
root_tool_call_id=tool_call.id,
|
||||
search_insights_query=result.artifact.get("query"),
|
||||
root_tool_calls_count=tool_call_count + 1,
|
||||
)
|
||||
case "docs":
|
||||
return PartialAssistantState(
|
||||
root_tool_call_id=tool_call.id,
|
||||
root_tool_calls_count=tool_call_count + 1,
|
||||
)
|
||||
|
||||
if (
|
||||
result.name == "read_data"
|
||||
and isinstance(result.artifact, dict)
|
||||
and result.artifact.get("kind") == "billing_info"
|
||||
):
|
||||
return PartialAssistantState(
|
||||
root_tool_call_id=tool_call.id,
|
||||
root_tool_calls_count=tool_call_count + 1,
|
||||
)
|
||||
|
||||
# If this is a navigation tool call, pause the graph execution
|
||||
@@ -496,7 +453,6 @@ class RootNodeTools(AssistantNode):
|
||||
id=str(uuid4()),
|
||||
tool_call_id=tool_call.id,
|
||||
)
|
||||
self.dispatcher.message(navigate_message)
|
||||
# Raising a `NodeInterrupt` ensures the assistant graph stops here and
|
||||
# surfaces the navigation confirmation to the client. The next user
|
||||
# interaction will resume the graph with potentially different
|
||||
@@ -512,29 +468,11 @@ class RootNodeTools(AssistantNode):
|
||||
|
||||
return PartialAssistantState(
|
||||
messages=[tool_message],
|
||||
root_tool_calls_count=tool_call_count + 1,
|
||||
)
|
||||
|
||||
def router(self, state: AssistantState) -> RouteName:
# This is only for the Inkeep node. Remove when inkeep_docs is removed.
def router(self, state: AssistantState) -> Literal["root", "end"]:
last_message = state.messages[-1]

if isinstance(last_message, AssistantToolCallMessage):
return "root" # Let the root either proceed or finish, since it now can see the tool call result
if isinstance(last_message, AssistantMessage) and state.root_tool_call_id:
tool_calls = getattr(last_message, "tool_calls", None)
if tool_calls and len(tool_calls) > 0:
tool_call = tool_calls[0]
tool_call_name = tool_call.name
if tool_call_name == "read_data" and tool_call.args.get("kind") == "billing_info":
return "billing"
if tool_call_name == "create_dashboard":
return "create_dashboard"
if state.root_tool_insight_plan:
return "insights"
elif state.search_insights_query:
return "insights_search"
elif state.session_summarization_query:
return "session_summarization"
else:
return "search_documentation"
return "end"
@@ -112,6 +112,10 @@ The user is a product engineer and will primarily request you perform product ma
- Tool results and user messages may include <system_reminder> tags. <system_reminder> tags contain useful information and reminders. They are NOT part of the user's provided input or the tool result.
</doing_tasks>

<tool_usage_policy>
- You can invoke multiple tools within a single response. When a request involves several independent pieces of information, batch your tool calls together for optimal performance.
</tool_usage_policy>

{{{billing_context}}}

{{{core_memory_prompt}}}

@@ -38,6 +38,7 @@ from ee.hogai.graph.root.prompts import (
|
||||
ROOT_BILLING_CONTEXT_WITH_ACCESS_PROMPT,
|
||||
ROOT_BILLING_CONTEXT_WITH_NO_ACCESS_PROMPT,
|
||||
)
|
||||
from ee.hogai.tool import ToolMessagesArtifact
|
||||
from ee.hogai.utils.tests import FakeChatAnthropic, FakeChatOpenAI
|
||||
from ee.hogai.utils.types import AssistantState, PartialAssistantState
|
||||
from ee.hogai.utils.types.base import AssistantMessageUnion
|
||||
@@ -644,6 +645,147 @@ class TestRootNode(ClickhouseTestMixin, BaseTest):
|
||||
result = node._is_hard_limit_reached(tool_calls_count)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
async def test_node_increments_tool_count_on_tool_call(self):
|
||||
"""Test that RootNode increments tool count when assistant makes a tool call"""
|
||||
with patch(
|
||||
"ee.hogai.graph.root.nodes.RootNode._get_model",
|
||||
return_value=FakeChatOpenAI(
|
||||
responses=[
|
||||
LangchainAIMessage(
|
||||
content="Let me help",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "tool-1",
|
||||
"name": "create_and_query_insight",
|
||||
"args": {"query_description": "test"},
|
||||
}
|
||||
],
|
||||
)
|
||||
]
|
||||
),
|
||||
):
|
||||
node = RootNode(self.team, self.user)
|
||||
|
||||
# Test starting from no tool calls
|
||||
state_1 = AssistantState(messages=[HumanMessage(content="Hello")])
|
||||
result_1 = await node.arun(state_1, {})
|
||||
self.assertEqual(result_1.root_tool_calls_count, 1)
|
||||
|
||||
# Test incrementing from existing count
|
||||
state_2 = AssistantState(
|
||||
messages=[HumanMessage(content="Hello")],
|
||||
root_tool_calls_count=5,
|
||||
)
|
||||
result_2 = await node.arun(state_2, {})
|
||||
self.assertEqual(result_2.root_tool_calls_count, 6)
|
||||
|
||||
async def test_node_resets_tool_count_on_plain_response(self):
|
||||
"""Test that RootNode resets tool count when assistant responds without tool calls"""
|
||||
with patch(
|
||||
"ee.hogai.graph.root.nodes.RootNode._get_model",
|
||||
return_value=FakeChatOpenAI(responses=[LangchainAIMessage(content="Here's your answer")]),
|
||||
):
|
||||
node = RootNode(self.team, self.user)
|
||||
|
||||
state = AssistantState(
|
||||
messages=[HumanMessage(content="Hello")],
|
||||
root_tool_calls_count=5,
|
||||
)
|
||||
result = await node.arun(state, {})
|
||||
self.assertIsNone(result.root_tool_calls_count)
|
||||
|
||||
def test_router_returns_end_for_plain_response(self):
|
||||
"""Test that router returns END when message has no tool calls"""
|
||||
from ee.hogai.utils.types import AssistantNodeName
|
||||
|
||||
node = RootNode(self.team, self.user)
|
||||
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
HumanMessage(content="Hello"),
|
||||
AssistantMessage(content="Hi there!"),
|
||||
]
|
||||
)
|
||||
result = node.router(state)
|
||||
self.assertEqual(result, AssistantNodeName.END)
|
||||
|
||||
def test_router_returns_send_for_single_tool_call(self):
|
||||
"""Test that router returns Send for single tool call"""
|
||||
from langgraph.types import Send
|
||||
|
||||
from ee.hogai.utils.types import AssistantNodeName
|
||||
|
||||
node = RootNode(self.team, self.user)
|
||||
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
HumanMessage(content="Generate insights"),
|
||||
AssistantMessage(
|
||||
content="Let me help",
|
||||
tool_calls=[
|
||||
AssistantToolCall(
|
||||
id="tool-1",
|
||||
name="create_and_query_insight",
|
||||
args={"query_description": "test"},
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
result = node.router(state)
|
||||
|
||||
# Verify it's a list of Send objects
|
||||
self.assertIsInstance(result, list)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIsInstance(result[0], Send)
|
||||
self.assertEqual(result[0].node, AssistantNodeName.ROOT_TOOLS)
|
||||
self.assertEqual(result[0].arg.root_tool_call_id, "tool-1")
|
||||
|
||||
def test_router_returns_multiple_sends_for_parallel_tool_calls(self):
|
||||
"""Test that router returns multiple Send objects for parallel tool calls"""
|
||||
from langgraph.types import Send
|
||||
|
||||
from ee.hogai.utils.types import AssistantNodeName
|
||||
|
||||
node = RootNode(self.team, self.user)
|
||||
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
HumanMessage(content="Generate multiple insights"),
|
||||
AssistantMessage(
|
||||
content="Let me create several insights",
|
||||
tool_calls=[
|
||||
AssistantToolCall(
|
||||
id="tool-1",
|
||||
name="create_and_query_insight",
|
||||
args={"query_description": "trends"},
|
||||
),
|
||||
AssistantToolCall(
|
||||
id="tool-2",
|
||||
name="create_and_query_insight",
|
||||
args={"query_description": "funnel"},
|
||||
),
|
||||
AssistantToolCall(
|
||||
id="tool-3",
|
||||
name="create_and_query_insight",
|
||||
args={"query_description": "retention"},
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
result = node.router(state)
|
||||
|
||||
# Verify it's a list of Send objects
|
||||
self.assertIsInstance(result, list)
|
||||
self.assertEqual(len(result), 3)
|
||||
|
||||
# Verify all are Send objects to ROOT_TOOLS
|
||||
for i, send in enumerate(result):
|
||||
self.assertIsInstance(send, Send)
|
||||
self.assertEqual(send.node, AssistantNodeName.ROOT_TOOLS)
|
||||
self.assertEqual(send.arg.root_tool_call_id, f"tool-{i+1}")
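The three router tests above pin down the contract for parallel tool calls: a plain reply routes to END, while each tool call becomes its own `Send` to ROOT_TOOLS carrying its `root_tool_call_id`. A minimal sketch of that fan-out, assuming LangGraph's `Send` API and the state fields asserted in these tests (not necessarily the exact implementation in this commit):

```python
from langgraph.types import Send

from ee.hogai.utils.types import AssistantNodeName, AssistantState


def route_root(state: AssistantState):
    """Sketch of RootNode.router: END on a plain reply, one Send per tool call otherwise."""
    last_message = state.messages[-1]
    tool_calls = getattr(last_message, "tool_calls", None) or []
    if not tool_calls:
        return AssistantNodeName.END
    # Each Send targets ROOT_TOOLS with a per-call copy of the state, so the tool
    # calls run as parallel branches whose partial states are merged afterwards.
    return [
        Send(AssistantNodeName.ROOT_TOOLS, state.model_copy(update={"root_tool_call_id": tool_call.id}))
        for tool_call in tool_calls
    ]
```
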
class TestRootNodeTools(BaseTest):
|
||||
def test_node_tools_router(self):
|
||||
@@ -675,9 +817,13 @@ class TestRootNodeTools(BaseTest):
|
||||
node = RootNodeTools(self.team, self.user)
|
||||
state = AssistantState(messages=[HumanMessage(content="Hello")])
|
||||
result = await node.arun(state, {})
|
||||
self.assertEqual(result, PartialAssistantState(root_tool_calls_count=0))
|
||||
self.assertEqual(result, PartialAssistantState(root_tool_call_id=None))
|
||||
|
||||
@patch("ee.hogai.graph.root.tools.create_and_query_insight.CreateAndQueryInsightTool._arun_impl")
|
||||
async def test_run_valid_tool_call(self, create_and_query_insight_mock):
|
||||
test_message = AssistantToolCallMessage(content="Tool result", tool_call_id="xyz", id="msg-1")
|
||||
create_and_query_insight_mock.return_value = ("", ToolMessagesArtifact(messages=[test_message]))
|
||||
|
||||
async def test_run_valid_tool_call(self):
|
||||
node = RootNodeTools(self.team, self.user)
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
@@ -688,17 +834,20 @@ class TestRootNodeTools(BaseTest):
|
||||
AssistantToolCall(
|
||||
id="xyz",
|
||||
name="create_and_query_insight",
|
||||
args={"query_kind": "trends", "query_description": "test query"},
|
||||
args={"query_description": "test query"},
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
],
|
||||
root_tool_call_id="xyz",
|
||||
)
|
||||
result = await node.arun(state, {})
|
||||
self.assertIsInstance(result, PartialAssistantState)
|
||||
self.assertEqual(result.root_tool_call_id, "xyz")
|
||||
self.assertEqual(result.root_tool_insight_plan, "test query")
|
||||
self.assertEqual(result.root_tool_insight_type, None) # Insight type is determined by query planner node
|
||||
assert result is not None
|
||||
self.assertEqual(len(result.messages), 1)
|
||||
assert isinstance(result.messages[0], AssistantToolCallMessage)
|
||||
self.assertEqual(result.messages[0].tool_call_id, "xyz")
|
||||
create_and_query_insight_mock.assert_called_once_with(query_description="test query")
|
||||
|
||||
async def test_run_valid_contextual_tool_call(self):
|
||||
node = RootNodeTools(self.team, self.user)
|
||||
@@ -715,7 +864,8 @@ class TestRootNodeTools(BaseTest):
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
],
|
||||
root_tool_call_id="xyz",
|
||||
)
|
||||
|
||||
result = await node.arun(
|
||||
@@ -730,69 +880,9 @@ class TestRootNodeTools(BaseTest):
|
||||
)
|
||||
|
||||
self.assertIsInstance(result, PartialAssistantState)
|
||||
self.assertEqual(result.root_tool_call_id, None) # Tool was fully handled by the node
|
||||
self.assertIsNone(result.root_tool_insight_plan) # No insight plan for contextual tools
|
||||
self.assertIsNone(result.root_tool_insight_type) # No insight type for contextual tools
|
||||
|
||||
async def test_run_multiple_tool_calls_raises(self):
|
||||
node = RootNodeTools(self.team, self.user)
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
AssistantMessage(
|
||||
content="Hello",
|
||||
id="test-id",
|
||||
tool_calls=[
|
||||
AssistantToolCall(
|
||||
id="xyz1",
|
||||
name="create_and_query_insight",
|
||||
args={"query_kind": "trends", "query_description": "test query 1"},
|
||||
),
|
||||
AssistantToolCall(
|
||||
id="xyz2",
|
||||
name="create_and_query_insight",
|
||||
args={"query_kind": "funnel", "query_description": "test query 2"},
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
await node.arun(state, {})
|
||||
self.assertEqual(str(cm.exception), "Expected exactly one tool call.")
|
||||
|
||||
async def test_run_increments_tool_count(self):
|
||||
node = RootNodeTools(self.team, self.user)
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
AssistantMessage(
|
||||
content="Hello",
|
||||
id="test-id",
|
||||
tool_calls=[
|
||||
AssistantToolCall(
|
||||
id="xyz",
|
||||
name="create_and_query_insight",
|
||||
args={"query_kind": "trends", "query_description": "test query"},
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
root_tool_calls_count=2, # Starting count
|
||||
)
|
||||
result = await node.arun(state, {})
|
||||
self.assertEqual(result.root_tool_calls_count, 3) # Should increment by 1
|
||||
|
||||
async def test_run_resets_tool_count(self):
|
||||
node = RootNodeTools(self.team, self.user)
|
||||
|
||||
# Test reset when no tool calls in AssistantMessage
|
||||
state_1 = AssistantState(messages=[AssistantMessage(content="Hello", tool_calls=[])], root_tool_calls_count=3)
|
||||
result = await node.arun(state_1, {})
|
||||
self.assertEqual(result.root_tool_calls_count, 0)
|
||||
|
||||
# Test reset when last message is HumanMessage
|
||||
state_2 = AssistantState(messages=[HumanMessage(content="Hello")], root_tool_calls_count=3)
|
||||
result = await node.arun(state_2, {})
|
||||
self.assertEqual(result.root_tool_calls_count, 0)
|
||||
assert result is not None
|
||||
self.assertEqual(len(result.messages), 1)
|
||||
self.assertIsInstance(result.messages[0], AssistantToolCallMessage)
|
||||
|
||||
async def test_navigate_tool_call_raises_node_interrupt(self):
|
||||
"""Test that navigate tool calls raise NodeInterrupt to pause graph execution"""
|
||||
@@ -805,7 +895,8 @@ class TestRootNodeTools(BaseTest):
|
||||
id="test-id",
|
||||
tool_calls=[AssistantToolCall(id="nav-123", name="navigate", args={"page_key": "insights"})],
|
||||
)
|
||||
]
|
||||
],
|
||||
root_tool_call_id="nav-123",
|
||||
)
|
||||
|
||||
mock_navigate_tool = AsyncMock()
|
||||
@@ -828,319 +919,6 @@ class TestRootNodeTools(BaseTest):
|
||||
self.assertEqual(interrupt_data.tool_call_id, "nav-123")
|
||||
self.assertEqual(interrupt_data.ui_payload, {"navigate": {"page_key": "insights"}})
|
||||
|
||||
def test_billing_tool_routing(self):
|
||||
"""Test that billing tool calls are routed correctly"""
|
||||
node = RootNodeTools(self.team, self.user)
|
||||
|
||||
# Create state with billing tool call (read_data with kind=billing_info)
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
AssistantMessage(
|
||||
content="Let me check your billing information",
|
||||
tool_calls=[AssistantToolCall(id="billing-123", name="read_data", args={"kind": "billing_info"})],
|
||||
)
|
||||
],
|
||||
root_tool_call_id="billing-123",
|
||||
)
|
||||
|
||||
# Should route to billing
|
||||
self.assertEqual(node.router(state), "billing")
|
||||
|
||||
def test_router_insights_path(self):
|
||||
"""Test router routes to insights when root_tool_insight_plan is set"""
|
||||
node = RootNodeTools(self.team, self.user)
|
||||
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
AssistantMessage(
|
||||
content="Creating insight",
|
||||
tool_calls=[
|
||||
AssistantToolCall(
|
||||
id="insight-123",
|
||||
name="create_and_query_insight",
|
||||
args={"query_kind": "trends", "query_description": "test"},
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
root_tool_call_id="insight-123",
|
||||
root_tool_insight_plan="test query plan",
|
||||
)
|
||||
|
||||
self.assertEqual(node.router(state), "insights")
|
||||
|
||||
def test_router_insights_search_path(self):
|
||||
"""Test router routes to insights_search when search_insights_query is set"""
|
||||
node = RootNodeTools(self.team, self.user)
|
||||
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
AssistantMessage(
|
||||
content="Searching insights",
|
||||
tool_calls=[AssistantToolCall(id="search-123", name="search", args={"kind": "insights"})],
|
||||
)
|
||||
],
|
||||
root_tool_call_id="search-123",
|
||||
search_insights_query="test search query",
|
||||
)
|
||||
|
||||
self.assertEqual(node.router(state), "insights_search")
|
||||
|
||||
def test_router_session_summarization_path(self):
|
||||
"""Test router routes to session_summarization when session_summarization_query is set"""
|
||||
node = RootNodeTools(self.team, self.user)
|
||||
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
AssistantMessage(
|
||||
content="Summarizing sessions",
|
||||
tool_calls=[
|
||||
AssistantToolCall(
|
||||
id="session-123", name="session_summarization", args={"session_summarization_query": "test"}
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
root_tool_call_id="session-123",
|
||||
session_summarization_query="test session query",
|
||||
)
|
||||
|
||||
self.assertEqual(node.router(state), "session_summarization")
|
||||
|
||||
def test_router_create_dashboard_path(self):
|
||||
"""Test router routes to create_dashboard when create_dashboard tool is called"""
|
||||
node = RootNodeTools(self.team, self.user)
|
||||
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
AssistantMessage(
|
||||
content="Creating dashboard",
|
||||
tool_calls=[
|
||||
AssistantToolCall(
|
||||
id="dashboard-123", name="create_dashboard", args={"search_insights_queries": []}
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
root_tool_call_id="dashboard-123",
|
||||
)
|
||||
|
||||
self.assertEqual(node.router(state), "create_dashboard")
|
||||
|
||||
def test_router_search_documentation_fallback(self):
|
||||
"""Test router routes to search_documentation when root_tool_call_id is set but no specific route"""
|
||||
node = RootNodeTools(self.team, self.user)
|
||||
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
AssistantMessage(
|
||||
content="Searching docs",
|
||||
tool_calls=[AssistantToolCall(id="search-123", name="search", args={"kind": "docs"})],
|
||||
)
|
||||
],
|
||||
root_tool_call_id="search-123",
|
||||
)
|
||||
|
||||
self.assertEqual(node.router(state), "search_documentation")
|
||||
|
||||
async def test_arun_session_summarization_with_all_args(self):
|
||||
"""Test session_summarization tool call with all arguments"""
|
||||
node = RootNodeTools(self.team, self.user)
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
AssistantMessage(
|
||||
content="Summarizing sessions",
|
||||
id="test-id",
|
||||
tool_calls=[
|
||||
AssistantToolCall(
|
||||
id="session-123",
|
||||
name="session_summarization",
|
||||
args={
|
||||
"session_summarization_query": "test query",
|
||||
"should_use_current_filters": True,
|
||||
"summary_title": "Test Summary",
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
result = await node.arun(state, {})
|
||||
self.assertIsInstance(result, PartialAssistantState)
|
||||
self.assertEqual(result.root_tool_call_id, "session-123")
|
||||
self.assertEqual(result.session_summarization_query, "test query")
|
||||
self.assertEqual(result.should_use_current_filters, True)
|
||||
self.assertEqual(result.summary_title, "Test Summary")
|
||||
|
||||
async def test_arun_session_summarization_missing_optional_args(self):
|
||||
"""Test session_summarization tool call with missing optional arguments"""
|
||||
node = RootNodeTools(self.team, self.user)
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
AssistantMessage(
|
||||
content="Summarizing sessions",
|
||||
id="test-id",
|
||||
tool_calls=[
|
||||
AssistantToolCall(
|
||||
id="session-123",
|
||||
name="session_summarization",
|
||||
args={"session_summarization_query": "test query"},
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
result = await node.arun(state, {})
|
||||
self.assertIsInstance(result, PartialAssistantState)
|
||||
self.assertEqual(result.should_use_current_filters, False) # Default value
|
||||
self.assertIsNone(result.summary_title)
|
||||
|
||||
async def test_arun_create_dashboard_with_queries(self):
|
||||
"""Test create_dashboard tool call with search_insights_queries"""
|
||||
node = RootNodeTools(self.team, self.user)
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
AssistantMessage(
|
||||
content="Creating dashboard",
|
||||
id="test-id",
|
||||
tool_calls=[
|
||||
AssistantToolCall(
|
||||
id="dashboard-123",
|
||||
name="create_dashboard",
|
||||
args={
|
||||
"dashboard_name": "Test Dashboard",
|
||||
"search_insights_queries": [
|
||||
{"name": "Query 1", "description": "Trends insight description"},
|
||||
{"name": "Query 2", "description": "Funnel insight description"},
|
||||
],
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
result = await node.arun(state, {})
|
||||
self.assertIsInstance(result, PartialAssistantState)
|
||||
self.assertEqual(result.root_tool_call_id, "dashboard-123")
|
||||
self.assertEqual(result.dashboard_name, "Test Dashboard")
|
||||
self.assertIsNotNone(result.search_insights_queries)
|
||||
assert result.search_insights_queries is not None
|
||||
self.assertEqual(len(result.search_insights_queries), 2)
|
||||
self.assertEqual(result.search_insights_queries[0].name, "Query 1")
|
||||
self.assertEqual(result.search_insights_queries[1].name, "Query 2")
|
||||
|
||||
async def test_arun_search_tool_insights_kind(self):
|
||||
"""Test search tool with kind=insights"""
|
||||
node = RootNodeTools(self.team, self.user)
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
AssistantMessage(
|
||||
content="Searching insights",
|
||||
id="test-id",
|
||||
tool_calls=[
|
||||
AssistantToolCall(id="search-123", name="search", args={"query": "test", "kind": "insights"})
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
mock_tool = AsyncMock()
|
||||
mock_tool.ainvoke.return_value = LangchainToolMessage(
|
||||
content="Search results",
|
||||
tool_call_id="search-123",
|
||||
name="search",
|
||||
artifact={"kind": "insights", "query": "test"},
|
||||
)
|
||||
|
||||
with mock_contextual_tool(mock_tool):
|
||||
result = await node.arun(state, {"configurable": {"contextual_tools": {"search": {}}}})
|
||||
|
||||
self.assertIsInstance(result, PartialAssistantState)
|
||||
self.assertEqual(result.root_tool_call_id, "search-123")
|
||||
self.assertEqual(result.search_insights_query, "test")
|
||||
|
||||
async def test_arun_search_tool_docs_kind(self):
|
||||
"""Test search tool with kind=docs"""
|
||||
node = RootNodeTools(self.team, self.user)
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
AssistantMessage(
|
||||
content="Searching docs",
|
||||
id="test-id",
|
||||
tool_calls=[
|
||||
AssistantToolCall(id="search-123", name="search", args={"query": "test", "kind": "docs"})
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
mock_tool = AsyncMock()
|
||||
mock_tool.ainvoke.return_value = LangchainToolMessage(
|
||||
content="Docs results", tool_call_id="search-123", name="search", artifact={"kind": "docs"}
|
||||
)
|
||||
|
||||
with mock_contextual_tool(mock_tool):
|
||||
result = await node.arun(state, {"configurable": {"contextual_tools": {"search": {}}}})
|
||||
|
||||
self.assertIsInstance(result, PartialAssistantState)
|
||||
self.assertEqual(result.root_tool_call_id, "search-123")
|
||||
|
||||
async def test_arun_read_data_billing_info(self):
|
||||
"""Test read_data tool with kind=billing_info"""
|
||||
node = RootNodeTools(self.team, self.user)
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
AssistantMessage(
|
||||
content="Reading billing info",
|
||||
id="test-id",
|
||||
tool_calls=[AssistantToolCall(id="read-123", name="read_data", args={"kind": "billing_info"})],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
mock_tool = AsyncMock()
|
||||
mock_tool.ainvoke.return_value = LangchainToolMessage(
|
||||
content="Billing data", tool_call_id="read-123", name="read_data", artifact={"kind": "billing_info"}
|
||||
)
|
||||
|
||||
with mock_contextual_tool(mock_tool):
|
||||
result = await node.arun(state, {"configurable": {"contextual_tools": {"read_data": {}}}})
|
||||
|
||||
self.assertIsInstance(result, PartialAssistantState)
|
||||
self.assertEqual(result.root_tool_call_id, "read-123")
|
||||
|
||||
async def test_arun_tool_updates_state(self):
|
||||
"""Test that when a tool updates its _state, the new messages are included"""
|
||||
node = RootNodeTools(self.team, self.user)
|
||||
state = AssistantState(
|
||||
messages=[
|
||||
AssistantMessage(
|
||||
content="Using tool",
|
||||
id="test-id",
|
||||
tool_calls=[AssistantToolCall(id="tool-123", name="test_tool", args={})],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
mock_tool = AsyncMock()
|
||||
# Simulate tool appending a message to state
|
||||
updated_state = AssistantState(
|
||||
messages=[
|
||||
*state.messages,
|
||||
AssistantToolCallMessage(content="Tool result", tool_call_id="tool-123", id="msg-1"),
|
||||
]
|
||||
)
|
||||
mock_tool._state = updated_state
|
||||
mock_tool.ainvoke.return_value = LangchainToolMessage(content="Tool result", tool_call_id="tool-123")
|
||||
|
||||
with mock_contextual_tool(mock_tool):
|
||||
result = await node.arun(state, {"configurable": {"contextual_tools": {"test_tool": {}}}})
|
||||
|
||||
# Should include the new message from the updated state
|
||||
self.assertIsInstance(result, PartialAssistantState)
|
||||
self.assertEqual(len(result.messages), 1)
|
||||
self.assertIsInstance(result.messages[0], AssistantToolCallMessage)
|
||||
|
||||
async def test_arun_tool_returns_wrong_type_returns_error_message(self):
|
||||
"""Test that tool returning wrong type returns an error message"""
|
||||
node = RootNodeTools(self.team, self.user)
|
||||
@@ -1151,7 +929,8 @@ class TestRootNodeTools(BaseTest):
|
||||
id="test-id",
|
||||
tool_calls=[AssistantToolCall(id="tool-123", name="test_tool", args={})],
|
||||
)
|
||||
]
|
||||
],
|
||||
root_tool_call_id="tool-123",
|
||||
)
|
||||
|
||||
mock_tool = AsyncMock()
|
||||
@@ -1161,10 +940,11 @@ class TestRootNodeTools(BaseTest):
|
||||
result = await node.arun(state, {"configurable": {"contextual_tools": {"test_tool": {}}}})
|
||||
|
||||
self.assertIsInstance(result, PartialAssistantState)
|
||||
assert result is not None
|
||||
self.assertEqual(len(result.messages), 1)
|
||||
assert isinstance(result.messages[0], AssistantToolCallMessage)
|
||||
self.assertEqual(result.messages[0].tool_call_id, "tool-123")
|
||||
self.assertEqual(result.root_tool_calls_count, 1)
|
||||
self.assertIn("internal error", result.messages[0].content)
|
||||
|
||||
async def test_arun_unknown_tool_returns_error_message(self):
|
||||
"""Test that unknown tool name returns an error message"""
|
||||
@@ -1176,14 +956,16 @@ class TestRootNodeTools(BaseTest):
|
||||
id="test-id",
|
||||
tool_calls=[AssistantToolCall(id="tool-123", name="unknown_tool", args={})],
|
||||
)
|
||||
]
|
||||
],
|
||||
root_tool_call_id="tool-123",
|
||||
)
|
||||
|
||||
with patch("ee.hogai.tool.get_contextual_tool_class", return_value=None):
|
||||
result = await node.arun(state, {})
|
||||
|
||||
self.assertIsInstance(result, PartialAssistantState)
|
||||
assert result is not None
|
||||
self.assertEqual(len(result.messages), 1)
|
||||
assert isinstance(result.messages[0], AssistantToolCallMessage)
|
||||
self.assertEqual(result.messages[0].tool_call_id, "tool-123")
|
||||
self.assertEqual(result.root_tool_calls_count, 1)
|
||||
self.assertIn("does not exist", result.messages[0].content)
|
||||
|
||||
@@ -1,21 +1,26 @@
|
||||
from .legacy import create_and_query_insight, create_dashboard, session_summarization
|
||||
from .create_and_query_insight import CreateAndQueryInsightTool, CreateAndQueryInsightToolArgs
|
||||
from .create_dashboard import CreateDashboardTool, CreateDashboardToolArgs
|
||||
from .navigate import NavigateTool, NavigateToolArgs
|
||||
from .read_data import ReadDataTool, ReadDataToolArgs
|
||||
from .read_taxonomy import ReadTaxonomyTool
|
||||
from .search import SearchTool, SearchToolArgs
|
||||
from .session_summarization import SessionSummarizationTool, SessionSummarizationToolArgs
|
||||
from .todo_write import TodoWriteTool, TodoWriteToolArgs
|
||||
|
||||
__all__ = [
|
||||
"CreateAndQueryInsightTool",
|
||||
"CreateAndQueryInsightToolArgs",
|
||||
"CreateDashboardTool",
|
||||
"CreateDashboardToolArgs",
|
||||
"NavigateTool",
|
||||
"NavigateToolArgs",
|
||||
"ReadDataTool",
|
||||
"ReadDataToolArgs",
|
||||
"ReadTaxonomyTool",
|
||||
"SearchTool",
|
||||
"SearchToolArgs",
|
||||
"ReadDataTool",
|
||||
"ReadDataToolArgs",
|
||||
"SessionSummarizationTool",
|
||||
"SessionSummarizationToolArgs",
|
||||
"TodoWriteTool",
|
||||
"TodoWriteToolArgs",
|
||||
"NavigateTool",
|
||||
"NavigateToolArgs",
|
||||
"create_and_query_insight",
|
||||
"session_summarization",
|
||||
"create_dashboard",
|
||||
]
|
||||
|
||||
ee/hogai/graph/root/tools/create_and_query_insight.py (new file, 206 lines)
@@ -0,0 +1,206 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from posthog.schema import AssistantTool, AssistantToolCallMessage, VisualizationMessage
|
||||
|
||||
from ee.hogai.context.context import AssistantContextManager
|
||||
from ee.hogai.graph.insights_graph.graph import InsightsGraph
|
||||
from ee.hogai.graph.schema_generator.nodes import SchemaGenerationException
|
||||
from ee.hogai.tool import MaxTool, ToolMessagesArtifact
|
||||
from ee.hogai.utils.prompt import format_prompt_string
|
||||
from ee.hogai.utils.types.base import AssistantState
|
||||
|
||||
INSIGHT_TOOL_PROMPT = """
|
||||
Use this tool to create a product analytics insight for a given natural language description by spawning a subagent.
|
||||
The tool generates a query and returns formatted text results for a specific data question or iterates on a previous query. It only retrieves a single query per call. If the user asks for multiple insights, you need to decompose a query into multiple subqueries and call the tool for each subquery.
|
||||
|
||||
Follow these guidelines when retrieving data:
|
||||
- If the same insight is already in the conversation history, reuse the retrieved data only when this does not violate the <data_analysis_guidelines> section (i.e. only when a presence-check, count, or sort on existing columns is enough).
|
||||
- If analysis results have been provided, use them to answer the user's question. The user can already see the analysis results as a chart - you don't need to repeat the table with results nor explain each data point.
|
||||
- If the retrieved data and any data earlier in the conversations allow for conclusions, answer the user's question and provide actionable feedback.
|
||||
- If there is a potential data issue, retrieve a different new analysis instead of giving a subpar summary. Note: empty data is NOT a potential data issue.
|
||||
- If the query cannot be answered with a UI-built insight type - trends, funnels, retention - choose the SQL type to answer the question (e.g. for listing events or aggregating in ways that aren't supported in trends/funnels/retention).
|
||||
|
||||
IMPORTANT: Avoid generic advice. Take into account what you know about the product. Your answer needs to be super high-impact and no more than a few sentences.
|
||||
Remember: do NOT retrieve data for the same query more than 3 times in a row.
|
||||
|
||||
# Data schema
|
||||
|
||||
You can pass events, actions, properties, and property values to this tool by specifying the "Data schema" section.
|
||||
|
||||
<example>
|
||||
User: Calculate onboarding completion rate for the last week.
|
||||
Assistant: I'm going to retrieve the existing data schema first.
|
||||
*Retrieves matching events, properties, and property values*
|
||||
Assistant: I'm going to create a new trends insight.
|
||||
*Calls this tool with the query description: "Trends insight of the onboarding completion rate. Data schema: Relevant matching data schema"*
|
||||
</example>
|
||||
|
||||
# Supported insight types
|
||||
## Trends
|
||||
A trends insight visualizes events over time using time series. They're useful for finding patterns in historical data.
|
||||
|
||||
The trends insights have the following features:
|
||||
- The insight can show multiple trends in one request.
|
||||
- Custom formulas can calculate derived metrics, like `A/B*100` to calculate a ratio.
|
||||
- Filter and break down data using multiple properties.
|
||||
- Compare with the previous period and sample data.
|
||||
- Apply various aggregation types, like sum, average, etc., and chart types.
|
||||
- And more.
|
||||
|
||||
Examples of use cases include:
|
||||
- How the product's most important metrics change over time.
|
||||
- Long-term patterns, or cycles in product's usage.
|
||||
- The usage of different features side-by-side.
|
||||
- How the properties of events vary using aggregation (sum, average, etc).
|
||||
- Users can also visualize the same data points in a variety of ways.
|
||||
|
||||
## Funnel
|
||||
A funnel insight visualizes a sequence of events that users go through in a product. They use percentages as the primary aggregation type. Funnels use two or more series, so the conversation history should mention at least two events.
|
||||
|
||||
The funnel insights have the following features:
|
||||
- Various visualization types (steps, time-to-convert, historical trends).
|
||||
- Filter data and apply exclusion steps.
|
||||
- Break down data using a single property.
|
||||
- Specify conversion windows, details of conversion calculation, attribution settings.
|
||||
- Sample data.
|
||||
- And more.
|
||||
|
||||
Examples of use cases include:
|
||||
- Conversion rates.
|
||||
- Drop off steps.
|
||||
- Steps with the highest friction and time to convert.
|
||||
- If product changes are improving their funnel over time.
|
||||
- Average/median time to convert.
|
||||
- Conversion trends over time.
|
||||
|
||||
## Retention
|
||||
A retention insight visualizes how many users return to the product after performing some action. They're useful for understanding user engagement and retention.
|
||||
|
||||
The retention insights have the following features: filter data, sample data, and more.
|
||||
|
||||
Examples of use cases include:
|
||||
- How many users come back and perform an action after their first visit.
|
||||
- How many users come back to perform action X after performing action Y.
|
||||
- How often users return to use a specific feature.
|
||||
|
||||
## SQL
|
||||
The 'sql' insight type allows you to write arbitrary SQL queries to retrieve data.
|
||||
|
||||
The SQL insights have the following features:
|
||||
- Filter data using arbitrary SQL.
|
||||
- All ClickHouse SQL features.
|
||||
- You can nest subqueries as needed.
|
||||
""".strip()
|
||||
|
||||
INSIGHT_TOOL_CONTEXT_PROMPT_TEMPLATE = """
|
||||
The user is currently editing an insight (aka query). Here is that insight's current definition, which can be edited using the `create_and_query_insight` tool:
|
||||
|
||||
```json
|
||||
{current_query}
|
||||
```
|
||||
|
||||
<system_reminder>
|
||||
Do not remove any fields from the current insight definition. Do not change any other fields than the ones the user asked for. Keep the rest as is.
|
||||
</system_reminder>
|
||||
""".strip()
|
||||
|
||||
INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT = """
|
||||
<system_reminder>
|
||||
Inform the user that you've encountered an error during the creation of the insight. Afterwards, try to generate a new insight with a different query.
|
||||
Terminate if the error persists.
|
||||
</system_reminder>
|
||||
""".strip()
|
||||
|
||||
INSIGHT_TOOL_HANDLED_FAILURE_PROMPT = """
|
||||
The agent has encountered the error while creating an insight.
|
||||
|
||||
Generated output:
|
||||
```
|
||||
{{{output}}}
|
||||
```
|
||||
|
||||
Error message:
|
||||
```
|
||||
{{{error_message}}}
|
||||
```
|
||||
|
||||
{{{system_reminder}}}
|
||||
""".strip()
|
||||
|
||||
|
||||
INSIGHT_TOOL_UNHANDLED_FAILURE_PROMPT = """
|
||||
The agent has encountered an unknown error while creating an insight.
|
||||
{{{system_reminder}}}
|
||||
""".strip()
|
||||
|
||||
|
||||
class CreateAndQueryInsightToolArgs(BaseModel):
|
||||
query_description: str = Field(
|
||||
description=(
|
||||
"A description of the query to generate, encapsulating the details of the user's request. "
|
||||
"Include all relevant context from earlier messages too, as the tool won't see that conversation history. "
|
||||
"If an existing insight has been used as a starting point, include that insight's filters and query in the description. "
|
||||
"Don't be overly prescriptive with event or property names, unless the user indicated they mean this specific name (e.g. with quotes). "
|
||||
"If the users seems to ask for a list of entities, rather than a count, state this explicitly."
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class CreateAndQueryInsightTool(MaxTool):
|
||||
name: Literal["create_and_query_insight"] = "create_and_query_insight"
|
||||
args_schema: type[BaseModel] = CreateAndQueryInsightToolArgs
|
||||
description: str = INSIGHT_TOOL_PROMPT
|
||||
context_prompt_template: str = INSIGHT_TOOL_CONTEXT_PROMPT_TEMPLATE
|
||||
thinking_message: str = "Coming up with an insight"
|
||||
|
||||
async def _arun_impl(self, query_description: str) -> tuple[str, ToolMessagesArtifact | None]:
|
||||
graph = InsightsGraph(self._team, self._user).compile_full_graph()
|
||||
new_state = self._state.model_copy(
|
||||
update={
|
||||
"root_tool_call_id": self.tool_call_id,
|
||||
"root_tool_insight_plan": query_description,
|
||||
},
|
||||
deep=True,
|
||||
)
|
||||
try:
|
||||
dict_state = await graph.ainvoke(new_state)
|
||||
except SchemaGenerationException as e:
|
||||
return format_prompt_string(
|
||||
INSIGHT_TOOL_HANDLED_FAILURE_PROMPT,
|
||||
output=e.llm_output,
|
||||
error_message=e.validation_message,
|
||||
system_reminder=INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT,
|
||||
), None
|
||||
|
||||
updated_state = AssistantState.model_validate(dict_state)
|
||||
maybe_viz_message, tool_call_message = updated_state.messages[-2:]
|
||||
|
||||
if not isinstance(tool_call_message, AssistantToolCallMessage):
|
||||
return format_prompt_string(
|
||||
INSIGHT_TOOL_UNHANDLED_FAILURE_PROMPT, system_reminder=INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT
|
||||
), None
|
||||
|
||||
# If the previous message is not a visualization message, the agent has requested human feedback.
|
||||
if not isinstance(maybe_viz_message, VisualizationMessage):
|
||||
return "", ToolMessagesArtifact(messages=[tool_call_message])
|
||||
|
||||
# If the contextual tool is available, we're editing an insight.
|
||||
# Add the UI payload to the tool call message.
|
||||
if self.is_editing_mode(self._context_manager):
|
||||
tool_call_message = AssistantToolCallMessage(
|
||||
content=tool_call_message.content,
|
||||
ui_payload={self.get_name(): maybe_viz_message.answer.model_dump(exclude_none=True)},
|
||||
id=tool_call_message.id,
|
||||
tool_call_id=tool_call_message.tool_call_id,
|
||||
)
|
||||
|
||||
return "", ToolMessagesArtifact(messages=[maybe_viz_message, tool_call_message])
|
||||
|
||||
@classmethod
|
||||
def is_editing_mode(cls, context_manager: AssistantContextManager) -> bool:
|
||||
"""
|
||||
Determines if the tool is in editing mode.
|
||||
"""
|
||||
return AssistantTool.CREATE_AND_QUERY_INSIGHT.value in context_manager.get_contextual_tools()
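For illustration, a hypothetical call (not part of the new file) showing the argument payload this tool expects: a single free-form description that carries all of the context the insights subagent needs.

```python
# Hypothetical example values; the schema itself is defined above.
args = CreateAndQueryInsightToolArgs(
    query_description=(
        "Trends insight of weekly signups over the last 30 days, broken down by plan. "
        "Data schema: `user signed up` event with a `plan` property."
    )
)
```
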
ee/hogai/graph/root/tools/create_dashboard.py (new file, 53 lines)
@@ -0,0 +1,53 @@
|
||||
from typing import Literal
|
||||
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ee.hogai.graph.dashboards.nodes import DashboardCreationNode
|
||||
from ee.hogai.tool import MaxTool, ToolMessagesArtifact
|
||||
from ee.hogai.utils.types.base import AssistantState, InsightQuery, PartialAssistantState
|
||||
|
||||
CREATE_DASHBOARD_TOOL_PROMPT = """
|
||||
Use this tool when users ask to create, build, or make a new dashboard with insights.
This tool will search for existing insights that match the user's requirements (so there is no need to call the `search` tool), or create new insights if none are found, then combine them into a dashboard.
Do not call this tool if the user only asks to find, search for, or look up existing insights and does not ask to create a dashboard.
If you decide to use this tool, there is no need to call the `search_insights` tool beforehand. The tool will search for existing insights that match the user's requirements and create new insights if none are found.
|
||||
""".strip()
|
||||
|
||||
|
||||
class CreateDashboardToolArgs(BaseModel):
|
||||
search_insights_queries: list[InsightQuery] = Field(
|
||||
description="A list of insights to be included in the dashboard. Include all the insights that the user mentioned."
|
||||
)
|
||||
dashboard_name: str = Field(
|
||||
description=(
|
||||
"The name of the dashboard to be created based on the user request. It should be short and concise as it will be displayed as a header in the dashboard tile."
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class CreateDashboardTool(MaxTool):
|
||||
name: Literal["create_dashboard"] = "create_dashboard"
|
||||
description: str = CREATE_DASHBOARD_TOOL_PROMPT
|
||||
thinking_message: str = "Creating a dashboard"
|
||||
context_prompt_template: str = "Creates a dashboard based on the user's request"
|
||||
args_schema: type[BaseModel] = CreateDashboardToolArgs
|
||||
show_tool_call_message: bool = False
|
||||
|
||||
async def _arun_impl(
|
||||
self, search_insights_queries: list[InsightQuery], dashboard_name: str
|
||||
) -> tuple[str, ToolMessagesArtifact | None]:
|
||||
node = DashboardCreationNode(self._team, self._user)
|
||||
chain: RunnableLambda[AssistantState, PartialAssistantState | None] = RunnableLambda(node)
|
||||
copied_state = self._state.model_copy(
|
||||
deep=True,
|
||||
update={
|
||||
"root_tool_call_id": self.tool_call_id,
|
||||
"search_insights_queries": search_insights_queries,
|
||||
"dashboard_name": dashboard_name,
|
||||
},
|
||||
)
|
||||
result = await chain.ainvoke(copied_state)
|
||||
if not result or not result.messages:
|
||||
return "Dashboard creation failed", None
|
||||
return "", ToolMessagesArtifact(messages=result.messages)
@@ -3,10 +3,13 @@ from unittest.mock import Mock, patch
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from parameterized import parameterized
|
||||
|
||||
from ee.hogai.graph.root.tools.full_text_search.tool import ENTITY_MAP, EntitySearchToolkit, FTSKind
|
||||
from ee.hogai.context import AssistantContextManager
|
||||
from ee.hogai.graph.root.tools.full_text_search.tool import ENTITY_MAP, EntitySearchTool, FTSKind
|
||||
from ee.hogai.graph.shared_prompts import HYPERLINK_USAGE_INSTRUCTIONS
|
||||
from ee.hogai.utils.types.base import AssistantState
|
||||
|
||||
|
||||
class TestEntitySearchToolkit(NonAtomicBaseTest):
|
||||
@@ -18,7 +21,13 @@ class TestEntitySearchToolkit(NonAtomicBaseTest):
|
||||
self.team.organization = Mock()
|
||||
self.team.organization.id = 789
|
||||
self.user = Mock()
|
||||
self.toolkit = EntitySearchToolkit(self.team, self.user)
|
||||
self.toolkit = EntitySearchTool(
|
||||
team=self.team,
|
||||
user=self.user,
|
||||
state=AssistantState(messages=[]),
|
||||
config=RunnableConfig(configurable={}),
|
||||
context_manager=AssistantContextManager(self.team, self.user, {}),
|
||||
)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
|
||||
@@ -6,13 +6,14 @@ import yaml
|
||||
from posthoganalytics import capture_exception
|
||||
|
||||
from posthog.api.search import EntityConfig, search_entities
|
||||
from posthog.models import Action, Cohort, Dashboard, Experiment, FeatureFlag, Insight, Survey, Team, User
|
||||
from posthog.models import Action, Cohort, Dashboard, Experiment, FeatureFlag, Insight, Survey
|
||||
from posthog.rbac.user_access_control import UserAccessControl
|
||||
from posthog.sync import database_sync_to_async
|
||||
|
||||
from products.error_tracking.backend.models import ErrorTrackingIssue
|
||||
|
||||
from ee.hogai.graph.shared_prompts import HYPERLINK_USAGE_INSTRUCTIONS
|
||||
from ee.hogai.tool import MaxSubtool
|
||||
|
||||
from .prompts import ENTITY_TYPE_SUMMARY_TEMPLATE, FOUND_ENTITIES_MESSAGE_TEMPLATE
|
||||
|
||||
@@ -91,14 +92,10 @@ SEARCH_KIND_TO_DATABASE_ENTITY_TYPE: dict[FTSKind, str] = {
|
||||
}
|
||||
|
||||
|
||||
class EntitySearchToolkit:
|
||||
class EntitySearchTool(MaxSubtool):
|
||||
MAX_ENTITY_RESULTS = 10
|
||||
MAX_CONCURRENT_SEARCHES = 10
|
||||
|
||||
def __init__(self, team: Team, user: User):
|
||||
self._team = team
|
||||
self._user = user
|
||||
|
||||
async def execute(self, query: str, search_kind: FTSKind) -> str:
|
||||
"""Search for entities by query and entity."""
|
||||
try:
|
||||
|
||||
@@ -1,204 +0,0 @@
|
||||
# The module contains tools that are deprecated and will be replaced in the future with MaxTool implementations.
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ee.hogai.utils.types.base import InsightQuery
|
||||
|
||||
|
||||
# Lower casing matters here. Do not change it.
|
||||
class create_and_query_insight(BaseModel):
|
||||
"""
|
||||
Use this tool to spawn a subagent that will create a product analytics insight for a given description.
|
||||
The tool generates a query and returns formatted text results for a specific data question or iterates on a previous query. It only retrieves a single query per call. If the user asks for multiple insights, you need to decompose a query into multiple subqueries and call the tool for each subquery.
|
||||
|
||||
Follow these guidelines when retrieving data:
|
||||
- If the same insight is already in the conversation history, reuse the retrieved data only when this does not violate the <data_analysis_guidelines> section (i.e. only when a presence-check, count, or sort on existing columns is enough).
|
||||
- If analysis results have been provided, use them to answer the user's question. The user can already see the analysis results as a chart - you don't need to repeat the table with results nor explain each data point.
|
||||
- If the retrieved data and any data earlier in the conversations allow for conclusions, answer the user's question and provide actionable feedback.
|
||||
- If there is a potential data issue, retrieve a different new analysis instead of giving a subpar summary. Note: empty data is NOT a potential data issue.
|
||||
- If the query cannot be answered with a UI-built insight type - trends, funnels, retention - choose the SQL type to answer the question (e.g. for listing events or aggregating in ways that aren't supported in trends/funnels/retention).
|
||||
|
||||
IMPORTANT: Avoid generic advice. Take into account what you know about the product. Your answer needs to be super high-impact and no more than a few sentences.
|
||||
Remember: do NOT retrieve data for the same query more than 3 times in a row.
|
||||
|
||||
# Data schema
|
||||
|
||||
You can pass events, actions, properties, and property values to this tool by specifying the "Data schema" section.
|
||||
|
||||
<example>
|
||||
User: Calculate onboarding completion rate for the last week.
|
||||
Assistant: I'm going to retrieve the existing data schema first.
|
||||
*Retrieves matching events, properties, and property values*
|
||||
Assistant: I'm going to create a new trends insight.
|
||||
*Calls this tool with the query description: "Trends insight of the onboarding completion rate. Data schema: Relevant matching data schema"*
|
||||
</example>
|
||||
|
||||
# Supported insight types
|
||||
## Trends
|
||||
A trends insight visualizes events over time using time series. They're useful for finding patterns in historical data.
|
||||
|
||||
The trends insights have the following features:
|
||||
- The insight can show multiple trends in one request.
|
||||
- Custom formulas can calculate derived metrics, like `A/B*100` to calculate a ratio.
|
||||
- Filter and break down data using multiple properties.
|
||||
- Compare with the previous period and sample data.
|
||||
- Apply various aggregation types, like sum, average, etc., and chart types.
|
||||
- And more.
|
||||
|
||||
Examples of use cases include:
|
||||
- How the product's most important metrics change over time.
|
||||
- Long-term patterns, or cycles in product's usage.
|
||||
- The usage of different features side-by-side.
|
||||
- How the properties of events vary using aggregation (sum, average, etc).
|
||||
- Users can also visualize the same data points in a variety of ways.
|
||||
|
||||
## Funnel
|
||||
A funnel insight visualizes a sequence of events that users go through in a product. They use percentages as the primary aggregation type. Funnels use two or more series, so the conversation history should mention at least two events.
|
||||
|
||||
The funnel insights have the following features:
|
||||
- Various visualization types (steps, time-to-convert, historical trends).
|
||||
- Filter data and apply exclusion steps.
|
||||
- Break down data using a single property.
|
||||
- Specify conversion windows, details of conversion calculation, attribution settings.
|
||||
- Sample data.
|
||||
- And more.
|
||||
|
||||
Examples of use cases include:
|
||||
- Conversion rates.
|
||||
- Drop off steps.
|
||||
- Steps with the highest friction and time to convert.
|
||||
- If product changes are improving their funnel over time.
|
||||
- Average/median time to convert.
|
||||
- Conversion trends over time.
|
||||
|
||||
## Retention
|
||||
A retention insight visualizes how many users return to the product after performing some action. They're useful for understanding user engagement and retention.
|
||||
|
||||
The retention insights have the following features: filter data, sample data, and more.
|
||||
|
||||
Examples of use cases include:
|
||||
- How many users come back and perform an action after their first visit.
|
||||
- How many users come back to perform action X after performing action Y.
|
||||
- How often users return to use a specific feature.
|
||||
|
||||
## SQL
|
||||
The 'sql' insight type allows you to write arbitrary SQL queries to retrieve data.
|
||||
|
||||
The SQL insights have the following features:
|
||||
- Filter data using arbitrary SQL.
|
||||
- All ClickHouse SQL features.
|
||||
- You can nest subqueries as needed.
|
||||
"""
|
||||
|
||||
query_description: str = Field(
|
||||
description=(
|
||||
"A description of the query to generate, encapsulating the details of the user's request. "
|
||||
"Include all relevant context from earlier messages too, as the tool won't see that conversation history. "
|
||||
"If an existing insight has been used as a starting point, include that insight's filters and query in the description. "
|
||||
"Don't be overly prescriptive with event or property names, unless the user indicated they mean this specific name (e.g. with quotes). "
|
||||
"If the users seems to ask for a list of entities, rather than a count, state this explicitly."
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class session_summarization(BaseModel):
|
||||
"""
|
||||
Use this tool to summarize session recordings by analysing the events within those sessions to find patterns and issues.
|
||||
It will return a textual summary of the captured session recordings.
|
||||
|
||||
# When to use the tool:
|
||||
When the user asks to summarize session recordings:
|
||||
- "summarize" synonyms: "watch", "analyze", "review", and similar
|
||||
- "session recordings" synonyms: "sessions", "recordings", "replays", "user sessions", and similar
|
||||
|
||||
# When NOT to use the tool:
|
||||
- When the user asks to find, search for, or look up session recordings, but doesn't ask to summarize them
|
||||
- When users asks to update, change, or adjust session recordings filters
|
||||
|
||||
# Synonyms
|
||||
- "summarize": "watch", "analyze", "review", and similar
|
||||
- "session recordings": "sessions", "recordings", "replays", "user sessions", and similar
|
||||
|
||||
# Managing context
|
||||
If the conversation history contains context about the current filters or session recordings, follow these steps:
|
||||
- Convert the user query into a `session_summarization_query`
|
||||
- The query should be used to understand the user's intent
|
||||
- Decide if the query is relevant to the current filters and set `should_use_current_filters` accordingly
|
||||
- Generate the `summary_title` based on the user's query and the current filters
|
||||
|
||||
Otherwise:
|
||||
- Convert the user query into a `session_summarization_query`
|
||||
- The query should be used to search for relevant sessions and then summarize them
|
||||
- Assume the `should_use_current_filters` should be always `false`
|
||||
- Generate the `summary_title` based on the user's query
|
||||
|
||||
# Additional guidelines
|
||||
- CRITICAL: Always pass the user's complete, unmodified query to the `session_summarization_query` parameter
|
||||
- DO NOT truncate, summarize, or extract keywords from the user's query
|
||||
- The query is used to find relevant sessions - context helps find better matches
|
||||
- Use explicit tool definition to make a decision
|
||||
"""
|
||||
|
||||
session_summarization_query: str = Field(
|
||||
description="""
|
||||
- The user's complete query for session recordings summarization.
|
||||
- This will be used to find relevant session recordings.
|
||||
- Always pass the user's complete, unmodified query.
|
||||
- Examples:
|
||||
* 'summarize all session recordings from yesterday'
|
||||
* 'analyze mobile user session recordings from last week, even if 1 second'
|
||||
* 'watch last 300 session recordings of MacOS users from US'
|
||||
* and similar
|
||||
"""
|
||||
)
|
||||
should_use_current_filters: bool = Field(
|
||||
description="""
|
||||
- Whether to use current filters from user's UI to find relevant session recordings.
|
||||
- IMPORTANT: Should be always `false` if the current filters or `search_session_recordings` tool are not present in the conversation history.
|
||||
- Examples:
|
||||
* Set to `true` if one of the conditions is met:
|
||||
- the user wants to summarize "current/selected/opened/my/all/these" session recordings
|
||||
- the user wants to use "current/these" filters
|
||||
- the user's query specifies filters identical to the current filters
|
||||
- if the user's query doesn't specify any filters/conditions
|
||||
- the user refers to what they're "looking at" or "viewing"
|
||||
* Set to `false` if one of the conditions is met:
|
||||
- no current filters or `search_session_recordings` tool are present in the conversation
|
||||
- the user specifies date/time period different from the current filters
|
||||
- the user specifies conditions (user, device, id, URL, etc.) not present in the current filters
|
||||
""",
|
||||
)
|
||||
summary_title: str = Field(
|
||||
description="""
|
||||
- The name of the summary that is expected to be generated from the user's `session_summarization_query` and/or `current_filters` (if present).
|
||||
- The name should cover in 3-7 words what sessions would be to be summarized in the summary
|
||||
- This won't be used for any search of filtering, only to properly label the generated summary.
|
||||
- Examples:
|
||||
* If `should_use_current_filters` is `false`, then the `summary_title` should be generated based on the `session_summarization_query`:
|
||||
- query: "I want to watch all the sessions of user `user@example.com` in the last 30 days no matter how long" -> name: "Sessions of the user user@example.com (last 30 days)"
|
||||
- query: "summarize my last 100 session recordings" -> name: "Last 100 sessions"
|
||||
- and similar
|
||||
* If `should_use_current_filters` is `true`, then the `summary_title` should be generated based on the current filters in the context (if present):
|
||||
- filters: "{"key":"$os","value":["Mac OS X"],"operator":"exact","type":"event"}" -> name: "MacOS users"
|
||||
- filters: "{"date_from": "-7d", "filter_test_accounts": True}" -> name: "All sessions (last 7 days)"
|
||||
- and similar
|
||||
* If there's not enough context to generated the summary name - keep it an empty string ("")
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
class create_dashboard(BaseModel):
|
||||
"""
|
||||
Use this tool when users ask to create, build, or make a new dashboard with insights.
|
||||
This tool will search for existing insights that match the user's requirements so no need to call `search_insights` tool, or create new insights if none are found, then combine them into a dashboard.
|
||||
Do not call this tool if the user only asks to find, search for, or look up existing insights and does not ask to create a dashboard.
|
||||
If you decided to use this tool, there is no need to call `search_insights` tool beforehand. The tool will search for existing insights that match the user's requirements and create new insights if none are found.
|
||||
"""
|
||||
|
||||
search_insights_queries: list[InsightQuery] = Field(
|
||||
description="A list of insights to be included in the dashboard. Include all the insights that the user mentioned."
|
||||
)
|
||||
dashboard_name: str = Field(
|
||||
description=(
|
||||
"The name of the dashboard to be created based on the user request. It should be short and concise as it will be displayed as a header in the dashboard tile."
|
||||
)
|
||||
)
|
||||
@@ -10,7 +10,7 @@ from posthog.models import Team, User
|
||||
from ee.hogai.context.context import AssistantContextManager
|
||||
from ee.hogai.tool import MaxTool
|
||||
from ee.hogai.utils.prompt import format_prompt_string
|
||||
from ee.hogai.utils.types.base import AssistantState
|
||||
from ee.hogai.utils.types.base import AssistantState, NodePath
|
||||
|
||||
NAVIGATION_TOOL_PROMPT = """
|
||||
Use the `navigate` tool to move between different pages in the PostHog application.
|
||||
@@ -57,7 +57,7 @@ class NavigateTool(MaxTool):
|
||||
*,
|
||||
team: Team,
|
||||
user: User,
|
||||
tool_call_id: str,
|
||||
node_path: tuple[NodePath, ...] | None = None,
|
||||
state: AssistantState | None = None,
|
||||
config: RunnableConfig | None = None,
|
||||
context_manager: AssistantContextManager | None = None,
|
||||
@@ -69,7 +69,7 @@ class NavigateTool(MaxTool):
|
||||
return cls(
|
||||
team=team,
|
||||
user=user,
|
||||
tool_call_id=tool_call_id,
|
||||
node_path=node_path,
|
||||
state=state,
|
||||
config=config,
|
||||
context_manager=context_manager,
|
||||
|
||||
@@ -2,11 +2,12 @@ import datetime
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from posthog.test.base import BaseTest, ClickhouseTestMixin
|
||||
from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest
|
||||
from unittest.mock import patch
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from posthog.schema import (
|
||||
AssistantToolCallMessage,
|
||||
BillingSpendResponseBreakdownType,
|
||||
BillingUsageResponseBreakdownType,
|
||||
MaxAddonInfo,
|
||||
@@ -21,26 +22,30 @@ from posthog.schema import (
|
||||
UsageHistoryItem,
|
||||
)
|
||||
|
||||
from ee.hogai.graph.billing.nodes import BillingNode
|
||||
from ee.hogai.context.context import AssistantContextManager
|
||||
from ee.hogai.graph.root.tools.read_billing_tool.tool import ReadBillingTool
|
||||
from ee.hogai.utils.types import AssistantState
|
||||
|
||||
|
||||
class TestBillingNode(ClickhouseTestMixin, BaseTest):
|
||||
class TestBillingNode(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
CLASS_DATA_LEVEL_SETUP = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.node = BillingNode(self.team, self.user)
|
||||
self.tool_call_id = str(uuid4())
|
||||
self.state = AssistantState(messages=[], root_tool_call_id=self.tool_call_id)
|
||||
self.tool = ReadBillingTool(
|
||||
team=self.team,
|
||||
user=self.user,
|
||||
state=AssistantState(messages=[], root_tool_call_id=str(uuid4())),
|
||||
config=RunnableConfig(configurable={}),
|
||||
context_manager=AssistantContextManager(self.team, self.user, {}),
|
||||
)
|
||||
|
||||
def test_run_with_no_billing_context(self):
|
||||
with patch.object(self.node.context_manager, "get_billing_context", return_value=None):
|
||||
result = self.node.run(self.state, {})
|
||||
self.assertEqual(len(result.messages), 1)
|
||||
message = result.messages[0]
|
||||
self.assertIsInstance(message, AssistantToolCallMessage)
|
||||
self.assertEqual(cast(AssistantToolCallMessage, message).content, "No billing information available")
|
||||
async def test_run_with_no_billing_context(self):
|
||||
with patch.object(self.tool._context_manager, "get_billing_context", return_value=None):
|
||||
result = await self.tool.execute()
|
||||
self.assertEqual(result, "No billing information available")
|
||||
|
||||
def test_run_with_billing_context(self):
|
||||
async def test_run_with_billing_context(self):
|
||||
billing_context = MaxBillingContext(
|
||||
subscription_level=MaxBillingContextSubscriptionLevel.PAID,
|
||||
billing_plan="paid",
|
||||
@@ -50,17 +55,13 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
|
||||
products=[],
|
||||
)
|
||||
with (
|
||||
patch.object(self.node.context_manager, "get_billing_context", return_value=billing_context),
|
||||
patch.object(self.node, "_format_billing_context", return_value="Formatted Context"),
|
||||
patch.object(self.tool._context_manager, "get_billing_context", return_value=billing_context),
|
||||
patch.object(self.tool, "_format_billing_context", return_value="Formatted Context"),
|
||||
):
|
||||
result = self.node.run(self.state, {})
|
||||
self.assertEqual(len(result.messages), 1)
|
||||
message = result.messages[0]
|
||||
self.assertIsInstance(message, AssistantToolCallMessage)
|
||||
self.assertEqual(cast(AssistantToolCallMessage, message).content, "Formatted Context")
|
||||
self.assertEqual(cast(AssistantToolCallMessage, message).tool_call_id, self.tool_call_id)
|
||||
result = await self.tool.execute()
|
||||
self.assertEqual(result, "Formatted Context")
|
||||
|
||||
def test_format_billing_context(self):
|
||||
async def test_format_billing_context(self):
|
||||
billing_context = MaxBillingContext(
|
||||
subscription_level=MaxBillingContextSubscriptionLevel.PAID,
|
||||
billing_plan="paid",
|
||||
@@ -89,8 +90,8 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
|
||||
settings=MaxBillingContextSettings(autocapture_on=True, active_destinations=2),
|
||||
)
|
||||
|
||||
with patch.object(self.node, "_get_top_events_by_usage", return_value=[]):
|
||||
formatted_string = self.node._format_billing_context(billing_context)
|
||||
with patch.object(self.tool, "_get_top_events_by_usage", return_value=[]):
|
||||
formatted_string = await self.tool._format_billing_context(billing_context)
|
||||
self.assertIn("(paid)", formatted_string)
|
||||
self.assertIn("Period: 2023-01-01 to 2023-01-31", formatted_string)
|
||||
|
||||
@@ -123,19 +124,21 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
|
||||
),
|
||||
]
|
||||
|
||||
usage_table = self.node._format_history_table(usage_history)
|
||||
usage_table = self.tool._format_history_table(usage_history)
|
||||
self.assertIn("### Overall (all projects)", usage_table)
|
||||
self.assertIn("| Data Type | 2023-01-01 | 2023-01-02 |", usage_table)
|
||||
self.assertIn("| Recordings | 100.00 | 200.00 |", usage_table)
|
||||
self.assertIn("| Events | 1.50 | 2.50 |", usage_table)
|
||||
|
||||
spend_table = self.node._format_history_table(spend_history)
|
||||
spend_table = self.tool._format_history_table(spend_history)
|
||||
self.assertIn("| Mobile Recordings | 10.50 | 20.00 |", spend_table)
|
||||
|
||||
def test_get_top_events_by_usage(self):
|
||||
mock_results = [("pageview", 1000), ("$autocapture", 500)]
|
||||
with patch("ee.hogai.graph.billing.nodes.sync_execute", return_value=mock_results) as mock_sync_execute:
|
||||
top_events = self.node._get_top_events_by_usage()
|
||||
with patch(
|
||||
"ee.hogai.graph.root.tools.read_billing_tool.tool.sync_execute", return_value=mock_results
|
||||
) as mock_sync_execute:
|
||||
top_events = self.tool._get_top_events_by_usage()
|
||||
self.assertEqual(len(top_events), 2)
|
||||
self.assertEqual(top_events[0]["event"], "pageview")
|
||||
self.assertEqual(top_events[0]["count"], 1000)
|
||||
@@ -150,13 +153,14 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
|
||||
|
||||
def test_get_top_events_by_usage_query_fails(self):
|
||||
with patch(
|
||||
"ee.hogai.graph.billing.nodes.sync_execute", side_effect=Exception("DB connection failed")
|
||||
"ee.hogai.graph.root.tools.read_billing_tool.tool.sync_execute",
|
||||
side_effect=Exception("DB connection failed"),
|
||||
) as mock_sync_execute:
|
||||
top_events = self.node._get_top_events_by_usage()
|
||||
top_events = self.tool._get_top_events_by_usage()
|
||||
self.assertEqual(top_events, [])
|
||||
mock_sync_execute.assert_called_once()
|
||||
|
||||
def test_format_billing_context_with_addons(self):
|
||||
async def test_format_billing_context_with_addons(self):
|
||||
"""Test that addons are properly nested within products in the formatted output"""
|
||||
billing_context = MaxBillingContext(
|
||||
subscription_level=MaxBillingContextSubscriptionLevel.PAID,
|
||||
@@ -230,14 +234,14 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
self.node,
|
||||
self.tool,
|
||||
"_get_top_events_by_usage",
|
||||
return_value=[
|
||||
{"event": "$pageview", "count": 50000, "formatted_count": "50,000"},
|
||||
{"event": "$autocapture", "count": 30000, "formatted_count": "30,000"},
|
||||
],
|
||||
):
|
||||
formatted_string = self.node._format_billing_context(billing_context)
|
||||
formatted_string = await self.tool._format_billing_context(billing_context)
|
||||
|
||||
# Check basic info
|
||||
self.assertIn("paid subscription (startup)", formatted_string)
|
||||
@@ -274,7 +278,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
|
||||
self.assertIn("$pageview", formatted_string)
|
||||
self.assertIn("50,000 events", formatted_string)
|
||||
|
||||
def test_format_billing_context_no_subscription(self):
|
||||
async def test_format_billing_context_no_subscription(self):
|
||||
"""Test formatting when user has no active subscription (free plan)"""
|
||||
billing_context = MaxBillingContext(
|
||||
subscription_level=MaxBillingContextSubscriptionLevel.FREE,
|
||||
@@ -286,8 +290,8 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
|
||||
trial=MaxBillingContextTrial(is_active=True, expires_at="2023-02-01", target="teams"),
|
||||
)
|
||||
|
||||
with patch.object(self.node, "_get_top_events_by_usage", return_value=[]):
|
||||
formatted_string = self.node._format_billing_context(billing_context)
|
||||
with patch.object(self.tool, "_get_top_events_by_usage", return_value=[]):
|
||||
formatted_string = await self.tool._format_billing_context(billing_context)
|
||||
|
||||
self.assertIn("free subscription", formatted_string)
|
||||
self.assertIn("Active subscription: No (Free plan)", formatted_string)
|
||||
@@ -297,7 +301,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
|
||||
def test_format_history_table_with_team_breakdown(self):
|
||||
"""Test that history tables properly group by team when breakdown includes team IDs"""
|
||||
# Mock the teams map
|
||||
self.node._teams_map = {
|
||||
self.tool._teams_map = {
|
||||
1: "Team Alpha (ID: 1)",
|
||||
2: "Team Beta (ID: 2)",
|
||||
}
|
||||
@@ -337,7 +341,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
|
||||
),
|
||||
]
|
||||
|
||||
table = self.node._format_history_table(usage_history)
|
||||
table = self.tool._format_history_table(usage_history)
|
||||
|
||||
# Check team-specific tables
|
||||
self.assertIn("### Team Alpha (ID: 1)", table)
|
||||
@@ -349,7 +353,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
|
||||
self.assertIn("| Recordings | 100.00 | 200.00 |", table)
|
||||
self.assertIn("| Feature Flag Requests | 50.00 | 100.00 |", table)
|
||||
|
||||
def test_format_billing_context_edge_cases(self):
|
||||
async def test_format_billing_context_edge_cases(self):
|
||||
"""Test edge cases and potential security issues"""
|
||||
billing_context = MaxBillingContext(
|
||||
subscription_level=MaxBillingContextSubscriptionLevel.CUSTOM,
|
||||
@@ -374,8 +378,8 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
|
||||
settings=MaxBillingContextSettings(autocapture_on=True, active_destinations=0),
|
||||
)
|
||||
|
||||
with patch.object(self.node, "_get_top_events_by_usage", return_value=[]):
|
||||
formatted_string = self.node._format_billing_context(billing_context)
|
||||
with patch.object(self.tool, "_get_top_events_by_usage", return_value=[]):
|
||||
formatted_string = await self.tool._format_billing_context(billing_context)
|
||||
|
||||
# Check deactivated status
|
||||
self.assertIn("Status: Account is deactivated", formatted_string)
|
||||
@@ -393,7 +397,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
|
||||
# Check exceeded limit warning
|
||||
self.assertIn("⚠️ Usage limit exceeded", formatted_string)
|
||||
|
||||
def test_format_billing_context_complete_template_coverage(self):
|
||||
async def test_format_billing_context_complete_template_coverage(self):
|
||||
"""Test all possible template variables are covered"""
|
||||
billing_context = MaxBillingContext(
|
||||
subscription_level=MaxBillingContextSubscriptionLevel.PAID,
|
||||
@@ -472,14 +476,14 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
|
||||
)
|
||||
|
||||
# Mock teams map for history table
|
||||
self.node._teams_map = {1: "Main Team (ID: 1)"}
|
||||
self.tool._teams_map = {1: "Main Team (ID: 1)"}
|
||||
|
||||
with patch.object(
|
||||
self.node,
|
||||
self.tool,
|
||||
"_get_top_events_by_usage",
|
||||
return_value=[{"event": "$identify", "count": 10000, "formatted_count": "10,000"}],
|
||||
):
|
||||
formatted_string = self.node._format_billing_context(billing_context)
|
||||
formatted_string = await self.tool._format_billing_context(billing_context)
|
||||
|
||||
# Verify all template sections are present
|
||||
self.assertIn("<billing_context>", formatted_string)
|
||||
@@ -545,12 +549,12 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
|
||||
),
|
||||
]
|
||||
|
||||
self.node._teams_map = {
|
||||
self.tool._teams_map = {
|
||||
84444: "Project 84444",
|
||||
12345: "Project 12345",
|
||||
}
|
||||
|
||||
table = self.node._format_history_table(usage_history)
|
||||
table = self.tool._format_history_table(usage_history)
|
||||
|
||||
# Should always include aggregated table first
|
||||
self.assertIn("### Overall (all projects)", table)
|
||||
@@ -597,7 +601,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
|
||||
),
|
||||
]
|
||||
|
||||
table = self.node._format_history_table(usage_history)
|
||||
table = self.tool._format_history_table(usage_history)
|
||||
|
||||
# Should include aggregated table
|
||||
self.assertIn("### Overall (all projects)", table)
|
||||
@@ -645,7 +649,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
|
||||
)
|
||||
]
|
||||
|
||||
aggregated = self.node._create_aggregated_items(team_items, other_items)
|
||||
aggregated = self.tool._create_aggregated_items(team_items, other_items)
|
||||
|
||||
# Should have 2 aggregated items: events and feature flags
|
||||
self.assertEqual(len(aggregated), 2)
|
||||
@@ -687,7 +691,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
|
||||
),
|
||||
]
|
||||
|
||||
table = self.node._format_history_table(usage_history)
|
||||
table = self.tool._format_history_table(usage_history)
|
||||
|
||||
# Should only have aggregated table, no team-specific tables
|
||||
self.assertIn("### Overall (all projects)", table)
|
||||
@@ -721,18 +725,18 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
|
||||
)
|
||||
]
|
||||
|
||||
self.node._teams_map = {
|
||||
self.tool._teams_map = {
|
||||
123: "Project 123",
|
||||
}
|
||||
|
||||
# Test usage history formatting
|
||||
usage_table = self.node._format_history_table(usage_history)
|
||||
usage_table = self.tool._format_history_table(usage_history)
|
||||
self.assertIn("### Overall (all projects)", usage_table)
|
||||
self.assertIn("### Project 123", usage_table)
|
||||
self.assertIn("| Events | 1,000.00 |", usage_table)
|
||||
|
||||
# Test spend history formatting
|
||||
spend_table = self.node._format_history_table(spend_history)
|
||||
spend_table = self.tool._format_history_table(spend_history)
|
||||
self.assertIn("### Overall (all projects)", spend_table)
|
||||
self.assertIn("### Project 123", spend_table)
|
||||
self.assertIn("| Events | 50.00 |", spend_table)
|
||||
@@ -758,7 +762,7 @@ class TestBillingNode(ClickhouseTestMixin, BaseTest):
|
||||
),
|
||||
]
|
||||
|
||||
table = self.node._format_history_table(usage_history)
|
||||
table = self.tool._format_history_table(usage_history)
|
||||
|
||||
# Should handle empty dates gracefully and still show valid data
|
||||
self.assertIn("### Overall (all projects)", table)
|
||||
@@ -1,18 +1,19 @@
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from posthog.schema import AssistantToolCallMessage, MaxBillingContext, SpendHistoryItem, UsageHistoryItem
|
||||
from posthog.schema import MaxBillingContext, SpendHistoryItem, UsageHistoryItem
|
||||
|
||||
from posthog.clickhouse.client import sync_execute
|
||||
from posthog.models import Team, User
|
||||
from posthog.sync import database_sync_to_async
|
||||
|
||||
from ee.hogai.graph.base import AssistantNode
|
||||
from ee.hogai.graph.billing.prompts import BILLING_CONTEXT_PROMPT
|
||||
from ee.hogai.context.context import AssistantContextManager
|
||||
from ee.hogai.tool import MaxSubtool
|
||||
from ee.hogai.utils.types import AssistantState
|
||||
from ee.hogai.utils.types.base import AssistantNodeName, PartialAssistantState
|
||||
from ee.hogai.utils.types.composed import MaxNodeName
|
||||
|
||||
from .prompts import BILLING_CONTEXT_PROMPT
|
||||
|
||||
# sync with frontend/src/scenes/billing/constants.ts
|
||||
USAGE_TYPES = [
|
||||
@@ -33,33 +34,27 @@ USAGE_TYPES = [
|
||||
]
|
||||
|
||||
|
||||
class BillingNode(AssistantNode):
|
||||
_teams_map: dict[int, str] = {}
|
||||
class ReadBillingTool(MaxSubtool):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
team: Team,
|
||||
user: User,
|
||||
state: AssistantState,
|
||||
config: RunnableConfig,
|
||||
context_manager: AssistantContextManager,
|
||||
):
|
||||
super().__init__(team=team, user=user, state=state, config=config, context_manager=context_manager)
|
||||
self._teams_map: dict[int, str] = {}
|
||||
|
||||
@property
|
||||
def node_name(self) -> MaxNodeName:
|
||||
return AssistantNodeName.BILLING
|
||||
|
||||
def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
|
||||
tool_call_id = cast(str, state.root_tool_call_id)
|
||||
billing_context = self.context_manager.get_billing_context()
|
||||
async def execute(self) -> str:
|
||||
billing_context = self._context_manager.get_billing_context()
|
||||
if not billing_context:
|
||||
return PartialAssistantState(
|
||||
messages=[
|
||||
AssistantToolCallMessage(
|
||||
content="No billing information available", id=str(uuid4()), tool_call_id=tool_call_id
|
||||
)
|
||||
],
|
||||
root_tool_call_id=None,
|
||||
)
|
||||
formatted_billing_context = self._format_billing_context(billing_context)
|
||||
return PartialAssistantState(
|
||||
messages=[
|
||||
AssistantToolCallMessage(content=formatted_billing_context, tool_call_id=tool_call_id, id=str(uuid4())),
|
||||
],
|
||||
root_tool_call_id=None,
|
||||
)
|
||||
return "No billing information available"
|
||||
formatted_billing_context = await self._format_billing_context(billing_context)
|
||||
return formatted_billing_context
|
||||
|
||||
@database_sync_to_async(thread_sensitive=False)
|
||||
def _format_billing_context(self, billing_context: MaxBillingContext) -> str:
|
||||
"""Format billing context into a readable prompt section."""
|
||||
# Convert billing context to a format suitable for the mustache template
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Literal, Self
|
||||
from typing import Literal, Self
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from pydantic import BaseModel
|
||||
@@ -9,12 +9,14 @@ from ee.hogai.context.context import AssistantContextManager
|
||||
from ee.hogai.graph.sql.mixins import HogQLDatabaseMixin
|
||||
from ee.hogai.tool import MaxTool
|
||||
from ee.hogai.utils.prompt import format_prompt_string
|
||||
from ee.hogai.utils.types.base import AssistantState
|
||||
from ee.hogai.utils.types.base import AssistantState, NodePath
|
||||
|
||||
from .read_billing_tool.tool import ReadBillingTool
|
||||
|
||||
READ_DATA_BILLING_PROMPT = """
|
||||
# Billing information
|
||||
|
||||
Use this tool with the "billing_info" kind to retrieve the billing information if the user asks about billing, their subscription, their usage, or their spending.
|
||||
Use this tool with the "billing_info" kind to retrieve the billing information if the user asks about their billing, subscription, product usage, spending, or cost reduction strategies.
|
||||
You can use the information retrieved to check which PostHog products and add-ons the user has activated, how much they are spending, their usage history across all products in the last 30 days, as well as trials, spending limits, billing period, and more.
|
||||
If the user wants to reduce their spending, always call this tool to get suggestions on how to do so.
|
||||
If an insight shows zero data, it could mean either the query is looking at the wrong data or there was a temporary data collection issue. You can investigate potential dips in usage/captured data using the billing tool.
|
||||
@@ -62,7 +64,7 @@ class ReadDataTool(HogQLDatabaseMixin, MaxTool):
|
||||
*,
|
||||
team: Team,
|
||||
user: User,
|
||||
tool_call_id: str,
|
||||
node_path: tuple[NodePath, ...] | None = None,
|
||||
state: AssistantState | None = None,
|
||||
config: RunnableConfig | None = None,
|
||||
context_manager: AssistantContextManager | None = None,
|
||||
@@ -83,21 +85,29 @@ class ReadDataTool(HogQLDatabaseMixin, MaxTool):
|
||||
return cls(
|
||||
team=team,
|
||||
user=user,
|
||||
tool_call_id=tool_call_id,
|
||||
state=state,
|
||||
node_path=node_path,
|
||||
config=config,
|
||||
args_schema=args,
|
||||
description=description,
|
||||
context_manager=context_manager,
|
||||
)
|
||||
|
||||
async def _arun_impl(self, kind: ReadDataAdminAccessKind | ReadDataKind) -> tuple[str, dict[str, Any] | None]:
|
||||
async def _arun_impl(self, kind: ReadDataAdminAccessKind | ReadDataKind) -> tuple[str, None]:
|
||||
match kind:
|
||||
case "billing_info":
|
||||
has_access = await self._context_manager.check_user_has_billing_access()
|
||||
if not has_access:
|
||||
return BILLING_INSUFFICIENT_ACCESS_PROMPT, None
|
||||
# used for routing
|
||||
return "", self.args_schema(kind=kind).model_dump()
|
||||
billing_tool = ReadBillingTool(
|
||||
team=self._team,
|
||||
user=self._user,
|
||||
state=self._state,
|
||||
config=self._config,
|
||||
context_manager=self._context_manager,
|
||||
)
|
||||
result = await billing_tool.execute()
|
||||
return result, None
|
||||
case "datawarehouse_schema":
|
||||
return await self._serialize_database_schema(), None
|
||||
|
||||
@@ -9,7 +9,7 @@ from ee.hogai.context.context import AssistantContextManager
|
||||
from ee.hogai.graph.query_planner.toolkit import TaxonomyAgentToolkit
|
||||
from ee.hogai.tool import MaxTool
|
||||
from ee.hogai.utils.helpers import format_events_yaml
|
||||
from ee.hogai.utils.types.base import AssistantState
|
||||
from ee.hogai.utils.types.base import AssistantState, NodePath
|
||||
|
||||
READ_TAXONOMY_TOOL_DESCRIPTION = """
|
||||
Use this tool to explore the user's taxonomy (i.e. data schema).
|
||||
@@ -155,7 +155,7 @@ class ReadTaxonomyTool(MaxTool):
|
||||
*,
|
||||
team: Team,
|
||||
user: User,
|
||||
tool_call_id: str,
|
||||
node_path: tuple[NodePath, ...] | None = None,
|
||||
state: AssistantState | None = None,
|
||||
config: RunnableConfig | None = None,
|
||||
context_manager: AssistantContextManager | None = None,
|
||||
@@ -200,9 +200,9 @@ class ReadTaxonomyTool(MaxTool):
|
||||
return cls(
|
||||
team=team,
|
||||
user=user,
|
||||
tool_call_id=tool_call_id,
|
||||
state=state,
|
||||
config=config,
|
||||
node_path=node_path,
|
||||
args_schema=ReadTaxonomyToolArgsWithGroups,
|
||||
context_manager=context_manager,
|
||||
)
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
from typing import Any, Literal
|
||||
from typing import Literal
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
import posthoganalytics
|
||||
from langchain_core.output_parsers import SimpleJsonOutputParser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from posthog.models import Team, User
|
||||
|
||||
from ee.hogai.graph.root.tools.full_text_search.tool import EntitySearchToolkit, FTSKind
|
||||
from ee.hogai.tool import MaxTool
|
||||
from ee.hogai.graph.insights.nodes import InsightSearchNode, NoInsightsException
|
||||
from ee.hogai.graph.root.tools.full_text_search.tool import EntitySearchTool, FTSKind
|
||||
from ee.hogai.tool import MaxSubtool, MaxTool, ToolMessagesArtifact
|
||||
from ee.hogai.utils.prompt import format_prompt_string
|
||||
from ee.hogai.utils.types.base import AssistantState, PartialAssistantState
|
||||
|
||||
SEARCH_TOOL_PROMPT = """
|
||||
Use this tool to search docs, insights, dashboards, cohorts, actions, experiments, feature flags, notebooks, error tracking issues, and surveys in PostHog.
|
||||
@@ -57,35 +59,10 @@ If you want to search for all entities, you should use `all`.
|
||||
|
||||
""".strip()
|
||||
|
||||
DOCS_SEARCH_RESULTS_TEMPLATE = """Found {count} relevant documentation page(s):
|
||||
|
||||
{docs}
|
||||
<system_reminder>
|
||||
Use retrieved documentation to answer the user's question if it is relevant to the user's query.
|
||||
Format the response using Markdown and reference the documentation using hyperlinks.
|
||||
Every link to docs must be clearly and explicitly labeled, for example as "(see docs)".
|
||||
</system_reminder>
|
||||
INVALID_ENTITY_KIND_PROMPT = """
|
||||
Invalid entity kind: {{{kind}}}. Please provide a valid entity kind for the tool.
|
||||
""".strip()
|
||||
|
||||
DOCS_SEARCH_NO_RESULTS_TEMPLATE = """
|
||||
No documentation found.
|
||||
|
||||
<system_reminder>
|
||||
Do not answer the user's question if you did not find any documentation. Try rewriting the query.
|
||||
If after a couple of attempts you still do not find any documentation, suggest the user navigate to the documentation page, which is available at `https://posthog.com/docs`.
|
||||
</system_reminder>
|
||||
""".strip()
|
||||
|
||||
DOC_ITEM_TEMPLATE = """
|
||||
# {title}
|
||||
URL: {url}
|
||||
|
||||
{text}
|
||||
""".strip()
|
||||
|
||||
|
||||
FTS_SEARCH_FEATURE_FLAG = "hogai-insights-fts-search"
|
||||
|
||||
ENTITIES = [f"{entity}" for entity in FTSKind if entity != FTSKind.INSIGHTS]
|
||||
|
||||
SearchKind = Literal["insights", "docs", *ENTITIES] # type: ignore
|
||||
@@ -126,50 +103,106 @@ class SearchTool(MaxTool):
|
||||
context_prompt_template: str = "Searches documentation, insights, dashboards, cohorts, actions, experiments, feature flags, notebooks, error tracking issues, and surveys in PostHog"
|
||||
args_schema: type[BaseModel] = SearchToolArgs
|
||||
|
||||
@staticmethod
|
||||
def _get_fts_entities(include_insight_fts: bool) -> list[str]:
|
||||
if not include_insight_fts:
|
||||
entities = [e for e in FTSKind if e != FTSKind.INSIGHTS]
|
||||
else:
|
||||
entities = list(FTSKind)
|
||||
return [*entities, FTSKind.ALL]
|
||||
|
||||
async def _arun_impl(self, kind: SearchKind, query: str) -> tuple[str, dict[str, Any] | None]:
|
||||
async def _arun_impl(self, kind: str, query: str) -> tuple[str, ToolMessagesArtifact | None]:
|
||||
if kind == "docs":
|
||||
if not settings.INKEEP_API_KEY:
|
||||
return "This tool is not available in this environment.", None
|
||||
if self._has_docs_search_feature_flag():
|
||||
return await self._search_docs(query), None
|
||||
docs_tool = InkeepDocsSearchTool(
|
||||
team=self._team,
|
||||
user=self._user,
|
||||
state=self._state,
|
||||
config=self._config,
|
||||
context_manager=self._context_manager,
|
||||
)
|
||||
return await docs_tool.execute(query, self.tool_call_id)
|
||||
|
||||
fts_entities = SearchTool._get_fts_entities(SearchTool._has_fts_search_feature_flag(self._user, self._team))
|
||||
if kind == "insights" and not self._has_insights_fts_search_feature_flag():
|
||||
insights_tool = InsightSearchTool(
|
||||
team=self._team,
|
||||
user=self._user,
|
||||
state=self._state,
|
||||
config=self._config,
|
||||
context_manager=self._context_manager,
|
||||
)
|
||||
return await insights_tool.execute(query, self.tool_call_id)
|
||||
|
||||
if kind in fts_entities:
|
||||
entity_search_toolkit = EntitySearchToolkit(self._team, self._user)
|
||||
response = await entity_search_toolkit.execute(query, FTSKind(kind))
|
||||
return response, None
|
||||
# Used for routing
|
||||
return "Search tool executed", SearchToolArgs(kind=kind, query=query).model_dump()
|
||||
if kind not in self._fts_entities:
|
||||
return format_prompt_string(INVALID_ENTITY_KIND_PROMPT, kind=kind), None
|
||||
|
||||
def _has_docs_search_feature_flag(self) -> bool:
|
||||
entity_search_toolkit = EntitySearchTool(
|
||||
team=self._team,
|
||||
user=self._user,
|
||||
state=self._state,
|
||||
config=self._config,
|
||||
context_manager=self._context_manager,
|
||||
)
|
||||
response = await entity_search_toolkit.execute(query, FTSKind(kind))
|
||||
return response, None
|
||||
|
||||
@property
|
||||
def _fts_entities(self) -> list[str]:
|
||||
entities = list(FTSKind)
|
||||
return [*entities, FTSKind.ALL]
|
||||
|
||||
def _has_insights_fts_search_feature_flag(self) -> bool:
|
||||
return posthoganalytics.feature_enabled(
|
||||
"max-inkeep-rag-docs-search",
|
||||
"hogai-insights-fts-search",
|
||||
str(self._user.distinct_id),
|
||||
groups={"organization": str(self._team.organization_id)},
|
||||
group_properties={"organization": {"id": str(self._team.organization_id)}},
|
||||
send_feature_flag_events=False,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _has_fts_search_feature_flag(user: User, team: Team) -> bool:
|
||||
return posthoganalytics.feature_enabled(
|
||||
FTS_SEARCH_FEATURE_FLAG,
|
||||
str(user.distinct_id),
|
||||
groups={"organization": str(team.organization_id)},
|
||||
group_properties={"organization": {"id": str(team.organization_id)}},
|
||||
send_feature_flag_events=False,
|
||||
)
|
||||
|
||||
async def _search_docs(self, query: str) -> str:
|
||||
DOCS_SEARCH_RESULTS_TEMPLATE = """Found {count} relevant documentation page(s):
|
||||
|
||||
{docs}
|
||||
<system_reminder>
|
||||
Use retrieved documentation to answer the user's question if it is relevant to the user's query.
|
||||
Format the response using Markdown and reference the documentation using hyperlinks.
|
||||
Every link to docs must be clearly and explicitly labeled, for example as "(see docs)".
|
||||
</system_reminder>
|
||||
""".strip()
|
||||
|
||||
DOCS_SEARCH_NO_RESULTS_TEMPLATE = """
|
||||
No documentation found.
|
||||
|
||||
<system_reminder>
|
||||
Do not answer the user's question if you did not find any documentation. Try rewriting the query.
|
||||
If after a couple of attempts you still do not find any documentation, suggest the user navigate to the documentation page, which is available at `https://posthog.com/docs`.
|
||||
</system_reminder>
|
||||
""".strip()
|
||||
|
||||
DOC_ITEM_TEMPLATE = """
|
||||
# {title}
|
||||
URL: {url}
|
||||
|
||||
{text}
|
||||
""".strip()
|
||||
|
||||
|
||||
class InkeepDocsSearchTool(MaxSubtool):
|
||||
async def execute(self, query: str, tool_call_id: str) -> tuple[str, ToolMessagesArtifact | None]:
|
||||
if self._has_rag_docs_search_feature_flag():
|
||||
return await self._search_using_rag_endpoint(query, tool_call_id)
|
||||
else:
|
||||
return await self._search_using_node(query, tool_call_id)
|
||||
|
||||
async def _search_using_node(self, query: str, tool_call_id: str) -> tuple[str, ToolMessagesArtifact | None]:
|
||||
# Avoid circular import
|
||||
from ee.hogai.graph.inkeep_docs.nodes import InkeepDocsNode
|
||||
|
||||
# Init the graph
|
||||
node = InkeepDocsNode(self._team, self._user)
|
||||
chain: RunnableLambda[AssistantState, PartialAssistantState | None] = RunnableLambda(node)
|
||||
copied_state = self._state.model_copy(deep=True, update={"root_tool_call_id": tool_call_id})
|
||||
result = await chain.ainvoke(copied_state)
|
||||
assert result is not None
|
||||
return "", ToolMessagesArtifact(messages=result.messages)
|
||||
|
||||
async def _search_using_rag_endpoint(
|
||||
self, query: str, tool_call_id: str
|
||||
) -> tuple[str, ToolMessagesArtifact | None]:
|
||||
model = ChatOpenAI(
|
||||
model="inkeep-rag",
|
||||
base_url="https://api.inkeep.com/v1/",
|
||||
@@ -184,7 +217,7 @@ class SearchTool(MaxTool):
|
||||
rag_context_raw = await chain.ainvoke({"query": query})
|
||||
|
||||
if not rag_context_raw or not rag_context_raw.get("content"):
|
||||
return DOCS_SEARCH_NO_RESULTS_TEMPLATE
|
||||
return DOCS_SEARCH_NO_RESULTS_TEMPLATE, None
|
||||
|
||||
rag_context = InkeepResponse.model_validate(rag_context_raw)
|
||||
|
||||
@@ -197,7 +230,35 @@ class SearchTool(MaxTool):
|
||||
docs.append(DOC_ITEM_TEMPLATE.format(title=doc.title, url=doc.url, text=text))
|
||||
|
||||
if not docs:
|
||||
return DOCS_SEARCH_NO_RESULTS_TEMPLATE
|
||||
return DOCS_SEARCH_NO_RESULTS_TEMPLATE, None
|
||||
|
||||
formatted_docs = "\n\n---\n\n".join(docs)
|
||||
return DOCS_SEARCH_RESULTS_TEMPLATE.format(count=len(docs), docs=formatted_docs)
|
||||
return DOCS_SEARCH_RESULTS_TEMPLATE.format(count=len(docs), docs=formatted_docs), None
|
||||
|
||||
def _has_rag_docs_search_feature_flag(self) -> bool:
|
||||
return posthoganalytics.feature_enabled(
|
||||
"max-inkeep-rag-docs-search",
|
||||
str(self._user.distinct_id),
|
||||
groups={"organization": str(self._team.organization_id)},
|
||||
group_properties={"organization": {"id": str(self._team.organization_id)}},
|
||||
send_feature_flag_events=False,
|
||||
)
|
||||
|
||||
|
||||
EMPTY_DATABASE_ERROR_MESSAGE = """
|
||||
The user doesn't have any insights created yet.
|
||||
""".strip()
|
||||
|
||||
|
||||
class InsightSearchTool(MaxSubtool):
|
||||
async def execute(self, query: str, tool_call_id: str) -> tuple[str, ToolMessagesArtifact | None]:
|
||||
try:
|
||||
node = InsightSearchNode(self._team, self._user)
|
||||
copied_state = self._state.model_copy(
|
||||
deep=True, update={"search_insights_query": query, "root_tool_call_id": tool_call_id}
|
||||
)
|
||||
chain: RunnableLambda[AssistantState, PartialAssistantState | None] = RunnableLambda(node)
|
||||
result = await chain.ainvoke(copied_state)
|
||||
return "", ToolMessagesArtifact(messages=result.messages) if result else None
|
||||
except NoInsightsException:
|
||||
return EMPTY_DATABASE_ERROR_MESSAGE, None
|
||||
|
||||
new file: ee/hogai/graph/root/tools/session_summarization.py (122 lines)
@@ -0,0 +1,122 @@
|
||||
from typing import Literal
|
||||
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ee.hogai.graph.session_summaries.nodes import SessionSummarizationNode
|
||||
from ee.hogai.tool import MaxTool, ToolMessagesArtifact
|
||||
from ee.hogai.utils.types.base import AssistantState, PartialAssistantState
|
||||
|
||||
SESSION_SUMMARIZATION_TOOL_PROMPT = """
|
||||
Use this tool to summarize session recordings by analysing the events within those sessions to find patterns and issues.
|
||||
It will return a textual summary of the captured session recordings.
|
||||
|
||||
# When to use the tool:
|
||||
When the user asks to summarize session recordings:
|
||||
- "summarize" synonyms: "watch", "analyze", "review", and similar
|
||||
- "session recordings" synonyms: "sessions", "recordings", "replays", "user sessions", and similar
|
||||
|
||||
# When NOT to use the tool:
|
||||
- When the user asks to find, search for, or look up session recordings, but doesn't ask to summarize them
|
||||
- When the user asks to update, change, or adjust session recordings filters
|
||||
|
||||
# Synonyms
|
||||
- "summarize": "watch", "analyze", "review", and similar
|
||||
- "session recordings": "sessions", "recordings", "replays", "user sessions", and similar
|
||||
|
||||
# Managing context
|
||||
If the conversation history contains context about the current filters or session recordings, follow these steps:
|
||||
- Convert the user query into a `session_summarization_query`
|
||||
- The query should be used to understand the user's intent
|
||||
- Decide if the query is relevant to the current filters and set `should_use_current_filters` accordingly
|
||||
- Generate the `summary_title` based on the user's query and the current filters
|
||||
|
||||
Otherwise:
|
||||
- Convert the user query into a `session_summarization_query`
|
||||
- The query should be used to search for relevant sessions and then summarize them
|
||||
- Assume `should_use_current_filters` should always be `false`
|
||||
- Generate the `summary_title` based on the user's query
|
||||
|
||||
# Additional guidelines
|
||||
- CRITICAL: Always pass the user's complete, unmodified query to the `session_summarization_query` parameter
|
||||
- DO NOT truncate, summarize, or extract keywords from the user's query
|
||||
- The query is used to find relevant sessions - context helps find better matches
|
||||
- Use the explicit tool definition to make a decision
|
||||
""".strip()
|
||||
|
||||
|
||||
class SessionSummarizationToolArgs(BaseModel):
|
||||
session_summarization_query: str = Field(
|
||||
description="""
|
||||
- The user's complete query for session recordings summarization.
|
||||
- This will be used to find relevant session recordings.
|
||||
- Always pass the user's complete, unmodified query.
|
||||
- Examples:
|
||||
* 'summarize all session recordings from yesterday'
|
||||
* 'analyze mobile user session recordings from last week, even if 1 second'
|
||||
* 'watch last 300 session recordings of MacOS users from US'
|
||||
* and similar
|
||||
"""
|
||||
)
|
||||
should_use_current_filters: bool = Field(
|
||||
description="""
|
||||
- Whether to use current filters from user's UI to find relevant session recordings.
|
||||
- IMPORTANT: Should always be `false` if the current filters or the `search_session_recordings` tool are not present in the conversation history.
|
||||
- Examples:
|
||||
* Set to `true` if one of the conditions is met:
|
||||
- the user wants to summarize "current/selected/opened/my/all/these" session recordings
|
||||
- the user wants to use "current/these" filters
|
||||
- the user's query specifies filters identical to the current filters
|
||||
- the user's query doesn't specify any filters/conditions
|
||||
- the user refers to what they're "looking at" or "viewing"
|
||||
* Set to `false` if one of the conditions is met:
|
||||
- no current filters or `search_session_recordings` tool are present in the conversation
|
||||
- the user specifies date/time period different from the current filters
|
||||
- the user specifies conditions (user, device, id, URL, etc.) not present in the current filters
|
||||
""",
|
||||
)
|
||||
summary_title: str = Field(
|
||||
description="""
|
||||
- The name of the summary that is expected to be generated from the user's `session_summarization_query` and/or `current_filters` (if present).
|
||||
- The name should cover in 3-7 words what sessions are to be summarized
|
||||
- This won't be used for any search or filtering, only to properly label the generated summary.
|
||||
- Examples:
|
||||
* If `should_use_current_filters` is `false`, then the `summary_title` should be generated based on the `session_summarization_query`:
|
||||
- query: "I want to watch all the sessions of user `user@example.com` in the last 30 days no matter how long" -> name: "Sessions of the user user@example.com (last 30 days)"
|
||||
- query: "summarize my last 100 session recordings" -> name: "Last 100 sessions"
|
||||
- and similar
|
||||
* If `should_use_current_filters` is `true`, then the `summary_title` should be generated based on the current filters in the context (if present):
|
||||
- filters: "{"key":"$os","value":["Mac OS X"],"operator":"exact","type":"event"}" -> name: "MacOS users"
|
||||
- filters: "{"date_from": "-7d", "filter_test_accounts": True}" -> name: "All sessions (last 7 days)"
|
||||
- and similar
|
||||
* If there's not enough context to generate the summary name - keep it an empty string ("")
|
||||
"""
|
||||
)
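As a quick sanity check of the field descriptions above, a minimal sketch of a populated argument object; the query and title mirror the examples given in the descriptions and are otherwise hypothetical:

```python
# Illustrative only: SessionSummarizationToolArgs populated per the field descriptions above.
args = SessionSummarizationToolArgs(
    session_summarization_query="summarize my last 100 session recordings",
    should_use_current_filters=False,  # the query defines its own scope, so current filters are not reused
    summary_title="Last 100 sessions",
)
```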
|
||||
|
||||
|
||||
class SessionSummarizationTool(MaxTool):
|
||||
name: Literal["session_summarization"] = "session_summarization"
|
||||
description: str = SESSION_SUMMARIZATION_TOOL_PROMPT
|
||||
thinking_message: str = "Summarizing session recordings"
|
||||
context_prompt_template: str = "Summarizes session recordings based on the user's query and current filters"
|
||||
args_schema: type[BaseModel] = SessionSummarizationToolArgs
|
||||
show_tool_call_message: bool = False
|
||||
|
||||
async def _arun_impl(
|
||||
self, session_summarization_query: str, should_use_current_filters: bool, summary_title: str
|
||||
) -> tuple[str, ToolMessagesArtifact | None]:
|
||||
node = SessionSummarizationNode(self._team, self._user)
|
||||
chain: RunnableLambda[AssistantState, PartialAssistantState | None] = RunnableLambda(node)
|
||||
copied_state = self._state.model_copy(
|
||||
deep=True,
|
||||
update={
|
||||
"root_tool_call_id": self.tool_call_id,
|
||||
"session_summarization_query": session_summarization_query,
|
||||
"should_use_current_filters": should_use_current_filters,
|
||||
"summary_title": summary_title,
|
||||
},
|
||||
)
|
||||
result = await chain.ainvoke(copied_state)
|
||||
if not result or not result.messages:
|
||||
return "Session summarization failed", None
|
||||
return "", ToolMessagesArtifact(messages=result.messages)
|
||||
new file: ee/hogai/graph/root/tools/test/test_create_and_query_insight.py (253 lines)
@@ -0,0 +1,253 @@
|
||||
from typing import Any
|
||||
|
||||
from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from posthog.schema import (
|
||||
AssistantMessage,
|
||||
AssistantTool,
|
||||
AssistantToolCallMessage,
|
||||
AssistantTrendsQuery,
|
||||
VisualizationMessage,
|
||||
)
|
||||
|
||||
from ee.hogai.context.context import AssistantContextManager
|
||||
from ee.hogai.graph.root.tools.create_and_query_insight import (
|
||||
INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT,
|
||||
CreateAndQueryInsightTool,
|
||||
)
|
||||
from ee.hogai.graph.schema_generator.nodes import SchemaGenerationException
|
||||
from ee.hogai.utils.types import AssistantState
|
||||
from ee.hogai.utils.types.base import NodePath
|
||||
from ee.models.assistant import Conversation
|
||||
|
||||
|
||||
class TestCreateAndQueryInsightTool(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
CLASS_DATA_LEVEL_SETUP = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.conversation = Conversation.objects.create(team=self.team, user=self.user)
|
||||
self.tool_call_id = "test_tool_call_id"
|
||||
|
||||
def _create_tool(
|
||||
self, state: AssistantState | None = None, contextual_tools: dict[str, dict[str, Any]] | None = None
|
||||
):
|
||||
"""Helper to create tool instance with optional state and contextual tools."""
|
||||
if state is None:
|
||||
state = AssistantState(messages=[])
|
||||
|
||||
config: RunnableConfig = RunnableConfig()
|
||||
if contextual_tools:
|
||||
config = RunnableConfig(configurable={"contextual_tools": contextual_tools})
|
||||
|
||||
context_manager = AssistantContextManager(team=self.team, user=self.user, config=config)
|
||||
|
||||
return CreateAndQueryInsightTool(
|
||||
team=self.team,
|
||||
user=self.user,
|
||||
state=state,
|
||||
context_manager=context_manager,
|
||||
node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),),
|
||||
)
|
||||
|
||||
async def test_successful_insight_creation_returns_messages(self):
|
||||
"""Test successful insight creation returns visualization and tool call messages."""
|
||||
tool = self._create_tool()
|
||||
|
||||
query = AssistantTrendsQuery(series=[])
|
||||
viz_message = VisualizationMessage(query="test query", answer=query, plan="test plan")
|
||||
tool_call_message = AssistantToolCallMessage(content="Results are here", tool_call_id=self.tool_call_id)
|
||||
|
||||
mock_state = AssistantState(messages=[viz_message, tool_call_message])
|
||||
|
||||
with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
|
||||
mock_graph = AsyncMock()
|
||||
mock_graph.ainvoke = AsyncMock(return_value=mock_state.model_dump())
|
||||
mock_compile.return_value = mock_graph
|
||||
|
||||
result_text, artifact = await tool._arun_impl(query_description="test description")
|
||||
|
||||
self.assertEqual(result_text, "")
|
||||
self.assertIsNotNone(artifact)
|
||||
self.assertEqual(len(artifact.messages), 2)
|
||||
self.assertIsInstance(artifact.messages[0], VisualizationMessage)
|
||||
self.assertIsInstance(artifact.messages[1], AssistantToolCallMessage)
|
||||
|
||||
async def test_schema_generation_exception_returns_formatted_error(self):
|
||||
"""Test SchemaGenerationException is caught and returns formatted error message."""
|
||||
tool = self._create_tool()
|
||||
|
||||
exception = SchemaGenerationException(
|
||||
llm_output="Invalid query structure", validation_message="Missing required field: series"
|
||||
)
|
||||
|
||||
with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
|
||||
mock_graph = AsyncMock()
|
||||
mock_graph.ainvoke = AsyncMock(side_effect=exception)
|
||||
mock_compile.return_value = mock_graph
|
||||
|
||||
result_text, artifact = await tool._arun_impl(query_description="test description")
|
||||
|
||||
self.assertIsNone(artifact)
|
||||
self.assertIn("Invalid query structure", result_text)
|
||||
self.assertIn("Missing required field: series", result_text)
|
||||
self.assertIn(INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT, result_text)
|
||||
|
||||
async def test_invalid_tool_call_message_type_returns_error(self):
|
||||
"""Test when the last message is not AssistantToolCallMessage, returns error."""
|
||||
tool = self._create_tool()
|
||||
|
||||
viz_message = VisualizationMessage(query="test query", answer=AssistantTrendsQuery(series=[]), plan="test plan")
|
||||
# Last message is AssistantMessage instead of AssistantToolCallMessage
|
||||
invalid_message = AssistantMessage(content="Not a tool call message")
|
||||
|
||||
mock_state = AssistantState(messages=[viz_message, invalid_message])
|
||||
|
||||
with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
|
||||
mock_graph = AsyncMock()
|
||||
mock_graph.ainvoke = AsyncMock(return_value=mock_state.model_dump())
|
||||
mock_compile.return_value = mock_graph
|
||||
|
||||
result_text, artifact = await tool._arun_impl(query_description="test description")
|
||||
|
||||
self.assertIsNone(artifact)
|
||||
self.assertIn("unknown error", result_text)
|
||||
self.assertIn(INSIGHT_TOOL_FAILURE_SYSTEM_REMINDER_PROMPT, result_text)
|
||||
|
||||
async def test_human_feedback_requested_returns_only_tool_call_message(self):
|
||||
"""Test when visualization message is not present, returns only tool call message."""
|
||||
tool = self._create_tool()
|
||||
|
||||
# When agent requests human feedback, there's no VisualizationMessage
|
||||
some_message = AssistantMessage(content="I need help with this query")
|
||||
tool_call_message = AssistantToolCallMessage(content="Need clarification", tool_call_id=self.tool_call_id)
|
||||
|
||||
mock_state = AssistantState(messages=[some_message, tool_call_message])
|
||||
|
||||
with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
|
||||
mock_graph = AsyncMock()
|
||||
mock_graph.ainvoke = AsyncMock(return_value=mock_state.model_dump())
|
||||
mock_compile.return_value = mock_graph
|
||||
|
||||
result_text, artifact = await tool._arun_impl(query_description="test description")
|
||||
|
||||
self.assertEqual(result_text, "")
|
||||
self.assertIsNotNone(artifact)
|
||||
self.assertEqual(len(artifact.messages), 1)
|
||||
self.assertIsInstance(artifact.messages[0], AssistantToolCallMessage)
|
||||
self.assertEqual(artifact.messages[0].content, "Need clarification")
|
||||
|
||||
async def test_editing_mode_adds_ui_payload(self):
|
||||
"""Test that in editing mode, UI payload is added to tool call message."""
|
||||
# Create tool with contextual tool available
|
||||
tool = self._create_tool(contextual_tools={AssistantTool.CREATE_AND_QUERY_INSIGHT.value: {}})
|
||||
|
||||
query = AssistantTrendsQuery(series=[])
|
||||
viz_message = VisualizationMessage(query="test query", answer=query, plan="test plan")
|
||||
tool_call_message = AssistantToolCallMessage(content="Results are here", tool_call_id=self.tool_call_id)
|
||||
|
||||
mock_state = AssistantState(messages=[viz_message, tool_call_message])
|
||||
|
||||
with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
|
||||
mock_graph = AsyncMock()
|
||||
mock_graph.ainvoke = AsyncMock(return_value=mock_state.model_dump())
|
||||
mock_compile.return_value = mock_graph
|
||||
|
||||
result_text, artifact = await tool._arun_impl(query_description="test description")
|
||||
|
||||
self.assertEqual(result_text, "")
|
||||
self.assertIsNotNone(artifact)
|
||||
self.assertEqual(len(artifact.messages), 2)
|
||||
|
||||
# Check that UI payload was added to tool call message
|
||||
returned_tool_call_message = artifact.messages[1]
|
||||
self.assertIsInstance(returned_tool_call_message, AssistantToolCallMessage)
|
||||
self.assertIsNotNone(returned_tool_call_message.ui_payload)
|
||||
self.assertIn("create_and_query_insight", returned_tool_call_message.ui_payload)
|
||||
self.assertEqual(
|
||||
returned_tool_call_message.ui_payload["create_and_query_insight"], query.model_dump(exclude_none=True)
|
||||
)
|
||||
|
||||
async def test_non_editing_mode_no_ui_payload(self):
|
||||
"""Test that in non-editing mode, no UI payload is added to tool call message."""
|
||||
# Create tool without contextual tools (non-editing mode)
|
||||
tool = self._create_tool(contextual_tools={})
|
||||
|
||||
query = AssistantTrendsQuery(series=[])
|
||||
viz_message = VisualizationMessage(query="test query", answer=query, plan="test plan")
|
||||
tool_call_message = AssistantToolCallMessage(content="Results are here", tool_call_id=self.tool_call_id)
|
||||
|
||||
mock_state = AssistantState(messages=[viz_message, tool_call_message])
|
||||
|
||||
with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
|
||||
mock_graph = AsyncMock()
|
||||
mock_graph.ainvoke = AsyncMock(return_value=mock_state.model_dump())
|
||||
mock_compile.return_value = mock_graph
|
||||
|
||||
result_text, artifact = await tool._arun_impl(query_description="test description")
|
||||
|
||||
self.assertEqual(result_text, "")
|
||||
self.assertIsNotNone(artifact)
|
||||
self.assertEqual(len(artifact.messages), 2)
|
||||
|
||||
# Check that the original tool call message is returned without modification
|
||||
returned_tool_call_message = artifact.messages[1]
|
||||
self.assertIsInstance(returned_tool_call_message, AssistantToolCallMessage)
|
||||
# In non-editing mode, the original message is returned as-is
|
||||
self.assertEqual(returned_tool_call_message, tool_call_message)
|
||||
|
||||
async def test_state_updates_include_tool_call_metadata(self):
|
||||
"""Test that the state passed to graph includes root_tool_call_id and root_tool_insight_plan."""
|
||||
initial_state = AssistantState(messages=[AssistantMessage(content="initial")])
|
||||
tool = self._create_tool(state=initial_state)
|
||||
|
||||
query = AssistantTrendsQuery(series=[])
|
||||
viz_message = VisualizationMessage(query="test query", answer=query, plan="test plan")
|
||||
tool_call_message = AssistantToolCallMessage(content="Results", tool_call_id=self.tool_call_id)
|
||||
mock_state = AssistantState(messages=[viz_message, tool_call_message])
|
||||
|
||||
invoked_state = None
|
||||
|
||||
async def capture_invoked_state(state):
|
||||
nonlocal invoked_state
|
||||
invoked_state = state
|
||||
return mock_state.model_dump()
|
||||
|
||||
with patch("ee.hogai.graph.insights_graph.graph.InsightsGraph.compile_full_graph") as mock_compile:
|
||||
mock_graph = AsyncMock()
|
||||
mock_graph.ainvoke = AsyncMock(side_effect=capture_invoked_state)
|
||||
mock_compile.return_value = mock_graph
|
||||
|
||||
await tool._arun_impl(query_description="my test query")
|
||||
|
||||
# Verify the state passed to ainvoke has the correct metadata
|
||||
self.assertIsNotNone(invoked_state)
|
||||
validated_state = AssistantState.model_validate(invoked_state)
|
||||
self.assertEqual(validated_state.root_tool_call_id, self.tool_call_id)
|
||||
self.assertEqual(validated_state.root_tool_insight_plan, "my test query")
|
||||
# Original message should still be there
|
||||
self.assertEqual(len(validated_state.messages), 1)
|
||||
assert isinstance(validated_state.messages[0], AssistantMessage)
|
||||
self.assertEqual(validated_state.messages[0].content, "initial")
|
||||
|
||||
async def test_is_editing_mode_classmethod(self):
|
||||
"""Test the is_editing_mode class method correctly detects editing mode."""
|
||||
# Test with editing mode enabled
|
||||
config_editing: RunnableConfig = RunnableConfig(
|
||||
configurable={"contextual_tools": {AssistantTool.CREATE_AND_QUERY_INSIGHT.value: {}}}
|
||||
)
|
||||
context_manager_editing = AssistantContextManager(team=self.team, user=self.user, config=config_editing)
|
||||
self.assertTrue(CreateAndQueryInsightTool.is_editing_mode(context_manager_editing))
|
||||
|
||||
# Test with editing mode disabled
|
||||
config_not_editing = RunnableConfig(configurable={"contextual_tools": {}})
|
||||
context_manager_not_editing = AssistantContextManager(team=self.team, user=self.user, config=config_not_editing)
|
||||
self.assertFalse(CreateAndQueryInsightTool.is_editing_mode(context_manager_not_editing))
|
||||
|
||||
# Test with other contextual tools but not create_and_query_insight
|
||||
config_other = RunnableConfig(configurable={"contextual_tools": {"some_other_tool": {}}})
|
||||
context_manager_other = AssistantContextManager(team=self.team, user=self.user, config=config_other)
|
||||
self.assertFalse(CreateAndQueryInsightTool.is_editing_mode(context_manager_other))
|
||||
new file: ee/hogai/graph/root/tools/test/test_create_dashboard.py (248 lines)
@@ -0,0 +1,248 @@
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from posthog.schema import AssistantMessage
|
||||
|
||||
from ee.hogai.context.context import AssistantContextManager
|
||||
from ee.hogai.graph.root.tools.create_dashboard import CreateDashboardTool
|
||||
from ee.hogai.utils.types import AssistantState, InsightQuery, PartialAssistantState
|
||||
from ee.hogai.utils.types.base import NodePath
|
||||
|
||||
|
||||
class TestCreateDashboardTool(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
CLASS_DATA_LEVEL_SETUP = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.tool_call_id = str(uuid4())
|
||||
self.state = AssistantState(messages=[], root_tool_call_id=self.tool_call_id)
|
||||
self.context_manager = AssistantContextManager(self.team, self.user, {})
|
||||
self.tool = CreateDashboardTool(
|
||||
team=self.team,
|
||||
user=self.user,
|
||||
state=self.state,
|
||||
context_manager=self.context_manager,
|
||||
node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),),
|
||||
)
|
||||
|
||||
async def test_execute_calls_dashboard_creation_node(self):
|
||||
mock_node_instance = MagicMock()
|
||||
mock_result = PartialAssistantState(
|
||||
messages=[AssistantMessage(content="Dashboard created successfully with 3 insights")]
|
||||
)
|
||||
|
||||
with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode", return_value=mock_node_instance):
|
||||
with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda") as mock_runnable:
|
||||
mock_chain = MagicMock()
|
||||
mock_chain.ainvoke = AsyncMock(return_value=mock_result)
|
||||
mock_runnable.return_value = mock_chain
|
||||
|
||||
insight_queries = [
|
||||
InsightQuery(name="Pageviews", description="Show pageviews for last 7 days"),
|
||||
InsightQuery(name="User signups", description="Show user signups funnel"),
|
||||
]
|
||||
|
||||
result, artifact = await self.tool._arun_impl(
|
||||
search_insights_queries=insight_queries,
|
||||
dashboard_name="Marketing Dashboard",
|
||||
)
|
||||
|
||||
self.assertEqual(result, "")
|
||||
self.assertIsNotNone(artifact)
|
||||
assert artifact is not None
|
||||
self.assertEqual(len(artifact.messages), 1)
|
||||
message = cast(AssistantMessage, artifact.messages[0])
|
||||
self.assertEqual(message.content, "Dashboard created successfully with 3 insights")
|
||||
|
||||
async def test_execute_updates_state_with_all_parameters(self):
|
||||
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test response")])
|
||||
|
||||
insight_queries = [
|
||||
InsightQuery(name="Revenue Trends", description="Monthly revenue trends for Q4"),
|
||||
InsightQuery(name="Churn Rate", description="Customer churn rate by cohort"),
|
||||
InsightQuery(name="NPS Score", description="Net Promoter Score over time"),
|
||||
]
|
||||
|
||||
async def mock_ainvoke(state):
|
||||
self.assertEqual(state.search_insights_queries, insight_queries)
|
||||
self.assertEqual(state.dashboard_name, "Executive Summary Q4")
|
||||
self.assertEqual(state.root_tool_call_id, self.tool_call_id)
|
||||
return mock_result
|
||||
|
||||
mock_chain = MagicMock()
|
||||
mock_chain.ainvoke = mock_ainvoke
|
||||
|
||||
with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
|
||||
with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
|
||||
await self.tool._arun_impl(
|
||||
search_insights_queries=insight_queries,
|
||||
dashboard_name="Executive Summary Q4",
|
||||
)
|
||||
|
||||
async def test_execute_with_single_insight_query(self):
|
||||
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Dashboard with one insight created")])
|
||||
|
||||
insight_queries = [InsightQuery(name="Daily Active Users", description="Count of daily active users")]
|
||||
|
||||
async def mock_ainvoke(state):
|
||||
self.assertEqual(len(state.search_insights_queries), 1)
|
||||
self.assertEqual(state.search_insights_queries[0].name, "Daily Active Users")
|
||||
self.assertEqual(state.search_insights_queries[0].description, "Count of daily active users")
|
||||
self.assertEqual(state.dashboard_name, "User Activity Dashboard")
|
||||
return mock_result
|
||||
|
||||
mock_chain = MagicMock()
|
||||
mock_chain.ainvoke = mock_ainvoke
|
||||
|
||||
with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
|
||||
with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
|
||||
result, artifact = await self.tool._arun_impl(
|
||||
search_insights_queries=insight_queries,
|
||||
dashboard_name="User Activity Dashboard",
|
||||
)
|
||||
|
||||
self.assertEqual(result, "")
|
||||
self.assertIsNotNone(artifact)
|
||||
|
||||
async def test_execute_with_many_insight_queries(self):
|
||||
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Large dashboard created")])
|
||||
|
||||
insight_queries = [
|
||||
InsightQuery(name=f"Insight {i}", description=f"Description for insight {i}") for i in range(10)
|
||||
]
|
||||
|
||||
async def mock_ainvoke(state):
|
||||
self.assertEqual(len(state.search_insights_queries), 10)
|
||||
self.assertEqual(state.search_insights_queries[0].name, "Insight 0")
|
||||
self.assertEqual(state.search_insights_queries[9].name, "Insight 9")
|
||||
self.assertEqual(state.dashboard_name, "Comprehensive Dashboard")
|
||||
return mock_result
|
||||
|
||||
mock_chain = MagicMock()
|
||||
mock_chain.ainvoke = mock_ainvoke
|
||||
|
||||
with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
|
||||
with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
|
||||
result, artifact = await self.tool._arun_impl(
|
||||
search_insights_queries=insight_queries,
|
||||
dashboard_name="Comprehensive Dashboard",
|
||||
)
|
||||
|
||||
self.assertEqual(result, "")
|
||||
self.assertIsNotNone(artifact)
|
||||
|
||||
async def test_execute_returns_failure_message_when_result_is_none(self):
|
||||
async def mock_ainvoke(state):
|
||||
return None
|
||||
|
||||
mock_chain = MagicMock()
|
||||
mock_chain.ainvoke = mock_ainvoke
|
||||
|
||||
with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
|
||||
with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
|
||||
result, artifact = await self.tool._arun_impl(
|
||||
search_insights_queries=[InsightQuery(name="Test", description="Test insight")],
|
||||
dashboard_name="Test Dashboard",
|
||||
)
|
||||
|
||||
self.assertEqual(result, "Dashboard creation failed")
|
||||
self.assertIsNone(artifact)
|
||||
|
||||
async def test_execute_returns_failure_message_when_result_has_no_messages(self):
|
||||
mock_result = PartialAssistantState(messages=[])
|
||||
|
||||
async def mock_ainvoke(state):
|
||||
return mock_result
|
||||
|
||||
mock_chain = MagicMock()
|
||||
mock_chain.ainvoke = mock_ainvoke
|
||||
|
||||
with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
|
||||
with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
|
||||
result, artifact = await self.tool._arun_impl(
|
||||
search_insights_queries=[InsightQuery(name="Test", description="Test insight")],
|
||||
dashboard_name="Test Dashboard",
|
||||
)
|
||||
|
||||
self.assertEqual(result, "Dashboard creation failed")
|
||||
self.assertIsNone(artifact)
|
||||
|
||||
async def test_execute_preserves_original_state(self):
|
||||
"""Test that the original state is not modified when creating the copied state"""
|
||||
original_queries = [InsightQuery(name="Original", description="Original insight")]
|
||||
original_state = AssistantState(
|
||||
messages=[],
|
||||
root_tool_call_id=self.tool_call_id,
|
||||
search_insights_queries=original_queries,
|
||||
dashboard_name="Original Dashboard",
|
||||
)
|
||||
|
||||
tool = CreateDashboardTool(
|
||||
team=self.team,
|
||||
user=self.user,
|
||||
state=original_state,
|
||||
context_manager=self.context_manager,
|
||||
node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),),
|
||||
)
|
||||
|
||||
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test")])
|
||||
|
||||
new_queries = [InsightQuery(name="New", description="New insight")]
|
||||
|
||||
async def mock_ainvoke(state):
|
||||
# Verify the new state has updated values
|
||||
self.assertEqual(state.search_insights_queries, new_queries)
|
||||
self.assertEqual(state.dashboard_name, "New Dashboard")
|
||||
return mock_result
|
||||
|
||||
mock_chain = MagicMock()
|
||||
mock_chain.ainvoke = mock_ainvoke
|
||||
|
||||
with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
|
||||
with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
|
||||
await tool._arun_impl(
|
||||
search_insights_queries=new_queries,
|
||||
dashboard_name="New Dashboard",
|
||||
)
|
||||
|
||||
# Verify original state was not modified
|
||||
self.assertEqual(original_state.search_insights_queries, original_queries)
|
||||
self.assertEqual(original_state.dashboard_name, "Original Dashboard")
|
||||
self.assertEqual(original_state.root_tool_call_id, self.tool_call_id)
|
||||
|
||||
async def test_execute_with_complex_insight_descriptions(self):
|
||||
"""Test that complex insight descriptions with special characters are handled correctly"""
|
||||
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Dashboard created")])
|
||||
|
||||
insight_queries = [
|
||||
InsightQuery(
|
||||
name="User Journey",
|
||||
description='Show funnel from "Sign up" → "Create project" → "Invite team" for users with property email containing "@company.com"',
|
||||
),
|
||||
InsightQuery(
|
||||
name="Revenue (USD)",
|
||||
description="Track total revenue in USD from 2024-01-01 to 2024-12-31, filtered by plan_type = 'premium' OR plan_type = 'enterprise'",
|
||||
),
|
||||
]
|
||||
|
||||
async def mock_ainvoke(state):
|
||||
self.assertEqual(len(state.search_insights_queries), 2)
|
||||
self.assertIn("@company.com", state.search_insights_queries[0].description)
|
||||
self.assertIn("'premium'", state.search_insights_queries[1].description)
|
||||
return mock_result
|
||||
|
||||
mock_chain = MagicMock()
|
||||
mock_chain.ainvoke = mock_ainvoke
|
||||
|
||||
with patch("ee.hogai.graph.dashboards.nodes.DashboardCreationNode"):
|
||||
with patch("ee.hogai.graph.root.tools.create_dashboard.RunnableLambda", return_value=mock_chain):
|
||||
result, artifact = await self.tool._arun_impl(
|
||||
search_insights_queries=insight_queries,
|
||||
dashboard_name="Complex Dashboard",
|
||||
)
|
||||
|
||||
self.assertEqual(result, "")
|
||||
self.assertIsNotNone(artifact)
|
||||
@@ -1,27 +1,178 @@
|
||||
from posthog.test.base import BaseTest
|
||||
from unittest.mock import patch
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from django.test import override_settings
|
||||
|
||||
from langchain_core import messages
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from posthog.schema import AssistantMessage
|
||||
|
||||
from ee.hogai.context.context import AssistantContextManager
|
||||
from ee.hogai.graph.root.tools.search import (
|
||||
DOC_ITEM_TEMPLATE,
|
||||
DOCS_SEARCH_NO_RESULTS_TEMPLATE,
|
||||
DOCS_SEARCH_RESULTS_TEMPLATE,
|
||||
EMPTY_DATABASE_ERROR_MESSAGE,
|
||||
InkeepDocsSearchTool,
|
||||
InsightSearchTool,
|
||||
SearchTool,
|
||||
)
|
||||
from ee.hogai.utils.tests import FakeChatOpenAI
|
||||
from ee.hogai.utils.types import AssistantState, PartialAssistantState
|
||||
from ee.hogai.utils.types.base import NodePath
|
||||
|
||||
|
||||
class TestSearchToolDocumentation(BaseTest):
|
||||
class TestSearchTool(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
CLASS_DATA_LEVEL_SETUP = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.tool = SearchTool(team=self.team, user=self.user, tool_call_id="test-tool-call-id")
|
||||
self.tool_call_id = "test_tool_call_id"
|
||||
self.state = AssistantState(messages=[], root_tool_call_id=str(uuid4()))
|
||||
self.context_manager = AssistantContextManager(self.team, self.user, {})
|
||||
self.tool = SearchTool(
|
||||
team=self.team,
|
||||
user=self.user,
|
||||
state=self.state,
|
||||
context_manager=self.context_manager,
|
||||
node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),),
|
||||
)
|
||||
|
||||
async def test_run_docs_search_without_api_key(self):
|
||||
with patch("ee.hogai.graph.root.tools.search.settings") as mock_settings:
|
||||
mock_settings.INKEEP_API_KEY = None
|
||||
result, artifact = await self.tool._arun_impl(kind="docs", query="How to use feature flags?")
|
||||
self.assertEqual(result, "This tool is not available in this environment.")
|
||||
self.assertIsNone(artifact)
|
||||
|
||||
async def test_run_docs_search_with_api_key(self):
|
||||
mock_docs_tool = MagicMock()
|
||||
mock_docs_tool.execute = AsyncMock(return_value=("", MagicMock()))
|
||||
|
||||
with (
|
||||
patch("ee.hogai.graph.root.tools.search.settings") as mock_settings,
|
||||
patch("ee.hogai.graph.root.tools.search.InkeepDocsSearchTool", return_value=mock_docs_tool),
|
||||
):
|
||||
mock_settings.INKEEP_API_KEY = "test-key"
|
||||
result, artifact = await self.tool._arun_impl(kind="docs", query="How to use feature flags?")
|
||||
|
||||
mock_docs_tool.execute.assert_called_once_with("How to use feature flags?", self.tool_call_id)
|
||||
self.assertEqual(result, "")
|
||||
self.assertIsNotNone(artifact)
|
||||
|
||||
async def test_run_insights_search(self):
|
||||
mock_insights_tool = MagicMock()
|
||||
mock_insights_tool.execute = AsyncMock(return_value=("", MagicMock()))
|
||||
|
||||
with patch("ee.hogai.graph.root.tools.search.InsightSearchTool", return_value=mock_insights_tool):
|
||||
result, artifact = await self.tool._arun_impl(kind="insights", query="user signups")
|
||||
|
||||
mock_insights_tool.execute.assert_called_once_with("user signups", self.tool_call_id)
|
||||
self.assertEqual(result, "")
|
||||
self.assertIsNotNone(artifact)
|
||||
|
||||
async def test_run_unknown_kind(self):
|
||||
result, artifact = await self.tool._arun_impl(kind="unknown", query="test")
|
||||
self.assertEqual(result, "Invalid entity kind: unknown. Please provide a valid entity kind for the tool.")
|
||||
self.assertIsNone(artifact)
|
||||
|
||||
@patch("ee.hogai.graph.root.tools.search.EntitySearchTool.execute")
|
||||
async def test_arun_impl_error_tracking_issues_returns_routing_data(self, mock_execute):
|
||||
mock_execute.return_value = "Search results for error tracking issues"
|
||||
|
||||
result, artifact = await self.tool._arun_impl(
|
||||
kind="error_tracking_issues", query="test error tracking issue query"
|
||||
)
|
||||
|
||||
self.assertEqual(result, "Search results for error tracking issues")
|
||||
self.assertIsNone(artifact)
|
||||
mock_execute.assert_called_once_with("test error tracking issue query", "error_tracking_issues")
|
||||
|
||||
@patch("ee.hogai.graph.root.tools.search.EntitySearchTool.execute")
|
||||
@patch("ee.hogai.graph.root.tools.search.SearchTool._has_insights_fts_search_feature_flag")
|
||||
async def test_arun_impl_insight_with_feature_flag_disabled(
|
||||
self, mock_has_insights_fts_search_feature_flag, mock_execute
|
||||
):
|
||||
mock_has_insights_fts_search_feature_flag.return_value = False
|
||||
mock_execute.return_value = "Search results for insights"
|
||||
|
||||
result, artifact = await self.tool._arun_impl(kind="insights", query="test insight query")
|
||||
|
||||
self.assertEqual(result, "The user doesn't have any insights created yet.")
|
||||
self.assertIsNone(artifact)
|
||||
mock_execute.assert_not_called()
|
||||
|
||||
@patch("ee.hogai.graph.root.tools.search.EntitySearchTool.execute")
|
||||
@patch("ee.hogai.graph.root.tools.search.SearchTool._has_insights_fts_search_feature_flag")
|
||||
async def test_arun_impl_insight_with_feature_flag_enabled(
|
||||
self, mock_has_insights_fts_search_feature_flag, mock_execute
|
||||
):
|
||||
mock_has_insights_fts_search_feature_flag.return_value = True
|
||||
mock_execute.return_value = "Search results for insights"
|
||||
|
||||
result, artifact = await self.tool._arun_impl(kind="insights", query="test insight query")
|
||||
|
||||
self.assertEqual(result, "Search results for insights")
|
||||
self.assertIsNone(artifact)
|
||||
mock_execute.assert_called_once_with("test insight query", "insights")
|
||||
|
||||
|
||||
class TestInkeepDocsSearchTool(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
CLASS_DATA_LEVEL_SETUP = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.tool_call_id = str(uuid4())
|
||||
self.state = AssistantState(messages=[], root_tool_call_id=self.tool_call_id)
|
||||
self.context_manager = AssistantContextManager(self.team, self.user, {})
|
||||
self.tool = InkeepDocsSearchTool(
|
||||
team=self.team,
|
||||
user=self.user,
|
||||
state=self.state,
|
||||
config=RunnableConfig(configurable={}),
|
||||
context_manager=self.context_manager,
|
||||
)
|
||||
|
||||
async def test_execute_calls_inkeep_docs_node(self):
|
||||
mock_node_instance = MagicMock()
|
||||
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Here is the answer from docs")])
|
||||
|
||||
with patch("ee.hogai.graph.inkeep_docs.nodes.InkeepDocsNode", return_value=mock_node_instance):
|
||||
with patch("ee.hogai.graph.root.tools.search.RunnableLambda") as mock_runnable:
|
||||
mock_chain = MagicMock()
|
||||
mock_chain.ainvoke = AsyncMock(return_value=mock_result)
|
||||
mock_runnable.return_value = mock_chain
|
||||
|
||||
result, artifact = await self.tool.execute("How to track events?", "test-tool-call-id")
|
||||
|
||||
self.assertEqual(result, "")
|
||||
self.assertIsNotNone(artifact)
|
||||
assert artifact is not None
|
||||
self.assertEqual(len(artifact.messages), 1)
|
||||
message = cast(AssistantMessage, artifact.messages[0])
|
||||
self.assertEqual(message.content, "Here is the answer from docs")
|
||||
|
||||
async def test_execute_updates_state_with_tool_call_id(self):
|
||||
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test response")])
|
||||
|
||||
async def mock_ainvoke(state):
|
||||
self.assertEqual(state.root_tool_call_id, "custom-tool-call-id")
|
||||
return mock_result
|
||||
|
||||
mock_chain = MagicMock()
|
||||
mock_chain.ainvoke = mock_ainvoke
|
||||
|
||||
with patch("ee.hogai.graph.inkeep_docs.nodes.InkeepDocsNode"):
|
||||
with patch("ee.hogai.graph.root.tools.search.RunnableLambda", return_value=mock_chain):
|
||||
await self.tool.execute("test query", "custom-tool-call-id")
|
||||
|
||||
@override_settings(INKEEP_API_KEY="test-inkeep-key")
|
||||
@patch("ee.hogai.graph.root.tools.search.ChatOpenAI")
|
||||
async def test_search_docs_with_successful_results(self, mock_llm_class):
|
||||
@patch("ee.hogai.graph.root.tools.search.InkeepDocsSearchTool._has_rag_docs_search_feature_flag", return_value=True)
|
||||
async def test_search_docs_with_successful_results(self, mock_has_rag_docs_search_feature_flag, mock_llm_class):
|
||||
response_json = """{
|
||||
"content": [
|
||||
{
|
||||
@@ -47,7 +198,7 @@ class TestSearchToolDocumentation(BaseTest):
|
||||
fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content=response_json)])
|
||||
mock_llm_class.return_value = fake_llm
|
||||
|
||||
result = await self.tool._search_docs("how to use feature")
|
||||
result, _ = await self.tool.execute("how to use feature", "test-tool-call-id")
|
||||
|
||||
expected_doc_1 = DOC_ITEM_TEMPLATE.format(
|
||||
title="Feature Documentation",
|
||||
@@ -68,196 +219,76 @@ class TestSearchToolDocumentation(BaseTest):
|
||||
self.assertEqual(mock_llm_class.call_args.kwargs["api_key"], "test-inkeep-key")
|
||||
self.assertEqual(mock_llm_class.call_args.kwargs["streaming"], False)
|
||||
|
||||
@override_settings(INKEEP_API_KEY="test-inkeep-key")
|
||||
@patch("ee.hogai.graph.root.tools.search.ChatOpenAI")
|
||||
async def test_search_docs_with_no_results(self, mock_llm_class):
|
||||
fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content="{}")])
|
||||
mock_llm_class.return_value = fake_llm
|
||||
|
||||
result = await self.tool._search_docs("nonexistent feature")
|
||||
class TestInsightSearchTool(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
CLASS_DATA_LEVEL_SETUP = False
|
||||
|
||||
self.assertEqual(result, DOCS_SEARCH_NO_RESULTS_TEMPLATE)
|
||||
|
||||
@override_settings(INKEEP_API_KEY="test-inkeep-key")
|
||||
@patch("ee.hogai.graph.root.tools.search.ChatOpenAI")
|
||||
async def test_search_docs_with_empty_content(self, mock_llm_class):
|
||||
fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content='{"content": []}')])
|
||||
mock_llm_class.return_value = fake_llm
|
||||
|
||||
result = await self.tool._search_docs("query")
|
||||
|
||||
self.assertEqual(result, DOCS_SEARCH_NO_RESULTS_TEMPLATE)
|
||||
|
||||
@override_settings(INKEEP_API_KEY="test-inkeep-key")
|
||||
@patch("ee.hogai.graph.root.tools.search.ChatOpenAI")
|
||||
async def test_search_docs_filters_non_document_types(self, mock_llm_class):
|
||||
response_json = """{
|
||||
"content": [
|
||||
{
|
||||
"type": "snippet",
|
||||
"record_type": "code",
|
||||
"url": "https://posthog.com/code",
|
||||
"title": "Code Snippet",
|
||||
"source": {"type": "text", "content": [{"type": "text", "text": "Code example"}]}
|
||||
},
|
||||
{
|
||||
"type": "answer",
|
||||
"record_type": "answer",
|
||||
"url": "https://posthog.com/answer",
|
||||
"title": "Answer",
|
||||
"source": {"type": "text", "content": [{"type": "text", "text": "Answer text"}]}
|
||||
}
|
||||
]
|
||||
}"""
|
||||
|
||||
fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content=response_json)])
|
||||
mock_llm_class.return_value = fake_llm
|
||||
|
||||
result = await self.tool._search_docs("query")
|
||||
|
||||
self.assertEqual(result, DOCS_SEARCH_NO_RESULTS_TEMPLATE)
|
||||
|
||||
@override_settings(INKEEP_API_KEY="test-inkeep-key")
|
||||
@patch("ee.hogai.graph.root.tools.search.ChatOpenAI")
|
||||
async def test_search_docs_handles_empty_source_content(self, mock_llm_class):
|
||||
response_json = """{
|
||||
"content": [
|
||||
{
|
||||
"type": "document",
|
||||
"record_type": "page",
|
||||
"url": "https://posthog.com/docs/feature",
|
||||
"title": "Feature Documentation",
|
||||
"source": {"type": "text", "content": []}
|
||||
}
|
||||
]
|
||||
}"""
|
||||
|
||||
fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content=response_json)])
|
||||
mock_llm_class.return_value = fake_llm
|
||||
|
||||
result = await self.tool._search_docs("query")
|
||||
|
||||
expected_doc = DOC_ITEM_TEMPLATE.format(
|
||||
title="Feature Documentation", url="https://posthog.com/docs/feature", text=""
|
||||
)
|
||||
expected_result = DOCS_SEARCH_RESULTS_TEMPLATE.format(count=1, docs=expected_doc)
|
||||
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
@override_settings(INKEEP_API_KEY="test-inkeep-key")
|
||||
@patch("ee.hogai.graph.root.tools.search.ChatOpenAI")
|
||||
async def test_search_docs_handles_mixed_document_types(self, mock_llm_class):
|
||||
response_json = """{
|
||||
"content": [
|
||||
{
|
||||
"type": "document",
|
||||
"record_type": "page",
|
||||
"url": "https://posthog.com/docs/valid",
|
||||
"title": "Valid Doc",
|
||||
"source": {"type": "text", "content": [{"type": "text", "text": "Valid content"}]}
|
||||
},
|
||||
{
|
||||
"type": "snippet",
|
||||
"record_type": "code",
|
||||
"url": "https://posthog.com/code",
|
||||
"title": "Code Snippet",
|
||||
"source": {"type": "text", "content": [{"type": "text", "text": "Code"}]}
|
||||
},
|
||||
{
|
||||
"type": "document",
|
||||
"record_type": "guide",
|
||||
"url": "https://posthog.com/docs/another",
|
||||
"title": "Another Valid Doc",
|
||||
"source": {"type": "text", "content": [{"type": "text", "text": "More content"}]}
|
||||
}
|
||||
]
|
||||
}"""
|
||||
|
||||
fake_llm = FakeChatOpenAI(responses=[messages.AIMessage(content=response_json)])
|
||||
mock_llm_class.return_value = fake_llm
|
||||
|
||||
result = await self.tool._search_docs("query")
|
||||
|
||||
expected_doc_1 = DOC_ITEM_TEMPLATE.format(
|
||||
title="Valid Doc", url="https://posthog.com/docs/valid", text="Valid content"
|
||||
)
|
||||
expected_doc_2 = DOC_ITEM_TEMPLATE.format(
|
||||
title="Another Valid Doc", url="https://posthog.com/docs/another", text="More content"
|
||||
)
|
||||
expected_result = DOCS_SEARCH_RESULTS_TEMPLATE.format(
|
||||
count=2, docs=f"{expected_doc_1}\n\n---\n\n{expected_doc_2}"
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.tool_call_id = str(uuid4())
|
||||
self.state = AssistantState(messages=[], root_tool_call_id=self.tool_call_id)
|
||||
self.context_manager = AssistantContextManager(self.team, self.user, {})
|
||||
self.tool = InsightSearchTool(
|
||||
team=self.team,
|
||||
user=self.user,
|
||||
state=self.state,
|
||||
config=RunnableConfig(configurable={}),
|
||||
context_manager=self.context_manager,
|
||||
)
|
||||
|
||||
self.assertEqual(result, expected_result)
|
||||
async def test_execute_calls_insight_search_node(self):
|
||||
mock_node_instance = MagicMock()
|
||||
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Found 3 insights matching your query")])
|
||||
|
||||
@override_settings(INKEEP_API_KEY=None)
|
||||
async def test_arun_impl_docs_without_api_key(self):
|
||||
result, artifact = await self.tool._arun_impl(kind="docs", query="test query")
|
||||
with patch("ee.hogai.graph.insights.nodes.InsightSearchNode", return_value=mock_node_instance):
|
||||
with patch("ee.hogai.graph.root.tools.search.RunnableLambda") as mock_runnable:
|
||||
mock_chain = MagicMock()
|
||||
mock_chain.ainvoke = AsyncMock(return_value=mock_result)
|
||||
mock_runnable.return_value = mock_chain
|
||||
|
||||
self.assertEqual(result, "This tool is not available in this environment.")
|
||||
self.assertIsNone(artifact)
|
||||
result, artifact = await self.tool.execute("user signups by week", self.tool_call_id)
|
||||
|
||||
@override_settings(INKEEP_API_KEY="test-key")
|
||||
@patch("ee.hogai.graph.root.tools.search.posthoganalytics.feature_enabled")
|
||||
async def test_arun_impl_docs_with_feature_flag_disabled(self, mock_feature_enabled):
|
||||
mock_feature_enabled.return_value = False
|
||||
self.assertEqual(result, "")
|
||||
self.assertIsNotNone(artifact)
|
||||
assert artifact is not None
|
||||
self.assertEqual(len(artifact.messages), 1)
|
||||
message = cast(AssistantMessage, artifact.messages[0])
|
||||
self.assertEqual(message.content, "Found 3 insights matching your query")
|
||||
|
||||
result, artifact = await self.tool._arun_impl(kind="docs", query="test query")
|
||||
async def test_execute_updates_state_with_search_query(self):
|
||||
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test response")])
|
||||
|
||||
self.assertEqual(result, "Search tool executed")
|
||||
self.assertEqual(artifact, {"kind": "docs", "query": "test query"})
|
||||
async def mock_ainvoke(state):
|
||||
self.assertEqual(state.search_insights_query, "custom search query")
|
||||
self.assertEqual(state.root_tool_call_id, self.tool_call_id)
|
||||
return mock_result
|
||||
|
||||
@override_settings(INKEEP_API_KEY="test-key")
|
||||
@patch("ee.hogai.graph.root.tools.search.posthoganalytics.feature_enabled")
|
||||
@patch.object(SearchTool, "_search_docs")
|
||||
async def test_arun_impl_docs_with_feature_flag_enabled(self, mock_search_docs, mock_feature_enabled):
|
||||
mock_feature_enabled.return_value = True
|
||||
mock_search_docs.return_value = "Search results"
|
||||
mock_chain = MagicMock()
|
||||
mock_chain.ainvoke = mock_ainvoke
|
||||
|
||||
result, artifact = await self.tool._arun_impl(kind="docs", query="test query")
|
||||
with patch("ee.hogai.graph.insights.nodes.InsightSearchNode"):
|
||||
with patch("ee.hogai.graph.root.tools.search.RunnableLambda", return_value=mock_chain):
|
||||
await self.tool.execute("custom search query", self.tool_call_id)
|
||||
|
||||
self.assertEqual(result, "Search results")
|
||||
self.assertIsNone(artifact)
|
||||
mock_search_docs.assert_called_once_with("test query")
|
||||
async def test_execute_handles_no_insights_exception(self):
|
||||
from ee.hogai.graph.insights.nodes import NoInsightsException
|
||||
|
||||
async def test_arun_impl_insights_returns_routing_data(self):
|
||||
result, artifact = await self.tool._arun_impl(kind="insights", query="test insight query")
|
||||
with patch("ee.hogai.graph.insights.nodes.InsightSearchNode", side_effect=NoInsightsException()):
|
||||
result, artifact = await self.tool.execute("user signups", self.tool_call_id)
|
||||
|
||||
self.assertEqual(result, "Search tool executed")
|
||||
self.assertEqual(artifact, {"kind": "insights", "query": "test insight query"})
|
||||
self.assertEqual(result, EMPTY_DATABASE_ERROR_MESSAGE)
|
||||
self.assertIsNone(artifact)
|
||||
|
||||
@patch("ee.hogai.graph.root.tools.search.EntitySearchToolkit.execute")
|
||||
async def test_arun_impl_error_tracking_issues_returns_routing_data(self, mock_execute):
|
||||
mock_execute.return_value = "Search results for error tracking issues"
|
||||
async def test_execute_returns_none_artifact_when_result_is_none(self):
|
||||
async def mock_ainvoke(state):
|
||||
return None
|
||||
|
||||
result, artifact = await self.tool._arun_impl(
|
||||
kind="error_tracking_issues", query="test error tracking issue query"
|
||||
)
|
||||
mock_chain = MagicMock()
|
||||
mock_chain.ainvoke = mock_ainvoke
|
||||
|
||||
self.assertEqual(result, "Search results for error tracking issues")
|
||||
self.assertIsNone(artifact)
|
||||
mock_execute.assert_called_once_with("test error tracking issue query", "error_tracking_issues")
|
||||
with patch("ee.hogai.graph.insights.nodes.InsightSearchNode"):
|
||||
with patch("ee.hogai.graph.root.tools.search.RunnableLambda", return_value=mock_chain):
|
||||
result, artifact = await self.tool.execute("test query", self.tool_call_id)
|
||||
|
||||
@patch("ee.hogai.graph.root.tools.search.EntitySearchToolkit.execute")
|
||||
@patch("ee.hogai.graph.root.tools.search.SearchTool._has_fts_search_feature_flag")
|
||||
async def test_arun_impl_insight_with_feature_flag_disabled(self, mock_has_fts_search_feature_flag, mock_execute):
|
||||
mock_has_fts_search_feature_flag.return_value = False
|
||||
mock_execute.return_value = "Search results for insights"
|
||||
|
||||
result, artifact = await self.tool._arun_impl(kind="insights", query="test insight query")
|
||||
|
||||
self.assertEqual(result, "Search tool executed")
|
||||
self.assertEqual(artifact, {"kind": "insights", "query": "test insight query"})
|
||||
mock_execute.assert_not_called()
|
||||
|
||||
@patch("ee.hogai.graph.root.tools.search.EntitySearchToolkit.execute")
|
||||
@patch("ee.hogai.graph.root.tools.search.SearchTool._has_fts_search_feature_flag")
|
||||
async def test_arun_impl_insight_with_feature_flag_enabled(self, mock_has_fts_search_feature_flag, mock_execute):
|
||||
mock_has_fts_search_feature_flag.return_value = True
|
||||
mock_execute.return_value = "Search results for insights"
|
||||
|
||||
result, artifact = await self.tool._arun_impl(kind="insights", query="test insight query")
|
||||
|
||||
self.assertEqual(result, "Search results for insights")
|
||||
self.assertIsNone(artifact)
|
||||
mock_execute.assert_called_once_with("test insight query", "insights")
|
||||
self.assertEqual(result, "")
|
||||
self.assertIsNone(artifact)
|
||||
|
||||
193
ee/hogai/graph/root/tools/test/test_session_summarization.py
Normal file
@@ -0,0 +1,193 @@
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from posthog.schema import AssistantMessage
|
||||
|
||||
from ee.hogai.context.context import AssistantContextManager
|
||||
from ee.hogai.graph.root.tools.session_summarization import SessionSummarizationTool
|
||||
from ee.hogai.utils.types import AssistantState, PartialAssistantState
|
||||
from ee.hogai.utils.types.base import NodePath
|
||||
|
||||
|
||||
class TestSessionSummarizationTool(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
CLASS_DATA_LEVEL_SETUP = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.tool_call_id = str(uuid4())
|
||||
self.state = AssistantState(messages=[], root_tool_call_id=self.tool_call_id)
|
||||
self.context_manager = AssistantContextManager(self.team, self.user, {})
|
||||
self.tool = SessionSummarizationTool(
|
||||
team=self.team,
|
||||
user=self.user,
|
||||
state=self.state,
|
||||
context_manager=self.context_manager,
|
||||
node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),),
|
||||
)
|
||||
|
||||
async def test_execute_calls_session_summarization_node(self):
|
||||
mock_node_instance = MagicMock()
|
||||
mock_result = PartialAssistantState(
|
||||
messages=[AssistantMessage(content="Session summary: 10 sessions analyzed with 5 key patterns found")]
|
||||
)
|
||||
|
||||
with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode", return_value=mock_node_instance):
|
||||
with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda") as mock_runnable:
|
||||
mock_chain = MagicMock()
|
||||
mock_chain.ainvoke = AsyncMock(return_value=mock_result)
|
||||
mock_runnable.return_value = mock_chain
|
||||
|
||||
result, artifact = await self.tool._arun_impl(
|
||||
session_summarization_query="summarize all sessions from yesterday",
|
||||
should_use_current_filters=False,
|
||||
summary_title="All sessions from yesterday",
|
||||
)
|
||||
|
||||
self.assertEqual(result, "")
|
||||
self.assertIsNotNone(artifact)
|
||||
assert artifact is not None
|
||||
self.assertEqual(len(artifact.messages), 1)
|
||||
message = cast(AssistantMessage, artifact.messages[0])
|
||||
self.assertEqual(message.content, "Session summary: 10 sessions analyzed with 5 key patterns found")
|
||||
|
||||
async def test_execute_updates_state_with_all_parameters(self):
|
||||
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test response")])
|
||||
|
||||
async def mock_ainvoke(state):
|
||||
self.assertEqual(state.session_summarization_query, "analyze mobile user sessions")
|
||||
self.assertEqual(state.should_use_current_filters, True)
|
||||
self.assertEqual(state.summary_title, "Mobile user sessions")
|
||||
self.assertEqual(state.root_tool_call_id, self.tool_call_id)
|
||||
return mock_result
|
||||
|
||||
mock_chain = MagicMock()
|
||||
mock_chain.ainvoke = mock_ainvoke
|
||||
|
||||
with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"):
|
||||
with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain):
|
||||
await self.tool._arun_impl(
|
||||
session_summarization_query="analyze mobile user sessions",
|
||||
should_use_current_filters=True,
|
||||
summary_title="Mobile user sessions",
|
||||
)
|
||||
|
||||
async def test_execute_with_should_use_current_filters_false(self):
|
||||
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test response")])
|
||||
|
||||
async def mock_ainvoke(state):
|
||||
self.assertEqual(state.should_use_current_filters, False)
|
||||
self.assertEqual(state.session_summarization_query, "watch last 300 session recordings")
|
||||
self.assertEqual(state.summary_title, "Last 300 sessions")
|
||||
return mock_result
|
||||
|
||||
mock_chain = MagicMock()
|
||||
mock_chain.ainvoke = mock_ainvoke
|
||||
|
||||
with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"):
|
||||
with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain):
|
||||
await self.tool._arun_impl(
|
||||
session_summarization_query="watch last 300 session recordings",
|
||||
should_use_current_filters=False,
|
||||
summary_title="Last 300 sessions",
|
||||
)
|
||||
|
||||
async def test_execute_returns_failure_message_when_result_is_none(self):
|
||||
async def mock_ainvoke(state):
|
||||
return None
|
||||
|
||||
mock_chain = MagicMock()
|
||||
mock_chain.ainvoke = mock_ainvoke
|
||||
|
||||
with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"):
|
||||
with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain):
|
||||
result, artifact = await self.tool._arun_impl(
|
||||
session_summarization_query="test query",
|
||||
should_use_current_filters=False,
|
||||
summary_title="Test",
|
||||
)
|
||||
|
||||
self.assertEqual(result, "Session summarization failed")
|
||||
self.assertIsNone(artifact)
|
||||
|
||||
async def test_execute_returns_failure_message_when_result_has_no_messages(self):
|
||||
mock_result = PartialAssistantState(messages=[])
|
||||
|
||||
async def mock_ainvoke(state):
|
||||
return mock_result
|
||||
|
||||
mock_chain = MagicMock()
|
||||
mock_chain.ainvoke = mock_ainvoke
|
||||
|
||||
with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"):
|
||||
with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain):
|
||||
result, artifact = await self.tool._arun_impl(
|
||||
session_summarization_query="test query",
|
||||
should_use_current_filters=False,
|
||||
summary_title="Test",
|
||||
)
|
||||
|
||||
self.assertEqual(result, "Session summarization failed")
|
||||
self.assertIsNone(artifact)
|
||||
|
||||
async def test_execute_with_empty_summary_title(self):
|
||||
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Summary completed")])
|
||||
|
||||
async def mock_ainvoke(state):
|
||||
self.assertEqual(state.summary_title, "")
|
||||
return mock_result
|
||||
|
||||
mock_chain = MagicMock()
|
||||
mock_chain.ainvoke = mock_ainvoke
|
||||
|
||||
with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"):
|
||||
with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain):
|
||||
result, artifact = await self.tool._arun_impl(
|
||||
session_summarization_query="summarize sessions",
|
||||
should_use_current_filters=False,
|
||||
summary_title="",
|
||||
)
|
||||
|
||||
self.assertEqual(result, "")
|
||||
self.assertIsNotNone(artifact)
|
||||
|
||||
async def test_execute_preserves_original_state(self):
|
||||
"""Test that the original state is not modified when creating the copied state"""
|
||||
original_query = "original query"
|
||||
original_state = AssistantState(
|
||||
messages=[],
|
||||
root_tool_call_id=self.tool_call_id,
|
||||
session_summarization_query=original_query,
|
||||
)
|
||||
|
||||
tool = SessionSummarizationTool(
|
||||
team=self.team,
|
||||
user=self.user,
|
||||
state=original_state,
|
||||
context_manager=self.context_manager,
|
||||
node_path=(NodePath(name="test_node", tool_call_id=self.tool_call_id, message_id="test"),),
|
||||
)
|
||||
|
||||
mock_result = PartialAssistantState(messages=[AssistantMessage(content="Test")])
|
||||
|
||||
async def mock_ainvoke(state):
|
||||
# Verify the new state has updated values
|
||||
self.assertEqual(state.session_summarization_query, "new query")
|
||||
return mock_result
|
||||
|
||||
mock_chain = MagicMock()
|
||||
mock_chain.ainvoke = mock_ainvoke
|
||||
|
||||
with patch("ee.hogai.graph.session_summaries.nodes.SessionSummarizationNode"):
|
||||
with patch("ee.hogai.graph.root.tools.session_summarization.RunnableLambda", return_value=mock_chain):
|
||||
await tool._arun_impl(
|
||||
session_summarization_query="new query",
|
||||
should_use_current_filters=True,
|
||||
summary_title="New Summary",
|
||||
)
|
||||
|
||||
# Verify original state was not modified
|
||||
self.assertEqual(original_state.session_summarization_query, original_query)
|
||||
self.assertEqual(original_state.root_tool_call_id, self.tool_call_id)
|
||||
@@ -138,11 +138,11 @@ When unsure, use this tool. Proactive task management shows attentiveness and he
""".strip()


# Has its unique schema that doesn't match the Deep Research schema
class TodoItem(BaseModel):
content: str = Field(..., min_length=1)
status: Literal["pending", "in_progress", "completed"]
id: str
priority: Literal["low", "medium", "high"]


class TodoWriteToolArgs(BaseModel):
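A small, self-contained sketch of how the `TodoItem` schema above behaves under Pydantic validation; only the fields shown in this hunk are taken from the source, the surrounding script is illustrative only.

```python
from typing import Literal

from pydantic import BaseModel, Field, ValidationError


class TodoItem(BaseModel):
    # Mirrors the fields shown in the hunk above.
    content: str = Field(..., min_length=1)
    status: Literal["pending", "in_progress", "completed"]
    id: str
    priority: Literal["low", "medium", "high"]


# A well-formed item passes validation.
item = TodoItem(content="Write tests", status="pending", id="todo-1", priority="high")

# An empty `content` or an unknown `status` is rejected.
try:
    TodoItem(content="", status="done", id="todo-2", priority="low")
except ValidationError as exc:
    print(f"{len(exc.errors())} validation errors")
```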
@@ -13,7 +13,7 @@ from langchain_core.messages import (
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.runnables import RunnableConfig

from posthog.schema import FailureMessage, VisualizationMessage
from posthog.schema import VisualizationMessage

from posthog.models.group_type_mapping import GroupTypeMapping

@@ -36,6 +36,15 @@ from .utils import Q, SchemaGeneratorOutput
RETRIES_ALLOWED = 2


class SchemaGenerationException(Exception):
"""An error occurred while generating a schema in the `SchemaGeneratorNode` node."""

def __init__(self, llm_output: str, validation_message: str):
super().__init__("Failed to generate schema")
self.llm_output = llm_output
self.validation_message = validation_message


class SchemaGeneratorNode(AssistantNode, Generic[Q]):
INSIGHT_NAME: str
"""
@@ -87,9 +96,8 @@ class SchemaGeneratorNode(AssistantNode, Generic[Q]):

chain = generation_prompt | merger | self._model | self._parse_output

result: SchemaGeneratorOutput[Q] | None = None
try:
result = await chain.ainvoke(
result: SchemaGeneratorOutput[Q] = await chain.ainvoke(
{
"project_datetime": self.project_now,
"project_timezone": self.project_timezone,
@@ -120,18 +128,9 @@ class SchemaGeneratorNode(AssistantNode, Generic[Q]):
query_generation_retry_count=len(intermediate_steps) + 1,
)

if not result:
# We've got no usable result after exhausting all iteration attempts - it's failure message time
return PartialAssistantState(
messages=[
FailureMessage(
content=f"It looks like I'm having trouble generating this {self.INSIGHT_NAME} insight."
)
],
intermediate_steps=None,
plan=None,
query_generation_retry_count=len(intermediate_steps) + 1,
)
if isinstance(e, PydanticOutputParserException):
raise SchemaGenerationException(e.llm_output, e.validation_message)
raise SchemaGenerationException(e.llm_output or "No input was provided.", str(e))

# We've got a result that either passed the quality check or we've exhausted all attempts at iterating - return
return PartialAssistantState(

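The hunks above replace the in-node `FailureMessage` fallback with a raised `SchemaGenerationException`. A minimal sketch of how a caller might translate that exception back into a failure message; the wrapper function here is hypothetical, only the exception's attributes, `FailureMessage`, and `PartialAssistantState` come from the diff.

```python
from posthog.schema import FailureMessage

from ee.hogai.graph.schema_generator.nodes import SchemaGenerationException
from ee.hogai.utils.types import PartialAssistantState


async def run_generator_with_fallback(node, state, config) -> PartialAssistantState:
    """Hypothetical wrapper: convert a SchemaGenerationException into a FailureMessage."""
    try:
        return await node.arun(state, config)
    except SchemaGenerationException as e:
        # e.llm_output and e.validation_message carry the raw LLM output and the
        # validation error; here we only surface a generic failure to the user.
        return PartialAssistantState(
            messages=[
                FailureMessage(
                    content=f"It looks like I'm having trouble generating this {node.INSIGHT_NAME} insight."
                )
            ],
            intermediate_steps=None,
            plan=None,
        )
```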
@@ -13,7 +13,12 @@ from langchain_core.runnables import RunnableConfig, RunnableLambda
|
||||
|
||||
from posthog.schema import AssistantMessage, AssistantTrendsQuery, FailureMessage, HumanMessage, VisualizationMessage
|
||||
|
||||
from ee.hogai.graph.schema_generator.nodes import RETRIES_ALLOWED, SchemaGeneratorNode, SchemaGeneratorToolsNode
|
||||
from ee.hogai.graph.schema_generator.nodes import (
|
||||
RETRIES_ALLOWED,
|
||||
SchemaGenerationException,
|
||||
SchemaGeneratorNode,
|
||||
SchemaGeneratorToolsNode,
|
||||
)
|
||||
from ee.hogai.graph.schema_generator.parsers import PydanticOutputParserException
|
||||
from ee.hogai.graph.schema_generator.utils import SchemaGeneratorOutput
|
||||
from ee.hogai.utils.types import AssistantState, PartialAssistantState
|
||||
@@ -258,7 +263,7 @@ class TestSchemaGeneratorNode(BaseTest):
|
||||
self.assertEqual(action.log, "Field validation failed")
|
||||
|
||||
async def test_quality_check_failure_with_retries_exhausted(self):
|
||||
"""Test quality check failure with retries exhausted still returns VisualizationMessage."""
|
||||
"""Test quality check failure with retries exhausted raises SchemaGenerationException."""
|
||||
node = DummyGeneratorNode(self.team, self.user)
|
||||
with (
|
||||
patch.object(DummyGeneratorNode, "_model") as generator_model_mock,
|
||||
@@ -273,26 +278,25 @@ class TestSchemaGeneratorNode(BaseTest):
|
||||
)
|
||||
|
||||
# Start with RETRIES_ALLOWED intermediate steps (so no more allowed)
|
||||
new_state = await node.arun(
|
||||
AssistantState(
|
||||
messages=[HumanMessage(content="Text", id="0")],
|
||||
start_id="0",
|
||||
intermediate_steps=cast(
|
||||
list[IntermediateStep],
|
||||
[
|
||||
(AgentAction(tool="handle_incorrect_response", tool_input="", log=""), "retry"),
|
||||
],
|
||||
)
|
||||
* RETRIES_ALLOWED,
|
||||
),
|
||||
{},
|
||||
)
|
||||
with self.assertRaises(SchemaGenerationException) as cm:
|
||||
await node.arun(
|
||||
AssistantState(
|
||||
messages=[HumanMessage(content="Text", id="0")],
|
||||
start_id="0",
|
||||
intermediate_steps=cast(
|
||||
list[IntermediateStep],
|
||||
[
|
||||
(AgentAction(tool="handle_incorrect_response", tool_input="", log=""), "retry"),
|
||||
],
|
||||
)
|
||||
* RETRIES_ALLOWED,
|
||||
),
|
||||
{},
|
||||
)
|
||||
|
||||
# Should return VisualizationMessage despite quality check failure
|
||||
self.assertEqual(new_state.intermediate_steps, None)
|
||||
self.assertEqual(len(new_state.messages), 1)
|
||||
self.assertEqual(new_state.messages[0].type, "ai/viz")
|
||||
self.assertEqual(cast(VisualizationMessage, new_state.messages[0]).answer, self.basic_trends)
|
||||
# Verify the exception contains the expected information
|
||||
self.assertEqual(cm.exception.llm_output, '{"query": "test"}')
|
||||
self.assertEqual(cm.exception.validation_message, "Quality check failed")
|
||||
|
||||
async def test_node_leaves_failover(self):
|
||||
node = DummyGeneratorNode(self.team, self.user)
|
||||
@@ -329,20 +333,17 @@ class TestSchemaGeneratorNode(BaseTest):
|
||||
schema = DummySchema.model_construct(query=[]).model_dump() # type: ignore
|
||||
generator_model_mock.return_value = RunnableLambda(lambda _: json.dumps(schema))
|
||||
|
||||
new_state = await node.arun(
|
||||
AssistantState(
|
||||
messages=[HumanMessage(content="Text")],
|
||||
intermediate_steps=[
|
||||
(AgentAction(tool="", tool_input="", log="exception"), "exception"),
|
||||
(AgentAction(tool="", tool_input="", log="exception"), "exception"),
|
||||
],
|
||||
),
|
||||
{},
|
||||
)
|
||||
self.assertEqual(new_state.intermediate_steps, None)
|
||||
self.assertEqual(len(new_state.messages), 1)
|
||||
self.assertIsInstance(new_state.messages[0], FailureMessage)
|
||||
self.assertEqual(new_state.plan, None)
|
||||
with self.assertRaises(SchemaGenerationException):
|
||||
await node.arun(
|
||||
AssistantState(
|
||||
messages=[HumanMessage(content="Text")],
|
||||
intermediate_steps=[
|
||||
(AgentAction(tool="", tool_input="", log="exception"), "exception"),
|
||||
(AgentAction(tool="", tool_input="", log="exception"), "exception"),
|
||||
],
|
||||
),
|
||||
{},
|
||||
)
|
||||
|
||||
async def test_agent_reconstructs_conversation_with_failover(self):
|
||||
action = AgentAction(tool="fix", tool_input="validation error", log="exception")
|
||||
|
||||
@@ -11,7 +11,6 @@ from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from posthog.schema import (
|
||||
AssistantMessage,
|
||||
AssistantToolCallMessage,
|
||||
MaxRecordingUniversalFilters,
|
||||
NotebookUpdateMessage,
|
||||
@@ -70,11 +69,7 @@ class SessionSummarizationNode(AssistantNode):
"""Push summarization progress as reasoning messages"""
content = prepare_reasoning_progress_message(progress_message)
if content:
self.dispatcher.message(
AssistantMessage(
content=content,
)
)
self.dispatcher.update(content)

async def _stream_notebook_content(self, content: dict, state: AssistantState, partial: bool = True) -> None:
"""Stream TipTap content directly to a notebook if notebook_id is present in state."""
@@ -202,7 +197,7 @@ class _SessionSearch:
tool = await SearchSessionRecordingsTool.create_tool_class(
team=self._node._team,
user=self._node._user,
tool_call_id=self._node._parent_tool_call_id or "",
node_path=self._node.node_path,
state=state,
config=config,
context_manager=self._node.context_manager,

@@ -1,6 +1,6 @@
from langchain_core.runnables import RunnableConfig

from posthog.schema import AssistantHogQLQuery, AssistantMessage
from posthog.schema import AssistantHogQLQuery

from posthog.hogql.context import HogQLContext

@@ -25,7 +25,7 @@ class SQLGeneratorNode(HogQLGeneratorMixin, SchemaGeneratorNode[AssistantHogQLQu
return AssistantNodeName.SQL_GENERATOR

async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
self.dispatcher.message(AssistantMessage(content="Creating SQL query"))
self.dispatcher.update("Creating SQL query")
prompt = await self._construct_system_prompt()
return await super()._run_with_prompt(state, prompt, config=config)


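Both hunks above migrate transient progress text from `dispatcher.message(AssistantMessage(...))` to the lighter `dispatcher.update(...)`. A rough sketch of the resulting pattern inside a node's `arun`; the node below is illustrative, not one from the repository, and assumes the `BaseAssistantNode`/dispatcher API shown elsewhere in this diff.

```python
from langchain_core.runnables import RunnableConfig

from posthog.schema import AssistantMessage

from ee.hogai.graph.base.node import BaseAssistantNode
from ee.hogai.utils.types.base import AssistantNodeName, AssistantState, PartialAssistantState


class ProgressReportingNode(BaseAssistantNode[AssistantState, PartialAssistantState]):
    """Illustrative node showing the update-vs-message split."""

    @property
    def node_name(self) -> str:
        return AssistantNodeName.ROOT

    async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
        # Transient progress text, surfaced by the stream processor as a status update.
        self.dispatcher.update("Creating SQL query")
        # Messages that should become part of the conversation go through the returned
        # state (or dispatcher.message for intermediate AssistantMessages).
        return PartialAssistantState(messages=[AssistantMessage(content="SQL query created")])
```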
@@ -3,34 +3,43 @@ from typing import Generic
from posthog.models import Team, User

from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.hogai.graph.graph import BaseAssistantGraph
from ee.hogai.graph.base import BaseAssistantGraph
from ee.hogai.utils.types import PartialStateType, StateType
from ee.hogai.utils.types.base import AssistantGraphName

from .nodes import StateClassMixin, TaxonomyAgentNode, TaxonomyAgentToolsNode
from .toolkit import TaxonomyAgentToolkit
from .types import TaxonomyNodeName


class TaxonomyAgent(BaseAssistantGraph[StateType], Generic[StateType, PartialStateType], StateClassMixin):
class TaxonomyAgent(
BaseAssistantGraph[StateType, PartialStateType], Generic[StateType, PartialStateType], StateClassMixin
):
"""Taxonomy agent that can be configured with different node classes."""

def __init__(
self,
team: Team,
user: User,
tool_call_id: str,
loop_node_class: type["TaxonomyAgentNode"],
tools_node_class: type["TaxonomyAgentToolsNode"],
toolkit_class: type["TaxonomyAgentToolkit"],
):
# Extract the State type from the generic parameter
state_class, _ = self._get_state_class(TaxonomyAgent)
super().__init__(team, user, state_class, parent_tool_call_id=tool_call_id)

super().__init__(team, user)
self._loop_node_class = loop_node_class
self._tools_node_class = tools_node_class
self._toolkit_class = toolkit_class

@property
def state_type(self) -> type[StateType]:
# Extract the State type from the generic parameter
state_type, _ = self._get_state_class(TaxonomyAgent)
return state_type

@property
def graph_name(self) -> AssistantGraphName:
return AssistantGraphName.TAXONOMY

def compile_full_graph(self, checkpointer: DjangoCheckpointer | None = None):
"""Compile a complete taxonomy graph."""
return self.add_taxonomy_generator().compile(checkpointer=checkpointer)

@@ -40,25 +40,10 @@ TaxonomyPartialStateType = TypeVar("TaxonomyPartialStateType", bound=TaxonomyAge
|
||||
TaxonomyNodeBound = BaseAssistantNode[TaxonomyStateType, TaxonomyPartialStateType]
|
||||
|
||||
|
||||
class ParentToolCallIdMixin(BaseAssistantNode[TaxonomyStateType, TaxonomyPartialStateType]):
|
||||
_parent_tool_call_id: str | None = None
|
||||
_toolkit: TaxonomyAgentToolkit
|
||||
|
||||
async def __call__(self, state: TaxonomyStateType, config: RunnableConfig) -> TaxonomyPartialStateType | None:
|
||||
"""
|
||||
Run the assistant node and handle cancelled conversation before the node is run.
|
||||
"""
|
||||
if self._parent_tool_call_id:
|
||||
self._toolkit._parent_tool_call_id = self._parent_tool_call_id
|
||||
|
||||
return await super().__call__(state, config)
|
||||
|
||||
|
||||
class TaxonomyAgentNode(
|
||||
Generic[TaxonomyStateType, TaxonomyPartialStateType],
|
||||
StateClassMixin,
|
||||
TaxonomyUpdateDispatcherNodeMixin,
|
||||
ParentToolCallIdMixin,
|
||||
TaxonomyNodeBound,
|
||||
ABC,
|
||||
):
|
||||
@@ -173,7 +158,6 @@ class TaxonomyAgentToolsNode(
|
||||
Generic[TaxonomyStateType, TaxonomyPartialStateType],
|
||||
StateClassMixin,
|
||||
TaxonomyUpdateDispatcherNodeMixin,
|
||||
ParentToolCallIdMixin,
|
||||
TaxonomyNodeBound,
|
||||
):
|
||||
"""Base tools node for taxonomy agents."""
|
||||
|
||||
@@ -36,14 +36,13 @@ class TestTaxonomyAgent(BaseTest):
|
||||
self.mock_graph = Mock()
|
||||
|
||||
# Patch StateGraph in the parent class where it's actually used
|
||||
self.patcher = patch("ee.hogai.graph.graph.StateGraph")
|
||||
self.patcher = patch("ee.hogai.graph.base.graph.StateGraph")
|
||||
mock_state_graph_class = self.patcher.start()
|
||||
mock_state_graph_class.return_value = self.mock_graph
|
||||
|
||||
self.agent = ConcreteTaxonomyAgent(
|
||||
team=self.team,
|
||||
user=self.user,
|
||||
tool_call_id="test_tool_call_id",
|
||||
loop_node_class=MockTaxonomyAgentNode,
|
||||
tools_node_class=MockTaxonomyAgentToolsNode,
|
||||
toolkit_class=MockTaxonomyAgentToolkit,
|
||||
@@ -75,7 +74,6 @@ class TestTaxonomyAgent(BaseTest):
|
||||
NonGenericAgent(
|
||||
team=self.team,
|
||||
user=self.user,
|
||||
tool_call_id="test_tool_call_id",
|
||||
loop_node_class=MockTaxonomyAgentNode,
|
||||
tools_node_class=MockTaxonomyAgentToolsNode,
|
||||
toolkit_class=MockTaxonomyAgentToolkit,
|
||||
|
||||
@@ -122,10 +122,6 @@ class TaxonomyTaskExecutorNode(
|
||||
Task executor node specifically for taxonomy operations.
|
||||
"""
|
||||
|
||||
def __init__(self, team: Team, user: User, parent_tool_call_id: str | None):
|
||||
super().__init__(team, user)
|
||||
self._parent_tool_call_id = parent_tool_call_id
|
||||
|
||||
@property
|
||||
def node_name(self) -> MaxNodeName:
|
||||
return TaxonomyNodeName.TASK_EXECUTOR
|
||||
@@ -148,8 +144,6 @@ class TaxonomyTaskExecutorNode(
|
||||
class TaxonomyAgentToolkit:
|
||||
"""Base toolkit for taxonomy agents that handle tool execution."""
|
||||
|
||||
_parent_tool_call_id: str | None
|
||||
|
||||
def __init__(self, team: Team, user: User):
|
||||
self._team = team
|
||||
self._user = user
|
||||
@@ -391,7 +385,7 @@ class TaxonomyAgentToolkit:
|
||||
)
|
||||
)
|
||||
message = AssistantMessage(content="", id=str(uuid4()), tool_calls=tool_calls)
|
||||
executor = TaxonomyTaskExecutorNode(self._team, self._user, parent_tool_call_id=self._parent_tool_call_id)
|
||||
executor = TaxonomyTaskExecutorNode(self._team, self._user)
|
||||
result = await executor.arun(AssistantState(messages=[message]), RunnableConfig())
|
||||
|
||||
final_result = {}
|
||||
@@ -651,7 +645,7 @@ class TaxonomyAgentToolkit:
|
||||
for event_name_or_action_id in event_name_or_action_ids
|
||||
]
|
||||
message = AssistantMessage(content="", id=str(uuid4()), tool_calls=tool_calls)
|
||||
executor = TaxonomyTaskExecutorNode(self._team, self._user, parent_tool_call_id=self._parent_tool_call_id)
|
||||
executor = TaxonomyTaskExecutorNode(self._team, self._user)
|
||||
result = await executor.arun(AssistantState(messages=[message]), RunnableConfig())
|
||||
return {task_result.id: task_result.result for task_result in result.task_results}
|
||||
|
||||
|
||||
@@ -1,29 +0,0 @@
from posthog.test.base import BaseTest

from langchain_core.runnables import RunnableLambda
from langgraph.checkpoint.memory import InMemorySaver

from ee.hogai.graph.graph import BaseAssistantGraph
from ee.hogai.utils.types import AssistantNodeName, AssistantState, PartialAssistantState


class TestAssistantGraph(BaseTest):
async def test_pydantic_state_resets_with_none(self):
"""When a None field is set, it should be reset to None."""

async def runnable(state: AssistantState) -> PartialAssistantState:
return PartialAssistantState(start_id=None)

graph = BaseAssistantGraph(self.team, self.user, state_type=AssistantState)
compiled_graph = (
graph.add_node(AssistantNodeName.ROOT, RunnableLambda(runnable))
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
.compile(checkpointer=InMemorySaver())
)
state = await compiled_graph.ainvoke(
AssistantState(messages=[], graph_status="resumed", start_id=None),
{"configurable": {"thread_id": "test"}},
)
self.assertEqual(state["start_id"], None)
self.assertEqual(state["graph_status"], "resumed")
@@ -4,8 +4,7 @@ Integration tests for dispatcher usage in BaseAssistantNode and graph execution.
|
||||
These tests ensure that the dispatcher pattern works correctly end-to-end in real graph execution.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from posthog.test.base import BaseTest
|
||||
from unittest.mock import MagicMock, patch
|
||||
@@ -14,30 +13,40 @@ from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from posthog.schema import AssistantMessage
|
||||
|
||||
from ee.hogai.graph.base import BaseAssistantNode
|
||||
from ee.hogai.utils.dispatcher import MessageAction, NodeStartAction
|
||||
from ee.hogai.utils.types import AssistantNodeName
|
||||
from ee.hogai.utils.types.base import AssistantDispatcherEvent, AssistantState, PartialAssistantState
|
||||
from ee.hogai.graph.base.node import BaseAssistantNode
|
||||
from ee.hogai.utils.types.base import (
|
||||
AssistantDispatcherEvent,
|
||||
AssistantGraphName,
|
||||
AssistantNodeName,
|
||||
AssistantState,
|
||||
MessageAction,
|
||||
NodeEndAction,
|
||||
NodePath,
|
||||
NodeStartAction,
|
||||
PartialAssistantState,
|
||||
UpdateAction,
|
||||
)
|
||||
|
||||
|
||||
class MockAssistantNode(BaseAssistantNode):
class MockAssistantNode(BaseAssistantNode[AssistantState, PartialAssistantState]):
    """Mock node for testing dispatcher integration."""

    def __init__(self, team, user):
        super().__init__(team, user)
    def __init__(self, team, user, node_path=None):
        super().__init__(team, user, node_path)
        self.arun_called = False
        self.messages_dispatched = []

    @property
    def node_name(self) -> AssistantNodeName:
    def node_name(self) -> str:
        return AssistantNodeName.ROOT

    async def arun(self, state, config: RunnableConfig) -> PartialAssistantState:
    async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
        self.arun_called = True

        # Simulate dispatching messages during execution
        self.dispatcher.message(AssistantMessage(content="Processing..."))
        self.dispatcher.message(AssistantMessage(content="Done!"))
        self.dispatcher.update("Processing...")
        self.dispatcher.message(AssistantMessage(content="Intermediate result"))
        self.dispatcher.update("Done!")

        return PartialAssistantState(messages=[AssistantMessage(content="Final result")])

@@ -57,147 +66,169 @@ class TestDispatcherIntegration(BaseTest):
        node = MockAssistantNode(self.mock_team, self.mock_user)

        state = AssistantState(messages=[])
        config = RunnableConfig()
        config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_1"})

        await node.arun(state, config)
        with patch("ee.hogai.utils.dispatcher.get_stream_writer", side_effect=RuntimeError("Not streaming")):
            await node(state, config)

        self.assertTrue(node.arun_called)
        self.assertIsNotNone(node.dispatcher)

    async def test_node_path_propagation(self):
        """Test that node_path is correctly set and propagated."""
        node_path = (
            NodePath(name=AssistantGraphName.ASSISTANT),
            NodePath(name=AssistantNodeName.ROOT),
        )

        node = MockAssistantNode(self.mock_team, self.mock_user, node_path=node_path)

        self.assertEqual(node.node_path, node_path)

    async def test_dispatcher_dispatches_node_start_and_end(self):
        """Test that NODE_START and NODE_END actions are dispatched."""
        node = MockAssistantNode(self.mock_team, self.mock_user)

        state = AssistantState(messages=[])
        config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_2"})

        dispatched_actions = []

        def capture_write(event):
            if isinstance(event, AssistantDispatcherEvent):
                dispatched_actions.append(event)

        with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
            await node(state, config)

        # Should have dispatched actions in this order:
        # 1. NODE_START
        # 2. UpdateAction ("Processing...")
        # 3. MessageAction (intermediate)
        # 4. UpdateAction ("Done!")
        # 5. NODE_END with final state
        self.assertGreater(len(dispatched_actions), 0)

        # Verify NODE_START is first
        self.assertIsInstance(dispatched_actions[0].action, NodeStartAction)

        # Verify NODE_END is last
        last_action = dispatched_actions[-1].action
        self.assertIsInstance(last_action, NodeEndAction)
        self.assertIsNotNone(cast(NodeEndAction, last_action).state)

    async def test_messages_dispatched_during_node_execution(self):
        """Test that messages dispatched during node execution are sent to writer."""
        node = MockAssistantNode(self.mock_team, self.mock_user)

        state = AssistantState(messages=[])
        config = RunnableConfig()
        config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_3"})

        dispatched_actions = []

        def capture_write(update: Any):
            dispatched_actions.append(update)
        def capture_write(event: AssistantDispatcherEvent):
            dispatched_actions.append(event)

        # Mock get_stream_writer to return our test writer
        with patch("ee.hogai.graph.base.get_stream_writer", return_value=capture_write):
            # Call the node (not arun) to trigger __call__ which handles dispatching
        with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
            await node(state, config)

        # Should have:
        # 1. NODE_START from __call__
        # 2. Two MESSAGE actions from arun (Processing..., Done!)
        # 3. One MESSAGE action from returned state (Final result)
        # Find the update and message actions (excluding NODE_START and NODE_END)
        update_actions = [e for e in dispatched_actions if isinstance(e.action, UpdateAction)]
        message_actions = [e for e in dispatched_actions if isinstance(e.action, MessageAction)]

        self.assertEqual(len(dispatched_actions), 4)
        self.assertIsInstance(dispatched_actions[0], AssistantDispatcherEvent)
        self.assertIsInstance(dispatched_actions[0].action, NodeStartAction)
        self.assertIsInstance(dispatched_actions[1], AssistantDispatcherEvent)
        self.assertIsInstance(dispatched_actions[1].action, MessageAction)
        self.assertEqual(dispatched_actions[1].action.message.content, "Processing...")
        self.assertIsInstance(dispatched_actions[2], AssistantDispatcherEvent)
        self.assertIsInstance(dispatched_actions[2].action, MessageAction)
        self.assertEqual(dispatched_actions[2].action.message.content, "Done!")
        self.assertIsInstance(dispatched_actions[3], AssistantDispatcherEvent)
        self.assertIsInstance(dispatched_actions[3].action, MessageAction)
        self.assertEqual(dispatched_actions[3].action.message.content, "Final result")
        # Should have 2 updates: "Processing..." and "Done!"
        self.assertEqual(len(update_actions), 2)
        self.assertEqual(cast(UpdateAction, update_actions[0].action).content, "Processing...")
        self.assertEqual(cast(UpdateAction, update_actions[1].action).content, "Done!")

    async def test_node_start_action_dispatched(self):
        """Test that NODE_START action is dispatched at node entry."""
        # Should have 1 message: intermediate (final message is in NODE_END state, not dispatched separately)
        self.assertEqual(len(message_actions), 1)
        msg = cast(MessageAction, message_actions[0].action).message
        self.assertEqual(cast(AssistantMessage, msg).content, "Intermediate result")

async def test_node_path_included_in_dispatched_events(self):
|
||||
"""Test that node_path is included in all dispatched events."""
|
||||
node_path = (
|
||||
NodePath(name=AssistantGraphName.ASSISTANT),
|
||||
NodePath(name=AssistantNodeName.ROOT),
|
||||
)
|
||||
|
||||
node = MockAssistantNode(self.mock_team, self.mock_user, node_path=node_path)
|
||||
|
||||
state = AssistantState(messages=[])
|
||||
config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_4"})
|
||||
|
||||
dispatched_events = []
|
||||
|
||||
def capture_write(event: AssistantDispatcherEvent):
|
||||
dispatched_events.append(event)
|
||||
|
||||
with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
|
||||
await node(state, config)
|
||||
|
||||
# Verify all events have the correct node_path
|
||||
for event in dispatched_events:
|
||||
self.assertEqual(event.node_path, node_path)
|
||||
|
||||
async def test_node_run_id_included_in_dispatched_events(self):
|
||||
"""Test that node_run_id is included in all dispatched events."""
|
||||
node = MockAssistantNode(self.mock_team, self.mock_user)
|
||||
|
||||
state = AssistantState(messages=[])
|
||||
config = RunnableConfig()
|
||||
checkpoint_ns = "checkpoint_xyz_789"
|
||||
config = RunnableConfig(
|
||||
metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": checkpoint_ns}
|
||||
)
|
||||
|
||||
dispatched_actions = []
|
||||
dispatched_events = []
|
||||
|
||||
def capture_write(update):
|
||||
if isinstance(update, AssistantDispatcherEvent):
|
||||
dispatched_actions.append(update.action)
|
||||
def capture_write(event: AssistantDispatcherEvent):
|
||||
dispatched_events.append(event)
|
||||
|
||||
with patch("ee.hogai.graph.base.get_stream_writer", return_value=capture_write):
|
||||
with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
|
||||
await node(state, config)
|
||||
|
||||
# Should have at least one NODE_START action
|
||||
from ee.hogai.utils.dispatcher import NodeStartAction
|
||||
# Verify all events have the correct node_run_id
|
||||
for event in dispatched_events:
|
||||
self.assertEqual(event.node_run_id, checkpoint_ns)
|
||||
|
||||
node_start_actions = [action for action in dispatched_actions if isinstance(action, NodeStartAction)]
|
||||
self.assertGreater(len(node_start_actions), 0)
|
||||
|
||||
@patch("ee.hogai.graph.base.get_stream_writer")
|
||||
async def test_parent_tool_call_id_propagation(self, mock_get_stream_writer):
|
||||
"""Test that parent_tool_call_id is propagated to dispatched messages."""
|
||||
parent_tool_call_id = str(uuid.uuid4())
|
||||
|
||||
class NodeWithParent(BaseAssistantNode):
|
||||
@property
|
||||
def node_name(self):
|
||||
return AssistantNodeName.ROOT
|
||||
|
||||
async def arun(self, state, config: RunnableConfig) -> PartialAssistantState:
|
||||
# Dispatch message - should inherit parent_tool_call_id from state
|
||||
self.dispatcher.message(AssistantMessage(content="Child message"))
|
||||
return PartialAssistantState(messages=[])
|
||||
|
||||
node = NodeWithParent(self.mock_team, self.mock_user)
|
||||
|
||||
state = {"messages": [], "parent_tool_call_id": parent_tool_call_id}
|
||||
config = RunnableConfig()
|
||||
|
||||
dispatched_messages = []
|
||||
|
||||
def capture_write(update):
|
||||
if isinstance(update, AssistantDispatcherEvent) and isinstance(update.action, MessageAction):
|
||||
dispatched_messages.append(update.action.message)
|
||||
|
||||
mock_get_stream_writer.return_value = capture_write
|
||||
|
||||
await node.arun(state, config)
|
||||
|
||||
# Verify dispatched messages have parent_tool_call_id
|
||||
assistant_messages = [msg for msg in dispatched_messages if isinstance(msg, AssistantMessage)]
|
||||
for msg in assistant_messages:
|
||||
# If the implementation propagates it, this should be true
|
||||
# Otherwise this test will help catch that as a potential issue
|
||||
if msg.parent_tool_call_id:
|
||||
self.assertEqual(msg.parent_tool_call_id, parent_tool_call_id)
|
||||
|
||||
@patch("ee.hogai.graph.base.get_stream_writer")
|
||||
async def test_dispatcher_error_handling(self, mock_get_stream_writer):
|
||||
async def test_dispatcher_error_handling_does_not_crash_node(self):
|
||||
"""Test that errors in dispatcher don't crash node execution."""
|
||||
|
||||
class FailingDispatcherNode(BaseAssistantNode):
|
||||
class FailingDispatcherNode(BaseAssistantNode[AssistantState, PartialAssistantState]):
|
||||
@property
|
||||
def node_name(self):
|
||||
return AssistantNodeName.ROOT
|
||||
|
||||
async def arun(self, state, config: RunnableConfig) -> PartialAssistantState:
|
||||
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
|
||||
# Try to dispatch - if writer fails, should handle gracefully
|
||||
self.dispatcher.message(AssistantMessage(content="Test"))
|
||||
self.dispatcher.update("Test")
|
||||
return PartialAssistantState(messages=[AssistantMessage(content="Result")])
|
||||
|
||||
node = FailingDispatcherNode(self.mock_team, self.mock_user)
|
||||
|
||||
state = AssistantState(messages=[])
|
||||
config = RunnableConfig()
|
||||
config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_5"})
|
||||
|
||||
# Make writer raise an error
|
||||
def failing_writer(data):
|
||||
raise RuntimeError("Writer failed")
|
||||
|
||||
mock_get_stream_writer.return_value = failing_writer
|
||||
with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=failing_writer):
|
||||
# Should not crash - node should complete
|
||||
result = await node(state, config)
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
# Should not crash - node should complete
|
||||
result = await node.arun(state, config)
|
||||
self.assertIsNotNone(result)
|
||||
async def test_messages_in_partial_state_dispatched_via_node_end(self):
|
||||
"""Test that messages in PartialState are dispatched via NODE_END."""
|
||||
|
||||
async def test_messages_in_partial_state_are_auto_dispatched(self):
|
||||
"""Test that messages in PartialState are automatically dispatched."""
|
||||
|
||||
class NodeReturningMessages(BaseAssistantNode):
|
||||
class NodeReturningMessages(BaseAssistantNode[AssistantState, PartialAssistantState]):
|
||||
@property
|
||||
def node_name(self):
|
||||
return AssistantNodeName.ROOT
|
||||
|
||||
async def arun(self, state, config: RunnableConfig) -> PartialAssistantState:
|
||||
# Return messages in state - should be auto-dispatched
|
||||
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
|
||||
# Return messages in state
|
||||
return PartialAssistantState(
|
||||
messages=[
|
||||
AssistantMessage(content="Message 1"),
|
||||
@@ -208,39 +239,139 @@ class TestDispatcherIntegration(BaseTest):
|
||||
node = NodeReturningMessages(self.mock_team, self.mock_user)
|
||||
|
||||
state = AssistantState(messages=[])
|
||||
config = RunnableConfig()
|
||||
config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_6"})
|
||||
|
||||
dispatched_messages = []
|
||||
dispatched_events = []
|
||||
|
||||
def capture_write(update):
|
||||
if isinstance(update, AssistantDispatcherEvent) and isinstance(update.action, MessageAction):
|
||||
dispatched_messages.append(update.action.message)
|
||||
def capture_write(event: AssistantDispatcherEvent):
|
||||
dispatched_events.append(event)
|
||||
|
||||
with patch("ee.hogai.graph.base.get_stream_writer", return_value=capture_write):
|
||||
with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
|
||||
await node(state, config)
|
||||
|
||||
# Should have dispatched the messages from PartialState (and NODE_START)
|
||||
# We expect at least 2 message actions (Message 1 and Message 2)
|
||||
self.assertGreaterEqual(len(dispatched_messages), 2)
|
||||
# Should have NODE_END action with state containing messages
|
||||
node_end_actions = [e for e in dispatched_events if isinstance(e.action, NodeEndAction)]
|
||||
self.assertEqual(len(node_end_actions), 1)
|
||||
|
||||
node_end_state = cast(NodeEndAction, node_end_actions[0].action).state
|
||||
self.assertIsNotNone(node_end_state)
|
||||
assert node_end_state is not None
|
||||
self.assertEqual(len(node_end_state.messages), 2)
|
||||
|
||||
async def test_node_returns_none_state_handling(self):
|
||||
"""Test that node can return None state without errors."""
|
||||
|
||||
class NoneStateNode(BaseAssistantNode):
|
||||
class NoneStateNode(BaseAssistantNode[AssistantState, PartialAssistantState]):
|
||||
@property
|
||||
def node_name(self):
|
||||
return AssistantNodeName.ROOT
|
||||
|
||||
async def arun(self, state, config: RunnableConfig) -> None:
|
||||
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
|
||||
# Dispatch a message but return None state
|
||||
self.dispatcher.message(AssistantMessage(content="Test"))
|
||||
self.dispatcher.update("Test")
|
||||
return None
|
||||
|
||||
node = NoneStateNode(self.mock_team, self.mock_user)
|
||||
|
||||
state = AssistantState(messages=[])
|
||||
config = RunnableConfig()
|
||||
config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_7"})
|
||||
|
||||
result = await node(state, config)
|
||||
# Should handle None gracefully
|
||||
self.assertIsNone(result)
|
||||
with patch("ee.hogai.utils.dispatcher.get_stream_writer", side_effect=RuntimeError("Not streaming")):
|
||||
result = await node(state, config)
|
||||
# Should handle None gracefully
|
||||
self.assertIsNone(result)
|
||||
|
||||
async def test_nested_node_path_in_dispatched_events(self):
|
||||
"""Test that nested nodes have correct node_path."""
|
||||
# Simulate a nested node path (e.g., from a tool call)
|
||||
parent_path = (
|
||||
NodePath(name=AssistantGraphName.ASSISTANT),
|
||||
NodePath(name=AssistantNodeName.ROOT, message_id="msg_123", tool_call_id="tc_123"),
|
||||
NodePath(name=AssistantGraphName.INSIGHTS),
|
||||
)
|
||||
|
||||
node = MockAssistantNode(self.mock_team, self.mock_user, node_path=parent_path)
|
||||
|
||||
state = AssistantState(messages=[])
|
||||
config = RunnableConfig(
|
||||
metadata={"langgraph_node": AssistantNodeName.TRENDS_GENERATOR, "langgraph_checkpoint_ns": "cp_8"}
|
||||
)
|
||||
|
||||
dispatched_events = []
|
||||
|
||||
def capture_write(event: AssistantDispatcherEvent):
|
||||
dispatched_events.append(event)
|
||||
|
||||
with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
|
||||
await node(state, config)
|
||||
|
||||
# Verify all events have the nested path
|
||||
for event in dispatched_events:
|
||||
self.assertIsNotNone(event.node_path)
|
||||
assert event.node_path is not None
|
||||
self.assertEqual(event.node_path, parent_path)
|
||||
self.assertEqual(len(event.node_path), 3)
|
||||
self.assertEqual(event.node_path[1].message_id, "msg_123")
|
||||
self.assertEqual(event.node_path[1].tool_call_id, "tc_123")
|
||||
|
||||
async def test_update_actions_include_node_metadata(self):
|
||||
"""Test that update actions include correct node metadata."""
|
||||
node = MockAssistantNode(self.mock_team, self.mock_user)
|
||||
|
||||
state = AssistantState(messages=[])
|
||||
config = RunnableConfig(metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_9"})
|
||||
|
||||
dispatched_events = []
|
||||
|
||||
def capture_write(event: AssistantDispatcherEvent):
|
||||
dispatched_events.append(event)
|
||||
|
||||
with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write):
|
||||
await node(state, config)
|
||||
|
||||
# Find update actions
|
||||
update_events = [e for e in dispatched_events if isinstance(e.action, UpdateAction)]
|
||||
|
||||
for event in update_events:
|
||||
self.assertEqual(event.node_name, AssistantNodeName.ROOT)
|
||||
self.assertEqual(event.node_run_id, "cp_9")
|
||||
self.assertIsNotNone(event.node_path)
|
||||
|
||||
async def test_concurrent_node_executions_independent_dispatchers(self):
|
||||
"""Test that concurrent node executions use independent dispatchers."""
|
||||
node1 = MockAssistantNode(self.mock_team, self.mock_user)
|
||||
node2 = MockAssistantNode(self.mock_team, self.mock_user)
|
||||
|
||||
state = AssistantState(messages=[])
|
||||
config1 = RunnableConfig(
|
||||
metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_10a"}
|
||||
)
|
||||
config2 = RunnableConfig(
|
||||
metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "cp_10b"}
|
||||
)
|
||||
|
||||
events1 = []
|
||||
events2 = []
|
||||
|
||||
def capture_write1(event: AssistantDispatcherEvent):
|
||||
events1.append(event)
|
||||
|
||||
def capture_write2(event: AssistantDispatcherEvent):
|
||||
events2.append(event)
|
||||
|
||||
with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write1):
|
||||
await node1(state, config1)
|
||||
|
||||
with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=capture_write2):
|
||||
await node2(state, config2)
|
||||
|
||||
# Verify events went to separate lists
|
||||
self.assertGreater(len(events1), 0)
|
||||
self.assertGreater(len(events2), 0)
|
||||
|
||||
# Verify node_run_ids are different
|
||||
for event in events1:
|
||||
self.assertEqual(event.node_run_id, "cp_10a")
|
||||
|
||||
for event in events2:
|
||||
self.assertEqual(event.node_run_id, "cp_10b")
|
||||
|
||||
@@ -1,7 +1,7 @@
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig

from posthog.schema import AssistantMessage, AssistantTrendsQuery
from posthog.schema import AssistantTrendsQuery

from ee.hogai.utils.types import AssistantState
from ee.hogai.utils.types.base import AssistantNodeName, PartialAssistantState
@@ -25,7 +25,7 @@ class TrendsGeneratorNode(SchemaGeneratorNode[AssistantTrendsQuery]):
        return AssistantNodeName.TRENDS_GENERATOR

    async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
        self.dispatcher.message(AssistantMessage(content="Creating trends query"))
        self.dispatcher.update("Creating trends query")
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", TRENDS_SYSTEM_PROMPT),

@@ -9,7 +9,7 @@ from posthog.test.base import (
|
||||
_create_person,
|
||||
flush_persons_and_events,
|
||||
)
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from django.test import override_settings
|
||||
|
||||
@@ -61,13 +61,13 @@ from posthog.models import Action
|
||||
|
||||
from ee.hogai.assistant.base import BaseAssistant
|
||||
from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
|
||||
from ee.hogai.graph.base import BaseAssistantNode
|
||||
from ee.hogai.graph.base import AssistantNode
|
||||
from ee.hogai.graph.funnels.nodes import FunnelsSchemaGeneratorOutput
|
||||
from ee.hogai.graph.insights_graph.graph import InsightsGraph
|
||||
from ee.hogai.graph.memory import prompts as memory_prompts
|
||||
from ee.hogai.graph.retention.nodes import RetentionSchemaGeneratorOutput
|
||||
from ee.hogai.graph.root.nodes import SLASH_COMMAND_INIT
|
||||
from ee.hogai.graph.trends.nodes import TrendsSchemaGeneratorOutput
|
||||
from ee.hogai.utils.state import GraphValueUpdateTuple
|
||||
from ee.hogai.utils.tests import FakeAnthropicRunnableLambdaWithTokenCounter, FakeChatAnthropic, FakeChatOpenAI
|
||||
from ee.hogai.utils.types import (
|
||||
AssistantMode,
|
||||
@@ -76,16 +76,21 @@ from ee.hogai.utils.types import (
|
||||
AssistantState,
|
||||
PartialAssistantState,
|
||||
)
|
||||
from ee.hogai.utils.types.base import ReplaceMessages
|
||||
from ee.models.assistant import Conversation, CoreMemory
|
||||
|
||||
from ..assistant import Assistant
|
||||
from ..graph import AssistantGraph, InsightsAssistantGraph
|
||||
from ..graph.graph import AssistantGraph
|
||||
|
||||
title_generator_mock = patch(
|
||||
"ee.hogai.graph.title_generator.nodes.TitleGeneratorNode._model",
|
||||
return_value=FakeChatOpenAI(responses=[messages.AIMessage(content="Title")]),
|
||||
)
|
||||
|
||||
query_executor_mock = patch(
|
||||
"ee.hogai.graph.query_executor.nodes.QueryExecutorNode._format_query_result", new=MagicMock(return_value="Result")
|
||||
)
|
||||
|
||||
|
||||
class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
CLASS_DATA_LEVEL_SETUP = False
|
||||
@@ -118,7 +123,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
),
|
||||
).start()
|
||||
|
||||
self.checkpointer_patch = patch("ee.hogai.graph.graph.global_checkpointer", new=DjangoCheckpointer())
|
||||
self.checkpointer_patch = patch("ee.hogai.graph.base.graph.global_checkpointer", new=DjangoCheckpointer())
|
||||
self.checkpointer_patch.start()
|
||||
|
||||
def tearDown(self):
|
||||
@@ -232,14 +237,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
graph = (
|
||||
AssistantGraph(self.team, self.user)
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_root(
|
||||
{
|
||||
"insights": AssistantNodeName.INSIGHTS_SUBGRAPH,
|
||||
"root": AssistantNodeName.ROOT,
|
||||
"end": AssistantNodeName.END,
|
||||
}
|
||||
)
|
||||
.add_insights(AssistantNodeName.ROOT)
|
||||
.add_root()
|
||||
.compile()
|
||||
)
|
||||
|
||||
@@ -269,7 +267,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
{
|
||||
"id": "1",
|
||||
"name": "create_and_query_insight",
|
||||
"args": {"query_description": "Foobar", "query_kind": insight_type},
|
||||
"args": {"query_description": "Foobar"},
|
||||
}
|
||||
],
|
||||
)
|
||||
@@ -302,7 +300,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
AssistantToolCall(
|
||||
id="1",
|
||||
name="create_and_query_insight",
|
||||
args={"query_description": "Foobar", "query_kind": insight_type},
|
||||
args={"query_description": "Foobar"},
|
||||
)
|
||||
],
|
||||
),
|
||||
@@ -327,12 +325,12 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
snapshot: StateSnapshot = await graph.aget_state(config)
|
||||
self.assertFalse(snapshot.next)
|
||||
self.assertFalse(snapshot.values.get("intermediate_steps"))
|
||||
self.assertFalse(snapshot.values["plan"])
|
||||
self.assertFalse(snapshot.values["graph_status"])
|
||||
self.assertFalse(snapshot.values["root_tool_call_id"])
|
||||
self.assertFalse(snapshot.values["root_tool_insight_plan"])
|
||||
self.assertFalse(snapshot.values["root_tool_insight_type"])
|
||||
self.assertFalse(snapshot.values["root_tool_calls_count"])
|
||||
self.assertFalse(snapshot.values.get("plan"))
|
||||
self.assertFalse(snapshot.values.get("graph_status"))
|
||||
self.assertFalse(snapshot.values.get("root_tool_call_id"))
|
||||
self.assertFalse(snapshot.values.get("root_tool_insight_plan"))
|
||||
self.assertFalse(snapshot.values.get("root_tool_insight_type"))
|
||||
self.assertFalse(snapshot.values.get("root_tool_calls_count"))
|
||||
|
||||
async def test_trends_interrupt_when_asking_for_help(self):
|
||||
await self._test_human_in_the_loop("trends")
|
||||
@@ -345,7 +343,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
|
||||
async def test_ai_messages_appended_after_interrupt(self):
|
||||
with patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model") as mock:
|
||||
graph = InsightsAssistantGraph(self.team, self.user).compile_full_graph()
|
||||
graph = InsightsGraph(self.team, self.user).compile_full_graph()
|
||||
config: RunnableConfig = {
|
||||
"configurable": {
|
||||
"thread_id": self.conversation.id,
|
||||
@@ -441,16 +439,13 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
self.assertIsInstance(output[1][1], FailureMessage)
|
||||
|
||||
async def test_new_conversation_handles_serialized_conversation(self):
|
||||
from ee.hogai.graph.base import BaseAssistantNode
|
||||
|
||||
class TestNode(BaseAssistantNode):
|
||||
class TestNode(AssistantNode):
|
||||
@property
|
||||
def node_name(self):
|
||||
return AssistantNodeName.ROOT
|
||||
|
||||
async def arun(self, state, config):
|
||||
self.dispatcher.message(AssistantMessage(content="Hello"))
|
||||
return PartialAssistantState()
|
||||
return PartialAssistantState(messages=[AssistantMessage(content="Hello", id=str(uuid4()))])
|
||||
|
||||
test_node = TestNode(self.team, self.user)
|
||||
graph = (
|
||||
@@ -478,16 +473,13 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
self.assertNotEqual(output[0][0], "conversation")
|
||||
|
||||
async def test_async_stream(self):
|
||||
from ee.hogai.graph.base import BaseAssistantNode
|
||||
|
||||
class TestNode(BaseAssistantNode):
|
||||
class TestNode(AssistantNode):
|
||||
@property
|
||||
def node_name(self):
|
||||
return AssistantNodeName.ROOT
|
||||
|
||||
async def arun(self, state, config):
|
||||
self.dispatcher.message(AssistantMessage(content="bar"))
|
||||
return PartialAssistantState()
|
||||
return PartialAssistantState(messages=[AssistantMessage(content="bar", id=str(uuid4()))])
|
||||
|
||||
test_node = TestNode(self.team, self.user)
|
||||
graph = (
|
||||
@@ -517,12 +509,13 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
self.assertConversationEqual(actual_output, expected_output)
|
||||
|
||||
async def test_async_stream_handles_exceptions(self):
|
||||
def node_handler(state):
|
||||
raise ValueError()
|
||||
class NodeHandler(AssistantNode):
|
||||
async def arun(self, state, config):
|
||||
raise ValueError
|
||||
|
||||
graph = (
|
||||
AssistantGraph(self.team, self.user)
|
||||
.add_node(AssistantNodeName.ROOT, node_handler)
|
||||
.add_node(AssistantNodeName.ROOT, NodeHandler(self.team, self.user))
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
|
||||
.compile()
|
||||
@@ -536,12 +529,11 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
("message", HumanMessage(content="foo")),
|
||||
("message", FailureMessage()),
|
||||
]
|
||||
actual_output = []
|
||||
async for event in assistant.astream():
|
||||
actual_output.append(event)
|
||||
actual_output, _ = await self._run_assistant_graph(graph, message="foo")
|
||||
self.assertConversationEqual(actual_output, expected_output)
|
||||
|
||||
@title_generator_mock
|
||||
@query_executor_mock
|
||||
@patch("ee.hogai.graph.schema_generator.nodes.SchemaGeneratorNode._model")
|
||||
@patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model")
|
||||
@patch("ee.hogai.graph.root.nodes.RootNode._get_model")
|
||||
@@ -556,7 +548,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
{
|
||||
"id": "xyz",
|
||||
"name": "create_and_query_insight",
|
||||
"args": {"query_description": "Foobar", "query_kind": "trends"},
|
||||
"args": {"query_description": "Foobar"},
|
||||
}
|
||||
],
|
||||
)
|
||||
@@ -596,7 +588,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
AssistantToolCall(
|
||||
id="xyz",
|
||||
name="create_and_query_insight",
|
||||
args={"query_description": "Foobar", "query_kind": "trends"},
|
||||
args={"query_description": "Foobar"},
|
||||
)
|
||||
],
|
||||
),
|
||||
@@ -611,7 +603,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
),
|
||||
("update", AssistantUpdateEvent(id="message_2", tool_call_id="xyz", content="Creating trends query")),
|
||||
("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")),
|
||||
("message", {"tool_call_id": "xyz", "type": "tool"}), # Don't check content as it's implementation detail
|
||||
("message", AssistantToolCallMessage(tool_call_id="xyz", content="Result")),
|
||||
("message", AssistantMessage(content="The results indicate a great future for you.")),
|
||||
]
|
||||
self.assertConversationEqual(actual_output, expected_output)
|
||||
@@ -634,6 +626,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
)
|
||||
|
||||
@title_generator_mock
|
||||
@query_executor_mock
|
||||
@patch("ee.hogai.graph.schema_generator.nodes.SchemaGeneratorNode._model")
|
||||
@patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model")
|
||||
@patch("ee.hogai.graph.root.nodes.RootNode._get_model")
|
||||
@@ -648,7 +641,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
{
|
||||
"id": "xyz",
|
||||
"name": "create_and_query_insight",
|
||||
"args": {"query_description": "Foobar", "query_kind": "funnel"},
|
||||
"args": {"query_description": "Foobar"},
|
||||
}
|
||||
],
|
||||
)
|
||||
@@ -693,7 +686,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
AssistantToolCall(
|
||||
id="xyz",
|
||||
name="create_and_query_insight",
|
||||
args={"query_description": "Foobar", "query_kind": "funnel"},
|
||||
args={"query_description": "Foobar"},
|
||||
)
|
||||
],
|
||||
),
|
||||
@@ -708,7 +701,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
),
|
||||
("update", AssistantUpdateEvent(id="message_2", tool_call_id="xyz", content="Creating funnel query")),
|
||||
("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")),
|
||||
("message", {"tool_call_id": "xyz", "type": "tool"}), # Don't check content as it's implementation detail
|
||||
("message", AssistantToolCallMessage(tool_call_id="xyz", content="Result")),
|
||||
("message", AssistantMessage(content="The results indicate a great future for you.")),
|
||||
]
|
||||
self.assertConversationEqual(actual_output, expected_output)
|
||||
@@ -731,6 +724,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
)
|
||||
|
||||
@title_generator_mock
|
||||
@query_executor_mock
|
||||
@patch("ee.hogai.graph.schema_generator.nodes.SchemaGeneratorNode._model")
|
||||
@patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model")
|
||||
@patch("ee.hogai.graph.root.nodes.RootNode._get_model")
|
||||
@@ -747,7 +741,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
{
|
||||
"id": "xyz",
|
||||
"name": "create_and_query_insight",
|
||||
"args": {"query_description": "Foobar", "query_kind": "retention"},
|
||||
"args": {"query_description": "Foobar"},
|
||||
}
|
||||
],
|
||||
)
|
||||
@@ -792,7 +786,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
AssistantToolCall(
|
||||
id="xyz",
|
||||
name="create_and_query_insight",
|
||||
args={"query_description": "Foobar", "query_kind": "retention"},
|
||||
args={"query_description": "Foobar"},
|
||||
)
|
||||
],
|
||||
),
|
||||
@@ -807,7 +801,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
),
|
||||
("update", AssistantUpdateEvent(id="message_2", tool_call_id="xyz", content="Creating retention query")),
|
||||
("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")),
|
||||
("message", {"tool_call_id": "xyz", "type": "tool"}), # Don't check content as it's implementation detail
|
||||
("message", AssistantToolCallMessage(tool_call_id="xyz", content="Result")),
|
||||
("message", AssistantMessage(content="The results indicate a great future for you.")),
|
||||
]
|
||||
self.assertConversationEqual(actual_output, expected_output)
|
||||
@@ -830,6 +824,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
)
|
||||
|
||||
@title_generator_mock
|
||||
@query_executor_mock
|
||||
@patch("ee.hogai.graph.schema_generator.nodes.SchemaGeneratorNode._model")
|
||||
@patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model")
|
||||
@patch("ee.hogai.graph.root.nodes.RootNode._get_model")
|
||||
@@ -844,7 +839,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
{
|
||||
"id": "xyz",
|
||||
"name": "create_and_query_insight",
|
||||
"args": {"query_description": "Foobar", "query_kind": "sql"},
|
||||
"args": {"query_description": "Foobar"},
|
||||
}
|
||||
],
|
||||
)
|
||||
@@ -883,7 +878,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
AssistantToolCall(
|
||||
id="xyz",
|
||||
name="create_and_query_insight",
|
||||
args={"query_description": "Foobar", "query_kind": "sql"},
|
||||
args={"query_description": "Foobar"},
|
||||
)
|
||||
],
|
||||
),
|
||||
@@ -898,7 +893,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
),
|
||||
("update", AssistantUpdateEvent(id="message_2", tool_call_id="xyz", content="Creating SQL query")),
|
||||
("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")),
|
||||
("message", {"tool_call_id": "xyz", "type": "tool"}), # Don't check content as it's implementation detail
|
||||
("message", AssistantToolCallMessage(tool_call_id="xyz", content="Result")),
|
||||
("message", AssistantMessage(content="The results indicate a great future for you.")),
|
||||
]
|
||||
self.assertConversationEqual(actual_output, expected_output)
|
||||
@@ -1096,7 +1091,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
{
|
||||
"id": str(uuid4()),
|
||||
"name": "create_and_query_insight",
|
||||
"args": {"query_description": "Foobar", "query_kind": "trends"},
|
||||
"args": {"query_description": "Foobar"},
|
||||
}
|
||||
],
|
||||
)
|
||||
@@ -1138,7 +1133,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
graph = (
|
||||
AssistantGraph(self.team, self.user)
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_root({"root": AssistantNodeName.ROOT, "end": AssistantNodeName.END})
|
||||
.add_root()
|
||||
.compile()
|
||||
)
|
||||
self.assertEqual(self.conversation.status, Conversation.Status.IDLE)
|
||||
@@ -1158,7 +1153,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
graph = (
|
||||
AssistantGraph(self.team, self.user)
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_root({"root": AssistantNodeName.ROOT, "end": AssistantNodeName.END})
|
||||
.add_root()
|
||||
.compile()
|
||||
)
|
||||
|
||||
@@ -1177,7 +1172,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
{
|
||||
"id": "1",
|
||||
"name": "create_and_query_insight",
|
||||
"args": {"query_description": "Foobar", "query_kind": "trends"},
|
||||
"args": {"query_description": "Foobar"},
|
||||
}
|
||||
],
|
||||
)
|
||||
@@ -1209,14 +1204,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
graph = (
|
||||
AssistantGraph(self.team, self.user)
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_root(
|
||||
{
|
||||
"search_documentation": AssistantNodeName.INKEEP_DOCS,
|
||||
"root": AssistantNodeName.ROOT,
|
||||
"end": AssistantNodeName.END,
|
||||
}
|
||||
)
|
||||
.add_inkeep_docs()
|
||||
.add_root()
|
||||
.compile()
|
||||
)
|
||||
|
||||
@@ -1235,45 +1223,25 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
self.assertConversationEqual(
|
||||
output,
|
||||
[
|
||||
(
|
||||
"message",
|
||||
HumanMessage(
|
||||
content="How do I use feature flags?",
|
||||
id="d93b3d57-2ff3-4c63-8635-8349955b16f0",
|
||||
parent_tool_call_id=None,
|
||||
type="human",
|
||||
ui_context=None,
|
||||
),
|
||||
),
|
||||
("message", HumanMessage(content="How do I use feature flags?")),
|
||||
(
|
||||
"message",
|
||||
AssistantMessage(
|
||||
content="",
|
||||
id="ea1f4faf-85b4-4d98-b044-842b7092ba5d",
|
||||
meta=None,
|
||||
parent_tool_call_id=None,
|
||||
tool_calls=[
|
||||
AssistantToolCall(
|
||||
args={"kind": "docs", "query": "test"}, id="1", name="search", type="tool_call"
|
||||
)
|
||||
],
|
||||
type="ai",
|
||||
),
|
||||
),
|
||||
(
|
||||
"update",
|
||||
AssistantUpdateEvent(id="message_2", tool_call_id="1", content="Checking PostHog documentation..."),
|
||||
AssistantUpdateEvent(content="Checking PostHog documentation...", id="1", tool_call_id="1"),
|
||||
),
|
||||
(
|
||||
"message",
|
||||
AssistantToolCallMessage(
|
||||
content="Checking PostHog documentation...",
|
||||
id="00149847-0bcd-4b7c-a514-bab176312ae8",
|
||||
parent_tool_call_id=None,
|
||||
tool_call_id="1",
|
||||
type="tool",
|
||||
ui_payload=None,
|
||||
),
|
||||
AssistantToolCallMessage(content="Checking PostHog documentation...", tool_call_id="1"),
|
||||
),
|
||||
(
|
||||
"message",
|
||||
@@ -1283,12 +1251,10 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
)
|
||||
|
||||
@title_generator_mock
|
||||
@query_executor_mock
|
||||
@patch("ee.hogai.graph.schema_generator.nodes.SchemaGeneratorNode._model")
|
||||
@patch("ee.hogai.graph.query_planner.nodes.QueryPlannerNode._get_model")
|
||||
@patch("ee.hogai.graph.graph.QueryExecutorNode")
|
||||
async def test_insights_tool_mode_flow(
|
||||
self, query_executor_mock, planner_mock, generator_mock, title_generator_mock
|
||||
):
|
||||
async def test_insights_tool_mode_flow(self, planner_mock, generator_mock, title_generator_mock):
|
||||
"""Test that the insights tool mode works correctly."""
|
||||
query = AssistantTrendsQuery(series=[])
|
||||
tool_call_id = str(uuid4())
|
||||
@@ -1315,25 +1281,6 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
)
|
||||
generator_mock.return_value = RunnableLambda(lambda _: TrendsSchemaGeneratorOutput(query=query))
|
||||
|
||||
class QueryExecutorNodeMock(BaseAssistantNode):
|
||||
def __init__(self, team, user):
|
||||
super().__init__(team, user)
|
||||
|
||||
@property
|
||||
def node_name(self):
|
||||
return AssistantNodeName.QUERY_EXECUTOR
|
||||
|
||||
async def arun(self, state, config):
|
||||
return PartialAssistantState(
|
||||
messages=[
|
||||
AssistantToolCallMessage(
|
||||
content="The results indicate a great future for you.", tool_call_id=tool_call_id
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
query_executor_mock.return_value = QueryExecutorNodeMock(self.team, self.user)
|
||||
|
||||
# Run in insights tool mode
|
||||
output, _ = await self._run_assistant_graph(
|
||||
conversation=self.conversation,
|
||||
@@ -1347,9 +1294,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")),
|
||||
(
|
||||
"message",
|
||||
AssistantToolCallMessage(
|
||||
content="The results indicate a great future for you.", tool_call_id=tool_call_id
|
||||
),
|
||||
AssistantToolCallMessage(content="Result", tool_call_id=tool_call_id),
|
||||
),
|
||||
]
|
||||
self.assertConversationEqual(output, expected_output)
|
||||
@@ -1387,7 +1332,6 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
|
||||
async def test_merges_messages_with_same_id(self):
|
||||
"""Test that messages with the same ID are merged into one."""
|
||||
from ee.hogai.graph.base import BaseAssistantNode
|
||||
|
||||
message_ids = [str(uuid4()), str(uuid4())]
|
||||
|
||||
@@ -1395,7 +1339,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
first_content = "First version of message"
|
||||
updated_content = "Updated version of message"
|
||||
|
||||
class MessageUpdatingNode(BaseAssistantNode):
|
||||
class MessageUpdatingNode(AssistantNode):
|
||||
def __init__(self, team, user):
|
||||
super().__init__(team, user)
|
||||
self.call_count = 0
|
||||
@@ -1447,7 +1391,6 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
|
||||
async def test_assistant_filters_messages_correctly(self):
|
||||
"""Test that the Assistant class correctly filters messages based on should_output_assistant_message."""
|
||||
from ee.hogai.graph.base import BaseAssistantNode
|
||||
|
||||
output_messages = [
|
||||
# Should be output (has content)
|
||||
@@ -1467,7 +1410,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
|
||||
for test_message, expected_in_output in output_messages:
|
||||
# Create a simple graph that produces different message types to test filtering
|
||||
class MessageFilteringNode(BaseAssistantNode):
|
||||
class MessageFilteringNode(AssistantNode):
|
||||
def __init__(self, team, user, message_to_return):
|
||||
super().__init__(team, user)
|
||||
self.message_to_return = message_to_return
|
||||
@@ -1477,8 +1420,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
return AssistantNodeName.ROOT
|
||||
|
||||
async def arun(self, state, config):
|
||||
self.dispatcher.message(self.message_to_return)
|
||||
return PartialAssistantState()
|
||||
return PartialAssistantState(messages=[self.message_to_return])
|
||||
|
||||
# Create a graph with our test node
|
||||
node = MessageFilteringNode(self.team, self.user, test_message)
|
||||
@@ -1505,12 +1447,13 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
"""Test that ui_context persists when retrieving conversation state across multiple runs."""
|
||||
|
||||
# Create a simple graph that just returns the initial state
|
||||
def return_initial_state(state):
|
||||
return {"messages": [AssistantMessage(content="Response from assistant")]}
|
||||
class ReturnInitialStateNode(AssistantNode):
|
||||
async def arun(self, state, config):
|
||||
return PartialAssistantState(messages=[AssistantMessage(content="Response from assistant")])
|
||||
|
||||
graph = (
|
||||
AssistantGraph(self.team, self.user)
|
||||
.add_node(AssistantNodeName.ROOT, return_initial_state)
|
||||
.add_node(AssistantNodeName.ROOT, ReturnInitialStateNode(self.team, self.user))
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
|
||||
.compile()
|
||||
@@ -1585,8 +1528,8 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "xyz",
|
||||
"name": "edit_current_insight",
|
||||
"args": {"query_description": "Foobar", "query_kind": "trends"},
|
||||
"name": "create_and_query_insight",
|
||||
"args": {"query_description": "Foobar"},
|
||||
}
|
||||
],
|
||||
)
|
||||
@@ -1622,20 +1565,13 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
output, assistant = await self._run_assistant_graph(
|
||||
test_graph=AssistantGraph(self.team, self.user)
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_root(
|
||||
{
|
||||
"root": AssistantNodeName.ROOT,
|
||||
"insights": AssistantNodeName.INSIGHTS_SUBGRAPH,
|
||||
"end": AssistantNodeName.END,
|
||||
}
|
||||
)
|
||||
.add_insights()
|
||||
.add_root()
|
||||
.compile(),
|
||||
conversation=self.conversation,
|
||||
is_new_conversation=True,
|
||||
message="Hello",
|
||||
mode=AssistantMode.ASSISTANT,
|
||||
contextual_tools={"edit_current_insight": {"current_query": "query"}},
|
||||
contextual_tools={"create_and_query_insight": {"current_query": "query"}},
|
||||
)
|
||||
|
||||
expected_output = [
|
||||
@@ -1645,23 +1581,30 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
"message",
|
||||
AssistantMessage(
|
||||
content="",
|
||||
id="56076433-5d90-4248-9a46-df3fda42bd0a",
|
||||
tool_calls=[
|
||||
AssistantToolCall(
|
||||
args={"query_description": "Foobar"},
|
||||
id="xyz",
|
||||
name="edit_current_insight",
|
||||
args={"query_description": "Foobar", "query_kind": "trends"},
|
||||
name="create_and_query_insight",
|
||||
type="tool_call",
|
||||
)
|
||||
],
|
||||
),
|
||||
),
|
||||
(
|
||||
"update",
|
||||
AssistantUpdateEvent(content="Picking relevant events and properties", tool_call_id="xyz", id=""),
|
||||
),
|
||||
("update", AssistantUpdateEvent(content="Creating trends query", tool_call_id="xyz", id="")),
|
||||
("message", VisualizationMessage(query="Foobar", answer=query, plan="Plan")),
|
||||
(
|
||||
"message",
|
||||
{
|
||||
"tool_call_id": "xyz",
|
||||
"type": "tool",
|
||||
"ui_payload": {"edit_current_insight": query.model_dump(exclude_none=True)},
|
||||
}, # Don't check content as it's implementation detail
|
||||
AssistantToolCallMessage(
|
||||
content="The results indicate a great future for you.",
|
||||
tool_call_id="xyz",
|
||||
ui_payload={"create_and_query_insight": query.model_dump(exclude_none=True)},
|
||||
),
|
||||
),
|
||||
("message", AssistantMessage(content="Everything is fine")),
|
||||
]
|
||||
@@ -1671,7 +1614,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
state = AssistantState.model_validate(snapshot.values)
|
||||
expected_state_messages = [
|
||||
ContextMessage(
|
||||
content="<system_reminder>\nContextual tools that are available to you on this page are:\n<edit_current_insight>\nThe user is currently editing an insight (aka query). Here is that insight's current definition, which can be edited using the `edit_current_insight` tool:\n\n```json\nquery\n```\n\nIMPORTANT: DO NOT REMOVE ANY FIELDS FROM THE CURRENT INSIGHT DEFINITION. DO NOT CHANGE ANY OTHER FIELDS THAN THE ONES THE USER ASKED FOR. KEEP THE REST AS IS.\n</edit_current_insight>\nIMPORTANT: this context may or may not be relevant to your tasks. You should not respond to this context unless it is highly relevant to your task.\n</system_reminder>"
|
||||
content="<system_reminder>\nContextual tools that are available to you on this page are:\n<create_and_query_insight>\nThe user is currently editing an insight (aka query). Here is that insight's current definition, which can be edited using the `create_and_query_insight` tool:\n\n```json\nquery\n```\n\n<system_reminder>\nDo not remove any fields from the current insight definition. Do not change any other fields than the ones the user asked for. Keep the rest as is.\n</system_reminder>\n</create_and_query_insight>\nIMPORTANT: this context may or may not be relevant to your tasks. You should not respond to this context unless it is highly relevant to your task.\n</system_reminder>"
|
||||
),
|
||||
HumanMessage(content="Hello"),
|
||||
AssistantMessage(
|
||||
@@ -1679,8 +1622,8 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
tool_calls=[
|
||||
AssistantToolCall(
|
||||
id="xyz",
|
||||
name="edit_current_insight",
|
||||
args={"query_description": "Foobar", "query_kind": "trends"},
|
||||
name="create_and_query_insight",
|
||||
args={"query_description": "Foobar"},
|
||||
)
|
||||
],
|
||||
),
|
||||
@@ -1688,7 +1631,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
AssistantToolCallMessage(
|
||||
content="The results indicate a great future for you.",
|
||||
tool_call_id="xyz",
|
||||
ui_payload={"edit_current_insight": query.model_dump(exclude_none=True)},
|
||||
ui_payload={"create_and_query_insight": query.model_dump(exclude_none=True)},
|
||||
),
|
||||
AssistantMessage(content="Everything is fine"),
|
||||
]
|
||||
@@ -1705,7 +1648,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
graph = (
|
||||
AssistantGraph(self.team, self.user)
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_root({"root": AssistantNodeName.ROOT, "end": AssistantNodeName.END})
|
||||
.add_root()
|
||||
.compile()
|
||||
)
|
||||
|
||||
@@ -1757,16 +1700,14 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
# Tests for ainvoke method
|
||||
async def test_ainvoke_basic_functionality(self):
|
||||
"""Test ainvoke returns all messages at once without streaming."""
|
||||
from ee.hogai.graph.base import BaseAssistantNode
|
||||
|
||||
class TestNode(BaseAssistantNode):
|
||||
class TestNode(AssistantNode):
|
||||
@property
|
||||
def node_name(self):
|
||||
return AssistantNodeName.ROOT
|
||||
|
||||
async def arun(self, state, config):
|
||||
self.dispatcher.message(AssistantMessage(content="Response"))
|
||||
return PartialAssistantState()
|
||||
return PartialAssistantState(messages=[AssistantMessage(content="Response", id=str(uuid4()))])
|
||||
|
||||
test_node = TestNode(self.team, self.user)
|
||||
graph = (
|
||||
@@ -1798,28 +1739,6 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
self.assertIsInstance(item[1], AssistantMessage)
|
||||
self.assertEqual(cast(AssistantMessage, item[1]).content, "Response")
|
||||
|
||||
async def test_process_value_update_returns_none(self):
|
||||
"""Test that _aprocess_value_update returns None for basic state updates (ACKs are now handled by reducer)."""
|
||||
|
||||
assistant = Assistant.create(
|
||||
self.team, self.conversation, new_message=HumanMessage(content="Hello"), user=self.user
|
||||
)
|
||||
|
||||
# Create a value update tuple that doesn't match special nodes
|
||||
update = cast(
|
||||
GraphValueUpdateTuple,
|
||||
(
|
||||
AssistantNodeName.ROOT,
|
||||
{"root": {"messages": []}}, # Empty update that doesn't match visualization or verbose nodes
|
||||
),
|
||||
)
|
||||
|
||||
# Process the update
|
||||
result = await assistant._aprocess_value_update(update)
|
||||
|
||||
# Should return None (ACK events are now generated by the reducer)
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_billing_context_in_config(self):
|
||||
billing_context = MaxBillingContext(
|
||||
has_active_subscription=True,
|
||||
@@ -1855,7 +1774,225 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
|
||||
)
|
||||
|
||||
config = assistant._get_config()
|
||||
self.assertEqual(config.get("configurable", {}).get("billing_context"), billing_context)
|
||||
self.assertEqual(config["configurable"]["billing_context"], billing_context)
|
||||
|
||||
@patch("ee.hogai.context.context.AssistantContextManager.check_user_has_billing_access", return_value=True)
|
||||
@patch("ee.hogai.graph.root.nodes.RootNode._get_model")
|
||||
async def test_billing_tool_execution(self, root_mock, access_mock):
|
||||
"""Test that the billing tool can be called and returns formatted billing information."""
|
||||
billing_context = MaxBillingContext(
|
||||
subscription_level=MaxBillingContextSubscriptionLevel.PAID,
|
||||
billing_plan="startup",
|
||||
has_active_subscription=True,
|
||||
is_deactivated=False,
|
||||
products=[
|
||||
MaxProductInfo(
|
||||
name="Product Analytics",
|
||||
type="analytics",
|
||||
description="Track user behavior",
|
||||
current_usage=50000,
|
||||
usage_limit=100000,
|
||||
has_exceeded_limit=False,
|
||||
is_used=True,
|
||||
percentage_usage=0.5,
|
||||
addons=[],
|
||||
)
|
||||
],
|
||||
settings=MaxBillingContextSettings(autocapture_on=True, active_destinations=2),
|
||||
)
|
||||
|
||||
# Mock the root node to call the read_data tool with billing_info kind
|
||||
tool_call_id = str(uuid4())
|
||||
|
||||
def root_side_effect(msgs: list[BaseMessage]):
|
||||
# Check if we've already received a tool result
|
||||
last_message = msgs[-1]
|
||||
if (
|
||||
isinstance(last_message.content, list)
|
||||
and isinstance(last_message.content[-1], dict)
|
||||
and last_message.content[-1]["type"] == "tool_result"
|
||||
):
|
||||
# After tool execution, respond with final message
|
||||
return messages.AIMessage(content="Your billing information shows you're on a startup plan.")
|
||||
|
||||
# First call - request the billing tool
|
||||
return messages.AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": tool_call_id,
|
||||
"name": "read_data",
|
||||
"args": {"kind": "billing_info"},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
root_mock.return_value = FakeAnthropicRunnableLambdaWithTokenCounter(root_side_effect)
|
||||
|
||||
# Create a minimal test graph
|
||||
test_graph = (
|
||||
AssistantGraph(self.team, self.user)
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_root()
|
||||
.compile()
|
||||
)
|
||||
|
||||
# Run the assistant with billing context
|
||||
assistant = Assistant.create(
|
||||
team=self.team,
|
||||
conversation=self.conversation,
|
||||
user=self.user,
|
||||
new_message=HumanMessage(content="What's my current billing status?"),
|
||||
billing_context=billing_context,
|
||||
)
|
||||
assistant._graph = test_graph
|
||||
|
||||
output: list[AssistantOutput] = []
|
||||
async for event in assistant.astream():
|
||||
output.append(event)
|
||||
|
||||
# Verify we received messages
|
||||
self.assertGreater(len(output), 0)
|
||||
|
||||
# Find the assistant's final response
|
||||
assistant_messages = [msg for event_type, msg in output if isinstance(msg, AssistantMessage)]
|
||||
self.assertGreater(len(assistant_messages), 0)
|
||||
|
||||
# Verify the assistant received and used the billing information
|
||||
# The mock returns "Your billing information shows you're on a startup plan."
|
||||
final_message = cast(AssistantMessage, assistant_messages[-1])
|
||||
self.assertIn("billing", final_message.content.lower())
|
||||
self.assertIn("startup", final_message.content.lower())
|
||||
|
||||
async def test_messages_without_id_are_yielded(self):
|
||||
"""Test that messages without ID are always yielded."""
|
||||
|
||||
class MessageWithoutIdNode(AssistantNode):
|
||||
call_count = 0
|
||||
|
||||
async def arun(self, state, config):
|
||||
self.call_count += 1
|
||||
# Return message without ID - should always be yielded
|
||||
return PartialAssistantState(
|
||||
messages=[
|
||||
AssistantMessage(content=f"Message {self.call_count} without ID"),
|
||||
AssistantMessage(content=f"Message {self.call_count} without ID"),
|
||||
]
|
||||
)
|
||||
|
||||
node = MessageWithoutIdNode(self.team, self.user)
|
||||
graph = (
|
||||
AssistantGraph(self.team, self.user)
|
||||
.add_node(AssistantNodeName.ROOT, node)
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
|
||||
.compile()
|
||||
)
|
||||
|
||||
# Run the assistant multiple times
|
||||
output1, _ = await self._run_assistant_graph(graph, message="First run", conversation=self.conversation)
|
||||
output2, _ = await self._run_assistant_graph(graph, message="Second run", conversation=self.conversation)
|
||||
|
||||
# Both runs should yield their messages (human + assistant message each)
|
||||
self.assertEqual(len(output1), 3) # Human message + AI message + AI message
|
||||
self.assertEqual(len(output2), 3) # Human message + AI message + AI message
|
||||
|
||||
    async def test_messages_with_id_are_deduplicated(self):
        """Test that messages with ID are deduplicated during streaming."""
        message_id = str(uuid4())

        class DuplicateMessageNode(AssistantNode):
            call_count = 0

            async def arun(self, state, config):
                self.call_count += 1
                # Always return the same message with same ID
                return PartialAssistantState(
                    messages=[
                        AssistantMessage(id=message_id, content=f"Call {self.call_count}"),
                        AssistantMessage(id=message_id, content=f"Call {self.call_count}"),
                    ]
                )

        node = DuplicateMessageNode(self.team, self.user)
        graph = (
            AssistantGraph(self.team, self.user)
            .add_node(AssistantNodeName.ROOT, node)
            .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
            .add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
            .compile()
        )

        # Create assistant and manually test the streaming behavior
        assistant = Assistant.create(
            self.team,
            self.conversation,
            new_message=HumanMessage(content="Test message"),
            user=self.user,
            is_new_conversation=False,
        )
        assistant._graph = graph

        # Collect all streamed messages
        streamed_messages = []
        async for event_type, message in assistant.astream(stream_first_message=False):
            if event_type == AssistantEventType.MESSAGE:
                streamed_messages.append(message)

        # Should only get one message despite the node being called multiple times
        assistant_messages = [
            msg for msg in streamed_messages if isinstance(msg, AssistantMessage) and msg.id == message_id
        ]
        self.assertEqual(len(assistant_messages), 1, "Message with same ID should only be yielded once")

    async def test_replaced_messaged_are_not_double_streamed(self):
        """Test that existing messages are not streamed again"""
        # Create messages with IDs that should be tracked
        message_id_1 = str(uuid4())
        message_id_2 = str(uuid4())
        call_count = [0]

        # Create a simple graph that returns messages with IDs
        class TestNode(AssistantNode):
            async def arun(self, state, config):
                result = None
                if call_count[0] == 0:
                    result = PartialAssistantState(
                        messages=[
                            AssistantMessage(id=message_id_1, content="Message 1"),
                        ]
                    )
                else:
                    result = PartialAssistantState(
                        messages=ReplaceMessages(
                            [
                                AssistantMessage(id=message_id_1, content="Message 1"),
                                AssistantMessage(id=message_id_2, content="Message 2"),
                            ]
                        )
                    )
                call_count[0] += 1
                return result

        graph = (
            AssistantGraph(self.team, self.user)
            .add_node(AssistantNodeName.ROOT, TestNode(self.team, self.user))
            .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
            .add_edge(AssistantNodeName.ROOT, AssistantNodeName.END)
            .compile()
        )

        output, _ = await self._run_assistant_graph(graph, message="First run", conversation=self.conversation)
        # Filter for assistant messages only, as the test is about tracking assistant message IDs
        assistant_output = [(event_type, msg) for event_type, msg in output if isinstance(msg, AssistantMessage)]
        self.assertEqual(len(assistant_output), 1)
        self.assertEqual(cast(AssistantMessage, assistant_output[0][1]).id, message_id_1)

        output, _ = await self._run_assistant_graph(graph, message="Second run", conversation=self.conversation)
        # Filter for assistant messages only, as the test is about tracking assistant message IDs
        assistant_output = [(event_type, msg) for event_type, msg in output if isinstance(msg, AssistantMessage)]
        self.assertEqual(len(assistant_output), 1)
        self.assertEqual(cast(AssistantMessage, assistant_output[0][1]).id, message_id_2)

    @patch(
        "ee.hogai.graph.conversation_summarizer.nodes.AnthropicConversationSummarizer.summarize",
@@ -1887,22 +2024,7 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
        mock_tool.return_value = ("Event list" * 128000, None)
        mock_should_compact.side_effect = cycle([False, True])  # Also changed this

        graph = (
            AssistantGraph(self.team, self.user)
            .add_root(
                path_map={
                    "insights": AssistantNodeName.END,
                    "search_documentation": AssistantNodeName.END,
                    "root": AssistantNodeName.ROOT,
                    "end": AssistantNodeName.END,
                    "insights_search": AssistantNodeName.END,
                    "session_summarization": AssistantNodeName.END,
                    "create_dashboard": AssistantNodeName.END,
                }
            )
            .add_memory_onboarding()
            .compile()
        )
        graph = AssistantGraph(self.team, self.user).add_root().add_memory_onboarding().compile()

        expected_output = [
            ("message", HumanMessage(content="First")),
@@ -1928,3 +2050,86 @@ class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest):
        self.assertEqual(state.start_id, new_human_message.id)
        # should be equal to the summary message, minus reasoning message
        self.assertEqual(state.root_conversation_start_id, state.messages[3].id)

@patch("ee.hogai.graph.root.tools.search.SearchTool._arun_impl", return_value=("Docs doubt it", None))
|
||||
@patch(
|
||||
"ee.hogai.graph.root.tools.read_taxonomy.ReadTaxonomyTool._run_impl",
|
||||
return_value=("Hedgehogs have not talked yet", None),
|
||||
)
|
||||
@patch("ee.hogai.graph.root.nodes.RootNode._get_model")
|
||||
async def test_root_node_can_execute_multiple_tool_calls(self, root_mock, search_mock, read_taxonomy_mock):
|
||||
"""Test that the root node can execute multiple tool calls in parallel."""
|
||||
tool_call_id1, tool_call_id2 = [str(uuid4()), str(uuid4())]
|
||||
|
||||
def root_side_effect(msgs: list[BaseMessage]):
|
||||
# Check if we've already received a tool result
|
||||
last_message = msgs[-1]
|
||||
if (
|
||||
isinstance(last_message.content, list)
|
||||
and isinstance(last_message.content[-1], dict)
|
||||
and last_message.content[-1]["type"] == "tool_result"
|
||||
):
|
||||
# After tool execution, respond with final message
|
||||
return messages.AIMessage(content="No")
|
||||
|
||||
return messages.AIMessage(
|
||||
content="Not sure. Let me check.",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": tool_call_id1,
|
||||
"name": "search",
|
||||
"args": {"kind": "docs", "query": "Do hedgehogs speak?"},
|
||||
},
|
||||
{
|
||||
"id": tool_call_id2,
|
||||
"name": "read_taxonomy",
|
||||
"args": {"query": {"kind": "events"}},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
root_mock.return_value = FakeAnthropicRunnableLambdaWithTokenCounter(root_side_effect)
|
||||
|
||||
# Create a minimal test graph
|
||||
graph = (
|
||||
AssistantGraph(self.team, self.user)
|
||||
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
|
||||
.add_root()
|
||||
.compile()
|
||||
)
|
||||
|
||||
expected_output = [
|
||||
(AssistantEventType.MESSAGE, HumanMessage(content="Do hedgehogs speak?")),
|
||||
(
|
||||
AssistantEventType.MESSAGE,
|
||||
AssistantMessage(
|
||||
content="Not sure. Let me check.",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": tool_call_id1,
|
||||
"name": "search",
|
||||
"args": {"kind": "docs", "query": "Do hedgehogs speak?"},
|
||||
},
|
||||
{
|
||||
"id": tool_call_id2,
|
||||
"name": "read_taxonomy",
|
||||
"args": {"query": {"kind": "events"}},
|
||||
},
|
||||
],
|
||||
),
|
||||
),
|
||||
(
|
||||
AssistantEventType.MESSAGE,
|
||||
AssistantToolCallMessage(content="Docs doubt it", tool_call_id=tool_call_id1),
|
||||
),
|
||||
(
|
||||
AssistantEventType.MESSAGE,
|
||||
AssistantToolCallMessage(content="Hedgehogs have not talked yet", tool_call_id=tool_call_id2),
|
||||
),
|
||||
(AssistantEventType.MESSAGE, AssistantMessage(content="No")),
|
||||
]
|
||||
output, _ = await self._run_assistant_graph(
|
||||
graph, message="Do hedgehogs speak?", conversation=self.conversation
|
||||
)
|
||||
|
||||
self.assertConversationEqual(output, expected_output)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import pkgutil
|
||||
import importlib
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal, Self
|
||||
|
||||
@@ -16,9 +17,9 @@ from posthog.models import Team, User
|
||||
import products
|
||||
|
||||
from ee.hogai.context.context import AssistantContextManager
|
||||
from ee.hogai.graph.mixins import AssistantContextMixin
|
||||
from ee.hogai.utils.types import AssistantState
|
||||
from ee.hogai.utils.types.base import AssistantMessageUnion
|
||||
from ee.hogai.graph.base.context import get_node_path, set_node_path
|
||||
from ee.hogai.graph.mixins import AssistantContextMixin, AssistantDispatcherMixin
|
||||
from ee.hogai.utils.types.base import AssistantMessageUnion, AssistantState, NodePath
|
||||
|
||||
CONTEXTUAL_TOOL_NAME_TO_TOOL: dict[AssistantTool, type["MaxTool"]] = {}
|
||||
|
||||
@@ -51,7 +52,7 @@ class ToolMessagesArtifact(BaseModel):
|
||||
messages: Sequence[AssistantMessageUnion]
|
||||
|
||||
|
||||
class MaxTool(AssistantContextMixin, BaseTool):
|
||||
class MaxTool(AssistantContextMixin, AssistantDispatcherMixin, BaseTool):
|
||||
# LangChain's default is just "content", but we always want to return the tool call artifact too
|
||||
# - it becomes the `ui_payload`
|
||||
response_format: Literal["content_and_artifact"] = "content_and_artifact"
|
||||
@@ -66,7 +67,7 @@ class MaxTool(AssistantContextMixin, BaseTool):
|
||||
_config: RunnableConfig
|
||||
_state: AssistantState
|
||||
_context_manager: AssistantContextManager
|
||||
_tool_call_id: str
|
||||
_node_path: tuple[NodePath, ...]
|
||||
|
||||
# DEPRECATED: Use `_arun_impl` instead
|
||||
def _run_impl(self, *args, **kwargs) -> tuple[str, Any]:
|
||||
@@ -82,7 +83,7 @@ class MaxTool(AssistantContextMixin, BaseTool):
|
||||
*,
|
||||
team: Team,
|
||||
user: User,
|
||||
tool_call_id: str,
|
||||
node_path: tuple[NodePath, ...] | None = None,
|
||||
state: AssistantState | None = None,
|
||||
config: RunnableConfig | None = None,
|
||||
name: str | None = None,
|
||||
@@ -102,7 +103,10 @@ class MaxTool(AssistantContextMixin, BaseTool):
|
||||
super().__init__(**tool_kwargs, **kwargs)
|
||||
self._team = team
|
||||
self._user = user
|
||||
self._tool_call_id = tool_call_id
|
||||
if node_path is None:
|
||||
self._node_path = (*(get_node_path() or ()), NodePath(name=self.node_name))
|
||||
else:
|
||||
self._node_path = node_path
|
||||
self._state = state if state else AssistantState(messages=[])
|
||||
self._config = config if config else RunnableConfig(configurable={})
|
||||
self._context_manager = context_manager or AssistantContextManager(team, user, self._config)
|
||||
@@ -120,19 +124,35 @@ class MaxTool(AssistantContextMixin, BaseTool):
|
||||
CONTEXTUAL_TOOL_NAME_TO_TOOL[accepted_name] = cls
|
||||
|
||||
def _run(self, *args, config: RunnableConfig, **kwargs):
|
||||
"""LangChain default runner."""
|
||||
try:
|
||||
return self._run_impl(*args, **kwargs)
|
||||
return self._run_with_context(*args, **kwargs)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
return async_to_sync(self._arun_impl)(*args, **kwargs)
|
||||
return async_to_sync(self._arun_with_context)(*args, **kwargs)
|
||||
|
||||
async def _arun(self, *args, config: RunnableConfig, **kwargs):
|
||||
"""LangChain default runner."""
|
||||
try:
|
||||
return await self._arun_impl(*args, **kwargs)
|
||||
return await self._arun_with_context(*args, **kwargs)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
return await super()._arun(*args, config=config, **kwargs)
|
||||
|
||||
def _run_with_context(self, *args, **kwargs):
|
||||
"""Sets the context for the tool."""
|
||||
with set_node_path(self.node_path):
|
||||
return self._run_impl(*args, **kwargs)
|
||||
|
||||
async def _arun_with_context(self, *args, **kwargs):
|
||||
"""Sets the context for the tool."""
|
||||
with set_node_path(self.node_path):
|
||||
return await self._arun_impl(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def node_name(self) -> str:
|
||||
return f"max_tool.{self.get_name()}"
|
||||
|
||||
@property
|
||||
def context(self) -> dict:
|
||||
return self._context_manager.get_contextual_tools().get(self.get_name(), {})
|
||||
@@ -149,7 +169,7 @@ class MaxTool(AssistantContextMixin, BaseTool):
|
||||
*,
|
||||
team: Team,
|
||||
user: User,
|
||||
tool_call_id: str,
|
||||
node_path: tuple[NodePath, ...] | None = None,
|
||||
state: AssistantState | None = None,
|
||||
config: RunnableConfig | None = None,
|
||||
context_manager: AssistantContextManager | None = None,
|
||||
@@ -160,5 +180,31 @@ class MaxTool(AssistantContextMixin, BaseTool):
|
||||
Override this factory to dynamically modify the tool name, description, args schema, etc.
|
||||
"""
|
||||
return cls(
|
||||
team=team, user=user, tool_call_id=tool_call_id, state=state, config=config, context_manager=context_manager
|
||||
team=team, user=user, node_path=node_path, state=state, config=config, context_manager=context_manager
|
||||
)
|
||||
|
||||
|
||||
class MaxSubtool(AssistantDispatcherMixin, ABC):
|
||||
_config: RunnableConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
team: Team,
|
||||
user: User,
|
||||
state: AssistantState,
|
||||
config: RunnableConfig,
|
||||
context_manager: AssistantContextManager,
|
||||
):
|
||||
self._team = team
|
||||
self._user = user
|
||||
self._state = state
|
||||
self._context_manager = context_manager
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, *args, **kwargs) -> Any:
|
||||
pass
|
||||
|
||||
@property
|
||||
def node_name(self) -> str:
|
||||
return f"max_subtool.{self.__class__.__name__}"
|
||||
|
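# --- Editorial sketch, not part of this commit ---
# Minimal illustration of subclassing MaxTool with the content_and_artifact
# contract shown above: _arun_impl returns (content, artifact), where content
# goes back to the model and the artifact becomes the tool call's ui_payload.
# The tool name, description, and body below are hypothetical; args_schema and
# other BaseTool plumbing are omitted for brevity.
class _ExampleEchoTool(MaxTool):
    name: str = "example_echo"
    description: str = "Illustrative tool that echoes its input"

    async def _arun_impl(self, text: str) -> tuple[str, Any]:
        # content for the model, artifact for the UI payload
        return f"Echo: {text}", {"echoed": text}
# --- End of editorial sketch ---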
||||
@@ -1,21 +1,19 @@
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.config import get_stream_writer
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from posthog.schema import AssistantMessage, AssistantToolCallMessage
|
||||
|
||||
from ee.hogai.utils.types.base import (
|
||||
AssistantActionUnion,
|
||||
AssistantDispatcherEvent,
|
||||
AssistantMessageUnion,
|
||||
MessageAction,
|
||||
NodeStartAction,
|
||||
NodePath,
|
||||
UpdateAction,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ee.hogai.utils.types.composed import MaxNodeName
|
||||
|
||||
|
||||
class AssistantDispatcher:
|
||||
"""
|
||||
@@ -26,13 +24,14 @@ class AssistantDispatcher:
|
||||
The dispatcher does NOT update state - it just emits actions to the stream.
|
||||
"""
|
||||
|
||||
_parent_tool_call_id: str | None = None
|
||||
_node_path: tuple[NodePath, ...]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
writer: StreamWriter | Callable[[Any], None],
|
||||
node_name: "MaxNodeName",
|
||||
parent_tool_call_id: str | None = None,
|
||||
node_path: tuple[NodePath, ...],
|
||||
node_name: str,
|
||||
node_run_id: str,
|
||||
):
|
||||
"""
|
||||
Create a dispatcher for a specific node.
|
||||
@@ -40,9 +39,10 @@ class AssistantDispatcher:
|
||||
Args:
|
||||
node_name: The name of the node dispatching actions (for attribution)
|
||||
"""
|
||||
self._node_name = node_name
|
||||
self._writer = writer
|
||||
self._parent_tool_call_id = parent_tool_call_id
|
||||
self._node_path = node_path
|
||||
self._node_name = node_name
|
||||
self._node_run_id = node_run_id
|
||||
|
||||
def dispatch(self, action: AssistantActionUnion) -> None:
|
||||
"""
|
||||
@@ -56,7 +56,11 @@ class AssistantDispatcher:
|
||||
action: Action dict with "type" and "payload" keys
|
||||
"""
|
||||
try:
|
||||
self._writer(AssistantDispatcherEvent(action=action, node_name=self._node_name))
|
||||
self._writer(
|
||||
AssistantDispatcherEvent(
|
||||
action=action, node_path=self._node_path, node_name=self._node_name, node_run_id=self._node_run_id
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# Log error but don't crash node execution
|
||||
# The dispatcher should be resilient to writer failures
|
||||
@@ -69,34 +73,29 @@ class AssistantDispatcher:
|
||||
"""
|
||||
Dispatch a message to the stream.
|
||||
"""
|
||||
if self._parent_tool_call_id:
|
||||
# If the dispatcher is initialized with a parent tool call id, set the parent tool call id on the message
|
||||
# This is to ensure that the message is associated with the correct tool call
|
||||
# Don't set parent_tool_call_id on:
|
||||
# 1. AssistantToolCallMessage with the same tool_call_id (to avoid self-reference)
|
||||
# 2. AssistantMessage with tool_calls containing the same ID (to avoid cycles)
|
||||
should_skip = False
|
||||
if isinstance(message, AssistantToolCallMessage) and self._parent_tool_call_id == message.tool_call_id:
|
||||
should_skip = True
|
||||
elif isinstance(message, AssistantMessage) and message.tool_calls:
|
||||
# Check if any tool call has the same ID as the parent
|
||||
for tool_call in message.tool_calls:
|
||||
if tool_call.id == self._parent_tool_call_id:
|
||||
should_skip = True
|
||||
break
|
||||
|
||||
if not should_skip:
|
||||
message.parent_tool_call_id = self._parent_tool_call_id
|
||||
self.dispatch(MessageAction(message=message))
|
||||
|
||||
def node_start(self) -> None:
|
||||
"""
|
||||
Dispatch a node start action to the stream.
|
||||
"""
|
||||
self.dispatch(NodeStartAction())
|
||||
def update(self, content: str):
|
||||
"""Dispatch a transient update message to the stream that will be associated with a tool call in the UI."""
|
||||
self.dispatch(UpdateAction(content=content))
|
||||
|
||||
def set_as_root(self) -> None:
|
||||
"""
|
||||
Set the dispatcher as the root.
|
||||
"""
|
||||
self._parent_tool_call_id = None
|
||||
|
||||
def create_dispatcher_from_config(config: RunnableConfig, node_path: tuple[NodePath, ...]) -> AssistantDispatcher:
|
||||
"""Create a dispatcher from a RunnableConfig and node path"""
|
||||
# Set writer from LangGraph context
|
||||
try:
|
||||
writer = get_stream_writer()
|
||||
except RuntimeError:
|
||||
# Not in streaming context (e.g., testing)
|
||||
# Use noop writer
|
||||
def noop(*_args, **_kwargs):
|
||||
pass
|
||||
|
||||
writer = noop
|
||||
|
||||
metadata = config.get("metadata") or {}
|
||||
node_name: str = metadata.get("langgraph_node") or ""
|
||||
# `langgraph_checkpoint_ns` contains the nested path to the node, so it's more accurate for streaming.
|
||||
node_run_id: str = metadata.get("langgraph_checkpoint_ns") or ""
|
||||
|
||||
return AssistantDispatcher(writer, node_path=node_path, node_run_id=node_run_id, node_name=node_name)
|
||||
|
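# --- Editorial sketch, not part of this commit ---
# Typical call site for the factory above, e.g. from inside a node's arun():
# build the dispatcher from the RunnableConfig and emit progress. The helper
# name below is hypothetical; create_dispatcher_from_config, update(), and
# message() are the functions defined above in this module.
def _example_emit_progress(config: RunnableConfig, node_path: tuple[NodePath, ...]) -> None:
    dispatcher = create_dispatcher_from_config(config, node_path)
    dispatcher.update("Looking things up...")  # transient update tied to the closest tool call
    dispatcher.message(AssistantMessage(content="Done"))  # full message streamed to the client
# --- End of editorial sketch ---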
||||
@@ -38,8 +38,7 @@ from posthog.hogql_queries.query_runner import ExecutionMode
from posthog.models import Team
from posthog.taxonomy.taxonomy import CORE_FILTER_DEFINITIONS_BY_GROUP

from ee.hogai.utils.types import AssistantMessageUnion
from ee.hogai.utils.types.base import AssistantDispatcherEvent
from ee.hogai.utils.types.base import AssistantDispatcherEvent, AssistantMessageUnion


def remove_line_breaks(line: str) -> str:

@@ -5,7 +5,7 @@ from structlog import get_logger

from ee.hogai.graph.deep_research.types import DeepResearchNodeName, PartialDeepResearchState
from ee.hogai.graph.taxonomy.types import TaxonomyAgentState, TaxonomyNodeName
from ee.hogai.utils.types import PartialAssistantState
from ee.hogai.utils.types.base import PartialAssistantState
from ee.hogai.utils.types.composed import AssistantMaxGraphState, AssistantMaxPartialGraphState, MaxNodeName

# A state update can have a partial state or a LangGraph's reserved dataclasses like Interrupt.
@@ -45,6 +45,7 @@ def validate_value_update(

class LangGraphState(TypedDict):
    langgraph_node: MaxNodeName
    langgraph_checkpoint_ns: str


GraphMessageUpdateTuple = tuple[Literal["messages"], tuple[Union[AIMessageChunk, Any], LangGraphState]]

@@ -1,4 +1,4 @@
|
||||
from typing import cast
|
||||
from typing import Generic, Protocol, TypeVar, cast, get_args
|
||||
|
||||
import structlog
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
@@ -6,7 +6,6 @@ from langchain_core.messages import AIMessageChunk
|
||||
from posthog.schema import (
|
||||
AssistantGenerationStatusEvent,
|
||||
AssistantGenerationStatusType,
|
||||
AssistantMessage,
|
||||
AssistantToolCallMessage,
|
||||
AssistantUpdateEvent,
|
||||
FailureMessage,
|
||||
@@ -16,21 +15,51 @@ from posthog.schema import (
|
||||
)
|
||||
|
||||
from ee.hogai.utils.helpers import normalize_ai_message, should_output_assistant_message
|
||||
from ee.hogai.utils.state import merge_message_chunk
|
||||
from ee.hogai.utils.state import is_message_update, is_state_update, merge_message_chunk
|
||||
from ee.hogai.utils.types.base import (
|
||||
AssistantDispatcherEvent,
|
||||
AssistantGraphName,
|
||||
AssistantMessageUnion,
|
||||
AssistantResultUnion,
|
||||
BaseStateWithMessages,
|
||||
LangGraphUpdateEvent,
|
||||
MessageAction,
|
||||
MessageChunkAction,
|
||||
NodeEndAction,
|
||||
NodePath,
|
||||
NodeStartAction,
|
||||
UpdateAction,
|
||||
)
|
||||
from ee.hogai.utils.types.composed import MaxNodeName
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class AssistantStreamProcessor:
|
||||
class AssistantStreamProcessorProtocol(Protocol):
|
||||
"""Protocol defining the interface for assistant stream processors."""
|
||||
|
||||
_streamed_update_ids: set[str]
|
||||
"""Tracks the IDs of messages that have been streamed."""
|
||||
|
||||
def process(self, event: AssistantDispatcherEvent) -> list[AssistantResultUnion] | None:
|
||||
"""Process a dispatcher event and return a result or None."""
|
||||
...
|
||||
|
||||
def process_langgraph_update(self, event: LangGraphUpdateEvent) -> list[AssistantResultUnion] | None:
|
||||
"""Process a LangGraph update event and return a list of results or None."""
|
||||
...
|
||||
|
||||
def mark_id_as_streamed(self, message_id: str) -> None:
|
||||
"""Mark a message ID as streamed."""
|
||||
self._streamed_update_ids.add(message_id)
|
||||
|
||||
|
||||
StateType = TypeVar("StateType", bound=BaseStateWithMessages)
|
||||
|
||||
MESSAGE_TYPE_TUPLE = get_args(AssistantMessageUnion)
|
||||
|
||||
|
||||
class AssistantStreamProcessor(AssistantStreamProcessorProtocol, Generic[StateType]):
|
||||
"""
|
||||
Reduces streamed actions to client-facing messages.
|
||||
|
||||
@@ -38,36 +67,34 @@ class AssistantStreamProcessor:
|
||||
handlers based on action type and message characteristics.
|
||||
"""
|
||||
|
||||
_verbose_nodes: set[MaxNodeName]
|
||||
"""Nodes that emit messages."""
|
||||
_streaming_nodes: set[MaxNodeName]
|
||||
"""Nodes that produce streaming messages."""
|
||||
_visualization_nodes: dict[MaxNodeName, type]
|
||||
"""Nodes that produce visualization messages."""
|
||||
_tool_call_id_to_message: dict[str, AssistantMessage]
|
||||
"""Maps tool call IDs to their parent messages for message chain tracking."""
|
||||
_streamed_update_ids: set[str]
|
||||
"""Tracks the IDs of messages that have been streamed."""
|
||||
_chunks: AIMessageChunk
|
||||
_chunks: dict[str, AIMessageChunk]
|
||||
"""Tracks the current message chunk."""
|
||||
_state: StateType | None
|
||||
"""Tracks the current state."""
|
||||
_state_type: type[StateType]
|
||||
"""The type of the state."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
streaming_nodes: set[MaxNodeName],
|
||||
visualization_nodes: dict[MaxNodeName, type],
|
||||
):
|
||||
def __init__(self, verbose_nodes: set[MaxNodeName], streaming_nodes: set[MaxNodeName], state_type: type[StateType]):
|
||||
"""
|
||||
Initialize the stream processor with node configuration.
|
||||
|
||||
Args:
|
||||
verbose_nodes: Nodes that produce messages
|
||||
streaming_nodes: Nodes that produce streaming messages
|
||||
visualization_nodes: Nodes that produce visualization messages
|
||||
"""
|
||||
# If a node is streaming node, it should also be verbose.
|
||||
self._verbose_nodes = verbose_nodes | streaming_nodes
|
||||
self._streaming_nodes = streaming_nodes
|
||||
self._visualization_nodes = visualization_nodes
|
||||
self._tool_call_id_to_message = {}
|
||||
self._streamed_update_ids = set()
|
||||
self._chunks = AIMessageChunk(content="")
|
||||
self._chunks = {}
|
||||
self._state_type = state_type
|
||||
self._state = None
|
||||
|
||||
def process(self, event: AssistantDispatcherEvent) -> AssistantResultUnion | None:
|
||||
def process(self, event: AssistantDispatcherEvent) -> list[AssistantResultUnion] | None:
|
||||
"""
|
||||
Reduce streamed actions to client messages.
|
||||
|
||||
@@ -75,86 +102,113 @@ class AssistantStreamProcessor:
|
||||
to specialized handlers based on action type and message characteristics.
|
||||
"""
|
||||
action = event.action
|
||||
node_name = event.node_name
|
||||
|
||||
if isinstance(action, MessageChunkAction):
|
||||
return self._handle_message_stream(action.message, cast(MaxNodeName, node_name))
|
||||
|
||||
if isinstance(action, NodeStartAction):
|
||||
return AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK)
|
||||
self._chunks[event.node_run_id] = AIMessageChunk(content="")
|
||||
return [AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK)]
|
||||
|
||||
if isinstance(action, NodeEndAction):
|
||||
if event.node_run_id in self._chunks:
|
||||
del self._chunks[event.node_run_id]
|
||||
return self._handle_node_end(event, action)
|
||||
|
||||
if isinstance(action, MessageChunkAction) and (result := self._handle_message_stream(event, action.message)):
|
||||
return [result]
|
||||
|
||||
if isinstance(action, MessageAction):
|
||||
message = action.message
|
||||
if result := self._handle_message(event, message):
|
||||
return [result]
|
||||
|
||||
# Register any tool calls for later parent chain lookups
|
||||
self._register_tool_calls(message)
|
||||
result = self._handle_message(message, cast(MaxNodeName, node_name))
|
||||
return (
|
||||
result if result is not None else AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK)
|
||||
if isinstance(action, UpdateAction) and (update_event := self._handle_update_message(event, action)):
|
||||
return [update_event]
|
||||
|
||||
return None
|
||||
|
||||
def process_langgraph_update(self, event: LangGraphUpdateEvent) -> list[AssistantResultUnion] | None:
|
||||
"""
|
||||
Process a LangGraph update event.
|
||||
"""
|
||||
if is_message_update(event.update):
|
||||
# Convert the message chunk update to a dispatcher event to prepare for a bright future without LangGraph
|
||||
maybe_message_chunk, state = event.update[1]
|
||||
if not isinstance(maybe_message_chunk, AIMessageChunk):
|
||||
return None
|
||||
action = AssistantDispatcherEvent(
|
||||
action=MessageChunkAction(message=maybe_message_chunk),
|
||||
node_name=state["langgraph_node"],
|
||||
node_run_id=state["langgraph_checkpoint_ns"],
|
||||
)
|
||||
return self.process(action)
|
||||
|
||||
def _find_parent_ids(self, message: AssistantMessage) -> tuple[str | None, str | None]:
|
||||
"""
|
||||
Walk up the message chain to find the root parent's message_id and tool_call_id.
|
||||
if is_state_update(event.update):
|
||||
new_state = self._state_type.model_validate(event.update[1])
|
||||
self._state = new_state
|
||||
|
||||
Returns (root_message_id, root_tool_call_id) for the root message in the chain.
|
||||
Includes cycle detection and max depth protection.
|
||||
"""
|
||||
root_tool_call_id = message.parent_tool_call_id
|
||||
if root_tool_call_id is None:
|
||||
return message.id, None
|
||||
return None
|
||||
|
||||
root_message_id = None
|
||||
visited: set[str] = set()
|
||||
def _handle_message(
|
||||
self, action: AssistantDispatcherEvent, message: AssistantMessageUnion
|
||||
) -> AssistantResultUnion | None:
|
||||
"""Handle a message from a node."""
|
||||
node_name = cast(MaxNodeName, action.node_name)
|
||||
produced_message: AssistantResultUnion | None = None
|
||||
|
||||
while root_tool_call_id is not None:
|
||||
if root_tool_call_id in visited:
|
||||
# Cycle detected, we skip this message
|
||||
return None, None
|
||||
# Output all messages from the top-level graph.
|
||||
if not self._is_message_from_nested_node_or_graph(action.node_path or ()):
|
||||
produced_message = self._handle_root_message(message, node_name)
|
||||
# Other message types with parents (viz, notebook, failure, tool call)
|
||||
else:
|
||||
produced_message = self._handle_special_child_message(message, node_name)
|
||||
|
||||
visited.add(root_tool_call_id)
|
||||
parent_message = self._tool_call_id_to_message.get(root_tool_call_id)
|
||||
if parent_message is None:
|
||||
# The parent message is not registered, we skip this message as it could come
|
||||
# from a sub-nested graph invoked directly by a contextual tool.
|
||||
return None, None
|
||||
# Messages with existing IDs must be deduplicated.
|
||||
# Messages WITHOUT IDs must be streamed because they're progressive.
|
||||
if isinstance(produced_message, MESSAGE_TYPE_TUPLE) and produced_message.id is not None:
|
||||
if produced_message.id in self._streamed_update_ids:
|
||||
return None
|
||||
self._streamed_update_ids.add(produced_message.id)
|
||||
|
||||
next_parent_tool_call_id = parent_message.parent_tool_call_id
|
||||
root_message_id = parent_message.id
|
||||
if next_parent_tool_call_id is None:
|
||||
return root_message_id, root_tool_call_id
|
||||
root_tool_call_id = next_parent_tool_call_id
|
||||
raise ValueError("Should not reach here")
|
||||
return produced_message
|
||||
|
||||
def _register_tool_calls(self, message: AssistantMessageUnion) -> None:
|
||||
"""Register any tool calls in the message for later lookup."""
|
||||
if isinstance(message, AssistantMessage) and message.tool_calls is not None:
|
||||
for tool_call in message.tool_calls:
|
||||
self._tool_call_id_to_message[tool_call.id] = message
|
||||
def _is_message_from_nested_node_or_graph(self, node_path: tuple[NodePath, ...]) -> bool:
|
||||
"""Check if the message is from a nested node or graph."""
|
||||
if not node_path:
|
||||
return False
|
||||
# The first path is always the top-level graph.
|
||||
# The second path is always the top-level node.
|
||||
# If the path is longer than 2, it's a MaxTool or nested graphs.
|
||||
if len(node_path) > 2 and next((path for path in node_path[1:] if path.name in AssistantGraphName), None):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _handle_root_message(
|
||||
self, message: AssistantMessageUnion, node_name: MaxNodeName
|
||||
) -> AssistantMessageUnion | None:
|
||||
"""Handle messages with no parent (root messages)."""
|
||||
if not should_output_assistant_message(message):
|
||||
if node_name not in self._verbose_nodes or not should_output_assistant_message(message):
|
||||
return None
|
||||
return message
|
||||
|
||||
def _handle_assistant_message_with_parent(self, message: AssistantMessage) -> AssistantUpdateEvent | None:
|
||||
def _handle_update_message(
|
||||
self, event: AssistantDispatcherEvent, action: UpdateAction
|
||||
) -> AssistantUpdateEvent | None:
|
||||
"""Handle AssistantMessage that has a parent, creating an AssistantUpdateEvent."""
|
||||
parent_id, parent_tool_call_id = self._find_parent_ids(message)
|
||||
|
||||
if parent_tool_call_id is None or parent_id is None:
|
||||
if not event.node_path or not action.content:
|
||||
return None
|
||||
|
||||
if message.content == "":
|
||||
# Find the closest tool call id to the update.
|
||||
parent_path = next((path for path in reversed(event.node_path) if path.tool_call_id), None)
|
||||
# Updates from the top-level graph nodes are not supported.
|
||||
if not parent_path:
|
||||
return None
|
||||
|
||||
return AssistantUpdateEvent(
|
||||
id=parent_id,
|
||||
tool_call_id=parent_tool_call_id,
|
||||
content=message.content,
|
||||
)
|
||||
tool_call_id = parent_path.tool_call_id
|
||||
message_id = parent_path.message_id
|
||||
|
||||
if not message_id or not tool_call_id:
|
||||
return None
|
||||
|
||||
return AssistantUpdateEvent(id=message_id, tool_call_id=tool_call_id, content=action.content)
|
||||
|
||||
def _handle_special_child_message(
|
||||
self, message: AssistantMessageUnion, node_name: MaxNodeName
|
||||
@@ -164,14 +218,10 @@ class AssistantStreamProcessor:
|
||||
|
||||
These messages are returned as-is regardless of where in the nesting hierarchy they are.
|
||||
"""
|
||||
# Return visualization messages only if from visualization nodes
|
||||
if isinstance(message, VisualizationMessage | MultiVisualizationMessage):
|
||||
if node_name in self._visualization_nodes:
|
||||
return message
|
||||
return None
|
||||
|
||||
# These message types are always returned as-is
|
||||
if isinstance(message, NotebookUpdateMessage | FailureMessage):
|
||||
if isinstance(message, VisualizationMessage | MultiVisualizationMessage) or isinstance(
|
||||
message, NotebookUpdateMessage | FailureMessage
|
||||
):
|
||||
return message
|
||||
|
||||
if isinstance(message, AssistantToolCallMessage):
|
||||
@@ -181,37 +231,44 @@ class AssistantStreamProcessor:
|
||||
# Should not reach here
|
||||
raise ValueError(f"Unhandled special message type: {type(message).__name__}")
|
||||
|
||||
def _handle_message(self, message: AssistantMessageUnion, node_name: MaxNodeName) -> AssistantResultUnion | None:
|
||||
# Messages with existing IDs must be deduplicated.
|
||||
# Messages WITHOUT IDs must be streamed because they're progressive.
|
||||
if hasattr(message, "id") and message.id is not None:
|
||||
if message.id in self._streamed_update_ids:
|
||||
return None
|
||||
self._streamed_update_ids.add(message.id)
|
||||
|
||||
# Root messages (no parent) are filtered by VERBOSE_NODES
|
||||
if message.parent_tool_call_id is None:
|
||||
return self._handle_root_message(message, node_name)
|
||||
|
||||
# AssistantMessage with parent creates AssistantUpdateEvent
|
||||
if isinstance(message, AssistantMessage):
|
||||
return self._handle_assistant_message_with_parent(message)
|
||||
else:
|
||||
# Other message types with parents (viz, notebook, failure, tool call)
|
||||
return self._handle_special_child_message(message, node_name)
|
||||
|
||||
def _handle_message_stream(self, message: AIMessageChunk, node_name: MaxNodeName) -> AssistantResultUnion | None:
|
||||
def _handle_message_stream(
|
||||
self, event: AssistantDispatcherEvent, message: AIMessageChunk
|
||||
) -> AssistantResultUnion | None:
|
||||
"""
|
||||
Process LLM chunks from "messages" stream mode.
|
||||
|
||||
With dispatch pattern, complete messages are dispatched by nodes.
|
||||
This handles AIMessageChunk for ephemeral streaming (responsiveness).
|
||||
"""
|
||||
node_name = cast(MaxNodeName, event.node_name)
|
||||
run_id = event.node_run_id
|
||||
|
||||
if node_name not in self._streaming_nodes:
|
||||
return None
|
||||
if run_id not in self._chunks:
|
||||
self._chunks[run_id] = AIMessageChunk(content="")
|
||||
|
||||
# Merge message chunks
|
||||
self._chunks = merge_message_chunk(self._chunks, message)
|
||||
self._chunks[run_id] = merge_message_chunk(self._chunks[run_id], message)
|
||||
|
||||
# Stream ephemeral message (no ID = not persisted)
|
||||
return normalize_ai_message(self._chunks)
|
||||
return normalize_ai_message(self._chunks[run_id])
|
||||
|
||||
def _handle_node_end(
|
||||
self, event: AssistantDispatcherEvent, action: NodeEndAction
|
||||
) -> list[AssistantResultUnion] | None:
|
||||
"""Handle the end of a node. Reset the streaming chunks."""
|
||||
if not isinstance(action.state, BaseStateWithMessages):
|
||||
return None
|
||||
results: list[AssistantResultUnion] = []
|
||||
for message in action.state.messages:
|
||||
if new_event := self.process(
|
||||
AssistantDispatcherEvent(
|
||||
action=MessageAction(message=message),
|
||||
node_name=event.node_name,
|
||||
node_run_id=event.node_run_id,
|
||||
node_path=event.node_path,
|
||||
)
|
||||
):
|
||||
results.extend(new_event)
|
||||
return results
|
||||
|
||||
@@ -1,165 +1,426 @@
|
||||
"""
|
||||
Comprehensive tests for AssistantDispatcher.
|
||||
|
||||
Tests the dispatcher logic that emits actions to LangGraph custom stream.
|
||||
"""
|
||||
|
||||
from posthog.test.base import BaseTest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from posthog.schema import AssistantMessage, AssistantToolCall
|
||||
|
||||
from ee.hogai.graph.base import BaseAssistantNode
|
||||
from ee.hogai.utils.dispatcher import AssistantDispatcher
|
||||
from ee.hogai.utils.types import AssistantState, PartialAssistantState
|
||||
from ee.hogai.utils.types.base import AssistantDispatcherEvent, AssistantNodeName
|
||||
from ee.hogai.utils.dispatcher import AssistantDispatcher, create_dispatcher_from_config
|
||||
from ee.hogai.utils.types.base import (
|
||||
AssistantDispatcherEvent,
|
||||
AssistantGraphName,
|
||||
AssistantNodeName,
|
||||
MessageAction,
|
||||
NodePath,
|
||||
UpdateAction,
|
||||
)
|
||||
|
||||
|
||||
class TestMessageDispatcher(BaseTest):
|
||||
class TestAssistantDispatcher(BaseTest):
|
||||
"""Test the AssistantDispatcher in isolation."""
|
||||
|
||||
def setUp(self):
|
||||
self._writer = MagicMock()
|
||||
self._writer.__aiter__ = AsyncMock(return_value=iter([]))
|
||||
self._writer.write = AsyncMock()
|
||||
super().setUp()
|
||||
self.dispatched_events: list[AssistantDispatcherEvent] = []
|
||||
|
||||
def test_create_dispatcher(self):
|
||||
"""Test creating a MessageDispatcher"""
|
||||
dispatcher = AssistantDispatcher(self._writer, node_name=AssistantNodeName.ROOT)
|
||||
def mock_writer(event):
|
||||
self.dispatched_events.append(event)
|
||||
|
||||
self.writer = mock_writer
|
||||
self.node_path = (
|
||||
NodePath(name=AssistantGraphName.ASSISTANT),
|
||||
NodePath(name=AssistantNodeName.ROOT),
|
||||
)
|
||||
|
||||
def test_create_dispatcher_with_basic_params(self):
|
||||
"""Test creating a dispatcher with required parameters."""
|
||||
dispatcher = AssistantDispatcher(
|
||||
writer=self.writer,
|
||||
node_path=self.node_path,
|
||||
node_name=AssistantNodeName.ROOT,
|
||||
node_run_id="test_run_123",
|
||||
)
|
||||
|
||||
self.assertEqual(dispatcher._node_name, AssistantNodeName.ROOT)
|
||||
self.assertEqual(dispatcher._writer, self._writer)
|
||||
self.assertEqual(dispatcher._node_run_id, "test_run_123")
|
||||
self.assertEqual(dispatcher._node_path, self.node_path)
|
||||
self.assertEqual(dispatcher._writer, self.writer)
|
||||
|
||||
async def test_dispatch_with_writer(self):
|
||||
"""Test dispatching with a writer"""
|
||||
dispatched_actions = []
|
||||
def test_dispatch_message_action(self):
|
||||
"""Test dispatching a message via the message() method."""
|
||||
dispatcher = AssistantDispatcher(
|
||||
writer=self.writer,
|
||||
node_path=self.node_path,
|
||||
node_name=AssistantNodeName.ROOT,
|
||||
node_run_id="test_run_456",
|
||||
)
|
||||
|
||||
def mock_writer(data):
|
||||
dispatched_actions.append(data)
|
||||
|
||||
dispatcher = AssistantDispatcher(mock_writer, node_name=AssistantNodeName.ROOT)
|
||||
|
||||
message = AssistantMessage(content="test message", parent_tool_call_id="tc1")
|
||||
message = AssistantMessage(content="Test message")
|
||||
dispatcher.message(message)
|
||||
|
||||
self.assertEqual(len(dispatched_actions), 1)
|
||||
self.assertEqual(len(self.dispatched_events), 1)
|
||||
event = self.dispatched_events[0]
|
||||
|
||||
# Verify action structure
|
||||
event = dispatched_actions[0]
|
||||
self.assertIsInstance(event, AssistantDispatcherEvent)
|
||||
self.assertEqual(event.action.type, "MESSAGE")
|
||||
self.assertIsInstance(event.action, MessageAction)
|
||||
assert isinstance(event.action, MessageAction)
|
||||
self.assertEqual(event.action.message, message)
|
||||
self.assertEqual(event.node_name, AssistantNodeName.ROOT)
|
||||
self.assertEqual(event.node_run_id, "test_run_456")
|
||||
self.assertEqual(event.node_path, self.node_path)
|
||||
|
||||
def test_dispatch_update_action(self):
|
||||
"""Test dispatching an update via the update() method."""
|
||||
dispatcher = AssistantDispatcher(
|
||||
writer=self.writer,
|
||||
node_path=self.node_path,
|
||||
node_name=AssistantNodeName.TRENDS_GENERATOR,
|
||||
node_run_id="test_run_789",
|
||||
)
|
||||
|
||||
dispatcher.update("Processing query...")
|
||||
|
||||
self.assertEqual(len(self.dispatched_events), 1)
|
||||
event = self.dispatched_events[0]
|
||||
|
||||
self.assertIsInstance(event, AssistantDispatcherEvent)
|
||||
self.assertIsInstance(event.action, UpdateAction)
|
||||
assert isinstance(event.action, UpdateAction)
|
||||
self.assertEqual(event.action.content, "Processing query...")
|
||||
self.assertEqual(event.node_name, AssistantNodeName.TRENDS_GENERATOR)
|
||||
self.assertEqual(event.node_run_id, "test_run_789")
|
||||
|
||||
def test_dispatch_multiple_messages(self):
|
||||
"""Test dispatching multiple messages in sequence."""
|
||||
dispatcher = AssistantDispatcher(
|
||||
writer=self.writer,
|
||||
node_path=self.node_path,
|
||||
node_name=AssistantNodeName.ROOT,
|
||||
node_run_id="test_run_multi",
|
||||
)
|
||||
|
||||
message1 = AssistantMessage(content="First message")
|
||||
message2 = AssistantMessage(content="Second message")
|
||||
message3 = AssistantMessage(content="Third message")
|
||||
|
||||
dispatcher.message(message1)
|
||||
dispatcher.message(message2)
|
||||
dispatcher.message(message3)
|
||||
|
||||
self.assertEqual(len(self.dispatched_events), 3)
|
||||
|
||||
contents = []
|
||||
for event in self.dispatched_events:
|
||||
if isinstance(event.action, MessageAction):
|
||||
msg = event.action.message
|
||||
if isinstance(msg, AssistantMessage):
|
||||
contents.append(msg.content)
|
||||
self.assertEqual(contents, ["First message", "Second message", "Third message"])
|
||||
|
||||
def test_dispatch_message_with_tool_calls(self):
|
||||
"""Test dispatching a message with tool calls preserves all data."""
|
||||
dispatcher = AssistantDispatcher(
|
||||
writer=self.writer,
|
||||
node_path=self.node_path,
|
||||
node_name=AssistantNodeName.ROOT,
|
||||
node_run_id="test_run_tools",
|
||||
)
|
||||
|
||||
tool_call = AssistantToolCall(id="tool_123", name="search", args={"query": "test query"})
|
||||
message = AssistantMessage(content="Running search...", tool_calls=[tool_call])
|
||||
|
||||
dispatcher.message(message)
|
||||
|
||||
self.assertEqual(len(self.dispatched_events), 1)
|
||||
event = self.dispatched_events[0]
|
||||
|
||||
self.assertIsInstance(event.action, MessageAction)
|
||||
assert isinstance(event.action, MessageAction)
|
||||
dispatched_message = event.action.message
|
||||
assert isinstance(dispatched_message, AssistantMessage)
|
||||
self.assertEqual(dispatched_message.content, "Running search...")
|
||||
self.assertIsNotNone(dispatched_message.tool_calls)
|
||||
assert dispatched_message.tool_calls is not None
|
||||
self.assertEqual(len(dispatched_message.tool_calls), 1)
|
||||
self.assertEqual(dispatched_message.tool_calls[0].id, "tool_123")
|
||||
self.assertEqual(dispatched_message.tool_calls[0].name, "search")
|
||||
self.assertEqual(dispatched_message.tool_calls[0].args, {"query": "test query"})
|
||||
|
||||
def test_dispatch_with_nested_node_path(self):
|
||||
"""Test that nested node paths are preserved in dispatched events."""
|
||||
nested_path = (
|
||||
NodePath(name=AssistantGraphName.ASSISTANT),
|
||||
NodePath(name=AssistantNodeName.ROOT, message_id="msg_123", tool_call_id="tc_123"),
|
||||
NodePath(name=AssistantGraphName.INSIGHTS),
|
||||
NodePath(name=AssistantNodeName.TRENDS_GENERATOR),
|
||||
)
|
||||
|
||||
dispatcher = AssistantDispatcher(
|
||||
writer=self.writer, node_path=nested_path, node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id="run_1"
|
||||
)
|
||||
|
||||
message = AssistantMessage(content="Nested message")
|
||||
dispatcher.message(message)
|
||||
|
||||
self.assertEqual(len(self.dispatched_events), 1)
|
||||
event = self.dispatched_events[0]
|
||||
|
||||
self.assertEqual(event.node_path, nested_path)
|
||||
assert event.node_path is not None
|
||||
self.assertEqual(len(event.node_path), 4)
|
||||
self.assertEqual(event.node_path[1].message_id, "msg_123")
|
||||
self.assertEqual(event.node_path[1].tool_call_id, "tc_123")
|
||||
|
||||
def test_dispatch_error_handling_continues_execution(self):
|
||||
"""Test that dispatch errors are caught and logged but don't crash."""
|
||||
|
||||
def failing_writer(event):
|
||||
raise RuntimeError("Writer failed!")
|
||||
|
||||
dispatcher = AssistantDispatcher(
|
||||
writer=failing_writer,
|
||||
node_path=self.node_path,
|
||||
node_name=AssistantNodeName.ROOT,
|
||||
node_run_id="test_run_error",
|
||||
)
|
||||
|
||||
message = AssistantMessage(content="This should not crash")
|
||||
|
||||
# Should not raise exception
|
||||
with patch("logging.getLogger") as mock_get_logger:
|
||||
mock_logger = MagicMock()
|
||||
mock_get_logger.return_value = mock_logger
|
||||
|
||||
dispatcher.message(message)
|
||||
|
||||
# Verify error was logged
|
||||
mock_logger.error.assert_called_once()
|
||||
args, kwargs = mock_logger.error.call_args
|
||||
self.assertIn("Failed to dispatch action", args[0])
|
||||
|
||||
def test_dispatch_mixed_actions(self):
|
||||
"""Test dispatching both messages and updates in sequence."""
|
||||
dispatcher = AssistantDispatcher(
|
||||
writer=self.writer,
|
||||
node_path=self.node_path,
|
||||
node_name=AssistantNodeName.TRENDS_GENERATOR,
|
||||
node_run_id="test_run_mixed",
|
||||
)
|
||||
|
||||
dispatcher.update("Starting analysis...")
|
||||
dispatcher.message(AssistantMessage(content="Found 3 insights"))
|
||||
dispatcher.update("Finalizing results...")
|
||||
|
||||
self.assertEqual(len(self.dispatched_events), 3)
|
||||
|
||||
self.assertIsInstance(self.dispatched_events[0].action, UpdateAction)
|
||||
assert isinstance(self.dispatched_events[0].action, UpdateAction)
|
||||
self.assertEqual(self.dispatched_events[0].action.content, "Starting analysis...")
|
||||
|
||||
self.assertIsInstance(self.dispatched_events[1].action, MessageAction)
|
||||
assert isinstance(self.dispatched_events[1].action, MessageAction)
|
||||
assert isinstance(self.dispatched_events[1].action.message, AssistantMessage)
|
||||
self.assertEqual(self.dispatched_events[1].action.message.content, "Found 3 insights")
|
||||
|
||||
self.assertIsInstance(self.dispatched_events[2].action, UpdateAction)
|
||||
assert isinstance(self.dispatched_events[2].action, UpdateAction)
|
||||
self.assertEqual(self.dispatched_events[2].action.content, "Finalizing results...")
|
||||
|
||||
def test_dispatch_preserves_message_id(self):
|
||||
"""Test that message IDs are preserved through dispatch."""
|
||||
dispatcher = AssistantDispatcher(
|
||||
writer=self.writer,
|
||||
node_path=self.node_path,
|
||||
node_name=AssistantNodeName.ROOT,
|
||||
node_run_id="test_run_id_preservation",
|
||||
)
|
||||
|
||||
message = AssistantMessage(id="msg_xyz_789", content="Message with ID")
|
||||
dispatcher.message(message)
|
||||
|
||||
event = self.dispatched_events[0]
|
||||
assert isinstance(event.action, MessageAction)
|
||||
self.assertEqual(event.action.message.id, "msg_xyz_789")
|
||||
|
||||
def test_dispatch_with_empty_node_path(self):
|
||||
"""Test dispatcher with an empty node path."""
|
||||
dispatcher = AssistantDispatcher(
|
||||
writer=self.writer, node_path=(), node_name=AssistantNodeName.ROOT, node_run_id="test_run_empty_path"
|
||||
)
|
||||
|
||||
message = AssistantMessage(content="Root level message")
|
||||
dispatcher.message(message)
|
||||
|
||||
self.assertEqual(len(self.dispatched_events), 1)
|
||||
event = self.dispatched_events[0]
|
||||
self.assertEqual(event.node_path, ())
|
||||
|
||||
|
||||
class MockNode(BaseAssistantNode[AssistantState, PartialAssistantState]):
|
||||
"""Mock node for testing"""
|
||||
class TestCreateDispatcherFromConfig(BaseTest):
|
||||
"""Test the create_dispatcher_from_config helper function."""
|
||||
|
||||
@property
|
||||
def node_name(self):
|
||||
return AssistantNodeName.ROOT
|
||||
def test_create_dispatcher_from_config_with_stream_writer(self):
|
||||
"""Test creating dispatcher from config with LangGraph stream writer."""
|
||||
node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT))
|
||||
|
||||
async def arun(self, state, config):
|
||||
# Use dispatch to add a message
|
||||
self.dispatcher.message(AssistantMessage(content="Test message from node"))
|
||||
return PartialAssistantState()
|
||||
config = RunnableConfig(
|
||||
metadata={"langgraph_node": AssistantNodeName.TRENDS_GENERATOR, "langgraph_checkpoint_ns": "checkpoint_abc"}
|
||||
)
|
||||
|
||||
mock_writer = MagicMock()
|
||||
|
||||
with patch("ee.hogai.utils.dispatcher.get_stream_writer", return_value=mock_writer):
|
||||
dispatcher = create_dispatcher_from_config(config, node_path)
|
||||
|
||||
self.assertEqual(dispatcher._node_name, AssistantNodeName.TRENDS_GENERATOR)
|
||||
self.assertEqual(dispatcher._node_run_id, "checkpoint_abc")
|
||||
self.assertEqual(dispatcher._node_path, node_path)
|
||||
self.assertEqual(dispatcher._writer, mock_writer)
|
||||
|
||||
def test_create_dispatcher_from_config_without_stream_writer(self):
|
||||
"""Test creating dispatcher when not in streaming context (e.g., testing)."""
|
||||
node_path = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT))
|
||||
|
||||
config = RunnableConfig(
|
||||
metadata={"langgraph_node": AssistantNodeName.ROOT, "langgraph_checkpoint_ns": "checkpoint_xyz"}
|
||||
)
|
||||
|
||||
# Simulate RuntimeError when get_stream_writer is called outside streaming context
|
||||
with patch("ee.hogai.utils.dispatcher.get_stream_writer", side_effect=RuntimeError("Not in streaming context")):
|
||||
dispatcher = create_dispatcher_from_config(config, node_path)
|
||||
|
||||
# Should create a noop writer
|
||||
self.assertEqual(dispatcher._node_name, AssistantNodeName.ROOT)
|
||||
self.assertEqual(dispatcher._node_run_id, "checkpoint_xyz")
|
||||
self.assertEqual(dispatcher._node_path, node_path)
|
||||
|
||||
# Verify the noop writer doesn't raise exceptions
|
||||
message = AssistantMessage(content="Test")
|
||||
dispatcher.message(message) # Should not crash
|
||||
|
||||
def test_create_dispatcher_with_missing_metadata(self):
|
||||
"""Test creating dispatcher when metadata fields are missing."""
|
||||
node_path = (NodePath(name=AssistantGraphName.ASSISTANT),)
|
||||
|
||||
config = RunnableConfig(metadata={})
|
||||
|
||||
with patch("ee.hogai.utils.dispatcher.get_stream_writer", side_effect=RuntimeError("Not in streaming context")):
|
||||
dispatcher = create_dispatcher_from_config(config, node_path)
|
||||
|
||||
# Should use empty strings as defaults
|
||||
self.assertEqual(dispatcher._node_name, "")
|
||||
self.assertEqual(dispatcher._node_run_id, "")
|
||||
self.assertEqual(dispatcher._node_path, node_path)
|
||||
|
||||
def test_create_dispatcher_preserves_node_path(self):
|
||||
"""Test that node path is correctly passed through to the dispatcher."""
|
||||
nested_path = (
|
||||
NodePath(name=AssistantGraphName.ASSISTANT),
|
||||
NodePath(name=AssistantNodeName.ROOT, message_id="msg_1", tool_call_id="tc_1"),
|
||||
NodePath(name=AssistantGraphName.INSIGHTS),
|
||||
)
|
||||
|
||||
config = RunnableConfig(
|
||||
metadata={"langgraph_node": AssistantNodeName.TRENDS_GENERATOR, "langgraph_checkpoint_ns": "cp_1"}
|
||||
)
|
||||
|
||||
with patch("ee.hogai.utils.dispatcher.get_stream_writer", side_effect=RuntimeError("Not in streaming context")):
|
||||
dispatcher = create_dispatcher_from_config(config, nested_path)
|
||||
|
||||
self.assertEqual(dispatcher._node_path, nested_path)
|
||||
self.assertEqual(len(dispatcher._node_path), 3)
|
||||
|
||||
|
||||
class TestDispatcherIntegration(BaseTest):
|
||||
async def test_node_dispatch_flow(self):
|
||||
"""Test that a node can dispatch messages"""
|
||||
# Use sync test helpers since BaseTest provides them
|
||||
team, user = self.team, self.user
|
||||
"""Integration tests for dispatcher usage patterns."""
|
||||
|
||||
node = MockNode(team=team, user=user)
|
||||
def test_dispatcher_in_node_context(self):
|
||||
"""Test typical usage pattern within a node."""
|
||||
dispatched_events = []
|
||||
|
||||
# Track dispatched actions
|
||||
dispatched_actions = []
|
||||
def mock_writer(event):
|
||||
dispatched_events.append(event)
|
||||
|
||||
def mock_writer(data):
|
||||
dispatched_actions.append(data)
|
||||
|
||||
# Set up node's dispatcher
|
||||
node._dispatcher = AssistantDispatcher(mock_writer, node_name=AssistantNodeName.ROOT)
|
||||
|
||||
# Run the node
|
||||
state = AssistantState(messages=[])
|
||||
config = RunnableConfig(configurable={})
|
||||
|
||||
await node.arun(state, config)
|
||||
|
||||
# Verify action was dispatched
|
||||
self.assertEqual(len(dispatched_actions), 1)
|
||||
|
||||
event = dispatched_actions[0]
|
||||
self.assertIsInstance(event, AssistantDispatcherEvent)
|
||||
self.assertEqual(event.action.type, "MESSAGE")
|
||||
self.assertEqual(event.action.message.content, "Test message from node")
|
||||
self.assertEqual(event.action.message.parent_tool_call_id, None)
|
||||
self.assertEqual(event.node_name, AssistantNodeName.ROOT)
|
||||
|
||||
async def test_action_preservation_through_stream(self):
|
||||
"""Test that action data is preserved through the stream"""
|
||||
|
||||
captured_updates = []
|
||||
|
||||
def mock_writer(data):
|
||||
captured_updates.append(data)
|
||||
|
||||
dispatcher = AssistantDispatcher(mock_writer, node_name=AssistantNodeName.ROOT)
|
||||
|
||||
# Create complex message with metadata
|
||||
message = AssistantMessage(
|
||||
content="Complex message",
|
||||
parent_tool_call_id="tc123",
|
||||
tool_calls=[AssistantToolCall(id="tool1", name="search", args={"query": "test"})],
|
||||
node_path = (
|
||||
NodePath(name=AssistantGraphName.ASSISTANT),
|
||||
NodePath(name=AssistantNodeName.ROOT),
|
||||
)
|
||||
|
||||
dispatcher.message(message)
|
||||
dispatcher = AssistantDispatcher(
|
||||
writer=mock_writer, node_path=node_path, node_name=AssistantNodeName.ROOT, node_run_id="integration_run_1"
|
||||
)
|
||||
|
||||
# Extract action
|
||||
event = captured_updates[0]
|
||||
self.assertIsInstance(event, AssistantDispatcherEvent)
|
||||
# Simulate node execution pattern
|
||||
dispatcher.update("Starting node execution...")
|
||||
|
||||
# Verify all message fields preserved
|
||||
payload = event.action.message
|
||||
self.assertEqual(payload.content, "Complex message")
|
||||
self.assertEqual(payload.parent_tool_call_id, "tc123")
|
||||
self.assertIsNotNone(payload.tool_calls)
|
||||
self.assertEqual(len(payload.tool_calls), 1)
|
||||
self.assertEqual(payload.tool_calls[0].name, "search")
|
||||
tool_call = AssistantToolCall(id="tc_int_1", name="generate_insight", args={"type": "trends"})
|
||||
dispatcher.message(AssistantMessage(content="Generating insight", tool_calls=[tool_call]))
|
||||
|
||||
async def test_multiple_dispatches_from_node(self):
|
||||
"""Test that a node can dispatch multiple messages"""
|
||||
# Use sync test helpers since BaseTest provides them
|
||||
team, user = self.team, self.user
|
||||
dispatcher.update("Processing data...")
|
||||
|
||||
class MultiDispatchNode(BaseAssistantNode[AssistantState, PartialAssistantState]):
|
||||
@property
|
||||
def node_name(self):
|
||||
return AssistantNodeName.ROOT
|
||||
dispatcher.message(AssistantMessage(content="Insight generated successfully"))
|
||||
|
||||
async def arun(self, state, config):
|
||||
# Dispatch multiple messages
|
||||
self.dispatcher.message(AssistantMessage(content="First message"))
|
||||
self.dispatcher.message(AssistantMessage(content="Second message"))
|
||||
self.dispatcher.message(AssistantMessage(content="Third message"))
|
||||
return PartialAssistantState()
|
||||
# Verify all events were dispatched
|
||||
self.assertEqual(len(dispatched_events), 4)
|
||||
|
||||
node = MultiDispatchNode(team=team, user=user)
|
||||
# Verify event types and order
|
||||
self.assertIsInstance(dispatched_events[0].action, UpdateAction)
|
||||
self.assertIsInstance(dispatched_events[1].action, MessageAction)
|
||||
self.assertIsInstance(dispatched_events[2].action, UpdateAction)
|
||||
self.assertIsInstance(dispatched_events[3].action, MessageAction)
|
||||
|
||||
dispatched_actions = []
|
||||
# Verify all events have consistent metadata
|
||||
for event in dispatched_events:
|
||||
self.assertEqual(event.node_name, AssistantNodeName.ROOT)
|
||||
self.assertEqual(event.node_run_id, "integration_run_1")
|
||||
self.assertEqual(event.node_path, node_path)
|
||||
|
||||
def mock_writer(data):
|
||||
dispatched_actions.append(data)
|
||||
def test_concurrent_dispatchers(self):
|
||||
"""Test multiple dispatchers can coexist without interference."""
|
||||
dispatched_events_1 = []
|
||||
dispatched_events_2 = []
|
||||
|
||||
node._dispatcher = AssistantDispatcher(mock_writer, node_name=AssistantNodeName.ROOT)
|
||||
def writer_1(event):
|
||||
dispatched_events_1.append(event)
|
||||
|
||||
state = AssistantState(messages=[])
|
||||
config = RunnableConfig(configurable={})
|
||||
def writer_2(event):
|
||||
dispatched_events_2.append(event)
|
||||
|
||||
await node.arun(state, config)
|
||||
path_1 = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.ROOT))
|
||||
path_2 = (NodePath(name=AssistantGraphName.ASSISTANT), NodePath(name=AssistantNodeName.TRENDS_GENERATOR))
|
||||
|
||||
# Verify all three dispatches
|
||||
self.assertEqual(len(dispatched_actions), 3)
|
||||
dispatcher_1 = AssistantDispatcher(
|
||||
writer=writer_1, node_path=path_1, node_name=AssistantNodeName.ROOT, node_run_id="run_1"
|
||||
)
|
||||
|
||||
contents = []
|
||||
for event in dispatched_actions:
|
||||
self.assertIsInstance(event, AssistantDispatcherEvent)
|
||||
contents.append(event.action.message.content)
|
||||
dispatcher_2 = AssistantDispatcher(
|
||||
writer=writer_2, node_path=path_2, node_name=AssistantNodeName.TRENDS_GENERATOR, node_run_id="run_2"
|
||||
)
|
||||
|
||||
self.assertEqual(contents, ["First message", "Second message", "Third message"])
|
||||
# Dispatch from both
|
||||
dispatcher_1.message(AssistantMessage(content="From dispatcher 1"))
|
||||
dispatcher_2.message(AssistantMessage(content="From dispatcher 2"))
|
||||
dispatcher_1.update("Update from dispatcher 1")
|
||||
dispatcher_2.update("Update from dispatcher 2")
|
||||
|
||||
# Verify each dispatcher wrote to its own writer
|
||||
self.assertEqual(len(dispatched_events_1), 2)
|
||||
self.assertEqual(len(dispatched_events_2), 2)
|
||||
|
||||
# Verify events went to correct writers
|
||||
assert isinstance(dispatched_events_1[0].action, MessageAction)
|
||||
assert isinstance(dispatched_events_1[0].action.message, AssistantMessage)
|
||||
self.assertEqual(dispatched_events_1[0].action.message.content, "From dispatcher 1")
|
||||
assert isinstance(dispatched_events_2[0].action, MessageAction)
|
||||
assert isinstance(dispatched_events_2[0].action.message, AssistantMessage)
|
||||
self.assertEqual(dispatched_events_2[0].action.message.content, "From dispatcher 2")
|
||||
|
||||
# Verify node names are correct
|
||||
self.assertEqual(dispatched_events_1[0].node_name, AssistantNodeName.ROOT)
|
||||
self.assertEqual(dispatched_events_2[0].node_name, AssistantNodeName.TRENDS_GENERATOR)
|
||||
|
||||