mirror of
https://github.com/run-llama/fs-explorer.git
synced 2026-07-01 21:45:00 -04:00
fix: refactor reranker
This commit is contained in:
@@ -15,6 +15,7 @@ class Pipeline:
|
||||
self,
|
||||
qdrant_client: AsyncQdrantClient,
|
||||
qdrant_collection_name: str,
|
||||
rrf_constant: int = 60,
|
||||
parsing_kwargs: dict[str, Any] | None = None,
|
||||
cache_directory: str | None = None,
|
||||
openai_api_key: str | None = None,
|
||||
@@ -50,6 +51,7 @@ class Pipeline:
|
||||
qdrant_client=qdrant_client,
|
||||
collection_name=qdrant_collection_name,
|
||||
embedder=self.embedder,
|
||||
rrf_constant=rrf_constant,
|
||||
)
|
||||
self.filter_llm = LLMFilter(api_key=openai_api_key, model=openai_llm_model)
|
||||
self.file_paths: list[str] = []
|
||||
|
||||
@@ -9,7 +9,6 @@ from qdrant_client.models import (
|
||||
FieldCondition,
|
||||
MatchValue,
|
||||
)
|
||||
from statistics import mean
|
||||
from typing import TypedDict, Literal, cast
|
||||
|
||||
|
||||
@@ -26,29 +25,40 @@ class SearchResult(TypedDict):
|
||||
|
||||
|
||||
class SimpleReranker:
    """Fuse dense and sparse search hits with Reciprocal Rank Fusion (RRF).

    Hits are identified by their ``"content"`` value: two hits with equal
    content are treated as the same document.
    """

    def __init__(self, k: int = 60) -> None:
        """
        Args:
            k: Constant for the RRF formula ``1 / (k + rank)``. Higher values
                reduce the impact of top-ranked items. Default of 60 is
                commonly used in literature.

        Raises:
            ValueError: If ``k`` is negative, which could make ``k + rank``
                zero or negative and break the scoring.
        """
        if k < 0:
            raise ValueError(f"k must be non-negative, got {k}")
        self.k = k

    def _reciprocal_rank_fusion(
        self, dense_results: list[SearchResult], sparse_results: list[SearchResult]
    ) -> dict[str, float]:
        """Compute the fused RRF score for every distinct content string.

        Each input list is treated as one ranker, with rank 1 for its first
        element. A document contributes ``1 / (k + rank)`` once per list, at
        the position of its first occurrence; repeated occurrences inside
        the same list are ignored so duplicates cannot inflate the score.

        Returns:
            Mapping from content to the sum of its per-list contributions.
        """
        rrf_scores: dict[str, float] = {}
        for ranked_list in (dense_results, sparse_results):
            # Tracks what this particular ranker has already contributed.
            seen: set[str] = set()
            for rank, result in enumerate(ranked_list, start=1):
                content = result["content"]
                if content in seen:
                    continue
                seen.add(content)
                rrf_scores[content] = rrf_scores.get(content, 0.0) + 1 / (self.k + rank)
        return rrf_scores

    def _dedupe_and_merge(
        self, dense_results: list[SearchResult], sparse_results: list[SearchResult]
    ) -> dict[str, SearchResult]:
        """Index the hits of both lists by content, keeping the first one seen.

        Dense results are visited before sparse results, so when both lists
        contain the same content the dense hit is the one that is kept.
        """
        results_map: dict[str, SearchResult] = {}
        for ranked_list in (dense_results, sparse_results):
            for result in ranked_list:
                # setdefault leaves an already-registered hit untouched.
                results_map.setdefault(result["content"], result)
        return results_map

    def rerank(
        self,
        dense_results: list[SearchResult],
        sparse_results: list[SearchResult],
        limit: int = 1,
    ) -> list[SearchResult]:
        """Return the top ``limit`` distinct hits ordered by fused RRF score.

        Args:
            dense_results: Hits from the dense retriever, best first.
            sparse_results: Hits from the sparse retriever, best first.
            limit: Maximum number of hits to return. Zero or a negative
                value yields an empty list.

        Returns:
            Shallow copies of the selected hits whose ``"score"`` is replaced
            by the RRF score, sorted from highest to lowest. The input
            dictionaries are not modified.
        """
        if limit <= 0:
            # A negative slice bound would otherwise mean "all but the last N".
            return []

        rrf_scores = self._reciprocal_rank_fusion(dense_results, sparse_results)
        results_map = self._dedupe_and_merge(dense_results, sparse_results)

        reranked_results: list[SearchResult] = []
        for content, result in results_map.items():
            result_copy = result.copy()
            result_copy["score"] = rrf_scores[content]
            reranked_results.append(result_copy)

        # The sort is stable, so ties keep dense-then-sparse insertion order.
        reranked_results.sort(key=lambda x: x["score"], reverse=True)
        return reranked_results[:limit]
|
||||
|
||||
|
||||
class VectorDB:
|
||||
def __init__(
    self,
    qdrant_client: AsyncQdrantClient,
    collection_name: str,
    embedder: Embedder,
    rrf_constant: int = 60,
) -> None:
    """Keep the collaborators and build the reranker used by this instance.

    Args:
        qdrant_client: Async Qdrant client, stored as ``self._client``.
        collection_name: Name of the collection, stored as given.
        embedder: Embedder instance, stored as ``self.embedder``.
        rrf_constant: Passed to ``SimpleReranker`` as its ``k`` argument.
    """
    self.embedder = embedder
    self.collection_name = collection_name
    self._client = qdrant_client
    # Built last: the reranker only depends on the RRF constant.
    self._reranker = SimpleReranker(k=rrf_constant)
|
||||
|
||||
async def configure_collection(self) -> None:
|
||||
if await self._client.collection_exists(self.collection_name):
|
||||
|
||||
Reference in New Issue
Block a user