mirror of
https://github.com/run-llama/llama_extract.git
synced 2026-07-01 01:37:54 -04:00
make lint
This commit is contained in:
@@ -26,6 +26,7 @@ repos:
|
||||
rev: v1.0.1
|
||||
hooks:
|
||||
- id: mypy
|
||||
files: ^llama_extract/
|
||||
additional_dependencies:
|
||||
[
|
||||
"types-requests",
|
||||
|
||||
+26
-32
@@ -3,7 +3,7 @@ import os
|
||||
import time
|
||||
from io import BufferedIOBase, BufferedReader, BytesIO
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Type, Union, Tuple
|
||||
from typing import List, Optional, Type, Union, Tuple, Coroutine, Any, TypeVar
|
||||
import warnings
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
@@ -20,13 +20,15 @@ from llama_cloud import (
|
||||
Project,
|
||||
)
|
||||
from llama_cloud.client import AsyncLlamaCloud
|
||||
from llama_extract.utils import JSONObjectType, run_sync, handle_async_errors
|
||||
from llama_extract.utils import JSONObjectType, augment_async_errors
|
||||
from llama_index.core.schema import BaseComponent
|
||||
from llama_index.core.async_utils import run_jobs
|
||||
from llama_index.core.bridge.pydantic import Field, PrivateAttr
|
||||
from llama_index.core.constants import DEFAULT_BASE_URL
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
FileInput = Union[str, Path, bytes, BufferedIOBase]
|
||||
SchemaInput = Union[JSONObjectType, Type[BaseModel]]
|
||||
|
||||
@@ -61,17 +63,17 @@ class ExtractionAgent:
|
||||
self.num_workers = num_workers
|
||||
self.show_progress = show_progress
|
||||
self._verbose = verbose
|
||||
self._data_schema = None
|
||||
self._config = None
|
||||
self._data_schema: Union[JSONObjectType, None] = None
|
||||
self._config: Union[ExtractConfig, None] = None
|
||||
self._thread_pool = ThreadPoolExecutor(
|
||||
max_workers=min(10, (os.cpu_count() or 1) + 4)
|
||||
)
|
||||
|
||||
def _run_in_thread(self, coro):
|
||||
def _run_in_thread(self, coro: Coroutine[Any, Any, T]) -> T:
|
||||
"""Run coroutine in a separate thread to avoid event loop issues"""
|
||||
|
||||
def run_coro():
|
||||
async def wrapped_coro():
|
||||
def run_coro() -> T:
|
||||
async def wrapped_coro() -> T:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=self._client._client_wrapper.httpx_client.timeout,
|
||||
) as client:
|
||||
@@ -99,30 +101,27 @@ class ExtractionAgent:
|
||||
return self._agent.data_schema if not self._data_schema else self._data_schema
|
||||
|
||||
@data_schema.setter
|
||||
def data_schema(self, data_schema: SchemaInput):
|
||||
def data_schema(self, data_schema: SchemaInput) -> None:
|
||||
processed_schema: JSONObjectType
|
||||
if isinstance(data_schema, dict):
|
||||
data_schema = data_schema
|
||||
elif issubclass(data_schema, BaseModel):
|
||||
data_schema = data_schema.model_json_schema()
|
||||
# TODO: if we expose a get_validated JSON schema method, we can use it here
|
||||
processed_schema = data_schema # type: ignore
|
||||
elif isinstance(data_schema, type) and issubclass(data_schema, BaseModel):
|
||||
processed_schema = data_schema.model_json_schema()
|
||||
else:
|
||||
raise ValueError(
|
||||
"data_schema must be either a dictionary or a Pydantic model"
|
||||
)
|
||||
self._data_schema = data_schema
|
||||
self._data_schema = processed_schema
|
||||
|
||||
@property
|
||||
def config(self) -> ExtractConfig:
|
||||
return self._agent.config if not self._config else self._config
|
||||
|
||||
@config.setter
|
||||
def config(self, config: ExtractConfig):
|
||||
def config(self, config: ExtractConfig) -> None:
|
||||
self._config = config
|
||||
|
||||
def _run_sync(self, coro):
|
||||
"""Helper method to run async code in sync context."""
|
||||
with run_sync() as runner:
|
||||
return runner(coro)
|
||||
|
||||
async def _upload_file(self, file_input: FileInput) -> File:
|
||||
"""Upload a file for extraction."""
|
||||
if isinstance(file_input, BufferedIOBase):
|
||||
@@ -208,7 +207,7 @@ class ExtractionAgent:
|
||||
single_file = False
|
||||
|
||||
upload_tasks = [self._upload_file(file) for file in files]
|
||||
with handle_async_errors():
|
||||
with augment_async_errors():
|
||||
uploaded_files = await run_jobs(
|
||||
upload_tasks,
|
||||
workers=self.num_workers,
|
||||
@@ -227,7 +226,7 @@ class ExtractionAgent:
|
||||
)
|
||||
for file in uploaded_files
|
||||
]
|
||||
with handle_async_errors():
|
||||
with augment_async_errors():
|
||||
results = await run_jobs(
|
||||
job_tasks,
|
||||
workers=self.num_workers,
|
||||
@@ -266,8 +265,8 @@ class ExtractionAgent:
|
||||
# Queue all files for extraction
|
||||
jobs = await self.queue_extraction(files)
|
||||
# Wait for all results concurrently
|
||||
result_tasks = [self._wait_for_job_result(job.id) for job in jobs]
|
||||
with handle_async_errors():
|
||||
result_tasks = [self._wait_for_job_result(job.id) for job, _ in jobs]
|
||||
with augment_async_errors():
|
||||
results = await run_jobs(
|
||||
result_tasks,
|
||||
workers=self.num_workers,
|
||||
@@ -310,7 +309,7 @@ class ExtractionAgent:
|
||||
)
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"ExtractionAgent(id={self.id}, name={self.name})"
|
||||
|
||||
|
||||
@@ -393,12 +392,12 @@ class LlamaExtract(BaseComponent):
|
||||
max_workers=min(10, (os.cpu_count() or 1) + 4)
|
||||
)
|
||||
|
||||
def _run_in_thread(self, coro):
|
||||
def _run_in_thread(self, coro: Coroutine[Any, Any, T]) -> T:
|
||||
"""Run coroutine in a separate thread to avoid event loop issues"""
|
||||
|
||||
def run_coro():
|
||||
def run_coro() -> T:
|
||||
# Create a new client for this thread
|
||||
async def wrapped_coro():
|
||||
async def wrapped_coro() -> T:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=self._async_client._client_wrapper.httpx_client.timeout,
|
||||
) as client:
|
||||
@@ -416,11 +415,6 @@ class LlamaExtract(BaseComponent):
|
||||
|
||||
return self._thread_pool.submit(run_coro).result()
|
||||
|
||||
def _run_sync(self, coro):
|
||||
"""Helper method to run async code in sync context."""
|
||||
with run_sync() as runner:
|
||||
return runner(coro)
|
||||
|
||||
def create_agent(
|
||||
self,
|
||||
name: str,
|
||||
@@ -540,7 +534,7 @@ class LlamaExtract(BaseComponent):
|
||||
for agent in agents
|
||||
]
|
||||
|
||||
def delete_agent(self, agent_id: str):
|
||||
def delete_agent(self, agent_id: str) -> None:
|
||||
"""Delete an extraction agent by ID.
|
||||
|
||||
Args:
|
||||
|
||||
+2
-20
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict, List, Union, Generator
|
||||
import asyncio
|
||||
from llama_index.core.async_utils import asyncio_run
|
||||
from contextlib import contextmanager
|
||||
@@ -22,25 +22,7 @@ def is_jupyter() -> bool:
|
||||
|
||||
|
||||
@contextmanager
|
||||
def run_sync():
|
||||
"""Context manager to handle async runtime errors."""
|
||||
|
||||
def run_with_error_handling(coro):
|
||||
# Only apply special handling in Jupyter
|
||||
if is_jupyter():
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
return asyncio.get_event_loop().run_until_complete(coro)
|
||||
|
||||
with handle_async_errors():
|
||||
return asyncio_run(coro)
|
||||
|
||||
yield run_with_error_handling
|
||||
|
||||
|
||||
@contextmanager
|
||||
def handle_async_errors():
|
||||
def augment_async_errors() -> Generator[None, None, None]:
|
||||
"""Context manager to add helpful information for errors due to nested event loops."""
|
||||
try:
|
||||
yield
|
||||
|
||||
@@ -2,6 +2,10 @@
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.mypy]
|
||||
files = ["llama_extract"]
|
||||
python_version = "3.9"
|
||||
|
||||
[tool.poetry]
|
||||
name = "llama-extract"
|
||||
version = "0.1.0"
|
||||
|
||||
+3
-1
@@ -1,8 +1,10 @@
|
||||
from typing import Any
|
||||
|
||||
from autoevals.string import Levenshtein
|
||||
from autoevals.number import NumericDiff
|
||||
|
||||
|
||||
def json_subset_match_score(expected, actual):
|
||||
def json_subset_match_score(expected: Any, actual: Any) -> float:
|
||||
"""
|
||||
Adapted from autoevals.JsonDiff to only test on the subset of keys within the expected json.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user