Compare commits

...

2 Commits

Author SHA1 Message Date
Adrian Lyjak 47cb887fbc Expose raw api result 2025-09-21 22:34:09 -04:00
Cursor Agent 8cf2930320 feat: Add untyped agent data retrieval and handling
Introduces methods to retrieve agent data as untyped dictionaries,
handling validation errors gracefully. This allows for more flexible
data access when strict typing is not required or when data may be
malformed.

Co-authored-by: adrian <adrian@runllama.ai>
2025-09-21 14:00:43 +00:00
5 changed files with 223 additions and 18 deletions
@@ -1,6 +1,11 @@
import os
from typing import Any, Dict, Generic, List, Optional, Type
from llama_cloud import (
AgentData,
PaginatedResponseAgentData,
PaginatedResponseAggregateGroup,
)
from llama_cloud.client import AsyncLlamaCloud
from tenacity import (
WrappedFn,
@@ -157,10 +162,14 @@ class AsyncAgentDataClient(Generic[AgentDataT]):
@agent_data_retry
async def get_item(self, item_id: str) -> TypedAgentData[AgentDataT]:
raw_data = await self.client.beta.get_agent_data(
raw_data = await self.untyped_get_item(item_id)
return TypedAgentData.from_raw(raw_data, self.type)
@agent_data_retry
async def untyped_get_item(self, item_id: str) -> AgentData:
return await self.client.beta.get_agent_data(
item_id=item_id,
)
return TypedAgentData.from_raw(raw_data, validator=self.type)
@agent_data_retry
async def create_item(self, data: AgentDataT) -> TypedAgentData[AgentDataT]:
@@ -211,9 +220,7 @@ class AsyncAgentDataClient(Generic[AgentDataT]):
offset: Number of items to skip from the beginning. Defaults to 0.
include_total: Whether to include the total count in the response. Defaults to False to improve performance. It's recommended to only request on the first page.
"""
raw = await self.client.beta.search_agent_data_api_v_1_beta_agent_data_search_post(
deployment_name=self.deployment_name,
collection=self.collection,
raw = await self.untyped_search(
filter=filter,
order_by=order_by,
offset=offset,
@@ -228,6 +235,25 @@ class AsyncAgentDataClient(Generic[AgentDataT]):
total=raw.total_size,
)
@agent_data_retry
async def untyped_search(
self,
filter: Optional[Dict[str, Dict[ComparisonOperator, Any]]] = None,
order_by: Optional[str] = None,
offset: Optional[int] = None,
page_size: Optional[int] = None,
include_total: bool = False,
) -> PaginatedResponseAgentData:
return await self.client.beta.search_agent_data_api_v_1_beta_agent_data_search_post(
deployment_name=self.deployment_name,
collection=self.collection,
filter=filter,
order_by=order_by,
offset=offset,
page_size=page_size,
include_total=include_total,
)
@agent_data_retry
async def aggregate(
self,
@@ -254,7 +280,37 @@ class AsyncAgentDataClient(Generic[AgentDataT]):
offset: Number of groups to skip from the beginning. Defaults to 0.
page_size: Maximum number of groups to return per page.
"""
raw = await self.client.beta.aggregate_agent_data_api_v_1_beta_agent_data_aggregate_post(
raw = await self.untyped_aggregate(
filter=filter,
group_by=group_by,
count=count,
first=first,
order_by=order_by,
offset=offset,
page_size=page_size,
)
return TypedAggregateGroupItems(
items=[
TypedAggregateGroup.from_raw(grp, validator=self.type)
for grp in raw.items
],
has_more=raw.next_page_token is not None,
total=raw.total_size,
)
@agent_data_retry
async def untyped_aggregate(
self,
filter: Optional[Dict[str, Dict[ComparisonOperator, Any]]] = None,
group_by: Optional[List[str]] = None,
count: Optional[bool] = None,
first: Optional[bool] = None,
order_by: Optional[str] = None,
offset: Optional[int] = None,
page_size: Optional[int] = None,
) -> PaginatedResponseAggregateGroup:
return await self.client.beta.aggregate_agent_data_api_v_1_beta_agent_data_aggregate_post(
deployment_name=self.deployment_name,
collection=self.collection,
page_size=page_size,
@@ -265,11 +321,3 @@ class AsyncAgentDataClient(Generic[AgentDataT]):
first=first,
offset=offset,
)
return TypedAggregateGroupItems(
items=[
TypedAggregateGroup.from_raw(item, validator=self.type)
for item in raw.items
],
has_more=raw.next_page_token is not None,
total=raw.total_size,
)
@@ -56,7 +56,6 @@ from typing import (
# Type variable for user-defined data models
AgentDataT = TypeVar("AgentDataT", bound=BaseModel)
# Type variable for extracted data (can be dict or Pydantic model)
ExtractedT = TypeVar("ExtractedT", bound=Union[BaseModel, dict])
@@ -116,10 +115,10 @@ class TypedAgentData(BaseModel, Generic[AgentDataT]):
Args:
raw_data: Raw agent data from the API
validator: Pydantic model class to validate the data field
Returns:
TypedAgentData instance with validated data
"""
data: AgentDataT = validator.model_validate(raw_data.data)
return cls(
@@ -0,0 +1,158 @@
import pytest
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from datetime import datetime
from llama_cloud.types.agent_data import AgentData
from llama_cloud.types.aggregate_group import AggregateGroup
from llama_cloud_services.beta.agent_data.client import AsyncAgentDataClient
class Person(BaseModel):
name: str
age: int
class FakeBeta:
def __init__(self) -> None:
self._get_item_response: Optional[AgentData] = None
self._search_items: List[AgentData] = []
self._aggregate_items: List[AggregateGroup] = []
self._total_size: Optional[int] = None
self._next_page_token: Optional[str] = None
# Single get
async def get_agent_data(self, item_id: str) -> AgentData:
assert self._get_item_response is not None, "_get_item_response not set"
return self._get_item_response
# Search
async def search_agent_data_api_v_1_beta_agent_data_search_post(
self,
*,
deployment_name: str,
collection: str,
filter: Optional[Dict[str, Any]] = None,
order_by: Optional[str] = None,
offset: Optional[int] = None,
page_size: Optional[int] = None,
include_total: bool = False,
) -> Any:
class Resp:
def __init__(
self,
items: List[AgentData],
total_size: Optional[int],
next_page_token: Optional[str],
) -> None:
self.items = items
self.total_size = total_size
self.next_page_token = next_page_token
return Resp(self._search_items, self._total_size, self._next_page_token)
# Aggregate
async def aggregate_agent_data_api_v_1_beta_agent_data_aggregate_post(
self,
*,
deployment_name: str,
collection: str,
page_size: Optional[int] = None,
filter: Optional[Dict[str, Any]] = None,
order_by: Optional[str] = None,
group_by: Optional[List[str]] = None,
count: Optional[bool] = None,
first: Optional[bool] = None,
offset: Optional[int] = None,
) -> Any:
class Resp:
def __init__(
self,
items: List[AggregateGroup],
total_size: Optional[int],
next_page_token: Optional[str],
) -> None:
self.items = items
self.total_size = total_size
self.next_page_token = next_page_token
return Resp(self._aggregate_items, self._total_size, self._next_page_token)
class FakeClient:
def __init__(self) -> None:
self.beta = FakeBeta()
def make_agent_data(data: Dict[str, Any]) -> AgentData:
return AgentData(
id="id-1",
deployment_name="dep",
collection="col",
data=data,
created_at=datetime.now(),
updated_at=datetime.now(),
)
def make_group(
group_key: Dict[str, Any],
first_item: Optional[Dict[str, Any]],
count: Optional[int] = None,
) -> AggregateGroup:
return AggregateGroup(group_key=group_key, count=count, first_item=first_item)
@pytest.mark.asyncio
async def test_untyped_get_item_valid_to_dict() -> None:
client = FakeClient()
client.beta._get_item_response = make_agent_data({"name": "Alice", "age": 30})
adc = AsyncAgentDataClient(type=Person, client=client, deployment_name="dep")
item = await adc.untyped_get_item("id-1")
assert item.data == {"name": "Alice", "age": 30}
@pytest.mark.asyncio
async def test_untyped_get_item_invalid_retains_dict() -> None:
client = FakeClient()
# age wrong type; will fail validation and should be returned as dict
client.beta._get_item_response = make_agent_data({"name": "Bob", "age": "x"})
adc = AsyncAgentDataClient(type=Person, client=client, deployment_name="dep")
item = await adc.untyped_get_item("id-1")
assert item.data == {"name": "Bob", "age": "x"}
@pytest.mark.asyncio
async def test_untyped_search_mixed_items() -> None:
client = FakeClient()
client.beta._search_items = [
make_agent_data({"name": "Carol", "age": 22}),
make_agent_data({"name": "Dave", "age": "bad"}),
]
client.beta._total_size = 2
adc = AsyncAgentDataClient(type=Person, client=client, deployment_name="dep")
results = await adc.untyped_search(include_total=True)
assert len(results.items) == 2
assert results.items[0].data == {"name": "Carol", "age": 22}
assert results.items[1].data == {"name": "Dave", "age": "bad"}
assert results.total_size == 2
@pytest.mark.asyncio
async def test_untyped_aggregate_first_item_dict() -> None:
client = FakeClient()
client.beta._aggregate_items = [
make_group({"k": 1}, {"name": "Eve", "age": 40}),
make_group({"k": 2}, {"name": "Frank", "age": "bad"}),
]
client.beta._total_size = 2
adc = AsyncAgentDataClient(type=Person, client=client, deployment_name="dep")
results = await adc.untyped_aggregate(group_by=["k"], first=True)
assert len(results.items) == 2
assert results.items[0].first_item == {"name": "Eve", "age": 40}
assert results.items[1].first_item == {"name": "Frank", "age": "bad"}
@@ -56,7 +56,7 @@ def test_typed_agent_data_from_raw():
def test_typed_agent_data_from_raw_validation_error():
"""Test TypedAgentData.from_raw with invalid data."""
"""Test TypedAgentData.from_raw with invalid data now raises InvalidTypedAgentData."""
raw_data = AgentData(
id="789",
deployment_name="test-agent",
Generated
+1 -1
View File
@@ -1596,7 +1596,7 @@ wheels = [
[[package]]
name = "llama-cloud-services"
version = "0.6.67"
version = "0.6.68"
source = { editable = "." }
dependencies = [
{ name = "click", version = "8.1.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },