# posthog/ee/hogai/graph/insights/nodes.py
import re
import time
import inspect
import warnings
from datetime import timedelta
from functools import wraps
from typing import Optional, TypedDict
from uuid import uuid4
from django.db.models import Max
from django.utils import timezone
import structlog
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from posthog.schema import AssistantToolCallMessage, VisualizationMessage
from posthog.exceptions_capture import capture_exception
from posthog.models import Insight
from ee.hogai.context import SUPPORTED_QUERY_MODEL_BY_KIND
from ee.hogai.graph.base import AssistantNode
from ee.hogai.graph.query_executor.query_executor import AssistantQueryExecutor, SupportedQueryTypes
from ee.hogai.graph.shared_prompts import HYPERLINK_USAGE_INSTRUCTIONS
from ee.hogai.llm import MaxChatOpenAI
from ee.hogai.utils.helpers import build_insight_url
from ee.hogai.utils.types import AssistantState, PartialAssistantState
from .prompts import (
ITERATIVE_SEARCH_SYSTEM_PROMPT,
ITERATIVE_SEARCH_USER_PROMPT,
NO_INSIGHTS_FOUND_MESSAGE,
PAGINATION_INSTRUCTIONS_TEMPLATE,
TOOL_BASED_EVALUATION_SYSTEM_PROMPT,
)
logger = structlog.get_logger(__name__)
# Silence Pydantic serializer warnings raised when building VisualizationMessage objects and executing queries
warnings.filterwarnings("ignore", category=UserWarning, message=".*Pydantic serializer.*")
TIMING_LOG_PREFIX = "[INSIGHT_SEARCH]"
def timing_logger(func_name: str | None = None):
"""Decorator to log execution time of functions.

    Added to investigate a production bottleneck.
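
    Example (illustrative):

        @timing_logger("insights.load_page")
        async def load_page(page: int): ...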
"""
def decorator(func):
name = func_name or f"{func.__module__}.{func.__qualname__}"
        # Coroutine functions need an async wrapper so awaits are preserved; plain functions get a sync one
if inspect.iscoroutinefunction(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
start_time = time.time()
logger.warning(f"{TIMING_LOG_PREFIX} Starting {name}")
try:
result = await func(*args, **kwargs)
elapsed = time.time() - start_time
logger.warning(f"{TIMING_LOG_PREFIX} {name} completed in {elapsed:.3f}s")
return result
except Exception:
elapsed = time.time() - start_time
logger.exception(f"{TIMING_LOG_PREFIX} {name} failed after {elapsed:.3f}s")
raise
return async_wrapper
else:
@wraps(func)
def sync_wrapper(*args, **kwargs):
start_time = time.time()
logger.warning(f"{TIMING_LOG_PREFIX} Starting {name}")
try:
result = func(*args, **kwargs)
elapsed = time.time() - start_time
logger.warning(f"{TIMING_LOG_PREFIX} {name} completed in {elapsed:.3f}s")
return result
except Exception:
elapsed = time.time() - start_time
logger.exception(f"{TIMING_LOG_PREFIX} {name} failed after {elapsed:.3f}s")
raise
return sync_wrapper
return decorator
class InsightDict(TypedDict):
"""TypedDict for insight data returned from queryset.values()."""
id: int
name: Optional[str]
description: Optional[str]
query: Optional[dict]
derived_name: Optional[str]
short_id: str
class NoInsightsException(Exception):
"""Exception indicating that the insight search cannot be done because the user does not have any insights."""
class InsightSearchNode(AssistantNode):
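    """Agentic RAG node that finds existing insights matching a user's query.

    Flow: page through recently viewed insights with an LLM tool loop, shortlist
    candidate IDs, evaluate the shortlist with selection tools, and either reuse
    existing insights or fall back to creating a new one.
    """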
PAGE_SIZE = 500
MAX_SEARCH_ITERATIONS = 6
MAX_INSIGHTS_TO_RETURN = 3
MAX_EVALUATION_ITERATIONS = 3
INSIGHTS_CUTOFF_DAYS = 180
MAX_SERIES_TO_PROCESS = 3
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._current_page = 0
self._page_size = self.PAGE_SIZE
self._max_iterations = self.MAX_SEARCH_ITERATIONS
self._current_iteration = 0
        self._loaded_pages: dict[int, list[InsightDict]] = {}
        self._total_insights_count: int | None = None
        self._max_insights_to_select = self.MAX_INSIGHTS_TO_RETURN
        self._max_insights_evaluation_iterations = self.MAX_EVALUATION_ITERATIONS
        self._evaluation_selections: dict[int, dict] = {}
        self._rejection_reason: str | None = None
        self._cutoff_date_for_insights_in_days = self.INSIGHTS_CUTOFF_DAYS
        self._query_cache: dict[int, tuple[SupportedQueryTypes | None, str | None]] = {}
        self._insight_id_cache: dict[int, InsightDict] = {}
@timing_logger("InsightSearchNode.arun")
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
self.dispatcher.update("Searching for insights")
search_query = state.search_insights_query
self._current_iteration = 0
total_count = await self._get_total_insights_count()
if total_count == 0:
raise NoInsightsException
selected_insights = await self._search_insights_iteratively(search_query or "")
logger.warning(
f"{TIMING_LOG_PREFIX} search_insights_iteratively returned {len(selected_insights)} insights: {selected_insights}"
)
if selected_insights:
self.dispatcher.update(f"Evaluating {len(selected_insights)} insights to find the best match")
else:
self.dispatcher.update("No existing insights found, creating a new one")
evaluation_result = await self._evaluate_insights_with_tools(
selected_insights, search_query or "", max_selections=1
)
return self._handle_evaluation_result(evaluation_result, state)
def _create_page_reader_tool(self):
"""Create tool for reading insights pages during agentic RAG loop."""
@tool
async def read_insights_page(page_number: int) -> str:
"""Read a page of insights data.
Args:
page_number: The page number to read (0-based)
Returns:
Formatted insights data for the requested page
"""
page_insights = await self._load_insights_page(page_number)
if not page_insights:
return "No more insights available."
formatted_insights = [self._format_insight_for_display(insight) for insight in page_insights]
return f"Page {page_number + 1} insights:\n" + "\n".join(formatted_insights)
return read_insights_page
def _create_insight_evaluation_tools(self):
"""Create tools for insight evaluation."""
@tool
def select_insight(insight_id: int, explanation: str) -> str:
"""Select an insight as useful for the user's query."""
insight = self._find_insight_by_id(insight_id)
if not insight:
return f"Insight {insight_id} not found"
self._evaluation_selections[insight_id] = {"insight": insight, "explanation": explanation}
name = insight["name"] or insight["derived_name"] or "Unnamed"
insight_url = build_insight_url(self._team, insight["short_id"])
return f"Selected insight {insight_id}: {name} (url: {insight_url})"
@tool
def reject_all_insights(reason: str) -> str:
"""Indicate that none of the insights are suitable."""
self._evaluation_selections = {}
self._rejection_reason = reason
return "All insights rejected. Will create new insight."
return [select_insight, reject_all_insights]
@timing_logger("InsightSearchNode._get_insights_queryset")
def _get_insights_queryset(self):
"""Get Insight objects with latest view time annotated and cutoff date."""
cutoff_date = timezone.now() - timedelta(days=self._cutoff_date_for_insights_in_days)
return (
Insight.objects.filter(team=self._team, deleted=False)
# Annotate with latest view time from InsightViewed
.annotate(latest_view_time=Max("insightviewed__last_viewed_at"))
        # Only include insights viewed within the cutoff window (180 days by default)
.filter(latest_view_time__gte=cutoff_date)
.values("id", "name", "description", "query", "derived_name", "short_id")
.order_by("-latest_view_time")
)
@timing_logger("InsightSearchNode._get_total_insights_count")
async def _get_total_insights_count(self) -> int:
if self._total_insights_count is None:
self._total_insights_count = await self._get_insights_queryset().acount()
return self._total_insights_count
def _handle_evaluation_result(self, evaluation_result: dict, state: AssistantState) -> PartialAssistantState:
"""Process the evaluation result and return appropriate response."""
if evaluation_result["should_use_existing"]:
return self._create_existing_insights_response(evaluation_result, state)
else:
return self._create_new_insight_response(state.search_insights_query, state)
@timing_logger("InsightSearchNode._create_existing_insights_response")
def _create_existing_insights_response(
self, evaluation_result: dict, state: AssistantState
) -> PartialAssistantState:
"""Create response for when existing insights are found."""
formatted_content = f"**Evaluation Result**: {evaluation_result['explanation']}"
formatted_content += HYPERLINK_USAGE_INSTRUCTIONS
return PartialAssistantState(
messages=[
*evaluation_result["visualization_messages"],
AssistantToolCallMessage(
content=formatted_content,
tool_call_id=state.root_tool_call_id or "unknown",
id=str(uuid4()),
),
],
selected_insight_ids=evaluation_result["selected_insights"],
search_insights_query=None,
root_tool_call_id=None,
root_tool_insight_plan=None,
)
def _create_new_insight_response(self, search_query: str | None, state: AssistantState) -> PartialAssistantState:
"""Create response for when no suitable insights are found."""
return PartialAssistantState(
messages=[
AssistantToolCallMessage(
content=NO_INSIGHTS_FOUND_MESSAGE,
tool_call_id=state.root_tool_call_id or "unknown",
id=str(uuid4()),
)
],
root_tool_insight_plan=search_query,
search_insights_query=None,
selected_insight_ids=None,
)
def _format_insight_for_display(self, insight: InsightDict) -> str:
"""Format a single insight for display."""
name = insight["name"] or insight["derived_name"] or "Unnamed"
description = insight["description"] or ""
base = f"ID: {insight['id']} | {name}"
return f"{base} - {description}" if description else base
@timing_logger("InsightSearchNode._load_insights_page")
async def _load_insights_page(self, page_number: int) -> list[InsightDict]:
"""Load a specific page of insights from database."""
logger.warning(f"{TIMING_LOG_PREFIX} _load_insights_page called with page_number={page_number}")
if page_number in self._loaded_pages:
logger.info(
f"{TIMING_LOG_PREFIX} Page {page_number} found in cache with {len(self._loaded_pages[page_number])} insights"
)
return self._loaded_pages[page_number]
start_idx = page_number * self._page_size
end_idx = start_idx + self._page_size
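        # Slicing the queryset compiles to LIMIT/OFFSET in SQL; the query only
        # executes once the async iteration below begins.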
insights_qs = self._get_insights_queryset()[start_idx:end_idx]
logger.warning(
f"{TIMING_LOG_PREFIX} Executing async query for page {page_number} (range: {start_idx}-{end_idx})"
)
db_start = time.time()
page_insights = []
insight_count = 0
try:
logger.warning(f"{TIMING_LOG_PREFIX} Starting async iteration for page {page_number}")
last_progress_time = time.time()
async for i in insights_qs:
insight_count += 1
page_insights.append(i)
current_time = time.time()
                # Log progress every 100 insights or every 10 seconds to diagnose stalls
if insight_count % 100 == 0 or (current_time - last_progress_time) > 10:
elapsed_so_far = current_time - db_start
logger.warning(
f"{TIMING_LOG_PREFIX} Progress: loaded {insight_count} insights for page {page_number} in {elapsed_so_far:.2f}s"
)
last_progress_time = current_time
                # Certain insights may take too long; measure against db_start rather than last_progress_time
if (current_time - db_start) > 5 and insight_count == 1:
logger.warning(
f"{TIMING_LOG_PREFIX} Slow insight processing detected for page {page_number}, insight #{insight_count}"
)
logger.warning(
f"{TIMING_LOG_PREFIX} Async iteration completed for page {page_number}, total insights: {insight_count}"
)
except Exception as e:
elapsed_on_error = time.time() - db_start
logger.error(
f"{TIMING_LOG_PREFIX} Exception during async iteration for page {page_number} after {elapsed_on_error:.2f}s, loaded {insight_count} insights: {e}",
exc_info=True,
)
raise
db_elapsed = time.time() - db_start
logger.warning(
f"{TIMING_LOG_PREFIX} Database query completed in {db_elapsed:.2f}s, loaded {len(page_insights)} insights for page {page_number}"
)
logger.warning(f"{TIMING_LOG_PREFIX} DB QUERY: took {db_elapsed:.2f}s to load page {page_number}")
self._loaded_pages[page_number] = page_insights
for insight in page_insights:
self._insight_id_cache[insight["id"]] = insight
return page_insights
@timing_logger("InsightSearchNode._search_insights_iteratively")
async def _search_insights_iteratively(self, search_query: str) -> list[int]:
"""Execute iterative insight search with LLM and tool calling."""
messages = await self._build_search_messages(search_query)
llm_with_tools = await self._prepare_llm_with_tools()
selected_insights = await self._perform_iterative_search(messages, llm_with_tools)
if not selected_insights:
return []
return selected_insights[: self._max_insights_to_select]
async def _build_search_messages(self, search_query: str) -> list[BaseMessage]:
"""Build the initial messages for the search."""
first_page = await self._format_insights_page(0)
pagination_instructions = await self._get_pagination_instructions()
system_prompt = ITERATIVE_SEARCH_SYSTEM_PROMPT.format(
first_page_insights=first_page, pagination_instructions=pagination_instructions
)
user_prompt = ITERATIVE_SEARCH_USER_PROMPT.format(query=search_query)
return [SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)]
async def _get_pagination_instructions(self) -> str:
"""Get pagination instructions based on available insights."""
total_insights = await self._get_total_insights_count()
total_pages = self._calculate_total_pages(total_insights)
if total_pages > 1:
return PAGINATION_INSTRUCTIONS_TEMPLATE.format(total_pages=total_pages)
return "This is the only page of insights available."
async def _prepare_llm_with_tools(self):
"""Prepare LLM with pagination tools if needed."""
total_insights = await self._get_total_insights_count()
total_pages = self._calculate_total_pages(total_insights)
if total_pages > 1:
read_tool = self._create_page_reader_tool()
return self._model.bind_tools([read_tool])
return self._model
@timing_logger("InsightSearchNode._perform_iterative_search")
async def _perform_iterative_search(self, messages: list[BaseMessage], llm_with_tools) -> list[int]:
"""Perform the iterative search with the LLM."""
selected_insights = []
for step in ["Searching through existing insights", "Analyzing available insights"]:
self.dispatcher.update(step)
logger.warning(f"{TIMING_LOG_PREFIX} Starting iterative search, max_iterations={self._max_iterations}")
while self._current_iteration < self._max_iterations:
self._current_iteration += 1
logger.warning(f"{TIMING_LOG_PREFIX} Iteration {self._current_iteration}/{self._max_iterations} starting")
try:
response = await llm_with_tools.ainvoke(messages)
if hasattr(response, "tool_calls") and response.tool_calls:
                    # The tool-calling protocol requires the assistant message with tool_calls before the matching ToolMessages
messages.append(response)
for tool_call in response.tool_calls:
if tool_call.get("name") == "read_insights_page":
page_num = tool_call.get("args", {}).get("page_number", 0)
logger.warning(f"{TIMING_LOG_PREFIX} Reading insights page {page_num}")
                            logger.warning(
                                f"{TIMING_LOG_PREFIX} Possible stall point: streaming status and fetching page content"
                            )
self.dispatcher.update("Finding the most relevant insights")
logger.warning(f"{TIMING_LOG_PREFIX} Fetching page content for page {page_num}")
tool_response = await self._get_page_content_for_tool(page_num)
logger.warning(
f"{TIMING_LOG_PREFIX} Page content fetched successfully, length={len(tool_response)}"
)
messages.append(
ToolMessage(content=tool_response, tool_call_id=tool_call.get("id", "unknown"))
)
logger.warning(f"{TIMING_LOG_PREFIX} Continuing to next iteration after tool calls")
continue
                # No tool calls: extract insight IDs from the response and finish the search
content = response.content if isinstance(response.content, str) else str(response.content)
selected_insights = self._parse_insight_ids(content)
if selected_insights:
self.dispatcher.update(f"Found {len(selected_insights)} relevant insights")
else:
self.dispatcher.update("No matching insights found")
break
except Exception as e:
capture_exception(e)
                error_message = "Error during search"
self.dispatcher.update(error_message)
break
return selected_insights
async def _get_page_content_for_tool(self, page_number: int) -> str:
"""Get page content for tool response."""
if page_number == 0:
return "Page 0 data is already provided in the initial context above."
else:
page_content = await self._format_insights_page(page_number)
return f"Page {page_number + 1} results:\n{page_content}"
def _calculate_total_pages(self, total_insights: int) -> int:
"""Calculate total number of pages for insights."""
return (total_insights + self._page_size - 1) // self._page_size
async def _format_insights_page(self, page_number: int) -> str:
"""Format a page of insights for display."""
page_insights = await self._load_insights_page(page_number)
if not page_insights:
return "No insights available on this page."
formatted_insights = [self._format_insight_for_display(insight) for insight in page_insights]
return "\n".join(formatted_insights)
@timing_logger("InsightSearchNode._get_all_loaded_insight_ids")
def _get_all_loaded_insight_ids(self) -> set[int]:
"""Get all insight IDs from loaded pages."""
all_ids = set()
for page_insights in self._loaded_pages.values():
for insight in page_insights:
all_ids.add(insight["id"])
return all_ids
@timing_logger("InsightSearchNode._find_insight_by_id")
def _find_insight_by_id(self, insight_id: int) -> InsightDict | None:
"""Find an insight by ID across all loaded pages (with cache)."""
return self._insight_id_cache.get(insight_id)
@timing_logger("InsightSearchNode._process_insight_query")
async def _process_insight_query(self, insight: InsightDict) -> tuple[SupportedQueryTypes | None, str | None]:
"""
        Process an insight's query, caching the query object and formatted results for later reference.
"""
insight_id = insight["id"]
cached_result = self._get_cached_query(insight_id)
if cached_result is not None:
return cached_result
if not insight["query"]:
return self._cache_and_return(insight_id, None, None)
query_obj, formatted_results = await self._extract_and_execute_query(insight)
return self._cache_and_return(insight_id, query_obj, formatted_results)
def _get_cached_query(self, insight_id: int) -> tuple[SupportedQueryTypes | None, str | None] | None:
"""Get cached query result if available."""
if insight_id in self._query_cache:
return self._query_cache[insight_id]
return None
def _cache_and_return(
self,
insight_id: int,
query_obj: SupportedQueryTypes | None,
formatted_results: str | None,
) -> tuple[SupportedQueryTypes | None, str | None]:
"""Cache and return query result."""
result = (query_obj, formatted_results)
self._query_cache[insight_id] = result
return result
@timing_logger("InsightSearchNode._extract_and_execute_query")
async def _extract_and_execute_query(self, insight: InsightDict) -> tuple[SupportedQueryTypes | None, str | None]:
"""Extract query object and execute it."""
try:
query_dict = insight["query"]
if query_dict is None:
return None, "Query is missing"
query_source = query_dict.get("source", {})
insight_type = query_source.get("kind", "Unknown")
query_obj = self._validate_and_create_query_object(insight_type, query_source)
if query_obj is None:
return None, "Query type not supported for execution"
formatted_results = await self._execute_and_format_query(query_obj, insight["id"])
return query_obj, formatted_results
except Exception as e:
capture_exception(e)
return None, "Query processing failed"
@timing_logger("InsightSearchNode._validate_and_create_query_object")
def _validate_and_create_query_object(self, insight_type: str, query_source: dict) -> SupportedQueryTypes | None:
"""Validate query type and create query object."""
if insight_type not in SUPPORTED_QUERY_MODEL_BY_KIND:
return None
AssistantQueryModel = SUPPORTED_QUERY_MODEL_BY_KIND[insight_type]
return AssistantQueryModel.model_validate(query_source, strict=False)
@timing_logger("InsightSearchNode._execute_and_format_query")
async def _execute_and_format_query(self, query_obj: SupportedQueryTypes, insight_id: int) -> str:
"""Execute query and format results with timing instrumentation."""
try:
query_executor = AssistantQueryExecutor(team=self._team, utc_now_datetime=self._utc_now_datetime)
results, _ = await query_executor.arun_and_format_query(query_obj, debug_timing=True)
return results
except Exception as e:
capture_exception(e)
return "Query execution failed"
@timing_logger("InsightSearchNode._parse_insight_ids")
def _parse_insight_ids(self, response_content: str) -> list[int]:
"""Parse insight IDs from LLM response, removing duplicates and preserving order."""
numbers = re.findall(r"\b\d+\b", response_content)
# Convert to integers and validate against available insights
available_ids = self._get_all_loaded_insight_ids()
valid_ids = []
seen_ids = set()
        for num_str in numbers:
            insight_id = int(num_str)  # \b\d+\b guarantees a parseable integer
            if insight_id in available_ids and insight_id not in seen_ids:
                valid_ids.append(insight_id)
                seen_ids.add(insight_id)
                # Stop once we've found enough unique insights
                if len(valid_ids) >= self._max_insights_to_select:
                    break
return valid_ids
@timing_logger("InsightSearchNode._create_enhanced_insight_summary")
async def _create_enhanced_insight_summary(self, insight: InsightDict) -> str:
"""Create enhanced summary with metadata and basic execution info."""
insight_id = insight["id"]
name = insight["name"] or insight["derived_name"] or "Unnamed"
description = insight["description"] or ""
insight_type = "Unknown"
query_info = None
_, query_result = await self._process_insight_query(insight)
if insight["query"]:
try:
query_dict = insight["query"]
query_source = query_dict.get("source", {})
insight_type = query_source.get("kind", "Unknown")
query_info = self._extract_query_metadata(query_source)
except Exception as e:
capture_exception(e)
insight_url = build_insight_url(self._team, insight["short_id"])
hyperlink_format = f"[{name}]({insight_url})"
summary_parts = [
f"ID: {insight_id} | {name} | {hyperlink_format}",
f"Type: {insight_type}",
f"Query result: {query_result}",
]
if description:
summary_parts.append(f"Description: {description}")
if query_info:
summary_parts.append(f"Query: {query_info}")
return " | ".join(summary_parts)
@timing_logger("InsightSearchNode._extract_query_metadata")
def _extract_query_metadata(self, query_source: dict) -> str | None:
"""Extract basic query information from Insight object without execution."""
try:
if not query_source:
return None
# Extract basic info from query
info_parts = []
            # Get events/series info - only process the first MAX_SERIES_TO_PROCESS series for efficiency
series = query_source.get("series", [])
if series:
events = []
for series_item in series[: self.MAX_SERIES_TO_PROCESS]:
if isinstance(series_item, dict):
event_name = series_item.get("event", series_item.get("name", "Unknown"))
if event_name:
events.append(str(event_name))
if events:
info_parts.append(f"Events: {', '.join(events)}")
# Get date range info
date_range = query_source.get("dateRange", {})
if date_range:
date_from = date_range.get("date_from", "")
if date_from:
info_parts.append(f"Period: {date_from}")
return " | ".join(info_parts) if info_parts else None
except Exception as e:
capture_exception(e)
return None
@timing_logger("InsightSearchNode._create_visualization_message_for_insight")
async def _create_visualization_message_for_insight(self, insight: InsightDict) -> VisualizationMessage | None:
"""Create a VisualizationMessage to render the insight UI."""
try:
for step in ["Executing insight query...", "Processing query parameters", "Running data analysis"]:
self.dispatcher.update(step)
query_obj, _ = await self._process_insight_query(insight)
if not query_obj:
return None
insight_name = insight["name"] or insight["derived_name"] or "Unnamed Insight"
visualization_message = VisualizationMessage(
query=f"Existing insight: {insight_name}",
plan=f"Showing existing insight: {insight_name}",
answer=query_obj,
id=str(uuid4()),
short_id=insight["short_id"],
)
return visualization_message
except Exception as e:
capture_exception(e)
return None
def _create_error_response(self, content: str, tool_call_id: str | None) -> PartialAssistantState:
"""Create error response for the assistant."""
return PartialAssistantState(
messages=[
AssistantToolCallMessage(
content=content,
tool_call_id=tool_call_id or "unknown",
id=str(uuid4()),
),
],
search_insights_query=None,
root_tool_call_id=None,
)
@timing_logger("InsightSearchNode._evaluate_insights_with_tools")
async def _evaluate_insights_with_tools(
self, selected_insights: list[int], user_query: str, max_selections: int = 1
) -> dict:
"""Evaluate insights using tool calls for fine-grained selection.
Args:
selected_insights: List of insight IDs to evaluate
user_query: The user's search query
max_selections: Maximum number of insights to select (default: 1, best possible match)
"""
self._reset_evaluation_state()
insights_summary, final_selected_insights = await self._prepare_insights_for_evaluation(selected_insights)
if not final_selected_insights:
return self._no_insights_found_result()
await self._run_evaluation_loop(user_query, insights_summary, max_selections)
if self._evaluation_selections:
return await self._create_successful_evaluation_result()
else:
return await self._create_rejection_result()
def _reset_evaluation_state(self) -> None:
"""Reset evaluation state for new evaluation."""
self._evaluation_selections = {}
self._rejection_reason = None
async def _prepare_insights_for_evaluation(self, selected_insights: list[int]) -> tuple[list[str], list[int]]:
"""Prepare insights for evaluation."""
insights_summary = []
final_selected_insights = []
for insight_id in selected_insights:
insight = self._find_insight_by_id(insight_id)
if insight:
enhanced_summary = await self._create_enhanced_insight_summary(insight)
insights_summary.append(enhanced_summary)
final_selected_insights.append(insight_id)
return insights_summary, final_selected_insights
def _no_insights_found_result(self) -> dict:
"""Return result when no insights are found."""
return {
"should_use_existing": False,
"selected_insights": [],
"explanation": "No insights found matching the user's query.",
"visualization_messages": [],
}
@timing_logger("InsightSearchNode._run_evaluation_loop")
async def _run_evaluation_loop(self, user_query: str, insights_summary: list[str], max_selections: int) -> None:
"""Run the evaluation loop with LLM."""
for step in ["Analyzing insights to match your request", "Comparing insights for best fit"]:
self.dispatcher.update(step)
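        # bind_tools exposes select_insight / reject_all_insights so the model can decide via tool calls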
tools = self._create_insight_evaluation_tools()
llm_with_tools = self._model.bind_tools(tools)
selection_instruction = self._build_selection_instruction(max_selections)
messages = self._build_evaluation_messages(user_query, insights_summary, selection_instruction)
for iteration in range(self._max_insights_evaluation_iterations):
response = await llm_with_tools.ainvoke(messages)
if getattr(response, "tool_calls", None):
# Only stream on first iteration to avoid noise
if iteration == 0:
self.dispatcher.update("Making evaluation decisions")
self._process_evaluation_tool_calls(response, messages, tools)
else:
break
def _build_selection_instruction(self, max_selections: int) -> str:
"""Build instruction for insight selection."""
insight_word = "insight" if max_selections == 1 else "insights"
verb = "matches" if max_selections == 1 else "match"
return f"Select ONLY the {max_selections} BEST {insight_word} that {verb} the user's query."
def _build_evaluation_messages(
self, user_query: str, insights_summary: list[str], selection_instruction: str
) -> list[BaseMessage]:
"""Build messages for evaluation."""
system_prompt = TOOL_BASED_EVALUATION_SYSTEM_PROMPT.format(
user_query=user_query,
            insights_summary="\n".join(insights_summary),
selection_instruction=selection_instruction,
)
return [SystemMessage(content=system_prompt)]
def _process_evaluation_tool_calls(self, response, messages: list[BaseMessage], tools: list) -> None:
"""Process tool calls during evaluation."""
messages.append(response)
for tool_call in response.tool_calls:
if tool_call["name"] in ["select_insight", "reject_all_insights"]:
tool_fn = next(t for t in tools if t.name == tool_call["name"])
result = tool_fn.invoke(tool_call["args"])
messages.append(ToolMessage(content=result, tool_call_id=tool_call["id"]))
@timing_logger("InsightSearchNode._create_successful_evaluation_result")
async def _create_successful_evaluation_result(self) -> dict:
"""Create result for successful evaluation."""
visualization_messages = []
explanations = []
        num_insights = len(self._evaluation_selections)
        # If no insights were actually selected, this shouldn't be a successful result
        if num_insights == 0:
            return await self._create_rejection_result()
        insight_word = "insight" if num_insights == 1 else "insights"
        self.dispatcher.update(f"Perfect! Found {num_insights} suitable {insight_word}")
for _, selection in self._evaluation_selections.items():
insight = selection["insight"]
visualization_message = await self._create_visualization_message_for_insight(insight)
if visualization_message:
visualization_messages.append(visualization_message)
insight_name = insight["name"] or insight["derived_name"] or "Unnamed"
insight_url = build_insight_url(self._team, insight["short_id"])
insight_hyperlink = f"[{insight_name}]({insight_url})"
explanations.append(f"- {insight_hyperlink}: {selection['explanation']}")
return {
"should_use_existing": True,
"selected_insights": list(self._evaluation_selections.keys()),
"explanation": f"Found {num_insights} relevant {insight_word}:\n" + "\n".join(explanations),
"visualization_messages": visualization_messages,
}
async def _create_rejection_result(self) -> dict:
"""Create result for when all insights are rejected."""
self.dispatcher.update("Will create a custom insight tailored to your request")
return {
"should_use_existing": False,
"selected_insights": [],
"explanation": self._rejection_reason or "No suitable insights found.",
"visualization_messages": [],
}
@property
def _model(self):
return MaxChatOpenAI(
model="gpt-4.1-mini",
temperature=0.7,
max_completion_tokens=1000,
streaming=False,
stream_usage=False,
max_retries=3,
disable_streaming=True,
user=self._user,
team=self._team,
billable=True,
inject_context=False,
)