fix: refactor reranker

This commit is contained in:
Clelia (Astra) Bertelli
2026-01-07 14:43:38 +01:00
parent 5abc97790d
commit 98943aaf26
2 changed files with 49 additions and 27 deletions
@@ -15,6 +15,7 @@ class Pipeline:
self,
qdrant_client: AsyncQdrantClient,
qdrant_collection_name: str,
rrf_constant: int = 60,
parsing_kwargs: dict[str, Any] | None = None,
cache_directory: str | None = None,
openai_api_key: str | None = None,
@@ -50,6 +51,7 @@ class Pipeline:
qdrant_client=qdrant_client,
collection_name=qdrant_collection_name,
embedder=self.embedder,
rrf_constant=rrf_constant,
)
self.filter_llm = LLMFilter(api_key=openai_api_key, model=openai_llm_model)
self.file_paths: list[str] = []
@@ -9,7 +9,6 @@ from qdrant_client.models import (
FieldCondition,
MatchValue,
)
from statistics import mean
from typing import TypedDict, Literal, cast
@@ -26,29 +25,40 @@ class SearchResult(TypedDict):
class SimpleReranker:
def __init__(self) -> None:
pass
def __init__(self, k: int = 60) -> None:
"""
Args:
k: Constant for RRF formula. Higher values reduce the impact of top-ranked items. Default of 60 is commonly used in literature.
"""
self.k = k
def _dedupe(
def _reciprocal_rank_fusion(
self, dense_results: list[SearchResult], sparse_results: list[SearchResult]
) -> list[SearchResult]:
dense_results_ranked = dense_results.copy()
sparse_results_ranked = sparse_results.copy()
dense_results_ranked.sort(key=lambda x: x["content"], reverse=False)
sparse_results_ranked.sort(key=lambda x: x["content"], reverse=False)
for i, r in enumerate(dense_results_ranked):
r["score"] = i + 1
for i, r in enumerate(sparse_results_ranked):
r["score"] = i + 1
for result in sparse_results_ranked:
for i, r in enumerate(dense_results_ranked):
if r["content"] == result["content"]:
r["score"] = mean([r["score"], result["score"]])
dense_results_ranked[i] = r
break
else:
dense_results_ranked.append(result)
return dense_results_ranked
) -> dict[str, float]:
rrf_scores: dict[str, float] = {}
for rank, result in enumerate(dense_results, start=1):
content = result["content"]
rrf_scores[content] = rrf_scores.get(content, 0.0) + 1 / (self.k + rank)
for rank, result in enumerate(sparse_results, start=1):
content = result["content"]
rrf_scores[content] = rrf_scores.get(content, 0.0) + 1 / (self.k + rank)
return rrf_scores
def _dedupe_and_merge(
self, dense_results: list[SearchResult], sparse_results: list[SearchResult]
) -> dict[str, SearchResult]:
results_map: dict[str, SearchResult] = {}
for result in dense_results:
if result["content"] not in results_map:
results_map[result["content"]] = result
for result in sparse_results:
if result["content"] not in results_map:
results_map[result["content"]] = result
return results_map
def rerank(
self,
@@ -56,19 +66,29 @@ class SimpleReranker:
sparse_results: list[SearchResult],
limit: int = 1,
) -> list[SearchResult]:
results = self._dedupe(dense_results, sparse_results)
results.sort(key=lambda x: x["score"], reverse=True)
return results[:limit]
rrf_scores = self._reciprocal_rank_fusion(dense_results, sparse_results)
results_map = self._dedupe_and_merge(dense_results, sparse_results)
reranked_results: list[SearchResult] = []
for content, result in results_map.items():
result_copy = result.copy()
result_copy["score"] = rrf_scores[content]
reranked_results.append(result_copy)
reranked_results.sort(key=lambda x: x["score"], reverse=True)
return reranked_results[:limit]
class VectorDB:
def __init__(
self, qdrant_client: AsyncQdrantClient, collection_name: str, embedder: Embedder
self,
qdrant_client: AsyncQdrantClient,
collection_name: str,
embedder: Embedder,
rrf_constant: int = 60,
) -> None:
self._client = qdrant_client
self.collection_name = collection_name
self.embedder = embedder
self._reranker = SimpleReranker()
self._reranker = SimpleReranker(k=rrf_constant)
async def configure_collection(self) -> None:
if await self._client.collection_exists(self.collection_name):