mirror of
https://github.com/run-llama/llama_deploy.git
synced 2026-07-01 21:04:00 -04:00
refact: bypass message queue and control plane from apiserver (#542)
* depend on workflows explicitly * temp * backup * revert * backup * fix test * e2e * remove cp and mq * remove sleeps * fix unit tests * fix ui server
This commit is contained in:
committed by
GitHub
parent
4137ddb317
commit
671295d518
@@ -8,7 +8,7 @@ services:
|
||||
name: Git Workflow
|
||||
source:
|
||||
type: git
|
||||
name: https://github.com/run-llama/llama_deploy.git
|
||||
name: https://github.com/run-llama/llama_deploy.git@massi/refact
|
||||
env:
|
||||
VAR_1: x # this gets overwritten because VAR_1 also exists in the provided .env
|
||||
VAR_2: y
|
||||
|
||||
@@ -1,13 +1,7 @@
|
||||
import asyncio
|
||||
|
||||
from llama_index.core.workflow import (
|
||||
Context,
|
||||
Event,
|
||||
StartEvent,
|
||||
StopEvent,
|
||||
Workflow,
|
||||
step,
|
||||
)
|
||||
from workflows import Context, Workflow, step
|
||||
from workflows.events import Event, StartEvent, StopEvent
|
||||
|
||||
|
||||
class Message(Event):
|
||||
@@ -20,7 +14,7 @@ class EchoWorkflow(Workflow):
|
||||
@step()
|
||||
async def run_step(self, ctx: Context, ev: StartEvent) -> StopEvent:
|
||||
for i in range(3):
|
||||
ctx.write_event_to_stream(Message(text=f"message number {i+1}"))
|
||||
ctx.write_event_to_stream(Message(text=f"message number {i + 1}"))
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
return StopEvent(result="Done.")
|
||||
|
||||
@@ -1,13 +1,8 @@
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from llama_index.core.workflow import (
|
||||
Context,
|
||||
StartEvent,
|
||||
StopEvent,
|
||||
Workflow,
|
||||
step,
|
||||
)
|
||||
from workflows import Context, Workflow, step
|
||||
from workflows.events import StartEvent, StopEvent
|
||||
|
||||
|
||||
class MyWorkflow(Workflow):
|
||||
@@ -18,7 +13,7 @@ class MyWorkflow(Workflow):
|
||||
api_key = os.environ.get("API_KEY")
|
||||
return StopEvent(
|
||||
# result depends on variables read from environment
|
||||
result=(f"var_1: {var_1}, " f"var_2: {var_2}, " f"api_key: {api_key}")
|
||||
result=(f"var_1: {var_1}, var_2: {var_2}, api_key: {api_key}")
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
from llama_index.core.workflow import (
|
||||
StartEvent,
|
||||
StopEvent,
|
||||
Workflow,
|
||||
step,
|
||||
)
|
||||
from llama_index.core.workflow.events import (
|
||||
from workflows import Workflow, step
|
||||
from workflows.events import (
|
||||
HumanResponseEvent,
|
||||
InputRequiredEvent,
|
||||
StartEvent,
|
||||
StopEvent,
|
||||
)
|
||||
|
||||
|
||||
@@ -20,4 +17,4 @@ class HumanInTheLoopWorkflow(Workflow):
|
||||
return StopEvent(result=ev.response)
|
||||
|
||||
|
||||
workflow = HumanInTheLoopWorkflow()
|
||||
workflow = HumanInTheLoopWorkflow(timeout=3)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step
|
||||
from workflows import Context, Workflow, step
|
||||
from workflows.events import StartEvent, StopEvent
|
||||
|
||||
|
||||
class EchoWithPrompt(Workflow):
|
||||
|
||||
@@ -1,13 +1,7 @@
|
||||
import asyncio
|
||||
|
||||
from llama_index.core.workflow import (
|
||||
Context,
|
||||
Event,
|
||||
StartEvent,
|
||||
StopEvent,
|
||||
Workflow,
|
||||
step,
|
||||
)
|
||||
from workflows import Context, Workflow, step
|
||||
from workflows.events import Event, StartEvent, StopEvent
|
||||
|
||||
|
||||
class Message(Event):
|
||||
|
||||
@@ -1,22 +1,23 @@
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_deploy.types.core import TaskDefinition
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_env_vars_git(apiserver, client):
|
||||
here = Path(__file__).parent
|
||||
deployment_fp = here / "deployments" / "deployment_env_git.yml"
|
||||
with open(deployment_fp) as f:
|
||||
await client.apiserver.deployments.create(f, base_path=deployment_fp.parent)
|
||||
await asyncio.sleep(5)
|
||||
deployment = await client.apiserver.deployments.create(
|
||||
f, base_path=deployment_fp.parent
|
||||
)
|
||||
|
||||
session = await client.core.sessions.create()
|
||||
|
||||
# run workflow
|
||||
result = await session.run(
|
||||
"workflow_git", env_vars_to_read=["VAR_1", "VAR_2", "API_KEY"]
|
||||
input_str = json.dumps({"env_vars_to_read": ["VAR_1", "VAR_2", "API_KEY"]})
|
||||
result = await deployment.tasks.run(
|
||||
TaskDefinition(service_id="workflow_git", input=input_str)
|
||||
)
|
||||
|
||||
assert result == "VAR_1: x, VAR_2: y, API_KEY: 123"
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_deploy.types.core import TaskDefinition
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_env_vars_local(apiserver, client):
|
||||
here = Path(__file__).parent
|
||||
deployment_fp = here / "deployments" / "deployment_env_local.yml"
|
||||
with open(deployment_fp) as f:
|
||||
await client.apiserver.deployments.create(f, base_path=deployment_fp.parent)
|
||||
await asyncio.sleep(5)
|
||||
deployment = await client.apiserver.deployments.create(
|
||||
f, base_path=deployment_fp.parent
|
||||
)
|
||||
|
||||
session = await client.core.sessions.create()
|
||||
|
||||
# run workflow
|
||||
result = await session.run("test_env_workflow")
|
||||
result = await deployment.tasks.run(
|
||||
TaskDefinition(service_id="test_env_workflow", input="")
|
||||
)
|
||||
|
||||
assert result == "var_1: z, var_2: y, api_key: 123"
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from llama_index.core.workflow.events import HumanResponseEvent
|
||||
from workflows.events import HumanResponseEvent
|
||||
|
||||
from llama_deploy.types import TaskDefinition
|
||||
|
||||
@@ -15,16 +15,14 @@ async def test_hitl(apiserver, client):
|
||||
deployment = await client.apiserver.deployments.create(
|
||||
f, base_path=deployment_fp.parent
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
tasks = deployment.tasks
|
||||
task_handler = await tasks.create(TaskDefinition(input="{}"))
|
||||
task_handler = await deployment.tasks.create(TaskDefinition(input="{}"))
|
||||
ev_def = await task_handler.send_event(
|
||||
ev=HumanResponseEvent(response="42"), service_name="hitl_workflow"
|
||||
)
|
||||
|
||||
# wait for workflow to finish
|
||||
await asyncio.sleep(2)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
result = await task_handler.results()
|
||||
assert ev_def.service_id == "hitl_workflow"
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -14,7 +13,6 @@ async def test_reload(apiserver, client):
|
||||
deployment = await client.apiserver.deployments.create(
|
||||
f, base_path=deployment_fp.parent
|
||||
)
|
||||
await asyncio.sleep(3)
|
||||
|
||||
tasks = deployment.tasks
|
||||
res = await tasks.run(TaskDefinition(input='{"data": "bar"}'))
|
||||
@@ -25,7 +23,6 @@ async def test_reload(apiserver, client):
|
||||
deployment = await client.apiserver.deployments.create(
|
||||
f, base_path=deployment_fp.parent, reload=True
|
||||
)
|
||||
await asyncio.sleep(3)
|
||||
|
||||
tasks = deployment.tasks
|
||||
res = await tasks.run(TaskDefinition(input='{"data": "bar"}'))
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -14,13 +13,12 @@ async def test_stream(apiserver, client):
|
||||
deployment = await client.apiserver.deployments.create(
|
||||
f, base_path=deployment_fp.parent
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
tasks = deployment.tasks
|
||||
task = await tasks.create(TaskDefinition(input='{"a": "b"}'))
|
||||
task = await deployment.tasks.create(TaskDefinition(input='{"a": "b"}'))
|
||||
|
||||
read_events = []
|
||||
async for ev in task.events():
|
||||
if "text" in ev:
|
||||
if ev and "text" in ev:
|
||||
read_events.append(ev)
|
||||
assert len(read_events) == 3
|
||||
# the workflow produces events sequentially, so here we can assume events arrived in order
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from llama_index.core.workflow.events import HumanResponseEvent
|
||||
from workflows.events import HumanResponseEvent
|
||||
|
||||
from llama_deploy.client import Client
|
||||
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
from llama_index.core.workflow import (
|
||||
StartEvent,
|
||||
StopEvent,
|
||||
Workflow,
|
||||
step,
|
||||
)
|
||||
from llama_index.core.workflow.events import (
|
||||
from workflows import Workflow, step
|
||||
from workflows.events import (
|
||||
HumanResponseEvent,
|
||||
InputRequiredEvent,
|
||||
StartEvent,
|
||||
StopEvent,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -6,31 +6,20 @@ import site
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from asyncio.subprocess import Process
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from pathlib import Path
|
||||
from typing import Type
|
||||
|
||||
import httpx
|
||||
from dotenv import dotenv_values
|
||||
from tenacity import AsyncRetrying, RetryError, wait_exponential
|
||||
from workflows import Context, Workflow
|
||||
from workflows.handler import WorkflowHandler
|
||||
|
||||
from llama_deploy.apiserver.source_managers.base import SyncPolicy
|
||||
from llama_deploy.client import Client
|
||||
from llama_deploy.control_plane import ControlPlaneServer
|
||||
from llama_deploy.message_queues import (
|
||||
AbstractMessageQueue,
|
||||
KafkaMessageQueue,
|
||||
RabbitMQMessageQueue,
|
||||
RedisMessageQueue,
|
||||
SimpleMessageQueue,
|
||||
SimpleMessageQueueConfig,
|
||||
)
|
||||
from llama_deploy.message_queues.simple import SimpleMessageQueueServer
|
||||
from llama_deploy.services import WorkflowService, WorkflowServiceConfig
|
||||
|
||||
from .deployment_config_parser import (
|
||||
DeploymentConfig,
|
||||
MessageQueueConfig,
|
||||
Service,
|
||||
SourceType,
|
||||
)
|
||||
@@ -76,25 +65,23 @@ class Deployment:
|
||||
self._deployment_path = (
|
||||
deployment_path if local else deployment_path / config.name
|
||||
)
|
||||
self._queue_client = self._load_message_queue_client(config.message_queue)
|
||||
self._control_plane_config = config.control_plane
|
||||
self._control_plane = ControlPlaneServer(
|
||||
self._queue_client, config=config.control_plane
|
||||
)
|
||||
self._client = Client(control_plane_url=config.control_plane.url)
|
||||
self._default_service: str | None = None
|
||||
self._running = False
|
||||
self._service_tasks: list[asyncio.Task] = []
|
||||
self._service_startup_complete = asyncio.Event()
|
||||
self._ui_server_process: Process | None = None
|
||||
# Ready to load services
|
||||
self._workflow_services: dict[str, WorkflowService] = self._load_services(
|
||||
config
|
||||
)
|
||||
self._workflow_services: dict[str, Workflow] = self._load_services(config)
|
||||
self._contexts: dict[str, Context] = {}
|
||||
self._handlers: dict[str, WorkflowHandler] = {}
|
||||
self._handler_inputs: dict[str, str] = {}
|
||||
self._config = config
|
||||
deployment_state.labels(self._name).state("ready")
|
||||
|
||||
@property
|
||||
def default_service(self) -> str | None:
|
||||
def default_service(self) -> str:
|
||||
if not self._default_service:
|
||||
self._default_service = list(self._workflow_services.keys())[0]
|
||||
return self._default_service
|
||||
|
||||
@property
|
||||
@@ -120,80 +107,27 @@ class Deployment:
|
||||
"""
|
||||
self._running = True
|
||||
|
||||
# Control Plane
|
||||
tasks = await self._start_control_plane()
|
||||
# UI
|
||||
if self._config.ui:
|
||||
await self._start_ui_server()
|
||||
|
||||
# Start the services. It makes no sense for a deployment to have no services but
|
||||
# the configuration allows it, so let's be defensive here.
|
||||
deployment_state.labels(self._name).state("starting_services")
|
||||
if self._workflow_services:
|
||||
tasks.append(asyncio.create_task(self._run_services()))
|
||||
async def reload(self, config: DeploymentConfig) -> None:
|
||||
# Reset default service, it might change across reloads
|
||||
self._default_service = None
|
||||
# Tear down the UI server
|
||||
self._stop_ui_server()
|
||||
# Reload the services
|
||||
self._workflow_services = self._load_services(config)
|
||||
|
||||
# UI
|
||||
if self._config.ui:
|
||||
await self._start_ui_server()
|
||||
|
||||
# Run allthethings
|
||||
deployment_state.labels(self._name).state("running")
|
||||
await asyncio.gather(*tasks)
|
||||
deployment_state.labels(self._name).state("stopped")
|
||||
self._running = False
|
||||
def _stop_ui_server(self) -> None:
|
||||
if self._ui_server_process is None:
|
||||
return
|
||||
|
||||
async def reload(self, config: DeploymentConfig) -> None:
|
||||
"""Reload this deployment by restarting its services.
|
||||
|
||||
The reload process consists in cancelling the services tasks
|
||||
and rely on the fact that _run_services() will restart them
|
||||
with the new configuration. This function won't return until
|
||||
_run_services will trigger the _service_startup_complete signal.
|
||||
"""
|
||||
self._workflow_services = self._load_services(config)
|
||||
self._default_service = config.default_service
|
||||
|
||||
for t in self._service_tasks:
|
||||
# t is awaited in _run_services(), we don't need to await here
|
||||
t.cancel()
|
||||
|
||||
# Hold until _run_services() has restarted all the tasks
|
||||
await self._service_startup_complete.wait()
|
||||
|
||||
async def _start_control_plane(self) -> list[asyncio.Task]:
|
||||
tasks = []
|
||||
tasks.append(asyncio.create_task(self._control_plane.launch_server()))
|
||||
# Wait for the Control Plane to boot
|
||||
try:
|
||||
async for attempt in AsyncRetrying(
|
||||
wait=wait_exponential(min=1, max=10),
|
||||
):
|
||||
with attempt:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(self._control_plane_config.url)
|
||||
response.raise_for_status()
|
||||
except RetryError:
|
||||
msg = f"Unable to reach Control Plane at {self._control_plane_config.url}"
|
||||
raise DeploymentError(msg)
|
||||
|
||||
return tasks
|
||||
|
||||
async def _run_services(self) -> None:
|
||||
"""Start an asyncio task for each service and gather them.
|
||||
|
||||
For the time self._running holds true, the tasks will be restarted
|
||||
if they are all cancelled. This is to support the reload process
|
||||
(see reload() for more details).
|
||||
"""
|
||||
while self._running:
|
||||
self._service_tasks = []
|
||||
# If this is a reload, self._workflow_services contains the updated configurations
|
||||
for name, wfs in self._workflow_services.items():
|
||||
logger.debug(f"Starting service {name}")
|
||||
service_task = asyncio.create_task(wfs.launch_server())
|
||||
self._service_tasks.append(service_task)
|
||||
await wfs.register_to_control_plane(self._control_plane_config.url)
|
||||
|
||||
# If this is a reload, unblock the reload() function signalling that tasks are up and running
|
||||
self._service_startup_complete.set()
|
||||
await asyncio.gather(*self._service_tasks)
|
||||
self._ui_server_process.terminate()
|
||||
|
||||
async def _start_ui_server(self) -> None:
|
||||
"""Creates WorkflowService instances according to the configuration object."""
|
||||
@@ -224,7 +158,7 @@ class Deployment:
|
||||
# Override PORT and force using the one from the deployment.yaml file
|
||||
env["PORT"] = str(self._config.ui.port)
|
||||
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
self._ui_server_process = await asyncio.create_subprocess_exec(
|
||||
"pnpm",
|
||||
"run",
|
||||
"dev",
|
||||
@@ -232,9 +166,9 @@ class Deployment:
|
||||
env=env,
|
||||
)
|
||||
|
||||
print(f"Started Next.js app with PID {process.pid}")
|
||||
print(f"Started Next.js app with PID {self._ui_server_process.pid}")
|
||||
|
||||
def _load_services(self, config: DeploymentConfig) -> dict[str, WorkflowService]:
|
||||
def _load_services(self, config: DeploymentConfig) -> dict[str, Workflow]:
|
||||
"""Creates WorkflowService instances according to the configuration object."""
|
||||
deployment_state.labels(self._name).state("loading_services")
|
||||
workflow_services = {}
|
||||
@@ -247,21 +181,10 @@ class Deployment:
|
||||
# TODO: possibly start the default service if not running already
|
||||
continue
|
||||
|
||||
# FIXME: Momentarily assuming everything is a workflow
|
||||
if service_config.import_path is None:
|
||||
msg = "path field in service definition must be set"
|
||||
raise ValueError(msg)
|
||||
|
||||
if service_config.port is None:
|
||||
# This won't happen if we arrive here from Manager.deploy(), the manager will assign a port
|
||||
msg = "port field in service definition must be set"
|
||||
raise ValueError(msg)
|
||||
|
||||
if service_config.host is None:
|
||||
# This won't happen if we arrive here from Manager.deploy(), the manager will assign a host
|
||||
msg = "host field in service definition must be set"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Sync the service source
|
||||
service_state.labels(self._name, service_id).state("syncing")
|
||||
destination = self._deployment_path.resolve()
|
||||
@@ -285,27 +208,17 @@ class Deployment:
|
||||
sys.path.append(str(pythonpath))
|
||||
|
||||
module = importlib.import_module(module_name)
|
||||
workflow_services[service_id] = getattr(module, workflow_name)
|
||||
|
||||
workflow = getattr(module, workflow_name)
|
||||
workflow_config = WorkflowServiceConfig(
|
||||
host=service_config.host,
|
||||
port=service_config.port,
|
||||
internal_host="0.0.0.0",
|
||||
internal_port=service_config.port,
|
||||
service_name=service_id,
|
||||
)
|
||||
workflow_services[service_id] = WorkflowService(
|
||||
workflow=workflow,
|
||||
message_queue=self._queue_client,
|
||||
config=workflow_config,
|
||||
)
|
||||
service_state.labels(self._name, service_id).state("ready")
|
||||
|
||||
if config.default_service in workflow_services:
|
||||
self._default_service = config.default_service
|
||||
else:
|
||||
msg = f"There is no service with id '{config.default_service}' in this deployment, cannot set default."
|
||||
logger.warning(msg)
|
||||
if config.default_service:
|
||||
if config.default_service in workflow_services:
|
||||
self._default_service = config.default_service
|
||||
else:
|
||||
msg = f"Service with id '{config.default_service}' does not exist, cannot set it as default."
|
||||
logger.warning(msg)
|
||||
self._default_service = None
|
||||
|
||||
return workflow_services
|
||||
|
||||
@@ -428,26 +341,6 @@ class Deployment:
|
||||
msg = f"Unable to install service dependencies using command '{e.cmd}': {e.stderr}"
|
||||
raise DeploymentError(msg) from None
|
||||
|
||||
def _load_message_queue_client(
|
||||
self, cfg: MessageQueueConfig | None
|
||||
) -> AbstractMessageQueue:
|
||||
# Use the SimpleMessageQueue as the default
|
||||
if cfg is None:
|
||||
# we use model_validate instead of __init__ to avoid static checkers complaining over field aliases
|
||||
cfg = SimpleMessageQueueConfig()
|
||||
|
||||
if cfg.type == "kafka":
|
||||
return KafkaMessageQueue(cfg)
|
||||
elif cfg.type == "rabbitmq":
|
||||
return RabbitMQMessageQueue(cfg)
|
||||
elif cfg.type == "redis":
|
||||
return RedisMessageQueue(cfg)
|
||||
elif cfg.type == "simple":
|
||||
return SimpleMessageQueue(cfg)
|
||||
else:
|
||||
msg = f"Unsupported message queue: {cfg.type}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
class Manager:
|
||||
"""The Manager orchestrates deployments and their runtime.
|
||||
@@ -544,33 +437,6 @@ class Manager:
|
||||
msg = "Reached the maximum number of deployments, cannot schedule more"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Set the control plane TCP port in the config where not specified
|
||||
self._assign_control_plane_address(config)
|
||||
|
||||
# Get the message queue configuration
|
||||
msg_queue = config.message_queue or SimpleMessageQueueConfig()
|
||||
|
||||
# Spawn SimpleMessageQueue server if needed
|
||||
if (
|
||||
isinstance(msg_queue, SimpleMessageQueueConfig)
|
||||
and self._simple_message_queue_server is None
|
||||
):
|
||||
self._simple_message_queue_server = asyncio.create_task(
|
||||
SimpleMessageQueueServer(msg_queue).launch_server()
|
||||
)
|
||||
|
||||
# the other components need the queue to run in order to start, give the queue some time to start
|
||||
try:
|
||||
async for attempt in AsyncRetrying(
|
||||
wait=wait_exponential(min=1, max=10),
|
||||
):
|
||||
with attempt:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(msg_queue.base_url)
|
||||
response.raise_for_status()
|
||||
except RetryError:
|
||||
msg = f"Unable to reach SimpleMessageQueueServer at {msg_queue.base_url}"
|
||||
raise DeploymentError(msg)
|
||||
deployment = Deployment(
|
||||
config=config,
|
||||
base_path=Path(base_path),
|
||||
@@ -578,7 +444,7 @@ class Manager:
|
||||
local=local,
|
||||
)
|
||||
self._deployments[config.name] = deployment
|
||||
self._pool.apply_async(func=asyncio.run, args=(deployment.start(),))
|
||||
await deployment.start()
|
||||
else:
|
||||
if config.name not in self._deployments:
|
||||
msg = f"Cannot find deployment to reload: {config.name}"
|
||||
@@ -586,11 +452,3 @@ class Manager:
|
||||
|
||||
deployment = self._deployments[config.name]
|
||||
await deployment.reload(config)
|
||||
|
||||
def _assign_control_plane_address(self, config: DeploymentConfig) -> None:
|
||||
for service in config.services.values():
|
||||
if not service.port:
|
||||
service.port = self._last_control_plane_port
|
||||
self._last_control_plane_port += 1
|
||||
if not service.host:
|
||||
service.host = "localhost"
|
||||
|
||||
@@ -8,6 +8,9 @@ import websockets
|
||||
from fastapi import APIRouter, File, HTTPException, Request, UploadFile, WebSocket
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from starlette.background import BackgroundTask
|
||||
from workflows import Context
|
||||
from workflows.context import JsonSerializer
|
||||
from workflows.handler import WorkflowHandler
|
||||
|
||||
from llama_deploy.apiserver.deployment_config_parser import DeploymentConfig
|
||||
from llama_deploy.apiserver.server import manager
|
||||
@@ -17,7 +20,7 @@ from llama_deploy.types import (
|
||||
SessionDefinition,
|
||||
TaskDefinition,
|
||||
)
|
||||
from llama_deploy.types.core import TaskResult
|
||||
from llama_deploy.types.core import TaskResult, generate_id
|
||||
|
||||
deployments_router = APIRouter(
|
||||
prefix="/deployments",
|
||||
@@ -76,18 +79,17 @@ async def create_deployment_task(
|
||||
detail=f"Service '{task_definition.service_id}' not found in deployment 'deployment_name'",
|
||||
)
|
||||
|
||||
workflow = deployment._workflow_services[task_definition.service_id]
|
||||
if session_id:
|
||||
session = await deployment.client.core.sessions.get(session_id)
|
||||
context = deployment._contexts[session_id]
|
||||
result = await workflow.run(
|
||||
context=context, **json.loads(task_definition.input)
|
||||
)
|
||||
else:
|
||||
session = await deployment.client.core.sessions.create()
|
||||
|
||||
result = await session.run(
|
||||
task_definition.service_id or "", **json.loads(task_definition.input)
|
||||
)
|
||||
|
||||
# Assume the request does not care about the session if no session_id is provided
|
||||
if session_id is None:
|
||||
await deployment.client.core.sessions.delete(session.id)
|
||||
if task_definition.input:
|
||||
result = await workflow.run(**json.loads(task_definition.input))
|
||||
else:
|
||||
result = await workflow.run()
|
||||
|
||||
return JSONResponse(result)
|
||||
|
||||
@@ -109,16 +111,20 @@ async def create_deployment_task_nowait(
|
||||
)
|
||||
task_definition.service_id = deployment.default_service
|
||||
|
||||
workflow = deployment._workflow_services[task_definition.service_id]
|
||||
if session_id:
|
||||
session = await deployment.client.core.sessions.get(session_id)
|
||||
context = deployment._contexts[session_id]
|
||||
handler = workflow.run(context=context, **json.loads(task_definition.input))
|
||||
else:
|
||||
session = await deployment.client.core.sessions.create()
|
||||
session_id = session.id
|
||||
handler = workflow.run(**json.loads(task_definition.input))
|
||||
session_id = generate_id()
|
||||
deployment._contexts[session_id] = handler.ctx or Context(workflow)
|
||||
|
||||
handler_id = generate_id()
|
||||
deployment._handlers[handler_id] = handler
|
||||
deployment._handler_inputs[handler_id] = task_definition.input
|
||||
task_definition.session_id = session_id
|
||||
task_definition.task_id = await session.run_nowait(
|
||||
task_definition.service_id or "", **json.loads(task_definition.input)
|
||||
)
|
||||
task_definition.task_id = handler_id
|
||||
|
||||
return task_definition
|
||||
|
||||
@@ -135,9 +141,10 @@ async def send_event(
|
||||
if deployment is None:
|
||||
raise HTTPException(status_code=404, detail="Deployment not found")
|
||||
|
||||
session = await deployment.client.core.sessions.get(session_id)
|
||||
|
||||
await session.send_event_def(task_id=task_id, ev_def=event_def)
|
||||
ctx = deployment._contexts[session_id]
|
||||
serializer = JsonSerializer()
|
||||
event = serializer.deserialize(event_def.event_obj_str)
|
||||
ctx.send_event(event)
|
||||
|
||||
return event_def
|
||||
|
||||
@@ -160,18 +167,20 @@ async def get_events(
|
||||
if deployment is None:
|
||||
raise HTTPException(status_code=404, detail="Deployment not found")
|
||||
|
||||
session = await deployment.client.core.sessions.get(session_id)
|
||||
|
||||
async def event_stream() -> AsyncGenerator[str, None]:
|
||||
async def event_stream(handler: WorkflowHandler) -> AsyncGenerator[str, None]:
|
||||
serializer = JsonSerializer()
|
||||
# need to convert back to str to use SSE
|
||||
async for event in session.get_task_result_stream(task_id):
|
||||
async for event in handler.stream_events():
|
||||
data = json.loads(serializer.serialize(event))
|
||||
if raw_event:
|
||||
yield json.dumps(event) + "\n"
|
||||
yield json.dumps(data) + "\n"
|
||||
else:
|
||||
yield json.dumps(event.get("value")) + "\n"
|
||||
yield json.dumps(data.get("value")) + "\n"
|
||||
await asyncio.sleep(0.01)
|
||||
await handler
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
event_stream(deployment._handlers[task_id]),
|
||||
media_type="application/x-ndjson",
|
||||
)
|
||||
|
||||
@@ -185,8 +194,8 @@ async def get_task_result(
|
||||
if deployment is None:
|
||||
raise HTTPException(status_code=404, detail="Deployment not found")
|
||||
|
||||
session = await deployment.client.core.sessions.get(session_id)
|
||||
return await session.get_task_result(task_id)
|
||||
handler = deployment._handlers[task_id]
|
||||
return TaskResult(task_id=task_id, history=[], result=await handler)
|
||||
|
||||
|
||||
@deployments_router.get("/{deployment_name}/tasks")
|
||||
@@ -199,9 +208,11 @@ async def get_tasks(
|
||||
raise HTTPException(status_code=404, detail="Deployment not found")
|
||||
|
||||
tasks: list[TaskDefinition] = []
|
||||
for session in await deployment.client.core.sessions.list():
|
||||
for task_def in await session.get_tasks():
|
||||
tasks.append(task_def)
|
||||
for task_id in deployment._handlers.keys():
|
||||
tasks.append(
|
||||
TaskDefinition(task_id=task_id, input=deployment._handler_inputs[task_id])
|
||||
)
|
||||
|
||||
return tasks
|
||||
|
||||
|
||||
@@ -214,8 +225,7 @@ async def get_sessions(
|
||||
if deployment is None:
|
||||
raise HTTPException(status_code=404, detail="Deployment not found")
|
||||
|
||||
sessions = await deployment.client.core.sessions.list()
|
||||
return [SessionDefinition(session_id=s.id) for s in sessions]
|
||||
return [SessionDefinition(session_id=k) for k in deployment._contexts.keys()]
|
||||
|
||||
|
||||
@deployments_router.get("/{deployment_name}/sessions/{session_id}")
|
||||
@@ -225,8 +235,7 @@ async def get_session(deployment_name: str, session_id: str) -> SessionDefinitio
|
||||
if deployment is None:
|
||||
raise HTTPException(status_code=404, detail="Deployment not found")
|
||||
|
||||
session = await deployment.client.core.sessions.get(session_id)
|
||||
return SessionDefinition(session_id=session.id)
|
||||
return SessionDefinition(session_id=session_id)
|
||||
|
||||
|
||||
@deployments_router.post("/{deployment_name}/sessions/create")
|
||||
@@ -236,8 +245,11 @@ async def create_session(deployment_name: str) -> SessionDefinition:
|
||||
if deployment is None:
|
||||
raise HTTPException(status_code=404, detail="Deployment not found")
|
||||
|
||||
session = await deployment.client.core.sessions.create()
|
||||
return SessionDefinition(session_id=session.id)
|
||||
workflow = deployment._workflow_services[deployment.default_service]
|
||||
session_id = generate_id()
|
||||
deployment._contexts[session_id] = Context(workflow)
|
||||
|
||||
return SessionDefinition(session_id=session_id)
|
||||
|
||||
|
||||
@deployments_router.post("/{deployment_name}/sessions/delete")
|
||||
@@ -247,7 +259,7 @@ async def delete_session(deployment_name: str, session_id: str) -> None:
|
||||
if deployment is None:
|
||||
raise HTTPException(status_code=404, detail="Deployment not found")
|
||||
|
||||
await deployment.client.core.sessions.delete(session_id)
|
||||
deployment._contexts.pop(session_id)
|
||||
|
||||
|
||||
async def _ws_proxy(ws: WebSocket, upstream_url: str) -> None:
|
||||
|
||||
@@ -79,6 +79,18 @@ class SessionCollection(Collection):
|
||||
|
||||
return r.json()
|
||||
|
||||
async def get(self, id: str) -> SessionDefinition:
|
||||
"""Gets a deployment by id."""
|
||||
get_url = f"{self.client.api_server_url}/deployments/{self.deployment_id}/sessions/{id}"
|
||||
await self.client.request(
|
||||
"GET",
|
||||
get_url,
|
||||
verify=not self.client.disable_ssl,
|
||||
timeout=self.client.timeout,
|
||||
)
|
||||
model_class = self._prepare(SessionDefinition)
|
||||
return model_class(client=self.client, id=id)
|
||||
|
||||
|
||||
class Task(Model):
|
||||
"""A model representing a task belonging to a given session in the given deployment."""
|
||||
|
||||
+2
-1
@@ -51,7 +51,8 @@ dependencies = [
|
||||
"platformdirs>=4.3.6,<5",
|
||||
"rich>=13.9.4,<14",
|
||||
"brotli>=1.1.0",
|
||||
"websockets>=15.0.1"
|
||||
"websockets>=15.0.1",
|
||||
"llama-index-workflows>=0.2.1"
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step
|
||||
|
||||
from workflows import Context, Workflow, step
|
||||
from workflows.events import StartEvent, StopEvent
|
||||
|
||||
|
||||
class MyWorkflow(Workflow):
|
||||
|
||||
@@ -130,6 +130,9 @@ def test_run_deployment_task(
|
||||
) -> None:
|
||||
deployment = mock.AsyncMock()
|
||||
deployment.default_service = "TestService"
|
||||
mocked_workflow = mock.AsyncMock()
|
||||
mocked_workflow.run.return_value = "foo"
|
||||
deployment._workflow_services = {"TestService": mocked_workflow}
|
||||
|
||||
session = mock.AsyncMock(id="42")
|
||||
deployment.client.core.sessions.create.return_value = session
|
||||
@@ -145,7 +148,6 @@ def test_run_deployment_task(
|
||||
json={"input": "{}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
deployment.client.core.sessions.delete.assert_called_with("42")
|
||||
|
||||
deployment.reset_mock()
|
||||
response = http_client.post(
|
||||
@@ -163,13 +165,15 @@ def test_create_deployment_task(
|
||||
deployment = mock.AsyncMock()
|
||||
deployment.default_service = "TestService"
|
||||
|
||||
session = mock.AsyncMock(id="42")
|
||||
deployment.client.core.sessions.create.return_value = session
|
||||
session.run_nowait.return_value = "test_task_id"
|
||||
|
||||
session_from_get = mock.AsyncMock(id="84")
|
||||
deployment.client.core.sessions.get.return_value = session_from_get
|
||||
session_from_get.run_nowait.return_value = "another_test_task_id"
|
||||
# Mock workflow that returns a handler
|
||||
mock_workflow = mock.MagicMock()
|
||||
mock_handler = mock.MagicMock()
|
||||
mock_handler.ctx = mock.MagicMock()
|
||||
mock_workflow.run.return_value = mock_handler
|
||||
deployment._workflow_services = {"TestService": mock_workflow}
|
||||
deployment._handlers = {}
|
||||
deployment._handler_inputs = {}
|
||||
deployment._contexts = {"84": mock.MagicMock()} # For session_id test
|
||||
|
||||
mock_manager.get_deployment.return_value = deployment
|
||||
response = http_client.post(
|
||||
@@ -178,7 +182,7 @@ def test_create_deployment_task(
|
||||
)
|
||||
assert response.status_code == 200
|
||||
td = TaskDefinition(**response.json())
|
||||
assert td.task_id == "test_task_id"
|
||||
assert td.task_id is not None
|
||||
|
||||
deployment.reset_mock()
|
||||
response = http_client.post(
|
||||
@@ -187,7 +191,6 @@ def test_create_deployment_task(
|
||||
params={"session_id": 84},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
deployment.client.core.sessions.delete.assert_not_called()
|
||||
|
||||
|
||||
def test_send_event_not_found(
|
||||
@@ -211,9 +214,8 @@ def test_send_event(
|
||||
) -> None:
|
||||
deployment = mock.AsyncMock()
|
||||
deployment.default_service = "TestService"
|
||||
session = mock.AsyncMock()
|
||||
deployment.client.core.sessions.create.return_value = session
|
||||
session.id = "42"
|
||||
mock_context = mock.MagicMock()
|
||||
deployment._contexts = {"42": mock_context}
|
||||
mock_manager.get_deployment.return_value = deployment
|
||||
|
||||
serializer = JsonSerializer()
|
||||
@@ -256,12 +258,23 @@ async def test_get_event_stream(
|
||||
|
||||
deployment = mock.AsyncMock()
|
||||
deployment.default_service = "TestService"
|
||||
session = mock.MagicMock()
|
||||
deployment.client.core.sessions.get.return_value = session
|
||||
|
||||
# Mock handler that streams events
|
||||
class MockHandler:
|
||||
async def stream_events(self): # type:ignore
|
||||
for event in mock_events:
|
||||
yield Event(msg=event["value"]["_data"]["msg"])
|
||||
|
||||
def __await__(self): # type:ignore
|
||||
# Make it awaitable
|
||||
async def await_impl(): # type:ignore
|
||||
return "completed"
|
||||
|
||||
return await_impl().__await__()
|
||||
|
||||
mock_handler = MockHandler()
|
||||
deployment._handlers = {"test_task_id": mock_handler}
|
||||
mock_manager.get_deployment.return_value = deployment
|
||||
mocked_get_task_result_stream = mock.MagicMock()
|
||||
mocked_get_task_result_stream.__aiter__.return_value = mock_events
|
||||
session.get_task_result_stream.return_value = mocked_get_task_result_stream
|
||||
|
||||
response = http_client.get(
|
||||
"/deployments/test-deployment/tasks/test_task_id/events/?session_id=42",
|
||||
@@ -272,8 +285,6 @@ async def test_get_event_stream(
|
||||
data = json.loads(line)
|
||||
assert data == mock_events[ix].get("value")
|
||||
ix += 1
|
||||
deployment.client.core.sessions.get.assert_called_with("42")
|
||||
session.get_task_result_stream.assert_called_with("test_task_id")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -289,12 +300,23 @@ async def test_get_event_stream_raw(
|
||||
|
||||
deployment = mock.AsyncMock()
|
||||
deployment.default_service = "TestService"
|
||||
session = mock.MagicMock()
|
||||
deployment.client.core.sessions.get.return_value = session
|
||||
|
||||
# Mock handler that streams events
|
||||
class MockHandler:
|
||||
async def stream_events(self): # type:ignore
|
||||
for event in mock_events:
|
||||
yield Event(msg=event["value"]["_data"]["msg"])
|
||||
|
||||
def __await__(self): # type:ignore
|
||||
# Make it awaitable
|
||||
async def await_impl(): # type:ignore
|
||||
return "completed"
|
||||
|
||||
return await_impl().__await__()
|
||||
|
||||
mock_handler = MockHandler()
|
||||
deployment._handlers = {"test_task_id": mock_handler}
|
||||
mock_manager.get_deployment.return_value = deployment
|
||||
mocked_get_task_result_stream = mock.MagicMock()
|
||||
mocked_get_task_result_stream.__aiter__.return_value = mock_events
|
||||
session.get_task_result_stream.return_value = mocked_get_task_result_stream
|
||||
|
||||
response = http_client.get(
|
||||
"/deployments/test-deployment/tasks/test_task_id/events/?session_id=42&raw_event=true",
|
||||
@@ -309,8 +331,6 @@ async def test_get_event_stream_raw(
|
||||
assert "qualified_name" in data
|
||||
assert data["qualified_name"] == "llama_index.core.workflow.events.Event"
|
||||
ix += 1
|
||||
deployment.client.core.sessions.get.assert_called_with("42")
|
||||
session.get_task_result_stream.assert_called_with("test_task_id")
|
||||
|
||||
|
||||
def test_get_task_result_not_found(
|
||||
@@ -337,10 +357,9 @@ def test_get_tasks(
|
||||
http_client: TestClient, data_path: Path, mock_manager: MagicMock
|
||||
) -> None:
|
||||
deployment = mock.AsyncMock()
|
||||
deployment._handlers = {"task1": mock.MagicMock()}
|
||||
deployment._handler_inputs = {"task1": "foo"}
|
||||
mock_manager.get_deployment.return_value = deployment
|
||||
session = mock.AsyncMock()
|
||||
session.get_tasks.return_value = [TaskDefinition(input="foo")]
|
||||
deployment.client.core.sessions.list.return_value = [session]
|
||||
|
||||
response = http_client.get(
|
||||
"/deployments/test-deployment/tasks",
|
||||
@@ -355,11 +374,18 @@ def test_get_task_result(
|
||||
) -> None:
|
||||
deployment = mock.AsyncMock()
|
||||
deployment.default_service = "TestService"
|
||||
session = mock.AsyncMock()
|
||||
deployment.client.core.sessions.get.return_value = session
|
||||
session.get_task_result.return_value = TaskResult(
|
||||
result="test_result", history=[], task_id="test_task_id"
|
||||
)
|
||||
|
||||
# Mock the handler to return the expected result - needs to be awaitable
|
||||
class MockHandler:
|
||||
def __await__(self): # type:ignore
|
||||
async def await_impl(): # type:ignore
|
||||
return "test_result"
|
||||
|
||||
return await_impl().__await__()
|
||||
|
||||
mock_handler = MockHandler()
|
||||
deployment._handlers = {"test_task_id": mock_handler}
|
||||
|
||||
mock_manager.get_deployment.return_value = deployment
|
||||
|
||||
response = http_client.get(
|
||||
@@ -367,8 +393,6 @@ def test_get_task_result(
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert TaskResult(**response.json()).result == "test_result"
|
||||
session.get_task_result.assert_called_with("test_task_id")
|
||||
deployment.client.core.sessions.get.assert_called_with("42")
|
||||
|
||||
|
||||
def test_get_sessions_not_found(
|
||||
@@ -386,7 +410,7 @@ def test_get_sessions(
|
||||
) -> None:
|
||||
deployment = mock.AsyncMock()
|
||||
deployment.default_service = "TestService"
|
||||
deployment.client.list_sessions.return_value = []
|
||||
deployment._contexts = {} # Empty contexts
|
||||
mock_manager.get_deployment.return_value = deployment
|
||||
|
||||
response = http_client.get(
|
||||
@@ -412,13 +436,13 @@ def test_delete_session(
|
||||
) -> None:
|
||||
deployment = mock.AsyncMock()
|
||||
deployment.default_service = "TestService"
|
||||
deployment._contexts = {"42": mock.MagicMock()} # Mock context to be deleted
|
||||
mock_manager.get_deployment.return_value = deployment
|
||||
|
||||
response = http_client.post(
|
||||
"/deployments/test-deployment/sessions/delete/?session_id=42",
|
||||
)
|
||||
assert response.status_code == 200
|
||||
deployment.client.core.sessions.delete.assert_called_with("42")
|
||||
|
||||
|
||||
def test_get_session_not_found(
|
||||
@@ -458,8 +482,9 @@ def test_create_session(
|
||||
http_client: TestClient, data_path: Path, mock_manager: MagicMock
|
||||
) -> None:
|
||||
deployment = mock.AsyncMock()
|
||||
session = mock.AsyncMock(id="test-session-id")
|
||||
deployment.client.core.sessions.create.return_value = session
|
||||
deployment.default_service = "TestService"
|
||||
deployment._workflow_services = {"TestService": mock.MagicMock()}
|
||||
deployment._contexts = {}
|
||||
mock_manager.get_deployment.return_value = deployment
|
||||
|
||||
response = http_client.post(
|
||||
@@ -467,15 +492,13 @@ def test_create_session(
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"session_id": "test-session-id",
|
||||
"state": {},
|
||||
"task_ids": [],
|
||||
}
|
||||
# The response should contain a generated session_id
|
||||
assert "session_id" in response.json()
|
||||
assert response.json()["state"] == {}
|
||||
assert response.json()["task_ids"] == []
|
||||
|
||||
# Verify the mocked calls
|
||||
mock_manager.get_deployment.assert_called_once_with("test-deployment")
|
||||
deployment.client.core.sessions.create.assert_called_once()
|
||||
|
||||
|
||||
@respx.mock
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import asyncio
|
||||
from collections.abc import Generator
|
||||
import subprocess
|
||||
import sys
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from tenacity import RetryError
|
||||
|
||||
from llama_deploy.apiserver.deployment import (
|
||||
SOURCE_MANAGERS,
|
||||
@@ -23,9 +22,7 @@ from llama_deploy.apiserver.deployment_config_parser import (
|
||||
SyncPolicy,
|
||||
UIService,
|
||||
)
|
||||
from llama_deploy.control_plane import ControlPlaneConfig, ControlPlaneServer
|
||||
from llama_deploy.message_queues import SimpleMessageQueue
|
||||
from llama_deploy.message_queues.redis import RedisMessageQueueConfig
|
||||
from llama_deploy.control_plane import ControlPlaneConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -60,11 +57,10 @@ def test_deployment_ctor(data_path: Path, mock_importlib: Any, tmp_path: Path) -
|
||||
sm_dict["git"].return_value.sync.assert_called_once()
|
||||
assert d.name == "TestDeployment"
|
||||
assert d._deployment_path.name == "TestDeployment"
|
||||
assert type(d._control_plane) is ControlPlaneServer
|
||||
assert len(d._workflow_services) == 1
|
||||
assert d.service_names == ["test-workflow"]
|
||||
assert d.client is not None
|
||||
assert d.default_service is None
|
||||
assert d.default_service == "test-workflow"
|
||||
|
||||
|
||||
def test_deployment_ctor_missing_service_path(data_path: Path, tmp_path: Path) -> None:
|
||||
@@ -76,24 +72,6 @@ def test_deployment_ctor_missing_service_path(data_path: Path, tmp_path: Path) -
|
||||
Deployment(config=config, base_path=data_path, deployment_path=tmp_path)
|
||||
|
||||
|
||||
def test_deployment_ctor_missing_service_port(data_path: Path, tmp_path: Path) -> None:
|
||||
config = DeploymentConfig.from_yaml(data_path / "git_service.yaml")
|
||||
config.services["test-workflow"].port = None
|
||||
with pytest.raises(
|
||||
ValueError, match="port field in service definition must be set"
|
||||
):
|
||||
Deployment(config=config, base_path=data_path, deployment_path=tmp_path)
|
||||
|
||||
|
||||
def test_deployment_ctor_missing_service_host(data_path: Path, tmp_path: Path) -> None:
|
||||
config = DeploymentConfig.from_yaml(data_path / "git_service.yaml")
|
||||
config.services["test-workflow"].host = None
|
||||
with pytest.raises(
|
||||
ValueError, match="host field in service definition must be set"
|
||||
):
|
||||
Deployment(config=config, base_path=data_path, deployment_path=tmp_path)
|
||||
|
||||
|
||||
def test_deployment_ctor_skip_default_service(
|
||||
data_path: Path, mock_importlib: Any, tmp_path: Path
|
||||
) -> None:
|
||||
@@ -112,10 +90,9 @@ def test_deployment_ctor_invalid_default_service(
|
||||
config = DeploymentConfig.from_yaml(data_path / "local.yaml")
|
||||
config.default_service = "does-not-exist"
|
||||
|
||||
d = Deployment(config=config, base_path=data_path, deployment_path=tmp_path)
|
||||
assert d.default_service is None
|
||||
Deployment(config=config, base_path=data_path, deployment_path=tmp_path)
|
||||
assert (
|
||||
"There is no service with id 'does-not-exist' in this deployment, cannot set default."
|
||||
"Service with id 'does-not-exist' does not exist, cannot set it as default."
|
||||
in caplog.text
|
||||
)
|
||||
|
||||
@@ -130,41 +107,6 @@ def test_deployment_ctor_default_service(
|
||||
assert d.default_service == "test-workflow"
|
||||
|
||||
|
||||
def test_deployment___load_message_queue_default(mocked_deployment: Deployment) -> None:
|
||||
q = mocked_deployment._load_message_queue_client(None)
|
||||
assert type(q) is SimpleMessageQueue
|
||||
assert q._config.port == 8001
|
||||
assert q._config.host == "127.0.0.1"
|
||||
|
||||
|
||||
def test_deployment___load_message_queue_not_supported(
|
||||
mocked_deployment: Deployment,
|
||||
) -> None:
|
||||
mocked_config = mock.MagicMock(queue_type="does_not_exist")
|
||||
with pytest.raises(ValueError, match="Unsupported message queue:"):
|
||||
mocked_deployment._load_message_queue_client(mocked_config)
|
||||
|
||||
|
||||
def test_deployment__load_message_queues(mocked_deployment: Deployment) -> None:
|
||||
with mock.patch("llama_deploy.apiserver.deployment.KafkaMessageQueue") as m:
|
||||
mocked_config = mock.MagicMock(type="kafka")
|
||||
mocked_config.model_dump.return_value = {"foo": "kafka"}
|
||||
mocked_deployment._load_message_queue_client(mocked_config)
|
||||
m.assert_called_with(mocked_config)
|
||||
|
||||
with mock.patch("llama_deploy.apiserver.deployment.RabbitMQMessageQueue") as m:
|
||||
mocked_config = mock.MagicMock(type="rabbitmq")
|
||||
mocked_config.model_dump.return_value = {"foo": "rabbitmq"}
|
||||
mocked_deployment._load_message_queue_client(mocked_config)
|
||||
m.assert_called_with(mocked_config)
|
||||
|
||||
with mock.patch("llama_deploy.apiserver.deployment.RedisMessageQueue") as m:
|
||||
mocked_config = mock.MagicMock(type="redis")
|
||||
mocked_config.model_dump.return_value = {"foo": "redis"}
|
||||
mocked_deployment._load_message_queue_client(mocked_config)
|
||||
m.assert_called_with(mocked_config)
|
||||
|
||||
|
||||
def test__install_dependencies(data_path: Path) -> None:
|
||||
config = DeploymentConfig.from_yaml(data_path / "python_dependencies.yaml")
|
||||
service_config = config.services["myworkflow"]
|
||||
@@ -530,17 +472,19 @@ async def test_manager_deploy_maximum_reached(data_path: Path) -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_deploy(data_path: Path) -> None:
|
||||
config = DeploymentConfig.from_yaml(data_path / "git_service.yaml")
|
||||
# Do not use SimpleMessageQueue here, to avoid starting the server
|
||||
config.message_queue = RedisMessageQueueConfig()
|
||||
|
||||
with mock.patch(
|
||||
"llama_deploy.apiserver.deployment.Deployment"
|
||||
) as mocked_deployment:
|
||||
# Mock the start method as an async method
|
||||
mocked_deployment.return_value.start = mock.AsyncMock()
|
||||
|
||||
m = Manager()
|
||||
m._serving = True
|
||||
m._deployments_path = Path()
|
||||
await m.deploy(config, base_path=str(data_path))
|
||||
mocked_deployment.assert_called_once()
|
||||
mocked_deployment.return_value.start.assert_awaited_once()
|
||||
assert m.deployment_names == ["TestDeployment"]
|
||||
assert m.get_deployment("TestDeployment") is not None
|
||||
|
||||
@@ -563,82 +507,6 @@ async def test_manager_serve_loop(tmp_path: Path) -> None:
|
||||
assert serve_task.exception() is None
|
||||
|
||||
|
||||
def test_manager_assign_control_plane_port(data_path: Path) -> None:
|
||||
m = Manager()
|
||||
config = DeploymentConfig.from_yaml(data_path / "service_ports.yaml")
|
||||
m._assign_control_plane_address(config)
|
||||
assert config.services["no-port"].port == 8002
|
||||
assert config.services["has-port"].port == 9999
|
||||
assert config.services["no-port-again"].port == 8003
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_control_plane_success(
|
||||
deployment_config: DeploymentConfig, tmp_path: Path
|
||||
) -> None:
|
||||
# Create deployment instance
|
||||
deployment = Deployment(
|
||||
config=deployment_config, base_path=Path(), deployment_path=tmp_path
|
||||
)
|
||||
|
||||
# Mock control plane methods
|
||||
deployment._control_plane.launch_server = mock.AsyncMock() # type: ignore
|
||||
|
||||
# Mock httpx client
|
||||
mock_response = mock.MagicMock()
|
||||
mock_response.raise_for_status = mock.MagicMock()
|
||||
|
||||
mock_client = mock.AsyncMock()
|
||||
mock_client.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
with mock.patch("httpx.AsyncClient", return_value=mock_client):
|
||||
# Run the method
|
||||
tasks = await deployment._start_control_plane()
|
||||
|
||||
# Verify tasks were created
|
||||
assert len(tasks)
|
||||
assert all(isinstance(task, asyncio.Task) for task in tasks)
|
||||
|
||||
# Verify control plane methods were called
|
||||
deployment._control_plane.launch_server.assert_called_once()
|
||||
|
||||
# Verify health check was performed
|
||||
mock_client.__aenter__.return_value.get.assert_called_with(
|
||||
deployment_config.control_plane.url
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_control_plane_failure(
|
||||
deployment_config: DeploymentConfig, tmp_path: Path
|
||||
) -> None:
|
||||
# Create deployment instance
|
||||
deployment = Deployment(
|
||||
config=deployment_config, base_path=Path(), deployment_path=tmp_path
|
||||
)
|
||||
|
||||
# Mock control plane methods
|
||||
deployment._control_plane.launch_server = mock.AsyncMock() # type: ignore
|
||||
|
||||
# Create a mock attempt
|
||||
mock_attempt: asyncio.Future = asyncio.Future()
|
||||
mock_attempt.set_exception(Exception("Connection failed"))
|
||||
|
||||
# Mock AsyncRetrying to raise an exception
|
||||
with mock.patch(
|
||||
"llama_deploy.apiserver.deployment.AsyncRetrying",
|
||||
side_effect=RetryError(last_attempt=mock_attempt), # type: ignore
|
||||
):
|
||||
# Verify DeploymentError is raised
|
||||
with pytest.raises(DeploymentError) as exc_info:
|
||||
await deployment._start_control_plane()
|
||||
|
||||
assert "Unable to reach Control Plane" in str(exc_info.value)
|
||||
|
||||
# Verify control plane methods were still called
|
||||
deployment._control_plane.launch_server.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_sequence(
|
||||
deployment_config: DeploymentConfig, tmp_path: Path
|
||||
@@ -646,13 +514,8 @@ async def test_start_sequence(
|
||||
deployment = Deployment(
|
||||
config=deployment_config, base_path=Path(), deployment_path=tmp_path
|
||||
)
|
||||
deployment._start_control_plane = mock.AsyncMock() # type: ignore
|
||||
deployment._run_services = mock.AsyncMock() # type: ignore
|
||||
deployment._start_ui_server = mock.AsyncMock() # type: ignore
|
||||
await deployment.start()
|
||||
deployment._start_control_plane.assert_awaited_once()
|
||||
# no services should start
|
||||
deployment._run_services.assert_not_awaited()
|
||||
# no ui server
|
||||
deployment._start_ui_server.assert_not_awaited()
|
||||
|
||||
@@ -664,20 +527,12 @@ async def test_start_with_services(data_path: Path, tmp_path: Path) -> None:
|
||||
deployment = Deployment(
|
||||
config=config, base_path=data_path, deployment_path=tmp_path
|
||||
)
|
||||
deployment._start_control_plane = mock.AsyncMock(return_value=[]) # type: ignore
|
||||
deployment._run_services = mock.AsyncMock(return_value=[]) # type: ignore
|
||||
deployment._start_ui_server = mock.AsyncMock(return_value=[]) # type: ignore
|
||||
|
||||
sm_dict["git"] = mock.MagicMock()
|
||||
await deployment.start()
|
||||
sm_dict["git"].return_value.sync.assert_called_once()
|
||||
|
||||
# Verify control plane was started
|
||||
deployment._start_control_plane.assert_awaited_once()
|
||||
|
||||
# Verify services were started
|
||||
deployment._run_services.assert_awaited_once()
|
||||
|
||||
# Verify UI server was not started
|
||||
deployment._start_ui_server.assert_not_awaited()
|
||||
|
||||
@@ -690,15 +545,11 @@ async def test_start_with_services_ui(data_path: Path, tmp_path: Path) -> None:
|
||||
deployment = Deployment(
|
||||
config=config, base_path=data_path, deployment_path=tmp_path
|
||||
)
|
||||
deployment._start_control_plane = mock.AsyncMock(return_value=[]) # type: ignore
|
||||
deployment._run_services = mock.AsyncMock(return_value=[]) # type: ignore
|
||||
deployment._start_ui_server = mock.AsyncMock(return_value=[]) # type: ignore
|
||||
|
||||
sm_dict["git"] = mock.MagicMock()
|
||||
await deployment.start()
|
||||
|
||||
deployment._start_control_plane.assert_awaited_once()
|
||||
deployment._run_services.assert_awaited_once()
|
||||
deployment._start_ui_server.assert_awaited_once()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user