mirror of
https://github.com/run-llama/template-workflow-extract-reconcile-invoice.git
synced 2026-06-30 22:17:53 -04:00
extract-reconcile-invoice: fix review UI and contract indexing (#276)
This commit is contained in:
@@ -10,7 +10,9 @@ can either carry an inline snapshot OR point at a saved platform config.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import jsonref
|
||||
from llama_cloud.types.beta.split_category import SplitCategory
|
||||
from llama_cloud.types.classify_v2_parameters import ClassifyV2Parameters, Rule
|
||||
from llama_cloud.types.extract_v2_parameters import ExtractV2Parameters
|
||||
@@ -144,8 +146,14 @@ class Discrepancy(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class InvoiceWithReconciliation(InvoiceExtractionSchema):
|
||||
"""Invoice data with reconciliation information"""
|
||||
class Reconciliation(BaseModel):
|
||||
"""Contract-linkage fields appended to an invoice by the reconcile step.
|
||||
|
||||
Separate from `InvoiceExtractionSchema` because these are not part of the
|
||||
LlamaCloud extract output; they're produced downstream by
|
||||
`reconcile_with_contract`. The review UI's display schema is built by
|
||||
overlaying these properties onto the extract schema.
|
||||
"""
|
||||
|
||||
matched_contract_id: str | None = Field(
|
||||
default=None, description="ID of the matched contract file in LlamaCloud"
|
||||
@@ -164,3 +172,31 @@ class InvoiceWithReconciliation(InvoiceExtractionSchema):
|
||||
default=None,
|
||||
description="List of discrepancies found between invoice and contract",
|
||||
)
|
||||
|
||||
|
||||
class InvoiceWithReconciliation(Reconciliation, InvoiceExtractionSchema):
|
||||
"""Invoice data plus reconciliation overlay (the full agent_data shape)."""
|
||||
|
||||
|
||||
# Reconciliation overlay properties resolve once at import: the model is fixed,
|
||||
# so its JSON schema is constant. `replace_refs` flattens the `$defs/Discrepancy`
|
||||
# reference embedded in `discrepancies.items`.
|
||||
_RECON_OVERLAY_PROPS: dict[str, Any] = jsonref.replace_refs(
|
||||
Reconciliation.model_json_schema(), proxies=False
|
||||
)["properties"]
|
||||
|
||||
|
||||
def build_review_schema(extract_data_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Compose the review-UI display schema from the extract data schema.
|
||||
|
||||
The `reconcile_with_contract` step appends fields beyond what extract
|
||||
returns, so the displayed schema is the extract schema with reconciliation
|
||||
properties overlaid. Reconciliation keys go first so the contract-match
|
||||
verdict renders above invoice fields. The merge order here is the source
|
||||
of truth for ordering, NOT the Pydantic class graph.
|
||||
"""
|
||||
schema = jsonref.replace_refs(extract_data_schema, proxies=False)
|
||||
return {
|
||||
**schema,
|
||||
"properties": {**_RECON_OVERLAY_PROPS, **schema["properties"]},
|
||||
}
|
||||
|
||||
@@ -1,187 +1,191 @@
|
||||
"""
|
||||
Workflow for indexing contract documents into LlamaCloud Index for retrieval.
|
||||
|
||||
Each contract is parsed via LlamaParse and the resulting markdown is upserted
|
||||
into the contracts pipeline. The reconcile step retrieves and feeds that
|
||||
markdown to the matching LLM, so it has to be human-readable text, not raw
|
||||
PDF bytes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import httpx
|
||||
from llama_cloud import AsyncLlamaCloud
|
||||
from llama_cloud.types.pipelines import CloudDocumentCreateParam
|
||||
from pydantic import BaseModel
|
||||
from workflows import Context, Workflow, step
|
||||
from workflows.events import Event, StartEvent, StopEvent
|
||||
from workflows.resource import Resource, ResourceConfig
|
||||
|
||||
from .clients import get_contracts_pipeline_id, get_llama_cloud_client, project_id
|
||||
from .config import ParseConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContractFileEvent(StartEvent):
|
||||
"""Event to start contract indexing with a file ID"""
|
||||
"""Event to start contract indexing with file IDs."""
|
||||
|
||||
file_ids: list[str]
|
||||
|
||||
|
||||
class DownloadContractEvent(Event):
|
||||
"""Event to trigger contract download"""
|
||||
class IndexContractFileEvent(Event):
|
||||
"""Per-file fan-out event."""
|
||||
|
||||
file_id: str
|
||||
|
||||
|
||||
class ContractDownloadedEvent(Event):
|
||||
"""Event indicating contract has been downloaded"""
|
||||
class ContractParseStartedEvent(Event):
|
||||
"""Event indicating a contract parse job has started."""
|
||||
|
||||
file_id: str
|
||||
file_path: str
|
||||
filename: str
|
||||
parse_job_id: str
|
||||
|
||||
|
||||
class ContractIndexedEvent(Event):
|
||||
"""Event indicating a single contract has been indexed"""
|
||||
"""Event indicating a single contract has been parsed and indexed."""
|
||||
|
||||
file_id: str
|
||||
filename: str
|
||||
|
||||
|
||||
class Status(Event):
|
||||
"""Event to show toast notifications in the UI"""
|
||||
"""Toast notification for the UI."""
|
||||
|
||||
level: Literal["info", "warning", "error"]
|
||||
message: str
|
||||
|
||||
|
||||
class ContractIndexState(BaseModel):
|
||||
"""State for contract indexing workflow"""
|
||||
|
||||
total_files: int = 0
|
||||
# Store file info keyed by file_id
|
||||
file_paths: dict[str, str] = {}
|
||||
filenames: dict[str, str] = {}
|
||||
|
||||
|
||||
class IndexContractWorkflow(Workflow):
|
||||
"""
|
||||
Workflow to download and index a contract document into LlamaCloud Index.
|
||||
"""
|
||||
"""Parse contracts via LlamaParse and index their markdown for retrieval."""
|
||||
|
||||
@step()
|
||||
async def start_indexing(
|
||||
self, event: ContractFileEvent, ctx: Context[ContractIndexState]
|
||||
) -> DownloadContractEvent | None:
|
||||
"""Initialize the workflow with multiple file IDs and fan out to parallel downloads"""
|
||||
) -> IndexContractFileEvent | None:
|
||||
"""Fan out to one parse-and-index task per file."""
|
||||
logger.info(f"Starting contract indexing for {len(event.file_ids)} files")
|
||||
async with ctx.store.edit_state() as state:
|
||||
state.total_files = len(event.file_ids)
|
||||
|
||||
# Fan out: emit one download event per file
|
||||
for file_id in event.file_ids:
|
||||
ctx.send_event(DownloadContractEvent(file_id=file_id))
|
||||
|
||||
ctx.send_event(IndexContractFileEvent(file_id=file_id))
|
||||
return None
|
||||
|
||||
@step(num_workers=4)
|
||||
async def download_contract(
|
||||
self, event: DownloadContractEvent, ctx: Context[ContractIndexState]
|
||||
) -> ContractDownloadedEvent:
|
||||
"""Download the contract file from LlamaCloud storage (runs in parallel)"""
|
||||
async def start_contract_parse(
|
||||
self,
|
||||
event: IndexContractFileEvent,
|
||||
ctx: Context[ContractIndexState],
|
||||
llama_cloud_client: Annotated[
|
||||
AsyncLlamaCloud, Resource(get_llama_cloud_client)
|
||||
],
|
||||
parse_config: Annotated[
|
||||
ParseConfig,
|
||||
ResourceConfig(
|
||||
config_file="configs/config.json",
|
||||
path_selector="parse",
|
||||
label="Contract Parsing Settings",
|
||||
description="Parse settings used when indexing contract documents",
|
||||
),
|
||||
],
|
||||
) -> ContractParseStartedEvent:
|
||||
"""Start a LlamaParse job for the contract."""
|
||||
file_id = event.file_id
|
||||
|
||||
client = get_llama_cloud_client()
|
||||
file_metadata = None
|
||||
async for f in client.files.list(file_ids=[file_id], project_id=project_id):
|
||||
async for f in llama_cloud_client.files.list(
|
||||
file_ids=[file_id], project_id=project_id
|
||||
):
|
||||
file_metadata = f
|
||||
break
|
||||
if file_metadata is None:
|
||||
raise ValueError(f"File {file_id} not found")
|
||||
file_url = await client.files.get(file_id)
|
||||
|
||||
temp_dir = tempfile.gettempdir()
|
||||
filename = file_metadata.name
|
||||
file_path = os.path.join(temp_dir, filename)
|
||||
|
||||
logger.info(f"Downloading contract {filename} from {file_url.url}")
|
||||
logger.info(f"Parsing contract {filename}")
|
||||
ctx.write_event_to_stream(
|
||||
Status(level="info", message=f"Downloading contract: {filename}")
|
||||
Status(level="info", message=f"Parsing contract: {filename}")
|
||||
)
|
||||
|
||||
client = httpx.AsyncClient()
|
||||
async with client.stream("GET", file_url.url) as response:
|
||||
with open(file_path, "wb") as f:
|
||||
async for chunk in response.aiter_bytes():
|
||||
f.write(chunk)
|
||||
parse_kwargs = parse_config.model_dump(
|
||||
exclude={"configuration_id", "product_type"},
|
||||
exclude_none=True,
|
||||
)
|
||||
parse_job = await llama_cloud_client.parsing.create(
|
||||
file_id=file_id,
|
||||
project_id=project_id,
|
||||
**parse_kwargs,
|
||||
)
|
||||
|
||||
logger.info(f"Downloaded contract to {file_path}")
|
||||
async with ctx.store.edit_state() as state:
|
||||
state.file_paths[file_id] = file_path
|
||||
state.filenames[file_id] = filename
|
||||
|
||||
return ContractDownloadedEvent(
|
||||
file_id=file_id, file_path=file_path, filename=filename
|
||||
return ContractParseStartedEvent(
|
||||
file_id=file_id,
|
||||
filename=filename,
|
||||
parse_job_id=parse_job.id,
|
||||
)
|
||||
|
||||
@step(num_workers=4)
|
||||
async def index_contract(
|
||||
self, event: ContractDownloadedEvent, ctx: Context[ContractIndexState]
|
||||
async def index_parsed_contract(
|
||||
self,
|
||||
event: ContractParseStartedEvent,
|
||||
ctx: Context[ContractIndexState],
|
||||
llama_cloud_client: Annotated[
|
||||
AsyncLlamaCloud, Resource(get_llama_cloud_client)
|
||||
],
|
||||
) -> ContractIndexedEvent:
|
||||
"""Index the contract document into LlamaCloud Index (runs in parallel)"""
|
||||
"""Wait for LlamaParse completion and upsert the contract markdown."""
|
||||
file_id = event.file_id
|
||||
file_path = event.file_path
|
||||
filename = event.filename
|
||||
|
||||
await llama_cloud_client.parsing.wait_for_completion(
|
||||
event.parse_job_id,
|
||||
project_id=project_id,
|
||||
)
|
||||
parse_result = await llama_cloud_client.parsing.get(
|
||||
event.parse_job_id,
|
||||
expand=["markdown"],
|
||||
project_id=project_id,
|
||||
)
|
||||
pages = parse_result.markdown.pages if parse_result.markdown else []
|
||||
markdown = "\n\n".join(page.markdown for page in pages if page.success)
|
||||
if not markdown:
|
||||
raise ValueError(f"Parse produced no markdown for contract {filename}")
|
||||
|
||||
logger.info(f"Indexing contract {filename}")
|
||||
ctx.write_event_to_stream(
|
||||
Status(level="info", message=f"Indexing contract: {filename}")
|
||||
)
|
||||
|
||||
# Create a document with metadata
|
||||
file_content = Path(file_path).read_text(errors="ignore")
|
||||
document = CloudDocumentCreateParam(
|
||||
text=file_content,
|
||||
text=markdown,
|
||||
metadata={
|
||||
"filename": filename,
|
||||
"file_id": file_id,
|
||||
"document_type": "contract",
|
||||
},
|
||||
)
|
||||
|
||||
# Get the contracts pipeline and upsert the document
|
||||
client = get_llama_cloud_client()
|
||||
pipeline_id = await get_contracts_pipeline_id()
|
||||
await client.pipelines.documents.upsert(
|
||||
await llama_cloud_client.pipelines.documents.upsert(
|
||||
pipeline_id=pipeline_id,
|
||||
body=[document],
|
||||
)
|
||||
|
||||
logger.info(f"Successfully indexed contract {filename}")
|
||||
ctx.write_event_to_stream(
|
||||
Status(
|
||||
level="info",
|
||||
message=f"Successfully indexed contract: {filename}",
|
||||
)
|
||||
)
|
||||
|
||||
return ContractIndexedEvent(file_id=file_id, filename=filename)
|
||||
|
||||
@step()
|
||||
async def collect_results(
|
||||
self, event: ContractIndexedEvent, ctx: Context[ContractIndexState]
|
||||
) -> StopEvent | None:
|
||||
"""Collect all indexed contracts and return final results (fan-in)"""
|
||||
"""Wait for every contract to finish, then return aggregated results."""
|
||||
state = await ctx.store.get_state()
|
||||
|
||||
# Collect all ContractIndexedEvent events - one for each file
|
||||
events = ctx.collect_events(event, [ContractIndexedEvent] * state.total_files)
|
||||
|
||||
if events is None:
|
||||
# Not all files have been indexed yet
|
||||
return None
|
||||
|
||||
# All files have been indexed, return aggregated results
|
||||
results = [{"file_id": ev.file_id, "filename": ev.filename} for ev in events]
|
||||
|
||||
logger.info(f"Successfully indexed all {len(results)} contracts")
|
||||
ctx.write_event_to_stream(
|
||||
Status(
|
||||
@@ -189,21 +193,22 @@ class IndexContractWorkflow(Workflow):
|
||||
message=f"Successfully indexed all {len(results)} contracts",
|
||||
)
|
||||
)
|
||||
|
||||
return StopEvent(result={"contracts": results, "total": len(results)})
|
||||
|
||||
|
||||
workflow = IndexContractWorkflow(timeout=None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
async def main():
|
||||
# Example usage - upload a contract and index it
|
||||
file = await get_llama_cloud_client().files.create(
|
||||
file=Path("sample_contract.pdf").open("rb"),
|
||||
purpose="extract",
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
from typing import Annotated, Any
|
||||
|
||||
import jsonref
|
||||
from llama_cloud.types.configuration_response import ExtractV2Parameters
|
||||
from workflows import Workflow, step
|
||||
from workflows.events import StartEvent, StopEvent
|
||||
from workflows.resource import ResourceConfig
|
||||
|
||||
from .clients import get_contracts_pipeline_id, get_llama_cloud_client, project_id
|
||||
from .config import EXTRACTED_DATA_COLLECTION, ExtractConfig
|
||||
from .config import (
|
||||
EXTRACTED_DATA_COLLECTION,
|
||||
ExtractConfig,
|
||||
build_review_schema,
|
||||
)
|
||||
|
||||
|
||||
class MetadataResponse(StopEvent):
|
||||
@@ -54,10 +57,9 @@ class MetadataWorkflow(Workflow):
|
||||
else:
|
||||
schema_dict = dict(extract_config.data_schema)
|
||||
|
||||
json_schema = jsonref.replace_refs(schema_dict, proxies=False)
|
||||
contracts_pipeline_id = await get_contracts_pipeline_id()
|
||||
return MetadataResponse(
|
||||
json_schema=json_schema,
|
||||
json_schema=build_review_schema(schema_dict),
|
||||
extracted_data_collection=EXTRACTED_DATA_COLLECTION,
|
||||
contracts_pipeline_id=contracts_pipeline_id,
|
||||
)
|
||||
|
||||
+31
-3
@@ -1,5 +1,5 @@
|
||||
import pytest
|
||||
from extraction_review.config import EXTRACTED_DATA_COLLECTION
|
||||
from extraction_review.config import EXTRACTED_DATA_COLLECTION, Reconciliation
|
||||
from extraction_review.index_contract import ContractFileEvent
|
||||
from extraction_review.index_contract import workflow as index_contract_workflow
|
||||
from extraction_review.metadata_workflow import MetadataResponse
|
||||
@@ -31,8 +31,12 @@ async def test_index_contract_workflow(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
fake: FakeLlamaCloudServer,
|
||||
) -> None:
|
||||
"""Regression: index_contract previously called stale SDK methods
|
||||
(files.get_file / files.read_file_content) and died on any contract upload.
|
||||
"""The contract indexer must run LlamaParse and upsert real markdown text.
|
||||
|
||||
Earlier versions read PDFs via `Path(...).read_text(errors="ignore")` and
|
||||
upserted the resulting binary garbage into the contracts pipeline, so the
|
||||
matching LLM saw stream-decoded bytes instead of contract content. Lock
|
||||
the fix: the upserted document must not be raw PDF bytes.
|
||||
"""
|
||||
monkeypatch.setenv("LLAMA_CLOUD_API_KEY", "fake-api-key")
|
||||
file_id = fake.files.preload(path="tests/files/test.pdf")
|
||||
@@ -43,6 +47,22 @@ async def test_index_contract_workflow(
|
||||
assert result["total"] == 1
|
||||
assert result["contracts"][0]["file_id"] == file_id
|
||||
|
||||
# The fake stores upserted documents under pipelines._documents; pull the
|
||||
# contract-tagged ones and verify the indexed text is parsed markdown,
|
||||
# not the raw PDF stream.
|
||||
indexed_texts = [
|
||||
doc.text
|
||||
for store in fake.pipelines._documents.values()
|
||||
for doc in store.values()
|
||||
if doc.metadata.get("file_id") == file_id
|
||||
]
|
||||
assert indexed_texts, "expected at least one upserted contract document"
|
||||
for text in indexed_texts:
|
||||
assert text, "indexed contract text is empty"
|
||||
assert not text.startswith("%PDF"), (
|
||||
f"contract was indexed as raw PDF bytes: {text[:40]!r}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metadata_workflow(
|
||||
@@ -56,3 +76,11 @@ async def test_metadata_workflow(
|
||||
assert isinstance(result.json_schema, dict)
|
||||
assert "properties" in result.json_schema
|
||||
assert result.contracts_pipeline_id
|
||||
|
||||
# Reconciliation overlay: presence guards the regression that dropped the
|
||||
# linkage fields entirely; ordering guards a "cleanup" that silently moves
|
||||
# the verdict to the bottom of the form.
|
||||
properties = result.json_schema["properties"]
|
||||
reconciliation_fields = set(Reconciliation.model_fields)
|
||||
assert reconciliation_fields.issubset(properties.keys())
|
||||
assert next(iter(properties)) in reconciliation_fields
|
||||
|
||||
Reference in New Issue
Block a user