mirror of
https://github.com/run-llama/template-workflow-data-extraction.git
synced 2026-07-01 21:34:19 -04:00
259 lines
9.0 KiB
Python
259 lines
9.0 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Annotated, Any, Literal
|
|
|
|
from llama_cloud import AsyncLlamaCloud
|
|
from llama_cloud.types.beta.extracted_data import ExtractedData, InvalidExtractionData
|
|
from llama_cloud.types.configuration_response import ExtractV2Parameters
|
|
from pydantic import BaseModel
|
|
from workflows import Context, Workflow, step
|
|
from workflows.events import Event, StartEvent, StopEvent
|
|
from workflows.resource import Resource, ResourceConfig
|
|
|
|
from .clients import agent_name, get_llama_cloud_client, project_id
|
|
from .config import EXTRACTED_DATA_COLLECTION, ExtractConfig, get_extraction_schema
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class FileEvent(StartEvent):
|
|
file_id: str
|
|
file_hash: str | None = None
|
|
|
|
|
|
class Status(Event):
|
|
level: Literal["info", "warning", "error"]
|
|
message: str
|
|
|
|
|
|
class ExtractJobStartedEvent(Event):
|
|
pass
|
|
|
|
|
|
class ExtractedEvent(Event):
|
|
data: ExtractedData
|
|
|
|
|
|
class ExtractedInvalidEvent(Event):
|
|
data: ExtractedData[dict[str, Any]]
|
|
|
|
|
|
class ExtractionState(BaseModel):
|
|
file_id: str | None = None
|
|
filename: str | None = None
|
|
file_hash: str | None = None
|
|
extract_job_id: str | None = None
|
|
|
|
|
|
class ProcessFileWorkflow(Workflow):
|
|
"""Extract structured data from a document and save it for review."""
|
|
|
|
@step()
|
|
async def start_extraction(
|
|
self,
|
|
event: FileEvent,
|
|
ctx: Context[ExtractionState],
|
|
llama_cloud_client: Annotated[
|
|
AsyncLlamaCloud, Resource(get_llama_cloud_client)
|
|
],
|
|
extract_config: Annotated[
|
|
ExtractConfig,
|
|
ResourceConfig(
|
|
config_file="configs/config.json",
|
|
path_selector="extract",
|
|
label="Extraction Settings",
|
|
description="Configuration for document extraction quality and features",
|
|
),
|
|
],
|
|
) -> ExtractJobStartedEvent:
|
|
"""Start extraction job for the document."""
|
|
file_id = event.file_id
|
|
logger.info(f"Running file {file_id}")
|
|
|
|
# Get file metadata (v2: files.list returns AsyncPaginator)
|
|
try:
|
|
file_metadata = None
|
|
async for f in llama_cloud_client.files.list(file_ids=[file_id]):
|
|
file_metadata = f
|
|
break
|
|
if file_metadata is None:
|
|
raise ValueError(f"File {file_id} not found")
|
|
filename = file_metadata.name
|
|
except Exception as e:
|
|
logger.error(f"Error fetching file metadata {file_id}: {e}", exc_info=True)
|
|
ctx.write_event_to_stream(
|
|
Status(
|
|
level="error",
|
|
message=f"Error fetching file metadata {file_id}: {e}",
|
|
)
|
|
)
|
|
raise e
|
|
|
|
logger.info(f"Extracting data from file {filename}")
|
|
ctx.write_event_to_stream(
|
|
Status(level="info", message=f"Extracting data from file {filename}")
|
|
)
|
|
|
|
if extract_config.configuration_id:
|
|
extract_job = await llama_cloud_client.extract.create(
|
|
file_input=file_id,
|
|
configuration_id=extract_config.configuration_id,
|
|
project_id=project_id,
|
|
)
|
|
else:
|
|
extract_job = await llama_cloud_client.extract.create(
|
|
file_input=file_id,
|
|
configuration=extract_config.model_dump(
|
|
exclude={"configuration_id", "product_type"},
|
|
exclude_none=True,
|
|
),
|
|
project_id=project_id,
|
|
)
|
|
|
|
# Use file_hash from the event (computed by UI from file content)
|
|
# or fall back to external_file_id from file metadata for deduplication
|
|
file_hash = event.file_hash or file_metadata.external_file_id
|
|
|
|
async with ctx.store.edit_state() as state:
|
|
state.file_id = file_id
|
|
state.filename = filename
|
|
state.file_hash = file_hash
|
|
state.extract_job_id = extract_job.id
|
|
|
|
return ExtractJobStartedEvent()
|
|
|
|
@step()
|
|
async def complete_extraction(
|
|
self,
|
|
event: ExtractJobStartedEvent,
|
|
ctx: Context[ExtractionState],
|
|
llama_cloud_client: Annotated[
|
|
AsyncLlamaCloud, Resource(get_llama_cloud_client)
|
|
],
|
|
extract_config: Annotated[
|
|
ExtractConfig,
|
|
ResourceConfig(
|
|
config_file="configs/config.json",
|
|
path_selector="extract",
|
|
label="Extraction Settings",
|
|
description="Configuration for document extraction quality and features",
|
|
),
|
|
],
|
|
) -> StopEvent:
|
|
"""Wait for extraction to complete, validate results, and save for review."""
|
|
state = await ctx.store.get_state()
|
|
if state.extract_job_id is None:
|
|
raise ValueError("Job ID cannot be null when waiting for its completion")
|
|
|
|
# Wait for extraction job to complete; v2 returns the completed job
|
|
# (with extract_result embedded) directly.
|
|
job = await llama_cloud_client.extract.wait_for_completion(
|
|
state.extract_job_id,
|
|
project_id=project_id,
|
|
)
|
|
# Re-fetch with extract_metadata expansion for field-level metadata
|
|
job = await llama_cloud_client.extract.get(
|
|
state.extract_job_id,
|
|
expand=["extract_metadata"],
|
|
project_id=project_id,
|
|
)
|
|
|
|
extracted_event: ExtractedEvent | ExtractedInvalidEvent
|
|
try:
|
|
logger.info(
|
|
f"Extracted data: {json.dumps(job.model_dump(mode='json'), indent=2, default=str)}"
|
|
)
|
|
# Resolve the schema that governs the extracted data.
|
|
if extract_config.configuration_id:
|
|
config_resp = await llama_cloud_client.configurations.retrieve(
|
|
extract_config.configuration_id,
|
|
project_id=project_id,
|
|
)
|
|
params = config_resp.parameters
|
|
if not isinstance(params, ExtractV2Parameters):
|
|
raise ValueError(
|
|
f"Configuration {extract_config.configuration_id} is not extract_v2"
|
|
)
|
|
schema_class = get_extraction_schema(dict(params.data_schema))
|
|
else:
|
|
schema_class = get_extraction_schema(dict(extract_config.data_schema))
|
|
|
|
data = ExtractedData.from_extract_job(
|
|
job=job,
|
|
schema=schema_class,
|
|
file_name=state.filename,
|
|
file_id=state.file_id,
|
|
file_hash=state.file_hash,
|
|
)
|
|
extracted_event = ExtractedEvent(data=data)
|
|
except InvalidExtractionData as e:
|
|
logger.error(f"Error validating extracted data: {e}", exc_info=True)
|
|
extracted_event = ExtractedInvalidEvent(data=e.invalid_item)
|
|
except Exception as e:
|
|
logger.error(
|
|
f"Error extracting data from file {state.filename}: {e}", exc_info=True
|
|
)
|
|
ctx.write_event_to_stream(
|
|
Status(
|
|
level="error",
|
|
message=f"Error extracting data from file {state.filename}: {e}",
|
|
)
|
|
)
|
|
raise e
|
|
|
|
ctx.write_event_to_stream(extracted_event)
|
|
|
|
extracted_data = extracted_event.data
|
|
data_dict = extracted_data.model_dump()
|
|
if extracted_data.file_hash is not None:
|
|
delete_result = await llama_cloud_client.beta.agent_data.delete_by_query(
|
|
deployment_name=agent_name or "_public",
|
|
collection=EXTRACTED_DATA_COLLECTION,
|
|
filter={
|
|
"file_hash": {
|
|
"eq": extracted_data.file_hash,
|
|
},
|
|
},
|
|
)
|
|
if delete_result.deleted_count > 0:
|
|
logger.info(
|
|
f"Removed {delete_result.deleted_count} existing record(s) "
|
|
f"for file {extracted_data.file_name}"
|
|
)
|
|
item = await llama_cloud_client.beta.agent_data.create(
|
|
data=data_dict,
|
|
deployment_name=agent_name or "_public",
|
|
collection=EXTRACTED_DATA_COLLECTION,
|
|
)
|
|
logger.info(
|
|
f"Recorded extracted data for file {extracted_data.file_name or ''}"
|
|
)
|
|
ctx.write_event_to_stream(
|
|
Status(
|
|
level="info",
|
|
message=f"Recorded extracted data for file {extracted_data.file_name or ''}",
|
|
)
|
|
)
|
|
return StopEvent(result=item.id)
|
|
|
|
|
|
workflow = ProcessFileWorkflow(timeout=None)
|
|
|
|
if __name__ == "__main__":
|
|
from pathlib import Path
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
load_dotenv()
|
|
logging.basicConfig(level=logging.INFO)
|
|
|
|
async def main():
|
|
file = await get_llama_cloud_client().files.create(
|
|
file=Path("test.pdf").open("rb"),
|
|
purpose="extract",
|
|
)
|
|
await workflow.run(start_event=FileEvent(file_id=file.id))
|
|
|
|
asyncio.run(main())
|