mirror of
https://github.com/run-llama/template-workflow-classify-extract-sec.git
synced 2026-07-01 21:54:02 -04:00
Classify v2 fake, agent_data.create migration, downstream copier updates (#268)
This commit is contained in:
+1
-1
@@ -1,3 +1,3 @@
|
||||
# Changes here will be overwritten by Copier; NEVER EDIT MANUALLY
|
||||
_commit: v0.7.2
|
||||
_commit: v0.7.3
|
||||
_src_path: https://github.com/run-llama/template-workflow-data-extraction
|
||||
|
||||
+12
-3
@@ -1,11 +1,11 @@
|
||||
[project]
|
||||
name = "extraction-review"
|
||||
name = "classify-extract-sec"
|
||||
version = "0.1.0"
|
||||
description = "Extracts data"
|
||||
description = "Classify SEC filings and extract per-type data"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"llama-cloud>=2.3.0,<3",
|
||||
"llama-cloud>=2.4.1,<3",
|
||||
"json-schema-to-pydantic>=0.4.8",
|
||||
"llama-index-workflows>=2.16.0,<3.0.0",
|
||||
"python-dotenv>=1.1.0",
|
||||
@@ -23,12 +23,17 @@ dev = [
|
||||
"pytest>=8.4.1",
|
||||
"hatch>=1.14.1",
|
||||
"pytest-asyncio>=1.3.0",
|
||||
"pytest-timeout>=2.3.1",
|
||||
"llama-cloud-fake>=0.1,<0.2",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/extraction_review"]
|
||||
|
||||
[tool.hatch.envs.default.scripts]
|
||||
"format" = "ruff format ."
|
||||
"format-check" = "ruff format --check ."
|
||||
@@ -39,6 +44,10 @@ test = "pytest"
|
||||
"all-check" = ["format-check", "lint-check", "test"]
|
||||
"all-fix" = ["format", "lint", "test"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
timeout = 120
|
||||
timeout_method = "thread"
|
||||
|
||||
[tool.llamadeploy]
|
||||
env_files = [".env"]
|
||||
llama_cloud = true
|
||||
|
||||
@@ -392,7 +392,7 @@ class ProcessFileWorkflow(Workflow):
|
||||
f"Removed {delete_result.deleted_count} existing record(s) "
|
||||
f"for file {extracted_data.file_name}"
|
||||
)
|
||||
item = await llama_cloud_client.beta.agent_data.agent_data(
|
||||
item = await llama_cloud_client.beta.agent_data.create(
|
||||
data=data_dict,
|
||||
deployment_name=agent_name or "_public",
|
||||
collection=EXTRACTED_DATA_COLLECTION,
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
"""Pytest configuration: install the LlamaCloud fake server for all tests."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from llama_cloud_fake import FakeLlamaCloudServer
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
|
||||
_fake = FakeLlamaCloudServer().install()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake() -> FakeLlamaCloudServer:
|
||||
return _fake
|
||||
Binary file not shown.
@@ -1,2 +0,0 @@
|
||||
def test_placeholder():
|
||||
pass
|
||||
@@ -0,0 +1,72 @@
|
||||
from importlib.metadata import version
|
||||
|
||||
import pytest
|
||||
from extraction_review.config import EXTRACTED_DATA_COLLECTION
|
||||
from extraction_review.metadata_workflow import DISCRIMINATOR_FIELD, MetadataResponse
|
||||
from extraction_review.metadata_workflow import workflow as metadata_workflow
|
||||
from extraction_review.process_file import FileEvent, Status
|
||||
from extraction_review.process_file import workflow as process_file_workflow
|
||||
from llama_cloud_fake import FakeLlamaCloudServer
|
||||
from workflows.events import StartEvent
|
||||
|
||||
FILING_TYPES = {"10-K", "10-Q", "8-K", "other"}
|
||||
FAKE_HAS_CLASSIFY_V2 = version("llama-cloud-fake") >= "0.1.1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_file_workflow(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
fake: FakeLlamaCloudServer,
|
||||
) -> None:
|
||||
monkeypatch.setenv("LLAMA_CLOUD_API_KEY", "fake-api-key")
|
||||
file_id = fake.files.preload(path="tests/files/test.pdf")
|
||||
try:
|
||||
result = await process_file_workflow.run(start_event=FileEvent(file_id=file_id))
|
||||
except Exception:
|
||||
result = None
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
assert len(result) == 7
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(
|
||||
not FAKE_HAS_CLASSIFY_V2,
|
||||
reason="llama-cloud-fake < 0.1.1 does not mock classify v2",
|
||||
)
|
||||
async def test_classify_v2_assigns_filing_type(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
fake: FakeLlamaCloudServer,
|
||||
) -> None:
|
||||
"""process_file reports a concrete SEC filing type from classify v2."""
|
||||
monkeypatch.setenv("LLAMA_CLOUD_API_KEY", "fake-api-key")
|
||||
file_id = fake.files.preload(path="tests/files/test.pdf")
|
||||
|
||||
handler = process_file_workflow.run(start_event=FileEvent(file_id=file_id))
|
||||
classified_statuses: list[Status] = []
|
||||
async for event in handler.stream_events():
|
||||
if isinstance(event, Status):
|
||||
if event.level == "error":
|
||||
raise AssertionError(f"workflow errored: {event.message}")
|
||||
if event.message.startswith("Classified as "):
|
||||
classified_statuses.append(event)
|
||||
await handler
|
||||
|
||||
# A real classify v2 result produces a "Classified as <type>" info status.
|
||||
# The fallback path (classification error -> "other") does *not* emit this.
|
||||
assert classified_statuses, (
|
||||
"expected a 'Classified as ...' status from a completed classify v2 job"
|
||||
)
|
||||
message = classified_statuses[-1].message
|
||||
matched = next((t for t in FILING_TYPES if f"Classified as {t} " in message), None)
|
||||
assert matched is not None, f"unexpected classification status: {message}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metadata_workflow() -> None:
|
||||
result = await metadata_workflow.run(start_event=StartEvent())
|
||||
assert isinstance(result, MetadataResponse)
|
||||
assert result.extracted_data_collection == EXTRACTED_DATA_COLLECTION
|
||||
assert result.discriminator_field == DISCRIMINATOR_FIELD
|
||||
assert set(result.schemas.keys()) == FILING_TYPES
|
||||
assert DISCRIMINATOR_FIELD in result.json_schema.get("properties", {})
|
||||
Reference in New Issue
Block a user