diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index 6423a9ae15..3cf1f764d1 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -443,6 +443,7 @@ class MCPServerManager: server: MCPServer, mcp_auth_header: Optional[Union[str, Dict[str, str]]] = None, extra_headers: Optional[Dict[str, str]] = None, + add_prefix: bool = True, ) -> List[MCPTool]: """ Helper method to get tools from a single MCP server with prefixed names. @@ -468,9 +469,11 @@ class MCPServerManager: tools = await self._fetch_tools_with_timeout(client, server.name) - prefixed_tools = self._create_prefixed_tools(tools, server) + prefixed_or_original_tools = self._create_prefixed_tools( + tools, server, add_prefix=add_prefix + ) - return prefixed_tools + return prefixed_or_original_tools except Exception as e: verbose_logger.warning( @@ -539,7 +542,7 @@ class MCPServerManager: return [] def _create_prefixed_tools( - self, tools: List[MCPTool], server: MCPServer + self, tools: List[MCPTool], server: MCPServer, add_prefix: bool = True ) -> List[MCPTool]: """ Create prefixed tools and update tool mapping. @@ -557,14 +560,16 @@ class MCPServerManager: for tool in tools: prefixed_name = add_server_prefix_to_tool_name(tool.name, prefix) - prefixed_tool = MCPTool( - name=prefixed_name, + name_to_use = prefixed_name if add_prefix else tool.name + + tool_obj = MCPTool( + name=name_to_use, description=tool.description, inputSchema=tool.inputSchema, ) - prefixed_tools.append(prefixed_tool) + prefixed_tools.append(tool_obj) - # Update tool to server mapping with both original and prefixed names + # Update tool to server mapping for resolution (support both forms) self.tool_name_to_mcp_server_name_mapping[tool.name] = prefix self.tool_name_to_mcp_server_name_mapping[prefixed_name] = prefix diff --git a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py index ef770a9d43..236dfdb6cf 100644 --- a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py @@ -73,6 +73,7 @@ if MCP_AVAILABLE: tools = await global_mcp_server_manager._get_tools_from_server( server=server, mcp_auth_header=server_auth_header, + add_prefix=False, ) return _create_tool_response_objects(tools, server.mcp_info) diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index 94f81ccef8..9c2551fec2 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -414,6 +414,9 @@ if MCP_AVAILABLE: allowed_mcp_servers=allowed_mcp_servers, ) + # Decide whether to add prefix based on number of allowed servers + add_prefix = not (len(allowed_mcp_servers) == 1) + # Get tools from each allowed server all_tools = [] for server_id in allowed_mcp_servers: @@ -448,6 +451,7 @@ if MCP_AVAILABLE: server=server, mcp_auth_header=server_auth_header, extra_headers=extra_headers, + add_prefix=add_prefix, ) all_tools.extend(filter_tools_by_allowed_tools(tools, server)) verbose_logger.debug( diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py index 10a6ab8cb0..6a1d43b33e 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py @@ -103,7 +103,7 @@ async def test_get_tools_from_mcp_servers_continues_when_one_server_fails(): ) async def mock_get_tools_from_server( - server, mcp_auth_header=None, extra_headers=None + server, mcp_auth_header=None, extra_headers=None, add_prefix=True ): if server.name == "working_server": # Working server returns tools @@ -187,7 +187,7 @@ async def test_get_tools_from_mcp_servers_handles_all_servers_failing(): ) async def mock_get_tools_from_server( - server, mcp_auth_header=None, extra_headers=None + server, mcp_auth_header=None, extra_headers=None, add_prefix=True ): # All servers fail raise Exception(f"Server {server.name} connection failed") @@ -564,3 +564,120 @@ async def test_oauth2_headers_passed_to_mcp_client(): captured_client_args["extra_headers"]["Authorization"] == "Bearer github_oauth_token_12345" ) + +@pytest.mark.asyncio +async def test_list_tools_single_server_unprefixed_names(): + """When only one MCP server is allowed, list tools should return unprefixed names.""" + try: + from litellm.proxy._experimental.mcp_server.server import ( + _get_tools_from_mcp_servers, + set_auth_context, + ) + except ImportError: + pytest.skip("MCP server not available") + + # Mock user auth + user_api_key_auth = UserAPIKeyAuth(api_key="test_key", user_id="test_user") + set_auth_context(user_api_key_auth) + + # One allowed server + server = MagicMock() + server.server_id = "server1" + server.name = "Zapier MCP" + server.alias = "zapier" + + # Mock manager: allow just one server and return a tool based on add_prefix flag + mock_manager = MagicMock() + mock_manager.get_allowed_mcp_servers = AsyncMock(return_value=["server1"]) + mock_manager.get_mcp_server_by_id = ( + lambda server_id: server if server_id == "server1" else None + ) + + async def mock_get_tools_from_server( + server, mcp_auth_header=None, extra_headers=None, add_prefix=False + ): + tool = MagicMock() + tool.name = f"{server.alias}-toolA" if add_prefix else "toolA" + tool.description = "desc" + tool.inputSchema = {} + return [tool] + + mock_manager._get_tools_from_server = mock_get_tools_from_server + + with patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", + mock_manager, + ): + tools = await _get_tools_from_mcp_servers( + user_api_key_auth=user_api_key_auth, + mcp_auth_header=None, + mcp_servers=None, + mcp_server_auth_headers=None, + ) + + # Should be unprefixed since only one server is allowed + assert len(tools) == 1 + assert tools[0].name == "toolA" + + +@pytest.mark.asyncio +async def test_list_tools_multiple_servers_prefixed_names(): + """When multiple MCP servers are allowed, list tools should return prefixed names.""" + try: + from litellm.proxy._experimental.mcp_server.server import ( + _get_tools_from_mcp_servers, + set_auth_context, + ) + except ImportError: + pytest.skip("MCP server not available") + + # Mock user auth + user_api_key_auth = UserAPIKeyAuth(api_key="test_key", user_id="test_user") + set_auth_context(user_api_key_auth) + + # Two allowed servers + server1 = MagicMock() + server1.server_id = "server1" + server1.name = "Zapier MCP" + server1.alias = "zapier" + + server2 = MagicMock() + server2.server_id = "server2" + server2.name = "Jira MCP" + server2.alias = "jira" + + # Mock manager + mock_manager = MagicMock() + mock_manager.get_allowed_mcp_servers = AsyncMock( + return_value=["server1", "server2"] + ) + mock_manager.get_mcp_server_by_id = ( + lambda server_id: server1 if server_id == "server1" else server2 + ) + + async def mock_get_tools_from_server( + server, mcp_auth_header=None, extra_headers=None, add_prefix=True + ): + tool = MagicMock() + # When multiple servers, add_prefix should be True -> prefixed names + tool.name = f"{server.alias}-toolA" if add_prefix else "toolA" + tool.description = "desc" + tool.inputSchema = {} + return [tool] + + mock_manager._get_tools_from_server = mock_get_tools_from_server + + with patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", + mock_manager, + ): + tools = await _get_tools_from_mcp_servers( + user_api_key_auth=user_api_key_auth, + mcp_auth_header=None, + mcp_servers=None, + mcp_server_auth_headers=None, + ) + + # Should be prefixed since multiple servers are allowed + names = sorted([t.name for t in tools]) + assert names == ["jira-toolA", "zapier-toolA"] diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py index 126a4ebb89..80ea95a221 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py @@ -654,6 +654,110 @@ class TestMCPServerManager: "Tool tool3 is not allowed for server test-server" in exc_info.value.detail["error"] ) + async def test_get_tools_from_server_add_prefix(self): + """Verify _get_tools_from_server respects add_prefix True/False.""" + manager = MCPServerManager() + + # Create a minimal server with alias used as prefix + server = MCPServer( + server_id="zapier", + name="zapier", + transport=MCPTransport.http, + ) + + # Mock client creation and fetching tools + manager._create_mcp_client = MagicMock(return_value=object()) + + # Tools returned upstream (unprefixed from provider) + upstream_tool = MagicMock() + upstream_tool.name = "send_email" + upstream_tool.description = "Send an email" + upstream_tool.inputSchema = {} + + manager._fetch_tools_with_timeout = AsyncMock(return_value=[upstream_tool]) + + # Case 1: add_prefix=True (default for multi-server) -> expect prefixed + tools_prefixed = await manager._get_tools_from_server(server, add_prefix=True) + assert len(tools_prefixed) == 1 + assert tools_prefixed[0].name == "zapier-send_email" + + # Case 2: add_prefix=False (single-server) -> expect unprefixed + tools_unprefixed = await manager._get_tools_from_server( + server, add_prefix=False + ) + assert len(tools_unprefixed) == 1 + assert tools_unprefixed[0].name == "send_email" + + def test_create_prefixed_tools_updates_mapping_for_both_forms(self): + """_create_prefixed_tools should populate mapping for prefixed and original names even when not adding prefix in output.""" + manager = MCPServerManager() + + server = MCPServer( + server_id="jira", + name="jira", + transport=MCPTransport.http, + ) + + # Input tools as would come from upstream + t1 = MagicMock() + t1.name = "create_issue" + t1.description = "" + t1.inputSchema = {} + t2 = MagicMock() + t2.name = "close_issue" + t2.description = "" + t2.inputSchema = {} + + # Do not add prefix in returned objects + out_tools = manager._create_prefixed_tools([t1, t2], server, add_prefix=False) + + # Returned names should be unprefixed + names = sorted([t.name for t in out_tools]) + assert names == ["close_issue", "create_issue"] + + # Mapping should include both original and prefixed names -> resolves calls either way + assert manager.tool_name_to_mcp_server_name_mapping["create_issue"] == "jira" + assert ( + manager.tool_name_to_mcp_server_name_mapping["jira-create_issue"] == "jira" + ) + assert manager.tool_name_to_mcp_server_name_mapping["close_issue"] == "jira" + assert ( + manager.tool_name_to_mcp_server_name_mapping["jira-close_issue"] == "jira" + ) + + def test_get_mcp_server_from_tool_name_with_prefixed_and_unprefixed(self): + """After mapping is populated, manager resolves both prefixed and unprefixed tool names to the same server.""" + manager = MCPServerManager() + + server = MCPServer( + server_id="zapier", + name="zapier", + server_name="zapier", + transport=MCPTransport.http, + ) + + # Register server so resolution can find it + manager.registry = {server.server_id: server} + + # Populate mapping (add_prefix value doesn't matter for mapping population) + base_tool = MagicMock() + base_tool.name = "create_zap" + base_tool.description = "" + base_tool.inputSchema = {} + _ = manager._create_prefixed_tools([base_tool], server, add_prefix=False) + + # Unprefixed resolution + resolved_server_unpref = manager._get_mcp_server_from_tool_name("create_zap") + print(resolved_server_unpref) + assert resolved_server_unpref is not None + assert resolved_server_unpref.server_id == server.server_id + + # Prefixed resolution + resolved_server_pref = manager._get_mcp_server_from_tool_name( + "zapier-create_zap" + ) + assert resolved_server_pref is not None + assert resolved_server_pref.server_id == server.server_id if __name__ == "__main__":