mirror of
https://github.com/onyx-dot-app/litellm.git
synced 2026-07-01 20:44:04 -04:00
fix(mcp/): add ssl certificate settings for mcp clients
respect ca bundle path for mcp calls
This commit is contained in:
@@ -5,8 +5,9 @@ LiteLLM Proxy uses this MCP Client to connnect to other MCP servers.
|
||||
import asyncio
|
||||
import base64
|
||||
from datetime import timedelta
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
@@ -17,6 +18,8 @@ from mcp.types import TextContent
|
||||
from mcp.types import Tool as MCPTool
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.custom_httpx.http_handler import get_ssl_configuration
|
||||
from litellm.types.llms.custom_http import VerifyTypes
|
||||
from litellm.types.mcp import (
|
||||
MCPAuth,
|
||||
MCPAuthType,
|
||||
@@ -48,6 +51,7 @@ class MCPClient:
|
||||
timeout: float = 60.0,
|
||||
stdio_config: Optional[MCPStdioConfig] = None,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
ssl_verify: Optional[VerifyTypes] = None,
|
||||
):
|
||||
self.server_url: str = server_url
|
||||
self.transport_type: MCPTransport = transport_type
|
||||
@@ -62,6 +66,7 @@ class MCPClient:
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self.stdio_config: Optional[MCPStdioConfig] = stdio_config
|
||||
self.extra_headers: Optional[Dict[str, str]] = extra_headers
|
||||
self.ssl_verify: Optional[VerifyTypes] = ssl_verify
|
||||
# handle the basic auth value if provided
|
||||
if auth_value:
|
||||
self.update_auth_value(auth_value)
|
||||
@@ -104,10 +109,12 @@ class MCPClient:
|
||||
await self._session.initialize()
|
||||
elif self.transport_type == MCPTransport.sse:
|
||||
headers = self._get_auth_headers()
|
||||
httpx_client_factory = self._create_httpx_client_factory()
|
||||
self._transport_ctx = sse_client(
|
||||
url=self.server_url,
|
||||
timeout=self.timeout,
|
||||
headers=headers,
|
||||
httpx_client_factory=httpx_client_factory,
|
||||
)
|
||||
self._transport = await self._transport_ctx.__aenter__()
|
||||
self._session_ctx = ClientSession(
|
||||
@@ -117,6 +124,7 @@ class MCPClient:
|
||||
await self._session.initialize()
|
||||
else: # http
|
||||
headers = self._get_auth_headers()
|
||||
httpx_client_factory = self._create_httpx_client_factory()
|
||||
verbose_logger.debug(
|
||||
"litellm headers for streamablehttp_client: ", headers
|
||||
)
|
||||
@@ -124,6 +132,7 @@ class MCPClient:
|
||||
url=self.server_url,
|
||||
timeout=timedelta(seconds=self.timeout),
|
||||
headers=headers,
|
||||
httpx_client_factory=httpx_client_factory,
|
||||
)
|
||||
self._transport = await self._transport_ctx.__aenter__()
|
||||
self._session_ctx = ClientSession(
|
||||
@@ -215,6 +224,41 @@ class MCPClient:
|
||||
|
||||
return headers
|
||||
|
||||
def _create_httpx_client_factory(self) -> Callable[..., httpx.AsyncClient]:
|
||||
"""
|
||||
Create a custom httpx client factory that uses LiteLLM's SSL configuration.
|
||||
|
||||
This factory follows the same CA bundle path logic as http_handler.py:
|
||||
1. Check ssl_verify parameter (can be SSLContext, bool, or path to CA bundle)
|
||||
2. Check SSL_VERIFY environment variable
|
||||
3. Check SSL_CERT_FILE environment variable
|
||||
4. Fall back to certifi CA bundle
|
||||
"""
|
||||
|
||||
def factory(
|
||||
*,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[httpx.Timeout] = None,
|
||||
auth: Optional[httpx.Auth] = None,
|
||||
) -> httpx.AsyncClient:
|
||||
"""Create an httpx.AsyncClient with LiteLLM's SSL configuration."""
|
||||
# Get unified SSL configuration using the same logic as http_handler.py
|
||||
ssl_config = get_ssl_configuration(self.ssl_verify)
|
||||
|
||||
verbose_logger.debug(
|
||||
f"MCP client using SSL configuration: {type(ssl_config).__name__}"
|
||||
)
|
||||
|
||||
return httpx.AsyncClient(
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
verify=ssl_config,
|
||||
follow_redirects=True,
|
||||
)
|
||||
|
||||
return factory
|
||||
|
||||
async def list_tools(self) -> List[MCPTool]:
|
||||
"""List available tools from the server."""
|
||||
if not self._session:
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import os
|
||||
import ssl
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
# Add the parent directory to the path so we can import litellm
|
||||
sys.path.insert(0, '../../../')
|
||||
sys.path.insert(0, "../../../")
|
||||
|
||||
from litellm.experimental_mcp_client.client import MCPClient
|
||||
from litellm.types.mcp import MCPStdioConfig, MCPTransport
|
||||
@@ -16,57 +19,54 @@ class TestMCPClient:
|
||||
def test_mcp_client_stdio_init(self):
|
||||
"""Test MCPClient initialization with stdio config"""
|
||||
stdio_config = MCPStdioConfig(
|
||||
command="python",
|
||||
args=["-m", "my_mcp_server"],
|
||||
env={"DEBUG": "1"}
|
||||
command="python", args=["-m", "my_mcp_server"], env={"DEBUG": "1"}
|
||||
)
|
||||
|
||||
client = MCPClient(
|
||||
transport_type=MCPTransport.stdio,
|
||||
stdio_config=stdio_config
|
||||
)
|
||||
|
||||
|
||||
client = MCPClient(transport_type=MCPTransport.stdio, stdio_config=stdio_config)
|
||||
|
||||
assert client.transport_type == MCPTransport.stdio
|
||||
assert client.stdio_config == stdio_config
|
||||
assert client.stdio_config["command"] == "python"
|
||||
assert client.stdio_config["args"] == ["-m", "my_mcp_server"]
|
||||
assert client.stdio_config is not None
|
||||
assert client.stdio_config.get("command") == "python"
|
||||
assert client.stdio_config.get("args") == ["-m", "my_mcp_server"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_client_stdio_connect_error(self):
|
||||
"""Test MCP client stdio connection error handling"""
|
||||
# Test missing stdio_config
|
||||
client = MCPClient(transport_type=MCPTransport.stdio)
|
||||
|
||||
with pytest.raises(ValueError, match="stdio_config is required for stdio transport"):
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="stdio_config is required for stdio transport"
|
||||
):
|
||||
await client.connect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('litellm.experimental_mcp_client.client.stdio_client')
|
||||
@patch('litellm.experimental_mcp_client.client.ClientSession')
|
||||
async def test_mcp_client_stdio_connect_success(self, mock_session, mock_stdio_client):
|
||||
@patch("litellm.experimental_mcp_client.client.stdio_client")
|
||||
@patch("litellm.experimental_mcp_client.client.ClientSession")
|
||||
async def test_mcp_client_stdio_connect_success(
|
||||
self, mock_session, mock_stdio_client
|
||||
):
|
||||
"""Test successful stdio connection"""
|
||||
# Setup mocks
|
||||
mock_transport = (MagicMock(), MagicMock())
|
||||
mock_stdio_client.return_value.__aenter__ = AsyncMock(return_value=mock_transport)
|
||||
|
||||
mock_stdio_client.return_value.__aenter__ = AsyncMock(
|
||||
return_value=mock_transport
|
||||
)
|
||||
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.__aenter__ = AsyncMock(return_value=mock_session_instance)
|
||||
mock_session_instance.initialize = AsyncMock()
|
||||
mock_session.return_value = mock_session_instance
|
||||
|
||||
|
||||
stdio_config = MCPStdioConfig(
|
||||
command="python",
|
||||
args=["-m", "my_mcp_server"],
|
||||
env={"DEBUG": "1"}
|
||||
command="python", args=["-m", "my_mcp_server"], env={"DEBUG": "1"}
|
||||
)
|
||||
|
||||
client = MCPClient(
|
||||
transport_type=MCPTransport.stdio,
|
||||
stdio_config=stdio_config
|
||||
)
|
||||
|
||||
|
||||
client = MCPClient(transport_type=MCPTransport.stdio, stdio_config=stdio_config)
|
||||
|
||||
await client.connect()
|
||||
|
||||
|
||||
# Verify stdio_client was called with correct parameters
|
||||
mock_stdio_client.assert_called_once()
|
||||
call_args = mock_stdio_client.call_args[0][0]
|
||||
@@ -74,6 +74,162 @@ class TestMCPClient:
|
||||
assert call_args.args == ["-m", "my_mcp_server"]
|
||||
assert call_args.env == {"DEBUG": "1"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.experimental_mcp_client.client.streamablehttp_client")
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"SSL_CERT_FILE": "/path/to/custom/ca-bundle.pem",
|
||||
"SSL_CERTIFICATE": "/path/to/client-cert.pem",
|
||||
},
|
||||
)
|
||||
async def test_mcp_client_ssl_configuration_from_env(
|
||||
self, mock_streamablehttp_client
|
||||
):
|
||||
"""Test that MCP client uses SSL configuration from environment variables"""
|
||||
# Setup mocks
|
||||
mock_transport = (MagicMock(), MagicMock())
|
||||
mock_streamablehttp_client.return_value.__aenter__ = AsyncMock(
|
||||
return_value=mock_transport
|
||||
)
|
||||
|
||||
# Mock the session
|
||||
with patch(
|
||||
"litellm.experimental_mcp_client.client.ClientSession"
|
||||
) as mock_session:
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.__aenter__ = AsyncMock(
|
||||
return_value=mock_session_instance
|
||||
)
|
||||
mock_session_instance.initialize = AsyncMock()
|
||||
mock_session.return_value = mock_session_instance
|
||||
|
||||
client = MCPClient(
|
||||
server_url="https://mcp-server.example.com",
|
||||
transport_type=MCPTransport.http,
|
||||
)
|
||||
|
||||
await client.connect()
|
||||
|
||||
# Verify streamablehttp_client was called
|
||||
mock_streamablehttp_client.assert_called_once()
|
||||
call_kwargs = mock_streamablehttp_client.call_args[1]
|
||||
|
||||
# Verify httpx_client_factory was passed
|
||||
assert "httpx_client_factory" in call_kwargs
|
||||
httpx_factory = call_kwargs["httpx_client_factory"]
|
||||
|
||||
# Test the factory creates a client with proper SSL config
|
||||
# When SSL_CERT_FILE is set, the factory should use get_ssl_configuration
|
||||
test_client = httpx_factory(headers={"test": "header"})
|
||||
|
||||
# Verify the client was created successfully with SSL configuration
|
||||
assert test_client is not None
|
||||
assert isinstance(test_client, httpx.AsyncClient)
|
||||
# Verify it has the expected properties
|
||||
assert test_client.headers is not None
|
||||
# Clean up
|
||||
await test_client.aclose()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.experimental_mcp_client.client.sse_client")
|
||||
async def test_mcp_client_ssl_verify_parameter(self, mock_sse_client):
|
||||
"""Test that MCP client uses ssl_verify parameter when provided"""
|
||||
# Setup mocks
|
||||
mock_transport = (MagicMock(), MagicMock())
|
||||
mock_sse_client.return_value.__aenter__ = AsyncMock(return_value=mock_transport)
|
||||
|
||||
# Mock the session
|
||||
with patch(
|
||||
"litellm.experimental_mcp_client.client.ClientSession"
|
||||
) as mock_session:
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.__aenter__ = AsyncMock(
|
||||
return_value=mock_session_instance
|
||||
)
|
||||
mock_session_instance.initialize = AsyncMock()
|
||||
mock_session.return_value = mock_session_instance
|
||||
|
||||
# Test with ssl_verify=False
|
||||
client = MCPClient(
|
||||
server_url="https://mcp-server.example.com",
|
||||
transport_type=MCPTransport.sse,
|
||||
ssl_verify=False,
|
||||
)
|
||||
|
||||
await client.connect()
|
||||
|
||||
# Verify sse_client was called
|
||||
mock_sse_client.assert_called_once()
|
||||
call_kwargs = mock_sse_client.call_args[1]
|
||||
|
||||
# Verify httpx_client_factory was passed
|
||||
assert "httpx_client_factory" in call_kwargs
|
||||
httpx_factory = call_kwargs["httpx_client_factory"]
|
||||
|
||||
# Test the factory creates a client with SSL verification disabled
|
||||
# When ssl_verify=False, the factory should disable SSL verification
|
||||
test_client = httpx_factory(headers={"test": "header"})
|
||||
|
||||
# Verify the client was created successfully
|
||||
assert test_client is not None
|
||||
assert isinstance(test_client, httpx.AsyncClient)
|
||||
# Verify it has the expected properties
|
||||
assert test_client.headers is not None
|
||||
# Clean up
|
||||
await test_client.aclose()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.experimental_mcp_client.client.streamablehttp_client")
|
||||
async def test_mcp_client_ssl_verify_custom_path(self, mock_streamablehttp_client):
|
||||
"""Test that MCP client uses custom CA bundle path from ssl_verify parameter"""
|
||||
# Setup mocks
|
||||
mock_transport = (MagicMock(), MagicMock())
|
||||
mock_streamablehttp_client.return_value.__aenter__ = AsyncMock(
|
||||
return_value=mock_transport
|
||||
)
|
||||
|
||||
# Mock the session
|
||||
with patch(
|
||||
"litellm.experimental_mcp_client.client.ClientSession"
|
||||
) as mock_session:
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.__aenter__ = AsyncMock(
|
||||
return_value=mock_session_instance
|
||||
)
|
||||
mock_session_instance.initialize = AsyncMock()
|
||||
mock_session.return_value = mock_session_instance
|
||||
|
||||
# Test with custom CA bundle path
|
||||
custom_ca_path = "/custom/path/to/ca-bundle.pem"
|
||||
client = MCPClient(
|
||||
server_url="https://mcp-server.example.com",
|
||||
transport_type=MCPTransport.http,
|
||||
ssl_verify=custom_ca_path,
|
||||
)
|
||||
|
||||
await client.connect()
|
||||
|
||||
# Verify streamablehttp_client was called
|
||||
mock_streamablehttp_client.assert_called_once()
|
||||
call_kwargs = mock_streamablehttp_client.call_args[1]
|
||||
|
||||
# Verify httpx_client_factory was passed
|
||||
assert "httpx_client_factory" in call_kwargs
|
||||
httpx_factory = call_kwargs["httpx_client_factory"]
|
||||
|
||||
# Test the factory creates a client with custom CA bundle path
|
||||
# When ssl_verify is a path, the factory should use that path for SSL verification
|
||||
test_client = httpx_factory(headers={"test": "header"})
|
||||
|
||||
# Verify the client was created successfully
|
||||
assert test_client is not None
|
||||
assert isinstance(test_client, httpx.AsyncClient)
|
||||
# Verify it has the expected properties
|
||||
assert test_client.headers is not None
|
||||
# Clean up
|
||||
await test_client.aclose()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
pytest.main([__file__])
|
||||
|
||||
Reference in New Issue
Block a user