mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-27 05:07:36 +00:00
Merge pull request #21992 from BerriAI/litellm_fix_oauth_mcp
fix: Missing OAuth session state
This commit is contained in:
+2
-2
@@ -5832,8 +5832,8 @@ def get_model_info(
|
||||
if value is not None:
|
||||
_model_info[key] = value # type: ignore
|
||||
|
||||
if verbose_logger.isEnabledFor(logging.DEBUG):
|
||||
verbose_logger.debug(f"model_info: {_model_info}")
|
||||
# if verbose_logger.isEnabledFor(logging.DEBUG):
|
||||
# verbose_logger.debug(f"model_info: {_model_info}")
|
||||
|
||||
returned_model_info = ModelInfo(
|
||||
**_model_info, supported_openai_params=supported_openai_params
|
||||
|
||||
@@ -1487,3 +1487,182 @@ async def test_discovery_root_includes_server_name_prefix():
|
||||
assert response["scopes_supported"] == ["read", "write"]
|
||||
finally:
|
||||
global_mcp_server_manager.registry.clear()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_callback_redirects_with_state():
|
||||
"""Test OAuth callback endpoint properly decodes state and redirects to client callback URL."""
|
||||
try:
|
||||
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
|
||||
callback,
|
||||
)
|
||||
except ImportError:
|
||||
pytest.skip("MCP discoverable endpoints not available")
|
||||
|
||||
# Mock the state decoding
|
||||
mock_state_data = {
|
||||
"base_url": "http://localhost:3000/ui/mcp/oauth/callback",
|
||||
"original_state": "test-uuid-state-123",
|
||||
"code_challenge": "test_challenge",
|
||||
"code_challenge_method": "S256",
|
||||
"client_redirect_uri": "http://localhost:3000/ui/mcp/oauth/callback",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"litellm.proxy._experimental.mcp_server.discoverable_endpoints.decode_state_hash"
|
||||
) as mock_decode:
|
||||
mock_decode.return_value = mock_state_data
|
||||
|
||||
# Call callback endpoint with code and state
|
||||
response = await callback(
|
||||
code="test_authorization_code_12345",
|
||||
state="encrypted_state_value",
|
||||
)
|
||||
|
||||
# Should redirect to the client callback URL with code and original state
|
||||
assert response.status_code == 302
|
||||
assert "http://localhost:3000/ui/mcp/oauth/callback" in response.headers["location"]
|
||||
assert "code=test_authorization_code_12345" in response.headers["location"]
|
||||
assert "state=test-uuid-state-123" in response.headers["location"]
|
||||
|
||||
# Verify state was decoded
|
||||
mock_decode.assert_called_once_with("encrypted_state_value")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_callback_handles_invalid_state():
|
||||
"""Test OAuth callback returns error page when state decryption fails."""
|
||||
try:
|
||||
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
|
||||
callback,
|
||||
)
|
||||
except ImportError:
|
||||
pytest.skip("MCP discoverable endpoints not available")
|
||||
|
||||
# Mock state decoding to raise an exception
|
||||
with patch(
|
||||
"litellm.proxy._experimental.mcp_server.discoverable_endpoints.decode_state_hash"
|
||||
) as mock_decode:
|
||||
mock_decode.side_effect = Exception("Failed to decrypt state")
|
||||
|
||||
# Call callback endpoint with invalid state
|
||||
response = await callback(
|
||||
code="test_code",
|
||||
state="invalid_encrypted_state",
|
||||
)
|
||||
|
||||
# Should return HTML error page
|
||||
assert response.status_code == 200
|
||||
assert "Authentication incomplete" in response.body.decode()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_authorize_includes_scopes_from_server_config():
|
||||
"""Test that authorize endpoint includes scopes from server configuration."""
|
||||
try:
|
||||
from fastapi import Request
|
||||
|
||||
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
|
||||
authorize_with_server,
|
||||
)
|
||||
from litellm.proxy._types import MCPTransport
|
||||
from litellm.types.mcp import MCPAuth
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPServer
|
||||
except ImportError:
|
||||
pytest.skip("MCP discoverable endpoints not available")
|
||||
|
||||
# Create server with specific scopes (e.g., GitLab requires 'ai_workflows')
|
||||
oauth_server = MCPServer(
|
||||
server_id="gitlab_server",
|
||||
name="gitlab",
|
||||
server_name="gitlab",
|
||||
transport=MCPTransport.http,
|
||||
auth_type=MCPAuth.oauth2,
|
||||
authorization_url="https://gitlab.com/oauth/authorize",
|
||||
token_url="https://gitlab.com/oauth/token",
|
||||
client_id="test_client",
|
||||
scopes=["api", "read_user", "ai_workflows"], # GitLab-specific scopes
|
||||
)
|
||||
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.base_url = "https://litellm.example.com/"
|
||||
mock_request.headers = {}
|
||||
|
||||
with patch(
|
||||
"litellm.proxy._experimental.mcp_server.discoverable_endpoints.encrypt_value_helper"
|
||||
) as mock_encrypt:
|
||||
mock_encrypt.return_value = "encrypted_state"
|
||||
|
||||
# Call authorize without explicit scope parameter
|
||||
response = await authorize_with_server(
|
||||
request=mock_request,
|
||||
mcp_server=oauth_server,
|
||||
client_id="test_client",
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
state="test_state",
|
||||
code_challenge="test_challenge",
|
||||
code_challenge_method="S256",
|
||||
response_type="code",
|
||||
scope=None, # No scope in request, should use server's scopes
|
||||
)
|
||||
|
||||
# Should redirect with scopes from server config
|
||||
assert response.status_code in (307, 302)
|
||||
redirect_url = response.headers["location"]
|
||||
assert "scope=api+read_user+ai_workflows" in redirect_url or "scope=api%20read_user%20ai_workflows" in redirect_url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_authorize_prefers_request_scope_over_server_config():
|
||||
"""Test that explicit scope parameter takes precedence over server configuration."""
|
||||
try:
|
||||
from fastapi import Request
|
||||
|
||||
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
|
||||
authorize_with_server,
|
||||
)
|
||||
from litellm.proxy._types import MCPTransport
|
||||
from litellm.types.mcp import MCPAuth
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPServer
|
||||
except ImportError:
|
||||
pytest.skip("MCP discoverable endpoints not available")
|
||||
|
||||
oauth_server = MCPServer(
|
||||
server_id="test_server",
|
||||
name="test",
|
||||
server_name="test",
|
||||
transport=MCPTransport.http,
|
||||
auth_type=MCPAuth.oauth2,
|
||||
authorization_url="https://provider.com/oauth/authorize",
|
||||
token_url="https://provider.com/oauth/token",
|
||||
client_id="test_client",
|
||||
scopes=["default_scope1", "default_scope2"],
|
||||
)
|
||||
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.base_url = "https://litellm.example.com/"
|
||||
mock_request.headers = {}
|
||||
|
||||
with patch(
|
||||
"litellm.proxy._experimental.mcp_server.discoverable_endpoints.encrypt_value_helper"
|
||||
) as mock_encrypt:
|
||||
mock_encrypt.return_value = "encrypted_state"
|
||||
|
||||
# Call authorize WITH explicit scope parameter
|
||||
response = await authorize_with_server(
|
||||
request=mock_request,
|
||||
mcp_server=oauth_server,
|
||||
client_id="test_client",
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
state="test_state",
|
||||
code_challenge="test_challenge",
|
||||
code_challenge_method="S256",
|
||||
response_type="code",
|
||||
scope="custom_scope1 custom_scope2", # Explicit scope should take precedence
|
||||
)
|
||||
|
||||
# Should use the explicit scope, not server config
|
||||
assert response.status_code in (307, 302)
|
||||
redirect_url = response.headers["location"]
|
||||
assert "scope=custom_scope1+custom_scope2" in redirect_url or "scope=custom_scope1%20custom_scope2" in redirect_url
|
||||
assert "default_scope" not in redirect_url
|
||||
|
||||
@@ -41,13 +41,16 @@ const McpOAuthCallbackContent = () => {
|
||||
}
|
||||
|
||||
try {
|
||||
// Store in both sessionStorage and localStorage for redundancy
|
||||
window.sessionStorage.setItem(RESULT_STORAGE_KEY, JSON.stringify(payload));
|
||||
window.localStorage.setItem(RESULT_STORAGE_KEY, JSON.stringify(payload));
|
||||
} catch (err) {
|
||||
console.error("Failed to persist OAuth callback payload", err);
|
||||
// Silently ignore storage errors
|
||||
}
|
||||
|
||||
const returnUrl = window.sessionStorage.getItem(RETURN_URL_STORAGE_KEY);
|
||||
console.info("[MCP OAuth callback] returnUrl", returnUrl);
|
||||
// Check both sessionStorage and localStorage for return URL
|
||||
const returnUrl = window.sessionStorage.getItem(RETURN_URL_STORAGE_KEY) ||
|
||||
window.localStorage.getItem(RETURN_URL_STORAGE_KEY);
|
||||
const destination = returnUrl || resolveDefaultRedirect();
|
||||
window.location.replace(destination);
|
||||
}, [payload]);
|
||||
|
||||
@@ -129,6 +129,21 @@ const CreateMCPServer: React.FC<CreateMCPServerProps> = ({
|
||||
},
|
||||
onTokenReceived: (token) => {
|
||||
setOauthAccessToken(token?.access_token ?? null);
|
||||
|
||||
if (token?.access_token) {
|
||||
const credentials = {
|
||||
access_token: token.access_token,
|
||||
...(token.refresh_token && { refresh_token: token.refresh_token }),
|
||||
...(token.expires_in && { expires_in: token.expires_in }),
|
||||
...(token.scope && { scope: token.scope }),
|
||||
};
|
||||
|
||||
form.setFieldsValue({ credentials });
|
||||
|
||||
NotificationsManager.success(
|
||||
"OAuth authorization successful! Please click 'Create MCP Server' to save the configuration."
|
||||
);
|
||||
}
|
||||
},
|
||||
onBeforeRedirect: persistCreateUiState,
|
||||
});
|
||||
|
||||
@@ -117,6 +117,21 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
|
||||
},
|
||||
onTokenReceived: (token) => {
|
||||
setOauthAccessToken(token?.access_token ?? null);
|
||||
|
||||
if (token?.access_token) {
|
||||
const credentials = {
|
||||
access_token: token.access_token,
|
||||
...(token.refresh_token && { refresh_token: token.refresh_token }),
|
||||
...(token.expires_in && { expires_in: token.expires_in }),
|
||||
...(token.scope && { scope: token.scope }),
|
||||
};
|
||||
|
||||
form.setFieldsValue({ credentials });
|
||||
|
||||
NotificationsManager.success(
|
||||
"OAuth authorization successful! Please click 'Update MCP Server' to save the credentials."
|
||||
);
|
||||
}
|
||||
},
|
||||
onBeforeRedirect: persistEditUiState,
|
||||
});
|
||||
@@ -234,8 +249,13 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
|
||||
}
|
||||
}, [mcpServer]);
|
||||
|
||||
// Fetch tools when component mounts
|
||||
// Fetch tools when component mounts or when OAuth token is received
|
||||
// But only if the server has been properly saved (has a permanent server_id)
|
||||
useEffect(() => {
|
||||
// Don't fetch if server hasn't been saved yet (no permanent server_id)
|
||||
if (!mcpServer.server_id || mcpServer.server_id.trim() === "") {
|
||||
return;
|
||||
}
|
||||
fetchTools();
|
||||
}, [mcpServer, accessToken, oauthAccessToken]);
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import NotificationsManager from "@/components/molecules/notifications_manager";
|
||||
import {
|
||||
buildMcpOAuthAuthorizeUrl,
|
||||
@@ -61,6 +61,7 @@ export const useMcpOAuthFlow = ({
|
||||
const [status, setStatus] = useState<McpOAuthStatus>("idle");
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [tokenResponse, setTokenResponse] = useState<Record<string, any> | null>(null);
|
||||
const processingRef = useRef(false);
|
||||
|
||||
const FLOW_STATE_KEY = "litellm-mcp-oauth-flow-state";
|
||||
const RESULT_KEY = "litellm-mcp-oauth-result";
|
||||
@@ -75,6 +76,28 @@ export const useMcpOAuthFlow = ({
|
||||
redirectUri: string;
|
||||
};
|
||||
|
||||
const setStorageItem = (key: string, value: string) => {
|
||||
if (typeof window === "undefined") return;
|
||||
try {
|
||||
// Store in both sessionStorage and localStorage for redundancy
|
||||
window.sessionStorage.setItem(key, value);
|
||||
window.localStorage.setItem(key, value);
|
||||
} catch (err) {
|
||||
console.warn(`Failed to set storage item ${key}`, err);
|
||||
}
|
||||
};
|
||||
|
||||
const getStorageItem = (key: string): string | null => {
|
||||
if (typeof window === "undefined") return null;
|
||||
try {
|
||||
// Try sessionStorage first, fall back to localStorage
|
||||
return window.sessionStorage.getItem(key) || window.localStorage.getItem(key);
|
||||
} catch (err) {
|
||||
console.warn(`Failed to get storage item ${key}`, err);
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
const clearStoredFlow = () => {
|
||||
if (typeof window === "undefined") {
|
||||
return;
|
||||
@@ -83,6 +106,9 @@ export const useMcpOAuthFlow = ({
|
||||
window.sessionStorage.removeItem(FLOW_STATE_KEY);
|
||||
window.sessionStorage.removeItem(RESULT_KEY);
|
||||
window.sessionStorage.removeItem(RETURN_URL_KEY);
|
||||
window.localStorage.removeItem(FLOW_STATE_KEY);
|
||||
window.localStorage.removeItem(RESULT_KEY);
|
||||
window.localStorage.removeItem(RETURN_URL_KEY);
|
||||
} catch (err) {
|
||||
console.warn("Failed to clear OAuth storage", err);
|
||||
}
|
||||
@@ -187,10 +213,9 @@ export const useMcpOAuthFlow = ({
|
||||
}
|
||||
|
||||
try {
|
||||
window.sessionStorage.setItem(FLOW_STATE_KEY, JSON.stringify(flowState));
|
||||
window.sessionStorage.setItem(RETURN_URL_KEY, window.location.href);
|
||||
setStorageItem(FLOW_STATE_KEY, JSON.stringify(flowState));
|
||||
setStorageItem(RETURN_URL_KEY, window.location.href);
|
||||
} catch (storageErr) {
|
||||
console.error("Unable to persist OAuth state", storageErr);
|
||||
throw new Error("Unable to access browser storage for OAuth. Please enable storage and retry.");
|
||||
}
|
||||
|
||||
@@ -209,19 +234,28 @@ export const useMcpOAuthFlow = ({
|
||||
return;
|
||||
}
|
||||
|
||||
// Prevent duplicate processing
|
||||
if (processingRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
let payload: Record<string, any> | null = null;
|
||||
let flowState: StoredFlowState | null = null;
|
||||
|
||||
try {
|
||||
const storedPayload = window.sessionStorage.getItem(RESULT_KEY);
|
||||
const storedPayload = getStorageItem(RESULT_KEY);
|
||||
if (!storedPayload) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Mark as processing
|
||||
processingRef.current = true;
|
||||
payload = JSON.parse(storedPayload);
|
||||
flowState = JSON.parse(window.sessionStorage.getItem(FLOW_STATE_KEY) || "null");
|
||||
const storedFlowState = getStorageItem(FLOW_STATE_KEY);
|
||||
flowState = storedFlowState ? JSON.parse(storedFlowState) : null;
|
||||
} catch (err) {
|
||||
console.error("Failed to read OAuth session state", err);
|
||||
clearStoredFlow();
|
||||
processingRef.current = false;
|
||||
setError("Failed to resume OAuth flow. Please retry.");
|
||||
setStatus("error");
|
||||
NotificationsManager.error("Failed to resume OAuth flow. Please retry.");
|
||||
@@ -229,14 +263,26 @@ export const useMcpOAuthFlow = ({
|
||||
}
|
||||
|
||||
if (!payload) {
|
||||
processingRef.current = false;
|
||||
return;
|
||||
}
|
||||
|
||||
window.sessionStorage.removeItem(RESULT_KEY);
|
||||
// Clear the result key after reading it
|
||||
if (typeof window !== "undefined") {
|
||||
try {
|
||||
window.sessionStorage.removeItem(RESULT_KEY);
|
||||
window.localStorage.removeItem(RESULT_KEY);
|
||||
} catch (err) {
|
||||
// Silently ignore storage errors
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
if (!flowState || !flowState.state || !flowState.codeVerifier || !flowState.serverId) {
|
||||
throw new Error("Missing OAuth session state. Please retry.");
|
||||
throw new Error(
|
||||
"OAuth session state was lost. This can happen if you have strict browser privacy settings. " +
|
||||
"Please try again and ensure cookies/storage is enabled."
|
||||
);
|
||||
}
|
||||
if (!payload.state || payload.state !== flowState.state) {
|
||||
throw new Error("OAuth state mismatch. Please retry.");
|
||||
@@ -264,31 +310,21 @@ export const useMcpOAuthFlow = ({
|
||||
setError(null);
|
||||
NotificationsManager.success("OAuth token retrieved successfully");
|
||||
} catch (err) {
|
||||
console.error("OAuth flow failed", err);
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
setError(message);
|
||||
setStatus("error");
|
||||
NotificationsManager.error(message);
|
||||
} finally {
|
||||
clearStoredFlow();
|
||||
// Reset processing flag after a delay to allow UI updates
|
||||
setTimeout(() => {
|
||||
processingRef.current = false;
|
||||
}, 1000);
|
||||
}
|
||||
}, [onTokenReceived]);
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
|
||||
const maybeResume = async () => {
|
||||
if (cancelled) {
|
||||
return;
|
||||
}
|
||||
await resumeOAuthFlow();
|
||||
};
|
||||
|
||||
maybeResume();
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
resumeOAuthFlow();
|
||||
}, [resumeOAuthFlow]);
|
||||
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user