Mirror of https://github.com/BillyOutlast/posthog.git (synced 2026-02-04 03:01:23 +01:00)
chore: lazy loading stream writer (#37175)
Committed by GitHub. Parent: a5394c47f2. Commit: 28b472717b.
@@ -1,5 +1,4 @@
 import re
-import logging
 import warnings
 from datetime import timedelta
 from typing import Literal
@@ -8,6 +7,7 @@ from uuid import uuid4
 from django.db.models import Max
 from django.utils import timezone

+import structlog
 from langchain_core.messages import AIMessageChunk, BaseMessage, HumanMessage, SystemMessage, ToolMessage
 from langchain_core.runnables import RunnableConfig
 from langchain_core.tools import tool
@@ -36,7 +36,7 @@ from .prompts import (
     TOOL_BASED_EVALUATION_SYSTEM_PROMPT,
 )

-logger = logging.getLogger(__name__)
+logger = structlog.get_logger(__name__)
 # Silence Pydantic serializer warnings for creation of VisualizationMessage/Query execution
 warnings.filterwarnings("ignore", category=UserWarning, message=".*Pydantic serializer.*")
@@ -64,12 +64,15 @@ class InsightSearchNode(AssistantNode):
         self._cutoff_date_for_insights_in_days = self.INSIGHTS_CUTOFF_DAYS
         self._query_cache = {}
         self._insight_id_cache = {}
+        self._stream_writer = None

     def _get_stream_writer(self) -> StreamWriter | None:
-        try:
-            return get_stream_writer()
-        except Exception:
-            return None
+        if self._stream_writer is None:
+            try:
+                self._stream_writer = get_stream_writer()
+            except Exception:
+                self._stream_writer = None
+        return self._stream_writer

     def _stream_reasoning(
         self, content: str, substeps: list[str] | None = None, writer: StreamWriter | None = None
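This hunk is the commit's namesake change: instead of resolving the writer on every call, the node caches it on the instance and retries the lookup only while it is still unresolved. A minimal self-contained sketch of the same memoization pattern, assuming LangGraph's get_stream_writer/StreamWriter (the LazyWriterNode class is illustrative, not PostHog code):

# Sketch of the lazy, cached stream-writer lookup introduced above.
from langgraph.config import get_stream_writer
from langgraph.types import StreamWriter


class LazyWriterNode:
    def __init__(self) -> None:
        self._stream_writer: StreamWriter | None = None

    def _get_stream_writer(self) -> StreamWriter | None:
        # get_stream_writer() raises outside an active LangGraph run;
        # in that case we keep None and simply retry on the next call.
        if self._stream_writer is None:
            try:
                self._stream_writer = get_stream_writer()
            except Exception:
                self._stream_writer = None
        return self._stream_writer

Because a failed lookup leaves the cache as None, calls made outside a run context stay cheap, while calls inside a run pay the lookup cost once and reuse the writer afterwards.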
@@ -92,7 +95,7 @@ class InsightSearchNode(AssistantNode):
             writer(("insights_search_node", "messages", message))

         except Exception as e:
-            logger.exception("Failed to stream reasoning message", extra={"error": str(e), "content": content})
+            logger.exception("Failed to stream reasoning message", error=str(e), content=content)

     def _create_page_reader_tool(self):
         """Create tool for reading insights pages during agentic RAG loop."""
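The only functional change here is the logging call: structlog loggers take structured context as plain keyword arguments, whereas the stdlib logger needed it smuggled through the extra mapping. A side-by-side sketch, assuming structlog's default configuration:

import logging

import structlog

stdlib_logger = logging.getLogger(__name__)
struct_logger = structlog.get_logger(__name__)

try:
    raise RuntimeError("writer failed")
except Exception as e:
    # stdlib: context rides in the `extra` mapping and appears only if the
    # formatter is configured to render those attributes.
    stdlib_logger.exception("Failed to stream reasoning message", extra={"error": str(e)})
    # structlog: key/value pairs are first-class fields of the event dict.
    struct_logger.exception("Failed to stream reasoning message", error=str(e))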
@@ -144,7 +147,6 @@ class InsightSearchNode(AssistantNode):

     async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
         search_query = state.search_insights_query

         try:
             self._current_iteration = 0
@@ -153,6 +155,9 @@ class InsightSearchNode(AssistantNode):
             return self._handle_empty_database(state)

         selected_insights = await self._search_insights_iteratively(search_query or "")
+        logger.warning(
+            f"search_insights_iteratively returned {len(selected_insights)} insights: {selected_insights}"
+        )

         writer = self._get_stream_writer()
         if selected_insights:
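For context on where the writer's tuples end up: LangGraph re-emits whatever a node passes to its stream writer when the graph is streamed with stream_mode="custom". A hedged sketch of the consuming side (graph stands in for the compiled assistant graph, which this diff does not show):

# Sketch: consuming the custom events emitted via writer(("insights_search_node", ...)).
async def consume_reasoning_events(graph, inputs: dict) -> None:
    # `graph` is a placeholder for a compiled LangGraph graph.
    async for event in graph.astream(inputs, stream_mode="custom"):
        # Each event is exactly the object the node handed to its writer,
        # e.g. ("insights_search_node", "messages", message).
        node_name, channel, payload = event
        print(node_name, channel, payload)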
@@ -261,14 +266,31 @@ class InsightSearchNode(AssistantNode):

     async def _load_insights_page(self, page_number: int) -> list[Insight]:
         """Load a specific page of insights from database."""
+        logger.warning(f"_load_insights_page called with page_number={page_number}")

         if page_number in self._loaded_pages:
+            logger.info(f"Page {page_number} found in cache with {len(self._loaded_pages[page_number])} insights")
             return self._loaded_pages[page_number]

         start_idx = page_number * self._page_size
         end_idx = start_idx + self._page_size

         insights_qs = self._get_insights_queryset()[start_idx:end_idx]
-        page_insights = [i async for i in insights_qs]
+        logger.warning(f"Executing async query for page {page_number}")
+
+        import time
+
+        db_start = time.time()
+        page_insights = []
+        async for i in insights_qs:
+            page_insights.append(i)
+            logger.debug(f"Loaded insight {i.id}: {i.name or i.derived_name}")
+        db_elapsed = time.time() - db_start
+
+        logger.warning(
+            f"Database query completed in {db_elapsed:.2f}s, loaded {len(page_insights)} insights for page {page_number}"
+        )
+        logger.warning(f"DB QUERY: took {db_elapsed:.2f}s to load page {page_number}")

         self._loaded_pages[page_number] = page_insights
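The timing added above wraps the async iteration in bare time.time() calls. If the measurement outlives the debugging session, it can be factored into a small helper; a sketch using only the standard library (timed and its names are my own, not PostHog code):

import time
from contextlib import contextmanager


@contextmanager
def timed(label: str, logger):
    # perf_counter is monotonic and better suited to intervals than time.time().
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        logger.warning(f"{label} took {elapsed:.2f}s")

Usage mirrors the page loader: wrap the `async for` loop in `with timed(f"DB QUERY for page {page_number}", logger):`.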
@@ -330,9 +352,11 @@ class InsightSearchNode(AssistantNode):
             substeps=[f"Analyzing {await self._get_total_insights_count()} available insights"],
             writer=writer,
         )
+        logger.warning(f"Starting iterative search, max_iterations={self._max_iterations}")

         while self._current_iteration < self._max_iterations:
             self._current_iteration += 1
+            logger.warning(f"Iteration {self._current_iteration}/{self._max_iterations} starting")

             try:
                 response = await llm_with_tools.ainvoke(messages)
@@ -344,13 +368,23 @@ class InsightSearchNode(AssistantNode):
                 for tool_call in response.tool_calls:
                     if tool_call.get("name") == "read_insights_page":
                         page_num = tool_call.get("args", {}).get("page_number", 0)
+                        logger.warning(f"Reading insights page {page_num}")

                         page_message = "Finding the most relevant insights"
+                        logger.warning(
+                            f"STALL POINT(?): Streamed 'Finding the most relevant insights' - about to fetch page content"
+                        )
                         self._stream_reasoning(content=page_message, writer=writer)

+                        logger.warning(f"Fetching page content for page {page_num}")
                         tool_response = await self._get_page_content_for_tool(page_num)
+                        logger.warning(f"Page content fetched successfully, length={len(tool_response)}")

                         messages.append(
                             ToolMessage(content=tool_response, tool_call_id=tool_call.get("id", "unknown"))
                         )

+                logger.warning("Continuing to next iteration after tool calls")
                 continue

             # No tool calls, extract insight IDs from the response. Done with the search
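Stepping back, the loop being instrumented follows the usual agentic tool-calling shape: invoke the model, execute any requested tools, append ToolMessages, and stop once the model answers without tool calls. A stripped-down sketch of that skeleton (run_tool_loop and handle_read_page are illustrative stand-ins, not the node's actual code):

from langchain_core.messages import BaseMessage, ToolMessage


async def handle_read_page(args: dict) -> str:
    # Stand-in for _get_page_content_for_tool(page_number).
    return f"contents of page {args.get('page_number', 0)}"


async def run_tool_loop(llm_with_tools, messages: list[BaseMessage], max_iterations: int = 3):
    for _ in range(max_iterations):
        response = await llm_with_tools.ainvoke(messages)
        messages.append(response)
        if not response.tool_calls:
            return response  # the model answered directly; the search is done
        for tool_call in response.tool_calls:
            result = await handle_read_page(tool_call.get("args", {}))
            messages.append(ToolMessage(content=result, tool_call_id=tool_call.get("id", "unknown")))
    return None  # iteration budget exhausted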
The second file in the commit is the node's test module:

@@ -1,7 +1,7 @@
 import asyncio
 from datetime import timedelta

-from posthog.test.base import NonAtomicBaseTest
+from posthog.test.base import BaseTest
 from unittest.mock import AsyncMock, MagicMock, patch

 from django.utils import timezone
@@ -28,12 +28,29 @@ from ee.hogai.utils.types import AssistantState, PartialAssistantState
 from ee.models.assistant import Conversation


-# TRICKY: Preserve the non-atomic setup for the test. Otherwise, threads would stall because
-# `query_executor.arun_and_format_query` spawns threads with a new connection.
-class TestInsightSearchNode(NonAtomicBaseTest):
-    # TRICKY: See above.
-    CLASS_DATA_LEVEL_SETUP = False
+def create_mock_query_executor():
+    """Mock the query executor instead of querying ClickHouse (needed now that the test no longer uses NonAtomicBaseTest)."""
+    mock_executor = MagicMock()
+
+    async def mock_arun_and_format_query(query_obj):
+        """Return mocked query results based on query type."""
+        if isinstance(query_obj, TrendsQuery):
+            return "Mocked trends query results: Daily pageviews = 1000", {}
+        elif isinstance(query_obj, FunnelsQuery):
+            return "Mocked funnel query results: Conversion rate = 25%", {}
+        elif isinstance(query_obj, RetentionQuery):
+            return "Mocked retention query results: Day 1 retention = 40%", {}
+        elif isinstance(query_obj, HogQLQuery):
+            return "Mocked HogQL query results: Result count = 42", {}
+        else:
+            return "Mocked query results", {}
+
+    mock_executor.arun_and_format_query = mock_arun_and_format_query
+    return mock_executor
+
+
+@patch("ee.hogai.graph.insights.nodes.AssistantQueryExecutor", create_mock_query_executor)
+class TestInsightSearchNode(BaseTest):
     def setUp(self):
         super().setUp()
         self.node = InsightSearchNode(self.team, self.user)
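The rewritten test swaps the thread-friendly NonAtomicBaseTest for the plain transactional BaseTest and patches the executor away, so no ClickHouse query (and no extra database connection) ever runs. The subtle part is that @patch replaces the class with a zero-argument factory, so every AssistantQueryExecutor() instantiation inside the node quietly returns a fresh mock. A compact illustration of that idiom (the patch target is copied from the diff; the test class itself is illustrative and assumes the PostHog test environment):

from unittest import IsolatedAsyncioTestCase
from unittest.mock import MagicMock, patch


def create_fake_executor():
    """Zero-arg factory installed in place of the class by @patch below."""
    executor = MagicMock()

    async def arun_and_format_query(query_obj):
        return "Mocked query results", {}

    executor.arun_and_format_query = arun_and_format_query
    return executor


@patch("ee.hogai.graph.insights.nodes.AssistantQueryExecutor", create_fake_executor)
class ExampleInsightSearchTest(IsolatedAsyncioTestCase):
    async def test_query_execution_is_mocked(self):
        # Inside the patch, the module attribute is the factory itself.
        from ee.hogai.graph.insights.nodes import AssistantQueryExecutor

        formatted, _meta = await AssistantQueryExecutor().arun_and_format_query(object())
        self.assertEqual(formatted, "Mocked query results")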