diff --git a/ee/hogai/graph/insights/nodes.py b/ee/hogai/graph/insights/nodes.py
index 535dfcb617..9e280be3f6 100644
--- a/ee/hogai/graph/insights/nodes.py
+++ b/ee/hogai/graph/insights/nodes.py
@@ -1,5 +1,4 @@
 import re
-import logging
 import warnings
 from datetime import timedelta
 from typing import Literal
@@ -8,6 +7,7 @@ from uuid import uuid4
 
 from django.db.models import Max
 from django.utils import timezone
+import structlog
 from langchain_core.messages import AIMessageChunk, BaseMessage, HumanMessage, SystemMessage, ToolMessage
 from langchain_core.runnables import RunnableConfig
 from langchain_core.tools import tool
@@ -36,7 +36,7 @@ from .prompts import (
     TOOL_BASED_EVALUATION_SYSTEM_PROMPT,
 )
 
-logger = logging.getLogger(__name__)
+logger = structlog.get_logger(__name__)
 
 # Silence Pydantic serializer warnings for creation of VisualizationMessage/Query execution
 warnings.filterwarnings("ignore", category=UserWarning, message=".*Pydantic serializer.*")
@@ -64,12 +64,15 @@ class InsightSearchNode(AssistantNode):
         self._cutoff_date_for_insights_in_days = self.INSIGHTS_CUTOFF_DAYS
         self._query_cache = {}
         self._insight_id_cache = {}
+        self._stream_writer = None
 
     def _get_stream_writer(self) -> StreamWriter | None:
-        try:
-            return get_stream_writer()
-        except Exception:
-            return None
+        if self._stream_writer is None:
+            try:
+                self._stream_writer = get_stream_writer()
+            except Exception:
+                self._stream_writer = None
+        return self._stream_writer
 
     def _stream_reasoning(
         self, content: str, substeps: list[str] | None = None, writer: StreamWriter | None = None
@@ -92,7 +95,7 @@ class InsightSearchNode(AssistantNode):
                 writer(("insights_search_node", "messages", message))
 
         except Exception as e:
-            logger.exception("Failed to stream reasoning message", extra={"error": str(e), "content": content})
+            logger.exception("Failed to stream reasoning message", error=str(e), content=content)
 
     def _create_page_reader_tool(self):
        """Create tool for reading insights pages during agentic RAG loop."""
@@ -144,7 +147,6 @@ class InsightSearchNode(AssistantNode):
 
     async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
         search_query = state.search_insights_query
-
         try:
             self._current_iteration = 0
 
@@ -153,6 +155,9 @@ class InsightSearchNode(AssistantNode):
                 return self._handle_empty_database(state)
 
             selected_insights = await self._search_insights_iteratively(search_query or "")
+            logger.warning(
+                f"search_insights_iteratively returned {len(selected_insights)} insights: {selected_insights}"
+            )
 
             writer = self._get_stream_writer()
             if selected_insights:
@@ -261,14 +266,31 @@ class InsightSearchNode(AssistantNode):
 
     async def _load_insights_page(self, page_number: int) -> list[Insight]:
         """Load a specific page of insights from database."""
+        logger.warning(f"_load_insights_page called with page_number={page_number}")
+
         if page_number in self._loaded_pages:
+            logger.info(f"Page {page_number} found in cache with {len(self._loaded_pages[page_number])} insights")
             return self._loaded_pages[page_number]
 
         start_idx = page_number * self._page_size
         end_idx = start_idx + self._page_size
 
         insights_qs = self._get_insights_queryset()[start_idx:end_idx]
-        page_insights = [i async for i in insights_qs]
+        logger.warning(f"Executing async query for page {page_number}")
+
+        import time
+
+        db_start = time.time()
+        page_insights = []
+        async for i in insights_qs:
+            page_insights.append(i)
+            logger.debug(f"Loaded insight {i.id}: {i.name or i.derived_name}")
+        db_elapsed = time.time() - db_start
+
+        logger.warning(
+            f"Database query completed in {db_elapsed:.2f}s, loaded {len(page_insights)} insights for page {page_number}"
+        )
+        logger.warning(f"DB QUERY: took {db_elapsed:.2f}s to load page {page_number}")
 
         self._loaded_pages[page_number] = page_insights
 
@@ -330,9 +352,11 @@ class InsightSearchNode(AssistantNode):
             substeps=[f"Analyzing {await self._get_total_insights_count()} available insights"],
             writer=writer,
         )
+        logger.warning(f"Starting iterative search, max_iterations={self._max_iterations}")
 
         while self._current_iteration < self._max_iterations:
             self._current_iteration += 1
+            logger.warning(f"Iteration {self._current_iteration}/{self._max_iterations} starting")
 
             try:
                 response = await llm_with_tools.ainvoke(messages)
@@ -344,13 +368,23 @@ class InsightSearchNode(AssistantNode):
                     for tool_call in response.tool_calls:
                         if tool_call.get("name") == "read_insights_page":
                             page_num = tool_call.get("args", {}).get("page_number", 0)
+                            logger.warning(f"Reading insights page {page_num}")
+
                             page_message = "Finding the most relevant insights"
+                            logger.warning(
+                                f"STALL POINT(?): Streamed 'Finding the most relevant insights' - about to fetch page content"
+                            )
                             self._stream_reasoning(content=page_message, writer=writer)
 
+                            logger.warning(f"Fetching page content for page {page_num}")
                             tool_response = await self._get_page_content_for_tool(page_num)
+                            logger.warning(f"Page content fetched successfully, length={len(tool_response)}")
+
                             messages.append(
                                 ToolMessage(content=tool_response, tool_call_id=tool_call.get("id", "unknown"))
                             )
+
+                    logger.warning("Continuing to next iteration after tool calls")
                     continue
 
                 # No tool calls, extract insight IDs from the response. Done with the search
diff --git a/ee/hogai/graph/insights/test/test_nodes.py b/ee/hogai/graph/insights/test/test_nodes.py
index 9303101879..ea1839986a 100644
--- a/ee/hogai/graph/insights/test/test_nodes.py
+++ b/ee/hogai/graph/insights/test/test_nodes.py
@@ -1,7 +1,7 @@
 import asyncio
 from datetime import timedelta
 
-from posthog.test.base import NonAtomicBaseTest
+from posthog.test.base import BaseTest
 from unittest.mock import AsyncMock, MagicMock, patch
 
 from django.utils import timezone
@@ -28,12 +28,29 @@ from ee.hogai.utils.types import AssistantState, PartialAssistantState
 from ee.models.assistant import Conversation
 
 
-# TRICKY: Preserve the non-atomic setup for the test. Otherwise, threads would stall because
-# `query_executor.arun_and_format_query` spawns threads with a new connection.
-class TestInsightSearchNode(NonAtomicBaseTest):
-    # TRICKY: See above.
-    CLASS_DATA_LEVEL_SETUP = False
+def create_mock_query_executor(*args, **kwargs):
+    """Mock the query executor instead of querying ClickHouse, so the test no longer needs NonAtomicBaseTest."""
+    mock_executor = MagicMock()
 
+    async def mock_arun_and_format_query(query_obj):
+        """Return mocked query results based on query type."""
+        if isinstance(query_obj, TrendsQuery):
+            return "Mocked trends query results: Daily pageviews = 1000", {}
+        elif isinstance(query_obj, FunnelsQuery):
+            return "Mocked funnel query results: Conversion rate = 25%", {}
+        elif isinstance(query_obj, RetentionQuery):
+            return "Mocked retention query results: Day 1 retention = 40%", {}
+        elif isinstance(query_obj, HogQLQuery):
+            return "Mocked HogQL query results: Result count = 42", {}
+        else:
+            return "Mocked query results", {}
+
+    mock_executor.arun_and_format_query = mock_arun_and_format_query
+    return mock_executor
+
+
+@patch("ee.hogai.graph.insights.nodes.AssistantQueryExecutor", create_mock_query_executor)
+class TestInsightSearchNode(BaseTest):
     def setUp(self):
         super().setUp()
         self.node = InsightSearchNode(self.team, self.user)