mirror of
https://github.com/run-llama/llama_extract.git
synced 2026-07-01 01:37:54 -04:00
fix tests
This commit is contained in:
@@ -189,7 +189,7 @@ class ExtractionAgent:
|
||||
async def queue_extraction(
|
||||
self,
|
||||
files: Union[FileInput, List[FileInput]],
|
||||
) -> Union[ExtractionResult, List[ExtractionResult]]:
|
||||
) -> Union[ExtractJob, List[ExtractJob]]:
|
||||
"""
|
||||
Queue multiple files for extraction.
|
||||
|
||||
@@ -197,7 +197,7 @@ class ExtractionAgent:
|
||||
files (Union[FileInput, List[FileInput]]): The files to extract
|
||||
|
||||
Returns:
|
||||
Union[ExtractionResult, List[ExtractionResult]]: The queued extraction jobs
|
||||
Union[ExtractJob, List[ExtractJob]]: The queued extraction jobs
|
||||
"""
|
||||
"""Queue one or more files for extraction concurrently."""
|
||||
if not isinstance(files, list):
|
||||
@@ -227,7 +227,7 @@ class ExtractionAgent:
|
||||
for file in uploaded_files
|
||||
]
|
||||
with augment_async_errors():
|
||||
results = await run_jobs(
|
||||
extract_jobs = await run_jobs(
|
||||
job_tasks,
|
||||
workers=self.num_workers,
|
||||
desc="Creating extraction jobs",
|
||||
@@ -243,7 +243,7 @@ class ExtractionAgent:
|
||||
f"Queued file extraction for file {file_repr} under job_id {job.id}"
|
||||
)
|
||||
|
||||
return results[0] if single_file else results
|
||||
return extract_jobs[0] if single_file else extract_jobs
|
||||
|
||||
async def aextract(
|
||||
self, files: Union[FileInput, List[FileInput]]
|
||||
@@ -265,7 +265,7 @@ class ExtractionAgent:
|
||||
# Queue all files for extraction
|
||||
jobs = await self.queue_extraction(files)
|
||||
# Wait for all results concurrently
|
||||
result_tasks = [self._wait_for_job_result(job.id) for job, _ in jobs]
|
||||
result_tasks = [self._wait_for_job_result(job.id) for job in jobs]
|
||||
with augment_async_errors():
|
||||
results = await run_jobs(
|
||||
result_tasks,
|
||||
|
||||
@@ -110,7 +110,8 @@ def extraction_agent(test_case: TestCase):
|
||||
)
|
||||
@pytest.mark.parametrize("test_case", get_test_cases(), ids=lambda x: x.name)
|
||||
def test_extraction(test_case: TestCase, extraction_agent: ExtractionAgent) -> None:
|
||||
result = extraction_agent.extract(test_case.input_file).data
|
||||
_, result = extraction_agent.extract(test_case.input_file)
|
||||
result = result.data
|
||||
with open(test_case.expected_output, "r") as f:
|
||||
expected = json.load(f)
|
||||
assert json_subset_match_score(expected, result) > 0.5, DeepDiff(
|
||||
|
||||
Reference in New Issue
Block a user