refact: bypass message queue and control plane from apiserver (#542)

* depend on workflows explicitly

* temp

* backup

* revert

* backup

* fix test

* e2e

* remove cp and mq

* remove sleeps

* fix unit tests

* fix ui server
This commit is contained in:
Massimiliano Pippi
2025-06-19 08:22:46 +02:00
committed by GitHub
parent 4137ddb317
commit 671295d518
21 changed files with 1743 additions and 1958 deletions
@@ -8,7 +8,7 @@ services:
name: Git Workflow
source:
type: git
name: https://github.com/run-llama/llama_deploy.git
name: https://github.com/run-llama/llama_deploy.git@massi/refact
env:
VAR_1: x # this gets overwritten because VAR_1 also exists in the provided .env
VAR_2: y
@@ -1,13 +1,7 @@
import asyncio
from llama_index.core.workflow import (
Context,
Event,
StartEvent,
StopEvent,
Workflow,
step,
)
from workflows import Context, Workflow, step
from workflows.events import Event, StartEvent, StopEvent
class Message(Event):
@@ -20,7 +14,7 @@ class EchoWorkflow(Workflow):
@step()
async def run_step(self, ctx: Context, ev: StartEvent) -> StopEvent:
for i in range(3):
ctx.write_event_to_stream(Message(text=f"message number {i+1}"))
ctx.write_event_to_stream(Message(text=f"message number {i + 1}"))
await asyncio.sleep(0.5)
return StopEvent(result="Done.")
@@ -1,13 +1,8 @@
import asyncio
import os
from llama_index.core.workflow import (
Context,
StartEvent,
StopEvent,
Workflow,
step,
)
from workflows import Context, Workflow, step
from workflows.events import StartEvent, StopEvent
class MyWorkflow(Workflow):
@@ -18,7 +13,7 @@ class MyWorkflow(Workflow):
api_key = os.environ.get("API_KEY")
return StopEvent(
# result depends on variables read from environment
result=(f"var_1: {var_1}, " f"var_2: {var_2}, " f"api_key: {api_key}")
result=(f"var_1: {var_1}, var_2: {var_2}, api_key: {api_key}")
)
@@ -1,12 +1,9 @@
from llama_index.core.workflow import (
StartEvent,
StopEvent,
Workflow,
step,
)
from llama_index.core.workflow.events import (
from workflows import Workflow, step
from workflows.events import (
HumanResponseEvent,
InputRequiredEvent,
StartEvent,
StopEvent,
)
@@ -20,4 +17,4 @@ class HumanInTheLoopWorkflow(Workflow):
return StopEvent(result=ev.response)
workflow = HumanInTheLoopWorkflow()
workflow = HumanInTheLoopWorkflow(timeout=3)
@@ -1,4 +1,5 @@
from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step
from workflows import Context, Workflow, step
from workflows.events import StartEvent, StopEvent
class EchoWithPrompt(Workflow):
+2 -8
View File
@@ -1,13 +1,7 @@
import asyncio
from llama_index.core.workflow import (
Context,
Event,
StartEvent,
StopEvent,
Workflow,
step,
)
from workflows import Context, Workflow, step
from workflows.events import Event, StartEvent, StopEvent
class Message(Event):
+9 -8
View File
@@ -1,22 +1,23 @@
import asyncio
import json
from pathlib import Path
import pytest
from llama_deploy.types.core import TaskDefinition
@pytest.mark.asyncio
async def test_read_env_vars_git(apiserver, client):
here = Path(__file__).parent
deployment_fp = here / "deployments" / "deployment_env_git.yml"
with open(deployment_fp) as f:
await client.apiserver.deployments.create(f, base_path=deployment_fp.parent)
await asyncio.sleep(5)
deployment = await client.apiserver.deployments.create(
f, base_path=deployment_fp.parent
)
session = await client.core.sessions.create()
# run workflow
result = await session.run(
"workflow_git", env_vars_to_read=["VAR_1", "VAR_2", "API_KEY"]
input_str = json.dumps({"env_vars_to_read": ["VAR_1", "VAR_2", "API_KEY"]})
result = await deployment.tasks.run(
TaskDefinition(service_id="workflow_git", input=input_str)
)
assert result == "VAR_1: x, VAR_2: y, API_KEY: 123"
+8 -7
View File
@@ -1,20 +1,21 @@
import asyncio
from pathlib import Path
import pytest
from llama_deploy.types.core import TaskDefinition
@pytest.mark.asyncio
async def test_read_env_vars_local(apiserver, client):
here = Path(__file__).parent
deployment_fp = here / "deployments" / "deployment_env_local.yml"
with open(deployment_fp) as f:
await client.apiserver.deployments.create(f, base_path=deployment_fp.parent)
await asyncio.sleep(5)
deployment = await client.apiserver.deployments.create(
f, base_path=deployment_fp.parent
)
session = await client.core.sessions.create()
# run workflow
result = await session.run("test_env_workflow")
result = await deployment.tasks.run(
TaskDefinition(service_id="test_env_workflow", input="")
)
assert result == "var_1: z, var_2: y, api_key: 123"
+3 -5
View File
@@ -2,7 +2,7 @@ import asyncio
from pathlib import Path
import pytest
from llama_index.core.workflow.events import HumanResponseEvent
from workflows.events import HumanResponseEvent
from llama_deploy.types import TaskDefinition
@@ -15,16 +15,14 @@ async def test_hitl(apiserver, client):
deployment = await client.apiserver.deployments.create(
f, base_path=deployment_fp.parent
)
await asyncio.sleep(5)
tasks = deployment.tasks
task_handler = await tasks.create(TaskDefinition(input="{}"))
task_handler = await deployment.tasks.create(TaskDefinition(input="{}"))
ev_def = await task_handler.send_event(
ev=HumanResponseEvent(response="42"), service_name="hitl_workflow"
)
# wait for workflow to finish
await asyncio.sleep(2)
await asyncio.sleep(0.1)
result = await task_handler.results()
assert ev_def.service_id == "hitl_workflow"
-3
View File
@@ -1,4 +1,3 @@
import asyncio
from pathlib import Path
import pytest
@@ -14,7 +13,6 @@ async def test_reload(apiserver, client):
deployment = await client.apiserver.deployments.create(
f, base_path=deployment_fp.parent
)
await asyncio.sleep(3)
tasks = deployment.tasks
res = await tasks.run(TaskDefinition(input='{"data": "bar"}'))
@@ -25,7 +23,6 @@ async def test_reload(apiserver, client):
deployment = await client.apiserver.deployments.create(
f, base_path=deployment_fp.parent, reload=True
)
await asyncio.sleep(3)
tasks = deployment.tasks
res = await tasks.run(TaskDefinition(input='{"data": "bar"}'))
+3 -5
View File
@@ -1,4 +1,3 @@
import asyncio
from pathlib import Path
import pytest
@@ -14,13 +13,12 @@ async def test_stream(apiserver, client):
deployment = await client.apiserver.deployments.create(
f, base_path=deployment_fp.parent
)
await asyncio.sleep(5)
tasks = deployment.tasks
task = await tasks.create(TaskDefinition(input='{"a": "b"}'))
task = await deployment.tasks.create(TaskDefinition(input='{"a": "b"}'))
read_events = []
async for ev in task.events():
if "text" in ev:
if ev and "text" in ev:
read_events.append(ev)
assert len(read_events) == 3
# the workflow produces events sequentially, so here we can assume events arrived in order
+1 -1
View File
@@ -2,7 +2,7 @@ import asyncio
import time
import pytest
from llama_index.core.workflow.events import HumanResponseEvent
from workflows.events import HumanResponseEvent
from llama_deploy.client import Client
+4 -7
View File
@@ -1,12 +1,9 @@
from llama_index.core.workflow import (
StartEvent,
StopEvent,
Workflow,
step,
)
from llama_index.core.workflow.events import (
from workflows import Workflow, step
from workflows.events import (
HumanResponseEvent,
InputRequiredEvent,
StartEvent,
StopEvent,
)
+37 -179
View File
@@ -6,31 +6,20 @@ import site
import subprocess
import sys
import tempfile
from asyncio.subprocess import Process
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Type
import httpx
from dotenv import dotenv_values
from tenacity import AsyncRetrying, RetryError, wait_exponential
from workflows import Context, Workflow
from workflows.handler import WorkflowHandler
from llama_deploy.apiserver.source_managers.base import SyncPolicy
from llama_deploy.client import Client
from llama_deploy.control_plane import ControlPlaneServer
from llama_deploy.message_queues import (
AbstractMessageQueue,
KafkaMessageQueue,
RabbitMQMessageQueue,
RedisMessageQueue,
SimpleMessageQueue,
SimpleMessageQueueConfig,
)
from llama_deploy.message_queues.simple import SimpleMessageQueueServer
from llama_deploy.services import WorkflowService, WorkflowServiceConfig
from .deployment_config_parser import (
DeploymentConfig,
MessageQueueConfig,
Service,
SourceType,
)
@@ -76,25 +65,23 @@ class Deployment:
self._deployment_path = (
deployment_path if local else deployment_path / config.name
)
self._queue_client = self._load_message_queue_client(config.message_queue)
self._control_plane_config = config.control_plane
self._control_plane = ControlPlaneServer(
self._queue_client, config=config.control_plane
)
self._client = Client(control_plane_url=config.control_plane.url)
self._default_service: str | None = None
self._running = False
self._service_tasks: list[asyncio.Task] = []
self._service_startup_complete = asyncio.Event()
self._ui_server_process: Process | None = None
# Ready to load services
self._workflow_services: dict[str, WorkflowService] = self._load_services(
config
)
self._workflow_services: dict[str, Workflow] = self._load_services(config)
self._contexts: dict[str, Context] = {}
self._handlers: dict[str, WorkflowHandler] = {}
self._handler_inputs: dict[str, str] = {}
self._config = config
deployment_state.labels(self._name).state("ready")
@property
def default_service(self) -> str | None:
def default_service(self) -> str:
if not self._default_service:
self._default_service = list(self._workflow_services.keys())[0]
return self._default_service
@property
@@ -120,80 +107,27 @@ class Deployment:
"""
self._running = True
# Control Plane
tasks = await self._start_control_plane()
# UI
if self._config.ui:
await self._start_ui_server()
# Start the services. It makes no sense for a deployment to have no services but
# the configuration allows it, so let's be defensive here.
deployment_state.labels(self._name).state("starting_services")
if self._workflow_services:
tasks.append(asyncio.create_task(self._run_services()))
async def reload(self, config: DeploymentConfig) -> None:
# Reset default service, it might change across reloads
self._default_service = None
# Tear down the UI server
self._stop_ui_server()
# Reload the services
self._workflow_services = self._load_services(config)
# UI
if self._config.ui:
await self._start_ui_server()
# Run allthethings
deployment_state.labels(self._name).state("running")
await asyncio.gather(*tasks)
deployment_state.labels(self._name).state("stopped")
self._running = False
def _stop_ui_server(self) -> None:
if self._ui_server_process is None:
return
async def reload(self, config: DeploymentConfig) -> None:
"""Reload this deployment by restarting its services.
The reload process consists in cancelling the services tasks
and rely on the fact that _run_services() will restart them
with the new configuration. This function won't return until
_run_services will trigger the _service_startup_complete signal.
"""
self._workflow_services = self._load_services(config)
self._default_service = config.default_service
for t in self._service_tasks:
# t is awaited in _run_services(), we don't need to await here
t.cancel()
# Hold until _run_services() has restarted all the tasks
await self._service_startup_complete.wait()
async def _start_control_plane(self) -> list[asyncio.Task]:
tasks = []
tasks.append(asyncio.create_task(self._control_plane.launch_server()))
# Wait for the Control Plane to boot
try:
async for attempt in AsyncRetrying(
wait=wait_exponential(min=1, max=10),
):
with attempt:
async with httpx.AsyncClient() as client:
response = await client.get(self._control_plane_config.url)
response.raise_for_status()
except RetryError:
msg = f"Unable to reach Control Plane at {self._control_plane_config.url}"
raise DeploymentError(msg)
return tasks
async def _run_services(self) -> None:
"""Start an asyncio task for each service and gather them.
For the time self._running holds true, the tasks will be restarted
if they are all cancelled. This is to support the reload process
(see reload() for more details).
"""
while self._running:
self._service_tasks = []
# If this is a reload, self._workflow_services contains the updated configurations
for name, wfs in self._workflow_services.items():
logger.debug(f"Starting service {name}")
service_task = asyncio.create_task(wfs.launch_server())
self._service_tasks.append(service_task)
await wfs.register_to_control_plane(self._control_plane_config.url)
# If this is a reload, unblock the reload() function signalling that tasks are up and running
self._service_startup_complete.set()
await asyncio.gather(*self._service_tasks)
self._ui_server_process.terminate()
async def _start_ui_server(self) -> None:
"""Creates WorkflowService instances according to the configuration object."""
@@ -224,7 +158,7 @@ class Deployment:
# Override PORT and force using the one from the deployment.yaml file
env["PORT"] = str(self._config.ui.port)
process = await asyncio.create_subprocess_exec(
self._ui_server_process = await asyncio.create_subprocess_exec(
"pnpm",
"run",
"dev",
@@ -232,9 +166,9 @@ class Deployment:
env=env,
)
print(f"Started Next.js app with PID {process.pid}")
print(f"Started Next.js app with PID {self._ui_server_process.pid}")
def _load_services(self, config: DeploymentConfig) -> dict[str, WorkflowService]:
def _load_services(self, config: DeploymentConfig) -> dict[str, Workflow]:
"""Creates WorkflowService instances according to the configuration object."""
deployment_state.labels(self._name).state("loading_services")
workflow_services = {}
@@ -247,21 +181,10 @@ class Deployment:
# TODO: possibly start the default service if not running already
continue
# FIXME: Momentarily assuming everything is a workflow
if service_config.import_path is None:
msg = "path field in service definition must be set"
raise ValueError(msg)
if service_config.port is None:
# This won't happen if we arrive here from Manager.deploy(), the manager will assign a port
msg = "port field in service definition must be set"
raise ValueError(msg)
if service_config.host is None:
# This won't happen if we arrive here from Manager.deploy(), the manager will assign a host
msg = "host field in service definition must be set"
raise ValueError(msg)
# Sync the service source
service_state.labels(self._name, service_id).state("syncing")
destination = self._deployment_path.resolve()
@@ -285,27 +208,17 @@ class Deployment:
sys.path.append(str(pythonpath))
module = importlib.import_module(module_name)
workflow_services[service_id] = getattr(module, workflow_name)
workflow = getattr(module, workflow_name)
workflow_config = WorkflowServiceConfig(
host=service_config.host,
port=service_config.port,
internal_host="0.0.0.0",
internal_port=service_config.port,
service_name=service_id,
)
workflow_services[service_id] = WorkflowService(
workflow=workflow,
message_queue=self._queue_client,
config=workflow_config,
)
service_state.labels(self._name, service_id).state("ready")
if config.default_service in workflow_services:
self._default_service = config.default_service
else:
msg = f"There is no service with id '{config.default_service}' in this deployment, cannot set default."
logger.warning(msg)
if config.default_service:
if config.default_service in workflow_services:
self._default_service = config.default_service
else:
msg = f"Service with id '{config.default_service}' does not exist, cannot set it as default."
logger.warning(msg)
self._default_service = None
return workflow_services
@@ -428,26 +341,6 @@ class Deployment:
msg = f"Unable to install service dependencies using command '{e.cmd}': {e.stderr}"
raise DeploymentError(msg) from None
def _load_message_queue_client(
self, cfg: MessageQueueConfig | None
) -> AbstractMessageQueue:
# Use the SimpleMessageQueue as the default
if cfg is None:
# we use model_validate instead of __init__ to avoid static checkers complaining over field aliases
cfg = SimpleMessageQueueConfig()
if cfg.type == "kafka":
return KafkaMessageQueue(cfg)
elif cfg.type == "rabbitmq":
return RabbitMQMessageQueue(cfg)
elif cfg.type == "redis":
return RedisMessageQueue(cfg)
elif cfg.type == "simple":
return SimpleMessageQueue(cfg)
else:
msg = f"Unsupported message queue: {cfg.type}"
raise ValueError(msg)
class Manager:
"""The Manager orchestrates deployments and their runtime.
@@ -544,33 +437,6 @@ class Manager:
msg = "Reached the maximum number of deployments, cannot schedule more"
raise ValueError(msg)
# Set the control plane TCP port in the config where not specified
self._assign_control_plane_address(config)
# Get the message queue configuration
msg_queue = config.message_queue or SimpleMessageQueueConfig()
# Spawn SimpleMessageQueue server if needed
if (
isinstance(msg_queue, SimpleMessageQueueConfig)
and self._simple_message_queue_server is None
):
self._simple_message_queue_server = asyncio.create_task(
SimpleMessageQueueServer(msg_queue).launch_server()
)
# the other components need the queue to run in order to start, give the queue some time to start
try:
async for attempt in AsyncRetrying(
wait=wait_exponential(min=1, max=10),
):
with attempt:
async with httpx.AsyncClient() as client:
response = await client.get(msg_queue.base_url)
response.raise_for_status()
except RetryError:
msg = f"Unable to reach SimpleMessageQueueServer at {msg_queue.base_url}"
raise DeploymentError(msg)
deployment = Deployment(
config=config,
base_path=Path(base_path),
@@ -578,7 +444,7 @@ class Manager:
local=local,
)
self._deployments[config.name] = deployment
self._pool.apply_async(func=asyncio.run, args=(deployment.start(),))
await deployment.start()
else:
if config.name not in self._deployments:
msg = f"Cannot find deployment to reload: {config.name}"
@@ -586,11 +452,3 @@ class Manager:
deployment = self._deployments[config.name]
await deployment.reload(config)
def _assign_control_plane_address(self, config: DeploymentConfig) -> None:
for service in config.services.values():
if not service.port:
service.port = self._last_control_plane_port
self._last_control_plane_port += 1
if not service.host:
service.host = "localhost"
+51 -39
View File
@@ -8,6 +8,9 @@ import websockets
from fastapi import APIRouter, File, HTTPException, Request, UploadFile, WebSocket
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.background import BackgroundTask
from workflows import Context
from workflows.context import JsonSerializer
from workflows.handler import WorkflowHandler
from llama_deploy.apiserver.deployment_config_parser import DeploymentConfig
from llama_deploy.apiserver.server import manager
@@ -17,7 +20,7 @@ from llama_deploy.types import (
SessionDefinition,
TaskDefinition,
)
from llama_deploy.types.core import TaskResult
from llama_deploy.types.core import TaskResult, generate_id
deployments_router = APIRouter(
prefix="/deployments",
@@ -76,18 +79,17 @@ async def create_deployment_task(
detail=f"Service '{task_definition.service_id}' not found in deployment 'deployment_name'",
)
workflow = deployment._workflow_services[task_definition.service_id]
if session_id:
session = await deployment.client.core.sessions.get(session_id)
context = deployment._contexts[session_id]
result = await workflow.run(
context=context, **json.loads(task_definition.input)
)
else:
session = await deployment.client.core.sessions.create()
result = await session.run(
task_definition.service_id or "", **json.loads(task_definition.input)
)
# Assume the request does not care about the session if no session_id is provided
if session_id is None:
await deployment.client.core.sessions.delete(session.id)
if task_definition.input:
result = await workflow.run(**json.loads(task_definition.input))
else:
result = await workflow.run()
return JSONResponse(result)
@@ -109,16 +111,20 @@ async def create_deployment_task_nowait(
)
task_definition.service_id = deployment.default_service
workflow = deployment._workflow_services[task_definition.service_id]
if session_id:
session = await deployment.client.core.sessions.get(session_id)
context = deployment._contexts[session_id]
handler = workflow.run(context=context, **json.loads(task_definition.input))
else:
session = await deployment.client.core.sessions.create()
session_id = session.id
handler = workflow.run(**json.loads(task_definition.input))
session_id = generate_id()
deployment._contexts[session_id] = handler.ctx or Context(workflow)
handler_id = generate_id()
deployment._handlers[handler_id] = handler
deployment._handler_inputs[handler_id] = task_definition.input
task_definition.session_id = session_id
task_definition.task_id = await session.run_nowait(
task_definition.service_id or "", **json.loads(task_definition.input)
)
task_definition.task_id = handler_id
return task_definition
@@ -135,9 +141,10 @@ async def send_event(
if deployment is None:
raise HTTPException(status_code=404, detail="Deployment not found")
session = await deployment.client.core.sessions.get(session_id)
await session.send_event_def(task_id=task_id, ev_def=event_def)
ctx = deployment._contexts[session_id]
serializer = JsonSerializer()
event = serializer.deserialize(event_def.event_obj_str)
ctx.send_event(event)
return event_def
@@ -160,18 +167,20 @@ async def get_events(
if deployment is None:
raise HTTPException(status_code=404, detail="Deployment not found")
session = await deployment.client.core.sessions.get(session_id)
async def event_stream() -> AsyncGenerator[str, None]:
async def event_stream(handler: WorkflowHandler) -> AsyncGenerator[str, None]:
serializer = JsonSerializer()
# need to convert back to str to use SSE
async for event in session.get_task_result_stream(task_id):
async for event in handler.stream_events():
data = json.loads(serializer.serialize(event))
if raw_event:
yield json.dumps(event) + "\n"
yield json.dumps(data) + "\n"
else:
yield json.dumps(event.get("value")) + "\n"
yield json.dumps(data.get("value")) + "\n"
await asyncio.sleep(0.01)
await handler
return StreamingResponse(
event_stream(),
event_stream(deployment._handlers[task_id]),
media_type="application/x-ndjson",
)
@@ -185,8 +194,8 @@ async def get_task_result(
if deployment is None:
raise HTTPException(status_code=404, detail="Deployment not found")
session = await deployment.client.core.sessions.get(session_id)
return await session.get_task_result(task_id)
handler = deployment._handlers[task_id]
return TaskResult(task_id=task_id, history=[], result=await handler)
@deployments_router.get("/{deployment_name}/tasks")
@@ -199,9 +208,11 @@ async def get_tasks(
raise HTTPException(status_code=404, detail="Deployment not found")
tasks: list[TaskDefinition] = []
for session in await deployment.client.core.sessions.list():
for task_def in await session.get_tasks():
tasks.append(task_def)
for task_id in deployment._handlers.keys():
tasks.append(
TaskDefinition(task_id=task_id, input=deployment._handler_inputs[task_id])
)
return tasks
@@ -214,8 +225,7 @@ async def get_sessions(
if deployment is None:
raise HTTPException(status_code=404, detail="Deployment not found")
sessions = await deployment.client.core.sessions.list()
return [SessionDefinition(session_id=s.id) for s in sessions]
return [SessionDefinition(session_id=k) for k in deployment._contexts.keys()]
@deployments_router.get("/{deployment_name}/sessions/{session_id}")
@@ -225,8 +235,7 @@ async def get_session(deployment_name: str, session_id: str) -> SessionDefinitio
if deployment is None:
raise HTTPException(status_code=404, detail="Deployment not found")
session = await deployment.client.core.sessions.get(session_id)
return SessionDefinition(session_id=session.id)
return SessionDefinition(session_id=session_id)
@deployments_router.post("/{deployment_name}/sessions/create")
@@ -236,8 +245,11 @@ async def create_session(deployment_name: str) -> SessionDefinition:
if deployment is None:
raise HTTPException(status_code=404, detail="Deployment not found")
session = await deployment.client.core.sessions.create()
return SessionDefinition(session_id=session.id)
workflow = deployment._workflow_services[deployment.default_service]
session_id = generate_id()
deployment._contexts[session_id] = Context(workflow)
return SessionDefinition(session_id=session_id)
@deployments_router.post("/{deployment_name}/sessions/delete")
@@ -247,7 +259,7 @@ async def delete_session(deployment_name: str, session_id: str) -> None:
if deployment is None:
raise HTTPException(status_code=404, detail="Deployment not found")
await deployment.client.core.sessions.delete(session_id)
deployment._contexts.pop(session_id)
async def _ws_proxy(ws: WebSocket, upstream_url: str) -> None:
+12
View File
@@ -79,6 +79,18 @@ class SessionCollection(Collection):
return r.json()
async def get(self, id: str) -> SessionDefinition:
"""Gets a deployment by id."""
get_url = f"{self.client.api_server_url}/deployments/{self.deployment_id}/sessions/{id}"
await self.client.request(
"GET",
get_url,
verify=not self.client.disable_ssl,
timeout=self.client.timeout,
)
model_class = self._prepare(SessionDefinition)
return model_class(client=self.client, id=id)
class Task(Model):
"""A model representing a task belonging to a given session in the given deployment."""
+2 -1
View File
@@ -51,7 +51,8 @@ dependencies = [
"platformdirs>=4.3.6,<5",
"rich>=13.9.4,<14",
"brotli>=1.1.0",
"websockets>=15.0.1"
"websockets>=15.0.1",
"llama-index-workflows>=0.2.1"
]
[project.optional-dependencies]
@@ -1,5 +1,7 @@
import os
from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step
from workflows import Context, Workflow, step
from workflows.events import StartEvent, StopEvent
class MyWorkflow(Workflow):
+70 -47
View File
@@ -130,6 +130,9 @@ def test_run_deployment_task(
) -> None:
deployment = mock.AsyncMock()
deployment.default_service = "TestService"
mocked_workflow = mock.AsyncMock()
mocked_workflow.run.return_value = "foo"
deployment._workflow_services = {"TestService": mocked_workflow}
session = mock.AsyncMock(id="42")
deployment.client.core.sessions.create.return_value = session
@@ -145,7 +148,6 @@ def test_run_deployment_task(
json={"input": "{}"},
)
assert response.status_code == 200
deployment.client.core.sessions.delete.assert_called_with("42")
deployment.reset_mock()
response = http_client.post(
@@ -163,13 +165,15 @@ def test_create_deployment_task(
deployment = mock.AsyncMock()
deployment.default_service = "TestService"
session = mock.AsyncMock(id="42")
deployment.client.core.sessions.create.return_value = session
session.run_nowait.return_value = "test_task_id"
session_from_get = mock.AsyncMock(id="84")
deployment.client.core.sessions.get.return_value = session_from_get
session_from_get.run_nowait.return_value = "another_test_task_id"
# Mock workflow that returns a handler
mock_workflow = mock.MagicMock()
mock_handler = mock.MagicMock()
mock_handler.ctx = mock.MagicMock()
mock_workflow.run.return_value = mock_handler
deployment._workflow_services = {"TestService": mock_workflow}
deployment._handlers = {}
deployment._handler_inputs = {}
deployment._contexts = {"84": mock.MagicMock()} # For session_id test
mock_manager.get_deployment.return_value = deployment
response = http_client.post(
@@ -178,7 +182,7 @@ def test_create_deployment_task(
)
assert response.status_code == 200
td = TaskDefinition(**response.json())
assert td.task_id == "test_task_id"
assert td.task_id is not None
deployment.reset_mock()
response = http_client.post(
@@ -187,7 +191,6 @@ def test_create_deployment_task(
params={"session_id": 84},
)
assert response.status_code == 200
deployment.client.core.sessions.delete.assert_not_called()
def test_send_event_not_found(
@@ -211,9 +214,8 @@ def test_send_event(
) -> None:
deployment = mock.AsyncMock()
deployment.default_service = "TestService"
session = mock.AsyncMock()
deployment.client.core.sessions.create.return_value = session
session.id = "42"
mock_context = mock.MagicMock()
deployment._contexts = {"42": mock_context}
mock_manager.get_deployment.return_value = deployment
serializer = JsonSerializer()
@@ -256,12 +258,23 @@ async def test_get_event_stream(
deployment = mock.AsyncMock()
deployment.default_service = "TestService"
session = mock.MagicMock()
deployment.client.core.sessions.get.return_value = session
# Mock handler that streams events
class MockHandler:
async def stream_events(self): # type:ignore
for event in mock_events:
yield Event(msg=event["value"]["_data"]["msg"])
def __await__(self): # type:ignore
# Make it awaitable
async def await_impl(): # type:ignore
return "completed"
return await_impl().__await__()
mock_handler = MockHandler()
deployment._handlers = {"test_task_id": mock_handler}
mock_manager.get_deployment.return_value = deployment
mocked_get_task_result_stream = mock.MagicMock()
mocked_get_task_result_stream.__aiter__.return_value = mock_events
session.get_task_result_stream.return_value = mocked_get_task_result_stream
response = http_client.get(
"/deployments/test-deployment/tasks/test_task_id/events/?session_id=42",
@@ -272,8 +285,6 @@ async def test_get_event_stream(
data = json.loads(line)
assert data == mock_events[ix].get("value")
ix += 1
deployment.client.core.sessions.get.assert_called_with("42")
session.get_task_result_stream.assert_called_with("test_task_id")
@pytest.mark.asyncio
@@ -289,12 +300,23 @@ async def test_get_event_stream_raw(
deployment = mock.AsyncMock()
deployment.default_service = "TestService"
session = mock.MagicMock()
deployment.client.core.sessions.get.return_value = session
# Mock handler that streams events
class MockHandler:
async def stream_events(self): # type:ignore
for event in mock_events:
yield Event(msg=event["value"]["_data"]["msg"])
def __await__(self): # type:ignore
# Make it awaitable
async def await_impl(): # type:ignore
return "completed"
return await_impl().__await__()
mock_handler = MockHandler()
deployment._handlers = {"test_task_id": mock_handler}
mock_manager.get_deployment.return_value = deployment
mocked_get_task_result_stream = mock.MagicMock()
mocked_get_task_result_stream.__aiter__.return_value = mock_events
session.get_task_result_stream.return_value = mocked_get_task_result_stream
response = http_client.get(
"/deployments/test-deployment/tasks/test_task_id/events/?session_id=42&raw_event=true",
@@ -309,8 +331,6 @@ async def test_get_event_stream_raw(
assert "qualified_name" in data
assert data["qualified_name"] == "llama_index.core.workflow.events.Event"
ix += 1
deployment.client.core.sessions.get.assert_called_with("42")
session.get_task_result_stream.assert_called_with("test_task_id")
def test_get_task_result_not_found(
@@ -337,10 +357,9 @@ def test_get_tasks(
http_client: TestClient, data_path: Path, mock_manager: MagicMock
) -> None:
deployment = mock.AsyncMock()
deployment._handlers = {"task1": mock.MagicMock()}
deployment._handler_inputs = {"task1": "foo"}
mock_manager.get_deployment.return_value = deployment
session = mock.AsyncMock()
session.get_tasks.return_value = [TaskDefinition(input="foo")]
deployment.client.core.sessions.list.return_value = [session]
response = http_client.get(
"/deployments/test-deployment/tasks",
@@ -355,11 +374,18 @@ def test_get_task_result(
) -> None:
deployment = mock.AsyncMock()
deployment.default_service = "TestService"
session = mock.AsyncMock()
deployment.client.core.sessions.get.return_value = session
session.get_task_result.return_value = TaskResult(
result="test_result", history=[], task_id="test_task_id"
)
# Mock the handler to return the expected result - needs to be awaitable
class MockHandler:
def __await__(self): # type:ignore
async def await_impl(): # type:ignore
return "test_result"
return await_impl().__await__()
mock_handler = MockHandler()
deployment._handlers = {"test_task_id": mock_handler}
mock_manager.get_deployment.return_value = deployment
response = http_client.get(
@@ -367,8 +393,6 @@ def test_get_task_result(
)
assert response.status_code == 200
assert TaskResult(**response.json()).result == "test_result"
session.get_task_result.assert_called_with("test_task_id")
deployment.client.core.sessions.get.assert_called_with("42")
def test_get_sessions_not_found(
@@ -386,7 +410,7 @@ def test_get_sessions(
) -> None:
deployment = mock.AsyncMock()
deployment.default_service = "TestService"
deployment.client.list_sessions.return_value = []
deployment._contexts = {} # Empty contexts
mock_manager.get_deployment.return_value = deployment
response = http_client.get(
@@ -412,13 +436,13 @@ def test_delete_session(
) -> None:
deployment = mock.AsyncMock()
deployment.default_service = "TestService"
deployment._contexts = {"42": mock.MagicMock()} # Mock context to be deleted
mock_manager.get_deployment.return_value = deployment
response = http_client.post(
"/deployments/test-deployment/sessions/delete/?session_id=42",
)
assert response.status_code == 200
deployment.client.core.sessions.delete.assert_called_with("42")
def test_get_session_not_found(
@@ -458,8 +482,9 @@ def test_create_session(
http_client: TestClient, data_path: Path, mock_manager: MagicMock
) -> None:
deployment = mock.AsyncMock()
session = mock.AsyncMock(id="test-session-id")
deployment.client.core.sessions.create.return_value = session
deployment.default_service = "TestService"
deployment._workflow_services = {"TestService": mock.MagicMock()}
deployment._contexts = {}
mock_manager.get_deployment.return_value = deployment
response = http_client.post(
@@ -467,15 +492,13 @@ def test_create_session(
)
assert response.status_code == 200
assert response.json() == {
"session_id": "test-session-id",
"state": {},
"task_ids": [],
}
# The response should contain a generated session_id
assert "session_id" in response.json()
assert response.json()["state"] == {}
assert response.json()["task_ids"] == []
# Verify the mocked calls
mock_manager.get_deployment.assert_called_once_with("test-deployment")
deployment.client.core.sessions.create.assert_called_once()
@respx.mock
+9 -158
View File
@@ -1,14 +1,13 @@
import asyncio
from collections.abc import Generator
import subprocess
import sys
from collections.abc import Generator
from copy import deepcopy
from pathlib import Path
from typing import Any
from unittest import mock
import pytest
from tenacity import RetryError
from llama_deploy.apiserver.deployment import (
SOURCE_MANAGERS,
@@ -23,9 +22,7 @@ from llama_deploy.apiserver.deployment_config_parser import (
SyncPolicy,
UIService,
)
from llama_deploy.control_plane import ControlPlaneConfig, ControlPlaneServer
from llama_deploy.message_queues import SimpleMessageQueue
from llama_deploy.message_queues.redis import RedisMessageQueueConfig
from llama_deploy.control_plane import ControlPlaneConfig
@pytest.fixture
@@ -60,11 +57,10 @@ def test_deployment_ctor(data_path: Path, mock_importlib: Any, tmp_path: Path) -
sm_dict["git"].return_value.sync.assert_called_once()
assert d.name == "TestDeployment"
assert d._deployment_path.name == "TestDeployment"
assert type(d._control_plane) is ControlPlaneServer
assert len(d._workflow_services) == 1
assert d.service_names == ["test-workflow"]
assert d.client is not None
assert d.default_service is None
assert d.default_service == "test-workflow"
def test_deployment_ctor_missing_service_path(data_path: Path, tmp_path: Path) -> None:
@@ -76,24 +72,6 @@ def test_deployment_ctor_missing_service_path(data_path: Path, tmp_path: Path) -
Deployment(config=config, base_path=data_path, deployment_path=tmp_path)
def test_deployment_ctor_missing_service_port(data_path: Path, tmp_path: Path) -> None:
config = DeploymentConfig.from_yaml(data_path / "git_service.yaml")
config.services["test-workflow"].port = None
with pytest.raises(
ValueError, match="port field in service definition must be set"
):
Deployment(config=config, base_path=data_path, deployment_path=tmp_path)
def test_deployment_ctor_missing_service_host(data_path: Path, tmp_path: Path) -> None:
config = DeploymentConfig.from_yaml(data_path / "git_service.yaml")
config.services["test-workflow"].host = None
with pytest.raises(
ValueError, match="host field in service definition must be set"
):
Deployment(config=config, base_path=data_path, deployment_path=tmp_path)
def test_deployment_ctor_skip_default_service(
data_path: Path, mock_importlib: Any, tmp_path: Path
) -> None:
@@ -112,10 +90,9 @@ def test_deployment_ctor_invalid_default_service(
config = DeploymentConfig.from_yaml(data_path / "local.yaml")
config.default_service = "does-not-exist"
d = Deployment(config=config, base_path=data_path, deployment_path=tmp_path)
assert d.default_service is None
Deployment(config=config, base_path=data_path, deployment_path=tmp_path)
assert (
"There is no service with id 'does-not-exist' in this deployment, cannot set default."
"Service with id 'does-not-exist' does not exist, cannot set it as default."
in caplog.text
)
@@ -130,41 +107,6 @@ def test_deployment_ctor_default_service(
assert d.default_service == "test-workflow"
def test_deployment___load_message_queue_default(mocked_deployment: Deployment) -> None:
q = mocked_deployment._load_message_queue_client(None)
assert type(q) is SimpleMessageQueue
assert q._config.port == 8001
assert q._config.host == "127.0.0.1"
def test_deployment___load_message_queue_not_supported(
mocked_deployment: Deployment,
) -> None:
mocked_config = mock.MagicMock(queue_type="does_not_exist")
with pytest.raises(ValueError, match="Unsupported message queue:"):
mocked_deployment._load_message_queue_client(mocked_config)
def test_deployment__load_message_queues(mocked_deployment: Deployment) -> None:
with mock.patch("llama_deploy.apiserver.deployment.KafkaMessageQueue") as m:
mocked_config = mock.MagicMock(type="kafka")
mocked_config.model_dump.return_value = {"foo": "kafka"}
mocked_deployment._load_message_queue_client(mocked_config)
m.assert_called_with(mocked_config)
with mock.patch("llama_deploy.apiserver.deployment.RabbitMQMessageQueue") as m:
mocked_config = mock.MagicMock(type="rabbitmq")
mocked_config.model_dump.return_value = {"foo": "rabbitmq"}
mocked_deployment._load_message_queue_client(mocked_config)
m.assert_called_with(mocked_config)
with mock.patch("llama_deploy.apiserver.deployment.RedisMessageQueue") as m:
mocked_config = mock.MagicMock(type="redis")
mocked_config.model_dump.return_value = {"foo": "redis"}
mocked_deployment._load_message_queue_client(mocked_config)
m.assert_called_with(mocked_config)
def test__install_dependencies(data_path: Path) -> None:
config = DeploymentConfig.from_yaml(data_path / "python_dependencies.yaml")
service_config = config.services["myworkflow"]
@@ -530,17 +472,19 @@ async def test_manager_deploy_maximum_reached(data_path: Path) -> None:
@pytest.mark.asyncio
async def test_manager_deploy(data_path: Path) -> None:
config = DeploymentConfig.from_yaml(data_path / "git_service.yaml")
# Do not use SimpleMessageQueue here, to avoid starting the server
config.message_queue = RedisMessageQueueConfig()
with mock.patch(
"llama_deploy.apiserver.deployment.Deployment"
) as mocked_deployment:
# Mock the start method as an async method
mocked_deployment.return_value.start = mock.AsyncMock()
m = Manager()
m._serving = True
m._deployments_path = Path()
await m.deploy(config, base_path=str(data_path))
mocked_deployment.assert_called_once()
mocked_deployment.return_value.start.assert_awaited_once()
assert m.deployment_names == ["TestDeployment"]
assert m.get_deployment("TestDeployment") is not None
@@ -563,82 +507,6 @@ async def test_manager_serve_loop(tmp_path: Path) -> None:
assert serve_task.exception() is None
def test_manager_assign_control_plane_port(data_path: Path) -> None:
m = Manager()
config = DeploymentConfig.from_yaml(data_path / "service_ports.yaml")
m._assign_control_plane_address(config)
assert config.services["no-port"].port == 8002
assert config.services["has-port"].port == 9999
assert config.services["no-port-again"].port == 8003
@pytest.mark.asyncio
async def test_start_control_plane_success(
deployment_config: DeploymentConfig, tmp_path: Path
) -> None:
# Create deployment instance
deployment = Deployment(
config=deployment_config, base_path=Path(), deployment_path=tmp_path
)
# Mock control plane methods
deployment._control_plane.launch_server = mock.AsyncMock() # type: ignore
# Mock httpx client
mock_response = mock.MagicMock()
mock_response.raise_for_status = mock.MagicMock()
mock_client = mock.AsyncMock()
mock_client.__aenter__.return_value.get.return_value = mock_response
with mock.patch("httpx.AsyncClient", return_value=mock_client):
# Run the method
tasks = await deployment._start_control_plane()
# Verify tasks were created
assert len(tasks)
assert all(isinstance(task, asyncio.Task) for task in tasks)
# Verify control plane methods were called
deployment._control_plane.launch_server.assert_called_once()
# Verify health check was performed
mock_client.__aenter__.return_value.get.assert_called_with(
deployment_config.control_plane.url
)
@pytest.mark.asyncio
async def test_start_control_plane_failure(
deployment_config: DeploymentConfig, tmp_path: Path
) -> None:
# Create deployment instance
deployment = Deployment(
config=deployment_config, base_path=Path(), deployment_path=tmp_path
)
# Mock control plane methods
deployment._control_plane.launch_server = mock.AsyncMock() # type: ignore
# Create a mock attempt
mock_attempt: asyncio.Future = asyncio.Future()
mock_attempt.set_exception(Exception("Connection failed"))
# Mock AsyncRetrying to raise an exception
with mock.patch(
"llama_deploy.apiserver.deployment.AsyncRetrying",
side_effect=RetryError(last_attempt=mock_attempt), # type: ignore
):
# Verify DeploymentError is raised
with pytest.raises(DeploymentError) as exc_info:
await deployment._start_control_plane()
assert "Unable to reach Control Plane" in str(exc_info.value)
# Verify control plane methods were still called
deployment._control_plane.launch_server.assert_called_once()
@pytest.mark.asyncio
async def test_start_sequence(
deployment_config: DeploymentConfig, tmp_path: Path
@@ -646,13 +514,8 @@ async def test_start_sequence(
deployment = Deployment(
config=deployment_config, base_path=Path(), deployment_path=tmp_path
)
deployment._start_control_plane = mock.AsyncMock() # type: ignore
deployment._run_services = mock.AsyncMock() # type: ignore
deployment._start_ui_server = mock.AsyncMock() # type: ignore
await deployment.start()
deployment._start_control_plane.assert_awaited_once()
# no services should start
deployment._run_services.assert_not_awaited()
# no ui server
deployment._start_ui_server.assert_not_awaited()
@@ -664,20 +527,12 @@ async def test_start_with_services(data_path: Path, tmp_path: Path) -> None:
deployment = Deployment(
config=config, base_path=data_path, deployment_path=tmp_path
)
deployment._start_control_plane = mock.AsyncMock(return_value=[]) # type: ignore
deployment._run_services = mock.AsyncMock(return_value=[]) # type: ignore
deployment._start_ui_server = mock.AsyncMock(return_value=[]) # type: ignore
sm_dict["git"] = mock.MagicMock()
await deployment.start()
sm_dict["git"].return_value.sync.assert_called_once()
# Verify control plane was started
deployment._start_control_plane.assert_awaited_once()
# Verify services were started
deployment._run_services.assert_awaited_once()
# Verify UI server was not started
deployment._start_ui_server.assert_not_awaited()
@@ -690,15 +545,11 @@ async def test_start_with_services_ui(data_path: Path, tmp_path: Path) -> None:
deployment = Deployment(
config=config, base_path=data_path, deployment_path=tmp_path
)
deployment._start_control_plane = mock.AsyncMock(return_value=[]) # type: ignore
deployment._run_services = mock.AsyncMock(return_value=[]) # type: ignore
deployment._start_ui_server = mock.AsyncMock(return_value=[]) # type: ignore
sm_dict["git"] = mock.MagicMock()
await deployment.start()
deployment._start_control_plane.assert_awaited_once()
deployment._run_services.assert_awaited_once()
deployment._start_ui_server.assert_awaited_once()
Generated
+1515 -1462
View File
File diff suppressed because it is too large Load Diff