make lint

This commit is contained in:
Neeraj Pradhan
2025-01-16 18:19:58 -08:00
parent 214e519820
commit c8c38ccd54
5 changed files with 36 additions and 53 deletions
+1
View File
@@ -26,6 +26,7 @@ repos:
rev: v1.0.1
hooks:
- id: mypy
files: ^llama_extract/
additional_dependencies:
[
"types-requests",
+26 -32
View File
@@ -3,7 +3,7 @@ import os
import time
from io import BufferedIOBase, BufferedReader, BytesIO
from pathlib import Path
from typing import List, Optional, Type, Union, Tuple
from typing import List, Optional, Type, Union, Tuple, Coroutine, Any, TypeVar
import warnings
import httpx
from pydantic import BaseModel
@@ -20,13 +20,15 @@ from llama_cloud import (
Project,
)
from llama_cloud.client import AsyncLlamaCloud
from llama_extract.utils import JSONObjectType, run_sync, handle_async_errors
from llama_extract.utils import JSONObjectType, augment_async_errors
from llama_index.core.schema import BaseComponent
from llama_index.core.async_utils import run_jobs
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.constants import DEFAULT_BASE_URL
from concurrent.futures import ThreadPoolExecutor
T = TypeVar("T")
FileInput = Union[str, Path, bytes, BufferedIOBase]
SchemaInput = Union[JSONObjectType, Type[BaseModel]]
@@ -61,17 +63,17 @@ class ExtractionAgent:
self.num_workers = num_workers
self.show_progress = show_progress
self._verbose = verbose
self._data_schema = None
self._config = None
self._data_schema: Union[JSONObjectType, None] = None
self._config: Union[ExtractConfig, None] = None
self._thread_pool = ThreadPoolExecutor(
max_workers=min(10, (os.cpu_count() or 1) + 4)
)
def _run_in_thread(self, coro):
def _run_in_thread(self, coro: Coroutine[Any, Any, T]) -> T:
"""Run coroutine in a separate thread to avoid event loop issues"""
def run_coro():
async def wrapped_coro():
def run_coro() -> T:
async def wrapped_coro() -> T:
async with httpx.AsyncClient(
timeout=self._client._client_wrapper.httpx_client.timeout,
) as client:
@@ -99,30 +101,27 @@ class ExtractionAgent:
return self._agent.data_schema if not self._data_schema else self._data_schema
@data_schema.setter
def data_schema(self, data_schema: SchemaInput):
def data_schema(self, data_schema: SchemaInput) -> None:
processed_schema: JSONObjectType
if isinstance(data_schema, dict):
data_schema = data_schema
elif issubclass(data_schema, BaseModel):
data_schema = data_schema.model_json_schema()
# TODO: if we expose a get_validated JSON schema method, we can use it here
processed_schema = data_schema # type: ignore
elif isinstance(data_schema, type) and issubclass(data_schema, BaseModel):
processed_schema = data_schema.model_json_schema()
else:
raise ValueError(
"data_schema must be either a dictionary or a Pydantic model"
)
self._data_schema = data_schema
self._data_schema = processed_schema
@property
def config(self) -> ExtractConfig:
return self._agent.config if not self._config else self._config
@config.setter
def config(self, config: ExtractConfig):
def config(self, config: ExtractConfig) -> None:
self._config = config
def _run_sync(self, coro):
"""Helper method to run async code in sync context."""
with run_sync() as runner:
return runner(coro)
async def _upload_file(self, file_input: FileInput) -> File:
"""Upload a file for extraction."""
if isinstance(file_input, BufferedIOBase):
@@ -208,7 +207,7 @@ class ExtractionAgent:
single_file = False
upload_tasks = [self._upload_file(file) for file in files]
with handle_async_errors():
with augment_async_errors():
uploaded_files = await run_jobs(
upload_tasks,
workers=self.num_workers,
@@ -227,7 +226,7 @@ class ExtractionAgent:
)
for file in uploaded_files
]
with handle_async_errors():
with augment_async_errors():
results = await run_jobs(
job_tasks,
workers=self.num_workers,
@@ -266,8 +265,8 @@ class ExtractionAgent:
# Queue all files for extraction
jobs = await self.queue_extraction(files)
# Wait for all results concurrently
result_tasks = [self._wait_for_job_result(job.id) for job in jobs]
with handle_async_errors():
result_tasks = [self._wait_for_job_result(job.id) for job, _ in jobs]
with augment_async_errors():
results = await run_jobs(
result_tasks,
workers=self.num_workers,
@@ -310,7 +309,7 @@ class ExtractionAgent:
)
)
def __repr__(self):
def __repr__(self) -> str:
return f"ExtractionAgent(id={self.id}, name={self.name})"
@@ -393,12 +392,12 @@ class LlamaExtract(BaseComponent):
max_workers=min(10, (os.cpu_count() or 1) + 4)
)
def _run_in_thread(self, coro):
def _run_in_thread(self, coro: Coroutine[Any, Any, T]) -> T:
"""Run coroutine in a separate thread to avoid event loop issues"""
def run_coro():
def run_coro() -> T:
# Create a new client for this thread
async def wrapped_coro():
async def wrapped_coro() -> T:
async with httpx.AsyncClient(
timeout=self._async_client._client_wrapper.httpx_client.timeout,
) as client:
@@ -416,11 +415,6 @@ class LlamaExtract(BaseComponent):
return self._thread_pool.submit(run_coro).result()
def _run_sync(self, coro):
"""Helper method to run async code in sync context."""
with run_sync() as runner:
return runner(coro)
def create_agent(
self,
name: str,
@@ -540,7 +534,7 @@ class LlamaExtract(BaseComponent):
for agent in agents
]
def delete_agent(self, agent_id: str):
def delete_agent(self, agent_id: str) -> None:
"""Delete an extraction agent by ID.
Args:
+2 -20
View File
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Union, Generator
import asyncio
from llama_index.core.async_utils import asyncio_run
from contextlib import contextmanager
@@ -22,25 +22,7 @@ def is_jupyter() -> bool:
@contextmanager
def run_sync():
"""Context manager to handle async runtime errors."""
def run_with_error_handling(coro):
# Only apply special handling in Jupyter
if is_jupyter():
import nest_asyncio
nest_asyncio.apply()
return asyncio.get_event_loop().run_until_complete(coro)
with handle_async_errors():
return asyncio_run(coro)
yield run_with_error_handling
@contextmanager
def handle_async_errors():
def augment_async_errors() -> Generator[None, None, None]:
"""Context manager to add helpful information for errors due to nested event loops."""
try:
yield
+4
View File
@@ -2,6 +2,10 @@
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.mypy]
files = ["llama_extract"]
python_version = "3.9"
[tool.poetry]
name = "llama-extract"
version = "0.1.0"
+3 -1
View File
@@ -1,8 +1,10 @@
from typing import Any
from autoevals.string import Levenshtein
from autoevals.number import NumericDiff
def json_subset_match_score(expected, actual):
def json_subset_match_score(expected: Any, actual: Any) -> float:
"""
Adapted from autoevals.JsonDiff to only test on the subset of keys within the expected json.
"""