Compare commits

...

2 Commits

Author SHA1 Message Date
Logan Markewich ff94e1a277 add changeset 2025-11-17 16:00:55 -06:00
Logan Markewich c756dae145 propagate retrieval metadata to retrieved nodes 2025-11-17 15:58:51 -06:00
4 changed files with 62 additions and 11 deletions
+6
View File
@@ -0,0 +1,6 @@
---
"llama-cloud-services": patch
"llama-cloud-services-py": patch
---
Propagate retrieval metadata to retrieved nodes
+18 -2
View File
@@ -258,6 +258,7 @@ def page_screenshot_nodes_to_node_with_score(
client: LlamaCloud,
raw_image_nodes: Optional[List[PageScreenshotNodeWithScore]],
project_id: str,
metadata: Optional[dict] = None,
) -> List[NodeWithScore]:
if not raw_image_nodes:
return []
@@ -273,6 +274,7 @@ def page_screenshot_nodes_to_node_with_score(
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
image_node_metadata: Dict[str, Any] = {
**(raw_image_node.node.metadata or {}),
**(metadata or {}),
"file_id": raw_image_node.node.file_id,
"page_index": raw_image_node.node.page_index,
}
@@ -289,6 +291,7 @@ def image_nodes_to_node_with_score(
client: LlamaCloud,
raw_image_nodes: Optional[List[PageScreenshotNodeWithScore]],
project_id: str,
metadata: Optional[dict] = None,
) -> List[NodeWithScore]:
"""
Legacy method to alias page_screenshot_nodes_to_node_with_score.
@@ -297,7 +300,10 @@ def image_nodes_to_node_with_score(
return []
return page_screenshot_nodes_to_node_with_score(
client=client, raw_image_nodes=raw_image_nodes, project_id=project_id
client=client,
raw_image_nodes=raw_image_nodes,
project_id=project_id,
metadata=metadata,
)
@@ -305,6 +311,7 @@ def page_figure_nodes_to_node_with_score(
client: LlamaCloud,
raw_figure_nodes: Optional[List[PageFigureNodeWithScore]],
project_id: str,
metadata: Optional[dict] = None,
) -> List[NodeWithScore]:
if not raw_figure_nodes:
return []
@@ -321,6 +328,7 @@ def page_figure_nodes_to_node_with_score(
figure_base64 = base64.b64encode(figure_bytes).decode("utf-8")
figure_node_metadata: Dict[str, Any] = {
**(raw_figure_node.node.metadata or {}),
**(metadata or {}),
"file_id": raw_figure_node.node.file_id,
"page_index": raw_figure_node.node.page_index,
"figure_name": raw_figure_node.node.figure_name,
@@ -337,6 +345,7 @@ async def apage_screenshot_nodes_to_node_with_score(
client: AsyncLlamaCloud,
raw_image_nodes: Optional[List[PageScreenshotNodeWithScore]],
project_id: str,
metadata: Optional[dict] = None,
) -> List[NodeWithScore]:
if not raw_image_nodes:
return []
@@ -357,6 +366,7 @@ async def apage_screenshot_nodes_to_node_with_score(
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
image_node_metadata: Dict[str, Any] = {
**(raw_image_node.node.metadata or {}),
**(metadata or {}),
"file_id": raw_image_node.node.file_id,
"page_index": raw_image_node.node.page_index,
}
@@ -372,6 +382,7 @@ async def aimage_nodes_to_node_with_score(
client: AsyncLlamaCloud,
raw_image_nodes: Optional[List[PageScreenshotNodeWithScore]],
project_id: str,
metadata: Optional[dict] = None,
) -> List[NodeWithScore]:
"""
Legacy method to alias apage_screenshot_nodes_to_node_with_score.
@@ -380,7 +391,10 @@ async def aimage_nodes_to_node_with_score(
return []
return await apage_screenshot_nodes_to_node_with_score(
client=client, raw_image_nodes=raw_image_nodes, project_id=project_id
client=client,
raw_image_nodes=raw_image_nodes,
project_id=project_id,
metadata=metadata,
)
@@ -388,6 +402,7 @@ async def apage_figure_nodes_to_node_with_score(
client: AsyncLlamaCloud,
raw_figure_nodes: Optional[List[PageFigureNodeWithScore]],
project_id: str,
metadata: Optional[dict] = None,
) -> List[NodeWithScore]:
if not raw_figure_nodes:
return []
@@ -409,6 +424,7 @@ async def apage_figure_nodes_to_node_with_score(
figure_base64 = base64.b64encode(figure_bytes).decode("utf-8")
figure_node_metadata: Dict[str, Any] = {
**(raw_figure_node.node.metadata or {}),
**(metadata or {}),
"file_id": raw_figure_node.node.file_id,
"page_index": raw_figure_node.node.page_index,
"figure_name": raw_figure_node.node.figure_name,
+25 -8
View File
@@ -129,11 +129,12 @@ class LlamaCloudRetriever(BaseRetriever):
)
def _result_nodes_to_node_with_score(
self, result_nodes: List[TextNodeWithScore]
self, result_nodes: List[TextNodeWithScore], metadata: Optional[dict] = None
) -> List[NodeWithScore]:
nodes = []
for res in result_nodes:
text_node = TextNode.parse_obj(res.node.dict())
text_node = TextNode.model_validate(res.node.dict())
text_node.metadata.update(metadata or {})
nodes.append(NodeWithScore(node=text_node, score=res.score))
return nodes
@@ -161,17 +162,25 @@ class LlamaCloudRetriever(BaseRetriever):
search_filters_inference_schema=search_filters_inference_schema,
)
result_nodes = self._result_nodes_to_node_with_score(results.retrieval_nodes)
result_nodes = self._result_nodes_to_node_with_score(
results.retrieval_nodes, metadata=results.metadata
)
if self._retrieve_page_screenshot_nodes:
result_nodes.extend(
page_screenshot_nodes_to_node_with_score(
self._client, results.image_nodes, self.project.id
self._client,
results.image_nodes,
self.project.id,
metadata=results.metadata,
)
)
if self._retrieve_page_figure_nodes:
result_nodes.extend(
page_figure_nodes_to_node_with_score(
self._client, results.page_figure_nodes, self.project.id
self._client,
results.page_figure_nodes,
self.project.id,
metadata=results.metadata,
)
)
@@ -200,17 +209,25 @@ class LlamaCloudRetriever(BaseRetriever):
search_filters_inference_schema=search_filters_inference_schema,
)
result_nodes = self._result_nodes_to_node_with_score(results.retrieval_nodes)
result_nodes = self._result_nodes_to_node_with_score(
results.retrieval_nodes, metadata=results.metadata
)
if self._retrieve_page_screenshot_nodes:
result_nodes.extend(
await apage_screenshot_nodes_to_node_with_score(
self._aclient, results.image_nodes, self.project.id
self._aclient,
results.image_nodes,
self.project.id,
metadata=results.metadata,
)
)
if self._retrieve_page_figure_nodes:
result_nodes.extend(
await apage_figure_nodes_to_node_with_score(
self._aclient, results.page_figure_nodes, self.project.id
self._aclient,
results.page_figure_nodes,
self.project.id,
metadata=results.metadata,
)
)
@@ -34,12 +34,15 @@ export class LlamaCloudRetriever extends BaseRetriever {
private resultNodesToNodeWithScore(
nodes: TextNodeWithScore[],
metadata: Record<string, string> | undefined,
): NodeWithScore[] {
return nodes.map((node: TextNodeWithScore) => {
const textNode = jsonToNode(node.node, ObjectType.TEXT);
const extra_metadata = metadata || {};
textNode.metadata = {
...textNode.metadata,
...node.node.extra_info, // append LlamaCloud extra_info to node metadata (file_name, pipeline_id, etc.)
...extra_metadata, // append retrieval-level metadata
};
return {
// Currently LlamaCloud only supports text nodes
@@ -63,6 +66,7 @@ export class LlamaCloudRetriever extends BaseRetriever {
private async pageScreenshotNodesToNodeWithScore(
nodes: PageScreenshotNodeWithScore[] | undefined,
projectId: string,
metadata: Record<string, string> | undefined,
): Promise<NodeWithScore[]> {
if (!nodes || nodes.length === 0) return [];
@@ -87,6 +91,7 @@ export class LlamaCloudRetriever extends BaseRetriever {
image: base64,
metadata: {
...(n.node.metadata ?? {}),
...(metadata || {}),
file_id: n.node.file_id,
page_index: n.node.page_index,
},
@@ -101,6 +106,7 @@ export class LlamaCloudRetriever extends BaseRetriever {
private async pageFigureNodesToNodeWithScore(
nodes: PageFigureNodeWithScore[] | undefined,
projectId: string,
metadata: Record<string, string> | undefined,
): Promise<NodeWithScore[]> {
if (!nodes || nodes.length === 0) return [];
@@ -126,6 +132,7 @@ export class LlamaCloudRetriever extends BaseRetriever {
image: base64,
metadata: {
...(n.node.metadata ?? {}),
...(metadata || {}),
file_id: n.node.file_id,
page_index: n.node.page_index,
figure_name: n.node.figure_name,
@@ -222,7 +229,10 @@ export class LlamaCloudRetriever extends BaseRetriever {
},
});
const textNodes = this.resultNodesToNodeWithScore(results.retrieval_nodes);
const textNodes = this.resultNodesToNodeWithScore(
results.retrieval_nodes,
results.metadata,
);
const needScreenshots = (this.retrieveParams as RetrievalParams)
.retrieve_page_screenshot_nodes;
@@ -240,12 +250,14 @@ export class LlamaCloudRetriever extends BaseRetriever {
? this.pageScreenshotNodesToNodeWithScore(
results.image_nodes,
projectId,
results.metadata,
)
: Promise.resolve([] as NodeWithScore[]),
needFigures
? this.pageFigureNodesToNodeWithScore(
results.page_figure_nodes,
projectId,
results.metadata,
)
: Promise.resolve([] as NodeWithScore[]),
]);