chore: lazy loading stream writer (#37175)

Alessandro Pogliaghi
2025-08-26 16:49:19 +01:00
committed by GitHub
parent a5394c47f2
commit 28b472717b
2 changed files with 66 additions and 15 deletions
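In short, the first file below stops resolving the stream writer eagerly: the node now caches the result of get_stream_writer() on the instance after the first successful lookup, and the try/except in the diff shows that the lookup can raise when no run is active. The following standalone sketch illustrates only that caching pattern; LazyWriterNode and the local get_stream_writer stub are stand-ins for this illustration, not code from the commit.

from typing import Callable, Optional

StreamWriter = Callable[[tuple], None]  # stand-in for the StreamWriter type used in the diff


def get_stream_writer() -> StreamWriter:
    # Stand-in for the real lookup, which can raise when no graph run is active.
    raise RuntimeError("no active graph run")


class LazyWriterNode:
    def __init__(self) -> None:
        self._stream_writer: Optional[StreamWriter] = None  # nothing is resolved at construction time

    def _get_stream_writer(self) -> Optional[StreamWriter]:
        # Resolve the writer on first use and cache a successful lookup; a failed lookup
        # leaves the cache at None, so the next call simply tries again.
        if self._stream_writer is None:
            try:
                self._stream_writer = get_stream_writer()
            except Exception:
                self._stream_writer = None
        return self._stream_writer


node = LazyWriterNode()
assert node._get_stream_writer() is None  # lookup failed here, but the node keeps working

The trade-off mirrored in the diff: the writer is fetched at most once per successful lookup, and streaming silently degrades to a no-op when no writer is available.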

View File

@@ -1,5 +1,4 @@
 import re
-import logging
 import warnings
 from datetime import timedelta
 from typing import Literal
@@ -8,6 +7,7 @@ from uuid import uuid4
 from django.db.models import Max
 from django.utils import timezone
+import structlog
 from langchain_core.messages import AIMessageChunk, BaseMessage, HumanMessage, SystemMessage, ToolMessage
 from langchain_core.runnables import RunnableConfig
 from langchain_core.tools import tool
@@ -36,7 +36,7 @@ from .prompts import (
     TOOL_BASED_EVALUATION_SYSTEM_PROMPT,
 )
-logger = logging.getLogger(__name__)
+logger = structlog.get_logger(__name__)
 # Silence Pydantic serializer warnings for creation of VisualizationMessage/Query execution
 warnings.filterwarnings("ignore", category=UserWarning, message=".*Pydantic serializer.*")
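The logger swap above also changes the calling convention at call sites further down in the diff: stdlib logging carries context in an extra= mapping, while structlog accepts keyword arguments as first-class event fields. A minimal sketch of the difference (the message text is taken from the diff; everything else is illustrative):

import logging

import structlog

stdlib_logger = logging.getLogger(__name__)
struct_logger = structlog.get_logger(__name__)

try:
    raise ValueError("boom")
except ValueError as e:
    # stdlib: context rides along in `extra` and only appears if the formatter knows about it
    stdlib_logger.exception("Failed to stream reasoning message", extra={"error": str(e), "content": "hi"})
    # structlog: keyword arguments become structured fields on the emitted event
    struct_logger.exception("Failed to stream reasoning message", error=str(e), content="hi")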
@@ -64,12 +64,15 @@ class InsightSearchNode(AssistantNode):
         self._cutoff_date_for_insights_in_days = self.INSIGHTS_CUTOFF_DAYS
         self._query_cache = {}
         self._insight_id_cache = {}
+        self._stream_writer = None

     def _get_stream_writer(self) -> StreamWriter | None:
-        try:
-            return get_stream_writer()
-        except Exception:
-            return None
+        if self._stream_writer is None:
+            try:
+                self._stream_writer = get_stream_writer()
+            except Exception:
+                self._stream_writer = None
+        return self._stream_writer

     def _stream_reasoning(
         self, content: str, substeps: list[str] | None = None, writer: StreamWriter | None = None
@@ -92,7 +95,7 @@ class InsightSearchNode(AssistantNode):
             writer(("insights_search_node", "messages", message))
         except Exception as e:
-            logger.exception("Failed to stream reasoning message", extra={"error": str(e), "content": content})
+            logger.exception("Failed to stream reasoning message", error=str(e), content=content)

     def _create_page_reader_tool(self):
         """Create tool for reading insights pages during agentic RAG loop."""
@@ -144,7 +147,6 @@ class InsightSearchNode(AssistantNode):
     async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
         search_query = state.search_insights_query
         try:
             self._current_iteration = 0
@@ -153,6 +155,9 @@ class InsightSearchNode(AssistantNode):
                 return self._handle_empty_database(state)
             selected_insights = await self._search_insights_iteratively(search_query or "")
+            logger.warning(
+                f"search_insights_iteratively returned {len(selected_insights)} insights: {selected_insights}"
+            )
             writer = self._get_stream_writer()
             if selected_insights:
@@ -261,14 +266,31 @@ class InsightSearchNode(AssistantNode):
     async def _load_insights_page(self, page_number: int) -> list[Insight]:
         """Load a specific page of insights from database."""
+        logger.warning(f"_load_insights_page called with page_number={page_number}")
         if page_number in self._loaded_pages:
+            logger.info(f"Page {page_number} found in cache with {len(self._loaded_pages[page_number])} insights")
             return self._loaded_pages[page_number]
         start_idx = page_number * self._page_size
         end_idx = start_idx + self._page_size
         insights_qs = self._get_insights_queryset()[start_idx:end_idx]
-        page_insights = [i async for i in insights_qs]
+        logger.warning(f"Executing async query for page {page_number}")
+        import time
+        db_start = time.time()
+        page_insights = []
+        async for i in insights_qs:
+            page_insights.append(i)
+            logger.debug(f"Loaded insight {i.id}: {i.name or i.derived_name}")
+        db_elapsed = time.time() - db_start
+        logger.warning(
+            f"Database query completed in {db_elapsed:.2f}s, loaded {len(page_insights)} insights for page {page_number}"
+        )
+        logger.warning(f"DB QUERY: took {db_elapsed:.2f}s to load page {page_number}")
         self._loaded_pages[page_number] = page_insights
@@ -330,9 +352,11 @@ class InsightSearchNode(AssistantNode):
             substeps=[f"Analyzing {await self._get_total_insights_count()} available insights"],
             writer=writer,
         )
+        logger.warning(f"Starting iterative search, max_iterations={self._max_iterations}")
         while self._current_iteration < self._max_iterations:
             self._current_iteration += 1
+            logger.warning(f"Iteration {self._current_iteration}/{self._max_iterations} starting")
             try:
                 response = await llm_with_tools.ainvoke(messages)
@@ -344,13 +368,23 @@ class InsightSearchNode(AssistantNode):
                     for tool_call in response.tool_calls:
                         if tool_call.get("name") == "read_insights_page":
                             page_num = tool_call.get("args", {}).get("page_number", 0)
+                            logger.warning(f"Reading insights page {page_num}")
                             page_message = "Finding the most relevant insights"
+                            logger.warning(
+                                f"STALL POINT(?): Streamed 'Finding the most relevant insights' - about to fetch page content"
+                            )
                             self._stream_reasoning(content=page_message, writer=writer)
+                            logger.warning(f"Fetching page content for page {page_num}")
                             tool_response = await self._get_page_content_for_tool(page_num)
+                            logger.warning(f"Page content fetched successfully, length={len(tool_response)}")
                             messages.append(
                                 ToolMessage(content=tool_response, tool_call_id=tool_call.get("id", "unknown"))
                             )
+                    logger.warning("Continuing to next iteration after tool calls")
                     continue
                 # No tool calls, extract insight IDs from the response. Done with the search
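The hunk above is the agentic read loop: while the model keeps asking for read_insights_page, the node streams a progress message, fetches the requested page, and answers each tool call with a ToolMessage carrying the matching tool_call_id before invoking the model again. A reduced sketch of that round trip, with fetch_page standing in for _get_page_content_for_tool and a hand-built AIMessage in place of a real model response:

from langchain_core.messages import AIMessage, ToolMessage


def fetch_page(page_number: int) -> str:
    # Hypothetical stand-in for the node's page fetcher.
    return f"insights on page {page_number}"


# A fabricated model turn that requests one page read.
response = AIMessage(
    content="",
    tool_calls=[{"name": "read_insights_page", "args": {"page_number": 0}, "id": "call_1"}],
)

messages: list = [response]
for tool_call in response.tool_calls:
    if tool_call.get("name") == "read_insights_page":
        page_num = tool_call.get("args", {}).get("page_number", 0)
        # Echo the tool_call_id so the model can pair the result with its request on the next turn.
        messages.append(ToolMessage(content=fetch_page(page_num), tool_call_id=tool_call.get("id", "unknown")))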

View File

@@ -1,7 +1,7 @@
 import asyncio
 from datetime import timedelta
-from posthog.test.base import NonAtomicBaseTest
+from posthog.test.base import BaseTest
 from unittest.mock import AsyncMock, MagicMock, patch
 from django.utils import timezone
@@ -28,12 +28,29 @@ from ee.hogai.utils.types import AssistantState, PartialAssistantState
 from ee.models.assistant import Conversation

-# TRICKY: Preserve the non-atomic setup for the test. Otherwise, threads would stall because
-# `query_executor.arun_and_format_query` spawns threads with a new connection.
-class TestInsightSearchNode(NonAtomicBaseTest):
-    # TRICKY: See above.
-    CLASS_DATA_LEVEL_SETUP = False
+def create_mock_query_executor():
+    """Mock the query executor instead of querying ClickHouse, now that the test no longer uses NonAtomicBaseTest."""
+    mock_executor = MagicMock()
+
+    async def mock_arun_and_format_query(query_obj):
+        """Return mocked query results based on query type."""
+        if isinstance(query_obj, TrendsQuery):
+            return "Mocked trends query results: Daily pageviews = 1000", {}
+        elif isinstance(query_obj, FunnelsQuery):
+            return "Mocked funnel query results: Conversion rate = 25%", {}
+        elif isinstance(query_obj, RetentionQuery):
+            return "Mocked retention query results: Day 1 retention = 40%", {}
+        elif isinstance(query_obj, HogQLQuery):
+            return "Mocked HogQL query results: Result count = 42", {}
+        else:
+            return "Mocked query results", {}
+
+    mock_executor.arun_and_format_query = mock_arun_and_format_query
+    return mock_executor
+
+
+@patch("ee.hogai.graph.insights.nodes.AssistantQueryExecutor", create_mock_query_executor)
+class TestInsightSearchNode(BaseTest):
     def setUp(self):
         super().setUp()
         self.node = InsightSearchNode(self.team, self.user)
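For the test change above, the mechanism is worth spelling out: the class-level patch swaps the AssistantQueryExecutor name inside the node module for the factory, so an instantiation inside the node during a test is meant to come back as a MagicMock whose arun_and_format_query coroutine returns canned strings, keeping ClickHouse (and the thread-spawning query path the old TRICKY comment worried about) out of the test. A self-contained sketch of that general pattern, with hypothetical names (RealExecutor, make_mock_executor, run_node):

import asyncio
from unittest.mock import MagicMock, patch


class RealExecutor:
    # Hypothetical stand-in for a production executor that would hit a real query engine.
    async def arun_and_format_query(self, query_obj):
        raise RuntimeError("would execute a real query")


def make_mock_executor():
    mock_executor = MagicMock()

    async def mock_arun_and_format_query(query_obj):
        # Deterministic result instead of running anything.
        return f"Mocked results for {type(query_obj).__name__}", {}

    mock_executor.arun_and_format_query = mock_arun_and_format_query
    return mock_executor


async def run_node():
    # Stand-in for code under test: it looks the executor class up by name and instantiates it.
    executor = RealExecutor()
    formatted, _meta = await executor.arun_and_format_query(object())
    return formatted


with patch(f"{__name__}.RealExecutor", make_mock_executor):
    # While patched, "instantiating" RealExecutor actually calls the factory and yields the mock.
    print(asyncio.run(run_node()))  # -> Mocked results for object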