Compare commits

...

2 Commits

Author SHA1 Message Date
Jerry Liu 7551e9fc0f cr 2025-08-17 15:31:12 -07:00
Jerry Liu dfd83c6039 cr 2025-08-17 15:30:55 -07:00
4 changed files with 60 additions and 6 deletions
@@ -222,11 +222,18 @@ class LlamaCloudCompositeRetriever(BaseRetriever):
rerank_config if rerank_config is not None else self._rerank_config
)
# Inject rerank_top_n into rerank_config if specified
if rerank_top_n is not None and rerank_top_n != OMIT:
if rerank_config is None or rerank_config == OMIT:
rerank_config = ReRankConfig(top_n=rerank_top_n)
else:
# Update existing rerank_config with top_n
rerank_config = rerank_config.copy(update={"top_n": rerank_top_n})
if self._persisted:
result = self._client.retrievers.retrieve(
self.retriever.id, # type: ignore [union-attr]
mode=mode,
rerank_top_n=rerank_top_n,
rerank_config=rerank_config,
query=query_bundle.query_str,
)
@@ -234,7 +241,6 @@ class LlamaCloudCompositeRetriever(BaseRetriever):
result = self._client.retrievers.direct_retrieve(
project_id=self.project.id,
mode=mode,
rerank_top_n=rerank_top_n,
rerank_config=rerank_config,
query=query_bundle.query_str,
pipelines=self.retriever.pipelines, # type: ignore [union-attr]
@@ -263,19 +269,25 @@ class LlamaCloudCompositeRetriever(BaseRetriever):
rerank_config if rerank_config is not None else self._rerank_config
)
# Inject rerank_top_n into rerank_config if specified
if rerank_top_n is not None and rerank_top_n != OMIT:
if rerank_config is None or rerank_config == OMIT:
rerank_config = ReRankConfig(top_n=rerank_top_n)
else:
# Update existing rerank_config with top_n
rerank_config = rerank_config.copy(update={"top_n": rerank_top_n})
if self._persisted:
result = await self._aclient.retrievers.retrieve(
self.retriever.id, # type: ignore [union-attr]
mode=mode,
rerank_config=rerank_config,
rerank_top_n=rerank_top_n,
query=query_bundle.query_str,
)
else:
result = await self._aclient.retrievers.direct_retrieve(
project_id=self.project.id,
mode=mode,
rerank_top_n=rerank_top_n,
rerank_config=rerank_config,
query=query_bundle.query_str,
pipelines=self.retriever.pipelines, # type: ignore [union-attr]
+1 -1
View File
@@ -19,7 +19,7 @@ dev = [
[project]
name = "llama-cloud-services"
version = "0.6.60"
version = "0.6.61"
description = "Tailored SDK clients for LlamaCloud services."
authors = [{name = "Logan Markewich", email = "logan@runllama.ai"}]
requires-python = ">=3.9,<4.0"
+42
View File
@@ -412,12 +412,54 @@ async def test_composite_retriever(index_name: str):
# Assertions to verify the retrieval
assert len(nodes) >= 2
# Test additional rerank_top_n configurations to cover the injection logic
# Test retriever with only rerank_top_n=1 (no existing rerank_config)
retriever_with_rerank_top_n = LlamaCloudCompositeRetriever(
name="composite_retriever_test_2",
project_name=project_name,
api_key=api_key,
base_url=base_url,
create_if_not_exists=True,
mode=CompositeRetrievalMode.FULL,
rerank_top_n=1,
)
retriever_with_rerank_top_n.add_index(index1)
retriever_with_rerank_top_n.add_index(index2)
nodes = retriever_with_rerank_top_n.retrieve("Hello world.")
assert len(nodes) <= 1 # Should be limited to 1 result by rerank_top_n
# Test retriever with both rerank_top_n and custom rerank_config
custom_config = ReRankConfig(top_n=10, model="test-model")
retriever_with_both = LlamaCloudCompositeRetriever(
name="composite_retriever_test_3",
project_name=project_name,
api_key=api_key,
base_url=base_url,
create_if_not_exists=True,
mode=CompositeRetrievalMode.FULL,
rerank_top_n=2,
rerank_config=custom_config,
)
retriever_with_both.add_index(index1)
retriever_with_both.add_index(index2)
nodes = retriever_with_both.retrieve("Hello world.")
assert len(nodes) >= 2 # Should have results from both indices
# Retrieve nodes using the composite retriever
nodes = await retriever.aretrieve("Hello world.")
# Assertions to verify the retrieval
assert len(nodes) >= 2
# Test async retrieve with the rerank_top_n only retriever
nodes = await retriever_with_rerank_top_n.aretrieve("Hello world.")
assert len(nodes) >= 1
# Test async retrieve with the both rerank_top_n and rerank_config retriever
nodes = await retriever_with_both.aretrieve("Hello world.")
assert len(nodes) >= 2
@pytest.mark.skipif(
not base_url or not api_key, reason="No platform base url or api key set"
Generated
+1 -1
View File
@@ -1596,7 +1596,7 @@ wheels = [
[[package]]
name = "llama-cloud-services"
version = "0.6.59"
version = "0.6.61"
source = { editable = "." }
dependencies = [
{ name = "click", version = "8.1.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },