mirror of
https://github.com/onyx-dot-app/litellm.git
synced 2026-07-01 20:44:04 -04:00
fix: Add response_type + PKCE parameters to OAuth authorization endpoint (#15720)
* fix: Add response_type parameter to OAuth authorization endpoint Fixes #15684 OAuth providers like Google require the response_type parameter during the authorization flow. This commit adds response_type=code to the authorization redirect parameters, which is required by the OAuth 2.0 specification (RFC 6749 Section 4.1.1). Changes: - Added response_type=code to authorization params in discoverable_endpoints.py - Added test coverage for the response_type parameter 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * fix oauth flow by forwarding code_challenge and forwarding code_verifier --------- Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode, urlparse, urlunparse
|
||||
|
||||
from fastapi import APIRouter, Form, HTTPException, Request
|
||||
@@ -19,32 +19,47 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
def encode_state_with_base_url(base_url: str, original_state: str) -> str:
|
||||
def encode_state_with_base_url(
|
||||
base_url: str,
|
||||
original_state: str,
|
||||
code_challenge: Optional[str] = None,
|
||||
code_challenge_method: Optional[str] = None,
|
||||
client_redirect_uri: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Encode the base_url and original state using encryption.
|
||||
Encode the base_url, original state, and PKCE parameters using encryption.
|
||||
|
||||
Args:
|
||||
base_url: The base URL to encode
|
||||
original_state: The original state parameter
|
||||
code_challenge: PKCE code challenge from client
|
||||
code_challenge_method: PKCE code challenge method from client
|
||||
client_redirect_uri: Original redirect_uri from client
|
||||
|
||||
Returns:
|
||||
An encrypted string that encodes both values
|
||||
An encrypted string that encodes all values
|
||||
"""
|
||||
state_data = {"base_url": base_url, "original_state": original_state}
|
||||
state_data = {
|
||||
"base_url": base_url,
|
||||
"original_state": original_state,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": code_challenge_method,
|
||||
"client_redirect_uri": client_redirect_uri,
|
||||
}
|
||||
state_json = json.dumps(state_data, sort_keys=True)
|
||||
encrypted_state = encrypt_value_helper(state_json)
|
||||
return encrypted_state
|
||||
|
||||
|
||||
def decode_state_hash(encrypted_state: str) -> Tuple[str, str]:
|
||||
def decode_state_hash(encrypted_state: str) -> dict:
|
||||
"""
|
||||
Decode an encrypted state to retrieve the base_url and original state.
|
||||
Decode an encrypted state to retrieve all OAuth session data.
|
||||
|
||||
Args:
|
||||
encrypted_state: The encrypted string to decode
|
||||
|
||||
Returns:
|
||||
A tuple of (base_url, original_state)
|
||||
A dict containing base_url, original_state, and optional PKCE parameters
|
||||
|
||||
Raises:
|
||||
Exception: If decryption fails or data is malformed
|
||||
@@ -54,7 +69,7 @@ def decode_state_hash(encrypted_state: str) -> Tuple[str, str]:
|
||||
raise ValueError("Failed to decrypt state parameter")
|
||||
|
||||
state_data = json.loads(decrypted_json)
|
||||
return state_data["base_url"], state_data["original_state"]
|
||||
return state_data
|
||||
|
||||
|
||||
@router.get("/{mcp_server_name}/authorize")
|
||||
@@ -65,8 +80,12 @@ async def authorize(
|
||||
redirect_uri: str,
|
||||
state: str = "",
|
||||
mcp_server_name: Optional[str] = None,
|
||||
code_challenge: Optional[str] = None,
|
||||
code_challenge_method: Optional[str] = None,
|
||||
response_type: Optional[str] = None,
|
||||
scope: Optional[str] = None,
|
||||
):
|
||||
# Redirect to real GitHub OAuth
|
||||
# Redirect to real OAuth provider with PKCE support
|
||||
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||
global_mcp_server_manager,
|
||||
)
|
||||
@@ -90,15 +109,30 @@ async def authorize(
|
||||
base_url = urlunparse(parsed._replace(query=""))
|
||||
request_base_url = str(request.base_url).rstrip("/")
|
||||
|
||||
# Encode the base_url and original state in a unique hash
|
||||
encoded_state = encode_state_with_base_url(base_url, state)
|
||||
# Encode the base_url, original state, PKCE params, and client redirect_uri in encrypted state
|
||||
encoded_state = encode_state_with_base_url(
|
||||
base_url=base_url,
|
||||
original_state=state,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
client_redirect_uri=redirect_uri,
|
||||
)
|
||||
|
||||
# Build params for upstream OAuth provider
|
||||
params = {
|
||||
"client_id": mcp_server.client_id,
|
||||
"redirect_uri": f"{request_base_url}/callback",
|
||||
"scope": " ".join(mcp_server.scopes),
|
||||
"scope": scope or " ".join(mcp_server.scopes),
|
||||
"state": encoded_state,
|
||||
"response_type": response_type or "code",
|
||||
}
|
||||
|
||||
# Forward PKCE parameters if present
|
||||
if code_challenge:
|
||||
params["code_challenge"] = code_challenge
|
||||
if code_challenge_method:
|
||||
params["code_challenge_method"] = code_challenge_method
|
||||
|
||||
return RedirectResponse(f"{mcp_server.authorization_url}?{urlencode(params)}")
|
||||
|
||||
|
||||
@@ -110,15 +144,16 @@ async def token_endpoint(
|
||||
redirect_uri: str = Form(None),
|
||||
client_id: str = Form(...),
|
||||
client_secret: str = Form(...),
|
||||
code_verifier: str = Form(None),
|
||||
):
|
||||
"""
|
||||
Accept the authorization code from Claude and exchange it for GitHub token.
|
||||
Forward the GitHub token back to Claude in standard OAuth format.
|
||||
Accept the authorization code from client and exchange it for OAuth token.
|
||||
Supports PKCE flow by forwarding code_verifier to upstream provider.
|
||||
|
||||
1. Call the token endpoint
|
||||
2. Store the user's PAT in the db - and generate a LiteLLM virtual key
|
||||
2. Return the token
|
||||
3. Return a virtual key in this response
|
||||
1. Call the token endpoint with PKCE parameters
|
||||
2. Store the user's token in the db - and generate a LiteLLM virtual key
|
||||
3. Return the token
|
||||
4. Return a virtual key in this response
|
||||
"""
|
||||
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||
global_mcp_server_manager,
|
||||
@@ -136,41 +171,60 @@ async def token_endpoint(
|
||||
|
||||
proxy_base_url = str(request.base_url).rstrip("/")
|
||||
|
||||
# Exchange code for real GitHub token
|
||||
# Build token request data
|
||||
token_data = {
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": mcp_server.client_id,
|
||||
"client_secret": mcp_server.client_secret,
|
||||
"code": code,
|
||||
"redirect_uri": f"{proxy_base_url}/callback",
|
||||
}
|
||||
|
||||
# Forward PKCE code_verifier if present
|
||||
if code_verifier:
|
||||
token_data["code_verifier"] = code_verifier
|
||||
|
||||
# Exchange code for real OAuth token
|
||||
async_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
|
||||
response = await async_client.post(
|
||||
mcp_server.token_url,
|
||||
headers={"Accept": "application/json"},
|
||||
data={
|
||||
"client_id": mcp_server.client_id,
|
||||
"client_secret": mcp_server.client_secret,
|
||||
"code": code,
|
||||
"redirect_uri": f"{proxy_base_url}/callback",
|
||||
},
|
||||
data=token_data,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
github_token = response.json()["access_token"]
|
||||
token_response = response.json()
|
||||
access_token = token_response["access_token"]
|
||||
|
||||
# Return to Claude in expected OAuth 2 format
|
||||
# Return to client in expected OAuth 2 format
|
||||
# Only include fields that have values
|
||||
result = {
|
||||
"access_token": access_token,
|
||||
"token_type": token_response.get("token_type", "Bearer"),
|
||||
"expires_in": token_response.get("expires_in", 3600),
|
||||
}
|
||||
|
||||
### return a virtual key in this response
|
||||
# Add optional fields only if they exist
|
||||
if "refresh_token" in token_response and token_response["refresh_token"]:
|
||||
result["refresh_token"] = token_response["refresh_token"]
|
||||
if "scope" in token_response and token_response["scope"]:
|
||||
result["scope"] = token_response["scope"]
|
||||
|
||||
return JSONResponse(
|
||||
{"access_token": github_token, "token_type": "Bearer", "expires_in": 3600}
|
||||
)
|
||||
return JSONResponse(result)
|
||||
|
||||
|
||||
@router.get("/callback")
|
||||
async def callback(code: str, state: str):
|
||||
try:
|
||||
# Decode the state hash to get base_url and original state
|
||||
base_url, original_state = decode_state_hash(state)
|
||||
# Decode the state hash to get base_url, original state, and PKCE params
|
||||
state_data = decode_state_hash(state)
|
||||
base_url = state_data["base_url"]
|
||||
original_state = state_data["original_state"]
|
||||
|
||||
# Exchange code for token with GitHub
|
||||
# Forward code and original state back to client
|
||||
params = {"code": code, "state": original_state}
|
||||
|
||||
# Forward token to Claude ephemeral endpoint
|
||||
# Forward to client's callback endpoint
|
||||
complete_returned_url = f"{base_url}?{urlencode(params)}"
|
||||
return RedirectResponse(url=complete_returned_url, status_code=302)
|
||||
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
"""Tests for MCP OAuth discoverable endpoints"""
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_endpoint_includes_response_type():
|
||||
"""Test that authorize endpoint includes response_type=code parameter (fixes #15684)"""
|
||||
try:
|
||||
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
|
||||
authorize,
|
||||
)
|
||||
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||
global_mcp_server_manager,
|
||||
)
|
||||
from litellm.types.mcp import MCPAuth
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPServer
|
||||
from litellm.proxy._types import MCPTransport
|
||||
from fastapi import Request
|
||||
except ImportError:
|
||||
pytest.skip("MCP discoverable endpoints not available")
|
||||
|
||||
# Clear registry
|
||||
global_mcp_server_manager.registry.clear()
|
||||
|
||||
# Create mock OAuth2 server
|
||||
oauth2_server = MCPServer(
|
||||
server_id="test_oauth_server",
|
||||
name="test_oauth",
|
||||
server_name="test_oauth",
|
||||
alias="test_oauth",
|
||||
transport=MCPTransport.http,
|
||||
auth_type=MCPAuth.oauth2,
|
||||
client_id="test_client_id",
|
||||
client_secret="test_client_secret",
|
||||
authorization_url="https://provider.com/oauth/authorize",
|
||||
token_url="https://provider.com/oauth/token",
|
||||
scopes=["read", "write"],
|
||||
)
|
||||
global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server
|
||||
|
||||
# Mock request
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.base_url = "https://litellm.example.com/"
|
||||
mock_request.headers = {}
|
||||
|
||||
# Mock the encryption functions to avoid needing a signing key
|
||||
with patch(
|
||||
"litellm.proxy._experimental.mcp_server.discoverable_endpoints.encrypt_value_helper"
|
||||
) as mock_encrypt:
|
||||
mock_encrypt.return_value = "mocked_encrypted_state"
|
||||
|
||||
# Call authorize endpoint
|
||||
response = await authorize(
|
||||
request=mock_request,
|
||||
client_id="test_oauth",
|
||||
redirect_uri="https://client.example.com/callback",
|
||||
state="test_state",
|
||||
)
|
||||
|
||||
# Verify response is a redirect
|
||||
assert response.status_code == 307 # FastAPI RedirectResponse default
|
||||
|
||||
# Verify response_type is in the redirect URL
|
||||
assert "response_type=code" in response.headers["location"]
|
||||
assert "https://provider.com/oauth/authorize" in response.headers["location"]
|
||||
assert "client_id=test_client_id" in response.headers["location"]
|
||||
assert "scope=read+write" in response.headers["location"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_endpoint_forwards_pkce_parameters():
|
||||
"""Test that authorize endpoint forwards PKCE parameters (code_challenge and code_challenge_method)"""
|
||||
try:
|
||||
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
|
||||
authorize,
|
||||
)
|
||||
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||
global_mcp_server_manager,
|
||||
)
|
||||
from litellm.types.mcp import MCPAuth
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPServer
|
||||
from litellm.proxy._types import MCPTransport
|
||||
from fastapi import Request
|
||||
except ImportError:
|
||||
pytest.skip("MCP discoverable endpoints not available")
|
||||
|
||||
# Clear registry
|
||||
global_mcp_server_manager.registry.clear()
|
||||
|
||||
# Create mock OAuth2 server (simulating Google OAuth)
|
||||
oauth2_server = MCPServer(
|
||||
server_id="google_mcp",
|
||||
name="google_mcp",
|
||||
server_name="google_mcp",
|
||||
alias="google_mcp",
|
||||
transport=MCPTransport.http,
|
||||
auth_type=MCPAuth.oauth2,
|
||||
client_id="669428968603-test.apps.googleusercontent.com",
|
||||
client_secret="GOCSPX-test_secret",
|
||||
authorization_url="https://accounts.google.com/o/oauth2/v2/auth",
|
||||
token_url="https://oauth2.googleapis.com/token",
|
||||
scopes=["https://www.googleapis.com/auth/drive", "openid", "email"],
|
||||
)
|
||||
global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server
|
||||
|
||||
# Mock request
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.base_url = "https://litellm-proxy.example.com/"
|
||||
mock_request.headers = {}
|
||||
|
||||
# Mock the encryption function
|
||||
with patch(
|
||||
"litellm.proxy._experimental.mcp_server.discoverable_endpoints.encrypt_value_helper"
|
||||
) as mock_encrypt:
|
||||
mock_encrypt.return_value = "mocked_encrypted_state_with_pkce"
|
||||
|
||||
# Call authorize endpoint with PKCE parameters
|
||||
response = await authorize(
|
||||
request=mock_request,
|
||||
client_id="google_mcp",
|
||||
redirect_uri="http://localhost:60108/callback",
|
||||
state="test_client_state",
|
||||
code_challenge="x6YH_qgwbvOzbsHDuL1sW9gYkR9-gObUiIB5RkPwxDk",
|
||||
code_challenge_method="S256",
|
||||
)
|
||||
|
||||
# Verify response is a redirect
|
||||
assert response.status_code == 307
|
||||
|
||||
# Verify PKCE parameters are included in the redirect URL
|
||||
location = response.headers["location"]
|
||||
assert "https://accounts.google.com/o/oauth2/v2/auth" in location
|
||||
assert "code_challenge=x6YH_qgwbvOzbsHDuL1sW9gYkR9-gObUiIB5RkPwxDk" in location
|
||||
assert "code_challenge_method=S256" in location
|
||||
assert "client_id=669428968603-test.apps.googleusercontent.com" in location
|
||||
assert "response_type=code" in location
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_endpoint_forwards_code_verifier():
|
||||
"""Test that token endpoint forwards code_verifier for PKCE flow"""
|
||||
try:
|
||||
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
|
||||
token_endpoint,
|
||||
)
|
||||
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||
global_mcp_server_manager,
|
||||
)
|
||||
from litellm.types.mcp import MCPAuth
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPServer
|
||||
from litellm.proxy._types import MCPTransport
|
||||
from fastapi import Request
|
||||
import httpx
|
||||
except ImportError:
|
||||
pytest.skip("MCP discoverable endpoints not available")
|
||||
|
||||
# Clear registry
|
||||
global_mcp_server_manager.registry.clear()
|
||||
|
||||
# Create mock OAuth2 server
|
||||
oauth2_server = MCPServer(
|
||||
server_id="google_mcp",
|
||||
name="google_mcp",
|
||||
server_name="google_mcp",
|
||||
alias="google_mcp",
|
||||
transport=MCPTransport.http,
|
||||
auth_type=MCPAuth.oauth2,
|
||||
client_id="669428968603-test.apps.googleusercontent.com",
|
||||
client_secret="GOCSPX-test_secret",
|
||||
authorization_url="https://accounts.google.com/o/oauth2/v2/auth",
|
||||
token_url="https://oauth2.googleapis.com/token",
|
||||
scopes=["https://www.googleapis.com/auth/drive", "openid", "email"],
|
||||
)
|
||||
global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server
|
||||
|
||||
# Mock request
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.base_url = "https://litellm-proxy.example.com/"
|
||||
|
||||
# Mock httpx client response
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"access_token": "ya29.test_access_token",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3599,
|
||||
"scope": "openid email https://www.googleapis.com/auth/drive",
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
# Mock the async httpx client with AsyncMock for async methods
|
||||
from unittest.mock import AsyncMock
|
||||
with patch(
|
||||
"litellm.proxy._experimental.mcp_server.discoverable_endpoints.get_async_httpx_client"
|
||||
) as mock_get_client:
|
||||
mock_async_client = MagicMock()
|
||||
# Use AsyncMock for the async post method
|
||||
mock_async_client.post = AsyncMock(return_value=mock_response)
|
||||
mock_get_client.return_value = mock_async_client
|
||||
|
||||
# Call token endpoint with code_verifier
|
||||
response = await token_endpoint(
|
||||
request=mock_request,
|
||||
grant_type="authorization_code",
|
||||
code="4/test_authorization_code",
|
||||
redirect_uri="http://localhost:60108/callback",
|
||||
client_id="google_mcp",
|
||||
client_secret="dummy",
|
||||
code_verifier="test_code_verifier_from_client",
|
||||
)
|
||||
|
||||
# Verify that the token endpoint was called with code_verifier
|
||||
mock_async_client.post.assert_called_once()
|
||||
call_args = mock_async_client.post.call_args
|
||||
|
||||
# Check the data parameter includes code_verifier
|
||||
assert call_args[1]["data"]["code_verifier"] == "test_code_verifier_from_client"
|
||||
assert call_args[1]["data"]["code"] == "4/test_authorization_code"
|
||||
assert call_args[1]["data"]["client_id"] == "669428968603-test.apps.googleusercontent.com"
|
||||
assert call_args[1]["data"]["client_secret"] == "GOCSPX-test_secret"
|
||||
assert call_args[1]["data"]["grant_type"] == "authorization_code"
|
||||
|
||||
# Verify response
|
||||
response_data = response.body
|
||||
import json
|
||||
token_data = json.loads(response_data)
|
||||
assert token_data["access_token"] == "ya29.test_access_token"
|
||||
assert token_data["token_type"] == "Bearer"
|
||||
Reference in New Issue
Block a user