Merge pull request #21992 from BerriAI/litellm_fix_oauth_mcp

fix: Missing OAuth session state
This commit is contained in:
Sameer Kankute
2026-02-24 19:37:09 +05:30
committed by GitHub
6 changed files with 283 additions and 30 deletions
+2 -2
View File
@@ -5832,8 +5832,8 @@ def get_model_info(
if value is not None:
_model_info[key] = value # type: ignore
if verbose_logger.isEnabledFor(logging.DEBUG):
verbose_logger.debug(f"model_info: {_model_info}")
# if verbose_logger.isEnabledFor(logging.DEBUG):
# verbose_logger.debug(f"model_info: {_model_info}")
returned_model_info = ModelInfo(
**_model_info, supported_openai_params=supported_openai_params
@@ -1487,3 +1487,182 @@ async def test_discovery_root_includes_server_name_prefix():
assert response["scopes_supported"] == ["read", "write"]
finally:
global_mcp_server_manager.registry.clear()
@pytest.mark.asyncio
async def test_oauth_callback_redirects_with_state():
"""Test OAuth callback endpoint properly decodes state and redirects to client callback URL."""
try:
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
callback,
)
except ImportError:
pytest.skip("MCP discoverable endpoints not available")
# Mock the state decoding
mock_state_data = {
"base_url": "http://localhost:3000/ui/mcp/oauth/callback",
"original_state": "test-uuid-state-123",
"code_challenge": "test_challenge",
"code_challenge_method": "S256",
"client_redirect_uri": "http://localhost:3000/ui/mcp/oauth/callback",
}
with patch(
"litellm.proxy._experimental.mcp_server.discoverable_endpoints.decode_state_hash"
) as mock_decode:
mock_decode.return_value = mock_state_data
# Call callback endpoint with code and state
response = await callback(
code="test_authorization_code_12345",
state="encrypted_state_value",
)
# Should redirect to the client callback URL with code and original state
assert response.status_code == 302
assert "http://localhost:3000/ui/mcp/oauth/callback" in response.headers["location"]
assert "code=test_authorization_code_12345" in response.headers["location"]
assert "state=test-uuid-state-123" in response.headers["location"]
# Verify state was decoded
mock_decode.assert_called_once_with("encrypted_state_value")
@pytest.mark.asyncio
async def test_oauth_callback_handles_invalid_state():
"""Test OAuth callback returns error page when state decryption fails."""
try:
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
callback,
)
except ImportError:
pytest.skip("MCP discoverable endpoints not available")
# Mock state decoding to raise an exception
with patch(
"litellm.proxy._experimental.mcp_server.discoverable_endpoints.decode_state_hash"
) as mock_decode:
mock_decode.side_effect = Exception("Failed to decrypt state")
# Call callback endpoint with invalid state
response = await callback(
code="test_code",
state="invalid_encrypted_state",
)
# Should return HTML error page
assert response.status_code == 200
assert "Authentication incomplete" in response.body.decode()
@pytest.mark.asyncio
async def test_oauth_authorize_includes_scopes_from_server_config():
"""Test that authorize endpoint includes scopes from server configuration."""
try:
from fastapi import Request
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
authorize_with_server,
)
from litellm.proxy._types import MCPTransport
from litellm.types.mcp import MCPAuth
from litellm.types.mcp_server.mcp_server_manager import MCPServer
except ImportError:
pytest.skip("MCP discoverable endpoints not available")
# Create server with specific scopes (e.g., GitLab requires 'ai_workflows')
oauth_server = MCPServer(
server_id="gitlab_server",
name="gitlab",
server_name="gitlab",
transport=MCPTransport.http,
auth_type=MCPAuth.oauth2,
authorization_url="https://gitlab.com/oauth/authorize",
token_url="https://gitlab.com/oauth/token",
client_id="test_client",
scopes=["api", "read_user", "ai_workflows"], # GitLab-specific scopes
)
mock_request = MagicMock(spec=Request)
mock_request.base_url = "https://litellm.example.com/"
mock_request.headers = {}
with patch(
"litellm.proxy._experimental.mcp_server.discoverable_endpoints.encrypt_value_helper"
) as mock_encrypt:
mock_encrypt.return_value = "encrypted_state"
# Call authorize without explicit scope parameter
response = await authorize_with_server(
request=mock_request,
mcp_server=oauth_server,
client_id="test_client",
redirect_uri="http://localhost:3000/callback",
state="test_state",
code_challenge="test_challenge",
code_challenge_method="S256",
response_type="code",
scope=None, # No scope in request, should use server's scopes
)
# Should redirect with scopes from server config
assert response.status_code in (307, 302)
redirect_url = response.headers["location"]
assert "scope=api+read_user+ai_workflows" in redirect_url or "scope=api%20read_user%20ai_workflows" in redirect_url
@pytest.mark.asyncio
async def test_oauth_authorize_prefers_request_scope_over_server_config():
"""Test that explicit scope parameter takes precedence over server configuration."""
try:
from fastapi import Request
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
authorize_with_server,
)
from litellm.proxy._types import MCPTransport
from litellm.types.mcp import MCPAuth
from litellm.types.mcp_server.mcp_server_manager import MCPServer
except ImportError:
pytest.skip("MCP discoverable endpoints not available")
oauth_server = MCPServer(
server_id="test_server",
name="test",
server_name="test",
transport=MCPTransport.http,
auth_type=MCPAuth.oauth2,
authorization_url="https://provider.com/oauth/authorize",
token_url="https://provider.com/oauth/token",
client_id="test_client",
scopes=["default_scope1", "default_scope2"],
)
mock_request = MagicMock(spec=Request)
mock_request.base_url = "https://litellm.example.com/"
mock_request.headers = {}
with patch(
"litellm.proxy._experimental.mcp_server.discoverable_endpoints.encrypt_value_helper"
) as mock_encrypt:
mock_encrypt.return_value = "encrypted_state"
# Call authorize WITH explicit scope parameter
response = await authorize_with_server(
request=mock_request,
mcp_server=oauth_server,
client_id="test_client",
redirect_uri="http://localhost:3000/callback",
state="test_state",
code_challenge="test_challenge",
code_challenge_method="S256",
response_type="code",
scope="custom_scope1 custom_scope2", # Explicit scope should take precedence
)
# Should use the explicit scope, not server config
assert response.status_code in (307, 302)
redirect_url = response.headers["location"]
assert "scope=custom_scope1+custom_scope2" in redirect_url or "scope=custom_scope1%20custom_scope2" in redirect_url
assert "default_scope" not in redirect_url
@@ -41,13 +41,16 @@ const McpOAuthCallbackContent = () => {
}
try {
// Store in both sessionStorage and localStorage for redundancy
window.sessionStorage.setItem(RESULT_STORAGE_KEY, JSON.stringify(payload));
window.localStorage.setItem(RESULT_STORAGE_KEY, JSON.stringify(payload));
} catch (err) {
console.error("Failed to persist OAuth callback payload", err);
// Silently ignore storage errors
}
const returnUrl = window.sessionStorage.getItem(RETURN_URL_STORAGE_KEY);
console.info("[MCP OAuth callback] returnUrl", returnUrl);
// Check both sessionStorage and localStorage for return URL
const returnUrl = window.sessionStorage.getItem(RETURN_URL_STORAGE_KEY) ||
window.localStorage.getItem(RETURN_URL_STORAGE_KEY);
const destination = returnUrl || resolveDefaultRedirect();
window.location.replace(destination);
}, [payload]);
@@ -129,6 +129,21 @@ const CreateMCPServer: React.FC<CreateMCPServerProps> = ({
},
onTokenReceived: (token) => {
setOauthAccessToken(token?.access_token ?? null);
if (token?.access_token) {
const credentials = {
access_token: token.access_token,
...(token.refresh_token && { refresh_token: token.refresh_token }),
...(token.expires_in && { expires_in: token.expires_in }),
...(token.scope && { scope: token.scope }),
};
form.setFieldsValue({ credentials });
NotificationsManager.success(
"OAuth authorization successful! Please click 'Create MCP Server' to save the configuration."
);
}
},
onBeforeRedirect: persistCreateUiState,
});
@@ -117,6 +117,21 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
},
onTokenReceived: (token) => {
setOauthAccessToken(token?.access_token ?? null);
if (token?.access_token) {
const credentials = {
access_token: token.access_token,
...(token.refresh_token && { refresh_token: token.refresh_token }),
...(token.expires_in && { expires_in: token.expires_in }),
...(token.scope && { scope: token.scope }),
};
form.setFieldsValue({ credentials });
NotificationsManager.success(
"OAuth authorization successful! Please click 'Update MCP Server' to save the credentials."
);
}
},
onBeforeRedirect: persistEditUiState,
});
@@ -234,8 +249,13 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
}
}, [mcpServer]);
// Fetch tools when component mounts
// Fetch tools when component mounts or when OAuth token is received
// But only if the server has been properly saved (has a permanent server_id)
useEffect(() => {
// Don't fetch if server hasn't been saved yet (no permanent server_id)
if (!mcpServer.server_id || mcpServer.server_id.trim() === "") {
return;
}
fetchTools();
}, [mcpServer, accessToken, oauthAccessToken]);
@@ -1,6 +1,6 @@
"use client";
import { useCallback, useEffect, useState } from "react";
import { useCallback, useEffect, useRef, useState } from "react";
import NotificationsManager from "@/components/molecules/notifications_manager";
import {
buildMcpOAuthAuthorizeUrl,
@@ -61,6 +61,7 @@ export const useMcpOAuthFlow = ({
const [status, setStatus] = useState<McpOAuthStatus>("idle");
const [error, setError] = useState<string | null>(null);
const [tokenResponse, setTokenResponse] = useState<Record<string, any> | null>(null);
const processingRef = useRef(false);
const FLOW_STATE_KEY = "litellm-mcp-oauth-flow-state";
const RESULT_KEY = "litellm-mcp-oauth-result";
@@ -75,6 +76,28 @@ export const useMcpOAuthFlow = ({
redirectUri: string;
};
const setStorageItem = (key: string, value: string) => {
if (typeof window === "undefined") return;
try {
// Store in both sessionStorage and localStorage for redundancy
window.sessionStorage.setItem(key, value);
window.localStorage.setItem(key, value);
} catch (err) {
console.warn(`Failed to set storage item ${key}`, err);
}
};
const getStorageItem = (key: string): string | null => {
if (typeof window === "undefined") return null;
try {
// Try sessionStorage first, fall back to localStorage
return window.sessionStorage.getItem(key) || window.localStorage.getItem(key);
} catch (err) {
console.warn(`Failed to get storage item ${key}`, err);
return null;
}
};
const clearStoredFlow = () => {
if (typeof window === "undefined") {
return;
@@ -83,6 +106,9 @@ export const useMcpOAuthFlow = ({
window.sessionStorage.removeItem(FLOW_STATE_KEY);
window.sessionStorage.removeItem(RESULT_KEY);
window.sessionStorage.removeItem(RETURN_URL_KEY);
window.localStorage.removeItem(FLOW_STATE_KEY);
window.localStorage.removeItem(RESULT_KEY);
window.localStorage.removeItem(RETURN_URL_KEY);
} catch (err) {
console.warn("Failed to clear OAuth storage", err);
}
@@ -187,10 +213,9 @@ export const useMcpOAuthFlow = ({
}
try {
window.sessionStorage.setItem(FLOW_STATE_KEY, JSON.stringify(flowState));
window.sessionStorage.setItem(RETURN_URL_KEY, window.location.href);
setStorageItem(FLOW_STATE_KEY, JSON.stringify(flowState));
setStorageItem(RETURN_URL_KEY, window.location.href);
} catch (storageErr) {
console.error("Unable to persist OAuth state", storageErr);
throw new Error("Unable to access browser storage for OAuth. Please enable storage and retry.");
}
@@ -209,19 +234,28 @@ export const useMcpOAuthFlow = ({
return;
}
// Prevent duplicate processing
if (processingRef.current) {
return;
}
let payload: Record<string, any> | null = null;
let flowState: StoredFlowState | null = null;
try {
const storedPayload = window.sessionStorage.getItem(RESULT_KEY);
const storedPayload = getStorageItem(RESULT_KEY);
if (!storedPayload) {
return;
}
// Mark as processing
processingRef.current = true;
payload = JSON.parse(storedPayload);
flowState = JSON.parse(window.sessionStorage.getItem(FLOW_STATE_KEY) || "null");
const storedFlowState = getStorageItem(FLOW_STATE_KEY);
flowState = storedFlowState ? JSON.parse(storedFlowState) : null;
} catch (err) {
console.error("Failed to read OAuth session state", err);
clearStoredFlow();
processingRef.current = false;
setError("Failed to resume OAuth flow. Please retry.");
setStatus("error");
NotificationsManager.error("Failed to resume OAuth flow. Please retry.");
@@ -229,14 +263,26 @@ export const useMcpOAuthFlow = ({
}
if (!payload) {
processingRef.current = false;
return;
}
window.sessionStorage.removeItem(RESULT_KEY);
// Clear the result key after reading it
if (typeof window !== "undefined") {
try {
window.sessionStorage.removeItem(RESULT_KEY);
window.localStorage.removeItem(RESULT_KEY);
} catch (err) {
// Silently ignore storage errors
}
}
try {
if (!flowState || !flowState.state || !flowState.codeVerifier || !flowState.serverId) {
throw new Error("Missing OAuth session state. Please retry.");
throw new Error(
"OAuth session state was lost. This can happen if you have strict browser privacy settings. " +
"Please try again and ensure cookies/storage is enabled."
);
}
if (!payload.state || payload.state !== flowState.state) {
throw new Error("OAuth state mismatch. Please retry.");
@@ -264,31 +310,21 @@ export const useMcpOAuthFlow = ({
setError(null);
NotificationsManager.success("OAuth token retrieved successfully");
} catch (err) {
console.error("OAuth flow failed", err);
const message = err instanceof Error ? err.message : String(err);
setError(message);
setStatus("error");
NotificationsManager.error(message);
} finally {
clearStoredFlow();
// Reset processing flag after a delay to allow UI updates
setTimeout(() => {
processingRef.current = false;
}, 1000);
}
}, [onTokenReceived]);
useEffect(() => {
let cancelled = false;
const maybeResume = async () => {
if (cancelled) {
return;
}
await resumeOAuthFlow();
};
maybeResume();
return () => {
cancelled = true;
};
resumeOAuthFlow();
}, [resumeOAuthFlow]);
return {