Update to llama cloud v 0.1.11 (#48)

This commit is contained in:
Neeraj Pradhan
2025-01-28 08:57:35 -08:00
committed by GitHub
parent e049c747a8
commit 90eb45ba37
5 changed files with 243 additions and 34 deletions
+56 -4
View File
@@ -12,12 +12,13 @@ from llama_cloud import (
ExtractConfig,
ExtractJob,
ExtractJobCreate,
ExtractResultset,
ExtractRun,
File,
ExtractMode,
StatusEnum,
Project,
ExtractTarget,
LlamaExtractSettings,
)
from llama_cloud.client import AsyncLlamaCloud
from llama_extract.utils import JSONObjectType, augment_async_errors
@@ -33,11 +34,10 @@ FileInput = Union[str, Path, bytes, BufferedIOBase]
SchemaInput = Union[JSONObjectType, Type[BaseModel]]
DEFAULT_EXTRACT_CONFIG = ExtractConfig(
extraction_mode=ExtractMode.PER_DOC,
extraction_target=ExtractTarget.PER_DOC,
extraction_mode=ExtractMode.ACCURATE,
)
ExtractionResult = Tuple[ExtractJob, ExtractResultset]
class ExtractionAgent:
"""Class representing a single extraction agent with methods for extraction operations."""
@@ -192,6 +192,58 @@ class ExtractionAgent:
)
)
async def _queue_extraction_test(
self,
files: Union[FileInput, List[FileInput]],
extract_settings: LlamaExtractSettings,
) -> Union[ExtractJob, List[ExtractJob]]:
if not isinstance(files, list):
files = [files]
single_file = True
else:
single_file = False
upload_tasks = [self._upload_file(file) for file in files]
with augment_async_errors():
uploaded_files = await run_jobs(
upload_tasks,
workers=self.num_workers,
desc="Uploading files",
show_progress=self.show_progress,
)
async def run_job(file: File) -> ExtractRun:
job_queued = await self._client.llama_extract.run_job_test_user(
job_create=ExtractJobCreate(
extraction_agent_id=self.id,
file_id=file.id,
data_schema_override=self.data_schema,
config_override=self.config,
),
extract_settings=extract_settings,
)
return await self._wait_for_job_result(job_queued.id)
job_tasks = [run_job(file) for file in uploaded_files]
with augment_async_errors():
extract_jobs = await run_jobs(
job_tasks,
workers=self.num_workers,
desc="Creating extraction jobs",
show_progress=self.show_progress,
)
if self._verbose:
for file, job in zip(files, extract_jobs):
file_repr = (
str(file) if isinstance(file, (str, Path)) else "<bytes/buffer>"
)
print(
f"Queued file extraction for file {file_repr} under job_id {job.id}"
)
return extract_jobs[0] if single_file else extract_jobs
async def queue_extraction(
self,
files: Union[FileInput, List[FileInput]],
Generated
+20 -20
View File
@@ -285,13 +285,13 @@ files = [
[[package]]
name = "attrs"
version = "24.3.0"
version = "25.1.0"
description = "Classes Without Boilerplate"
optional = false
python-versions = ">=3.8"
files = [
{file = "attrs-24.3.0-py3-none-any.whl", hash = "sha256:ac96cd038792094f438ad1f6ff80837353805ac950cd2aa0e0625ef19850c308"},
{file = "attrs-24.3.0.tar.gz", hash = "sha256:8f5c07333d543103541ba7be0e2ce16eeee8130cb0b3f9238ab904ce1e85baff"},
{file = "attrs-25.1.0-py3-none-any.whl", hash = "sha256:c75a69e28a550a7e93789579c22aa26b0f5b83b75dc4e08fe092980051e1090a"},
{file = "attrs-25.1.0.tar.gz", hash = "sha256:1c97078a80c814273a76b2a298a932eb681c87415c11dee0a6921de7f1b02c3e"},
]
[package.extras]
@@ -737,20 +737,20 @@ files = [
[[package]]
name = "deprecated"
version = "1.2.15"
version = "1.2.18"
description = "Python @deprecated decorator to deprecate old python classes, functions or methods."
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7"
files = [
{file = "Deprecated-1.2.15-py2.py3-none-any.whl", hash = "sha256:353bc4a8ac4bfc96800ddab349d89c25dec1079f65fd53acdcc1e0b975b21320"},
{file = "deprecated-1.2.15.tar.gz", hash = "sha256:683e561a90de76239796e6b6feac66b99030d2dd3fcf61ef996330f14bbb9b0d"},
{file = "Deprecated-1.2.18-py2.py3-none-any.whl", hash = "sha256:bd5011788200372a32418f888e326a09ff80d0214bd961147cfed01b5c018eec"},
{file = "deprecated-1.2.18.tar.gz", hash = "sha256:422b6f6d859da6f2ef57857761bfb392480502a64c3028ca9bbe86085d72115d"},
]
[package.dependencies]
wrapt = ">=1.10,<2"
[package.extras]
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "jinja2 (>=3.0.3,<3.1.0)", "setuptools", "sphinx (<2)", "tox"]
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "setuptools", "tox"]
[[package]]
name = "dirtyjson"
@@ -1784,13 +1784,13 @@ rapidfuzz = ">=3.9.0,<4.0.0"
[[package]]
name = "llama-cloud"
version = "0.1.10"
version = "0.1.11"
description = ""
optional = false
python-versions = "<4,>=3.8"
files = [
{file = "llama_cloud-0.1.10-py3-none-any.whl", hash = "sha256:d91198ad92ea6c3a25757e5d6cb565b4bd6db385dc4fa596a725c0fb81a68f4e"},
{file = "llama_cloud-0.1.10.tar.gz", hash = "sha256:56ffe8f2910c2047dd4eb1b13da31ee5f67321a000794eee559e0b56954d2f76"},
{file = "llama_cloud-0.1.11-py3-none-any.whl", hash = "sha256:b703765d03783a5a0fc57a52adc9892f8b91b0c19bbecb85a54ad4e813342951"},
{file = "llama_cloud-0.1.11.tar.gz", hash = "sha256:d4be5b48659fd9fe1698727be257269a22d7f2733a2ed11bce7065768eb94cbe"},
]
[package.dependencies]
@@ -1905,13 +1905,13 @@ files = [
[[package]]
name = "marshmallow"
version = "3.25.1"
version = "3.26.0"
description = "A lightweight library for converting complex datatypes to and from native Python datatypes."
optional = false
python-versions = ">=3.9"
files = [
{file = "marshmallow-3.25.1-py3-none-any.whl", hash = "sha256:ec5d00d873ce473b7f2ffcb7104286a376c354cab0c2fa12f5573dab03e87210"},
{file = "marshmallow-3.25.1.tar.gz", hash = "sha256:f4debda3bb11153d81ac34b0d582bf23053055ee11e791b54b4b35493468040a"},
{file = "marshmallow-3.26.0-py3-none-any.whl", hash = "sha256:1287bca04e6a5f4094822ac153c03da5e214a0a60bcd557b140f3e66991b8ca1"},
{file = "marshmallow-3.26.0.tar.gz", hash = "sha256:eb36762a1cc76d7abf831e18a3a1b26d3d481bbc74581b8e532a3d3a8115e1cb"},
]
[package.dependencies]
@@ -2751,13 +2751,13 @@ files = [
[[package]]
name = "pydantic"
version = "2.10.5"
version = "2.10.6"
description = "Data validation using Python type hints"
optional = false
python-versions = ">=3.8"
files = [
{file = "pydantic-2.10.5-py3-none-any.whl", hash = "sha256:4dd4e322dbe55472cb7ca7e73f4b63574eecccf2835ffa2af9021ce113c83c53"},
{file = "pydantic-2.10.5.tar.gz", hash = "sha256:278b38dbbaec562011d659ee05f63346951b3a248a6f3642e1bc68894ea2b4ff"},
{file = "pydantic-2.10.6-py3-none-any.whl", hash = "sha256:427d664bf0b8a2b34ff5dd0f5a18df00591adcee7198fbd71981054cef37b584"},
{file = "pydantic-2.10.6.tar.gz", hash = "sha256:ca5daa827cce33de7a42be142548b0096bf05a7e7b365aebfa5f8eeec7128236"},
]
[package.dependencies]
@@ -3307,13 +3307,13 @@ all = ["numpy"]
[[package]]
name = "referencing"
version = "0.36.1"
version = "0.36.2"
description = "JSON Referencing + Python"
optional = false
python-versions = ">=3.9"
files = [
{file = "referencing-0.36.1-py3-none-any.whl", hash = "sha256:363d9c65f080d0d70bc41c721dce3c7f3e77fc09f269cd5c8813da18069a6794"},
{file = "referencing-0.36.1.tar.gz", hash = "sha256:ca2e6492769e3602957e9b831b94211599d2aade9477f5d44110d2530cf9aade"},
{file = "referencing-0.36.2-py3-none-any.whl", hash = "sha256:e8699adbbf8b5c7de96d8ffa0eb5c158b3beafce084968e2ea8bb08c6794dcd0"},
{file = "referencing-0.36.2.tar.gz", hash = "sha256:df2e89862cd09deabbdba16944cc3f10feb6b3e6f18e902f7cc25609a34775aa"},
]
[package.dependencies]
@@ -4317,4 +4317,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<4.0"
content-hash = "90b87c91c45412dae185dac7a8eb80b8e6f37d9677990ff394590acd751c7565"
content-hash = "1ff53e863a18be137ee0eff10a8c4412e2db95ca7cdc7aedf22339325d3fc818"
+1 -1
View File
@@ -18,7 +18,7 @@ packages = [{include = "llama_extract"}]
[tool.poetry.dependencies]
python = ">=3.9,<4.0"
llama-index-core = "^0.11.0"
llama-cloud = "0.1.10"
llama-cloud = "0.1.11"
python-dotenv = "^1.0.1"
[tool.poetry.group.dev.dependencies]
+149
View File
@@ -0,0 +1,149 @@
import os
import pytest
from pathlib import Path
from llama_extract import LlamaExtract, ExtractionAgent
from dotenv import load_dotenv
from time import perf_counter
from collections import namedtuple
import json
import uuid
from llama_cloud.core.api_error import ApiError
from llama_cloud.types import (
ExtractConfig,
ExtractMode,
LlamaParseParameters,
LlamaExtractSettings,
)
load_dotenv(Path(__file__).parent.parent / ".env.dev", override=True)
TEST_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
# Get configuration from environment
LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY")
LLAMA_CLOUD_BASE_URL = os.getenv("LLAMA_CLOUD_BASE_URL")
LLAMA_CLOUD_PROJECT_ID = os.getenv("LLAMA_CLOUD_PROJECT_ID")
TestCase = namedtuple(
"TestCase", ["name", "schema_path", "config", "input_file", "expected_output"]
)
def get_test_cases():
"""Get all test cases from TEST_DIR.
Returns:
List[TestCase]: List of test cases
"""
test_cases = []
for data_type in os.listdir(TEST_DIR):
data_type_dir = os.path.join(TEST_DIR, data_type)
if not os.path.isdir(data_type_dir):
continue
schema_path = os.path.join(data_type_dir, "schema.json")
if not os.path.exists(schema_path):
continue
input_files = []
for file in os.listdir(data_type_dir):
file_path = os.path.join(data_type_dir, file)
if (
not os.path.isfile(file_path)
or file == "schema.json"
or file.endswith(".test.json")
):
continue
input_files.append(file_path)
settings = [
ExtractConfig(extraction_mode=ExtractMode.FAST),
ExtractConfig(extraction_mode=ExtractMode.ACCURATE),
]
for input_file in sorted(input_files):
base_name = os.path.splitext(os.path.basename(input_file))[0]
expected_output = os.path.join(data_type_dir, f"{base_name}.test.json")
if not os.path.exists(expected_output):
continue
test_name = f"{data_type}/{os.path.basename(input_file)}"
for setting in settings:
test_cases.append(
TestCase(
name=test_name,
schema_path=schema_path,
input_file=input_file,
config=setting,
expected_output=expected_output,
)
)
return test_cases
@pytest.fixture(scope="session")
def extractor():
"""Create a single LlamaExtract instance for all tests."""
extract = LlamaExtract(
api_key=LLAMA_CLOUD_API_KEY,
base_url=LLAMA_CLOUD_BASE_URL,
project_id=LLAMA_CLOUD_PROJECT_ID,
verbose=True,
)
yield extract
# Cleanup thread pool at end of session
extract._thread_pool.shutdown()
@pytest.fixture
def extraction_agent(test_case: TestCase, extractor: LlamaExtract):
"""Fixture to create and cleanup extraction agent for each test."""
# Create unique name with random UUID (important for CI to avoid conflicts)
unique_id = uuid.uuid4().hex[:8]
agent_name = f"{test_case.name}_{unique_id}"
with open(test_case.schema_path, "r") as f:
schema = json.load(f)
# Clean up any existing agents with this name
try:
agents = extractor.list_agents()
for agent in agents:
if agent.name == agent_name:
extractor.delete_agent(agent.id)
except Exception as e:
print(f"Warning: Failed to cleanup existing agent: {str(e)}")
# Create new agent
agent = extractor.create_agent(agent_name, schema, config=test_case.config)
yield agent
@pytest.mark.skipif(
"CI" in os.environ,
reason="CI environment is not suitable for benchmarking",
)
@pytest.mark.parametrize("test_case", get_test_cases(), ids=lambda x: x.name)
@pytest.mark.asyncio(loop_scope="session")
async def test_extraction(
test_case: TestCase, extraction_agent: ExtractionAgent
) -> None:
start = perf_counter()
result = await extraction_agent._queue_extraction_test(
test_case.input_file,
extract_settings=LlamaExtractSettings(
llama_parse_params=LlamaParseParameters(
invalidate_cache=True,
do_not_cache=True,
)
),
)
end = perf_counter()
print(f"Time taken: {end - start} seconds")
print(result)
+17 -9
View File
@@ -8,6 +8,7 @@ from collections import namedtuple
import json
import uuid
from llama_cloud.core.api_error import ApiError
from llama_cloud.types import ExtractConfig, ExtractMode, ExtractConfig
from deepdiff import DeepDiff
from tests.util import json_subset_match_score
@@ -21,7 +22,7 @@ LLAMA_CLOUD_BASE_URL = os.getenv("LLAMA_CLOUD_BASE_URL")
LLAMA_CLOUD_PROJECT_ID = os.getenv("LLAMA_CLOUD_PROJECT_ID")
TestCase = namedtuple(
"TestCase", ["name", "schema_path", "input_file", "expected_output"]
"TestCase", ["name", "schema_path", "config", "input_file", "expected_output"]
)
@@ -55,6 +56,11 @@ def get_test_cases():
input_files.append(file_path)
settings = [
ExtractConfig(extraction_mode=ExtractMode.FAST),
ExtractConfig(extraction_mode=ExtractMode.ACCURATE),
]
for input_file in sorted(input_files):
base_name = os.path.splitext(os.path.basename(input_file))[0]
expected_output = os.path.join(data_type_dir, f"{base_name}.test.json")
@@ -63,14 +69,16 @@ def get_test_cases():
continue
test_name = f"{data_type}/{os.path.basename(input_file)}"
test_cases.append(
TestCase(
name=test_name,
schema_path=schema_path,
input_file=input_file,
expected_output=expected_output,
for setting in settings:
test_cases.append(
TestCase(
name=test_name,
schema_path=schema_path,
input_file=input_file,
config=setting,
expected_output=expected_output,
)
)
)
return test_cases
@@ -109,7 +117,7 @@ def extraction_agent(test_case: TestCase, extractor: LlamaExtract):
print(f"Warning: Failed to cleanup existing agent: {str(e)}")
# Create new agent
agent = extractor.create_agent(agent_name, schema)
agent = extractor.create_agent(agent_name, schema, config=test_case.config)
yield agent
# Cleanup after test