diff --git a/docs/my-website/docs/mcp.md b/docs/my-website/docs/mcp.md index c735b8ecdd..9fd803f434 100644 --- a/docs/my-website/docs/mcp.md +++ b/docs/my-website/docs/mcp.md @@ -211,11 +211,12 @@ mcp_servers: oauth2_example: url: "https://my-mcp-server.com/mcp" auth_type: "oauth2" # 👈 KEY CHANGE - authorization_url: "https://my-mcp-server.com/oauth/authorize" # optional for client-credentials - token_url: "https://my-mcp-server.com/oauth/token" # required + authorization_url: "https://my-mcp-server.com/oauth/authorize" # optional override + token_url: "https://my-mcp-server.com/oauth/token" # optional override + registration_url: "https://my-mcp-server.com/oauth/register" # optional override client_id: os.environ/OAUTH_CLIENT_ID client_secret: os.environ/OAUTH_CLIENT_SECRET - scopes: ["tool.read", "tool.write"] # optional + scopes: ["tool.read", "tool.write"] # optional override bearer_example: url: "https://my-mcp-server.com/mcp" @@ -325,6 +326,10 @@ mcp_servers: | `spec_path` | Yes | Path or URL to your OpenAPI specification file (JSON or YAML) | | `auth_type` | No | Authentication type: `none`, `api_key`, `bearer_token`, `basic`, `authorization` | | `auth_value` | No | Authentication value (required if `auth_type` is set) | +| `authorization_url` | No | For `auth_type: oauth2`. Optional override; if omitted LiteLLM auto-discovers it. | +| `token_url` | No | For `auth_type: oauth2`. Optional override; if omitted LiteLLM auto-discovers it. | +| `registration_url` | No | For `auth_type: oauth2`. Optional override; if omitted LiteLLM auto-discovers it. | +| `scopes` | No | For `auth_type: oauth2`. Optional override; if omitted LiteLLM uses the scopes advertised by the server. | | `description` | No | Optional description for the MCP server | | `allowed_tools` | No | List of specific tools to allow (see [MCP Tool Filtering](#mcp-tool-filtering)) | | `disallowed_tools` | No | List of specific tools to block (see [MCP Tool Filtering](#mcp-tool-filtering)) | @@ -1224,17 +1229,10 @@ mcp_servers: github_mcp: url: "https://api.githubcopilot.com/mcp" auth_type: oauth2 - authorization_url: https://github.com/login/oauth/authorize - token_url: https://github.com/login/oauth/access_token client_id: os.environ/GITHUB_OAUTH_CLIENT_ID client_secret: os.environ/GITHUB_OAUTH_CLIENT_SECRET - scopes: ["public_repo", "user:email"] ``` -**Note** -In the future, users will only need to specify the `url` of the MCP server. -LiteLLM will automatically resolve the corresponding `authorization_url`, `token_url`, and `registration_url` based on the MCP server metadata (e.g., `.well-known/oauth-authorization-server` or `oauth-protected-resource`). - [**See Claude Code Tutorial**](./tutorials/claude_responses_api#connecting-mcp-servers) ## Using your MCP with client side credentials @@ -1887,4 +1885,4 @@ async with stdio_client(server_params) as (read, write): ``` - \ No newline at end of file + diff --git a/docs/my-website/docs/tutorials/claude_responses_api.md b/docs/my-website/docs/tutorials/claude_responses_api.md index 343f938b67..0dbb4a2f1e 100644 --- a/docs/my-website/docs/tutorials/claude_responses_api.md +++ b/docs/my-website/docs/tutorials/claude_responses_api.md @@ -237,11 +237,8 @@ mcp_servers: github_mcp: url: "https://api.githubcopilot.com/mcp" auth_type: oauth2 - authorization_url: https://github.com/login/oauth/authorize - token_url: https://github.com/login/oauth/access_token client_id: os.environ/GITHUB_OAUTH_CLIENT_ID client_secret: os.environ/GITHUB_OAUTH_CLIENT_SECRET - scopes: ["public_repo", "user:email"] ``` @@ -255,9 +252,6 @@ atlassian_mcp: url: "https://mcp.atlassian.com/v1/sse" transport: "sse" auth_type: oauth2 - authorization_url: https://mcp.atlassian.com/v1/authorize - token_url: https://cf.mcp.atlassian.com/v1/token - registration_url: https://cf.mcp.atlassian.com/v1/register ``` diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index aefbbc8d4a..b42364b6db 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -10,9 +10,13 @@ import asyncio import datetime import hashlib import json -from typing import Any, Dict, List, Optional, Set, Union, cast +import re +from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast +from urllib.parse import urlparse from fastapi import HTTPException +import httpx +from httpx import HTTPStatusError from mcp.types import CallToolRequestParams as MCPCallToolRequestParams from mcp.types import CallToolResult from mcp.types import Tool as MCPTool @@ -43,7 +47,11 @@ from litellm.proxy.common_utils.encrypt_decrypt_utils import ( ) from litellm.proxy.utils import ProxyLogging from litellm.types.mcp import MCPAuth, MCPStdioConfig -from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer +from litellm.types.mcp_server.mcp_server_manager import ( + MCPInfo, + MCPOAuthMetadata, + MCPServer, +) def _deserialize_json_dict(data: Any) -> Optional[Dict[str, str]]: @@ -100,7 +108,7 @@ class MCPServerManager: """ return self.config_mcp_servers | self.registry - def load_servers_from_config( + async def load_servers_from_config( self, mcp_servers_config: Dict[str, Any], mcp_aliases: Optional[Dict[str, str]] = None, @@ -180,35 +188,57 @@ class MCPServerManager: )() name_for_prefix = get_server_prefix(temp_server) + server_url = server_config.get("url", None) or "" # Generate stable server ID based on parameters server_id = self._generate_stable_server_id( server_name=server_name, - url=server_config.get("url", None) or "", + url=server_url, transport=server_config.get("transport", MCPTransport.http), auth_type=server_config.get("auth_type", None), alias=alias, ) + auth_type = server_config.get("auth_type", None) + if server_url and auth_type is not None and auth_type == MCPAuth.oauth2: + mcp_oauth_metadata = await self._descovery_metadata( + server_url=server_url, + ) + else: + mcp_oauth_metadata = None + + resolved_scopes = server_config.get("scopes") or ( + mcp_oauth_metadata.scopes if mcp_oauth_metadata else None + ) + resolved_authorization_url = server_config.get("authorization_url") or ( + mcp_oauth_metadata.authorization_url if mcp_oauth_metadata else None + ) + resolved_token_url = server_config.get("token_url") or ( + mcp_oauth_metadata.token_url if mcp_oauth_metadata else None + ) + resolved_registration_url = server_config.get("registration_url") or ( + mcp_oauth_metadata.registration_url if mcp_oauth_metadata else None + ) + new_server = MCPServer( server_id=server_id, name=name_for_prefix, alias=alias, server_name=server_name, spec_path=server_config.get("spec_path", None), - url=server_config.get("url", None) or "", + url=server_url, command=server_config.get("command", None) or "", args=server_config.get("args", None) or [], env=server_config.get("env", None) or {}, # oauth specific fields client_id=server_config.get("client_id", None), client_secret=server_config.get("client_secret", None), - scopes=server_config.get("scopes", None), - authorization_url=server_config.get("authorization_url", None), - token_url=server_config.get("token_url", None), - registration_url=server_config.get("registration_url", None), + scopes=resolved_scopes, + authorization_url=resolved_authorization_url, + token_url=resolved_token_url, + registration_url=resolved_registration_url, # TODO: utility fn the default values transport=server_config.get("transport", MCPTransport.http), - auth_type=server_config.get("auth_type", None), + auth_type=auth_type, authentication_token=server_config.get( "authentication_token", server_config.get("auth_value", None) ), @@ -692,6 +722,250 @@ class MCPServerManager: except Exception: pass + async def _descovery_metadata( + self, + server_url: str, + ) -> Optional[MCPOAuthMetadata]: + """Discover OAuth metadata by following RFC 9728 (protected resource metadata discovery).""" + + try: + async with httpx.AsyncClient( + timeout=10.0, follow_redirects=False + ) as client: + response = await client.get(server_url) + response.raise_for_status() + verbose_logger.warning( + "MCP OAuth discovery unexpectedly succeeded for %s; server did not challenge", + server_url, + ) + raise RuntimeError("OAuth discovery must not succeed without a challenge") + except HTTPStatusError as exc: + verbose_logger.debug( + "MCP OAuth discovery for %s received status error: %s", + server_url, + exc, + ) + + header_value: Optional[str] = None + if exc.response is not None: + header_value = exc.response.headers.get( + "WWW-Authenticate" + ) or exc.response.headers.get("www-authenticate") + + resource_metadata_url, scopes = self._parse_www_authenticate_header( + header_value + ) + + authorization_servers: List[str] = [] + resource_scopes: Optional[List[str]] = None + if resource_metadata_url: + ( + authorization_servers, + resource_scopes, + ) = await self._fetch_oauth_metadata_from_resource( + resource_metadata_url + ) + else: + ( + authorization_servers, + resource_scopes, + ) = await self._attempt_well_known_discovery(server_url) + + metadata = None + if not authorization_servers: + try: + parsed_url = urlparse(server_url) + if parsed_url.scheme and parsed_url.netloc: + authorization_servers = [ + f"{parsed_url.scheme}://{parsed_url.netloc}" + ] + except Exception: + authorization_servers = [] + + if authorization_servers: + metadata = await self._fetch_authorization_server_metadata( + authorization_servers + ) + + preferred_scopes = scopes or resource_scopes + if metadata is None and preferred_scopes: + metadata = MCPOAuthMetadata(scopes=preferred_scopes) + elif metadata is not None and preferred_scopes: + metadata.scopes = preferred_scopes + + return metadata + except Exception as exc: # pragma: no cover - network/transient issues + verbose_logger.debug( + "MCP OAuth discovery failed for %s: %s", server_url, exc + ) + return None + + def _parse_www_authenticate_header( + self, header_value: Optional[str] + ) -> Tuple[Optional[str], Optional[List[str]]]: + if not header_value: + return None, None + + _, _, params_section = header_value.partition(" ") + params_section = params_section or header_value + + param_pattern = re.compile(r"([a-zA-Z0-9_]+)\s*=\s*\"?([^\",]+)\"?") + params: Dict[str, str] = { + match.group(1).lower(): match.group(2).strip() + for match in param_pattern.finditer(params_section) + } + + resource_metadata_url = params.get("resource_metadata") + + scope_value = params.get("scope") + scopes_list = [s for s in (scope_value.split() if scope_value else []) if s] + scopes = scopes_list or None + + return resource_metadata_url, scopes + + async def _fetch_oauth_metadata_from_resource( + self, resource_metadata_url: str + ) -> Tuple[List[str], Optional[List[str]]]: + if not resource_metadata_url: + return [], None + + try: + async with httpx.AsyncClient(timeout=10.0, follow_redirects=True) as client: + response = await client.get(resource_metadata_url) + response.raise_for_status() + data = response.json() + except Exception as exc: # pragma: no cover - network issues + verbose_logger.debug( + "Failed to fetch MCP OAuth metadata from %s: %s", + resource_metadata_url, + exc, + ) + return [], None + + raw_servers = data.get("authorization_servers") + if isinstance(raw_servers, list): + authorization_servers = [ + entry + for entry in raw_servers + if isinstance(entry, str) and entry.strip() != "" + ] + else: + authorization_servers = [] + + scopes = self._extract_scopes( + data.get("scopes_supported") or data.get("scopes") + ) + + return authorization_servers, scopes + + async def _attempt_well_known_discovery( + self, server_url: str + ) -> Tuple[List[str], Optional[List[str]]]: + try: + parsed = urlparse(server_url) + except Exception: + return [], None + + if not parsed.scheme or not parsed.netloc: + return [], None + + base = f"{parsed.scheme}://{parsed.netloc}" + path = parsed.path or "" + path = path.strip("/") + + candidate_urls: List[str] = [] + if path: + candidate_urls.append(f"{base}/.well-known/oauth-protected-resource/{path}") + candidate_urls.append(f"{base}/.well-known/oauth-protected-resource") + + for url in candidate_urls: + ( + authorization_servers, + scopes, + ) = await self._fetch_oauth_metadata_from_resource(url) + if authorization_servers: + return authorization_servers, scopes + + return [], None + + async def _fetch_authorization_server_metadata( + self, authorization_servers: List[str] + ) -> Optional[MCPOAuthMetadata]: + for issuer in authorization_servers: + metadata = await self._fetch_single_authorization_server_metadata(issuer) + if metadata is not None: + return metadata + return None + + async def _fetch_single_authorization_server_metadata( + self, issuer_url: str + ) -> Optional[MCPOAuthMetadata]: + try: + parsed = urlparse(issuer_url) + except Exception: + return None + + if not parsed.scheme or not parsed.netloc: + return None + + base = f"{parsed.scheme}://{parsed.netloc}" + path = (parsed.path or "").strip("/") + + candidate_urls: List[str] = [] + if path: + candidate_urls.append( + f"{base}/.well-known/oauth-authorization-server/{path}" + ) + candidate_urls.append(f"{base}/.well-known/openid-configuration/{path}") + candidate_urls.append(f"{base}/.well-known/oauth-authorization-server") + candidate_urls.append(f"{base}/.well-known/openid-configuration") + candidate_urls.append(issuer_url.rstrip("/")) + + for url in candidate_urls: + try: + async with httpx.AsyncClient( + timeout=10.0, follow_redirects=True + ) as client: + response = await client.get(url) + response.raise_for_status() + data = response.json() + except Exception as exc: # pragma: no cover - network issues + verbose_logger.debug( + "Failed to fetch authorization metadata from %s: %s", + url, + exc, + ) + continue + + scopes = self._extract_scopes(data.get("scopes_supported")) + metadata = MCPOAuthMetadata( + scopes=scopes, + authorization_url=data.get("authorization_endpoint"), + token_url=data.get("token_endpoint"), + registration_url=data.get("registration_endpoint"), + ) + + if any( + [ + metadata.scopes, + metadata.authorization_url, + metadata.token_url, + metadata.registration_url, + ] + ): + return metadata + + return None + + def _extract_scopes(self, scopes_value: Any) -> Optional[List[str]]: + if isinstance(scopes_value, str): + scopes = [s.strip() for s in scopes_value.split() if s.strip()] + return scopes or None + if isinstance(scopes_value, list): + scopes = [s for s in scopes_value if isinstance(s, str) and s.strip()] + return scopes or None + return None + async def _fetch_tools_with_timeout( self, client: MCPClient, server_name: str ) -> List[MCPTool]: @@ -721,11 +995,6 @@ class MCPServerManager: f"Client operation failed for {server_name}: {str(e)}" ) return [] - finally: - try: - await client.disconnect() - except Exception: - pass try: return await asyncio.wait_for(_list_tools_task(), timeout=30.0) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index f81bcd14d5..06964d1cf8 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2572,11 +2572,11 @@ class ProxyConfig: litellm.credential_list = credential_list_dict ## NON-LLM CONFIGS eg. MCP tools, vector stores, etc. - self._init_non_llm_configs(config=config) + await self._init_non_llm_configs(config=config) return router, router.get_model_list(), general_settings - def _init_non_llm_configs(self, config: dict): + async def _init_non_llm_configs(self, config: dict): """ Initialize non-LLM configs eg. MCP tools, vector stores, etc. """ @@ -2595,7 +2595,7 @@ class ProxyConfig: litellm_settings = config.get("litellm_settings", {}) mcp_aliases = litellm_settings.get("mcp_aliases", None) - global_mcp_server_manager.load_servers_from_config( + await global_mcp_server_manager.load_servers_from_config( mcp_servers_config, mcp_aliases ) diff --git a/litellm/types/mcp_server/mcp_server_manager.py b/litellm/types/mcp_server/mcp_server_manager.py index bba28b26ae..869037546c 100644 --- a/litellm/types/mcp_server/mcp_server_manager.py +++ b/litellm/types/mcp_server/mcp_server_manager.py @@ -7,6 +7,11 @@ from litellm.proxy._types import MCPAuthType, MCPTransportType # MCPInfo now allows arbitrary additional fields for custom metadata MCPInfo = Dict[str, Any] +class MCPOAuthMetadata(BaseModel): + scopes: Optional[List[str]] = None + authorization_url: Optional[str] = None + token_url: Optional[str] = None + registration_url: Optional[str] = None class MCPServer(BaseModel): server_id: str diff --git a/tests/mcp_tests/test_mcp_auth_priority.py b/tests/mcp_tests/test_mcp_auth_priority.py index 1f120b4669..bc217a7903 100644 --- a/tests/mcp_tests/test_mcp_auth_priority.py +++ b/tests/mcp_tests/test_mcp_auth_priority.py @@ -43,7 +43,7 @@ def test_mcp_server_works_without_config_auth_value(): @pytest.mark.parametrize("token_key", ["authentication_token", "auth_value"]) -def test_mcp_server_config_auth_value_header_used(token_key): +async def test_mcp_server_config_auth_value_header_used(token_key): """Ensure auth header is sent when auth token configured in config""" config = { "test_server": { @@ -55,7 +55,7 @@ def test_mcp_server_config_auth_value_header_used(token_key): } manager = MCPServerManager() - manager.load_servers_from_config(config) + await manager.load_servers_from_config(config) server = next(iter(manager.config_mcp_servers.values())) client = manager._create_mcp_client(server) diff --git a/tests/mcp_tests/test_mcp_logging.py b/tests/mcp_tests/test_mcp_logging.py index 7049a77345..6ee52364cc 100644 --- a/tests/mcp_tests/test_mcp_logging.py +++ b/tests/mcp_tests/test_mcp_logging.py @@ -66,7 +66,7 @@ async def test_mcp_cost_tracking(): with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor): # Load the server config - local_mcp_server_manager.load_servers_from_config( + await local_mcp_server_manager.load_servers_from_config( mcp_servers_config={ "zapier_gmail_server": { "url": os.getenv("ZAPIER_MCP_HTTPS_SERVER_URL"), @@ -166,7 +166,7 @@ async def test_mcp_cost_tracking_per_tool(): with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor): # Load the server config with per-tool costs - local_mcp_server_manager.load_servers_from_config( + await local_mcp_server_manager.load_servers_from_config( mcp_servers_config={ "test_server": { "url": os.getenv("ZAPIER_MCP_HTTPS_SERVER_URL"), @@ -310,7 +310,7 @@ async def test_mcp_tool_call_hook(): with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor): # Load the server config - local_mcp_server_manager.load_servers_from_config( + await local_mcp_server_manager.load_servers_from_config( mcp_servers_config={ "zapier_gmail_server": { "url": os.getenv("ZAPIER_MCP_HTTPS_SERVER_URL"), diff --git a/tests/mcp_tests/test_mcp_server.py b/tests/mcp_tests/test_mcp_server.py index 8391e8f77e..953dc373ab 100644 --- a/tests/mcp_tests/test_mcp_server.py +++ b/tests/mcp_tests/test_mcp_server.py @@ -24,7 +24,7 @@ mcp_server_manager = MCPServerManager() @pytest.mark.asyncio @pytest.mark.skip(reason="Local only test") async def test_mcp_server_manager(): - mcp_server_manager.load_servers_from_config( + await mcp_server_manager.load_servers_from_config( { "zapier_mcp_server": { "url": os.environ.get("ZAPIER_MCP_SERVER_URL"), @@ -79,7 +79,7 @@ async def test_mcp_server_manager_https_server(): "litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient", mock_client_constructor, ): - mcp_server_manager.load_servers_from_config( + await mcp_server_manager.load_servers_from_config( { "zapier_mcp_server": { "url": "https://test-mcp-server.com/mcp", @@ -189,7 +189,7 @@ async def test_mcp_http_transport_list_tools_mock(): mock_client_constructor, ): # Load server config with HTTP transport - test_manager.load_servers_from_config( + await test_manager.load_servers_from_config( { "test_http_server": { "url": "https://test-mcp-server.com/mcp", @@ -266,7 +266,7 @@ async def test_mcp_http_transport_call_tool_mock(): mock_client_constructor, ): # Load server config with HTTP transport - test_manager.load_servers_from_config( + await test_manager.load_servers_from_config( { "test_http_server": { "url": "https://test-mcp-server.com/mcp", @@ -333,7 +333,7 @@ async def test_mcp_http_transport_call_tool_error_mock(): mock_client_constructor, ): # Load server config with HTTP transport - test_manager.load_servers_from_config( + await test_manager.load_servers_from_config( { "test_http_server": { "url": "https://test-mcp-server.com/mcp", @@ -376,7 +376,7 @@ async def test_mcp_http_transport_tool_not_found(): test_manager = MCPServerManager() # Load server config - test_manager.load_servers_from_config( + await test_manager.load_servers_from_config( { "test_http_server": { "url": "https://test-mcp-server.com/mcp", @@ -699,7 +699,7 @@ async def test_list_tools_rest_api_success(): mock_client_constructor, ): # Load server config into global manager - global_mcp_server_manager.load_servers_from_config( + await global_mcp_server_manager.load_servers_from_config( { "test_server": { "url": "https://test-server.com/mcp", @@ -878,7 +878,7 @@ async def test_list_tools_only_returns_allowed_servers(monkeypatch): test_manager = MCPServerManager() # Setup two servers in the config - test_manager.load_servers_from_config( + await test_manager.load_servers_from_config( { "server_a": { "url": "https://server-a.com/mcp", @@ -944,12 +944,13 @@ async def test_list_tools_only_returns_allowed_servers(monkeypatch): assert tools[0].name.startswith(f"{expected_prefix}-") -def test_mcp_server_manager_access_groups_from_config(): +@pytest.mark.asyncio +async def test_mcp_server_manager_access_groups_from_config(): """ Test that access_groups are loaded from config and can be resolved. """ test_manager = MCPServerManager() - test_manager.load_servers_from_config( + await test_manager.load_servers_from_config( { "config_server": { "url": "https://config-mcp-server.com/mcp", @@ -986,15 +987,15 @@ def test_mcp_server_manager_access_groups_from_config(): # Should find config_server for group-a, both for group-b, other_server for group-c import asyncio - server_ids_a = asyncio.run( - MCPRequestHandler._get_mcp_servers_from_access_groups(["group-a"]) - ) - server_ids_b = asyncio.run( - MCPRequestHandler._get_mcp_servers_from_access_groups(["group-b"]) - ) - server_ids_c = asyncio.run( - MCPRequestHandler._get_mcp_servers_from_access_groups(["group-c"]) - ) + server_ids_a = await MCPRequestHandler._get_mcp_servers_from_access_groups([ + "group-a" + ]) + server_ids_b = await MCPRequestHandler._get_mcp_servers_from_access_groups([ + "group-b" + ]) + server_ids_c = await MCPRequestHandler._get_mcp_servers_from_access_groups([ + "group-c" + ]) assert any(config_server.server_id == sid for sid in server_ids_a) assert set(server_ids_b) == set( [ @@ -1009,7 +1010,7 @@ def test_mcp_server_manager_access_groups_from_config(): ) -def test_mcp_server_manager_config_integration_with_database(): +async def test_mcp_server_manager_config_integration_with_database(): """ Test that config-based servers properly integrate with database servers, specifically testing access_groups and description fields. @@ -1020,7 +1021,7 @@ def test_mcp_server_manager_config_integration_with_database(): test_manager = MCPServerManager() # Test 1: Load config with access_groups and description - test_manager.load_servers_from_config( + await test_manager.load_servers_from_config( { "config_server_with_groups": { "url": "https://config-server.com/mcp", @@ -1081,10 +1082,8 @@ def test_mcp_server_manager_config_integration_with_database(): # Test the method (this tests our second fix) import asyncio - servers_list = asyncio.run( - test_manager.get_all_mcp_servers_with_health_and_teams( - user_api_key_auth=mock_user_auth - ) + servers_list = await test_manager.get_all_mcp_servers_with_health_and_teams( + user_api_key_auth=mock_user_auth ) # Verify we have the config server properly converted @@ -1536,7 +1535,7 @@ async def test_mcp_protocol_version_passed_to_client(): mock_client_constructor, ): # Load a test server - test_manager.load_servers_from_config( + await test_manager.load_servers_from_config( { "test_server": { "url": "https://test-server.com/mcp", @@ -2423,7 +2422,7 @@ async def test_mcp_server_manager_with_access_groups_integration(): test_manager = MCPServerManager() # Load servers with access groups - test_manager.load_servers_from_config( + await test_manager.load_servers_from_config( { "staff_server": { "url": "https://staff-server.com/mcp", @@ -2470,7 +2469,7 @@ async def test_get_allowed_mcp_servers_returns_registry_for_admin(): ) test_manager = MCPServerManager() - test_manager.load_servers_from_config( + await test_manager.load_servers_from_config( { "alpha_server": { "url": "https://alpha.server/mcp", @@ -2505,7 +2504,7 @@ async def test_get_allowed_mcp_servers_returns_empty_for_non_admin_without_permi ) test_manager = MCPServerManager() - test_manager.load_servers_from_config( + await test_manager.load_servers_from_config( { "alpha_server": { "url": "https://alpha.server/mcp", diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_custom_fields.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_custom_fields.py index d2fa7853e7..51f2861b2d 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_custom_fields.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_custom_fields.py @@ -22,7 +22,7 @@ from litellm.proxy._types import LiteLLM_MCPServerTable class TestMCPCustomFields: """Test custom fields functionality in MCP server configuration.""" - def test_custom_fields_preserved_from_config(self): + async def test_custom_fields_preserved_from_config(self): """Test that custom fields in mcp_info are preserved when loading from config.""" manager = MCPServerManager() @@ -46,7 +46,7 @@ class TestMCPCustomFields: } # Load servers from config - manager.load_servers_from_config(mock_config) + await manager.load_servers_from_config(mock_config) # Get the loaded server servers = list(manager.config_mcp_servers.values()) @@ -109,7 +109,7 @@ class TestMCPCustomFields: assert mcp_info["metadata"] == {"source": "database"} assert mcp_info["version"] == "1.0.0" - def test_empty_mcp_info_handled_gracefully(self): + async def test_empty_mcp_info_handled_gracefully(self): """Test that empty or missing mcp_info is handled gracefully.""" manager = MCPServerManager() @@ -122,7 +122,7 @@ class TestMCPCustomFields: } } - manager.load_servers_from_config(mock_config) + await manager.load_servers_from_config(mock_config) servers = list(manager.config_mcp_servers.values()) assert len(servers) == 1 @@ -133,7 +133,7 @@ class TestMCPCustomFields: # Should have default server_name assert mcp_info["server_name"] == "test_server" - def test_missing_mcp_info_creates_defaults(self): + async def test_missing_mcp_info_creates_defaults(self): """Test that missing mcp_info creates appropriate defaults.""" manager = MCPServerManager() @@ -146,7 +146,7 @@ class TestMCPCustomFields: } } - manager.load_servers_from_config(mock_config) + await manager.load_servers_from_config(mock_config) servers = list(manager.config_mcp_servers.values()) assert len(servers) == 1 @@ -158,7 +158,7 @@ class TestMCPCustomFields: assert mcp_info["server_name"] == "test_server" assert mcp_info["description"] == "Server description" - def test_config_description_fallback(self): + async def test_config_description_fallback(self): """Test that description from config level is used as fallback.""" manager = MCPServerManager() @@ -174,7 +174,7 @@ class TestMCPCustomFields: } } - manager.load_servers_from_config(mock_config) + await manager.load_servers_from_config(mock_config) servers = list(manager.config_mcp_servers.values()) server = servers[0] @@ -184,7 +184,7 @@ class TestMCPCustomFields: assert mcp_info["description"] == "Config level description" assert mcp_info["custom_field"] == "custom_value" - def test_mcp_info_description_takes_precedence(self): + async def test_mcp_info_description_takes_precedence(self): """Test that description in mcp_info takes precedence over config level.""" manager = MCPServerManager() @@ -201,7 +201,7 @@ class TestMCPCustomFields: } } - manager.load_servers_from_config(mock_config) + await manager.load_servers_from_config(mock_config) servers = list(manager.config_mcp_servers.values()) server = servers[0] diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py index f36ae1aec0..d5b3832971 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py @@ -8,12 +8,14 @@ from fastapi import HTTPException # Add the parent directory to the path so we can import litellm sys.path.insert(0, "../../../../../") +import httpx from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( MCPServerManager, _deserialize_json_dict, ) from litellm.proxy._types import LiteLLM_MCPServerTable, MCPTransport -from litellm.types.mcp_server.mcp_server_manager import MCPServer +from litellm.types.mcp import MCPAuth +from litellm.types.mcp_server.mcp_server_manager import MCPServer, MCPOAuthMetadata class TestMCPServerManager: @@ -218,6 +220,133 @@ class TestMCPServerManager: assert len(result) == 1 assert result[0].name == "github_tool_1" + @pytest.mark.asyncio + async def test_fetch_oauth_metadata_from_resource_returns_servers_and_scopes(self): + manager = MCPServerManager() + + mock_response = MagicMock() + mock_response.json.return_value = { + "authorization_servers": [ + "https://auth1.example.com", + "https://auth2.example.com", + ], + "scopes_supported": ["read", "write"], + } + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get = AsyncMock(return_value=mock_response) + + async_client_context = MagicMock() + async_client_context.__aenter__ = AsyncMock(return_value=mock_client) + async_client_context.__aexit__ = AsyncMock(return_value=None) + + with patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.httpx.AsyncClient", + return_value=async_client_context, + ): + servers, scopes = await manager._fetch_oauth_metadata_from_resource( + "https://protected.example.com/.well-known/oauth" + ) + + assert servers == [ + "https://auth1.example.com", + "https://auth2.example.com", + ] + assert scopes == ["read", "write"] + + @pytest.mark.asyncio + async def test_descovery_metadata_falls_back_to_origin_when_no_auth_servers(self): + manager = MCPServerManager() + server_url = "https://example.com/public/mcp" + + request = httpx.Request("GET", server_url) + response_obj = httpx.Response( + status_code=401, + request=request, + headers={"WWW-Authenticate": 'Bearer scope="read"'}, + ) + + def raise_http_error(): + raise httpx.HTTPStatusError( + "unauthorized", request=request, response=response_obj + ) + + response_obj.raise_for_status = MagicMock(side_effect=raise_http_error) + + mock_client = MagicMock() + mock_client.get = AsyncMock(return_value=response_obj) + + async_client_context = MagicMock() + async_client_context.__aenter__ = AsyncMock(return_value=mock_client) + async_client_context.__aexit__ = AsyncMock(return_value=None) + + mock_metadata = MCPOAuthMetadata( + scopes=None, + authorization_url="https://example.com/auth", + token_url="https://example.com/token", + registration_url=None, + ) + + with patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.httpx.AsyncClient", + return_value=async_client_context, + ), patch.object( + manager, + "_fetch_oauth_metadata_from_resource", + AsyncMock(return_value=([], None)), + ), patch.object( + manager, + "_attempt_well_known_discovery", + AsyncMock(return_value=([], None)), + ), patch.object( + manager, + "_fetch_authorization_server_metadata", + AsyncMock(return_value=mock_metadata), + ) as mock_fetch_auth: + result = await manager._descovery_metadata(server_url) + + mock_fetch_auth.assert_awaited_once_with(["https://example.com"]) + assert result is mock_metadata + assert result.scopes == ["read"] + + @pytest.mark.asyncio + async def test_load_servers_from_config_overrides_discovery_metadata(self): + manager = MCPServerManager() + + discovered_metadata = MCPOAuthMetadata( + scopes=["discovered"], + authorization_url="https://discovered.example.com/auth", + token_url="https://discovered.example.com/token", + registration_url="https://discovered.example.com/register", + ) + + async def fake_discovery(server_url: str): + assert server_url == "https://example.com/mcp" + return discovered_metadata + + manager._descovery_metadata = fake_discovery # type: ignore[attr-defined] + + config = { + "example": { + "url": "https://example.com/mcp", + "transport": MCPTransport.http, + "auth_type": MCPAuth.oauth2, + "scopes": ["config"], + "authorization_url": "https://config.example.com/auth", + } + } + + await manager.load_servers_from_config(config) + + server = next(iter(manager.config_mcp_servers.values())) + assert server.scopes == ["config"] # config overrides discovery + assert server.authorization_url == "https://config.example.com/auth" + assert server.token_url == "https://discovered.example.com/token" + assert ( + server.registration_url == "https://discovered.example.com/register" + ) + @pytest.mark.asyncio async def test_list_tools_handles_missing_server_alias(self): """Test that list_tools handles servers without alias gracefully"""