mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 22:48:35 +00:00
feat: add dynamic OAuth2 metadata discovery for MCP servers (#16676)
* feat: add dynamic OAuth2 metadata discovery for MCP servers * fix: lint error
This commit is contained in:
@@ -211,11 +211,12 @@ mcp_servers:
|
||||
oauth2_example:
|
||||
url: "https://my-mcp-server.com/mcp"
|
||||
auth_type: "oauth2" # 👈 KEY CHANGE
|
||||
authorization_url: "https://my-mcp-server.com/oauth/authorize" # optional for client-credentials
|
||||
token_url: "https://my-mcp-server.com/oauth/token" # required
|
||||
authorization_url: "https://my-mcp-server.com/oauth/authorize" # optional override
|
||||
token_url: "https://my-mcp-server.com/oauth/token" # optional override
|
||||
registration_url: "https://my-mcp-server.com/oauth/register" # optional override
|
||||
client_id: os.environ/OAUTH_CLIENT_ID
|
||||
client_secret: os.environ/OAUTH_CLIENT_SECRET
|
||||
scopes: ["tool.read", "tool.write"] # optional
|
||||
scopes: ["tool.read", "tool.write"] # optional override
|
||||
|
||||
bearer_example:
|
||||
url: "https://my-mcp-server.com/mcp"
|
||||
@@ -325,6 +326,10 @@ mcp_servers:
|
||||
| `spec_path` | Yes | Path or URL to your OpenAPI specification file (JSON or YAML) |
|
||||
| `auth_type` | No | Authentication type: `none`, `api_key`, `bearer_token`, `basic`, `authorization` |
|
||||
| `auth_value` | No | Authentication value (required if `auth_type` is set) |
|
||||
| `authorization_url` | No | For `auth_type: oauth2`. Optional override; if omitted LiteLLM auto-discovers it. |
|
||||
| `token_url` | No | For `auth_type: oauth2`. Optional override; if omitted LiteLLM auto-discovers it. |
|
||||
| `registration_url` | No | For `auth_type: oauth2`. Optional override; if omitted LiteLLM auto-discovers it. |
|
||||
| `scopes` | No | For `auth_type: oauth2`. Optional override; if omitted LiteLLM uses the scopes advertised by the server. |
|
||||
| `description` | No | Optional description for the MCP server |
|
||||
| `allowed_tools` | No | List of specific tools to allow (see [MCP Tool Filtering](#mcp-tool-filtering)) |
|
||||
| `disallowed_tools` | No | List of specific tools to block (see [MCP Tool Filtering](#mcp-tool-filtering)) |
|
||||
@@ -1224,17 +1229,10 @@ mcp_servers:
|
||||
github_mcp:
|
||||
url: "https://api.githubcopilot.com/mcp"
|
||||
auth_type: oauth2
|
||||
authorization_url: https://github.com/login/oauth/authorize
|
||||
token_url: https://github.com/login/oauth/access_token
|
||||
client_id: os.environ/GITHUB_OAUTH_CLIENT_ID
|
||||
client_secret: os.environ/GITHUB_OAUTH_CLIENT_SECRET
|
||||
scopes: ["public_repo", "user:email"]
|
||||
```
|
||||
|
||||
**Note**
|
||||
In the future, users will only need to specify the `url` of the MCP server.
|
||||
LiteLLM will automatically resolve the corresponding `authorization_url`, `token_url`, and `registration_url` based on the MCP server metadata (e.g., `.well-known/oauth-authorization-server` or `oauth-protected-resource`).
|
||||
|
||||
[**See Claude Code Tutorial**](./tutorials/claude_responses_api#connecting-mcp-servers)
|
||||
|
||||
## Using your MCP with client side credentials
|
||||
@@ -1887,4 +1885,4 @@ async with stdio_client(server_params) as (read, write):
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
</Tabs>
|
||||
|
||||
@@ -237,11 +237,8 @@ mcp_servers:
|
||||
github_mcp:
|
||||
url: "https://api.githubcopilot.com/mcp"
|
||||
auth_type: oauth2
|
||||
authorization_url: https://github.com/login/oauth/authorize
|
||||
token_url: https://github.com/login/oauth/access_token
|
||||
client_id: os.environ/GITHUB_OAUTH_CLIENT_ID
|
||||
client_secret: os.environ/GITHUB_OAUTH_CLIENT_SECRET
|
||||
scopes: ["public_repo", "user:email"]
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
@@ -255,9 +252,6 @@ atlassian_mcp:
|
||||
url: "https://mcp.atlassian.com/v1/sse"
|
||||
transport: "sse"
|
||||
auth_type: oauth2
|
||||
authorization_url: https://mcp.atlassian.com/v1/authorize
|
||||
token_url: https://cf.mcp.atlassian.com/v1/token
|
||||
registration_url: https://cf.mcp.atlassian.com/v1/register
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
@@ -10,9 +10,13 @@ import asyncio
|
||||
import datetime
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Set, Union, cast
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import HTTPException
|
||||
import httpx
|
||||
from httpx import HTTPStatusError
|
||||
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
|
||||
from mcp.types import CallToolResult
|
||||
from mcp.types import Tool as MCPTool
|
||||
@@ -43,7 +47,11 @@ from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||
)
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
from litellm.types.mcp import MCPAuth, MCPStdioConfig
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer
|
||||
from litellm.types.mcp_server.mcp_server_manager import (
|
||||
MCPInfo,
|
||||
MCPOAuthMetadata,
|
||||
MCPServer,
|
||||
)
|
||||
|
||||
|
||||
def _deserialize_json_dict(data: Any) -> Optional[Dict[str, str]]:
|
||||
@@ -100,7 +108,7 @@ class MCPServerManager:
|
||||
"""
|
||||
return self.config_mcp_servers | self.registry
|
||||
|
||||
def load_servers_from_config(
|
||||
async def load_servers_from_config(
|
||||
self,
|
||||
mcp_servers_config: Dict[str, Any],
|
||||
mcp_aliases: Optional[Dict[str, str]] = None,
|
||||
@@ -180,35 +188,57 @@ class MCPServerManager:
|
||||
)()
|
||||
name_for_prefix = get_server_prefix(temp_server)
|
||||
|
||||
server_url = server_config.get("url", None) or ""
|
||||
# Generate stable server ID based on parameters
|
||||
server_id = self._generate_stable_server_id(
|
||||
server_name=server_name,
|
||||
url=server_config.get("url", None) or "",
|
||||
url=server_url,
|
||||
transport=server_config.get("transport", MCPTransport.http),
|
||||
auth_type=server_config.get("auth_type", None),
|
||||
alias=alias,
|
||||
)
|
||||
|
||||
auth_type = server_config.get("auth_type", None)
|
||||
if server_url and auth_type is not None and auth_type == MCPAuth.oauth2:
|
||||
mcp_oauth_metadata = await self._descovery_metadata(
|
||||
server_url=server_url,
|
||||
)
|
||||
else:
|
||||
mcp_oauth_metadata = None
|
||||
|
||||
resolved_scopes = server_config.get("scopes") or (
|
||||
mcp_oauth_metadata.scopes if mcp_oauth_metadata else None
|
||||
)
|
||||
resolved_authorization_url = server_config.get("authorization_url") or (
|
||||
mcp_oauth_metadata.authorization_url if mcp_oauth_metadata else None
|
||||
)
|
||||
resolved_token_url = server_config.get("token_url") or (
|
||||
mcp_oauth_metadata.token_url if mcp_oauth_metadata else None
|
||||
)
|
||||
resolved_registration_url = server_config.get("registration_url") or (
|
||||
mcp_oauth_metadata.registration_url if mcp_oauth_metadata else None
|
||||
)
|
||||
|
||||
new_server = MCPServer(
|
||||
server_id=server_id,
|
||||
name=name_for_prefix,
|
||||
alias=alias,
|
||||
server_name=server_name,
|
||||
spec_path=server_config.get("spec_path", None),
|
||||
url=server_config.get("url", None) or "",
|
||||
url=server_url,
|
||||
command=server_config.get("command", None) or "",
|
||||
args=server_config.get("args", None) or [],
|
||||
env=server_config.get("env", None) or {},
|
||||
# oauth specific fields
|
||||
client_id=server_config.get("client_id", None),
|
||||
client_secret=server_config.get("client_secret", None),
|
||||
scopes=server_config.get("scopes", None),
|
||||
authorization_url=server_config.get("authorization_url", None),
|
||||
token_url=server_config.get("token_url", None),
|
||||
registration_url=server_config.get("registration_url", None),
|
||||
scopes=resolved_scopes,
|
||||
authorization_url=resolved_authorization_url,
|
||||
token_url=resolved_token_url,
|
||||
registration_url=resolved_registration_url,
|
||||
# TODO: utility fn the default values
|
||||
transport=server_config.get("transport", MCPTransport.http),
|
||||
auth_type=server_config.get("auth_type", None),
|
||||
auth_type=auth_type,
|
||||
authentication_token=server_config.get(
|
||||
"authentication_token", server_config.get("auth_value", None)
|
||||
),
|
||||
@@ -692,6 +722,250 @@ class MCPServerManager:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _descovery_metadata(
|
||||
self,
|
||||
server_url: str,
|
||||
) -> Optional[MCPOAuthMetadata]:
|
||||
"""Discover OAuth metadata by following RFC 9728 (protected resource metadata discovery)."""
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=10.0, follow_redirects=False
|
||||
) as client:
|
||||
response = await client.get(server_url)
|
||||
response.raise_for_status()
|
||||
verbose_logger.warning(
|
||||
"MCP OAuth discovery unexpectedly succeeded for %s; server did not challenge",
|
||||
server_url,
|
||||
)
|
||||
raise RuntimeError("OAuth discovery must not succeed without a challenge")
|
||||
except HTTPStatusError as exc:
|
||||
verbose_logger.debug(
|
||||
"MCP OAuth discovery for %s received status error: %s",
|
||||
server_url,
|
||||
exc,
|
||||
)
|
||||
|
||||
header_value: Optional[str] = None
|
||||
if exc.response is not None:
|
||||
header_value = exc.response.headers.get(
|
||||
"WWW-Authenticate"
|
||||
) or exc.response.headers.get("www-authenticate")
|
||||
|
||||
resource_metadata_url, scopes = self._parse_www_authenticate_header(
|
||||
header_value
|
||||
)
|
||||
|
||||
authorization_servers: List[str] = []
|
||||
resource_scopes: Optional[List[str]] = None
|
||||
if resource_metadata_url:
|
||||
(
|
||||
authorization_servers,
|
||||
resource_scopes,
|
||||
) = await self._fetch_oauth_metadata_from_resource(
|
||||
resource_metadata_url
|
||||
)
|
||||
else:
|
||||
(
|
||||
authorization_servers,
|
||||
resource_scopes,
|
||||
) = await self._attempt_well_known_discovery(server_url)
|
||||
|
||||
metadata = None
|
||||
if not authorization_servers:
|
||||
try:
|
||||
parsed_url = urlparse(server_url)
|
||||
if parsed_url.scheme and parsed_url.netloc:
|
||||
authorization_servers = [
|
||||
f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
]
|
||||
except Exception:
|
||||
authorization_servers = []
|
||||
|
||||
if authorization_servers:
|
||||
metadata = await self._fetch_authorization_server_metadata(
|
||||
authorization_servers
|
||||
)
|
||||
|
||||
preferred_scopes = scopes or resource_scopes
|
||||
if metadata is None and preferred_scopes:
|
||||
metadata = MCPOAuthMetadata(scopes=preferred_scopes)
|
||||
elif metadata is not None and preferred_scopes:
|
||||
metadata.scopes = preferred_scopes
|
||||
|
||||
return metadata
|
||||
except Exception as exc: # pragma: no cover - network/transient issues
|
||||
verbose_logger.debug(
|
||||
"MCP OAuth discovery failed for %s: %s", server_url, exc
|
||||
)
|
||||
return None
|
||||
|
||||
def _parse_www_authenticate_header(
|
||||
self, header_value: Optional[str]
|
||||
) -> Tuple[Optional[str], Optional[List[str]]]:
|
||||
if not header_value:
|
||||
return None, None
|
||||
|
||||
_, _, params_section = header_value.partition(" ")
|
||||
params_section = params_section or header_value
|
||||
|
||||
param_pattern = re.compile(r"([a-zA-Z0-9_]+)\s*=\s*\"?([^\",]+)\"?")
|
||||
params: Dict[str, str] = {
|
||||
match.group(1).lower(): match.group(2).strip()
|
||||
for match in param_pattern.finditer(params_section)
|
||||
}
|
||||
|
||||
resource_metadata_url = params.get("resource_metadata")
|
||||
|
||||
scope_value = params.get("scope")
|
||||
scopes_list = [s for s in (scope_value.split() if scope_value else []) if s]
|
||||
scopes = scopes_list or None
|
||||
|
||||
return resource_metadata_url, scopes
|
||||
|
||||
async def _fetch_oauth_metadata_from_resource(
|
||||
self, resource_metadata_url: str
|
||||
) -> Tuple[List[str], Optional[List[str]]]:
|
||||
if not resource_metadata_url:
|
||||
return [], None
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0, follow_redirects=True) as client:
|
||||
response = await client.get(resource_metadata_url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except Exception as exc: # pragma: no cover - network issues
|
||||
verbose_logger.debug(
|
||||
"Failed to fetch MCP OAuth metadata from %s: %s",
|
||||
resource_metadata_url,
|
||||
exc,
|
||||
)
|
||||
return [], None
|
||||
|
||||
raw_servers = data.get("authorization_servers")
|
||||
if isinstance(raw_servers, list):
|
||||
authorization_servers = [
|
||||
entry
|
||||
for entry in raw_servers
|
||||
if isinstance(entry, str) and entry.strip() != ""
|
||||
]
|
||||
else:
|
||||
authorization_servers = []
|
||||
|
||||
scopes = self._extract_scopes(
|
||||
data.get("scopes_supported") or data.get("scopes")
|
||||
)
|
||||
|
||||
return authorization_servers, scopes
|
||||
|
||||
async def _attempt_well_known_discovery(
|
||||
self, server_url: str
|
||||
) -> Tuple[List[str], Optional[List[str]]]:
|
||||
try:
|
||||
parsed = urlparse(server_url)
|
||||
except Exception:
|
||||
return [], None
|
||||
|
||||
if not parsed.scheme or not parsed.netloc:
|
||||
return [], None
|
||||
|
||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||
path = parsed.path or ""
|
||||
path = path.strip("/")
|
||||
|
||||
candidate_urls: List[str] = []
|
||||
if path:
|
||||
candidate_urls.append(f"{base}/.well-known/oauth-protected-resource/{path}")
|
||||
candidate_urls.append(f"{base}/.well-known/oauth-protected-resource")
|
||||
|
||||
for url in candidate_urls:
|
||||
(
|
||||
authorization_servers,
|
||||
scopes,
|
||||
) = await self._fetch_oauth_metadata_from_resource(url)
|
||||
if authorization_servers:
|
||||
return authorization_servers, scopes
|
||||
|
||||
return [], None
|
||||
|
||||
async def _fetch_authorization_server_metadata(
|
||||
self, authorization_servers: List[str]
|
||||
) -> Optional[MCPOAuthMetadata]:
|
||||
for issuer in authorization_servers:
|
||||
metadata = await self._fetch_single_authorization_server_metadata(issuer)
|
||||
if metadata is not None:
|
||||
return metadata
|
||||
return None
|
||||
|
||||
async def _fetch_single_authorization_server_metadata(
|
||||
self, issuer_url: str
|
||||
) -> Optional[MCPOAuthMetadata]:
|
||||
try:
|
||||
parsed = urlparse(issuer_url)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if not parsed.scheme or not parsed.netloc:
|
||||
return None
|
||||
|
||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||
path = (parsed.path or "").strip("/")
|
||||
|
||||
candidate_urls: List[str] = []
|
||||
if path:
|
||||
candidate_urls.append(
|
||||
f"{base}/.well-known/oauth-authorization-server/{path}"
|
||||
)
|
||||
candidate_urls.append(f"{base}/.well-known/openid-configuration/{path}")
|
||||
candidate_urls.append(f"{base}/.well-known/oauth-authorization-server")
|
||||
candidate_urls.append(f"{base}/.well-known/openid-configuration")
|
||||
candidate_urls.append(issuer_url.rstrip("/"))
|
||||
|
||||
for url in candidate_urls:
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=10.0, follow_redirects=True
|
||||
) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except Exception as exc: # pragma: no cover - network issues
|
||||
verbose_logger.debug(
|
||||
"Failed to fetch authorization metadata from %s: %s",
|
||||
url,
|
||||
exc,
|
||||
)
|
||||
continue
|
||||
|
||||
scopes = self._extract_scopes(data.get("scopes_supported"))
|
||||
metadata = MCPOAuthMetadata(
|
||||
scopes=scopes,
|
||||
authorization_url=data.get("authorization_endpoint"),
|
||||
token_url=data.get("token_endpoint"),
|
||||
registration_url=data.get("registration_endpoint"),
|
||||
)
|
||||
|
||||
if any(
|
||||
[
|
||||
metadata.scopes,
|
||||
metadata.authorization_url,
|
||||
metadata.token_url,
|
||||
metadata.registration_url,
|
||||
]
|
||||
):
|
||||
return metadata
|
||||
|
||||
return None
|
||||
|
||||
def _extract_scopes(self, scopes_value: Any) -> Optional[List[str]]:
|
||||
if isinstance(scopes_value, str):
|
||||
scopes = [s.strip() for s in scopes_value.split() if s.strip()]
|
||||
return scopes or None
|
||||
if isinstance(scopes_value, list):
|
||||
scopes = [s for s in scopes_value if isinstance(s, str) and s.strip()]
|
||||
return scopes or None
|
||||
return None
|
||||
|
||||
async def _fetch_tools_with_timeout(
|
||||
self, client: MCPClient, server_name: str
|
||||
) -> List[MCPTool]:
|
||||
@@ -721,11 +995,6 @@ class MCPServerManager:
|
||||
f"Client operation failed for {server_name}: {str(e)}"
|
||||
)
|
||||
return []
|
||||
finally:
|
||||
try:
|
||||
await client.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(_list_tools_task(), timeout=30.0)
|
||||
|
||||
@@ -2572,11 +2572,11 @@ class ProxyConfig:
|
||||
litellm.credential_list = credential_list_dict
|
||||
|
||||
## NON-LLM CONFIGS eg. MCP tools, vector stores, etc.
|
||||
self._init_non_llm_configs(config=config)
|
||||
await self._init_non_llm_configs(config=config)
|
||||
|
||||
return router, router.get_model_list(), general_settings
|
||||
|
||||
def _init_non_llm_configs(self, config: dict):
|
||||
async def _init_non_llm_configs(self, config: dict):
|
||||
"""
|
||||
Initialize non-LLM configs eg. MCP tools, vector stores, etc.
|
||||
"""
|
||||
@@ -2595,7 +2595,7 @@ class ProxyConfig:
|
||||
litellm_settings = config.get("litellm_settings", {})
|
||||
mcp_aliases = litellm_settings.get("mcp_aliases", None)
|
||||
|
||||
global_mcp_server_manager.load_servers_from_config(
|
||||
await global_mcp_server_manager.load_servers_from_config(
|
||||
mcp_servers_config, mcp_aliases
|
||||
)
|
||||
|
||||
|
||||
@@ -7,6 +7,11 @@ from litellm.proxy._types import MCPAuthType, MCPTransportType
|
||||
# MCPInfo now allows arbitrary additional fields for custom metadata
|
||||
MCPInfo = Dict[str, Any]
|
||||
|
||||
class MCPOAuthMetadata(BaseModel):
|
||||
scopes: Optional[List[str]] = None
|
||||
authorization_url: Optional[str] = None
|
||||
token_url: Optional[str] = None
|
||||
registration_url: Optional[str] = None
|
||||
|
||||
class MCPServer(BaseModel):
|
||||
server_id: str
|
||||
|
||||
@@ -43,7 +43,7 @@ def test_mcp_server_works_without_config_auth_value():
|
||||
|
||||
|
||||
@pytest.mark.parametrize("token_key", ["authentication_token", "auth_value"])
|
||||
def test_mcp_server_config_auth_value_header_used(token_key):
|
||||
async def test_mcp_server_config_auth_value_header_used(token_key):
|
||||
"""Ensure auth header is sent when auth token configured in config"""
|
||||
config = {
|
||||
"test_server": {
|
||||
@@ -55,7 +55,7 @@ def test_mcp_server_config_auth_value_header_used(token_key):
|
||||
}
|
||||
|
||||
manager = MCPServerManager()
|
||||
manager.load_servers_from_config(config)
|
||||
await manager.load_servers_from_config(config)
|
||||
|
||||
server = next(iter(manager.config_mcp_servers.values()))
|
||||
client = manager._create_mcp_client(server)
|
||||
|
||||
@@ -66,7 +66,7 @@ async def test_mcp_cost_tracking():
|
||||
|
||||
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
|
||||
# Load the server config
|
||||
local_mcp_server_manager.load_servers_from_config(
|
||||
await local_mcp_server_manager.load_servers_from_config(
|
||||
mcp_servers_config={
|
||||
"zapier_gmail_server": {
|
||||
"url": os.getenv("ZAPIER_MCP_HTTPS_SERVER_URL"),
|
||||
@@ -166,7 +166,7 @@ async def test_mcp_cost_tracking_per_tool():
|
||||
|
||||
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
|
||||
# Load the server config with per-tool costs
|
||||
local_mcp_server_manager.load_servers_from_config(
|
||||
await local_mcp_server_manager.load_servers_from_config(
|
||||
mcp_servers_config={
|
||||
"test_server": {
|
||||
"url": os.getenv("ZAPIER_MCP_HTTPS_SERVER_URL"),
|
||||
@@ -310,7 +310,7 @@ async def test_mcp_tool_call_hook():
|
||||
|
||||
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
|
||||
# Load the server config
|
||||
local_mcp_server_manager.load_servers_from_config(
|
||||
await local_mcp_server_manager.load_servers_from_config(
|
||||
mcp_servers_config={
|
||||
"zapier_gmail_server": {
|
||||
"url": os.getenv("ZAPIER_MCP_HTTPS_SERVER_URL"),
|
||||
|
||||
@@ -24,7 +24,7 @@ mcp_server_manager = MCPServerManager()
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="Local only test")
|
||||
async def test_mcp_server_manager():
|
||||
mcp_server_manager.load_servers_from_config(
|
||||
await mcp_server_manager.load_servers_from_config(
|
||||
{
|
||||
"zapier_mcp_server": {
|
||||
"url": os.environ.get("ZAPIER_MCP_SERVER_URL"),
|
||||
@@ -79,7 +79,7 @@ async def test_mcp_server_manager_https_server():
|
||||
"litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient",
|
||||
mock_client_constructor,
|
||||
):
|
||||
mcp_server_manager.load_servers_from_config(
|
||||
await mcp_server_manager.load_servers_from_config(
|
||||
{
|
||||
"zapier_mcp_server": {
|
||||
"url": "https://test-mcp-server.com/mcp",
|
||||
@@ -189,7 +189,7 @@ async def test_mcp_http_transport_list_tools_mock():
|
||||
mock_client_constructor,
|
||||
):
|
||||
# Load server config with HTTP transport
|
||||
test_manager.load_servers_from_config(
|
||||
await test_manager.load_servers_from_config(
|
||||
{
|
||||
"test_http_server": {
|
||||
"url": "https://test-mcp-server.com/mcp",
|
||||
@@ -266,7 +266,7 @@ async def test_mcp_http_transport_call_tool_mock():
|
||||
mock_client_constructor,
|
||||
):
|
||||
# Load server config with HTTP transport
|
||||
test_manager.load_servers_from_config(
|
||||
await test_manager.load_servers_from_config(
|
||||
{
|
||||
"test_http_server": {
|
||||
"url": "https://test-mcp-server.com/mcp",
|
||||
@@ -333,7 +333,7 @@ async def test_mcp_http_transport_call_tool_error_mock():
|
||||
mock_client_constructor,
|
||||
):
|
||||
# Load server config with HTTP transport
|
||||
test_manager.load_servers_from_config(
|
||||
await test_manager.load_servers_from_config(
|
||||
{
|
||||
"test_http_server": {
|
||||
"url": "https://test-mcp-server.com/mcp",
|
||||
@@ -376,7 +376,7 @@ async def test_mcp_http_transport_tool_not_found():
|
||||
test_manager = MCPServerManager()
|
||||
|
||||
# Load server config
|
||||
test_manager.load_servers_from_config(
|
||||
await test_manager.load_servers_from_config(
|
||||
{
|
||||
"test_http_server": {
|
||||
"url": "https://test-mcp-server.com/mcp",
|
||||
@@ -699,7 +699,7 @@ async def test_list_tools_rest_api_success():
|
||||
mock_client_constructor,
|
||||
):
|
||||
# Load server config into global manager
|
||||
global_mcp_server_manager.load_servers_from_config(
|
||||
await global_mcp_server_manager.load_servers_from_config(
|
||||
{
|
||||
"test_server": {
|
||||
"url": "https://test-server.com/mcp",
|
||||
@@ -878,7 +878,7 @@ async def test_list_tools_only_returns_allowed_servers(monkeypatch):
|
||||
test_manager = MCPServerManager()
|
||||
|
||||
# Setup two servers in the config
|
||||
test_manager.load_servers_from_config(
|
||||
await test_manager.load_servers_from_config(
|
||||
{
|
||||
"server_a": {
|
||||
"url": "https://server-a.com/mcp",
|
||||
@@ -944,12 +944,13 @@ async def test_list_tools_only_returns_allowed_servers(monkeypatch):
|
||||
assert tools[0].name.startswith(f"{expected_prefix}-")
|
||||
|
||||
|
||||
def test_mcp_server_manager_access_groups_from_config():
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_server_manager_access_groups_from_config():
|
||||
"""
|
||||
Test that access_groups are loaded from config and can be resolved.
|
||||
"""
|
||||
test_manager = MCPServerManager()
|
||||
test_manager.load_servers_from_config(
|
||||
await test_manager.load_servers_from_config(
|
||||
{
|
||||
"config_server": {
|
||||
"url": "https://config-mcp-server.com/mcp",
|
||||
@@ -986,15 +987,15 @@ def test_mcp_server_manager_access_groups_from_config():
|
||||
# Should find config_server for group-a, both for group-b, other_server for group-c
|
||||
import asyncio
|
||||
|
||||
server_ids_a = asyncio.run(
|
||||
MCPRequestHandler._get_mcp_servers_from_access_groups(["group-a"])
|
||||
)
|
||||
server_ids_b = asyncio.run(
|
||||
MCPRequestHandler._get_mcp_servers_from_access_groups(["group-b"])
|
||||
)
|
||||
server_ids_c = asyncio.run(
|
||||
MCPRequestHandler._get_mcp_servers_from_access_groups(["group-c"])
|
||||
)
|
||||
server_ids_a = await MCPRequestHandler._get_mcp_servers_from_access_groups([
|
||||
"group-a"
|
||||
])
|
||||
server_ids_b = await MCPRequestHandler._get_mcp_servers_from_access_groups([
|
||||
"group-b"
|
||||
])
|
||||
server_ids_c = await MCPRequestHandler._get_mcp_servers_from_access_groups([
|
||||
"group-c"
|
||||
])
|
||||
assert any(config_server.server_id == sid for sid in server_ids_a)
|
||||
assert set(server_ids_b) == set(
|
||||
[
|
||||
@@ -1009,7 +1010,7 @@ def test_mcp_server_manager_access_groups_from_config():
|
||||
)
|
||||
|
||||
|
||||
def test_mcp_server_manager_config_integration_with_database():
|
||||
async def test_mcp_server_manager_config_integration_with_database():
|
||||
"""
|
||||
Test that config-based servers properly integrate with database servers,
|
||||
specifically testing access_groups and description fields.
|
||||
@@ -1020,7 +1021,7 @@ def test_mcp_server_manager_config_integration_with_database():
|
||||
test_manager = MCPServerManager()
|
||||
|
||||
# Test 1: Load config with access_groups and description
|
||||
test_manager.load_servers_from_config(
|
||||
await test_manager.load_servers_from_config(
|
||||
{
|
||||
"config_server_with_groups": {
|
||||
"url": "https://config-server.com/mcp",
|
||||
@@ -1081,10 +1082,8 @@ def test_mcp_server_manager_config_integration_with_database():
|
||||
# Test the method (this tests our second fix)
|
||||
import asyncio
|
||||
|
||||
servers_list = asyncio.run(
|
||||
test_manager.get_all_mcp_servers_with_health_and_teams(
|
||||
user_api_key_auth=mock_user_auth
|
||||
)
|
||||
servers_list = await test_manager.get_all_mcp_servers_with_health_and_teams(
|
||||
user_api_key_auth=mock_user_auth
|
||||
)
|
||||
|
||||
# Verify we have the config server properly converted
|
||||
@@ -1536,7 +1535,7 @@ async def test_mcp_protocol_version_passed_to_client():
|
||||
mock_client_constructor,
|
||||
):
|
||||
# Load a test server
|
||||
test_manager.load_servers_from_config(
|
||||
await test_manager.load_servers_from_config(
|
||||
{
|
||||
"test_server": {
|
||||
"url": "https://test-server.com/mcp",
|
||||
@@ -2423,7 +2422,7 @@ async def test_mcp_server_manager_with_access_groups_integration():
|
||||
test_manager = MCPServerManager()
|
||||
|
||||
# Load servers with access groups
|
||||
test_manager.load_servers_from_config(
|
||||
await test_manager.load_servers_from_config(
|
||||
{
|
||||
"staff_server": {
|
||||
"url": "https://staff-server.com/mcp",
|
||||
@@ -2470,7 +2469,7 @@ async def test_get_allowed_mcp_servers_returns_registry_for_admin():
|
||||
)
|
||||
|
||||
test_manager = MCPServerManager()
|
||||
test_manager.load_servers_from_config(
|
||||
await test_manager.load_servers_from_config(
|
||||
{
|
||||
"alpha_server": {
|
||||
"url": "https://alpha.server/mcp",
|
||||
@@ -2505,7 +2504,7 @@ async def test_get_allowed_mcp_servers_returns_empty_for_non_admin_without_permi
|
||||
)
|
||||
|
||||
test_manager = MCPServerManager()
|
||||
test_manager.load_servers_from_config(
|
||||
await test_manager.load_servers_from_config(
|
||||
{
|
||||
"alpha_server": {
|
||||
"url": "https://alpha.server/mcp",
|
||||
|
||||
@@ -22,7 +22,7 @@ from litellm.proxy._types import LiteLLM_MCPServerTable
|
||||
class TestMCPCustomFields:
|
||||
"""Test custom fields functionality in MCP server configuration."""
|
||||
|
||||
def test_custom_fields_preserved_from_config(self):
|
||||
async def test_custom_fields_preserved_from_config(self):
|
||||
"""Test that custom fields in mcp_info are preserved when loading from config."""
|
||||
manager = MCPServerManager()
|
||||
|
||||
@@ -46,7 +46,7 @@ class TestMCPCustomFields:
|
||||
}
|
||||
|
||||
# Load servers from config
|
||||
manager.load_servers_from_config(mock_config)
|
||||
await manager.load_servers_from_config(mock_config)
|
||||
|
||||
# Get the loaded server
|
||||
servers = list(manager.config_mcp_servers.values())
|
||||
@@ -109,7 +109,7 @@ class TestMCPCustomFields:
|
||||
assert mcp_info["metadata"] == {"source": "database"}
|
||||
assert mcp_info["version"] == "1.0.0"
|
||||
|
||||
def test_empty_mcp_info_handled_gracefully(self):
|
||||
async def test_empty_mcp_info_handled_gracefully(self):
|
||||
"""Test that empty or missing mcp_info is handled gracefully."""
|
||||
manager = MCPServerManager()
|
||||
|
||||
@@ -122,7 +122,7 @@ class TestMCPCustomFields:
|
||||
}
|
||||
}
|
||||
|
||||
manager.load_servers_from_config(mock_config)
|
||||
await manager.load_servers_from_config(mock_config)
|
||||
|
||||
servers = list(manager.config_mcp_servers.values())
|
||||
assert len(servers) == 1
|
||||
@@ -133,7 +133,7 @@ class TestMCPCustomFields:
|
||||
# Should have default server_name
|
||||
assert mcp_info["server_name"] == "test_server"
|
||||
|
||||
def test_missing_mcp_info_creates_defaults(self):
|
||||
async def test_missing_mcp_info_creates_defaults(self):
|
||||
"""Test that missing mcp_info creates appropriate defaults."""
|
||||
manager = MCPServerManager()
|
||||
|
||||
@@ -146,7 +146,7 @@ class TestMCPCustomFields:
|
||||
}
|
||||
}
|
||||
|
||||
manager.load_servers_from_config(mock_config)
|
||||
await manager.load_servers_from_config(mock_config)
|
||||
|
||||
servers = list(manager.config_mcp_servers.values())
|
||||
assert len(servers) == 1
|
||||
@@ -158,7 +158,7 @@ class TestMCPCustomFields:
|
||||
assert mcp_info["server_name"] == "test_server"
|
||||
assert mcp_info["description"] == "Server description"
|
||||
|
||||
def test_config_description_fallback(self):
|
||||
async def test_config_description_fallback(self):
|
||||
"""Test that description from config level is used as fallback."""
|
||||
manager = MCPServerManager()
|
||||
|
||||
@@ -174,7 +174,7 @@ class TestMCPCustomFields:
|
||||
}
|
||||
}
|
||||
|
||||
manager.load_servers_from_config(mock_config)
|
||||
await manager.load_servers_from_config(mock_config)
|
||||
|
||||
servers = list(manager.config_mcp_servers.values())
|
||||
server = servers[0]
|
||||
@@ -184,7 +184,7 @@ class TestMCPCustomFields:
|
||||
assert mcp_info["description"] == "Config level description"
|
||||
assert mcp_info["custom_field"] == "custom_value"
|
||||
|
||||
def test_mcp_info_description_takes_precedence(self):
|
||||
async def test_mcp_info_description_takes_precedence(self):
|
||||
"""Test that description in mcp_info takes precedence over config level."""
|
||||
manager = MCPServerManager()
|
||||
|
||||
@@ -201,7 +201,7 @@ class TestMCPCustomFields:
|
||||
}
|
||||
}
|
||||
|
||||
manager.load_servers_from_config(mock_config)
|
||||
await manager.load_servers_from_config(mock_config)
|
||||
|
||||
servers = list(manager.config_mcp_servers.values())
|
||||
server = servers[0]
|
||||
|
||||
@@ -8,12 +8,14 @@ from fastapi import HTTPException
|
||||
# Add the parent directory to the path so we can import litellm
|
||||
sys.path.insert(0, "../../../../../")
|
||||
|
||||
import httpx
|
||||
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||
MCPServerManager,
|
||||
_deserialize_json_dict,
|
||||
)
|
||||
from litellm.proxy._types import LiteLLM_MCPServerTable, MCPTransport
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPServer
|
||||
from litellm.types.mcp import MCPAuth
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPServer, MCPOAuthMetadata
|
||||
|
||||
|
||||
class TestMCPServerManager:
|
||||
@@ -218,6 +220,133 @@ class TestMCPServerManager:
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "github_tool_1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_oauth_metadata_from_resource_returns_servers_and_scopes(self):
|
||||
manager = MCPServerManager()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"authorization_servers": [
|
||||
"https://auth1.example.com",
|
||||
"https://auth2.example.com",
|
||||
],
|
||||
"scopes_supported": ["read", "write"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
|
||||
async_client_context = MagicMock()
|
||||
async_client_context.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
async_client_context.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch(
|
||||
"litellm.proxy._experimental.mcp_server.mcp_server_manager.httpx.AsyncClient",
|
||||
return_value=async_client_context,
|
||||
):
|
||||
servers, scopes = await manager._fetch_oauth_metadata_from_resource(
|
||||
"https://protected.example.com/.well-known/oauth"
|
||||
)
|
||||
|
||||
assert servers == [
|
||||
"https://auth1.example.com",
|
||||
"https://auth2.example.com",
|
||||
]
|
||||
assert scopes == ["read", "write"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_descovery_metadata_falls_back_to_origin_when_no_auth_servers(self):
|
||||
manager = MCPServerManager()
|
||||
server_url = "https://example.com/public/mcp"
|
||||
|
||||
request = httpx.Request("GET", server_url)
|
||||
response_obj = httpx.Response(
|
||||
status_code=401,
|
||||
request=request,
|
||||
headers={"WWW-Authenticate": 'Bearer scope="read"'},
|
||||
)
|
||||
|
||||
def raise_http_error():
|
||||
raise httpx.HTTPStatusError(
|
||||
"unauthorized", request=request, response=response_obj
|
||||
)
|
||||
|
||||
response_obj.raise_for_status = MagicMock(side_effect=raise_http_error)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get = AsyncMock(return_value=response_obj)
|
||||
|
||||
async_client_context = MagicMock()
|
||||
async_client_context.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
async_client_context.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_metadata = MCPOAuthMetadata(
|
||||
scopes=None,
|
||||
authorization_url="https://example.com/auth",
|
||||
token_url="https://example.com/token",
|
||||
registration_url=None,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"litellm.proxy._experimental.mcp_server.mcp_server_manager.httpx.AsyncClient",
|
||||
return_value=async_client_context,
|
||||
), patch.object(
|
||||
manager,
|
||||
"_fetch_oauth_metadata_from_resource",
|
||||
AsyncMock(return_value=([], None)),
|
||||
), patch.object(
|
||||
manager,
|
||||
"_attempt_well_known_discovery",
|
||||
AsyncMock(return_value=([], None)),
|
||||
), patch.object(
|
||||
manager,
|
||||
"_fetch_authorization_server_metadata",
|
||||
AsyncMock(return_value=mock_metadata),
|
||||
) as mock_fetch_auth:
|
||||
result = await manager._descovery_metadata(server_url)
|
||||
|
||||
mock_fetch_auth.assert_awaited_once_with(["https://example.com"])
|
||||
assert result is mock_metadata
|
||||
assert result.scopes == ["read"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_servers_from_config_overrides_discovery_metadata(self):
|
||||
manager = MCPServerManager()
|
||||
|
||||
discovered_metadata = MCPOAuthMetadata(
|
||||
scopes=["discovered"],
|
||||
authorization_url="https://discovered.example.com/auth",
|
||||
token_url="https://discovered.example.com/token",
|
||||
registration_url="https://discovered.example.com/register",
|
||||
)
|
||||
|
||||
async def fake_discovery(server_url: str):
|
||||
assert server_url == "https://example.com/mcp"
|
||||
return discovered_metadata
|
||||
|
||||
manager._descovery_metadata = fake_discovery # type: ignore[attr-defined]
|
||||
|
||||
config = {
|
||||
"example": {
|
||||
"url": "https://example.com/mcp",
|
||||
"transport": MCPTransport.http,
|
||||
"auth_type": MCPAuth.oauth2,
|
||||
"scopes": ["config"],
|
||||
"authorization_url": "https://config.example.com/auth",
|
||||
}
|
||||
}
|
||||
|
||||
await manager.load_servers_from_config(config)
|
||||
|
||||
server = next(iter(manager.config_mcp_servers.values()))
|
||||
assert server.scopes == ["config"] # config overrides discovery
|
||||
assert server.authorization_url == "https://config.example.com/auth"
|
||||
assert server.token_url == "https://discovered.example.com/token"
|
||||
assert (
|
||||
server.registration_url == "https://discovered.example.com/register"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools_handles_missing_server_alias(self):
|
||||
"""Test that list_tools handles servers without alias gracefully"""
|
||||
|
||||
Reference in New Issue
Block a user