mirror of
https://github.com/run-llama/template-workflow-classify-extract-sec.git
synced 2026-07-01 21:54:02 -04:00
feat: copier update child templates from data-extraction v0.5.0 (#241)
This commit is contained in:
+1
-1
@@ -1,3 +1,3 @@
|
||||
# Changes here will be overwritten by Copier; NEVER EDIT MANUALLY
|
||||
_commit: v0.4.0
|
||||
_commit: v0.5.0
|
||||
_src_path: https://github.com/run-llama/template-workflow-data-extraction
|
||||
|
||||
+1206
File diff suppressed because it is too large
Load Diff
+3
-2
@@ -5,8 +5,9 @@ description = "Extracts data"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"llama-cloud-services>=0.6.69",
|
||||
"llama-index-workflows>=2.2.0,<3.0.0",
|
||||
"llama-cloud>=1.3.0",
|
||||
"json-schema-to-pydantic>=0.4.8",
|
||||
"llama-index-workflows>=2.16.0,<3.0.0",
|
||||
"python-dotenv>=1.1.0",
|
||||
"jsonref>=1.1.0",
|
||||
"click>=8.2.1,<8.3.0",
|
||||
|
||||
@@ -1,29 +1,11 @@
|
||||
import functools
|
||||
import os
|
||||
from typing import Any
|
||||
import httpx
|
||||
|
||||
from llama_cloud_services import ExtractionAgent, LlamaExtract
|
||||
from llama_cloud.core.api_error import ApiError
|
||||
from llama_cloud_services.beta.agent_data import AsyncAgentDataClient, ExtractedData
|
||||
from llama_cloud_services.beta.classifier.client import ClassifyClient
|
||||
from llama_cloud.client import AsyncLlamaCloud
|
||||
import logging
|
||||
import os
|
||||
|
||||
from extraction_review.config import (
|
||||
EXTRACT_CONFIG,
|
||||
EXTRACTED_DATA_COLLECTION,
|
||||
EXTRACTION_AGENT_NAME,
|
||||
USE_REMOTE_EXTRACTION_SCHEMA,
|
||||
ExtractionSchema,
|
||||
)
|
||||
from llama_cloud import AsyncLlamaCloud
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# deployed agents may infer their name from the deployment name
|
||||
# Note: Make sure that an agent deployment with this name actually exists
|
||||
# otherwise calls to get or set data will fail. For local development, set
|
||||
# LLAMA_DEPLOY_DEPLOYMENT_NAME or append an `or "<name>"` fallback to the lookup below
|
||||
agent_name = os.getenv("LLAMA_DEPLOY_DEPLOYMENT_NAME")
|
||||
# required for all llama cloud calls
|
||||
api_key = os.getenv("LLAMA_CLOUD_API_KEY")
|
||||
@@ -32,57 +14,10 @@ base_url = os.getenv("LLAMA_CLOUD_BASE_URL")
|
||||
project_id = os.getenv("LLAMA_DEPLOY_PROJECT_ID")
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_extract_agent() -> ExtractionAgent:
|
||||
extract_api = LlamaExtract(
|
||||
api_key=api_key, base_url=base_url, project_id=project_id
|
||||
)
|
||||
|
||||
try:
|
||||
existing = extract_api.get_agent(EXTRACTION_AGENT_NAME)
|
||||
if not USE_REMOTE_EXTRACTION_SCHEMA:
|
||||
existing.data_schema = ExtractionSchema
|
||||
existing.config = EXTRACT_CONFIG
|
||||
return existing
|
||||
except ApiError as e:
|
||||
if e.status_code == 404:
|
||||
if USE_REMOTE_EXTRACTION_SCHEMA:
|
||||
logger.warning(
|
||||
"Extraction agent does not exist, creating a new one from the local schema"
|
||||
)
|
||||
return extract_api.create_agent(
|
||||
name=EXTRACTION_AGENT_NAME,
|
||||
data_schema=ExtractionSchema,
|
||||
config=EXTRACT_CONFIG,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_data_client() -> AsyncAgentDataClient:
|
||||
return AsyncAgentDataClient(
|
||||
deployment_name=agent_name,
|
||||
collection=EXTRACTED_DATA_COLLECTION,
|
||||
type=ExtractedData[Any],
|
||||
client=get_llama_cloud_client(),
|
||||
)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_llama_cloud_client():
|
||||
def get_llama_cloud_client() -> AsyncLlamaCloud:
|
||||
"""Cloud services connection for file storage and processing."""
|
||||
return AsyncLlamaCloud(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
token=api_key,
|
||||
httpx_client=httpx.AsyncClient(
|
||||
timeout=60, headers={"Project-Id": project_id} if project_id else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_classifier_client():
|
||||
return ClassifyClient(
|
||||
client=get_llama_cloud_client(),
|
||||
project_id=project_id,
|
||||
default_headers={"Project-Id": project_id} if project_id else {},
|
||||
)
|
||||
|
||||
@@ -1,30 +1,35 @@
|
||||
"""
|
||||
For simple configuration of the extraction review application, just customize this file.
|
||||
Configuration for the extraction review application.
|
||||
|
||||
If you need more control, feel free to edit the rest of the application
|
||||
Configuration is loaded from configs/config.json via ResourceConfig.
|
||||
The unified config contains both extraction settings and the JSON schema.
|
||||
|
||||
Extraction can run in two modes, controlled by the "extraction_agent_id" field
|
||||
in configs/config.json:
|
||||
|
||||
- Local (default): extraction_agent_id is null. Uses the json_schema and
|
||||
settings defined in config.json directly via extraction.run().
|
||||
|
||||
- Remote agent: extraction_agent_id is set to a LlamaCloud extraction agent
|
||||
ID. Uses extraction.jobs.extract(extraction_agent_id=...) which delegates
|
||||
schema and settings to the remote agent. The local json_schema and settings
|
||||
in config.json are ignored — both extraction and the metadata workflow fetch
|
||||
the schema directly from the remote agent.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from typing import Type
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from llama_cloud import ExtractConfig
|
||||
from llama_cloud_services.extract import ExtractMode
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# If you change this to true, the schema and extraction configuration will be fetched from the remote extraction agent
|
||||
# rather than using the ExtractionSchema and configuration defined below.
|
||||
USE_REMOTE_EXTRACTION_SCHEMA: bool = False
|
||||
# The name of the extraction agent to use. Prefers the name of this deployment when deployed to isolate environments.
|
||||
# Note that the application will create a new agent from the below ExtractionSchema if the extraction agent does not yet exist.
|
||||
EXTRACTION_AGENT_NAME: str = (
|
||||
os.getenv("LLAMA_DEPLOY_DEPLOYMENT_NAME") or "sec-filing-extraction"
|
||||
)
|
||||
# The name of the collection to use for storing extracted data. This will be qualified by the agent name.
|
||||
# When developing locally, this will use the _public collection (shared within the project), otherwise agent
|
||||
# data is isolated to each agent
|
||||
EXTRACTED_DATA_COLLECTION: str = "sec-filing-extraction"
|
||||
from .json_util import create_union_schema as create_union_schema
|
||||
from .json_util import get_extraction_schema as get_extraction_schema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# The name of the collection to use for storing extracted data.
|
||||
EXTRACTED_DATA_COLLECTION: str = "sec-filing-extraction"
|
||||
|
||||
# SEC Filing Classification Types
|
||||
SEC_FILING_TYPES = ["10-K", "10-Q", "8-K", "other"]
|
||||
@@ -346,18 +351,58 @@ FILING_SCHEMAS = {
|
||||
}
|
||||
|
||||
|
||||
# This is only used if USE_REMOTE_EXTRACTION_SCHEMA is False.
|
||||
EXTRACT_CONFIG = ExtractConfig(
|
||||
extraction_mode=ExtractMode.PREMIUM,
|
||||
system_prompt=None,
|
||||
# advanced. Only compatible with Premium mode.
|
||||
citation_bbox=True,
|
||||
use_reasoning=False,
|
||||
cite_sources=True,
|
||||
confidence_scores=False,
|
||||
)
|
||||
class ExtractSettings(BaseModel):
|
||||
extraction_mode: Literal["FAST", "PREMIUM", "MULTIMODAL"]
|
||||
system_prompt: str | None = None
|
||||
citation_bbox: bool = False
|
||||
use_reasoning: bool = False
|
||||
cite_sources: bool = False
|
||||
confidence_scores: bool = False
|
||||
|
||||
|
||||
SCHEMA: Type[BaseModel] | None = (
|
||||
None if USE_REMOTE_EXTRACTION_SCHEMA else ExtractionSchema
|
||||
)
|
||||
class ExtractConfig(BaseModel):
|
||||
json_schema: dict[str, Any]
|
||||
settings: ExtractSettings
|
||||
# Set this to a LlamaCloud extraction agent ID to use a remote agent's
|
||||
# schema and settings instead of the local json_schema/settings above.
|
||||
# When set, extraction uses extraction.jobs.extract(extraction_agent_id=...)
|
||||
# and the local settings are ignored for extraction.
|
||||
extraction_agent_id: str | None = None
|
||||
|
||||
|
||||
class JsonSchema(BaseModel):
|
||||
type: str = "object"
|
||||
properties: dict[str, Any] = {}
|
||||
required: list[str] = []
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return self.model_dump(exclude_none=True)
|
||||
|
||||
|
||||
class ClassifyRule(BaseModel):
|
||||
"""Classify rule, with type (rule target) and description (rule description)"""
|
||||
|
||||
type: str
|
||||
description: str
|
||||
|
||||
|
||||
class ClassifyParsingConfig(BaseModel):
|
||||
"""Parsing config for Classify"""
|
||||
|
||||
lang: str = Field(description="two-letter ISO 639 language code", default="en")
|
||||
max_pages: int | None = None
|
||||
target_pages: list[int] | None = None
|
||||
|
||||
|
||||
class ClassifySettings(BaseModel):
|
||||
"""Extra settings for Classify"""
|
||||
|
||||
mode: Literal["FAST", "MULTIMODAL"] = "FAST"
|
||||
parsing_config: ClassifyParsingConfig = ClassifyParsingConfig()
|
||||
|
||||
|
||||
class ClassifyConfig(BaseModel):
|
||||
"""Classify configuration, with rules and settings"""
|
||||
|
||||
rules: list[ClassifyRule] = []
|
||||
settings: ClassifySettings = ClassifySettings()
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
"""Utilities for working with JSON schemas."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from json_schema_to_pydantic import create_model
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _hash_schema(json_schema: dict[str, Any]) -> str:
|
||||
"""Create a stable hash of a JSON schema for caching."""
|
||||
schema_str = json.dumps(json_schema, sort_keys=True)
|
||||
return hashlib.sha256(schema_str.encode()).hexdigest()
|
||||
|
||||
|
||||
@lru_cache(maxsize=16)
def _get_cached_model(schema_hash: str, schema_json: str) -> type[BaseModel]:
    """Build a Pydantic model from a serialized JSON schema, memoized.

    Args:
        schema_hash: Digest of the schema. It is not read in the body; it
            only takes part in the ``lru_cache`` key together with
            ``schema_json``.
        schema_json: The JSON schema serialized as a string (strings are
            hashable, which is what lets the call be cached).

    Returns:
        The model class produced by ``create_model`` for the decoded schema.
    """
    return create_model(json.loads(schema_json))
|
||||
|
||||
|
||||
def get_extraction_schema(
    json_schema: dict[str, Any],
    discriminator_field: str | None = None,
    discriminator_value: str | None = None,
) -> type[BaseModel]:
    """Turn a JSON schema into a Pydantic model class.

    When a discriminator field is requested, the schema is first extended
    with that field (see ``_add_discriminator_to_schema``). The resulting
    schema is then serialized and handed to the cached model builder.

    Args:
        json_schema: A JSON Schema object from config.
        discriminator_field: Field name to identify document type
            (e.g., "document_type"). ``None`` leaves the schema untouched.
        discriminator_value: The value for the discriminator field
            (e.g., "10-K"). Required whenever ``discriminator_field`` is set.

    Returns:
        A Pydantic model class for validation.

    Raises:
        ValueError: If ``discriminator_field`` is given without a
            ``discriminator_value``.
    """
    if discriminator_field is None:
        effective_schema = json_schema
    elif discriminator_value is None:
        raise ValueError(
            "discriminator_value is required when discriminator_field is set"
        )
    else:
        effective_schema = _add_discriminator_to_schema(
            json_schema, discriminator_field, discriminator_value
        )

    # The cache needs hashable arguments, so pass the digest and the
    # serialized schema rather than the dict itself.
    digest = _hash_schema(effective_schema)
    serialized = json.dumps(effective_schema, sort_keys=True)
    return _get_cached_model(digest, serialized)
|
||||
|
||||
|
||||
def _add_discriminator_to_schema(
|
||||
schema: dict[str, Any],
|
||||
discriminator_field: str,
|
||||
discriminator_value: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Add a discriminator field with a default value to a schema."""
|
||||
properties = schema.get("properties", {})
|
||||
required = schema.get("required", [])
|
||||
|
||||
new_properties = {
|
||||
discriminator_field: {
|
||||
"type": "string",
|
||||
"default": discriminator_value,
|
||||
"description": "Type of document that was extracted",
|
||||
},
|
||||
**properties,
|
||||
}
|
||||
|
||||
new_required = list(required)
|
||||
if discriminator_field not in new_required:
|
||||
new_required = [discriminator_field] + new_required
|
||||
|
||||
return {
|
||||
**schema,
|
||||
"properties": new_properties,
|
||||
"required": new_required,
|
||||
}
|
||||
|
||||
|
||||
def _schemas_are_equal(schema1: dict[str, Any], schema2: dict[str, Any]) -> bool:
|
||||
"""Check if two JSON schemas are structurally equal."""
|
||||
return json.dumps(schema1, sort_keys=True) == json.dumps(schema2, sort_keys=True)
|
||||
|
||||
|
||||
def _merge_property_schemas(
|
||||
existing: dict[str, Any], new: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Merge two property schemas, creating anyOf for conflicting types."""
|
||||
if _schemas_are_equal(existing, new):
|
||||
return existing
|
||||
|
||||
if "anyOf" in existing:
|
||||
for variant in existing["anyOf"]:
|
||||
if _schemas_are_equal(variant, new):
|
||||
return existing
|
||||
return {"anyOf": existing["anyOf"] + [new]}
|
||||
|
||||
return {"anyOf": [existing, new]}
|
||||
|
||||
|
||||
def create_union_schema(
    schemas: dict[str, dict[str, Any]],
    discriminator_field: str = "document_type",
) -> dict[str, Any]:
    """Combine several extraction schemas into one flat object schema.

    The result contains every property found in any input schema, plus a
    string discriminator property whose ``enum`` lists the keys of
    ``schemas``. A property that appears in several schemas with different
    definitions is combined through ``_merge_property_schemas``. A property
    named like the discriminator is dropped from the inputs (a warning is
    logged) and replaced by the union discriminator.

    Args:
        schemas: Mapping of document-type name to its JSON schema.
        discriminator_field: Name of the discriminator property.

    Returns:
        A schema dict with ``type``, ``properties`` and ``required``.
        ``required`` holds the discriminator followed by the sorted names
        that every input schema marks as required.
    """
    clashing_names = [
        schema_name
        for schema_name, candidate in schemas.items()
        if discriminator_field in candidate.get("properties", {})
    ]
    if clashing_names:
        logger.warning(
            f"Discriminator field '{discriminator_field}' found in schemas: "
            f"{', '.join(clashing_names)}. "
            f"It will be replaced with the union discriminator."
        )

    merged_properties: dict[str, Any] = {}
    for candidate in schemas.values():
        for field_name, definition in candidate.get("properties", {}).items():
            if field_name == discriminator_field:
                continue
            merged_properties[field_name] = (
                _merge_property_schemas(merged_properties[field_name], definition)
                if field_name in merged_properties
                else definition
            )

    # Only fields that every schema requires stay required in the union.
    required_sets = [set(candidate.get("required", [])) for candidate in schemas.values()]
    shared_required: set[str] = (
        set.intersection(*required_sets) if required_sets else set()
    )
    shared_required.discard(discriminator_field)

    discriminator_property = {
        "type": "string",
        "enum": list(schemas.keys()),
        "description": "Type of document that was extracted",
    }
    return {
        "type": "object",
        "properties": {discriminator_field: discriminator_property, **merged_properties},
        "required": [discriminator_field, *sorted(shared_required)],
    }
|
||||
@@ -1,34 +1,83 @@
|
||||
from typing import Any
|
||||
from typing import Annotated, Any
|
||||
|
||||
from workflows import Workflow, step
|
||||
from workflows.events import StartEvent, StopEvent
|
||||
from workflows.resource import Resource, ResourceConfig
|
||||
|
||||
import jsonref
|
||||
from .config import EXTRACTED_DATA_COLLECTION, ExtractConfig, create_union_schema
|
||||
|
||||
from .config import EXTRACTED_DATA_COLLECTION, FILING_SCHEMAS
|
||||
DISCRIMINATOR_FIELD = "document_type"
|
||||
|
||||
|
||||
class MetadataResponse(StopEvent):
|
||||
json_schema: dict[str, Any]
|
||||
schemas: dict[str, dict[str, Any]]
|
||||
discriminator_field: str
|
||||
extracted_data_collection: str
|
||||
|
||||
|
||||
async def get_presentation_schema(
    extract_10k: Annotated[
        ExtractConfig,
        ResourceConfig(
            config_file="configs/config.json",
            path_selector="extract-10k",
            label="10-K Extraction",
        ),
    ],
    extract_10q: Annotated[
        ExtractConfig,
        ResourceConfig(
            config_file="configs/config.json",
            path_selector="extract-10q",
            label="10-Q Extraction",
        ),
    ],
    extract_8k: Annotated[
        ExtractConfig,
        ResourceConfig(
            config_file="configs/config.json",
            path_selector="extract-8k",
            label="8-K Extraction",
        ),
    ],
    extract_other: Annotated[
        ExtractConfig,
        ResourceConfig(
            config_file="configs/config.json",
            path_selector="extract-other",
            label="Other Extraction",
        ),
    ],
) -> dict[str, Any]:
    """Assemble the schemas exposed to the review UI.

    Each parameter is an ``ExtractConfig`` annotated with a ``ResourceConfig``
    that names one section of ``configs/config.json``.

    Returns:
        A dict with three keys: ``json_schema`` (the union of all per-type
        schemas built by ``create_union_schema``), ``schemas`` (the per-type
        JSON schemas keyed by filing type) and ``discriminator_field``.
    """
    configs_by_type = (
        ("10-K", extract_10k),
        ("10-Q", extract_10q),
        ("8-K", extract_8k),
        ("other", extract_other),
    )
    per_type_schemas = {
        filing_type: config.json_schema for filing_type, config in configs_by_type
    }
    union_schema = create_union_schema(
        per_type_schemas, discriminator_field=DISCRIMINATOR_FIELD
    )
    return {
        "json_schema": union_schema,
        "schemas": per_type_schemas,
        "discriminator_field": DISCRIMINATOR_FIELD,
    }
|
||||
|
||||
|
||||
class MetadataWorkflow(Workflow):
|
||||
"""
|
||||
Simple single step workflow to expose configuration to the UI, such as all JSON schemas and collection name.
|
||||
"""
|
||||
"""Provide extraction schema and configuration to the workflow editor."""
|
||||
|
||||
@step
|
||||
async def get_metadata(self, _: StartEvent) -> MetadataResponse:
|
||||
# Convert all filing schemas to JSON schemas
|
||||
schemas = {}
|
||||
for filing_type, schema_class in FILING_SCHEMAS.items():
|
||||
json_schema = schema_class.model_json_schema()
|
||||
# Resolve any $ref references
|
||||
json_schema = jsonref.replace_refs(json_schema, proxies=False)
|
||||
schemas[filing_type] = json_schema
|
||||
|
||||
async def get_metadata(
|
||||
self,
|
||||
_: StartEvent,
|
||||
presentation: Annotated[dict[str, Any], Resource(get_presentation_schema)],
|
||||
) -> MetadataResponse:
|
||||
"""Return the data schemas and storage settings for the review interface."""
|
||||
return MetadataResponse(
|
||||
schemas=schemas,
|
||||
json_schema=presentation["json_schema"],
|
||||
schemas=presentation["schemas"],
|
||||
discriminator_field=presentation["discriminator_field"],
|
||||
extracted_data_collection=EXTRACTED_DATA_COLLECTION,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,41 +1,39 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from typing import Any, Literal
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
import httpx
|
||||
from llama_cloud import ClassificationResult, ExtractRun
|
||||
from llama_cloud.types import ClassifierRule, ClassifyParsingConfiguration
|
||||
from llama_cloud_services.extract import SourceText
|
||||
from llama_cloud_services.beta.agent_data import ExtractedData, InvalidExtractionData
|
||||
from llama_cloud import AsyncLlamaCloud
|
||||
from llama_cloud.types.beta.extracted_data import ExtractedData, InvalidExtractionData
|
||||
from pydantic import BaseModel
|
||||
from workflows import Context, Workflow, step
|
||||
from workflows.events import Event, StartEvent, StopEvent
|
||||
from workflows.resource import Resource, ResourceConfig
|
||||
|
||||
from .clients import (
|
||||
get_classifier_client,
|
||||
get_llama_cloud_client,
|
||||
get_data_client,
|
||||
get_extract_agent,
|
||||
from .clients import agent_name, get_llama_cloud_client, project_id
|
||||
from .config import (
|
||||
EXTRACTED_DATA_COLLECTION,
|
||||
ClassifyConfig,
|
||||
ExtractConfig,
|
||||
get_extraction_schema,
|
||||
)
|
||||
from .config import FILING_SCHEMAS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DISCRIMINATOR_FIELD = "document_type"
|
||||
|
||||
# Mapping from classify rule types to config.json extract section keys
|
||||
EXTRACT_CONFIG_KEYS = {
|
||||
"10-K": "extract-10k",
|
||||
"10-Q": "extract-10q",
|
||||
"8-K": "extract-8k",
|
||||
"other": "extract-other",
|
||||
}
|
||||
|
||||
|
||||
class FileEvent(StartEvent):
|
||||
file_id: str
|
||||
|
||||
|
||||
class DownloadFileEvent(Event):
|
||||
pass
|
||||
|
||||
|
||||
class FileDownloadedEvent(Event):
|
||||
pass
|
||||
file_hash: str | None = None
|
||||
|
||||
|
||||
class ClassifyFileEvent(Event):
|
||||
@@ -53,6 +51,10 @@ class Status(Event):
|
||||
message: str
|
||||
|
||||
|
||||
class ExtractJobStartedEvent(Event):
|
||||
pass
|
||||
|
||||
|
||||
class ExtractedEvent(Event):
|
||||
data: ExtractedData
|
||||
|
||||
@@ -63,76 +65,103 @@ class ExtractedInvalidEvent(Event):
|
||||
|
||||
class ExtractionState(BaseModel):
|
||||
file_id: str | None = None
|
||||
file_path: str | None = None
|
||||
filename: str | None = None
|
||||
file_hash: str | None = None
|
||||
extract_job_id: str | None = None
|
||||
filing_type: str | None = None
|
||||
classification_confidence: float | None = None
|
||||
classification_reasoning: str | None = None
|
||||
|
||||
|
||||
class ProcessFileWorkflow(Workflow):
|
||||
"""
|
||||
Given a file path, this workflow will process a single file through the custom extraction logic.
|
||||
"""
|
||||
"""Extract structured data from a document and save it for review."""
|
||||
|
||||
@step()
|
||||
async def run_file(self, event: FileEvent, ctx: Context) -> DownloadFileEvent:
|
||||
logger.info(f"Running file {event.file_id}")
|
||||
async with ctx.store.edit_state() as state:
|
||||
state.file_id = event.file_id
|
||||
return DownloadFileEvent()
|
||||
async def start_extraction(
|
||||
self,
|
||||
event: FileEvent,
|
||||
ctx: Context[ExtractionState],
|
||||
llama_cloud_client: Annotated[
|
||||
AsyncLlamaCloud, Resource(get_llama_cloud_client)
|
||||
],
|
||||
extract_config: Annotated[
|
||||
ExtractConfig,
|
||||
ResourceConfig(
|
||||
config_file="configs/config.json",
|
||||
path_selector="extract-10k",
|
||||
label="Default Extraction Settings",
|
||||
description="Default extraction config (10-K); actual schema selected after classification",
|
||||
),
|
||||
],
|
||||
) -> ExtractJobStartedEvent:
|
||||
"""Start extraction job for the document."""
|
||||
file_id = event.file_id
|
||||
logger.info(f"Running file {file_id}")
|
||||
|
||||
@step()
|
||||
async def download_file(
|
||||
self, event: DownloadFileEvent, ctx: Context[ExtractionState]
|
||||
) -> ClassifyFileEvent:
|
||||
"""Download the file reference from the cloud storage"""
|
||||
state = await ctx.store.get_state()
|
||||
if state.file_id is None:
|
||||
raise ValueError("File ID is not set")
|
||||
try:
|
||||
file_metadata = await get_llama_cloud_client().files.get_file(
|
||||
id=state.file_id
|
||||
)
|
||||
file_url = await get_llama_cloud_client().files.read_file_content(
|
||||
state.file_id
|
||||
)
|
||||
|
||||
temp_dir = tempfile.gettempdir()
|
||||
files_page = await llama_cloud_client.files.list(file_ids=[file_id])
|
||||
file_metadata = files_page.items[0]
|
||||
filename = file_metadata.name
|
||||
file_path = os.path.join(temp_dir, filename)
|
||||
client = httpx.AsyncClient()
|
||||
# Report progress to the UI
|
||||
logger.info(f"Downloading file {file_url.url} to {file_path}")
|
||||
|
||||
async with client.stream("GET", file_url.url) as response:
|
||||
with open(file_path, "wb") as f:
|
||||
async for chunk in response.aiter_bytes():
|
||||
f.write(chunk)
|
||||
logger.info(f"Downloaded file {file_url.url} to {file_path}")
|
||||
async with ctx.store.edit_state() as state:
|
||||
state.file_path = file_path
|
||||
state.filename = filename
|
||||
return ClassifyFileEvent()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading file {state.file_id}: {e}", exc_info=True)
|
||||
logger.error(f"Error fetching file metadata {file_id}: {e}", exc_info=True)
|
||||
ctx.write_event_to_stream(
|
||||
Status(
|
||||
level="error",
|
||||
message=f"Error downloading file {state.file_id}: {e}",
|
||||
message=f"Error fetching file metadata {file_id}: {e}",
|
||||
)
|
||||
)
|
||||
raise e
|
||||
|
||||
logger.info(f"Extracting data from file {filename}")
|
||||
ctx.write_event_to_stream(
|
||||
Status(level="info", message=f"Extracting data from file {filename}")
|
||||
)
|
||||
|
||||
if extract_config.extraction_agent_id:
|
||||
extract_job = await llama_cloud_client.extraction.jobs.extract(
|
||||
extraction_agent_id=extract_config.extraction_agent_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
else:
|
||||
extract_job = await llama_cloud_client.extraction.run(
|
||||
config=extract_config.settings.model_dump(),
|
||||
data_schema=extract_config.json_schema,
|
||||
file_id=file_id,
|
||||
project_id=project_id,
|
||||
)
|
||||
|
||||
file_hash = event.file_hash or file_metadata.external_file_id
|
||||
|
||||
async with ctx.store.edit_state() as state:
|
||||
state.file_id = file_id
|
||||
state.filename = filename
|
||||
state.file_hash = file_hash
|
||||
state.extract_job_id = extract_job.id
|
||||
|
||||
return ExtractJobStartedEvent()
|
||||
|
||||
@step()
|
||||
async def classify_file(
|
||||
self, event: ClassifyFileEvent, ctx: Context[ExtractionState]
|
||||
self,
|
||||
event: ExtractJobStartedEvent,
|
||||
ctx: Context[ExtractionState],
|
||||
llama_cloud_client: Annotated[
|
||||
AsyncLlamaCloud, Resource(get_llama_cloud_client)
|
||||
],
|
||||
classify_config: Annotated[
|
||||
ClassifyConfig,
|
||||
ResourceConfig(
|
||||
config_file="configs/config.json",
|
||||
path_selector="classify",
|
||||
label="Classification Rules",
|
||||
description="Rules for classifying SEC filing types",
|
||||
),
|
||||
],
|
||||
) -> FileClassifiedEvent:
|
||||
"""Classify the SEC filing document type"""
|
||||
"""Classify the SEC filing document type while extraction runs."""
|
||||
state = await ctx.store.get_state()
|
||||
if state.file_path is None or state.filename is None:
|
||||
raise ValueError("File path or filename is not set")
|
||||
if state.file_id is None or state.filename is None:
|
||||
raise ValueError("File ID or filename is not set")
|
||||
|
||||
try:
|
||||
logger.info(f"Classifying file {state.filename}")
|
||||
@@ -140,71 +169,42 @@ class ProcessFileWorkflow(Workflow):
|
||||
Status(level="info", message=f"Classifying file {state.filename}")
|
||||
)
|
||||
|
||||
# Initialize the classifier
|
||||
|
||||
classifier = get_classifier_client()
|
||||
|
||||
# Define classification rules for SEC filing types
|
||||
# Build rules from config
|
||||
rules = [
|
||||
ClassifierRule(
|
||||
type="10-K",
|
||||
description=(
|
||||
"Form 10-K is an annual report filed by public companies with the SEC. "
|
||||
"It provides a comprehensive summary of a company's financial performance for the year, "
|
||||
"including audited financial statements, management's discussion and analysis (MD&A), "
|
||||
"risk factors, business description, and executive compensation. "
|
||||
"Look for: 'Form 10-K', 'Annual Report', fiscal year references, audited financials."
|
||||
),
|
||||
),
|
||||
ClassifierRule(
|
||||
type="10-Q",
|
||||
description=(
|
||||
"Form 10-Q is a quarterly report filed by public companies with the SEC. "
|
||||
"It provides unaudited financial statements and management discussion for a specific quarter. "
|
||||
"Contains quarterly financial data, updates on business operations, and material changes. "
|
||||
"Look for: 'Form 10-Q', 'Quarterly Report', quarter references (Q1, Q2, Q3), unaudited statements."
|
||||
),
|
||||
),
|
||||
ClassifierRule(
|
||||
type="8-K",
|
||||
description=(
|
||||
"Form 8-K is a current report filed to announce material events or corporate changes. "
|
||||
"Used to notify investors of significant events like mergers, acquisitions, leadership changes, "
|
||||
"earnings releases, or other material corporate events that shareholders should know about. "
|
||||
"Look for: 'Form 8-K', 'Current Report', Item numbers (e.g., Item 1.01, Item 5.02), event dates, "
|
||||
"specific triggering events."
|
||||
),
|
||||
),
|
||||
ClassifierRule(
|
||||
type="other",
|
||||
description=(
|
||||
"Any other SEC filing type not covered by 10-K, 10-Q, or 8-K. "
|
||||
"This includes forms such as S-1 (IPO registration), DEF 14A (proxy statement), "
|
||||
"13F (institutional holdings), SC 13D (beneficial ownership), and other SEC forms."
|
||||
),
|
||||
),
|
||||
{"type": rule.type, "description": rule.description}
|
||||
for rule in classify_config.rules
|
||||
]
|
||||
|
||||
# Configure parsing - only parse first few pages for classification
|
||||
parsing_config = ClassifyParsingConfiguration(
|
||||
max_pages=5, # Only parse first 5 pages for faster classification
|
||||
)
|
||||
# Build parsing config from settings
|
||||
parsing_config: dict[str, Any] = {}
|
||||
if classify_config.settings.parsing_config.max_pages is not None:
|
||||
parsing_config["max_pages"] = (
|
||||
classify_config.settings.parsing_config.max_pages
|
||||
)
|
||||
if classify_config.settings.parsing_config.target_pages is not None:
|
||||
parsing_config["target_pages"] = (
|
||||
classify_config.settings.parsing_config.target_pages
|
||||
)
|
||||
|
||||
# Classify the file
|
||||
results = await classifier.aclassify_file_paths(
|
||||
# 3-step classify: create job, wait, get results
|
||||
classify_job = await llama_cloud_client.classifier.jobs.create(
|
||||
file_ids=[state.file_id],
|
||||
rules=rules,
|
||||
file_input_paths=[state.file_path],
|
||||
parsing_configuration=parsing_config,
|
||||
mode=classify_config.settings.mode,
|
||||
**({"parsing_configuration": parsing_config} if parsing_config else {}),
|
||||
)
|
||||
await llama_cloud_client.classifier.wait_for_completion(classify_job.id)
|
||||
results = await llama_cloud_client.classifier.jobs.get_results(
|
||||
classify_job.id
|
||||
)
|
||||
|
||||
# Extract classification result
|
||||
if results.items and len(results.items) > 0:
|
||||
item = results.items[0]
|
||||
result: ClassificationResult | None = item.result
|
||||
if result:
|
||||
filing_type = result.type
|
||||
confidence = result.confidence
|
||||
reasoning = result.reasoning
|
||||
if item.result:
|
||||
filing_type = item.result.type
|
||||
confidence = item.result.confidence
|
||||
reasoning = item.result.reasoning
|
||||
|
||||
logger.info(
|
||||
f"Classified {state.filename} as {filing_type} "
|
||||
@@ -228,7 +228,6 @@ class ProcessFileWorkflow(Workflow):
|
||||
reasoning=reasoning,
|
||||
)
|
||||
else:
|
||||
# Classification failed, default to "other"
|
||||
logger.warning(
|
||||
f"Classification failed for {state.filename}, defaulting to 'other'"
|
||||
)
|
||||
@@ -242,7 +241,6 @@ class ProcessFileWorkflow(Workflow):
|
||||
state.filing_type = "other"
|
||||
return FileClassifiedEvent(filing_type="other")
|
||||
else:
|
||||
# No results, default to "other"
|
||||
logger.warning(f"No classification results for {state.filename}")
|
||||
async with ctx.store.edit_state() as state:
|
||||
state.filing_type = "other"
|
||||
@@ -256,74 +254,113 @@ class ProcessFileWorkflow(Workflow):
|
||||
message=f"Classification failed, using default schema: {e}",
|
||||
)
|
||||
)
|
||||
# On error, default to "other" and continue
|
||||
async with ctx.store.edit_state() as state:
|
||||
state.filing_type = "other"
|
||||
return FileClassifiedEvent(filing_type="other")
|
||||
|
||||
@step()
|
||||
async def process_file(
|
||||
self, event: FileClassifiedEvent, ctx: Context[ExtractionState]
|
||||
) -> ExtractedEvent | ExtractedInvalidEvent:
|
||||
"""Runs the extraction against the file"""
|
||||
async def complete_extraction(
|
||||
self,
|
||||
event: FileClassifiedEvent,
|
||||
ctx: Context[ExtractionState],
|
||||
llama_cloud_client: Annotated[
|
||||
AsyncLlamaCloud, Resource(get_llama_cloud_client)
|
||||
],
|
||||
extract_10k: Annotated[
|
||||
ExtractConfig,
|
||||
ResourceConfig(
|
||||
config_file="configs/config.json",
|
||||
path_selector="extract-10k",
|
||||
label="10-K Extraction",
|
||||
),
|
||||
],
|
||||
extract_10q: Annotated[
|
||||
ExtractConfig,
|
||||
ResourceConfig(
|
||||
config_file="configs/config.json",
|
||||
path_selector="extract-10q",
|
||||
label="10-Q Extraction",
|
||||
),
|
||||
],
|
||||
extract_8k: Annotated[
|
||||
ExtractConfig,
|
||||
ResourceConfig(
|
||||
config_file="configs/config.json",
|
||||
path_selector="extract-8k",
|
||||
label="8-K Extraction",
|
||||
),
|
||||
],
|
||||
extract_other: Annotated[
|
||||
ExtractConfig,
|
||||
ResourceConfig(
|
||||
config_file="configs/config.json",
|
||||
path_selector="extract-other",
|
||||
label="Other Extraction",
|
||||
),
|
||||
],
|
||||
) -> StopEvent:
|
||||
"""Wait for extraction to complete, validate results, and save for review."""
|
||||
state = await ctx.store.get_state()
|
||||
if state.file_path is None or state.filename is None:
|
||||
raise ValueError("File path or filename is not set")
|
||||
if state.extract_job_id is None:
|
||||
raise ValueError("Job ID cannot be null when waiting for its completion")
|
||||
|
||||
# Select the extract config for the classified filing type
|
||||
extract_configs = {
|
||||
"10-K": extract_10k,
|
||||
"10-Q": extract_10q,
|
||||
"8-K": extract_8k,
|
||||
"other": extract_other,
|
||||
}
|
||||
filing_type = state.filing_type or "other"
|
||||
extract_config = extract_configs.get(filing_type, extract_other)
|
||||
|
||||
await llama_cloud_client.extraction.jobs.wait_for_completion(
|
||||
state.extract_job_id
|
||||
)
|
||||
|
||||
extracted_result = await llama_cloud_client.extraction.jobs.get_result(
|
||||
state.extract_job_id
|
||||
)
|
||||
extract_run = await llama_cloud_client.extraction.runs.get(
|
||||
run_id=extracted_result.run_id
|
||||
)
|
||||
|
||||
extracted_event: ExtractedEvent | ExtractedInvalidEvent
|
||||
try:
|
||||
# Get the appropriate schema based on classification
|
||||
filing_type = (state.filing_type or "other").upper()
|
||||
schema = FILING_SCHEMAS.get(filing_type, FILING_SCHEMAS["other"])
|
||||
|
||||
logger.info(f"Using schema for filing type: {filing_type}")
|
||||
ctx.write_event_to_stream(
|
||||
Status(
|
||||
level="info",
|
||||
message=f"Extracting data using {filing_type} schema",
|
||||
)
|
||||
logger.info(
|
||||
f"Extracted data: {json.dumps(extracted_result.model_dump(), indent=2)}"
|
||||
)
|
||||
|
||||
agent = get_extract_agent()
|
||||
# Update the agent's data schema for this specific filing type
|
||||
agent.data_schema = schema
|
||||
# track the content of the file, so as to be able to de-duplicate
|
||||
file_content = Path(state.file_path).read_bytes()
|
||||
file_hash = hashlib.sha256(file_content).hexdigest()
|
||||
source_text = SourceText(
|
||||
file=state.file_path,
|
||||
filename=state.filename,
|
||||
if extract_config.extraction_agent_id:
|
||||
agent = await llama_cloud_client.extraction.extraction_agents.get(
|
||||
extract_config.extraction_agent_id
|
||||
)
|
||||
schema_class = get_extraction_schema(agent.data_schema)
|
||||
else:
|
||||
schema_class = get_extraction_schema(
|
||||
extract_config.json_schema,
|
||||
discriminator_field=DISCRIMINATOR_FIELD,
|
||||
discriminator_value=filing_type,
|
||||
)
|
||||
data = ExtractedData.from_extraction_result(
|
||||
result=extract_run,
|
||||
schema=schema_class,
|
||||
file_name=state.filename,
|
||||
file_id=state.file_id,
|
||||
file_hash=state.file_hash,
|
||||
)
|
||||
logger.info(f"Extracting data from file {state.filename}")
|
||||
ctx.write_event_to_stream(
|
||||
Status(
|
||||
level="info", message=f"Extracting data from file {state.filename}"
|
||||
)
|
||||
)
|
||||
extracted_result: ExtractRun = await agent.aextract(source_text)
|
||||
try:
|
||||
logger.info(f"Extracted data: {extracted_result}")
|
||||
data = ExtractedData.from_extraction_result(
|
||||
result=extracted_result,
|
||||
schema=schema,
|
||||
file_hash=file_hash,
|
||||
)
|
||||
# Add classification information to the extracted data
|
||||
if data.metadata is None:
|
||||
data.metadata = {}
|
||||
data.metadata["classification"] = filing_type
|
||||
data.metadata["classification_confidence"] = (
|
||||
state.classification_confidence
|
||||
)
|
||||
data.metadata["classification_reasoning"] = (
|
||||
state.classification_reasoning
|
||||
)
|
||||
return ExtractedEvent(data=data)
|
||||
except InvalidExtractionData as e:
|
||||
logger.error(f"Error validating extracted data: {e}", exc_info=True)
|
||||
return ExtractedInvalidEvent(data=e.invalid_item)
|
||||
# Add classification information to the extracted data
|
||||
if data.metadata is None:
|
||||
data.metadata = {}
|
||||
data.metadata["classification"] = filing_type
|
||||
data.metadata["classification_confidence"] = state.classification_confidence
|
||||
data.metadata["classification_reasoning"] = state.classification_reasoning
|
||||
extracted_event = ExtractedEvent(data=data)
|
||||
except InvalidExtractionData as e:
|
||||
logger.error(f"Error validating extracted data: {e}", exc_info=True)
|
||||
extracted_event = ExtractedInvalidEvent(data=e.invalid_item)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error extracting data from file {state.filename}: {e}",
|
||||
exc_info=True,
|
||||
f"Error extracting data from file {state.filename}: {e}", exc_info=True
|
||||
)
|
||||
ctx.write_event_to_stream(
|
||||
Status(
|
||||
@@ -333,61 +370,56 @@ class ProcessFileWorkflow(Workflow):
|
||||
)
|
||||
raise e
|
||||
|
||||
@step()
|
||||
async def record_extracted_data(
|
||||
self, event: ExtractedEvent | ExtractedInvalidEvent, ctx: Context
|
||||
) -> StopEvent:
|
||||
"""Records the extracted data to the agent data API"""
|
||||
try:
|
||||
logger.info(f"Recorded extracted data for file {event.data.file_name}")
|
||||
ctx.write_event_to_stream(
|
||||
Status(
|
||||
level="info",
|
||||
message=f"Recorded extracted data for file {event.data.file_name}",
|
||||
)
|
||||
)
|
||||
# remove past data when reprocessing the same file
|
||||
if event.data.file_hash:
|
||||
await get_data_client().delete(
|
||||
filter={
|
||||
"file_hash": {
|
||||
"eq": event.data.file_hash,
|
||||
},
|
||||
ctx.write_event_to_stream(extracted_event)
|
||||
|
||||
extracted_data = extracted_event.data
|
||||
data_dict = extracted_data.model_dump()
|
||||
if extracted_data.file_hash is not None:
|
||||
delete_result = await llama_cloud_client.beta.agent_data.delete_by_query(
|
||||
deployment_name=agent_name or "_public",
|
||||
collection=EXTRACTED_DATA_COLLECTION,
|
||||
filter={
|
||||
"file_hash": {
|
||||
"eq": extracted_data.file_hash,
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
if delete_result.deleted_count > 0:
|
||||
logger.info(
|
||||
f"Removing past data for file {event.data.file_name} with hash {event.data.file_hash}"
|
||||
f"Removed {delete_result.deleted_count} existing record(s) "
|
||||
f"for file {extracted_data.file_name}"
|
||||
)
|
||||
# finally, save the new data
|
||||
item_id = await get_data_client().create_item(event.data)
|
||||
return StopEvent(
|
||||
result=item_id.id,
|
||||
item = await llama_cloud_client.beta.agent_data.agent_data(
|
||||
data=data_dict,
|
||||
deployment_name=agent_name or "_public",
|
||||
collection=EXTRACTED_DATA_COLLECTION,
|
||||
)
|
||||
logger.info(
|
||||
f"Recorded extracted data for file {extracted_data.file_name or ''}"
|
||||
)
|
||||
ctx.write_event_to_stream(
|
||||
Status(
|
||||
level="info",
|
||||
message=f"Recorded extracted data for file {extracted_data.file_name or ''}",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error recording extracted data for file {event.data.file_name}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
ctx.write_event_to_stream(
|
||||
Status(
|
||||
level="error",
|
||||
message=f"Error recording extracted data for file {event.data.file_name}: {e}",
|
||||
)
|
||||
)
|
||||
raise e
|
||||
)
|
||||
return StopEvent(result=item.id)
|
||||
|
||||
|
||||
workflow = ProcessFileWorkflow(timeout=None)
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
async def main():
|
||||
file = await get_llama_cloud_client().files.upload_file(
|
||||
upload_file=Path("test.pdf").open("rb")
|
||||
file = await get_llama_cloud_client().files.create(
|
||||
file=Path("test.pdf").open("rb"),
|
||||
purpose="extract",
|
||||
)
|
||||
await workflow.run(start_event=FileEvent(file_id=file.id))
|
||||
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
"""
|
||||
Selects a locally defined schema, or queries the remote extraction agent for the schema.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import jsonref
|
||||
from .clients import get_extract_agent
|
||||
from .config import USE_REMOTE_EXTRACTION_SCHEMA, ExtractionSchema
|
||||
from typing import Any, Type
|
||||
from pydantic import BaseModel
|
||||
from pydantic import create_model, Field
|
||||
|
||||
|
||||
# Module-level cache of the resolved schema model. It starts out as the locally
# defined ExtractionSchema, or as None when USE_REMOTE_EXTRACTION_SCHEMA is set
# and the schema has to be fetched from the extraction agent on first use.
SCHEMA: Type[BaseModel] | None = (
    ExtractionSchema if not USE_REMOTE_EXTRACTION_SCHEMA else None
)

# Held while get_extraction_schema() populates SCHEMA.
_schema_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_extraction_schema() -> Type[BaseModel]:
    """
    Return the Pydantic model used for extraction.

    When the module-level SCHEMA is already set it is returned as-is.
    Otherwise the schema is built from the extraction agent's ``data_schema``
    and stored in SCHEMA so later calls reuse it.
    """
    global SCHEMA
    if SCHEMA is None:
        async with _schema_lock:
            # Check again once the lock is held: another task may have
            # populated SCHEMA while this one was waiting.
            if SCHEMA is None:
                SCHEMA = model_from_schema(get_extract_agent().data_schema)
    return SCHEMA
|
||||
|
||||
|
||||
async def get_extraction_schema_json() -> dict[str, Any]:
    """
    Return the extraction schema as a JSON schema dictionary.

    The model's JSON schema is passed through ``jsonref.replace_refs`` with
    ``proxies=False`` before being returned.
    """
    schema_model = await get_extraction_schema()
    return jsonref.replace_refs(schema_model.model_json_schema(), proxies=False)
|
||||
|
||||
|
||||
def model_from_schema(schema: dict[str, Any]) -> Type[BaseModel]:
    """
    Converts a JSON schema back to a Pydantic model.

    Only the top-level ``properties`` of ``schema`` are turned into fields.
    Each property's ``type`` string is mapped to a Python builtin; ``array``
    and ``object`` become plain ``list`` and ``dict``, so nested structure is
    not modelled. Properties listed under ``required`` become required
    fields; all others default to ``None``.

    Args:
        schema: JSON schema dictionary. ``properties``, ``required`` and
            ``title`` are read; all three are optional.

    Returns:
        A dynamically created Pydantic model named after the schema's
        ``title`` (``"DynamicModel"`` when absent).
    """
    typemap = {
        "string": str,
        "integer": int,
        "number": float,
        "boolean": bool,
        "array": list,
        "object": dict,
    }
    # Look the required names up once instead of on every loop iteration.
    required = set(schema.get("required", []))
    fields: dict[str, Any] = {}
    for prop, meta in schema.get("properties", {}).items():
        json_type = meta.get("type")
        # JSON Schema also allows "type" to be a list such as
        # ["string", "null"]. A list is unhashable and would make the dict
        # lookup raise TypeError, so anything that is not a plain string
        # falls back to Any.
        py_type = typemap.get(json_type, Any) if isinstance(json_type, str) else Any
        default = ... if prop in required else None
        fields[prop] = (py_type, Field(default, description=meta.get("description")))
    return create_model(schema.get("title", "DynamicModel"), **fields)
|
||||
+7
-7
@@ -16,18 +16,18 @@
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.27.6",
|
||||
"@lezer/highlight": "^1.2.1",
|
||||
"@llamaindex/ui": "~3.6.1",
|
||||
"@llamaindex/workflows-client": "^1.7.0",
|
||||
"@radix-ui/themes": "^3.2.1",
|
||||
"@llamaindex/llama-cloud": "^1.3.0",
|
||||
"@llamaindex/ui": "^4.1.3",
|
||||
"@llamaindex/workflows-client": "^1.8.3",
|
||||
"@radix-ui/themes": "^3.3.0",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
"llama-cloud-services": "^0.5.2",
|
||||
"lucide-react": "^0.514.0",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0",
|
||||
"react-router-dom": "^6.30.0",
|
||||
"sonner": "^2.0.5",
|
||||
"tw-animate-css": "^1.3.5"
|
||||
"react-router-dom": "^7.8.0",
|
||||
"sonner": "^1.7.2",
|
||||
"tw-animate-css": "^1.4.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@tailwindcss/postcss": "^4.1.10",
|
||||
|
||||
+12
-18
@@ -1,10 +1,10 @@
|
||||
import { ExtractedData } from "llama-cloud-services/beta/agent";
|
||||
import {
|
||||
ApiClients,
|
||||
createWorkflowsClient,
|
||||
createWorkflowsConfig,
|
||||
createCloudAgentClient,
|
||||
cloudApiClient,
|
||||
configureCloudClient,
|
||||
getCloudClient,
|
||||
createAgentDataConfig,
|
||||
} from "@llamaindex/ui";
|
||||
import { AGENT_NAME } from "./config";
|
||||
import type { Metadata } from "./useMetadata";
|
||||
@@ -13,16 +13,11 @@ const platformToken = import.meta.env.VITE_LLAMA_CLOUD_API_KEY;
|
||||
const apiBaseUrl = import.meta.env.VITE_LLAMA_CLOUD_BASE_URL;
|
||||
const projectId = import.meta.env.VITE_LLAMA_DEPLOY_PROJECT_ID;
|
||||
|
||||
// Configure the platform client
|
||||
cloudApiClient.setConfig({
|
||||
...(apiBaseUrl && { baseUrl: apiBaseUrl }),
|
||||
headers: {
|
||||
// optionally use a backend API token scoped to a project. For local development,
|
||||
...(platformToken && { authorization: `Bearer ${platformToken}` }),
|
||||
// This header is required for requests to correctly scope to the agent's project
|
||||
// when authenticating with a user cookie
|
||||
...(projectId && { "Project-Id": projectId }),
|
||||
},
|
||||
// Configure the cloud client
|
||||
configureCloudClient({
|
||||
...(apiBaseUrl && { baseURL: apiBaseUrl }),
|
||||
...(platformToken && { apiKey: platformToken }),
|
||||
...(projectId && { projectId }),
|
||||
});
|
||||
|
||||
export function createBaseWorkflowClient(): ReturnType<
|
||||
@@ -37,15 +32,14 @@ export function createBaseWorkflowClient(): ReturnType<
|
||||
|
||||
export function createClients(metadata: Metadata): ApiClients {
|
||||
const workflowsClient = createBaseWorkflowClient();
|
||||
const agentClient = createCloudAgentClient<ExtractedData<any>>({
|
||||
client: cloudApiClient,
|
||||
const agentDataConfig = createAgentDataConfig({
|
||||
windowUrl: typeof window !== "undefined" ? window.location.href : undefined,
|
||||
collection: metadata.extracted_data_collection,
|
||||
});
|
||||
|
||||
return {
|
||||
workflowsClient,
|
||||
cloudApiClient,
|
||||
agentDataClient: agentClient,
|
||||
} as ApiClients;
|
||||
cloudApiClient: getCloudClient(),
|
||||
agentDataConfig,
|
||||
};
|
||||
}
|
||||
|
||||
+7
-16
@@ -1,11 +1,5 @@
|
||||
import type {
|
||||
ExtractedData,
|
||||
TypedAgentData,
|
||||
} from "llama-cloud-services/beta/agent";
|
||||
import type { AgentDataItem, ExtractedData } from "@llamaindex/ui";
|
||||
|
||||
/**
|
||||
* Downloads data as a JSON file
|
||||
*/
|
||||
export function downloadJSON<T>(
|
||||
data: T,
|
||||
filename: string = "extraction-results.json",
|
||||
@@ -20,19 +14,16 @@ export function downloadJSON<T>(
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
|
||||
// Cleanup
|
||||
document.body.removeChild(link);
|
||||
URL.revokeObjectURL(url);
|
||||
}
|
||||
|
||||
/**
|
||||
* Downloads extracted data item as JSON
|
||||
*/
|
||||
export function downloadExtractedDataItem<T>(
|
||||
item: TypedAgentData<ExtractedData<T>>,
|
||||
) {
|
||||
const fileName = item.data.file_name || "item";
|
||||
const timestamp = item.createdAt.toISOString().split("T")[0];
|
||||
export function downloadExtractedDataItem(item: AgentDataItem) {
|
||||
const extractedData = item.data as ExtractedData<unknown>;
|
||||
const fileName = extractedData.file_name || "item";
|
||||
const timestamp = item.created_at
|
||||
? new Date(item.created_at).toISOString().split("T")[0]
|
||||
: new Date().toISOString().split("T")[0];
|
||||
const filename = `${fileName}-${timestamp}.json`;
|
||||
|
||||
downloadJSON(item, filename);
|
||||
|
||||
+1
-2
@@ -1,7 +1,6 @@
|
||||
import { clsx, type ClassValue } from "clsx";
|
||||
import { twMerge } from "tailwind-merge";
|
||||
import type { Highlight } from "@llamaindex/ui";
|
||||
import type { FieldCitation } from "llama-cloud-services/beta/agent";
|
||||
import type { Highlight, FieldCitation } from "@llamaindex/ui";
|
||||
|
||||
export function cn(...inputs: ClassValue[]) {
|
||||
return twMerge(clsx(inputs));
|
||||
|
||||
@@ -3,8 +3,8 @@ import {
|
||||
WorkflowTrigger,
|
||||
ExtractedDataItemGrid,
|
||||
HandlerState,
|
||||
AgentDataItem,
|
||||
} from "@llamaindex/ui";
|
||||
import type { TypedAgentData } from "llama-cloud-services/beta/agent";
|
||||
import styles from "./HomePage.module.css";
|
||||
import { useNavigate } from "react-router-dom";
|
||||
import { useState } from "react";
|
||||
@@ -16,7 +16,7 @@ export default function HomePage() {
|
||||
|
||||
function TaskList() {
|
||||
const navigate = useNavigate();
|
||||
const goToItem = (item: TypedAgentData) => {
|
||||
const goToItem = (item: AgentDataItem) => {
|
||||
navigate(`/item/${item.id}`);
|
||||
};
|
||||
const [reloadSignal, setReloadSignal] = useState(0);
|
||||
@@ -52,9 +52,11 @@ function TaskList() {
|
||||
/>
|
||||
<WorkflowTrigger
|
||||
workflowName="process-file"
|
||||
contentHash={{ enabled: true }}
|
||||
customWorkflowInput={(files) => {
|
||||
return {
|
||||
file_id: files[0].fileId,
|
||||
file_hash: files[0].contentHash ?? null,
|
||||
};
|
||||
}}
|
||||
onSuccess={(handler) => {
|
||||
|
||||
+21
-17
@@ -5,6 +5,7 @@ import {
|
||||
FilePreview,
|
||||
useItemData,
|
||||
type Highlight,
|
||||
type ExtractedData,
|
||||
Button,
|
||||
} from "@llamaindex/ui";
|
||||
import { Clock, XCircle, Download } from "lucide-react";
|
||||
@@ -32,8 +33,11 @@ export default function ItemPage() {
|
||||
});
|
||||
|
||||
// Determine the correct schema based on classification
|
||||
const classificationData = itemHookData.item?.data as
|
||||
| ExtractedData<any>
|
||||
| undefined;
|
||||
const classification = (
|
||||
(itemHookData.item?.data?.metadata?.classification as string | undefined) ||
|
||||
(classificationData?.metadata?.classification as string | undefined) ||
|
||||
"10-K"
|
||||
).toUpperCase();
|
||||
const correctSchema =
|
||||
@@ -52,9 +56,11 @@ export default function ItemPage() {
|
||||
|
||||
const navigate = useNavigate();
|
||||
|
||||
// Update breadcrumb when item data loads
|
||||
useEffect(() => {
|
||||
const fileName = itemHookData.item?.data?.file_name;
|
||||
const extractedData = itemHookData.item?.data as
|
||||
| ExtractedData<unknown>
|
||||
| undefined;
|
||||
const fileName = extractedData?.file_name;
|
||||
if (fileName) {
|
||||
setBreadcrumbs([
|
||||
{ label: APP_TITLE, href: "/" },
|
||||
@@ -66,10 +72,9 @@ export default function ItemPage() {
|
||||
}
|
||||
|
||||
return () => {
|
||||
// Reset to default breadcrumb when leaving the page
|
||||
setBreadcrumbs([{ label: APP_TITLE, href: "/" }]);
|
||||
};
|
||||
}, [itemHookData.item?.data?.file_name, setBreadcrumbs]);
|
||||
}, [itemHookData.item?.data, setBreadcrumbs]);
|
||||
|
||||
useEffect(() => {
|
||||
setButtons(() => [
|
||||
@@ -83,10 +88,9 @@ export default function ItemPage() {
|
||||
}
|
||||
}}
|
||||
disabled={!itemData}
|
||||
>
|
||||
<Download className="h-4 w-4 mr-2" />
|
||||
Export JSON
|
||||
</Button>
|
||||
startIcon={<Download className="h-4 w-4" />}
|
||||
label="Export JSON"
|
||||
/>
|
||||
<AcceptReject<any>
|
||||
itemData={itemHookData}
|
||||
onComplete={() => navigate("/")}
|
||||
@@ -105,8 +109,8 @@ export default function ItemPage() {
|
||||
error,
|
||||
} = itemHookData;
|
||||
|
||||
const classificationReasoning = itemData?.data?.metadata
|
||||
?.classification_reasoning as string | undefined;
|
||||
const classificationReasoning = (itemData?.data as ExtractedData<any>)
|
||||
?.metadata?.classification_reasoning as string | undefined;
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
@@ -132,13 +136,15 @@ export default function ItemPage() {
|
||||
);
|
||||
}
|
||||
|
||||
const extractedData = itemData.data as ExtractedData<any>;
|
||||
const fileId = extractedData.file_id;
|
||||
|
||||
return (
|
||||
<div className="flex h-full bg-gray-50">
|
||||
{/* Left Side - File Preview */}
|
||||
<div className="w-1/2 border-r h-full border-gray-200 bg-white">
|
||||
{itemData.data.file_id && (
|
||||
{fileId && (
|
||||
<FilePreview
|
||||
fileId={itemData.data.file_id}
|
||||
fileId={fileId}
|
||||
onBoundingBoxClick={(box, pageNumber) => {
|
||||
console.log("Bounding box clicked:", box, "on page:", pageNumber);
|
||||
}}
|
||||
@@ -147,7 +153,6 @@ export default function ItemPage() {
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Right Side - Review Panel */}
|
||||
<div className="flex-1 bg-white h-full overflow-y-auto">
|
||||
<div className="p-4 space-y-4">
|
||||
{/* Classification Info */}
|
||||
@@ -163,10 +168,9 @@ export default function ItemPage() {
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
{/* Extracted Data */}
|
||||
<ExtractedDataDisplay<any>
|
||||
key={schemaKey}
|
||||
extractedData={itemData.data}
|
||||
extractedData={extractedData}
|
||||
title="Extracted Data"
|
||||
onChange={(updatedData) => {
|
||||
updateData(updatedData);
|
||||
|
||||
Reference in New Issue
Block a user