feat: copier update child templates from data-extraction v0.5.0 (#241)

This commit is contained in:
Adrian Lyjak
2026-03-18 19:11:41 -04:00
committed by GitHub
parent e161961a0a
commit cd3411d0a6
15 changed files with 1829 additions and 465 deletions
+1 -1
View File
@@ -1,3 +1,3 @@
# Changes here will be overwritten by Copier; NEVER EDIT MANUALLY
_commit: v0.4.0
_commit: v0.5.0
_src_path: https://github.com/run-llama/template-workflow-data-extraction
+1206
View File
File diff suppressed because it is too large Load Diff
+3 -2
View File
@@ -5,8 +5,9 @@ description = "Extracts data"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"llama-cloud-services>=0.6.69",
"llama-index-workflows>=2.2.0,<3.0.0",
"llama-cloud>=1.3.0",
"json-schema-to-pydantic>=0.4.8",
"llama-index-workflows>=2.16.0,<3.0.0",
"python-dotenv>=1.1.0",
"jsonref>=1.1.0",
"click>=8.2.1,<8.3.0",
+6 -71
View File
@@ -1,29 +1,11 @@
import functools
import os
from typing import Any
import httpx
from llama_cloud_services import ExtractionAgent, LlamaExtract
from llama_cloud.core.api_error import ApiError
from llama_cloud_services.beta.agent_data import AsyncAgentDataClient, ExtractedData
from llama_cloud_services.beta.classifier.client import ClassifyClient
from llama_cloud.client import AsyncLlamaCloud
import logging
import os
from extraction_review.config import (
EXTRACT_CONFIG,
EXTRACTED_DATA_COLLECTION,
EXTRACTION_AGENT_NAME,
USE_REMOTE_EXTRACTION_SCHEMA,
ExtractionSchema,
)
from llama_cloud import AsyncLlamaCloud
logger = logging.getLogger(__name__)
# deployed agents may infer their name from the deployment name
# Note: Make sure that an agent deployment with this name actually exists
# otherwise calls to get or set data will fail. You may need to adjust the `or `
# name for development
agent_name = os.getenv("LLAMA_DEPLOY_DEPLOYMENT_NAME")
# required for all llama cloud calls
api_key = os.getenv("LLAMA_CLOUD_API_KEY")
@@ -32,57 +14,10 @@ base_url = os.getenv("LLAMA_CLOUD_BASE_URL")
project_id = os.getenv("LLAMA_DEPLOY_PROJECT_ID")
@functools.lru_cache(maxsize=None)
def get_extract_agent() -> ExtractionAgent:
extract_api = LlamaExtract(
api_key=api_key, base_url=base_url, project_id=project_id
)
try:
existing = extract_api.get_agent(EXTRACTION_AGENT_NAME)
if not USE_REMOTE_EXTRACTION_SCHEMA:
existing.data_schema = ExtractionSchema
existing.config = EXTRACT_CONFIG
return existing
except ApiError as e:
if e.status_code == 404:
if USE_REMOTE_EXTRACTION_SCHEMA:
logger.warning(
"Extraction agent does not exist, creating a new one from the local schema"
)
return extract_api.create_agent(
name=EXTRACTION_AGENT_NAME,
data_schema=ExtractionSchema,
config=EXTRACT_CONFIG,
)
else:
raise
@functools.lru_cache(maxsize=None)
def get_data_client() -> AsyncAgentDataClient:
return AsyncAgentDataClient(
deployment_name=agent_name,
collection=EXTRACTED_DATA_COLLECTION,
type=ExtractedData[Any],
client=get_llama_cloud_client(),
)
@functools.lru_cache(maxsize=None)
def get_llama_cloud_client():
def get_llama_cloud_client() -> AsyncLlamaCloud:
"""Cloud services connection for file storage and processing."""
return AsyncLlamaCloud(
api_key=api_key,
base_url=base_url,
token=api_key,
httpx_client=httpx.AsyncClient(
timeout=60, headers={"Project-Id": project_id} if project_id else None
),
)
@functools.lru_cache(maxsize=None)
def get_classifier_client():
return ClassifyClient(
client=get_llama_cloud_client(),
project_id=project_id,
default_headers={"Project-Id": project_id} if project_id else {},
)
+77 -32
View File
@@ -1,30 +1,35 @@
"""
For simple configuration of the extraction review application, just customize this file.
Configuration for the extraction review application.
If you need more control, feel free to edit the rest of the application
Configuration is loaded from configs/config.json via ResourceConfig.
The unified config contains both extraction settings and the JSON schema.
Extraction can run in two modes, controlled by the "extraction_agent_id" field
in configs/config.json:
- Local (default): extraction_agent_id is null. Uses the json_schema and
settings defined in config.json directly via extraction.run().
- Remote agent: extraction_agent_id is set to a LlamaCloud extraction agent
ID. Uses extraction.jobs.extract(extraction_agent_id=...) which delegates
schema and settings to the remote agent. The local json_schema and settings
in config.json are ignored — both extraction and the metadata workflow fetch
the schema directly from the remote agent.
"""
from __future__ import annotations
import os
from typing import Type
import logging
from typing import Any, Literal
from llama_cloud import ExtractConfig
from llama_cloud_services.extract import ExtractMode
from pydantic import BaseModel, Field
# If you change this to true, the schema and extraction configuration will be fetched from the remote extraction agent
# rather than using the ExtractionSchema and configuration defined below.
USE_REMOTE_EXTRACTION_SCHEMA: bool = False
# The name of the extraction agent to use. Prefers the name of this deployment when deployed to isolate environments.
# Note that the application will create a new agent from the below ExtractionSchema if the extraction agent does not yet exist.
EXTRACTION_AGENT_NAME: str = (
os.getenv("LLAMA_DEPLOY_DEPLOYMENT_NAME") or "sec-filing-extraction"
)
# The name of the collection to use for storing extracted data. This will be qualified by the agent name.
# When developing locally, this will use the _public collection (shared within the project), otherwise agent
# data is isolated to each agent
EXTRACTED_DATA_COLLECTION: str = "sec-filing-extraction"
from .json_util import create_union_schema as create_union_schema
from .json_util import get_extraction_schema as get_extraction_schema
logger = logging.getLogger(__name__)
# The name of the collection to use for storing extracted data.
EXTRACTED_DATA_COLLECTION: str = "sec-filing-extraction"
# SEC Filing Classification Types
SEC_FILING_TYPES = ["10-K", "10-Q", "8-K", "other"]
@@ -346,18 +351,58 @@ FILING_SCHEMAS = {
}
# This is only used if USE_REMOTE_EXTRACTION_SCHEMA is False.
EXTRACT_CONFIG = ExtractConfig(
extraction_mode=ExtractMode.PREMIUM,
system_prompt=None,
# advanced. Only compatible with Premium mode.
citation_bbox=True,
use_reasoning=False,
cite_sources=True,
confidence_scores=False,
)
class ExtractSettings(BaseModel):
extraction_mode: Literal["FAST", "PREMIUM", "MULTIMODAL"]
system_prompt: str | None = None
citation_bbox: bool = False
use_reasoning: bool = False
cite_sources: bool = False
confidence_scores: bool = False
SCHEMA: Type[BaseModel] | None = (
None if USE_REMOTE_EXTRACTION_SCHEMA else ExtractionSchema
)
class ExtractConfig(BaseModel):
json_schema: dict[str, Any]
settings: ExtractSettings
# Set this to a LlamaCloud extraction agent ID to use a remote agent's
# schema and settings instead of the local json_schema/settings above.
# When set, extraction uses extraction.jobs.extract(extraction_agent_id=...)
# and the local settings are ignored for extraction.
extraction_agent_id: str | None = None
class JsonSchema(BaseModel):
type: str = "object"
properties: dict[str, Any] = {}
required: list[str] = []
def to_dict(self) -> dict[str, Any]:
return self.model_dump(exclude_none=True)
class ClassifyRule(BaseModel):
"""Classify rule, with type (rule target) and description (rule description)"""
type: str
description: str
class ClassifyParsingConfig(BaseModel):
"""Parsing config for Classify"""
lang: str = Field(description="two-letter ISO 639 language code", default="en")
max_pages: int | None = None
target_pages: list[int] | None = None
class ClassifySettings(BaseModel):
"""Extra settings for Classify"""
mode: Literal["FAST", "MULTIMODAL"] = "FAST"
parsing_config: ClassifyParsingConfig = ClassifyParsingConfig()
class ClassifyConfig(BaseModel):
"""Classify configuration, with rules and settings"""
rules: list[ClassifyRule] = []
settings: ClassifySettings = ClassifySettings()
+163
View File
@@ -0,0 +1,163 @@
"""Utilities for working with JSON schemas."""
import hashlib
import json
import logging
from functools import lru_cache
from typing import Any
from json_schema_to_pydantic import create_model
from pydantic import BaseModel
logger = logging.getLogger(__name__)
def _hash_schema(json_schema: dict[str, Any]) -> str:
"""Create a stable hash of a JSON schema for caching."""
schema_str = json.dumps(json_schema, sort_keys=True)
return hashlib.sha256(schema_str.encode()).hexdigest()
@lru_cache(maxsize=16)
def _get_cached_model(schema_hash: str, schema_json: str) -> type[BaseModel]:
"""Get or create a Pydantic model from a JSON schema, cached by hash."""
schema = json.loads(schema_json)
return create_model(schema)
def get_extraction_schema(
json_schema: dict[str, Any],
discriminator_field: str | None = None,
discriminator_value: str | None = None,
) -> type[BaseModel]:
"""Convert a JSON schema to a Pydantic model for validating extraction results.
Args:
json_schema: A JSON Schema object from config.
discriminator_field: Field name to identify document type (e.g., "document_type").
discriminator_value: The value for the discriminator field (e.g., "10-K").
Returns:
A Pydantic model class for validation.
"""
schema = json_schema
if discriminator_field is not None:
if discriminator_value is None:
raise ValueError(
"discriminator_value is required when discriminator_field is set"
)
schema = _add_discriminator_to_schema(
schema, discriminator_field, discriminator_value
)
schema_hash = _hash_schema(schema)
schema_json = json.dumps(schema, sort_keys=True)
return _get_cached_model(schema_hash, schema_json)
def _add_discriminator_to_schema(
schema: dict[str, Any],
discriminator_field: str,
discriminator_value: str,
) -> dict[str, Any]:
"""Add a discriminator field with a default value to a schema."""
properties = schema.get("properties", {})
required = schema.get("required", [])
new_properties = {
discriminator_field: {
"type": "string",
"default": discriminator_value,
"description": "Type of document that was extracted",
},
**properties,
}
new_required = list(required)
if discriminator_field not in new_required:
new_required = [discriminator_field] + new_required
return {
**schema,
"properties": new_properties,
"required": new_required,
}
def _schemas_are_equal(schema1: dict[str, Any], schema2: dict[str, Any]) -> bool:
"""Check if two JSON schemas are structurally equal."""
return json.dumps(schema1, sort_keys=True) == json.dumps(schema2, sort_keys=True)
def _merge_property_schemas(
existing: dict[str, Any], new: dict[str, Any]
) -> dict[str, Any]:
"""Merge two property schemas, creating anyOf for conflicting types."""
if _schemas_are_equal(existing, new):
return existing
if "anyOf" in existing:
for variant in existing["anyOf"]:
if _schemas_are_equal(variant, new):
return existing
return {"anyOf": existing["anyOf"] + [new]}
return {"anyOf": [existing, new]}
def create_union_schema(
schemas: dict[str, dict[str, Any]],
discriminator_field: str = "document_type",
) -> dict[str, Any]:
"""Create a union JSON schema from multiple extraction schemas.
Merges all properties from input schemas into a single flat schema,
adding a discriminator field.
"""
schemas_with_discriminator = [
name
for name, schema in schemas.items()
if discriminator_field in schema.get("properties", {})
]
if schemas_with_discriminator:
logger.warning(
f"Discriminator field '{discriminator_field}' found in schemas: "
f"{', '.join(schemas_with_discriminator)}. "
f"It will be replaced with the union discriminator."
)
all_properties: dict[str, Any] = {}
for schema in schemas.values():
for prop_name, prop_def in schema.get("properties", {}).items():
if prop_name == discriminator_field:
continue
if prop_name not in all_properties:
all_properties[prop_name] = prop_def
else:
all_properties[prop_name] = _merge_property_schemas(
all_properties[prop_name], prop_def
)
schema_list = list(schemas.values())
common_required: set[str] = set()
if schema_list:
common_required = set(schema_list[0].get("required", []))
for schema in schema_list[1:]:
common_required &= set(schema.get("required", []))
common_required.discard(discriminator_field)
required_fields = [discriminator_field] + sorted(common_required)
return {
"type": "object",
"properties": {
discriminator_field: {
"type": "string",
"enum": list(schemas.keys()),
"description": "Type of document that was extracted",
},
**all_properties,
},
"required": required_fields,
}
+65 -16
View File
@@ -1,34 +1,83 @@
from typing import Any
from typing import Annotated, Any
from workflows import Workflow, step
from workflows.events import StartEvent, StopEvent
from workflows.resource import Resource, ResourceConfig
import jsonref
from .config import EXTRACTED_DATA_COLLECTION, ExtractConfig, create_union_schema
from .config import EXTRACTED_DATA_COLLECTION, FILING_SCHEMAS
DISCRIMINATOR_FIELD = "document_type"
class MetadataResponse(StopEvent):
json_schema: dict[str, Any]
schemas: dict[str, dict[str, Any]]
discriminator_field: str
extracted_data_collection: str
async def get_presentation_schema(
extract_10k: Annotated[
ExtractConfig,
ResourceConfig(
config_file="configs/config.json",
path_selector="extract-10k",
label="10-K Extraction",
),
],
extract_10q: Annotated[
ExtractConfig,
ResourceConfig(
config_file="configs/config.json",
path_selector="extract-10q",
label="10-Q Extraction",
),
],
extract_8k: Annotated[
ExtractConfig,
ResourceConfig(
config_file="configs/config.json",
path_selector="extract-8k",
label="8-K Extraction",
),
],
extract_other: Annotated[
ExtractConfig,
ResourceConfig(
config_file="configs/config.json",
path_selector="extract-other",
label="Other Extraction",
),
],
) -> dict[str, Any]:
schemas = {
"10-K": extract_10k.json_schema,
"10-Q": extract_10q.json_schema,
"8-K": extract_8k.json_schema,
"other": extract_other.json_schema,
}
union = create_union_schema(schemas, discriminator_field=DISCRIMINATOR_FIELD)
return {
"json_schema": union,
"schemas": schemas,
"discriminator_field": DISCRIMINATOR_FIELD,
}
class MetadataWorkflow(Workflow):
"""
Simple single step workflow to expose configuration to the UI, such as all JSON schemas and collection name.
"""
"""Provide extraction schema and configuration to the workflow editor."""
@step
async def get_metadata(self, _: StartEvent) -> MetadataResponse:
# Convert all filing schemas to JSON schemas
schemas = {}
for filing_type, schema_class in FILING_SCHEMAS.items():
json_schema = schema_class.model_json_schema()
# Resolve any $ref references
json_schema = jsonref.replace_refs(json_schema, proxies=False)
schemas[filing_type] = json_schema
async def get_metadata(
self,
_: StartEvent,
presentation: Annotated[dict[str, Any], Resource(get_presentation_schema)],
) -> MetadataResponse:
"""Return the data schemas and storage settings for the review interface."""
return MetadataResponse(
schemas=schemas,
json_schema=presentation["json_schema"],
schemas=presentation["schemas"],
discriminator_field=presentation["discriminator_field"],
extracted_data_collection=EXTRACTED_DATA_COLLECTION,
)
+256 -224
View File
@@ -1,41 +1,39 @@
import asyncio
import hashlib
import json
import logging
import os
from pathlib import Path
import tempfile
from typing import Any, Literal
from typing import Annotated, Any, Literal
import httpx
from llama_cloud import ClassificationResult, ExtractRun
from llama_cloud.types import ClassifierRule, ClassifyParsingConfiguration
from llama_cloud_services.extract import SourceText
from llama_cloud_services.beta.agent_data import ExtractedData, InvalidExtractionData
from llama_cloud import AsyncLlamaCloud
from llama_cloud.types.beta.extracted_data import ExtractedData, InvalidExtractionData
from pydantic import BaseModel
from workflows import Context, Workflow, step
from workflows.events import Event, StartEvent, StopEvent
from workflows.resource import Resource, ResourceConfig
from .clients import (
get_classifier_client,
get_llama_cloud_client,
get_data_client,
get_extract_agent,
from .clients import agent_name, get_llama_cloud_client, project_id
from .config import (
EXTRACTED_DATA_COLLECTION,
ClassifyConfig,
ExtractConfig,
get_extraction_schema,
)
from .config import FILING_SCHEMAS
logger = logging.getLogger(__name__)
DISCRIMINATOR_FIELD = "document_type"
# Mapping from classify rule types to config.json extract section keys
EXTRACT_CONFIG_KEYS = {
"10-K": "extract-10k",
"10-Q": "extract-10q",
"8-K": "extract-8k",
"other": "extract-other",
}
class FileEvent(StartEvent):
file_id: str
class DownloadFileEvent(Event):
pass
class FileDownloadedEvent(Event):
pass
file_hash: str | None = None
class ClassifyFileEvent(Event):
@@ -53,6 +51,10 @@ class Status(Event):
message: str
class ExtractJobStartedEvent(Event):
pass
class ExtractedEvent(Event):
data: ExtractedData
@@ -63,76 +65,103 @@ class ExtractedInvalidEvent(Event):
class ExtractionState(BaseModel):
file_id: str | None = None
file_path: str | None = None
filename: str | None = None
file_hash: str | None = None
extract_job_id: str | None = None
filing_type: str | None = None
classification_confidence: float | None = None
classification_reasoning: str | None = None
class ProcessFileWorkflow(Workflow):
"""
Given a file path, this workflow will process a single file through the custom extraction logic.
"""
"""Extract structured data from a document and save it for review."""
@step()
async def run_file(self, event: FileEvent, ctx: Context) -> DownloadFileEvent:
logger.info(f"Running file {event.file_id}")
async with ctx.store.edit_state() as state:
state.file_id = event.file_id
return DownloadFileEvent()
async def start_extraction(
self,
event: FileEvent,
ctx: Context[ExtractionState],
llama_cloud_client: Annotated[
AsyncLlamaCloud, Resource(get_llama_cloud_client)
],
extract_config: Annotated[
ExtractConfig,
ResourceConfig(
config_file="configs/config.json",
path_selector="extract-10k",
label="Default Extraction Settings",
description="Default extraction config (10-K); actual schema selected after classification",
),
],
) -> ExtractJobStartedEvent:
"""Start extraction job for the document."""
file_id = event.file_id
logger.info(f"Running file {file_id}")
@step()
async def download_file(
self, event: DownloadFileEvent, ctx: Context[ExtractionState]
) -> ClassifyFileEvent:
"""Download the file reference from the cloud storage"""
state = await ctx.store.get_state()
if state.file_id is None:
raise ValueError("File ID is not set")
try:
file_metadata = await get_llama_cloud_client().files.get_file(
id=state.file_id
)
file_url = await get_llama_cloud_client().files.read_file_content(
state.file_id
)
temp_dir = tempfile.gettempdir()
files_page = await llama_cloud_client.files.list(file_ids=[file_id])
file_metadata = files_page.items[0]
filename = file_metadata.name
file_path = os.path.join(temp_dir, filename)
client = httpx.AsyncClient()
# Report progress to the UI
logger.info(f"Downloading file {file_url.url} to {file_path}")
async with client.stream("GET", file_url.url) as response:
with open(file_path, "wb") as f:
async for chunk in response.aiter_bytes():
f.write(chunk)
logger.info(f"Downloaded file {file_url.url} to {file_path}")
async with ctx.store.edit_state() as state:
state.file_path = file_path
state.filename = filename
return ClassifyFileEvent()
except Exception as e:
logger.error(f"Error downloading file {state.file_id}: {e}", exc_info=True)
logger.error(f"Error fetching file metadata {file_id}: {e}", exc_info=True)
ctx.write_event_to_stream(
Status(
level="error",
message=f"Error downloading file {state.file_id}: {e}",
message=f"Error fetching file metadata {file_id}: {e}",
)
)
raise e
logger.info(f"Extracting data from file {filename}")
ctx.write_event_to_stream(
Status(level="info", message=f"Extracting data from file {filename}")
)
if extract_config.extraction_agent_id:
extract_job = await llama_cloud_client.extraction.jobs.extract(
extraction_agent_id=extract_config.extraction_agent_id,
file_id=file_id,
)
else:
extract_job = await llama_cloud_client.extraction.run(
config=extract_config.settings.model_dump(),
data_schema=extract_config.json_schema,
file_id=file_id,
project_id=project_id,
)
file_hash = event.file_hash or file_metadata.external_file_id
async with ctx.store.edit_state() as state:
state.file_id = file_id
state.filename = filename
state.file_hash = file_hash
state.extract_job_id = extract_job.id
return ExtractJobStartedEvent()
@step()
async def classify_file(
self, event: ClassifyFileEvent, ctx: Context[ExtractionState]
self,
event: ExtractJobStartedEvent,
ctx: Context[ExtractionState],
llama_cloud_client: Annotated[
AsyncLlamaCloud, Resource(get_llama_cloud_client)
],
classify_config: Annotated[
ClassifyConfig,
ResourceConfig(
config_file="configs/config.json",
path_selector="classify",
label="Classification Rules",
description="Rules for classifying SEC filing types",
),
],
) -> FileClassifiedEvent:
"""Classify the SEC filing document type"""
"""Classify the SEC filing document type while extraction runs."""
state = await ctx.store.get_state()
if state.file_path is None or state.filename is None:
raise ValueError("File path or filename is not set")
if state.file_id is None or state.filename is None:
raise ValueError("File ID or filename is not set")
try:
logger.info(f"Classifying file {state.filename}")
@@ -140,71 +169,42 @@ class ProcessFileWorkflow(Workflow):
Status(level="info", message=f"Classifying file {state.filename}")
)
# Initialize the classifier
classifier = get_classifier_client()
# Define classification rules for SEC filing types
# Build rules from config
rules = [
ClassifierRule(
type="10-K",
description=(
"Form 10-K is an annual report filed by public companies with the SEC. "
"It provides a comprehensive summary of a company's financial performance for the year, "
"including audited financial statements, management's discussion and analysis (MD&A), "
"risk factors, business description, and executive compensation. "
"Look for: 'Form 10-K', 'Annual Report', fiscal year references, audited financials."
),
),
ClassifierRule(
type="10-Q",
description=(
"Form 10-Q is a quarterly report filed by public companies with the SEC. "
"It provides unaudited financial statements and management discussion for a specific quarter. "
"Contains quarterly financial data, updates on business operations, and material changes. "
"Look for: 'Form 10-Q', 'Quarterly Report', quarter references (Q1, Q2, Q3), unaudited statements."
),
),
ClassifierRule(
type="8-K",
description=(
"Form 8-K is a current report filed to announce material events or corporate changes. "
"Used to notify investors of significant events like mergers, acquisitions, leadership changes, "
"earnings releases, or other material corporate events that shareholders should know about. "
"Look for: 'Form 8-K', 'Current Report', Item numbers (e.g., Item 1.01, Item 5.02), event dates, "
"specific triggering events."
),
),
ClassifierRule(
type="other",
description=(
"Any other SEC filing type not covered by 10-K, 10-Q, or 8-K. "
"This includes forms such as S-1 (IPO registration), DEF 14A (proxy statement), "
"13F (institutional holdings), SC 13D (beneficial ownership), and other SEC forms."
),
),
{"type": rule.type, "description": rule.description}
for rule in classify_config.rules
]
# Configure parsing - only parse first few pages for classification
parsing_config = ClassifyParsingConfiguration(
max_pages=5, # Only parse first 5 pages for faster classification
)
# Build parsing config from settings
parsing_config: dict[str, Any] = {}
if classify_config.settings.parsing_config.max_pages is not None:
parsing_config["max_pages"] = (
classify_config.settings.parsing_config.max_pages
)
if classify_config.settings.parsing_config.target_pages is not None:
parsing_config["target_pages"] = (
classify_config.settings.parsing_config.target_pages
)
# Classify the file
results = await classifier.aclassify_file_paths(
# 3-step classify: create job, wait, get results
classify_job = await llama_cloud_client.classifier.jobs.create(
file_ids=[state.file_id],
rules=rules,
file_input_paths=[state.file_path],
parsing_configuration=parsing_config,
mode=classify_config.settings.mode,
**({"parsing_configuration": parsing_config} if parsing_config else {}),
)
await llama_cloud_client.classifier.wait_for_completion(classify_job.id)
results = await llama_cloud_client.classifier.jobs.get_results(
classify_job.id
)
# Extract classification result
if results.items and len(results.items) > 0:
item = results.items[0]
result: ClassificationResult | None = item.result
if result:
filing_type = result.type
confidence = result.confidence
reasoning = result.reasoning
if item.result:
filing_type = item.result.type
confidence = item.result.confidence
reasoning = item.result.reasoning
logger.info(
f"Classified {state.filename} as {filing_type} "
@@ -228,7 +228,6 @@ class ProcessFileWorkflow(Workflow):
reasoning=reasoning,
)
else:
# Classification failed, default to "other"
logger.warning(
f"Classification failed for {state.filename}, defaulting to 'other'"
)
@@ -242,7 +241,6 @@ class ProcessFileWorkflow(Workflow):
state.filing_type = "other"
return FileClassifiedEvent(filing_type="other")
else:
# No results, default to "other"
logger.warning(f"No classification results for {state.filename}")
async with ctx.store.edit_state() as state:
state.filing_type = "other"
@@ -256,74 +254,113 @@ class ProcessFileWorkflow(Workflow):
message=f"Classification failed, using default schema: {e}",
)
)
# On error, default to "other" and continue
async with ctx.store.edit_state() as state:
state.filing_type = "other"
return FileClassifiedEvent(filing_type="other")
@step()
async def process_file(
self, event: FileClassifiedEvent, ctx: Context[ExtractionState]
) -> ExtractedEvent | ExtractedInvalidEvent:
"""Runs the extraction against the file"""
async def complete_extraction(
self,
event: FileClassifiedEvent,
ctx: Context[ExtractionState],
llama_cloud_client: Annotated[
AsyncLlamaCloud, Resource(get_llama_cloud_client)
],
extract_10k: Annotated[
ExtractConfig,
ResourceConfig(
config_file="configs/config.json",
path_selector="extract-10k",
label="10-K Extraction",
),
],
extract_10q: Annotated[
ExtractConfig,
ResourceConfig(
config_file="configs/config.json",
path_selector="extract-10q",
label="10-Q Extraction",
),
],
extract_8k: Annotated[
ExtractConfig,
ResourceConfig(
config_file="configs/config.json",
path_selector="extract-8k",
label="8-K Extraction",
),
],
extract_other: Annotated[
ExtractConfig,
ResourceConfig(
config_file="configs/config.json",
path_selector="extract-other",
label="Other Extraction",
),
],
) -> StopEvent:
"""Wait for extraction to complete, validate results, and save for review."""
state = await ctx.store.get_state()
if state.file_path is None or state.filename is None:
raise ValueError("File path or filename is not set")
if state.extract_job_id is None:
raise ValueError("Job ID cannot be null when waiting for its completion")
# Select the extract config for the classified filing type
extract_configs = {
"10-K": extract_10k,
"10-Q": extract_10q,
"8-K": extract_8k,
"other": extract_other,
}
filing_type = state.filing_type or "other"
extract_config = extract_configs.get(filing_type, extract_other)
await llama_cloud_client.extraction.jobs.wait_for_completion(
state.extract_job_id
)
extracted_result = await llama_cloud_client.extraction.jobs.get_result(
state.extract_job_id
)
extract_run = await llama_cloud_client.extraction.runs.get(
run_id=extracted_result.run_id
)
extracted_event: ExtractedEvent | ExtractedInvalidEvent
try:
# Get the appropriate schema based on classification
filing_type = (state.filing_type or "other").upper()
schema = FILING_SCHEMAS.get(filing_type, FILING_SCHEMAS["other"])
logger.info(f"Using schema for filing type: {filing_type}")
ctx.write_event_to_stream(
Status(
level="info",
message=f"Extracting data using {filing_type} schema",
)
logger.info(
f"Extracted data: {json.dumps(extracted_result.model_dump(), indent=2)}"
)
agent = get_extract_agent()
# Update the agent's data schema for this specific filing type
agent.data_schema = schema
# track the content of the file, so as to be able to de-duplicate
file_content = Path(state.file_path).read_bytes()
file_hash = hashlib.sha256(file_content).hexdigest()
source_text = SourceText(
file=state.file_path,
filename=state.filename,
if extract_config.extraction_agent_id:
agent = await llama_cloud_client.extraction.extraction_agents.get(
extract_config.extraction_agent_id
)
schema_class = get_extraction_schema(agent.data_schema)
else:
schema_class = get_extraction_schema(
extract_config.json_schema,
discriminator_field=DISCRIMINATOR_FIELD,
discriminator_value=filing_type,
)
data = ExtractedData.from_extraction_result(
result=extract_run,
schema=schema_class,
file_name=state.filename,
file_id=state.file_id,
file_hash=state.file_hash,
)
logger.info(f"Extracting data from file {state.filename}")
ctx.write_event_to_stream(
Status(
level="info", message=f"Extracting data from file {state.filename}"
)
)
extracted_result: ExtractRun = await agent.aextract(source_text)
try:
logger.info(f"Extracted data: {extracted_result}")
data = ExtractedData.from_extraction_result(
result=extracted_result,
schema=schema,
file_hash=file_hash,
)
# Add classification information to the extracted data
if data.metadata is None:
data.metadata = {}
data.metadata["classification"] = filing_type
data.metadata["classification_confidence"] = (
state.classification_confidence
)
data.metadata["classification_reasoning"] = (
state.classification_reasoning
)
return ExtractedEvent(data=data)
except InvalidExtractionData as e:
logger.error(f"Error validating extracted data: {e}", exc_info=True)
return ExtractedInvalidEvent(data=e.invalid_item)
# Add classification information to the extracted data
if data.metadata is None:
data.metadata = {}
data.metadata["classification"] = filing_type
data.metadata["classification_confidence"] = state.classification_confidence
data.metadata["classification_reasoning"] = state.classification_reasoning
extracted_event = ExtractedEvent(data=data)
except InvalidExtractionData as e:
logger.error(f"Error validating extracted data: {e}", exc_info=True)
extracted_event = ExtractedInvalidEvent(data=e.invalid_item)
except Exception as e:
logger.error(
f"Error extracting data from file {state.filename}: {e}",
exc_info=True,
f"Error extracting data from file {state.filename}: {e}", exc_info=True
)
ctx.write_event_to_stream(
Status(
@@ -333,61 +370,56 @@ class ProcessFileWorkflow(Workflow):
)
raise e
@step()
async def record_extracted_data(
self, event: ExtractedEvent | ExtractedInvalidEvent, ctx: Context
) -> StopEvent:
"""Records the extracted data to the agent data API"""
try:
logger.info(f"Recorded extracted data for file {event.data.file_name}")
ctx.write_event_to_stream(
Status(
level="info",
message=f"Recorded extracted data for file {event.data.file_name}",
)
)
# remove past data when reprocessing the same file
if event.data.file_hash:
await get_data_client().delete(
filter={
"file_hash": {
"eq": event.data.file_hash,
},
ctx.write_event_to_stream(extracted_event)
extracted_data = extracted_event.data
data_dict = extracted_data.model_dump()
if extracted_data.file_hash is not None:
delete_result = await llama_cloud_client.beta.agent_data.delete_by_query(
deployment_name=agent_name or "_public",
collection=EXTRACTED_DATA_COLLECTION,
filter={
"file_hash": {
"eq": extracted_data.file_hash,
},
)
},
)
if delete_result.deleted_count > 0:
logger.info(
f"Removing past data for file {event.data.file_name} with hash {event.data.file_hash}"
f"Removed {delete_result.deleted_count} existing record(s) "
f"for file {extracted_data.file_name}"
)
# finally, save the new data
item_id = await get_data_client().create_item(event.data)
return StopEvent(
result=item_id.id,
item = await llama_cloud_client.beta.agent_data.agent_data(
data=data_dict,
deployment_name=agent_name or "_public",
collection=EXTRACTED_DATA_COLLECTION,
)
logger.info(
f"Recorded extracted data for file {extracted_data.file_name or ''}"
)
ctx.write_event_to_stream(
Status(
level="info",
message=f"Recorded extracted data for file {extracted_data.file_name or ''}",
)
except Exception as e:
logger.error(
f"Error recording extracted data for file {event.data.file_name}: {e}",
exc_info=True,
)
ctx.write_event_to_stream(
Status(
level="error",
message=f"Error recording extracted data for file {event.data.file_name}: {e}",
)
)
raise e
)
return StopEvent(result=item.id)
workflow = ProcessFileWorkflow(timeout=None)
if __name__ == "__main__":
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
logging.basicConfig(level=logging.INFO)
async def main():
file = await get_llama_cloud_client().files.upload_file(
upload_file=Path("test.pdf").open("rb")
file = await get_llama_cloud_client().files.create(
file=Path("test.pdf").open("rb"),
purpose="extract",
)
await workflow.run(start_event=FileEvent(file_id=file.id))
-57
View File
@@ -1,57 +0,0 @@
"""
Selects a locally defined shema, or queries the remote extraction agent for the schema.
"""
import asyncio
import jsonref
from .clients import get_extract_agent
from .config import USE_REMOTE_EXTRACTION_SCHEMA, ExtractionSchema
from typing import Any, Type
from pydantic import BaseModel
from pydantic import create_model, Field
SCHEMA: Type[BaseModel] | None = (
None if USE_REMOTE_EXTRACTION_SCHEMA else ExtractionSchema
)
_schema_lock = asyncio.Lock()
async def get_extraction_schema() -> Type[BaseModel]:
global SCHEMA
if SCHEMA is not None:
return SCHEMA
async with _schema_lock:
if SCHEMA is not None:
return SCHEMA
agent = get_extract_agent()
SCHEMA = model_from_schema(agent.data_schema)
return SCHEMA
async def get_extraction_schema_json() -> dict[str, Any]:
json_schema = (await get_extraction_schema()).model_json_schema()
json_schema = jsonref.replace_refs(json_schema, proxies=False)
return json_schema
def model_from_schema(schema: dict[str, Any]) -> Type[BaseModel]:
"""
Converts a JSON schema back to a Pydantic model.
"""
typemap = {
"string": str,
"integer": int,
"number": float,
"boolean": bool,
"array": list,
"object": dict,
}
fields = {}
for prop, meta in schema.get("properties", {}).items():
py_type = typemap.get(meta.get("type"), Any)
default = ... if prop in schema.get("required", []) else None
fields[prop] = (py_type, Field(default, description=meta.get("description")))
return create_model(schema.get("title", "DynamicModel"), **fields)
+7 -7
View File
@@ -16,18 +16,18 @@
"dependencies": {
"@babel/runtime": "^7.27.6",
"@lezer/highlight": "^1.2.1",
"@llamaindex/ui": "~3.6.1",
"@llamaindex/workflows-client": "^1.7.0",
"@radix-ui/themes": "^3.2.1",
"@llamaindex/llama-cloud": "^1.3.0",
"@llamaindex/ui": "^4.1.3",
"@llamaindex/workflows-client": "^1.8.3",
"@radix-ui/themes": "^3.3.0",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"llama-cloud-services": "^0.5.2",
"lucide-react": "^0.514.0",
"react": "^19.0.0",
"react-dom": "^19.0.0",
"react-router-dom": "^6.30.0",
"sonner": "^2.0.5",
"tw-animate-css": "^1.3.5"
"react-router-dom": "^7.8.0",
"sonner": "^1.7.2",
"tw-animate-css": "^1.4.0"
},
"devDependencies": {
"@tailwindcss/postcss": "^4.1.10",
+12 -18
View File
@@ -1,10 +1,10 @@
import { ExtractedData } from "llama-cloud-services/beta/agent";
import {
ApiClients,
createWorkflowsClient,
createWorkflowsConfig,
createCloudAgentClient,
cloudApiClient,
configureCloudClient,
getCloudClient,
createAgentDataConfig,
} from "@llamaindex/ui";
import { AGENT_NAME } from "./config";
import type { Metadata } from "./useMetadata";
@@ -13,16 +13,11 @@ const platformToken = import.meta.env.VITE_LLAMA_CLOUD_API_KEY;
const apiBaseUrl = import.meta.env.VITE_LLAMA_CLOUD_BASE_URL;
const projectId = import.meta.env.VITE_LLAMA_DEPLOY_PROJECT_ID;
// Configure the platform client
cloudApiClient.setConfig({
...(apiBaseUrl && { baseUrl: apiBaseUrl }),
headers: {
// optionally use a backend API token scoped to a project. For local development,
...(platformToken && { authorization: `Bearer ${platformToken}` }),
// This header is required for requests to correctly scope to the agent's project
// when authenticating with a user cookie
...(projectId && { "Project-Id": projectId }),
},
// Configure the cloud client
configureCloudClient({
...(apiBaseUrl && { baseURL: apiBaseUrl }),
...(platformToken && { apiKey: platformToken }),
...(projectId && { projectId }),
});
export function createBaseWorkflowClient(): ReturnType<
@@ -37,15 +32,14 @@ export function createBaseWorkflowClient(): ReturnType<
export function createClients(metadata: Metadata): ApiClients {
const workflowsClient = createBaseWorkflowClient();
const agentClient = createCloudAgentClient<ExtractedData<any>>({
client: cloudApiClient,
const agentDataConfig = createAgentDataConfig({
windowUrl: typeof window !== "undefined" ? window.location.href : undefined,
collection: metadata.extracted_data_collection,
});
return {
workflowsClient,
cloudApiClient,
agentDataClient: agentClient,
} as ApiClients;
cloudApiClient: getCloudClient(),
agentDataConfig,
};
}
+7 -16
View File
@@ -1,11 +1,5 @@
import type {
ExtractedData,
TypedAgentData,
} from "llama-cloud-services/beta/agent";
import type { AgentDataItem, ExtractedData } from "@llamaindex/ui";
/**
* Downloads data as a JSON file
*/
export function downloadJSON<T>(
data: T,
filename: string = "extraction-results.json",
@@ -20,19 +14,16 @@ export function downloadJSON<T>(
document.body.appendChild(link);
link.click();
// Cleanup
document.body.removeChild(link);
URL.revokeObjectURL(url);
}
/**
* Downloads extracted data item as JSON
*/
export function downloadExtractedDataItem<T>(
item: TypedAgentData<ExtractedData<T>>,
) {
const fileName = item.data.file_name || "item";
const timestamp = item.createdAt.toISOString().split("T")[0];
export function downloadExtractedDataItem(item: AgentDataItem) {
const extractedData = item.data as ExtractedData<unknown>;
const fileName = extractedData.file_name || "item";
const timestamp = item.created_at
? new Date(item.created_at).toISOString().split("T")[0]
: new Date().toISOString().split("T")[0];
const filename = `${fileName}-${timestamp}.json`;
downloadJSON(item, filename);
+1 -2
View File
@@ -1,7 +1,6 @@
import { clsx, type ClassValue } from "clsx";
import { twMerge } from "tailwind-merge";
import type { Highlight } from "@llamaindex/ui";
import type { FieldCitation } from "llama-cloud-services/beta/agent";
import type { Highlight, FieldCitation } from "@llamaindex/ui";
export function cn(...inputs: ClassValue[]) {
return twMerge(clsx(inputs));
+4 -2
View File
@@ -3,8 +3,8 @@ import {
WorkflowTrigger,
ExtractedDataItemGrid,
HandlerState,
AgentDataItem,
} from "@llamaindex/ui";
import type { TypedAgentData } from "llama-cloud-services/beta/agent";
import styles from "./HomePage.module.css";
import { useNavigate } from "react-router-dom";
import { useState } from "react";
@@ -16,7 +16,7 @@ export default function HomePage() {
function TaskList() {
const navigate = useNavigate();
const goToItem = (item: TypedAgentData) => {
const goToItem = (item: AgentDataItem) => {
navigate(`/item/${item.id}`);
};
const [reloadSignal, setReloadSignal] = useState(0);
@@ -52,9 +52,11 @@ function TaskList() {
/>
<WorkflowTrigger
workflowName="process-file"
contentHash={{ enabled: true }}
customWorkflowInput={(files) => {
return {
file_id: files[0].fileId,
file_hash: files[0].contentHash ?? null,
};
}}
onSuccess={(handler) => {
+21 -17
View File
@@ -5,6 +5,7 @@ import {
FilePreview,
useItemData,
type Highlight,
type ExtractedData,
Button,
} from "@llamaindex/ui";
import { Clock, XCircle, Download } from "lucide-react";
@@ -32,8 +33,11 @@ export default function ItemPage() {
});
// Determine the correct schema based on classification
const classificationData = itemHookData.item?.data as
| ExtractedData<any>
| undefined;
const classification = (
(itemHookData.item?.data?.metadata?.classification as string | undefined) ||
(classificationData?.metadata?.classification as string | undefined) ||
"10-K"
).toUpperCase();
const correctSchema =
@@ -52,9 +56,11 @@ export default function ItemPage() {
const navigate = useNavigate();
// Update breadcrumb when item data loads
useEffect(() => {
const fileName = itemHookData.item?.data?.file_name;
const extractedData = itemHookData.item?.data as
| ExtractedData<unknown>
| undefined;
const fileName = extractedData?.file_name;
if (fileName) {
setBreadcrumbs([
{ label: APP_TITLE, href: "/" },
@@ -66,10 +72,9 @@ export default function ItemPage() {
}
return () => {
// Reset to default breadcrumb when leaving the page
setBreadcrumbs([{ label: APP_TITLE, href: "/" }]);
};
}, [itemHookData.item?.data?.file_name, setBreadcrumbs]);
}, [itemHookData.item?.data, setBreadcrumbs]);
useEffect(() => {
setButtons(() => [
@@ -83,10 +88,9 @@ export default function ItemPage() {
}
}}
disabled={!itemData}
>
<Download className="h-4 w-4 mr-2" />
Export JSON
</Button>
startIcon={<Download className="h-4 w-4" />}
label="Export JSON"
/>
<AcceptReject<any>
itemData={itemHookData}
onComplete={() => navigate("/")}
@@ -105,8 +109,8 @@ export default function ItemPage() {
error,
} = itemHookData;
const classificationReasoning = itemData?.data?.metadata
?.classification_reasoning as string | undefined;
const classificationReasoning = (itemData?.data as ExtractedData<any>)
?.metadata?.classification_reasoning as string | undefined;
if (isLoading) {
return (
@@ -132,13 +136,15 @@ export default function ItemPage() {
);
}
const extractedData = itemData.data as ExtractedData<any>;
const fileId = extractedData.file_id;
return (
<div className="flex h-full bg-gray-50">
{/* Left Side - File Preview */}
<div className="w-1/2 border-r h-full border-gray-200 bg-white">
{itemData.data.file_id && (
{fileId && (
<FilePreview
fileId={itemData.data.file_id}
fileId={fileId}
onBoundingBoxClick={(box, pageNumber) => {
console.log("Bounding box clicked:", box, "on page:", pageNumber);
}}
@@ -147,7 +153,6 @@ export default function ItemPage() {
)}
</div>
{/* Right Side - Review Panel */}
<div className="flex-1 bg-white h-full overflow-y-auto">
<div className="p-4 space-y-4">
{/* Classification Info */}
@@ -163,10 +168,9 @@ export default function ItemPage() {
)}
</div>
)}
{/* Extracted Data */}
<ExtractedDataDisplay<any>
key={schemaKey}
extractedData={itemData.data}
extractedData={extractedData}
title="Extracted Data"
onChange={(updatedData) => {
updateData(updatedData);