fix tests

This commit is contained in:
Neeraj Pradhan
2025-01-16 18:43:06 -08:00
parent c8c38ccd54
commit c1271d7a1a
2 changed files with 7 additions and 6 deletions
+5 -5
View File
@@ -189,7 +189,7 @@ class ExtractionAgent:
async def queue_extraction(
self,
files: Union[FileInput, List[FileInput]],
) -> Union[ExtractionResult, List[ExtractionResult]]:
) -> Union[ExtractJob, List[ExtractJob]]:
"""
Queue multiple files for extraction.
@@ -197,7 +197,7 @@ class ExtractionAgent:
files (Union[FileInput, List[FileInput]]): The files to extract
Returns:
Union[ExtractionResult, List[ExtractionResult]]: The queued extraction jobs
Union[ExtractJob, List[ExtractJob]]: The queued extraction jobs
"""
"""Queue one or more files for extraction concurrently."""
if not isinstance(files, list):
@@ -227,7 +227,7 @@ class ExtractionAgent:
for file in uploaded_files
]
with augment_async_errors():
results = await run_jobs(
extract_jobs = await run_jobs(
job_tasks,
workers=self.num_workers,
desc="Creating extraction jobs",
@@ -243,7 +243,7 @@ class ExtractionAgent:
f"Queued file extraction for file {file_repr} under job_id {job.id}"
)
return results[0] if single_file else results
return extract_jobs[0] if single_file else extract_jobs
async def aextract(
self, files: Union[FileInput, List[FileInput]]
@@ -265,7 +265,7 @@ class ExtractionAgent:
# Queue all files for extraction
jobs = await self.queue_extraction(files)
# Wait for all results concurrently
result_tasks = [self._wait_for_job_result(job.id) for job, _ in jobs]
result_tasks = [self._wait_for_job_result(job.id) for job in jobs]
with augment_async_errors():
results = await run_jobs(
result_tasks,
+2 -1
View File
@@ -110,7 +110,8 @@ def extraction_agent(test_case: TestCase):
)
@pytest.mark.parametrize("test_case", get_test_cases(), ids=lambda x: x.name)
def test_extraction(test_case: TestCase, extraction_agent: ExtractionAgent) -> None:
result = extraction_agent.extract(test_case.input_file).data
_, result = extraction_agent.extract(test_case.input_file)
result = result.data
with open(test_case.expected_output, "r") as f:
expected = json.load(f)
assert json_subset_match_score(expected, result) > 0.5, DeepDiff(