mirror of
https://github.com/run-llama/llama_extract.git
synced 2026-07-01 01:37:54 -04:00
Update to llama cloud v 0.1.11 (#48)
This commit is contained in:
@@ -12,12 +12,13 @@ from llama_cloud import (
|
||||
ExtractConfig,
|
||||
ExtractJob,
|
||||
ExtractJobCreate,
|
||||
ExtractResultset,
|
||||
ExtractRun,
|
||||
File,
|
||||
ExtractMode,
|
||||
StatusEnum,
|
||||
Project,
|
||||
ExtractTarget,
|
||||
LlamaExtractSettings,
|
||||
)
|
||||
from llama_cloud.client import AsyncLlamaCloud
|
||||
from llama_extract.utils import JSONObjectType, augment_async_errors
|
||||
@@ -33,11 +34,10 @@ FileInput = Union[str, Path, bytes, BufferedIOBase]
|
||||
SchemaInput = Union[JSONObjectType, Type[BaseModel]]
|
||||
|
||||
# Default configuration used for extraction agents.
# The superseded `extraction_mode=ExtractMode.PER_DOC` entry is dropped: with
# llama-cloud 0.1.11 the per-document setting is expressed through
# `extraction_target`, and `extraction_mode` may only be passed once.
DEFAULT_EXTRACT_CONFIG = ExtractConfig(
    extraction_target=ExtractTarget.PER_DOC,
    extraction_mode=ExtractMode.ACCURATE,
)
|
||||
|
||||
ExtractionResult = Tuple[ExtractJob, ExtractResultset]
|
||||
|
||||
|
||||
class ExtractionAgent:
|
||||
"""Class representing a single extraction agent with methods for extraction operations."""
|
||||
@@ -192,6 +192,58 @@ class ExtractionAgent:
|
||||
)
|
||||
)
|
||||
|
||||
async def _queue_extraction_test(
    self,
    files: Union[FileInput, List[FileInput]],
    extract_settings: LlamaExtractSettings,
) -> Union[ExtractRun, List[ExtractRun]]:
    """Upload files, run them through the test extraction endpoint, and await the results.

    Every file is uploaded with ``_upload_file``, submitted through
    ``run_job_test_user`` (with this agent's schema and config passed as
    overrides, plus ``extract_settings``), and then awaited with
    ``_wait_for_job_result``.

    Args:
        files: A single file input, or a list of file inputs.
        extract_settings: Settings forwarded unchanged to ``run_job_test_user``.

    Returns:
        The awaited result when ``files`` is not a list, otherwise a list
        with one awaited result per input file.
    """
    single_file = not isinstance(files, list)
    file_list = [files] if single_file else files

    upload_tasks = [self._upload_file(item) for item in file_list]
    with augment_async_errors():
        uploaded_files = await run_jobs(
            upload_tasks,
            workers=self.num_workers,
            desc="Uploading files",
            show_progress=self.show_progress,
        )

    async def run_job(file: File) -> ExtractRun:
        # Submit the job, then block until its result is available.
        job_queued = await self._client.llama_extract.run_job_test_user(
            job_create=ExtractJobCreate(
                extraction_agent_id=self.id,
                file_id=file.id,
                data_schema_override=self.data_schema,
                config_override=self.config,
            ),
            extract_settings=extract_settings,
        )
        return await self._wait_for_job_result(job_queued.id)

    job_tasks = [run_job(uploaded) for uploaded in uploaded_files]
    with augment_async_errors():
        extract_jobs = await run_jobs(
            job_tasks,
            workers=self.num_workers,
            desc="Creating extraction jobs",
            show_progress=self.show_progress,
        )

    if self._verbose:
        # NOTE(review): the message says "Queued", but every entry here has
        # already been awaited by `_wait_for_job_result`; `job.id` is the id
        # of that returned object. Pairing with `file_list` assumes
        # `run_jobs` returns results in submission order - confirm.
        for file, job in zip(file_list, extract_jobs):
            file_repr = (
                str(file) if isinstance(file, (str, Path)) else "<bytes/buffer>"
            )
            print(
                f"Queued file extraction for file {file_repr} under job_id {job.id}"
            )

    return extract_jobs[0] if single_file else extract_jobs
|
||||
|
||||
async def queue_extraction(
|
||||
self,
|
||||
files: Union[FileInput, List[FileInput]],
|
||||
|
||||
Generated
+20
-20
@@ -285,13 +285,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "attrs"
|
||||
version = "24.3.0"
|
||||
version = "25.1.0"
|
||||
description = "Classes Without Boilerplate"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "attrs-24.3.0-py3-none-any.whl", hash = "sha256:ac96cd038792094f438ad1f6ff80837353805ac950cd2aa0e0625ef19850c308"},
|
||||
{file = "attrs-24.3.0.tar.gz", hash = "sha256:8f5c07333d543103541ba7be0e2ce16eeee8130cb0b3f9238ab904ce1e85baff"},
|
||||
{file = "attrs-25.1.0-py3-none-any.whl", hash = "sha256:c75a69e28a550a7e93789579c22aa26b0f5b83b75dc4e08fe092980051e1090a"},
|
||||
{file = "attrs-25.1.0.tar.gz", hash = "sha256:1c97078a80c814273a76b2a298a932eb681c87415c11dee0a6921de7f1b02c3e"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
@@ -737,20 +737,20 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "deprecated"
|
||||
version = "1.2.15"
|
||||
version = "1.2.18"
|
||||
description = "Python @deprecated decorator to deprecate old python classes, functions or methods."
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7"
|
||||
files = [
|
||||
{file = "Deprecated-1.2.15-py2.py3-none-any.whl", hash = "sha256:353bc4a8ac4bfc96800ddab349d89c25dec1079f65fd53acdcc1e0b975b21320"},
|
||||
{file = "deprecated-1.2.15.tar.gz", hash = "sha256:683e561a90de76239796e6b6feac66b99030d2dd3fcf61ef996330f14bbb9b0d"},
|
||||
{file = "Deprecated-1.2.18-py2.py3-none-any.whl", hash = "sha256:bd5011788200372a32418f888e326a09ff80d0214bd961147cfed01b5c018eec"},
|
||||
{file = "deprecated-1.2.18.tar.gz", hash = "sha256:422b6f6d859da6f2ef57857761bfb392480502a64c3028ca9bbe86085d72115d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
wrapt = ">=1.10,<2"
|
||||
|
||||
[package.extras]
|
||||
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "jinja2 (>=3.0.3,<3.1.0)", "setuptools", "sphinx (<2)", "tox"]
|
||||
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "setuptools", "tox"]
|
||||
|
||||
[[package]]
|
||||
name = "dirtyjson"
|
||||
@@ -1784,13 +1784,13 @@ rapidfuzz = ">=3.9.0,<4.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "llama-cloud"
|
||||
version = "0.1.10"
|
||||
version = "0.1.11"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = "<4,>=3.8"
|
||||
files = [
|
||||
{file = "llama_cloud-0.1.10-py3-none-any.whl", hash = "sha256:d91198ad92ea6c3a25757e5d6cb565b4bd6db385dc4fa596a725c0fb81a68f4e"},
|
||||
{file = "llama_cloud-0.1.10.tar.gz", hash = "sha256:56ffe8f2910c2047dd4eb1b13da31ee5f67321a000794eee559e0b56954d2f76"},
|
||||
{file = "llama_cloud-0.1.11-py3-none-any.whl", hash = "sha256:b703765d03783a5a0fc57a52adc9892f8b91b0c19bbecb85a54ad4e813342951"},
|
||||
{file = "llama_cloud-0.1.11.tar.gz", hash = "sha256:d4be5b48659fd9fe1698727be257269a22d7f2733a2ed11bce7065768eb94cbe"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -1905,13 +1905,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "marshmallow"
|
||||
version = "3.25.1"
|
||||
version = "3.26.0"
|
||||
description = "A lightweight library for converting complex datatypes to and from native Python datatypes."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "marshmallow-3.25.1-py3-none-any.whl", hash = "sha256:ec5d00d873ce473b7f2ffcb7104286a376c354cab0c2fa12f5573dab03e87210"},
|
||||
{file = "marshmallow-3.25.1.tar.gz", hash = "sha256:f4debda3bb11153d81ac34b0d582bf23053055ee11e791b54b4b35493468040a"},
|
||||
{file = "marshmallow-3.26.0-py3-none-any.whl", hash = "sha256:1287bca04e6a5f4094822ac153c03da5e214a0a60bcd557b140f3e66991b8ca1"},
|
||||
{file = "marshmallow-3.26.0.tar.gz", hash = "sha256:eb36762a1cc76d7abf831e18a3a1b26d3d481bbc74581b8e532a3d3a8115e1cb"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2751,13 +2751,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "pydantic"
|
||||
version = "2.10.5"
|
||||
version = "2.10.6"
|
||||
description = "Data validation using Python type hints"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pydantic-2.10.5-py3-none-any.whl", hash = "sha256:4dd4e322dbe55472cb7ca7e73f4b63574eecccf2835ffa2af9021ce113c83c53"},
|
||||
{file = "pydantic-2.10.5.tar.gz", hash = "sha256:278b38dbbaec562011d659ee05f63346951b3a248a6f3642e1bc68894ea2b4ff"},
|
||||
{file = "pydantic-2.10.6-py3-none-any.whl", hash = "sha256:427d664bf0b8a2b34ff5dd0f5a18df00591adcee7198fbd71981054cef37b584"},
|
||||
{file = "pydantic-2.10.6.tar.gz", hash = "sha256:ca5daa827cce33de7a42be142548b0096bf05a7e7b365aebfa5f8eeec7128236"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -3307,13 +3307,13 @@ all = ["numpy"]
|
||||
|
||||
[[package]]
|
||||
name = "referencing"
|
||||
version = "0.36.1"
|
||||
version = "0.36.2"
|
||||
description = "JSON Referencing + Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "referencing-0.36.1-py3-none-any.whl", hash = "sha256:363d9c65f080d0d70bc41c721dce3c7f3e77fc09f269cd5c8813da18069a6794"},
|
||||
{file = "referencing-0.36.1.tar.gz", hash = "sha256:ca2e6492769e3602957e9b831b94211599d2aade9477f5d44110d2530cf9aade"},
|
||||
{file = "referencing-0.36.2-py3-none-any.whl", hash = "sha256:e8699adbbf8b5c7de96d8ffa0eb5c158b3beafce084968e2ea8bb08c6794dcd0"},
|
||||
{file = "referencing-0.36.2.tar.gz", hash = "sha256:df2e89862cd09deabbdba16944cc3f10feb6b3e6f18e902f7cc25609a34775aa"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -4317,4 +4317,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9,<4.0"
|
||||
content-hash = "90b87c91c45412dae185dac7a8eb80b8e6f37d9677990ff394590acd751c7565"
|
||||
content-hash = "1ff53e863a18be137ee0eff10a8c4412e2db95ca7cdc7aedf22339325d3fc818"
|
||||
|
||||
+1
-1
@@ -18,7 +18,7 @@ packages = [{include = "llama_extract"}]
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.9,<4.0"
|
||||
llama-index-core = "^0.11.0"
|
||||
llama-cloud = "0.1.10"
|
||||
llama-cloud = "0.1.11"
|
||||
python-dotenv = "^1.0.1"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
import os
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from llama_extract import LlamaExtract, ExtractionAgent
|
||||
from dotenv import load_dotenv
|
||||
from time import perf_counter
|
||||
from collections import namedtuple
|
||||
import json
|
||||
import uuid
|
||||
from llama_cloud.core.api_error import ApiError
|
||||
from llama_cloud.types import (
|
||||
ExtractConfig,
|
||||
ExtractMode,
|
||||
LlamaParseParameters,
|
||||
LlamaExtractSettings,
|
||||
)
|
||||
|
||||
load_dotenv(Path(__file__).parent.parent / ".env.dev", override=True)
|
||||
|
||||
|
||||
TEST_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
|
||||
# Get configuration from environment
|
||||
LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY")
|
||||
LLAMA_CLOUD_BASE_URL = os.getenv("LLAMA_CLOUD_BASE_URL")
|
||||
LLAMA_CLOUD_PROJECT_ID = os.getenv("LLAMA_CLOUD_PROJECT_ID")
|
||||
|
||||
TestCase = namedtuple(
|
||||
"TestCase", ["name", "schema_path", "config", "input_file", "expected_output"]
|
||||
)
|
||||
|
||||
|
||||
def get_test_cases():
    """Collect test cases from the sub-directories of TEST_DIR.

    A sub-directory contributes cases only if it contains ``schema.json``.
    Every other regular file in it that does not end in ``.test.json`` is
    treated as an input; an input is used only when a sibling
    ``<stem>.test.json`` exists. Each usable input produces one case per
    extraction mode (FAST, then ACCURATE).

    Returns:
        List[TestCase]: List of test cases
    """
    cases = []

    for entry in os.listdir(TEST_DIR):
        entry_dir = os.path.join(TEST_DIR, entry)
        schema_path = os.path.join(entry_dir, "schema.json")
        if not (os.path.isdir(entry_dir) and os.path.exists(schema_path)):
            continue

        # Gather candidate inputs: regular files other than the schema and
        # the expected-output files.
        inputs = []
        for fname in os.listdir(entry_dir):
            path = os.path.join(entry_dir, fname)
            if not os.path.isfile(path):
                continue
            if fname == "schema.json" or fname.endswith(".test.json"):
                continue
            inputs.append(path)
        inputs.sort()

        configs = [
            ExtractConfig(extraction_mode=ExtractMode.FAST),
            ExtractConfig(extraction_mode=ExtractMode.ACCURATE),
        ]

        for input_file in inputs:
            file_name = os.path.basename(input_file)
            stem = os.path.splitext(file_name)[0]
            expected_output = os.path.join(entry_dir, f"{stem}.test.json")
            if not os.path.exists(expected_output):
                continue

            # One case per configuration; both share the same name.
            cases.extend(
                TestCase(
                    name=f"{entry}/{file_name}",
                    schema_path=schema_path,
                    input_file=input_file,
                    config=config,
                    expected_output=expected_output,
                )
                for config in configs
            )

    return cases
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def extractor():
    """Provide one LlamaExtract client shared by every test in the session."""
    client = LlamaExtract(
        api_key=LLAMA_CLOUD_API_KEY,
        base_url=LLAMA_CLOUD_BASE_URL,
        project_id=LLAMA_CLOUD_PROJECT_ID,
        verbose=True,
    )
    yield client
    # Teardown: shut down the client's thread pool once the session is over.
    client._thread_pool.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
def extraction_agent(test_case: TestCase, extractor: LlamaExtract):
    """Create an extraction agent for one test case and delete it afterwards.

    The agent is named after the test case plus a random suffix, and is
    created from the case's JSON schema and extraction config.

    Args:
        test_case: Case providing ``name``, ``schema_path`` and ``config``.
        extractor: Client used to list, create and delete agents.

    Yields:
        The agent returned by ``extractor.create_agent``.
    """
    # Create unique name with random UUID (important for CI to avoid conflicts)
    unique_id = uuid.uuid4().hex[:8]
    agent_name = f"{test_case.name}_{unique_id}"

    with open(test_case.schema_path, "r", encoding="utf-8") as f:
        schema = json.load(f)

    # Best-effort removal of any existing agent with this name; a failure
    # here is reported but must not prevent the test from running.
    try:
        for existing in extractor.list_agents():
            if existing.name == agent_name:
                extractor.delete_agent(existing.id)
    except Exception as e:
        print(f"Warning: Failed to cleanup existing agent: {str(e)}")

    # Create new agent
    agent = extractor.create_agent(agent_name, schema, config=test_case.config)
    yield agent

    # Cleanup after test: without this every run leaves an agent behind.
    # Best-effort, so a failed delete does not turn into a test error.
    try:
        extractor.delete_agent(agent.id)
    except Exception as e:
        print(f"Warning: Failed to delete agent {agent.id}: {str(e)}")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    "CI" in os.environ,
    reason="CI environment is not suitable for benchmarking",
)
@pytest.mark.parametrize("test_case", get_test_cases(), ids=lambda case: case.name)
@pytest.mark.asyncio(loop_scope="session")
async def test_extraction(
    test_case: TestCase, extraction_agent: ExtractionAgent
) -> None:
    """Time one extraction of the test case's input file and print the outcome.

    This is a benchmark rather than a correctness test: it makes no
    assertions and is skipped when ``CI`` is set in the environment.
    """
    start = perf_counter()
    # Both cache flags are set on the parse parameters - presumably so cached
    # parses do not skew the measurement (confirm against llama_cloud docs).
    parse_params = LlamaParseParameters(
        invalidate_cache=True,
        do_not_cache=True,
    )
    settings = LlamaExtractSettings(llama_parse_params=parse_params)
    result = await extraction_agent._queue_extraction_test(
        test_case.input_file,
        extract_settings=settings,
    )
    elapsed = perf_counter() - start
    print(f"Time taken: {elapsed} seconds")
    print(result)
|
||||
@@ -8,6 +8,7 @@ from collections import namedtuple
|
||||
import json
|
||||
import uuid
|
||||
from llama_cloud.core.api_error import ApiError
|
||||
from llama_cloud.types import ExtractConfig, ExtractMode, ExtractConfig
|
||||
from deepdiff import DeepDiff
|
||||
from tests.util import json_subset_match_score
|
||||
|
||||
@@ -21,7 +22,7 @@ LLAMA_CLOUD_BASE_URL = os.getenv("LLAMA_CLOUD_BASE_URL")
|
||||
LLAMA_CLOUD_PROJECT_ID = os.getenv("LLAMA_CLOUD_PROJECT_ID")
|
||||
|
||||
TestCase = namedtuple(
|
||||
"TestCase", ["name", "schema_path", "input_file", "expected_output"]
|
||||
"TestCase", ["name", "schema_path", "config", "input_file", "expected_output"]
|
||||
)
|
||||
|
||||
|
||||
@@ -55,6 +56,11 @@ def get_test_cases():
|
||||
|
||||
input_files.append(file_path)
|
||||
|
||||
settings = [
|
||||
ExtractConfig(extraction_mode=ExtractMode.FAST),
|
||||
ExtractConfig(extraction_mode=ExtractMode.ACCURATE),
|
||||
]
|
||||
|
||||
for input_file in sorted(input_files):
|
||||
base_name = os.path.splitext(os.path.basename(input_file))[0]
|
||||
expected_output = os.path.join(data_type_dir, f"{base_name}.test.json")
|
||||
@@ -63,14 +69,16 @@ def get_test_cases():
|
||||
continue
|
||||
|
||||
test_name = f"{data_type}/{os.path.basename(input_file)}"
|
||||
test_cases.append(
|
||||
TestCase(
|
||||
name=test_name,
|
||||
schema_path=schema_path,
|
||||
input_file=input_file,
|
||||
expected_output=expected_output,
|
||||
for setting in settings:
|
||||
test_cases.append(
|
||||
TestCase(
|
||||
name=test_name,
|
||||
schema_path=schema_path,
|
||||
input_file=input_file,
|
||||
config=setting,
|
||||
expected_output=expected_output,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return test_cases
|
||||
|
||||
@@ -109,7 +117,7 @@ def extraction_agent(test_case: TestCase, extractor: LlamaExtract):
|
||||
print(f"Warning: Failed to cleanup existing agent: {str(e)}")
|
||||
|
||||
# Create new agent
|
||||
agent = extractor.create_agent(agent_name, schema)
|
||||
agent = extractor.create_agent(agent_name, schema, config=test_case.config)
|
||||
yield agent
|
||||
|
||||
# Cleanup after test
|
||||
|
||||
Reference in New Issue
Block a user