extract-reconcile-invoice: fix review UI and contract indexing (#276)

This commit is contained in:
Adrian Lyjak
2026-04-29 15:32:36 -04:00
committed by GitHub
parent 01ddebd78f
commit 0a2139028c
4 changed files with 160 additions and 89 deletions
+38 -2
View File
@@ -10,7 +10,9 @@ can either carry an inline snapshot OR point at a saved platform config.
import logging
import os
from typing import Any
import jsonref
from llama_cloud.types.beta.split_category import SplitCategory
from llama_cloud.types.classify_v2_parameters import ClassifyV2Parameters, Rule
from llama_cloud.types.extract_v2_parameters import ExtractV2Parameters
@@ -144,8 +146,14 @@ class Discrepancy(BaseModel):
)
class InvoiceWithReconciliation(InvoiceExtractionSchema):
"""Invoice data with reconciliation information"""
class Reconciliation(BaseModel):
"""Contract-linkage fields appended to an invoice by the reconcile step.
Separate from `InvoiceExtractionSchema` because these are not part of the
LlamaCloud extract output; they're produced downstream by
`reconcile_with_contract`. The review UI's display schema is built by
overlaying these properties onto the extract schema.
"""
matched_contract_id: str | None = Field(
default=None, description="ID of the matched contract file in LlamaCloud"
@@ -164,3 +172,31 @@ class InvoiceWithReconciliation(InvoiceExtractionSchema):
default=None,
description="List of discrepancies found between invoice and contract",
)
class InvoiceWithReconciliation(Reconciliation, InvoiceExtractionSchema):
    """Invoice data plus reconciliation overlay (the full agent_data shape).

    Combines the extracted invoice fields (`InvoiceExtractionSchema`) with the
    contract-linkage fields (`Reconciliation`) through multiple inheritance.

    The order of the base classes here does not control how fields are ordered
    in the review UI; `build_review_schema` decides that ordering explicitly.
    """
# Reconciliation overlay properties, resolved once at import time: the
# `Reconciliation` model is fixed, so its JSON schema is constant.
# `replace_refs(..., proxies=False)` flattens the `$defs/Discrepancy` reference
# embedded in `discrepancies.items`, so only the `properties` mapping needs to
# be kept. Consumed by `build_review_schema`, which places these keys ahead of
# the extract schema's own properties.
_RECON_OVERLAY_PROPS: dict[str, Any] = jsonref.replace_refs(
    Reconciliation.model_json_schema(), proxies=False
)["properties"]
def build_review_schema(extract_data_schema: dict[str, Any]) -> dict[str, Any]:
    """Compose the review-UI display schema from the extract data schema.

    The `reconcile_with_contract` step appends fields beyond what extract
    returns, so the displayed schema is the extract schema with reconciliation
    properties overlaid. Reconciliation keys go first so the contract-match
    verdict renders above invoice fields. The merge order here is the source
    of truth for ordering, NOT the Pydantic class graph.

    Args:
        extract_data_schema: JSON schema of the extract output. `$ref`
            entries are resolved before merging. A schema without a
            `properties` key is treated as having no properties.

    Returns:
        A new schema dict carrying every top-level key of the resolved extract
        schema, with `properties` replaced by the merged mapping. If a key
        exists in both sources, the extract schema's definition is used but
        the key keeps its leading (reconciliation) position.
    """
    # Local import keeps this function self-contained; `copy` is stdlib.
    import copy

    schema = jsonref.replace_refs(extract_data_schema, proxies=False)
    # Tolerate schemas with a missing or null `properties` entry instead of
    # raising KeyError; the overlay alone is still a usable display schema.
    extract_properties = schema.get("properties") or {}
    # Deep-copy the module-level overlay so callers that mutate the returned
    # schema cannot corrupt the shared constant for later calls.
    overlay_properties = copy.deepcopy(_RECON_OVERLAY_PROPS)
    return {
        **schema,
        "properties": {**overlay_properties, **extract_properties},
    }
+85 -80
View File
@@ -1,187 +1,191 @@
"""
Workflow for indexing contract documents into LlamaCloud Index for retrieval.
Each contract is parsed via LlamaParse and the resulting markdown is upserted
into the contracts pipeline. The reconcile step retrieves and feeds that
markdown to the matching LLM, so it has to be human-readable text, not raw
PDF bytes.
"""
import logging
import os
import tempfile
from pathlib import Path
from typing import Literal
from typing import Annotated, Literal
import httpx
from llama_cloud import AsyncLlamaCloud
from llama_cloud.types.pipelines import CloudDocumentCreateParam
from pydantic import BaseModel
from workflows import Context, Workflow, step
from workflows.events import Event, StartEvent, StopEvent
from workflows.resource import Resource, ResourceConfig
from .clients import get_contracts_pipeline_id, get_llama_cloud_client, project_id
from .config import ParseConfig
logger = logging.getLogger(__name__)
class ContractFileEvent(StartEvent):
"""Event to start contract indexing with a file ID"""
"""Event to start contract indexing with file IDs."""
file_ids: list[str]
class DownloadContractEvent(Event):
"""Event to trigger contract download"""
class IndexContractFileEvent(Event):
"""Per-file fan-out event."""
file_id: str
class ContractDownloadedEvent(Event):
"""Event indicating contract has been downloaded"""
class ContractParseStartedEvent(Event):
"""Event indicating a contract parse job has started."""
file_id: str
file_path: str
filename: str
parse_job_id: str
class ContractIndexedEvent(Event):
"""Event indicating a single contract has been indexed"""
"""Event indicating a single contract has been parsed and indexed."""
file_id: str
filename: str
class Status(Event):
"""Event to show toast notifications in the UI"""
"""Toast notification for the UI."""
level: Literal["info", "warning", "error"]
message: str
class ContractIndexState(BaseModel):
"""State for contract indexing workflow"""
total_files: int = 0
# Store file info keyed by file_id
file_paths: dict[str, str] = {}
filenames: dict[str, str] = {}
class IndexContractWorkflow(Workflow):
"""
Workflow to download and index a contract document into LlamaCloud Index.
"""
"""Parse contracts via LlamaParse and index their markdown for retrieval."""
@step()
async def start_indexing(
self, event: ContractFileEvent, ctx: Context[ContractIndexState]
) -> DownloadContractEvent | None:
"""Initialize the workflow with multiple file IDs and fan out to parallel downloads"""
) -> IndexContractFileEvent | None:
"""Fan out to one parse-and-index task per file."""
logger.info(f"Starting contract indexing for {len(event.file_ids)} files")
async with ctx.store.edit_state() as state:
state.total_files = len(event.file_ids)
# Fan out: emit one download event per file
for file_id in event.file_ids:
ctx.send_event(DownloadContractEvent(file_id=file_id))
ctx.send_event(IndexContractFileEvent(file_id=file_id))
return None
@step(num_workers=4)
async def download_contract(
self, event: DownloadContractEvent, ctx: Context[ContractIndexState]
) -> ContractDownloadedEvent:
"""Download the contract file from LlamaCloud storage (runs in parallel)"""
async def start_contract_parse(
self,
event: IndexContractFileEvent,
ctx: Context[ContractIndexState],
llama_cloud_client: Annotated[
AsyncLlamaCloud, Resource(get_llama_cloud_client)
],
parse_config: Annotated[
ParseConfig,
ResourceConfig(
config_file="configs/config.json",
path_selector="parse",
label="Contract Parsing Settings",
description="Parse settings used when indexing contract documents",
),
],
) -> ContractParseStartedEvent:
"""Start a LlamaParse job for the contract."""
file_id = event.file_id
client = get_llama_cloud_client()
file_metadata = None
async for f in client.files.list(file_ids=[file_id], project_id=project_id):
async for f in llama_cloud_client.files.list(
file_ids=[file_id], project_id=project_id
):
file_metadata = f
break
if file_metadata is None:
raise ValueError(f"File {file_id} not found")
file_url = await client.files.get(file_id)
temp_dir = tempfile.gettempdir()
filename = file_metadata.name
file_path = os.path.join(temp_dir, filename)
logger.info(f"Downloading contract {filename} from {file_url.url}")
logger.info(f"Parsing contract {filename}")
ctx.write_event_to_stream(
Status(level="info", message=f"Downloading contract: {filename}")
Status(level="info", message=f"Parsing contract: {filename}")
)
client = httpx.AsyncClient()
async with client.stream("GET", file_url.url) as response:
with open(file_path, "wb") as f:
async for chunk in response.aiter_bytes():
f.write(chunk)
parse_kwargs = parse_config.model_dump(
exclude={"configuration_id", "product_type"},
exclude_none=True,
)
parse_job = await llama_cloud_client.parsing.create(
file_id=file_id,
project_id=project_id,
**parse_kwargs,
)
logger.info(f"Downloaded contract to {file_path}")
async with ctx.store.edit_state() as state:
state.file_paths[file_id] = file_path
state.filenames[file_id] = filename
return ContractDownloadedEvent(
file_id=file_id, file_path=file_path, filename=filename
return ContractParseStartedEvent(
file_id=file_id,
filename=filename,
parse_job_id=parse_job.id,
)
@step(num_workers=4)
async def index_contract(
self, event: ContractDownloadedEvent, ctx: Context[ContractIndexState]
async def index_parsed_contract(
self,
event: ContractParseStartedEvent,
ctx: Context[ContractIndexState],
llama_cloud_client: Annotated[
AsyncLlamaCloud, Resource(get_llama_cloud_client)
],
) -> ContractIndexedEvent:
"""Index the contract document into LlamaCloud Index (runs in parallel)"""
"""Wait for LlamaParse completion and upsert the contract markdown."""
file_id = event.file_id
file_path = event.file_path
filename = event.filename
await llama_cloud_client.parsing.wait_for_completion(
event.parse_job_id,
project_id=project_id,
)
parse_result = await llama_cloud_client.parsing.get(
event.parse_job_id,
expand=["markdown"],
project_id=project_id,
)
pages = parse_result.markdown.pages if parse_result.markdown else []
markdown = "\n\n".join(page.markdown for page in pages if page.success)
if not markdown:
raise ValueError(f"Parse produced no markdown for contract {filename}")
logger.info(f"Indexing contract {filename}")
ctx.write_event_to_stream(
Status(level="info", message=f"Indexing contract: {filename}")
)
# Create a document with metadata
file_content = Path(file_path).read_text(errors="ignore")
document = CloudDocumentCreateParam(
text=file_content,
text=markdown,
metadata={
"filename": filename,
"file_id": file_id,
"document_type": "contract",
},
)
# Get the contracts pipeline and upsert the document
client = get_llama_cloud_client()
pipeline_id = await get_contracts_pipeline_id()
await client.pipelines.documents.upsert(
await llama_cloud_client.pipelines.documents.upsert(
pipeline_id=pipeline_id,
body=[document],
)
logger.info(f"Successfully indexed contract {filename}")
ctx.write_event_to_stream(
Status(
level="info",
message=f"Successfully indexed contract: {filename}",
)
)
return ContractIndexedEvent(file_id=file_id, filename=filename)
@step()
async def collect_results(
self, event: ContractIndexedEvent, ctx: Context[ContractIndexState]
) -> StopEvent | None:
"""Collect all indexed contracts and return final results (fan-in)"""
"""Wait for every contract to finish, then return aggregated results."""
state = await ctx.store.get_state()
# Collect all ContractIndexedEvent events - one for each file
events = ctx.collect_events(event, [ContractIndexedEvent] * state.total_files)
if events is None:
# Not all files have been indexed yet
return None
# All files have been indexed, return aggregated results
results = [{"file_id": ev.file_id, "filename": ev.filename} for ev in events]
logger.info(f"Successfully indexed all {len(results)} contracts")
ctx.write_event_to_stream(
Status(
@@ -189,21 +193,22 @@ class IndexContractWorkflow(Workflow):
message=f"Successfully indexed all {len(results)} contracts",
)
)
return StopEvent(result={"contracts": results, "total": len(results)})
workflow = IndexContractWorkflow(timeout=None)
if __name__ == "__main__":
import asyncio
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
logging.basicConfig(level=logging.INFO)
async def main():
# Example usage - upload a contract and index it
file = await get_llama_cloud_client().files.create(
file=Path("sample_contract.pdf").open("rb"),
purpose="extract",
+6 -4
View File
@@ -1,13 +1,16 @@
from typing import Annotated, Any
import jsonref
from llama_cloud.types.configuration_response import ExtractV2Parameters
from workflows import Workflow, step
from workflows.events import StartEvent, StopEvent
from workflows.resource import ResourceConfig
from .clients import get_contracts_pipeline_id, get_llama_cloud_client, project_id
from .config import EXTRACTED_DATA_COLLECTION, ExtractConfig
from .config import (
EXTRACTED_DATA_COLLECTION,
ExtractConfig,
build_review_schema,
)
class MetadataResponse(StopEvent):
@@ -54,10 +57,9 @@ class MetadataWorkflow(Workflow):
else:
schema_dict = dict(extract_config.data_schema)
json_schema = jsonref.replace_refs(schema_dict, proxies=False)
contracts_pipeline_id = await get_contracts_pipeline_id()
return MetadataResponse(
json_schema=json_schema,
json_schema=build_review_schema(schema_dict),
extracted_data_collection=EXTRACTED_DATA_COLLECTION,
contracts_pipeline_id=contracts_pipeline_id,
)
+31 -3
View File
@@ -1,5 +1,5 @@
import pytest
from extraction_review.config import EXTRACTED_DATA_COLLECTION
from extraction_review.config import EXTRACTED_DATA_COLLECTION, Reconciliation
from extraction_review.index_contract import ContractFileEvent
from extraction_review.index_contract import workflow as index_contract_workflow
from extraction_review.metadata_workflow import MetadataResponse
@@ -31,8 +31,12 @@ async def test_index_contract_workflow(
monkeypatch: pytest.MonkeyPatch,
fake: FakeLlamaCloudServer,
) -> None:
"""Regression: index_contract previously called stale SDK methods
(files.get_file / files.read_file_content) and died on any contract upload.
"""The contract indexer must run LlamaParse and upsert real markdown text.
Earlier versions read PDFs via `Path(...).read_text(errors="ignore")` and
upserted the resulting binary garbage into the contracts pipeline, so the
matching LLM saw stream-decoded bytes instead of contract content. Lock
the fix: the upserted document must not be raw PDF bytes.
"""
monkeypatch.setenv("LLAMA_CLOUD_API_KEY", "fake-api-key")
file_id = fake.files.preload(path="tests/files/test.pdf")
@@ -43,6 +47,22 @@ async def test_index_contract_workflow(
assert result["total"] == 1
assert result["contracts"][0]["file_id"] == file_id
# The fake stores upserted documents under pipelines._documents; pull the
# contract-tagged ones and verify the indexed text is parsed markdown,
# not the raw PDF stream.
indexed_texts = [
doc.text
for store in fake.pipelines._documents.values()
for doc in store.values()
if doc.metadata.get("file_id") == file_id
]
assert indexed_texts, "expected at least one upserted contract document"
for text in indexed_texts:
assert text, "indexed contract text is empty"
assert not text.startswith("%PDF"), (
f"contract was indexed as raw PDF bytes: {text[:40]!r}"
)
@pytest.mark.asyncio
async def test_metadata_workflow(
@@ -56,3 +76,11 @@ async def test_metadata_workflow(
assert isinstance(result.json_schema, dict)
assert "properties" in result.json_schema
assert result.contracts_pipeline_id
# Reconciliation overlay: presence guards the regression that dropped the
# linkage fields entirely; ordering guards a "cleanup" that silently moves
# the verdict to the bottom of the form.
properties = result.json_schema["properties"]
reconciliation_fields = set(Reconciliation.model_fields)
assert reconciliation_fields.issubset(properties.keys())
assert next(iter(properties)) in reconciliation_fields