Fix sheets API client (#1032)

This commit is contained in:
George He
2025-12-03 14:39:47 -08:00
committed by GitHub
parent 32487763d5
commit dac0f79e51
4 changed files with 65 additions and 17 deletions
+35 -3
View File
@@ -2,7 +2,7 @@ import asyncio
import io
import os
import time
from typing import TYPE_CHECKING
from typing import Any, Dict, TYPE_CHECKING
import httpx
from llama_cloud.client import AsyncLlamaCloud
@@ -68,6 +68,8 @@ class LlamaSheets:
max_timeout: int = 300,
poll_interval: int = 5,
max_retries: int = 3,
project_id: str | None = None,
organization_id: str | None = None,
async_httpx_client: httpx.AsyncClient | None = None,
) -> None:
"""Initialize the LlamaSheets client.
@@ -78,6 +80,8 @@ class LlamaSheets:
max_timeout: Maximum time to wait for job completion in seconds
poll_interval: Interval between status checks in seconds
max_retries: Maximum number of retries for failed requests
project_id: Project ID for file operations. If not provided, will use LLAMA_CLOUD_PROJECT_ID env var
organization_id: Organization ID for file operations. If not provided, will use LLAMA_CLOUD_ORGANIZATION_ID env var
async_httpx_client: Optional custom async httpx client
"""
self.api_key = api_key or os.environ.get("LLAMA_CLOUD_API_KEY")
@@ -93,15 +97,32 @@ class LlamaSheets:
self.poll_interval = poll_interval
self.max_retries = max_retries
self.project_id = project_id or os.environ.get("LLAMA_CLOUD_PROJECT_ID")
self.organization_id = organization_id or os.environ.get(
"LLAMA_CLOUD_ORGANIZATION_ID"
)
self._async_client: httpx.AsyncClient | None = async_httpx_client
self._files_client = FileClient(
AsyncLlamaCloud(
token=self.api_key,
base_url=self.base_url,
httpx_client=async_httpx_client,
)
),
project_id=self.project_id,
organization_id=self.organization_id,
)
def _get_default_params(self) -> dict[str, str]:
"""Get default query parameters for API requests"""
params = {}
if self.project_id is not None:
params["project_id"] = self.project_id
if self.organization_id is not None:
params["organization_id"] = self.organization_id
return params
def _get_async_client(self) -> httpx.AsyncClient:
"""Get or create the async httpx client"""
if self._async_client is None:
@@ -306,6 +327,8 @@ class LlamaSheets:
"config": config.model_dump(mode="json", exclude_none=True),
}
params = self._get_default_params()
try:
async for attempt in AsyncRetrying(
stop=stop_after_attempt(self.max_retries),
@@ -318,6 +341,7 @@ class LlamaSheets:
response = await client.post(
f"{self.base_url}/api/v1/beta/sheets/jobs",
headers=self._get_headers(),
params=params,
json=payload,
)
response.raise_for_status()
@@ -347,12 +371,17 @@ class LlamaSheets:
):
with attempt:
client = self._get_async_client()
params: Dict[str, Any] = {
"include_results": include_results_metadata,
**self._get_default_params(),
}
response = await client.get(
f"{self.base_url}/api/v1/beta/sheets/jobs/{job_id}",
headers=self._get_headers(),
params={"include_results": include_results_metadata},
params=params,
)
response.raise_for_status()
return SpreadsheetJobResult.model_validate(response.json())
except Exception as e:
raise SpreadsheetAPIError(f"Failed to get job status: {e}") from e
@@ -415,6 +444,8 @@ class LlamaSheets:
# Get presigned URL
presigned_response = None
result_type_str = str(result_type)
params = self._get_default_params()
try:
async for attempt in AsyncRetrying(
stop=stop_after_attempt(self.max_retries),
@@ -427,6 +458,7 @@ class LlamaSheets:
response = await client.get(
f"{self.base_url}/api/v1/beta/sheets/jobs/{job_id}/regions/{region_id}/result/{result_type_str}",
headers=self._get_headers(),
params=params,
)
response.raise_for_status()
presigned_response = PresignedUrlResponse.model_validate(
+22
View File
@@ -654,6 +654,9 @@ class LlamaCloudIndex(BaseManagedIndex):
],
)
# Trigger a sync
client.pipelines.sync_pipeline(pipeline_id=index.pipeline.id)
doc_ids = [doc.id for doc in upserted_documents]
index.wait_for_completion(
doc_ids=doc_ids, verbose=verbose, raise_on_error=raise_on_error
@@ -738,6 +741,10 @@ class LlamaCloudIndex(BaseManagedIndex):
)
],
)
# Trigger a sync
self._client.pipelines.sync_pipeline(pipeline_id=self.pipeline.id)
upserted_document = upserted_documents[0]
self.wait_for_completion(
doc_ids=[upserted_document.id], verbose=verbose, raise_on_error=True
@@ -760,6 +767,9 @@ class LlamaCloudIndex(BaseManagedIndex):
)
],
)
# Trigger a sync
await self._aclient.pipelines.sync_pipeline(pipeline_id=self.pipeline.id)
upserted_document = upserted_documents[0]
await self.await_for_completion(
doc_ids=[upserted_document.id], verbose=verbose, raise_on_error=True
@@ -782,6 +792,9 @@ class LlamaCloudIndex(BaseManagedIndex):
)
],
)
# Trigger a sync
self._client.pipelines.sync_pipeline(pipeline_id=self.pipeline.id)
upserted_document = upserted_documents[0]
self.wait_for_completion(
doc_ids=[upserted_document.id], verbose=verbose, raise_on_error=True
@@ -804,6 +817,9 @@ class LlamaCloudIndex(BaseManagedIndex):
)
],
)
# Trigger a sync
await self._aclient.pipelines.sync_pipeline(pipeline_id=self.pipeline.id)
upserted_document = upserted_documents[0]
await self.await_for_completion(
doc_ids=[upserted_document.id], verbose=verbose, raise_on_error=True
@@ -827,6 +843,9 @@ class LlamaCloudIndex(BaseManagedIndex):
for doc in documents
],
)
# Trigger a sync
self._client.pipelines.sync_pipeline(pipeline_id=self.pipeline.id)
doc_ids = [doc.id for doc in upserted_documents]
self.wait_for_completion(doc_ids=doc_ids, verbose=True, raise_on_error=True)
return [True] * len(doc_ids)
@@ -849,6 +868,9 @@ class LlamaCloudIndex(BaseManagedIndex):
for doc in documents
],
)
# Trigger a sync
await self._aclient.pipelines.sync_pipeline(pipeline_id=self.pipeline.id)
doc_ids = [doc.id for doc in upserted_documents]
await self.await_for_completion(
doc_ids=doc_ids, verbose=True, raise_on_error=True
+6 -12
View File
@@ -44,16 +44,16 @@ class TestSpreadsheetParsingConfig:
@pytest.fixture
def sheets_client():
"""Create a LlamaSheets client for testing."""
api_key = os.getenv(
"LLAMA_CLOUD_API_KEY", "llx-3AEorIw5v0lnJPzEOI9xSl0N8yFx3fguw0Zn8QJHzGWmwg5r"
)
base_url = os.getenv("LLAMA_CLOUD_BASE_URL", "https://api.staging.llamaindex.ai")
api_key = os.getenv("LLAMA_CLOUD_API_KEY")
base_url = os.getenv("LLAMA_CLOUD_BASE_URL", "https://api.cloud.llamaindex.ai")
project_id = os.getenv("LLAMA_CLOUD_PROJECT_ID")
client = LlamaSheets(
api_key=api_key,
base_url=base_url,
max_timeout=300,
poll_interval=2,
project_id=project_id,
)
return client
@@ -85,10 +85,7 @@ def sample_excel_file():
@pytest.mark.skipif(
os.environ.get(
"LLAMA_CLOUD_API_KEY", "llx-3AEorIw5v0lnJPzEOI9xSl0N8yFx3fguw0Zn8QJHzGWmwg5r"
)
== "",
os.environ.get("LLAMA_CLOUD_API_KEY", "") == "",
reason="LLAMA_CLOUD_API_KEY not set",
)
@pytest.mark.asyncio
@@ -168,10 +165,7 @@ async def test_spreadsheet_extraction_e2e(
@pytest.mark.skipif(
os.environ.get(
"LLAMA_CLOUD_API_KEY", "llx-3AEorIw5v0lnJPzEOI9xSl0N8yFx3fguw0Zn8QJHzGWmwg5r"
)
== "",
os.environ.get("LLAMA_CLOUD_API_KEY", "") == "",
reason="LLAMA_CLOUD_API_KEY not set",
)
@pytest.mark.asyncio
+2 -2
View File
@@ -78,7 +78,7 @@ async def test_upload_bytes(
uploaded_file = await file_client.upload_bytes(file_bytes, external_file_id)
assert isinstance(uploaded_file, File)
expected_name = external_file_id if use_presigned_url else "upload"
expected_name = external_file_id
assert uploaded_file.name == expected_name
assert uploaded_file.external_file_id == external_file_id
@@ -100,7 +100,7 @@ async def test_upload_buffer(
uploaded_file = await file_client.upload_buffer(buffer, external_file_id, file_size)
assert isinstance(uploaded_file, File)
expected_name = external_file_id if use_presigned_url else "upload"
expected_name = external_file_id
assert uploaded_file.name == expected_name
assert uploaded_file.external_file_id == external_file_id