mirror of https://github.com/BillyOutlast/posthog.git (synced 2026-02-04 03:01:23 +01:00)
fix(ph-ai): conversation compaction edge cases (#41443)
@@ -7,10 +7,12 @@ from uuid import uuid4
 from langchain_anthropic import ChatAnthropic
 from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage as LangchainAIMessage,
     BaseMessage,
     HumanMessage as LangchainHumanMessage,
 )
+from langchain_core.tools import BaseTool
+from langchain_core.utils.function_calling import convert_to_openai_tool
 from pydantic import BaseModel
 
 from posthog.schema import AssistantMessage, AssistantToolCallMessage, ContextMessage, HumanMessage
@@ -36,7 +38,7 @@ class ConversationCompactionManager(ABC):
     Manages conversation window boundaries, message filtering, and summarization decisions.
     """
 
-    CONVERSATION_WINDOW_SIZE = 64000
+    CONVERSATION_WINDOW_SIZE = 100_000
     """
     Determines the maximum number of tokens allowed in the conversation window.
     """
@@ -54,7 +56,7 @@ class ConversationCompactionManager(ABC):
         new_window_id: str | None = None
         for message in reversed(messages):
             # Handle limits before assigning the window ID.
-            max_tokens -= self._get_estimated_tokens(message)
+            max_tokens -= self._get_estimated_assistant_message_tokens(message)
             max_messages -= 1
             if max_tokens < 0 or max_messages < 0:
                 break
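Note: the window scan above walks messages newest to oldest, deducting each message's estimated tokens and one unit of the message budget until either limit trips; everything newer than the break point stays in the window. A simplified, self-contained sketch of that loop, with illustrative budget numbers rather than the real defaults:

# Simplified sketch of the reverse-scan windowing: keep the newest messages that fit.
def find_window_start(messages: list[str], max_tokens: int, max_messages: int) -> int:
    start = 0  # index of the first kept message; 0 means everything fits
    for i in range(len(messages) - 1, -1, -1):
        max_tokens -= round(len(messages[i]) / 4)  # char/4 heuristic
        max_messages -= 1
        if max_tokens < 0 or max_messages < 0:
            start = i + 1  # this message blew the budget; the window starts after it
            break
    return start

# Ten 400-char (~100-token) messages against a 300-token budget: the last three fit.
print(find_window_start(["x" * 400] * 10, max_tokens=300, max_messages=50))  # 7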
@@ -83,12 +85,20 @@ class ConversationCompactionManager(ABC):
         """
         Determine if the conversation should be summarized based on token count.
+        Avoids summarizing if there are only two human messages or fewer.
         """
-        # Avoid summarizing the conversation if there is only two human messages.
+        return await self.calculate_token_count(model, messages, tools, **kwargs) > self.CONVERSATION_WINDOW_SIZE
+
+    async def calculate_token_count(
+        self, model: BaseChatModel, messages: list[BaseMessage], tools: LangchainTools | None = None, **kwargs
+    ) -> int:
+        """
+        Calculate the token count for a conversation.
+        """
         human_messages = [message for message in messages if isinstance(message, LangchainHumanMessage)]
         if len(human_messages) <= 2:
-            return False
-        token_count = await self._get_token_count(model, messages, tools, **kwargs)
-        return token_count > self.CONVERSATION_WINDOW_SIZE
+            tool_tokens = self._get_estimated_tools_tokens(tools) if tools else 0
+            return sum(self._get_estimated_langchain_message_tokens(message) for message in messages) + tool_tokens
+        return await self._get_token_count(model, messages, tools, **kwargs)
 
     def update_window(
         self, messages: Sequence[T], summary_message: ContextMessage, start_id: str | None = None
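In effect there are now two counting paths: with two or fewer human messages the count is a cheap local estimate (message characters plus serialized tool schemas, divided by 4), and only longer conversations pay for a model-side token count. A standalone sketch of the split; the function names here are illustrative, not PostHog's API:

APPROXIMATE_TOKEN_LENGTH = 4
CONVERSATION_WINDOW_SIZE = 100_000

def estimate_tokens(texts: list[str]) -> int:
    # Cheap path: pure character arithmetic, no provider call.
    return round(sum(len(t) for t in texts) / APPROXIMATE_TOKEN_LENGTH)

def should_compact(num_human_messages: int, texts: list[str], exact_counter) -> bool:
    if num_human_messages <= 2:
        return estimate_tokens(texts) > CONVERSATION_WINDOW_SIZE
    # Exact path: defer to the provider tokenizer (async in the real code).
    return exact_counter(texts) > CONVERSATION_WINDOW_SIZE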
@@ -134,7 +144,7 @@ class ConversationCompactionManager(ABC):
             updated_window_start_id=window_start_id_candidate,
         )
 
-    def _get_estimated_tokens(self, message: AssistantMessageUnion) -> int:
+    def _get_estimated_assistant_message_tokens(self, message: AssistantMessageUnion) -> int:
         """
         Estimate token count for a message using character/4 heuristic.
         """
@@ -149,6 +159,24 @@ class ConversationCompactionManager(ABC):
             char_count = len(message.content)
         return round(char_count / self.APPROXIMATE_TOKEN_LENGTH)
 
+    def _get_estimated_langchain_message_tokens(self, message: BaseMessage) -> int:
+        """
+        Estimate token count for a message using character/4 heuristic.
+        """
+        char_count = 0
+        if isinstance(message.content, str):
+            char_count = len(message.content)
+        else:
+            for content in message.content:
+                if isinstance(content, str):
+                    char_count += len(content)
+                elif isinstance(content, dict):
+                    char_count += self._count_json_tokens(content)
+        if isinstance(message, LangchainAIMessage) and message.tool_calls:
+            for tool_call in message.tool_calls:
+                char_count += len(json.dumps(tool_call, separators=(",", ":")))
+        return round(char_count / self.APPROXIMATE_TOKEN_LENGTH)
+
     def _get_conversation_window(self, messages: Sequence[T], start_id: str) -> Sequence[T]:
         """
         Get messages from the start_id onwards.
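A quick sanity check of the char/4 heuristic against the content shapes this method handles (the numbers line up with the tests further down):

import json

APPROXIMATE_TOKEN_LENGTH = 4

# String content: 100 chars -> 25 tokens.
print(round(len("A" * 100) / APPROXIMATE_TOKEN_LENGTH))  # 25

# Dict content blocks are measured by their compact-JSON length.
block = {"type": "text", "text": "A" * 100}
chars = len(json.dumps(block, separators=(",", ":")))
print(round(chars / APPROXIMATE_TOKEN_LENGTH))  # 31: 100 chars of text plus JSON overhead

# Tool calls are serialized the same way before counting.
tool_call = {"id": "t1", "name": "test_tool", "args": {"key": "value"}}
print(len(json.dumps(tool_call, separators=(",", ":"))))  # 53 chars -> ~13 tokens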
@@ -158,6 +186,22 @@ class ConversationCompactionManager(ABC):
             return messages[idx:]
         return messages
 
+    def _get_estimated_tools_tokens(self, tools: LangchainTools) -> int:
+        """
+        Estimate token count for tools by converting them to JSON schemas.
+        """
+        if not tools:
+            return 0
+
+        total_chars = 0
+        for tool in tools:
+            tool_schema = convert_to_openai_tool(tool)
+            total_chars += self._count_json_tokens(tool_schema)
+        return round(total_chars / self.APPROXIMATE_TOKEN_LENGTH)
+
+    def _count_json_tokens(self, json_data: dict) -> int:
+        return len(json.dumps(json_data, separators=(",", ":")))
+
     @abstractmethod
     async def _get_token_count(
         self,
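convert_to_openai_tool accepts @tool-decorated functions, Pydantic models, or already-formed dicts and returns an OpenAI-style function schema, so tools and message content end up measured by the same compact-JSON yardstick. A small runnable illustration (tool name and docstring are made up for the example):

import json

from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_tool

@tool
def search_events(query: str) -> str:
    """Search analytics events by free-text query."""
    return f"Results for {query}"

schema = convert_to_openai_tool(search_events)
# schema looks like {"type": "function", "function": {"name": "search_events", ...}}
compact = json.dumps(schema, separators=(",", ":"))
print(len(compact), round(len(compact) / 4))  # char count and its ~token estimate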
@@ -187,12 +187,18 @@ class AgentExecutable(BaseAgentExecutable):
         start_id = state.start_id
 
         # Summarize the conversation if it's too long.
-        if await self._window_manager.should_compact_conversation(
+        current_token_count = await self._window_manager.calculate_token_count(
             model, langchain_messages, tools=tools, thinking_config=self.THINKING_CONFIG
-        ):
+        )
+        if current_token_count > self._window_manager.CONVERSATION_WINDOW_SIZE:
             # Exclude the last message if it's the first turn.
             messages_to_summarize = langchain_messages[:-1] if self._is_first_turn(state) else langchain_messages
-            summary = await AnthropicConversationSummarizer(self._team, self._user).summarize(messages_to_summarize)
+            summary = await AnthropicConversationSummarizer(
+                self._team,
+                self._user,
+                extend_context_window=current_token_count > 195_000,
+            ).summarize(messages_to_summarize)
+
             summary_message = ContextMessage(
                 content=ROOT_CONVERSATION_SUMMARY_PROMPT.format(summary=summary),
                 id=str(uuid4()),
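The 195_000 cutoff reads as headroom arithmetic: Haiku's context window is 200k tokens (per the summarizer comment further down), so a conversation measured above ~195k would leave too little room for the summarization prompt and output, and the caller opts into the extended-context model instead. A sketch of that reasoning; the constant names are illustrative, only the 195_000 threshold appears in the diff:

HAIKU_CONTEXT_LIMIT = 200_000  # illustrative
SUMMARY_HEADROOM = 5_000  # illustrative

def needs_extended_context(current_token_count: int) -> bool:
    # True once the conversation would crowd out the summarizer's own prompt.
    return current_token_count > HAIKU_CONTEXT_LIMIT - SUMMARY_HEADROOM  # > 195_000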
@@ -114,11 +114,11 @@ class TestAnthropicConversationCompactionManager(BaseTest):
     @parameterized.expand(
         [
             # (num_human_messages, token_count, should_compact)
-            [1, 70000, False],  # Only 1 human message
-            [2, 70000, False],  # Only 2 human messages
-            [3, 50000, False],  # 3 human messages but under token limit
-            [3, 70000, True],  # 3 human messages and over token limit
-            [5, 70000, True],  # Many messages over limit
+            [1, 90000, False],  # Only 1 human message, under limit
+            [2, 90000, False],  # Only 2 human messages, under limit
+            [3, 80000, False],  # 3 human messages but under token limit
+            [3, 110000, True],  # 3 human messages and over token limit
+            [5, 110000, True],  # Many messages over limit
         ]
     )
     async def test_should_compact_conversation(self, num_human_messages, token_count, should_compact):
@@ -136,19 +136,57 @@ class TestAnthropicConversationCompactionManager(BaseTest):
         result = await self.window_manager.should_compact_conversation(mock_model, messages)
         self.assertEqual(result, should_compact)
 
-    def test_get_estimated_tokens_human_message(self):
+    async def test_should_compact_conversation_with_tools_under_limit(self):
+        """Test that tools are accounted for when estimating tokens with 2 or fewer human messages"""
+        from langchain_core.tools import tool
+
+        @tool
+        def test_tool(query: str) -> str:
+            """A test tool"""
+            return f"Result for {query}"
+
+        messages: list[BaseMessage] = [
+            LangchainHumanMessage(content="A" * 1000),  # ~250 tokens
+            LangchainAIMessage(content="B" * 1000),  # ~250 tokens
+        ]
+        tools = [test_tool]
+
+        mock_model = MagicMock()
+        # With 2 human messages, should use estimation and not call _get_token_count
+        result = await self.window_manager.should_compact_conversation(mock_model, messages, tools=tools)
+
+        # Total should be well under 100k limit
+        self.assertFalse(result)
+
+    async def test_should_compact_conversation_with_tools_over_limit(self):
+        """Test that tools push estimation over limit with 2 or fewer human messages"""
+        messages: list[BaseMessage] = [
+            LangchainHumanMessage(content="A" * 200000),  # ~50k tokens
+            LangchainAIMessage(content="B" * 200000),  # ~50k tokens
+        ]
+
+        # Create large tool schemas to push over 100k limit
+        tools = [{"type": "function", "function": {"name": f"tool_{i}", "description": "X" * 1000}} for i in range(100)]
+
+        mock_model = MagicMock()
+        result = await self.window_manager.should_compact_conversation(mock_model, messages, tools=tools)
+
+        # Should be over the 100k limit
+        self.assertTrue(result)
+
+    def test_get_estimated_assistant_message_tokens_human_message(self):
         """Test token estimation for human messages"""
         message = HumanMessage(content="A" * 100, id="1")  # 100 chars = ~25 tokens
-        tokens = self.window_manager._get_estimated_tokens(message)
+        tokens = self.window_manager._get_estimated_assistant_message_tokens(message)
         self.assertEqual(tokens, 25)
 
-    def test_get_estimated_tokens_assistant_message(self):
+    def test_get_estimated_assistant_message_tokens_assistant_message(self):
         """Test token estimation for assistant messages without tool calls"""
         message = AssistantMessage(content="A" * 100, id="1")  # 100 chars = ~25 tokens
-        tokens = self.window_manager._get_estimated_tokens(message)
+        tokens = self.window_manager._get_estimated_assistant_message_tokens(message)
         self.assertEqual(tokens, 25)
 
-    def test_get_estimated_tokens_assistant_message_with_tool_calls(self):
+    def test_get_estimated_assistant_message_tokens_assistant_message_with_tool_calls(self):
         """Test token estimation for assistant messages with tool calls"""
         message = AssistantMessage(
             content="A" * 100,  # 100 chars
@@ -162,17 +200,117 @@ class TestAnthropicConversationCompactionManager(BaseTest):
             ],
         )
         # Should count content + JSON serialized args
-        tokens = self.window_manager._get_estimated_tokens(message)
+        tokens = self.window_manager._get_estimated_assistant_message_tokens(message)
         # 100 chars content + ~15 chars for args = ~29 tokens
         self.assertGreater(tokens, 25)
         self.assertLess(tokens, 35)
 
-    def test_get_estimated_tokens_tool_call_message(self):
+    def test_get_estimated_assistant_message_tokens_tool_call_message(self):
         """Test token estimation for tool call messages"""
         message = AssistantToolCallMessage(content="A" * 200, id="1", tool_call_id="t1")
-        tokens = self.window_manager._get_estimated_tokens(message)
+        tokens = self.window_manager._get_estimated_assistant_message_tokens(message)
         self.assertEqual(tokens, 50)
 
+    def test_get_estimated_langchain_message_tokens_string_content(self):
+        """Test token estimation for langchain messages with string content"""
+        message = LangchainHumanMessage(content="A" * 100)
+        tokens = self.window_manager._get_estimated_langchain_message_tokens(message)
+        self.assertEqual(tokens, 25)
+
+    def test_get_estimated_langchain_message_tokens_list_content_with_strings(self):
+        """Test token estimation for langchain messages with list of string content"""
+        message = LangchainHumanMessage(content=["A" * 100, "B" * 100])
+        tokens = self.window_manager._get_estimated_langchain_message_tokens(message)
+        self.assertEqual(tokens, 50)
+
+    def test_get_estimated_langchain_message_tokens_list_content_with_dicts(self):
+        """Test token estimation for langchain messages with dict content"""
+        message = LangchainHumanMessage(content=[{"type": "text", "text": "A" * 100}])
+        tokens = self.window_manager._get_estimated_langchain_message_tokens(message)
+        # 100 chars for text + overhead for JSON structure
+        self.assertGreater(tokens, 25)
+        self.assertLess(tokens, 40)
+
+    def test_get_estimated_langchain_message_tokens_ai_message_with_tool_calls(self):
+        """Test token estimation for AI messages with tool calls"""
+        message = LangchainAIMessage(
+            content="A" * 100,
+            tool_calls=[
+                {"id": "t1", "name": "test_tool", "args": {"key": "value"}},
+                {"id": "t2", "name": "another_tool", "args": {"foo": "bar"}},
+            ],
+        )
+        tokens = self.window_manager._get_estimated_langchain_message_tokens(message)
+        # Content + tool calls JSON
+        self.assertGreater(tokens, 25)
+        self.assertLess(tokens, 70)
+
+    def test_get_estimated_langchain_message_tokens_ai_message_without_tool_calls(self):
+        """Test token estimation for AI messages without tool calls"""
+        message = LangchainAIMessage(content="A" * 100)
+        tokens = self.window_manager._get_estimated_langchain_message_tokens(message)
+        self.assertEqual(tokens, 25)
+
+    def test_count_json_tokens(self):
+        """Test JSON token counting helper"""
+        json_data = {"key": "value", "nested": {"foo": "bar"}}
+        char_count = self.window_manager._count_json_tokens(json_data)
+        # Should match length of compact JSON
+        import json
+
+        expected = len(json.dumps(json_data, separators=(",", ":")))
+        self.assertEqual(char_count, expected)
+
+    def test_get_estimated_tools_tokens_empty(self):
+        """Test tool token estimation with no tools"""
+        tokens = self.window_manager._get_estimated_tools_tokens([])
+        self.assertEqual(tokens, 0)
+
+    def test_get_estimated_tools_tokens_with_dict_tools(self):
+        """Test tool token estimation with dict tools"""
+        tools = [
+            {"type": "function", "function": {"name": "test_tool", "description": "A test tool"}},
+        ]
+        tokens = self.window_manager._get_estimated_tools_tokens(tools)
+        # Should be positive and reasonable
+        self.assertGreater(tokens, 0)
+        self.assertLess(tokens, 100)
+
+    def test_get_estimated_tools_tokens_with_base_tool(self):
+        """Test tool token estimation with BaseTool"""
+        from langchain_core.tools import tool
+
+        @tool
+        def sample_tool(query: str) -> str:
+            """A sample tool for testing"""
+            return f"Result for {query}"
+
+        tools = [sample_tool]
+        tokens = self.window_manager._get_estimated_tools_tokens(tools)
+        # Should count the tool schema
+        self.assertGreater(tokens, 0)
+        self.assertLess(tokens, 200)
+
+    def test_get_estimated_tools_tokens_multiple_tools(self):
+        """Test tool token estimation with multiple tools"""
+        from langchain_core.tools import tool
+
+        @tool
+        def tool1(x: int) -> int:
+            """First tool"""
+            return x * 2
+
+        @tool
+        def tool2(y: str) -> str:
+            """Second tool"""
+            return y.upper()
+
+        tools = [tool1, tool2]
+        tokens = self.window_manager._get_estimated_tools_tokens(tools)
+        # Should count both tool schemas
+        self.assertGreater(tokens, 0)
+        self.assertLess(tokens, 400)
+
     async def test_get_token_count_calls_model(self):
         """Test that _get_token_count properly calls the model's token counting"""
         mock_model = MagicMock()
@@ -520,13 +520,12 @@ class TestAgentNode(ClickhouseTestMixin, BaseTest):
         self.assertEqual(await node._get_billing_prompt(), expected_prompt)
 
     @patch("ee.hogai.graph.agent_modes.nodes.AgentExecutable._get_model", return_value=FakeChatOpenAI(responses=[]))
-    @patch(
-        "ee.hogai.graph.agent_modes.compaction_manager.AnthropicConversationCompactionManager.should_compact_conversation"
-    )
+    @patch("ee.hogai.graph.agent_modes.compaction_manager.AnthropicConversationCompactionManager.calculate_token_count")
     @patch("ee.hogai.graph.conversation_summarizer.nodes.AnthropicConversationSummarizer.summarize")
-    async def test_conversation_summarization_flow(self, mock_summarize, mock_should_compact, mock_model):
+    async def test_conversation_summarization_flow(self, mock_summarize, mock_calculate_tokens, mock_model):
         """Test that conversation is summarized when it gets too long"""
-        mock_should_compact.return_value = True
+        # Return a token count higher than CONVERSATION_WINDOW_SIZE (100,000)
+        mock_calculate_tokens.return_value = 150_000
         mock_summarize.return_value = "This is a summary of the conversation so far."
 
         mock_model_instance = FakeChatOpenAI(responses=[LangchainAIMessage(content="Response after summary")])
@@ -554,13 +553,12 @@ class TestAgentNode(ClickhouseTestMixin, BaseTest):
         self.assertIn("This is a summary of the conversation so far.", context_messages[0].content)
 
     @patch("ee.hogai.graph.agent_modes.nodes.AgentExecutable._get_model", return_value=FakeChatOpenAI(responses=[]))
-    @patch(
-        "ee.hogai.graph.agent_modes.compaction_manager.AnthropicConversationCompactionManager.should_compact_conversation"
-    )
+    @patch("ee.hogai.graph.agent_modes.compaction_manager.AnthropicConversationCompactionManager.calculate_token_count")
     @patch("ee.hogai.graph.conversation_summarizer.nodes.AnthropicConversationSummarizer.summarize")
-    async def test_conversation_summarization_on_first_turn(self, mock_summarize, mock_should_compact, mock_model):
+    async def test_conversation_summarization_on_first_turn(self, mock_summarize, mock_calculate_tokens, mock_model):
         """Test that on first turn, the last message is excluded from summarization"""
-        mock_should_compact.return_value = True
+        # Return a token count higher than CONVERSATION_WINDOW_SIZE (100,000)
+        mock_calculate_tokens.return_value = 150_000
         mock_summarize.return_value = "Summary without last message"
 
         mock_model_instance = FakeChatOpenAI(responses=[LangchainAIMessage(content="Response")])
@@ -55,9 +55,14 @@ class ConversationSummarizer:
 
 
 class AnthropicConversationSummarizer(ConversationSummarizer):
+    def __init__(self, team: Team, user: User, extend_context_window: bool | None = False):
+        super().__init__(team, user)
+        self._extend_context_window = extend_context_window
+
     def _get_model(self):
+        # Haiku has 200k token limit. Sonnet has 1M token limit.
         return MaxChatAnthropic(
-            model="claude-haiku-4-5",
+            model="claude-sonnet-4-5" if self._extend_context_window else "claude-haiku-4-5",
             streaming=False,
             stream_usage=False,
             max_tokens=8192,
@@ -65,6 +70,7 @@ class AnthropicConversationSummarizer(ConversationSummarizer):
             user=self._user,
             team=self._team,
             billable=True,
+            betas=["context-1m-2025-08-07"] if self._extend_context_window else None,
         )
 
     def _construct_messages(self, messages: Sequence[BaseMessage]):
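Putting the summarizer change together with the call site above, usage looks roughly like this; a sketch assuming team, user, langchain_messages, and the measured current_token_count are in scope:

async def summarize_conversation(team, user, langchain_messages, current_token_count: int) -> str:
    # Above ~195k tokens the summary request itself would overflow Haiku's 200k
    # window, so switch to Sonnet and enable the 1M-context beta.
    summarizer = AnthropicConversationSummarizer(
        team,
        user,
        extend_context_window=current_token_count > 195_000,
    )
    return await summarizer.summarize(langchain_messages)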