feat: add dynamic OAuth2 metadata discovery for MCP servers (#16676)

* feat: add dynamic OAuth2 metadata discovery for MCP servers

* fix: lint error
This commit is contained in:
YutaSaito
2025-11-15 11:14:43 +09:00
committed by GitHub
parent 6063a75155
commit f487f4e3a9
10 changed files with 474 additions and 80 deletions
+9 -11
View File
@@ -211,11 +211,12 @@ mcp_servers:
oauth2_example:
url: "https://my-mcp-server.com/mcp"
auth_type: "oauth2" # 👈 KEY CHANGE
authorization_url: "https://my-mcp-server.com/oauth/authorize" # optional for client-credentials
token_url: "https://my-mcp-server.com/oauth/token" # required
authorization_url: "https://my-mcp-server.com/oauth/authorize" # optional override
token_url: "https://my-mcp-server.com/oauth/token" # optional override
registration_url: "https://my-mcp-server.com/oauth/register" # optional override
client_id: os.environ/OAUTH_CLIENT_ID
client_secret: os.environ/OAUTH_CLIENT_SECRET
scopes: ["tool.read", "tool.write"] # optional
scopes: ["tool.read", "tool.write"] # optional override
bearer_example:
url: "https://my-mcp-server.com/mcp"
@@ -325,6 +326,10 @@ mcp_servers:
| `spec_path` | Yes | Path or URL to your OpenAPI specification file (JSON or YAML) |
| `auth_type` | No | Authentication type: `none`, `api_key`, `bearer_token`, `basic`, `authorization` |
| `auth_value` | No | Authentication value (required if `auth_type` is set) |
| `authorization_url` | No | For `auth_type: oauth2`. Optional override; if omitted LiteLLM auto-discovers it. |
| `token_url` | No | For `auth_type: oauth2`. Optional override; if omitted LiteLLM auto-discovers it. |
| `registration_url` | No | For `auth_type: oauth2`. Optional override; if omitted LiteLLM auto-discovers it. |
| `scopes` | No | For `auth_type: oauth2`. Optional override; if omitted LiteLLM uses the scopes advertised by the server. |
| `description` | No | Optional description for the MCP server |
| `allowed_tools` | No | List of specific tools to allow (see [MCP Tool Filtering](#mcp-tool-filtering)) |
| `disallowed_tools` | No | List of specific tools to block (see [MCP Tool Filtering](#mcp-tool-filtering)) |
@@ -1224,17 +1229,10 @@ mcp_servers:
github_mcp:
url: "https://api.githubcopilot.com/mcp"
auth_type: oauth2
authorization_url: https://github.com/login/oauth/authorize
token_url: https://github.com/login/oauth/access_token
client_id: os.environ/GITHUB_OAUTH_CLIENT_ID
client_secret: os.environ/GITHUB_OAUTH_CLIENT_SECRET
scopes: ["public_repo", "user:email"]
```
**Note**
In the future, users will only need to specify the `url` of the MCP server.
LiteLLM will automatically resolve the corresponding `authorization_url`, `token_url`, and `registration_url` based on the MCP server metadata (e.g., `.well-known/oauth-authorization-server` or `oauth-protected-resource`).
[**See Claude Code Tutorial**](./tutorials/claude_responses_api#connecting-mcp-servers)
## Using your MCP with client side credentials
@@ -1887,4 +1885,4 @@ async with stdio_client(server_params) as (read, write):
```
</TabItem>
</Tabs>
</Tabs>
@@ -237,11 +237,8 @@ mcp_servers:
github_mcp:
url: "https://api.githubcopilot.com/mcp"
auth_type: oauth2
authorization_url: https://github.com/login/oauth/authorize
token_url: https://github.com/login/oauth/access_token
client_id: os.environ/GITHUB_OAUTH_CLIENT_ID
client_secret: os.environ/GITHUB_OAUTH_CLIENT_SECRET
scopes: ["public_repo", "user:email"]
```
</TabItem>
@@ -255,9 +252,6 @@ atlassian_mcp:
url: "https://mcp.atlassian.com/v1/sse"
transport: "sse"
auth_type: oauth2
authorization_url: https://mcp.atlassian.com/v1/authorize
token_url: https://cf.mcp.atlassian.com/v1/token
registration_url: https://cf.mcp.atlassian.com/v1/register
```
</TabItem>
@@ -10,9 +10,13 @@ import asyncio
import datetime
import hashlib
import json
from typing import Any, Dict, List, Optional, Set, Union, cast
import re
from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
from urllib.parse import urlparse
from fastapi import HTTPException
import httpx
from httpx import HTTPStatusError
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
from mcp.types import CallToolResult
from mcp.types import Tool as MCPTool
@@ -43,7 +47,11 @@ from litellm.proxy.common_utils.encrypt_decrypt_utils import (
)
from litellm.proxy.utils import ProxyLogging
from litellm.types.mcp import MCPAuth, MCPStdioConfig
from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer
from litellm.types.mcp_server.mcp_server_manager import (
MCPInfo,
MCPOAuthMetadata,
MCPServer,
)
def _deserialize_json_dict(data: Any) -> Optional[Dict[str, str]]:
@@ -100,7 +108,7 @@ class MCPServerManager:
"""
return self.config_mcp_servers | self.registry
def load_servers_from_config(
async def load_servers_from_config(
self,
mcp_servers_config: Dict[str, Any],
mcp_aliases: Optional[Dict[str, str]] = None,
@@ -180,35 +188,57 @@ class MCPServerManager:
)()
name_for_prefix = get_server_prefix(temp_server)
server_url = server_config.get("url", None) or ""
# Generate stable server ID based on parameters
server_id = self._generate_stable_server_id(
server_name=server_name,
url=server_config.get("url", None) or "",
url=server_url,
transport=server_config.get("transport", MCPTransport.http),
auth_type=server_config.get("auth_type", None),
alias=alias,
)
auth_type = server_config.get("auth_type", None)
if server_url and auth_type is not None and auth_type == MCPAuth.oauth2:
mcp_oauth_metadata = await self._descovery_metadata(
server_url=server_url,
)
else:
mcp_oauth_metadata = None
resolved_scopes = server_config.get("scopes") or (
mcp_oauth_metadata.scopes if mcp_oauth_metadata else None
)
resolved_authorization_url = server_config.get("authorization_url") or (
mcp_oauth_metadata.authorization_url if mcp_oauth_metadata else None
)
resolved_token_url = server_config.get("token_url") or (
mcp_oauth_metadata.token_url if mcp_oauth_metadata else None
)
resolved_registration_url = server_config.get("registration_url") or (
mcp_oauth_metadata.registration_url if mcp_oauth_metadata else None
)
new_server = MCPServer(
server_id=server_id,
name=name_for_prefix,
alias=alias,
server_name=server_name,
spec_path=server_config.get("spec_path", None),
url=server_config.get("url", None) or "",
url=server_url,
command=server_config.get("command", None) or "",
args=server_config.get("args", None) or [],
env=server_config.get("env", None) or {},
# oauth specific fields
client_id=server_config.get("client_id", None),
client_secret=server_config.get("client_secret", None),
scopes=server_config.get("scopes", None),
authorization_url=server_config.get("authorization_url", None),
token_url=server_config.get("token_url", None),
registration_url=server_config.get("registration_url", None),
scopes=resolved_scopes,
authorization_url=resolved_authorization_url,
token_url=resolved_token_url,
registration_url=resolved_registration_url,
# TODO: utility fn the default values
transport=server_config.get("transport", MCPTransport.http),
auth_type=server_config.get("auth_type", None),
auth_type=auth_type,
authentication_token=server_config.get(
"authentication_token", server_config.get("auth_value", None)
),
@@ -692,6 +722,250 @@ class MCPServerManager:
except Exception:
pass
async def _descovery_metadata(
self,
server_url: str,
) -> Optional[MCPOAuthMetadata]:
"""Discover OAuth metadata by following RFC 9728 (protected resource metadata discovery)."""
try:
async with httpx.AsyncClient(
timeout=10.0, follow_redirects=False
) as client:
response = await client.get(server_url)
response.raise_for_status()
verbose_logger.warning(
"MCP OAuth discovery unexpectedly succeeded for %s; server did not challenge",
server_url,
)
raise RuntimeError("OAuth discovery must not succeed without a challenge")
except HTTPStatusError as exc:
verbose_logger.debug(
"MCP OAuth discovery for %s received status error: %s",
server_url,
exc,
)
header_value: Optional[str] = None
if exc.response is not None:
header_value = exc.response.headers.get(
"WWW-Authenticate"
) or exc.response.headers.get("www-authenticate")
resource_metadata_url, scopes = self._parse_www_authenticate_header(
header_value
)
authorization_servers: List[str] = []
resource_scopes: Optional[List[str]] = None
if resource_metadata_url:
(
authorization_servers,
resource_scopes,
) = await self._fetch_oauth_metadata_from_resource(
resource_metadata_url
)
else:
(
authorization_servers,
resource_scopes,
) = await self._attempt_well_known_discovery(server_url)
metadata = None
if not authorization_servers:
try:
parsed_url = urlparse(server_url)
if parsed_url.scheme and parsed_url.netloc:
authorization_servers = [
f"{parsed_url.scheme}://{parsed_url.netloc}"
]
except Exception:
authorization_servers = []
if authorization_servers:
metadata = await self._fetch_authorization_server_metadata(
authorization_servers
)
preferred_scopes = scopes or resource_scopes
if metadata is None and preferred_scopes:
metadata = MCPOAuthMetadata(scopes=preferred_scopes)
elif metadata is not None and preferred_scopes:
metadata.scopes = preferred_scopes
return metadata
except Exception as exc: # pragma: no cover - network/transient issues
verbose_logger.debug(
"MCP OAuth discovery failed for %s: %s", server_url, exc
)
return None
def _parse_www_authenticate_header(
self, header_value: Optional[str]
) -> Tuple[Optional[str], Optional[List[str]]]:
if not header_value:
return None, None
_, _, params_section = header_value.partition(" ")
params_section = params_section or header_value
param_pattern = re.compile(r"([a-zA-Z0-9_]+)\s*=\s*\"?([^\",]+)\"?")
params: Dict[str, str] = {
match.group(1).lower(): match.group(2).strip()
for match in param_pattern.finditer(params_section)
}
resource_metadata_url = params.get("resource_metadata")
scope_value = params.get("scope")
scopes_list = [s for s in (scope_value.split() if scope_value else []) if s]
scopes = scopes_list or None
return resource_metadata_url, scopes
async def _fetch_oauth_metadata_from_resource(
self, resource_metadata_url: str
) -> Tuple[List[str], Optional[List[str]]]:
if not resource_metadata_url:
return [], None
try:
async with httpx.AsyncClient(timeout=10.0, follow_redirects=True) as client:
response = await client.get(resource_metadata_url)
response.raise_for_status()
data = response.json()
except Exception as exc: # pragma: no cover - network issues
verbose_logger.debug(
"Failed to fetch MCP OAuth metadata from %s: %s",
resource_metadata_url,
exc,
)
return [], None
raw_servers = data.get("authorization_servers")
if isinstance(raw_servers, list):
authorization_servers = [
entry
for entry in raw_servers
if isinstance(entry, str) and entry.strip() != ""
]
else:
authorization_servers = []
scopes = self._extract_scopes(
data.get("scopes_supported") or data.get("scopes")
)
return authorization_servers, scopes
async def _attempt_well_known_discovery(
self, server_url: str
) -> Tuple[List[str], Optional[List[str]]]:
try:
parsed = urlparse(server_url)
except Exception:
return [], None
if not parsed.scheme or not parsed.netloc:
return [], None
base = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path or ""
path = path.strip("/")
candidate_urls: List[str] = []
if path:
candidate_urls.append(f"{base}/.well-known/oauth-protected-resource/{path}")
candidate_urls.append(f"{base}/.well-known/oauth-protected-resource")
for url in candidate_urls:
(
authorization_servers,
scopes,
) = await self._fetch_oauth_metadata_from_resource(url)
if authorization_servers:
return authorization_servers, scopes
return [], None
async def _fetch_authorization_server_metadata(
self, authorization_servers: List[str]
) -> Optional[MCPOAuthMetadata]:
for issuer in authorization_servers:
metadata = await self._fetch_single_authorization_server_metadata(issuer)
if metadata is not None:
return metadata
return None
async def _fetch_single_authorization_server_metadata(
self, issuer_url: str
) -> Optional[MCPOAuthMetadata]:
try:
parsed = urlparse(issuer_url)
except Exception:
return None
if not parsed.scheme or not parsed.netloc:
return None
base = f"{parsed.scheme}://{parsed.netloc}"
path = (parsed.path or "").strip("/")
candidate_urls: List[str] = []
if path:
candidate_urls.append(
f"{base}/.well-known/oauth-authorization-server/{path}"
)
candidate_urls.append(f"{base}/.well-known/openid-configuration/{path}")
candidate_urls.append(f"{base}/.well-known/oauth-authorization-server")
candidate_urls.append(f"{base}/.well-known/openid-configuration")
candidate_urls.append(issuer_url.rstrip("/"))
for url in candidate_urls:
try:
async with httpx.AsyncClient(
timeout=10.0, follow_redirects=True
) as client:
response = await client.get(url)
response.raise_for_status()
data = response.json()
except Exception as exc: # pragma: no cover - network issues
verbose_logger.debug(
"Failed to fetch authorization metadata from %s: %s",
url,
exc,
)
continue
scopes = self._extract_scopes(data.get("scopes_supported"))
metadata = MCPOAuthMetadata(
scopes=scopes,
authorization_url=data.get("authorization_endpoint"),
token_url=data.get("token_endpoint"),
registration_url=data.get("registration_endpoint"),
)
if any(
[
metadata.scopes,
metadata.authorization_url,
metadata.token_url,
metadata.registration_url,
]
):
return metadata
return None
def _extract_scopes(self, scopes_value: Any) -> Optional[List[str]]:
if isinstance(scopes_value, str):
scopes = [s.strip() for s in scopes_value.split() if s.strip()]
return scopes or None
if isinstance(scopes_value, list):
scopes = [s for s in scopes_value if isinstance(s, str) and s.strip()]
return scopes or None
return None
async def _fetch_tools_with_timeout(
self, client: MCPClient, server_name: str
) -> List[MCPTool]:
@@ -721,11 +995,6 @@ class MCPServerManager:
f"Client operation failed for {server_name}: {str(e)}"
)
return []
finally:
try:
await client.disconnect()
except Exception:
pass
try:
return await asyncio.wait_for(_list_tools_task(), timeout=30.0)
+3 -3
View File
@@ -2572,11 +2572,11 @@ class ProxyConfig:
litellm.credential_list = credential_list_dict
## NON-LLM CONFIGS eg. MCP tools, vector stores, etc.
self._init_non_llm_configs(config=config)
await self._init_non_llm_configs(config=config)
return router, router.get_model_list(), general_settings
def _init_non_llm_configs(self, config: dict):
async def _init_non_llm_configs(self, config: dict):
"""
Initialize non-LLM configs eg. MCP tools, vector stores, etc.
"""
@@ -2595,7 +2595,7 @@ class ProxyConfig:
litellm_settings = config.get("litellm_settings", {})
mcp_aliases = litellm_settings.get("mcp_aliases", None)
global_mcp_server_manager.load_servers_from_config(
await global_mcp_server_manager.load_servers_from_config(
mcp_servers_config, mcp_aliases
)
@@ -7,6 +7,11 @@ from litellm.proxy._types import MCPAuthType, MCPTransportType
# MCPInfo now allows arbitrary additional fields for custom metadata
MCPInfo = Dict[str, Any]
class MCPOAuthMetadata(BaseModel):
scopes: Optional[List[str]] = None
authorization_url: Optional[str] = None
token_url: Optional[str] = None
registration_url: Optional[str] = None
class MCPServer(BaseModel):
server_id: str
+2 -2
View File
@@ -43,7 +43,7 @@ def test_mcp_server_works_without_config_auth_value():
@pytest.mark.parametrize("token_key", ["authentication_token", "auth_value"])
def test_mcp_server_config_auth_value_header_used(token_key):
async def test_mcp_server_config_auth_value_header_used(token_key):
"""Ensure auth header is sent when auth token configured in config"""
config = {
"test_server": {
@@ -55,7 +55,7 @@ def test_mcp_server_config_auth_value_header_used(token_key):
}
manager = MCPServerManager()
manager.load_servers_from_config(config)
await manager.load_servers_from_config(config)
server = next(iter(manager.config_mcp_servers.values()))
client = manager._create_mcp_client(server)
+3 -3
View File
@@ -66,7 +66,7 @@ async def test_mcp_cost_tracking():
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
# Load the server config
local_mcp_server_manager.load_servers_from_config(
await local_mcp_server_manager.load_servers_from_config(
mcp_servers_config={
"zapier_gmail_server": {
"url": os.getenv("ZAPIER_MCP_HTTPS_SERVER_URL"),
@@ -166,7 +166,7 @@ async def test_mcp_cost_tracking_per_tool():
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
# Load the server config with per-tool costs
local_mcp_server_manager.load_servers_from_config(
await local_mcp_server_manager.load_servers_from_config(
mcp_servers_config={
"test_server": {
"url": os.getenv("ZAPIER_MCP_HTTPS_SERVER_URL"),
@@ -310,7 +310,7 @@ async def test_mcp_tool_call_hook():
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
# Load the server config
local_mcp_server_manager.load_servers_from_config(
await local_mcp_server_manager.load_servers_from_config(
mcp_servers_config={
"zapier_gmail_server": {
"url": os.getenv("ZAPIER_MCP_HTTPS_SERVER_URL"),
+28 -29
View File
@@ -24,7 +24,7 @@ mcp_server_manager = MCPServerManager()
@pytest.mark.asyncio
@pytest.mark.skip(reason="Local only test")
async def test_mcp_server_manager():
mcp_server_manager.load_servers_from_config(
await mcp_server_manager.load_servers_from_config(
{
"zapier_mcp_server": {
"url": os.environ.get("ZAPIER_MCP_SERVER_URL"),
@@ -79,7 +79,7 @@ async def test_mcp_server_manager_https_server():
"litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient",
mock_client_constructor,
):
mcp_server_manager.load_servers_from_config(
await mcp_server_manager.load_servers_from_config(
{
"zapier_mcp_server": {
"url": "https://test-mcp-server.com/mcp",
@@ -189,7 +189,7 @@ async def test_mcp_http_transport_list_tools_mock():
mock_client_constructor,
):
# Load server config with HTTP transport
test_manager.load_servers_from_config(
await test_manager.load_servers_from_config(
{
"test_http_server": {
"url": "https://test-mcp-server.com/mcp",
@@ -266,7 +266,7 @@ async def test_mcp_http_transport_call_tool_mock():
mock_client_constructor,
):
# Load server config with HTTP transport
test_manager.load_servers_from_config(
await test_manager.load_servers_from_config(
{
"test_http_server": {
"url": "https://test-mcp-server.com/mcp",
@@ -333,7 +333,7 @@ async def test_mcp_http_transport_call_tool_error_mock():
mock_client_constructor,
):
# Load server config with HTTP transport
test_manager.load_servers_from_config(
await test_manager.load_servers_from_config(
{
"test_http_server": {
"url": "https://test-mcp-server.com/mcp",
@@ -376,7 +376,7 @@ async def test_mcp_http_transport_tool_not_found():
test_manager = MCPServerManager()
# Load server config
test_manager.load_servers_from_config(
await test_manager.load_servers_from_config(
{
"test_http_server": {
"url": "https://test-mcp-server.com/mcp",
@@ -699,7 +699,7 @@ async def test_list_tools_rest_api_success():
mock_client_constructor,
):
# Load server config into global manager
global_mcp_server_manager.load_servers_from_config(
await global_mcp_server_manager.load_servers_from_config(
{
"test_server": {
"url": "https://test-server.com/mcp",
@@ -878,7 +878,7 @@ async def test_list_tools_only_returns_allowed_servers(monkeypatch):
test_manager = MCPServerManager()
# Setup two servers in the config
test_manager.load_servers_from_config(
await test_manager.load_servers_from_config(
{
"server_a": {
"url": "https://server-a.com/mcp",
@@ -944,12 +944,13 @@ async def test_list_tools_only_returns_allowed_servers(monkeypatch):
assert tools[0].name.startswith(f"{expected_prefix}-")
def test_mcp_server_manager_access_groups_from_config():
@pytest.mark.asyncio
async def test_mcp_server_manager_access_groups_from_config():
"""
Test that access_groups are loaded from config and can be resolved.
"""
test_manager = MCPServerManager()
test_manager.load_servers_from_config(
await test_manager.load_servers_from_config(
{
"config_server": {
"url": "https://config-mcp-server.com/mcp",
@@ -986,15 +987,15 @@ def test_mcp_server_manager_access_groups_from_config():
# Should find config_server for group-a, both for group-b, other_server for group-c
import asyncio
server_ids_a = asyncio.run(
MCPRequestHandler._get_mcp_servers_from_access_groups(["group-a"])
)
server_ids_b = asyncio.run(
MCPRequestHandler._get_mcp_servers_from_access_groups(["group-b"])
)
server_ids_c = asyncio.run(
MCPRequestHandler._get_mcp_servers_from_access_groups(["group-c"])
)
server_ids_a = await MCPRequestHandler._get_mcp_servers_from_access_groups([
"group-a"
])
server_ids_b = await MCPRequestHandler._get_mcp_servers_from_access_groups([
"group-b"
])
server_ids_c = await MCPRequestHandler._get_mcp_servers_from_access_groups([
"group-c"
])
assert any(config_server.server_id == sid for sid in server_ids_a)
assert set(server_ids_b) == set(
[
@@ -1009,7 +1010,7 @@ def test_mcp_server_manager_access_groups_from_config():
)
def test_mcp_server_manager_config_integration_with_database():
async def test_mcp_server_manager_config_integration_with_database():
"""
Test that config-based servers properly integrate with database servers,
specifically testing access_groups and description fields.
@@ -1020,7 +1021,7 @@ def test_mcp_server_manager_config_integration_with_database():
test_manager = MCPServerManager()
# Test 1: Load config with access_groups and description
test_manager.load_servers_from_config(
await test_manager.load_servers_from_config(
{
"config_server_with_groups": {
"url": "https://config-server.com/mcp",
@@ -1081,10 +1082,8 @@ def test_mcp_server_manager_config_integration_with_database():
# Test the method (this tests our second fix)
import asyncio
servers_list = asyncio.run(
test_manager.get_all_mcp_servers_with_health_and_teams(
user_api_key_auth=mock_user_auth
)
servers_list = await test_manager.get_all_mcp_servers_with_health_and_teams(
user_api_key_auth=mock_user_auth
)
# Verify we have the config server properly converted
@@ -1536,7 +1535,7 @@ async def test_mcp_protocol_version_passed_to_client():
mock_client_constructor,
):
# Load a test server
test_manager.load_servers_from_config(
await test_manager.load_servers_from_config(
{
"test_server": {
"url": "https://test-server.com/mcp",
@@ -2423,7 +2422,7 @@ async def test_mcp_server_manager_with_access_groups_integration():
test_manager = MCPServerManager()
# Load servers with access groups
test_manager.load_servers_from_config(
await test_manager.load_servers_from_config(
{
"staff_server": {
"url": "https://staff-server.com/mcp",
@@ -2470,7 +2469,7 @@ async def test_get_allowed_mcp_servers_returns_registry_for_admin():
)
test_manager = MCPServerManager()
test_manager.load_servers_from_config(
await test_manager.load_servers_from_config(
{
"alpha_server": {
"url": "https://alpha.server/mcp",
@@ -2505,7 +2504,7 @@ async def test_get_allowed_mcp_servers_returns_empty_for_non_admin_without_permi
)
test_manager = MCPServerManager()
test_manager.load_servers_from_config(
await test_manager.load_servers_from_config(
{
"alpha_server": {
"url": "https://alpha.server/mcp",
@@ -22,7 +22,7 @@ from litellm.proxy._types import LiteLLM_MCPServerTable
class TestMCPCustomFields:
"""Test custom fields functionality in MCP server configuration."""
def test_custom_fields_preserved_from_config(self):
async def test_custom_fields_preserved_from_config(self):
"""Test that custom fields in mcp_info are preserved when loading from config."""
manager = MCPServerManager()
@@ -46,7 +46,7 @@ class TestMCPCustomFields:
}
# Load servers from config
manager.load_servers_from_config(mock_config)
await manager.load_servers_from_config(mock_config)
# Get the loaded server
servers = list(manager.config_mcp_servers.values())
@@ -109,7 +109,7 @@ class TestMCPCustomFields:
assert mcp_info["metadata"] == {"source": "database"}
assert mcp_info["version"] == "1.0.0"
def test_empty_mcp_info_handled_gracefully(self):
async def test_empty_mcp_info_handled_gracefully(self):
"""Test that empty or missing mcp_info is handled gracefully."""
manager = MCPServerManager()
@@ -122,7 +122,7 @@ class TestMCPCustomFields:
}
}
manager.load_servers_from_config(mock_config)
await manager.load_servers_from_config(mock_config)
servers = list(manager.config_mcp_servers.values())
assert len(servers) == 1
@@ -133,7 +133,7 @@ class TestMCPCustomFields:
# Should have default server_name
assert mcp_info["server_name"] == "test_server"
def test_missing_mcp_info_creates_defaults(self):
async def test_missing_mcp_info_creates_defaults(self):
"""Test that missing mcp_info creates appropriate defaults."""
manager = MCPServerManager()
@@ -146,7 +146,7 @@ class TestMCPCustomFields:
}
}
manager.load_servers_from_config(mock_config)
await manager.load_servers_from_config(mock_config)
servers = list(manager.config_mcp_servers.values())
assert len(servers) == 1
@@ -158,7 +158,7 @@ class TestMCPCustomFields:
assert mcp_info["server_name"] == "test_server"
assert mcp_info["description"] == "Server description"
def test_config_description_fallback(self):
async def test_config_description_fallback(self):
"""Test that description from config level is used as fallback."""
manager = MCPServerManager()
@@ -174,7 +174,7 @@ class TestMCPCustomFields:
}
}
manager.load_servers_from_config(mock_config)
await manager.load_servers_from_config(mock_config)
servers = list(manager.config_mcp_servers.values())
server = servers[0]
@@ -184,7 +184,7 @@ class TestMCPCustomFields:
assert mcp_info["description"] == "Config level description"
assert mcp_info["custom_field"] == "custom_value"
def test_mcp_info_description_takes_precedence(self):
async def test_mcp_info_description_takes_precedence(self):
"""Test that description in mcp_info takes precedence over config level."""
manager = MCPServerManager()
@@ -201,7 +201,7 @@ class TestMCPCustomFields:
}
}
manager.load_servers_from_config(mock_config)
await manager.load_servers_from_config(mock_config)
servers = list(manager.config_mcp_servers.values())
server = servers[0]
@@ -8,12 +8,14 @@ from fastapi import HTTPException
# Add the parent directory to the path so we can import litellm
sys.path.insert(0, "../../../../../")
import httpx
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
MCPServerManager,
_deserialize_json_dict,
)
from litellm.proxy._types import LiteLLM_MCPServerTable, MCPTransport
from litellm.types.mcp_server.mcp_server_manager import MCPServer
from litellm.types.mcp import MCPAuth
from litellm.types.mcp_server.mcp_server_manager import MCPServer, MCPOAuthMetadata
class TestMCPServerManager:
@@ -218,6 +220,133 @@ class TestMCPServerManager:
assert len(result) == 1
assert result[0].name == "github_tool_1"
@pytest.mark.asyncio
async def test_fetch_oauth_metadata_from_resource_returns_servers_and_scopes(self):
manager = MCPServerManager()
mock_response = MagicMock()
mock_response.json.return_value = {
"authorization_servers": [
"https://auth1.example.com",
"https://auth2.example.com",
],
"scopes_supported": ["read", "write"],
}
mock_response.raise_for_status = MagicMock()
mock_client = MagicMock()
mock_client.get = AsyncMock(return_value=mock_response)
async_client_context = MagicMock()
async_client_context.__aenter__ = AsyncMock(return_value=mock_client)
async_client_context.__aexit__ = AsyncMock(return_value=None)
with patch(
"litellm.proxy._experimental.mcp_server.mcp_server_manager.httpx.AsyncClient",
return_value=async_client_context,
):
servers, scopes = await manager._fetch_oauth_metadata_from_resource(
"https://protected.example.com/.well-known/oauth"
)
assert servers == [
"https://auth1.example.com",
"https://auth2.example.com",
]
assert scopes == ["read", "write"]
@pytest.mark.asyncio
async def test_descovery_metadata_falls_back_to_origin_when_no_auth_servers(self):
manager = MCPServerManager()
server_url = "https://example.com/public/mcp"
request = httpx.Request("GET", server_url)
response_obj = httpx.Response(
status_code=401,
request=request,
headers={"WWW-Authenticate": 'Bearer scope="read"'},
)
def raise_http_error():
raise httpx.HTTPStatusError(
"unauthorized", request=request, response=response_obj
)
response_obj.raise_for_status = MagicMock(side_effect=raise_http_error)
mock_client = MagicMock()
mock_client.get = AsyncMock(return_value=response_obj)
async_client_context = MagicMock()
async_client_context.__aenter__ = AsyncMock(return_value=mock_client)
async_client_context.__aexit__ = AsyncMock(return_value=None)
mock_metadata = MCPOAuthMetadata(
scopes=None,
authorization_url="https://example.com/auth",
token_url="https://example.com/token",
registration_url=None,
)
with patch(
"litellm.proxy._experimental.mcp_server.mcp_server_manager.httpx.AsyncClient",
return_value=async_client_context,
), patch.object(
manager,
"_fetch_oauth_metadata_from_resource",
AsyncMock(return_value=([], None)),
), patch.object(
manager,
"_attempt_well_known_discovery",
AsyncMock(return_value=([], None)),
), patch.object(
manager,
"_fetch_authorization_server_metadata",
AsyncMock(return_value=mock_metadata),
) as mock_fetch_auth:
result = await manager._descovery_metadata(server_url)
mock_fetch_auth.assert_awaited_once_with(["https://example.com"])
assert result is mock_metadata
assert result.scopes == ["read"]
@pytest.mark.asyncio
async def test_load_servers_from_config_overrides_discovery_metadata(self):
manager = MCPServerManager()
discovered_metadata = MCPOAuthMetadata(
scopes=["discovered"],
authorization_url="https://discovered.example.com/auth",
token_url="https://discovered.example.com/token",
registration_url="https://discovered.example.com/register",
)
async def fake_discovery(server_url: str):
assert server_url == "https://example.com/mcp"
return discovered_metadata
manager._descovery_metadata = fake_discovery # type: ignore[attr-defined]
config = {
"example": {
"url": "https://example.com/mcp",
"transport": MCPTransport.http,
"auth_type": MCPAuth.oauth2,
"scopes": ["config"],
"authorization_url": "https://config.example.com/auth",
}
}
await manager.load_servers_from_config(config)
server = next(iter(manager.config_mcp_servers.values()))
assert server.scopes == ["config"] # config overrides discovery
assert server.authorization_url == "https://config.example.com/auth"
assert server.token_url == "https://discovered.example.com/token"
assert (
server.registration_url == "https://discovered.example.com/register"
)
@pytest.mark.asyncio
async def test_list_tools_handles_missing_server_alias(self):
"""Test that list_tools handles servers without alias gracefully"""