mirror of
https://github.com/run-llama/llama_cloud_services.git
synced 2026-07-01 21:44:37 -04:00
Fix sheets API client (#1032)
This commit is contained in:
@@ -2,7 +2,7 @@ import asyncio
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Any, Dict, TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
from llama_cloud.client import AsyncLlamaCloud
|
||||
@@ -68,6 +68,8 @@ class LlamaSheets:
|
||||
max_timeout: int = 300,
|
||||
poll_interval: int = 5,
|
||||
max_retries: int = 3,
|
||||
project_id: str | None = None,
|
||||
organization_id: str | None = None,
|
||||
async_httpx_client: httpx.AsyncClient | None = None,
|
||||
) -> None:
|
||||
"""Initialize the LlamaSheets client.
|
||||
@@ -78,6 +80,8 @@ class LlamaSheets:
|
||||
max_timeout: Maximum time to wait for job completion in seconds
|
||||
poll_interval: Interval between status checks in seconds
|
||||
max_retries: Maximum number of retries for failed requests
|
||||
project_id: Project ID for file operations. If not provided, will use LLAMA_CLOUD_PROJECT_ID env var
|
||||
organization_id: Organization ID for file operations. If not provided, will use LLAMA_CLOUD_ORGANIZATION_ID env var
|
||||
async_httpx_client: Optional custom async httpx client
|
||||
"""
|
||||
self.api_key = api_key or os.environ.get("LLAMA_CLOUD_API_KEY")
|
||||
@@ -93,15 +97,32 @@ class LlamaSheets:
|
||||
self.poll_interval = poll_interval
|
||||
self.max_retries = max_retries
|
||||
|
||||
self.project_id = project_id or os.environ.get("LLAMA_CLOUD_PROJECT_ID")
|
||||
self.organization_id = organization_id or os.environ.get(
|
||||
"LLAMA_CLOUD_ORGANIZATION_ID"
|
||||
)
|
||||
|
||||
self._async_client: httpx.AsyncClient | None = async_httpx_client
|
||||
self._files_client = FileClient(
|
||||
AsyncLlamaCloud(
|
||||
token=self.api_key,
|
||||
base_url=self.base_url,
|
||||
httpx_client=async_httpx_client,
|
||||
)
|
||||
),
|
||||
project_id=self.project_id,
|
||||
organization_id=self.organization_id,
|
||||
)
|
||||
|
||||
def _get_default_params(self) -> dict[str, str]:
|
||||
"""Get default query parameters for API requests"""
|
||||
params = {}
|
||||
if self.project_id is not None:
|
||||
params["project_id"] = self.project_id
|
||||
if self.organization_id is not None:
|
||||
params["organization_id"] = self.organization_id
|
||||
|
||||
return params
|
||||
|
||||
def _get_async_client(self) -> httpx.AsyncClient:
|
||||
"""Get or create the async httpx client"""
|
||||
if self._async_client is None:
|
||||
@@ -306,6 +327,8 @@ class LlamaSheets:
|
||||
"config": config.model_dump(mode="json", exclude_none=True),
|
||||
}
|
||||
|
||||
params = self._get_default_params()
|
||||
|
||||
try:
|
||||
async for attempt in AsyncRetrying(
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
@@ -318,6 +341,7 @@ class LlamaSheets:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/api/v1/beta/sheets/jobs",
|
||||
headers=self._get_headers(),
|
||||
params=params,
|
||||
json=payload,
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -347,12 +371,17 @@ class LlamaSheets:
|
||||
):
|
||||
with attempt:
|
||||
client = self._get_async_client()
|
||||
params: Dict[str, Any] = {
|
||||
"include_results": include_results_metadata,
|
||||
**self._get_default_params(),
|
||||
}
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/beta/sheets/jobs/{job_id}",
|
||||
headers=self._get_headers(),
|
||||
params={"include_results": include_results_metadata},
|
||||
params=params,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return SpreadsheetJobResult.model_validate(response.json())
|
||||
except Exception as e:
|
||||
raise SpreadsheetAPIError(f"Failed to get job status: {e}") from e
|
||||
@@ -415,6 +444,8 @@ class LlamaSheets:
|
||||
# Get presigned URL
|
||||
presigned_response = None
|
||||
result_type_str = str(result_type)
|
||||
params = self._get_default_params()
|
||||
|
||||
try:
|
||||
async for attempt in AsyncRetrying(
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
@@ -427,6 +458,7 @@ class LlamaSheets:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/beta/sheets/jobs/{job_id}/regions/{region_id}/result/{result_type_str}",
|
||||
headers=self._get_headers(),
|
||||
params=params,
|
||||
)
|
||||
response.raise_for_status()
|
||||
presigned_response = PresignedUrlResponse.model_validate(
|
||||
|
||||
@@ -654,6 +654,9 @@ class LlamaCloudIndex(BaseManagedIndex):
|
||||
],
|
||||
)
|
||||
|
||||
# Trigger a sync
|
||||
client.pipelines.sync_pipeline(pipeline_id=index.pipeline.id)
|
||||
|
||||
doc_ids = [doc.id for doc in upserted_documents]
|
||||
index.wait_for_completion(
|
||||
doc_ids=doc_ids, verbose=verbose, raise_on_error=raise_on_error
|
||||
@@ -738,6 +741,10 @@ class LlamaCloudIndex(BaseManagedIndex):
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Trigger a sync
|
||||
self._client.pipelines.sync_pipeline(pipeline_id=self.pipeline.id)
|
||||
|
||||
upserted_document = upserted_documents[0]
|
||||
self.wait_for_completion(
|
||||
doc_ids=[upserted_document.id], verbose=verbose, raise_on_error=True
|
||||
@@ -760,6 +767,9 @@ class LlamaCloudIndex(BaseManagedIndex):
|
||||
)
|
||||
],
|
||||
)
|
||||
# Trigger a sync
|
||||
await self._aclient.pipelines.sync_pipeline(pipeline_id=self.pipeline.id)
|
||||
|
||||
upserted_document = upserted_documents[0]
|
||||
await self.await_for_completion(
|
||||
doc_ids=[upserted_document.id], verbose=verbose, raise_on_error=True
|
||||
@@ -782,6 +792,9 @@ class LlamaCloudIndex(BaseManagedIndex):
|
||||
)
|
||||
],
|
||||
)
|
||||
# Trigger a sync
|
||||
self._client.pipelines.sync_pipeline(pipeline_id=self.pipeline.id)
|
||||
|
||||
upserted_document = upserted_documents[0]
|
||||
self.wait_for_completion(
|
||||
doc_ids=[upserted_document.id], verbose=verbose, raise_on_error=True
|
||||
@@ -804,6 +817,9 @@ class LlamaCloudIndex(BaseManagedIndex):
|
||||
)
|
||||
],
|
||||
)
|
||||
# Trigger a sync
|
||||
await self._aclient.pipelines.sync_pipeline(pipeline_id=self.pipeline.id)
|
||||
|
||||
upserted_document = upserted_documents[0]
|
||||
await self.await_for_completion(
|
||||
doc_ids=[upserted_document.id], verbose=verbose, raise_on_error=True
|
||||
@@ -827,6 +843,9 @@ class LlamaCloudIndex(BaseManagedIndex):
|
||||
for doc in documents
|
||||
],
|
||||
)
|
||||
# Trigger a sync
|
||||
self._client.pipelines.sync_pipeline(pipeline_id=self.pipeline.id)
|
||||
|
||||
doc_ids = [doc.id for doc in upserted_documents]
|
||||
self.wait_for_completion(doc_ids=doc_ids, verbose=True, raise_on_error=True)
|
||||
return [True] * len(doc_ids)
|
||||
@@ -849,6 +868,9 @@ class LlamaCloudIndex(BaseManagedIndex):
|
||||
for doc in documents
|
||||
],
|
||||
)
|
||||
# Trigger a sync
|
||||
await self._aclient.pipelines.sync_pipeline(pipeline_id=self.pipeline.id)
|
||||
|
||||
doc_ids = [doc.id for doc in upserted_documents]
|
||||
await self.await_for_completion(
|
||||
doc_ids=doc_ids, verbose=True, raise_on_error=True
|
||||
|
||||
@@ -44,16 +44,16 @@ class TestSpreadsheetParsingConfig:
|
||||
@pytest.fixture
|
||||
def sheets_client():
|
||||
"""Create a LlamaSheets client for testing."""
|
||||
api_key = os.getenv(
|
||||
"LLAMA_CLOUD_API_KEY", "llx-3AEorIw5v0lnJPzEOI9xSl0N8yFx3fguw0Zn8QJHzGWmwg5r"
|
||||
)
|
||||
base_url = os.getenv("LLAMA_CLOUD_BASE_URL", "https://api.staging.llamaindex.ai")
|
||||
api_key = os.getenv("LLAMA_CLOUD_API_KEY")
|
||||
base_url = os.getenv("LLAMA_CLOUD_BASE_URL", "https://api.cloud.llamaindex.ai")
|
||||
project_id = os.getenv("LLAMA_CLOUD_PROJECT_ID")
|
||||
|
||||
client = LlamaSheets(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
max_timeout=300,
|
||||
poll_interval=2,
|
||||
project_id=project_id,
|
||||
)
|
||||
return client
|
||||
|
||||
@@ -85,10 +85,7 @@ def sample_excel_file():
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get(
|
||||
"LLAMA_CLOUD_API_KEY", "llx-3AEorIw5v0lnJPzEOI9xSl0N8yFx3fguw0Zn8QJHzGWmwg5r"
|
||||
)
|
||||
== "",
|
||||
os.environ.get("LLAMA_CLOUD_API_KEY", "") == "",
|
||||
reason="LLAMA_CLOUD_API_KEY not set",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@@ -168,10 +165,7 @@ async def test_spreadsheet_extraction_e2e(
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get(
|
||||
"LLAMA_CLOUD_API_KEY", "llx-3AEorIw5v0lnJPzEOI9xSl0N8yFx3fguw0Zn8QJHzGWmwg5r"
|
||||
)
|
||||
== "",
|
||||
os.environ.get("LLAMA_CLOUD_API_KEY", "") == "",
|
||||
reason="LLAMA_CLOUD_API_KEY not set",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -78,7 +78,7 @@ async def test_upload_bytes(
|
||||
uploaded_file = await file_client.upload_bytes(file_bytes, external_file_id)
|
||||
|
||||
assert isinstance(uploaded_file, File)
|
||||
expected_name = external_file_id if use_presigned_url else "upload"
|
||||
expected_name = external_file_id
|
||||
assert uploaded_file.name == expected_name
|
||||
assert uploaded_file.external_file_id == external_file_id
|
||||
|
||||
@@ -100,7 +100,7 @@ async def test_upload_buffer(
|
||||
uploaded_file = await file_client.upload_buffer(buffer, external_file_id, file_size)
|
||||
|
||||
assert isinstance(uploaded_file, File)
|
||||
expected_name = external_file_id if use_presigned_url else "upload"
|
||||
expected_name = external_file_id
|
||||
assert uploaded_file.name == expected_name
|
||||
assert uploaded_file.external_file_id == external_file_id
|
||||
|
||||
|
||||
Reference in New Issue
Block a user