mirror of
https://github.com/open-webui/mcpo.git
synced 2026-07-01 21:04:00 -04:00
Merge pull request #251 from rkconsulting/issues/182/forward-client-http-headers
Feat: Add client header forwarding feature
This commit is contained in:
@@ -0,0 +1,151 @@
|
||||
# Client Header Forwarding in MCPO
|
||||
|
||||
MCPO supports forwarding HTTP headers from incoming client requests to MCP servers. This enables passing user context, authentication tokens, and other request-specific information to your MCP tools.
|
||||
|
||||
## Configuration
|
||||
|
||||
Add client header forwarding configuration to your MCP server config:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"some-mcp": {
|
||||
"command": "uvx",
|
||||
"args": ["some-mcp"],
|
||||
"client_header_forwarding": {
|
||||
"enabled": true,
|
||||
"whitelist": ["Authorization", "X-User-*", "X-Request-ID"],
|
||||
"blacklist": ["Host", "Content-Length"],
|
||||
"debug_headers": false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
- `enabled`: Enable/disable client header forwarding for this server (default: false)
|
||||
- `whitelist`: List of header patterns to forward (supports wildcards with `*`)
|
||||
- `blacklist`: List of header patterns to block (takes precedence over whitelist)
|
||||
- `debug_headers`: Enable debug logging for header processing (default: false)
|
||||
|
||||
## Header Pattern Matching
|
||||
|
||||
- **Exact match**: `"Authorization"` matches only the `Authorization` header
|
||||
- **Wildcard match**: `"X-User-*"` matches `X-User-ID`, `X-User-Email`, etc.
|
||||
- **Global wildcard**: `"*"` matches all headers (use with caution)
|
||||
|
||||
## How It Works
|
||||
|
||||
1. **Client Request**: A client makes an HTTP request to mcpo with headers like `Authorization: Bearer <token>`
|
||||
2. **Header Filtering**: Headers are filtered based on whitelist/blacklist rules
|
||||
3. **MCP Forwarding**: Filtered headers are passed to the MCP server via the `_meta.headers` field in tool calls
|
||||
|
||||
## Transport Support
|
||||
|
||||
Client header forwarding works with all MCP transport types:
|
||||
- **stdio**: Headers are passed via `_meta` field in JSON-RPC calls
|
||||
- **SSE**: Headers are passed via `_meta` field in JSON-RPC calls
|
||||
- **HTTP**: Headers are passed via `_meta` field in JSON-RPC calls
|
||||
|
||||
## Complementary Features
|
||||
|
||||
Client header forwarding works alongside mcpo's connection-level headers:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"protected-server": {
|
||||
"type": "sse",
|
||||
"url": "https://api.example.com/mcp",
|
||||
"headers": {
|
||||
"Authorization": "Bearer server-token-123"
|
||||
},
|
||||
"client_header_forwarding": {
|
||||
"enabled": true,
|
||||
"whitelist": ["Authorization", "X-User-*"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
- **`headers`**: Static headers for mcpo ↔ MCP server authentication
|
||||
- **`client_header_forwarding`**: Dynamic headers from client ↔ MCP server
|
||||
|
||||
## Security Considerations
|
||||
|
||||
- **Whitelist Headers**: Only forward necessary headers to minimize attack surface
|
||||
- **Blacklist Sensitive Headers**: Block headers like `Host`, `Content-Length`, etc.
|
||||
- **Debug Mode**: Only enable `debug_headers` in development environments
|
||||
|
||||
## MCP Server Integration
|
||||
|
||||
Your MCP server can access forwarded headers through the `_meta` field in tool calls:
|
||||
|
||||
```python
|
||||
from mcp.server.fastmcp import FastMCP, Context
|
||||
from mcp.server.session import ServerSession
|
||||
|
||||
mcp = FastMCP(name="Example Server")
|
||||
|
||||
@mcp.tool()
|
||||
async def protected_tool(data: str, ctx: Context[ServerSession, None]) -> str:
|
||||
# Access forwarded headers
|
||||
headers = getattr(ctx.request_meta, 'headers', {}) if hasattr(ctx, 'request_meta') else {}
|
||||
|
||||
# Check authorization
|
||||
auth_header = headers.get('Authorization', '')
|
||||
if not auth_header.startswith('Bearer '):
|
||||
raise ValueError("Missing or invalid authorization")
|
||||
|
||||
# Extract user context
|
||||
user_id = headers.get('X-User-ID', 'unknown')
|
||||
request_id = headers.get('X-Request-ID', 'unknown')
|
||||
|
||||
return f"Protected data for user {user_id} (request: {request_id}): {data}"
|
||||
```
|
||||
|
||||
## Example Use Cases
|
||||
|
||||
### 1. User Authentication
|
||||
```json
|
||||
{
|
||||
"client_header_forwarding": {
|
||||
"enabled": true,
|
||||
"whitelist": ["Authorization"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Forward JWT tokens or API keys for user authentication.
|
||||
|
||||
### 2. Request Tracing
|
||||
```json
|
||||
{
|
||||
"client_header_forwarding": {
|
||||
"enabled": true,
|
||||
"whitelist": ["X-Request-ID", "X-Trace-ID"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Forward tracing headers for request correlation across services.
|
||||
|
||||
### 3. User Context
|
||||
```json
|
||||
{
|
||||
"client_header_forwarding": {
|
||||
"enabled": true,
|
||||
"whitelist": ["X-User-*"],
|
||||
"blacklist": ["X-User-Secret"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Forward user information while blocking sensitive headers.
|
||||
|
||||
## Hot Reload Support
|
||||
|
||||
Client header forwarding configurations are automatically reloaded when using mcpo's `--hot-reload` feature. Changes to the configuration file will be applied without restarting the server.
|
||||
+13
-2
@@ -24,9 +24,8 @@ from mcpo.utils.main import (
|
||||
get_tool_handler,
|
||||
normalize_server_type,
|
||||
)
|
||||
from mcpo.utils.main import get_model_fields, get_tool_handler
|
||||
from mcpo.utils.auth import get_verify_api_key, APIKeyMiddleware
|
||||
from mcpo.utils.config_watcher import ConfigWatcher
|
||||
from mcpo.utils.headers import validate_client_header_forwarding_config
|
||||
from mcpo.utils.oauth import create_oauth_provider
|
||||
|
||||
|
||||
@@ -85,6 +84,11 @@ def load_config(config_path: str) -> Dict[str, Any]:
|
||||
# Validate each server configuration
|
||||
for server_name, server_cfg in mcp_servers.items():
|
||||
validate_server_config(server_name, server_cfg)
|
||||
|
||||
# Validate client header forwarding configuration if present
|
||||
header_config = server_cfg.get("client_header_forwarding", {})
|
||||
if header_config:
|
||||
validate_client_header_forwarding_config(server_name, header_config)
|
||||
|
||||
return config_data
|
||||
except json.JSONDecodeError as e:
|
||||
@@ -148,6 +152,9 @@ def create_sub_app(server_name: str, server_cfg: Dict[str, Any], cors_allow_orig
|
||||
sub_app.state.api_dependency = api_dependency
|
||||
sub_app.state.connection_timeout = connection_timeout
|
||||
|
||||
# Store client header forwarding configuration
|
||||
sub_app.state.client_header_forwarding = server_cfg.get("client_header_forwarding", {"enabled": False})
|
||||
|
||||
# Store OAuth configuration if present
|
||||
sub_app.state.oauth_config = server_cfg.get("oauth")
|
||||
|
||||
@@ -298,11 +305,15 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
||||
outputSchema.get("$defs", {}),
|
||||
)
|
||||
|
||||
# Get client header forwarding configuration from app state
|
||||
client_header_forwarding_config = getattr(app.state, "client_header_forwarding", {"enabled": False})
|
||||
|
||||
tool_handler = get_tool_handler(
|
||||
session,
|
||||
endpoint_name,
|
||||
form_model_fields,
|
||||
response_model_fields,
|
||||
client_header_forwarding_config,
|
||||
)
|
||||
|
||||
app.post(
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List, Optional, Any
|
||||
from fastapi import Request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_client_header_forwarding_config(server_name: str, config: Dict[str, Any]) -> None:
|
||||
"""Validate client header forwarding configuration for a server."""
|
||||
if not isinstance(config, dict):
|
||||
raise ValueError(f"Server '{server_name}' client_header_forwarding must be a dictionary")
|
||||
|
||||
enabled = config.get("enabled", False)
|
||||
if not isinstance(enabled, bool):
|
||||
raise ValueError(f"Server '{server_name}' client_header_forwarding.enabled must be a boolean")
|
||||
|
||||
if not enabled:
|
||||
return # No further validation needed if disabled
|
||||
|
||||
whitelist = config.get("whitelist", [])
|
||||
blacklist = config.get("blacklist", [])
|
||||
|
||||
if whitelist and not isinstance(whitelist, list):
|
||||
raise ValueError(f"Server '{server_name}' client_header_forwarding.whitelist must be a list")
|
||||
|
||||
if blacklist and not isinstance(blacklist, list):
|
||||
raise ValueError(f"Server '{server_name}' client_header_forwarding.blacklist must be a list")
|
||||
|
||||
debug_headers = config.get("debug_headers", False)
|
||||
if not isinstance(debug_headers, bool):
|
||||
raise ValueError(f"Server '{server_name}' client_header_forwarding.debug_headers must be a boolean")
|
||||
|
||||
|
||||
def match_header_pattern(header_name: str, patterns: List[str]) -> bool:
|
||||
"""Check if header name matches any of the given patterns."""
|
||||
for pattern in patterns:
|
||||
if pattern == "*":
|
||||
return True
|
||||
if pattern.endswith("*"):
|
||||
# Wildcard pattern like "X-User-*"
|
||||
prefix = pattern[:-1]
|
||||
if header_name.startswith(prefix):
|
||||
return True
|
||||
elif pattern == header_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def filter_headers(
|
||||
request_headers: Dict[str, str],
|
||||
whitelist: List[str],
|
||||
blacklist: List[str],
|
||||
debug_headers: bool = False
|
||||
) -> Dict[str, str]:
|
||||
"""Filter request headers based on whitelist and blacklist."""
|
||||
filtered_headers = {}
|
||||
|
||||
for header_name, header_value in request_headers.items():
|
||||
# Skip if in blacklist
|
||||
if blacklist and match_header_pattern(header_name, blacklist):
|
||||
if debug_headers:
|
||||
logger.debug(f"Header '{header_name}' blocked by blacklist")
|
||||
continue
|
||||
|
||||
# Include if in whitelist (or no whitelist specified)
|
||||
if not whitelist or match_header_pattern(header_name, whitelist):
|
||||
filtered_headers[header_name] = header_value
|
||||
if debug_headers:
|
||||
logger.debug(f"Header '{header_name}' forwarded")
|
||||
elif debug_headers:
|
||||
logger.debug(f"Header '{header_name}' not in whitelist")
|
||||
|
||||
return filtered_headers
|
||||
|
||||
|
||||
def process_headers_for_server(
|
||||
request: Request,
|
||||
header_config: Dict[str, Any]
|
||||
) -> Dict[str, str]:
|
||||
"""Process and filter headers for a specific MCP server."""
|
||||
if not header_config.get("enabled", False):
|
||||
return {}
|
||||
|
||||
# Convert FastAPI headers to dict
|
||||
request_headers = dict(request.headers)
|
||||
|
||||
# Get configuration values
|
||||
whitelist = header_config.get("whitelist", [])
|
||||
blacklist = header_config.get("blacklist", [])
|
||||
debug_headers = header_config.get("debug_headers", False)
|
||||
|
||||
# Filter headers based on whitelist/blacklist
|
||||
filtered_headers = filter_headers(request_headers, whitelist, blacklist, debug_headers)
|
||||
|
||||
if debug_headers:
|
||||
logger.debug(f"Final forwarded headers: {list(filtered_headers.keys())}")
|
||||
|
||||
return filtered_headers
|
||||
+29
-5
@@ -2,7 +2,7 @@ import json
|
||||
import traceback
|
||||
from typing import Any, Dict, ForwardRef, List, Optional, Type, Union
|
||||
import logging
|
||||
from fastapi import HTTPException
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from mcp import ClientSession, types
|
||||
from mcp.types import (
|
||||
@@ -19,6 +19,8 @@ from mcp.shared.exceptions import McpError
|
||||
from pydantic import Field, create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
from mcpo.utils.headers import process_headers_for_server
|
||||
|
||||
MCP_ERROR_TO_HTTP_STATUS = {
|
||||
PARSE_ERROR: 400,
|
||||
INVALID_REQUEST: 400,
|
||||
@@ -270,6 +272,7 @@ def get_tool_handler(
|
||||
endpoint_name,
|
||||
form_model_fields,
|
||||
response_model_fields=None,
|
||||
client_header_forwarding_config=None,
|
||||
):
|
||||
if form_model_fields:
|
||||
FormModel = create_model(f"{endpoint_name}_form_model", **form_model_fields)
|
||||
@@ -282,11 +285,22 @@ def get_tool_handler(
|
||||
def make_endpoint_func(
|
||||
endpoint_name: str, FormModel, session: ClientSession
|
||||
): # Parameterized endpoint
|
||||
async def tool(form_data: FormModel) -> Union[ResponseModel, Any]:
|
||||
async def tool(form_data: FormModel, request: Request) -> Union[ResponseModel, Any]:
|
||||
args = form_data.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
# Process headers for forwarding if configured
|
||||
forwarded_headers = {}
|
||||
if client_header_forwarding_config and client_header_forwarding_config.get("enabled", False):
|
||||
forwarded_headers = process_headers_for_server(request, client_header_forwarding_config)
|
||||
|
||||
# Add headers to _meta if any headers are being forwarded
|
||||
meta = {}
|
||||
if forwarded_headers:
|
||||
meta["headers"] = forwarded_headers
|
||||
|
||||
logger.info(f"Calling endpoint: {endpoint_name}, with args: {args}")
|
||||
try:
|
||||
result = await session.call_tool(endpoint_name, arguments=args)
|
||||
result = await session.call_tool(endpoint_name, arguments=args, _meta=meta if meta else None)
|
||||
|
||||
if result.isError:
|
||||
error_message = "Unknown tool execution error"
|
||||
@@ -338,11 +352,21 @@ def get_tool_handler(
|
||||
def make_endpoint_func_no_args(
|
||||
endpoint_name: str, session: ClientSession
|
||||
): # Parameterless endpoint
|
||||
async def tool(): # No parameters
|
||||
async def tool(request: Request): # No parameters but need request for headers
|
||||
# Process headers for forwarding if configured
|
||||
forwarded_headers = {}
|
||||
if client_header_forwarding_config and client_header_forwarding_config.get("enabled", False):
|
||||
forwarded_headers = process_headers_for_server(request, client_header_forwarding_config)
|
||||
|
||||
# Add headers to _meta if any headers are being forwarded
|
||||
meta = {}
|
||||
if forwarded_headers:
|
||||
meta["headers"] = forwarded_headers
|
||||
|
||||
logger.info(f"Calling endpoint: {endpoint_name}, with no args")
|
||||
try:
|
||||
result = await session.call_tool(
|
||||
endpoint_name, arguments={}
|
||||
endpoint_name, arguments={}, _meta=meta if meta else None
|
||||
) # Empty dict
|
||||
|
||||
if result.isError:
|
||||
|
||||
Reference in New Issue
Block a user