Merge pull request #287 from open-webui/dev

0.0.20
This commit is contained in:
Tim Baek
2026-02-27 21:58:04 +04:00
committed by GitHub
5 changed files with 377 additions and 190 deletions
+2
View File
@@ -11,3 +11,5 @@ wheels/
config.json
.python-version
.vscode
.DS_Store
+18
View File
@@ -5,6 +5,24 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.0.20] - 2026-02-27
### Added
* 🔁 **MCP Connection Manager with Auto-Reconnect**: Introduced MCPConnectionManager to manage the full lifecycle of MCP client sessions and transports. Connections that drop due to a ClosedResourceError are now automatically re-established and retried transparently, eliminating manual restarts for intermittent server disconnects.
* 🔄 **Automatic Retry on Closed Sessions**: Tool calls that encounter a ClosedResourceError will now automatically reconnect to the MCP server and retry the call, providing resilient end-to-end execution without user intervention.
### Changed
* 🧩 **Normalized disabledTools Configuration Key**: The disabled tools setting now accepts both camelCase (disabledTools) and snake_case (disabled_tools) in the config file for backwards compatibility and consistency with other configuration keys.
* 🧹 **Simplified Tool Handler Architecture**: Refactored get_tool_handler to remove the session parameter and nested factory functions in favor of a cleaner design that resolves the session from the request app state at call time, enabling reconnect-aware tool execution.
* 🛡️ **Guarded Lifespan Startup During Hot Reload**: The reload handler now checks for the presence of an async context manager before awaiting sub-app lifespan startup, preventing errors when lifespan context is unavailable.
### Fixed
* 🪛 **Fixed Erroneous Validation Raise in disabledTools**: Removed a misplaced raise statement in validate_server_config that would always throw an error after validating disabled_tools, even when the configuration was correct.
* 🧯 **Graceful Handling of asyncio.CancelledError**: Added explicit CancelledError handling during server creation and lifespan startup to properly roll back routes and log errors instead of crashing silently.
## [0.0.19] - 2025-10-14
### Fixed
+1 -1
View File
@@ -1,6 +1,6 @@
[project]
name = "mcpo"
version = "0.0.19"
version = "0.0.20"
description = "A simple, secure MCP-to-OpenAPI proxy server"
authors = [
{ name = "Timothy Jaeryang Baek", email = "tim@openwebui.com" }
+209 -54
View File
@@ -1,11 +1,12 @@
import asyncio
import contextlib
import json
import logging
import os
import signal
import socket
from contextlib import AsyncExitStack, asynccontextmanager
from typing import Optional, Dict, Any
from typing import Any, Dict, List, Optional
from urllib.parse import urljoin
import uvicorn
@@ -52,6 +53,139 @@ class GracefulShutdown:
task.add_done_callback(self.tasks.discard)
class MCPConnectionManager:
    """
    Manages lifecycle of the MCP ClientSession and underlying transport so that
    we can reconnect transparently when the remote server drops the connection.

    The manager owns exactly one transport context and one ``ClientSession`` at
    a time. ``get_session`` opens them lazily, ``ensure_initialized`` performs
    the ``initialize`` call once per opened session and caches its result, and
    ``reconnect`` tears both down and builds fresh ones.

    All constructor arguments are keyword-only:

    * ``server_type`` - transport kind; passed through
      ``normalize_server_type`` and then matched against ``"stdio"``,
      ``"sse"`` and ``"streamable-http"``.
    * ``command`` / ``args`` / ``env`` - for ``stdio``, the process command,
      its arguments and extra environment variables (layered over
      ``os.environ``). For the remote transports ``args[0]`` is used as the
      server URL.
    * ``headers`` - headers handed to the ``sse`` and ``streamable-http``
      clients.
    * ``connection_timeout`` - used as ``sse_read_timeout`` for ``sse``;
      a falsy value falls back to ``900``.
    * ``auth_provider`` - passed as ``auth`` to the ``streamable-http`` client.
    """

    def __init__(
        self,
        *,
        server_type: str,
        command: Optional[str],
        args: List[str],
        env: Dict[str, str],
        headers: Optional[Dict[str, str]],
        connection_timeout: Optional[int],
        auth_provider: Optional[Any] = None,
    ):
        self.server_type = normalize_server_type(server_type)
        self.command = command
        self.args = args
        self.env = env
        self.headers = headers
        self.connection_timeout = connection_timeout
        self.auth_provider = auth_provider

        # Currently open session and the transport context it runs on.
        self._session: "Optional[ClientSession]" = None
        self._client_context = None
        # Guards opening/closing of the session and transport.
        self._lock = asyncio.Lock()
        # Guards the one-time ``initialize`` call for the current session.
        self._initialize_lock = asyncio.Lock()
        self._initialize_result = None
        self._initialized = False

    @property
    def current_session(self) -> "Optional[ClientSession]":
        """Return the open session, or ``None`` if none is open (never connects)."""
        return self._session

    async def get_session(self) -> "ClientSession":
        """Return the open session, opening transport and session on first use."""
        async with self._lock:
            if self._session is None:
                await self._open_session_locked()
            return self._session

    async def ensure_initialized(self):
        """
        Return ``(session, initialize_result)`` for the current session.

        ``session.initialize()`` is awaited only if no result is cached for the
        session; otherwise the cached result is returned.
        """
        session = await self.get_session()
        if self._initialized and self._initialize_result is not None:
            return session, self._initialize_result

        async with self._initialize_lock:
            # Re-check under the lock: another caller may have finished the
            # initialize call while this one was waiting.
            if not self._initialized or self._initialize_result is None:
                initialize_result = await session.initialize()
                self._initialize_result = initialize_result
                self._initialized = True
            else:
                initialize_result = self._initialize_result
        return session, initialize_result

    async def reconnect(self):
        """
        Close the current session/transport, open new ones and initialize them.

        Returns the same ``(session, initialize_result)`` pair as
        ``ensure_initialized``.
        """
        async with self._lock:
            await self._close_session_locked()
            await self._open_session_locked()
        # Run initialize outside of the lock to avoid deadlocks on nested calls
        # (``ensure_initialized`` acquires ``self._lock`` via ``get_session``).
        return await self.ensure_initialized()

    async def close(self):
        """Close the session and transport, if any are open."""
        async with self._lock:
            await self._close_session_locked()

    async def _open_session_locked(self):
        """
        Enter the transport context and a ``ClientSession`` on top of it.

        Must be called with ``self._lock`` held. On any failure - including
        ``asyncio.CancelledError`` - everything entered so far is closed
        (best effort) and the original error is re-raised, leaving the
        manager's state untouched.
        """
        client_context = self._create_client_context()
        try:
            connection = await client_context.__aenter__()
        except BaseException:
            # BaseException rather than Exception: cancellation must not leave
            # a half-started transport behind.
            with contextlib.suppress(Exception):
                await client_context.__aexit__(None, None, None)
            raise

        session = None
        try:
            # Unpacking and construction live inside the guarded region so an
            # unexpected connection shape still releases the transport.
            reader, writer, *_ = connection
            session = ClientSession(reader, writer)
            await session.__aenter__()
        except BaseException:
            if session is not None:
                with contextlib.suppress(Exception):
                    await session.__aexit__(None, None, None)
            with contextlib.suppress(Exception):
                await client_context.__aexit__(None, None, None)
            raise

        self._client_context = client_context
        self._session = session
        # A new session has not been initialized yet.
        self._initialized = False
        self._initialize_result = None

    async def _close_session_locked(self):
        """
        Exit the session and then the transport context, ignoring errors.

        Must be called with ``self._lock`` held. State is cleared before any
        awaiting so the manager never exposes a session that is being closed.
        """
        session, client_context = self._session, self._client_context
        self._session = None
        self._client_context = None
        self._initialized = False
        self._initialize_result = None

        try:
            if session is not None:
                with contextlib.suppress(Exception):
                    await session.__aexit__(None, None, None)
        finally:
            # Still attempt to release the transport if closing the session
            # was interrupted (e.g. by cancellation).
            if client_context is not None:
                with contextlib.suppress(Exception):
                    await client_context.__aexit__(None, None, None)

    def _create_client_context(self):
        """
        Build (but do not enter) the transport context for ``self.server_type``.

        Raises:
            ValueError: if the server type is not one of ``stdio``, ``sse`` or
                ``streamable-http``.
        """
        if self.server_type == "stdio":
            server_params = StdioServerParameters(
                command=self.command,
                args=self.args,
                env={**os.environ, **self.env},
            )
            return stdio_client(server_params)

        if self.server_type == "sse":
            timeout = self.connection_timeout or 900
            return sse_client(
                url=self.args[0],
                sse_read_timeout=timeout,
                headers=self.headers,
            )

        if self.server_type == "streamable-http":
            return streamablehttp_client(
                url=self.args[0],
                headers=self.headers,
                auth=self.auth_provider,
            )

        raise ValueError(f"Unsupported server type: {self.server_type}")
def validate_server_config(server_name: str, server_cfg: Dict[str, Any]) -> None:
"""Validate individual server configuration."""
server_type = server_cfg.get("type")
@@ -73,17 +207,16 @@ def validate_server_config(server_name: str, server_cfg: Dict[str, Any]) -> None
else:
raise ValueError(f"Server '{server_name}' must have either 'command' for stdio or 'type' and 'url' for remote servers")
# Validate disabledTools
disabled_tools = server_cfg.get("disabled_tools")
# Validate disabledTools (supports camelCase & snake_case for backwards compatibility)
disabled_tools = server_cfg.get("disabledTools")
if disabled_tools is None:
disabled_tools = server_cfg.get("disabled_tools")
if disabled_tools is not None:
if not isinstance(disabled_tools, list):
raise ValueError(f"Server '{server_name}' 'disabledTools' must be a list")
for tool_name in disabled_tools:
if not isinstance(tool_name, str):
raise ValueError(f"Server '{server_name}' 'disabledTools' must contain only strings")
raise ValueError(
f"Server '{server_name}' must have either 'command' for stdio or 'type' and 'url' for remote servers"
)
def load_config(config_path: str) -> Dict[str, Any]:
@@ -177,8 +310,11 @@ def create_sub_app(
sub_app.state.api_dependency = api_dependency
sub_app.state.connection_timeout = connection_timeout
# Store list of tools to be disabled, if present
sub_app.state.disabled_tools = server_cfg.get("disabled_tools", [])
# Store list of tools to be disabled, if present (accept both key styles)
disabled_tools = server_cfg.get("disabledTools")
if disabled_tools is None:
disabled_tools = server_cfg.get("disabled_tools", [])
sub_app.state.disabled_tools = disabled_tools
# Store client header forwarding configuration
@@ -306,12 +442,20 @@ async def reload_config_handler(main_app: FastAPI, new_config_data: Dict[str, An
)
main_app.mount(f"{path_prefix}{server_name}", sub_app)
# Start the lifespan for the new sub-app
lifespan_context = sub_app.router.lifespan_context(sub_app)
await lifespan_context.__aenter__()
# Store the context manager for cleanup later
main_app.state.active_lifespans[server_name] = lifespan_context
# Start the lifespan for the new sub-app when available
lifespan_context = None
lifespan_factory = getattr(sub_app.router, "lifespan_context", None)
if lifespan_factory:
lifespan_context = lifespan_factory(sub_app)
if lifespan_context and hasattr(lifespan_context, "__aenter__"):
await lifespan_context.__aenter__()
# Store the context manager for cleanup later
main_app.state.active_lifespans[server_name] = lifespan_context
else:
logger.debug(
"Skipping lifespan startup for server '%s' (no async context manager)",
server_name,
)
# Check if connection was successful
is_connected = getattr(sub_app.state, "is_connected", False)
@@ -324,6 +468,11 @@ async def reload_config_handler(main_app: FastAPI, new_config_data: Dict[str, An
f"Failed to connect to new server: '{server_name}'"
)
except asyncio.CancelledError as e:
logger.error(f"Failed to create server '{server_name}' (cancelled): {e}")
# Rollback on failure
main_app.router.routes = backup_routes
raise
except Exception as e:
logger.error(f"Failed to create server '{server_name}': {e}")
# Rollback on failure
@@ -342,11 +491,18 @@ async def reload_config_handler(main_app: FastAPI, new_config_data: Dict[str, An
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
session: ClientSession = app.state.session
if not session:
raise ValueError("Session is not initialized in the app state.")
session_manager: Optional[MCPConnectionManager] = getattr(
app.state, "session_manager", None
)
result = await session.initialize()
if session_manager:
session, result = await session_manager.ensure_initialized()
app.state.session = session
else:
session: Optional[ClientSession] = getattr(app.state, "session", None)
if not session:
raise ValueError("Session is not initialized in the app state.")
result = await session.initialize()
server_info = getattr(result, "serverInfo", None)
if server_info:
app.title = server_info.name or app.title
@@ -400,7 +556,6 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
)
tool_handler = get_tool_handler(
session,
endpoint_name,
form_model_fields,
response_model_fields,
@@ -459,6 +614,14 @@ async def lifespan(app: FastAPI):
f"Connection attempt for '{server_name}' finished, but status is not 'connected'."
)
failed_servers.append(server_name)
except asyncio.CancelledError as e:
if shutdown_handler and shutdown_handler.shutdown_event.is_set():
raise
logger.error(
f"Failed to establish connection for server: '{server_name}' - CancelledError: {e}",
exc_info=True,
)
failed_servers.append(server_name)
except Exception as e:
error_class_name = type(e).__name__
if error_class_name == "ExceptionGroup" or (
@@ -534,42 +697,28 @@ async def lifespan(app: FastAPI):
)
raise
if server_type == "stdio":
# stdio doesn't support OAuth authentication
if oauth_config:
logger.warning(f"OAuth not supported for stdio server type")
server_params = StdioServerParameters(
command=command,
args=args,
env={**os.environ, **env},
)
client_context = stdio_client(server_params)
elif server_type == "sse":
# SSE doesn't support OAuth authentication currently
if oauth_config:
logger.warning(f"OAuth not supported for SSE server type")
headers = getattr(app.state, "headers", None)
client_context = sse_client(
url=args[0],
sse_read_timeout=connection_timeout or 900,
headers=headers,
)
elif server_type == "streamable-http":
headers = getattr(app.state, "headers", None)
client_context = streamablehttp_client(
url=args[0],
headers=headers,
auth=auth_provider, # Pass OAuth provider if configured
)
else:
raise ValueError(f"Unsupported server type: {server_type}")
if oauth_config and server_type == "stdio":
logger.warning("OAuth not supported for stdio server type")
if oauth_config and server_type == "sse":
logger.warning("OAuth not supported for SSE server type")
async with client_context as (reader, writer, *_):
async with ClientSession(reader, writer) as session:
app.state.session = session
await create_dynamic_endpoints(app, api_dependency=api_dependency)
app.state.is_connected = True
yield
headers = getattr(app.state, "headers", None)
session_manager = MCPConnectionManager(
server_type=server_type,
command=command,
args=args,
env=env,
headers=headers,
connection_timeout=connection_timeout,
auth_provider=auth_provider,
)
app.state.session_manager = session_manager
session = await session_manager.get_session()
app.state.session = session
await create_dynamic_endpoints(app, api_dependency=api_dependency)
app.state.is_connected = True
yield
except Exception as e:
# Log the full exception with traceback for debugging
logger.error(
@@ -579,6 +728,12 @@ async def lifespan(app: FastAPI):
app.state.is_connected = False
# Re-raise the exception so it propagates to the main app's lifespan
raise
finally:
session_manager = getattr(app.state, "session_manager", None)
if session_manager:
await session_manager.close()
app.state.session_manager = None
app.state.session = None
async def run(
+147 -135
View File
@@ -1,10 +1,12 @@
import logging
import json
import traceback
from typing import Any, Dict, ForwardRef, List, Optional, Type, Union
import logging
from anyio import ClosedResourceError
from fastapi import HTTPException, Request
from mcp import ClientSession, types
from mcp import types
from mcp.types import (
CallToolResult,
PARSE_ERROR,
@@ -275,12 +277,50 @@ def get_model_fields(form_model_name, properties, required_fields, schema_defs=N
def get_tool_handler(
session,
endpoint_name,
form_model_fields,
response_model_fields=None,
client_header_forwarding_config=None,
):
async def call_tool_with_reconnect(
request: Request, arguments: Dict[str, Any]
) -> CallToolResult:
session_manager = getattr(request.app.state, "session_manager", None)
async def _invoke(session):
return await session.call_tool(endpoint_name, arguments=arguments)
if session_manager:
try:
session, _ = await session_manager.ensure_initialized()
except ClosedResourceError:
logger.warning(
"Session closed while initializing '%s'; attempting reconnect",
endpoint_name,
)
session, _ = await session_manager.reconnect()
request.app.state.session = session
try:
return await _invoke(session)
except ClosedResourceError:
logger.warning(
"Session closed during call to '%s'; attempting reconnect",
endpoint_name,
)
session, _ = await session_manager.reconnect()
request.app.state.session = session
return await _invoke(session)
session = getattr(request.app.state, "session", None)
if not session:
raise HTTPException(
status_code=500,
detail={"message": "MCP session is not available"},
)
return await _invoke(session)
if form_model_fields:
FormModel = create_model(f"{endpoint_name}_form_model", **form_model_fields)
ResponseModel = (
@@ -289,149 +329,121 @@ def get_tool_handler(
else Any
)
def make_endpoint_func(
endpoint_name: str, FormModel, session: ClientSession
): # Parameterized endpoint
async def tool(
form_data: FormModel, request: Request
) -> Union[ResponseModel, Any]:
args = form_data.model_dump(exclude_none=True, by_alias=True)
async def tool(
form_data: FormModel, request: Request
) -> Union[ResponseModel, Any]:
args = form_data.model_dump(exclude_none=True, by_alias=True)
# Process headers for forwarding if configured
forwarded_headers = {}
if (
client_header_forwarding_config
and client_header_forwarding_config.get("enabled", False)
):
forwarded_headers = process_headers_for_server(
request, client_header_forwarding_config
)
forwarded_headers = {}
if (
client_header_forwarding_config
and client_header_forwarding_config.get("enabled", False)
):
forwarded_headers = process_headers_for_server(
request, client_header_forwarding_config
)
# Add headers to _meta if any headers are being forwarded
meta = {}
if forwarded_headers:
meta["headers"] = forwarded_headers
meta = {}
if forwarded_headers:
meta["headers"] = forwarded_headers
logger.info(f"Calling endpoint: {endpoint_name}, with args: {args}")
try:
result = await session.call_tool(endpoint_name, arguments=args)
logger.info(f"Calling endpoint: {endpoint_name}, with args: {args}")
try:
result = await call_tool_with_reconnect(request, args)
if result.isError:
error_message = "Unknown tool execution error"
error_data = None # Initialize error_data
if result.content:
if isinstance(result.content[0], types.TextContent):
error_message = result.content[0].text
detail = {"message": error_message}
if error_data is not None:
detail["data"] = error_data
raise HTTPException(
status_code=500,
detail=detail,
)
if result.isError:
error_message = "Unknown tool execution error"
error_data = None
if result.content and isinstance(
result.content[0], types.TextContent
):
error_message = result.content[0].text
detail = {"message": error_message}
if error_data is not None:
detail["data"] = error_data
raise HTTPException(status_code=500, detail=detail)
response_data = process_tool_response(result)
final_response = (
response_data[0] if len(response_data) == 1 else response_data
)
return final_response
response_data = process_tool_response(result)
final_response = (
response_data[0] if len(response_data) == 1 else response_data
)
return final_response
except McpError as e:
logger.info(
f"MCP Error calling {endpoint_name}: {traceback.format_exc()}"
)
status_code = MCP_ERROR_TO_HTTP_STATUS.get(e.error.code, 500)
raise HTTPException(
status_code=status_code,
detail=(
{"message": e.error.message, "data": e.error.data}
if e.error.data is not None
else {"message": e.error.message}
),
)
except Exception as e:
logger.info(
f"Unexpected error calling {endpoint_name}: {traceback.format_exc()}"
)
raise HTTPException(
status_code=500,
detail={"message": "Unexpected error", "error": str(e)},
)
except McpError as e:
logger.info(
f"MCP Error calling {endpoint_name}: {traceback.format_exc()}"
)
status_code = MCP_ERROR_TO_HTTP_STATUS.get(e.error.code, 500)
raise HTTPException(
status_code=status_code,
detail=(
{"message": e.error.message, "data": e.error.data}
if e.error.data is not None
else {"message": e.error.message}
),
)
except Exception as e:
logger.info(
f"Unexpected error calling {endpoint_name}: {traceback.format_exc()}"
)
raise HTTPException(
status_code=500,
detail={"message": "Unexpected error", "error": str(e)},
)
return tool
return tool
tool_handler = make_endpoint_func(endpoint_name, FormModel, session)
else:
async def tool(request: Request):
forwarded_headers = {}
if (
client_header_forwarding_config
and client_header_forwarding_config.get("enabled", False)
):
forwarded_headers = process_headers_for_server(
request, client_header_forwarding_config
)
def make_endpoint_func_no_args(
endpoint_name: str, session: ClientSession
): # Parameterless endpoint
async def tool(
request: Request,
): # No parameters but need request for headers
# Process headers for forwarding if configured
forwarded_headers = {}
if (
client_header_forwarding_config
and client_header_forwarding_config.get("enabled", False)
):
forwarded_headers = process_headers_for_server(
request, client_header_forwarding_config
)
meta = {}
if forwarded_headers:
meta["headers"] = forwarded_headers
# Add headers to _meta if any headers are being forwarded
meta = {}
if forwarded_headers:
meta["headers"] = forwarded_headers
logger.info(f"Calling endpoint: {endpoint_name}, with no args")
try:
result = await call_tool_with_reconnect(request, {})
logger.info(f"Calling endpoint: {endpoint_name}, with no args")
try:
result = await session.call_tool(
endpoint_name, arguments={}
) # Empty dict
if result.isError:
error_message = "Unknown tool execution error"
if result.content and isinstance(result.content[0], types.TextContent):
error_message = result.content[0].text
detail = {"message": error_message}
raise HTTPException(status_code=500, detail=detail)
if result.isError:
error_message = "Unknown tool execution error"
if result.content:
if isinstance(result.content[0], types.TextContent):
error_message = result.content[0].text
detail = {"message": error_message}
raise HTTPException(
status_code=500,
detail=detail,
)
response_data = process_tool_response(result)
final_response = (
response_data[0] if len(response_data) == 1 else response_data
)
return final_response
response_data = process_tool_response(result)
final_response = (
response_data[0] if len(response_data) == 1 else response_data
)
return final_response
except McpError as e:
logger.info(
f"MCP Error calling {endpoint_name}: {traceback.format_exc()}"
)
status_code = MCP_ERROR_TO_HTTP_STATUS.get(e.error.code, 500)
raise HTTPException(
status_code=status_code,
detail=(
{"message": e.error.message, "data": e.error.data}
if e.error.data is not None
else {"message": e.error.message}
),
)
except Exception as e:
logger.info(
f"Unexpected error calling {endpoint_name}: {traceback.format_exc()}"
)
raise HTTPException(
status_code=500,
detail={"message": "Unexpected error", "error": str(e)},
)
except McpError as e:
logger.info(
f"MCP Error calling {endpoint_name}: {traceback.format_exc()}"
)
status_code = MCP_ERROR_TO_HTTP_STATUS.get(e.error.code, 500)
# Propagate the error received from MCP as an HTTP exception
raise HTTPException(
status_code=status_code,
detail=(
{"message": e.error.message, "data": e.error.data}
if e.error.data is not None
else {"message": e.error.message}
),
)
except Exception as e:
logger.info(
f"Unexpected error calling {endpoint_name}: {traceback.format_exc()}"
)
raise HTTPException(
status_code=500,
detail={"message": "Unexpected error", "error": str(e)},
)
return tool
tool_handler = make_endpoint_func_no_args(endpoint_name, session)
return tool_handler
return tool