mirror of
https://github.com/run-llama/template-workflow-extract-basic.git
synced 2026-06-30 22:37:54 -04:00
Extract llama-cloud-fake package from extract-basic (#255)
This commit is contained in:
@@ -3,4 +3,3 @@ _exclude:
|
||||
- ".github"
|
||||
- "copier.yaml"
|
||||
- ".venv"
|
||||
- "tests/testing_utils"
|
||||
|
||||
@@ -28,8 +28,15 @@ dev = [
|
||||
"psutil>=7.1.3",
|
||||
"pytest-playwright>=0.7.2",
|
||||
"pytest-timeout>=2.3.1",
|
||||
"llama-cloud-fake>=0.1,<0.2",
|
||||
]
|
||||
|
||||
# Until llama-cloud-fake is published to PyPI, resolve from the monorepo
|
||||
# workspace. Remove this block once the package is on PyPI and the above
|
||||
# version pin can resolve normally.
|
||||
[tool.uv.sources]
|
||||
llama-cloud-fake = { path = "../../packages/llama-cloud-fake", editable = true }
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
@@ -3,8 +3,6 @@ import os
|
||||
|
||||
from llama_cloud import AsyncLlamaCloud
|
||||
|
||||
from .testing_utils import FakeLlamaCloudServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# deployed agents may infer their name from the deployment name
|
||||
@@ -16,7 +14,9 @@ base_url = os.getenv("LLAMA_CLOUD_BASE_URL")
|
||||
project_id = os.getenv("LLAMA_DEPLOY_PROJECT_ID")
|
||||
|
||||
if os.getenv("FAKE_LLAMA_CLOUD"):
|
||||
fake: FakeLlamaCloudServer | None = FakeLlamaCloudServer().install()
|
||||
from llama_cloud_fake import FakeLlamaCloudServer
|
||||
|
||||
fake = FakeLlamaCloudServer().install()
|
||||
else:
|
||||
fake = None
|
||||
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
from .matchers import FileMatcher, RequestMatcher, SchemaMatcher
|
||||
from .server import FakeLlamaCloudServer
|
||||
|
||||
__all__ = [
|
||||
"FakeLlamaCloudServer",
|
||||
"FileMatcher",
|
||||
"SchemaMatcher",
|
||||
"RequestMatcher",
|
||||
]
|
||||
@@ -1,223 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import random
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Iterable, Mapping, MutableMapping
|
||||
|
||||
from jsonref import JsonRef, replace_refs
|
||||
|
||||
|
||||
def hash_chunks(chunks: Iterable[bytes]) -> str:
|
||||
digest = hashlib.sha256()
|
||||
for chunk in chunks:
|
||||
digest.update(chunk)
|
||||
return digest.hexdigest()
|
||||
|
||||
|
||||
def fingerprint_file(content: bytes, filename: str | None = None) -> str:
|
||||
name_bytes = filename.encode("utf-8") if filename else b""
|
||||
return hash_chunks((content, name_bytes))
|
||||
|
||||
|
||||
def hash_schema(schema: Any) -> str:
|
||||
json_string = json.dumps(
|
||||
_to_serializable(schema),
|
||||
sort_keys=True,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
return hashlib.sha256(json_string.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def combined_seed(*parts: str) -> int:
|
||||
digest = hash_chunks(tuple(part.encode("utf-8") for part in parts))
|
||||
return int(digest[:16], 16)
|
||||
|
||||
|
||||
def generate_data_from_schema(schema: Any, seed: int) -> Any:
|
||||
rng = random.Random(seed)
|
||||
schema = replace_refs(schema)
|
||||
schema = {k: v for k, v in schema.items() if k != "$defs"}
|
||||
return _generate_value(schema, rng, depth=0)
|
||||
|
||||
|
||||
def generate_text_blob(seed: int, sentences: int = 3) -> str:
|
||||
rng = random.Random(seed)
|
||||
words = [
|
||||
"aurora",
|
||||
"copper",
|
||||
"delta",
|
||||
"ember",
|
||||
"fable",
|
||||
"glyph",
|
||||
"harbor",
|
||||
"iris",
|
||||
"juniper",
|
||||
"kepler",
|
||||
"lumen",
|
||||
"monarch",
|
||||
"nylon",
|
||||
"onyx",
|
||||
"paragon",
|
||||
"quartz",
|
||||
"raptor",
|
||||
"solstice",
|
||||
"topaz",
|
||||
"umbra",
|
||||
"verdant",
|
||||
"willow",
|
||||
"xenon",
|
||||
"yonder",
|
||||
"zephyr",
|
||||
]
|
||||
sentence_pieces = []
|
||||
for _ in range(sentences):
|
||||
length = rng.randint(6, 12)
|
||||
chosen = rng.sample(words, k=length)
|
||||
sentence = " ".join(chosen).capitalize() + "."
|
||||
sentence_pieces.append(sentence)
|
||||
return " ".join(sentence_pieces)
|
||||
|
||||
|
||||
def utcnow() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _to_serializable(value: Any) -> Any:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, (str, int, float, bool)):
|
||||
return value
|
||||
if isinstance(value, bytes):
|
||||
return value.decode("utf-8", errors="ignore")
|
||||
if isinstance(value, Mapping):
|
||||
return {key: _to_serializable(val) for key, val in value.items()}
|
||||
if isinstance(value, MutableMapping):
|
||||
return {key: _to_serializable(val) for key, val in value.items()}
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
return [_to_serializable(item) for item in value]
|
||||
if hasattr(value, "model_dump_json"):
|
||||
return json.loads(value.model_dump_json())
|
||||
if hasattr(value, "model_dump"):
|
||||
return value.model_dump()
|
||||
if hasattr(value, "dict"):
|
||||
return value.dict() # type: ignore[call-arg]
|
||||
if hasattr(value, "model_json_schema"):
|
||||
return value.model_json_schema()
|
||||
return str(value)
|
||||
|
||||
|
||||
def _generate_value(schema: Any, rng: random.Random, depth: int) -> Any:
|
||||
if depth > 8:
|
||||
return rng.choice(
|
||||
(
|
||||
rng.randint(1, 999),
|
||||
rng.random(),
|
||||
generate_text_blob(rng.randint(0, 1_000_000), sentences=1),
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(schema, JsonRef):
|
||||
schema = dict(schema) # type: ignore
|
||||
|
||||
if schema is None:
|
||||
return generate_text_blob(rng.randint(0, 1_000_000), sentences=1)
|
||||
|
||||
if isinstance(schema, list):
|
||||
return [_generate_value(item, rng, depth + 1) for item in schema]
|
||||
|
||||
if isinstance(schema, str):
|
||||
return f"{schema}-{rng.randint(100, 999)}"
|
||||
|
||||
if isinstance(schema, Mapping):
|
||||
if "enum" in schema:
|
||||
options = schema["enum"]
|
||||
if options:
|
||||
index = rng.randint(0, len(options) - 1)
|
||||
return options[index]
|
||||
|
||||
schema_type = schema.get("type")
|
||||
|
||||
# Handle union types like ["number", "null"]
|
||||
if isinstance(schema_type, list):
|
||||
concrete_types = [t for t in schema_type if t != "null"]
|
||||
if concrete_types:
|
||||
schema_type = concrete_types[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
if schema_type == "object":
|
||||
properties = schema.get("properties", {})
|
||||
result = {}
|
||||
for key, subschema in properties.items():
|
||||
result[key] = _generate_value(subschema, rng, depth + 1)
|
||||
return result
|
||||
|
||||
if schema_type == "array":
|
||||
items_schema = schema.get("items", {})
|
||||
min_items = schema.get("minItems", 1)
|
||||
max_items = schema.get("maxItems", max(3, min_items))
|
||||
length = rng.randint(min_items, min(min_items + 2, max_items))
|
||||
return [
|
||||
_generate_value(items_schema, rng, depth + 1) for _ in range(length)
|
||||
]
|
||||
|
||||
if schema_type == "integer":
|
||||
minimum = schema.get("minimum", 0)
|
||||
maximum = schema.get("maximum", minimum + 500)
|
||||
return rng.randint(int(minimum), int(maximum))
|
||||
|
||||
if schema_type == "number":
|
||||
minimum = schema.get("minimum", 0.0)
|
||||
maximum = schema.get("maximum", minimum + 500.0)
|
||||
value = rng.uniform(float(minimum), float(maximum))
|
||||
return round(value, 2)
|
||||
|
||||
if schema_type == "boolean":
|
||||
return rng.choice((True, False))
|
||||
|
||||
if schema_type == "string":
|
||||
fmt = schema.get("format")
|
||||
if fmt == "date-time":
|
||||
timestamp = utcnow().isoformat()
|
||||
return timestamp
|
||||
if fmt == "email":
|
||||
return f"user{rng.randint(1000, 9999)}@example.com"
|
||||
if fmt == "uri":
|
||||
return f"https://example.com/{rng.randint(1000, 9999)}"
|
||||
min_length = schema.get("minLength", 5)
|
||||
max_length = schema.get("maxLength", max(10, min_length))
|
||||
length = rng.randint(min_length, min(min_length + 5, max_length))
|
||||
return generate_text_blob(
|
||||
rng.randint(0, 1_000_000), sentences=max(1, length // 5)
|
||||
)
|
||||
|
||||
if schema_type == "null":
|
||||
return None
|
||||
|
||||
if "oneOf" in schema:
|
||||
option = rng.choice(schema["oneOf"])
|
||||
return _generate_value(option, rng, depth + 1)
|
||||
|
||||
if "anyOf" in schema:
|
||||
option = rng.choice(schema["anyOf"])
|
||||
return _generate_value(option, rng, depth + 1)
|
||||
|
||||
return generate_text_blob(rng.randint(0, 1_000_000), sentences=1)
|
||||
|
||||
|
||||
def categorize_pages(
|
||||
content: bytes, categories: list[str], seed: int
|
||||
) -> dict[str, list[int]]:
|
||||
rng = random.Random(seed)
|
||||
page_size = rng.randint(1, 50)
|
||||
categorized_pages: dict[str, list[int]] = {c: [] for c in categories}
|
||||
i = 0
|
||||
j = 0
|
||||
while j + page_size < len(content):
|
||||
i += 1
|
||||
category = rng.choice(categories)
|
||||
categorized_pages[category].append(i)
|
||||
j += page_size
|
||||
return categorized_pages
|
||||
@@ -1,335 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
from ._deterministic import utcnow, hash_schema
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .server import FakeLlamaCloudServer
|
||||
|
||||
|
||||
@dataclass
|
||||
class StoredAgentData:
|
||||
data: dict[str, Any]
|
||||
id: str
|
||||
collection: str
|
||||
deployment_name: str
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return self.data.get(name)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if name in ("data", "id", "collection", "deployment_name"):
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
self.data[name] = value
|
||||
|
||||
@classmethod
|
||||
def from_request_data(cls, data: dict[str, Any]) -> "StoredAgentData":
|
||||
return cls(
|
||||
data=data.get("data", {}),
|
||||
collection=data.get("collection", "default"),
|
||||
deployment_name=data.get("deployment_name", ""),
|
||||
id=hash_schema(data.get("data", {}))[:7],
|
||||
)
|
||||
|
||||
|
||||
def apply_filter(data: dict, filters: dict) -> bool:
|
||||
"""Check if data matches all filters"""
|
||||
ops = {
|
||||
"gt": lambda a, b: a > b,
|
||||
"gte": lambda a, b: a >= b,
|
||||
"lt": lambda a, b: a < b,
|
||||
"lte": lambda a, b: a <= b,
|
||||
"eq": lambda a, b: a == b,
|
||||
"ne": lambda a, b: a != b,
|
||||
"in": lambda a, b: a in b,
|
||||
"nin": lambda a, b: a not in b,
|
||||
}
|
||||
|
||||
for key, condition in filters.items():
|
||||
if key not in data:
|
||||
return False
|
||||
|
||||
if isinstance(condition, dict):
|
||||
for op, value in condition.items():
|
||||
if op in ops:
|
||||
if not ops[op](data[key], value):
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
if data[key] != condition:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class FakeAgentDataNamespace:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
server: "FakeLlamaCloudServer",
|
||||
) -> None:
|
||||
self._server = server
|
||||
self.stored: list[StoredAgentData] = []
|
||||
self.routes: Dict[str, Any] = {}
|
||||
|
||||
def _create_data(self, request: httpx.Request) -> httpx.Response:
|
||||
payload = self._server.json(request=request)
|
||||
data = StoredAgentData.from_request_data(payload)
|
||||
self.stored.append(data)
|
||||
response = {
|
||||
"data": data.data,
|
||||
"collection": data.collection,
|
||||
"deployment_name": data.deployment_name,
|
||||
"created_at": utcnow().isoformat(),
|
||||
"updated_at": None,
|
||||
"id": data.id,
|
||||
"project_id": None,
|
||||
"organization_id": None,
|
||||
}
|
||||
return self._server.json_response(response, status_code=200)
|
||||
|
||||
def _delete_data_by_query(self, request: httpx.Request) -> httpx.Response:
|
||||
payload = self._server.json(request=request)
|
||||
delete_count = 0
|
||||
if (filters := payload.get("filter")) is not None:
|
||||
to_keep = []
|
||||
for data in self.stored:
|
||||
if data.collection == payload.get(
|
||||
"collection", "default"
|
||||
) and data.deployment_name == payload.get("deployment_name"):
|
||||
if not apply_filter(data.data, filters):
|
||||
to_keep.append(data)
|
||||
else:
|
||||
delete_count += 1
|
||||
self.stored = to_keep
|
||||
return self._server.json_response(
|
||||
{"deleted_count": delete_count}, status_code=200
|
||||
)
|
||||
|
||||
def _delete_data_by_id(self, request: httpx.Request) -> httpx.Response:
|
||||
item_id = self._find_item_id(request=request)
|
||||
if not item_id:
|
||||
return self._server.json_response(
|
||||
{
|
||||
"detail": "An item_id path parameter is required to perform this operation"
|
||||
},
|
||||
status_code=400,
|
||||
)
|
||||
if not any(data.id == item_id for data in self.stored):
|
||||
return self._server.json_response(
|
||||
{"detail": f"No data with ID: {item_id}"}, status_code=404
|
||||
)
|
||||
self.stored = [data for data in self.stored if data.id != item_id]
|
||||
return self._server.json_response({}, status_code=200)
|
||||
|
||||
def _get_data_by_id(self, request: httpx.Request) -> httpx.Response:
|
||||
item_id = self._find_item_id(request=request)
|
||||
if not item_id:
|
||||
return self._server.json_response(
|
||||
{
|
||||
"detail": "An item_id path parameter is required to perform this operation"
|
||||
},
|
||||
status_code=400,
|
||||
)
|
||||
data = [data for data in self.stored if data.id == item_id]
|
||||
if data:
|
||||
response = {
|
||||
"data": data[0].data,
|
||||
"collection": data[0].collection,
|
||||
"deployment_name": data[0].deployment_name,
|
||||
"created_at": utcnow().isoformat(),
|
||||
"updated_at": None,
|
||||
"id": data[0].id,
|
||||
"project_id": None,
|
||||
"organization_id": None,
|
||||
}
|
||||
return self._server.json_response(response, status_code=200)
|
||||
else:
|
||||
return self._server.json_response(
|
||||
{"detail": f"No data with ID: {item_id}"}, status_code=404
|
||||
)
|
||||
|
||||
def _search_data(self, request: httpx.Request) -> httpx.Response:
|
||||
payload = self._server.json(request=request)
|
||||
found = []
|
||||
if (filters := payload.get("filter")) is not None:
|
||||
for data in self.stored:
|
||||
if data.collection == payload.get(
|
||||
"collection", "default"
|
||||
) and data.deployment_name == payload.get("deployment_name"):
|
||||
if apply_filter(data.data, filters):
|
||||
found.append(
|
||||
{
|
||||
"data": data.data,
|
||||
"collection": data.collection,
|
||||
"deployment_name": data.deployment_name,
|
||||
"created_at": utcnow().isoformat(),
|
||||
"updated_at": None,
|
||||
"id": data.id,
|
||||
"project_id": None,
|
||||
"organization_id": None,
|
||||
}
|
||||
)
|
||||
else:
|
||||
for data in self.stored:
|
||||
if data.collection == payload.get(
|
||||
"collection", "default"
|
||||
) and data.deployment_name == payload.get("deployment_name"):
|
||||
found.append(
|
||||
{
|
||||
"data": data.data,
|
||||
"collection": data.collection,
|
||||
"deployment_name": data.deployment_name,
|
||||
"created_at": utcnow().isoformat(),
|
||||
"updated_at": None,
|
||||
"id": data.id,
|
||||
"project_id": None,
|
||||
"organization_id": None,
|
||||
}
|
||||
)
|
||||
return self._server.json_response(
|
||||
{"items": found, "next_page_token": None, "total_size": len(found)},
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
def _update_data(self, request: httpx.Request) -> httpx.Response:
|
||||
item_id = self._find_item_id(request=request)
|
||||
payload = self._server.json(request=request)
|
||||
if not item_id:
|
||||
return self._server.json_response(
|
||||
{
|
||||
"detail": "An item_id path parameter is required to perform this operation"
|
||||
},
|
||||
status_code=400,
|
||||
)
|
||||
updated = None
|
||||
for i, data in enumerate(self.stored):
|
||||
if data.id == item_id:
|
||||
updated = data
|
||||
updated.data = payload.get("data", data.data)
|
||||
self.stored[i] = updated
|
||||
print(updated)
|
||||
if updated is not None:
|
||||
response = {
|
||||
"data": updated.data,
|
||||
"collection": updated.collection,
|
||||
"deployment_name": updated.deployment_name,
|
||||
"created_at": None,
|
||||
"updated_at": utcnow().isoformat(),
|
||||
"id": updated.id,
|
||||
"project_id": None,
|
||||
"organization_id": None,
|
||||
}
|
||||
status_code = 200
|
||||
else:
|
||||
response = {"detail": f"Record with id {item_id} not found"}
|
||||
status_code = 404
|
||||
return self._server.json_response(response, status_code=status_code)
|
||||
|
||||
def _aggregate_data(self, request: httpx.Request) -> httpx.Response:
|
||||
payload = self._server.json(request=request)
|
||||
add_count = payload.get("count", False)
|
||||
group_bys: list[str] = payload.get("group_by", [])
|
||||
groups: dict[str, dict[str, list[dict]]] = {key: {} for key in group_bys}
|
||||
if (filters := payload.get("filter")) is not None:
|
||||
for data in self.stored:
|
||||
if data.collection == payload.get(
|
||||
"collection", "default"
|
||||
) and data.deployment_name == payload.get("deployment_name"):
|
||||
if apply_filter(data.data, filters):
|
||||
for key in group_bys:
|
||||
if key in data.data and data.data[key] in groups[key]:
|
||||
groups[key][data.data[key]].append(data.data)
|
||||
elif key in data.data and data.data[key] not in groups[key]:
|
||||
groups[key][data.data[key]] = [data.data]
|
||||
else:
|
||||
for data in self.stored:
|
||||
if data.collection == payload.get(
|
||||
"collection", "default"
|
||||
) and data.deployment_name == payload.get("deployment_name"):
|
||||
for key in group_bys:
|
||||
if key in data.data and data.data[key] in groups[key]:
|
||||
groups[key][data.data[key]].append(data.data)
|
||||
elif key in data.data and data.data[key] not in groups[key]:
|
||||
groups[key][data.data[key]] = [data.data]
|
||||
|
||||
response: dict[str, Any] = {
|
||||
"items": [],
|
||||
"next_page_token": None,
|
||||
"total_size": 0,
|
||||
}
|
||||
for k in groups:
|
||||
if len(groups[k]) > 0:
|
||||
for v in groups[k]:
|
||||
if groups[k][v]:
|
||||
first_element = groups[k][v][0]
|
||||
else:
|
||||
first_element = None
|
||||
response["items"].append(
|
||||
{
|
||||
"first_item": first_element,
|
||||
"count": len(groups[k][v]) if add_count else None,
|
||||
"group_key": {k: v},
|
||||
}
|
||||
)
|
||||
response["total_size"] = len(response["items"])
|
||||
return self._server.json_response(response, status_code=200)
|
||||
|
||||
def _find_item_id(self, request: httpx.Request) -> str | None:
|
||||
matchgroups = re.search(r"/agent-data\/([^\/]+)$", request.url.path)
|
||||
return matchgroups.group(1) if matchgroups is not None else None
|
||||
|
||||
def register(self) -> None:
|
||||
server = self._server
|
||||
route = server.add_route(
|
||||
"POST",
|
||||
"/api/v1/beta/agent-data",
|
||||
self._create_data,
|
||||
namespace="create_item",
|
||||
)
|
||||
self.routes["stateless_run"] = route
|
||||
self.stateless_run = route
|
||||
server.add_route(
|
||||
"POST",
|
||||
"/api/v1/beta/agent-data/:aggregate",
|
||||
self._aggregate_data,
|
||||
namespace="untyped_aggregate",
|
||||
alias="aggregate",
|
||||
)
|
||||
server.add_route(
|
||||
"POST",
|
||||
"/api/v1/beta/agent-data/:delete",
|
||||
self._delete_data_by_query,
|
||||
namespace="delete",
|
||||
)
|
||||
server.add_route(
|
||||
"POST",
|
||||
"/api/v1/beta/agent-data/:search",
|
||||
self._search_data,
|
||||
namespace="untyped_search",
|
||||
alias="search",
|
||||
)
|
||||
server.add_route(
|
||||
"DELETE",
|
||||
"/api/v1/beta/agent-data/{item_id}",
|
||||
self._delete_data_by_id,
|
||||
namespace="delete_item",
|
||||
)
|
||||
server.add_route(
|
||||
"GET",
|
||||
"/api/v1/beta/agent-data/{item_id}",
|
||||
self._get_data_by_id,
|
||||
namespace="untyped_get_item",
|
||||
alias="get_item",
|
||||
)
|
||||
server.add_route(
|
||||
"PUT",
|
||||
"/api/v1/beta/agent-data/{item_id}",
|
||||
self._update_data,
|
||||
namespace="update_item",
|
||||
)
|
||||
@@ -1,154 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List
|
||||
|
||||
import httpx
|
||||
from llama_cloud.types.classifier import (
|
||||
ClassifierRule,
|
||||
ClassifyJob,
|
||||
)
|
||||
from llama_cloud.types.classifier.job_get_results_response import (
|
||||
Item,
|
||||
ItemResult,
|
||||
JobGetResultsResponse,
|
||||
)
|
||||
|
||||
from ._deterministic import combined_seed, utcnow
|
||||
from .files import FakeFilesNamespace, StoredFile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .server import FakeLlamaCloudServer
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClassificationJobRecord:
|
||||
job: ClassifyJob
|
||||
results: JobGetResultsResponse
|
||||
files: List[StoredFile]
|
||||
|
||||
|
||||
class FakeClassifyNamespace:
|
||||
def __init__(
|
||||
self, *, server: "FakeLlamaCloudServer", files: FakeFilesNamespace
|
||||
) -> None:
|
||||
self._server = server
|
||||
self._files = files
|
||||
self._jobs: Dict[str, ClassificationJobRecord] = {}
|
||||
|
||||
def register(self) -> None:
|
||||
server = self._server
|
||||
server.add_route(
|
||||
"POST",
|
||||
"/api/v1/classifier/jobs",
|
||||
self._handle_create_job,
|
||||
namespace="classify",
|
||||
)
|
||||
server.add_route(
|
||||
"GET",
|
||||
"/api/v1/classifier/jobs",
|
||||
self._handle_list_jobs,
|
||||
namespace="classify",
|
||||
)
|
||||
server.add_route(
|
||||
"GET",
|
||||
"/api/v1/classifier/jobs/{job_id}",
|
||||
self._handle_get_job,
|
||||
namespace="classify",
|
||||
)
|
||||
server.add_route(
|
||||
"GET",
|
||||
"/api/v1/classifier/jobs/{job_id}/results",
|
||||
self._handle_get_results,
|
||||
namespace="classify",
|
||||
)
|
||||
|
||||
def _handle_create_job(self, request: httpx.Request) -> httpx.Response:
|
||||
payload = self._server.json(request)
|
||||
file_ids = payload.get("file_ids", [])
|
||||
rules_payload = payload.get("rules", [])
|
||||
rules = [ClassifierRule.parse_obj(rule) for rule in rules_payload]
|
||||
stored_files = []
|
||||
for file_id in file_ids:
|
||||
stored = self._files.get(file_id)
|
||||
if not stored:
|
||||
return self._server.json_response(
|
||||
{"detail": f"File {file_id} not found"}, status_code=404
|
||||
)
|
||||
stored_files.append(stored)
|
||||
|
||||
job_id = self._server.new_id("classify-job")
|
||||
job = ClassifyJob(
|
||||
id=job_id,
|
||||
project_id=request.url.params.get(
|
||||
"project_id", self._server.default_project_id
|
||||
),
|
||||
user_id="fake-user",
|
||||
rules=rules,
|
||||
parsing_configuration=None,
|
||||
status="SUCCESS",
|
||||
created_at=utcnow(),
|
||||
updated_at=utcnow(),
|
||||
effective_at=utcnow(),
|
||||
error_message=None,
|
||||
job_record_id=None,
|
||||
)
|
||||
results = self._build_results(job_id, stored_files, rules)
|
||||
record = ClassificationJobRecord(job=job, results=results, files=stored_files)
|
||||
self._jobs[job_id] = record
|
||||
return self._server.json_response(job.dict())
|
||||
|
||||
def _handle_list_jobs(self, request: httpx.Request) -> httpx.Response:
|
||||
return self._server.json_response(
|
||||
[record.job.dict() for record in self._jobs.values()]
|
||||
)
|
||||
|
||||
def _handle_get_job(self, request: httpx.Request) -> httpx.Response:
|
||||
job_id = request.url.path.split("/")[-1]
|
||||
record = self._jobs.get(job_id)
|
||||
if not record:
|
||||
return self._server.json_response(
|
||||
{"detail": "Job not found"}, status_code=404
|
||||
)
|
||||
return self._server.json_response(record.job.dict())
|
||||
|
||||
def _handle_get_results(self, request: httpx.Request) -> httpx.Response:
|
||||
job_id = request.url.path.split("/")[-2]
|
||||
record = self._jobs.get(job_id)
|
||||
if not record:
|
||||
return self._server.json_response(
|
||||
{"detail": "Results not found"}, status_code=404
|
||||
)
|
||||
return self._server.json_response(record.results.dict())
|
||||
|
||||
def _build_results(
|
||||
self,
|
||||
job_id: str,
|
||||
stored_files: List[StoredFile],
|
||||
rules: List[ClassifierRule],
|
||||
) -> JobGetResultsResponse:
|
||||
items: List[Item] = []
|
||||
for stored in stored_files:
|
||||
seed = combined_seed(stored.sha256, job_id)
|
||||
rule_index = seed % len(rules) if rules else 0
|
||||
predicted_type = rules[rule_index].type if rules else "unlabeled"
|
||||
confidence = 0.55 + (seed % 40) / 100
|
||||
reasoning = (
|
||||
f"Selected rule '{predicted_type}' using deterministic seed {seed}."
|
||||
)
|
||||
classification = Item(
|
||||
id=self._server.new_id("classification"),
|
||||
file_id=stored.file.id,
|
||||
classify_job_id=job_id,
|
||||
created_at=utcnow(),
|
||||
updated_at=utcnow(),
|
||||
result=ItemResult(
|
||||
type=predicted_type,
|
||||
confidence=min(confidence, 0.95),
|
||||
reasoning=reasoning,
|
||||
),
|
||||
)
|
||||
items.append(classification)
|
||||
return JobGetResultsResponse(
|
||||
items=items, next_page_token=None, total_size=len(items)
|
||||
)
|
||||
@@ -1,153 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from llama_cloud.types.configuration_response import (
|
||||
ConfigurationResponse,
|
||||
ExtractV2Parameters,
|
||||
)
|
||||
|
||||
from ._deterministic import utcnow
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .server import FakeLlamaCloudServer
|
||||
|
||||
|
||||
@dataclass
|
||||
class StoredConfiguration:
|
||||
id: str
|
||||
name: str
|
||||
parameters: Dict[str, Any]
|
||||
created_at: datetime = field(default_factory=utcnow)
|
||||
updated_at: datetime = field(default_factory=utcnow)
|
||||
|
||||
|
||||
class FakeConfigurationsNamespace:
|
||||
"""Mocks the llama-cloud v2 configurations API.
|
||||
|
||||
Endpoints covered:
|
||||
GET /api/v1/beta/configurations/{config_id} retrieve
|
||||
GET /api/v1/beta/configurations list
|
||||
POST /api/v1/beta/configurations create
|
||||
PATCH /api/v1/beta/configurations/{config_id} update
|
||||
DELETE /api/v1/beta/configurations/{config_id} delete
|
||||
"""
|
||||
|
||||
def __init__(self, *, server: "FakeLlamaCloudServer") -> None:
|
||||
self._server = server
|
||||
self._configurations: Dict[str, StoredConfiguration] = {}
|
||||
|
||||
# Public API -----------------------------------------------------
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
parameters: Dict[str, Any],
|
||||
) -> StoredConfiguration:
|
||||
config_id = self._server.new_id("cfg")
|
||||
stored = StoredConfiguration(id=config_id, name=name, parameters=parameters)
|
||||
self._configurations[config_id] = stored
|
||||
return stored
|
||||
|
||||
def get(self, config_id: str) -> Optional[StoredConfiguration]:
|
||||
return self._configurations.get(config_id)
|
||||
|
||||
# Route registration ---------------------------------------------
|
||||
def register(self) -> None:
|
||||
server = self._server
|
||||
server.add_route(
|
||||
"GET",
|
||||
"/api/v1/beta/configurations",
|
||||
self._handle_list,
|
||||
namespace="configurations",
|
||||
)
|
||||
server.add_route(
|
||||
"POST",
|
||||
"/api/v1/beta/configurations",
|
||||
self._handle_create,
|
||||
namespace="configurations",
|
||||
)
|
||||
server.add_route(
|
||||
"GET",
|
||||
"/api/v1/beta/configurations/{config_id}",
|
||||
self._handle_get,
|
||||
namespace="configurations",
|
||||
)
|
||||
server.add_route(
|
||||
"PATCH",
|
||||
"/api/v1/beta/configurations/{config_id}",
|
||||
self._handle_update,
|
||||
namespace="configurations",
|
||||
)
|
||||
server.add_route(
|
||||
"DELETE",
|
||||
"/api/v1/beta/configurations/{config_id}",
|
||||
self._handle_delete,
|
||||
namespace="configurations",
|
||||
)
|
||||
|
||||
# Handlers -------------------------------------------------------
|
||||
def _handle_create(self, request: httpx.Request) -> httpx.Response:
|
||||
payload = self._server.json(request)
|
||||
stored = self.create(name=payload["name"], parameters=payload["parameters"])
|
||||
return self._server.json_response(self._to_dict(stored))
|
||||
|
||||
def _handle_get(self, request: httpx.Request) -> httpx.Response:
|
||||
config_id = request.url.path.rstrip("/").split("/")[-1]
|
||||
stored = self._configurations.get(config_id)
|
||||
if not stored:
|
||||
return self._server.json_response(
|
||||
{"detail": "Configuration not found"}, status_code=404
|
||||
)
|
||||
return self._server.json_response(self._to_dict(stored))
|
||||
|
||||
def _handle_update(self, request: httpx.Request) -> httpx.Response:
|
||||
config_id = request.url.path.rstrip("/").split("/")[-1]
|
||||
stored = self._configurations.get(config_id)
|
||||
if not stored:
|
||||
return self._server.json_response(
|
||||
{"detail": "Configuration not found"}, status_code=404
|
||||
)
|
||||
payload = self._server.json(request)
|
||||
if "name" in payload and payload["name"] is not None:
|
||||
stored.name = payload["name"]
|
||||
if "parameters" in payload and payload["parameters"] is not None:
|
||||
stored.parameters = payload["parameters"]
|
||||
stored.updated_at = utcnow()
|
||||
return self._server.json_response(self._to_dict(stored))
|
||||
|
||||
def _handle_delete(self, request: httpx.Request) -> httpx.Response:
|
||||
config_id = request.url.path.rstrip("/").split("/")[-1]
|
||||
self._configurations.pop(config_id, None)
|
||||
return self._server.json_response({}, status_code=200)
|
||||
|
||||
def _handle_list(self, request: httpx.Request) -> httpx.Response:
|
||||
product_type = request.url.params.get_list("product_type")
|
||||
items = []
|
||||
for stored in self._configurations.values():
|
||||
pt = stored.parameters.get("product_type")
|
||||
if product_type and pt not in product_type:
|
||||
continue
|
||||
items.append(self._to_dict(stored))
|
||||
return self._server.json_response(
|
||||
{"items": items, "next_page_token": None, "has_more": False}
|
||||
)
|
||||
|
||||
# Helpers --------------------------------------------------------
|
||||
def _to_dict(self, stored: StoredConfiguration) -> Dict[str, Any]:
|
||||
product_type = stored.parameters.get("product_type", "extract_v2")
|
||||
response = ConfigurationResponse(
|
||||
id=stored.id,
|
||||
name=stored.name,
|
||||
parameters=ExtractV2Parameters(**stored.parameters)
|
||||
if product_type == "extract_v2"
|
||||
else stored.parameters, # type: ignore[arg-type]
|
||||
product_type=product_type,
|
||||
version=stored.updated_at.isoformat(),
|
||||
created_at=stored.created_at,
|
||||
updated_at=stored.updated_at,
|
||||
)
|
||||
return response.model_dump(mode="json")
|
||||
@@ -1,275 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from llama_cloud.types.extract_configuration import ExtractConfiguration
|
||||
from llama_cloud.types.extract_v2_job import ExtractV2Job
|
||||
|
||||
from ._deterministic import (
|
||||
combined_seed,
|
||||
generate_data_from_schema,
|
||||
hash_schema,
|
||||
utcnow,
|
||||
)
|
||||
from .configurations import FakeConfigurationsNamespace
|
||||
from .files import FakeFilesNamespace, StoredFile
|
||||
from .matchers import RequestContext, RequestMatcher
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .server import FakeLlamaCloudServer
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractRunStub:
|
||||
matcher: Optional[RequestMatcher]
|
||||
data: Optional[Any]
|
||||
status: Optional[str]
|
||||
metadata: Optional[Dict[str, Any]]
|
||||
error: Optional[str]
|
||||
once: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class StoredJob:
|
||||
job: ExtractV2Job
|
||||
data_schema: Dict[str, Any]
|
||||
file: StoredFile
|
||||
|
||||
|
||||
class FakeExtractNamespace:
|
||||
"""Mocks the llama-cloud v2 extract API.
|
||||
|
||||
Endpoints covered:
|
||||
POST /api/v2/extract create extract job
|
||||
GET /api/v2/extract/{job_id} get job
|
||||
GET /api/v2/extract list jobs
|
||||
DELETE /api/v2/extract/{job_id} delete job
|
||||
POST /api/v2/extract/schema/validation validate schema
|
||||
POST /api/v2/extract/schema/generate generate schema
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
server: "FakeLlamaCloudServer",
|
||||
files: FakeFilesNamespace,
|
||||
configurations: FakeConfigurationsNamespace,
|
||||
) -> None:
|
||||
self._server = server
|
||||
self._files = files
|
||||
self._configurations = configurations
|
||||
self._jobs: Dict[str, StoredJob] = {}
|
||||
self._run_stubs: List[ExtractRunStub] = []
|
||||
self.routes: Dict[str, Any] = {}
|
||||
|
||||
# Public stub APIs -----------------------------------------------
|
||||
def stub_run(
|
||||
self,
|
||||
matcher: Optional[RequestMatcher],
|
||||
*,
|
||||
data: Optional[Any] = None,
|
||||
status: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
error: Optional[str] = None,
|
||||
once: bool = True,
|
||||
) -> None:
|
||||
self._run_stubs.append(
|
||||
ExtractRunStub(
|
||||
matcher=matcher,
|
||||
data=data,
|
||||
status=status,
|
||||
metadata=metadata,
|
||||
error=error,
|
||||
once=once,
|
||||
)
|
||||
)
|
||||
|
||||
# Route registration ---------------------------------------------
|
||||
def register(self) -> None:
|
||||
server = self._server
|
||||
create_route = server.add_route(
|
||||
"POST",
|
||||
"/api/v2/extract",
|
||||
self._handle_create_job,
|
||||
namespace="extract",
|
||||
alias="extract_create",
|
||||
)
|
||||
self.routes["create"] = create_route
|
||||
server.add_route(
|
||||
"GET",
|
||||
"/api/v2/extract",
|
||||
self._handle_list_jobs,
|
||||
namespace="extract",
|
||||
)
|
||||
server.add_route(
|
||||
"GET",
|
||||
"/api/v2/extract/{job_id}",
|
||||
self._handle_get_job,
|
||||
namespace="extract",
|
||||
)
|
||||
server.add_route(
|
||||
"DELETE",
|
||||
"/api/v2/extract/{job_id}",
|
||||
self._handle_delete_job,
|
||||
namespace="extract",
|
||||
)
|
||||
server.add_route(
|
||||
"POST",
|
||||
"/api/v2/extract/schema/validation",
|
||||
self._handle_validate_schema,
|
||||
namespace="extract",
|
||||
)
|
||||
server.add_route(
|
||||
"POST",
|
||||
"/api/v2/extract/schema/generate",
|
||||
self._handle_generate_schema,
|
||||
namespace="extract",
|
||||
)
|
||||
|
||||
# Handlers -------------------------------------------------------
|
||||
def _handle_create_job(self, request: httpx.Request) -> httpx.Response:
|
||||
payload = self._server.json(request)
|
||||
file_input = payload["file_input"]
|
||||
stored_file = self._files.get(file_input)
|
||||
if not stored_file:
|
||||
return self._server.json_response(
|
||||
{"detail": f"File {file_input} not found"}, status_code=404
|
||||
)
|
||||
|
||||
configuration = payload.get("configuration")
|
||||
configuration_id = payload.get("configuration_id")
|
||||
if configuration and configuration_id:
|
||||
return self._server.json_response(
|
||||
{"detail": "Provide configuration OR configuration_id"},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
if configuration_id:
|
||||
cfg = self._configurations.get(configuration_id)
|
||||
if not cfg:
|
||||
return self._server.json_response(
|
||||
{"detail": "Configuration not found"}, status_code=404
|
||||
)
|
||||
params = cfg.parameters
|
||||
data_schema = params["data_schema"]
|
||||
extract_config = {
|
||||
k: v for k, v in params.items() if k not in ("product_type",)
|
||||
}
|
||||
else:
|
||||
if not configuration:
|
||||
return self._server.json_response(
|
||||
{"detail": "configuration or configuration_id required"},
|
||||
status_code=400,
|
||||
)
|
||||
data_schema = configuration["data_schema"]
|
||||
extract_config = dict(configuration)
|
||||
|
||||
context = RequestContext(
|
||||
request=request,
|
||||
json=payload,
|
||||
file_id=stored_file.file.id,
|
||||
filename=stored_file.file.name,
|
||||
file_sha256=stored_file.sha256,
|
||||
schema_hash=hash_schema(data_schema),
|
||||
project_id=stored_file.file.project_id,
|
||||
organization_id=self._server.default_organization_id,
|
||||
)
|
||||
|
||||
stub = self._pop_stub(self._run_stubs, context)
|
||||
status = "COMPLETED"
|
||||
run_data = self._generate_run_data(data_schema, stored_file.sha256)
|
||||
error_message: Optional[str] = None
|
||||
|
||||
if stub:
|
||||
if stub.status:
|
||||
status = stub.status
|
||||
if stub.error:
|
||||
error_message = stub.error
|
||||
status = "FAILED"
|
||||
if stub.data is not None:
|
||||
if callable(stub.data):
|
||||
run_data = stub.data(payload) # type: ignore[assignment]
|
||||
else:
|
||||
run_data = stub.data
|
||||
|
||||
job_id = self._server.new_id("exj")
|
||||
now = utcnow()
|
||||
job = ExtractV2Job(
|
||||
id=job_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
file_input=file_input,
|
||||
project_id=stored_file.file.project_id,
|
||||
status=status,
|
||||
configuration=ExtractConfiguration(**extract_config)
|
||||
if not configuration_id
|
||||
else None,
|
||||
configuration_id=configuration_id,
|
||||
error_message=error_message,
|
||||
extract_result=run_data if status == "COMPLETED" else None,
|
||||
extract_metadata=None,
|
||||
metadata=None,
|
||||
)
|
||||
stored = StoredJob(job=job, data_schema=data_schema, file=stored_file)
|
||||
self._jobs[job_id] = stored
|
||||
return self._server.json_response(job.model_dump(mode="json"))
|
||||
|
||||
def _handle_get_job(self, request: httpx.Request) -> httpx.Response:
|
||||
job_id = request.url.path.rstrip("/").split("/")[-1]
|
||||
stored = self._jobs.get(job_id)
|
||||
if not stored:
|
||||
return self._server.json_response(
|
||||
{"detail": "Job not found"}, status_code=404
|
||||
)
|
||||
return self._server.json_response(stored.job.model_dump(mode="json"))
|
||||
|
||||
def _handle_list_jobs(self, request: httpx.Request) -> httpx.Response:
|
||||
items = [stored.job.model_dump(mode="json") for stored in self._jobs.values()]
|
||||
return self._server.json_response(
|
||||
{"items": items, "next_page_token": None, "has_more": False}
|
||||
)
|
||||
|
||||
def _handle_delete_job(self, request: httpx.Request) -> httpx.Response:
|
||||
job_id = request.url.path.rstrip("/").split("/")[-1]
|
||||
self._jobs.pop(job_id, None)
|
||||
return self._server.json_response({}, status_code=200)
|
||||
|
||||
def _handle_validate_schema(self, request: httpx.Request) -> httpx.Response:
|
||||
payload = self._server.json(request)
|
||||
return self._server.json_response({"data_schema": payload["data_schema"]})
|
||||
|
||||
def _handle_generate_schema(self, request: httpx.Request) -> httpx.Response:
|
||||
payload = self._server.json(request)
|
||||
name = payload.get("name") or "generated"
|
||||
schema = payload.get("data_schema") or {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
}
|
||||
return self._server.json_response(
|
||||
{
|
||||
"name": name,
|
||||
"parameters": {
|
||||
"product_type": "extract_v2",
|
||||
"data_schema": schema,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Helpers --------------------------------------------------------
|
||||
def _generate_run_data(self, schema: Dict[str, Any], file_hash: str) -> Any:
|
||||
seed = combined_seed(file_hash, hash_schema(schema))
|
||||
return generate_data_from_schema(schema, seed)
|
||||
|
||||
def _pop_stub(
|
||||
self,
|
||||
stubs: List[ExtractRunStub],
|
||||
context: RequestContext,
|
||||
) -> Optional[ExtractRunStub]:
|
||||
for index, stub in enumerate(list(stubs)):
|
||||
if context.matches(stub.matcher):
|
||||
if stub.once:
|
||||
stubs.pop(index)
|
||||
return stub
|
||||
return None
|
||||
@@ -1,335 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
import respx
|
||||
from llama_cloud.types import File as CloudFile
|
||||
from llama_cloud.types.file_list_response import FileListResponse
|
||||
from llama_cloud.types.file_query_response import FileQueryResponse, Item
|
||||
from llama_cloud.types.presigned_url import PresignedURL
|
||||
|
||||
from ._deterministic import (
|
||||
fingerprint_file,
|
||||
utcnow,
|
||||
)
|
||||
from .matchers import RequestMatcher
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .server import FakeLlamaCloudServer
|
||||
|
||||
|
||||
@dataclass
|
||||
class StoredFile:
|
||||
file: CloudFile
|
||||
content: bytes
|
||||
sha256: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingUpload:
|
||||
file_id: str
|
||||
filename: str
|
||||
project_id: str
|
||||
organization_id: str
|
||||
external_file_id: Optional[str]
|
||||
expected_size: Optional[int]
|
||||
|
||||
|
||||
class FakeFilesNamespace:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
server: "FakeLlamaCloudServer",
|
||||
upload_base_url: str,
|
||||
download_base_url: str,
|
||||
) -> None:
|
||||
self._server = server
|
||||
self._upload_base_url = upload_base_url.rstrip("/")
|
||||
self._download_base_url = download_base_url.rstrip("/")
|
||||
self._files: Dict[str, StoredFile] = {}
|
||||
self._pending: Dict[str, PendingUpload] = {}
|
||||
self._upload_stubs: List[
|
||||
tuple[RequestMatcher | None, int, Dict[str, Any], bool]
|
||||
] = []
|
||||
self.routes: Dict[str, respx.Route] = {}
|
||||
|
||||
# Public helpers -------------------------------------------------
|
||||
def preload(self, *, path: str | Path, filename: Optional[str] = None) -> str:
|
||||
path = Path(path)
|
||||
content = path.read_bytes()
|
||||
file_id = self._server.new_id("file")
|
||||
name = filename or path.name
|
||||
stored = self._build_file(
|
||||
file_id=file_id,
|
||||
name=name,
|
||||
project_id=self._server.default_project_id,
|
||||
organization_id=self._server.default_organization_id,
|
||||
content=content,
|
||||
external_file_id=None,
|
||||
)
|
||||
self._files[file_id] = stored
|
||||
return file_id
|
||||
|
||||
def read(self, file_id: str) -> bytes:
|
||||
return self._files[file_id].content
|
||||
|
||||
def get(self, file_id: str) -> Optional[StoredFile]:
|
||||
return self._files.get(file_id)
|
||||
|
||||
def preload_from_source(self, filename: str, content: bytes) -> str:
|
||||
file_id = self._server.new_id("file")
|
||||
name = filename
|
||||
stored = self._build_file(
|
||||
file_id=file_id,
|
||||
name=name,
|
||||
project_id=self._server.default_project_id,
|
||||
organization_id=self._server.default_organization_id,
|
||||
content=content,
|
||||
external_file_id=None,
|
||||
)
|
||||
self._files[file_id] = stored
|
||||
return file_id
|
||||
|
||||
def stub_upload(
|
||||
self,
|
||||
matcher: Optional[RequestMatcher],
|
||||
*,
|
||||
status_code: int = 413,
|
||||
json_body: Optional[Dict[str, Any]] = None,
|
||||
once: bool = True,
|
||||
) -> None:
|
||||
body = json_body or {"detail": "upload rejected by fake server"}
|
||||
self._upload_stubs.append((matcher, status_code, body, once))
|
||||
|
||||
def all_files(self) -> Dict[str, StoredFile]:
|
||||
return dict(self._files)
|
||||
|
||||
# Route registration ---------------------------------------------
|
||||
def register(self) -> None:
|
||||
server = self._server
|
||||
upload_route = server.add_route(
|
||||
"POST",
|
||||
"/api/v1/beta/files",
|
||||
self._handle_direct_upload,
|
||||
namespace="files",
|
||||
alias="upload",
|
||||
)
|
||||
self.routes["upload"] = upload_route
|
||||
list_route = server.add_route(
|
||||
"GET",
|
||||
"/api/v1/beta/files",
|
||||
self._handle_list,
|
||||
namespace="files",
|
||||
alias="list_files",
|
||||
)
|
||||
self.routes["list"] = list_route
|
||||
get_route = server.add_route(
|
||||
"GET",
|
||||
"/api/v1/beta/files/{file_id}/content",
|
||||
self._handle_read_content,
|
||||
namespace="files",
|
||||
alias="get",
|
||||
)
|
||||
self.routes["get"] = get_route
|
||||
server.add_route(
|
||||
"DELETE",
|
||||
"/api/v1/beta/files/{file_id}",
|
||||
self._handle_delete,
|
||||
namespace="files",
|
||||
)
|
||||
server.add_route(
|
||||
"POST",
|
||||
"/api/v1/beta/files/query",
|
||||
self._handle_query,
|
||||
namespace="files",
|
||||
alias="query",
|
||||
)
|
||||
server.add_route(
|
||||
"GET",
|
||||
"/files/{file_id}",
|
||||
self._handle_presigned_download,
|
||||
namespace="files",
|
||||
base_urls=[self._download_base_url],
|
||||
alias="download",
|
||||
)
|
||||
|
||||
# Handlers -------------------------------------------------------
|
||||
def _handle_direct_upload(self, request: httpx.Request) -> httpx.Response:
|
||||
file_bytes, filename = self._extract_multipart_file(request)
|
||||
file_id = self._server.new_id("file")
|
||||
stored = self._build_file(
|
||||
file_id=file_id,
|
||||
name=filename or f"upload-{file_id}.bin",
|
||||
project_id=request.url.params.get(
|
||||
"project_id", self._server.default_project_id
|
||||
),
|
||||
organization_id=request.url.params.get(
|
||||
"organization_id", self._server.default_organization_id
|
||||
),
|
||||
content=file_bytes,
|
||||
external_file_id=request.url.params.get("external_file_id"),
|
||||
)
|
||||
self._files[file_id] = stored
|
||||
return self._server.json_response(stored.file.model_dump())
|
||||
|
||||
def _handle_list(self, request: httpx.Request) -> httpx.Response:
|
||||
params = request.url.params
|
||||
file_ids_raw = params.multi_items()
|
||||
file_ids_filter = [v for k, v in file_ids_raw if k == "file_ids"]
|
||||
file_name = params.get("file_name")
|
||||
external_file_id = params.get("external_file_id")
|
||||
page_size = int(params.get("page_size", "50"))
|
||||
|
||||
files = list(self._files.values())
|
||||
if file_ids_filter:
|
||||
files = [f for f in files if f.file.id in file_ids_filter]
|
||||
if file_name:
|
||||
files = [f for f in files if f.file.name == file_name]
|
||||
if external_file_id:
|
||||
files = [f for f in files if f.file.external_file_id == external_file_id]
|
||||
|
||||
files = files[:page_size]
|
||||
items = [
|
||||
FileListResponse(
|
||||
id=f.file.id,
|
||||
name=f.file.name,
|
||||
project_id=f.file.project_id,
|
||||
expires_at=f.file.expires_at,
|
||||
external_file_id=f.file.external_file_id,
|
||||
file_type=f.file.file_type,
|
||||
last_modified_at=f.file.last_modified_at,
|
||||
purpose=f.file.purpose,
|
||||
)
|
||||
for f in files
|
||||
]
|
||||
return self._server.json_response(
|
||||
{
|
||||
"items": [item.model_dump() for item in items],
|
||||
"next_page_token": None,
|
||||
}
|
||||
)
|
||||
|
||||
def _handle_delete(self, request: httpx.Request) -> httpx.Response:
|
||||
file_id = request.url.path.split("/")[-1]
|
||||
self._files.pop(file_id, None)
|
||||
self._pending.pop(file_id, None)
|
||||
return self._server.json_response({}, status_code=200)
|
||||
|
||||
def _handle_read_content(self, request: httpx.Request) -> httpx.Response:
|
||||
file_id = request.url.path.split("/")[-2]
|
||||
if file_id not in self._files:
|
||||
return self._server.json_response(
|
||||
{"detail": "File not found"}, status_code=404
|
||||
)
|
||||
presigned = PresignedURL(
|
||||
url=f"{self._download_base_url}/files/{file_id}?{urlencode({'token': 'fake'})}",
|
||||
expires_at=utcnow(),
|
||||
form_fields=None,
|
||||
)
|
||||
return self._server.json_response(presigned.model_dump())
|
||||
|
||||
def _handle_presigned_download(self, request: httpx.Request) -> httpx.Response:
|
||||
file_id = request.url.path.split("/")[-1]
|
||||
stored = self._files.get(file_id)
|
||||
if not stored:
|
||||
return httpx.Response(404, json={"detail": "File not found"})
|
||||
return httpx.Response(200, content=stored.content)
|
||||
|
||||
def _handle_query(self, request: httpx.Request) -> httpx.Response:
|
||||
payload = self._server.json(request)
|
||||
files: list[StoredFile] = []
|
||||
items: list[Item] = []
|
||||
if payload.get("filter") is not None:
|
||||
file_ids = payload["filter"].get("file_ids", [])
|
||||
for file_id in self._files:
|
||||
if file_id in file_ids:
|
||||
files.append(self._files[file_id])
|
||||
else:
|
||||
files = list(self._files.values())
|
||||
for f in files:
|
||||
item = Item(
|
||||
id=f.file.id,
|
||||
name=f.file.name,
|
||||
project_id=self._server.default_project_id,
|
||||
expires_at=utcnow(),
|
||||
external_file_id=f.file.external_file_id,
|
||||
purpose=f.file.purpose,
|
||||
last_modified_at=utcnow(),
|
||||
file_type=f.file.file_type,
|
||||
)
|
||||
items.append(item)
|
||||
response = FileQueryResponse(
|
||||
items=items, next_page_token=None, total_size=len(items)
|
||||
)
|
||||
return self._server.json_response(response.model_dump())
|
||||
|
||||
# Internal helpers -----------------------------------------------
|
||||
def _build_file(
|
||||
self,
|
||||
*,
|
||||
file_id: str,
|
||||
name: str,
|
||||
project_id: str,
|
||||
organization_id: str,
|
||||
content: bytes,
|
||||
external_file_id: Optional[str],
|
||||
) -> StoredFile:
|
||||
sha256 = fingerprint_file(content, name)
|
||||
now = utcnow()
|
||||
cloud_file = CloudFile(
|
||||
id=file_id,
|
||||
name=name,
|
||||
project_id=project_id,
|
||||
external_file_id=external_file_id,
|
||||
file_size=len(content),
|
||||
file_type=Path(name).suffix or "application/octet-stream",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
data_source_id=None,
|
||||
permission_info=None,
|
||||
resource_info=None,
|
||||
last_modified_at=now,
|
||||
)
|
||||
return StoredFile(file=cloud_file, content=content, sha256=sha256)
|
||||
|
||||
def _extract_multipart_file(
|
||||
self, request: httpx.Request
|
||||
) -> tuple[bytes, Optional[str]]:
|
||||
content_type = request.headers.get("content-type", "")
|
||||
if "multipart/form-data" not in content_type:
|
||||
raise ValueError("Expected multipart upload")
|
||||
|
||||
boundary = content_type.split("boundary=")[-1]
|
||||
boundary_bytes = boundary.encode("utf-8")
|
||||
body = request.content
|
||||
delimiter = b"--" + boundary_bytes
|
||||
parts = [
|
||||
part
|
||||
for part in body.split(delimiter)
|
||||
if part.strip(b"\r\n") and part.strip(b"\r\n") != b"--"
|
||||
]
|
||||
for part in parts:
|
||||
headers, _, payload = part.partition(b"\r\n\r\n")
|
||||
header_text = headers.decode("utf-8", errors="ignore")
|
||||
if 'name="upload_file"' in header_text or 'name="file"' in header_text:
|
||||
filename = None
|
||||
if "filename=" in header_text:
|
||||
filename = (
|
||||
header_text.split("filename=")[-1].strip().strip('"').strip("'")
|
||||
)
|
||||
return payload.rstrip(b"\r\n"), filename
|
||||
raise ValueError("upload file part not found")
|
||||
|
||||
def decode_file_data(self, data: Dict[str, Any]) -> tuple[bytes, Optional[str]]:
|
||||
if "file" not in data:
|
||||
raise ValueError("file payload missing")
|
||||
file_payload = data["file"]
|
||||
encoded = file_payload["data"]
|
||||
content = base64.b64decode(encoded)
|
||||
filename = file_payload.get("filename")
|
||||
return content, filename
|
||||
@@ -1,100 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
MatcherPredicate = Callable[[httpx.Request], bool]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileMatcher:
|
||||
filename: Optional[str] = None
|
||||
sha256: Optional[str] = None
|
||||
file_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchemaMatcher:
|
||||
model: Optional[type] = None
|
||||
schema_hash: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestMatcher:
|
||||
file: Optional[FileMatcher | MatcherPredicate] = None
|
||||
schema: Optional[SchemaMatcher] = None
|
||||
agent_id: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
organization_id: Optional[str] = None
|
||||
predicate: Optional[MatcherPredicate] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext:
|
||||
request: httpx.Request
|
||||
json: Optional[dict[str, Any]]
|
||||
file_id: Optional[str] = None
|
||||
filename: Optional[str] = None
|
||||
file_sha256: Optional[str] = None
|
||||
schema_hash: Optional[str] = None
|
||||
agent_id: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
organization_id: Optional[str] = None
|
||||
|
||||
def matches(self, matcher: Optional[RequestMatcher]) -> bool:
|
||||
if matcher is None:
|
||||
return True
|
||||
|
||||
if matcher.project_id and matcher.project_id != self.project_id:
|
||||
return False
|
||||
|
||||
if matcher.organization_id and matcher.organization_id != self.organization_id:
|
||||
return False
|
||||
|
||||
if matcher.agent_id and matcher.agent_id != self.agent_id:
|
||||
return False
|
||||
|
||||
if matcher.file:
|
||||
if isinstance(matcher.file, FileMatcher):
|
||||
if matcher.file.filename and matcher.file.filename != self.filename:
|
||||
return False
|
||||
if matcher.file.file_id and matcher.file.file_id != self.file_id:
|
||||
return False
|
||||
if matcher.file.sha256 and matcher.file.sha256 != self.file_sha256:
|
||||
return False
|
||||
else:
|
||||
if not matcher.file(self.request):
|
||||
return False
|
||||
|
||||
if matcher.schema:
|
||||
if (
|
||||
matcher.schema.schema_hash
|
||||
and matcher.schema.schema_hash != self.schema_hash
|
||||
):
|
||||
return False
|
||||
if matcher.schema.model and matcher.schema.schema_hash:
|
||||
return matcher.schema.schema_hash == self.schema_hash
|
||||
if matcher.schema.model and matcher.schema.schema_hash is None:
|
||||
expected = _schema_hash_from_model(matcher.schema.model)
|
||||
return expected == self.schema_hash
|
||||
|
||||
if matcher.predicate and not matcher.predicate(self.request):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _schema_hash_from_model(model: type) -> Optional[str]:
|
||||
if hasattr(model, "model_json_schema"):
|
||||
schema = model.model_json_schema()
|
||||
elif hasattr(model, "schema"):
|
||||
schema = model.schema() # type: ignore[attr-defined]
|
||||
else:
|
||||
return None
|
||||
|
||||
from ._deterministic import hash_schema
|
||||
|
||||
return hash_schema(schema)
|
||||
@@ -1,269 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
import httpx
|
||||
from llama_cloud.types.parsing_create_response import ParsingCreateResponse
|
||||
from llama_cloud.types.parsing_get_response import (
|
||||
Items,
|
||||
ItemsPage,
|
||||
ItemsPageStructuredResultPage,
|
||||
TextItem,
|
||||
Job,
|
||||
Markdown,
|
||||
MarkdownPage,
|
||||
MarkdownPageMarkdownResultPage,
|
||||
ParsingGetResponse,
|
||||
Text,
|
||||
TextPage,
|
||||
)
|
||||
|
||||
from ._deterministic import generate_text_blob, hash_schema, utcnow
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .server import FakeLlamaCloudServer
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParseJobRecord:
|
||||
job_id: str
|
||||
file_name: str
|
||||
status: str
|
||||
result: Dict[str, Any]
|
||||
content: bytes
|
||||
|
||||
|
||||
class FakeParseNamespace:
|
||||
def __init__(self, *, server: "FakeLlamaCloudServer") -> None:
|
||||
self._server = server
|
||||
self._jobs: Dict[str, ParsingGetResponse] = {}
|
||||
self.routes: Dict[str, Any] = {}
|
||||
self.allowed_expands = ("text", "markdown", "items")
|
||||
|
||||
def register(self) -> None:
|
||||
server = self._server
|
||||
server.add_route(
|
||||
"POST",
|
||||
"/api/v2/parse/upload",
|
||||
self._handle_upload,
|
||||
namespace="parse",
|
||||
)
|
||||
server.add_route(
|
||||
"GET",
|
||||
"/api/v2/parse/{job_id}",
|
||||
self._handle_job_result,
|
||||
namespace="parse",
|
||||
)
|
||||
server.add_route(
|
||||
"POST",
|
||||
"/api/v2/parse",
|
||||
self._handle_file_id_source_url,
|
||||
namespace="parse",
|
||||
)
|
||||
|
||||
def _handle_upload(self, request: httpx.Request) -> httpx.Response:
|
||||
_, filename, form_data = self._split_multipart(request)
|
||||
job_id = self._server.new_id("parse-job")
|
||||
seed_hash = hash_schema({"filename": filename, "form": form_data})
|
||||
seed = int(seed_hash[:16], 16)
|
||||
page_text = generate_text_blob(seed, sentences=3)
|
||||
item_pages: list[ItemsPage] = [
|
||||
ItemsPageStructuredResultPage(
|
||||
items=[TextItem(md=page_text, value=page_text, bBox=None, type="text")],
|
||||
page_height=1,
|
||||
page_number=1,
|
||||
page_width=1,
|
||||
success=True,
|
||||
)
|
||||
]
|
||||
md_pages: list[MarkdownPage] = [
|
||||
MarkdownPageMarkdownResultPage(
|
||||
markdown=page_text,
|
||||
page_number=1,
|
||||
success=True,
|
||||
)
|
||||
]
|
||||
txt_pages: list[TextPage] = [
|
||||
TextPage(
|
||||
text=page_text,
|
||||
page_number=1,
|
||||
)
|
||||
]
|
||||
record = ParsingGetResponse(
|
||||
job=Job(
|
||||
id=job_id,
|
||||
status="COMPLETED",
|
||||
project_id=self._server.default_project_id,
|
||||
created_at=utcnow(),
|
||||
updated_at=utcnow(),
|
||||
error_message=None,
|
||||
),
|
||||
items=Items(pages=item_pages),
|
||||
markdown=Markdown(pages=md_pages),
|
||||
text=Text(pages=txt_pages),
|
||||
)
|
||||
self._jobs[job_id] = record
|
||||
response = ParsingCreateResponse(
|
||||
id=job_id,
|
||||
project_id=self._server.default_project_id,
|
||||
status="COMPLETED",
|
||||
created_at=utcnow(),
|
||||
updated_at=utcnow(),
|
||||
error_message=None,
|
||||
)
|
||||
return self._server.json_response(response.model_dump())
|
||||
|
||||
def _handle_file_id_source_url(self, request: httpx.Request) -> httpx.Response:
|
||||
payload = self._server.json(request)
|
||||
file_id = payload.get("file_id")
|
||||
source_url = payload.get("source_url")
|
||||
if file_id is not None:
|
||||
file = self._server.files.get(file_id)
|
||||
if file is None:
|
||||
return self._server.json_response(
|
||||
{"details": f"File {file_id} not found"},
|
||||
status_code=404,
|
||||
)
|
||||
else:
|
||||
seed_hash = file.sha256
|
||||
elif source_url is not None:
|
||||
response = self._get_file_from_source_url(source_url)
|
||||
if isinstance(response, int):
|
||||
return self._server.json_response(
|
||||
{"details": f"Could not find file associated with {source_url}"},
|
||||
status_code=response,
|
||||
)
|
||||
file_content, filename = response
|
||||
file_id = self._server.files.preload_from_source(filename, file_content)
|
||||
seed_hash = self._server.files._files[file_id].sha256
|
||||
else:
|
||||
return self._server.json_response(
|
||||
{
|
||||
"details": "At least one between file_id and source_url should be not-null",
|
||||
},
|
||||
status_code=400,
|
||||
)
|
||||
job_id = self._server.new_id("parse-job")
|
||||
seed = int(seed_hash[:16], 16)
|
||||
page_text = generate_text_blob(seed, sentences=3)
|
||||
item_pages: list[ItemsPage] = [
|
||||
ItemsPageStructuredResultPage(
|
||||
items=[TextItem(md=page_text, value=page_text, bBox=None, type="text")],
|
||||
page_height=1,
|
||||
page_number=1,
|
||||
page_width=1,
|
||||
success=True,
|
||||
)
|
||||
]
|
||||
md_pages: list[MarkdownPage] = [
|
||||
MarkdownPageMarkdownResultPage(
|
||||
markdown=page_text,
|
||||
page_number=1,
|
||||
success=True,
|
||||
)
|
||||
]
|
||||
txt_pages: list[TextPage] = [
|
||||
TextPage(
|
||||
text=page_text,
|
||||
page_number=1,
|
||||
)
|
||||
]
|
||||
record = ParsingGetResponse(
|
||||
job=Job(
|
||||
id=job_id,
|
||||
status="COMPLETED",
|
||||
project_id=self._server.default_project_id,
|
||||
created_at=utcnow(),
|
||||
updated_at=utcnow(),
|
||||
error_message=None,
|
||||
),
|
||||
items=Items(pages=item_pages),
|
||||
markdown=Markdown(pages=md_pages),
|
||||
text=Text(pages=txt_pages),
|
||||
)
|
||||
self._jobs[job_id] = record
|
||||
response = ParsingCreateResponse(
|
||||
id=job_id,
|
||||
project_id=self._server.default_project_id,
|
||||
status="COMPLETED",
|
||||
created_at=utcnow(),
|
||||
updated_at=utcnow(),
|
||||
error_message=None,
|
||||
)
|
||||
return self._server.json_response(response.model_dump())
|
||||
|
||||
def _handle_job_result(self, request: httpx.Request) -> httpx.Response:
|
||||
job_id = request.url.path.split("/")[-1]
|
||||
expandees = request.url.params.get_list("expand")
|
||||
expandees = (
|
||||
[e for e in expandees if e in self.allowed_expands]
|
||||
if len(expandees) > 0
|
||||
else ["items"]
|
||||
)
|
||||
job_response = self._jobs.get(job_id)
|
||||
if not job_response:
|
||||
return self._server.json_response(
|
||||
{"detail": "Result not found"}, status_code=404
|
||||
)
|
||||
jb_resp_copy = deepcopy(job_response)
|
||||
if "markdown" not in expandees:
|
||||
jb_resp_copy.markdown = None
|
||||
if "text" not in expandees:
|
||||
jb_resp_copy.text = None
|
||||
if "items" not in expandees:
|
||||
jb_resp_copy.items = None
|
||||
return self._server.json_response(jb_resp_copy.model_dump())
|
||||
|
||||
def _split_multipart(
|
||||
self, request: httpx.Request
|
||||
) -> tuple[bytes, str, Dict[str, str]]:
|
||||
content_type = request.headers.get("content-type", "")
|
||||
if "multipart/form-data" not in content_type:
|
||||
raise ValueError("Expected multipart form data for parse upload")
|
||||
boundary = content_type.split("boundary=")[-1]
|
||||
delimiter = f"--{boundary}".encode()
|
||||
closing = f"--{boundary}--".encode()
|
||||
parts = []
|
||||
body = request.content
|
||||
for chunk in body.split(delimiter):
|
||||
chunk = chunk.strip()
|
||||
if not chunk or chunk == closing:
|
||||
continue
|
||||
parts.append(chunk)
|
||||
|
||||
file_bytes = b""
|
||||
filename = "upload.pdf"
|
||||
form_data: Dict[str, str] = {}
|
||||
for part in parts:
|
||||
header_blob, _, payload = part.partition(b"\r\n\r\n")
|
||||
payload = payload.rstrip(b"\r\n")
|
||||
header_text = header_blob.decode("utf-8", errors="ignore")
|
||||
if "filename=" in header_text:
|
||||
# Extract filename from Content-Disposition header, handling quotes
|
||||
# and avoiding capturing subsequent headers or parameters
|
||||
match = re.search(r'filename="([^"]+)"', header_text)
|
||||
if not match:
|
||||
match = re.search(r"filename='([^']+)'", header_text)
|
||||
if not match:
|
||||
match = re.search(r"filename=([^\s;\r\n]+)", header_text)
|
||||
if match:
|
||||
filename = match.group(1)
|
||||
file_bytes = payload
|
||||
else:
|
||||
name = header_text.split('name="')[-1].split('"')[0].strip()
|
||||
form_data[name] = payload.decode("utf-8", errors="ignore")
|
||||
if not file_bytes:
|
||||
raise ValueError("File part missing from multipart payload")
|
||||
return file_bytes, filename, form_data
|
||||
|
||||
def _get_file_from_source_url(self, source_url: str) -> tuple[bytes, str] | int:
|
||||
name = source_url.split("/")[-1]
|
||||
with httpx.Client() as client:
|
||||
response = client.get(source_url, follow_redirects=True)
|
||||
if response.status_code >= 400:
|
||||
return response.status_code
|
||||
content = response.content
|
||||
return content, name
|
||||
@@ -1,418 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from llama_cloud.types.managed_ingestion_status_response import (
|
||||
ManagedIngestionStatusResponse,
|
||||
)
|
||||
from llama_cloud.types.pipeline import Pipeline
|
||||
from llama_cloud.types.pipeline_retrieve_response import (
|
||||
PipelineRetrieveResponse,
|
||||
RetrievalNode,
|
||||
)
|
||||
from llama_cloud.types.pipelines.cloud_document import CloudDocument
|
||||
from llama_cloud.types.pipelines.pipeline_file import PipelineFile
|
||||
from llama_cloud.types.pipelines.text_node import TextNode
|
||||
|
||||
from ._deterministic import combined_seed, generate_text_blob, utcnow
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .server import FakeLlamaCloudServer
|
||||
|
||||
|
||||
class FakePipelinesNamespace:
|
||||
def __init__(self, *, server: "FakeLlamaCloudServer") -> None:
|
||||
self._server = server
|
||||
self._pipelines: Dict[str, Pipeline] = {}
|
||||
# Per-pipeline storage for ingested documents and files
|
||||
self._documents: Dict[str, Dict[str, CloudDocument]] = {}
|
||||
self._files: Dict[str, Dict[str, PipelineFile]] = {}
|
||||
self.routes: Dict[str, Any] = {}
|
||||
|
||||
def register(self) -> None:
|
||||
server = self._server
|
||||
server.add_route(
|
||||
"POST",
|
||||
"/api/v1/pipelines",
|
||||
self._handle_create,
|
||||
namespace="pipelines",
|
||||
)
|
||||
server.add_route(
|
||||
"GET",
|
||||
"/api/v1/pipelines",
|
||||
self._handle_list,
|
||||
namespace="pipelines",
|
||||
)
|
||||
server.add_route(
|
||||
"GET",
|
||||
"/api/v1/pipelines/{pipeline_id}",
|
||||
self._handle_get,
|
||||
namespace="pipelines",
|
||||
)
|
||||
server.add_route(
|
||||
"PUT",
|
||||
"/api/v1/pipelines/{pipeline_id}",
|
||||
self._handle_update,
|
||||
namespace="pipelines",
|
||||
)
|
||||
server.add_route(
|
||||
"DELETE",
|
||||
"/api/v1/pipelines/{pipeline_id}",
|
||||
self._handle_delete,
|
||||
namespace="pipelines",
|
||||
)
|
||||
server.add_route(
|
||||
"GET",
|
||||
"/api/v1/pipelines/{pipeline_id}/status",
|
||||
self._handle_get_status,
|
||||
namespace="pipelines",
|
||||
)
|
||||
server.add_route(
|
||||
"POST",
|
||||
"/api/v1/pipelines/{pipeline_id}/retrieve",
|
||||
self._handle_retrieve,
|
||||
namespace="pipelines",
|
||||
)
|
||||
# Document ingestion
|
||||
server.add_route(
|
||||
"POST",
|
||||
"/api/v1/pipelines/{pipeline_id}/documents",
|
||||
self._handle_create_documents,
|
||||
namespace="pipelines",
|
||||
)
|
||||
server.add_route(
|
||||
"PUT",
|
||||
"/api/v1/pipelines/{pipeline_id}/documents",
|
||||
self._handle_upsert_documents,
|
||||
namespace="pipelines",
|
||||
)
|
||||
# File ingestion
|
||||
server.add_route(
|
||||
"PUT",
|
||||
"/api/v1/pipelines/{pipeline_id}/files",
|
||||
self._handle_upsert_files,
|
||||
namespace="pipelines",
|
||||
)
|
||||
|
||||
# Handlers -------------------------------------------------------
|
||||
|
||||
def _handle_create(self, request: httpx.Request) -> httpx.Response:
|
||||
payload = self._server.json(request)
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
return self._server.json_response(
|
||||
{"detail": "name is required"}, status_code=400
|
||||
)
|
||||
|
||||
pipeline_id = self._server.new_id("pipeline")
|
||||
project_id = request.url.params.get(
|
||||
"project_id", self._server.default_project_id
|
||||
)
|
||||
now = utcnow()
|
||||
|
||||
embedding_config = payload.get("embedding_config") or {
|
||||
"type": "MANAGED_OPENAI_EMBEDDING",
|
||||
"component": {},
|
||||
}
|
||||
|
||||
pipeline = Pipeline(
|
||||
id=pipeline_id,
|
||||
name=name,
|
||||
project_id=project_id,
|
||||
embedding_config=embedding_config,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
pipeline_type=payload.get("pipeline_type", "MANAGED"),
|
||||
status="CREATED",
|
||||
)
|
||||
self._pipelines[pipeline_id] = pipeline
|
||||
self._documents[pipeline_id] = {}
|
||||
self._files[pipeline_id] = {}
|
||||
return self._server.json_response(pipeline.model_dump(), status_code=200)
|
||||
|
||||
def _handle_list(self, request: httpx.Request) -> httpx.Response:
|
||||
params = request.url.params
|
||||
project_id = params.get("project_id")
|
||||
pipeline_name = params.get("pipeline_name")
|
||||
|
||||
pipelines = list(self._pipelines.values())
|
||||
if project_id:
|
||||
pipelines = [p for p in pipelines if p.project_id == project_id]
|
||||
if pipeline_name:
|
||||
pipelines = [p for p in pipelines if p.name == pipeline_name]
|
||||
|
||||
return self._server.json_response([p.model_dump() for p in pipelines])
|
||||
|
||||
def _handle_get(self, request: httpx.Request) -> httpx.Response:
|
||||
pipeline_id = request.url.path.split("/")[-1]
|
||||
pipeline = self._pipelines.get(pipeline_id)
|
||||
if not pipeline:
|
||||
return self._server.json_response(
|
||||
{"detail": f"Pipeline {pipeline_id} not found"}, status_code=404
|
||||
)
|
||||
return self._server.json_response(pipeline.model_dump())
|
||||
|
||||
def _handle_update(self, request: httpx.Request) -> httpx.Response:
|
||||
pipeline_id = request.url.path.split("/")[-1]
|
||||
pipeline = self._pipelines.get(pipeline_id)
|
||||
if not pipeline:
|
||||
return self._server.json_response(
|
||||
{"detail": f"Pipeline {pipeline_id} not found"}, status_code=404
|
||||
)
|
||||
|
||||
payload = self._server.json(request)
|
||||
data = pipeline.model_dump()
|
||||
data.update({k: v for k, v in payload.items() if v is not None})
|
||||
data["updated_at"] = utcnow()
|
||||
updated = Pipeline.model_validate(data)
|
||||
self._pipelines[pipeline_id] = updated
|
||||
return self._server.json_response(updated.model_dump())
|
||||
|
||||
def _handle_delete(self, request: httpx.Request) -> httpx.Response:
|
||||
pipeline_id = request.url.path.split("/")[-1]
|
||||
self._pipelines.pop(pipeline_id, None)
|
||||
self._documents.pop(pipeline_id, None)
|
||||
self._files.pop(pipeline_id, None)
|
||||
return self._server.json_response({}, status_code=200)
|
||||
|
||||
def _handle_get_status(self, request: httpx.Request) -> httpx.Response:
|
||||
parts = request.url.path.split("/")
|
||||
pipeline_id = parts[-2]
|
||||
pipeline = self._pipelines.get(pipeline_id)
|
||||
if not pipeline:
|
||||
return self._server.json_response(
|
||||
{"detail": f"Pipeline {pipeline_id} not found"}, status_code=404
|
||||
)
|
||||
status = ManagedIngestionStatusResponse(status="SUCCESS")
|
||||
return self._server.json_response(status.model_dump())
|
||||
|
||||
# Helpers --------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _extract_list(payload: Any, key: str) -> List[Dict[str, Any]]:
|
||||
"""Extract a list from payload that may be a bare array or a dict."""
|
||||
if isinstance(payload, list):
|
||||
return payload
|
||||
if isinstance(payload, dict):
|
||||
return payload.get(key, [])
|
||||
return []
|
||||
|
||||
# Document ingestion ---------------------------------------------
|
||||
|
||||
def _ingest_documents(
|
||||
self, pipeline_id: str, documents: List[Dict[str, Any]]
|
||||
) -> List[CloudDocument]:
|
||||
store = self._documents.setdefault(pipeline_id, {})
|
||||
results: List[CloudDocument] = []
|
||||
for doc_payload in documents:
|
||||
doc_id = doc_payload.get("id") or self._server.new_id("doc")
|
||||
doc = CloudDocument(
|
||||
id=doc_id,
|
||||
text=doc_payload.get("text", ""),
|
||||
metadata=doc_payload.get("metadata", {}),
|
||||
excluded_embed_metadata_keys=doc_payload.get(
|
||||
"excluded_embed_metadata_keys"
|
||||
),
|
||||
excluded_llm_metadata_keys=doc_payload.get(
|
||||
"excluded_llm_metadata_keys"
|
||||
),
|
||||
)
|
||||
store[doc_id] = doc
|
||||
results.append(doc)
|
||||
return results
|
||||
|
||||
def _handle_create_documents(self, request: httpx.Request) -> httpx.Response:
|
||||
parts = request.url.path.split("/")
|
||||
pipeline_id = parts[-2]
|
||||
pipeline = self._pipelines.get(pipeline_id)
|
||||
if not pipeline:
|
||||
return self._server.json_response(
|
||||
{"detail": f"Pipeline {pipeline_id} not found"}, status_code=404
|
||||
)
|
||||
|
||||
payload = self._server.json(request)
|
||||
documents = self._extract_list(payload, "documents")
|
||||
results = self._ingest_documents(pipeline_id, documents)
|
||||
return self._server.json_response(
|
||||
[d.model_dump() for d in results], status_code=200
|
||||
)
|
||||
|
||||
def _handle_upsert_documents(self, request: httpx.Request) -> httpx.Response:
|
||||
parts = request.url.path.split("/")
|
||||
pipeline_id = parts[-2]
|
||||
pipeline = self._pipelines.get(pipeline_id)
|
||||
if not pipeline:
|
||||
return self._server.json_response(
|
||||
{"detail": f"Pipeline {pipeline_id} not found"}, status_code=404
|
||||
)
|
||||
|
||||
payload = self._server.json(request)
|
||||
documents = self._extract_list(payload, "documents")
|
||||
results = self._ingest_documents(pipeline_id, documents)
|
||||
return self._server.json_response(
|
||||
[d.model_dump() for d in results], status_code=200
|
||||
)
|
||||
|
||||
# File ingestion -------------------------------------------------
|
||||
|
||||
def _handle_upsert_files(self, request: httpx.Request) -> httpx.Response:
|
||||
parts = request.url.path.split("/")
|
||||
pipeline_id = parts[-2]
|
||||
pipeline = self._pipelines.get(pipeline_id)
|
||||
if not pipeline:
|
||||
return self._server.json_response(
|
||||
{"detail": f"Pipeline {pipeline_id} not found"}, status_code=404
|
||||
)
|
||||
|
||||
payload = self._server.json(request)
|
||||
file_ids = self._extract_list(payload, "files")
|
||||
now = utcnow()
|
||||
store = self._files.setdefault(pipeline_id, {})
|
||||
results: List[PipelineFile] = []
|
||||
for entry in file_ids:
|
||||
# Each entry can be a dict with file_id + optional metadata, or a plain string
|
||||
if isinstance(entry, dict):
|
||||
file_id = entry.get("file_id", self._server.new_id("file"))
|
||||
custom_metadata = entry.get("custom_metadata")
|
||||
name = entry.get("name")
|
||||
else:
|
||||
file_id = str(entry)
|
||||
custom_metadata = None
|
||||
name = None
|
||||
|
||||
pf_id = self._server.new_id("pf")
|
||||
pf = PipelineFile(
|
||||
id=pf_id,
|
||||
pipeline_id=pipeline_id,
|
||||
file_id=file_id,
|
||||
name=name or f"file-{file_id}",
|
||||
status="SUCCESS",
|
||||
project_id=pipeline.project_id,
|
||||
custom_metadata=custom_metadata,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
store[pf_id] = pf
|
||||
results.append(pf)
|
||||
return self._server.json_response(
|
||||
[pf.model_dump() for pf in results], status_code=200
|
||||
)
|
||||
|
||||
# Retrieval ------------------------------------------------------
|
||||
|
||||
def _handle_retrieve(self, request: httpx.Request) -> httpx.Response:
|
||||
parts = request.url.path.split("/")
|
||||
pipeline_id = parts[-2]
|
||||
pipeline = self._pipelines.get(pipeline_id)
|
||||
if not pipeline:
|
||||
return self._server.json_response(
|
||||
{"detail": f"Pipeline {pipeline_id} not found"}, status_code=404
|
||||
)
|
||||
|
||||
payload = self._server.json(request)
|
||||
query = payload.get("query", "")
|
||||
top_k = payload.get("dense_similarity_top_k") or 3
|
||||
|
||||
nodes = self._build_retrieval_nodes(pipeline_id, query, top_k)
|
||||
response = PipelineRetrieveResponse(
|
||||
pipeline_id=pipeline_id,
|
||||
retrieval_nodes=nodes,
|
||||
)
|
||||
return self._server.json_response(response.model_dump())
|
||||
|
||||
def _build_retrieval_nodes(
|
||||
self, pipeline_id: str, query: str, top_k: int
|
||||
) -> List[RetrievalNode]:
|
||||
"""Build retrieval nodes from ingested documents and files."""
|
||||
chunks: List[_Chunk] = []
|
||||
|
||||
# Collect chunks from ingested documents
|
||||
for doc_id, doc in self._documents.get(pipeline_id, {}).items():
|
||||
text = doc.text or ""
|
||||
# Split the document text into paragraph-sized chunks
|
||||
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
|
||||
if not paragraphs:
|
||||
paragraphs = [text] if text else []
|
||||
for i, para in enumerate(paragraphs):
|
||||
chunks.append(
|
||||
_Chunk(
|
||||
text=para,
|
||||
source_id=doc_id,
|
||||
chunk_index=i,
|
||||
metadata=dict(doc.metadata) if doc.metadata else {},
|
||||
)
|
||||
)
|
||||
|
||||
# Collect chunks from ingested files (generate deterministic text)
|
||||
for pf_id, pf in self._files.get(pipeline_id, {}).items():
|
||||
seed = combined_seed(pipeline_id, pf.file_id or pf_id)
|
||||
file_text = generate_text_blob(seed, sentences=6)
|
||||
# Split generated text into sentence-pair chunks
|
||||
sentences = file_text.split(". ")
|
||||
for i in range(0, len(sentences), 2):
|
||||
chunk_text = ". ".join(sentences[i : i + 2])
|
||||
if not chunk_text.endswith("."):
|
||||
chunk_text += "."
|
||||
metadata: Dict[str, Any] = {"file_name": pf.name or ""}
|
||||
if pf.file_id:
|
||||
metadata["file_id"] = pf.file_id
|
||||
chunks.append(
|
||||
_Chunk(
|
||||
text=chunk_text,
|
||||
source_id=pf.file_id or pf_id,
|
||||
chunk_index=i // 2,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
|
||||
if not chunks:
|
||||
return []
|
||||
|
||||
# Score chunks deterministically based on query+chunk content
|
||||
scored: List[tuple[float, _Chunk]] = []
|
||||
for chunk in chunks:
|
||||
seed = combined_seed(query, chunk.text)
|
||||
# Generate a score between 0.5 and 1.0
|
||||
score = 0.5 + (seed % 5000) / 10000.0
|
||||
scored.append((score, chunk))
|
||||
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
selected = scored[:top_k]
|
||||
|
||||
nodes: List[RetrievalNode] = []
|
||||
for score, chunk in selected:
|
||||
node_id = self._server.new_id("node")
|
||||
text_node = TextNode(
|
||||
id=node_id,
|
||||
text=chunk.text,
|
||||
extra_info=chunk.metadata or None,
|
||||
start_char_idx=0,
|
||||
end_char_idx=len(chunk.text),
|
||||
)
|
||||
nodes.append(
|
||||
RetrievalNode(
|
||||
node=text_node,
|
||||
score=round(score, 4),
|
||||
)
|
||||
)
|
||||
return nodes
|
||||
|
||||
|
||||
class _Chunk:
|
||||
"""Internal helper to represent a text chunk for retrieval."""
|
||||
|
||||
__slots__ = ("text", "source_id", "chunk_index", "metadata")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
text: str,
|
||||
source_id: str,
|
||||
chunk_index: int,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.text = text
|
||||
self.source_id = source_id
|
||||
self.chunk_index = chunk_index
|
||||
self.metadata = metadata or {}
|
||||
@@ -1,210 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, Optional, Sequence
|
||||
|
||||
import httpx
|
||||
import respx
|
||||
|
||||
from .agent_data import FakeAgentDataNamespace
|
||||
from .classify import FakeClassifyNamespace
|
||||
from .configurations import FakeConfigurationsNamespace
|
||||
from .extract import FakeExtractNamespace
|
||||
from .files import FakeFilesNamespace
|
||||
from .parse import FakeParseNamespace
|
||||
from .pipelines import FakePipelinesNamespace
|
||||
from .sheets import FakeSheetsNamespace
|
||||
from .split import FakeSplitNamespace
|
||||
|
||||
Handler = Callable[[httpx.Request], httpx.Response]
|
||||
|
||||
|
||||
class FakeLlamaCloudServer:
|
||||
DEFAULT_BASE_URL = "https://api.cloud.llamaindex.ai"
|
||||
DEFAULT_UPLOAD_BASE = "https://uploads.fake-llama.test"
|
||||
DEFAULT_DOWNLOAD_BASE = "https://downloads.fake-llama.test"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
base_urls: Optional[Sequence[str]] = None,
|
||||
namespaces: Optional[Sequence[str]] = None,
|
||||
upload_base_url: Optional[str] = None,
|
||||
download_base_url: Optional[str] = None,
|
||||
default_project_id: str = "proj-test",
|
||||
default_organization_id: str = "org-test",
|
||||
default_user_id: str = "user-test",
|
||||
) -> None:
|
||||
self.base_urls = tuple(base_urls or (self.DEFAULT_BASE_URL,))
|
||||
selected = namespaces or (
|
||||
"files",
|
||||
"configurations",
|
||||
"extract",
|
||||
"parse",
|
||||
"classify",
|
||||
"agent_data",
|
||||
"split",
|
||||
"sheets",
|
||||
"pipelines",
|
||||
)
|
||||
self._namespace_names = {name.lower() for name in selected}
|
||||
self._upload_base_url = upload_base_url or self.DEFAULT_UPLOAD_BASE
|
||||
self._download_base_url = download_base_url or self.DEFAULT_DOWNLOAD_BASE
|
||||
self.default_project_id = default_project_id
|
||||
self.default_organization_id = default_organization_id
|
||||
self.default_user_id = default_user_id
|
||||
self.router = respx.MockRouter(assert_all_called=False)
|
||||
self._installed = False
|
||||
self._registered = False
|
||||
|
||||
self.files = FakeFilesNamespace(
|
||||
server=self,
|
||||
upload_base_url=self._upload_base_url,
|
||||
download_base_url=self._download_base_url,
|
||||
)
|
||||
self.configurations = FakeConfigurationsNamespace(server=self)
|
||||
self.extract = FakeExtractNamespace(
|
||||
server=self,
|
||||
files=self.files,
|
||||
configurations=self.configurations,
|
||||
)
|
||||
self.parse = FakeParseNamespace(server=self)
|
||||
self.classify = FakeClassifyNamespace(server=self, files=self.files)
|
||||
self.agent_data = FakeAgentDataNamespace(server=self)
|
||||
self.split = FakeSplitNamespace(server=self)
|
||||
self.pipelines = FakePipelinesNamespace(server=self)
|
||||
self.sheets = FakeSheetsNamespace(
|
||||
server=self,
|
||||
files=self.files,
|
||||
download_base_url=self._download_base_url,
|
||||
)
|
||||
|
||||
# Context management ----------------------------------------------
|
||||
def install(self) -> "FakeLlamaCloudServer":
|
||||
self.router.route(url__regex=r"^http://localhost:.*").pass_through()
|
||||
if not self._registered:
|
||||
self._register_namespaces()
|
||||
if not self._installed:
|
||||
self.router.__enter__()
|
||||
self._installed = True
|
||||
return self
|
||||
|
||||
def uninstall(self) -> None:
|
||||
if self._installed:
|
||||
self.router.__exit__(None, None, None)
|
||||
self._installed = False
|
||||
|
||||
def __enter__(self) -> "FakeLlamaCloudServer":
|
||||
return self.install()
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
|
||||
self.uninstall()
|
||||
|
||||
# Route utilities -------------------------------------------------
|
||||
def add_route(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
handler: Handler,
|
||||
*,
|
||||
namespace: str,
|
||||
alias: Optional[str] = None,
|
||||
base_urls: Optional[Sequence[str]] = None,
|
||||
) -> respx.Route:
|
||||
urls = base_urls or self.base_urls
|
||||
first_route: Optional[respx.Route] = None
|
||||
for base in urls:
|
||||
route = self._register_route(method, base, path, handler)
|
||||
if first_route is None:
|
||||
first_route = route
|
||||
if alias and first_route:
|
||||
setattr(self, alias, first_route)
|
||||
return first_route # type: ignore[return-value]
|
||||
|
||||
def _register_route(
|
||||
self,
|
||||
method: str,
|
||||
base: str,
|
||||
path: str,
|
||||
handler: Handler,
|
||||
) -> respx.Route:
|
||||
url = self._build_url(base, path)
|
||||
if "{" in path:
|
||||
regex = self._compile_regex(base, path)
|
||||
route = self.router.route(method=method, url__regex=regex)
|
||||
else:
|
||||
route = self.router.route(method=method, url=url)
|
||||
route.mock(side_effect=lambda request, func=handler: func(request))
|
||||
return route
|
||||
|
||||
def _build_url(self, base: str, path: str) -> str:
|
||||
base = base.rstrip("/")
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
return f"{base}{path}"
|
||||
|
||||
def _compile_regex(self, base: str, path: str) -> re.Pattern[str]:
|
||||
escaped = re.escape(base.rstrip("/"))
|
||||
regex_path = re.sub(r"\{[^/]+\}", r"[^/]+", path)
|
||||
pattern = f"^{escaped}{regex_path}(\\?.*)?$"
|
||||
return re.compile(pattern)
|
||||
|
||||
# Helpers ---------------------------------------------------------
|
||||
def json(self, request: httpx.Request) -> Dict[str, Any]:
|
||||
if not request.content:
|
||||
return {}
|
||||
return json.loads(request.content.decode("utf-8"))
|
||||
|
||||
def encode_json(self, payload: Dict[str, Any]) -> bytes:
|
||||
return json.dumps(payload).encode("utf-8")
|
||||
|
||||
def json_response(self, payload: Any, *, status_code: int = 200) -> httpx.Response:
|
||||
body = json.dumps(payload, default=self._json_default).encode("utf-8")
|
||||
headers = {"content-type": "application/json"}
|
||||
return httpx.Response(status_code=status_code, headers=headers, content=body)
|
||||
|
||||
def new_id(self, prefix: str) -> str:
|
||||
return f"{prefix}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Internal --------------------------------------------------------
|
||||
def _json_default(self, value: Any) -> Any:
|
||||
if hasattr(value, "model_dump"):
|
||||
return value.model_dump()
|
||||
if hasattr(value, "dict"):
|
||||
return value.dict()
|
||||
if isinstance(value, (set, frozenset)):
|
||||
return list(value)
|
||||
if isinstance(value, (bytes, bytearray)):
|
||||
return value.decode("utf-8")
|
||||
if hasattr(value, "isoformat"):
|
||||
try:
|
||||
return value.isoformat() # datetime/date support
|
||||
except Exception:
|
||||
pass
|
||||
raise TypeError(f"{value!r} is not JSON serializable")
|
||||
|
||||
def _register_namespaces(self) -> None:
|
||||
if "files" in self._namespace_names:
|
||||
self.files.register()
|
||||
if "configurations" in self._namespace_names:
|
||||
self.configurations.register()
|
||||
if "extract" in self._namespace_names:
|
||||
self.extract.register()
|
||||
if "parse" in self._namespace_names:
|
||||
self.parse.register()
|
||||
if "classify" in self._namespace_names:
|
||||
self.classify.register()
|
||||
if "agent_data" in self._namespace_names:
|
||||
self.agent_data.register()
|
||||
if "split" in self._namespace_names:
|
||||
self.split.register()
|
||||
if "pipelines" in self._namespace_names:
|
||||
self.pipelines.register()
|
||||
if "sheets" in self._namespace_names:
|
||||
self.sheets.register()
|
||||
self._registered = True
|
||||
|
||||
|
||||
__all__ = ["FakeLlamaCloudServer"]
|
||||
@@ -1,274 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
import struct
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
from llama_cloud.types.beta.sheets_job import Region, SheetsJob, WorksheetMetadata
|
||||
from llama_cloud.types.beta.sheets_parsing_config import SheetsParsingConfig
|
||||
from llama_cloud.types.presigned_url import PresignedURL
|
||||
|
||||
from ._deterministic import combined_seed, generate_text_blob, utcnow
|
||||
from .files import FakeFilesNamespace, StoredFile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .server import FakeLlamaCloudServer
|
||||
|
||||
|
||||
@dataclass
|
||||
class SheetsJobRecord:
|
||||
job: SheetsJob
|
||||
|
||||
|
||||
class FakeSheetsNamespace:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
server: "FakeLlamaCloudServer",
|
||||
files: FakeFilesNamespace,
|
||||
download_base_url: str,
|
||||
) -> None:
|
||||
self._server = server
|
||||
self._files = files
|
||||
self._download_base_url = download_base_url.rstrip("/")
|
||||
self._jobs: Dict[str, SheetsJobRecord] = {}
|
||||
self._region_content: Dict[str, bytes] = {}
|
||||
|
||||
def register(self) -> None:
|
||||
server = self._server
|
||||
server.add_route(
|
||||
"POST",
|
||||
"/api/v1/beta/sheets/jobs",
|
||||
self._handle_create,
|
||||
namespace="sheets",
|
||||
)
|
||||
server.add_route(
|
||||
"GET",
|
||||
"/api/v1/beta/sheets/jobs",
|
||||
self._handle_list,
|
||||
namespace="sheets",
|
||||
)
|
||||
server.add_route(
|
||||
"GET",
|
||||
"/api/v1/beta/sheets/jobs/{spreadsheet_job_id}",
|
||||
self._handle_get,
|
||||
namespace="sheets",
|
||||
)
|
||||
server.add_route(
|
||||
"DELETE",
|
||||
"/api/v1/beta/sheets/jobs/{spreadsheet_job_id}",
|
||||
self._handle_delete,
|
||||
namespace="sheets",
|
||||
)
|
||||
server.add_route(
|
||||
"GET",
|
||||
"/api/v1/beta/sheets/jobs/{spreadsheet_job_id}/regions/{region_id}/result/{region_type}",
|
||||
self._handle_get_result_table,
|
||||
namespace="sheets",
|
||||
)
|
||||
server.add_route(
|
||||
"GET",
|
||||
"/sheets/{job_id}/{region_id}/{region_type}",
|
||||
self._handle_presigned_download,
|
||||
namespace="sheets",
|
||||
base_urls=[self._download_base_url],
|
||||
)
|
||||
|
||||
def _handle_create(self, request: httpx.Request) -> httpx.Response:
|
||||
payload = self._server.json(request)
|
||||
file_id = payload.get("file_id", "")
|
||||
config_raw = payload.get("config") or {}
|
||||
|
||||
stored_file = self._files.get(file_id) if file_id else None
|
||||
if file_id and not stored_file:
|
||||
return self._server.json_response(
|
||||
{"detail": f"File {file_id} not found"}, status_code=404
|
||||
)
|
||||
|
||||
job_id = self._server.new_id("sheets-job")
|
||||
now = utcnow()
|
||||
|
||||
config = SheetsParsingConfig.model_validate(config_raw)
|
||||
regions, worksheet_metadata = self._build_results(job_id, config, stored_file)
|
||||
|
||||
job = SheetsJob(
|
||||
id=job_id,
|
||||
config=config,
|
||||
created_at=now.isoformat(),
|
||||
file_id=file_id,
|
||||
project_id=request.url.params.get(
|
||||
"project_id", self._server.default_project_id
|
||||
),
|
||||
status="SUCCESS",
|
||||
updated_at=now.isoformat(),
|
||||
user_id="fake-user",
|
||||
errors=None,
|
||||
file=None,
|
||||
regions=regions,
|
||||
success=True,
|
||||
worksheet_metadata=worksheet_metadata,
|
||||
)
|
||||
self._jobs[job_id] = SheetsJobRecord(job=job)
|
||||
return self._server.json_response(job.model_dump())
|
||||
|
||||
def _handle_list(self, request: httpx.Request) -> httpx.Response:
|
||||
items = [r.job.model_dump() for r in self._jobs.values()]
|
||||
return self._server.json_response({"items": items, "next_page_token": None})
|
||||
|
||||
def _handle_get(self, request: httpx.Request) -> httpx.Response:
|
||||
job_id = request.url.path.split("/")[-1]
|
||||
record = self._jobs.get(job_id)
|
||||
if not record:
|
||||
return self._server.json_response(
|
||||
{"detail": "Sheets job not found"}, status_code=404
|
||||
)
|
||||
return self._server.json_response(record.job.model_dump())
|
||||
|
||||
def _handle_delete(self, request: httpx.Request) -> httpx.Response:
|
||||
job_id = request.url.path.split("/")[-1]
|
||||
self._jobs.pop(job_id, None)
|
||||
return self._server.json_response({}, status_code=200)
|
||||
|
||||
def _handle_get_result_table(self, request: httpx.Request) -> httpx.Response:
|
||||
parts = request.url.path.split("/")
|
||||
# .../jobs/{job_id}/regions/{region_id}/result/{region_type}
|
||||
job_id = parts[-5]
|
||||
region_id = parts[-3]
|
||||
|
||||
record = self._jobs.get(job_id)
|
||||
if not record:
|
||||
return self._server.json_response(
|
||||
{"detail": "Sheets job not found"}, status_code=404
|
||||
)
|
||||
|
||||
if record.job.regions:
|
||||
found = any(r.region_id == region_id for r in record.job.regions)
|
||||
if not found:
|
||||
return self._server.json_response(
|
||||
{"detail": f"Region {region_id} not found"}, status_code=404
|
||||
)
|
||||
|
||||
region_type = parts[-1]
|
||||
presigned = PresignedURL(
|
||||
url=(
|
||||
f"{self._download_base_url}/sheets/{job_id}/{region_id}/{region_type}"
|
||||
f"?{urlencode({'token': 'fake'})}"
|
||||
),
|
||||
expires_at=utcnow(),
|
||||
form_fields=None,
|
||||
)
|
||||
return self._server.json_response(presigned.model_dump())
|
||||
|
||||
def _handle_presigned_download(self, request: httpx.Request) -> httpx.Response:
|
||||
parts = request.url.path.split("/")
|
||||
# /sheets/{job_id}/{region_id}/{region_type}
|
||||
region_id = parts[-2]
|
||||
content_key = region_id
|
||||
content = self._region_content.get(content_key)
|
||||
if content is None:
|
||||
return httpx.Response(404, json={"detail": "Region content not found"})
|
||||
return httpx.Response(
|
||||
200,
|
||||
content=content,
|
||||
headers={"content-type": "application/octet-stream"},
|
||||
)
|
||||
|
||||
def _build_results(
|
||||
self,
|
||||
job_id: str,
|
||||
config: SheetsParsingConfig,
|
||||
stored_file: Optional[StoredFile],
|
||||
) -> tuple[List[Region], List[WorksheetMetadata]]:
|
||||
file_hash = stored_file.sha256 if stored_file else "no-file"
|
||||
seed = combined_seed(file_hash, job_id)
|
||||
rng = random.Random(seed)
|
||||
|
||||
sheet_names_config = config.sheet_names if config else None
|
||||
if sheet_names_config:
|
||||
sheet_names = list(sheet_names_config)
|
||||
else:
|
||||
num_sheets = rng.randint(1, 3)
|
||||
sheet_names = [f"Sheet{i + 1}" for i in range(num_sheets)]
|
||||
|
||||
worksheet_metadata: List[WorksheetMetadata] = []
|
||||
for name in sheet_names:
|
||||
worksheet_metadata.append(
|
||||
WorksheetMetadata(
|
||||
sheet_name=name,
|
||||
title=f"Title for {name}",
|
||||
description=f"Description for {name}",
|
||||
)
|
||||
)
|
||||
|
||||
regions: List[Region] = []
|
||||
region_types = ["table", "extra"]
|
||||
for sheet_name in sheet_names:
|
||||
num_regions = rng.randint(1, 3)
|
||||
for j in range(num_regions):
|
||||
region_id = self._server.new_id("region")
|
||||
rtype = region_types[rng.randint(0, len(region_types) - 1)]
|
||||
row_start = rng.randint(1, 10)
|
||||
col_start = chr(ord("A") + rng.randint(0, 5))
|
||||
row_end = row_start + rng.randint(3, 20)
|
||||
col_end = chr(ord(col_start) + rng.randint(1, 5))
|
||||
location = f"{col_start}{row_start}:{col_end}{row_end}"
|
||||
regions.append(
|
||||
Region(
|
||||
region_id=region_id,
|
||||
region_type=rtype,
|
||||
sheet_name=sheet_name,
|
||||
location=location,
|
||||
title=f"Region {j + 1} in {sheet_name}",
|
||||
description=f"Deterministic region from {sheet_name}",
|
||||
)
|
||||
)
|
||||
content_seed = combined_seed(file_hash, job_id, region_id)
|
||||
self._region_content[region_id] = _build_fake_parquet(
|
||||
content_seed, sheet_name, location
|
||||
)
|
||||
|
||||
return regions, worksheet_metadata
|
||||
|
||||
|
||||
def _build_fake_parquet(seed: int, sheet_name: str, location: str) -> bytes:
|
||||
"""Build a minimal Apache Parquet file with deterministic tabular data.
|
||||
|
||||
The file uses the PAR1 magic bytes and contains a simplified but
|
||||
structurally valid Parquet layout so that downstream consumers can
|
||||
at minimum verify the magic header.
|
||||
"""
|
||||
rng = random.Random(seed)
|
||||
num_cols = rng.randint(2, 5)
|
||||
num_rows = rng.randint(3, 10)
|
||||
|
||||
headers = [f"col_{i}" for i in range(num_cols)]
|
||||
rows: List[List[str]] = []
|
||||
for _ in range(num_rows):
|
||||
row = [
|
||||
generate_text_blob(rng.randint(0, 1_000_000), sentences=1)[:30]
|
||||
for _ in headers
|
||||
]
|
||||
rows.append(row)
|
||||
|
||||
# Encode as minimal Parquet-like binary with correct magic.
|
||||
# Real Parquet parsers need Thrift metadata; for the mock server we
|
||||
# embed the data as JSON in the page payload between the PAR1 markers
|
||||
# so tests can verify content determinism if needed.
|
||||
import json as _json
|
||||
|
||||
payload = _json.dumps(
|
||||
{
|
||||
"sheet_name": sheet_name,
|
||||
"location": location,
|
||||
"headers": headers,
|
||||
"rows": rows,
|
||||
}
|
||||
).encode("utf-8")
|
||||
|
||||
magic = b"PAR1"
|
||||
# Parquet footer: 4-byte LE footer length + magic
|
||||
footer_len = struct.pack("<I", len(payload))
|
||||
return magic + payload + footer_len + magic
|
||||
@@ -1,141 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
from llama_cloud.types.beta.split_category import SplitCategory
|
||||
from llama_cloud.types.beta.split_category_param import SplitCategoryParam
|
||||
from llama_cloud.types.beta.split_create_response import SplitCreateResponse
|
||||
from llama_cloud.types.beta.split_document_input import SplitDocumentInput
|
||||
from llama_cloud.types.beta.split_get_response import SplitGetResponse
|
||||
from llama_cloud.types.beta.split_result_response import SplitResultResponse
|
||||
from llama_cloud.types.beta.split_segment_response import SplitSegmentResponse
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from ._deterministic import categorize_pages, utcnow
|
||||
from .files import StoredFile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .server import FakeLlamaCloudServer
|
||||
|
||||
|
||||
@dataclass
|
||||
class SplitRequest:
|
||||
categories: list[SplitCategoryParam]
|
||||
file_id: str
|
||||
stored_file: StoredFile
|
||||
|
||||
|
||||
class FakeSplitNamespace:
|
||||
def __init__(self, *, server: "FakeLlamaCloudServer") -> None:
|
||||
self._server = server
|
||||
self._jobs: dict[str, SplitGetResponse] = {}
|
||||
self.routes: dict[str, Any] = {}
|
||||
self._allowed_input_types = ("file_id",)
|
||||
self._page_size = 50
|
||||
|
||||
def _validate_split_request(
|
||||
self, request: httpx.Request
|
||||
) -> httpx.Response | SplitRequest:
|
||||
payload = self._server.json(request)
|
||||
document_input = payload.get("document_input")
|
||||
if not document_input:
|
||||
response = {"detail": "the document_input field should be non-null"}
|
||||
return self._server.json_response(response, status_code=400)
|
||||
input_type = document_input.get("type", "file_id")
|
||||
if input_type not in self._allowed_input_types:
|
||||
response = {
|
||||
"detail": f"document_input.type {input_type} is invalid. Allowed input types: {', '.join(self._allowed_input_types)}"
|
||||
}
|
||||
return self._server.json_response(response, status_code=400)
|
||||
input_value = document_input.get("value")
|
||||
if input_value is None:
|
||||
response = {"detail": "Missing document_input.value field"}
|
||||
return self._server.json_response(response, status_code=400)
|
||||
configuration = payload.get("configuration") or {}
|
||||
categories = configuration.get("categories", [])
|
||||
if not categories:
|
||||
response = {"detail": "categories field should be non-null and non-empty"}
|
||||
return self._server.json_response(response, status_code=400)
|
||||
stored_file = self._server.files.get(input_value)
|
||||
if stored_file is None:
|
||||
response = {"detail": f"file with ID {input_value} not found"}
|
||||
return self._server.json_response(response, status_code=404)
|
||||
return SplitRequest(
|
||||
categories=categories, file_id=input_value, stored_file=stored_file
|
||||
)
|
||||
|
||||
def _create_split_job(self, request: httpx.Request) -> httpx.Response:
|
||||
validated = self._validate_split_request(request)
|
||||
if isinstance(validated, httpx.Response):
|
||||
return validated
|
||||
categorized = categorize_pages(
|
||||
validated.stored_file.content,
|
||||
[category["name"] for category in validated.categories],
|
||||
0,
|
||||
)
|
||||
result = SplitResultResponse(segments=[])
|
||||
for c in categorized:
|
||||
result.segments.append(
|
||||
SplitSegmentResponse(
|
||||
category=c, confidence_category="high", pages=categorized[c]
|
||||
)
|
||||
)
|
||||
job_id = self._server.new_id("split-")
|
||||
job = SplitGetResponse(
|
||||
id=job_id,
|
||||
categories=[
|
||||
SplitCategory(name=c["name"], description=c.get("description"))
|
||||
for c in validated.categories
|
||||
],
|
||||
document_input=SplitDocumentInput(type="file_id", value=validated.file_id),
|
||||
project_id=self._server.default_project_id,
|
||||
user_id=self._server.default_user_id,
|
||||
status="completed",
|
||||
result=result,
|
||||
created_at=utcnow(),
|
||||
updated_at=utcnow(),
|
||||
error_message=None,
|
||||
)
|
||||
self._jobs[job_id] = job
|
||||
response = SplitCreateResponse(
|
||||
id=job_id,
|
||||
categories=[
|
||||
SplitCategory(name=c["name"], description=c.get("description"))
|
||||
for c in validated.categories
|
||||
],
|
||||
document_input=SplitDocumentInput(type="file_id", value=validated.file_id),
|
||||
project_id=self._server.default_project_id,
|
||||
user_id=self._server.default_user_id,
|
||||
status="pending",
|
||||
error_message=None,
|
||||
)
|
||||
return self._server.json_response(response.model_dump(), status_code=200)
|
||||
|
||||
def _get_split_job_result(self, request: httpx.Request) -> httpx.Response:
|
||||
job_id = request.url.path.split("/")[-1]
|
||||
job = self._jobs.get(job_id)
|
||||
if job is not None:
|
||||
return self._server.json_response(job.model_dump())
|
||||
return self._server.json_response(
|
||||
{"detail": f"job with ID {job_id} does not exist"}, status_code=404
|
||||
)
|
||||
|
||||
def register(self) -> None:
|
||||
server = self._server
|
||||
create_route = server.add_route(
|
||||
"POST",
|
||||
"/api/v1/beta/split/jobs",
|
||||
self._create_split_job,
|
||||
namespace="split",
|
||||
alias="create",
|
||||
)
|
||||
self.routes["create"] = create_route
|
||||
get_route = server.add_route(
|
||||
"GET",
|
||||
"/api/v1/beta/split/jobs/{split_job_id}",
|
||||
self._get_split_job_result,
|
||||
namespace="split",
|
||||
alias="get",
|
||||
)
|
||||
self.routes["get"] = get_route
|
||||
@@ -0,0 +1,20 @@
|
||||
"""Regression tests for ``extraction_review.clients``."""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def test_clients_import_does_not_load_fake_when_disabled():
|
||||
"""Without FAKE_LLAMA_CLOUD set, importing clients must not pull in
|
||||
llama_cloud_fake. The fake server is a dev-only dependency and may not be
|
||||
installed in production environments.
|
||||
"""
|
||||
src = (
|
||||
"import sys\n"
|
||||
"import extraction_review.clients\n"
|
||||
"assert 'llama_cloud_fake' not in sys.modules, "
|
||||
"'llama_cloud_fake was imported without FAKE_LLAMA_CLOUD'\n"
|
||||
)
|
||||
env = {k: v for k, v in os.environ.items() if k != "FAKE_LLAMA_CLOUD"}
|
||||
subprocess.run([sys.executable, "-c", src], check=True, env=env)
|
||||
@@ -1,30 +0,0 @@
|
||||
"""
|
||||
Pytest configuration for testing_utils tests.
|
||||
|
||||
These tests create their own FakeLlamaCloudServer instances and need
|
||||
the global server from extraction_review.clients to be uninstalled
|
||||
to avoid route conflicts.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolate_from_global_fake():
|
||||
"""Uninstall the global fake server for testing_utils tests.
|
||||
|
||||
The global server from extraction_review.clients intercepts HTTP
|
||||
requests before test-specific servers can handle them. This fixture
|
||||
temporarily uninstalls it so tests can use their own isolated instances.
|
||||
"""
|
||||
from extraction_review.clients import fake
|
||||
|
||||
was_installed = fake is not None and fake._installed
|
||||
|
||||
if was_installed:
|
||||
fake.uninstall()
|
||||
|
||||
yield
|
||||
|
||||
if was_installed:
|
||||
fake.install()
|
||||
@@ -1,313 +0,0 @@
|
||||
"""Tests for the FakeAgentDataNamespace mock implementation."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from extraction_review.testing_utils import FakeLlamaCloudServer
|
||||
from extraction_review.testing_utils._deterministic import hash_schema
|
||||
from llama_cloud import AsyncLlamaCloud
|
||||
from llama_cloud._exceptions import APIStatusError
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Receipt(BaseModel):
|
||||
merchant: str = Field(description="Vendor name")
|
||||
total: float = Field(description="Grand total")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server():
|
||||
"""Provide an installed FakeLlamaCloudServer."""
|
||||
with FakeLlamaCloudServer() as srv:
|
||||
yield srv
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(server) -> AsyncLlamaCloud:
|
||||
"""Provide an AsyncLlamaCloud client configured for the fake server."""
|
||||
return AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_item(server, client: AsyncLlamaCloud):
|
||||
"""Verify items can be created and have expected ID format."""
|
||||
data = Receipt(merchant="Test Inc", total=1000)
|
||||
item = await client.beta.agent_data.agent_data(
|
||||
data=data.model_dump(),
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
)
|
||||
|
||||
assert item.id == hash_schema(data)[:7]
|
||||
assert item.data["merchant"] == data.merchant
|
||||
assert item.data["total"] == data.total
|
||||
assert item.collection == "extracted_data"
|
||||
assert item.deployment_name == "extraction_agent"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_item(server, client: AsyncLlamaCloud):
|
||||
"""Verify items can be updated while preserving metadata."""
|
||||
data = Receipt(merchant="Test Inc", total=1000)
|
||||
item = await client.beta.agent_data.agent_data(
|
||||
data=data.model_dump(),
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
)
|
||||
assert item.id is not None
|
||||
|
||||
updated_data = Receipt(merchant="Testing Inc", total=1100)
|
||||
updated_item = await client.beta.agent_data.update(
|
||||
item_id=item.id, data=updated_data.model_dump()
|
||||
)
|
||||
|
||||
assert updated_item.data["merchant"] == updated_data.merchant
|
||||
assert updated_item.data["total"] == updated_data.total
|
||||
assert updated_item.id == item.id
|
||||
assert updated_item.collection == item.collection
|
||||
assert updated_item.deployment_name == item.deployment_name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_eq_filter(server, client: AsyncLlamaCloud):
|
||||
"""Verify search with equality filter returns matching items."""
|
||||
data1 = Receipt(merchant="Test Inc", total=1000)
|
||||
data2 = Receipt(merchant="Test Inc", total=1300)
|
||||
data3 = Receipt(merchant="Testing Inc", total=1100)
|
||||
item1 = await client.beta.agent_data.agent_data(
|
||||
data=data1.model_dump(),
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
)
|
||||
item2 = await client.beta.agent_data.agent_data(
|
||||
data=data2.model_dump(),
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
)
|
||||
await client.beta.agent_data.agent_data(
|
||||
data=data3.model_dump(),
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
)
|
||||
|
||||
result = await client.beta.agent_data.search(
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
filter={"merchant": {"eq": "Test Inc"}},
|
||||
)
|
||||
|
||||
assert result.total_size == 2
|
||||
assert any(item.id == item1.id for item in result.items)
|
||||
assert any(item.id == item2.id for item in result.items)
|
||||
assert all(item.data["merchant"] == "Test Inc" for item in result.items)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_lt_filter(server, client: AsyncLlamaCloud):
|
||||
"""Verify search with less-than filter returns matching items."""
|
||||
data1 = Receipt(merchant="Test Inc", total=1000)
|
||||
data2 = Receipt(merchant="Test Inc", total=1300)
|
||||
data3 = Receipt(merchant="Testing Inc", total=1100)
|
||||
item1 = await client.beta.agent_data.agent_data(
|
||||
data=data1.model_dump(),
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
)
|
||||
await client.beta.agent_data.agent_data(
|
||||
data=data2.model_dump(),
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
)
|
||||
item3 = await client.beta.agent_data.agent_data(
|
||||
data=data3.model_dump(),
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
)
|
||||
|
||||
result = await client.beta.agent_data.search(
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
filter={"total": {"lt": 1200}},
|
||||
)
|
||||
|
||||
assert result.total_size == 2
|
||||
assert any(item.id == item1.id for item in result.items)
|
||||
assert any(item.id == item3.id for item in result.items)
|
||||
assert all(cast(int, item.data["total"]) < 1200 for item in result.items)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregate_with_filter(server, client: AsyncLlamaCloud):
|
||||
"""Verify aggregation with filter groups correctly."""
|
||||
data1 = Receipt(merchant="Test Inc", total=1000)
|
||||
data2 = Receipt(merchant="Test Inc", total=1300)
|
||||
data3 = Receipt(merchant="Testing Inc", total=1100)
|
||||
await client.beta.agent_data.agent_data(
|
||||
data=data1.model_dump(),
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
)
|
||||
await client.beta.agent_data.agent_data(
|
||||
data=data2.model_dump(),
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
)
|
||||
await client.beta.agent_data.agent_data(
|
||||
data=data3.model_dump(),
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
)
|
||||
|
||||
result = await client.beta.agent_data.aggregate(
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
filter={"merchant": {"eq": "Test Inc"}},
|
||||
group_by=["merchant"],
|
||||
count=True,
|
||||
)
|
||||
|
||||
# Filtering for 'Test Inc' means only one group
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].count == 2
|
||||
assert result.items[0].first_item is not None
|
||||
assert result.items[0].first_item["merchant"] == data1.merchant
|
||||
assert result.items[0].group_key == {"merchant": "Test Inc"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregate_without_filter(server, client: AsyncLlamaCloud):
|
||||
"""Verify aggregation without filter groups all items."""
|
||||
data1 = Receipt(merchant="Test Inc", total=1000)
|
||||
data2 = Receipt(merchant="Test Inc", total=1300)
|
||||
data3 = Receipt(merchant="Testing Inc", total=1100)
|
||||
await client.beta.agent_data.agent_data(
|
||||
data=data1.model_dump(),
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
)
|
||||
await client.beta.agent_data.agent_data(
|
||||
data=data2.model_dump(),
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
)
|
||||
await client.beta.agent_data.agent_data(
|
||||
data=data3.model_dump(),
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
)
|
||||
result = await client.beta.agent_data.aggregate(
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
group_by=["merchant"],
|
||||
count=True,
|
||||
)
|
||||
|
||||
assert len(result.items) == 2
|
||||
# First group: Test Inc (2 items)
|
||||
assert result.items[0].count == 2
|
||||
assert result.items[0].group_key == {"merchant": "Test Inc"}
|
||||
# Second group: Testing Inc (1 item)
|
||||
assert result.items[1].count == 1
|
||||
assert result.items[1].group_key == {"merchant": "Testing Inc"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_item(server, client: AsyncLlamaCloud):
|
||||
"""Verify items can be retrieved by ID."""
|
||||
data1 = Receipt(merchant="Test Inc", total=1000)
|
||||
data2 = Receipt(merchant="Test Inc", total=1300)
|
||||
item1 = await client.beta.agent_data.agent_data(
|
||||
data=data1.model_dump(),
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
)
|
||||
item2 = await client.beta.agent_data.agent_data(
|
||||
data=data2.model_dump(),
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
)
|
||||
|
||||
assert item1.id is not None
|
||||
retrieved = await client.beta.agent_data.get(item_id=item1.id)
|
||||
|
||||
assert retrieved.collection == item1.collection
|
||||
assert retrieved.deployment_name == item1.deployment_name
|
||||
assert retrieved.data["merchant"] == data1.merchant
|
||||
assert retrieved.data["total"] == data1.total
|
||||
|
||||
assert item2.id is not None
|
||||
# Non-existent ID should raise 404
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
await client.beta.agent_data.get(item_id=item2.id + "nonexistent")
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
assert exc_info.value.body == {"detail": f"No data with ID: {item2.id}nonexistent"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_by_id(server, client: AsyncLlamaCloud):
|
||||
"""Verify items can be deleted by ID."""
|
||||
data = Receipt(merchant="Test Inc", total=1300)
|
||||
item = await client.beta.agent_data.agent_data(
|
||||
data=data.model_dump(),
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
)
|
||||
assert item.id is not None
|
||||
|
||||
await client.beta.agent_data.delete(item.id)
|
||||
|
||||
# Item should no longer exist
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
await client.beta.agent_data.get(item.id)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
# Deleting again should also raise 404
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
await client.beta.agent_data.delete(item.id)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_by_query(server, client: AsyncLlamaCloud):
|
||||
"""Verify items can be deleted by filter query."""
|
||||
data1 = Receipt(merchant="Test Inc", total=1000)
|
||||
data2 = Receipt(merchant="Test Inc", total=1300)
|
||||
data3 = Receipt(merchant="Testing Inc", total=1100)
|
||||
item1 = await client.beta.agent_data.agent_data(
|
||||
data=data1.model_dump(),
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
)
|
||||
item2 = await client.beta.agent_data.agent_data(
|
||||
data=data2.model_dump(),
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
)
|
||||
item3 = await client.beta.agent_data.agent_data(
|
||||
data=data3.model_dump(),
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
)
|
||||
|
||||
result = await client.beta.agent_data.delete_by_query(
|
||||
deployment_name="extraction_agent",
|
||||
collection="extracted_data",
|
||||
filter={"merchant": {"eq": "Test Inc"}},
|
||||
)
|
||||
|
||||
assert result.deleted_count == 2
|
||||
|
||||
# Deleted items should no longer exist
|
||||
for item in (item1, item2):
|
||||
assert item.id is not None
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
await client.beta.agent_data.get(item.id)
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
# Non-matching item should still exist
|
||||
assert item3.id is not None
|
||||
found = await client.beta.agent_data.get(item3.id)
|
||||
assert found.id == item3.id
|
||||
@@ -1,215 +0,0 @@
|
||||
"""Tests for the deterministic data generation utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
from extraction_review.testing_utils._deterministic import (
|
||||
_generate_value,
|
||||
generate_data_from_schema,
|
||||
)
|
||||
|
||||
SEED = 42
|
||||
|
||||
|
||||
def _rng(seed: int = SEED) -> random.Random:
|
||||
return random.Random(seed)
|
||||
|
||||
|
||||
# -- type → expected python type mapping for parametrize -----------------------
|
||||
|
||||
_TYPE_CASES = [
|
||||
("integer", int),
|
||||
("number", (int, float)),
|
||||
("boolean", bool),
|
||||
("string", str),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"schema_type, expected_type",
|
||||
_TYPE_CASES,
|
||||
ids=[t for t, _ in _TYPE_CASES],
|
||||
)
|
||||
def test_basic_types(schema_type, expected_type):
|
||||
value = _generate_value({"type": schema_type}, _rng(), depth=0)
|
||||
assert isinstance(value, expected_type)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"schema_type, expected_type",
|
||||
_TYPE_CASES,
|
||||
ids=[f"nullable_{t}" for t, _ in _TYPE_CASES],
|
||||
)
|
||||
def test_nullable_types(schema_type, expected_type):
|
||||
"""``["<type>", "null"]`` must produce the concrete type, not a text blob."""
|
||||
value = _generate_value({"type": [schema_type, "null"]}, _rng(), depth=0)
|
||||
assert isinstance(value, expected_type)
|
||||
|
||||
|
||||
def test_null_type():
|
||||
assert _generate_value({"type": "null"}, _rng(), depth=0) is None
|
||||
|
||||
|
||||
def test_all_null_union_returns_none():
|
||||
assert _generate_value({"type": ["null"]}, _rng(), depth=0) is None
|
||||
|
||||
|
||||
def test_multi_type_union_picks_first_concrete():
|
||||
value = _generate_value({"type": ["string", "integer"]}, _rng(), depth=0)
|
||||
assert isinstance(value, str)
|
||||
|
||||
|
||||
# -- constraints survive nullable wrapping ------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"schema, lo, hi",
|
||||
[
|
||||
({"type": "integer", "minimum": 10, "maximum": 20}, 10, 20),
|
||||
({"type": "number", "minimum": 0.5, "maximum": 1.5}, 0.5, 1.5),
|
||||
({"type": ["integer", "null"], "minimum": 10, "maximum": 20}, 10, 20),
|
||||
({"type": ["number", "null"], "minimum": 0.5, "maximum": 1.5}, 0.5, 1.5),
|
||||
],
|
||||
ids=["int", "float", "nullable_int", "nullable_float"],
|
||||
)
|
||||
def test_numeric_bounds(schema, lo, hi):
|
||||
value = _generate_value(schema, _rng(), depth=0)
|
||||
assert lo <= value <= hi
|
||||
|
||||
|
||||
# -- string formats -----------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"fmt, substring",
|
||||
[
|
||||
("date-time", "T"),
|
||||
("email", "@example.com"),
|
||||
("uri", "https://example.com/"),
|
||||
],
|
||||
)
|
||||
def test_string_formats(fmt, substring):
|
||||
value = _generate_value({"type": "string", "format": fmt}, _rng(), depth=0)
|
||||
assert isinstance(value, str)
|
||||
assert substring in value
|
||||
|
||||
|
||||
# -- composite / container schemas --------------------------------------------
|
||||
|
||||
|
||||
def test_enum():
|
||||
value = _generate_value({"enum": ["a", "b", "c"]}, _rng(), depth=0)
|
||||
assert value in ("a", "b", "c")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("keyword", ["oneOf", "anyOf"])
|
||||
def test_composition_keywords(keyword):
|
||||
schema = {keyword: [{"type": "integer"}, {"type": "string"}]}
|
||||
value = _generate_value(schema, _rng(), depth=0)
|
||||
assert isinstance(value, (int, str))
|
||||
|
||||
|
||||
def test_object():
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
|
||||
}
|
||||
value = _generate_value(schema, _rng(), depth=0)
|
||||
assert isinstance(value["name"], str)
|
||||
assert isinstance(value["age"], int)
|
||||
|
||||
|
||||
def test_array():
|
||||
value = _generate_value(
|
||||
{"type": "array", "items": {"type": "integer"}, "minItems": 2, "maxItems": 4},
|
||||
_rng(),
|
||||
depth=0,
|
||||
)
|
||||
assert 2 <= len(value) <= 4
|
||||
assert all(isinstance(v, int) for v in value)
|
||||
|
||||
|
||||
def test_nullable_object():
|
||||
schema = {
|
||||
"type": ["object", "null"],
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
value = _generate_value(schema, _rng(), depth=0)
|
||||
assert isinstance(value, dict) and "name" in value
|
||||
|
||||
|
||||
def test_nullable_array():
|
||||
schema = {"type": ["array", "null"], "items": {"type": "integer"}}
|
||||
value = _generate_value(schema, _rng(), depth=0)
|
||||
assert isinstance(value, list)
|
||||
assert all(isinstance(v, int) for v in value)
|
||||
|
||||
|
||||
# -- kitchen-sink integration test --------------------------------------------
|
||||
|
||||
|
||||
def test_mixed_nullable_object_end_to_end():
|
||||
"""Full schema with nullable numerics, enum, and nested nullable fields."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"total_revenue": {"type": ["number", "null"]},
|
||||
"employee_count": {"type": ["integer", "null"]},
|
||||
"filing_type": {"type": "string", "enum": ["10-K", "10-Q", "8-K"]},
|
||||
"is_audited": {"type": ["boolean", "null"]},
|
||||
"scores": {
|
||||
"type": ["array", "null"],
|
||||
"items": {"type": ["number", "null"]},
|
||||
},
|
||||
"metadata": {
|
||||
"type": ["object", "null"],
|
||||
"properties": {"source": {"type": ["string", "null"]}},
|
||||
},
|
||||
},
|
||||
}
|
||||
data = generate_data_from_schema(schema, seed=SEED)
|
||||
|
||||
assert isinstance(data["total_revenue"], (int, float))
|
||||
assert isinstance(data["employee_count"], int)
|
||||
assert data["filing_type"] in ("10-K", "10-Q", "8-K")
|
||||
assert isinstance(data["is_audited"], bool)
|
||||
assert all(isinstance(s, (int, float)) for s in data["scores"])
|
||||
assert isinstance(data["metadata"]["source"], str)
|
||||
|
||||
# deterministic
|
||||
assert generate_data_from_schema(schema, seed=SEED) == data
|
||||
|
||||
|
||||
# -- edge cases ---------------------------------------------------------------
|
||||
|
||||
|
||||
def test_depth_limit_returns_primitive():
|
||||
value = _generate_value({"type": "object", "properties": {}}, _rng(), depth=9)
|
||||
assert isinstance(value, (int, float, str))
|
||||
|
||||
|
||||
def test_none_schema():
|
||||
assert isinstance(_generate_value(None, _rng(), depth=0), str)
|
||||
|
||||
|
||||
def test_bare_string_schema():
|
||||
value = _generate_value("some_type", _rng(), depth=0)
|
||||
assert value.startswith("some_type-")
|
||||
|
||||
|
||||
def test_list_schema():
|
||||
value = _generate_value([{"type": "integer"}, {"type": "string"}], _rng(), depth=0)
|
||||
assert isinstance(value, list) and len(value) == 2
|
||||
|
||||
|
||||
def test_empty_enum_falls_through():
|
||||
value = _generate_value({"enum": [], "type": "string"}, _rng(), depth=0)
|
||||
assert isinstance(value, str)
|
||||
|
||||
|
||||
def test_unknown_mapping_falls_through():
|
||||
value = _generate_value({"description": "mystery"}, _rng(), depth=0)
|
||||
assert isinstance(value, str)
|
||||
@@ -1,100 +0,0 @@
|
||||
"""Tests for the FakeExtractNamespace mock implementation (llama-cloud v2)."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from extraction_review.testing_utils import FakeLlamaCloudServer
|
||||
from llama_cloud import AsyncLlamaCloud
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Receipt(BaseModel):
|
||||
merchant: str = Field(description="Vendor name")
|
||||
total: float = Field(description="Grand total")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def fake_env(monkeypatch):
|
||||
monkeypatch.setenv("LLAMA_CLOUD_API_KEY", "unit-test-key")
|
||||
monkeypatch.setenv("LLAMA_CLOUD_BASE_URL", FakeLlamaCloudServer.DEFAULT_BASE_URL)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server():
|
||||
with FakeLlamaCloudServer() as srv:
|
||||
yield srv
|
||||
|
||||
|
||||
def _write_sample_file(tmp_path: Path, name: str, content: str) -> Path:
|
||||
target = tmp_path / name
|
||||
target.write_text(content)
|
||||
return target
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stateless_extract_is_deterministic(server, tmp_path):
|
||||
"""Inline configuration + polling returns deterministic data."""
|
||||
client = AsyncLlamaCloud(api_key="unit-test-key")
|
||||
sample_path = _write_sample_file(
|
||||
tmp_path, "receipt.txt", "Merchant: Lunar Bistro\nTotal: 123.45"
|
||||
)
|
||||
|
||||
file_obj = await client.files.create(
|
||||
file=sample_path,
|
||||
purpose="extract",
|
||||
external_file_id=str(sample_path),
|
||||
)
|
||||
configuration = {
|
||||
"data_schema": Receipt.model_json_schema(),
|
||||
"tier": "cost_effective",
|
||||
}
|
||||
first = await client.extract.run(
|
||||
file_input=file_obj.id,
|
||||
configuration=configuration,
|
||||
)
|
||||
second = await client.extract.run(
|
||||
file_input=file_obj.id,
|
||||
configuration=configuration,
|
||||
)
|
||||
|
||||
assert first.status == "COMPLETED"
|
||||
assert isinstance(first.extract_result, dict)
|
||||
assert "merchant" in first.extract_result
|
||||
assert second.extract_result == first.extract_result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saved_configuration_flow(server, tmp_path):
|
||||
"""Using a saved configuration_id resolves schema + settings from the config."""
|
||||
client = AsyncLlamaCloud(api_key="unit-test-key")
|
||||
|
||||
cfg = await client.configurations.create(
|
||||
name="receipt-cfg",
|
||||
parameters={
|
||||
"product_type": "extract_v2",
|
||||
"data_schema": Receipt.model_json_schema(),
|
||||
"tier": "agentic",
|
||||
},
|
||||
)
|
||||
|
||||
sample_path = _write_sample_file(
|
||||
tmp_path, "contract.pdf", "Agreement between parties."
|
||||
)
|
||||
file_obj = await client.files.create(
|
||||
file=sample_path,
|
||||
purpose="extract",
|
||||
external_file_id=str(sample_path),
|
||||
)
|
||||
job = await client.extract.run(
|
||||
file_input=file_obj.id,
|
||||
configuration_id=cfg.id,
|
||||
)
|
||||
|
||||
assert job.status == "COMPLETED"
|
||||
assert isinstance(job.extract_result, dict)
|
||||
assert "merchant" in job.extract_result
|
||||
assert job.configuration_id == cfg.id
|
||||
|
||||
fetched = await client.configurations.retrieve(cfg.id)
|
||||
assert fetched.id == cfg.id
|
||||
assert fetched.parameters.product_type == "extract_v2"
|
||||
@@ -1,162 +0,0 @@
|
||||
"""Tests for the FakeFilesNamespace mock implementation."""
|
||||
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from extraction_review.testing_utils import FakeLlamaCloudServer
|
||||
from llama_cloud import APIStatusError, AsyncLlamaCloud
|
||||
from llama_cloud.types.file_query_params import Filter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server():
|
||||
"""Provide a server with files namespace enabled."""
|
||||
with FakeLlamaCloudServer(namespaces=["files"]) as srv:
|
||||
yield srv
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preload_and_download_as_presigned_url(server, tmp_path):
|
||||
"""Verify files can be preloaded and read back."""
|
||||
test_file = tmp_path / "test_file.txt"
|
||||
test_file.write_bytes(b"test content here")
|
||||
|
||||
file_id = server.files.preload(path=test_file)
|
||||
|
||||
content = server.files.read(file_id)
|
||||
assert content == b"test content here"
|
||||
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
presigned_url = await client.files.get(
|
||||
file_id=file_id,
|
||||
)
|
||||
assert (
|
||||
presigned_url.url
|
||||
== f"{server._download_base_url}/files/{file_id}?{urlencode({'token': 'fake'})}"
|
||||
)
|
||||
response = httpx.get(presigned_url.url)
|
||||
assert response.content == b"test content here"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found_returns_404(server):
|
||||
"""Verify non-existent file returns 404."""
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
await client.files.get(
|
||||
"does-not-exist",
|
||||
)
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_file(server, tmp_path):
|
||||
"""Verify files can be deleted."""
|
||||
test_file = tmp_path / "to_delete.txt"
|
||||
test_file.write_bytes(b"delete me")
|
||||
|
||||
file_id = server.files.preload(path=test_file)
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
# File should exist
|
||||
response = await client.files.get(file_id)
|
||||
assert file_id in response.url
|
||||
|
||||
# Delete the file
|
||||
await client.files.delete(
|
||||
file_id,
|
||||
)
|
||||
|
||||
# File should no longer exist
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
await client.files.get(
|
||||
file_id,
|
||||
)
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_files_native_upload(server, tmp_path):
|
||||
"""Verify that the client can natively upload the files without having to pass through server.preload"""
|
||||
test_file = tmp_path / "test_file.txt"
|
||||
test_file.write_bytes(b"test content here")
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
file_obj = await client.files.create(
|
||||
file=test_file,
|
||||
purpose="parse",
|
||||
external_file_id=str(test_file),
|
||||
)
|
||||
assert isinstance(file_obj.id, str)
|
||||
assert file_obj.file_type == "application/octet-stream"
|
||||
assert file_obj.id.startswith("file_")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_files_query_by_id(server, tmp_path):
|
||||
"""Test that you can upload and query files selecting them by file ID"""
|
||||
test_file_1 = tmp_path / "test_file1.txt"
|
||||
test_file_1.write_bytes(b"test content here 1")
|
||||
test_file_2 = tmp_path / "test_file2.txt"
|
||||
test_file_2.write_bytes(b"test content here 2")
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
file_obj_1 = await client.files.create(
|
||||
file=test_file_1,
|
||||
purpose="parse",
|
||||
external_file_id=str(test_file_1),
|
||||
)
|
||||
await client.files.create(
|
||||
file=test_file_2,
|
||||
purpose="parse",
|
||||
external_file_id=str(test_file_1),
|
||||
)
|
||||
response = await client.files.query(filter=Filter(file_ids=[file_obj_1.id]))
|
||||
assert len(response.items) == 1
|
||||
assert response.total_size == 1
|
||||
assert response.items[0].id == file_obj_1.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_files_list(server, tmp_path):
|
||||
"""Test that you can list files using the GET /files endpoint."""
|
||||
test_file_1 = tmp_path / "file_a.txt"
|
||||
test_file_1.write_bytes(b"content a")
|
||||
test_file_2 = tmp_path / "file_b.txt"
|
||||
test_file_2.write_bytes(b"content b")
|
||||
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
file_obj_1 = await client.files.create(
|
||||
file=test_file_1,
|
||||
purpose="extract",
|
||||
)
|
||||
file_obj_2 = await client.files.create(
|
||||
file=test_file_2,
|
||||
purpose="extract",
|
||||
)
|
||||
|
||||
# List all files
|
||||
all_files = await client.files.list()
|
||||
items = [f async for f in all_files]
|
||||
assert len(items) >= 2
|
||||
ids = {f.id for f in items}
|
||||
assert file_obj_1.id in ids
|
||||
assert file_obj_2.id in ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_files_list_by_name(server, tmp_path):
|
||||
"""Test filtering files by name via the list endpoint."""
|
||||
test_file = tmp_path / "unique_name.pdf"
|
||||
test_file.write_bytes(b"unique content")
|
||||
|
||||
# Use preload for reliable filename storage
|
||||
server.files.preload(path=test_file, filename="unique_name.pdf")
|
||||
server.files.preload_from_source("other_file.txt", b"other content")
|
||||
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
results = await client.files.list(file_name="unique_name.pdf")
|
||||
items = [f async for f in results]
|
||||
assert len(items) == 1
|
||||
assert items[0].name == "unique_name.pdf"
|
||||
@@ -1,179 +0,0 @@
|
||||
"""Tests for the FakeParseNamespace mock implementation."""
|
||||
|
||||
import pytest
|
||||
import respx
|
||||
from extraction_review.testing_utils import FakeLlamaCloudServer
|
||||
from llama_cloud import APIStatusError, AsyncLlamaCloud
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server():
|
||||
"""Provide a server with parse namespace enabled."""
|
||||
with FakeLlamaCloudServer() as srv:
|
||||
yield srv
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client() -> AsyncLlamaCloud:
|
||||
return AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def data() -> tuple[str, bytes, str]:
|
||||
with open("tests/files/test.pdf", "rb") as f:
|
||||
content = f.read()
|
||||
return ("tests/files/test.pdf", content, "application/pdf")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_with_upload_file(
|
||||
server: FakeLlamaCloudServer, client: AsyncLlamaCloud, data: tuple[str, bytes, str]
|
||||
) -> None:
|
||||
job_create = await client.parsing.create(
|
||||
tier="fast",
|
||||
version="latest",
|
||||
upload_file=data,
|
||||
)
|
||||
assert job_create.error_message is None
|
||||
assert job_create.status == "COMPLETED"
|
||||
assert job_create.project_id == server.default_project_id
|
||||
job_response = await client.parsing.get(
|
||||
job_id=job_create.id, expand=["text", "markdown", "items"]
|
||||
)
|
||||
assert job_response.job.id == job_create.id
|
||||
assert job_response.job.status == job_create.status
|
||||
assert job_response.job.project_id == job_create.project_id
|
||||
assert job_response.items is not None
|
||||
assert job_response.markdown is not None
|
||||
assert job_response.text is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_with_different_expand(
|
||||
server: FakeLlamaCloudServer, client: AsyncLlamaCloud, data: tuple[str, bytes, str]
|
||||
) -> None:
|
||||
job_create = await client.parsing.create(
|
||||
tier="fast",
|
||||
version="latest",
|
||||
upload_file=data,
|
||||
)
|
||||
job_response = await client.parsing.get(job_id=job_create.id, expand=["text"])
|
||||
assert job_response.items is None
|
||||
assert job_response.markdown is None
|
||||
assert job_response.text is not None
|
||||
job_response = await client.parsing.get(job_id=job_create.id, expand=["markdown"])
|
||||
assert job_response.items is None
|
||||
assert job_response.markdown is not None
|
||||
assert job_response.text is None
|
||||
job_response = await client.parsing.get(job_id=job_create.id, expand=["items"])
|
||||
assert job_response.items is not None
|
||||
assert job_response.markdown is None
|
||||
assert job_response.text is None
|
||||
# no expands -> defaul to items
|
||||
job_response = await client.parsing.get(job_id=job_create.id)
|
||||
assert job_response.items is not None
|
||||
assert job_response.markdown is None
|
||||
assert job_response.text is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_with_file_id(
|
||||
server: FakeLlamaCloudServer, client: AsyncLlamaCloud, data: tuple[str, bytes, str]
|
||||
) -> None:
|
||||
file_name, _, _ = data
|
||||
file_obj = await client.files.create(
|
||||
file=file_name,
|
||||
purpose="parse",
|
||||
external_file_id=file_name,
|
||||
)
|
||||
job_create = await client.parsing.create(
|
||||
tier="fast",
|
||||
version="latest",
|
||||
file_id=file_obj.id,
|
||||
)
|
||||
assert job_create.error_message is None
|
||||
assert job_create.status == "COMPLETED"
|
||||
assert job_create.project_id == server.default_project_id
|
||||
job_response = await client.parsing.get(
|
||||
job_id=job_create.id, expand=["text", "markdown", "items"]
|
||||
)
|
||||
assert job_response.job.id == job_create.id
|
||||
assert job_response.job.status == job_create.status
|
||||
assert job_response.job.project_id == job_create.project_id
|
||||
assert job_response.items is not None
|
||||
assert job_response.markdown is not None
|
||||
assert job_response.text is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_with_file_id_file_not_found(
|
||||
server: FakeLlamaCloudServer,
|
||||
client: AsyncLlamaCloud,
|
||||
) -> None:
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
await client.parsing.create(
|
||||
tier="fast",
|
||||
version="latest",
|
||||
file_id="does-not-exist",
|
||||
)
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_without_fileid_or_sourceurl(
|
||||
server: FakeLlamaCloudServer,
|
||||
client: AsyncLlamaCloud,
|
||||
) -> None:
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
await client.parsing.create(
|
||||
tier="fast",
|
||||
version="latest",
|
||||
)
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock(assert_all_mocked=False)
|
||||
async def test_parse_with_source_url(
|
||||
server: FakeLlamaCloudServer,
|
||||
client: AsyncLlamaCloud,
|
||||
) -> None:
|
||||
job_create = await client.parsing.create(
|
||||
tier="fast",
|
||||
version="latest",
|
||||
source_url="https://pdfobject.com/pdf/sample.pdf",
|
||||
)
|
||||
assert job_create.error_message is None
|
||||
assert job_create.status == "COMPLETED"
|
||||
assert job_create.project_id == server.default_project_id
|
||||
job_response = await client.parsing.get(
|
||||
job_id=job_create.id, expand=["text", "markdown", "items"]
|
||||
)
|
||||
assert job_response.job.id == job_create.id
|
||||
assert job_response.job.status == job_create.status
|
||||
assert job_response.job.project_id == job_create.project_id
|
||||
assert job_response.items is not None
|
||||
assert job_response.markdown is not None
|
||||
assert job_response.text is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_e2e(
|
||||
server: FakeLlamaCloudServer, client: AsyncLlamaCloud, data: tuple[str, bytes, str]
|
||||
) -> None:
|
||||
file_name, _, _ = data
|
||||
file_obj = await client.files.create(
|
||||
file=file_name,
|
||||
purpose="parse",
|
||||
external_file_id=file_name,
|
||||
)
|
||||
result = await client.parsing.parse(
|
||||
file_id=file_obj.id,
|
||||
expand=["markdown"],
|
||||
tier="agentic",
|
||||
version="latest",
|
||||
)
|
||||
assert result.markdown is not None
|
||||
assert len(result.markdown.pages) == 1
|
||||
assert hasattr(result.markdown.pages[0], "markdown")
|
||||
assert isinstance(result.markdown.pages[0].markdown, str) # type: ignore
|
||||
@@ -1,381 +0,0 @@
|
||||
"""Tests for the FakePipelinesNamespace mock implementation."""
|
||||
|
||||
import pytest
|
||||
from extraction_review.testing_utils import FakeLlamaCloudServer
|
||||
from llama_cloud import APIStatusError, AsyncLlamaCloud
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server():
|
||||
"""Provide a server with the pipelines namespace enabled."""
|
||||
with FakeLlamaCloudServer(namespaces=["pipelines"]) as srv:
|
||||
yield srv
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipelines_create_and_get(server):
|
||||
"""Verify a pipeline can be created and retrieved."""
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
pipeline = await client.pipelines.create(name="test-pipeline")
|
||||
assert pipeline.id.startswith("pipeline_")
|
||||
assert pipeline.name == "test-pipeline"
|
||||
assert pipeline.project_id == server.default_project_id
|
||||
assert pipeline.status == "CREATED"
|
||||
|
||||
retrieved = await client.pipelines.get(pipeline.id)
|
||||
assert retrieved.id == pipeline.id
|
||||
assert retrieved.name == "test-pipeline"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipelines_create_with_embedding_config(server):
|
||||
"""Verify a pipeline can be created with explicit embedding config."""
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
pipeline = await client.pipelines.create(
|
||||
name="custom-pipeline",
|
||||
embedding_config={
|
||||
"type": "MANAGED_OPENAI_EMBEDDING",
|
||||
"component": {},
|
||||
},
|
||||
pipeline_type="MANAGED",
|
||||
)
|
||||
assert pipeline.name == "custom-pipeline"
|
||||
assert pipeline.embedding_config is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipelines_list(server):
|
||||
"""Verify listing pipelines returns created pipelines."""
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
await client.pipelines.create(name="pipeline-1")
|
||||
await client.pipelines.create(name="pipeline-2")
|
||||
|
||||
pipelines = await client.pipelines.list()
|
||||
assert len(pipelines) == 2
|
||||
names = {p.name for p in pipelines}
|
||||
assert names == {"pipeline-1", "pipeline-2"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipelines_get_not_found(server):
|
||||
"""Verify non-existent pipeline returns 404."""
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
await client.pipelines.get("nonexistent-id")
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipelines_delete(server):
|
||||
"""Verify a pipeline can be deleted."""
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
pipeline = await client.pipelines.create(name="to-delete")
|
||||
retrieved = await client.pipelines.get(pipeline.id)
|
||||
assert retrieved.id == pipeline.id
|
||||
|
||||
await client.pipelines.delete(pipeline.id)
|
||||
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
await client.pipelines.get(pipeline.id)
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipelines_update(server):
|
||||
"""Verify a pipeline can be updated."""
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
pipeline = await client.pipelines.create(name="original-name")
|
||||
updated = await client.pipelines.update(pipeline.id, name="new-name")
|
||||
assert updated.name == "new-name"
|
||||
assert updated.id == pipeline.id
|
||||
|
||||
retrieved = await client.pipelines.get(pipeline.id)
|
||||
assert retrieved.name == "new-name"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipelines_get_status(server):
|
||||
"""Verify pipeline status can be retrieved."""
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
pipeline = await client.pipelines.create(name="status-test")
|
||||
status = await client.pipelines.get_status(pipeline.id)
|
||||
assert status.status == "SUCCESS"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipelines_retrieve(server):
|
||||
"""Verify pipeline retrieve endpoint works."""
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
pipeline = await client.pipelines.create(name="retrieve-test")
|
||||
result = await client.pipelines.retrieve(pipeline.id, query="test query")
|
||||
assert result.pipeline_id == pipeline.id
|
||||
assert result.retrieval_nodes is not None
|
||||
assert isinstance(result.retrieval_nodes, list)
|
||||
|
||||
|
||||
# --- Document ingestion tests ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipelines_ingest_documents_and_retrieve(server):
|
||||
"""Ingest documents into a pipeline and verify retrieval returns nodes."""
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
pipeline = await client.pipelines.create(name="doc-ingest-test")
|
||||
|
||||
docs = await client.pipelines.documents.create(
|
||||
pipeline.id,
|
||||
body=[
|
||||
{
|
||||
"text": "The quick brown fox jumps over the lazy dog.",
|
||||
"metadata": {"source": "test"},
|
||||
},
|
||||
{
|
||||
"text": "Machine learning is a subset of artificial intelligence.",
|
||||
"metadata": {"source": "ml"},
|
||||
},
|
||||
],
|
||||
)
|
||||
assert len(docs) == 2
|
||||
assert docs[0].text == "The quick brown fox jumps over the lazy dog."
|
||||
assert docs[0].metadata["source"] == "test"
|
||||
assert docs[1].text == "Machine learning is a subset of artificial intelligence."
|
||||
|
||||
result = await client.pipelines.retrieve(pipeline.id, query="fox")
|
||||
assert result.pipeline_id == pipeline.id
|
||||
assert len(result.retrieval_nodes) > 0
|
||||
|
||||
# Nodes should have text content from our documents
|
||||
texts = [n.node.text for n in result.retrieval_nodes]
|
||||
all_doc_texts = {
|
||||
"The quick brown fox jumps over the lazy dog.",
|
||||
"Machine learning is a subset of artificial intelligence.",
|
||||
}
|
||||
for text in texts:
|
||||
assert text in all_doc_texts
|
||||
|
||||
# Each node should have a score
|
||||
for node in result.retrieval_nodes:
|
||||
assert node.score is not None
|
||||
assert 0.0 <= node.score <= 1.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipelines_upsert_documents(server):
|
||||
"""Upsert documents (PUT) into a pipeline and verify they are stored."""
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
pipeline = await client.pipelines.create(name="doc-upsert-test")
|
||||
|
||||
docs = await client.pipelines.documents.upsert(
|
||||
pipeline.id,
|
||||
body=[
|
||||
{"text": "First document content.", "metadata": {"order": "1"}},
|
||||
],
|
||||
)
|
||||
assert len(docs) == 1
|
||||
assert docs[0].text == "First document content."
|
||||
|
||||
# Upsert more documents
|
||||
docs2 = await client.pipelines.documents.upsert(
|
||||
pipeline.id,
|
||||
body=[
|
||||
{"text": "Second document content.", "metadata": {"order": "2"}},
|
||||
],
|
||||
)
|
||||
assert len(docs2) == 1
|
||||
|
||||
# Retrieve should pick up both documents
|
||||
result = await client.pipelines.retrieve(pipeline.id, query="document")
|
||||
assert len(result.retrieval_nodes) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipelines_ingest_documents_with_custom_id(server):
|
||||
"""Documents with explicit IDs preserve them."""
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
pipeline = await client.pipelines.create(name="custom-id-test")
|
||||
|
||||
docs = await client.pipelines.documents.create(
|
||||
pipeline.id,
|
||||
body=[
|
||||
{"id": "my-doc-1", "text": "Custom ID document.", "metadata": {}},
|
||||
],
|
||||
)
|
||||
assert len(docs) == 1
|
||||
assert docs[0].id == "my-doc-1"
|
||||
|
||||
|
||||
# --- File ingestion tests ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipelines_ingest_files_and_retrieve(server):
|
||||
"""Add files to a pipeline and verify retrieval returns generated nodes."""
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
pipeline = await client.pipelines.create(name="file-ingest-test")
|
||||
|
||||
files = await client.pipelines.files.create(
|
||||
pipeline.id,
|
||||
body=[
|
||||
{"file_id": "file-abc123"},
|
||||
{"file_id": "file-def456"},
|
||||
],
|
||||
)
|
||||
assert len(files) == 2
|
||||
assert files[0].pipeline_id == pipeline.id
|
||||
assert files[0].status == "SUCCESS"
|
||||
assert files[1].file_id == "file-def456"
|
||||
|
||||
result = await client.pipelines.retrieve(pipeline.id, query="search query")
|
||||
assert result.pipeline_id == pipeline.id
|
||||
assert len(result.retrieval_nodes) > 0
|
||||
|
||||
# Nodes from files should have file metadata
|
||||
for node in result.retrieval_nodes:
|
||||
assert node.node.text # Should have generated text
|
||||
assert node.score is not None
|
||||
|
||||
|
||||
# --- Retrieval behavior tests ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipelines_retrieve_empty_pipeline(server):
|
||||
"""Retrieve on an empty pipeline returns no nodes."""
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
pipeline = await client.pipelines.create(name="empty-pipeline")
|
||||
result = await client.pipelines.retrieve(pipeline.id, query="anything")
|
||||
assert result.retrieval_nodes == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipelines_retrieve_respects_top_k(server):
|
||||
"""Retrieve respects the dense_similarity_top_k parameter."""
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
pipeline = await client.pipelines.create(name="topk-test")
|
||||
|
||||
# Ingest several documents to have more chunks than top_k
|
||||
await client.pipelines.documents.create(
|
||||
pipeline.id,
|
||||
body=[
|
||||
{
|
||||
"text": f"Document number {i} with unique content about topic {i}.",
|
||||
"metadata": {},
|
||||
}
|
||||
for i in range(10)
|
||||
],
|
||||
)
|
||||
|
||||
result = await client.pipelines.retrieve(
|
||||
pipeline.id,
|
||||
query="topic",
|
||||
dense_similarity_top_k=2,
|
||||
)
|
||||
assert len(result.retrieval_nodes) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipelines_retrieve_deterministic(server):
|
||||
"""Same query on same data produces the same results."""
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
pipeline = await client.pipelines.create(name="deterministic-test")
|
||||
await client.pipelines.documents.create(
|
||||
pipeline.id,
|
||||
body=[
|
||||
{"text": "Alpha bravo charlie delta.", "metadata": {}},
|
||||
{"text": "Echo foxtrot golf hotel.", "metadata": {}},
|
||||
],
|
||||
)
|
||||
|
||||
result1 = await client.pipelines.retrieve(pipeline.id, query="bravo")
|
||||
result2 = await client.pipelines.retrieve(pipeline.id, query="bravo")
|
||||
|
||||
texts1 = [n.node.text for n in result1.retrieval_nodes]
|
||||
texts2 = [n.node.text for n in result2.retrieval_nodes]
|
||||
assert texts1 == texts2
|
||||
|
||||
scores1 = [n.score for n in result1.retrieval_nodes]
|
||||
scores2 = [n.score for n in result2.retrieval_nodes]
|
||||
assert scores1 == scores2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipelines_delete_cleans_up_documents_and_files(server):
|
||||
"""Deleting a pipeline clears its ingested documents and files."""
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
pipeline = await client.pipelines.create(name="cleanup-test")
|
||||
|
||||
await client.pipelines.documents.create(
|
||||
pipeline.id,
|
||||
body=[{"text": "Some content.", "metadata": {}}],
|
||||
)
|
||||
await client.pipelines.files.create(
|
||||
pipeline.id,
|
||||
body=[{"file_id": "file-xyz"}],
|
||||
)
|
||||
|
||||
# Verify data exists
|
||||
result = await client.pipelines.retrieve(pipeline.id, query="content")
|
||||
assert len(result.retrieval_nodes) > 0
|
||||
|
||||
# Delete pipeline
|
||||
await client.pipelines.delete(pipeline.id)
|
||||
|
||||
# Internal stores should be cleaned
|
||||
assert pipeline.id not in server.pipelines._documents
|
||||
assert pipeline.id not in server.pipelines._files
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipelines_mixed_documents_and_files_retrieval(server):
|
||||
"""Retrieval combines results from both documents and files."""
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
pipeline = await client.pipelines.create(name="mixed-test")
|
||||
|
||||
await client.pipelines.documents.create(
|
||||
pipeline.id,
|
||||
body=[
|
||||
{
|
||||
"text": "Document text about important concepts.",
|
||||
"metadata": {"type": "doc"},
|
||||
}
|
||||
],
|
||||
)
|
||||
await client.pipelines.files.create(
|
||||
pipeline.id,
|
||||
body=[{"file_id": "file-mixed-001"}],
|
||||
)
|
||||
|
||||
result = await client.pipelines.retrieve(
|
||||
pipeline.id,
|
||||
query="concepts",
|
||||
dense_similarity_top_k=10,
|
||||
)
|
||||
assert len(result.retrieval_nodes) > 1
|
||||
|
||||
# Should have nodes from both sources
|
||||
has_doc_node = any(
|
||||
n.node.text == "Document text about important concepts."
|
||||
for n in result.retrieval_nodes
|
||||
)
|
||||
has_file_node = any(
|
||||
n.node.extra_info and n.node.extra_info.get("file_id") == "file-mixed-001"
|
||||
for n in result.retrieval_nodes
|
||||
)
|
||||
assert has_doc_node, "Should have a node from the ingested document"
|
||||
assert has_file_node, "Should have a node from the ingested file"
|
||||
@@ -1,36 +0,0 @@
|
||||
"""Tests for the FakeLlamaCloudServer lifecycle and configuration."""
|
||||
|
||||
from extraction_review.testing_utils import FakeLlamaCloudServer
|
||||
|
||||
|
||||
def test_context_manager_installs_and_uninstalls():
|
||||
"""Verify context manager properly installs and uninstalls the mock."""
|
||||
server = FakeLlamaCloudServer()
|
||||
assert not server._installed
|
||||
|
||||
with server:
|
||||
assert server._installed
|
||||
|
||||
assert not server._installed
|
||||
|
||||
|
||||
def test_install_is_idempotent():
|
||||
"""Verify calling install multiple times is safe."""
|
||||
server = FakeLlamaCloudServer()
|
||||
server.install()
|
||||
server.install() # Should not raise
|
||||
assert server._installed
|
||||
server.uninstall()
|
||||
assert not server._installed
|
||||
|
||||
|
||||
def test_selective_namespace_registration():
|
||||
"""Verify only requested namespaces are registered."""
|
||||
with FakeLlamaCloudServer(namespaces=["files"]) as server:
|
||||
assert "files" in server._namespace_names
|
||||
assert "parse" not in server._namespace_names
|
||||
|
||||
with FakeLlamaCloudServer(namespaces=["parse", "extract"]) as server:
|
||||
assert "parse" in server._namespace_names
|
||||
assert "extract" in server._namespace_names
|
||||
assert "files" not in server._namespace_names
|
||||
@@ -1,209 +0,0 @@
|
||||
"""Tests for the FakeSheetsNamespace mock implementation."""
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from extraction_review.testing_utils import FakeLlamaCloudServer
|
||||
from llama_cloud import APIStatusError, AsyncLlamaCloud
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server():
|
||||
"""Provide a server with files and sheets namespaces enabled."""
|
||||
with FakeLlamaCloudServer(namespaces=["files", "sheets"]) as srv:
|
||||
yield srv
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sheets_create_and_get(server, tmp_path):
|
||||
"""Verify a sheets job can be created and retrieved."""
|
||||
test_file = tmp_path / "spreadsheet.xlsx"
|
||||
test_file.write_bytes(b"fake spreadsheet content")
|
||||
file_id = server.files.preload(path=test_file)
|
||||
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
job = await client.beta.sheets.create(file_id=file_id)
|
||||
assert job.id.startswith("sheets-job_")
|
||||
assert job.status == "SUCCESS"
|
||||
assert job.success is True
|
||||
assert job.regions is not None
|
||||
assert len(job.regions) > 0
|
||||
|
||||
for region in job.regions:
|
||||
assert region.region_id is not None
|
||||
assert region.region_type in ("table", "extra")
|
||||
assert region.sheet_name is not None
|
||||
assert region.location is not None
|
||||
|
||||
assert job.worksheet_metadata is not None
|
||||
assert len(job.worksheet_metadata) > 0
|
||||
|
||||
# Retrieve the same job
|
||||
retrieved = await client.beta.sheets.get(job.id)
|
||||
assert retrieved.id == job.id
|
||||
assert retrieved.status == "SUCCESS"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sheets_file_not_found(server):
|
||||
"""Verify sheets with non-existent file returns 404."""
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
await client.beta.sheets.create(file_id="nonexistent-file")
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sheets_job_not_found(server):
|
||||
"""Verify non-existent sheets job returns 404."""
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
await client.beta.sheets.get("nonexistent-id")
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sheets_delete_job(server, tmp_path):
|
||||
"""Verify a sheets job can be deleted."""
|
||||
test_file = tmp_path / "spreadsheet.xlsx"
|
||||
test_file.write_bytes(b"delete me")
|
||||
file_id = server.files.preload(path=test_file)
|
||||
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
job = await client.beta.sheets.create(file_id=file_id)
|
||||
|
||||
# Job should exist
|
||||
retrieved = await client.beta.sheets.get(job.id)
|
||||
assert retrieved.id == job.id
|
||||
|
||||
# Delete
|
||||
await client.beta.sheets.delete_job(job.id)
|
||||
|
||||
# Should no longer exist
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
await client.beta.sheets.get(job.id)
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sheets_get_result_table(server, tmp_path):
|
||||
"""Verify presigned URL generation for region results."""
|
||||
test_file = tmp_path / "spreadsheet.xlsx"
|
||||
test_file.write_bytes(b"spreadsheet with tables")
|
||||
file_id = server.files.preload(path=test_file)
|
||||
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
job = await client.beta.sheets.create(file_id=file_id)
|
||||
|
||||
assert job.regions is not None
|
||||
assert len(job.regions) > 0
|
||||
|
||||
region = job.regions[0]
|
||||
presigned = await client.beta.sheets.get_result_table(
|
||||
"table",
|
||||
spreadsheet_job_id=job.id,
|
||||
region_id=region.region_id,
|
||||
)
|
||||
assert presigned.url is not None
|
||||
assert "token=fake" in presigned.url
|
||||
assert job.id in presigned.url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sheets_with_config(server, tmp_path):
|
||||
"""Verify sheets job with custom config."""
|
||||
test_file = tmp_path / "spreadsheet.xlsx"
|
||||
test_file.write_bytes(b"configured spreadsheet")
|
||||
file_id = server.files.preload(path=test_file)
|
||||
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
job = await client.beta.sheets.create(
|
||||
file_id=file_id,
|
||||
config={
|
||||
"sheet_names": ["Revenue", "Expenses"],
|
||||
"flatten_hierarchical_tables": True,
|
||||
"generate_additional_metadata": True,
|
||||
},
|
||||
)
|
||||
assert job.status == "SUCCESS"
|
||||
assert job.worksheet_metadata is not None
|
||||
|
||||
sheet_names = {ws.sheet_name for ws in job.worksheet_metadata}
|
||||
assert sheet_names == {"Revenue", "Expenses"}
|
||||
|
||||
if job.regions:
|
||||
region_sheets = {r.sheet_name for r in job.regions}
|
||||
assert region_sheets.issubset({"Revenue", "Expenses"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sheets_list_jobs(server, tmp_path):
|
||||
"""Verify listing sheets jobs returns created jobs."""
|
||||
test_file = tmp_path / "spreadsheet.xlsx"
|
||||
test_file.write_bytes(b"list test")
|
||||
file_id = server.files.preload(path=test_file)
|
||||
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
await client.beta.sheets.create(file_id=file_id)
|
||||
await client.beta.sheets.create(file_id=file_id)
|
||||
|
||||
jobs = await client.beta.sheets.list()
|
||||
items = [j async for j in jobs]
|
||||
assert len(items) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sheets_presigned_download(server, tmp_path):
|
||||
"""Verify that the presigned URL for a region result serves parquet content."""
|
||||
test_file = tmp_path / "spreadsheet.xlsx"
|
||||
test_file.write_bytes(b"spreadsheet for download test")
|
||||
file_id = server.files.preload(path=test_file)
|
||||
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
job = await client.beta.sheets.create(file_id=file_id)
|
||||
|
||||
assert job.regions is not None
|
||||
assert len(job.regions) > 0
|
||||
|
||||
region = job.regions[0]
|
||||
presigned = await client.beta.sheets.get_result_table(
|
||||
region.region_type,
|
||||
spreadsheet_job_id=job.id,
|
||||
region_id=region.region_id,
|
||||
)
|
||||
|
||||
# Follow the presigned URL to download the parquet content
|
||||
async with httpx.AsyncClient() as http:
|
||||
resp = await http.get(presigned.url)
|
||||
assert resp.status_code == 200
|
||||
content = resp.content
|
||||
# Verify PAR1 magic bytes at start and end
|
||||
assert content[:4] == b"PAR1"
|
||||
assert content[-4:] == b"PAR1"
|
||||
assert len(content) > 8 # has actual payload between magic markers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sheets_presigned_download_deterministic(server, tmp_path):
|
||||
"""Verify that repeated downloads for the same region return identical content."""
|
||||
test_file = tmp_path / "spreadsheet.xlsx"
|
||||
test_file.write_bytes(b"deterministic download test")
|
||||
file_id = server.files.preload(path=test_file)
|
||||
|
||||
client = AsyncLlamaCloud(api_key="fake-api-key")
|
||||
job = await client.beta.sheets.create(file_id=file_id)
|
||||
|
||||
assert job.regions is not None
|
||||
region = job.regions[0]
|
||||
presigned = await client.beta.sheets.get_result_table(
|
||||
region.region_type,
|
||||
spreadsheet_job_id=job.id,
|
||||
region_id=region.region_id,
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as http:
|
||||
resp1 = await http.get(presigned.url)
|
||||
resp2 = await http.get(presigned.url)
|
||||
assert resp1.content == resp2.content
|
||||
@@ -1,115 +0,0 @@
|
||||
import pytest
|
||||
from extraction_review.testing_utils import FakeLlamaCloudServer
|
||||
from llama_cloud import APIStatusError, AsyncLlamaCloud
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server():
|
||||
with FakeLlamaCloudServer() as srv:
|
||||
yield srv
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client() -> AsyncLlamaCloud:
|
||||
return AsyncLlamaCloud(api_key="fake-api-key")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_split_end_to_end(
|
||||
server: FakeLlamaCloudServer, client: AsyncLlamaCloud
|
||||
) -> None:
|
||||
file_id = server.files.preload(path="tests/files/test.pdf")
|
||||
split_job = await client.beta.split.create(
|
||||
configuration={
|
||||
"categories": [
|
||||
{"name": "hello", "description": ""},
|
||||
{"name": "world", "description": ""},
|
||||
],
|
||||
},
|
||||
document_input={"type": "file_id", "value": file_id},
|
||||
)
|
||||
assert split_job.id.startswith("split-")
|
||||
assert split_job.status == "pending"
|
||||
cts = [c.name for c in split_job.categories]
|
||||
cts.sort()
|
||||
assert cts == ["hello", "world"]
|
||||
split_result = await client.beta.split.wait_for_completion(
|
||||
split_job_id=split_job.id,
|
||||
)
|
||||
assert split_result.result is not None
|
||||
cts = [s.category for s in split_result.result.segments]
|
||||
cts.sort()
|
||||
assert cts == [
|
||||
"hello",
|
||||
"world",
|
||||
]
|
||||
assert any(len(s.pages) > 0 for s in split_result.result.segments)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_split_no_categories_raises_bad_request(
|
||||
server: FakeLlamaCloudServer, client: AsyncLlamaCloud
|
||||
) -> None:
|
||||
file_id = server.files.preload(path="tests/files/test.pdf")
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
await client.beta.split.create(
|
||||
configuration={"categories": []},
|
||||
document_input={"type": "file_id", "value": file_id},
|
||||
)
|
||||
assert exc_info.value.status_code == 400
|
||||
assert (
|
||||
"categories field should be non-null and non-empty"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_split_invalid_document_input_type_raises_bad_request(
|
||||
server: FakeLlamaCloudServer, client: AsyncLlamaCloud
|
||||
) -> None:
|
||||
file_id = server.files.preload(path="tests/files/test.pdf")
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
await client.beta.split.create(
|
||||
configuration={
|
||||
"categories": [
|
||||
{"name": "hello", "description": ""},
|
||||
{"name": "world", "description": ""},
|
||||
],
|
||||
},
|
||||
document_input={"type": "file", "value": file_id},
|
||||
)
|
||||
assert exc_info.value.status_code == 400
|
||||
assert (
|
||||
"document_input.type file is invalid. Allowed input types: file_id"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_split_non_existing_file_id_raises_notfound(
|
||||
server: FakeLlamaCloudServer, client: AsyncLlamaCloud
|
||||
) -> None:
|
||||
file_id = "file-doesnotexist"
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
await client.beta.split.create(
|
||||
configuration={
|
||||
"categories": [
|
||||
{"name": "hello", "description": ""},
|
||||
{"name": "world", "description": ""},
|
||||
],
|
||||
},
|
||||
document_input={"type": "file_id", "value": file_id},
|
||||
)
|
||||
assert exc_info.value.status_code == 404
|
||||
assert f"file with ID {file_id} not found" in exc_info.value.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_split_non_existing_job_id_raises_notfound(
|
||||
server: FakeLlamaCloudServer, client: AsyncLlamaCloud
|
||||
) -> None:
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
job_id = "split-doesnotexist"
|
||||
await client.beta.split.get(job_id)
|
||||
assert exc_info.value.status_code == 404
|
||||
assert f"job with ID {job_id} does not exist" in exc_info.value.message
|
||||
Reference in New Issue
Block a user