mirror of
https://github.com/open-webui/mcpo.git
synced 2026-07-01 21:04:00 -04:00
@@ -11,3 +11,5 @@ wheels/
|
||||
config.json
|
||||
.python-version
|
||||
.vscode
|
||||
|
||||
.DS_Store
|
||||
|
||||
@@ -5,6 +5,24 @@ All notable changes to this project will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [0.0.20] - 2026-02-27
|
||||
|
||||
### Added
|
||||
|
||||
* 🔁 **MCP Connection Manager with Auto-Reconnect**: Introduced MCPConnectionManager to manage the full lifecycle of MCP client sessions and transports. Connections that drop due to a ClosedResourceError are now automatically re-established and retried transparently, eliminating manual restarts for intermittent server disconnects.
|
||||
* 🔄 **Automatic Retry on Closed Sessions**: Tool calls that encounter a ClosedResourceError will now automatically reconnect to the MCP server and retry the call, providing resilient end-to-end execution without user intervention.
|
||||
|
||||
### Changed
|
||||
|
||||
* 🧩 **Normalized disabledTools Configuration Key**: The disabled tools setting now accepts both camelCase (disabledTools) and snake_case (disabled_tools) in the config file for backwards compatibility and consistency with other configuration keys.
|
||||
* 🧹 **Simplified Tool Handler Architecture**: Refactored get_tool_handler to remove the session parameter and nested factory functions in favor of a cleaner design that resolves the session from the request app state at call time, enabling reconnect-aware tool execution.
|
||||
* 🛡️ **Guarded Lifespan Startup During Hot Reload**: The reload handler now checks for the presence of an async context manager before awaiting sub-app lifespan startup, preventing errors when lifespan context is unavailable.
|
||||
|
||||
### Fixed
|
||||
|
||||
* 🪛 **Fixed Erroneous Validation Raise in disabledTools**: Removed a misplaced raise statement in validate_server_config that would always throw an error after validating disabled_tools, even when the configuration was correct.
|
||||
* 🧯 **Graceful Handling of asyncio.CancelledError**: Added explicit CancelledError handling during server creation and lifespan startup to properly roll back routes and log errors instead of crashing silently.
|
||||
|
||||
## [0.0.19] - 2025-10-14
|
||||
|
||||
### Fixed
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "mcpo"
|
||||
version = "0.0.19"
|
||||
version = "0.0.20"
|
||||
description = "A simple, secure MCP-to-OpenAPI proxy server"
|
||||
authors = [
|
||||
{ name = "Timothy Jaeryang Baek", email = "tim@openwebui.com" }
|
||||
|
||||
+209
-54
@@ -1,11 +1,12 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import uvicorn
|
||||
@@ -52,6 +53,139 @@ class GracefulShutdown:
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
|
||||
|
||||
class MCPConnectionManager:
|
||||
"""
|
||||
Manages lifecycle of the MCP ClientSession and underlying transport so that
|
||||
we can reconnect transparently when the remote server drops the connection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
server_type: str,
|
||||
command: Optional[str],
|
||||
args: List[str],
|
||||
env: Dict[str, str],
|
||||
headers: Optional[Dict[str, str]],
|
||||
connection_timeout: Optional[int],
|
||||
auth_provider: Optional[Any] = None,
|
||||
):
|
||||
self.server_type = normalize_server_type(server_type)
|
||||
self.command = command
|
||||
self.args = args
|
||||
self.env = env
|
||||
self.headers = headers
|
||||
self.connection_timeout = connection_timeout
|
||||
self.auth_provider = auth_provider
|
||||
|
||||
self._session: Optional[ClientSession] = None
|
||||
self._client_context = None
|
||||
self._lock = asyncio.Lock()
|
||||
self._initialize_lock = asyncio.Lock()
|
||||
self._initialize_result = None
|
||||
self._initialized = False
|
||||
|
||||
@property
|
||||
def current_session(self) -> Optional[ClientSession]:
|
||||
return self._session
|
||||
|
||||
async def get_session(self) -> ClientSession:
|
||||
async with self._lock:
|
||||
if self._session is None:
|
||||
await self._open_session_locked()
|
||||
return self._session
|
||||
|
||||
async def ensure_initialized(self):
|
||||
session = await self.get_session()
|
||||
if self._initialized and self._initialize_result is not None:
|
||||
return session, self._initialize_result
|
||||
|
||||
async with self._initialize_lock:
|
||||
if not self._initialized or self._initialize_result is None:
|
||||
initialize_result = await session.initialize()
|
||||
self._initialize_result = initialize_result
|
||||
self._initialized = True
|
||||
else:
|
||||
initialize_result = self._initialize_result
|
||||
|
||||
return session, initialize_result
|
||||
|
||||
async def reconnect(self):
|
||||
async with self._lock:
|
||||
await self._close_session_locked()
|
||||
await self._open_session_locked()
|
||||
# Run initialize outside of the lock to avoid deadlocks on nested calls
|
||||
return await self.ensure_initialized()
|
||||
|
||||
async def close(self):
|
||||
async with self._lock:
|
||||
await self._close_session_locked()
|
||||
|
||||
async def _open_session_locked(self):
|
||||
client_context = self._create_client_context()
|
||||
try:
|
||||
connection = await client_context.__aenter__()
|
||||
except Exception:
|
||||
# Ensure the context is closed if entering fails
|
||||
with contextlib.suppress(Exception):
|
||||
await client_context.__aexit__(None, None, None)
|
||||
raise
|
||||
|
||||
reader, writer, *_ = connection
|
||||
session = ClientSession(reader, writer)
|
||||
try:
|
||||
await session.__aenter__()
|
||||
except Exception:
|
||||
with contextlib.suppress(Exception):
|
||||
await session.__aexit__(None, None, None)
|
||||
with contextlib.suppress(Exception):
|
||||
await client_context.__aexit__(None, None, None)
|
||||
raise
|
||||
|
||||
self._client_context = client_context
|
||||
self._session = session
|
||||
self._initialized = False
|
||||
self._initialize_result = None
|
||||
|
||||
async def _close_session_locked(self):
|
||||
session, client_context = self._session, self._client_context
|
||||
self._session = None
|
||||
self._client_context = None
|
||||
self._initialized = False
|
||||
self._initialize_result = None
|
||||
|
||||
if session is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
await session.__aexit__(None, None, None)
|
||||
|
||||
if client_context is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
await client_context.__aexit__(None, None, None)
|
||||
|
||||
def _create_client_context(self):
|
||||
if self.server_type == "stdio":
|
||||
server_params = StdioServerParameters(
|
||||
command=self.command,
|
||||
args=self.args,
|
||||
env={**os.environ, **self.env},
|
||||
)
|
||||
return stdio_client(server_params)
|
||||
if self.server_type == "sse":
|
||||
timeout = self.connection_timeout or 900
|
||||
return sse_client(
|
||||
url=self.args[0],
|
||||
sse_read_timeout=timeout,
|
||||
headers=self.headers,
|
||||
)
|
||||
if self.server_type == "streamable-http":
|
||||
return streamablehttp_client(
|
||||
url=self.args[0],
|
||||
headers=self.headers,
|
||||
auth=self.auth_provider,
|
||||
)
|
||||
raise ValueError(f"Unsupported server type: {self.server_type}")
|
||||
|
||||
|
||||
def validate_server_config(server_name: str, server_cfg: Dict[str, Any]) -> None:
|
||||
"""Validate individual server configuration."""
|
||||
server_type = server_cfg.get("type")
|
||||
@@ -73,17 +207,16 @@ def validate_server_config(server_name: str, server_cfg: Dict[str, Any]) -> None
|
||||
else:
|
||||
raise ValueError(f"Server '{server_name}' must have either 'command' for stdio or 'type' and 'url' for remote servers")
|
||||
|
||||
# Validate disabledTools
|
||||
disabled_tools = server_cfg.get("disabled_tools")
|
||||
# Validate disabledTools (supports camelCase & snake_case for backwards compatibility)
|
||||
disabled_tools = server_cfg.get("disabledTools")
|
||||
if disabled_tools is None:
|
||||
disabled_tools = server_cfg.get("disabled_tools")
|
||||
if disabled_tools is not None:
|
||||
if not isinstance(disabled_tools, list):
|
||||
raise ValueError(f"Server '{server_name}' 'disabledTools' must be a list")
|
||||
for tool_name in disabled_tools:
|
||||
if not isinstance(tool_name, str):
|
||||
raise ValueError(f"Server '{server_name}' 'disabledTools' must contain only strings")
|
||||
raise ValueError(
|
||||
f"Server '{server_name}' must have either 'command' for stdio or 'type' and 'url' for remote servers"
|
||||
)
|
||||
|
||||
|
||||
def load_config(config_path: str) -> Dict[str, Any]:
|
||||
@@ -177,8 +310,11 @@ def create_sub_app(
|
||||
sub_app.state.api_dependency = api_dependency
|
||||
sub_app.state.connection_timeout = connection_timeout
|
||||
|
||||
# Store list of tools to be disabled, if present
|
||||
sub_app.state.disabled_tools = server_cfg.get("disabled_tools", [])
|
||||
# Store list of tools to be disabled, if present (accept both key styles)
|
||||
disabled_tools = server_cfg.get("disabledTools")
|
||||
if disabled_tools is None:
|
||||
disabled_tools = server_cfg.get("disabled_tools", [])
|
||||
sub_app.state.disabled_tools = disabled_tools
|
||||
|
||||
|
||||
# Store client header forwarding configuration
|
||||
@@ -306,12 +442,20 @@ async def reload_config_handler(main_app: FastAPI, new_config_data: Dict[str, An
|
||||
)
|
||||
main_app.mount(f"{path_prefix}{server_name}", sub_app)
|
||||
|
||||
# Start the lifespan for the new sub-app
|
||||
lifespan_context = sub_app.router.lifespan_context(sub_app)
|
||||
await lifespan_context.__aenter__()
|
||||
|
||||
# Store the context manager for cleanup later
|
||||
main_app.state.active_lifespans[server_name] = lifespan_context
|
||||
# Start the lifespan for the new sub-app when available
|
||||
lifespan_context = None
|
||||
lifespan_factory = getattr(sub_app.router, "lifespan_context", None)
|
||||
if lifespan_factory:
|
||||
lifespan_context = lifespan_factory(sub_app)
|
||||
if lifespan_context and hasattr(lifespan_context, "__aenter__"):
|
||||
await lifespan_context.__aenter__()
|
||||
# Store the context manager for cleanup later
|
||||
main_app.state.active_lifespans[server_name] = lifespan_context
|
||||
else:
|
||||
logger.debug(
|
||||
"Skipping lifespan startup for server '%s' (no async context manager)",
|
||||
server_name,
|
||||
)
|
||||
|
||||
# Check if connection was successful
|
||||
is_connected = getattr(sub_app.state, "is_connected", False)
|
||||
@@ -324,6 +468,11 @@ async def reload_config_handler(main_app: FastAPI, new_config_data: Dict[str, An
|
||||
f"Failed to connect to new server: '{server_name}'"
|
||||
)
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
logger.error(f"Failed to create server '{server_name}' (cancelled): {e}")
|
||||
# Rollback on failure
|
||||
main_app.router.routes = backup_routes
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create server '{server_name}': {e}")
|
||||
# Rollback on failure
|
||||
@@ -342,11 +491,18 @@ async def reload_config_handler(main_app: FastAPI, new_config_data: Dict[str, An
|
||||
|
||||
|
||||
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
||||
session: ClientSession = app.state.session
|
||||
if not session:
|
||||
raise ValueError("Session is not initialized in the app state.")
|
||||
session_manager: Optional[MCPConnectionManager] = getattr(
|
||||
app.state, "session_manager", None
|
||||
)
|
||||
|
||||
result = await session.initialize()
|
||||
if session_manager:
|
||||
session, result = await session_manager.ensure_initialized()
|
||||
app.state.session = session
|
||||
else:
|
||||
session: Optional[ClientSession] = getattr(app.state, "session", None)
|
||||
if not session:
|
||||
raise ValueError("Session is not initialized in the app state.")
|
||||
result = await session.initialize()
|
||||
server_info = getattr(result, "serverInfo", None)
|
||||
if server_info:
|
||||
app.title = server_info.name or app.title
|
||||
@@ -400,7 +556,6 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
||||
)
|
||||
|
||||
tool_handler = get_tool_handler(
|
||||
session,
|
||||
endpoint_name,
|
||||
form_model_fields,
|
||||
response_model_fields,
|
||||
@@ -459,6 +614,14 @@ async def lifespan(app: FastAPI):
|
||||
f"Connection attempt for '{server_name}' finished, but status is not 'connected'."
|
||||
)
|
||||
failed_servers.append(server_name)
|
||||
except asyncio.CancelledError as e:
|
||||
if shutdown_handler and shutdown_handler.shutdown_event.is_set():
|
||||
raise
|
||||
logger.error(
|
||||
f"Failed to establish connection for server: '{server_name}' - CancelledError: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
failed_servers.append(server_name)
|
||||
except Exception as e:
|
||||
error_class_name = type(e).__name__
|
||||
if error_class_name == "ExceptionGroup" or (
|
||||
@@ -534,42 +697,28 @@ async def lifespan(app: FastAPI):
|
||||
)
|
||||
raise
|
||||
|
||||
if server_type == "stdio":
|
||||
# stdio doesn't support OAuth authentication
|
||||
if oauth_config:
|
||||
logger.warning(f"OAuth not supported for stdio server type")
|
||||
server_params = StdioServerParameters(
|
||||
command=command,
|
||||
args=args,
|
||||
env={**os.environ, **env},
|
||||
)
|
||||
client_context = stdio_client(server_params)
|
||||
elif server_type == "sse":
|
||||
# SSE doesn't support OAuth authentication currently
|
||||
if oauth_config:
|
||||
logger.warning(f"OAuth not supported for SSE server type")
|
||||
headers = getattr(app.state, "headers", None)
|
||||
client_context = sse_client(
|
||||
url=args[0],
|
||||
sse_read_timeout=connection_timeout or 900,
|
||||
headers=headers,
|
||||
)
|
||||
elif server_type == "streamable-http":
|
||||
headers = getattr(app.state, "headers", None)
|
||||
client_context = streamablehttp_client(
|
||||
url=args[0],
|
||||
headers=headers,
|
||||
auth=auth_provider, # Pass OAuth provider if configured
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported server type: {server_type}")
|
||||
if oauth_config and server_type == "stdio":
|
||||
logger.warning("OAuth not supported for stdio server type")
|
||||
if oauth_config and server_type == "sse":
|
||||
logger.warning("OAuth not supported for SSE server type")
|
||||
|
||||
async with client_context as (reader, writer, *_):
|
||||
async with ClientSession(reader, writer) as session:
|
||||
app.state.session = session
|
||||
await create_dynamic_endpoints(app, api_dependency=api_dependency)
|
||||
app.state.is_connected = True
|
||||
yield
|
||||
headers = getattr(app.state, "headers", None)
|
||||
session_manager = MCPConnectionManager(
|
||||
server_type=server_type,
|
||||
command=command,
|
||||
args=args,
|
||||
env=env,
|
||||
headers=headers,
|
||||
connection_timeout=connection_timeout,
|
||||
auth_provider=auth_provider,
|
||||
)
|
||||
app.state.session_manager = session_manager
|
||||
|
||||
session = await session_manager.get_session()
|
||||
app.state.session = session
|
||||
await create_dynamic_endpoints(app, api_dependency=api_dependency)
|
||||
app.state.is_connected = True
|
||||
yield
|
||||
except Exception as e:
|
||||
# Log the full exception with traceback for debugging
|
||||
logger.error(
|
||||
@@ -579,6 +728,12 @@ async def lifespan(app: FastAPI):
|
||||
app.state.is_connected = False
|
||||
# Re-raise the exception so it propagates to the main app's lifespan
|
||||
raise
|
||||
finally:
|
||||
session_manager = getattr(app.state, "session_manager", None)
|
||||
if session_manager:
|
||||
await session_manager.close()
|
||||
app.state.session_manager = None
|
||||
app.state.session = None
|
||||
|
||||
|
||||
async def run(
|
||||
|
||||
+147
-135
@@ -1,10 +1,12 @@
|
||||
import logging
|
||||
import json
|
||||
import traceback
|
||||
from typing import Any, Dict, ForwardRef, List, Optional, Type, Union
|
||||
import logging
|
||||
|
||||
from anyio import ClosedResourceError
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from mcp import ClientSession, types
|
||||
from mcp import types
|
||||
from mcp.types import (
|
||||
CallToolResult,
|
||||
PARSE_ERROR,
|
||||
@@ -275,12 +277,50 @@ def get_model_fields(form_model_name, properties, required_fields, schema_defs=N
|
||||
|
||||
|
||||
def get_tool_handler(
|
||||
session,
|
||||
endpoint_name,
|
||||
form_model_fields,
|
||||
response_model_fields=None,
|
||||
client_header_forwarding_config=None,
|
||||
):
|
||||
async def call_tool_with_reconnect(
|
||||
request: Request, arguments: Dict[str, Any]
|
||||
) -> CallToolResult:
|
||||
session_manager = getattr(request.app.state, "session_manager", None)
|
||||
|
||||
async def _invoke(session):
|
||||
return await session.call_tool(endpoint_name, arguments=arguments)
|
||||
|
||||
if session_manager:
|
||||
try:
|
||||
session, _ = await session_manager.ensure_initialized()
|
||||
except ClosedResourceError:
|
||||
logger.warning(
|
||||
"Session closed while initializing '%s'; attempting reconnect",
|
||||
endpoint_name,
|
||||
)
|
||||
session, _ = await session_manager.reconnect()
|
||||
request.app.state.session = session
|
||||
|
||||
try:
|
||||
return await _invoke(session)
|
||||
except ClosedResourceError:
|
||||
logger.warning(
|
||||
"Session closed during call to '%s'; attempting reconnect",
|
||||
endpoint_name,
|
||||
)
|
||||
session, _ = await session_manager.reconnect()
|
||||
request.app.state.session = session
|
||||
return await _invoke(session)
|
||||
|
||||
session = getattr(request.app.state, "session", None)
|
||||
if not session:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"message": "MCP session is not available"},
|
||||
)
|
||||
|
||||
return await _invoke(session)
|
||||
|
||||
if form_model_fields:
|
||||
FormModel = create_model(f"{endpoint_name}_form_model", **form_model_fields)
|
||||
ResponseModel = (
|
||||
@@ -289,149 +329,121 @@ def get_tool_handler(
|
||||
else Any
|
||||
)
|
||||
|
||||
def make_endpoint_func(
|
||||
endpoint_name: str, FormModel, session: ClientSession
|
||||
): # Parameterized endpoint
|
||||
async def tool(
|
||||
form_data: FormModel, request: Request
|
||||
) -> Union[ResponseModel, Any]:
|
||||
args = form_data.model_dump(exclude_none=True, by_alias=True)
|
||||
async def tool(
|
||||
form_data: FormModel, request: Request
|
||||
) -> Union[ResponseModel, Any]:
|
||||
args = form_data.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
# Process headers for forwarding if configured
|
||||
forwarded_headers = {}
|
||||
if (
|
||||
client_header_forwarding_config
|
||||
and client_header_forwarding_config.get("enabled", False)
|
||||
):
|
||||
forwarded_headers = process_headers_for_server(
|
||||
request, client_header_forwarding_config
|
||||
)
|
||||
forwarded_headers = {}
|
||||
if (
|
||||
client_header_forwarding_config
|
||||
and client_header_forwarding_config.get("enabled", False)
|
||||
):
|
||||
forwarded_headers = process_headers_for_server(
|
||||
request, client_header_forwarding_config
|
||||
)
|
||||
|
||||
# Add headers to _meta if any headers are being forwarded
|
||||
meta = {}
|
||||
if forwarded_headers:
|
||||
meta["headers"] = forwarded_headers
|
||||
meta = {}
|
||||
if forwarded_headers:
|
||||
meta["headers"] = forwarded_headers
|
||||
|
||||
logger.info(f"Calling endpoint: {endpoint_name}, with args: {args}")
|
||||
try:
|
||||
result = await session.call_tool(endpoint_name, arguments=args)
|
||||
logger.info(f"Calling endpoint: {endpoint_name}, with args: {args}")
|
||||
try:
|
||||
result = await call_tool_with_reconnect(request, args)
|
||||
|
||||
if result.isError:
|
||||
error_message = "Unknown tool execution error"
|
||||
error_data = None # Initialize error_data
|
||||
if result.content:
|
||||
if isinstance(result.content[0], types.TextContent):
|
||||
error_message = result.content[0].text
|
||||
detail = {"message": error_message}
|
||||
if error_data is not None:
|
||||
detail["data"] = error_data
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=detail,
|
||||
)
|
||||
if result.isError:
|
||||
error_message = "Unknown tool execution error"
|
||||
error_data = None
|
||||
if result.content and isinstance(
|
||||
result.content[0], types.TextContent
|
||||
):
|
||||
error_message = result.content[0].text
|
||||
detail = {"message": error_message}
|
||||
if error_data is not None:
|
||||
detail["data"] = error_data
|
||||
raise HTTPException(status_code=500, detail=detail)
|
||||
|
||||
response_data = process_tool_response(result)
|
||||
final_response = (
|
||||
response_data[0] if len(response_data) == 1 else response_data
|
||||
)
|
||||
return final_response
|
||||
response_data = process_tool_response(result)
|
||||
final_response = (
|
||||
response_data[0] if len(response_data) == 1 else response_data
|
||||
)
|
||||
return final_response
|
||||
|
||||
except McpError as e:
|
||||
logger.info(
|
||||
f"MCP Error calling {endpoint_name}: {traceback.format_exc()}"
|
||||
)
|
||||
status_code = MCP_ERROR_TO_HTTP_STATUS.get(e.error.code, 500)
|
||||
raise HTTPException(
|
||||
status_code=status_code,
|
||||
detail=(
|
||||
{"message": e.error.message, "data": e.error.data}
|
||||
if e.error.data is not None
|
||||
else {"message": e.error.message}
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
f"Unexpected error calling {endpoint_name}: {traceback.format_exc()}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"message": "Unexpected error", "error": str(e)},
|
||||
)
|
||||
except McpError as e:
|
||||
logger.info(
|
||||
f"MCP Error calling {endpoint_name}: {traceback.format_exc()}"
|
||||
)
|
||||
status_code = MCP_ERROR_TO_HTTP_STATUS.get(e.error.code, 500)
|
||||
raise HTTPException(
|
||||
status_code=status_code,
|
||||
detail=(
|
||||
{"message": e.error.message, "data": e.error.data}
|
||||
if e.error.data is not None
|
||||
else {"message": e.error.message}
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
f"Unexpected error calling {endpoint_name}: {traceback.format_exc()}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"message": "Unexpected error", "error": str(e)},
|
||||
)
|
||||
|
||||
return tool
|
||||
return tool
|
||||
|
||||
tool_handler = make_endpoint_func(endpoint_name, FormModel, session)
|
||||
else:
|
||||
async def tool(request: Request):
|
||||
forwarded_headers = {}
|
||||
if (
|
||||
client_header_forwarding_config
|
||||
and client_header_forwarding_config.get("enabled", False)
|
||||
):
|
||||
forwarded_headers = process_headers_for_server(
|
||||
request, client_header_forwarding_config
|
||||
)
|
||||
|
||||
def make_endpoint_func_no_args(
|
||||
endpoint_name: str, session: ClientSession
|
||||
): # Parameterless endpoint
|
||||
async def tool(
|
||||
request: Request,
|
||||
): # No parameters but need request for headers
|
||||
# Process headers for forwarding if configured
|
||||
forwarded_headers = {}
|
||||
if (
|
||||
client_header_forwarding_config
|
||||
and client_header_forwarding_config.get("enabled", False)
|
||||
):
|
||||
forwarded_headers = process_headers_for_server(
|
||||
request, client_header_forwarding_config
|
||||
)
|
||||
meta = {}
|
||||
if forwarded_headers:
|
||||
meta["headers"] = forwarded_headers
|
||||
|
||||
# Add headers to _meta if any headers are being forwarded
|
||||
meta = {}
|
||||
if forwarded_headers:
|
||||
meta["headers"] = forwarded_headers
|
||||
logger.info(f"Calling endpoint: {endpoint_name}, with no args")
|
||||
try:
|
||||
result = await call_tool_with_reconnect(request, {})
|
||||
|
||||
logger.info(f"Calling endpoint: {endpoint_name}, with no args")
|
||||
try:
|
||||
result = await session.call_tool(
|
||||
endpoint_name, arguments={}
|
||||
) # Empty dict
|
||||
if result.isError:
|
||||
error_message = "Unknown tool execution error"
|
||||
if result.content and isinstance(result.content[0], types.TextContent):
|
||||
error_message = result.content[0].text
|
||||
detail = {"message": error_message}
|
||||
raise HTTPException(status_code=500, detail=detail)
|
||||
|
||||
if result.isError:
|
||||
error_message = "Unknown tool execution error"
|
||||
if result.content:
|
||||
if isinstance(result.content[0], types.TextContent):
|
||||
error_message = result.content[0].text
|
||||
detail = {"message": error_message}
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=detail,
|
||||
)
|
||||
response_data = process_tool_response(result)
|
||||
final_response = (
|
||||
response_data[0] if len(response_data) == 1 else response_data
|
||||
)
|
||||
return final_response
|
||||
|
||||
response_data = process_tool_response(result)
|
||||
final_response = (
|
||||
response_data[0] if len(response_data) == 1 else response_data
|
||||
)
|
||||
return final_response
|
||||
except McpError as e:
|
||||
logger.info(
|
||||
f"MCP Error calling {endpoint_name}: {traceback.format_exc()}"
|
||||
)
|
||||
status_code = MCP_ERROR_TO_HTTP_STATUS.get(e.error.code, 500)
|
||||
raise HTTPException(
|
||||
status_code=status_code,
|
||||
detail=(
|
||||
{"message": e.error.message, "data": e.error.data}
|
||||
if e.error.data is not None
|
||||
else {"message": e.error.message}
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
f"Unexpected error calling {endpoint_name}: {traceback.format_exc()}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"message": "Unexpected error", "error": str(e)},
|
||||
)
|
||||
|
||||
except McpError as e:
|
||||
logger.info(
|
||||
f"MCP Error calling {endpoint_name}: {traceback.format_exc()}"
|
||||
)
|
||||
status_code = MCP_ERROR_TO_HTTP_STATUS.get(e.error.code, 500)
|
||||
# Propagate the error received from MCP as an HTTP exception
|
||||
raise HTTPException(
|
||||
status_code=status_code,
|
||||
detail=(
|
||||
{"message": e.error.message, "data": e.error.data}
|
||||
if e.error.data is not None
|
||||
else {"message": e.error.message}
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
f"Unexpected error calling {endpoint_name}: {traceback.format_exc()}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"message": "Unexpected error", "error": str(e)},
|
||||
)
|
||||
|
||||
return tool
|
||||
|
||||
tool_handler = make_endpoint_func_no_args(endpoint_name, session)
|
||||
|
||||
return tool_handler
|
||||
return tool
|
||||
|
||||
Reference in New Issue
Block a user