mirror of
https://github.com/run-llama/llama_cloud_services.git
synced 2026-07-01 21:44:37 -04:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 47cb887fbc | |||
| 8cf2930320 |
@@ -1,6 +1,11 @@
|
||||
import os
|
||||
from typing import Any, Dict, Generic, List, Optional, Type
|
||||
|
||||
from llama_cloud import (
|
||||
AgentData,
|
||||
PaginatedResponseAgentData,
|
||||
PaginatedResponseAggregateGroup,
|
||||
)
|
||||
from llama_cloud.client import AsyncLlamaCloud
|
||||
from tenacity import (
|
||||
WrappedFn,
|
||||
@@ -157,10 +162,14 @@ class AsyncAgentDataClient(Generic[AgentDataT]):
|
||||
|
||||
@agent_data_retry
|
||||
async def get_item(self, item_id: str) -> TypedAgentData[AgentDataT]:
|
||||
raw_data = await self.client.beta.get_agent_data(
|
||||
raw_data = await self.untyped_get_item(item_id)
|
||||
return TypedAgentData.from_raw(raw_data, self.type)
|
||||
|
||||
@agent_data_retry
|
||||
async def untyped_get_item(self, item_id: str) -> AgentData:
|
||||
return await self.client.beta.get_agent_data(
|
||||
item_id=item_id,
|
||||
)
|
||||
return TypedAgentData.from_raw(raw_data, validator=self.type)
|
||||
|
||||
@agent_data_retry
|
||||
async def create_item(self, data: AgentDataT) -> TypedAgentData[AgentDataT]:
|
||||
@@ -211,9 +220,7 @@ class AsyncAgentDataClient(Generic[AgentDataT]):
|
||||
offset: Number of items to skip from the beginning. Defaults to 0.
|
||||
include_total: Whether to include the total count in the response. Defaults to False to improve performance. It's recommended to only request on the first page.
|
||||
"""
|
||||
raw = await self.client.beta.search_agent_data_api_v_1_beta_agent_data_search_post(
|
||||
deployment_name=self.deployment_name,
|
||||
collection=self.collection,
|
||||
raw = await self.untyped_search(
|
||||
filter=filter,
|
||||
order_by=order_by,
|
||||
offset=offset,
|
||||
@@ -228,6 +235,25 @@ class AsyncAgentDataClient(Generic[AgentDataT]):
|
||||
total=raw.total_size,
|
||||
)
|
||||
|
||||
@agent_data_retry
|
||||
async def untyped_search(
|
||||
self,
|
||||
filter: Optional[Dict[str, Dict[ComparisonOperator, Any]]] = None,
|
||||
order_by: Optional[str] = None,
|
||||
offset: Optional[int] = None,
|
||||
page_size: Optional[int] = None,
|
||||
include_total: bool = False,
|
||||
) -> PaginatedResponseAgentData:
|
||||
return await self.client.beta.search_agent_data_api_v_1_beta_agent_data_search_post(
|
||||
deployment_name=self.deployment_name,
|
||||
collection=self.collection,
|
||||
filter=filter,
|
||||
order_by=order_by,
|
||||
offset=offset,
|
||||
page_size=page_size,
|
||||
include_total=include_total,
|
||||
)
|
||||
|
||||
@agent_data_retry
|
||||
async def aggregate(
|
||||
self,
|
||||
@@ -254,7 +280,37 @@ class AsyncAgentDataClient(Generic[AgentDataT]):
|
||||
offset: Number of groups to skip from the beginning. Defaults to 0.
|
||||
page_size: Maximum number of groups to return per page.
|
||||
"""
|
||||
raw = await self.client.beta.aggregate_agent_data_api_v_1_beta_agent_data_aggregate_post(
|
||||
raw = await self.untyped_aggregate(
|
||||
filter=filter,
|
||||
group_by=group_by,
|
||||
count=count,
|
||||
first=first,
|
||||
order_by=order_by,
|
||||
offset=offset,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return TypedAggregateGroupItems(
|
||||
items=[
|
||||
TypedAggregateGroup.from_raw(grp, validator=self.type)
|
||||
for grp in raw.items
|
||||
],
|
||||
has_more=raw.next_page_token is not None,
|
||||
total=raw.total_size,
|
||||
)
|
||||
|
||||
@agent_data_retry
|
||||
async def untyped_aggregate(
|
||||
self,
|
||||
filter: Optional[Dict[str, Dict[ComparisonOperator, Any]]] = None,
|
||||
group_by: Optional[List[str]] = None,
|
||||
count: Optional[bool] = None,
|
||||
first: Optional[bool] = None,
|
||||
order_by: Optional[str] = None,
|
||||
offset: Optional[int] = None,
|
||||
page_size: Optional[int] = None,
|
||||
) -> PaginatedResponseAggregateGroup:
|
||||
return await self.client.beta.aggregate_agent_data_api_v_1_beta_agent_data_aggregate_post(
|
||||
deployment_name=self.deployment_name,
|
||||
collection=self.collection,
|
||||
page_size=page_size,
|
||||
@@ -265,11 +321,3 @@ class AsyncAgentDataClient(Generic[AgentDataT]):
|
||||
first=first,
|
||||
offset=offset,
|
||||
)
|
||||
return TypedAggregateGroupItems(
|
||||
items=[
|
||||
TypedAggregateGroup.from_raw(item, validator=self.type)
|
||||
for item in raw.items
|
||||
],
|
||||
has_more=raw.next_page_token is not None,
|
||||
total=raw.total_size,
|
||||
)
|
||||
|
||||
@@ -56,7 +56,6 @@ from typing import (
|
||||
|
||||
# Type variable for user-defined data models
|
||||
AgentDataT = TypeVar("AgentDataT", bound=BaseModel)
|
||||
|
||||
# Type variable for extracted data (can be dict or Pydantic model)
|
||||
ExtractedT = TypeVar("ExtractedT", bound=Union[BaseModel, dict])
|
||||
|
||||
@@ -116,10 +115,10 @@ class TypedAgentData(BaseModel, Generic[AgentDataT]):
|
||||
Args:
|
||||
raw_data: Raw agent data from the API
|
||||
validator: Pydantic model class to validate the data field
|
||||
|
||||
Returns:
|
||||
TypedAgentData instance with validated data
|
||||
"""
|
||||
|
||||
data: AgentDataT = validator.model_validate(raw_data.data)
|
||||
|
||||
return cls(
|
||||
|
||||
@@ -0,0 +1,158 @@
|
||||
import pytest
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
from llama_cloud.types.agent_data import AgentData
|
||||
from llama_cloud.types.aggregate_group import AggregateGroup
|
||||
|
||||
from llama_cloud_services.beta.agent_data.client import AsyncAgentDataClient
|
||||
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
|
||||
class FakeBeta:
|
||||
def __init__(self) -> None:
|
||||
self._get_item_response: Optional[AgentData] = None
|
||||
self._search_items: List[AgentData] = []
|
||||
self._aggregate_items: List[AggregateGroup] = []
|
||||
self._total_size: Optional[int] = None
|
||||
self._next_page_token: Optional[str] = None
|
||||
|
||||
# Single get
|
||||
async def get_agent_data(self, item_id: str) -> AgentData:
|
||||
assert self._get_item_response is not None, "_get_item_response not set"
|
||||
return self._get_item_response
|
||||
|
||||
# Search
|
||||
async def search_agent_data_api_v_1_beta_agent_data_search_post(
|
||||
self,
|
||||
*,
|
||||
deployment_name: str,
|
||||
collection: str,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
order_by: Optional[str] = None,
|
||||
offset: Optional[int] = None,
|
||||
page_size: Optional[int] = None,
|
||||
include_total: bool = False,
|
||||
) -> Any:
|
||||
class Resp:
|
||||
def __init__(
|
||||
self,
|
||||
items: List[AgentData],
|
||||
total_size: Optional[int],
|
||||
next_page_token: Optional[str],
|
||||
) -> None:
|
||||
self.items = items
|
||||
self.total_size = total_size
|
||||
self.next_page_token = next_page_token
|
||||
|
||||
return Resp(self._search_items, self._total_size, self._next_page_token)
|
||||
|
||||
# Aggregate
|
||||
async def aggregate_agent_data_api_v_1_beta_agent_data_aggregate_post(
|
||||
self,
|
||||
*,
|
||||
deployment_name: str,
|
||||
collection: str,
|
||||
page_size: Optional[int] = None,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
order_by: Optional[str] = None,
|
||||
group_by: Optional[List[str]] = None,
|
||||
count: Optional[bool] = None,
|
||||
first: Optional[bool] = None,
|
||||
offset: Optional[int] = None,
|
||||
) -> Any:
|
||||
class Resp:
|
||||
def __init__(
|
||||
self,
|
||||
items: List[AggregateGroup],
|
||||
total_size: Optional[int],
|
||||
next_page_token: Optional[str],
|
||||
) -> None:
|
||||
self.items = items
|
||||
self.total_size = total_size
|
||||
self.next_page_token = next_page_token
|
||||
|
||||
return Resp(self._aggregate_items, self._total_size, self._next_page_token)
|
||||
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self) -> None:
|
||||
self.beta = FakeBeta()
|
||||
|
||||
|
||||
def make_agent_data(data: Dict[str, Any]) -> AgentData:
|
||||
return AgentData(
|
||||
id="id-1",
|
||||
deployment_name="dep",
|
||||
collection="col",
|
||||
data=data,
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
|
||||
|
||||
def make_group(
|
||||
group_key: Dict[str, Any],
|
||||
first_item: Optional[Dict[str, Any]],
|
||||
count: Optional[int] = None,
|
||||
) -> AggregateGroup:
|
||||
return AggregateGroup(group_key=group_key, count=count, first_item=first_item)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_untyped_get_item_valid_to_dict() -> None:
|
||||
client = FakeClient()
|
||||
client.beta._get_item_response = make_agent_data({"name": "Alice", "age": 30})
|
||||
|
||||
adc = AsyncAgentDataClient(type=Person, client=client, deployment_name="dep")
|
||||
item = await adc.untyped_get_item("id-1")
|
||||
assert item.data == {"name": "Alice", "age": 30}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_untyped_get_item_invalid_retains_dict() -> None:
|
||||
client = FakeClient()
|
||||
# age wrong type; will fail validation and should be returned as dict
|
||||
client.beta._get_item_response = make_agent_data({"name": "Bob", "age": "x"})
|
||||
|
||||
adc = AsyncAgentDataClient(type=Person, client=client, deployment_name="dep")
|
||||
item = await adc.untyped_get_item("id-1")
|
||||
assert item.data == {"name": "Bob", "age": "x"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_untyped_search_mixed_items() -> None:
|
||||
client = FakeClient()
|
||||
client.beta._search_items = [
|
||||
make_agent_data({"name": "Carol", "age": 22}),
|
||||
make_agent_data({"name": "Dave", "age": "bad"}),
|
||||
]
|
||||
client.beta._total_size = 2
|
||||
|
||||
adc = AsyncAgentDataClient(type=Person, client=client, deployment_name="dep")
|
||||
results = await adc.untyped_search(include_total=True)
|
||||
assert len(results.items) == 2
|
||||
assert results.items[0].data == {"name": "Carol", "age": 22}
|
||||
assert results.items[1].data == {"name": "Dave", "age": "bad"}
|
||||
assert results.total_size == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_untyped_aggregate_first_item_dict() -> None:
|
||||
client = FakeClient()
|
||||
client.beta._aggregate_items = [
|
||||
make_group({"k": 1}, {"name": "Eve", "age": 40}),
|
||||
make_group({"k": 2}, {"name": "Frank", "age": "bad"}),
|
||||
]
|
||||
client.beta._total_size = 2
|
||||
|
||||
adc = AsyncAgentDataClient(type=Person, client=client, deployment_name="dep")
|
||||
results = await adc.untyped_aggregate(group_by=["k"], first=True)
|
||||
assert len(results.items) == 2
|
||||
assert results.items[0].first_item == {"name": "Eve", "age": 40}
|
||||
assert results.items[1].first_item == {"name": "Frank", "age": "bad"}
|
||||
@@ -56,7 +56,7 @@ def test_typed_agent_data_from_raw():
|
||||
|
||||
|
||||
def test_typed_agent_data_from_raw_validation_error():
|
||||
"""Test TypedAgentData.from_raw with invalid data."""
|
||||
"""Test TypedAgentData.from_raw with invalid data now raises InvalidTypedAgentData."""
|
||||
raw_data = AgentData(
|
||||
id="789",
|
||||
deployment_name="test-agent",
|
||||
|
||||
Generated
+1
-1
@@ -1596,7 +1596,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "llama-cloud-services"
|
||||
version = "0.6.67"
|
||||
version = "0.6.68"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "click", version = "8.1.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
|
||||
|
||||
Reference in New Issue
Block a user