fix(mcp/): add ssl certificate settings for mcp clients

respect ca bundle path for mcp calls
This commit is contained in:
Krrish Dholakia
2025-10-06 18:36:05 -07:00
parent 5336fcc000
commit 7f88a3f9c6
3 changed files with 232 additions and 32 deletions
View File
+45 -1
View File
@@ -5,8 +5,9 @@ LiteLLM Proxy uses this MCP Client to connnect to other MCP servers.
import asyncio
import base64
from datetime import timedelta
from typing import Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union
import httpx
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
@@ -17,6 +18,8 @@ from mcp.types import TextContent
from mcp.types import Tool as MCPTool
from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import get_ssl_configuration
from litellm.types.llms.custom_http import VerifyTypes
from litellm.types.mcp import (
MCPAuth,
MCPAuthType,
@@ -48,6 +51,7 @@ class MCPClient:
timeout: float = 60.0,
stdio_config: Optional[MCPStdioConfig] = None,
extra_headers: Optional[Dict[str, str]] = None,
ssl_verify: Optional[VerifyTypes] = None,
):
self.server_url: str = server_url
self.transport_type: MCPTransport = transport_type
@@ -62,6 +66,7 @@ class MCPClient:
self._task: Optional[asyncio.Task] = None
self.stdio_config: Optional[MCPStdioConfig] = stdio_config
self.extra_headers: Optional[Dict[str, str]] = extra_headers
self.ssl_verify: Optional[VerifyTypes] = ssl_verify
# handle the basic auth value if provided
if auth_value:
self.update_auth_value(auth_value)
@@ -104,10 +109,12 @@ class MCPClient:
await self._session.initialize()
elif self.transport_type == MCPTransport.sse:
headers = self._get_auth_headers()
httpx_client_factory = self._create_httpx_client_factory()
self._transport_ctx = sse_client(
url=self.server_url,
timeout=self.timeout,
headers=headers,
httpx_client_factory=httpx_client_factory,
)
self._transport = await self._transport_ctx.__aenter__()
self._session_ctx = ClientSession(
@@ -117,6 +124,7 @@ class MCPClient:
await self._session.initialize()
else: # http
headers = self._get_auth_headers()
httpx_client_factory = self._create_httpx_client_factory()
verbose_logger.debug(
"litellm headers for streamablehttp_client: ", headers
)
@@ -124,6 +132,7 @@ class MCPClient:
url=self.server_url,
timeout=timedelta(seconds=self.timeout),
headers=headers,
httpx_client_factory=httpx_client_factory,
)
self._transport = await self._transport_ctx.__aenter__()
self._session_ctx = ClientSession(
@@ -215,6 +224,41 @@ class MCPClient:
return headers
def _create_httpx_client_factory(self) -> Callable[..., httpx.AsyncClient]:
"""
Create a custom httpx client factory that uses LiteLLM's SSL configuration.
This factory follows the same CA bundle path logic as http_handler.py:
1. Check ssl_verify parameter (can be SSLContext, bool, or path to CA bundle)
2. Check SSL_VERIFY environment variable
3. Check SSL_CERT_FILE environment variable
4. Fall back to certifi CA bundle
"""
def factory(
*,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[httpx.Timeout] = None,
auth: Optional[httpx.Auth] = None,
) -> httpx.AsyncClient:
"""Create an httpx.AsyncClient with LiteLLM's SSL configuration."""
# Get unified SSL configuration using the same logic as http_handler.py
ssl_config = get_ssl_configuration(self.ssl_verify)
verbose_logger.debug(
f"MCP client using SSL configuration: {type(ssl_config).__name__}"
)
return httpx.AsyncClient(
headers=headers,
timeout=timeout,
auth=auth,
verify=ssl_config,
follow_redirects=True,
)
return factory
async def list_tools(self) -> List[MCPTool]:
"""List available tools from the server."""
if not self._session:
@@ -1,10 +1,13 @@
import os
import ssl
import sys
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
# Add the parent directory to the path so we can import litellm
sys.path.insert(0, '../../../')
sys.path.insert(0, "../../../")
from litellm.experimental_mcp_client.client import MCPClient
from litellm.types.mcp import MCPStdioConfig, MCPTransport
@@ -16,57 +19,54 @@ class TestMCPClient:
def test_mcp_client_stdio_init(self):
"""Test MCPClient initialization with stdio config"""
stdio_config = MCPStdioConfig(
command="python",
args=["-m", "my_mcp_server"],
env={"DEBUG": "1"}
command="python", args=["-m", "my_mcp_server"], env={"DEBUG": "1"}
)
client = MCPClient(
transport_type=MCPTransport.stdio,
stdio_config=stdio_config
)
client = MCPClient(transport_type=MCPTransport.stdio, stdio_config=stdio_config)
assert client.transport_type == MCPTransport.stdio
assert client.stdio_config == stdio_config
assert client.stdio_config["command"] == "python"
assert client.stdio_config["args"] == ["-m", "my_mcp_server"]
assert client.stdio_config is not None
assert client.stdio_config.get("command") == "python"
assert client.stdio_config.get("args") == ["-m", "my_mcp_server"]
@pytest.mark.asyncio
async def test_mcp_client_stdio_connect_error(self):
"""Test MCP client stdio connection error handling"""
# Test missing stdio_config
client = MCPClient(transport_type=MCPTransport.stdio)
with pytest.raises(ValueError, match="stdio_config is required for stdio transport"):
with pytest.raises(
ValueError, match="stdio_config is required for stdio transport"
):
await client.connect()
@pytest.mark.asyncio
@patch('litellm.experimental_mcp_client.client.stdio_client')
@patch('litellm.experimental_mcp_client.client.ClientSession')
async def test_mcp_client_stdio_connect_success(self, mock_session, mock_stdio_client):
@patch("litellm.experimental_mcp_client.client.stdio_client")
@patch("litellm.experimental_mcp_client.client.ClientSession")
async def test_mcp_client_stdio_connect_success(
self, mock_session, mock_stdio_client
):
"""Test successful stdio connection"""
# Setup mocks
mock_transport = (MagicMock(), MagicMock())
mock_stdio_client.return_value.__aenter__ = AsyncMock(return_value=mock_transport)
mock_stdio_client.return_value.__aenter__ = AsyncMock(
return_value=mock_transport
)
mock_session_instance = MagicMock()
mock_session_instance.__aenter__ = AsyncMock(return_value=mock_session_instance)
mock_session_instance.initialize = AsyncMock()
mock_session.return_value = mock_session_instance
stdio_config = MCPStdioConfig(
command="python",
args=["-m", "my_mcp_server"],
env={"DEBUG": "1"}
command="python", args=["-m", "my_mcp_server"], env={"DEBUG": "1"}
)
client = MCPClient(
transport_type=MCPTransport.stdio,
stdio_config=stdio_config
)
client = MCPClient(transport_type=MCPTransport.stdio, stdio_config=stdio_config)
await client.connect()
# Verify stdio_client was called with correct parameters
mock_stdio_client.assert_called_once()
call_args = mock_stdio_client.call_args[0][0]
@@ -74,6 +74,162 @@ class TestMCPClient:
assert call_args.args == ["-m", "my_mcp_server"]
assert call_args.env == {"DEBUG": "1"}
@pytest.mark.asyncio
@patch("litellm.experimental_mcp_client.client.streamablehttp_client")
@patch.dict(
os.environ,
{
"SSL_CERT_FILE": "/path/to/custom/ca-bundle.pem",
"SSL_CERTIFICATE": "/path/to/client-cert.pem",
},
)
async def test_mcp_client_ssl_configuration_from_env(
self, mock_streamablehttp_client
):
"""Test that MCP client uses SSL configuration from environment variables"""
# Setup mocks
mock_transport = (MagicMock(), MagicMock())
mock_streamablehttp_client.return_value.__aenter__ = AsyncMock(
return_value=mock_transport
)
# Mock the session
with patch(
"litellm.experimental_mcp_client.client.ClientSession"
) as mock_session:
mock_session_instance = MagicMock()
mock_session_instance.__aenter__ = AsyncMock(
return_value=mock_session_instance
)
mock_session_instance.initialize = AsyncMock()
mock_session.return_value = mock_session_instance
client = MCPClient(
server_url="https://mcp-server.example.com",
transport_type=MCPTransport.http,
)
await client.connect()
# Verify streamablehttp_client was called
mock_streamablehttp_client.assert_called_once()
call_kwargs = mock_streamablehttp_client.call_args[1]
# Verify httpx_client_factory was passed
assert "httpx_client_factory" in call_kwargs
httpx_factory = call_kwargs["httpx_client_factory"]
# Test the factory creates a client with proper SSL config
# When SSL_CERT_FILE is set, the factory should use get_ssl_configuration
test_client = httpx_factory(headers={"test": "header"})
# Verify the client was created successfully with SSL configuration
assert test_client is not None
assert isinstance(test_client, httpx.AsyncClient)
# Verify it has the expected properties
assert test_client.headers is not None
# Clean up
await test_client.aclose()
@pytest.mark.asyncio
@patch("litellm.experimental_mcp_client.client.sse_client")
async def test_mcp_client_ssl_verify_parameter(self, mock_sse_client):
"""Test that MCP client uses ssl_verify parameter when provided"""
# Setup mocks
mock_transport = (MagicMock(), MagicMock())
mock_sse_client.return_value.__aenter__ = AsyncMock(return_value=mock_transport)
# Mock the session
with patch(
"litellm.experimental_mcp_client.client.ClientSession"
) as mock_session:
mock_session_instance = MagicMock()
mock_session_instance.__aenter__ = AsyncMock(
return_value=mock_session_instance
)
mock_session_instance.initialize = AsyncMock()
mock_session.return_value = mock_session_instance
# Test with ssl_verify=False
client = MCPClient(
server_url="https://mcp-server.example.com",
transport_type=MCPTransport.sse,
ssl_verify=False,
)
await client.connect()
# Verify sse_client was called
mock_sse_client.assert_called_once()
call_kwargs = mock_sse_client.call_args[1]
# Verify httpx_client_factory was passed
assert "httpx_client_factory" in call_kwargs
httpx_factory = call_kwargs["httpx_client_factory"]
# Test the factory creates a client with SSL verification disabled
# When ssl_verify=False, the factory should disable SSL verification
test_client = httpx_factory(headers={"test": "header"})
# Verify the client was created successfully
assert test_client is not None
assert isinstance(test_client, httpx.AsyncClient)
# Verify it has the expected properties
assert test_client.headers is not None
# Clean up
await test_client.aclose()
@pytest.mark.asyncio
@patch("litellm.experimental_mcp_client.client.streamablehttp_client")
async def test_mcp_client_ssl_verify_custom_path(self, mock_streamablehttp_client):
"""Test that MCP client uses custom CA bundle path from ssl_verify parameter"""
# Setup mocks
mock_transport = (MagicMock(), MagicMock())
mock_streamablehttp_client.return_value.__aenter__ = AsyncMock(
return_value=mock_transport
)
# Mock the session
with patch(
"litellm.experimental_mcp_client.client.ClientSession"
) as mock_session:
mock_session_instance = MagicMock()
mock_session_instance.__aenter__ = AsyncMock(
return_value=mock_session_instance
)
mock_session_instance.initialize = AsyncMock()
mock_session.return_value = mock_session_instance
# Test with custom CA bundle path
custom_ca_path = "/custom/path/to/ca-bundle.pem"
client = MCPClient(
server_url="https://mcp-server.example.com",
transport_type=MCPTransport.http,
ssl_verify=custom_ca_path,
)
await client.connect()
# Verify streamablehttp_client was called
mock_streamablehttp_client.assert_called_once()
call_kwargs = mock_streamablehttp_client.call_args[1]
# Verify httpx_client_factory was passed
assert "httpx_client_factory" in call_kwargs
httpx_factory = call_kwargs["httpx_client_factory"]
# Test the factory creates a client with custom CA bundle path
# When ssl_verify is a path, the factory should use that path for SSL verification
test_client = httpx_factory(headers={"test": "header"})
# Verify the client was created successfully
assert test_client is not None
assert isinstance(test_client, httpx.AsyncClient)
# Verify it has the expected properties
assert test_client.headers is not None
# Clean up
await test_client.aclose()
if __name__ == "__main__":
pytest.main([__file__])
pytest.main([__file__])