fix(ph-ai): conversation compaction edge cases (#41443)

Georgiy Tarasov
2025-11-13 17:44:28 +01:00
committed by GitHub
parent da1a1fb6d9
commit d8fdf6bbfe
5 changed files with 225 additions and 33 deletions

View File

@@ -7,10 +7,12 @@ from uuid import uuid4
 from langchain_anthropic import ChatAnthropic
 from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import (
+    AIMessage as LangchainAIMessage,
     BaseMessage,
     HumanMessage as LangchainHumanMessage,
 )
 from langchain_core.tools import BaseTool
+from langchain_core.utils.function_calling import convert_to_openai_tool
 from pydantic import BaseModel
 
 from posthog.schema import AssistantMessage, AssistantToolCallMessage, ContextMessage, HumanMessage
@@ -36,7 +38,7 @@ class ConversationCompactionManager(ABC):
     Manages conversation window boundaries, message filtering, and summarization decisions.
     """
 
-    CONVERSATION_WINDOW_SIZE = 64000
+    CONVERSATION_WINDOW_SIZE = 100_000
     """
     Determines the maximum number of tokens allowed in the conversation window.
     """
@@ -54,7 +56,7 @@ class ConversationCompactionManager(ABC):
         new_window_id: str | None = None
         for message in reversed(messages):
             # Handle limits before assigning the window ID.
-            max_tokens -= self._get_estimated_tokens(message)
+            max_tokens -= self._get_estimated_assistant_message_tokens(message)
             max_messages -= 1
             if max_tokens < 0 or max_messages < 0:
                 break
@@ -83,12 +85,20 @@ class ConversationCompactionManager(ABC):
         Determine if the conversation should be summarized based on token count.
         Avoids summarizing if there are only two human messages or fewer.
         """
+        return await self.calculate_token_count(model, messages, tools, **kwargs) > self.CONVERSATION_WINDOW_SIZE
+
+    async def calculate_token_count(
+        self, model: BaseChatModel, messages: list[BaseMessage], tools: LangchainTools | None = None, **kwargs
+    ) -> int:
+        """
+        Calculate the token count for a conversation.
+        """
         # Avoid summarizing the conversation if there is only two human messages.
         human_messages = [message for message in messages if isinstance(message, LangchainHumanMessage)]
         if len(human_messages) <= 2:
-            return False
-        token_count = await self._get_token_count(model, messages, tools, **kwargs)
-        return token_count > self.CONVERSATION_WINDOW_SIZE
+            tool_tokens = self._get_estimated_tools_tokens(tools) if tools else 0
+            return sum(self._get_estimated_langchain_message_tokens(message) for message in messages) + tool_tokens
+        return await self._get_token_count(model, messages, tools, **kwargs)
 
     def update_window(
         self, messages: Sequence[T], summary_message: ContextMessage, start_id: str | None = None
@@ -134,7 +144,7 @@ class ConversationCompactionManager(ABC):
             updated_window_start_id=window_start_id_candidate,
         )
 
-    def _get_estimated_tokens(self, message: AssistantMessageUnion) -> int:
+    def _get_estimated_assistant_message_tokens(self, message: AssistantMessageUnion) -> int:
         """
         Estimate token count for a message using character/4 heuristic.
         """
@@ -149,6 +159,24 @@ class ConversationCompactionManager(ABC):
         char_count = len(message.content)
         return round(char_count / self.APPROXIMATE_TOKEN_LENGTH)
 
+    def _get_estimated_langchain_message_tokens(self, message: BaseMessage) -> int:
+        """
+        Estimate token count for a message using character/4 heuristic.
+        """
+        char_count = 0
+        if isinstance(message.content, str):
+            char_count = len(message.content)
+        else:
+            for content in message.content:
+                if isinstance(content, str):
+                    char_count += len(content)
+                elif isinstance(content, dict):
+                    char_count += self._count_json_tokens(content)
+        if isinstance(message, LangchainAIMessage) and message.tool_calls:
+            for tool_call in message.tool_calls:
+                char_count += len(json.dumps(tool_call, separators=(",", ":")))
+        return round(char_count / self.APPROXIMATE_TOKEN_LENGTH)
+
     def _get_conversation_window(self, messages: Sequence[T], start_id: str) -> Sequence[T]:
         """
         Get messages from the start_id onwards.
@@ -158,6 +186,22 @@ class ConversationCompactionManager(ABC):
                 return messages[idx:]
         return messages
 
+    def _get_estimated_tools_tokens(self, tools: LangchainTools) -> int:
+        """
+        Estimate token count for tools by converting them to JSON schemas.
+        """
+        if not tools:
+            return 0
+        total_chars = 0
+        for tool in tools:
+            tool_schema = convert_to_openai_tool(tool)
+            total_chars += self._count_json_tokens(tool_schema)
+        return round(total_chars / self.APPROXIMATE_TOKEN_LENGTH)
+
+    def _count_json_tokens(self, json_data: dict) -> int:
+        return len(json.dumps(json_data, separators=(",", ":")))
+
     @abstractmethod
     async def _get_token_count(
         self,
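
Taken together, the estimation path above is a characters-divided-by-4 heuristic applied to three buckets: message content (a string or a list of blocks), tool-call arguments serialized as compact JSON, and tool schemas converted to the OpenAI function format. A minimal standalone sketch of that arithmetic (the helper name and the single final rounding are illustrative; only the constant and the library calls mirror the diff):

import json

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.utils.function_calling import convert_to_openai_tool

APPROXIMATE_TOKEN_LENGTH = 4  # ~4 characters per token, as in the diff


def estimate_conversation_tokens(messages: list[BaseMessage], tools: list | None = None) -> int:
    """Hypothetical helper mirroring the char/4 heuristic from the diff."""
    chars = 0
    for message in messages:
        if isinstance(message.content, str):
            chars += len(message.content)
        else:  # list content: a mix of plain strings and dict blocks
            for part in message.content:
                chars += len(part) if isinstance(part, str) else len(json.dumps(part, separators=(",", ":")))
        # Tool calls on AI messages are serialized as compact JSON
        if isinstance(message, AIMessage) and message.tool_calls:
            chars += sum(len(json.dumps(tc, separators=(",", ":"))) for tc in message.tool_calls)
    # Tool schemas count too, after conversion to the OpenAI function format
    for tool in tools or []:
        chars += len(json.dumps(convert_to_openai_tool(tool), separators=(",", ":")))
    return round(chars / APPROXIMATE_TOKEN_LENGTH)


# "A" * 1000 estimates to 250 tokens, matching the tests below.
assert estimate_conversation_tokens([HumanMessage(content="A" * 1000)]) == 250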

View File

@@ -187,12 +187,18 @@ class AgentExecutable(BaseAgentExecutable):
         start_id = state.start_id
 
         # Summarize the conversation if it's too long.
-        if await self._window_manager.should_compact_conversation(
+        current_token_count = await self._window_manager.calculate_token_count(
             model, langchain_messages, tools=tools, thinking_config=self.THINKING_CONFIG
-        ):
+        )
+        if current_token_count > self._window_manager.CONVERSATION_WINDOW_SIZE:
             # Exclude the last message if it's the first turn.
             messages_to_summarize = langchain_messages[:-1] if self._is_first_turn(state) else langchain_messages
-            summary = await AnthropicConversationSummarizer(self._team, self._user).summarize(messages_to_summarize)
+            summary = await AnthropicConversationSummarizer(
+                self._team,
+                self._user,
+                extend_context_window=current_token_count > 195_000,
+            ).summarize(messages_to_summarize)
             summary_message = ContextMessage(
                 content=ROOT_CONVERSATION_SUMMARY_PROMPT.format(summary=summary),
                 id=str(uuid4()),
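
The control flow in this hunk hinges on computing the token count once and reusing it for two thresholds: the 100k compaction trigger, and a ~195k cutoff past which the transcript may no longer fit the default summarizer model. A hedged sketch of just that branching (the manager and summarizer objects are stand-ins, not the real classes):

CONVERSATION_WINDOW_SIZE = 100_000  # compaction threshold from the diff
EXTENDED_CONTEXT_THRESHOLD = 195_000  # headroom under Haiku's 200k context limit


async def maybe_compact(window_manager, summarizer_factory, model, messages, tools):
    # One token count drives both decisions below.
    current_token_count = await window_manager.calculate_token_count(model, messages, tools=tools)
    if current_token_count <= CONVERSATION_WINDOW_SIZE:
        return None  # conversation still fits the window; nothing to do
    # Past ~195k tokens the transcript itself could overflow the default
    # summarizer model (Haiku, 200k), so switch to the extended-context model.
    summarizer = summarizer_factory(extend_context_window=current_token_count > EXTENDED_CONTEXT_THRESHOLD)
    return await summarizer.summarize(messages)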

View File

@@ -114,11 +114,11 @@ class TestAnthropicConversationCompactionManager(BaseTest):
     @parameterized.expand(
         [
             # (num_human_messages, token_count, should_compact)
-            [1, 70000, False],  # Only 1 human message
-            [2, 70000, False],  # Only 2 human messages
-            [3, 50000, False],  # 3 human messages but under token limit
-            [3, 70000, True],  # 3 human messages and over token limit
-            [5, 70000, True],  # Many messages over limit
+            [1, 90000, False],  # Only 1 human message, under limit
+            [2, 90000, False],  # Only 2 human messages, under limit
+            [3, 80000, False],  # 3 human messages but under token limit
+            [3, 110000, True],  # 3 human messages and over token limit
+            [5, 110000, True],  # Many messages over limit
         ]
     )
     async def test_should_compact_conversation(self, num_human_messages, token_count, should_compact):
@@ -136,19 +136,57 @@ class TestAnthropicConversationCompactionManager(BaseTest):
         result = await self.window_manager.should_compact_conversation(mock_model, messages)
         self.assertEqual(result, should_compact)
 
-    def test_get_estimated_tokens_human_message(self):
+    async def test_should_compact_conversation_with_tools_under_limit(self):
+        """Test that tools are accounted for when estimating tokens with 2 or fewer human messages"""
+        from langchain_core.tools import tool
+
+        @tool
+        def test_tool(query: str) -> str:
+            """A test tool"""
+            return f"Result for {query}"
+
+        messages: list[BaseMessage] = [
+            LangchainHumanMessage(content="A" * 1000),  # ~250 tokens
+            LangchainAIMessage(content="B" * 1000),  # ~250 tokens
+        ]
+        tools = [test_tool]
+        mock_model = MagicMock()
+
+        # With 2 human messages, should use estimation and not call _get_token_count
+        result = await self.window_manager.should_compact_conversation(mock_model, messages, tools=tools)
+
+        # Total should be well under 100k limit
+        self.assertFalse(result)
+
+    async def test_should_compact_conversation_with_tools_over_limit(self):
+        """Test that tools push estimation over limit with 2 or fewer human messages"""
+        messages: list[BaseMessage] = [
+            LangchainHumanMessage(content="A" * 200000),  # ~50k tokens
+            LangchainAIMessage(content="B" * 200000),  # ~50k tokens
+        ]
+        # Create large tool schemas to push over 100k limit
+        tools = [{"type": "function", "function": {"name": f"tool_{i}", "description": "X" * 1000}} for i in range(100)]
+        mock_model = MagicMock()
+
+        result = await self.window_manager.should_compact_conversation(mock_model, messages, tools=tools)
+
+        # Should be over the 100k limit
+        self.assertTrue(result)
+
+    def test_get_estimated_assistant_message_tokens_human_message(self):
         """Test token estimation for human messages"""
         message = HumanMessage(content="A" * 100, id="1")  # 100 chars = ~25 tokens
-        tokens = self.window_manager._get_estimated_tokens(message)
+        tokens = self.window_manager._get_estimated_assistant_message_tokens(message)
         self.assertEqual(tokens, 25)
 
-    def test_get_estimated_tokens_assistant_message(self):
+    def test_get_estimated_assistant_message_tokens_assistant_message(self):
         """Test token estimation for assistant messages without tool calls"""
         message = AssistantMessage(content="A" * 100, id="1")  # 100 chars = ~25 tokens
-        tokens = self.window_manager._get_estimated_tokens(message)
+        tokens = self.window_manager._get_estimated_assistant_message_tokens(message)
         self.assertEqual(tokens, 25)
 
-    def test_get_estimated_tokens_assistant_message_with_tool_calls(self):
+    def test_get_estimated_assistant_message_tokens_assistant_message_with_tool_calls(self):
         """Test token estimation for assistant messages with tool calls"""
         message = AssistantMessage(
             content="A" * 100,  # 100 chars
@@ -162,17 +200,117 @@ class TestAnthropicConversationCompactionManager(BaseTest):
             ],
         )
 
         # Should count content + JSON serialized args
-        tokens = self.window_manager._get_estimated_tokens(message)
+        tokens = self.window_manager._get_estimated_assistant_message_tokens(message)
         # 100 chars content + ~15 chars for args = ~29 tokens
         self.assertGreater(tokens, 25)
         self.assertLess(tokens, 35)
 
-    def test_get_estimated_tokens_tool_call_message(self):
+    def test_get_estimated_assistant_message_tokens_tool_call_message(self):
         """Test token estimation for tool call messages"""
         message = AssistantToolCallMessage(content="A" * 200, id="1", tool_call_id="t1")
-        tokens = self.window_manager._get_estimated_tokens(message)
+        tokens = self.window_manager._get_estimated_assistant_message_tokens(message)
         self.assertEqual(tokens, 50)
 
+    def test_get_estimated_langchain_message_tokens_string_content(self):
+        """Test token estimation for langchain messages with string content"""
+        message = LangchainHumanMessage(content="A" * 100)
+        tokens = self.window_manager._get_estimated_langchain_message_tokens(message)
+        self.assertEqual(tokens, 25)
+
+    def test_get_estimated_langchain_message_tokens_list_content_with_strings(self):
+        """Test token estimation for langchain messages with list of string content"""
+        message = LangchainHumanMessage(content=["A" * 100, "B" * 100])
+        tokens = self.window_manager._get_estimated_langchain_message_tokens(message)
+        self.assertEqual(tokens, 50)
+
+    def test_get_estimated_langchain_message_tokens_list_content_with_dicts(self):
+        """Test token estimation for langchain messages with dict content"""
+        message = LangchainHumanMessage(content=[{"type": "text", "text": "A" * 100}])
+        tokens = self.window_manager._get_estimated_langchain_message_tokens(message)
+        # 100 chars for text + overhead for JSON structure
+        self.assertGreater(tokens, 25)
+        self.assertLess(tokens, 40)
+
+    def test_get_estimated_langchain_message_tokens_ai_message_with_tool_calls(self):
+        """Test token estimation for AI messages with tool calls"""
+        message = LangchainAIMessage(
+            content="A" * 100,
+            tool_calls=[
+                {"id": "t1", "name": "test_tool", "args": {"key": "value"}},
+                {"id": "t2", "name": "another_tool", "args": {"foo": "bar"}},
+            ],
+        )
+        tokens = self.window_manager._get_estimated_langchain_message_tokens(message)
+        # Content + tool calls JSON
+        self.assertGreater(tokens, 25)
+        self.assertLess(tokens, 70)
+
+    def test_get_estimated_langchain_message_tokens_ai_message_without_tool_calls(self):
+        """Test token estimation for AI messages without tool calls"""
+        message = LangchainAIMessage(content="A" * 100)
+        tokens = self.window_manager._get_estimated_langchain_message_tokens(message)
+        self.assertEqual(tokens, 25)
+
+    def test_count_json_tokens(self):
+        """Test JSON token counting helper"""
+        json_data = {"key": "value", "nested": {"foo": "bar"}}
+        char_count = self.window_manager._count_json_tokens(json_data)
+        # Should match length of compact JSON
+        import json
+
+        expected = len(json.dumps(json_data, separators=(",", ":")))
+        self.assertEqual(char_count, expected)
+
+    def test_get_estimated_tools_tokens_empty(self):
+        """Test tool token estimation with no tools"""
+        tokens = self.window_manager._get_estimated_tools_tokens([])
+        self.assertEqual(tokens, 0)
+
+    def test_get_estimated_tools_tokens_with_dict_tools(self):
+        """Test tool token estimation with dict tools"""
+        tools = [
+            {"type": "function", "function": {"name": "test_tool", "description": "A test tool"}},
+        ]
+        tokens = self.window_manager._get_estimated_tools_tokens(tools)
+        # Should be positive and reasonable
+        self.assertGreater(tokens, 0)
+        self.assertLess(tokens, 100)
+
+    def test_get_estimated_tools_tokens_with_base_tool(self):
+        """Test tool token estimation with BaseTool"""
+        from langchain_core.tools import tool
+
+        @tool
+        def sample_tool(query: str) -> str:
+            """A sample tool for testing"""
+            return f"Result for {query}"
+
+        tools = [sample_tool]
+        tokens = self.window_manager._get_estimated_tools_tokens(tools)
+        # Should count the tool schema
+        self.assertGreater(tokens, 0)
+        self.assertLess(tokens, 200)
+
+    def test_get_estimated_tools_tokens_multiple_tools(self):
+        """Test tool token estimation with multiple tools"""
+        from langchain_core.tools import tool
+
+        @tool
+        def tool1(x: int) -> int:
+            """First tool"""
+            return x * 2
+
+        @tool
+        def tool2(y: str) -> str:
+            """Second tool"""
+            return y.upper()
+
+        tools = [tool1, tool2]
+        tokens = self.window_manager._get_estimated_tools_tokens(tools)
+        # Should count both tool schemas
+        self.assertGreater(tokens, 0)
+        self.assertLess(tokens, 400)
+
     async def test_get_token_count_calls_model(self):
         """Test that _get_token_count properly calls the model's token counting"""
         mock_model = MagicMock()
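
As a sanity check on the over-limit test above: the two 200,000-character messages alone estimate to roughly the 100k boundary, so it is the 100 tool schemas that push the strict > comparison over. A back-of-the-envelope version (the ~30-character JSON envelope per schema is an assumption):

# Two 200,000-char messages: 400,000 chars / 4 = 100,000 tokens, right at the limit.
message_tokens = (200_000 + 200_000) // 4
# 100 schemas, each ~1,000 chars of description plus a small JSON envelope:
tool_tokens = 100 * (1_000 + 30) // 4  # ~25,750 tokens
assert message_tokens + tool_tokens > 100_000  # strictly over, so compaction triggers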

View File

@@ -520,13 +520,12 @@ class TestAgentNode(ClickhouseTestMixin, BaseTest):
         self.assertEqual(await node._get_billing_prompt(), expected_prompt)
 
     @patch("ee.hogai.graph.agent_modes.nodes.AgentExecutable._get_model", return_value=FakeChatOpenAI(responses=[]))
-    @patch(
-        "ee.hogai.graph.agent_modes.compaction_manager.AnthropicConversationCompactionManager.should_compact_conversation"
-    )
+    @patch("ee.hogai.graph.agent_modes.compaction_manager.AnthropicConversationCompactionManager.calculate_token_count")
     @patch("ee.hogai.graph.conversation_summarizer.nodes.AnthropicConversationSummarizer.summarize")
-    async def test_conversation_summarization_flow(self, mock_summarize, mock_should_compact, mock_model):
+    async def test_conversation_summarization_flow(self, mock_summarize, mock_calculate_tokens, mock_model):
         """Test that conversation is summarized when it gets too long"""
-        mock_should_compact.return_value = True
+        # Return a token count higher than CONVERSATION_WINDOW_SIZE (100,000)
+        mock_calculate_tokens.return_value = 150_000
         mock_summarize.return_value = "This is a summary of the conversation so far."
         mock_model_instance = FakeChatOpenAI(responses=[LangchainAIMessage(content="Response after summary")])
@@ -554,13 +553,12 @@ class TestAgentNode(ClickhouseTestMixin, BaseTest):
         self.assertIn("This is a summary of the conversation so far.", context_messages[0].content)
 
     @patch("ee.hogai.graph.agent_modes.nodes.AgentExecutable._get_model", return_value=FakeChatOpenAI(responses=[]))
-    @patch(
-        "ee.hogai.graph.agent_modes.compaction_manager.AnthropicConversationCompactionManager.should_compact_conversation"
-    )
+    @patch("ee.hogai.graph.agent_modes.compaction_manager.AnthropicConversationCompactionManager.calculate_token_count")
     @patch("ee.hogai.graph.conversation_summarizer.nodes.AnthropicConversationSummarizer.summarize")
-    async def test_conversation_summarization_on_first_turn(self, mock_summarize, mock_should_compact, mock_model):
+    async def test_conversation_summarization_on_first_turn(self, mock_summarize, mock_calculate_tokens, mock_model):
         """Test that on first turn, the last message is excluded from summarization"""
-        mock_should_compact.return_value = True
+        # Return a token count higher than CONVERSATION_WINDOW_SIZE (100,000)
+        mock_calculate_tokens.return_value = 150_000
         mock_summarize.return_value = "Summary without last message"
         mock_model_instance = FakeChatOpenAI(responses=[LangchainAIMessage(content="Response")])

View File

@@ -55,9 +55,14 @@ class ConversationSummarizer:
 class AnthropicConversationSummarizer(ConversationSummarizer):
+    def __init__(self, team: Team, user: User, extend_context_window: bool | None = False):
+        super().__init__(team, user)
+        self._extend_context_window = extend_context_window
+
     def _get_model(self):
+        # Haiku has 200k token limit. Sonnet has 1M token limit.
         return MaxChatAnthropic(
-            model="claude-haiku-4-5",
+            model="claude-sonnet-4-5" if self._extend_context_window else "claude-haiku-4-5",
             streaming=False,
             stream_usage=False,
             max_tokens=8192,
@@ -65,6 +70,7 @@ class AnthropicConversationSummarizer(ConversationSummarizer):
             user=self._user,
             team=self._team,
             billable=True,
+            betas=["context-1m-2025-08-07"] if self._extend_context_window else None,
         )
 
     def _construct_messages(self, messages: Sequence[BaseMessage]):
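
Distilling the summarizer change: the single extend_context_window flag flips both the model and the Anthropic beta header together. A minimal sketch of that selection logic in isolation (the helper is hypothetical; the model names, beta string, and max_tokens come from the diff):

def summarizer_model_config(extend_context_window: bool) -> dict:
    # Haiku has a 200k-token context; Sonnet with the 1M-context beta covers longer transcripts.
    return {
        "model": "claude-sonnet-4-5" if extend_context_window else "claude-haiku-4-5",
        "betas": ["context-1m-2025-08-07"] if extend_context_window else None,
        "max_tokens": 8192,
        "streaming": False,
        "stream_usage": False,
    }


assert summarizer_model_config(True)["model"] == "claude-sonnet-4-5"
assert summarizer_model_config(False)["betas"] is None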