mirror of https://github.com/run-llama/llama_deploy.git
synced 2026-07-01 21:04:00 -04:00

refact: Remove Control Plane and Message Queues (#544)

committed by GitHub
parent 671295d518
commit adc5dfbab5
@@ -16,15 +16,7 @@ jobs:
fail-fast: false
matrix:
python-version: ["3.10", "3.11", "3.12"]
test-package:
[
"basic_hitl",
"basic_streaming",
"apiserver",
"basic_session",
"basic_workflow",
"core",
]
test-package: ["apiserver"]
steps:
- uses: actions/checkout@v3

|
||||
@@ -35,21 +27,3 @@ jobs:
|
||||
|
||||
- name: Run All E2E Tests
|
||||
run: uv run -- pytest e2e_tests/${{ matrix.test-package }} -s
|
||||
|
||||
e2e-message-queues:
|
||||
runs-on: ubuntu-latest
|
||||
# E2E tests might get stuck, timeout aggressively for faster feedback
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
# Let the matrix finish to see if the failure was transient
|
||||
fail-fast: false
|
||||
matrix:
|
||||
test-package: ["kafka", "rabbitmq", "redis", "simple"]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Install uv and set the python version
|
||||
uses: astral-sh/setup-uv@v5
|
||||
|
||||
- name: Run E2E Tests for message queues
|
||||
run: uv run -- pytest e2e_tests/message_queues/${{ matrix.test-package }} -s
|
||||
|
||||
@@ -36,12 +36,10 @@ pip install -U llama-deploy

1. **Seamless Deployment**: It bridges the gap between development and production, allowing you to deploy `llama_index`
workflows with minimal changes to your code.
2. **Scalability**: The microservices architecture enables easy scaling of individual components as your system grows.
3. **Flexibility**: By using a hub-and-spoke architecture, you can easily swap out components (like message queues) or
add new services without disrupting the entire system.
4. **Fault Tolerance**: With built-in retry mechanisms and failure handling, LlamaDeploy adds robustness in
production environments.
5. **State Management**: The control plane manages state across services, simplifying complex multi-step processes.
6. **Async-First**: Designed for high-concurrency scenarios, making it suitable for real-time and high-throughput
applications.


|
||||
@@ -1,5 +0,0 @@
|
||||
# `control_plane`
|
||||
|
||||
::: llama_deploy.control_plane
|
||||
options:
|
||||
show_docstring_parameters: true
|
||||
@@ -1,6 +0,0 @@
|
||||
# message_queues
|
||||
|
||||
::: llama_deploy.message_queues.base
|
||||
options:
|
||||
members:
|
||||
- AbstractMessageQueue
|
||||
@@ -1,3 +0,0 @@
|
||||
# apache_kafka
|
||||
|
||||
::: llama_deploy.message_queues.apache_kafka
|
||||
@@ -1,3 +0,0 @@
|
||||
# rabbitmq
|
||||
|
||||
::: llama_deploy.message_queues.rabbitmq
|
||||
@@ -1,3 +0,0 @@
|
||||
# redis
|
||||
|
||||
::: llama_deploy.message_queues.redis
|
||||
@@ -1,3 +0,0 @@
|
||||
# simple
|
||||
|
||||
::: llama_deploy.message_queues.simple
|
||||
@@ -10,7 +10,3 @@
|
||||
## API Server functionalities
|
||||
|
||||
::: llama_deploy.client.models.apiserver
|
||||
|
||||
## Control Plane functionalities
|
||||
|
||||
::: llama_deploy.client.models.core
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
# `services`
|
||||
|
||||
::: llama_deploy.services
|
||||
@@ -22,9 +22,6 @@ Run the interactive wizard:
$ llamactl init
Project name [llama-deploy-app]: hello-deploy
Destination directory [.]:
Control plane port [8000]:
Select message queue type
simple redis rabbitmq kafka [simple]:
Workflow template:
basic - Basic workflow with OpenAI integration (recommended)
none - Do not create any sample workflow code
|
||||
@@ -114,5 +111,4 @@ hello-deploy/
|
||||
| Change the LLM, add tools or multiple steps | `src/workflow.py` – build with any [LlamaIndex Workflow](https://docs.llamaindex.ai/en/stable/understanding/workflows/). |
|
||||
| Add more workflows/services | Duplicate the `example_workflow` block in `deployment.yml` and point to your new workflow. |
|
||||
| Set secrets & environment variables | Use `env`/`env_files` inside each service. |
|
||||
| Switch message queue (e.g., Redis, Kafka) | Re-run `llamactl init` with a different `--message-queue-type` **or** edit the `message_queue` section manually. |
|
||||
| Deploy to production | Containerize your deployment (for example, see the [Google Cloud Run Example](https://github.com/run-llama/llama_deploy/tree/main/examples/google_cloud_run)), then run and scale the deployment anywhere you can run Docker/K8s. |
|
||||
|
||||
@@ -45,101 +45,6 @@ API is documented below.
|
||||
|
||||
!!swagger apiserver.json!!
|
||||
## Control Plane

The control plane is responsible for managing the state of the system, including:

- Registering services.
- Managing sessions and tasks.
- Handling service completion.
- Launching the control plane server.

The state of the system is persisted in a key-value store that by default consists of a simple mapping in memory.
In particular, the state store contains:

- The name and definition of the registered services.
- The active sessions and their relative tasks and event streams.
- The Context, in case the service is of type Workflow,

In case you need a more scalable storage for the system state, you can set the `state_store_uri` field in the Control
Plane configuration to point to one of the databases we support (see
[the Python API reference](../../api_reference/llama_deploy/control_plane.md)) for more details.
Using a scalable storage for the global state is mostly needed when:
- You want to scale the control plane horizontally, and you want every instance to share the same global state.
- The control plane has to deal with high traffic (many services, sessions and tasks).
- The global state needs to be persisted across restarts (for example, workflow contexts are stored in the global state).

## Service

The general structure of a service is as follows:

- A service has a name.
- A service has a service definition.
- A service uses a message queue to send/receive messages.
- A service has a processing loop, for continuous processing of messages.
- A service can process a message.
- A service can publish a message to another service.
- A service can be launched in-process.
- A service can be launched as a server.
- A service can be registered to the control plane.
- A service can be registered to the message queue.

## Message Queue

In addition to `SimpleMessageQueue`, we provide integrations for various
message queue providers, such as RabbitMQ, Redis, etc. The general usage pattern
for any of these message queues is the same as that for `SimpleMessageQueue`,
however the appropriate extra would need to be installed along with `llama-deploy`.

For example, for `RabbitMQMessageQueue`, we need to install the "rabbitmq" extra:

```sh
# using pip install
pip install llama-agents[rabbitmq]

# using poetry
poetry add llama-agents -E "rabbitmq"

# using uv
uv add llama-agents -extra "rabbitmq"
```

Using the `RabbitMQMessageQueue` is then done as follows:

```python
from llama_agents.message_queue.rabbitmq import (
    RabbitMQMessageQueueConfig,
    RabbitMQMessageQueue,
)

message_queue_config = (
    RabbitMQMessageQueueConfig()
)  # loads params from environment vars
message_queue = RabbitMQMessageQueue(**message_queue_config)
```


> [!NOTE]
> `RabbitMQMessageQueueConfig` can load its params from environment variables.

### Delivery policy

Currently the way service replicas receive the message to run a task depends on
the message queue implementation:

- `SimpleMessageQueue`: consumers are competing but the order is non
deterministic, the first subscriber (in this case, the first service) that
manages to get the message in the topic wins, all the others will keep trying
and never know a message was published.
- `RedisMessageQueue`: by default, all the services get the message and run the
task. If you set the `exclusive_mode` configuration parameter of the
`RedisMessageQueueConfig` class to `True`, services will compete for messages
and only the first coming will be able to read it.
- `RabbitMQMessageQueue`: consumers are competing, a round robin policy is used
to pick the recipient
- `KafkaMessageQueue`: same as RabbitMQ because the `group_id` of the consumer
is hardcoded

|
||||
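To make the difference between the delivery policies described in the removed "Delivery policy" section more concrete, here is a minimal pure-Python sketch. It does not use `llama_deploy` or any real broker: `broadcast` and `round_robin` are hypothetical helper names, and the replica and task names are made up. `broadcast` models the case where every replica receives every message (the default described for `RedisMessageQueue`), while `round_robin` models competing consumers with a rotating recipient (as described for `RabbitMQMessageQueue` and `KafkaMessageQueue`).

```python
from itertools import cycle


def broadcast(messages, consumers):
    """Every consumer receives every message."""
    return {name: list(messages) for name in consumers}


def round_robin(messages, consumers):
    """Each message goes to exactly one consumer, picked in rotation."""
    delivered = {name: [] for name in consumers}
    for msg, name in zip(messages, cycle(consumers)):
        delivered[name].append(msg)
    return delivered


# Hypothetical replicas of the same service and a few task messages.
replicas = ["replica-a", "replica-b"]
tasks = ["task-1", "task-2", "task-3", "task-4"]

print(broadcast(tasks, replicas))
print(round_robin(tasks, replicas))
```

With broadcast delivery every replica would run every task, which is why an exclusive or competing mode matters once a service is scaled to more than one replica.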
## Task
|
||||
|
||||
A Task is an object representing a request for an operation sent to a Service and the response that will be sent back.
|
||||
|
||||
@@ -45,7 +45,6 @@ async def check_status():
|
||||
The client provides access to two main components:
|
||||
|
||||
- `apiserver`: Interact with the [API server](./20_core_components.md#api-server)
|
||||
- `core`: Access [Control Plane](./20_core_components.md#control-plane) functionalities.
|
||||
|
||||
Each component exposes specific methods for managing and interacting with the deployed system.
|
||||
|
||||
@@ -53,10 +52,6 @@ Each component exposes specific methods for managing and interacting with the de
|
||||
> To use the `apiserver` functionalities, the API Server must be up and its URL
|
||||
> (by default `http://localhost:4501`) reachable by the host executing the client code.
|
||||
|
||||
> [!IMPORTANT]
|
||||
> To use the `core` functionalities, the Control Plane must be up and its URL
|
||||
> (by default `http://localhost:8000`) reachable by the host executing the client code.
|
||||
|
||||
|
||||
For a complete list of available methods and detailed API reference, see the
|
||||
[API Reference section](../../api_reference/llama_deploy/python_sdk.md).
|
||||
@@ -88,66 +83,3 @@ print(status)
|
||||
> [!IMPORTANT]
|
||||
> The synchronous API (`client.sync`) cannot be used within an async event loop.
|
||||
> Use the async methods directly in that case.
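The restriction in the note above can be illustrated with the standard library alone. The sketch below assumes, purely for illustration, that a synchronous wrapper drives a coroutine with `asyncio.run`; whether `client.sync` is implemented this way is not stated in the diff. `fetch_status` and `fetch_status_sync` are hypothetical stand-ins, not `llama_deploy` APIs.

```python
import asyncio


async def fetch_status():
    # Stand-in for an async client call (hypothetical).
    await asyncio.sleep(0)
    return "ok"


def fetch_status_sync():
    # Naive sync wrapper: starts its own event loop to run the coroutine.
    coro = fetch_status()
    try:
        return asyncio.run(coro)
    except RuntimeError:
        coro.close()  # avoid a "never awaited" warning when the loop refuses to start
        raise


async def main():
    try:
        fetch_status_sync()  # called while an event loop is already running
    except RuntimeError as exc:
        outcome = f"sync wrapper failed: {type(exc).__name__}"
    else:
        outcome = "sync wrapper worked"
    # Awaiting the coroutine directly works from inside the loop.
    return outcome, await fetch_status()


result = asyncio.run(main())
print(result)
```

Outside of a running loop, `fetch_status_sync()` returns normally, which matches the guidance to use the sync API in plain scripts and the async methods inside async code.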
|
||||
|
||||
### A more complex example
|
||||
|
||||
This is an example of how you would use the recommended async version of the client to
|
||||
run a deployed workflow and collect the events it streams.
|
||||
|
||||
```python
|
||||
import llama_deploy
|
||||
|
||||
|
||||
async def stream_events(services):
|
||||
client = llama_deploy.Client(timeout=10)
|
||||
|
||||
# Create a new session
|
||||
session = await client.core.sessions.create()
|
||||
|
||||
# Assuming there's a workflow called `streaming_workflow`, run it in the background
|
||||
task_id = await session.run_nowait(
|
||||
"streaming_workflow", arg="Hello, world!"
|
||||
)
|
||||
|
||||
# The workflow is supposed to stream events signalling its progress
|
||||
async for event in session.get_task_result_stream(task_id):
|
||||
if "progress" in event:
|
||||
print(f'Workflow Progress: {event["progress"]}')
|
||||
|
||||
# When done, collect the workflow output
|
||||
final_result = await session.get_task_result(task_id)
|
||||
print(final_result)
|
||||
|
||||
# Clean up the session
|
||||
await client.core.sessions.delete(session.id)
|
||||
```
|
||||
|
||||
The equivalent synchronous version would be the following:
|
||||
|
||||
```python
|
||||
import llama_deploy
|
||||
|
||||
|
||||
def stream_events(services):
|
||||
client = llama_deploy.Client(timeout=10)
|
||||
|
||||
# Create a new session
|
||||
session = client.sync.core.sessions.create()
|
||||
|
||||
# Assuming there's a workflow called `streaming_workflow`, run it
|
||||
task_id = session.run_nowait("streaming_workflow", arg1="hello_world")
|
||||
|
||||
# The workflow is supposed to stream events signalling its progress.
|
||||
# Since this is a synchronous call, by this time all the events were
|
||||
# streamed and collected in a list.
|
||||
for event in session.get_task_result_stream(task_id):
|
||||
if "progress" in event:
|
||||
print(f'Workflow Progress: {event["progress"]}')
|
||||
|
||||
# Collect the workflow output
|
||||
final_result = session.get_task_result(task_id)
|
||||
print(final_result)
|
||||
|
||||
# Clean up the session
|
||||
client.sync.core.sessions.delete(session.id)
|
||||
```
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -25,12 +25,10 @@ pip install -U llama-deploy
|
||||
|
||||
1. **Seamless Deployment**: It bridges the gap between development and production, allowing you to deploy `llama_index`
|
||||
workflows with minimal changes to your code.
|
||||
2. **Scalability**: The microservices architecture enables easy scaling of individual components as your system grows.
|
||||
3. **Flexibility**: By using a hub-and-spoke architecture, you can easily swap out components (like message queues) or
|
||||
add new services without disrupting the entire system.
|
||||
4. **Fault Tolerance**: With built-in retry mechanisms and failure handling, LlamaDeploy adds robustness in
|
||||
production environments.
|
||||
5. **State Management**: The control plane manages state across services, simplifying complex multi-step processes.
|
||||
6. **Async-First**: Designed for high-concurrency scenarios, making it suitable for real-time and high-throughput
|
||||
applications.
|
||||
|
||||
|
||||
@@ -19,11 +19,11 @@ $ uv run -- pytest ./e2e_tests
|
||||
To run a specific scenario:
|
||||
|
||||
```sh
|
||||
$ uv run -- pytest e2e_tests/basic_streaming
|
||||
$ uv run -- pytest e2e_tests/apiserver
|
||||
```
|
||||
|
||||
If you want to see the output of the different services running, pass the `-s` flag to pytest:
|
||||
|
||||
```sh
|
||||
$ uv run -- pytest e2e_tests/basic_streaming -s
|
||||
$ uv run -- pytest e2e_tests/apiserver/test_deploy.py -s
|
||||
```
|
||||
|
||||
@@ -8,7 +8,7 @@ services:
|
||||
name: Git Workflow
|
||||
source:
|
||||
type: git
|
||||
name: https://github.com/run-llama/llama_deploy.git@massi/refact
|
||||
name: https://github.com/run-llama/llama_deploy.git
|
||||
env:
|
||||
VAR_1: x # this gets overwritten because VAR_1 also exists in the provided .env
|
||||
VAR_2: y
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_deploy.control_plane import ControlPlaneConfig
|
||||
from llama_deploy.services import WorkflowServiceConfig
|
||||
|
||||
from ..utils import deploy_workflow
|
||||
from .workflow import HumanInTheLoopWorkflow
|
||||
|
||||
|
||||
def run_async_workflow():
|
||||
asyncio.run(
|
||||
deploy_workflow(
|
||||
HumanInTheLoopWorkflow(timeout=60),
|
||||
WorkflowServiceConfig(
|
||||
host="127.0.0.1",
|
||||
port=8002,
|
||||
service_name="hitl_workflow",
|
||||
),
|
||||
ControlPlaneConfig(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def services(core):
|
||||
p = multiprocessing.Process(target=run_async_workflow)
|
||||
p.start()
|
||||
time.sleep(5)
|
||||
|
||||
yield
|
||||
|
||||
p.kill()
|
||||
@@ -1,74 +0,0 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from workflows.events import HumanResponseEvent
|
||||
|
||||
from llama_deploy.client import Client
|
||||
|
||||
|
||||
def test_run_client(services):
|
||||
client = Client(timeout=10)
|
||||
|
||||
# sanity check
|
||||
sessions = client.sync.core.sessions.list()
|
||||
assert len(sessions) == 0, "Sessions list is not empty"
|
||||
|
||||
# create a session
|
||||
session = client.sync.core.sessions.create()
|
||||
|
||||
# kick off run
|
||||
task_id = session.run_nowait("hitl_workflow")
|
||||
|
||||
# send event
|
||||
session.send_event(
|
||||
ev=HumanResponseEvent(response="42"),
|
||||
service_name="hitl_workflow",
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
# get final result, polling to wait for workflow to finish after send event
|
||||
final_result = None
|
||||
while final_result is None:
|
||||
final_result = session.get_task_result(task_id)
|
||||
time.sleep(0.1)
|
||||
assert final_result.result == "42", "The human's response is not consistent."
|
||||
|
||||
# delete the session
|
||||
client.sync.core.sessions.delete(session.id)
|
||||
sessions = client.sync.core.sessions.list()
|
||||
assert len(sessions) == 0, "Sessions list is not empty"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_client_async(services):
|
||||
client = Client(timeout=10)
|
||||
|
||||
# sanity check
|
||||
sessions = await client.core.sessions.list()
|
||||
assert len(sessions) == 0, "Sessions list is not empty"
|
||||
|
||||
# create a session
|
||||
session = await client.core.sessions.create()
|
||||
|
||||
# kick off run
|
||||
task_id = await session.run_nowait("hitl_workflow")
|
||||
|
||||
# send event
|
||||
await session.send_event(
|
||||
ev=HumanResponseEvent(response="42"),
|
||||
service_name="hitl_workflow",
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
# get final result, polling to wait for workflow to finish after send event
|
||||
final_result = None
|
||||
while final_result is None:
|
||||
final_result = await session.get_task_result(task_id)
|
||||
await asyncio.sleep(0.1)
|
||||
assert final_result.result == "42", "The human's response is not consistent."
|
||||
|
||||
# delete the session
|
||||
await client.core.sessions.delete(session.id)
|
||||
sessions = await client.core.sessions.list()
|
||||
assert len(sessions) == 0, "Sessions list is not empty"
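The removed tests above poll `get_task_result` in a `while final_result is None` loop with no upper bound, relying on the CI job timeout if the workflow never finishes. As a hedged illustration only, a bounded variant could look like the sketch below. `poll_until` is a hypothetical helper, not part of `llama_deploy`, and the lambda stands in for a call such as `session.get_task_result(task_id)`.

```python
import time


def poll_until(fetch, timeout=5.0, interval=0.1):
    """Call `fetch` until it returns a non-None value or `timeout` seconds pass."""
    deadline = time.monotonic() + timeout
    while True:
        value = fetch()
        if value is not None:
            return value
        if time.monotonic() >= deadline:
            raise TimeoutError(f"no result after {timeout} seconds")
        time.sleep(interval)


# Stand-in for a result that becomes available on the third poll.
attempts = iter([None, None, "42"])
print(poll_until(lambda: next(attempts), timeout=1.0, interval=0.01))
```

Failing with `TimeoutError` inside the test gives a clearer signal than a job-level timeout, at the cost of having to pick a deadline per test.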
|
||||
@@ -1,17 +0,0 @@
from workflows import Workflow, step
from workflows.events import (
    HumanResponseEvent,
    InputRequiredEvent,
    StartEvent,
    StopEvent,
)


class HumanInTheLoopWorkflow(Workflow):
    @step
    async def step1(self, ev: StartEvent) -> InputRequiredEvent:
        return InputRequiredEvent(prefix="Enter a number: ")

    @step
    async def step2(self, ev: HumanResponseEvent) -> StopEvent:
        return StopEvent(result=ev.response)
|
||||
@@ -1,36 +0,0 @@
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_deploy.control_plane import ControlPlaneConfig
|
||||
from llama_deploy.services import WorkflowServiceConfig
|
||||
|
||||
from ..utils import deploy_workflow
|
||||
from .workflow import SessionWorkflow
|
||||
|
||||
|
||||
def run_async_workflow():
|
||||
asyncio.run(
|
||||
deploy_workflow(
|
||||
SessionWorkflow(timeout=10),
|
||||
WorkflowServiceConfig(
|
||||
host="127.0.0.1",
|
||||
port=8002,
|
||||
service_name="session_workflow",
|
||||
),
|
||||
ControlPlaneConfig(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def workflow(core):
|
||||
p = multiprocessing.Process(target=run_async_workflow)
|
||||
p.start()
|
||||
time.sleep(5)
|
||||
|
||||
yield
|
||||
|
||||
p.kill()
|
||||
@@ -1,44 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from llama_deploy.client import Client
|
||||
|
||||
|
||||
def test_run_client(workflow):
|
||||
client = Client(timeout=10)
|
||||
|
||||
# create session
|
||||
session = client.sync.core.sessions.create()
|
||||
|
||||
# test run with session
|
||||
result = session.run("session_workflow")
|
||||
assert result == "1"
|
||||
|
||||
# run again
|
||||
result = session.run("session_workflow")
|
||||
assert result == "2"
|
||||
|
||||
# create new session and run
|
||||
session = client.sync.core.sessions.create()
|
||||
result = session.run("session_workflow")
|
||||
assert result == "1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_client_async(workflow):
|
||||
client = Client(timeout=10)
|
||||
|
||||
# create session
|
||||
session = await client.core.sessions.create()
|
||||
|
||||
# run
|
||||
result = await session.run("session_workflow")
|
||||
assert result == "1"
|
||||
|
||||
# run again
|
||||
result = await session.run("session_workflow")
|
||||
assert result == "2"
|
||||
|
||||
# create new session and run
|
||||
session = await client.core.sessions.create()
|
||||
result = await session.run("session_workflow")
|
||||
assert result == "1"
|
||||
@@ -1,16 +0,0 @@
from llama_index.core.workflow import (
    Context,
    Workflow,
    StartEvent,
    StopEvent,
    step,
)


class SessionWorkflow(Workflow):
    @step()
    async def step_1(self, ctx: Context, ev: StartEvent) -> StopEvent:
        cur_val = await ctx.get("count", default=0)
        await ctx.set("count", cur_val + 1)

        return StopEvent(result=cur_val + 1)
|
||||
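The removed `SessionWorkflow` and its tests exercise one behaviour: the counter kept in the workflow context grows across runs within the same session and starts over in a new session. The sketch below models only that behaviour in plain Python. `FakeSession` is a hypothetical stand-in that owns one state dictionary per session; it is not how `llama_deploy` stores contexts. It returns strings because the removed tests compare the client-side result against `"1"` and `"2"`.

```python
class FakeSession:
    """Hypothetical stand-in for a session: one state dict per session."""

    def __init__(self):
        self.state = {}

    def run(self):
        # Same logic as `step_1`: read the counter, add one, store it, return it.
        count = self.state.get("count", 0) + 1
        self.state["count"] = count
        return str(count)


first = FakeSession()
print(first.run(), first.run())  # same session: the counter keeps growing

second = FakeSession()
print(second.run())  # new session: the counter starts over
```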
@@ -1,36 +0,0 @@
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_deploy.control_plane import ControlPlaneConfig
|
||||
from llama_deploy.services import WorkflowServiceConfig
|
||||
|
||||
from ..utils import deploy_workflow
|
||||
from .workflow import StreamingWorkflow
|
||||
|
||||
|
||||
def run_async_workflow():
|
||||
asyncio.run(
|
||||
deploy_workflow(
|
||||
StreamingWorkflow(timeout=10),
|
||||
WorkflowServiceConfig(
|
||||
host="127.0.0.1",
|
||||
port=8002,
|
||||
service_name="streaming_workflow",
|
||||
),
|
||||
ControlPlaneConfig(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def services(core):
|
||||
p = multiprocessing.Process(target=run_async_workflow)
|
||||
p.start()
|
||||
time.sleep(5)
|
||||
|
||||
yield
|
||||
|
||||
p.kill()
|
||||
@@ -1,55 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from llama_deploy.client import Client
|
||||
|
||||
|
||||
def test_run_client(services):
|
||||
client = Client(timeout=20)
|
||||
|
||||
# sanity check
|
||||
sessions = client.sync.core.sessions.list()
|
||||
assert len(sessions) == 0, "Sessions list is not empty"
|
||||
|
||||
# test streaming
|
||||
session = client.sync.core.sessions.create()
|
||||
|
||||
# kick off run
|
||||
task_id = session.run_nowait("streaming_workflow", arg1="hello_world")
|
||||
|
||||
progress_received = []
|
||||
for event in session.get_task_result_stream(task_id):
|
||||
event_data = event.get("value", {})
|
||||
if "progress" in event_data:
|
||||
progress_received.append(event_data.get("progress"))
|
||||
assert progress_received == [0.3, 0.6, 0.9]
|
||||
|
||||
# get final result
|
||||
final_result = session.get_task_result(task_id)
|
||||
assert final_result.result == "hello_world_result_result_result" # type: ignore
|
||||
|
||||
# delete everything
|
||||
client.sync.core.sessions.delete(session.id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_client_async(services):
|
||||
client = Client(timeout=20)
|
||||
|
||||
# test streaming
|
||||
session = await client.core.sessions.create()
|
||||
|
||||
# kick off run
|
||||
task_id = await session.run_nowait("streaming_workflow", arg1="hello_world")
|
||||
|
||||
progress_received = []
|
||||
async for event in session.get_task_result_stream(task_id):
|
||||
event_data = event.get("value", {})
|
||||
if "progress" in event_data:
|
||||
progress_received.append(event_data.get("progress"))
|
||||
assert progress_received == [0.3, 0.6, 0.9]
|
||||
|
||||
final_result = await session.get_task_result(task_id)
|
||||
assert final_result.result == "hello_world_result_result_result" # type: ignore
|
||||
|
||||
# delete everything
|
||||
await client.core.sessions.delete(session.id)
|
||||
@@ -1,52 +0,0 @@
from llama_index.core.workflow import (
    Context,
    Event,
    StartEvent,
    StopEvent,
    Workflow,
    step,
)


class ProgressEvent(Event):
    progress: float


class Step1(Event):
    arg1: str


class Step2(Event):
    arg1: str


class StreamingWorkflow(Workflow):
    @step()
    async def run_step_1(self, ctx: Context, ev: StartEvent) -> Step1:
        arg1 = ev.get("arg1")
        if not arg1:
            raise ValueError("arg1 is required.")

        ctx.write_event_to_stream(ProgressEvent(progress=0.3))

        return Step1(arg1=str(arg1) + "_result")

    @step()
    async def run_step_2(self, ctx: Context, ev: Step1) -> Step2:
        arg1 = ev.arg1
        if not arg1:
            raise ValueError("arg1 is required.")

        ctx.write_event_to_stream(ProgressEvent(progress=0.6))

        return Step2(arg1=str(arg1) + "_result")

    @step()
    async def run_step_3(self, ctx: Context, ev: Step2) -> StopEvent:
        arg1 = ev.arg1
        if not arg1:
            raise ValueError("arg1 is required.")

        ctx.write_event_to_stream(ProgressEvent(progress=0.9))

        return StopEvent(result=str(arg1) + "_result")
|
||||
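The removed `StreamingWorkflow` emits a progress event from each of its three steps and appends `_result` to its argument every time, which is what the removed streaming tests assert (`[0.3, 0.6, 0.9]` and `hello_world_result_result_result`). The sketch below reproduces that shape with a plain async generator from the standard library. `run_streaming` and `collect` are hypothetical names, and no `llama_index` or `llama_deploy` API is involved.

```python
import asyncio


async def run_streaming(arg1):
    """Yield ("progress", x) after each step, then a final ("result", value)."""
    value = arg1
    for progress in (0.3, 0.6, 0.9):
        value += "_result"
        yield ("progress", progress)
        await asyncio.sleep(0)  # give other tasks a chance to run between steps
    yield ("result", value)


async def collect(arg1):
    # Consume the stream the way a client would: gather progress, keep the result.
    progress, result = [], None
    async for kind, payload in run_streaming(arg1):
        if kind == "progress":
            progress.append(payload)
        else:
            result = payload
    return progress, result


print(asyncio.run(collect("hello_world")))
```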
@@ -1,36 +0,0 @@
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_deploy.control_plane import ControlPlaneConfig
|
||||
from llama_deploy.services import WorkflowServiceConfig
|
||||
|
||||
from ..utils import deploy_workflow
|
||||
from .workflow import OuterWorkflow
|
||||
|
||||
|
||||
def run_async_workflow():
|
||||
asyncio.run(
|
||||
deploy_workflow(
|
||||
OuterWorkflow(timeout=10),
|
||||
WorkflowServiceConfig(
|
||||
host="127.0.0.1",
|
||||
port=8002,
|
||||
service_name="outer",
|
||||
),
|
||||
ControlPlaneConfig(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def workflow(core):
|
||||
p = multiprocessing.Process(target=run_async_workflow)
|
||||
p.start()
|
||||
time.sleep(5)
|
||||
|
||||
yield
|
||||
|
||||
p.kill()
|
||||
@@ -1,60 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from llama_deploy.client import Client
|
||||
|
||||
|
||||
def test_run_client(workflow):
|
||||
client = Client(timeout=10)
|
||||
|
||||
# test connections
|
||||
assert len(client.sync.core.services.list()) == 1
|
||||
assert len(client.sync.core.sessions.list()) == 0
|
||||
|
||||
# test create session
|
||||
session = client.sync.core.sessions.get_or_create("fake_session_id")
|
||||
sessions = client.sync.core.sessions.list()
|
||||
assert len(sessions) == 1
|
||||
assert sessions[0].id == session.id
|
||||
|
||||
# test run with session
|
||||
result = session.run("outer", arg1="hello_world")
|
||||
assert result == "hello_world_result"
|
||||
|
||||
# test number of tasks
|
||||
tasks = session.get_tasks()
|
||||
assert len(tasks) == 1
|
||||
assert tasks[0].service_id == "outer"
|
||||
|
||||
# delete everything
|
||||
client.sync.core.sessions.delete(session.id)
|
||||
assert len(client.sync.core.sessions.list()) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_client_async(workflow):
|
||||
client = Client(timeout=10)
|
||||
|
||||
# test connections
|
||||
assert len(await client.core.services.list()) == 1
|
||||
assert len(await client.core.sessions.list()) == 0
|
||||
|
||||
# test create session
|
||||
session = await client.core.sessions.get_or_create("fake_session_id")
|
||||
sessions = await client.core.sessions.list()
|
||||
assert len(sessions) == 1, f"Expected 1 session, got {sessions}"
|
||||
assert sessions[0].id == session.id
|
||||
|
||||
# test run with session
|
||||
result = await session.run("outer", arg1="hello_world")
|
||||
assert result == "hello_world_result"
|
||||
|
||||
# test number of tasks
|
||||
tasks = await session.get_tasks()
|
||||
assert len(tasks) == 1, f"Expected 1 task, got {len(tasks)} tasks"
|
||||
assert (
|
||||
tasks[0].service_id == "outer"
|
||||
), f"Expected id to be 'outer', got {tasks[0].service_id}"
|
||||
|
||||
# delete everything
|
||||
await client.core.sessions.delete(session.id)
|
||||
assert len(await client.core.sessions.list()) == 0
|
||||
@@ -1,26 +0,0 @@
from llama_index.core.workflow import (
    Context,
    Event,
    StartEvent,
    StopEvent,
    Workflow,
    step,
)


class CustomEvent(Event):
    pass


class OuterWorkflow(Workflow):
    @step()
    async def run_step(self, ctx: Context, ev: StartEvent) -> CustomEvent:
        await ctx.set("arg1", ev.get("arg1"))
        # ensure the collect_events system serializes correctly
        ctx.collect_events(ev, [CustomEvent])
        return CustomEvent()

    @step
    async def run_final_step(self, ctx: Context, ev: CustomEvent) -> StopEvent:
        arg1 = await ctx.get("arg1")
        return StopEvent(result=str(arg1) + "_result")
|
||||
@@ -1,25 +0,0 @@
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_deploy.control_plane import ControlPlaneConfig
|
||||
from llama_deploy.message_queues.simple import SimpleMessageQueueConfig
|
||||
|
||||
from .utils import deploy_core
|
||||
|
||||
|
||||
def run_async_core():
|
||||
asyncio.run(deploy_core(ControlPlaneConfig(), SimpleMessageQueueConfig()))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def core():
|
||||
p = multiprocessing.Process(target=run_async_core)
|
||||
p.start()
|
||||
time.sleep(3)
|
||||
|
||||
yield
|
||||
|
||||
p.kill()
|
||||
@@ -1,36 +0,0 @@
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_deploy.control_plane import ControlPlaneConfig
|
||||
from llama_deploy.services import WorkflowServiceConfig
|
||||
|
||||
from ..utils import deploy_workflow
|
||||
from .workflow import BasicWorkflow
|
||||
|
||||
|
||||
def run_async_workflow():
|
||||
asyncio.run(
|
||||
deploy_workflow(
|
||||
BasicWorkflow(timeout=10),
|
||||
WorkflowServiceConfig(
|
||||
host="127.0.0.1",
|
||||
port=8002,
|
||||
service_name="basic",
|
||||
),
|
||||
ControlPlaneConfig(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def workflow(core):
|
||||
p = multiprocessing.Process(target=run_async_workflow)
|
||||
p.start()
|
||||
time.sleep(5)
|
||||
|
||||
yield
|
||||
|
||||
p.kill()
|
||||
@@ -1,73 +0,0 @@
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_deploy.client import Client
|
||||
from llama_deploy.types.core import ServiceDefinition
|
||||
|
||||
from .conftest import run_async_workflow
|
||||
|
||||
|
||||
def test_services(workflow):
|
||||
client = Client()
|
||||
|
||||
services = client.sync.core.services
|
||||
assert len(services.list()) == 1
|
||||
|
||||
services.deregister("basic")
|
||||
assert len(services.items) == 0
|
||||
|
||||
new_s = services.register(
|
||||
ServiceDefinition(service_name="another_basic", description="none")
|
||||
)
|
||||
assert new_s.id == "another_basic"
|
||||
assert len(services.items) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_services_async(workflow):
|
||||
client = Client()
|
||||
|
||||
assert len(await client.core.services.list()) == 1
|
||||
await client.core.services.deregister("basic")
|
||||
assert len(await client.core.services.list()) == 0
|
||||
|
||||
new_s = await client.core.services.register(
|
||||
ServiceDefinition(service_name="another_basic", description="none")
|
||||
)
|
||||
assert new_s.id == "another_basic"
|
||||
assert len(await client.core.services.list()) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_restart(core):
|
||||
client = Client()
|
||||
|
||||
# create workflow service in a separate process
|
||||
p = multiprocessing.Process(target=run_async_workflow)
|
||||
p.start()
|
||||
await asyncio.sleep(5)
|
||||
|
||||
# create session
|
||||
session = await client.core.sessions.create()
|
||||
|
||||
# run
|
||||
result = await session.run("basic")
|
||||
assert result == "n/a_result"
|
||||
|
||||
# kill the service
|
||||
p.kill()
|
||||
p.join()
|
||||
|
||||
# restart the service
|
||||
p = multiprocessing.Process(target=run_async_workflow)
|
||||
p.start()
|
||||
await asyncio.sleep(5)
|
||||
|
||||
# run again, same session
|
||||
result = await session.run("basic")
|
||||
assert result == "n/a_result"
|
||||
|
||||
p.kill()
|
||||
p.join()
|
||||
@@ -1,8 +0,0 @@
|
||||
from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step
|
||||
|
||||
|
||||
class BasicWorkflow(Workflow):
|
||||
@step()
|
||||
async def run_step(self, ctx: Context, ev: StartEvent) -> StopEvent:
|
||||
arg1 = ev.get("arg1", "n/a")
|
||||
return StopEvent(result=str(arg1) + "_result")
|
||||
@@ -1,103 +0,0 @@
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_deploy.control_plane import ControlPlaneConfig
|
||||
from llama_deploy.message_queues import KafkaMessageQueue, KafkaMessageQueueConfig
|
||||
from llama_deploy.services import WorkflowServiceConfig
|
||||
|
||||
from ...utils import deploy_core, deploy_workflow
|
||||
from .workflow import BasicWorkflow
|
||||
|
||||
|
||||
@pytest.fixture(scope="package")
|
||||
def kafka_service():
|
||||
compose_file = Path(__file__).resolve().parent / "docker-compose.yml"
|
||||
proc = subprocess.Popen(
|
||||
["docker", "compose", "-f", f"{compose_file}", "up", "-d", "--wait"]
|
||||
)
|
||||
proc.communicate()
|
||||
yield
|
||||
proc = subprocess.Popen(["docker", "compose", "-f", f"{compose_file}", "down"])
|
||||
proc.communicate()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mq(kafka_service):
|
||||
return KafkaMessageQueue(KafkaMessageQueueConfig())
|
||||
|
||||
|
||||
def run_workflow_one():
|
||||
asyncio.run(
|
||||
deploy_workflow(
|
||||
BasicWorkflow(timeout=10, name="Workflow one"),
|
||||
WorkflowServiceConfig(
|
||||
host="127.0.0.1",
|
||||
port=8003,
|
||||
service_name="basic",
|
||||
),
|
||||
ControlPlaneConfig(topic_namespace="core_one", port=8001),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def run_workflow_two():
|
||||
asyncio.run(
|
||||
deploy_workflow(
|
||||
BasicWorkflow(timeout=10, name="Workflow two"),
|
||||
WorkflowServiceConfig(
|
||||
host="127.0.0.1",
|
||||
port=8004,
|
||||
service_name="basic",
|
||||
),
|
||||
ControlPlaneConfig(topic_namespace="core_two", port=8002),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def run_core_one():
|
||||
asyncio.run(
|
||||
deploy_core(
|
||||
ControlPlaneConfig(topic_namespace="core_one", port=8001),
|
||||
KafkaMessageQueueConfig(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def run_core_two():
|
||||
asyncio.run(
|
||||
deploy_core(
|
||||
ControlPlaneConfig(topic_namespace="core_two", port=8002),
|
||||
KafkaMessageQueueConfig(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def control_planes(kafka_service):
|
||||
p1 = multiprocessing.Process(target=run_core_one)
|
||||
p1.start()
|
||||
time.sleep(2)
|
||||
|
||||
p2 = multiprocessing.Process(target=run_core_two)
|
||||
p2.start()
|
||||
time.sleep(2)
|
||||
|
||||
p3 = multiprocessing.Process(target=run_workflow_one)
|
||||
p3.start()
|
||||
time.sleep(2)
|
||||
|
||||
p4 = multiprocessing.Process(target=run_workflow_two)
|
||||
p4.start()
|
||||
time.sleep(2)
|
||||
|
||||
yield
|
||||
|
||||
p1.terminate()
|
||||
p2.terminate()
|
||||
p3.terminate()
|
||||
p4.terminate()
|
||||
@@ -1,29 +0,0 @@
|
||||
services:
|
||||
kafka:
|
||||
image: apache/kafka:3.7.1
|
||||
ports:
|
||||
- "9092:9092"
|
||||
environment:
|
||||
KAFKA_NODE_ID: 1
|
||||
KAFKA_LISTENER_SECURITY_PROTOCOL_MAP: "CONTROLLER:PLAINTEXT,PLAINTEXT:PLAINTEXT,PLAINTEXT_HOST:PLAINTEXT"
|
||||
KAFKA_ADVERTISED_LISTENERS: "PLAINTEXT_HOST://localhost:9092,PLAINTEXT://kafka:19092"
|
||||
KAFKA_PROCESS_ROLES: "broker,controller"
|
||||
KAFKA_CONTROLLER_QUORUM_VOTERS: "1@kafka:29093"
|
||||
KAFKA_LISTENERS: "CONTROLLER://:29093,PLAINTEXT_HOST://:9092,PLAINTEXT://:19092"
|
||||
KAFKA_INTER_BROKER_LISTENER_NAME: "PLAINTEXT"
|
||||
KAFKA_CONTROLLER_LISTENER_NAMES: "CONTROLLER"
|
||||
KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR: 1
|
||||
KAFKA_GROUP_INITIAL_REBALANCE_DELAY_MS: 0
|
||||
KAFKA_TRANSACTION_STATE_LOG_MIN_ISR: 1
|
||||
KAFKA_TRANSACTION_STATE_LOG_REPLICATION_FACTOR: 1
|
||||
KAFKA_LOG_DIRS: "/tmp/kraft-combined-logs"
|
||||
healthcheck:
|
||||
test:
|
||||
- CMD-SHELL
|
||||
- -c
|
||||
- |
|
||||
nc -z localhost 9092 || exit -1
|
||||
start_period: 15s
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 5
|
||||
@@ -1,38 +0,0 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_deploy.client import Client
|
||||
from llama_deploy.types import QueueMessage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_roundtrip(mq):
|
||||
# produce a message
|
||||
test_message = QueueMessage(type="test_message", data={"message": "this is a test"})
|
||||
await mq.publish(test_message, topic="test")
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async for m in mq.get_messages("test"):
|
||||
assert m == test_message
|
||||
break
|
||||
|
||||
# Give time for shutting down kafka consumer
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_control_planes(control_planes):
|
||||
c1 = Client(control_plane_url="http://localhost:8001")
|
||||
c2 = Client(control_plane_url="http://localhost:8002")
|
||||
|
||||
session = await c1.core.sessions.create()
|
||||
r1 = await session.run("basic", arg="Hello One!")
|
||||
await c1.core.sessions.delete(session.id)
|
||||
assert r1 == "Workflow one received Hello One!"
|
||||
|
||||
session = await c2.core.sessions.create()
|
||||
r2 = await session.run("basic", arg="Hello Two!")
|
||||
await c2.core.sessions.delete(session.id)
|
||||
assert r2 == "Workflow two received Hello Two!"
|
||||
@@ -1,12 +0,0 @@
|
||||
from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step
|
||||
|
||||
|
||||
class BasicWorkflow(Workflow):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._name = kwargs.pop("name")
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@step()
|
||||
async def run_step(self, ctx: Context, ev: StartEvent) -> StopEvent:
|
||||
received = ev.get("arg")
|
||||
return StopEvent(result=f"{self._name} received {received}")
|
||||
@@ -1,103 +0,0 @@
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_deploy.control_plane import ControlPlaneConfig
|
||||
from llama_deploy.message_queues import RabbitMQMessageQueue, RabbitMQMessageQueueConfig
|
||||
from llama_deploy.services import WorkflowServiceConfig
|
||||
|
||||
from ...utils import deploy_core, deploy_workflow
|
||||
from .workflow import BasicWorkflow
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def rabbitmq_service():
|
||||
compose_file = Path(__file__).resolve().parent / "docker-compose.yml"
|
||||
proc = subprocess.Popen(
|
||||
["docker", "compose", "-f", f"{compose_file}", "up", "-d", "--wait"]
|
||||
)
|
||||
proc.communicate()
|
||||
yield
|
||||
proc = subprocess.Popen(["docker", "compose", "-f", f"{compose_file}", "down"])
|
||||
proc.communicate()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mq(rabbitmq_service):
|
||||
return RabbitMQMessageQueue(RabbitMQMessageQueueConfig())
|
||||
|
||||
|
||||
def run_workflow_one():
|
||||
asyncio.run(
|
||||
deploy_workflow(
|
||||
BasicWorkflow(timeout=10, name="Workflow one"),
|
||||
WorkflowServiceConfig(
|
||||
host="127.0.0.1",
|
||||
port=8003,
|
||||
service_name="basic",
|
||||
),
|
||||
ControlPlaneConfig(topic_namespace="core_one", port=8001),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def run_workflow_two():
|
||||
asyncio.run(
|
||||
deploy_workflow(
|
||||
BasicWorkflow(timeout=10, name="Workflow two"),
|
||||
WorkflowServiceConfig(
|
||||
host="127.0.0.1",
|
||||
port=8004,
|
||||
service_name="basic",
|
||||
),
|
||||
ControlPlaneConfig(topic_namespace="core_two", port=8002),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def run_core_one():
|
||||
asyncio.run(
|
||||
deploy_core(
|
||||
ControlPlaneConfig(topic_namespace="core_one", port=8001),
|
||||
RabbitMQMessageQueueConfig(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def run_core_two():
|
||||
asyncio.run(
|
||||
deploy_core(
|
||||
ControlPlaneConfig(topic_namespace="core_two", port=8002),
|
||||
RabbitMQMessageQueueConfig(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def control_planes(rabbitmq_service):
|
||||
p1 = multiprocessing.Process(target=run_core_one)
|
||||
p1.start()
|
||||
|
||||
p2 = multiprocessing.Process(target=run_core_two)
|
||||
p2.start()
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
p3 = multiprocessing.Process(target=run_workflow_one)
|
||||
p3.start()
|
||||
|
||||
p4 = multiprocessing.Process(target=run_workflow_two)
|
||||
p4.start()
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
yield
|
||||
|
||||
p1.terminate()
|
||||
p2.terminate()
|
||||
p3.terminate()
|
||||
p4.terminate()
|
||||
@@ -1,17 +0,0 @@
|
||||
services:
|
||||
rabbitmq:
|
||||
image: rabbitmq:3-management-alpine
|
||||
hostname: "rabbitmq"
|
||||
ports:
|
||||
- "5672:5672"
|
||||
- "15672:15672"
|
||||
healthcheck:
|
||||
test:
|
||||
- CMD-SHELL
|
||||
- -c
|
||||
- |
|
||||
rabbitmq-diagnostics -q check_running
|
||||
rabbitmq-diagnostics -q check_port_connectivity
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
@@ -1,39 +0,0 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_deploy.client import Client
|
||||
from llama_deploy.types import QueueMessage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_roundtrip(mq):
|
||||
async def consume():
|
||||
async for m in mq.get_messages("test"):
|
||||
return m
|
||||
|
||||
t = asyncio.create_task(consume())
|
||||
await asyncio.sleep(1)
|
||||
|
||||
test_message = QueueMessage(type="test_message", data={"message": "this is a test"})
|
||||
await mq.publish(test_message, topic="test")
|
||||
|
||||
result = await t
|
||||
assert result == test_message
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_control_planes(control_planes):
|
||||
c1 = Client(control_plane_url="http://localhost:8001")
|
||||
c2 = Client(control_plane_url="http://localhost:8002")
|
||||
|
||||
session = await c1.core.sessions.create()
|
||||
r1 = await session.run("basic", arg="Hello One!")
|
||||
await c1.core.sessions.delete(session.id)
|
||||
assert r1 == "Workflow one received Hello One!"
|
||||
|
||||
session = await c2.core.sessions.create()
|
||||
r2 = await session.run("basic", arg="Hello Two!")
|
||||
await c2.core.sessions.delete(session.id)
|
||||
assert r2 == "Workflow two received Hello Two!"
|
||||
@@ -1,12 +0,0 @@
|
||||
from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step
|
||||
|
||||
|
||||
class BasicWorkflow(Workflow):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._name = kwargs.pop("name")
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@step()
|
||||
async def run_step(self, ctx: Context, ev: StartEvent) -> StopEvent:
|
||||
received = ev.get("arg")
|
||||
return StopEvent(result=f"{self._name} received {received}")
|
||||
@@ -1,106 +0,0 @@
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_deploy.control_plane import ControlPlaneConfig
|
||||
from llama_deploy.message_queues import RedisMessageQueue, RedisMessageQueueConfig
|
||||
from llama_deploy.services import WorkflowServiceConfig
|
||||
|
||||
from ...utils import deploy_core, deploy_workflow
|
||||
from .workflow import BasicWorkflow
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def redis_service():
|
||||
compose_file = Path(__file__).resolve().parent / "docker-compose.yml"
|
||||
proc = subprocess.Popen(
|
||||
["docker", "compose", "-f", f"{compose_file}", "up", "-d", "--wait"]
|
||||
)
|
||||
proc.communicate()
|
||||
yield
|
||||
proc = subprocess.Popen(["docker", "compose", "-f", f"{compose_file}", "down"])
|
||||
proc.communicate()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
async def mq(redis_service):
|
||||
mq = RedisMessageQueue()
|
||||
yield mq
|
||||
await mq.cleanup()
|
||||
|
||||
|
||||
def run_workflow_one():
|
||||
asyncio.run(
|
||||
deploy_workflow(
|
||||
BasicWorkflow(timeout=10, name="Workflow one"),
|
||||
WorkflowServiceConfig(
|
||||
host="127.0.0.1",
|
||||
port=8003,
|
||||
service_name="basic",
|
||||
),
|
||||
ControlPlaneConfig(topic_namespace="core_one", port=8001),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def run_workflow_two():
|
||||
asyncio.run(
|
||||
deploy_workflow(
|
||||
BasicWorkflow(timeout=10, name="Workflow two"),
|
||||
WorkflowServiceConfig(
|
||||
host="127.0.0.1",
|
||||
port=8004,
|
||||
service_name="basic",
|
||||
),
|
||||
ControlPlaneConfig(topic_namespace="core_two", port=8002),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def run_core_one():
|
||||
asyncio.run(
|
||||
deploy_core(
|
||||
ControlPlaneConfig(topic_namespace="core_one", port=8001),
|
||||
RedisMessageQueueConfig(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def run_core_two():
|
||||
asyncio.run(
|
||||
deploy_core(
|
||||
ControlPlaneConfig(topic_namespace="core_two", port=8002),
|
||||
RedisMessageQueueConfig(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def control_planes(redis_service):
|
||||
p1 = multiprocessing.Process(target=run_core_one)
|
||||
p1.start()
|
||||
|
||||
p2 = multiprocessing.Process(target=run_core_two)
|
||||
p2.start()
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
p3 = multiprocessing.Process(target=run_workflow_one)
|
||||
p3.start()
|
||||
|
||||
p4 = multiprocessing.Process(target=run_workflow_two)
|
||||
p4.start()
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
yield
|
||||
|
||||
p1.terminate()
|
||||
p2.terminate()
|
||||
p3.terminate()
|
||||
p4.terminate()
|
||||
@@ -1,15 +0,0 @@
|
||||
services:
|
||||
redis:
|
||||
image: redis:latest
|
||||
hostname: redis
|
||||
ports:
|
||||
- "6379:6379"
|
||||
healthcheck:
|
||||
test:
|
||||
- CMD-SHELL
|
||||
- -c
|
||||
- |
|
||||
redis-cli --raw incr ping
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
@@ -1,41 +0,0 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_deploy.client import Client
|
||||
from llama_deploy.message_queues.redis import RedisMessageQueue
|
||||
from llama_deploy.types import QueueMessage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_roundtrip(mq: RedisMessageQueue):
|
||||
# Redis pubsub has no persistence, we need to start the subscriber
|
||||
# before publishing
|
||||
async def consume():
|
||||
async for m in mq.get_messages("test"):
|
||||
return m
|
||||
|
||||
t = asyncio.create_task(consume())
|
||||
await asyncio.sleep(1)
|
||||
|
||||
test_message = QueueMessage(type="test_message", data={"message": "this is a test"})
|
||||
await mq.publish(test_message, topic="test")
|
||||
|
||||
result = await t
|
||||
assert result == test_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_control_planes(control_planes):
|
||||
c1 = Client(control_plane_url="http://localhost:8001")
|
||||
c2 = Client(control_plane_url="http://localhost:8002")
|
||||
|
||||
session = await c1.core.sessions.create()
|
||||
r1 = await session.run("basic", arg="Hello One!")
|
||||
await c1.core.sessions.delete(session.id)
|
||||
assert r1 == "Workflow one received Hello One!"
|
||||
|
||||
session = await c2.core.sessions.create()
|
||||
r2 = await session.run("basic", arg="Hello Two!")
|
||||
await c2.core.sessions.delete(session.id)
|
||||
assert r2 == "Workflow two received Hello Two!"
|
||||
@@ -1,12 +0,0 @@
|
||||
from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step
|
||||
|
||||
|
||||
class BasicWorkflow(Workflow):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._name = kwargs.pop("name")
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@step()
|
||||
async def run_step(self, ctx: Context, ev: StartEvent) -> StopEvent:
|
||||
received = ev.get("arg")
|
||||
return StopEvent(result=f"{self._name} received {received}")
|
||||
@@ -1,28 +0,0 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_deploy.message_queues.simple import (
|
||||
SimpleMessageQueue,
|
||||
SimpleMessageQueueConfig,
|
||||
SimpleMessageQueueServer,
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
async def simple_server():
|
||||
queue = SimpleMessageQueueServer(SimpleMessageQueueConfig(port=8009))
|
||||
t = asyncio.create_task(queue.launch_server())
|
||||
# let message queue boot up
|
||||
await asyncio.sleep(1)
|
||||
|
||||
yield
|
||||
|
||||
t.cancel()
|
||||
await t
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mq(simple_server):
|
||||
return SimpleMessageQueue(SimpleMessageQueueConfig(port=8009))
|
||||
@@ -1,19 +0,0 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_deploy.message_queues import SimpleMessageQueue
|
||||
from llama_deploy.types import QueueMessage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_roundtrip(mq: SimpleMessageQueue):
|
||||
# produce a message
|
||||
test_message = QueueMessage(type="test_message", data={"message": "this is a test"})
|
||||
await mq.publish(test_message, topic="test")
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async for m in mq.get_messages("test"):
|
||||
assert m == test_message
|
||||
break
|
||||
@@ -1,21 +0,0 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_deploy.message_queues import (
|
||||
SimpleMessageQueueConfig,
|
||||
SimpleMessageQueueServer,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_launch_server():
|
||||
mq = SimpleMessageQueueServer(SimpleMessageQueueConfig(port=8009))
|
||||
t = asyncio.create_task(mq.launch_server())
|
||||
|
||||
# Make sure the queue starts
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Cancel
|
||||
t.cancel()
|
||||
await t
|
||||
@@ -1,157 +0,0 @@
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
from llama_index.core.workflow import Workflow
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from llama_deploy.control_plane.server import ControlPlaneConfig, ControlPlaneServer
|
||||
from llama_deploy.message_queues import (
|
||||
AbstractMessageQueue,
|
||||
KafkaMessageQueue,
|
||||
KafkaMessageQueueConfig,
|
||||
RabbitMQMessageQueue,
|
||||
RabbitMQMessageQueueConfig,
|
||||
RedisMessageQueue,
|
||||
RedisMessageQueueConfig,
|
||||
SimpleMessageQueueConfig,
|
||||
SimpleMessageQueueServer,
|
||||
)
|
||||
from llama_deploy.message_queues.simple import SimpleMessageQueue
|
||||
from llama_deploy.services.network_service_manager import NetworkServiceManager
|
||||
from llama_deploy.services.workflow import WorkflowService, WorkflowServiceConfig
|
||||
|
||||
DEFAULT_TIMEOUT = 120.0
|
||||
|
||||
|
||||
def _get_message_queue_config(config_dict: dict) -> BaseSettings:
|
||||
key = next(iter(config_dict.keys()))
|
||||
if key == SimpleMessageQueueConfig.__name__:
|
||||
return SimpleMessageQueueConfig(**config_dict[key])
|
||||
elif key == KafkaMessageQueueConfig.__name__:
|
||||
return KafkaMessageQueueConfig(**config_dict[key])
|
||||
elif key == RabbitMQMessageQueueConfig.__name__:
|
||||
return RabbitMQMessageQueueConfig(**config_dict[key])
|
||||
elif key == RedisMessageQueueConfig.__name__:
|
||||
return RedisMessageQueueConfig(**config_dict[key])
|
||||
else:
|
||||
raise ValueError(f"Unknown message queue: {key}")
|
||||
|
||||
|
||||
def _get_message_queue_client(config: BaseSettings) -> AbstractMessageQueue:
|
||||
if isinstance(config, SimpleMessageQueueConfig):
|
||||
return SimpleMessageQueue(config)
|
||||
elif isinstance(config, KafkaMessageQueueConfig):
|
||||
return KafkaMessageQueue(config)
|
||||
elif isinstance(config, RabbitMQMessageQueueConfig):
|
||||
return RabbitMQMessageQueue(config)
|
||||
elif isinstance(config, RedisMessageQueueConfig):
|
||||
return RedisMessageQueue(config)
|
||||
else:
|
||||
raise ValueError(f"Invalid message queue config: {config}")
|
||||
|
||||
|
||||
async def deploy_core(
|
||||
control_plane_config: ControlPlaneConfig | None = None,
|
||||
message_queue_config: BaseSettings | None = None,
|
||||
disable_message_queue: bool = False,
|
||||
disable_control_plane: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Deploy the core components of the llama_deploy system.
|
||||
|
||||
This function sets up and launches the message queue, control plane, and orchestrator.
|
||||
It handles the initialization and connection of these core components.
|
||||
|
||||
Args:
|
||||
control_plane_config (Optional[ControlPlaneConfig]): Configuration for the control plane.
|
||||
message_queue_config (Optional[BaseSettings]): Configuration for the message queue. Defaults to a local SimpleMessageQueue.
|
||||
disable_message_queue (bool): Whether to disable deploying the message queue. Defaults to False.
|
||||
disable_control_plane (bool): Whether to disable deploying the control plane. Defaults to False.
|
||||
|
||||
Raises:
|
||||
ValueError: If an unknown message queue type is specified in the config.
|
||||
Exception: If any of the launched tasks encounter an error.
|
||||
"""
|
||||
control_plane_config = control_plane_config or ControlPlaneConfig()
|
||||
message_queue_config = message_queue_config or SimpleMessageQueueConfig()
|
||||
|
||||
tasks = []
|
||||
|
||||
message_queue_client = _get_message_queue_client(message_queue_config)
|
||||
# If needed, start the SimpleMessageQueueServer
|
||||
if (
|
||||
isinstance(message_queue_config, SimpleMessageQueueConfig)
|
||||
and not disable_message_queue
|
||||
):
|
||||
queue = SimpleMessageQueueServer(message_queue_config)
|
||||
tasks.append(asyncio.create_task(queue.launch_server()))
|
||||
# let message queue boot up
|
||||
await asyncio.sleep(2)
|
||||
|
||||
if not disable_control_plane:
|
||||
control_plane = ControlPlaneServer(
|
||||
message_queue_client, config=control_plane_config
|
||||
)
|
||||
tasks.append(asyncio.create_task(control_plane.launch_server()))
|
||||
# let service spin up
|
||||
await asyncio.sleep(4)
|
||||
|
||||
# let things run
|
||||
try:
|
||||
await asyncio.gather(*tasks)
|
||||
except (Exception, asyncio.CancelledError):
|
||||
await message_queue_client.cleanup()
|
||||
for task in tasks:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
async def deploy_workflow(
|
||||
workflow: Workflow,
|
||||
workflow_config: WorkflowServiceConfig,
|
||||
control_plane_config: ControlPlaneConfig | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Deploy a workflow as a service within the llama_deploy system.
|
||||
|
||||
This function sets up a workflow as a service, connects it to the message queue,
|
||||
and registers it with the control plane.
|
||||
|
||||
Args:
|
||||
workflow (Workflow): The workflow to be deployed as a service.
|
||||
workflow_config (WorkflowServiceConfig): Configuration for the workflow service.
|
||||
control_plane_config (Optional[ControlPlaneConfig]): Configuration for the control plane.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPError: If there's an error communicating with the control plane.
|
||||
ValueError: If an invalid message queue config is encountered.
|
||||
Exception: If any of the launched tasks encounter an error.
|
||||
"""
|
||||
control_plane_config = control_plane_config or ControlPlaneConfig()
|
||||
control_plane_url = control_plane_config.url
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"{control_plane_url}/queue_config")
|
||||
queue_config_dict = response.json()
|
||||
|
||||
message_queue_config = _get_message_queue_config(queue_config_dict)
|
||||
message_queue_client = _get_message_queue_client(message_queue_config)
|
||||
|
||||
# override the service manager, while maintaining dict of existing services
|
||||
workflow._service_manager = NetworkServiceManager(
|
||||
workflow._service_manager._services
|
||||
)
|
||||
|
||||
service = WorkflowService(
|
||||
workflow=workflow,
|
||||
message_queue=message_queue_client,
|
||||
config=workflow_config,
|
||||
)
|
||||
|
||||
# register to control plane
|
||||
await service.register_to_control_plane(control_plane_url)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await service.launch_server()
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
|
||||
from llama_index.core.workflow import Workflow, StartEvent, StopEvent, step
|
||||
from workflows import Workflow, step
|
||||
from workflows.events import StartEvent, StopEvent
|
||||
|
||||
|
||||
# create a dummy workflow
|
||||
|
||||
@@ -115,7 +115,7 @@ services:
|
||||
```
|
||||
|
||||
The YAML code above defines the deployment that LlamaDeploy will create and run as a service. As you can
|
||||
see, this deployment has a name, some configuration for the control plane and one service to wrap our workflow. The
|
||||
see, this deployment has a name and one service to wrap our workflow. The
|
||||
service will look for a Python variable named `llamacloud_workflow` in a Python module named `workflow` and run the workflow.
|
||||
|
||||
At this point we have all we need to run this deployment. Ideally, we would have the API server already running
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import json
|
||||
import asyncio
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_cloud.client import LlamaCloud
|
||||
from llama_cloud.types import (
|
||||
@@ -11,15 +12,8 @@ from llama_cloud.types import (
|
||||
ConfigurableDataSourceNames,
|
||||
DataSourceCreate,
|
||||
)
|
||||
from llama_index.core.workflow import (
|
||||
Event,
|
||||
StartEvent,
|
||||
StopEvent,
|
||||
Workflow,
|
||||
step,
|
||||
)
|
||||
|
||||
import os
|
||||
from workflows import Workflow, step
|
||||
from workflows.events import Event, StartEvent, StopEvent
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import nest_asyncio
|
||||
from llama_index.core.workflow import Workflow, StartEvent, StopEvent, step
|
||||
import yaml
|
||||
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
|
||||
from workflows import Workflow, step
|
||||
from workflows.events import StartEvent, StopEvent
|
||||
|
||||
# Apply nest_asyncio at the start
|
||||
nest_asyncio.apply()
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import asyncio
|
||||
|
||||
from llama_index.core.workflow import Workflow, StartEvent, StopEvent, step
|
||||
from pyfiglet import Figlet
|
||||
from cowpy import cow
|
||||
from fortune import fortune
|
||||
from pyfiglet import Figlet
|
||||
from workflows import Workflow, step
|
||||
from workflows.events import StartEvent, StopEvent
|
||||
|
||||
|
||||
# create a dummy workflow
|
||||
|
||||
@@ -34,7 +34,6 @@ Let's walk through the important files and folders:
|
||||
|
||||
The application relies on different components:
|
||||
|
||||
- A Redis instance used by the LlamaDeploy message queue
|
||||
- A Qdrant instance used by the RAG workflow
|
||||
- A LlamaDeploy API server instance managing the deployment
|
||||
- The Reflex application serving the UI at http://localhost:3000
|
||||
|
||||
@@ -8,18 +8,6 @@ services:
|
||||
volumes:
|
||||
- qdrant_data:/qdrant/storage
|
||||
|
||||
redis:
|
||||
# LlamaDeploy message queue
|
||||
image: redis:latest
|
||||
hostname: redis
|
||||
ports:
|
||||
- "6379:6379"
|
||||
healthcheck:
|
||||
test: redis-cli --raw incr ping
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
|
||||
apiserver:
|
||||
# LlamaDeploy API server, will run the workflows
|
||||
image: llamaindex/llama-deploy:main
|
||||
|
||||
@@ -4,16 +4,11 @@ from typing import List
|
||||
from llama_index.core.llms import ChatMessage
|
||||
from llama_index.core.memory import ChatMemoryBuffer
|
||||
from llama_index.core.tools import FunctionTool
|
||||
from llama_index.core.workflow import (
|
||||
Context,
|
||||
Event,
|
||||
StartEvent,
|
||||
StopEvent,
|
||||
Workflow,
|
||||
step,
|
||||
)
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
from workflows import Context, Workflow, step
|
||||
from workflows.events import Event, StartEvent, StopEvent
|
||||
|
||||
from .rag_workflow import RAGWorkflow
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
@@ -6,20 +6,15 @@ from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
|
||||
from llama_index.core.node_parser import SemanticSplitterNodeParser, SentenceSplitter
|
||||
from llama_index.core.response_synthesizers import CompactAndRefine
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from llama_index.core.workflow import (
|
||||
Context,
|
||||
Event,
|
||||
StartEvent,
|
||||
StopEvent,
|
||||
Workflow,
|
||||
step,
|
||||
)
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
from llama_index.llms.openai import OpenAI
|
||||
from llama_index.postprocessor.rankgpt_rerank import RankGPTRerank
|
||||
from llama_index.vector_stores.qdrant import QdrantVectorStore
|
||||
from qdrant_client import AsyncQdrantClient, QdrantClient
|
||||
|
||||
from workflows import Context, Workflow, step
|
||||
from workflows.events import Event, StartEvent, StopEvent
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,8 @@ The `src` folder contains a `workflow.py` file defining a trivial workflow:
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from llama_index.core.workflow import Workflow, StartEvent, StopEvent, step
|
||||
from workflows import Context, Workflow, step
|
||||
from workflows.events import Event, StartEvent, StopEvent
|
||||
|
||||
|
||||
class EchoWorkflow(Workflow):
|
||||
@@ -61,7 +62,7 @@ ui:
|
||||
```
|
||||
|
||||
The YAML code above defines the deployment that LlamaDeploy will create and run as a service. As you can
|
||||
see, this deployment has a name, some configuration for the control plane and one service to wrap our workflow. The
|
||||
see, this deployment has a name and one service to wrap our workflow. The
|
||||
service will look for a Python variable named `echo_workflow` in a Python module named `workflow` and run the workflow.
|
||||
|
||||
This example includes a Next.js-based UI interface that allows you to interact with your deployment through a web browser.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
|
||||
from llama_index.core.workflow import Workflow, StartEvent, StopEvent, step
|
||||
from workflows import Workflow, step
|
||||
from workflows.events import StartEvent, StopEvent
|
||||
|
||||
|
||||
# create a dummy workflow
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
# Using Redis as Message Queue provider
|
||||
|
||||
> [!NOTE]
|
||||
> This example is mostly based on the [Quick Start](../quick_start/README.md), see there for more details.
|
||||
|
||||
We'll be deploying a simple workflow on a container running LlamaDeploy using Redis as the message queue
|
||||
provider. The Redis container will be started in a different container using Docker Compose.
|
||||
|
||||
This is the code defining our deployment, with comments to the relevant bits:
|
||||
|
||||
```yaml
|
||||
name: RedisMessageQueue
|
||||
|
||||
control-plane:
|
||||
port: 8000
|
||||
|
||||
message-queue:
|
||||
type: redis
|
||||
# what follows depends on what's in the docker compose file
|
||||
host: redis
|
||||
port: 6379
|
||||
|
||||
default-service: counter_workflow_service
|
||||
|
||||
services:
|
||||
counter_workflow_service:
|
||||
name: Counter Workflow
|
||||
source:
|
||||
type: local
|
||||
name: .
|
||||
path: workflow:counter_workflow
|
||||
```
|
||||
|
||||
Note how the deployment file contains the `message-queue` key to instruct LlamaDeploy to use
|
||||
Redis as the message queue provider.
|
||||
|
||||
Before starting the containers, two things to note about how LlamaDeploy is configured:
|
||||
|
||||
- We mount our application, consisting of the `deployment.yml` file and a Python
|
||||
module `workflow.py` containing the LlamaIndex code implementing the workflow, under the path
|
||||
`/opt/app` inside the container
|
||||
- We set the `LLAMA_DEPLOY_APISERVER_RC_PATH` environment variable so that when LlamaDeploy
|
||||
starts, it will look under the `/opt/app` folder for deployments to create automatically.
|
||||
|
||||
We can now start the Docker containers using Compose:
|
||||
|
||||
```
|
||||
$ docker compose up -d
|
||||
```
|
||||
|
||||
When the containers are up and running, we can use `llamactl` from our local host to
|
||||
interact with the deployment:
|
||||
|
||||
```
|
||||
$ llamactl status
|
||||
LlamaDeploy is up and running.
|
||||
|
||||
Active deployments:
|
||||
- RedisMessageQueue
|
||||
```
|
||||
|
||||
Our workflow is now part of the `RedisMessageQueue` deployment and ready to serve requests! Since we want to persist
|
||||
a counter across workflow runs, first we manually create a session:
|
||||
|
||||
```
|
||||
$ llamactl sessions create -d RedisMessageQueue
|
||||
session_id='<YOUR_SESSION_ID>' task_ids=[] state={}
|
||||
```
|
||||
|
||||
Then we run the workflow multiple times, always using the same session we created in the previous step:
|
||||
|
||||
```
|
||||
$ llamactl run --deployment RedisMessageQueue --arg amount 3 -i <YOUR_SESSION_ID>
|
||||
Current balance: 3.0
|
||||
$ llamactl run --deployment RedisMessageQueue --arg amount 3 -i <YOUR_SESSION_ID>
|
||||
Current balance: 6.0
|
||||
```
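
The arithmetic behind repeated runs in one session can be sketched in plain Python. The snippet is only an illustration under assumptions: `session_state` is a hypothetical stand-in for the state LlamaDeploy keeps for a session, and `run_counter` mimics what the single step of the counter workflow in `workflow.py` does with the `amount` argument.

```python
from typing import Dict


def run_counter(session_state: Dict[str, float], amount: float) -> str:
    """Mimic the counter step: add `amount` to the total stored in the session."""
    total = session_state.get("total", 0.0) + float(amount)
    session_state["total"] = total
    return f"Current balance: {total}"


# One dict per session: reusing it is what makes the counter persist.
session_state: Dict[str, float] = {}
print(run_counter(session_state, 3))
print(run_counter(session_state, 3))
```

Passing a fresh dictionary would start again from zero, which is the analogue of running the workflow without the `-i` session argument.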
|
||||
|
||||
_Note_: If you have multiple replicas of the workflow and control plane and only want one replica to process messages,
|
||||
set `REDIS_EXCLUSIVE_MODE` to true.
|
||||
@@ -1,38 +0,0 @@
|
||||
services:
|
||||
redis:
|
||||
# Use as KV store
|
||||
image: redis:latest
|
||||
hostname: redis
|
||||
ports:
|
||||
- "6379:6379"
|
||||
healthcheck:
|
||||
test:
|
||||
- CMD-SHELL
|
||||
- -c
|
||||
- |
|
||||
redis-cli --raw incr ping
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
|
||||
apiserver:
|
||||
image: llamaindex/llama-deploy:main
|
||||
hostname: apiserver
|
||||
ports:
|
||||
- "4501:4501"
|
||||
environment:
|
||||
LLAMA_DEPLOY_APISERVER_RC_PATH: /opt/app/
|
||||
depends_on:
|
||||
redis:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test:
|
||||
- CMD-SHELL
|
||||
- -c
|
||||
- llamactl status
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
volumes:
|
||||
- ./src:/opt/app
|
||||
working_dir: /opt/app
|
||||
@@ -1,20 +0,0 @@
|
||||
name: RedisMessageQueue
|
||||
|
||||
control-plane:
|
||||
port: 8000
|
||||
|
||||
message-queue:
|
||||
type: redis
|
||||
# what follows depends on what's in the docker compose file
|
||||
host: redis
|
||||
port: 6379
|
||||
|
||||
default-service: counter_workflow_service
|
||||
|
||||
services:
|
||||
counter_workflow_service:
|
||||
name: Counter Workflow
|
||||
source:
|
||||
type: local
|
||||
name: .
|
||||
path: workflow:counter_workflow
|
||||
@@ -1,26 +0,0 @@
|
||||
import asyncio
|
||||
|
||||
from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step
|
||||
|
||||
|
||||
# create a dummy workflow
|
||||
class CounterWorkflow(Workflow):
|
||||
"""A dummy workflow with only one step sending back the input given."""
|
||||
|
||||
@step()
|
||||
async def run_step(self, ctx: Context, ev: StartEvent) -> StopEvent:
|
||||
amount = float(ev.get("amount", 0.0))
|
||||
total = await ctx.get("total", 0.0) + amount
|
||||
await ctx.set("total", total)
|
||||
return StopEvent(result=f"Current balance: {total}")
|
||||
|
||||
|
||||
counter_workflow = CounterWorkflow()
|
||||
|
||||
|
||||
async def main():
|
||||
print(await counter_workflow.run(amount=10.0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,79 +0,0 @@
|
||||
# Using Redis as State Store
|
||||
|
||||
> [!NOTE]
|
||||
> This example is mostly based on the [Quick Start](../quick_start/README.md), see there for more details.
|
||||
|
||||
We'll be deploying a simple workflow on a local instance of LlamaDeploy using Redis as a scalable storage for the
|
||||
global state. See [the Control Plane documentation](https://docs.llamaindex.ai/en/stable/module_guides/llama_deploy/20_core_components/#control-plane)
|
||||
for an overview of what the global state consists of and when the default storage might not be enough.
|
||||
|
||||
Before starting LlamaDeploy, use Docker compose to start the Redis container and run it in the background:
|
||||
|
||||
```
|
||||
$ docker compose up -d
|
||||
```
|
||||
|
||||
Make sure to install the package to support the Redis KV store in the virtual environment where we'll run LlamaDeploy:
|
||||
|
||||
```
|
||||
$ pip install -r requirements.txt
|
||||
```
|
||||
|
||||
This is the code defining our deployment, with comments to the relevant bits:
|
||||
|
||||
```yaml
|
||||
name: RedisStateStore
|
||||
|
||||
control-plane:
|
||||
port: 8000
|
||||
# Here we tell the Control Plane to use Redis
|
||||
state_store_uri: redis://localhost:6379
|
||||
|
||||
default-service: counter_workflow_service
|
||||
|
||||
services:
|
||||
counter_workflow_service:
|
||||
name: Counter Workflow
|
||||
source:
|
||||
type: local
|
||||
name: src
|
||||
path: workflow:counter_workflow
|
||||
```
|
||||
|
||||
Note how we provide a connection URI for Redis in the `state_store_uri` field of the control plane configuration.
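
As a rough illustration of what a connection URI carries, the sketch below splits one with `urllib.parse.urlparse` and maps its scheme to a backend label. The function name `backend_for_uri` and the returned labels are assumptions made for this example; the schemes are the ones checked by the `parse_state_store_uri` helper in the control plane configuration module.

```python
from urllib.parse import urlparse

# Schemes accepted for Redis, following redis-py's `Redis.from_url`.
REDIS_SCHEMES = {"redis", "rediss", "unix"}


def backend_for_uri(uri: str) -> str:
    """Return a label for the storage backend a state store URI points to."""
    bits = urlparse(uri)
    if bits.scheme in REDIS_SCHEMES:
        return "redis"
    if bits.scheme == "mongodb+srv":
        return "mongodb"
    raise ValueError(f"unsupported state store scheme: {bits.scheme!r}")


bits = urlparse("redis://localhost:6379")
print(bits.scheme, bits.hostname, bits.port)
print(backend_for_uri("redis://localhost:6379"))
```

Dispatching on the scheme keeps the deployment file down to a single string, with the choice of storage backend following from the URI itself.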
|
||||
|
||||
At this point we have all we need to run this deployment. Ideally, we would have the API server already running
|
||||
somewhere in the cloud, but to get started let's start an instance locally. Run the following command
|
||||
from a shell:
|
||||
|
||||
```
|
||||
$ python -m llama_deploy.apiserver
|
||||
INFO: Started server process [10842]
|
||||
INFO: Waiting for application startup.
|
||||
INFO: Application startup complete.
|
||||
INFO: Uvicorn running on http://0.0.0.0:4501 (Press CTRL+C to quit)
|
||||
```
|
||||
|
||||
From another shell, use the CLI, `llamactl`, to create the deployment:
|
||||
|
||||
```
|
||||
$ llamactl deploy redis_store.yml
|
||||
Deployment successful: RedisStateStore
|
||||
```
|
||||
|
||||
Our workflow is now part of the `RedisStateStore` deployment and ready to serve requests! Since we want to persist
|
||||
a counter across workflow runs, first we manually create a session:
|
||||
|
||||
```
|
||||
$ llamactl sessions create -d RedisStateStore
|
||||
session_id='<YOUR_SESSION_ID>' task_ids=[] state={}
|
||||
```
|
||||
|
||||
Then we run the workflow multiple times, always using the same session we created in the previous step:
|
||||
|
||||
```
|
||||
$ llamactl run --deployment RedisStateStore --arg amount 3 -i <YOUR_SESSION_ID>
|
||||
Current balance: 3.0
|
||||
$ llamactl run --deployment RedisStateStore --arg amount 3 -i <YOUR_SESSION_ID>
|
||||
Current balance: 6.0
|
||||
```
|
||||
@@ -1,12 +0,0 @@
|
||||
services:
|
||||
redis:
|
||||
# Use as KV store
|
||||
image: redis:latest
|
||||
hostname: redis
|
||||
ports:
|
||||
- "6379:6379"
|
||||
healthcheck:
|
||||
test: redis-cli --raw incr ping
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
@@ -1,15 +0,0 @@
|
||||
name: RedisStateStore
|
||||
|
||||
control-plane:
|
||||
port: 8000
|
||||
state_store_uri: redis://localhost:6379
|
||||
|
||||
default-service: counter_workflow_service
|
||||
|
||||
services:
|
||||
counter_workflow_service:
|
||||
name: Counter Workflow
|
||||
source:
|
||||
type: local
|
||||
name: src
|
||||
path: workflow:counter_workflow
|
||||
@@ -1 +0,0 @@
|
||||
llama-index-storage-kvstore-redis
|
||||
@@ -1,26 +0,0 @@
|
||||
import asyncio
|
||||
|
||||
from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step
|
||||
|
||||
|
||||
# create a dummy workflow
|
||||
class CounterWorkflow(Workflow):
|
||||
"""A dummy workflow with only one step sending back the input given."""
|
||||
|
||||
@step()
|
||||
async def run_step(self, ctx: Context, ev: StartEvent) -> StopEvent:
|
||||
amount = float(ev.get("amount", 0.0))
|
||||
total = await ctx.get("total", 0.0) + amount
|
||||
await ctx.set("total", total)
|
||||
return StopEvent(result=f"Current balance: {total}")
|
||||
|
||||
|
||||
counter_workflow = CounterWorkflow()
|
||||
|
||||
|
||||
async def main():
|
||||
print(await counter_workflow.run(amount=10.0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -37,12 +37,6 @@ class DeploymentError(Exception): ...
|
||||
|
||||
|
||||
class Deployment:
|
||||
"""A Deployment consists of running services and core component instances.
|
||||
|
||||
Every Deployment is self contained, running a dedicated instance of the control plane
|
||||
and the message queue along with any service defined in the configuration object.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -65,7 +59,7 @@ class Deployment:
|
||||
self._deployment_path = (
|
||||
deployment_path if local else deployment_path / config.name
|
||||
)
|
||||
self._client = Client(control_plane_url=config.control_plane.url)
|
||||
self._client = Client()
|
||||
self._default_service: str | None = None
|
||||
self._running = False
|
||||
self._service_tasks: list[asyncio.Task] = []
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Optional, Union
|
||||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self
|
||||
@@ -13,24 +12,6 @@ else: # pragma: no cover
|
||||
import yaml
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from llama_deploy.control_plane.server import ControlPlaneConfig
|
||||
from llama_deploy.message_queues import (
|
||||
KafkaMessageQueueConfig,
|
||||
RabbitMQMessageQueueConfig,
|
||||
RedisMessageQueueConfig,
|
||||
SimpleMessageQueueConfig,
|
||||
)
|
||||
|
||||
MessageQueueConfig = Annotated[
|
||||
Union[
|
||||
"KafkaMessageQueueConfig",
|
||||
"RabbitMQMessageQueueConfig",
|
||||
"RedisMessageQueueConfig",
|
||||
"SimpleMessageQueueConfig",
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class SourceType(str, Enum):
|
||||
"""Supported types for the `Service.source` parameter."""
|
||||
@@ -116,11 +97,9 @@ class UIService(Service):
|
||||
class DeploymentConfig(BaseModel):
|
||||
"""Model definition mapping a deployment config file."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
model_config = ConfigDict(populate_by_name=True, extra="ignore")
|
||||
|
||||
name: str
|
||||
control_plane: "ControlPlaneConfig"
|
||||
message_queue: MessageQueueConfig | None = Field(None)
|
||||
default_service: str | None = Field(None)
|
||||
services: dict[str, Service]
|
||||
ui: UIService | None = None
|
||||
|
||||
@@ -11,30 +11,11 @@ from pydantic import BaseModel
|
||||
# Import pydantic models
|
||||
from llama_deploy.apiserver.deployment_config_parser import (
|
||||
DeploymentConfig,
|
||||
MessageQueueConfig,
|
||||
Service,
|
||||
ServiceSource,
|
||||
SourceType,
|
||||
UIService,
|
||||
)
|
||||
from llama_deploy.control_plane.server import ControlPlaneConfig
|
||||
from llama_deploy.message_queues import (
|
||||
KafkaMessageQueueConfig,
|
||||
RabbitMQMessageQueueConfig,
|
||||
RedisMessageQueueConfig,
|
||||
SimpleMessageQueueConfig,
|
||||
)
|
||||
|
||||
SUPPORTED_MESSAGE_QUEUES: Dict[str, Type[MessageQueueConfig]] = {
|
||||
x.model_json_schema()["properties"]["type"]["default"]: x # type: ignore
|
||||
for x in [
|
||||
KafkaMessageQueueConfig,
|
||||
RabbitMQMessageQueueConfig,
|
||||
RedisMessageQueueConfig,
|
||||
SimpleMessageQueueConfig,
|
||||
]
|
||||
if hasattr(x, "model_json_schema")
|
||||
}
|
||||
|
||||
|
||||
@click.command()
|
||||
@@ -50,18 +31,6 @@ SUPPORTED_MESSAGE_QUEUES: Dict[str, Type[MessageQueueConfig]] = {
|
||||
default=None,
|
||||
help="Directory where the project will be created",
|
||||
)
|
||||
@click.option(
|
||||
"--port",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Port for the control plane server",
|
||||
)
|
||||
@click.option(
|
||||
"--message-queue-type",
|
||||
type=click.Choice(["simple", "redis", "rabbitmq", "kafka"]),
|
||||
default=None,
|
||||
help="Type of message queue to use",
|
||||
)
|
||||
@click.option(
|
||||
"--template",
|
||||
type=click.Choice(["basic", "none"]), # For future: add more templates
|
||||
@@ -71,8 +40,6 @@ SUPPORTED_MESSAGE_QUEUES: Dict[str, Type[MessageQueueConfig]] = {
|
||||
def init(
|
||||
name: Optional[str] = None,
|
||||
destination: Optional[str] = None,
|
||||
port: Optional[int] = None,
|
||||
message_queue_type: Optional[str] = None,
|
||||
template: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Bootstrap a new llama-deploy project with a basic workflow and configuration."""
|
||||
@@ -85,6 +52,7 @@ def init(
|
||||
show_default=True,
|
||||
type=str,
|
||||
)
|
||||
assert name
|
||||
|
||||
if destination is None:
|
||||
destination = click.prompt(
|
||||
@@ -93,22 +61,7 @@ def init(
|
||||
type=str,
|
||||
show_default=True,
|
||||
)
|
||||
|
||||
if port is None:
|
||||
port = click.prompt(
|
||||
"Control plane port",
|
||||
default="8000",
|
||||
type=int,
|
||||
show_default=True,
|
||||
)
|
||||
|
||||
if message_queue_type is None:
|
||||
message_queue_type = click.prompt(
|
||||
"Select message queue type\n",
|
||||
default="simple",
|
||||
type=click.Choice(SUPPORTED_MESSAGE_QUEUES.keys()),
|
||||
show_default=True,
|
||||
)
|
||||
assert destination
|
||||
|
||||
if template is None:
|
||||
click.echo("\nWorkflow template:")
|
||||
@@ -143,7 +96,7 @@ def init(
|
||||
click.echo(f"Created project directory: {project_dir}")
|
||||
|
||||
# Create deployment.yml using pydantic models
|
||||
deployment_config = create_deployment_config(name, port, message_queue_type, use_ui)
|
||||
deployment_config = create_deployment_config(name, use_ui)
|
||||
deployment_path = project_dir / "deployment.yml"
|
||||
|
||||
# Exclude several fields that would only confuse users
|
||||
@@ -391,14 +344,6 @@ def write_yaml_with_comments(
|
||||
|
||||
# Add section comments
|
||||
section_comments = {
|
||||
"control_plane:": [
|
||||
"# Control plane configuration",
|
||||
"# The control plane manages the state of the system and coordinates services",
|
||||
],
|
||||
"message_queue:": [
|
||||
"# Message queue configuration",
|
||||
"# The message queue handles communication between services",
|
||||
],
|
||||
"default_service:": [
|
||||
"# The default service to use when no service is specified",
|
||||
],
|
||||
@@ -431,20 +376,8 @@ def write_yaml_with_comments(
|
||||
f.write(commented_yaml)
|
||||
|
||||
|
||||
def create_deployment_config(
|
||||
name: str, port: int, message_queue_type: str, use_ui: bool = False
|
||||
) -> DeploymentConfig:
|
||||
def create_deployment_config(name: str, use_ui: bool = False) -> DeploymentConfig:
|
||||
"""Create a deployment configuration using pydantic models."""
|
||||
# Create control plane config
|
||||
control_plane = ControlPlaneConfig(port=port)
|
||||
|
||||
# Create message queue config
|
||||
message_queue_cls = SUPPORTED_MESSAGE_QUEUES.get(message_queue_type)
|
||||
if message_queue_cls is None:
|
||||
raise ValueError(f"Message queue type {message_queue_type} not supported")
|
||||
|
||||
message_queue = message_queue_cls()
|
||||
|
||||
# Create the example service
|
||||
service = Service(
|
||||
name="Example Workflow",
|
||||
@@ -478,8 +411,6 @@ def create_deployment_config(
|
||||
# Create the deployment config
|
||||
deployment_config = DeploymentConfig(
|
||||
name=name,
|
||||
control_plane=control_plane,
|
||||
message_queue=message_queue,
|
||||
default_service="example_workflow",
|
||||
services={"example_workflow": service},
|
||||
ui=ui_service,
|
||||
|
||||
@@ -14,7 +14,6 @@ class _BaseClient(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_prefix="LLAMA_DEPLOY_")
|
||||
|
||||
api_server_url: str = "http://localhost:4501"
|
||||
control_plane_url: str = "http://localhost:8000"
|
||||
disable_ssl: bool = False
|
||||
timeout: float | None = 120.0
|
||||
poll_interval: float = 0.5
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
from typing import Any
|
||||
|
||||
from .base import _BaseClient
|
||||
from .models import ApiServer, Core, make_sync
|
||||
from .models import ApiServer, make_sync
|
||||
|
||||
|
||||
class Client(_BaseClient):
|
||||
@@ -42,17 +42,8 @@ class Client(_BaseClient):
|
||||
"""Access the API Server functionalities."""
|
||||
return ApiServer(client=self, id="apiserver")
|
||||
|
||||
@property
|
||||
def core(self) -> Core:
|
||||
"""Access the Control Plane functionalities."""
|
||||
return Core(client=self, id="core")
|
||||
|
||||
|
||||
class _SyncClient(_BaseClient):
|
||||
@property
|
||||
def apiserver(self) -> Any:
|
||||
return make_sync(ApiServer)(client=self, id="apiserver")
|
||||
|
||||
@property
|
||||
def core(self) -> Any:
|
||||
return make_sync(Core)(client=self, id="core")
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from .apiserver import ApiServer
|
||||
from .core import Core
|
||||
from .model import Collection, Model, make_sync
|
||||
|
||||
__all__ = ["ApiServer", "Collection", "Core", "Model", "make_sync"]
|
||||
__all__ = ["ApiServer", "Collection", "Model", "make_sync"]
|
||||
|
||||
@@ -11,9 +11,9 @@ import json
|
||||
from typing import Any, AsyncGenerator, TextIO
|
||||
|
||||
import httpx
|
||||
from llama_index.core.workflow.context_serializers import JsonSerializer
|
||||
from llama_index.core.workflow.events import Event
|
||||
from pydantic import Field
|
||||
from workflows.context import JsonSerializer
|
||||
from workflows.events import Event
|
||||
|
||||
from llama_deploy.types.apiserver import Status, StatusEnum
|
||||
from llama_deploy.types.core import (
|
||||
|
||||
@@ -1,316 +0,0 @@
|
||||
"""Client functionalities to operate on the Control Plane.
|
||||
|
||||
This module allows the client to use all the functionalities
|
||||
from the Control Plane. For this to work, the Control Plane
|
||||
must be up and its URL (by default `http://localhost:8000`)
|
||||
reachable by the host executing the client code.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import httpx
|
||||
from llama_index.core.workflow import Event
|
||||
from llama_index.core.workflow.context_serializers import JsonSerializer
|
||||
|
||||
from llama_deploy.types.core import (
|
||||
EventDefinition,
|
||||
ServiceDefinition,
|
||||
TaskDefinition,
|
||||
TaskResult,
|
||||
)
|
||||
|
||||
from .model import Collection, Model
|
||||
|
||||
|
||||
class Session(Model):
|
||||
"""A model representing a Session."""
|
||||
|
||||
async def run(self, service_name: str, **run_kwargs: Any) -> str:
|
||||
"""Implements the workflow-based run API for a session."""
|
||||
task_input = json.dumps(run_kwargs)
|
||||
task_def = TaskDefinition(input=task_input, service_id=service_name)
|
||||
task_id = await self._do_create_task(task_def)
|
||||
|
||||
# wait for task to complete, up to timeout seconds
|
||||
async def _get_result() -> str:
|
||||
while True:
|
||||
task_result = await self._do_get_task_result(task_id)
|
||||
|
||||
if isinstance(task_result, TaskResult):
|
||||
return task_result.result or ""
|
||||
await asyncio.sleep(self.client.poll_interval)
|
||||
|
||||
return await asyncio.wait_for(_get_result(), timeout=self.client.timeout)
|
||||
|
||||
async def run_nowait(self, service_name: str, **run_kwargs: Any) -> str:
|
||||
"""Implements the workflow-based run API for a session, but does not wait for the task to complete."""
|
||||
|
||||
task_input = json.dumps(run_kwargs)
|
||||
task_def = TaskDefinition(input=task_input, service_id=service_name)
|
||||
task_id = await self._do_create_task(task_def)
|
||||
|
||||
return task_id
|
||||
|
||||
async def create_task(self, task_def: TaskDefinition) -> str:
|
||||
"""Create a new task in this session.
|
||||
|
||||
Args:
|
||||
task_def (TaskDefinition): The task definition.
|
||||
|
||||
Returns:
|
||||
str: The ID of the created task.
|
||||
"""
|
||||
return await self._do_create_task(task_def)
|
||||
|
||||
async def _do_create_task(self, task_def: TaskDefinition) -> str:
|
||||
"""Async-only version of create_task, to be used internally from other methods."""
|
||||
task_def.session_id = self.id
|
||||
url = f"{self.client.control_plane_url}/sessions/{self.id}/tasks"
|
||||
response = await self.client.request("POST", url, json=task_def.model_dump())
|
||||
return response.json()
|
||||
|
||||
async def get_task_result(self, task_id: str) -> TaskResult | None:
|
||||
"""Get the result of a task in this session if it has one.
|
||||
|
||||
Args:
|
||||
task_id (str): The ID of the task to get the result for.
|
||||
|
||||
Returns:
|
||||
Optional[TaskResult]: The result of the task if it has one, otherwise None.
|
||||
"""
|
||||
return await self._do_get_task_result(task_id)
|
||||
|
||||
async def _do_get_task_result(self, task_id: str) -> TaskResult | None:
|
||||
"""Async-only version of get_task_result, to be used internally from other methods."""
|
||||
url = (
|
||||
f"{self.client.control_plane_url}/sessions/{self.id}/tasks/{task_id}/result"
|
||||
)
|
||||
response = await self.client.request("GET", url)
|
||||
data = response.json()
|
||||
return TaskResult(**data) if data else None
|
||||
|
||||
async def get_tasks(self) -> list[TaskDefinition]:
|
||||
"""Get all tasks in this session.
|
||||
|
||||
Returns:
|
||||
list[TaskDefinition]: A list of task definitions in the session.
|
||||
"""
|
||||
url = f"{self.client.control_plane_url}/sessions/{self.id}/tasks"
|
||||
response = await self.client.request("GET", url)
|
||||
return [TaskDefinition(**task) for task in response.json()]
|
||||
|
||||
async def send_event(self, service_name: str, task_id: str, ev: Event) -> None:
|
||||
"""Send event to a Workflow service.
|
||||
|
||||
Args:
|
||||
service_name (str): The name of the service running the target Task.
|
||||
task_id (str): The ID of the task running the workflow receiving the event.
|
||||
ev (Event): The event to be sent to the workflow task.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
serializer = JsonSerializer()
|
||||
event_def = EventDefinition(
|
||||
event_obj_str=serializer.serialize(ev), service_id=service_name
|
||||
)
|
||||
|
||||
url = f"{self.client.control_plane_url}/sessions/{self.id}/tasks/{task_id}/send_event"
|
||||
await self.client.request("POST", url, json=event_def.model_dump())
|
||||
|
||||
async def send_event_def(self, task_id: str, ev_def: EventDefinition) -> None:
|
||||
"""Send event to a Workflow service.
|
||||
|
||||
Args:
|
||||
task_id (str): The ID of the task running the workflow receiving the event.
|
||||
ev_def (EventDefinition): The event definition describing the Event to send.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
url = f"{self.client.control_plane_url}/sessions/{self.id}/tasks/{task_id}/send_event"
|
||||
await self.client.request("POST", url, json=ev_def.model_dump())
|
||||
|
||||
async def get_task_result_stream(
|
||||
self, task_id: str
|
||||
) -> AsyncGenerator[dict[str, Any], None]:
|
||||
"""Get the result of a task in this session if it has one.
|
||||
|
||||
Args:
|
||||
task_id (str): The ID of the task to get the result for.
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[dict[str, Any], None]: A generator that yields the streamed task results as parsed JSON objects.
|
||||
"""
|
||||
url = f"{self.client.control_plane_url}/sessions/{self.id}/tasks/{task_id}/result_stream"
|
||||
start_time = time.time()
|
||||
while True:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.client.timeout) as client:
|
||||
async with client.stream("GET", url) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
json_line = json.loads(line)
|
||||
yield json_line
|
||||
break # Exit the function if successful
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code != 404:
|
||||
raise # Re-raise if it's not a 404 error
|
||||
if (
|
||||
self.client.timeout is None # means no timeout, always poll
|
||||
or time.time() - start_time < self.client.timeout
|
||||
):
|
||||
await asyncio.sleep(self.client.poll_interval)
|
||||
else:
|
||||
raise TimeoutError(
|
||||
f"Task result not available after waiting for {self.client.timeout} seconds"
|
||||
)
|
||||
|
||||
|
||||
class SessionCollection(Collection):
|
||||
"""A model representing a collection of sessions."""
|
||||
|
||||
async def list(self) -> list[Session]: # type: ignore
|
||||
"""Returns a list of all the sessions in the collection."""
|
||||
sessions_url = f"{self.client.control_plane_url}/sessions"
|
||||
response = await self.client.request("GET", sessions_url)
|
||||
sessions = []
|
||||
model_class = self._prepare(Session)
|
||||
for id, session_def in response.json().items():
|
||||
sessions.append(model_class(client=self.client, id=id))
|
||||
return sessions
|
||||
|
||||
async def create(self) -> Session:
|
||||
"""Creates a new session and returns a Session object.
|
||||
|
||||
Returns:
|
||||
Session: A Session object representing the newly created session.
|
||||
"""
|
||||
return await self._create()
|
||||
|
||||
async def _create(self) -> Session:
|
||||
"""Async-only version of create, to be used internally from other methods."""
|
||||
create_url = f"{self.client.control_plane_url}/sessions/create"
|
||||
response = await self.client.request("POST", create_url)
|
||||
session_id = response.json()
|
||||
model_class = self._prepare(Session)
|
||||
return model_class(client=self.client, id=session_id)
|
||||
|
||||
async def get(self, id: str) -> Session:
|
||||
"""Gets a session by ID.
|
||||
|
||||
Args:
|
||||
id (str): The ID of the session to get.
|
||||
|
||||
Returns:
|
||||
Session: A Session object representing the specified session.
|
||||
|
||||
Raises:
|
||||
ValueError: If the session does not exist.
|
||||
"""
|
||||
return await self._get(id)
|
||||
|
||||
async def _get(self, id: str) -> Session:
|
||||
"""Async-only version of get, to be used internally from other methods."""
|
||||
|
||||
get_url = f"{self.client.control_plane_url}/sessions/{id}"
|
||||
await self.client.request("GET", get_url)
|
||||
model_class = self._prepare(Session)
|
||||
return model_class(client=self.client, id=id)
|
||||
|
||||
async def get_or_create(self, id: str) -> Session:
|
||||
"""Gets a session by ID, or creates a new one if it doesn't exist.
|
||||
|
||||
Returns:
|
||||
Session: A Session object representing the specified session.
|
||||
"""
|
||||
try:
|
||||
return await self._get(id)
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 404:
|
||||
return await self._create()
|
||||
raise e
|
||||
|
||||
async def delete(self, session_id: str) -> None:
|
||||
"""Deletes a session by ID.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the session to delete.
|
||||
"""
|
||||
delete_url = f"{self.client.control_plane_url}/sessions/{session_id}/delete"
|
||||
await self.client.request("POST", delete_url)
|
||||
|
||||
|
||||
class Service(Model):
|
||||
"""A model representing a Service."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ServiceCollection(Collection):
|
||||
async def list(self) -> list[Service]: # type: ignore
|
||||
"""Returns a list containing all the services registered with the control plane.
|
||||
|
||||
Returns:
|
||||
list[Service]: List of services registered with the control plane.
|
||||
"""
|
||||
services_url = f"{self.client.control_plane_url}/services"
|
||||
response = await self.client.request("GET", services_url)
|
||||
services = []
|
||||
model_class = self._prepare(Service)
|
||||
|
||||
for name, service in response.json().items():
|
||||
services.append(model_class(client=self.client, id=name))
|
||||
|
||||
return services
|
||||
|
||||
async def register(self, service: ServiceDefinition) -> Service:
|
||||
"""Registers a service with the control plane.
|
||||
|
||||
Args:
|
||||
service: Definition of the Service to register.
|
||||
"""
|
||||
register_url = f"{self.client.control_plane_url}/services/register"
|
||||
await self.client.request("POST", register_url, json=service.model_dump())
|
||||
model_class = self._prepare(Service)
|
||||
s = model_class(id=service.service_name, client=self.client)
|
||||
self.items[service.service_name] = s
|
||||
return s
|
||||
|
||||
async def deregister(self, service_name: str) -> None:
|
||||
"""Deregisters a service from the control plane.
|
||||
|
||||
Args:
|
||||
service_name: The name of the Service to deregister.
|
||||
"""
|
||||
deregister_url = f"{self.client.control_plane_url}/services/deregister"
|
||||
await self.client.request(
|
||||
"POST",
|
||||
deregister_url,
|
||||
params={"service_name": service_name},
|
||||
)
|
||||
|
||||
|
||||
class Core(Model):
|
||||
@property
|
||||
def services(self) -> ServiceCollection:
|
||||
"""Returns a collection containing all the services registered with the control plane.
|
||||
|
||||
Returns:
|
||||
ServiceCollection: Collection of services registered with the control plane.
|
||||
"""
|
||||
model_class = self._prepare(ServiceCollection)
|
||||
return model_class(client=self.client, items={})
|
||||
|
||||
@property
|
||||
def sessions(self) -> SessionCollection:
|
||||
"""Returns a collection to access all the sessions registered with the control plane.
|
||||
|
||||
Returns:
|
||||
SessionCollection: Collection of sessions registered with the control plane.
|
||||
"""
|
||||
model_class = self._prepare(SessionCollection)
|
||||
return model_class(client=self.client, items={})
|
||||
@@ -1,4 +0,0 @@
|
||||
from .config import ControlPlaneConfig
|
||||
from .server import ControlPlaneServer
|
||||
|
||||
__all__ = ["ControlPlaneServer", "ControlPlaneConfig"]
|
||||
@@ -1,97 +0,0 @@
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from llama_index.core.storage.kvstore.types import BaseKVStore
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class ControlPlaneConfig(BaseSettings):
|
||||
"""Control plane configuration."""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="CONTROL_PLANE_", arbitrary_types_allowed=True
|
||||
)
|
||||
|
||||
services_store_key: str = Field(
|
||||
default="services",
|
||||
description="Key for the services store. Defaults to 'services'.",
|
||||
)
|
||||
tasks_store_key: str = Field(
|
||||
default="tasks",
|
||||
description="Key for the tasks store. Defaults to 'tasks'.",
|
||||
)
|
||||
session_store_key: str = Field(
|
||||
default="sessions",
|
||||
description="Key for the session store. Defaults to 'sessions'.",
|
||||
)
|
||||
step_interval: float = Field(
|
||||
default=0.1,
|
||||
description="The interval in seconds to poll for tool call results. Defaults to 0.1s.",
|
||||
)
|
||||
host: str = Field(
|
||||
default="127.0.0.1",
|
||||
description="The host where to run the control plane server",
|
||||
)
|
||||
port: int = Field(
|
||||
default=8000, description="The TCP port where to bind the control plane server"
|
||||
)
|
||||
internal_host: str | None = None
|
||||
internal_port: int | None = None
|
||||
running: bool = True
|
||||
cors_origins: List[str] | None = Field(
|
||||
default=None,
|
||||
description="List of hosts from which the service will accept CORS requests. Use ['*'] for all hosts.",
|
||||
)
|
||||
topic_namespace: str = Field(
|
||||
default="llama_deploy",
|
||||
description="The prefix used in the message queue topic to namespace messages from this control plane",
|
||||
)
|
||||
state_store_uri: str | None = Field(
|
||||
default=None,
|
||||
description="The connection URI of the database where to store state. If None, SimpleKVStore will be used",
|
||||
)
|
||||
use_tls: bool = Field(
|
||||
default=False,
|
||||
description="Use TLS (HTTPS) to communicate with the control plane",
|
||||
)
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
if self.use_tls:
|
||||
return f"https://{self.host}:{self.port}"
|
||||
return f"http://{self.host}:{self.port}"
|
||||
|
||||
|
||||
def parse_state_store_uri(uri: str) -> BaseKVStore:
|
||||
bits = urlparse(uri)
|
||||
|
||||
# Redis supports multiple schemes:
|
||||
# https://redis-py.readthedocs.io/en/stable/connections.html#redis.Redis.from_url
|
||||
if bits.scheme in {"redis", "rediss", "unix"}:
|
||||
try:
|
||||
from llama_index.storage.kvstore.redis import RedisKVStore # type: ignore
|
||||
|
||||
return RedisKVStore(redis_uri=uri)
|
||||
except ImportError:
|
||||
msg = (
|
||||
"key-value store redis is not available, please install the required "
|
||||
"llama_index integration with 'pip install llama-index-storage-kvstore-redis'."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
elif bits.scheme == "mongodb+srv":
|
||||
try:
|
||||
from llama_index.storage.kvstore.mongodb import ( # type:ignore
|
||||
MongoDBKVStore,
|
||||
)
|
||||
|
||||
return MongoDBKVStore(uri=uri)
|
||||
except ImportError:
|
||||
msg = (
|
||||
f"key-value store {bits.scheme} is not available, please install the required "
|
||||
"llama_index integration with 'pip install llama-index-storage-kvstore-mongodb'."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
msg = f"key-value store '{bits.scheme}' is not supported."
|
||||
raise ValueError(msg)
|
||||
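The removed `parse_state_store_uri` picks a key-value backend purely from the URI scheme. The dispatch can be sketched without any `llama_index` dependency as below. This is a minimal illustration, not the removed implementation: `SCHEME_TO_BACKEND` and `backend_for_uri` are hypothetical names, and the function returns a label instead of a store instance.

```python
from urllib.parse import urlparse

# Hypothetical table mirroring the schemes checked in the removed code.
# The real function returned RedisKVStore / MongoDBKVStore instances.
SCHEME_TO_BACKEND = {
    "redis": "redis",
    "rediss": "redis",
    "unix": "redis",
    "mongodb+srv": "mongodb",
}


def backend_for_uri(uri: str) -> str:
    """Return the backend label for a state store URI, or raise ValueError."""
    scheme = urlparse(uri).scheme
    try:
        return SCHEME_TO_BACKEND[scheme]
    except KeyError:
        raise ValueError(f"key-value store '{scheme}' is not supported.") from None
```

Note that, as in the removed code, only the `mongodb+srv` scheme is mapped to MongoDB, so a plain `mongodb://` URI falls through to the unsupported branch.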
@@ -1,666 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from logging import getLogger
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from llama_index.core.storage.kvstore import SimpleKVStore
|
||||
from llama_index.core.storage.kvstore.types import BaseKVStore
|
||||
|
||||
from llama_deploy.apiserver.tracing import (
|
||||
add_span_attribute,
|
||||
trace_async_method,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_deploy.message_queues.base import AbstractMessageQueue, PublishCallback
|
||||
|
||||
from llama_deploy.types import (
|
||||
ActionTypes,
|
||||
EventDefinition,
|
||||
QueueMessage,
|
||||
ServiceDefinition,
|
||||
SessionDefinition,
|
||||
TaskDefinition,
|
||||
TaskResult,
|
||||
TaskStream,
|
||||
)
|
||||
|
||||
from .config import ControlPlaneConfig, parse_state_store_uri
|
||||
from .utils import get_result_key, get_stream_key
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
CONTROL_PLANE_MESSAGE_TYPE = "control_plane"
|
||||
|
||||
|
||||
class ControlPlaneServer:
|
||||
"""Control plane server.
|
||||
|
||||
The control plane is responsible for managing the state of the system, including:
|
||||
- Registering services.
|
||||
- Submitting tasks.
|
||||
- Managing task state.
|
||||
- Handling service completion.
|
||||
- Launching the control plane server.
|
||||
|
||||
Args:
|
||||
message_queue (AbstractMessageQueue): Message queue for the system.
|
||||
publish_callback (Optional[PublishCallback], optional): Callback for publishing messages. Defaults to None.
|
||||
state_store (Optional[BaseKVStore], optional): State store for the system. Defaults to None.
|
||||
config (ControlPlaneConfig, optional): Configuration for the control plane. Defaults to None.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
from llama_deploy.control_plane import ControlPlaneServer
|
||||
from llama_deploy.message_queues import SimpleMessageQueue
|
||||
|
||||
control_plane = ControlPlaneServer(
|
||||
SimpleMessageQueue(),
|
||||
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_queue: "AbstractMessageQueue",
|
||||
publish_callback: "PublishCallback | None" = None,
|
||||
state_store: BaseKVStore | None = None,
|
||||
config: ControlPlaneConfig | None = None,
|
||||
) -> None:
|
||||
self._config = config or ControlPlaneConfig()
|
||||
|
||||
if state_store is not None and self._config.state_store_uri is not None:
|
||||
raise ValueError("Please use either 'state_store' or 'state_store_uri'.")
|
||||
|
||||
if state_store:
|
||||
self._state_store = state_store
|
||||
elif self._config.state_store_uri:
|
||||
self._state_store = parse_state_store_uri(self._config.state_store_uri)
|
||||
else:
|
||||
self._state_store = SimpleKVStore()
|
||||
|
||||
self._message_queue = message_queue
|
||||
self._publisher_id = f"{self.__class__.__qualname__}-{uuid.uuid4()}"
|
||||
self._publish_callback = publish_callback
|
||||
|
||||
self.app = FastAPI()
|
||||
if self._config.cors_origins:
|
||||
self.app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=self._config.cors_origins,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
self.app.add_api_route("/", self.home, methods=["GET"], tags=["Control Plane"])
|
||||
self.app.add_api_route(
|
||||
"/queue_config",
|
||||
self.get_message_queue_config,
|
||||
methods=["GET"],
|
||||
tags=["Message Queue"],
|
||||
)
|
||||
|
||||
self.app.add_api_route(
|
||||
"/services/register",
|
||||
self.register_service,
|
||||
methods=["POST"],
|
||||
tags=["Services"],
|
||||
)
|
||||
self.app.add_api_route(
|
||||
"/services/deregister",
|
||||
self.deregister_service,
|
||||
methods=["POST"],
|
||||
tags=["Services"],
|
||||
)
|
||||
self.app.add_api_route(
|
||||
"/services/{service_name}",
|
||||
self.get_service,
|
||||
methods=["GET"],
|
||||
tags=["Services"],
|
||||
)
|
||||
self.app.add_api_route(
|
||||
"/services",
|
||||
self.get_all_services,
|
||||
methods=["GET"],
|
||||
tags=["Services"],
|
||||
)
|
||||
|
||||
self.app.add_api_route(
|
||||
"/sessions/{session_id}",
|
||||
self.get_session,
|
||||
methods=["GET"],
|
||||
tags=["Sessions"],
|
||||
)
|
||||
self.app.add_api_route(
|
||||
"/sessions/create",
|
||||
self.create_session,
|
||||
methods=["POST"],
|
||||
tags=["Sessions"],
|
||||
)
|
||||
self.app.add_api_route(
|
||||
"/sessions/{session_id}/delete",
|
||||
self.delete_session,
|
||||
methods=["POST"],
|
||||
tags=["Sessions"],
|
||||
)
|
||||
self.app.add_api_route(
|
||||
"/sessions/{session_id}/tasks",
|
||||
self.add_task_to_session,
|
||||
methods=["POST"],
|
||||
tags=["Sessions"],
|
||||
)
|
||||
self.app.add_api_route(
|
||||
"/sessions",
|
||||
self.get_all_sessions,
|
||||
methods=["GET"],
|
||||
tags=["Sessions"],
|
||||
)
|
||||
self.app.add_api_route(
|
||||
"/sessions/{session_id}/tasks",
|
||||
self.get_session_tasks,
|
||||
methods=["GET"],
|
||||
tags=["Sessions"],
|
||||
)
|
||||
self.app.add_api_route(
|
||||
"/sessions/{session_id}/current_task",
|
||||
self.get_current_task,
|
||||
methods=["GET"],
|
||||
tags=["Sessions"],
|
||||
)
|
||||
self.app.add_api_route(
|
||||
"/sessions/{session_id}/tasks/{task_id}/result",
|
||||
self.get_task_result,
|
||||
methods=["GET"],
|
||||
tags=["Sessions"],
|
||||
)
|
||||
self.app.add_api_route(
|
||||
"/sessions/{session_id}/tasks/{task_id}/result_stream",
|
||||
self.get_task_result_stream,
|
||||
methods=["GET"],
|
||||
tags=["Sessions"],
|
||||
)
|
||||
self.app.add_api_route(
|
||||
"/sessions/{session_id}/tasks/{task_id}/send_event",
|
||||
self.send_event,
|
||||
methods=["POST"],
|
||||
tags=["Sessions"],
|
||||
)
|
||||
self.app.add_api_route(
|
||||
"/sessions/{session_id}/state",
|
||||
self.get_session_state,
|
||||
methods=["GET"],
|
||||
tags=["Sessions"],
|
||||
)
|
||||
self.app.add_api_route(
|
||||
"/sessions/{session_id}/state",
|
||||
self.update_session_state,
|
||||
methods=["POST"],
|
||||
tags=["Sessions"],
|
||||
)
|
||||
|
||||
@property
|
||||
def message_queue(self) -> "AbstractMessageQueue":
|
||||
return self._message_queue
|
||||
|
||||
@property
|
||||
def publisher_id(self) -> str:
|
||||
return self._publisher_id
|
||||
|
||||
@property
|
||||
def publish_callback(self) -> Optional["PublishCallback"]:
|
||||
return self._publish_callback
|
||||
|
||||
async def _process_messages(self, topic: str) -> None:
|
||||
async for message in self._message_queue.get_messages(topic):
|
||||
if not message.data:
|
||||
raise ValueError(
|
||||
f"Invalid field 'data' in QueueMessage: {message.data}"
|
||||
)
|
||||
|
||||
action = message.action
|
||||
if action == ActionTypes.NEW_TASK:
|
||||
task_def = TaskDefinition(**message.data)
|
||||
if task_def.session_id is None:
|
||||
task_def.session_id = await self.create_session()
|
||||
|
||||
await self.add_task_to_session(task_def.session_id, task_def)
|
||||
elif action == ActionTypes.COMPLETED_TASK:
|
||||
await self.handle_service_completion(TaskResult(**message.data))
|
||||
elif action == ActionTypes.TASK_STREAM:
|
||||
await self.add_stream_to_session(TaskStream(**message.data))
|
||||
else:
|
||||
raise ValueError(f"Action {action} not supported by control plane")
|
||||
|
||||
async def launch_server(self) -> None:
|
||||
# give precedence to internal settings
|
||||
host = self._config.internal_host or self._config.host
|
||||
port = self._config.internal_port or self._config.port
|
||||
logger.info(f"Launching control plane server at {host}:{port}")
|
||||
|
||||
message_queue_consumer = asyncio.create_task(
|
||||
self._process_messages(self.get_topic(CONTROL_PLANE_MESSAGE_TYPE))
|
||||
)
|
||||
|
||||
class CustomServer(uvicorn.Server):
|
||||
def install_signal_handlers(self) -> None:
|
||||
pass
|
||||
|
||||
cfg = uvicorn.Config(self.app, host=host, port=port)
|
||||
server = CustomServer(cfg)
|
||||
try:
|
||||
await server.serve()
|
||||
except asyncio.CancelledError:
|
||||
self._running = False
|
||||
message_queue_consumer.cancel()
|
||||
await asyncio.gather(
|
||||
server.shutdown(), message_queue_consumer, return_exceptions=True
|
||||
)
|
||||
|
||||
async def home(self) -> Dict[str, str]:
|
||||
return {
|
||||
"running": str(self._config.running),
|
||||
"step_interval": str(self._config.step_interval),
|
||||
"services_store_key": self._config.services_store_key,
|
||||
"tasks_store_key": self._config.tasks_store_key,
|
||||
"session_store_key": self._config.session_store_key,
|
||||
}
|
||||
|
||||
async def register_service(
|
||||
self, service_def: ServiceDefinition
|
||||
) -> ControlPlaneConfig:
|
||||
await self._state_store.aput(
|
||||
service_def.service_name,
|
||||
service_def.model_dump(),
|
||||
collection=self._config.services_store_key,
|
||||
)
|
||||
return self._config
|
||||
|
||||
async def deregister_service(self, service_name: str) -> None:
|
||||
await self._state_store.adelete(
|
||||
service_name, collection=self._config.services_store_key
|
||||
)
|
||||
|
||||
async def get_service(self, service_name: str) -> ServiceDefinition:
|
||||
service_dict = await self._state_store.aget(
|
||||
service_name, collection=self._config.services_store_key
|
||||
)
|
||||
if service_dict is None:
|
||||
raise HTTPException(status_code=404, detail="Service not found")
|
||||
|
||||
return ServiceDefinition.model_validate(service_dict)
|
||||
|
||||
async def get_all_services(self) -> Dict[str, ServiceDefinition]:
|
||||
service_dicts = await self._state_store.aget_all(
|
||||
collection=self._config.services_store_key
|
||||
)
|
||||
|
||||
return {
|
||||
service_name: ServiceDefinition.model_validate(service_dict)
|
||||
for service_name, service_dict in service_dicts.items()
|
||||
}
|
||||
|
||||
async def create_session(self) -> str:
|
||||
session = SessionDefinition()
|
||||
await self._state_store.aput(
|
||||
session.session_id,
|
||||
session.model_dump(),
|
||||
collection=self._config.session_store_key,
|
||||
)
|
||||
|
||||
return session.session_id
|
||||
|
||||
async def get_session(self, session_id: str) -> SessionDefinition:
|
||||
session_dict = await self._state_store.aget(
|
||||
session_id, collection=self._config.session_store_key
|
||||
)
|
||||
if session_dict is None:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
return SessionDefinition.model_validate(session_dict)
|
||||
|
||||
async def delete_session(self, session_id: str) -> None:
|
||||
await self._state_store.adelete(
|
||||
session_id, collection=self._config.session_store_key
|
||||
)
|
||||
|
||||
async def get_all_sessions(self) -> Dict[str, SessionDefinition]:
|
||||
session_dicts = await self._state_store.aget_all(
|
||||
collection=self._config.session_store_key
|
||||
)
|
||||
|
||||
return {
|
||||
session_id: SessionDefinition.model_validate(session_dict)
|
||||
for session_id, session_dict in session_dicts.items()
|
||||
}
|
||||
|
||||
async def get_session_tasks(self, session_id: str) -> List[TaskDefinition]:
|
||||
session = await self.get_session(session_id)
|
||||
task_defs = []
|
||||
for task_id in session.task_ids:
|
||||
task_defs.append(await self.get_task(task_id))
|
||||
return task_defs
|
||||
|
||||
async def get_current_task(self, session_id: str) -> Optional[TaskDefinition]:
|
||||
session = await self.get_session(session_id)
|
||||
if len(session.task_ids) == 0:
|
||||
return None
|
||||
return await self.get_task(session.task_ids[-1])
|
||||
|
||||
@trace_async_method("control_plane.add_task_to_session")
|
||||
async def add_task_to_session(
|
||||
self, session_id: str, task_def: TaskDefinition
|
||||
) -> str:
|
||||
add_span_attribute("session.id", session_id)
|
||||
add_span_attribute("task.id", task_def.task_id)
|
||||
add_span_attribute("task.service_id", task_def.service_id or "auto")
|
||||
|
||||
session_dict = await self._state_store.aget(
|
||||
session_id, collection=self._config.session_store_key
|
||||
)
|
||||
if session_dict is None:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
if not task_def.session_id:
|
||||
task_def.session_id = session_id
|
||||
|
||||
if task_def.session_id != session_id:
|
||||
msg = f"Wrong task definition: task.session_id is {task_def.session_id} but should be {session_id}"
|
||||
raise HTTPException(status_code=400, detail=msg)
|
||||
|
||||
session = SessionDefinition(**session_dict)
|
||||
session.task_ids.append(task_def.task_id)
|
||||
await self._state_store.aput(
|
||||
session_id, session.model_dump(), collection=self._config.session_store_key
|
||||
)
|
||||
|
||||
await self._state_store.aput(
|
||||
task_def.task_id,
|
||||
task_def.model_dump(),
|
||||
collection=self._config.tasks_store_key,
|
||||
)
|
||||
|
||||
task_def = await self.send_task_to_service(task_def)
|
||||
|
||||
return task_def.task_id
|
||||
|
||||
async def send_task_to_service(self, task_def: TaskDefinition) -> TaskDefinition:
|
||||
if task_def.session_id is None:
|
||||
raise ValueError(f"Task with id {task_def.task_id} has no session")
|
||||
|
||||
session = await self.get_session(task_def.session_id)
|
||||
|
||||
next_messages, session_state = await self.get_next_messages(
|
||||
task_def, session.state
|
||||
)
|
||||
|
||||
logger.debug(f"Sending task {task_def.task_id} to services: {next_messages}")
|
||||
|
||||
for message in next_messages:
|
||||
await self.publish(message)
|
||||
|
||||
session.state.update(session_state)
|
||||
|
||||
await self._state_store.aput(
|
||||
task_def.session_id,
|
||||
session.model_dump(),
|
||||
collection=self._config.session_store_key,
|
||||
)
|
||||
|
||||
return task_def
|
||||
|
||||
@trace_async_method("control_plane.handle_service_completion")
|
||||
async def handle_service_completion(
|
||||
self,
|
||||
task_result: TaskResult,
|
||||
) -> None:
|
||||
add_span_attribute("task.id", task_result.task_id)
|
||||
add_span_attribute("task.result_length", len(task_result.result))
|
||||
|
||||
# add result to task state
|
||||
task_def = await self.get_task(task_result.task_id)
|
||||
if task_def.session_id is None:
|
||||
raise ValueError(f"Task with id {task_result.task_id} has no session")
|
||||
|
||||
session = await self.get_session(task_def.session_id)
|
||||
state = await self.add_result_to_state(task_result, session.state)
|
||||
|
||||
# update session state
|
||||
session.state.update(state)
|
||||
await self._state_store.aput(
|
||||
session.session_id,
|
||||
session.model_dump(),
|
||||
collection=self._config.session_store_key,
|
||||
)
|
||||
|
||||
# generate and send new tasks when needed
|
||||
task_def = await self.send_task_to_service(task_def)
|
||||
|
||||
await self._state_store.aput(
|
||||
task_def.task_id,
|
||||
task_def.model_dump(),
|
||||
collection=self._config.tasks_store_key,
|
||||
)
|
||||
|
||||
async def get_task(self, task_id: str) -> TaskDefinition:
|
||||
state_dict = await self._state_store.aget(
|
||||
task_id, collection=self._config.tasks_store_key
|
||||
)
|
||||
if state_dict is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
return TaskDefinition(**state_dict)
|
||||
|
||||
async def get_task_result(
|
||||
self, task_id: str, session_id: str
|
||||
) -> Optional[TaskResult]:
|
||||
"""Get the result of a task if it has one.
|
||||
|
||||
Args:
|
||||
task_id (str): The ID of the task to get the result for.
|
||||
session_id (str): The ID of the session the task belongs to.
|
||||
|
||||
Returns:
|
||||
Optional[TaskResult]: The result of the task if it has one, otherwise None.
|
||||
"""
|
||||
session = await self.get_session(session_id)
|
||||
|
||||
result_key = get_result_key(task_id)
|
||||
if result_key not in session.state:
|
||||
return None
|
||||
|
||||
result = session.state[result_key]
|
||||
if not isinstance(result, TaskResult):
|
||||
if isinstance(result, dict):
|
||||
result = TaskResult(**result)
|
||||
elif isinstance(result, str):
|
||||
result = TaskResult(**json.loads(result))
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Unexpected result type")
|
||||
|
||||
# sanity check
|
||||
if result.task_id != task_id:
|
||||
logger.debug(
|
||||
f"Retrieved result did not match requested task_id: {str(result)}"
|
||||
)
|
||||
return None
|
||||
|
||||
return result
|
||||
|
||||
async def add_stream_to_session(self, task_stream: TaskStream) -> None:
|
||||
# get session
|
||||
if task_stream.session_id is None:
|
||||
raise ValueError(
|
||||
f"Task stream with id {task_stream.task_id} has no session"
|
||||
)
|
||||
|
||||
session = await self.get_session(task_stream.session_id)
|
||||
|
||||
# add new stream data to session state
|
||||
existing_stream = session.state.get(get_stream_key(task_stream.task_id), [])
|
||||
existing_stream.append(task_stream.model_dump())
|
||||
session.state[get_stream_key(task_stream.task_id)] = existing_stream
|
||||
|
||||
# update session state in store
|
||||
await self._state_store.aput(
|
||||
task_stream.session_id,
|
||||
session.model_dump(),
|
||||
collection=self._config.session_store_key,
|
||||
)
|
||||
|
||||
async def get_task_result_stream(
|
||||
self, session_id: str, task_id: str
|
||||
) -> StreamingResponse:
|
||||
session = await self.get_session(session_id)
|
||||
|
||||
stream_key = get_stream_key(task_id)
|
||||
if stream_key not in session.state:
|
||||
raise HTTPException(status_code=404, detail="Task stream not found")
|
||||
|
||||
async def event_generator(
|
||||
session: SessionDefinition, stream_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
try:
|
||||
last_index = 0
|
||||
while True:
|
||||
session = await self.get_session(session_id)
|
||||
stream_results = session.state[stream_key][last_index:]
|
||||
stream_results = sorted(stream_results, key=lambda x: x["index"])
|
||||
for result in stream_results:
|
||||
if not isinstance(result, TaskStream):
|
||||
if isinstance(result, dict):
|
||||
result = TaskStream(**result)
|
||||
elif isinstance(result, str):
|
||||
result = TaskStream(**json.loads(result))
|
||||
else:
|
||||
raise ValueError("Unexpected result type in stream")
|
||||
|
||||
yield json.dumps(result.data) + "\n"
|
||||
|
||||
# check if there is a final result
|
||||
final_result = await self.get_task_result(task_id, session_id)
|
||||
if final_result is not None:
|
||||
return
|
||||
|
||||
last_index += len(stream_results)
|
||||
# Small delay to prevent tight loop
|
||||
await asyncio.sleep(self._config.step_interval)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in event stream for session {session_id}, task {task_id}: {str(e)}"
|
||||
)
|
||||
yield json.dumps({"error": str(e)}) + "\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(session, stream_key),
|
||||
media_type="application/x-ndjson",
|
||||
)
|
||||
|
||||
async def send_event(
|
||||
self,
|
||||
session_id: str,
|
||||
task_id: str,
|
||||
event_def: EventDefinition,
|
||||
) -> None:
|
||||
task_def = TaskDefinition(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
input=event_def.event_obj_str,
|
||||
service_id=event_def.service_id,
|
||||
)
|
||||
message = QueueMessage(
|
||||
type=event_def.service_id,
|
||||
action=ActionTypes.SEND_EVENT,
|
||||
data=task_def.model_dump(),
|
||||
)
|
||||
await self.publish(message)
|
||||
|
||||
async def get_session_state(self, session_id: str) -> Dict[str, Any]:
|
||||
session = await self.get_session(session_id)
|
||||
if session.task_ids is None:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
return session.state
|
||||
|
||||
async def update_session_state(
|
||||
self, session_id: str, state: Dict[str, Any]
|
||||
) -> None:
|
||||
session = await self.get_session(session_id)
|
||||
|
||||
session.state.update(state)
|
||||
await self._state_store.aput(
|
||||
session_id, session.model_dump(), collection=self._config.session_store_key
|
||||
)
|
||||
|
||||
async def get_message_queue_config(self) -> Dict[str, dict]:
|
||||
"""
|
||||
Gets the config dict for the message queue being used.
|
||||
|
||||
Returns:
|
||||
Dict[str, dict]: A dict of message queue name -> config dict
|
||||
"""
|
||||
queue_config = self._message_queue.as_config()
|
||||
return {queue_config.__class__.__name__: queue_config.model_dump()}
|
||||
|
||||
def get_topic(self, msg_type: str) -> str:
|
||||
return f"{self._config.topic_namespace}.{msg_type}"
|
||||
|
||||
async def get_next_messages(
|
||||
self, task_def: TaskDefinition, state: Dict[str, Any]
|
||||
) -> Tuple[List[QueueMessage], Dict[str, Any]]:
|
||||
"""Get the next messages to process. Returns the messages and the new state.
|
||||
|
||||
Assumes the service_id is the destination for the next message.
|
||||
|
||||
Runs the required service, then sends the result to the final message type.
|
||||
"""
|
||||
if task_def.service_id is None:
|
||||
raise ValueError(
|
||||
"Task definition must have a service_id specified to identify a service"
|
||||
)
|
||||
|
||||
if task_def.task_id not in state:
|
||||
state[task_def.task_id] = {}
|
||||
|
||||
if state.get(get_result_key(task_def.task_id)) is not None:
|
||||
return [], state
|
||||
|
||||
destination = task_def.service_id
|
||||
destination_messages = [
|
||||
QueueMessage(
|
||||
type=destination,
|
||||
action=ActionTypes.NEW_TASK,
|
||||
data=task_def.model_dump(),
|
||||
)
|
||||
]
|
||||
|
||||
return destination_messages, state
|
||||
|
||||
async def add_result_to_state(
|
||||
self, result: TaskResult, state: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Add the result of processing a message to the state. Returns the new state."""
|
||||
|
||||
# TODO: detect failures + retries
|
||||
cur_retries = state.get("retries", -1) + 1
|
||||
state["retries"] = cur_retries
|
||||
|
||||
# add result to state
|
||||
state[get_result_key(result.task_id)] = result
|
||||
|
||||
return state
|
||||
|
||||
async def publish(self, message: QueueMessage, **kwargs: Any) -> Any:
|
||||
"""Publish message."""
|
||||
message.publisher_id = self.publisher_id
|
||||
return await self.message_queue.publish(
|
||||
message,
|
||||
callback=self.publish_callback,
|
||||
topic=self.get_topic(message.type),
|
||||
**kwargs,
|
||||
)
|
||||
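The removed `get_task_result_stream` endpoint follows a polling pattern: re-read the session state, emit stream entries not seen yet as NDJSON lines, stop once a final result is present, and otherwise sleep for `step_interval`. A minimal sketch of that loop, assuming a plain in-memory dict in place of the state store (`stream_events` and `demo` are hypothetical names used only for illustration):

```python
import asyncio
import json
from typing import Any, AsyncGenerator, Dict, List


async def stream_events(
    state: Dict[str, Any], task_id: str, step_interval: float = 0.01
) -> AsyncGenerator[str, None]:
    """Yield one NDJSON line per stream entry until a result key appears."""
    stream_key = f"stream_{task_id}"
    result_key = f"result_{task_id}"
    last_index = 0
    while True:
        # Only look at entries appended since the previous poll.
        new_items = sorted(state[stream_key][last_index:], key=lambda x: x["index"])
        for item in new_items:
            yield json.dumps(item["data"]) + "\n"
        # A stored result marks the end of the stream.
        if result_key in state:
            return
        last_index += len(new_items)
        await asyncio.sleep(step_interval)


async def demo() -> List[str]:
    state: Dict[str, Any] = {"stream_t1": [{"index": 0, "data": {"msg": "a"}}]}

    async def producer() -> None:
        # Simulates a service publishing one more chunk and then the result.
        await asyncio.sleep(0.02)
        state["stream_t1"].append({"index": 1, "data": {"msg": "b"}})
        state["result_t1"] = "done"

    task = asyncio.create_task(producer())
    lines = [line async for line in stream_events(state, "t1")]
    await task
    return lines
```

Because the stream is read before the result check, entries that arrive in the same poll as the final result are still emitted before the generator returns.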
@@ -1,6 +0,0 @@
|
||||
def get_result_key(task_id: str) -> str:
|
||||
return f"result_{task_id}"
|
||||
|
||||
|
||||
def get_stream_key(task_id: str) -> str:
|
||||
return f"stream_{task_id}"
|
||||
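These two helpers namespace per-task entries inside a single session state dictionary, so a task's final result and its stream chunks can live side by side. A small sketch of that layout follows; `append_stream_chunk` is a hypothetical helper that loosely mirrors what the removed `add_stream_to_session` did with an in-memory dict.

```python
from typing import Any, Dict


def get_result_key(task_id: str) -> str:
    return f"result_{task_id}"


def get_stream_key(task_id: str) -> str:
    return f"stream_{task_id}"


def append_stream_chunk(state: Dict[str, Any], task_id: str, chunk: Dict[str, Any]) -> None:
    """Append a chunk to the task's stream list, creating the list on first use."""
    state.setdefault(get_stream_key(task_id), []).append(chunk)


# Usage: one session state holding a stream and a result for the same task.
state: Dict[str, Any] = {}
append_stream_chunk(state, "t1", {"index": 0, "data": "hello"})
append_stream_chunk(state, "t1", {"index": 1, "data": "world"})
state[get_result_key("t1")] = {"task_id": "t1", "result": "done"}
```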
@@ -1,28 +0,0 @@
|
||||
from .apache_kafka import (
|
||||
KafkaMessageQueue,
|
||||
KafkaMessageQueueConfig,
|
||||
)
|
||||
from .base import AbstractMessageQueue
|
||||
from .rabbitmq import (
|
||||
RabbitMQMessageQueue,
|
||||
RabbitMQMessageQueueConfig,
|
||||
)
|
||||
from .redis import RedisMessageQueue, RedisMessageQueueConfig
|
||||
from .simple import (
|
||||
SimpleMessageQueue,
|
||||
SimpleMessageQueueConfig,
|
||||
SimpleMessageQueueServer,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AbstractMessageQueue",
|
||||
"KafkaMessageQueue",
|
||||
"KafkaMessageQueueConfig",
|
||||
"RabbitMQMessageQueue",
|
||||
"RabbitMQMessageQueueConfig",
|
||||
"RedisMessageQueue",
|
||||
"RedisMessageQueueConfig",
|
||||
"SimpleMessageQueueServer",
|
||||
"SimpleMessageQueueConfig",
|
||||
"SimpleMessageQueue",
|
||||
]
|
||||
@@ -1,202 +0,0 @@
|
||||
"""Apache Kafka Message Queue."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from logging import getLogger
|
||||
from typing import Any, AsyncIterator, Dict, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from llama_deploy.message_queues.base import AbstractMessageQueue
|
||||
from llama_deploy.types import QueueMessage
|
||||
|
||||
logger = getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
DEFAULT_URL = "localhost:9092"
|
||||
DEFAULT_TOPIC_PARTITIONS = 10
|
||||
DEFAULT_TOPIC_REPLICATION_FACTOR = 1
|
||||
DEFAULT_TOPIC_NAME = "control_plane"
|
||||
DEFAULT_GROUP_ID = "default_group" # single group for competing consumers
|
||||
|
||||
|
||||
class KafkaMessageQueueConfig(BaseSettings):
|
||||
"""Kafka message queue configuration."""
|
||||
|
||||
model_config = SettingsConfigDict(env_prefix="KAFKA_")
|
||||
|
||||
type: Literal["kafka"] = Field(default="kafka")
|
||||
url: str = DEFAULT_URL
|
||||
host: str | None = None
|
||||
port: int | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def update_url(self) -> "KafkaMessageQueueConfig":
|
||||
if self.host and self.port:
|
||||
self.url = f"{self.host}:{self.port}"
|
||||
return self
|
||||
|
||||
|
||||
class KafkaMessageQueue(AbstractMessageQueue):
|
||||
"""Apache Kafka integration with aiokafka.
|
||||
|
||||
This class implements a traditional message broker using Apache Kafka.
|
||||
- Topics are created with N partitions
|
||||
- Consumers are registered to a single group to implement a competing
|
||||
consumer scheme where only one consumer subscribed to a topic gets the
|
||||
message
|
||||
- Default round-robin assignment is used
|
||||
|
||||
Attributes:
|
||||
url (str): The broker url string to connect to the Kafka server
|
||||
|
||||
Examples:
|
||||
```python
|
||||
from llama_deploy.message_queues.apache_kafka import KafkaMessageQueue
|
||||
|
||||
message_queue = KafkaMessageQueue() # uses the default url
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config: KafkaMessageQueueConfig | None = None) -> None:
|
||||
self._config = config or KafkaMessageQueueConfig()
|
||||
self._kafka_consumers: dict[str, Any] = {}
|
||||
self._registered_topics: set[str] = set()
|
||||
|
||||
@classmethod
|
||||
def from_url_params(
|
||||
cls,
|
||||
host: str,
|
||||
port: int | None = None,
|
||||
) -> "KafkaMessageQueue":
|
||||
"""Convenience constructor from url params.
|
||||
|
||||
Args:
|
||||
host (str): host for the Kafka server
|
||||
port (Optional[int], optional): port for the Kafka server. Defaults to None.
|
||||
|
||||
Returns:
|
||||
KafkaMessageQueue: An Apache Kafka MessageQueue integration.
|
||||
"""
|
||||
url = f"{host}:{port}" if port else f"{host}"
|
||||
return cls(KafkaMessageQueueConfig(url=url))
|
||||
|
||||
def _create_new_topic(
|
||||
self,
|
||||
topic_name: str,
|
||||
num_partitions: int | None = None,
|
||||
replication_factor: int | None = None,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Create a new topic.
|
||||
|
||||
Use kafka-python-ng instead of aiokafka, as the latter has issues with
|
||||
resolving api_version with broker.
|
||||
|
||||
TODO: convert to aiokafka once this is resolved there.
|
||||
"""
|
||||
try:
|
||||
from kafka.admin import KafkaAdminClient, NewTopic
|
||||
from kafka.errors import TopicAlreadyExistsError
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"kafka-python-ng is not installed. "
|
||||
"Please install it using `pip install kafka-python-ng`."
|
||||
)
|
||||
|
||||
admin_client = KafkaAdminClient(bootstrap_servers=self._config.url)
|
||||
try:
|
||||
topic = NewTopic(
|
||||
name=topic_name,
|
||||
num_partitions=num_partitions or DEFAULT_TOPIC_PARTITIONS,
|
||||
replication_factor=replication_factor
|
||||
or DEFAULT_TOPIC_REPLICATION_FACTOR,
|
||||
**kwargs,
|
||||
)
|
||||
admin_client.create_topics(new_topics=[topic])
|
||||
self._registered_topics.add(topic_name)
|
||||
logger.info(f"New topic {topic_name} created.")
|
||||
except TopicAlreadyExistsError:
|
||||
logger.info(f"Topic {topic_name} already exists.")
|
||||
pass
|
||||
|
||||
async def _publish(
|
||||
self, message: QueueMessage, topic: str, create_topic: bool
|
||||
) -> Any:
|
||||
"""Publish message to the queue."""
|
||||
try:
|
||||
from aiokafka import AIOKafkaProducer
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"aiokafka is not installed. "
|
||||
"Please install it using `pip install aiokafka`."
|
||||
)
|
||||
|
||||
if create_topic:
|
||||
self._create_new_topic(topic)
|
||||
|
||||
producer = AIOKafkaProducer(bootstrap_servers=self._config.url)
|
||||
await producer.start()
|
||||
try:
|
||||
message_body = json.dumps(message.model_dump()).encode("utf-8")
|
||||
await producer.send_and_wait(topic, message_body)
|
||||
logger.info(f"published message {message.id_}")
|
||||
finally:
|
||||
await producer.stop()
|
||||
|
||||
async def cleanup(self, *args: Any, **kwargs: Dict[str, Any]) -> None:
|
||||
"""Cleanup for local runs.
|
||||
|
||||
Use kafka-python-ng instead of aiokafka, as the latter has issues with
|
||||
resolving api_version with broker when using admin client.
|
||||
|
||||
TODO: convert to aiokafka once this is resolved there.
|
||||
"""
|
||||
try:
|
||||
from kafka.admin import KafkaAdminClient
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"kafka-python-ng is not installed. "
"Please install it using `pip install kafka-python-ng`."
|
||||
)
|
||||
|
||||
admin_client = KafkaAdminClient(bootstrap_servers=self._config.url)
|
||||
active_topics = admin_client.list_topics()
|
||||
topics_to_delete = [el for el in self._registered_topics if el in active_topics]
|
||||
admin_client.delete_consumer_groups(DEFAULT_GROUP_ID)
|
||||
if topics_to_delete:
|
||||
admin_client.delete_topics(topics_to_delete)
|
||||
|
||||
async def get_messages(self, topic: str) -> AsyncIterator[QueueMessage]:
|
||||
try:
|
||||
from aiokafka import AIOKafkaConsumer
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"aiokafka is not installed. "
|
||||
"Please install it using `pip install aiokafka`."
|
||||
)
|
||||
|
||||
kafka_consumer = AIOKafkaConsumer(
|
||||
topic,
|
||||
bootstrap_servers=self._config.url,
|
||||
group_id=DEFAULT_GROUP_ID,
|
||||
auto_offset_reset="earliest",
|
||||
)
|
||||
|
||||
await kafka_consumer.start()
|
||||
|
||||
try:
|
||||
async for msg in kafka_consumer:
|
||||
if msg.value is None:
|
||||
raise RuntimeError("msg.value is None")
|
||||
decoded_message = json.loads(msg.value.decode("utf-8"))
|
||||
yield QueueMessage.model_validate(decoded_message)
|
||||
finally:
|
||||
stop_task = asyncio.create_task(kafka_consumer.stop())
|
||||
await asyncio.shield(stop_task)
|
||||
|
||||
def as_config(self) -> BaseModel:
|
||||
return self._config
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Message queue module."""
|
||||
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import getLogger
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_deploy.apiserver.tracing import add_span_attribute, create_span
|
||||
from llama_deploy.types import QueueMessage
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
PublishCallback = (
|
||||
Callable[[QueueMessage], Any] | Callable[[QueueMessage], Awaitable[Any]]
|
||||
)
|
||||
|
||||
|
||||
class AbstractMessageQueue(ABC):
|
||||
"""Message broker interface between publisher and consumer."""
|
||||
|
||||
@abstractmethod
|
||||
async def _publish(
|
||||
self, message: QueueMessage, topic: str, create_topic: bool
|
||||
) -> Any:
|
||||
"""Subclasses implement publish logic here."""
|
||||
|
||||
async def publish(
|
||||
self,
|
||||
message: QueueMessage,
|
||||
topic: str,
|
||||
callback: PublishCallback | None = None,
|
||||
create_topic: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Send message to a consumer."""
|
||||
with create_span("message_queue.publish"):
|
||||
add_span_attribute("message.type", message.type)
|
||||
add_span_attribute("message.action", str(message.action))
|
||||
add_span_attribute("message.topic", topic)
|
||||
add_span_attribute("message.id", message.id_)
|
||||
|
||||
logger.info(
|
||||
f"Publishing message of type '{message.type}' with action '{message.action}' to topic '{topic}'"
|
||||
)
|
||||
logger.debug(f"Message: {message.model_dump()}")
|
||||
|
||||
message.stats.publish_time = message.stats.timestamp_str()
|
||||
message.stats.set_trace_context()
|
||||
|
||||
await self._publish(message, topic, create_topic)
|
||||
|
||||
if callback:
|
||||
with create_span("message_queue.callback"):
|
||||
if inspect.iscoroutinefunction(callback):
|
||||
await callback(message, **kwargs)
|
||||
else:
|
||||
callback(message, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
async def cleanup(self, *args: Any, **kwargs: dict[str, Any]) -> None:
|
||||
"""Perform any cleanup before shutting down."""
|
||||
|
||||
@abstractmethod
|
||||
def as_config(self) -> BaseModel:
|
||||
"""Returns the config dict to reconstruct the message queue."""
|
||||
|
||||
async def get_messages(self, topic: str) -> AsyncIterator[QueueMessage]:
|
||||
if False:
|
||||
# This is to help type checkers
|
||||
yield
|
||||
@@ -1,240 +0,0 @@
|
||||
"""RabbitMQ Message Queue."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from logging import getLogger
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Literal, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from llama_deploy.message_queues.base import AbstractMessageQueue
|
||||
from llama_deploy.types import QueueMessage
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from aio_pika import Connection
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
DEFAULT_URL = "amqp://guest:guest@localhost/"
|
||||
DEFAULT_EXCHANGE_NAME = "llama-deploy"
|
||||
|
||||
|
||||
class RabbitMQMessageQueueConfig(BaseSettings):
|
||||
"""RabbitMQ message queue configuration."""
|
||||
|
||||
model_config = SettingsConfigDict(env_prefix="RABBITMQ_")
|
||||
|
||||
type: Literal["rabbitmq"] = Field(default="rabbitmq")
|
||||
url: str = DEFAULT_URL
|
||||
exchange_name: str = DEFAULT_EXCHANGE_NAME
|
||||
username: str | None = None
|
||||
password: str | None = None
|
||||
host: str | None = None
|
||||
port: int | None = None
|
||||
vhost: str | None = None
|
||||
secure: bool | None = None
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if self.username and self.password and self.host:
|
||||
scheme = "amqps" if self.secure else "amqp"
|
||||
self.url = f"{scheme}://{self.username}:{self.password}@{self.host}"
|
||||
if self.port:
|
||||
self.url += f":{self.port}"
|
||||
if self.vhost:
|
||||
self.url += f"/{self.vhost}"
|
||||
|
||||
|
||||
async def _establish_connection(url: str) -> "Connection":
|
||||
try:
|
||||
import aio_pika
|
||||
from aio_pika import Connection
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Missing aio-pika optional dependency. Please install by running `pip install llama-deploy[rabbitmq]`."
|
||||
)
|
||||
return cast(Connection, await aio_pika.connect(url))
|
||||
|
||||
|
||||
class RabbitMQMessageQueue(AbstractMessageQueue):
|
||||
"""RabbitMQ integration with aio-pika client.
|
||||
|
||||
This class creates a Work (or Task) Queue. For more information on Work Queues
|
||||
with RabbitMQ see the pages linked below:
|
||||
1. https://aio-pika.readthedocs.io/en/latest/rabbitmq-tutorial/2-work-queues.html
|
||||
2. https://aio-pika.readthedocs.io/en/latest/rabbitmq-tutorial/3-publish-subscribe.html
|
||||
|
||||
Connections are established from a URL that uses the [AMQP URI scheme](https://www.rabbitmq.com/docs/uri-spec#the-amqp-uri-scheme):
|
||||
|
||||
```
|
||||
amqp_URI = "amqp://" amqp_authority [ "/" vhost ] [ "?" query ]
|
||||
amqp_authority = [ amqp_userinfo "@" ] host [ ":" port ]
|
||||
amqp_userinfo = username [ ":" password ]
|
||||
username = *( unreserved / pct-encoded / sub-delims )
|
||||
password = *( unreserved / pct-encoded / sub-delims )
|
||||
vhost = segment
|
||||
```
|
||||
|
||||
The Work Queue created has the following properties:
|
||||
- Exchange with name self.exchange
|
||||
- Messages are published to this queue through the exchange
|
||||
- Consumers are bound to the exchange and have queues based on their
|
||||
message type
|
||||
- Round-robin dispatching: with multiple consumers listening to the same
|
||||
queue, each message is delivered to exactly one consumer, chosen in turn.
|
||||
|
||||
Attributes:
|
||||
url (str): The amqp url string to connect to the RabbitMQ server
|
||||
exchange_name (str): The name to give to the so-called exchange within
|
||||
RabbitMQ AMQP 0-9 protocol.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
from llama_deploy.message_queues.rabbitmq import RabbitMQMessageQueue
|
||||
|
||||
message_queue = RabbitMQMessageQueue() # uses the default url
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RabbitMQMessageQueueConfig | None = None,
|
||||
url: str = DEFAULT_URL,
|
||||
exchange_name: str = DEFAULT_EXCHANGE_NAME,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._config = config or RabbitMQMessageQueueConfig()
|
||||
self._registered_topics: set[str] = set()
|
||||
|
||||
@classmethod
|
||||
def from_url_params(
|
||||
cls,
|
||||
username: str,
|
||||
password: str,
|
||||
host: str,
|
||||
vhost: str = "",
|
||||
port: int | None = None,
|
||||
secure: bool = False,
|
||||
exchange_name: str = DEFAULT_EXCHANGE_NAME,
|
||||
) -> "RabbitMQMessageQueue":
|
||||
"""Convenience constructor from url params.
|
||||
|
||||
Args:
|
||||
username (str): username for the amqp authority
|
||||
password (str): password for the amqp authority
|
||||
host (str): host for rabbitmq server
|
||||
port (int | None, optional): port for rabbitmq server. Defaults to None.
|
||||
secure (bool, optional): Whether or not to use SSL. Defaults to False.
|
||||
exchange_name (str, optional): The exchange name. Defaults to DEFAULT_EXCHANGE_NAME.
|
||||
|
||||
Returns:
|
||||
RabbitMQMessageQueue: A RabbitMQ MessageQueue integration.
|
||||
"""
|
||||
if not secure:
|
||||
if port:
|
||||
url = f"amqp://{username}:{password}@{host}:{port}/{vhost}"
|
||||
else:
|
||||
url = f"amqp://{username}:{password}@{host}/{vhost}"
|
||||
else:
|
||||
if port:
|
||||
url = f"amqps://{username}:{password}@{host}:{port}/{vhost}"
|
||||
else:
|
||||
url = f"amqps://{username}:{password}@{host}/{vhost}"
|
||||
return cls(RabbitMQMessageQueueConfig(url=url, exchange_name=exchange_name))
|
||||
|
||||
async def new_connection(self) -> "Connection":
|
||||
"""Returns a new connection to the RabbitMQ server."""
|
||||
return await _establish_connection(self._config.url)
|
||||
|
||||
async def _publish(
|
||||
self, message: QueueMessage, topic: str, create_topic: bool
|
||||
) -> Any:
|
||||
"""Publish message to the queue."""
|
||||
from aio_pika import DeliveryMode, ExchangeType
|
||||
from aio_pika import Message as AioPikaMessage
|
||||
|
||||
connection = await _establish_connection(self._config.url)
|
||||
|
||||
async with connection:
|
||||
channel = await connection.channel()
|
||||
exchange = await channel.declare_exchange(
|
||||
self._config.exchange_name,
|
||||
ExchangeType.DIRECT,
|
||||
)
|
||||
message_body = json.dumps(message.model_dump()).encode("utf-8")
|
||||
|
||||
aio_pika_message = AioPikaMessage(
|
||||
message_body,
|
||||
delivery_mode=DeliveryMode.PERSISTENT,
|
||||
)
|
||||
# Sending the message
|
||||
await exchange.publish(aio_pika_message, routing_key=topic)
|
||||
self._registered_topics.add(topic)
|
||||
logger.info(f"published message {message.id_} to {topic}")
|
||||
|
||||
async def get_messages(self, topic: str) -> AsyncIterator[QueueMessage]:
|
||||
from aio_pika import Channel, ExchangeType, IncomingMessage, Queue
|
||||
from aio_pika.abc import AbstractIncomingMessage
|
||||
|
||||
# Use a Queue to get the messages out of the callback
|
||||
message_queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
async def on_message(message: AbstractIncomingMessage) -> None:
|
||||
message = cast(IncomingMessage, message)
|
||||
async with message.process():
|
||||
try:
|
||||
decoded_message = json.loads(message.body.decode("utf-8"))
|
||||
queue_message = QueueMessage.model_validate(decoded_message)
|
||||
await message_queue.put(queue_message)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}", exc_info=True)
|
||||
|
||||
while True:
|
||||
connection = None
|
||||
try:
|
||||
connection = await _establish_connection(self._config.url)
|
||||
async with connection:
|
||||
channel = cast(Channel, await connection.channel())
|
||||
exchange = await channel.declare_exchange(
|
||||
self._config.exchange_name,
|
||||
ExchangeType.DIRECT,
|
||||
)
|
||||
queue = cast(Queue, await channel.declare_queue(name=topic))
|
||||
await queue.bind(exchange)
|
||||
await queue.consume(on_message)
|
||||
# Yield messages as they arrive
|
||||
while True:
|
||||
try:
|
||||
# Wait for a message with a timeout to allow for cancellation checks
|
||||
message = await asyncio.wait_for(
|
||||
message_queue.get(), timeout=1.0
|
||||
)
|
||||
yield message
|
||||
except asyncio.TimeoutError:
|
||||
# Check if connection is still alive, continue if so
|
||||
if connection.is_closed:
|
||||
break
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {e}", exc_info=True)
|
||||
# Wait before reconnecting. Ideally we'd want exponential backoff here.
|
||||
await asyncio.sleep(10)
|
||||
finally:
|
||||
if connection and not connection.is_closed:
|
||||
await connection.close()
|
||||
|
||||
async def cleanup(self, *args: Any, **kwargs: dict[str, Any]) -> None:
|
||||
"""Perform any clean up of queues and exchanges."""
|
||||
connection = await self.new_connection()
|
||||
async with connection:
|
||||
channel = await connection.channel()
|
||||
for queue_name in self._registered_topics:
|
||||
await channel.queue_delete(queue_name=queue_name)
|
||||
await channel.exchange_delete(exchange_name=self._config.exchange_name)
|
||||
|
||||
def as_config(self) -> BaseModel:
|
||||
return RabbitMQMessageQueueConfig(
|
||||
url=self._config.url, exchange_name=self._config.exchange_name
|
||||
)
|
||||
@@ -1,114 +0,0 @@
|
||||
"""Redis Message Queue."""
|
||||
|
||||
import json
|
||||
from logging import getLogger
|
||||
from typing import Any, AsyncIterator, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from llama_deploy.message_queues.base import AbstractMessageQueue
|
||||
from llama_deploy.types import QueueMessage
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class RedisMessageQueueConfig(BaseSettings):
|
||||
"""Redis message queue configuration."""
|
||||
|
||||
model_config = SettingsConfigDict(env_prefix="REDIS_")
|
||||
|
||||
type: Literal["redis"] = Field(default="redis")
|
||||
url: str = "redis://localhost:6379"
|
||||
host: str | None = None
|
||||
port: int | None = None
|
||||
db: int | None = None
|
||||
username: str | None = None
|
||||
password: str | None = None
|
||||
ssl: bool | None = None
|
||||
exclusive_mode: bool = False
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if self.host and self.port:
|
||||
scheme = "rediss" if self.ssl else "redis"
|
||||
auth = (
|
||||
f"{self.username}:{self.password}@"
|
||||
if self.username and self.password
|
||||
else ""
|
||||
)
|
||||
self.url = f"{scheme}://{auth}{self.host}:{self.port}/{self.db or ''}"
|
||||
|
||||
|
||||
class RedisMessageQueue(AbstractMessageQueue):
|
||||
"""Redis integration for message queue.
|
||||
|
||||
This class uses Redis Pub/Sub functionality for message distribution.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
from llama_deploy.message_queues.redis import RedisMessageQueue
|
||||
|
||||
message_queue = RedisMessageQueue() # uses the default url
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config: RedisMessageQueueConfig | None = None) -> None:
|
||||
self._config = config or RedisMessageQueueConfig()
|
||||
|
||||
try:
|
||||
from redis.asyncio import Redis
|
||||
|
||||
self._redis: Redis = Redis.from_url(self._config.url)
|
||||
except ImportError:
|
||||
msg = "Missing redis optional dependency. Please install by running `pip install llama-deploy[redis]`."
|
||||
raise ValueError(msg)
|
||||
|
||||
async def _publish(
|
||||
self, message: QueueMessage, topic: str, create_topic: bool
|
||||
) -> Any:
|
||||
"""Publish message to the Redis channel."""
|
||||
message_json = json.dumps(message.model_dump())
|
||||
result = await self._redis.publish(topic, message_json)
|
||||
logger.info(
|
||||
f"Published message {message.id_} to topic {topic} with {result} subscribers"
|
||||
)
|
||||
return result
|
||||
|
||||
async def get_messages(self, topic: str) -> AsyncIterator[QueueMessage]:
|
||||
pubsub = self._redis.pubsub()
|
||||
await pubsub.subscribe(topic)
|
||||
|
||||
processed_message_key = f"{topic}.processed_messages"
|
||||
try:
|
||||
while True:
|
||||
message = await pubsub.get_message(ignore_subscribe_messages=True)
|
||||
if message:
|
||||
decoded_message = json.loads(message["data"])
|
||||
queue_message = QueueMessage.model_validate(decoded_message)
|
||||
|
||||
# Deduplication check
|
||||
if self._config.exclusive_mode:
|
||||
new_message = await self._redis.sadd( # type: ignore
|
||||
processed_message_key, queue_message.id_
|
||||
)
|
||||
if not new_message:
|
||||
logger.debug(
|
||||
f"Skipping message {queue_message.id_} as it has "
|
||||
"already been consumed."
|
||||
)
|
||||
continue
|
||||
|
||||
# Set expiration for deduplication key. Expire processed messages
|
||||
# in 5 minutes.
|
||||
await self._redis.expire(processed_message_key, 300, nx=True)
|
||||
|
||||
yield queue_message
|
||||
finally:
|
||||
return
|
||||
|
||||
async def cleanup(self, *args: Any, **kwargs: dict[str, Any]) -> None:
|
||||
# Close main Redis connection
|
||||
await self._redis.aclose() # type: ignore
|
||||
|
||||
def as_config(self) -> BaseModel:
|
||||
return self._config
|
||||
@@ -1,11 +0,0 @@
|
||||
from .client import SimpleMessageQueue
|
||||
from .server import (
|
||||
SimpleMessageQueueConfig,
|
||||
SimpleMessageQueueServer,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SimpleMessageQueueServer",
|
||||
"SimpleMessageQueueConfig",
|
||||
"SimpleMessageQueue",
|
||||
]
|
||||
@@ -1,67 +0,0 @@
|
||||
import asyncio
|
||||
from logging import getLogger
|
||||
from typing import Any, AsyncIterator, Dict
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_deploy.message_queues.base import AbstractMessageQueue
|
||||
from llama_deploy.types import QueueMessage
|
||||
|
||||
from .config import SimpleMessageQueueConfig
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class SimpleMessageQueue(AbstractMessageQueue):
|
||||
"""Remote client to be used with a SimpleMessageQueue server."""
|
||||
|
||||
def __init__(
|
||||
self, config: SimpleMessageQueueConfig = SimpleMessageQueueConfig()
|
||||
) -> None:
|
||||
self._config = config
|
||||
self._topics: set[str] = set()
|
||||
|
||||
async def _publish(
|
||||
self, message: QueueMessage, topic: str, create_topic: bool
|
||||
) -> Any:
|
||||
"""Sends a message to the SimpleMessageQueueServer."""
|
||||
if topic not in self._topics:
|
||||
# call the server to create it
|
||||
url = f"{self._config.base_url}topics/{topic}"
|
||||
async with httpx.AsyncClient(**self._config.client_kwargs) as client:
|
||||
result = await client.post(url)
|
||||
result.raise_for_status()
|
||||
self._topics.add(topic)
|
||||
|
||||
url = f"{self._config.base_url}messages/{topic}"
|
||||
async with httpx.AsyncClient(**self._config.client_kwargs) as client:
|
||||
result = await client.post(url, json=message.model_dump())
|
||||
return result
|
||||
|
||||
async def get_messages(self, topic: str) -> AsyncIterator[QueueMessage]:
|
||||
url = f"{self._config.base_url}messages/{topic}"
|
||||
client = httpx.AsyncClient(**self._config.client_kwargs)
|
||||
while True:
|
||||
try:
|
||||
result = await client.get(url)
|
||||
result.raise_for_status()
|
||||
if result.json():
|
||||
yield QueueMessage.model_validate(result.json())
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.debug(f"HTTP error occurred while fetching messages: {e}")
|
||||
await asyncio.sleep(1) # Back off on errors
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while fetching messages: {e}")
|
||||
await asyncio.sleep(1) # Back off on errors
|
||||
continue
|
||||
|
||||
async def cleanup(self, *args: Any, **kwargs: Dict[str, Any]) -> None:
|
||||
# Nothing to clean up
|
||||
pass
|
||||
|
||||
def as_config(self) -> SimpleMessageQueueConfig:
|
||||
return self._config
|
||||
@@ -1,31 +0,0 @@
|
||||
from logging import getLogger
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class SimpleMessageQueueConfig(BaseSettings):
|
||||
"""Simple message queue configuration."""
|
||||
|
||||
model_config = SettingsConfigDict(env_prefix="SIMPLE_MESSAGE_QUEUE_")
|
||||
|
||||
type: Literal["simple"] = Field(default="simple")
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 8001
|
||||
client_kwargs: dict[str, Any] = Field(
|
||||
default_factory=dict, description="The kwargs to pass to the httpx client."
|
||||
)
|
||||
raise_exceptions: bool = Field(
|
||||
default=False, description="Whether to raise exceptions when an error occurs."
|
||||
)
|
||||
use_ssl: bool = Field(default=False)
|
||||
|
||||
@property
|
||||
def base_url(self) -> str:
|
||||
protocol = "https" if self.use_ssl else "http"
|
||||
if self.port != 80:
|
||||
return f"{protocol}://{self.host}:{self.port}/"
|
||||
return f"{protocol}://{self.host}/"
|
||||
@@ -1,118 +0,0 @@
|
||||
"""Simple Message Queue."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import Any, Dict
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
|
||||
from llama_deploy.types import QueueMessage
|
||||
|
||||
from .config import SimpleMessageQueueConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessagesPollFilter(logging.Filter):
|
||||
"""Filters out access logs for /messages/.
|
||||
|
||||
The message queue client works with plain HTTP and as a form of pubsub
|
||||
subscription indefinitely polls the /messages/ endpoint on the server. To
|
||||
avoid cluttering the logs, we filter out GET requests on that specific
|
||||
endpoint.
|
||||
"""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
return "GET /messages/" not in record.getMessage()
|
||||
|
||||
|
||||
uvicorn_logger = logging.getLogger("uvicorn.access")
|
||||
uvicorn_logger.addFilter(MessagesPollFilter())
|
||||
|
||||
|
||||
class SimpleMessageQueueServer:
|
||||
"""An in-memory message queue that implements a push model for consumers.
|
||||
|
||||
When registering, a specific queue for a consumer is created.
|
||||
When a message is published, it is added to the queue for the given message type.
|
||||
"""
|
||||
|
||||
def __init__(self, config: SimpleMessageQueueConfig = SimpleMessageQueueConfig()):
|
||||
self._config = config
|
||||
self._queues: dict[str, deque] = {}
|
||||
self._running = False
|
||||
self._app = FastAPI()
|
||||
|
||||
self._app.add_api_route(
|
||||
"/",
|
||||
self._home,
|
||||
methods=["GET"],
|
||||
)
|
||||
self._app.add_api_route(
|
||||
"/topics/{topic}",
|
||||
self._create_topic,
|
||||
methods=["POST"],
|
||||
)
|
||||
self._app.add_api_route(
|
||||
"/messages/{topic}",
|
||||
self._publish,
|
||||
methods=["POST"],
|
||||
)
|
||||
self._app.add_api_route(
|
||||
"/messages/{topic}",
|
||||
self._get_messages,
|
||||
methods=["GET"],
|
||||
)
|
||||
|
||||
async def launch_server(self) -> None:
|
||||
"""Launch the message queue as a FastAPI server."""
|
||||
logger.info(f"Launching message queue server at {self._config.base_url}")
|
||||
self._running = True
|
||||
|
||||
cfg = uvicorn.Config(
|
||||
self._app, host=self._config.host, port=self._config.port or 80
|
||||
)
|
||||
server = uvicorn.Server(cfg)
|
||||
|
||||
try:
|
||||
await server.serve()
|
||||
except asyncio.CancelledError:
|
||||
self._running = False
|
||||
await asyncio.gather(server.shutdown(), return_exceptions=True)
|
||||
|
||||
#
|
||||
# HTTP API endpoints
|
||||
#
|
||||
|
||||
async def _home(self) -> Dict[str, str]:
|
||||
return {
|
||||
"service_name": "message_queue",
|
||||
"description": "Message queue for multi-agent system",
|
||||
}
|
||||
|
||||
async def _create_topic(self, topic: str) -> Any:
|
||||
"""If topic already exists, this is a no-op."""
|
||||
if topic not in self._queues:
|
||||
self._queues[topic] = deque()
|
||||
|
||||
async def _publish(self, message: QueueMessage, topic: str) -> Any:
|
||||
"""Publish message to a queue."""
|
||||
if topic not in self._queues:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=f"topic {topic} not found"
|
||||
)
|
||||
|
||||
self._queues[topic].append(message)
|
||||
|
||||
async def _get_messages(self, topic: str) -> QueueMessage | None:
|
||||
if topic not in self._queues:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=f"topic {topic} not found"
|
||||
)
|
||||
if queue := self._queues[topic]:
|
||||
message: QueueMessage = queue.popleft()
|
||||
return message
|
||||
|
||||
return None
|
||||
Some files were not shown because too many files have changed in this diff