Files
template-workflow-data-extr…/src/extraction_review/process_file.py
T

259 lines
9.0 KiB
Python

import asyncio
import json
import logging
from typing import Annotated, Any, Literal
from llama_cloud import AsyncLlamaCloud
from llama_cloud.types.beta.extracted_data import ExtractedData, InvalidExtractionData
from llama_cloud.types.configuration_response import ExtractV2Parameters
from pydantic import BaseModel
from workflows import Context, Workflow, step
from workflows.events import Event, StartEvent, StopEvent
from workflows.resource import Resource, ResourceConfig
from .clients import agent_name, get_llama_cloud_client, project_id
from .config import EXTRACTED_DATA_COLLECTION, ExtractConfig, get_extraction_schema
logger = logging.getLogger(__name__)
class FileEvent(StartEvent):
file_id: str
file_hash: str | None = None
class Status(Event):
level: Literal["info", "warning", "error"]
message: str
class ExtractJobStartedEvent(Event):
pass
class ExtractedEvent(Event):
data: ExtractedData
class ExtractedInvalidEvent(Event):
data: ExtractedData[dict[str, Any]]
class ExtractionState(BaseModel):
file_id: str | None = None
filename: str | None = None
file_hash: str | None = None
extract_job_id: str | None = None
class ProcessFileWorkflow(Workflow):
"""Extract structured data from a document and save it for review."""
@step()
async def start_extraction(
self,
event: FileEvent,
ctx: Context[ExtractionState],
llama_cloud_client: Annotated[
AsyncLlamaCloud, Resource(get_llama_cloud_client)
],
extract_config: Annotated[
ExtractConfig,
ResourceConfig(
config_file="configs/config.json",
path_selector="extract",
label="Extraction Settings",
description="Configuration for document extraction quality and features",
),
],
) -> ExtractJobStartedEvent:
"""Start extraction job for the document."""
file_id = event.file_id
logger.info(f"Running file {file_id}")
# Get file metadata (v2: files.list returns AsyncPaginator)
try:
file_metadata = None
async for f in llama_cloud_client.files.list(file_ids=[file_id]):
file_metadata = f
break
if file_metadata is None:
raise ValueError(f"File {file_id} not found")
filename = file_metadata.name
except Exception as e:
logger.error(f"Error fetching file metadata {file_id}: {e}", exc_info=True)
ctx.write_event_to_stream(
Status(
level="error",
message=f"Error fetching file metadata {file_id}: {e}",
)
)
raise e
logger.info(f"Extracting data from file {filename}")
ctx.write_event_to_stream(
Status(level="info", message=f"Extracting data from file {filename}")
)
if extract_config.configuration_id:
extract_job = await llama_cloud_client.extract.create(
file_input=file_id,
configuration_id=extract_config.configuration_id,
project_id=project_id,
)
else:
extract_job = await llama_cloud_client.extract.create(
file_input=file_id,
configuration=extract_config.model_dump(
exclude={"configuration_id", "product_type"},
exclude_none=True,
),
project_id=project_id,
)
# Use file_hash from the event (computed by UI from file content)
# or fall back to external_file_id from file metadata for deduplication
file_hash = event.file_hash or file_metadata.external_file_id
async with ctx.store.edit_state() as state:
state.file_id = file_id
state.filename = filename
state.file_hash = file_hash
state.extract_job_id = extract_job.id
return ExtractJobStartedEvent()
@step()
async def complete_extraction(
self,
event: ExtractJobStartedEvent,
ctx: Context[ExtractionState],
llama_cloud_client: Annotated[
AsyncLlamaCloud, Resource(get_llama_cloud_client)
],
extract_config: Annotated[
ExtractConfig,
ResourceConfig(
config_file="configs/config.json",
path_selector="extract",
label="Extraction Settings",
description="Configuration for document extraction quality and features",
),
],
) -> StopEvent:
"""Wait for extraction to complete, validate results, and save for review."""
state = await ctx.store.get_state()
if state.extract_job_id is None:
raise ValueError("Job ID cannot be null when waiting for its completion")
# Wait for extraction job to complete; v2 returns the completed job
# (with extract_result embedded) directly.
job = await llama_cloud_client.extract.wait_for_completion(
state.extract_job_id,
project_id=project_id,
)
# Re-fetch with extract_metadata expansion for field-level metadata
job = await llama_cloud_client.extract.get(
state.extract_job_id,
expand=["extract_metadata"],
project_id=project_id,
)
extracted_event: ExtractedEvent | ExtractedInvalidEvent
try:
logger.info(
f"Extracted data: {json.dumps(job.model_dump(mode='json'), indent=2, default=str)}"
)
# Resolve the schema that governs the extracted data.
if extract_config.configuration_id:
config_resp = await llama_cloud_client.configurations.retrieve(
extract_config.configuration_id,
project_id=project_id,
)
params = config_resp.parameters
if not isinstance(params, ExtractV2Parameters):
raise ValueError(
f"Configuration {extract_config.configuration_id} is not extract_v2"
)
schema_class = get_extraction_schema(dict(params.data_schema))
else:
schema_class = get_extraction_schema(dict(extract_config.data_schema))
data = ExtractedData.from_extract_job(
job=job,
schema=schema_class,
file_name=state.filename,
file_id=state.file_id,
file_hash=state.file_hash,
)
extracted_event = ExtractedEvent(data=data)
except InvalidExtractionData as e:
logger.error(f"Error validating extracted data: {e}", exc_info=True)
extracted_event = ExtractedInvalidEvent(data=e.invalid_item)
except Exception as e:
logger.error(
f"Error extracting data from file {state.filename}: {e}", exc_info=True
)
ctx.write_event_to_stream(
Status(
level="error",
message=f"Error extracting data from file {state.filename}: {e}",
)
)
raise e
ctx.write_event_to_stream(extracted_event)
extracted_data = extracted_event.data
data_dict = extracted_data.model_dump()
if extracted_data.file_hash is not None:
delete_result = await llama_cloud_client.beta.agent_data.delete_by_query(
deployment_name=agent_name or "_public",
collection=EXTRACTED_DATA_COLLECTION,
filter={
"file_hash": {
"eq": extracted_data.file_hash,
},
},
)
if delete_result.deleted_count > 0:
logger.info(
f"Removed {delete_result.deleted_count} existing record(s) "
f"for file {extracted_data.file_name}"
)
item = await llama_cloud_client.beta.agent_data.create(
data=data_dict,
deployment_name=agent_name or "_public",
collection=EXTRACTED_DATA_COLLECTION,
)
logger.info(
f"Recorded extracted data for file {extracted_data.file_name or ''}"
)
ctx.write_event_to_stream(
Status(
level="info",
message=f"Recorded extracted data for file {extracted_data.file_name or ''}",
)
)
return StopEvent(result=item.id)
workflow = ProcessFileWorkflow(timeout=None)
if __name__ == "__main__":
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
logging.basicConfig(level=logging.INFO)
async def main():
file = await get_llama_cloud_client().files.create(
file=Path("test.pdf").open("rb"),
purpose="extract",
)
await workflow.run(start_event=FileEvent(file_id=file.id))
asyncio.run(main())