mirror of
https://github.com/run-llama/llama_deploy.git
synced 2026-07-01 21:04:00 -04:00
feat: Introduce tracing support for monitoring LlamaDeploy (#530)
* feat: add tracing support * fix * fix deps * simplify import structure * remove noise, simplify code * update lockfile * update docs * fix e2e tests
This commit is contained in:
committed by
GitHub
parent
e3b0a85145
commit
878e5923b4
+5
-2
@@ -1,5 +1,8 @@
|
||||
[formatting]
|
||||
align_comments = false
|
||||
reorder_keys = false
|
||||
indent_string = " " # adapt to toml-sort
|
||||
array_trailing_comma = false # adapt to toml-sort
|
||||
# Following are to be consistent with toml-sort
|
||||
indent_string = " "
|
||||
array_trailing_comma = false
|
||||
compact_arrays = true
|
||||
compact_inline_tables = true
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import uvicorn
|
||||
from prometheus_client import start_http_server
|
||||
|
||||
from llama_deploy.apiserver import settings
|
||||
from llama_deploy.apiserver.settings import settings
|
||||
|
||||
if __name__ == "__main__":
|
||||
if settings.prometheus_enabled:
|
||||
|
||||
@@ -6,7 +6,7 @@ from pathlib import Path
|
||||
import uvicorn
|
||||
from prometheus_client import start_http_server
|
||||
|
||||
from llama_deploy.apiserver import settings
|
||||
from llama_deploy.apiserver.settings import settings
|
||||
|
||||
CLONED_REPO_FOLDER = Path("cloned_repo")
|
||||
RC_PATH = Path("/data")
|
||||
|
||||
@@ -0,0 +1,250 @@
|
||||
# Observability
|
||||
|
||||
LlamaDeploy provides comprehensive observability capabilities through distributed tracing and metrics collection.
|
||||
This allows you to monitor workflow execution, track performance, and debug issues across your distributed deployment.
|
||||
|
||||
## Overview
|
||||
|
||||
LlamaDeploy supports two main observability features:
|
||||
|
||||
1. **Distributed Tracing** - Track request flows across services using OpenTelemetry
|
||||
2. **Metrics Collection** - Monitor system performance with Prometheus metrics
|
||||
|
||||
## Distributed Tracing
|
||||
|
||||
Distributed tracing provides end-to-end visibility into workflow execution across all components in your deployment.
|
||||
Traces show the complete journey of a request from the API server through message queues to workflow completion.
|
||||
|
||||
### What Gets Traced
|
||||
|
||||
- **Control Plane**: Service registration, task orchestration, and session management
|
||||
- **Workflow Services**: Complete workflow execution lifecycle including state loading, workflow running, event
|
||||
streaming, and result publishing
|
||||
- **Message Queues**: Message publishing and consumption with trace context propagation
|
||||
|
||||
### Configuration
|
||||
|
||||
Tracing is disabled by default and can be enabled through environment variables:
|
||||
|
||||
```bash
|
||||
# Enable tracing
|
||||
export LLAMA_DEPLOY_APISERVER_TRACING_ENABLED=true
|
||||
|
||||
# Set service name (default: llama-deploy-apiserver)
|
||||
export LLAMA_DEPLOY_APISERVER_TRACING_SERVICE_NAME=my-api-server
|
||||
|
||||
# Configure exporter (console, otlp)
|
||||
export LLAMA_DEPLOY_APISERVER_TRACING_EXPORTER=otlp
|
||||
export LLAMA_DEPLOY_APISERVER_TRACING_ENDPOINT=http://localhost:4317
|
||||
|
||||
# Configure sampling rate (0.0 to 1.0)
|
||||
export LLAMA_DEPLOY_APISERVER_TRACING_SAMPLE_RATE=0.1
|
||||
```
|
||||
|
||||
### Supported Exporters
|
||||
|
||||
#### Console Exporter
|
||||
Prints traces to the console - useful for development:
|
||||
|
||||
```bash
|
||||
export LLAMA_DEPLOY_APISERVER_TRACING_EXPORTER=console
|
||||
```
|
||||
|
||||
#### OTLP Exporter
|
||||
Exports traces using OpenTelemetry Protocol (works with many backends like Jaeger):
|
||||
|
||||
```bash
|
||||
export LLAMA_DEPLOY_APISERVER_TRACING_EXPORTER=otlp
|
||||
export LLAMA_DEPLOY_APISERVER_TRACING_ENDPOINT=http://localhost:4317
|
||||
```
|
||||
|
||||
### Setting Up Jaeger
|
||||
|
||||
To set up Jaeger for trace collection and visualization:
|
||||
|
||||
```bash
|
||||
# Run Jaeger all-in-one container
|
||||
docker run --rm -d \
|
||||
-e COLLECTOR_ZIPKIN_HOST_PORT=:9411 \
|
||||
-p 16686:16686 \
|
||||
-p 4317:4317 \
|
||||
-p 4318:4318 \
|
||||
-p 9411:9411 \
|
||||
jaegertracing/all-in-one:latest
|
||||
|
||||
# Configure LlamaDeploy to export traces to Jaeger over OTLP
|
||||
export LLAMA_DEPLOY_APISERVER_TRACING_ENABLED=true
|
||||
export LLAMA_DEPLOY_APISERVER_TRACING_EXPORTER=otlp
|
||||
export LLAMA_DEPLOY_APISERVER_TRACING_ENDPOINT=http://localhost:4317
|
||||
```
|
||||
|
||||
Access the Jaeger UI at http://localhost:16686 to view traces.
|
||||
|
||||
### Trace Context Propagation
|
||||
|
||||
Traces automatically propagate across service boundaries through message queues. Each message includes trace context
|
||||
(`trace_id` and `span_id`) in the `QueueMessageStats`, ensuring complete end-to-end tracing.
|
||||
|
||||
## Metrics Collection
|
||||
|
||||
LlamaDeploy includes Prometheus metrics for monitoring system performance and health.
|
||||
|
||||
### API Server Metrics
|
||||
|
||||
The API server automatically exposes Prometheus metrics when enabled:
|
||||
|
||||
```bash
|
||||
# Enable Prometheus metrics (default: true)
|
||||
export LLAMA_DEPLOY_APISERVER_PROMETHEUS_ENABLED=true
|
||||
|
||||
# Set metrics port (default: 9000)
|
||||
export LLAMA_DEPLOY_APISERVER_PROMETHEUS_PORT=9000
|
||||
```
|
||||
|
||||
Metrics are available at `http://localhost:9000/metrics`.
|
||||
|
||||
### Available Metrics
|
||||
|
||||
The API server tracks several key metrics:
|
||||
|
||||
- **Deployment State**: Current state of deployments (running, stopped, etc.)
|
||||
- **Service State**: Health and status of registered services
|
||||
- **API Request Metrics**: HTTP request counts, durations, and error rates (via tracing integration)
|
||||
|
||||
### Setting Up Prometheus
|
||||
|
||||
Create a `prometheus.yml` configuration file:
|
||||
|
||||
```yaml
|
||||
global:
|
||||
scrape_interval: 15s
|
||||
|
||||
scrape_configs:
|
||||
- job_name: 'llama-deploy-apiserver'
|
||||
static_configs:
|
||||
- targets: ['localhost:9000']
|
||||
scrape_interval: 5s
|
||||
```
|
||||
|
||||
Run Prometheus:
|
||||
|
||||
```bash
|
||||
# Using Docker
|
||||
docker run -d --name prometheus \
|
||||
-p 9090:9090 \
|
||||
-v $(pwd)/prometheus.yml:/etc/prometheus/prometheus.yml \
|
||||
prom/prometheus
|
||||
|
||||
# Using binary
|
||||
prometheus --config.file=prometheus.yml
|
||||
```
|
||||
|
||||
Access Prometheus at http://localhost:9090.
|
||||
|
||||
### Setting Up Grafana
|
||||
|
||||
For advanced visualization, set up Grafana with Prometheus as a data source:
|
||||
|
||||
```bash
|
||||
# Run Grafana
|
||||
docker run -d --name grafana \
|
||||
-p 3000:3000 \
|
||||
grafana/grafana
|
||||
```
|
||||
|
||||
1. Open http://localhost:3000 (admin/admin)
|
||||
2. Add Prometheus as data source: http://localhost:9090
|
||||
3. Create dashboards to visualize LlamaDeploy metrics
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Sampling
|
||||
|
||||
For production deployments, configure appropriate sampling rates to balance observability with performance:
|
||||
|
||||
```bash
|
||||
# Sample 10% of traces
|
||||
export LLAMA_DEPLOY_APISERVER_TRACING_SAMPLE_RATE=0.1
|
||||
```
|
||||
|
||||
### Service Names
|
||||
|
||||
Use descriptive service names to distinguish between different deployments:
|
||||
|
||||
```bash
|
||||
export LLAMA_DEPLOY_APISERVER_TRACING_SERVICE_NAME=prod-rag-workflow
|
||||
export LLAMA_DEPLOY_APISERVER_TRACING_SERVICE_NAME=prod-api-server
|
||||
```
|
||||
|
||||
### Resource Attributes
|
||||
|
||||
Add custom attributes to the current span for better filtering:
|
||||
|
||||
```python
|
||||
from llama_deploy.apiserver.tracing import add_span_attribute
|
||||
|
||||
# Add custom attributes in your workflow
|
||||
add_span_attribute("workflow.type", "rag")
|
||||
add_span_attribute("environment", "production")
|
||||
```
|
||||
|
||||
### Monitoring Alerts
|
||||
|
||||
Set up alerts based on metrics:
|
||||
|
||||
```yaml
|
||||
# Prometheus alerting rule example
|
||||
- alert: DeploymentDown
|
||||
expr: llama_deploy_deployment_state{state="stopped"} > 0
|
||||
for: 5m
|
||||
labels:
|
||||
severity: critical
|
||||
annotations:
|
||||
summary: "LlamaDeployment deployment is down"
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Traces Not Appearing
|
||||
|
||||
1. Verify tracing is enabled: `LLAMA_DEPLOY_APISERVER_TRACING_ENABLED=true`
|
||||
2. Check exporter configuration and endpoint connectivity
|
||||
3. Verify sampling rate is not too low
|
||||
4. Check application logs for tracing errors
|
||||
|
||||
### Missing Dependencies
|
||||
|
||||
Install tracing dependencies:
|
||||
|
||||
```bash
|
||||
pip install llama-deploy[observability]
|
||||
```
|
||||
|
||||
### Performance Impact
|
||||
|
||||
- Tracing adds minimal overhead when properly configured
|
||||
- Use sampling to reduce overhead in high-traffic scenarios
|
||||
- Console exporter has higher overhead than the OTLP exporter
|
||||
|
||||
### Metrics Not Available
|
||||
|
||||
1. Verify Prometheus is enabled: `LLAMA_DEPLOY_APISERVER_PROMETHEUS_ENABLED=true`
|
||||
2. Check metrics port is accessible: `curl http://localhost:9000/metrics`
|
||||
3. Verify Prometheus configuration and targets
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### Sensitive Data
|
||||
|
||||
Tracing doesn't automatically exclude sensitive parameters from span attributes. Review custom span attributes to
|
||||
ensure no sensitive data is included.
|
||||
|
||||
### Network Security
|
||||
|
||||
When using OTLP in production:
|
||||
|
||||
```bash
|
||||
# Use secure connections
|
||||
export LLAMA_DEPLOY_APISERVER_TRACING_INSECURE=false
|
||||
export LLAMA_DEPLOY_APISERVER_TRACING_ENDPOINT=https://your-secure-endpoint.com
|
||||
```
|
||||
@@ -10,7 +10,7 @@ from llama_deploy.client import Client
|
||||
|
||||
|
||||
def run_apiserver():
|
||||
uvicorn.run("llama_deploy.apiserver:app", host="127.0.0.1", port=4501)
|
||||
uvicorn.run("llama_deploy.apiserver.app:app", host="127.0.0.1", port=4501)
|
||||
|
||||
|
||||
@retry(wait=wait_exponential(min=1, max=10))
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
from .app import app
|
||||
from .deployment_config_parser import DeploymentConfig
|
||||
from .settings import settings
|
||||
|
||||
__all__ = [
|
||||
"app",
|
||||
"settings",
|
||||
"DeploymentConfig",
|
||||
]
|
||||
|
||||
@@ -8,7 +8,7 @@ if __name__ == "__main__":
|
||||
start_http_server(settings.prometheus_port)
|
||||
|
||||
uvicorn.run(
|
||||
"llama_deploy.apiserver:app",
|
||||
"llama_deploy.apiserver.app:app",
|
||||
host=settings.host,
|
||||
port=settings.port,
|
||||
)
|
||||
|
||||
@@ -8,12 +8,17 @@ from fastapi.responses import JSONResponse
|
||||
|
||||
from .routers import deployments_router, status_router
|
||||
from .server import lifespan
|
||||
from .settings import settings
|
||||
from .tracing import configure_tracing
|
||||
|
||||
logger = logging.getLogger("uvicorn.info")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# Setup tracing
|
||||
configure_tracing(settings)
|
||||
|
||||
# Configure CORS middleware unless the DISABLE_CORS environment variable is set
|
||||
if not os.environ.get("DISABLE_CORS", False):
|
||||
app.add_middleware(
|
||||
|
||||
@@ -22,10 +22,10 @@ from llama_deploy.message_queues import (
|
||||
|
||||
MessageQueueConfig = Annotated[
|
||||
Union[
|
||||
KafkaMessageQueueConfig,
|
||||
RabbitMQMessageQueueConfig,
|
||||
RedisMessageQueueConfig,
|
||||
SimpleMessageQueueConfig,
|
||||
"KafkaMessageQueueConfig",
|
||||
"RabbitMQMessageQueueConfig",
|
||||
"RedisMessageQueueConfig",
|
||||
"SimpleMessageQueueConfig",
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
@@ -105,7 +105,7 @@ class DeploymentConfig(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
name: str
|
||||
control_plane: ControlPlaneConfig
|
||||
control_plane: "ControlPlaneConfig"
|
||||
message_queue: MessageQueueConfig | None = Field(None)
|
||||
default_service: str | None = Field(None)
|
||||
services: dict[str, Service]
|
||||
|
||||
@@ -27,6 +27,8 @@ class ApiserverSettings(BaseSettings):
|
||||
default=False,
|
||||
description="Use TLS (HTTPS) to communicate with the API Server",
|
||||
)
|
||||
|
||||
# Metrics collection settings
|
||||
prometheus_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether to enable the Prometheus metrics exporter along with the API Server",
|
||||
@@ -36,6 +38,36 @@ class ApiserverSettings(BaseSettings):
|
||||
description="The port where to serve Prometheus metrics",
|
||||
)
|
||||
|
||||
# Tracing settings
|
||||
tracing_enabled: bool = Field(
|
||||
default=False,
|
||||
description="Enable OpenTelemetry tracing. Defaults to False.",
|
||||
)
|
||||
tracing_service_name: str = Field(
|
||||
default="llama-deploy-apiserver",
|
||||
description="Service name for tracing. Defaults to 'llama-deploy-apiserver'.",
|
||||
)
|
||||
tracing_exporter: str = Field(
|
||||
default="console",
|
||||
description="Trace exporter type: 'console', 'jaeger', 'otlp'. Defaults to 'console'.",
|
||||
)
|
||||
tracing_endpoint: str | None = Field(
|
||||
default=None,
|
||||
description="Trace exporter endpoint. Required for 'jaeger' and 'otlp' exporters.",
|
||||
)
|
||||
tracing_sample_rate: float = Field(
|
||||
default=1.0,
|
||||
description="Trace sampling rate (0.0 to 1.0). Defaults to 1.0 (100% sampling).",
|
||||
)
|
||||
tracing_insecure: bool = Field(
|
||||
default=True,
|
||||
description="Use insecure connection for OTLP exporter. Defaults to True.",
|
||||
)
|
||||
tracing_timeout: int = Field(
|
||||
default=30,
|
||||
description="Timeout in seconds for trace export. Defaults to 30.",
|
||||
)
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
protocol = "https://" if self.use_tls else "http://"
|
||||
|
||||
@@ -0,0 +1,237 @@
|
||||
"""Tracing utilities for llama_deploy."""
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, Callable, Generator, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_deploy.apiserver.settings import ApiserverSettings
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Since opentelemetry is optional, we have to use Any to type the tracer
|
||||
_tracer: Any | None = None
|
||||
_tracing_enabled = False
|
||||
_null_context = nullcontext()
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def configure_tracing(settings: "ApiserverSettings") -> None:
    """Configure OpenTelemetry tracing from the API server settings.

    When ``settings.tracing_enabled`` is false this only marks tracing as
    disabled. Otherwise a ``TracerProvider`` is built with the configured
    service name and sampling ratio, a batch span processor is attached for
    the selected exporter (``console`` or ``otlp``), the provider is installed
    globally, the module-level tracer is created and asyncio
    auto-instrumentation is switched on.

    Configuration problems never propagate to the caller: missing
    OpenTelemetry packages are logged as a warning, any other failure
    (unsupported exporter, missing OTLP endpoint, ...) is logged as an error,
    and in both cases tracing is left disabled.

    Args:
        settings: Settings object exposing the ``tracing_*`` fields.
    """
    global _tracer, _tracing_enabled

    if not settings.tracing_enabled:
        logger.debug("Tracing is disabled")
        _tracing_enabled = False
        return

    try:
        # OpenTelemetry is an optional dependency, so it is imported lazily.
        from opentelemetry import trace
        from opentelemetry.instrumentation.asyncio import AsyncioInstrumentor
        from opentelemetry.sdk.resources import SERVICE_NAME, Resource
        from opentelemetry.sdk.trace import TracerProvider
        from opentelemetry.sdk.trace.export import BatchSpanProcessor
        from opentelemetry.sdk.trace.sampling import TraceIdRatioBased

        # Create resource with service name
        resource = Resource.create({SERVICE_NAME: settings.tracing_service_name})

        # Create tracer provider with sampling
        tracer_provider = TracerProvider(
            resource=resource, sampler=TraceIdRatioBased(settings.tracing_sample_rate)
        )

        # Configure exporter based on config
        if settings.tracing_exporter == "console":
            from opentelemetry.sdk.trace.export import ConsoleSpanExporter

            exporter = ConsoleSpanExporter()

        elif settings.tracing_exporter == "otlp":
            from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
                OTLPSpanExporter,
            )

            if not settings.tracing_endpoint:
                raise ValueError("OTLP exporter requires an endpoint")
            # The gRPC exporter takes the collector address itself (for
            # example "http://localhost:4317" or "localhost:4317"). The
            # "/v1/traces" path only exists on the OTLP/HTTP transport;
            # appending it turns a scheme-less endpoint into an invalid gRPC
            # target, so the endpoint is passed through unchanged.
            exporter = OTLPSpanExporter(
                endpoint=settings.tracing_endpoint,
                insecure=settings.tracing_insecure,
                timeout=settings.tracing_timeout,
            )
        else:
            raise ValueError(f"Unsupported exporter: {settings.tracing_exporter}")

        # Add span processor
        span_processor = BatchSpanProcessor(exporter)
        tracer_provider.add_span_processor(span_processor)

        # Set the global tracer provider
        trace.set_tracer_provider(tracer_provider)

        # Initialize global tracer
        _tracer = trace.get_tracer(__name__)
        _tracing_enabled = True

        # Setup auto-instrumentation
        AsyncioInstrumentor().instrument()

        logger.info(
            f"Tracing configured with {settings.tracing_exporter} exporter, service: {settings.tracing_service_name}"
        )

    except ImportError as e:
        msg = (
            f"Tracing is enabled but OpenTelemetry instrumentation packages are missing: {e}. "
            "Run `pip install llama_deploy[observability]`"
        )
        logger.warning(msg)
        _tracing_enabled = False
    except Exception as e:
        # Tracing must never prevent the server from starting.
        logger.error(f"Failed to configure tracing: {e}")
        _tracing_enabled = False
|
||||
|
||||
|
||||
def get_tracer() -> Any | None:
|
||||
"""Get the configured tracer instance."""
|
||||
return _tracer if _tracing_enabled else None
|
||||
|
||||
|
||||
def is_tracing_enabled() -> bool:
    """Report whether tracing is currently switched on for this process."""
    return _tracing_enabled
|
||||
|
||||
|
||||
def trace_method(
    span_name: str | None = None, attributes: dict | None = None
) -> Callable[[F], F]:
    """Decorator to add tracing to synchronous methods.

    The decorated callable runs inside a span named ``span_name`` (or
    ``<module>.<qualname>`` when omitted). Static ``attributes`` are set on
    the span, each positional argument is recorded as ``arg.<parameter>``
    (stringified and truncated to 100 characters, ``self``/``cls`` skipped),
    and the outcome is stored in ``success`` plus ``error.type`` and
    ``error.message`` when the call raises. Exceptions are always re-raised.

    Whether tracing is active is decided on every call rather than when the
    decorator is applied, so callables decorated at import time are still
    traced when tracing gets configured afterwards. While no tracer is
    available the wrapped callable is invoked directly.

    Args:
        span_name: Explicit span name; defaults to the callable's qualified
            name.
        attributes: Attributes attached to every span created by the wrapper.

    Returns:
        A decorator wrapping the target callable.
    """

    def decorator(func: F) -> F:
        import inspect

        # Resolve positional parameter names from the real signature, once.
        # ``__annotations__`` cannot be used for this: it omits unannotated
        # parameters such as ``self`` and contains ``return``, which pairs
        # names with the wrong positional values.
        try:
            positional_names = [
                param.name
                for param in inspect.signature(func).parameters.values()
                if param.kind
                in (
                    inspect.Parameter.POSITIONAL_ONLY,
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                )
            ]
        except (TypeError, ValueError):
            # Some callables expose no introspectable signature.
            positional_names = []

        @wraps(func)
        def wrapper(*args, **kwargs):  # type: ignore
            tracer = get_tracer()
            if not tracer:
                return func(*args, **kwargs)

            name = span_name or f"{func.__module__}.{func.__qualname__}"
            with tracer.start_as_current_span(name) as span:
                if attributes:
                    span.set_attributes(attributes)

                for param_name, value in zip(positional_names, args):
                    if param_name not in {"self", "cls"}:
                        span.set_attribute(
                            f"arg.{param_name}", str(value)[:100]
                        )  # Truncate long values

                try:
                    result = func(*args, **kwargs)
                    span.set_attribute("success", True)
                    return result
                except Exception as e:
                    span.set_attribute("success", False)
                    span.set_attribute("error.type", type(e).__name__)
                    span.set_attribute("error.message", str(e))
                    raise

        return wrapper  # type: ignore

    return decorator
|
||||
|
||||
|
||||
def trace_async_method(
    span_name: str | None = None, attributes: dict | None = None
) -> Callable[[F], F]:
    """Decorator to add tracing to asynchronous methods.

    The decorated coroutine function is awaited inside a span named
    ``span_name`` (or ``<module>.<qualname>`` when omitted). Static
    ``attributes`` are set on the span, each positional argument is recorded
    as ``arg.<parameter>`` (stringified and truncated to 100 characters,
    ``self``/``cls`` skipped), and the outcome is stored in ``success`` plus
    ``error.type`` and ``error.message`` when the call raises. Exceptions are
    always re-raised.

    Whether tracing is active is decided on every call rather than when the
    decorator is applied, so coroutines decorated at import time are still
    traced when tracing gets configured afterwards. While no tracer is
    available the wrapped coroutine function is awaited directly.

    Args:
        span_name: Explicit span name; defaults to the callable's qualified
            name.
        attributes: Attributes attached to every span created by the wrapper.

    Returns:
        A decorator wrapping the target coroutine function.
    """

    def decorator(func: F) -> F:
        import inspect

        # Resolve positional parameter names from the real signature, once.
        # ``__annotations__`` cannot be used for this: it omits unannotated
        # parameters such as ``self`` and contains ``return``, which pairs
        # names with the wrong positional values.
        try:
            positional_names = [
                param.name
                for param in inspect.signature(func).parameters.values()
                if param.kind
                in (
                    inspect.Parameter.POSITIONAL_ONLY,
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                )
            ]
        except (TypeError, ValueError):
            # Some callables expose no introspectable signature.
            positional_names = []

        @wraps(func)
        async def wrapper(*args, **kwargs):  # type: ignore
            tracer = get_tracer()
            if not tracer:
                return await func(*args, **kwargs)

            name = span_name or f"{func.__module__}.{func.__qualname__}"
            with tracer.start_as_current_span(name) as span:
                if attributes:
                    span.set_attributes(attributes)

                for param_name, value in zip(positional_names, args):
                    if param_name not in {"self", "cls"}:
                        span.set_attribute(
                            f"arg.{param_name}", str(value)[:100]
                        )  # Truncate long values

                try:
                    result = await func(*args, **kwargs)
                    span.set_attribute("success", True)
                    return result
                except Exception as e:
                    span.set_attribute("success", False)
                    span.set_attribute("error.type", type(e).__name__)
                    span.set_attribute("error.message", str(e))
                    raise

        return wrapper  # type: ignore

    return decorator
|
||||
|
||||
|
||||
@contextmanager
def create_span(
    name: str, attributes: dict | None = None
) -> Generator[Any, None, None]:
    """Run the ``with`` body inside a new current span called ``name``.

    Yields the started span, or ``None`` when no tracer is available, in
    which case the body simply runs untraced.

    Args:
        name: Name given to the new span.
        attributes: Optional mapping of attributes set on the span before the
            body executes.
    """
    active_tracer = get_tracer()
    if active_tracer is not None:
        with active_tracer.start_as_current_span(name) as new_span:
            for attr_key, attr_value in (attributes or {}).items():
                new_span.set_attribute(attr_key, attr_value)
            yield new_span
    else:
        # No tracer: behave as a no-op context manager.
        yield
|
||||
|
||||
|
||||
def add_span_attribute(key: str, value: Any) -> None:
    """Add an attribute to the current span if tracing is enabled.

    The value is converted with ``str()`` before being stored. This is a
    best-effort helper: any failure is logged at debug level and never
    propagates to the caller.

    Args:
        key: Attribute name.
        value: Attribute value; stored as its string representation.
    """
    if not _tracing_enabled:
        return

    try:
        from opentelemetry import trace

        current_span = trace.get_current_span()
        if current_span:
            current_span.set_attribute(key, str(value))
    except Exception:
        # Tracing must never break the traced code path, but leave a trail
        # instead of discarding the failure completely.
        logger.debug("Failed to set span attribute %r", key, exc_info=True)
|
||||
|
||||
|
||||
def add_span_event(name: str, attributes: dict | None = None) -> None:
|
||||
"""Add an event to the current span if tracing is enabled."""
|
||||
if not _tracing_enabled:
|
||||
return
|
||||
|
||||
try:
|
||||
from opentelemetry import trace
|
||||
|
||||
current_span = trace.get_current_span()
|
||||
if current_span:
|
||||
current_span.add_event(name, attributes or {})
|
||||
except Exception:
|
||||
# Silently ignore tracing errors
|
||||
pass
|
||||
@@ -6,7 +6,7 @@ import click
|
||||
from prometheus_client import start_http_server
|
||||
from tenacity import RetryError, Retrying, stop_after_attempt, wait_fixed
|
||||
|
||||
from llama_deploy.apiserver import settings
|
||||
from llama_deploy.apiserver.settings import settings
|
||||
from llama_deploy.client import Client
|
||||
|
||||
RETRY_WAIT_SECONDS = 1
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from logging import getLogger
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
@@ -11,7 +11,14 @@ from fastapi.responses import StreamingResponse
|
||||
from llama_index.core.storage.kvstore import SimpleKVStore
|
||||
from llama_index.core.storage.kvstore.types import BaseKVStore
|
||||
|
||||
from llama_deploy.message_queues.base import AbstractMessageQueue, PublishCallback
|
||||
from llama_deploy.apiserver.tracing import (
|
||||
add_span_attribute,
|
||||
trace_async_method,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_deploy.message_queues.base import AbstractMessageQueue, PublishCallback
|
||||
|
||||
from llama_deploy.types import (
|
||||
ActionTypes,
|
||||
EventDefinition,
|
||||
@@ -62,8 +69,8 @@ class ControlPlaneServer:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_queue: AbstractMessageQueue,
|
||||
publish_callback: PublishCallback | None = None,
|
||||
message_queue: "AbstractMessageQueue",
|
||||
publish_callback: "PublishCallback | None" = None,
|
||||
state_store: BaseKVStore | None = None,
|
||||
config: ControlPlaneConfig | None = None,
|
||||
) -> None:
|
||||
@@ -198,7 +205,7 @@ class ControlPlaneServer:
|
||||
)
|
||||
|
||||
@property
|
||||
def message_queue(self) -> AbstractMessageQueue:
|
||||
def message_queue(self) -> "AbstractMessageQueue":
|
||||
return self._message_queue
|
||||
|
||||
@property
|
||||
@@ -206,7 +213,7 @@ class ControlPlaneServer:
|
||||
return self._publisher_id
|
||||
|
||||
@property
|
||||
def publish_callback(self) -> Optional[PublishCallback]:
|
||||
def publish_callback(self) -> Optional["PublishCallback"]:
|
||||
return self._publish_callback
|
||||
|
||||
async def _process_messages(self, topic: str) -> None:
|
||||
@@ -345,9 +352,14 @@ class ControlPlaneServer:
|
||||
return None
|
||||
return await self.get_task(session.task_ids[-1])
|
||||
|
||||
@trace_async_method("control_plane.add_task_to_session")
|
||||
async def add_task_to_session(
|
||||
self, session_id: str, task_def: TaskDefinition
|
||||
) -> str:
|
||||
add_span_attribute("session.id", session_id)
|
||||
add_span_attribute("task.id", task_def.task_id)
|
||||
add_span_attribute("task.service_id", task_def.service_id or "auto")
|
||||
|
||||
session_dict = await self._state_store.aget(
|
||||
session_id, collection=self._config.session_store_key
|
||||
)
|
||||
@@ -402,10 +414,14 @@ class ControlPlaneServer:
|
||||
|
||||
return task_def
|
||||
|
||||
@trace_async_method("control_plane.handle_service_completion")
|
||||
async def handle_service_completion(
|
||||
self,
|
||||
task_result: TaskResult,
|
||||
) -> None:
|
||||
add_span_attribute("task.id", task_result.task_id)
|
||||
add_span_attribute("task.result_length", len(task_result.result))
|
||||
|
||||
# add result to task state
|
||||
task_def = await self.get_task(task_result.task_id)
|
||||
if task_def.session_id is None:
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from llama_deploy.message_queues.apache_kafka import (
|
||||
from .apache_kafka import (
|
||||
KafkaMessageQueue,
|
||||
KafkaMessageQueueConfig,
|
||||
)
|
||||
from llama_deploy.message_queues.base import AbstractMessageQueue
|
||||
from llama_deploy.message_queues.rabbitmq import (
|
||||
from .base import AbstractMessageQueue
|
||||
from .rabbitmq import (
|
||||
RabbitMQMessageQueue,
|
||||
RabbitMQMessageQueueConfig,
|
||||
)
|
||||
from llama_deploy.message_queues.redis import RedisMessageQueue, RedisMessageQueueConfig
|
||||
from llama_deploy.message_queues.simple import (
|
||||
from .redis import RedisMessageQueue, RedisMessageQueueConfig
|
||||
from .simple import (
|
||||
SimpleMessageQueue,
|
||||
SimpleMessageQueueConfig,
|
||||
SimpleMessageQueueServer,
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Any, AsyncIterator, Awaitable, Callable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_deploy.apiserver.tracing import add_span_attribute, create_span
|
||||
from llama_deploy.types import QueueMessage
|
||||
|
||||
logger = getLogger(__name__)
|
||||
@@ -35,19 +36,28 @@ class AbstractMessageQueue(ABC):
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Send message to a consumer."""
|
||||
logger.info(
|
||||
f"Publishing message of type '{message.type}' with action '{message.action}' to topic '{topic}'"
|
||||
)
|
||||
logger.debug(f"Message: {message.model_dump()}")
|
||||
with create_span("message_queue.publish"):
|
||||
add_span_attribute("message.type", message.type)
|
||||
add_span_attribute("message.action", str(message.action))
|
||||
add_span_attribute("message.topic", topic)
|
||||
add_span_attribute("message.id", message.id_)
|
||||
|
||||
message.stats.publish_time = message.stats.timestamp_str()
|
||||
await self._publish(message, topic, create_topic)
|
||||
logger.info(
|
||||
f"Publishing message of type '{message.type}' with action '{message.action}' to topic '{topic}'"
|
||||
)
|
||||
logger.debug(f"Message: {message.model_dump()}")
|
||||
|
||||
if callback:
|
||||
if inspect.iscoroutinefunction(callback):
|
||||
await callback(message, **kwargs)
|
||||
else:
|
||||
callback(message, **kwargs)
|
||||
message.stats.publish_time = message.stats.timestamp_str()
|
||||
message.stats.set_trace_context()
|
||||
|
||||
await self._publish(message, topic, create_topic)
|
||||
|
||||
if callback:
|
||||
with create_span("message_queue.callback"):
|
||||
if inspect.iscoroutinefunction(callback):
|
||||
await callback(message, **kwargs)
|
||||
else:
|
||||
callback(message, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
async def cleanup(self, *args: Any, **kwargs: dict[str, Any]) -> None:
|
||||
|
||||
@@ -18,6 +18,12 @@ from llama_index.core.workflow.handler import WorkflowHandler
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from llama_deploy.apiserver.tracing import (
|
||||
add_span_attribute,
|
||||
add_span_event,
|
||||
create_span,
|
||||
trace_async_method,
|
||||
)
|
||||
from llama_deploy.control_plane.server import (
|
||||
CONTROL_PLANE_MESSAGE_TYPE,
|
||||
ControlPlaneConfig,
|
||||
@@ -212,6 +218,7 @@ class WorkflowService:
|
||||
# Store the state in the control plane
|
||||
await self.update_session_state(current_state.session_id, session_state)
|
||||
|
||||
@trace_async_method("workflow.process_call")
|
||||
async def process_call(self, current_call: WorkflowState) -> None:
|
||||
"""Processes a given task, and writes a response to the message queue.
|
||||
|
||||
@@ -222,78 +229,98 @@ class WorkflowService:
|
||||
current_call (WorkflowState):
|
||||
The state of the current task, including run_kwargs and other session state.
|
||||
"""
|
||||
add_span_attribute("task.id", current_call.task_id)
|
||||
add_span_attribute("task.session_id", current_call.session_id or "none")
|
||||
add_span_attribute("workflow.service_name", self.service_name)
|
||||
|
||||
# create send_event background task
|
||||
close_send_events = asyncio.Event()
|
||||
handler = None
|
||||
|
||||
try:
|
||||
add_span_event("workflow.state.loading")
|
||||
# load the state
|
||||
ctx = await self.get_workflow_state(current_call)
|
||||
with create_span("workflow.get_state", {"task.id": current_call.task_id}):
|
||||
ctx = await self.get_workflow_state(current_call)
|
||||
|
||||
add_span_event("workflow.execution.starting")
|
||||
# run the workflow
|
||||
handler = self.workflow.run(ctx=ctx, **current_call.run_kwargs)
|
||||
if handler.ctx is None:
|
||||
# This should never happen, workflow.run actually sets the Context
|
||||
# even if handler.ctx is typed as Optional[Context]
|
||||
raise ValueError("Context cannot be None.")
|
||||
|
||||
async def send_events(
|
||||
handler: WorkflowHandler, close_event: asyncio.Event
|
||||
) -> None:
|
||||
with create_span("workflow.run", {"task.id": current_call.task_id}):
|
||||
handler = self.workflow.run(ctx=ctx, **current_call.run_kwargs)
|
||||
if handler.ctx is None:
|
||||
raise ValueError("handler does not have a valid Context.")
|
||||
# This should never happen, workflow.run actually sets the Context
|
||||
# even if handler.ctx is typed as Optional[Context]
|
||||
raise ValueError("Context cannot be None.")
|
||||
|
||||
while not close_event.is_set():
|
||||
try:
|
||||
event = self._events_buffer[current_call.task_id].get_nowait()
|
||||
handler.ctx.send_event(event)
|
||||
except asyncio.QueueEmpty:
|
||||
pass
|
||||
await asyncio.sleep(self.config.step_interval)
|
||||
async def send_events(
|
||||
handler: WorkflowHandler, close_event: asyncio.Event
|
||||
) -> None:
|
||||
if handler.ctx is None:
|
||||
raise ValueError("handler does not have a valid Context.")
|
||||
|
||||
_ = asyncio.create_task(send_events(handler, close_send_events))
|
||||
while not close_event.is_set():
|
||||
try:
|
||||
event = self._events_buffer[
|
||||
current_call.task_id
|
||||
].get_nowait()
|
||||
handler.ctx.send_event(event)
|
||||
except asyncio.QueueEmpty:
|
||||
pass
|
||||
await asyncio.sleep(self.config.step_interval)
|
||||
|
||||
index = 0
|
||||
async for ev in handler.stream_events():
|
||||
# send the event to control plane for client / api server streaming
|
||||
logger.debug(f"Publishing event: {ev}")
|
||||
_ = asyncio.create_task(send_events(handler, close_send_events))
|
||||
|
||||
index = 0
|
||||
async for ev in handler.stream_events():
|
||||
# send the event to control plane for client / api server streaming
|
||||
logger.debug(f"Publishing event: {ev}")
|
||||
with create_span("workflow.event.publish", {"event.index": index}):
|
||||
await self.message_queue.publish(
|
||||
QueueMessage(
|
||||
type=CONTROL_PLANE_MESSAGE_TYPE,
|
||||
action=ActionTypes.TASK_STREAM,
|
||||
data=TaskStream(
|
||||
task_id=current_call.task_id,
|
||||
session_id=current_call.session_id,
|
||||
data=ev.model_dump(),
|
||||
index=index,
|
||||
).model_dump(),
|
||||
),
|
||||
self.get_topic(CONTROL_PLANE_MESSAGE_TYPE),
|
||||
)
|
||||
index += 1
|
||||
|
||||
final_result = await handler
|
||||
add_span_attribute("workflow.result.type", type(final_result).__name__)
|
||||
|
||||
add_span_event("workflow.state.saving")
|
||||
# dump the state
|
||||
with create_span("workflow.set_state", {"task.id": current_call.task_id}):
|
||||
await self.set_workflow_state(handler.ctx, current_call)
|
||||
|
||||
add_span_event("workflow.result.publishing")
|
||||
logger.info(
|
||||
f"Publishing final result: {final_result} to '{self.get_topic(CONTROL_PLANE_MESSAGE_TYPE)}'"
|
||||
)
|
||||
with create_span("workflow.result.publish"):
|
||||
await self.message_queue.publish(
|
||||
QueueMessage(
|
||||
type=CONTROL_PLANE_MESSAGE_TYPE,
|
||||
action=ActionTypes.TASK_STREAM,
|
||||
data=TaskStream(
|
||||
action=ActionTypes.COMPLETED_TASK,
|
||||
data=TaskResult(
|
||||
task_id=current_call.task_id,
|
||||
session_id=current_call.session_id,
|
||||
data=ev.model_dump(),
|
||||
index=index,
|
||||
history=[],
|
||||
result=str(final_result),
|
||||
data={},
|
||||
).model_dump(),
|
||||
),
|
||||
self.get_topic(CONTROL_PLANE_MESSAGE_TYPE),
|
||||
)
|
||||
index += 1
|
||||
|
||||
final_result = await handler
|
||||
|
||||
# dump the state
|
||||
await self.set_workflow_state(handler.ctx, current_call)
|
||||
|
||||
logger.info(
|
||||
f"Publishing final result: {final_result} to '{self.get_topic(CONTROL_PLANE_MESSAGE_TYPE)}'"
|
||||
)
|
||||
await self.message_queue.publish(
|
||||
QueueMessage(
|
||||
type=CONTROL_PLANE_MESSAGE_TYPE,
|
||||
action=ActionTypes.COMPLETED_TASK,
|
||||
data=TaskResult(
|
||||
task_id=current_call.task_id,
|
||||
history=[],
|
||||
result=str(final_result),
|
||||
data={},
|
||||
).model_dump(),
|
||||
),
|
||||
self.get_topic(CONTROL_PLANE_MESSAGE_TYPE),
|
||||
)
|
||||
except Exception as e:
|
||||
add_span_attribute("error.occurred", True)
|
||||
add_span_attribute("error.type", type(e).__name__)
|
||||
add_span_attribute("error.message", str(e))
|
||||
|
||||
if self.config.raise_exceptions:
|
||||
raise e
|
||||
|
||||
@@ -306,19 +333,20 @@ class WorkflowService:
|
||||
await self.set_workflow_state(handler.ctx, current_call)
|
||||
|
||||
# return failure
|
||||
await self.message_queue.publish(
|
||||
QueueMessage(
|
||||
type=CONTROL_PLANE_MESSAGE_TYPE,
|
||||
action=ActionTypes.COMPLETED_TASK,
|
||||
data=TaskResult(
|
||||
task_id=current_call.task_id,
|
||||
history=[],
|
||||
result=str(e),
|
||||
data={},
|
||||
).model_dump(),
|
||||
),
|
||||
self.get_topic(CONTROL_PLANE_MESSAGE_TYPE),
|
||||
)
|
||||
with create_span("workflow.error.publish"):
|
||||
await self.message_queue.publish(
|
||||
QueueMessage(
|
||||
type=CONTROL_PLANE_MESSAGE_TYPE,
|
||||
action=ActionTypes.COMPLETED_TASK,
|
||||
data=TaskResult(
|
||||
task_id=current_call.task_id,
|
||||
history=[],
|
||||
result=str(e),
|
||||
data={},
|
||||
).model_dump(),
|
||||
),
|
||||
self.get_topic(CONTROL_PLANE_MESSAGE_TYPE),
|
||||
)
|
||||
finally:
|
||||
# clean up
|
||||
close_send_events.set()
|
||||
|
||||
@@ -39,16 +39,39 @@ class QueueMessageStats(BaseModel):
|
||||
The time the message processing started.
|
||||
process_end_time (Optional[str]):
|
||||
The time the message processing ended.
|
||||
trace_id (Optional[str]):
|
||||
The trace ID for distributed tracing.
|
||||
span_id (Optional[str]):
|
||||
The span ID for distributed tracing.
|
||||
"""
|
||||
|
||||
publish_time: str | None = Field(default=None)
|
||||
process_start_time: str | None = Field(default=None)
|
||||
process_end_time: str | None = Field(default=None)
|
||||
trace_id: str | None = Field(default=None)
|
||||
span_id: str | None = Field(default=None)
|
||||
|
||||
@staticmethod
|
||||
def timestamp_str(format: str = "%Y-%m-%d %H:%M:%S") -> str:
|
||||
return datetime.now().strftime(format)
|
||||
|
||||
def set_trace_context(self) -> None:
    """Copy the active OpenTelemetry trace and span IDs onto this object.

    This is a best-effort operation that never raises. ``trace_id`` and
    ``span_id`` are left untouched when the ``opentelemetry`` package is
    not installed, when the current span is not recording, or when its
    span context is not valid. Any other failure is logged at debug level
    instead of being silently discarded.

    On success ``trace_id`` is set to a 32-character lowercase hex string
    and ``span_id`` to a 16-character lowercase hex string.
    """
    try:
        from opentelemetry import trace
    except ImportError:
        # OpenTelemetry is not available, so there is nothing to record.
        return

    try:
        current_span = trace.get_current_span()
        if not current_span or not current_span.is_recording():
            return

        ctx = current_span.get_span_context()
        if not ctx.is_valid:
            # An invalid context carries all-zero IDs; do not store them.
            return

        # Format both IDs before assigning so the pair is never half-set.
        trace_id = format(ctx.trace_id, "032x")
        span_id = format(ctx.span_id, "016x")
        self.trace_id = trace_id
        self.span_id = span_id
    except Exception:
        # Tracing must never break message handling, but keep the failure
        # visible to anyone debugging missing trace metadata.
        import logging

        logging.getLogger(__name__).debug(
            "Unable to capture the current trace context", exc_info=True
        )
|
||||
|
||||
|
||||
class QueueMessage(BaseModel):
|
||||
"""A message for the message queue.
|
||||
|
||||
@@ -58,6 +58,13 @@ dependencies = [
|
||||
kafka = ["aiokafka>=0.11.0,<0.12", "kafka-python-ng>=2.2.2,<3"]
|
||||
rabbitmq = ["aio-pika>=9.4.2,<10"]
|
||||
redis = ["redis>=5.0.7,<6"]
|
||||
observability = [
|
||||
"opentelemetry-api>=1.20.0,<2.0",
|
||||
"opentelemetry-sdk>=1.20.0,<2.0",
|
||||
"opentelemetry-instrumentation-asyncio>=0.41b0,<1.0",
|
||||
"opentelemetry-exporter-jaeger>=1.20.0,<2.0",
|
||||
"opentelemetry-exporter-otlp>=1.20.0,<2.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
llamactl = "llama_deploy.cli.__main__:main"
|
||||
|
||||
@@ -14,7 +14,7 @@ from fastapi.testclient import TestClient
|
||||
from llama_index.core.workflow.context_serializers import JsonSerializer
|
||||
from llama_index.core.workflow.events import Event
|
||||
|
||||
from llama_deploy.apiserver import DeploymentConfig
|
||||
from llama_deploy.apiserver.deployment_config_parser import DeploymentConfig
|
||||
from llama_deploy.types import TaskResult
|
||||
from llama_deploy.types.core import EventDefinition, TaskDefinition
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from unittest import mock
|
||||
import httpx
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from llama_deploy.apiserver import settings
|
||||
from llama_deploy.apiserver.settings import settings
|
||||
|
||||
|
||||
def test_read_main(http_client: TestClient) -> None:
|
||||
|
||||
@@ -3,7 +3,7 @@ from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_deploy.apiserver import DeploymentConfig
|
||||
from llama_deploy.apiserver.deployment_config_parser import DeploymentConfig
|
||||
from llama_deploy.apiserver.source_managers.base import SyncPolicy
|
||||
from llama_deploy.apiserver.source_managers.local import LocalSourceManager
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import pytest
|
||||
from click.testing import CliRunner
|
||||
from tenacity import RetryError
|
||||
|
||||
from llama_deploy.apiserver import settings
|
||||
from llama_deploy.apiserver.settings import settings
|
||||
from llama_deploy.cli.serve import serve
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user