Merge pull request #251 from rkconsulting/issues/182/forward-client-http-headers

Feat: Add client header forwarding feature
This commit is contained in:
Tim Jaeryang Baek
2025-09-24 18:04:50 -05:00
committed by GitHub
5 changed files with 664 additions and 379 deletions
+151
View File
@@ -0,0 +1,151 @@
# Client Header Forwarding in MCPO
MCPO supports forwarding HTTP headers from incoming client requests to MCP servers. This enables passing user context, authentication tokens, and other request-specific information to your MCP tools.
## Configuration
Add client header forwarding configuration to your MCP server config:
```json
{
"mcpServers": {
"some-mcp": {
"command": "uvx",
"args": ["some-mcp"],
"client_header_forwarding": {
"enabled": true,
"whitelist": ["Authorization", "X-User-*", "X-Request-ID"],
"blacklist": ["Host", "Content-Length"],
"debug_headers": false
}
}
}
}
```
## Configuration Options
- `enabled`: Enable/disable client header forwarding for this server (default: false)
- `whitelist`: List of header patterns to forward (supports wildcards with `*`)
- `blacklist`: List of header patterns to block (takes precedence over whitelist)
- `debug_headers`: Enable debug logging for header processing (default: false)
## Header Pattern Matching
- **Exact match**: `"Authorization"` matches only the `Authorization` header
- **Wildcard match**: `"X-User-*"` matches `X-User-ID`, `X-User-Email`, etc.
- **Global wildcard**: `"*"` matches all headers (use with caution)
## How It Works
1. **Client Request**: A client makes an HTTP request to mcpo with headers like `Authorization: Bearer <token>`
2. **Header Filtering**: Headers are filtered based on whitelist/blacklist rules
3. **MCP Forwarding**: Filtered headers are passed to the MCP server via the `_meta.headers` field in tool calls
## Transport Support
Client header forwarding works with all MCP transport types:
- **stdio**: Headers are passed via `_meta` field in JSON-RPC calls
- **SSE**: Headers are passed via `_meta` field in JSON-RPC calls
- **HTTP**: Headers are passed via `_meta` field in JSON-RPC calls
## Complementary Features
Client header forwarding works alongside mcpo's connection-level headers:
```json
{
"mcpServers": {
"protected-server": {
"type": "sse",
"url": "https://api.example.com/mcp",
"headers": {
"Authorization": "Bearer server-token-123"
},
"client_header_forwarding": {
"enabled": true,
"whitelist": ["Authorization", "X-User-*"]
}
}
}
}
```
- **`headers`**: Static headers for mcpo ↔ MCP server authentication
- **`client_header_forwarding`**: Dynamic headers from client ↔ MCP server
## Security Considerations
- **Whitelist Headers**: Only forward necessary headers to minimize attack surface
- **Blacklist Sensitive Headers**: Block headers like `Host`, `Content-Length`, etc.
- **Debug Mode**: Only enable `debug_headers` in development environments
## MCP Server Integration
Your MCP server can access forwarded headers through the `_meta` field in tool calls:
```python
from mcp.server.fastmcp import FastMCP, Context
from mcp.server.session import ServerSession
mcp = FastMCP(name="Example Server")
@mcp.tool()
async def protected_tool(data: str, ctx: Context[ServerSession, None]) -> str:
# Access forwarded headers
headers = getattr(ctx.request_meta, 'headers', {}) if hasattr(ctx, 'request_meta') else {}
# Check authorization
auth_header = headers.get('Authorization', '')
if not auth_header.startswith('Bearer '):
raise ValueError("Missing or invalid authorization")
# Extract user context
user_id = headers.get('X-User-ID', 'unknown')
request_id = headers.get('X-Request-ID', 'unknown')
return f"Protected data for user {user_id} (request: {request_id}): {data}"
```
## Example Use Cases
### 1. User Authentication
```json
{
"client_header_forwarding": {
"enabled": true,
"whitelist": ["Authorization"]
}
}
```
Forward JWT tokens or API keys for user authentication.
### 2. Request Tracing
```json
{
"client_header_forwarding": {
"enabled": true,
"whitelist": ["X-Request-ID", "X-Trace-ID"]
}
}
```
Forward tracing headers for request correlation across services.
### 3. User Context
```json
{
"client_header_forwarding": {
"enabled": true,
"whitelist": ["X-User-*"],
"blacklist": ["X-User-Secret"]
}
}
```
Forward user information while blocking sensitive headers.
## Hot Reload Support
Client header forwarding configurations are automatically reloaded when using mcpo's `--hot-reload` feature. Changes to the configuration file will be applied without restarting the server.
+13 -2
View File
@@ -24,9 +24,8 @@ from mcpo.utils.main import (
get_tool_handler,
normalize_server_type,
)
from mcpo.utils.main import get_model_fields, get_tool_handler
from mcpo.utils.auth import get_verify_api_key, APIKeyMiddleware
from mcpo.utils.config_watcher import ConfigWatcher
from mcpo.utils.headers import validate_client_header_forwarding_config
from mcpo.utils.oauth import create_oauth_provider
@@ -85,6 +84,11 @@ def load_config(config_path: str) -> Dict[str, Any]:
# Validate each server configuration
for server_name, server_cfg in mcp_servers.items():
validate_server_config(server_name, server_cfg)
# Validate client header forwarding configuration if present
header_config = server_cfg.get("client_header_forwarding", {})
if header_config:
validate_client_header_forwarding_config(server_name, header_config)
return config_data
except json.JSONDecodeError as e:
@@ -148,6 +152,9 @@ def create_sub_app(server_name: str, server_cfg: Dict[str, Any], cors_allow_orig
sub_app.state.api_dependency = api_dependency
sub_app.state.connection_timeout = connection_timeout
# Store client header forwarding configuration
sub_app.state.client_header_forwarding = server_cfg.get("client_header_forwarding", {"enabled": False})
# Store OAuth configuration if present
sub_app.state.oauth_config = server_cfg.get("oauth")
@@ -298,11 +305,15 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
outputSchema.get("$defs", {}),
)
# Get client header forwarding configuration from app state
client_header_forwarding_config = getattr(app.state, "client_header_forwarding", {"enabled": False})
tool_handler = get_tool_handler(
session,
endpoint_name,
form_model_fields,
response_model_fields,
client_header_forwarding_config,
)
app.post(
+99
View File
@@ -0,0 +1,99 @@
import logging
import re
from typing import Dict, List, Optional, Any
from fastapi import Request
logger = logging.getLogger(__name__)
def validate_client_header_forwarding_config(server_name: str, config: Dict[str, Any]) -> None:
"""Validate client header forwarding configuration for a server."""
if not isinstance(config, dict):
raise ValueError(f"Server '{server_name}' client_header_forwarding must be a dictionary")
enabled = config.get("enabled", False)
if not isinstance(enabled, bool):
raise ValueError(f"Server '{server_name}' client_header_forwarding.enabled must be a boolean")
if not enabled:
return # No further validation needed if disabled
whitelist = config.get("whitelist", [])
blacklist = config.get("blacklist", [])
if whitelist and not isinstance(whitelist, list):
raise ValueError(f"Server '{server_name}' client_header_forwarding.whitelist must be a list")
if blacklist and not isinstance(blacklist, list):
raise ValueError(f"Server '{server_name}' client_header_forwarding.blacklist must be a list")
debug_headers = config.get("debug_headers", False)
if not isinstance(debug_headers, bool):
raise ValueError(f"Server '{server_name}' client_header_forwarding.debug_headers must be a boolean")
def match_header_pattern(header_name: str, patterns: List[str]) -> bool:
"""Check if header name matches any of the given patterns."""
for pattern in patterns:
if pattern == "*":
return True
if pattern.endswith("*"):
# Wildcard pattern like "X-User-*"
prefix = pattern[:-1]
if header_name.startswith(prefix):
return True
elif pattern == header_name:
return True
return False
def filter_headers(
request_headers: Dict[str, str],
whitelist: List[str],
blacklist: List[str],
debug_headers: bool = False
) -> Dict[str, str]:
"""Filter request headers based on whitelist and blacklist."""
filtered_headers = {}
for header_name, header_value in request_headers.items():
# Skip if in blacklist
if blacklist and match_header_pattern(header_name, blacklist):
if debug_headers:
logger.debug(f"Header '{header_name}' blocked by blacklist")
continue
# Include if in whitelist (or no whitelist specified)
if not whitelist or match_header_pattern(header_name, whitelist):
filtered_headers[header_name] = header_value
if debug_headers:
logger.debug(f"Header '{header_name}' forwarded")
elif debug_headers:
logger.debug(f"Header '{header_name}' not in whitelist")
return filtered_headers
def process_headers_for_server(
request: Request,
header_config: Dict[str, Any]
) -> Dict[str, str]:
"""Process and filter headers for a specific MCP server."""
if not header_config.get("enabled", False):
return {}
# Convert FastAPI headers to dict
request_headers = dict(request.headers)
# Get configuration values
whitelist = header_config.get("whitelist", [])
blacklist = header_config.get("blacklist", [])
debug_headers = header_config.get("debug_headers", False)
# Filter headers based on whitelist/blacklist
filtered_headers = filter_headers(request_headers, whitelist, blacklist, debug_headers)
if debug_headers:
logger.debug(f"Final forwarded headers: {list(filtered_headers.keys())}")
return filtered_headers
+29 -5
View File
@@ -2,7 +2,7 @@ import json
import traceback
from typing import Any, Dict, ForwardRef, List, Optional, Type, Union
import logging
from fastapi import HTTPException
from fastapi import HTTPException, Request
from mcp import ClientSession, types
from mcp.types import (
@@ -19,6 +19,8 @@ from mcp.shared.exceptions import McpError
from pydantic import Field, create_model
from pydantic.fields import FieldInfo
from mcpo.utils.headers import process_headers_for_server
MCP_ERROR_TO_HTTP_STATUS = {
PARSE_ERROR: 400,
INVALID_REQUEST: 400,
@@ -270,6 +272,7 @@ def get_tool_handler(
endpoint_name,
form_model_fields,
response_model_fields=None,
client_header_forwarding_config=None,
):
if form_model_fields:
FormModel = create_model(f"{endpoint_name}_form_model", **form_model_fields)
@@ -282,11 +285,22 @@ def get_tool_handler(
def make_endpoint_func(
endpoint_name: str, FormModel, session: ClientSession
): # Parameterized endpoint
async def tool(form_data: FormModel) -> Union[ResponseModel, Any]:
async def tool(form_data: FormModel, request: Request) -> Union[ResponseModel, Any]:
args = form_data.model_dump(exclude_none=True, by_alias=True)
# Process headers for forwarding if configured
forwarded_headers = {}
if client_header_forwarding_config and client_header_forwarding_config.get("enabled", False):
forwarded_headers = process_headers_for_server(request, client_header_forwarding_config)
# Add headers to _meta if any headers are being forwarded
meta = {}
if forwarded_headers:
meta["headers"] = forwarded_headers
logger.info(f"Calling endpoint: {endpoint_name}, with args: {args}")
try:
result = await session.call_tool(endpoint_name, arguments=args)
result = await session.call_tool(endpoint_name, arguments=args, _meta=meta if meta else None)
if result.isError:
error_message = "Unknown tool execution error"
@@ -338,11 +352,21 @@ def get_tool_handler(
def make_endpoint_func_no_args(
endpoint_name: str, session: ClientSession
): # Parameterless endpoint
async def tool(): # No parameters
async def tool(request: Request): # No parameters but need request for headers
# Process headers for forwarding if configured
forwarded_headers = {}
if client_header_forwarding_config and client_header_forwarding_config.get("enabled", False):
forwarded_headers = process_headers_for_server(request, client_header_forwarding_config)
# Add headers to _meta if any headers are being forwarded
meta = {}
if forwarded_headers:
meta["headers"] = forwarded_headers
logger.info(f"Calling endpoint: {endpoint_name}, with no args")
try:
result = await session.call_tool(
endpoint_name, arguments={}
endpoint_name, arguments={}, _meta=meta if meta else None
) # Empty dict
if result.isError:
Generated
+372 -372
View File
File diff suppressed because it is too large Load Diff