Files
posthog/ee/hogai/graph/conversation_summarizer/nodes.py
2025-11-13 16:44:28 +00:00

88 lines
3.1 KiB
Python

import re
from abc import ABC, abstractmethod
from collections.abc import Sequence

from langchain_core.messages import BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

from posthog.models import Team, User

from ee.hogai.graph.conversation_summarizer.prompts import SYSTEM_PROMPT, USER_PROMPT
from ee.hogai.llm import MaxChatAnthropic
class ConversationSummarizer(ABC):
    """Summarizes a conversation's messages into a single text summary.

    Subclasses must implement `_get_model` to supply the chat model.
    Fix: inherit from `ABC` so that `@abstractmethod` is actually enforced —
    without the `ABCMeta` metaclass the decorator is inert, the base class is
    instantiable, and `_get_model` would silently return `None`.
    """

    def __init__(self, team: Team, user: User):
        self._user = user
        self._team = team

    async def summarize(self, messages: Sequence[BaseMessage]) -> str:
        """Run the summarization chain over `messages` and return the summary text."""
        prompt = self._construct_messages(messages)
        model = self._get_model()
        chain = prompt | model | StrOutputParser() | self._parse_xml_tags
        response: str = await chain.ainvoke({})  # Do not pass config here, so the node doesn't stream
        return response

    @abstractmethod
    def _get_model(self):
        """Return the chat model used for summarization. Implemented by subclasses."""
        ...

    def _construct_messages(self, messages: Sequence[BaseMessage]):
        # Prompt order: system instructions, then the conversation itself,
        # then the user-facing summarization request.
        return (
            ChatPromptTemplate.from_messages([("system", SYSTEM_PROMPT)])
            + messages
            + ChatPromptTemplate.from_messages([("user", USER_PROMPT)])
        )

    def _parse_xml_tags(self, message: str) -> str:
        """
        Extract the <summary> tag content from a message.

        Args:
            message: The message content to parse

        Returns:
            Summary (falls back to original message if not present)
        """
        # Case-insensitive, multiline-tolerant match of the first <summary> block.
        summary_match = re.search(r"<summary>(.*?)</summary>", message, re.DOTALL | re.IGNORECASE)
        if summary_match:
            return summary_match.group(1).strip()
        # No tag found: return the model output unchanged.
        return message
class AnthropicConversationSummarizer(ConversationSummarizer):
    """Conversation summarizer backed by Anthropic models via `MaxChatAnthropic`."""

    def __init__(self, team: Team, user: User, extend_context_window: bool | None = False):
        super().__init__(team, user)
        # When truthy, switch to the larger-context Sonnet model (see _get_model).
        self._extend_context_window = extend_context_window

    def _get_model(self):
        # Haiku has 200k token limit. Sonnet has 1M token limit.
        big_context = self._extend_context_window
        return MaxChatAnthropic(
            model="claude-sonnet-4-5" if big_context else "claude-haiku-4-5",
            streaming=False,
            stream_usage=False,
            max_tokens=8192,
            disable_streaming=True,
            user=self._user,
            team=self._team,
            billable=True,
            betas=["context-1m-2025-08-07"] if big_context else None,
        )

    def _construct_messages(self, messages: Sequence[BaseMessage]):
        """Removes cache_control headers."""
        cleaned: list[BaseMessage] = []
        for original in messages:
            # Only list-shaped content can carry cache_control entries;
            # anything else passes through untouched.
            if not isinstance(original.content, list):
                cleaned.append(original)
                continue
            # Deep-copy before mutating so callers' messages are not modified.
            stripped = original.model_copy(deep=True)
            for part in stripped.content:
                if isinstance(part, dict):
                    part.pop("cache_control", None)
            cleaned.append(stripped)
        return super()._construct_messages(cleaned)