mirror of
https://github.com/run-llama/llama_cloud_services.git
synced 2026-07-01 21:44:37 -04:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7551e9fc0f | |||
| dfd83c6039 |
@@ -222,11 +222,18 @@ class LlamaCloudCompositeRetriever(BaseRetriever):
|
||||
rerank_config if rerank_config is not None else self._rerank_config
|
||||
)
|
||||
|
||||
# Inject rerank_top_n into rerank_config if specified
|
||||
if rerank_top_n is not None and rerank_top_n != OMIT:
|
||||
if rerank_config is None or rerank_config == OMIT:
|
||||
rerank_config = ReRankConfig(top_n=rerank_top_n)
|
||||
else:
|
||||
# Update existing rerank_config with top_n
|
||||
rerank_config = rerank_config.copy(update={"top_n": rerank_top_n})
|
||||
|
||||
if self._persisted:
|
||||
result = self._client.retrievers.retrieve(
|
||||
self.retriever.id, # type: ignore [union-attr]
|
||||
mode=mode,
|
||||
rerank_top_n=rerank_top_n,
|
||||
rerank_config=rerank_config,
|
||||
query=query_bundle.query_str,
|
||||
)
|
||||
@@ -234,7 +241,6 @@ class LlamaCloudCompositeRetriever(BaseRetriever):
|
||||
result = self._client.retrievers.direct_retrieve(
|
||||
project_id=self.project.id,
|
||||
mode=mode,
|
||||
rerank_top_n=rerank_top_n,
|
||||
rerank_config=rerank_config,
|
||||
query=query_bundle.query_str,
|
||||
pipelines=self.retriever.pipelines, # type: ignore [union-attr]
|
||||
@@ -263,19 +269,25 @@ class LlamaCloudCompositeRetriever(BaseRetriever):
|
||||
rerank_config if rerank_config is not None else self._rerank_config
|
||||
)
|
||||
|
||||
# Inject rerank_top_n into rerank_config if specified
|
||||
if rerank_top_n is not None and rerank_top_n != OMIT:
|
||||
if rerank_config is None or rerank_config == OMIT:
|
||||
rerank_config = ReRankConfig(top_n=rerank_top_n)
|
||||
else:
|
||||
# Update existing rerank_config with top_n
|
||||
rerank_config = rerank_config.copy(update={"top_n": rerank_top_n})
|
||||
|
||||
if self._persisted:
|
||||
result = await self._aclient.retrievers.retrieve(
|
||||
self.retriever.id, # type: ignore [union-attr]
|
||||
mode=mode,
|
||||
rerank_config=rerank_config,
|
||||
rerank_top_n=rerank_top_n,
|
||||
query=query_bundle.query_str,
|
||||
)
|
||||
else:
|
||||
result = await self._aclient.retrievers.direct_retrieve(
|
||||
project_id=self.project.id,
|
||||
mode=mode,
|
||||
rerank_top_n=rerank_top_n,
|
||||
rerank_config=rerank_config,
|
||||
query=query_bundle.query_str,
|
||||
pipelines=self.retriever.pipelines, # type: ignore [union-attr]
|
||||
|
||||
+1
-1
@@ -19,7 +19,7 @@ dev = [
|
||||
|
||||
[project]
|
||||
name = "llama-cloud-services"
|
||||
version = "0.6.60"
|
||||
version = "0.6.61"
|
||||
description = "Tailored SDK clients for LlamaCloud services."
|
||||
authors = [{name = "Logan Markewich", email = "logan@runllama.ai"}]
|
||||
requires-python = ">=3.9,<4.0"
|
||||
|
||||
@@ -412,12 +412,54 @@ async def test_composite_retriever(index_name: str):
|
||||
# Assertions to verify the retrieval
|
||||
assert len(nodes) >= 2
|
||||
|
||||
# Test additional rerank_top_n configurations to cover the injection logic
|
||||
|
||||
# Test retriever with only rerank_top_n=1 (no existing rerank_config)
|
||||
retriever_with_rerank_top_n = LlamaCloudCompositeRetriever(
|
||||
name="composite_retriever_test_2",
|
||||
project_name=project_name,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
create_if_not_exists=True,
|
||||
mode=CompositeRetrievalMode.FULL,
|
||||
rerank_top_n=1,
|
||||
)
|
||||
retriever_with_rerank_top_n.add_index(index1)
|
||||
retriever_with_rerank_top_n.add_index(index2)
|
||||
nodes = retriever_with_rerank_top_n.retrieve("Hello world.")
|
||||
assert len(nodes) <= 1 # Should be limited to 1 result by rerank_top_n
|
||||
|
||||
# Test retriever with both rerank_top_n and custom rerank_config
|
||||
custom_config = ReRankConfig(top_n=10, model="test-model")
|
||||
retriever_with_both = LlamaCloudCompositeRetriever(
|
||||
name="composite_retriever_test_3",
|
||||
project_name=project_name,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
create_if_not_exists=True,
|
||||
mode=CompositeRetrievalMode.FULL,
|
||||
rerank_top_n=2,
|
||||
rerank_config=custom_config,
|
||||
)
|
||||
retriever_with_both.add_index(index1)
|
||||
retriever_with_both.add_index(index2)
|
||||
nodes = retriever_with_both.retrieve("Hello world.")
|
||||
assert len(nodes) >= 2 # Should have results from both indices
|
||||
|
||||
# Retrieve nodes using the composite retriever
|
||||
nodes = await retriever.aretrieve("Hello world.")
|
||||
|
||||
# Assertions to verify the retrieval
|
||||
assert len(nodes) >= 2
|
||||
|
||||
# Test async retrieve with the rerank_top_n only retriever
|
||||
nodes = await retriever_with_rerank_top_n.aretrieve("Hello world.")
|
||||
assert len(nodes) >= 1
|
||||
|
||||
# Test async retrieve with the both rerank_top_n and rerank_config retriever
|
||||
nodes = await retriever_with_both.aretrieve("Hello world.")
|
||||
assert len(nodes) >= 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not base_url or not api_key, reason="No platform base url or api key set"
|
||||
|
||||
Generated
+1
-1
@@ -1596,7 +1596,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "llama-cloud-services"
|
||||
version = "0.6.59"
|
||||
version = "0.6.61"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "click", version = "8.1.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
|
||||
|
||||
Reference in New Issue
Block a user