feat(llma): Generate and embed LLM traces issue-search summaries (#40364)
Co-authored-by: Andrew Maguire <andrewm4894@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
0
ee/hogai/llm_traces_summaries/__init__.py
Normal file
10
ee/hogai/llm_traces_summaries/constants.py
Normal file
@@ -0,0 +1,10 @@
# Summaries
LLM_TRACES_SUMMARIES_MODEL_TO_SUMMARIZE_STRINGIFIED_TRACES = "gemini-2.5-flash-lite-preview-09-2025"

# Embeddings
DOCUMENT_EMBEDDINGS_TOPIC = "document_embeddings_input"
LLM_TRACES_SUMMARIES_PRODUCT = "llm-analytics"
LLM_TRACES_SUMMARIES_DOCUMENT_TYPE = "llm-trace-summary"
LLM_TRACES_SUMMARIES_SEARCH_QUERY_DOCUMENT_TYPE = "trace-summary-search-query"
LLM_TRACES_SUMMARIES_SEARCH_QUERY_POLL_INTERVAL_SECONDS = 3
LLM_TRACES_SUMMARIES_SEARCH_QUERY_MAX_ATTEMPTS = 10
100
ee/hogai/llm_traces_summaries/summarize_traces.py
Normal file
@@ -0,0 +1,100 @@
import structlog

from posthog.schema import DateRange

from posthog.models.team.team import Team
from posthog.sync import database_sync_to_async

from ee.hogai.llm_traces_summaries.tools.embed_summaries import LLMTracesSummarizerEmbedder
from ee.hogai.llm_traces_summaries.tools.find_similar_traces import LLMTracesSummarizerFinder
from ee.hogai.llm_traces_summaries.tools.generate_stringified_summaries import LLMTraceSummarizerGenerator
from ee.hogai.llm_traces_summaries.tools.get_traces import LLMTracesSummarizerCollector
from ee.hogai.llm_traces_summaries.tools.stringify_trace import LLMTracesSummarizerStringifier
from ee.hogai.llm_traces_summaries.utils.load_from_csv import load_traces_from_csv_files
from ee.models.llm_traces_summaries import LLMTraceSummary

logger = structlog.get_logger(__name__)


class LLMTracesSummarizer:
    def __init__(self, team: Team):
        self._team = team

    async def summarize_traces_for_date_range(self, date_range: DateRange) -> None:
        """Get, stringify, summarize, embed and store summaries for all traces in the date range."""
        stringified_traces = await self._collect_and_stringify_traces_for_date_range(date_range=date_range)
        # Summarize stringified traces
        await self._summarize_stringified_traces(stringified_traces=stringified_traces)
        # Returns nothing if everything succeeded
        return None

    async def _collect_and_stringify_traces_for_date_range(self, date_range: DateRange) -> dict[str, str]:
        collector = LLMTracesSummarizerCollector(team=self._team)
        # Collect and stringify traces in-memory
        stringifier = LLMTracesSummarizerStringifier(team=self._team)
        stringified_traces: dict[str, str] = {}  # trace_id -> stringified trace
        offset = 0

        # Iterate to collect and stringify all traces in the date range
        while True:
            # Process in chunks to avoid storing all the heavy traces in memory at once (stringified ones are far lighter)
            response = await database_sync_to_async(collector.get_db_traces_per_page)(
                offset=offset, date_range=date_range
            )
            results = response.results
            offset += len(results)
            if len(results) == 0:
                break
            stringified_traces_chunk = stringifier.stringify_traces(traces_chunk=results)
            stringified_traces.update(stringified_traces_chunk)
            if response.hasMore is not True:
                break
        return stringified_traces

    async def summarize_traces_from_csv_files(self, csv_paths: list[str]) -> None:
        """Collect and stringify traces from CSV files and summarize them; useful for local development."""
        stringified_traces = self._collect_and_stringify_traces_from_csv_files(csv_paths=csv_paths)
        # Summarize stringified traces
        await self._summarize_stringified_traces(stringified_traces=stringified_traces)
        return None

    def _collect_and_stringify_traces_from_csv_files(self, csv_paths: list[str]) -> dict[str, str]:
        # Collect and stringify traces in-memory
        stringifier = LLMTracesSummarizerStringifier(team=self._team)
        stringified_traces: dict[str, str] = {}  # trace_id -> stringified trace
        for trace in load_traces_from_csv_files(csv_paths=csv_paths):
            stringified_trace = stringifier.stringify_traces(traces_chunk=[trace])
            stringified_traces.update(stringified_trace)
        return stringified_traces

    async def _summarize_stringified_traces(self, stringified_traces: dict[str, str]) -> None:
        # Summarize stringified traces
        summary_generator = LLMTraceSummarizerGenerator(team=self._team)
        summarized_traces = await summary_generator.summarize_stringified_traces(stringified_traces=stringified_traces)
        # Store summaries in the database
        await database_sync_to_async(summary_generator.store_summaries_in_db)(summarized_traces=summarized_traces)
        # Embed summaries
        embedder = LLMTracesSummarizerEmbedder(team=self._team)
        embedder.embed_summaries(
            summarized_traces=summarized_traces, summary_type=LLMTraceSummary.LLMTraceSummaryType.ISSUES_SEARCH
        )
        # Returns nothing if everything succeeded
        return None

    def find_top_similar_traces_for_query(
        self,
        query: str,
        request_id: str,
        top: int,
        date_range: DateRange,
        summary_type: LLMTraceSummary.LLMTraceSummaryType,
    ):
        """Search all summarized traces within the date range for the query and return the top similar traces."""
        finder = LLMTracesSummarizerFinder(team=self._team)
        return finder.find_top_similar_traces_for_query(
            query=query,
            request_id=request_id,
            top=top,
            date_range=date_range,
            summary_type=summary_type,
        )
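A minimal usage sketch of the orchestrator above (not part of the commit): it assumes a Django-initialized environment and an existing Team instance, and that `DateRange` accepts relative date strings.

# Hypothetical usage sketch: summarize the last day's traces for a team.
import asyncio

from posthog.schema import DateRange

async def run_daily_summaries(team):
    summarizer = LLMTracesSummarizer(team=team)
    await summarizer.summarize_traces_for_date_range(date_range=DateRange(date_from="-1d", date_to=None))

# asyncio.run(run_daily_summaries(team))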
@@ -0,0 +1,387 @@
import math
import uuid
from dataclasses import dataclass

import numpy as np
import structlog
from sklearn.cluster import KMeans, MiniBatchKMeans
from tqdm.asyncio import tqdm

logger = structlog.get_logger(__name__)


def cosine_similarity_func(a: list[float], b: list[float]) -> float:
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))


@dataclass(frozen=True)
class RelevantTracesGroup:
    traces: list[str]
    avg_similarity: float | None


@dataclass(frozen=True)
class TracesCluster:
    traces: list[str]
    embeddings: list[list[float]]


# How many times at most to re-group singles to increase the group count
EMBEDDINGS_CLUSTERING_MAX_RECURSION: int = 3
# How many additional recursions are allowed if the tail (loose traces) is too large
EMBEDDINGS_CLUSTERING_MAX_TAIL_RECURSION: int = 3
# If the tail is larger than this, try to cluster once more with a looser approach
EMBEDDINGS_CLUSTERING_MAX_TAIL_PERCENTAGE: float = 0.50
# Split embeddings into chunks to speed up clustering
EMBEDDINGS_CLUSTERING_CHUNK_SIZE: int = 1000  # Increased from the default 25
# Expected average similarity between embeddings to group them
EMBEDDINGS_COSINE_SIMILARITY: float = 0.72  # Lowered from the default 0.95
# How many grouping iterations to run before stopping
EMBEDDINGS_CLUSTERING_ITERATIONS: int = 5
# How many grouping iterations to run when trying to shrink the tail (too many loose traces)
EMBEDDINGS_CLUSTERING_MAX_TAIL_ITERATIONS: int = 1
# Expected minimal number of traces per group when grouping embeddings
EXPECTED_SUGGESTIONS_PER_EMBEDDINGS_GROUP: int = 25  # Increased from the default 5
# Max traces per group, to avoid large loosely-related groups
MAX_SUGGESTIONS_PER_EMBEDDINGS_GROUP: int = 100
# By how much to decrease the similarity required to group embeddings with each iteration,
# to increase the number of groups and improve the user experience
EMBEDDINGS_COSINE_SIMILARITY_DECREASE: float = 0.01
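# To make the iteration arithmetic concrete: with the defaults above, the loosest
# threshold used for the oversized-tail retry (see `_clusterize_traces` below) is
# EMBEDDINGS_COSINE_SIMILARITY - (EMBEDDINGS_CLUSTERING_ITERATIONS - 1) * EMBEDDINGS_COSINE_SIMILARITY_DECREASE
# = 0.72 - (5 - 1) * 0.01 = 0.68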
# Clustering benchmark results (all runs on 2025-10-19, chunk size 1000, cosine similarity 0.72):
#
# Run                       Traces/group  Groups  Singles  Avg cosine similarity
# 1000 items                15            70      506      0.7435714285714285
# 1000 items                25            32      598      0.720625
# 1000 items                50            7       907      0.7099999999999999
# 6636 items (1000 chunks)  15            582     1088     0.7607044673539523
# 6636 items (1000 chunks)  25            329     1960     0.7536778115501521
class KmeansClusterizer:
    @classmethod
    def clusterize_traces(
        cls,
        embedded_traces: list[str],
        embeddings: list[list[float]],
        max_tail_size: int,
        pre_combined_groups: dict[str, RelevantTracesGroup] | None = None,
        iteration: int = 0,
    ) -> tuple[dict[str, RelevantTracesGroup], TracesCluster]:
        """
        Wrapper for clusterizing traces, to allow tracking stats
        for clusterization iterations only once, on the final iteration.
        """
        # Assuming the input is sorted alphabetically, in the hope of improving grouping quality
        return cls._clusterize_traces(
            embedded_traces=embedded_traces,
            embeddings=embeddings,
            max_tail_size=max_tail_size,
            pre_combined_groups=pre_combined_groups,
            iteration=iteration,
        )

    @classmethod
    def _clusterize_traces(
        cls,
        embedded_traces: list[str],
        embeddings: list[list[float]],
        max_tail_size: int,
        pre_combined_groups: dict[str, RelevantTracesGroup] | None,
        iteration: int,
        cosine_similarity: float = EMBEDDINGS_COSINE_SIMILARITY,
        clustering_iterations: int = EMBEDDINGS_CLUSTERING_ITERATIONS,
    ) -> tuple[dict[str, RelevantTracesGroup], TracesCluster]:
        groups, singles = cls._clusterize_traces_iteration(
            embedded_traces=embedded_traces,
            embeddings=embeddings,
            iteration=iteration,
            cosine_similarity=cosine_similarity,
            clustering_iterations=clustering_iterations,
        )
        combined_groups: dict[str, RelevantTracesGroup] = {}
        # If pre-combined groups are provided, add them to the combined groups
        if pre_combined_groups:
            groups = [pre_combined_groups, *groups]
        for group_set in groups:
            combined_groups = {**combined_groups, **group_set}
        # Combine the singles in the expected format
        combined_singles = TracesCluster(traces=[], embeddings=[])
        for single in singles:
            combined_singles.traces.extend(single.traces)
            combined_singles.embeddings.extend(single.embeddings)
        # If there are still iterations left, iterate again
        if iteration < EMBEDDINGS_CLUSTERING_MAX_RECURSION:
            return cls._clusterize_traces(
                embedded_traces=combined_singles.traces,
                embeddings=combined_singles.embeddings,
                max_tail_size=max_tail_size,
                pre_combined_groups=combined_groups,
                iteration=iteration + 1,
            )
        # If the iterations are exhausted and the tail is acceptable, return the results
        if len(combined_singles.traces) <= max_tail_size:
            return combined_groups, combined_singles
        # If the tail is still too large, but no max-tail recursions are left, return the results anyway
        if iteration >= (EMBEDDINGS_CLUSTERING_MAX_RECURSION + EMBEDDINGS_CLUSTERING_MAX_TAIL_RECURSION):
            return combined_groups, combined_singles
        # If the tail is still too large and there are max-tail recursions left,
        # iterate again with the lowest allowed average similarity
        max_tail_cosine_similarity = round(
            (
                # Calculate the lowest allowed similarity
                EMBEDDINGS_COSINE_SIMILARITY
                - (
                    # The first iteration doesn't count (i-0), so decrease the similarity
                    (EMBEDDINGS_CLUSTERING_ITERATIONS - 1) * EMBEDDINGS_COSINE_SIMILARITY_DECREASE
                )
            ),
            2,
        )
        return cls._clusterize_traces(
            embedded_traces=combined_singles.traces,
            embeddings=combined_singles.embeddings,
            max_tail_size=max_tail_size,
            pre_combined_groups=combined_groups,
            iteration=iteration + 1,
            cosine_similarity=max_tail_cosine_similarity,
            clustering_iterations=EMBEDDINGS_CLUSTERING_MAX_TAIL_ITERATIONS,
        )

    @classmethod
    def _clusterize_traces_iteration(
        cls,
        embedded_traces: list[str],
        embeddings: list[list[float]],
        iteration: int,
        cosine_similarity: float,
        clustering_iterations: int,
    ) -> tuple[list[dict[str, RelevantTracesGroup]], list[TracesCluster]]:
        # Split traces into large chunks, then search for groups within each chunk
        n_clusters = math.ceil(len(embeddings) / EMBEDDINGS_CLUSTERING_CHUNK_SIZE)
        if n_clusters == 1:
            # If it's a single cluster, create it manually
            init_embeddings_clusters = {"single_cluster": TracesCluster(traces=embedded_traces, embeddings=embeddings)}
        else:
            init_embeddings_clusters = cls._calculate_embeddings_clusters(
                embedded_traces=embedded_traces,
                embeddings=embeddings,
                n_clusters=n_clusters,
                minibatch=True,
            )
        return cls._group_multiple_embeddings_clusters(
            init_embeddings_clusters=init_embeddings_clusters,
            iteration=iteration,
            cosine_similarity=cosine_similarity,
            clustering_iterations=clustering_iterations,
        )

    @classmethod
    def _group_multiple_embeddings_clusters(
        cls,
        init_embeddings_clusters: dict[str, TracesCluster],
        iteration: int,
        cosine_similarity: float,
        clustering_iterations: int,
    ) -> tuple[list[dict[str, RelevantTracesGroup]], list[TracesCluster]]:
        groups = []
        singles = []
        # Find groups of traces in each cluster, one by one
        for _i, (_, cluster) in enumerate(
            tqdm(
                init_embeddings_clusters.items(),
                desc=f"Grouping embeddings clusters (iteration: {iteration})",
            )
        ):
            cluster_groups, cluster_singles = cls._group_embeddings_cluster(
                embedded_traces=cluster.traces,
                embeddings=cluster.embeddings,
                cosine_similarity=cosine_similarity,
                clustering_iterations=clustering_iterations,
            )
            groups.append(cluster_groups)
            singles.append(cluster_singles)
        return groups, singles

    @classmethod
    def _group_embeddings_cluster(
        cls,
        embedded_traces: list[str],
        embeddings: list[list[float]],
        cosine_similarity: float,
        clustering_iterations: int,
    ) -> tuple[dict[str, RelevantTracesGroup], TracesCluster]:
        # Define result variables to update with each iteration
        result_relevant_groups: dict[str, RelevantTracesGroup] = {}
        result_singles: TracesCluster = TracesCluster(traces=[], embeddings=[])
        traces_input, embeddings_input = embedded_traces, embeddings
        # An expected average of traces per group
        embeddings_per_group = EXPECTED_SUGGESTIONS_PER_EMBEDDINGS_GROUP
        # Cap the number of clusterization rounds (to keep the loop from running forever).
        # Decrease the required similarity (- quality) and decrease the cluster size (+ quality) with each iteration
        for similarity_iteration in range(clustering_iterations):
            n_clusters = math.ceil(len(traces_input) / embeddings_per_group)
            # Decrease the similarity required to group embeddings with each iteration,
            # to allow more ideas to be grouped and improve the user experience
            avg_similarity_threshold = round(
                cosine_similarity - (EMBEDDINGS_COSINE_SIMILARITY_DECREASE * similarity_iteration),
                2,
            )
            (
                relevant_groups,
                result_singles,
            ) = cls._group_embeddings_cluster_iteration(
                embedded_traces=traces_input,
                embeddings=embeddings_input,
                n_clusters=n_clusters,
                avg_similarity_threshold=avg_similarity_threshold,
            )
            # Save successfully grouped traces
            result_relevant_groups = {**result_relevant_groups, **relevant_groups}
            # If no singles are left, there's nothing to group again, so return the results
            if not result_singles.traces:
                return result_relevant_groups, result_singles
            # If singles are left, but fewer than a single group, don't group them again
            if len(result_singles.traces) < embeddings_per_group:
                return result_relevant_groups, result_singles
            # If enough singles are left, try to clusterize them again
            traces_input, embeddings_input = (
                result_singles.traces,
                result_singles.embeddings,
            )
        # Return the final results
        return result_relevant_groups, result_singles

    @staticmethod
    def _calculate_embeddings_clusters(
        embedded_traces: list[str],
        embeddings: list[list[float]],
        n_clusters: int,
        minibatch: bool,
    ) -> dict[str, TracesCluster]:
        matrix = np.vstack(embeddings)
        if not minibatch:
            kmeans = KMeans(n_clusters=n_clusters, init="k-means++", n_init=10, random_state=42)
        else:
            kmeans = MiniBatchKMeans(n_clusters=n_clusters, init="k-means++", n_init=10, random_state=42)
        kmeans.fit_predict(matrix)
        labels: list[int] = kmeans.labels_
        # Organize clustered traces
        grouped_traces: dict[str, TracesCluster] = {}
        # Generate a unique label for each clustering calculation
        unique_label = str(uuid.uuid4())
        for trace, label, emb in zip(embedded_traces, labels, embeddings):
            formatted_label = f"{label}_{unique_label}"
            if formatted_label not in grouped_traces:
                grouped_traces[formatted_label] = TracesCluster(traces=[], embeddings=[])
            grouped_traces[formatted_label].traces.append(trace)
            grouped_traces[formatted_label].embeddings.append(emb)
        return grouped_traces

    @classmethod
    def _group_embeddings_cluster_iteration(
        cls,
        embedded_traces: list[str],
        embeddings: list[list[float]],
        n_clusters: int,
        avg_similarity_threshold: float | None,
    ) -> tuple[dict[str, RelevantTracesGroup], TracesCluster]:
        embeddings_clusters = cls._calculate_embeddings_clusters(
            embedded_traces=embedded_traces,
            embeddings=embeddings,
            n_clusters=n_clusters,
            minibatch=False,
        )
        # Split into relevant groups and singles
        relevant_groups: dict[str, RelevantTracesGroup] = {}
        singles = TracesCluster(traces=[], embeddings=[])
        for group_label, cluster in embeddings_clusters.items():
            if len(cluster.traces) <= 1:
                # Groups with a single idea move to singles automatically
                singles.traces.extend(cluster.traces)
                singles.embeddings.extend(cluster.embeddings)
                continue
            # If no avg similarity threshold was provided (initial chunking), don't filter groups
            if not avg_similarity_threshold:
                # Keep the proper-sized groups that are close to each other
                relevant_groups[group_label] = RelevantTracesGroup(
                    traces=cluster.traces,
                    avg_similarity=None,
                )
                continue
            # Calculate the average similarity between all traces in the group
            similarities = []
            for trace, emb in zip(cluster.traces, cluster.embeddings):
                for l_trace, l_emb in zip(cluster.traces, cluster.embeddings):
                    if trace != l_trace:
                        similarities.append(cosine_similarity_func(emb, l_emb))
            if not similarities:
                # TODO: Add proper logging
                continue
            # Round to 2 digits after the dot
            avg_similarity = round(sum(similarities) / len(similarities), 2)
            # Groups that aren't close enough move to singles
            if avg_similarity < avg_similarity_threshold:
                singles.traces.extend(cluster.traces)
                singles.embeddings.extend(cluster.embeddings)
                continue
            # Avoid having large loosely-connected groups
            if len(cluster.traces) > MAX_SUGGESTIONS_PER_EMBEDDINGS_GROUP:
                singles.traces.extend(cluster.traces)
                singles.embeddings.extend(cluster.embeddings)
                continue
            # Keep the proper-sized groups that are close to each other
            relevant_groups[group_label] = RelevantTracesGroup(
                # Don't save embeddings, as the group is already relevant,
                # so it won't go through clusterization again
                traces=cluster.traces,
                avg_similarity=avg_similarity,
            )
        return relevant_groups, singles
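A hedged usage sketch of the clusterizer (not part of the commit): toy 2-D vectors stand in for real embeddings, which come from the document-embeddings pipeline and have thousands of dimensions. With so few inputs everything may end up in the tail; real runs use thousands of traces.

# Hypothetical usage sketch: cluster a few toy "embeddings".
toy_traces = ["trace-a", "trace-b", "trace-c", "trace-d"]
toy_embeddings = [[1.0, 0.0], [0.99, 0.05], [0.0, 1.0], [0.02, 0.98]]
groups, tail = KmeansClusterizer.clusterize_traces(
    embedded_traces=toy_traces,
    embeddings=toy_embeddings,
    max_tail_size=2,  # Accept up to 2 loose traces
)
# `groups` maps generated labels to RelevantTracesGroup objects;
# `tail` is a TracesCluster holding the traces left ungrouped.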
@@ -0,0 +1,129 @@
import os
import json
from dataclasses import dataclass
from typing import Any

import structlog
from google import genai
from google.genai.types import GenerateContentConfig

logger = structlog.get_logger(__name__)

CLUSTER_NAME_PROMPT = """
- Analyze this list of chat conversation summary groups describing users' issues
- Summaries were combined into these groups by similarity, so they share a specific topic
- Generate a name for this group that describes the topic its summaries share
- The group name should be concise, up to 10 words
- IMPORTANT: focus on a specific issue, product, or feature the user had problems with
- Start every meaningful word in the title with a capital letter
- Return the name of the group as plain text, without any comments or explanations

```
{summaries}
```
"""


@dataclass(frozen=True)
class ClusterizedSuggestion:
    summary: str
    trace_id: str


@dataclass(frozen=True)
class ClusterizedSuggestionsGroup:
    suggestions: list[ClusterizedSuggestion]
    avg_similarity: float
    cluster_label: str


@dataclass(frozen=True)
class ExplainedClusterizedSuggestionsGroup(ClusterizedSuggestionsGroup):
    name: str


class ClusterExplainer:
    def __init__(self, model_id: str, groups_raw: dict[str, Any], summaries_to_trace_ids_mapping: dict[str, str]):
        self._groups_raw = groups_raw
        self._summaries_to_trace_ids_mapping = summaries_to_trace_ids_mapping
        self.model_id = model_id
        self.client = self._prepare_client()

    def explain_clusters(self) -> dict[str, ExplainedClusterizedSuggestionsGroup]:
        enriched_clusters: dict[str, ClusterizedSuggestionsGroup] = {}
        for cluster_label, cluster_raw in self._groups_raw.items():
            enriched_clusters[cluster_label] = self._enrich_cluster_with_trace_ids(
                cluster_raw=cluster_raw, cluster_label=cluster_label
            )
        named_clusters = self._name_clusters(enriched_clusters)
        # Sort clusters to show the best ones first
        sorted_named_clusters = self.sort_named_clusters(named_clusters)
        return sorted_named_clusters

    @staticmethod
    def _prepare_client() -> genai.Client:
        api_key = os.getenv("GEMINI_API_KEY")
        return genai.Client(api_key=api_key)

    def _name_clusters(
        self, enriched_clusters: dict[str, ClusterizedSuggestionsGroup]
    ) -> dict[str, ExplainedClusterizedSuggestionsGroup]:
        named_clusters: dict[str, ExplainedClusterizedSuggestionsGroup] = {}
        tasks = {}
        for label, cluster in enriched_clusters.items():
            tasks[label] = self._generate_cluster_name(
                # Provide the first 5 summaries; should be enough context for the name generation
                summaries=[x.summary for x in cluster.suggestions][:5]
            )
        for label, result in tasks.items():
            current_cluster = enriched_clusters[label]
            named_clusters[label] = ExplainedClusterizedSuggestionsGroup(
                suggestions=current_cluster.suggestions,
                avg_similarity=current_cluster.avg_similarity,
                cluster_label=current_cluster.cluster_label,
                name=result,
            )
        return named_clusters

    def _generate_cluster_name(self, summaries: list[str]) -> str:
        message = CLUSTER_NAME_PROMPT.format(summaries=json.dumps(summaries))
        config_kwargs = {"temperature": 0}  # No system prompt, to save tokens; the task prompt should be good enough
        response = self.client.models.generate_content(
            model=self.model_id, contents=message, config=GenerateContentConfig(**config_kwargs)
        )
        if not response.text:
            raise ValueError("No cluster name was generated")
        sentences = response.text.split(".")
        if len(sentences) > 1:
            # If the LLM generated an explanation (it shouldn't), use the last sentence
            return sentences[-1]
        return response.text

    def _enrich_cluster_with_trace_ids(
        self, cluster_raw: dict[str, Any], cluster_label: str
    ) -> ClusterizedSuggestionsGroup:
        try:
            avg_similarity: float = cluster_raw["avg_similarity"]
            suggestions: list[str] = cluster_raw["suggestions"]
            suggestions_with_trace_ids: list[ClusterizedSuggestion] = []
            for suggestion in suggestions:
                trace_id = self._summaries_to_trace_ids_mapping[suggestion]
                suggestions_with_trace_ids.append(ClusterizedSuggestion(summary=suggestion, trace_id=trace_id))
            return ClusterizedSuggestionsGroup(
                suggestions=suggestions_with_trace_ids, avg_similarity=avg_similarity, cluster_label=cluster_label
            )
        except Exception as err:
            raise ValueError(f"Error enriching cluster {cluster_label} with trace IDs: {err}") from err

    @staticmethod
    def sort_named_clusters(
        named_clusters: dict[str, ExplainedClusterizedSuggestionsGroup],
    ) -> dict[str, ExplainedClusterizedSuggestionsGroup]:
        # Sort named clusters by average similarity
        return dict(
            sorted(
                named_clusters.items(),
                key=lambda item: item[1].avg_similarity,
                reverse=True,
            )
        )
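A minimal sketch of how the explainer might be driven (not part of the commit): the `groups_raw` shape mirrors what `_enrich_cluster_with_trace_ids` reads, the summaries and trace ids are invented, and GEMINI_API_KEY must be set in the environment.

# Hypothetical usage sketch: name a single cluster of two similar summaries.
groups_raw = {
    "cluster_0": {
        "avg_similarity": 0.81,
        "suggestions": ["User could not export a dashboard", "Dashboard export kept failing"],
    }
}
mapping = {
    "User could not export a dashboard": "trace-id-1",
    "Dashboard export kept failing": "trace-id-2",
}
explainer = ClusterExplainer(
    model_id=LLM_TRACES_SUMMARIES_MODEL_TO_SUMMARIZE_STRINGIFIED_TRACES,  # Reusing the constant; any Gemini model id works
    groups_raw=groups_raw,
    summaries_to_trace_ids_mapping=mapping,
)
named = explainer.explain_clusters()  # {label: ExplainedClusterizedSuggestionsGroup with a generated name}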
114
ee/hogai/llm_traces_summaries/tools/embed_summaries.py
Normal file
@@ -0,0 +1,114 @@
import time
from datetime import datetime

from django.utils import timezone

from posthog.schema import EmbeddingModelName

from posthog.clickhouse.client import sync_execute
from posthog.kafka_client.client import KafkaProducer
from posthog.models.team.team import Team

from ee.hogai.llm_traces_summaries.constants import (
    DOCUMENT_EMBEDDINGS_TOPIC,
    LLM_TRACES_SUMMARIES_DOCUMENT_TYPE,
    LLM_TRACES_SUMMARIES_PRODUCT,
    LLM_TRACES_SUMMARIES_SEARCH_QUERY_DOCUMENT_TYPE,
    LLM_TRACES_SUMMARIES_SEARCH_QUERY_MAX_ATTEMPTS,
    LLM_TRACES_SUMMARIES_SEARCH_QUERY_POLL_INTERVAL_SECONDS,
)
from ee.models.llm_traces_summaries import LLMTraceSummary


class LLMTracesSummarizerEmbedder:
    def __init__(
        self, team: Team, embedding_model_name: EmbeddingModelName = EmbeddingModelName.TEXT_EMBEDDING_3_LARGE_3072
    ):
        self._team = team
        self._producer = KafkaProducer()
        self._embedding_model_name = embedding_model_name

    def embed_summaries(self, summarized_traces: dict[str, str], summary_type: LLMTraceSummary.LLMTraceSummaryType):
        """Generate embeddings for all summaries of stringified traces."""
        # Send all the summaries to the Kafka producer to be stored in ClickHouse
        for trace_id, summary in summarized_traces.items():
            self._embed_document(
                content=summary,
                document_id=trace_id,
                document_type=LLM_TRACES_SUMMARIES_DOCUMENT_TYPE,
                product=LLM_TRACES_SUMMARIES_PRODUCT,
                rendering=summary_type.value,
            )
        # No immediate results needed, so return nothing
        return None

    def embed_summaries_search_query_with_timestamp(
        self, query: str, request_id: str, summary_type: LLMTraceSummary.LLMTraceSummaryType
    ) -> datetime:
        """
        Embed the search query so the most similar summaries can be retrieved, and return the embedding timestamp.
        We expect the query to come either from a conversation or from a search request.
        """
        # Embed the search query and store the timestamp it was generated at
        timestamp = self._embed_document(
            content=query,
            document_id=request_id,
            document_type=LLM_TRACES_SUMMARIES_SEARCH_QUERY_DOCUMENT_TYPE,
            product=LLM_TRACES_SUMMARIES_PRODUCT,
            rendering=summary_type.value,
        )
        # Check if the embeddings are ready
        # TODO: Find a better, more predictable way to check if the embeddings are ready
        embeddings_ready = False
        attempts = 0
        while attempts < LLM_TRACES_SUMMARIES_SEARCH_QUERY_MAX_ATTEMPTS:
            embeddings_ready = self._check_embedding_exists(
                document_id=request_id, document_type=LLM_TRACES_SUMMARIES_SEARCH_QUERY_DOCUMENT_TYPE
            )
            if embeddings_ready:
                break
            attempts += 1
            time.sleep(LLM_TRACES_SUMMARIES_SEARCH_QUERY_POLL_INTERVAL_SECONDS)
        if not embeddings_ready:
            raise ValueError(
                f"Embeddings not ready after {LLM_TRACES_SUMMARIES_SEARCH_QUERY_MAX_ATTEMPTS} attempts when embedding search query for traces summaries"
            )
        return timestamp

    def _check_embedding_exists(self, document_id: str, document_type: str) -> bool:
        """Check if an embedding exists in ClickHouse for the given document_id."""
        query = """
            SELECT count()
            FROM posthog_document_embeddings
            WHERE team_id = %(team_id)s
            AND product = %(product)s
            AND document_type = %(document_type)s
            AND document_id = %(document_id)s
        """
        result = sync_execute(
            query,
            {
                "team_id": self._team.id,
                "product": LLM_TRACES_SUMMARIES_PRODUCT,
                "document_type": document_type,
                "document_id": document_id,
            },
        )
        return result[0][0] > 0

    def _embed_document(
        self, content: str, document_id: str, document_type: str, rendering: str, product: str
    ) -> datetime:
        timestamp = timezone.now()
        payload = {
            "team_id": self._team.id,
            "product": product,
            "document_type": document_type,
            "rendering": rendering,
            "document_id": document_id,
            "timestamp": timestamp.isoformat(),
            "content": content,
            "models": [self._embedding_model_name.value],
        }
        self._producer.produce(topic=DOCUMENT_EMBEDDINGS_TOPIC, data=payload)
        return timestamp
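With the defaults in constants.py, the polling loop above waits at most LLM_TRACES_SUMMARIES_SEARCH_QUERY_MAX_ATTEMPTS × LLM_TRACES_SUMMARIES_SEARCH_QUERY_POLL_INTERVAL_SECONDS = 10 × 3 = 30 seconds for the query embedding to land in ClickHouse before raising.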
98
ee/hogai/llm_traces_summaries/tools/find_similar_traces.py
Normal file
@@ -0,0 +1,98 @@
import structlog

from posthog.schema import (
    CachedDocumentSimilarityQueryResponse,
    DateRange,
    DistanceFunc,
    DocumentSimilarityQuery,
    EmbeddedDocument,
    EmbeddingDistance,
    EmbeddingModelName,
    OrderBy,
    OrderDirection,
)

from posthog.hogql_queries.document_embeddings_query_runner import DocumentEmbeddingsQueryRunner
from posthog.models.team.team import Team

from ee.hogai.llm_traces_summaries.constants import (
    LLM_TRACES_SUMMARIES_DOCUMENT_TYPE,
    LLM_TRACES_SUMMARIES_PRODUCT,
    LLM_TRACES_SUMMARIES_SEARCH_QUERY_DOCUMENT_TYPE,
)
from ee.hogai.llm_traces_summaries.tools.embed_summaries import LLMTracesSummarizerEmbedder
from ee.models.llm_traces_summaries import LLMTraceSummary

logger = structlog.get_logger(__name__)


class LLMTracesSummarizerFinder:
    def __init__(
        self, team: Team, embedding_model_name: EmbeddingModelName = EmbeddingModelName.TEXT_EMBEDDING_3_LARGE_3072
    ):
        self._team = team
        self._embedding_model_name = embedding_model_name

    def find_top_similar_traces_for_query(
        self,
        query: str,
        request_id: str,
        top: int,
        date_range: DateRange,
        summary_type: LLMTraceSummary.LLMTraceSummaryType,
    ) -> dict[str, tuple[EmbeddingDistance, LLMTraceSummary]]:
        """Search all summarized traces for the query and return the top similar traces."""
        embedder = LLMTracesSummarizerEmbedder(team=self._team, embedding_model_name=self._embedding_model_name)
        # Embed the search query and add it to the document embeddings table to be able to search for similar summaries
        embedding_timestamp = embedder.embed_summaries_search_query_with_timestamp(
            query=query, request_id=request_id, summary_type=summary_type
        )
        similarity_query = DocumentSimilarityQuery(
            dateRange=date_range,
            distance_func=DistanceFunc.COSINE_DISTANCE,
            document_types=[LLM_TRACES_SUMMARIES_DOCUMENT_TYPE],  # Searching for summaries
            products=[LLM_TRACES_SUMMARIES_PRODUCT],
            # Searching for summaries with the explicit type of summary (like issues search)
            renderings=[summary_type.value],
            limit=top,
            model=self._embedding_model_name.value,
            order_by=OrderBy.DISTANCE,
            order_direction=OrderDirection.ASC,  # Best matches first
            origin=EmbeddedDocument(
                document_id=request_id,
                document_type=LLM_TRACES_SUMMARIES_SEARCH_QUERY_DOCUMENT_TYPE,  # Searching with a query
                product=LLM_TRACES_SUMMARIES_PRODUCT,
                timestamp=embedding_timestamp,
            ),
        )
        runner = DocumentEmbeddingsQueryRunner(query=similarity_query, team=self._team)
        response = runner.run()
        if not isinstance(response, CachedDocumentSimilarityQueryResponse):
            raise ValueError(
                f'Failed to get similarity results for query "{query}" ({request_id}) '
                f"from team {self._team.id} when searching for summarized LLM traces"
            )
        distances: list[EmbeddingDistance] = response.results
        # Get relevant summaries for the document_id + team + summary type, newest first
        summaries = LLMTraceSummary.objects.filter(
            team=self._team,
            trace_id__in=[distance.result.document_id for distance in distances],
            trace_summary_type=summary_type,
        ).order_by("-created_at")
        if len(summaries) != len(distances):
            # Log a warning, but don't fail, as some results can still be returned
            logger.warning(
                f"Number of summaries ({len(summaries)}) does not match number of distances ({len(distances)}) for "
                f"query {query} ({request_id}) for team {self._team.id} when searching for summarized LLM traces"
            )
        # Combine distances with summaries
        results: dict[str, tuple[EmbeddingDistance, LLMTraceSummary]] = {}
        for distance in distances:
            summaries_for_trace = [x for x in summaries if x.trace_id == distance.result.document_id]
            if not summaries_for_trace:
                logger.warning(
                    f"No summary found for trace {distance.result.document_id} for query {query} ({request_id}) for team {self._team.id} when searching for summarized LLM traces"
                )
                continue
            results[distance.result.document_id] = (distance, summaries_for_trace[0])
        return results
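A minimal usage sketch of the finder (not part of the commit): `team` is an existing Team instance, `request_id` just needs to be unique per search, and the attribute names in the print line are assumptions about the result objects.

# Hypothetical usage sketch: search last week's summaries for an issue.
import uuid

from posthog.schema import DateRange

finder = LLMTracesSummarizerFinder(team=team)
results = finder.find_top_similar_traces_for_query(
    query="users frustrated with dashboard exports",
    request_id=str(uuid.uuid4()),
    top=10,
    date_range=DateRange(date_from="-7d", date_to=None),
    summary_type=LLMTraceSummary.LLMTraceSummaryType.ISSUES_SEARCH,
)
for trace_id, (distance, summary) in results.items():
    print(trace_id, summary.summary)  # `summary` field name on the model is an assumption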
@@ -0,0 +1,217 @@
import re
import asyncio
import difflib
from copy import copy

import structlog
from google import genai
from google.genai.types import GenerateContentConfig
from rich.console import Console

from posthog.models.team.team import Team
from posthog.sync import database_sync_to_async

from products.llm_analytics.backend.providers.gemini import GeminiProvider

from ee.hogai.llm_traces_summaries.constants import LLM_TRACES_SUMMARIES_MODEL_TO_SUMMARIZE_STRINGIFIED_TRACES
from ee.models.llm_traces_summaries import LLMTraceSummary

logger = structlog.get_logger(__name__)

GENERATE_STRINGIFIED_TRACE_SUMMARY_PROMPT = """
- Analyze this conversation between the user and the PostHog AI assistant
- List all pain points, frustrations, and feature limitations the user experienced.
- IMPORTANT: Count only specific issues the user experienced when interacting with the assistant; don't guess or suggest.
- If there are no issues, return only the text "No issues found", without any additional comments.
- If issues are found, provide the output as plain English text in a maximum of 10 sentences, while highlighting all the crucial parts.

```
{stringified_trace}
```
"""

console = Console()


class LLMTraceSummarizerGenerator:
    def __init__(
        self,
        team: Team,
        model_id: str = LLM_TRACES_SUMMARIES_MODEL_TO_SUMMARIZE_STRINGIFIED_TRACES,
        summary_type: LLMTraceSummary.LLMTraceSummaryType = LLMTraceSummary.LLMTraceSummaryType.ISSUES_SEARCH,
    ):
        self._team = team
        self._summary_type = summary_type
        self._model_id = model_id
        self._provider = GeminiProvider(model_id=model_id)
        # Using the default Google client, as the PostHog wrapper doesn't support `aio` for async calls yet
        self._client = genai.Client(api_key=self._provider.get_api_key())
        # Programmatically remove excessive summary parts that add no value, to concentrate the summary's meaning.
        # Parts that add no value and can be safely removed
        self._no_value_parts = [
            "several",
            "during the interaction",
            "explicitly",
        ]
        # Repeating prefixes (model-specific)
        self._excessive_prefixes = ["The user experienced"]
        # Excessive markdown formatting (model-specific)
        self._excessive_formatting = ["**"]
        # Narrative words (one word before a comma at the start of a sentence), like "Next, " or "Finally, "
        self._narrative_words_regex = r"(^|\n|\.\"\s|\.\s)([A-Z]\w+, )"
        # Ensure the summary is readable after the clean-up
        self._proper_capitalization_regex = (
            r"(\.\s|\n\s|\.\"\s|^)([a-z])"  # Replace lowercase letters with uppercase at the start of a sentence
        )

    async def summarize_stringified_traces(self, stringified_traces: dict[str, str]) -> dict[str, str]:
        """Summarize a dictionary of stringified traces."""
        tasks = {}
        # Check which traces already have summaries to avoid re-generating them
        existing_trace_ids = await database_sync_to_async(self._check_existing_summaries)(
            trace_ids=list(stringified_traces.keys())
        )
        # Limit to 10 concurrent API calls
        semaphore = asyncio.Semaphore(10)
        async with asyncio.TaskGroup() as tg:
            for trace_id, stringified_trace in list(stringified_traces.items()):
                if trace_id in existing_trace_ids:
                    # Avoid re-generating summaries that already exist for this team + trace + type
                    continue
                tasks[trace_id] = tg.create_task(
                    self._generate_trace_summary_with_semaphore(
                        semaphore=semaphore, trace_id=trace_id, stringified_trace=stringified_trace
                    )
                )
        summarized_traces: dict[str, str] = {}
        for trace_id, task in tasks.items():
            res: str | Exception = task.result()
            if isinstance(res, Exception):
                logger.exception(
                    f"Failed to generate summary for trace {trace_id} from team {self._team.id} when summarizing traces: {res}",
                    error=str(res),
                )
                continue
            # If the generated summary is too large to store, skip it
            if len(res) > 1000:
                logger.warning(
                    f"Summary for trace {trace_id} from team {self._team.id} is too large to store (over 1000 characters), skipping",
                )
                continue
            # Return only successful summaries
            summarized_traces[trace_id] = res
        return summarized_traces

    async def _generate_trace_summary_with_semaphore(
        self, semaphore: asyncio.Semaphore, trace_id: str, stringified_trace: str
    ) -> str | Exception:
        """Wrapper to limit concurrent API calls using a semaphore."""
        async with semaphore:
            return await self._generate_trace_summary(trace_id=trace_id, stringified_trace=stringified_trace)

    async def _generate_trace_summary(self, trace_id: str, stringified_trace: str) -> str | Exception:
        prompt = GENERATE_STRINGIFIED_TRACE_SUMMARY_PROMPT.format(stringified_trace=stringified_trace)
        try:
            self._provider.validate_model(self._model_id)
            config_kwargs = self._provider.prepare_config_kwargs(system="")
            response = await self._client.aio.models.generate_content(
                model=self._model_id,
                contents=prompt,
                config=GenerateContentConfig(**config_kwargs),
            )
            if not response.text:
                raise ValueError(f"No trace summary was generated for trace {trace_id} from team {self._team.id}")
            # Avoid the LLM returning excessive comments when no issues were found
            if "no issues found" in response.text.lower() and response.text.lower() != "no issues found":
                logger.info(
                    f"Original 'no issues' text for trace {trace_id} from team {self._team.id} (replaced with 'No issues found'): {response.text}"
                )
                return "No issues found"
            if response.text.lower() == "no issues found":
                return "No issues found"
            cleaned_up_summary = self._clean_up_summary_before_embedding(trace_id=trace_id, summary=response.text)
            return cleaned_up_summary
        except Exception as err:
            return err  # Let the caller handle the error

    def _clean_up_summary_before_embedding(self, trace_id: str, summary: str, log_diff: bool = False) -> str:
        """Remove repetitive phrases and excessive formatting to make embeddings more accurate."""
        original_summary = copy(summary)
        # Remove parts that don't add value
        for part in self._no_value_parts:
            summary = summary.replace(f" {part} ", " ")
        # Remove excessive prefixes
        for prefix in self._excessive_prefixes:
            while True:
                # Remove all occurrences
                prefix_index = summary.find(prefix)
                if prefix_index == -1:
                    # Not found
                    break
                # Remove the prefix
                summary = summary[:prefix_index] + summary[prefix_index + len(prefix) :]
        # Remove narrative words (one word before a comma at the start of a sentence)
        summary = re.sub(self._narrative_words_regex, lambda m: m.group(1), summary)
        # Remove excessive formatting
        for formatting in self._excessive_formatting:
            while True:
                # Remove all occurrences
                formatting_index = summary.find(formatting)
                if formatting_index == -1:
                    break
                # Remove the formatting
                summary = summary[:formatting_index] + summary[formatting_index + len(formatting) :]
        # Strip, just in case
        summary = summary.strip()
        # Uppercase lowercase letters that follow a dot + space, a newline + space, or the start of the text
        summary = re.sub(self._proper_capitalization_regex, lambda m: m.group(1) + m.group(2).upper(), summary)
        if len(summary) / len(original_summary) <= 0.8:
            logger.warning(
                f"Summary for trace {trace_id} from team {self._team.id} is too different from the original summary "
                "(more than 20% smaller after cleanup) when summarizing traces",
            )
            # Force-log the diff on a drastic difference
            log_diff = True
        if summary == original_summary or not log_diff:
            return summary
        # Log differences, if any, when asked explicitly
        self._log_diff(trace_id=trace_id, original_summary=original_summary, summary=summary)
        return summary

    def _log_diff(self, trace_id: str, original_summary: str, summary: str) -> None:
        """Optional helper to log the differences between the original and cleaned-up summaries."""
        logger.info(f"Summary cleaned up for trace {trace_id} from team {self._team.id} when summarizing traces")
        logger.info(f"Original summary:\n{original_summary}")
        logger.info(f"Cleaned summary:\n{summary}")
        # Character-level diff for precise changes
        console.print("[bold]Changes:[/bold]")
        matcher = difflib.SequenceMatcher(None, original_summary, summary)
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == "delete":
                console.print(f"[red]Removed: '{original_summary[i1:i2]}'[/red]")
            elif tag == "insert":
                console.print(f"[green]Added: '{summary[j1:j2]}'[/green]")
            elif tag == "replace":
                console.print(f"[red]Removed: '{original_summary[i1:i2]}'[/red]")
                console.print(f"[green]Added: '{summary[j1:j2]}'[/green]")
        console.print("=" * 50 + "\n")

    def store_summaries_in_db(self, summarized_traces: dict[str, str]):
        # Storing summaries in the database should eventually be part of the embedding process.
        # Temporary PostgreSQL solution to test the end-to-end summarization pipeline.
        # TODO: Replace (or migrate) later with a ClickHouse-powered solution to allow FTS
        summaries_batch_size = 500
        summaries_for_db = [
            LLMTraceSummary(team=self._team, trace_id=trace_id, summary=summary, trace_summary_type=self._summary_type)
            for trace_id, summary in summarized_traces.items()
        ]
        # Ignore already-processed trace summaries, if they get to this stage
        LLMTraceSummary.objects.bulk_create(summaries_for_db, batch_size=summaries_batch_size, ignore_conflicts=True)

    def _check_existing_summaries(self, trace_ids: list[str]) -> set[str]:
        existing_trace_ids = set(
            LLMTraceSummary.objects.filter(
                team=self._team, trace_summary_type=self._summary_type, trace_id__in=trace_ids
            ).values_list("trace_id", flat=True)
        )
        return existing_trace_ids
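To make the clean-up regexes above concrete, a small standalone demo (illustrative, not part of the commit; the sample sentence is invented):

# Demo of the narrative-words and capitalization clean-up.
import re

narrative_words_regex = r"(^|\n|\.\"\s|\.\s)([A-Z]\w+, )"
proper_capitalization_regex = r"(\.\s|\n\s|\.\"\s|^)([a-z])"

text = "Finally, the user gave up. However, the export still failed."
cleaned = re.sub(narrative_words_regex, lambda m: m.group(1), text)
# cleaned == "the user gave up. the export still failed."
cleaned = re.sub(proper_capitalization_regex, lambda m: m.group(1) + m.group(2).upper(), cleaned)
# cleaned == "The user gave up. The export still failed."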
45
ee/hogai/llm_traces_summaries/tools/get_traces.py
Normal file
@@ -0,0 +1,45 @@
from posthog.schema import CachedTracesQueryResponse, DateRange, HogQLPropertyFilter, QueryLogTags, TracesQuery

from posthog.hogql_queries.ai.traces_query_runner import TracesQueryRunner
from posthog.models.team.team import Team


class LLMTracesSummarizerCollector:
    def __init__(self, team: Team):
        self._team = team
        # Should be large enough to go fast, and small enough to avoid any memory issues
        self._traces_per_page = 100

    def get_db_traces_per_page(self, offset: int, date_range: DateRange) -> CachedTracesQueryResponse:
        query = self._get_traces_query(offset=offset, date_range=date_range, limit=self._traces_per_page)
        runner = TracesQueryRunner(query=query, team=self._team)
        response = runner.run()
        if not isinstance(response, CachedTracesQueryResponse):
            raise ValueError(f"Failed to get traces for the date range when summarizing LLM traces: {response}")
        return response

    def get_db_trace_ids(self, date_range: DateRange, limit: int) -> list[str]:
        """Get all the trace ids (but ids only) within the date range."""
        query = self._get_traces_query(offset=0, date_range=date_range, limit=limit)
        runner = TracesQueryRunner(query=query, team=self._team)
        # Expecting to get all trace ids in a single query, as it should be relatively lightweight
        trace_ids, _, _ = runner._get_trace_ids()
        return trace_ids

    @staticmethod
    def _get_traces_query(offset: int, date_range: DateRange, limit: int) -> TracesQuery:
        return TracesQuery(
            dateRange=date_range,
            filterTestAccounts=False,  # Internal users are active, so there's no reason to filter them out for now
            limit=limit,
            offset=offset,
            properties=[
                HogQLPropertyFilter(
                    # Analyze only LangGraph traces initially
                    type="hogql",
                    key="properties.$ai_span_name = 'LangGraph'",
                    value=None,
                )
            ],
            tags=QueryLogTags(productKey="LLMAnalytics"),
        )
134
ee/hogai/llm_traces_summaries/tools/stringify_trace.py
Normal file
@@ -0,0 +1,134 @@
from typing import Any

import tiktoken
import structlog

from posthog.schema import LLMTrace

from posthog.models.team.team import Team

logger = structlog.get_logger(__name__)


class LLMTracesSummarizerStringifier:
    def __init__(self, team: Team):
        self._team = team
        self._token_encoder = tiktoken.encoding_for_model("gpt-4o")
        self._stringified_trace_max_tokens = 5000

    def stringify_traces(self, traces_chunk: list[LLMTrace]) -> dict[str, str]:
        stringified_traces: dict[str, str] = {}
        for trace in traces_chunk:
            stringified_trace = self._stringify_trace_messages(trace)
            if not stringified_trace:
                continue
            stringified_traces[trace.id] = stringified_trace
        return stringified_traces

    def _stringify_trace_messages(self, trace: LLMTrace) -> str | None:
        stringified_messages: list[str] = []
        # TODO: Iterate full conversations (traces combined) instead of just traces, as this leads to duplicates
        # Guard against a missing "messages" key as well as a missing output state
        messages = (trace.outputState.get("messages") or []) if trace.outputState else []
        for message in messages:
            stringified_message = self.stringify_message(message=message, trace_id=trace.id)
            # Skip empty messages
            if not stringified_message:
                continue
            # Check that the previous message isn't identical
            if stringified_messages and stringified_messages[-1] == stringified_message:
                continue
            stringified_messages.append(stringified_message)
        # If there are no messages, skip the trace
        if not stringified_messages:
            return None
        # If the human didn't respond to any AI messages (no interaction), skip the trace
        no_interaction_found = True
        for i, message in enumerate(stringified_messages):
            if message.startswith("human") and i > 0 and stringified_messages[i - 1].startswith("ai"):
                no_interaction_found = False
                break
        if no_interaction_found:
            return None
        # Combine into a string
        stringified_messages_str = "\n\n".join(stringified_messages)
        # Check if the trace is too long for summarization
        num_tokens = len(self._token_encoder.encode(stringified_messages_str))
        if num_tokens > self._stringified_trace_max_tokens:
            logger.warning(
                f"Trace {trace.id} from team {self._team.id} stringified version is too long ({num_tokens} tokens > {self._stringified_trace_max_tokens}) "
                "for summarization when summarizing LLM traces, skipping"
            )
            return None
        return stringified_messages_str

    @staticmethod
    def _stringify_answer(message: dict[str, Any]) -> str | None:
        answer_kind = message["answer"]["kind"]
        message_content = f"*AI displayed a {answer_kind}*"
        if not message_content:
            return None
        return f"ai/answer: {message_content}"

    @staticmethod
    def _stringify_ai_message(message: dict[str, Any]) -> str | None:
        message_content = message["content"]
        tools_called = []
        for tc in message.get("tool_calls") or []:
            if tc.get("type") != "tool_call":
                continue
            tools_called.append(tc.get("name"))
        if tools_called:
            tool_content = f"*AI called tools: {', '.join(tools_called)}*"
            message_content += f" {tool_content}" if message_content else tool_content
        if not message_content:
            return None
        return f"ai: {message_content}"

    @staticmethod
    def _stringify_tool_message(message: dict[str, Any]) -> str | None:
        # Keep navigation messages
        if (
            message.get("ui_payload")
            and isinstance(message["ui_payload"], dict)
            and message["ui_payload"].get("navigate")
        ):
            return f"ai/navigation: *{message['content']}*"
        # TODO: Decide how to catch errors, as they aren't marked as errors in the trace
        return None

    @staticmethod
    def _stringify_human_message(message: dict[str, Any]) -> str | None:
        message_content = message["content"]
        if not message_content:
            return None
        return f"human: {message_content}"

    def stringify_message(self, message: dict[str, Any], trace_id: str) -> str | None:
        message_type = message.get("type")
        if not message_type:
            logger.warning(
                f"Message {message.get('id')} for trace {trace_id} from team {self._team.id} has no type, skipping"
            )
            return None
        try:
            # Answers
            if message.get("answer"):
                return self._stringify_answer(message)
            # Messages
            if message_type == "ai":
                return self._stringify_ai_message(message)
            if message_type == "human":
                return self._stringify_human_message(message)
            if message_type == "context":  # Skip context messages
                return None
            if message_type == "tool":  # TODO: Decide whether to keep tool messages
                return self._stringify_tool_message(message)
            # Ignore other message types
            # TODO: Decide if there's a need for other message types
            return None
        except Exception as err:
            logger.exception(
                f"Error stringifying message {message_type} ({err}) for trace {trace_id} from team {self._team.id}:\n{message}",
                error=str(err),
            )
            return None
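An illustrative sketch of how a couple of raw messages stringify (not part of the commit): the dict shapes follow the keys the methods above read, the message contents and tool name are invented, and `team` is an existing Team instance.

# Hypothetical example of stringify_message on two sample messages.
stringifier = LLMTracesSummarizerStringifier(team=team)
human = {"type": "human", "content": "Why is my funnel empty?"}
ai = {
    "type": "ai",
    "content": "Let me check your events.",
    "tool_calls": [{"type": "tool_call", "name": "create_and_query_insight"}],
}
print(stringifier.stringify_message(message=human, trace_id="t1"))
# human: Why is my funnel empty?
print(stringifier.stringify_message(message=ai, trace_id="t1"))
# ai: Let me check your events. *AI called tools: create_and_query_insight*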
92
ee/hogai/llm_traces_summaries/utils/load_from_csv.py
Normal file
@@ -0,0 +1,92 @@
import re
import csv
import sys
import json
import uuid
from collections.abc import Generator

import structlog

from posthog.schema import LLMTrace, LLMTracePerson

logger = structlog.get_logger(__name__)
csv.field_size_limit(sys.maxsize)


def load_traces_from_csv_files(csv_paths: list[str]) -> Generator[LLMTrace, None, None]:
    """Load traces from CSV files, useful for local development."""
    # The assumption is that the CSV was exported through the default Traces query runner query
    fields_to_column_mapping = {
        "id": 0,
        "createdAt": 1,
        "person": 2,
        "totalLatency": 3,
        "inputTokens": 4,
        "outputTokens": 5,
        "inputCost": 6,
        "outputCost": 7,
        "totalCost": 8,
        "inputState": 9,
        "outputState": 10,
        "traceName": 11,
    }
    for csv_path in csv_paths:
        with open(csv_path) as file:
            reader = csv.reader(file)
            # Skip header
            next(reader)
            for row in reader:
                output_state_raw = row[fields_to_column_mapping["outputState"]]
                if not output_state_raw:
                    continue
                output_state = json.loads(output_state_raw)
                input_state_raw = row[fields_to_column_mapping["inputState"]]
                if not input_state_raw:
                    input_state = {}  # Allowing empty, as it's not used in calculations
                else:
                    input_state = json.loads(input_state_raw)
                # Create a person with minimal data (uuid and email)
                raw_person_data = row[fields_to_column_mapping["person"]]
                person_uuids_search = re.findall(r"UUID\('(.*?)'\)", raw_person_data)
                if not person_uuids_search:
                    logger.warning(f"No person UUID found for person: {raw_person_data.splitlines()[0]}")
                    continue
                person_uuid = person_uuids_search[0]
                properties = {}
                person_emails_search = re.findall(r'"email":"(.*?)"', raw_person_data)
                if person_emails_search:
                    person_email = person_emails_search[0]
                    properties["email"] = person_email
                person = LLMTracePerson(
                    created_at=row[fields_to_column_mapping["createdAt"]],
                    distinct_id=str(uuid.uuid4()),  # Not used in calculations
                    properties=properties,
                    uuid=person_uuid,
                )
                # Process other properties; could be skipped for now, as they aren't used in calculations
                input_cost = row[fields_to_column_mapping["inputCost"]] or 0
                input_tokens = row[fields_to_column_mapping["inputTokens"]] or 0
                output_cost = row[fields_to_column_mapping["outputCost"]] or 0
                output_tokens = row[fields_to_column_mapping["outputTokens"]] or 0
                total_cost = row[fields_to_column_mapping["totalCost"]] or 0
                total_latency = row[fields_to_column_mapping["totalLatency"]]
                trace_name = row[fields_to_column_mapping["traceName"]]
                # Create the LLM trace object
                trace = LLMTrace(
                    aiSessionId=str(uuid.uuid4()),  # Not used in calculations
                    createdAt=row[fields_to_column_mapping["createdAt"]],
                    errorCount=0,  # Not used in calculations
                    events=[],  # Not used in calculations
                    id=row[fields_to_column_mapping["id"]],
                    inputCost=input_cost,
                    inputState=input_state,
                    inputTokens=input_tokens,
                    outputCost=output_cost,
                    outputState=output_state,
                    outputTokens=output_tokens,
                    person=person,
                    totalCost=total_cost,
                    totalLatency=total_latency,
                    traceName=trace_name,
                )
                yield trace

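Exercising this loader locally might look like the following; the CSV path is illustrative:

# Illustrative only: assumes a CSV exported from the Traces query runner with the column order above.
for trace in load_traces_from_csv_files(["/tmp/traces_export.csv"]):
    print(trace.id, trace.traceName, trace.person.uuid)
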
51
ee/migrations/0029_llmtracesummary_llmtracesummary_unique_trace_summary.py
Normal file
@@ -0,0 +1,51 @@
# Generated by Django 4.2.22 on 2025-10-25 16:38

import django.db.models.deletion
from django.db import migrations, models

import posthog.models.utils


class Migration(migrations.Migration):
    dependencies = [
        ("posthog", "0888_datawarehousemanagedviewset_and_more"),
        ("ee", "0028_alter_conversation_type"),
    ]

    operations = [
        migrations.CreateModel(
            name="LLMTraceSummary",
            fields=[
                (
                    "id",
                    models.UUIDField(
                        default=posthog.models.utils.UUIDT, editable=False, primary_key=True, serialize=False
                    ),
                ),
                ("created_at", models.DateTimeField(auto_now_add=True, null=True)),
                ("updated_at", models.DateTimeField(auto_now=True, null=True)),
                (
                    "trace_summary_type",
                    models.CharField(
                        choices=[("issues_search", "Issues Search")], default="issues_search", max_length=100
                    ),
                ),
                ("trace_id", models.CharField(help_text="Trace ID", max_length=255)),
                ("summary", models.CharField(help_text="Trace summary", max_length=1000)),
                ("team", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="posthog.team")),
            ],
            options={
                "indexes": [
                    models.Index(
                        fields=["team", "trace_id", "trace_summary_type"], name="ee_llmtrace_team_id_93a147_idx"
                    )
                ],
            },
        ),
        migrations.AddConstraint(
            model_name="llmtracesummary",
            constraint=models.UniqueConstraint(
                fields=("team", "trace_id", "trace_summary_type"), name="unique_trace_summary"
            ),
        ),
    ]
@@ -1 +1 @@
-0028_alter_conversation_type
+0029_llmtracesummary_llmtracesummary_unique_trace_summary
30
ee/models/llm_traces_summaries.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from django.db import models
|
||||
|
||||
from posthog.models.team.team import Team
|
||||
from posthog.models.utils import UUIDTModel
|
||||
|
||||
|
||||
class LLMTraceSummary(UUIDTModel):
|
||||
class Meta:
|
||||
indexes = [
|
||||
models.Index(fields=["team", "trace_id", "trace_summary_type"]),
|
||||
]
|
||||
constraints = [
|
||||
models.UniqueConstraint(fields=["team", "trace_id", "trace_summary_type"], name="unique_trace_summary"),
|
||||
]
|
||||
|
||||
class LLMTraceSummaryType(models.TextChoices):
|
||||
"""
|
||||
Traces could be summarized with different types of prompts for different use cases.
|
||||
"""
|
||||
|
||||
ISSUES_SEARCH = "issues_search"
|
||||
|
||||
team = models.ForeignKey(Team, on_delete=models.CASCADE)
|
||||
created_at = models.DateTimeField(auto_now_add=True, null=True)
|
||||
updated_at = models.DateTimeField(auto_now=True, null=True)
|
||||
trace_summary_type = models.CharField(
|
||||
max_length=100, choices=LLMTraceSummaryType.choices, default=LLMTraceSummaryType.ISSUES_SEARCH
|
||||
)
|
||||
trace_id = models.CharField(max_length=255, help_text="Trace ID")
|
||||
summary = models.CharField(max_length=1000, help_text="Trace summary")
|
||||
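Given the unique constraint on (team, trace_id, trace_summary_type), writes against this model can be made idempotent; a sketch, assuming team, trace_id, and summary_text are in scope:

# Sketch: upsert keyed on the unique constraint defined above.
summary_obj, created = LLMTraceSummary.objects.update_or_create(
    team=team,
    trace_id=trace_id,
    trace_summary_type=LLMTraceSummary.LLMTraceSummaryType.ISSUES_SEARCH,
    defaults={"summary": summary_text},
)
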
@@ -186,7 +186,8 @@ class DocumentEmbeddingsQueryRunner(AnalyticsQueryRunner[DocumentSimilarityQuery
                 ),
             ]
         ),
         ast.And(
             # If the document type or product is different - compare all renderings
             ast.Or(
                 exprs=[
                     ast.CompareOperation(
                         op=ast.CompareOperationOp.NotEq,
@@ -10,6 +10,11 @@ from posthog.temporal.ai.session_summary.activities.video_validation import (
 )
 from posthog.temporal.ai.session_summary.types.single import SingleSessionSummaryInputs
 
+from .llm_traces_summaries.summarize_traces import (
+    SummarizeLLMTracesInputs,
+    SummarizeLLMTracesWorkflow,
+    summarize_llm_traces_activity,
+)
 from .session_summary.summarize_session import (
     SummarizeSingleSessionStreamWorkflow,
     SummarizeSingleSessionWorkflow,
@@ -37,6 +42,7 @@ WORKFLOWS = [
     SummarizeSingleSessionWorkflow,
     SummarizeSessionGroupWorkflow,
     AssistantConversationRunnerWorkflow,
+    SummarizeLLMTracesWorkflow,
 ]
 
 ACTIVITIES = [
@@ -53,6 +59,7 @@ ACTIVITIES = [
     split_session_summaries_into_chunks_for_patterns_extraction_activity,
     process_conversation_activity,
     validate_llm_single_session_summary_with_videos_activity,
+    summarize_llm_traces_activity,
 ]
 
 __all__ = [
@@ -60,4 +67,5 @@ __all__ = [
     "SingleSessionSummaryInputs",
     "SessionGroupSummaryInputs",
     "SessionGroupSummaryOfSummariesInputs",
+    "SummarizeLLMTracesInputs",
 ]

82
posthog/temporal/ai/llm_traces_summaries/summarize_traces.py
Normal file
@@ -0,0 +1,82 @@
import json
import uuid
import dataclasses
from datetime import timedelta

from django.conf import settings

import structlog
import temporalio
from temporalio.common import RetryPolicy, WorkflowIDReusePolicy

from posthog.schema import DateRange

from posthog.models.team.team import Team
from posthog.settings.temporal import MAX_AI_TASK_QUEUE
from posthog.sync import database_sync_to_async
from posthog.temporal.common.base import PostHogWorkflow
from posthog.temporal.common.client import async_connect

from ee.hogai.llm_traces_summaries.summarize_traces import LLMTracesSummarizer

logger = structlog.get_logger(__name__)


@dataclasses.dataclass(frozen=True, kw_only=True)
class SummarizeLLMTracesInputs:
    date_to: str | None = None
    date_from: str | None = None
    team_id: int


@temporalio.activity.defn
async def summarize_llm_traces_activity(
    inputs: SummarizeLLMTracesInputs,
) -> None:
    """Summarize and store embeddings for all LLM traces in the date range."""
    team = await database_sync_to_async(Team.objects.get)(id=inputs.team_id)
    summarizer = LLMTracesSummarizer(team=team)
    date_range = DateRange(date_from=inputs.date_from, date_to=inputs.date_to)
    await summarizer.summarize_traces_for_date_range(date_range=date_range)
    return None


@temporalio.workflow.defn(name="summarize-llm-traces")
class SummarizeLLMTracesWorkflow(PostHogWorkflow):
    @staticmethod
    def parse_inputs(inputs: list[str]) -> SummarizeLLMTracesInputs:
        """Parse inputs from the management command CLI."""
        loaded = json.loads(inputs[0])
        return SummarizeLLMTracesInputs(**loaded)

    @temporalio.workflow.run
    async def run(self, inputs: SummarizeLLMTracesInputs) -> None:
        await temporalio.workflow.execute_activity(
            summarize_llm_traces_activity,
            inputs,
            start_to_close_timeout=timedelta(minutes=30),
            retry_policy=RetryPolicy(maximum_attempts=3),
        )
        return None


async def execute_summarize_llm_traces(
    date_range: DateRange,
    team: Team,
) -> None:
    """
    Start the direct summarization workflow (no streaming) to stringify traces > generate summaries > generate embeddings and store everything.
    """
    workflow_id = f"llm-traces:summarize-traces:{date_range.date_from}:{date_range.date_to}:{team.id}:{uuid.uuid4()}"
    if not date_range.date_from and not date_range.date_to:
        raise ValueError("At least one of date_from or date_to must be provided when summarizing traces")
    client = await async_connect()
    retry_policy = RetryPolicy(maximum_attempts=int(settings.TEMPORAL_WORKFLOW_MAX_ATTEMPTS))
    await client.execute_workflow(
        "summarize-llm-traces",
        SummarizeLLMTracesInputs(date_to=date_range.date_to, date_from=date_range.date_from, team_id=team.id),
        id=workflow_id,
        id_reuse_policy=WorkflowIDReusePolicy.ALLOW_DUPLICATE_FAILED_ONLY,
        task_queue=MAX_AI_TASK_QUEUE,
        retry_policy=retry_policy,
    )

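Triggering the workflow for a team, e.g. from an async management command, could then look like this; the date range values and the team lookup are illustrative:

# Illustrative trigger: assumes an async context and a Team instance already loaded.
date_range = DateRange(date_from="-7d", date_to=None)
await execute_summarize_llm_traces(date_range=date_range, team=team)
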
@@ -10,6 +10,7 @@ class TestAITemporalModuleIntegrity:
             "SummarizeSingleSessionWorkflow",
             "SummarizeSessionGroupWorkflow",
             "AssistantConversationRunnerWorkflow",
+            "SummarizeLLMTracesWorkflow",
         ]
         actual_workflow_names = [workflow.__name__ for workflow in ai.WORKFLOWS]
         assert len(actual_workflow_names) == len(expected_workflows), (
@@ -42,6 +43,7 @@ class TestAITemporalModuleIntegrity:
             "split_session_summaries_into_chunks_for_patterns_extraction_activity",
             "process_conversation_activity",
             "validate_llm_single_session_summary_with_videos_activity",
+            "summarize_llm_traces_activity",
         ]
         actual_activity_names = [activity.__name__ for activity in ai.ACTIVITIES]
         assert len(actual_activity_names) == len(expected_activities), (
@@ -65,6 +67,7 @@ class TestAITemporalModuleIntegrity:
             "SingleSessionSummaryInputs",
             "SessionGroupSummaryInputs",
             "SessionGroupSummaryOfSummariesInputs",
+            "SummarizeLLMTracesInputs",
         ]
         actual_exports = ai.__all__
         assert len(actual_exports) == len(expected_exports), (

@@ -2,6 +2,7 @@ import json
 import uuid
 import logging
 from collections.abc import Generator
+from typing import Any
 
 from django.conf import settings
@@ -72,9 +73,11 @@ class GeminiProvider:
                             "id": f"gemini_tool_{hash(str(part.function_call))}",
                             "function": {
                                 "name": part.function_call.name,
-                                "arguments": json.dumps(dict(part.function_call.args))
-                                if part.function_call.args
-                                else "{}",
+                                "arguments": (
+                                    json.dumps(dict(part.function_call.args))
+                                    if part.function_call.args
+                                    else "{}"
+                                ),
                             },
                         }
@@ -84,6 +87,27 @@ class GeminiProvider:
 
         return results
 
+    @staticmethod
+    def prepare_config_kwargs(
+        system: str,
+        temperature: float | None = None,
+        max_tokens: int | None = None,
+        tools: list[dict] | None = None,
+    ) -> dict[str, Any]:
+        effective_temperature = temperature if temperature is not None else GeminiConfig.TEMPERATURE
+        effective_max_tokens = max_tokens  # May be None; Gemini API uses max_output_tokens
+        # Build config with conditionals
+        config_kwargs: dict[str, Any] = {
+            "temperature": effective_temperature,
+        }
+        if system:
+            config_kwargs["system_instruction"] = system
+        if effective_max_tokens is not None:
+            config_kwargs["max_output_tokens"] = effective_max_tokens
+        if tools is not None:
+            config_kwargs["tools"] = tools
+        return config_kwargs
+
     def stream_response(
         self,
         system: str,
@@ -98,24 +122,13 @@ class GeminiProvider:
         groups: dict | None = None,
     ) -> Generator[str, None]:
         """
-        Async generator function that yields SSE formatted data
+        Async generator function that yields SSE formatted data.
         """
         self.validate_model(self.model_id)
 
         try:
-            effective_temperature = temperature if temperature is not None else GeminiConfig.TEMPERATURE
-            effective_max_tokens = max_tokens  # May be None; Gemini API uses max_output_tokens
-
-            # Build config with conditionals
-            config_kwargs = {
-                "system_instruction": system,
-                "temperature": effective_temperature,
-            }
-            if effective_max_tokens is not None:
-                config_kwargs["max_output_tokens"] = effective_max_tokens
-            if tools is not None:
-                config_kwargs["tools"] = tools
-
+            config_kwargs = self.prepare_config_kwargs(
+                system=system, temperature=temperature, max_tokens=max_tokens, tools=tools
+            )
             response = self.client.models.generate_content_stream(
                 model=self.model_id,
                 contents=convert_anthropic_messages_to_gemini(messages),
@@ -137,10 +150,47 @@ class GeminiProvider:
                 yield f"data: {json.dumps({'type': 'usage', 'input_tokens': input_tokens, 'output_tokens': output_tokens})}\n\n"
 
         except APIError as e:
-            logger.exception(f"Gemini API error: {e}")
+            logger.exception(f"Gemini API error when streaming response: {e}")
             yield f"data: {json.dumps({'type': 'error', 'error': f'Gemini API error'})}\n\n"
             return
         except Exception as e:
-            logger.exception(f"Unexpected error: {e}")
+            logger.exception(f"Unexpected error when streaming response: {e}")
             yield f"data: {json.dumps({'type': 'error', 'error': f'Unexpected error'})}\n\n"
             return
 
+    def get_response(
+        self,
+        system: str,
+        prompt: str,
+        temperature: float | None = None,
+        max_tokens: int | None = None,
+        tools: list[dict] | None = None,
+        distinct_id: str = "",
+        trace_id: str | None = None,
+        properties: dict | None = None,
+        groups: dict | None = None,
+    ) -> str:
+        """
+        Get a direct string response from the Gemini API for a provided string prompt (no streaming).
+        """
+        self.validate_model(self.model_id)
+        try:
+            config_kwargs = self.prepare_config_kwargs(
+                system=system, temperature=temperature, max_tokens=max_tokens, tools=tools
+            )
+            response = self.client.models.generate_content(
+                model=self.model_id,
+                contents=prompt,
+                config=GenerateContentConfig(**config_kwargs),
+                posthog_distinct_id=distinct_id,
+                posthog_trace_id=trace_id or str(uuid.uuid4()),
+                posthog_properties={**(properties or {}), "ai_product": "playground"},
+                posthog_groups=groups or {},
+            )
+            return response.text
+        except APIError as err:
+            logger.exception(f"Gemini API error when getting response: {err}")
+            raise
+        except Exception as err:
+            logger.exception(f"Unexpected error when getting response: {err}")
+            raise

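A call to the new non-streaming helper might then look like this; the provider construction and model id are assumptions, not shown in this diff:

# Sketch only: GeminiProvider construction details aren't part of this change.
provider = GeminiProvider(model_id="gemini-2.5-flash")
text = provider.get_response(
    system="You are a concise assistant.",
    prompt="Summarize this LLM trace in one sentence.",
    max_tokens=256,
    distinct_id="user-123",
)
print(text)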