mirror of
https://github.com/BillyOutlast/posthog.git
synced 2026-02-04 03:01:23 +01:00
feat(max): Add logic on if to generate Replay filters or to pass through existing ones (#37424)
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -31,18 +33,23 @@ def call_root_for_replay_sessions(demo_org_team_user):
|
||||
.compile(checkpointer=DjangoCheckpointer())
|
||||
)
|
||||
|
||||
async def callable(messages: str | list[AssistantMessageUnion]) -> AssistantMessage:
|
||||
async def callable(
|
||||
messages: str | list[AssistantMessageUnion], include_search_session_recordings_context: bool
|
||||
) -> AssistantMessage:
|
||||
conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])
|
||||
initial_state = AssistantState(
|
||||
messages=[HumanMessage(content=messages)] if isinstance(messages, str) else messages
|
||||
)
|
||||
# Simulate session replay page context
|
||||
# Conditionally include session replay page context
|
||||
contextual_tools = (
|
||||
{"search_session_recordings": {"current_filters": {"date_from": "-7d", "filter_test_accounts": True}}}
|
||||
if include_search_session_recordings_context
|
||||
else {}
|
||||
)
|
||||
config = {
|
||||
"configurable": {
|
||||
"thread_id": conversation.id,
|
||||
"contextual_tools": {
|
||||
"search_session_recordings": {"current_filters": {"date_from": "-7d", "filter_test_accounts": True}}
|
||||
},
|
||||
"contextual_tools": contextual_tools,
|
||||
}
|
||||
}
|
||||
raw_state = await graph.ainvoke(initial_state, config)
|
||||
@@ -60,7 +67,7 @@ def call_root_for_replay_sessions(demo_org_team_user):
|
||||
@pytest.mark.django_db
|
||||
@patch("posthoganalytics.feature_enabled", return_value=True)
|
||||
async def eval_tool_routing_session_replay(patch_feature_enabled, call_root_for_replay_sessions, pytestconfig):
|
||||
"""Test routing between search_session_recordings (contextual) and session_summarization (root)."""
|
||||
"""Test routing between search_session_recordings (contextual) and session_summarization (root) with context."""
|
||||
|
||||
await MaxPublicEval(
|
||||
experiment_name="tool_routing_session_replay",
|
||||
@@ -73,6 +80,7 @@ async def eval_tool_routing_session_replay(patch_feature_enabled, call_root_for_
|
||||
expected=AssistantToolCall(
|
||||
id="1",
|
||||
name="search_session_recordings",
|
||||
# Expect the period to be guessed from current filters
|
||||
args={"change": "show me recordings from mobile users"},
|
||||
),
|
||||
),
|
||||
@@ -89,7 +97,7 @@ async def eval_tool_routing_session_replay(patch_feature_enabled, call_root_for_
|
||||
expected=AssistantToolCall(
|
||||
id="3",
|
||||
name="search_session_recordings",
|
||||
args={"change": "show recordings longer than 5 minutes"},
|
||||
args={"change": "show only recordings longer than 5 minutes"},
|
||||
),
|
||||
),
|
||||
# Cases where session_summarization should be used (analysis/summary)
|
||||
@@ -98,15 +106,21 @@ async def eval_tool_routing_session_replay(patch_feature_enabled, call_root_for_
|
||||
expected=AssistantToolCall(
|
||||
id="5",
|
||||
name="session_summarization",
|
||||
args={"session_summarization_query": "summarize sessions from yesterday"},
|
||||
args={
|
||||
"session_summarization_query": "summarize sessions from yesterday",
|
||||
"should_use_current_filters": False, # Specific time frame differs from current filters
|
||||
},
|
||||
),
|
||||
),
|
||||
EvalCase(
|
||||
input="watch sessions of the user 09081 in the last 7 days",
|
||||
input="watch sessions of the user 09081 in the last 30 days",
|
||||
expected=AssistantToolCall(
|
||||
id="6",
|
||||
name="session_summarization",
|
||||
args={"session_summarization_query": "watch sessions of the user 09081 in the last 7 days"},
|
||||
args={
|
||||
"session_summarization_query": "watch sessions of the user 09081 in the last 30 days",
|
||||
"should_use_current_filters": False, # Specific user and timeframe
|
||||
},
|
||||
),
|
||||
),
|
||||
EvalCase(
|
||||
@@ -114,24 +128,127 @@ async def eval_tool_routing_session_replay(patch_feature_enabled, call_root_for_
|
||||
expected=AssistantToolCall(
|
||||
id="7",
|
||||
name="session_summarization",
|
||||
args={"session_summarization_query": "analyze mobile user sessions from last week"},
|
||||
args={
|
||||
"session_summarization_query": "analyze mobile user sessions from last week",
|
||||
"should_use_current_filters": False, # Specific device type and timeframe
|
||||
},
|
||||
),
|
||||
),
|
||||
# Edge cases - ambiguous queries
|
||||
EvalCase(
|
||||
input="show me what users did on the checkout page",
|
||||
input="summarize sessions from the last 30 days, including test accounts",
|
||||
expected=AssistantToolCall(
|
||||
id="8",
|
||||
name="session_summarization",
|
||||
args={
|
||||
"session_summarization_query": "summarize sessions from the last 30 days with test accounts included",
|
||||
"should_use_current_filters": False, # Different time frame/conditions
|
||||
},
|
||||
),
|
||||
),
|
||||
# Cases where should_use_current_filters should be true (referring to current/selected filters)
|
||||
EvalCase(
|
||||
input="summarize these sessions",
|
||||
expected=AssistantToolCall(
|
||||
id="9",
|
||||
name="session_summarization",
|
||||
args={"session_summarization_query": "show me what users did on the checkout page"},
|
||||
args={
|
||||
"session_summarization_query": "summarize these sessions",
|
||||
"should_use_current_filters": True, # "these" refers to current filters
|
||||
},
|
||||
),
|
||||
),
|
||||
EvalCase(
|
||||
input="replay user sessions from this morning",
|
||||
input="summarize all sessions",
|
||||
expected=AssistantToolCall(
|
||||
id="10",
|
||||
name="session_summarization",
|
||||
args={"session_summarization_query": "replay user sessions from this morning"},
|
||||
args={
|
||||
"session_summarization_query": "summarize all sessions",
|
||||
"should_use_current_filters": True, # "all" in context of filtered view
|
||||
},
|
||||
),
|
||||
),
|
||||
EvalCase(
|
||||
input="summarize sessions from the last 7 days with test accounts filtered out",
|
||||
expected=AssistantToolCall(
|
||||
id="11",
|
||||
name="session_summarization",
|
||||
args={
|
||||
"session_summarization_query": "summarize sessions from the last 7 days with test accounts filtered out",
|
||||
"should_use_current_filters": True, # Matches current filters exactly
|
||||
},
|
||||
),
|
||||
),
|
||||
# Ambiguous cases
|
||||
EvalCase(
|
||||
input="show me the summary of what users did with our app in the last 7 days",
|
||||
expected=AssistantToolCall(
|
||||
id="12",
|
||||
name="session_summarization",
|
||||
args={
|
||||
"session_summarization_query": "show me what users did with our app",
|
||||
"should_use_current_filters": True, # Analyzing user behavior, use current context
|
||||
},
|
||||
),
|
||||
),
|
||||
EvalCase(
|
||||
input="show me sessions with users who visited checkout page",
|
||||
expected=AssistantToolCall(
|
||||
id="13",
|
||||
name="search_session_recordings",
|
||||
args={"change": "show me sessions with users who visited checkout page"},
|
||||
),
|
||||
),
|
||||
],
|
||||
pytestconfig=pytestconfig,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("posthoganalytics.feature_enabled", return_value=True)
|
||||
async def eval_session_summarization_no_context(patch_feature_enabled, call_root_for_replay_sessions, pytestconfig):
|
||||
"""Test session summarization without search_session_recordings context - should_use_current_filters should always be false."""
|
||||
|
||||
# Use partial to avoid adding session search context
|
||||
task_without_context = partial(call_root_for_replay_sessions, include_search_session_recordings_context=False)
|
||||
|
||||
await MaxPublicEval(
|
||||
experiment_name="session_summarization_no_context",
|
||||
task=task_without_context,
|
||||
scores=[ToolRelevance(semantic_similarity_args={"session_summarization_query"})],
|
||||
data=[
|
||||
# All cases should have should_use_current_filters=false when no context
|
||||
EvalCase(
|
||||
input="summarize sessions from yesterday",
|
||||
expected=AssistantToolCall(
|
||||
id="1",
|
||||
name="session_summarization",
|
||||
args={
|
||||
"session_summarization_query": "summarize sessions from yesterday",
|
||||
"should_use_current_filters": False, # No context, always false
|
||||
},
|
||||
),
|
||||
),
|
||||
EvalCase(
|
||||
input="analyze the current recordings from today",
|
||||
expected=AssistantToolCall(
|
||||
id="3",
|
||||
name="session_summarization",
|
||||
args={
|
||||
"session_summarization_query": "analyze the current recordings from today",
|
||||
"should_use_current_filters": False, # Even with "current", no context means false
|
||||
},
|
||||
),
|
||||
),
|
||||
EvalCase(
|
||||
input="watch all my session recordings",
|
||||
expected=AssistantToolCall(
|
||||
id="5",
|
||||
name="session_summarization",
|
||||
args={
|
||||
"session_summarization_query": "watch all session recordings",
|
||||
"should_use_current_filters": False, # Even with "all", no context means false
|
||||
},
|
||||
),
|
||||
),
|
||||
],
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import re
|
||||
import json
|
||||
import math
|
||||
from typing import Literal, Optional, TypeVar, cast
|
||||
from uuid import uuid4
|
||||
@@ -57,6 +58,9 @@ from .prompts import (
|
||||
ROOT_INSIGHTS_CONTEXT_PROMPT,
|
||||
ROOT_SYSTEM_PROMPT,
|
||||
ROOT_UI_CONTEXT_PROMPT,
|
||||
SESSION_SUMMARIZATION_PROMPT_BASE,
|
||||
SESSION_SUMMARIZATION_PROMPT_NO_REPLAY_CONTEXT,
|
||||
SESSION_SUMMARIZATION_PROMPT_WITH_REPLAY_CONTEXT,
|
||||
)
|
||||
|
||||
# Map query kinds to their respective full UI query classes
|
||||
@@ -337,14 +341,12 @@ class RootNode(RootNodeUIContextMixin):
|
||||
# Build system prompt with conditional session summarization and insight search sections
|
||||
system_prompt_template = ROOT_SYSTEM_PROMPT
|
||||
# Check if session summarization is enabled for the user
|
||||
if not self._has_session_summarization_feature_flag():
|
||||
# Remove session summarization section from prompt using regex
|
||||
if self._has_session_summarization_feature_flag():
|
||||
context = self._render_session_summarization_context(config)
|
||||
# Inject session summarization context
|
||||
system_prompt_template = re.sub(
|
||||
r"\n?<session_summarization>.*?</session_summarization>", "", system_prompt_template, flags=re.DOTALL
|
||||
r"\n?<session_summarization></session_summarization>", context, system_prompt_template, flags=re.DOTALL
|
||||
)
|
||||
# Also remove the reference to session_summarization in basic_functionality
|
||||
system_prompt_template = re.sub(r"\n?\d+\. `session_summarization`.*?[^\n]*", "", system_prompt_template)
|
||||
|
||||
# Check if insight search is enabled for the user
|
||||
if not self._has_insight_search_feature_flag():
|
||||
# Remove the reference to search_insights in basic_functionality
|
||||
@@ -611,6 +613,27 @@ class RootNode(RootNodeUIContextMixin):
|
||||
return messages[idx:]
|
||||
return messages
|
||||
|
||||
def _render_session_summarization_context(self, config: RunnableConfig) -> str:
|
||||
"""Render the user context template with the provided context strings."""
|
||||
search_session_recordings_context = self._get_contextual_tools(config).get("search_session_recordings")
|
||||
if (
|
||||
not search_session_recordings_context
|
||||
or not isinstance(search_session_recordings_context, dict)
|
||||
or not search_session_recordings_context.get("current_filters")
|
||||
or not isinstance(search_session_recordings_context["current_filters"], dict)
|
||||
):
|
||||
conditional_context = SESSION_SUMMARIZATION_PROMPT_NO_REPLAY_CONTEXT
|
||||
else:
|
||||
current_filters = search_session_recordings_context["current_filters"]
|
||||
conditional_template = PromptTemplate.from_template(
|
||||
SESSION_SUMMARIZATION_PROMPT_WITH_REPLAY_CONTEXT, template_format="mustache"
|
||||
)
|
||||
conditional_context = conditional_template.format_prompt(
|
||||
current_filters=json.dumps(current_filters)
|
||||
).to_string()
|
||||
template = PromptTemplate.from_template(SESSION_SUMMARIZATION_PROMPT_BASE, template_format="mustache")
|
||||
return template.format_prompt(conditional_context=conditional_context).to_string()
|
||||
|
||||
|
||||
class RootNodeTools(AssistantNode):
|
||||
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
|
||||
@@ -652,6 +675,8 @@ class RootNodeTools(AssistantNode):
|
||||
return PartialAssistantState(
|
||||
root_tool_call_id=tool_call.id,
|
||||
session_summarization_query=tool_call.args["session_summarization_query"],
|
||||
# Safety net in case the argument is missing to avoid raising exceptions internally
|
||||
should_use_current_filters=tool_call.args.get("should_use_current_filters", False),
|
||||
root_tool_calls_count=tool_call_count + 1,
|
||||
)
|
||||
elif ToolClass := get_contextual_tool_class(tool_call.name):
|
||||
|
||||
@@ -48,7 +48,7 @@ You have access to these main tools:
|
||||
1. `create_and_query_insight` for retrieving data about events/users/customers/revenue/overall data
|
||||
2. `search_documentation` for answering questions related to PostHog features, concepts, usage, sdk integration, troubleshooting, and so on – use `search_documentation` liberally!
|
||||
3. `search_insights` for finding existing insights when you deem necessary to look for insights, when users ask to search, find, or look up insights or when creating dashboards
|
||||
4. `session_summarization` for summarizing sessions, when users ask to summarize (e.g. watch, analyze) specific sessions (e.g. replays, recordings)
|
||||
4. `session_summarization` for summarizing session recordings
|
||||
|
||||
Before using a tool, say what you're about to do, in one sentence. If calling the navigation tool, do not say anything.
|
||||
|
||||
@@ -119,22 +119,49 @@ Follow these guidelines when searching insights:
|
||||
- The search functionality works better with natural language queries that include context
|
||||
</insight_search>
|
||||
|
||||
<session_summarization>
|
||||
The tool `session_summarization` helps you to summarize sessions by converting user query into a search for relevant sessions and then summarizing the events within those sessions.
|
||||
|
||||
Follow these guidelines when summarizing sessions:
|
||||
- Sessions may also be called "recordings", "replays", "session recordings", or "user sessions"
|
||||
- Use this tool when users ask to watch, summarize, analyze, or review sessions
|
||||
- CRITICAL: Always pass the user's complete, unmodified query to the `session_summarization_query` parameter
|
||||
- DO NOT truncate, summarize, or extract keywords from the user's query
|
||||
- The query is used to find relevant sessions - context helps find better matches
|
||||
</session_summarization>
|
||||
<session_summarization></session_summarization>
|
||||
|
||||
{{{ui_context}}}
|
||||
{{{billing_context}}}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
SESSION_SUMMARIZATION_PROMPT_BASE = """
|
||||
<session_summarization>
|
||||
The tool `session_summarization` helps you to summarize session recordings by analysing the events within those sessions.
|
||||
|
||||
{{{conditional_context}}}
|
||||
|
||||
Synonyms:
|
||||
- "summarize": "watch", "analyze", "review", and similar
|
||||
- "session recordings": "sessions", "recordings", "replays", "user sessions", and similar
|
||||
|
||||
Follow these guidelines when summarizing session recordings:
|
||||
- CRITICAL: Always pass the user's complete, unmodified query to the `session_summarization_query` parameter
|
||||
- DO NOT truncate, summarize, or extract keywords from the user's query
|
||||
- The query is used to find relevant sessions - context helps find better matches
|
||||
- Use explicit tool definition to make a decision
|
||||
</session_summarization>
|
||||
"""
|
||||
|
||||
SESSION_SUMMARIZATION_PROMPT_NO_REPLAY_CONTEXT = """
|
||||
There are no current filters in the user's UI context. It means that you need to:
|
||||
- Convert the user query into a `session_summarization_query`
|
||||
- The query should be used to search for relevant sessions and then summarize them
|
||||
- Assume the `should_use_current_filters` should be always `false`
|
||||
"""
|
||||
|
||||
SESSION_SUMMARIZATION_PROMPT_WITH_REPLAY_CONTEXT = """
|
||||
There are current filters in the user's UI context. It means that you need to:
|
||||
- Convert the user query into a `session_summarization_query`
|
||||
- The query should be used to understand the user's intent
|
||||
- Decide if the query is relevant to the current filters and set `should_use_current_filters` accordingly
|
||||
|
||||
```json
|
||||
{{{current_filters}}}
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
ROOT_INSIGHT_DESCRIPTION_PROMPT = """
|
||||
Pick the most suitable visualization type for the user's question.
|
||||
|
||||
@@ -24,7 +24,11 @@ from posthog.temporal.ai.session_summary.summarize_session_group import (
|
||||
)
|
||||
|
||||
from ee.hogai.graph.base import AssistantNode
|
||||
from ee.hogai.session_summaries.constants import GROUP_SUMMARIES_MIN_SESSIONS, SESSION_SUMMARIES_STREAMING_MODEL
|
||||
from ee.hogai.session_summaries.constants import (
|
||||
GROUP_SUMMARIES_MIN_SESSIONS,
|
||||
MAX_SESSIONS_TO_SUMMARIZE,
|
||||
SESSION_SUMMARIES_STREAMING_MODEL,
|
||||
)
|
||||
from ee.hogai.session_summaries.session_group.patterns import EnrichedSessionGroupSummaryPatternsList
|
||||
from ee.hogai.session_summaries.session_group.summarize_session_group import find_sessions_timestamps
|
||||
from ee.hogai.session_summaries.session_group.summary_notebooks import (
|
||||
@@ -120,7 +124,9 @@ class SessionSummarizationNode(AssistantNode):
|
||||
max_filters = cast(MaxRecordingUniversalFilters, filters_data)
|
||||
return max_filters
|
||||
|
||||
def _get_session_ids_with_filters(self, replay_filters: MaxRecordingUniversalFilters) -> list[str] | None:
|
||||
def _get_session_ids_with_filters(
|
||||
self, replay_filters: MaxRecordingUniversalFilters, limit: int = MAX_SESSIONS_TO_SUMMARIZE
|
||||
) -> list[str] | None:
|
||||
from posthog.session_recordings.queries.session_recording_list_from_query import SessionRecordingListFromQuery
|
||||
|
||||
# Convert Max filters into recordings query format
|
||||
@@ -144,6 +150,7 @@ class SessionSummarizationNode(AssistantNode):
|
||||
if replay_filters.duration
|
||||
else None
|
||||
),
|
||||
limit=limit,
|
||||
)
|
||||
# Execute the query to get session IDs
|
||||
query_runner = SessionRecordingListFromQuery(
|
||||
@@ -257,23 +264,46 @@ class SessionSummarizationNode(AssistantNode):
|
||||
conversation_id = config.get("configurable", {}).get("thread_id", "unknown")
|
||||
writer = self._get_stream_writer()
|
||||
# If query was not provided for some reason
|
||||
if not state.session_summarization_query:
|
||||
if state.session_summarization_query is None:
|
||||
self._log_failure(
|
||||
f"Session summarization query is not provided: {state.session_summarization_query}",
|
||||
f"Session summarization query is not provided when summarizing sessions: {state.session_summarization_query}",
|
||||
conversation_id,
|
||||
start_time,
|
||||
)
|
||||
return self._create_error_response(self._base_error_instructions, state)
|
||||
# If the decision on the current filters is not made
|
||||
if state.should_use_current_filters is None:
|
||||
self._log_failure(
|
||||
f"Use current filters decision is not made when summarizing sessions: {state.should_use_current_filters}",
|
||||
conversation_id,
|
||||
start_time,
|
||||
)
|
||||
return self._create_error_response(self._base_error_instructions, state)
|
||||
# If the current filters were marked as relevant, but not present in the context
|
||||
current_filters = self._get_contextual_tools(config).get("search_session_recordings", {}).get("current_filters")
|
||||
try:
|
||||
# Generate filters to get session ids from DB
|
||||
replay_filters = await self._generate_replay_filters(state.session_summarization_query)
|
||||
if not replay_filters:
|
||||
self._log_failure(
|
||||
f"No Replay filters were generated for session summarization: {state.session_summarization_query}",
|
||||
conversation_id,
|
||||
start_time,
|
||||
)
|
||||
return self._create_error_response(self._base_error_instructions, state)
|
||||
# Use current filters, if provided
|
||||
if state.should_use_current_filters:
|
||||
if not current_filters:
|
||||
self._log_failure(
|
||||
f"Use current filters decision was set to True, but current filters were not provided when summarizing sessions: {state.should_use_current_filters}",
|
||||
conversation_id,
|
||||
start_time,
|
||||
)
|
||||
return self._create_error_response(self._base_error_instructions, state)
|
||||
current_filters = cast(dict[str, Any], current_filters)
|
||||
replay_filters = MaxRecordingUniversalFilters.model_validate(current_filters)
|
||||
# If not - generate filters to get session ids from DB
|
||||
else:
|
||||
generated_filters = await self._generate_replay_filters(state.session_summarization_query)
|
||||
if not generated_filters:
|
||||
self._log_failure(
|
||||
f"No Replay filters were generated for session summarization: {state.session_summarization_query}",
|
||||
conversation_id,
|
||||
start_time,
|
||||
)
|
||||
return self._create_error_response(self._base_error_instructions, state)
|
||||
replay_filters = generated_filters
|
||||
# Query the filters to get session ids
|
||||
session_ids = await database_sync_to_async(self._get_session_ids_with_filters, thread_sensitive=False)(
|
||||
replay_filters
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from posthog.test.base import BaseTest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
@@ -7,6 +7,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from django.utils import timezone
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from posthog.schema import (
|
||||
AssistantToolCallMessage,
|
||||
@@ -101,13 +102,17 @@ class TestSessionSummarizationNode(BaseTest):
|
||||
return mock_database_sync_to_async
|
||||
|
||||
def _create_test_state(
|
||||
self, query: str | None = None, root_tool_call_id: str | None = "test_tool_call_id"
|
||||
self,
|
||||
query: str | None = None,
|
||||
root_tool_call_id: str | None = "test_tool_call_id",
|
||||
should_use_current_filters: bool | None = None,
|
||||
) -> AssistantState:
|
||||
"""Helper to create a test AssistantState."""
|
||||
return AssistantState(
|
||||
messages=[HumanMessage(content="Test")],
|
||||
session_summarization_query=query,
|
||||
root_tool_call_id=root_tool_call_id,
|
||||
should_use_current_filters=should_use_current_filters,
|
||||
)
|
||||
|
||||
def test_create_error_response(self) -> None:
|
||||
@@ -256,7 +261,25 @@ class TestSessionSummarizationNode(BaseTest):
|
||||
mock_get_stream_writer.return_value = None
|
||||
conversation = Conversation.objects.create(team=self.team, user=self.user)
|
||||
|
||||
state = self._create_test_state(query=None)
|
||||
state = self._create_test_state(query=None, should_use_current_filters=False)
|
||||
|
||||
result = async_to_sync(self.node.arun)(state, {"configurable": {"thread_id": str(conversation.id)}})
|
||||
|
||||
self.assertIsInstance(result, PartialAssistantState)
|
||||
self.assertIsNotNone(result)
|
||||
assert result is not None
|
||||
message = result.messages[0]
|
||||
self.assertIsInstance(message, AssistantToolCallMessage)
|
||||
assert isinstance(message, AssistantToolCallMessage)
|
||||
self.assertIn("encountered an issue", message.content)
|
||||
|
||||
@patch("ee.hogai.graph.session_summaries.nodes.get_stream_writer")
|
||||
def test_arun_no_use_current_filters_decision(self, mock_get_stream_writer: MagicMock) -> None:
|
||||
"""Test arun returns error when should_use_current_filters decision is not made."""
|
||||
mock_get_stream_writer.return_value = None
|
||||
conversation = Conversation.objects.create(team=self.team, user=self.user)
|
||||
|
||||
state = self._create_test_state(query="test query", should_use_current_filters=None)
|
||||
|
||||
result = async_to_sync(self.node.arun)(state, {"configurable": {"thread_id": str(conversation.id)}})
|
||||
|
||||
@@ -281,7 +304,7 @@ class TestSessionSummarizationNode(BaseTest):
|
||||
mock_graph_instance, _ = self._create_mock_filter_graph(output_filters=None)
|
||||
mock_filter_graph_class.return_value = mock_graph_instance
|
||||
|
||||
state = self._create_test_state(query="test query")
|
||||
state = self._create_test_state(query="test query", should_use_current_filters=False)
|
||||
|
||||
result = async_to_sync(self.node.arun)(state, {"configurable": {"thread_id": str(conversation.id)}})
|
||||
|
||||
@@ -318,7 +341,7 @@ class TestSessionSummarizationNode(BaseTest):
|
||||
|
||||
mock_db_sync.side_effect = self._create_mock_db_sync_to_async()
|
||||
|
||||
state = self._create_test_state(query="test query")
|
||||
state = self._create_test_state(query="test query", should_use_current_filters=False)
|
||||
|
||||
result = async_to_sync(self.node.arun)(state, {"configurable": {"thread_id": str(conversation.id)}})
|
||||
|
||||
@@ -377,7 +400,7 @@ class TestSessionSummarizationNode(BaseTest):
|
||||
|
||||
mock_execute_summarize.side_effect = mock_summarize_side_effect
|
||||
|
||||
state = self._create_test_state(query="test query")
|
||||
state = self._create_test_state(query="test query", should_use_current_filters=False)
|
||||
|
||||
result = async_to_sync(self.node.arun)(state, {"configurable": {"thread_id": str(conversation.id)}})
|
||||
|
||||
@@ -404,7 +427,7 @@ class TestSessionSummarizationNode(BaseTest):
|
||||
# Mock filter generation to raise exception
|
||||
mock_filter_graph_class.side_effect = Exception("Test exception")
|
||||
|
||||
state = self._create_test_state(query="test query")
|
||||
state = self._create_test_state(query="test query", should_use_current_filters=False)
|
||||
|
||||
result = async_to_sync(self.node.arun)(state, {"configurable": {"thread_id": str(conversation.id)}})
|
||||
|
||||
@@ -417,3 +440,124 @@ class TestSessionSummarizationNode(BaseTest):
|
||||
assert isinstance(message, AssistantToolCallMessage)
|
||||
self.assertIn("encountered an issue", message.content)
|
||||
self.assertEqual(message.tool_call_id, "test_tool_call_id")
|
||||
|
||||
@patch("posthog.session_recordings.queries.session_recording_list_from_query.SessionRecordingListFromQuery")
|
||||
@patch("ee.hogai.graph.session_summaries.nodes.database_sync_to_async")
|
||||
@patch("ee.hogai.graph.session_summaries.nodes.get_stream_writer")
|
||||
def test_arun_use_current_filters_true_no_context(
|
||||
self,
|
||||
mock_get_stream_writer: MagicMock,
|
||||
mock_db_sync: MagicMock,
|
||||
mock_query_runner_class: MagicMock,
|
||||
) -> None:
|
||||
"""Test arun returns error when should_use_current_filters=True but no context provided."""
|
||||
mock_get_stream_writer.return_value = None
|
||||
conversation = Conversation.objects.create(team=self.team, user=self.user)
|
||||
|
||||
state = self._create_test_state(query="test query", should_use_current_filters=True)
|
||||
|
||||
# No contextual tools provided
|
||||
result = async_to_sync(self.node.arun)(state, {"configurable": {"thread_id": str(conversation.id)}})
|
||||
|
||||
self.assertIsInstance(result, PartialAssistantState)
|
||||
self.assertIsNotNone(result)
|
||||
assert result is not None
|
||||
message = result.messages[0]
|
||||
self.assertIsInstance(message, AssistantToolCallMessage)
|
||||
assert isinstance(message, AssistantToolCallMessage)
|
||||
self.assertIn("encountered an issue", message.content)
|
||||
|
||||
@patch("posthog.session_recordings.queries.session_recording_list_from_query.SessionRecordingListFromQuery")
|
||||
@patch("ee.hogai.graph.session_summaries.nodes.database_sync_to_async")
|
||||
@patch("ee.hogai.graph.session_summaries.nodes.get_stream_writer")
|
||||
def test_arun_use_current_filters_true_with_context(
|
||||
self,
|
||||
mock_get_stream_writer: MagicMock,
|
||||
mock_db_sync: MagicMock,
|
||||
mock_query_runner_class: MagicMock,
|
||||
) -> None:
|
||||
"""Test arun uses current filters when should_use_current_filters=True and context is provided."""
|
||||
mock_get_stream_writer.return_value = None
|
||||
conversation = Conversation.objects.create(team=self.team, user=self.user)
|
||||
|
||||
# Mock empty session results for simplicity
|
||||
mock_query_runner_class.return_value = self._create_mock_query_runner([])
|
||||
mock_db_sync.side_effect = self._create_mock_db_sync_to_async()
|
||||
|
||||
state = self._create_test_state(query="test query", should_use_current_filters=True)
|
||||
|
||||
# Provide contextual filters - need to match MaxRecordingUniversalFilters structure
|
||||
config = cast(
|
||||
RunnableConfig,
|
||||
{
|
||||
"configurable": {
|
||||
"thread_id": str(conversation.id),
|
||||
"contextual_tools": {
|
||||
"search_session_recordings": {
|
||||
"current_filters": {
|
||||
"date_from": "-30d",
|
||||
"date_to": "2024-01-31",
|
||||
"filter_test_accounts": True,
|
||||
"duration": [],
|
||||
"filter_group": {"type": "AND", "values": []},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
result = async_to_sync(self.node.arun)(state, config)
|
||||
|
||||
# Should return "No sessions were found" message since we mocked empty results
|
||||
self.assertIsInstance(result, PartialAssistantState)
|
||||
self.assertIsNotNone(result)
|
||||
assert result is not None
|
||||
message = result.messages[0]
|
||||
self.assertIsInstance(message, AssistantToolCallMessage)
|
||||
assert isinstance(message, AssistantToolCallMessage)
|
||||
self.assertEqual(message.content, "No sessions were found.")
|
||||
|
||||
# Verify that the query runner was called (meaning it used the current filters)
|
||||
mock_query_runner_class.assert_called_once()
|
||||
|
||||
@patch("posthog.session_recordings.queries.session_recording_list_from_query.SessionRecordingListFromQuery")
|
||||
@patch("ee.hogai.graph.session_summaries.nodes.database_sync_to_async")
|
||||
@patch("products.replay.backend.max_tools.SessionReplayFilterOptionsGraph")
|
||||
@patch("ee.hogai.graph.session_summaries.nodes.get_stream_writer")
|
||||
def test_arun_use_current_filters_false_generates_filters(
|
||||
self,
|
||||
mock_get_stream_writer: MagicMock,
|
||||
mock_filter_graph_class: MagicMock,
|
||||
mock_db_sync: MagicMock,
|
||||
mock_query_runner_class: MagicMock,
|
||||
) -> None:
|
||||
"""Test arun generates new filters when should_use_current_filters=False."""
|
||||
mock_get_stream_writer.return_value = None
|
||||
conversation = Conversation.objects.create(team=self.team, user=self.user)
|
||||
|
||||
# Setup filter generation mock
|
||||
mock_filters = self._create_mock_filters()
|
||||
mock_graph_instance, _ = self._create_mock_filter_graph(mock_filters)
|
||||
mock_filter_graph_class.return_value = mock_graph_instance
|
||||
|
||||
# Mock empty session results
|
||||
mock_query_runner_class.return_value = self._create_mock_query_runner([])
|
||||
mock_db_sync.side_effect = self._create_mock_db_sync_to_async()
|
||||
|
||||
state = self._create_test_state(query="test query", should_use_current_filters=False)
|
||||
|
||||
result = async_to_sync(self.node.arun)(state, {"configurable": {"thread_id": str(conversation.id)}})
|
||||
|
||||
# Verify filter generation was called
|
||||
mock_filter_graph_class.assert_called_once()
|
||||
mock_graph_instance.compile_full_graph.assert_called_once()
|
||||
|
||||
# Should return "No sessions were found" message
|
||||
self.assertIsInstance(result, PartialAssistantState)
|
||||
self.assertIsNotNone(result)
|
||||
assert result is not None
|
||||
message = result.messages[0]
|
||||
self.assertIsInstance(message, AssistantToolCallMessage)
|
||||
assert isinstance(message, AssistantToolCallMessage)
|
||||
self.assertEqual(message.content, "No sessions were found.")
|
||||
|
||||
@@ -11,6 +11,7 @@ SESSION_SUMMARIES_TEMPERATURE = 0.1 # Reduce hallucinations, but >0 to allow fo
|
||||
BASE_LLM_CALL_TIMEOUT_S = 600.0
|
||||
|
||||
# Summarization
|
||||
MAX_SESSIONS_TO_SUMMARIZE = 200 # Maximum number of sessions to summarize at once
|
||||
HALLUCINATED_EVENTS_MIN_RATIO = 0.15 # If more than 15% of events in the summary hallucinated, fail the summarization
|
||||
# Minimum number of sessions to use group summary logic (find patterns) instead of summarizing them separately
|
||||
GROUP_SUMMARIES_MIN_SESSIONS = 5
|
||||
|
||||
@@ -50,12 +50,44 @@ class search_insights(BaseModel):
|
||||
|
||||
class session_summarization(BaseModel):
|
||||
"""
|
||||
Analyze sessions by finding relevant sessions based on user query and summarizing their events.
|
||||
Use this tool for summarizing sessions, when users ask to summarize (e.g. watch, analyze) specific sessions (e.g. replays, recordings)
|
||||
- Summarize session recordings to find patterns and issues by summarizing sessions' events.
|
||||
- When to use the tool:
|
||||
* When the user asks to summarize session recordings
|
||||
- "summarize" synonyms: "watch", "analyze", "review", and similar
|
||||
- "session recordings" synonyms: "sessions", "recordings", "replays", "user sessions", and similar
|
||||
- When NOT to use the tool:
|
||||
* When the user asks to find, search for, or look up session recordings, but doesn't ask to summarize them
|
||||
* When users asks to update, change, or adjust session recordings filters
|
||||
"""
|
||||
|
||||
session_summarization_query: str = Field(
|
||||
description="The user's complete query for session summarization. This will be used to find relevant sessions. Examples: 'summarize sessions from yesterday', 'watch what user X did on the checkout page', 'analyze mobile user sessions from last week'"
|
||||
description="""
|
||||
- The user's complete query for session recordings summarization.
|
||||
- This will be used to find relevant session recordings.
|
||||
- Always pass the user's complete, unmodified query.
|
||||
- Examples:
|
||||
* 'summarize all session recordings from yesterday'
|
||||
* 'analyze mobile user session recordings from last week, even if 1 second'
|
||||
* 'watch last 300 session recordings of MacOS users from US'
|
||||
* and similar
|
||||
"""
|
||||
)
|
||||
should_use_current_filters: bool = Field(
|
||||
description="""
|
||||
- Whether to use current filters from user's UI to find relevant session recordings.
|
||||
- IMPORTANT: Should be always `false` if the current filters or `search_session_recordings` tool are not present in the conversation history.
|
||||
- Examples:
|
||||
* Set to `true` if one of the conditions is met:
|
||||
- the user wants to summarize "current/selected/opened/my/all/these" session recordings
|
||||
- the user wants to use "current/these" filters
|
||||
- the user's query specifies filters identical to the current filters
|
||||
- if the user's query doesn't specify any filters/conditions
|
||||
- the user refers to what they're "looking at" or "viewing"
|
||||
* Set to `false` if one of the conditions is met:
|
||||
- no current filters or `search_session_recordings` tool are present in the conversation
|
||||
- the user specifies date/time period different from the current filters
|
||||
- the user specifies conditions (user, device, id, URL, etc.) not present in the current filters
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -196,7 +196,11 @@ class _SharedAssistantState(BaseState):
|
||||
"""
|
||||
session_summarization_query: Optional[str] = Field(default=None)
|
||||
"""
|
||||
The user's query for summarizing sessions.
|
||||
The user's query for summarizing sessions. Always pass the user's complete, unmodified query.
|
||||
"""
|
||||
should_use_current_filters: Optional[bool] = Field(default=None)
|
||||
"""
|
||||
Whether to use current filters from user's UI to find relevant sessions.
|
||||
"""
|
||||
notebook_id: Optional[str] = Field(default=None)
|
||||
"""
|
||||
|
||||
@@ -100,9 +100,17 @@ class SearchSessionRecordingsArgs(BaseModel):
|
||||
|
||||
class SearchSessionRecordingsTool(MaxTool):
|
||||
name: str = "search_session_recordings"
|
||||
description: str = (
|
||||
"Update session recordings filters on this page, in order to search for session recordings by any criteria."
|
||||
)
|
||||
description: str = """
|
||||
- Update session recordings filters on this page, in order to search for session recordings.
|
||||
- When to use the tool:
|
||||
* When the user asks to update session recordings filters
|
||||
- "update" synonyms: "change", "modify", "adjust", and similar
|
||||
- "session recordings" synonyms: "sessions", "recordings", "replays", "user sessions", and similar
|
||||
* When the user asks to search for session recordings
|
||||
- "search for" synonyms: "find", "look up", and similar
|
||||
- When NOT to use the tool:
|
||||
* When the user asks to summarize session recordings
|
||||
"""
|
||||
thinking_message: str = "Coming up with session recordings filters"
|
||||
root_system_prompt_template: str = "Current recordings filters are: {current_filters}"
|
||||
args_schema: type[BaseModel] = SearchSessionRecordingsArgs
|
||||
|
||||
Reference in New Issue
Block a user