fix: Add response_type + PKCE parameters to OAuth authorization endpoint (#15720)

* fix: Add response_type parameter to OAuth authorization endpoint

Fixes #15684

OAuth providers like Google require the response_type parameter during
the authorization flow. This commit adds response_type=code to the
authorization redirect parameters, which is required by the OAuth 2.0
specification (RFC 6749 Section 4.1.1).

Changes:
- Added response_type=code to authorization params in discoverable_endpoints.py
- Added test coverage for the response_type parameter

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix oauth flow by forwarding code_challenge and forwarding code_verifier

---------

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Talal
2025-10-21 09:43:19 -07:00
committed by GitHub
parent 98f1d63508
commit 46d55bd92a
2 changed files with 318 additions and 36 deletions
@@ -1,5 +1,5 @@
import json
from typing import Optional, Tuple
from typing import Optional
from urllib.parse import urlencode, urlparse, urlunparse
from fastapi import APIRouter, Form, HTTPException, Request
@@ -19,32 +19,47 @@ router = APIRouter(
)
def encode_state_with_base_url(base_url: str, original_state: str) -> str:
def encode_state_with_base_url(
base_url: str,
original_state: str,
code_challenge: Optional[str] = None,
code_challenge_method: Optional[str] = None,
client_redirect_uri: Optional[str] = None,
) -> str:
"""
Encode the base_url and original state using encryption.
Encode the base_url, original state, and PKCE parameters using encryption.
Args:
base_url: The base URL to encode
original_state: The original state parameter
code_challenge: PKCE code challenge from client
code_challenge_method: PKCE code challenge method from client
client_redirect_uri: Original redirect_uri from client
Returns:
An encrypted string that encodes both values
An encrypted string that encodes all values
"""
state_data = {"base_url": base_url, "original_state": original_state}
state_data = {
"base_url": base_url,
"original_state": original_state,
"code_challenge": code_challenge,
"code_challenge_method": code_challenge_method,
"client_redirect_uri": client_redirect_uri,
}
state_json = json.dumps(state_data, sort_keys=True)
encrypted_state = encrypt_value_helper(state_json)
return encrypted_state
def decode_state_hash(encrypted_state: str) -> Tuple[str, str]:
def decode_state_hash(encrypted_state: str) -> dict:
"""
Decode an encrypted state to retrieve the base_url and original state.
Decode an encrypted state to retrieve all OAuth session data.
Args:
encrypted_state: The encrypted string to decode
Returns:
A tuple of (base_url, original_state)
A dict containing base_url, original_state, and optional PKCE parameters
Raises:
Exception: If decryption fails or data is malformed
@@ -54,7 +69,7 @@ def decode_state_hash(encrypted_state: str) -> Tuple[str, str]:
raise ValueError("Failed to decrypt state parameter")
state_data = json.loads(decrypted_json)
return state_data["base_url"], state_data["original_state"]
return state_data
@router.get("/{mcp_server_name}/authorize")
@@ -65,8 +80,12 @@ async def authorize(
redirect_uri: str,
state: str = "",
mcp_server_name: Optional[str] = None,
code_challenge: Optional[str] = None,
code_challenge_method: Optional[str] = None,
response_type: Optional[str] = None,
scope: Optional[str] = None,
):
# Redirect to real GitHub OAuth
# Redirect to real OAuth provider with PKCE support
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
global_mcp_server_manager,
)
@@ -90,15 +109,30 @@ async def authorize(
base_url = urlunparse(parsed._replace(query=""))
request_base_url = str(request.base_url).rstrip("/")
# Encode the base_url and original state in a unique hash
encoded_state = encode_state_with_base_url(base_url, state)
# Encode the base_url, original state, PKCE params, and client redirect_uri in encrypted state
encoded_state = encode_state_with_base_url(
base_url=base_url,
original_state=state,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
client_redirect_uri=redirect_uri,
)
# Build params for upstream OAuth provider
params = {
"client_id": mcp_server.client_id,
"redirect_uri": f"{request_base_url}/callback",
"scope": " ".join(mcp_server.scopes),
"scope": scope or " ".join(mcp_server.scopes),
"state": encoded_state,
"response_type": response_type or "code",
}
# Forward PKCE parameters if present
if code_challenge:
params["code_challenge"] = code_challenge
if code_challenge_method:
params["code_challenge_method"] = code_challenge_method
return RedirectResponse(f"{mcp_server.authorization_url}?{urlencode(params)}")
@@ -110,15 +144,16 @@ async def token_endpoint(
redirect_uri: str = Form(None),
client_id: str = Form(...),
client_secret: str = Form(...),
code_verifier: str = Form(None),
):
"""
Accept the authorization code from Claude and exchange it for GitHub token.
Forward the GitHub token back to Claude in standard OAuth format.
Accept the authorization code from client and exchange it for OAuth token.
Supports PKCE flow by forwarding code_verifier to upstream provider.
1. Call the token endpoint
2. Store the user's PAT in the db - and generate a LiteLLM virtual key
2. Return the token
3. Return a virtual key in this response
1. Call the token endpoint with PKCE parameters
2. Store the user's token in the db - and generate a LiteLLM virtual key
3. Return the token
4. Return a virtual key in this response
"""
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
global_mcp_server_manager,
@@ -136,41 +171,60 @@ async def token_endpoint(
proxy_base_url = str(request.base_url).rstrip("/")
# Exchange code for real GitHub token
# Build token request data
token_data = {
"grant_type": "authorization_code",
"client_id": mcp_server.client_id,
"client_secret": mcp_server.client_secret,
"code": code,
"redirect_uri": f"{proxy_base_url}/callback",
}
# Forward PKCE code_verifier if present
if code_verifier:
token_data["code_verifier"] = code_verifier
# Exchange code for real OAuth token
async_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
response = await async_client.post(
mcp_server.token_url,
headers={"Accept": "application/json"},
data={
"client_id": mcp_server.client_id,
"client_secret": mcp_server.client_secret,
"code": code,
"redirect_uri": f"{proxy_base_url}/callback",
},
data=token_data,
)
response.raise_for_status()
github_token = response.json()["access_token"]
token_response = response.json()
access_token = token_response["access_token"]
# Return to Claude in expected OAuth 2 format
# Return to client in expected OAuth 2 format
# Only include fields that have values
result = {
"access_token": access_token,
"token_type": token_response.get("token_type", "Bearer"),
"expires_in": token_response.get("expires_in", 3600),
}
### return a virtual key in this response
# Add optional fields only if they exist
if "refresh_token" in token_response and token_response["refresh_token"]:
result["refresh_token"] = token_response["refresh_token"]
if "scope" in token_response and token_response["scope"]:
result["scope"] = token_response["scope"]
return JSONResponse(
{"access_token": github_token, "token_type": "Bearer", "expires_in": 3600}
)
return JSONResponse(result)
@router.get("/callback")
async def callback(code: str, state: str):
try:
# Decode the state hash to get base_url and original state
base_url, original_state = decode_state_hash(state)
# Decode the state hash to get base_url, original state, and PKCE params
state_data = decode_state_hash(state)
base_url = state_data["base_url"]
original_state = state_data["original_state"]
# Exchange code for token with GitHub
# Forward code and original state back to client
params = {"code": code, "state": original_state}
# Forward token to Claude ephemeral endpoint
# Forward to client's callback endpoint
complete_returned_url = f"{base_url}?{urlencode(params)}"
return RedirectResponse(url=complete_returned_url, status_code=302)
@@ -0,0 +1,228 @@
"""Tests for MCP OAuth discoverable endpoints"""
import pytest
from unittest.mock import MagicMock, patch
@pytest.mark.asyncio
async def test_authorize_endpoint_includes_response_type():
"""Test that authorize endpoint includes response_type=code parameter (fixes #15684)"""
try:
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
authorize,
)
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
global_mcp_server_manager,
)
from litellm.types.mcp import MCPAuth
from litellm.types.mcp_server.mcp_server_manager import MCPServer
from litellm.proxy._types import MCPTransport
from fastapi import Request
except ImportError:
pytest.skip("MCP discoverable endpoints not available")
# Clear registry
global_mcp_server_manager.registry.clear()
# Create mock OAuth2 server
oauth2_server = MCPServer(
server_id="test_oauth_server",
name="test_oauth",
server_name="test_oauth",
alias="test_oauth",
transport=MCPTransport.http,
auth_type=MCPAuth.oauth2,
client_id="test_client_id",
client_secret="test_client_secret",
authorization_url="https://provider.com/oauth/authorize",
token_url="https://provider.com/oauth/token",
scopes=["read", "write"],
)
global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server
# Mock request
mock_request = MagicMock(spec=Request)
mock_request.base_url = "https://litellm.example.com/"
mock_request.headers = {}
# Mock the encryption functions to avoid needing a signing key
with patch(
"litellm.proxy._experimental.mcp_server.discoverable_endpoints.encrypt_value_helper"
) as mock_encrypt:
mock_encrypt.return_value = "mocked_encrypted_state"
# Call authorize endpoint
response = await authorize(
request=mock_request,
client_id="test_oauth",
redirect_uri="https://client.example.com/callback",
state="test_state",
)
# Verify response is a redirect
assert response.status_code == 307 # FastAPI RedirectResponse default
# Verify response_type is in the redirect URL
assert "response_type=code" in response.headers["location"]
assert "https://provider.com/oauth/authorize" in response.headers["location"]
assert "client_id=test_client_id" in response.headers["location"]
assert "scope=read+write" in response.headers["location"]
@pytest.mark.asyncio
async def test_authorize_endpoint_forwards_pkce_parameters():
"""Test that authorize endpoint forwards PKCE parameters (code_challenge and code_challenge_method)"""
try:
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
authorize,
)
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
global_mcp_server_manager,
)
from litellm.types.mcp import MCPAuth
from litellm.types.mcp_server.mcp_server_manager import MCPServer
from litellm.proxy._types import MCPTransport
from fastapi import Request
except ImportError:
pytest.skip("MCP discoverable endpoints not available")
# Clear registry
global_mcp_server_manager.registry.clear()
# Create mock OAuth2 server (simulating Google OAuth)
oauth2_server = MCPServer(
server_id="google_mcp",
name="google_mcp",
server_name="google_mcp",
alias="google_mcp",
transport=MCPTransport.http,
auth_type=MCPAuth.oauth2,
client_id="669428968603-test.apps.googleusercontent.com",
client_secret="GOCSPX-test_secret",
authorization_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
scopes=["https://www.googleapis.com/auth/drive", "openid", "email"],
)
global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server
# Mock request
mock_request = MagicMock(spec=Request)
mock_request.base_url = "https://litellm-proxy.example.com/"
mock_request.headers = {}
# Mock the encryption function
with patch(
"litellm.proxy._experimental.mcp_server.discoverable_endpoints.encrypt_value_helper"
) as mock_encrypt:
mock_encrypt.return_value = "mocked_encrypted_state_with_pkce"
# Call authorize endpoint with PKCE parameters
response = await authorize(
request=mock_request,
client_id="google_mcp",
redirect_uri="http://localhost:60108/callback",
state="test_client_state",
code_challenge="x6YH_qgwbvOzbsHDuL1sW9gYkR9-gObUiIB5RkPwxDk",
code_challenge_method="S256",
)
# Verify response is a redirect
assert response.status_code == 307
# Verify PKCE parameters are included in the redirect URL
location = response.headers["location"]
assert "https://accounts.google.com/o/oauth2/v2/auth" in location
assert "code_challenge=x6YH_qgwbvOzbsHDuL1sW9gYkR9-gObUiIB5RkPwxDk" in location
assert "code_challenge_method=S256" in location
assert "client_id=669428968603-test.apps.googleusercontent.com" in location
assert "response_type=code" in location
@pytest.mark.asyncio
async def test_token_endpoint_forwards_code_verifier():
"""Test that token endpoint forwards code_verifier for PKCE flow"""
try:
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
token_endpoint,
)
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
global_mcp_server_manager,
)
from litellm.types.mcp import MCPAuth
from litellm.types.mcp_server.mcp_server_manager import MCPServer
from litellm.proxy._types import MCPTransport
from fastapi import Request
import httpx
except ImportError:
pytest.skip("MCP discoverable endpoints not available")
# Clear registry
global_mcp_server_manager.registry.clear()
# Create mock OAuth2 server
oauth2_server = MCPServer(
server_id="google_mcp",
name="google_mcp",
server_name="google_mcp",
alias="google_mcp",
transport=MCPTransport.http,
auth_type=MCPAuth.oauth2,
client_id="669428968603-test.apps.googleusercontent.com",
client_secret="GOCSPX-test_secret",
authorization_url="https://accounts.google.com/o/oauth2/v2/auth",
token_url="https://oauth2.googleapis.com/token",
scopes=["https://www.googleapis.com/auth/drive", "openid", "email"],
)
global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server
# Mock request
mock_request = MagicMock(spec=Request)
mock_request.base_url = "https://litellm-proxy.example.com/"
# Mock httpx client response
mock_response = MagicMock()
mock_response.json.return_value = {
"access_token": "ya29.test_access_token",
"token_type": "Bearer",
"expires_in": 3599,
"scope": "openid email https://www.googleapis.com/auth/drive",
}
mock_response.raise_for_status = MagicMock()
# Mock the async httpx client with AsyncMock for async methods
from unittest.mock import AsyncMock
with patch(
"litellm.proxy._experimental.mcp_server.discoverable_endpoints.get_async_httpx_client"
) as mock_get_client:
mock_async_client = MagicMock()
# Use AsyncMock for the async post method
mock_async_client.post = AsyncMock(return_value=mock_response)
mock_get_client.return_value = mock_async_client
# Call token endpoint with code_verifier
response = await token_endpoint(
request=mock_request,
grant_type="authorization_code",
code="4/test_authorization_code",
redirect_uri="http://localhost:60108/callback",
client_id="google_mcp",
client_secret="dummy",
code_verifier="test_code_verifier_from_client",
)
# Verify that the token endpoint was called with code_verifier
mock_async_client.post.assert_called_once()
call_args = mock_async_client.post.call_args
# Check the data parameter includes code_verifier
assert call_args[1]["data"]["code_verifier"] == "test_code_verifier_from_client"
assert call_args[1]["data"]["code"] == "4/test_authorization_code"
assert call_args[1]["data"]["client_id"] == "669428968603-test.apps.googleusercontent.com"
assert call_args[1]["data"]["client_secret"] == "GOCSPX-test_secret"
assert call_args[1]["data"]["grant_type"] == "authorization_code"
# Verify response
response_data = response.body
import json
token_data = json.loads(response_data)
assert token_data["access_token"] == "ya29.test_access_token"
assert token_data["token_type"] == "Bearer"