Revert "Revert "Merge pull request #14720 from uc4w6c/feat/remove-servername-prefix-mcp_tools""

This reverts commit a88d774f94.
This commit is contained in:
Yuta Saito
2025-09-28 08:23:42 +09:00
parent 208cd5fcb5
commit dae7d08ff2
5 changed files with 240 additions and 9 deletions
@@ -443,6 +443,7 @@ class MCPServerManager:
server: MCPServer,
mcp_auth_header: Optional[Union[str, Dict[str, str]]] = None,
extra_headers: Optional[Dict[str, str]] = None,
add_prefix: bool = True,
) -> List[MCPTool]:
"""
Helper method to get tools from a single MCP server with prefixed names.
@@ -468,9 +469,11 @@ class MCPServerManager:
tools = await self._fetch_tools_with_timeout(client, server.name)
prefixed_tools = self._create_prefixed_tools(tools, server)
prefixed_or_original_tools = self._create_prefixed_tools(
tools, server, add_prefix=add_prefix
)
return prefixed_tools
return prefixed_or_original_tools
except Exception as e:
verbose_logger.warning(
@@ -539,7 +542,7 @@ class MCPServerManager:
return []
def _create_prefixed_tools(
self, tools: List[MCPTool], server: MCPServer
self, tools: List[MCPTool], server: MCPServer, add_prefix: bool = True
) -> List[MCPTool]:
"""
Create prefixed tools and update tool mapping.
@@ -557,14 +560,16 @@ class MCPServerManager:
for tool in tools:
prefixed_name = add_server_prefix_to_tool_name(tool.name, prefix)
prefixed_tool = MCPTool(
name=prefixed_name,
name_to_use = prefixed_name if add_prefix else tool.name
tool_obj = MCPTool(
name=name_to_use,
description=tool.description,
inputSchema=tool.inputSchema,
)
prefixed_tools.append(prefixed_tool)
prefixed_tools.append(tool_obj)
# Update tool to server mapping with both original and prefixed names
# Update tool to server mapping for resolution (support both forms)
self.tool_name_to_mcp_server_name_mapping[tool.name] = prefix
self.tool_name_to_mcp_server_name_mapping[prefixed_name] = prefix
@@ -73,6 +73,7 @@ if MCP_AVAILABLE:
tools = await global_mcp_server_manager._get_tools_from_server(
server=server,
mcp_auth_header=server_auth_header,
add_prefix=False,
)
return _create_tool_response_objects(tools, server.mcp_info)
@@ -414,6 +414,9 @@ if MCP_AVAILABLE:
allowed_mcp_servers=allowed_mcp_servers,
)
# Decide whether to add prefix based on number of allowed servers
add_prefix = not (len(allowed_mcp_servers) == 1)
# Get tools from each allowed server
all_tools = []
for server_id in allowed_mcp_servers:
@@ -448,6 +451,7 @@ if MCP_AVAILABLE:
server=server,
mcp_auth_header=server_auth_header,
extra_headers=extra_headers,
add_prefix=add_prefix,
)
all_tools.extend(filter_tools_by_allowed_tools(tools, server))
verbose_logger.debug(
@@ -103,7 +103,7 @@ async def test_get_tools_from_mcp_servers_continues_when_one_server_fails():
)
async def mock_get_tools_from_server(
server, mcp_auth_header=None, extra_headers=None
server, mcp_auth_header=None, extra_headers=None, add_prefix=True
):
if server.name == "working_server":
# Working server returns tools
@@ -187,7 +187,7 @@ async def test_get_tools_from_mcp_servers_handles_all_servers_failing():
)
async def mock_get_tools_from_server(
server, mcp_auth_header=None, extra_headers=None
server, mcp_auth_header=None, extra_headers=None, add_prefix=True
):
# All servers fail
raise Exception(f"Server {server.name} connection failed")
@@ -564,3 +564,120 @@ async def test_oauth2_headers_passed_to_mcp_client():
captured_client_args["extra_headers"]["Authorization"]
== "Bearer github_oauth_token_12345"
)
@pytest.mark.asyncio
async def test_list_tools_single_server_unprefixed_names():
"""When only one MCP server is allowed, list tools should return unprefixed names."""
try:
from litellm.proxy._experimental.mcp_server.server import (
_get_tools_from_mcp_servers,
set_auth_context,
)
except ImportError:
pytest.skip("MCP server not available")
# Mock user auth
user_api_key_auth = UserAPIKeyAuth(api_key="test_key", user_id="test_user")
set_auth_context(user_api_key_auth)
# One allowed server
server = MagicMock()
server.server_id = "server1"
server.name = "Zapier MCP"
server.alias = "zapier"
# Mock manager: allow just one server and return a tool based on add_prefix flag
mock_manager = MagicMock()
mock_manager.get_allowed_mcp_servers = AsyncMock(return_value=["server1"])
mock_manager.get_mcp_server_by_id = (
lambda server_id: server if server_id == "server1" else None
)
async def mock_get_tools_from_server(
server, mcp_auth_header=None, extra_headers=None, add_prefix=False
):
tool = MagicMock()
tool.name = f"{server.alias}-toolA" if add_prefix else "toolA"
tool.description = "desc"
tool.inputSchema = {}
return [tool]
mock_manager._get_tools_from_server = mock_get_tools_from_server
with patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
mock_manager,
):
tools = await _get_tools_from_mcp_servers(
user_api_key_auth=user_api_key_auth,
mcp_auth_header=None,
mcp_servers=None,
mcp_server_auth_headers=None,
)
# Should be unprefixed since only one server is allowed
assert len(tools) == 1
assert tools[0].name == "toolA"
@pytest.mark.asyncio
async def test_list_tools_multiple_servers_prefixed_names():
"""When multiple MCP servers are allowed, list tools should return prefixed names."""
try:
from litellm.proxy._experimental.mcp_server.server import (
_get_tools_from_mcp_servers,
set_auth_context,
)
except ImportError:
pytest.skip("MCP server not available")
# Mock user auth
user_api_key_auth = UserAPIKeyAuth(api_key="test_key", user_id="test_user")
set_auth_context(user_api_key_auth)
# Two allowed servers
server1 = MagicMock()
server1.server_id = "server1"
server1.name = "Zapier MCP"
server1.alias = "zapier"
server2 = MagicMock()
server2.server_id = "server2"
server2.name = "Jira MCP"
server2.alias = "jira"
# Mock manager
mock_manager = MagicMock()
mock_manager.get_allowed_mcp_servers = AsyncMock(
return_value=["server1", "server2"]
)
mock_manager.get_mcp_server_by_id = (
lambda server_id: server1 if server_id == "server1" else server2
)
async def mock_get_tools_from_server(
server, mcp_auth_header=None, extra_headers=None, add_prefix=True
):
tool = MagicMock()
# When multiple servers, add_prefix should be True -> prefixed names
tool.name = f"{server.alias}-toolA" if add_prefix else "toolA"
tool.description = "desc"
tool.inputSchema = {}
return [tool]
mock_manager._get_tools_from_server = mock_get_tools_from_server
with patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
mock_manager,
):
tools = await _get_tools_from_mcp_servers(
user_api_key_auth=user_api_key_auth,
mcp_auth_header=None,
mcp_servers=None,
mcp_server_auth_headers=None,
)
# Should be prefixed since multiple servers are allowed
names = sorted([t.name for t in tools])
assert names == ["jira-toolA", "zapier-toolA"]
@@ -654,6 +654,110 @@ class TestMCPServerManager:
"Tool tool3 is not allowed for server test-server"
in exc_info.value.detail["error"]
)
async def test_get_tools_from_server_add_prefix(self):
"""Verify _get_tools_from_server respects add_prefix True/False."""
manager = MCPServerManager()
# Create a minimal server with alias used as prefix
server = MCPServer(
server_id="zapier",
name="zapier",
transport=MCPTransport.http,
)
# Mock client creation and fetching tools
manager._create_mcp_client = MagicMock(return_value=object())
# Tools returned upstream (unprefixed from provider)
upstream_tool = MagicMock()
upstream_tool.name = "send_email"
upstream_tool.description = "Send an email"
upstream_tool.inputSchema = {}
manager._fetch_tools_with_timeout = AsyncMock(return_value=[upstream_tool])
# Case 1: add_prefix=True (default for multi-server) -> expect prefixed
tools_prefixed = await manager._get_tools_from_server(server, add_prefix=True)
assert len(tools_prefixed) == 1
assert tools_prefixed[0].name == "zapier-send_email"
# Case 2: add_prefix=False (single-server) -> expect unprefixed
tools_unprefixed = await manager._get_tools_from_server(
server, add_prefix=False
)
assert len(tools_unprefixed) == 1
assert tools_unprefixed[0].name == "send_email"
def test_create_prefixed_tools_updates_mapping_for_both_forms(self):
"""_create_prefixed_tools should populate mapping for prefixed and original names even when not adding prefix in output."""
manager = MCPServerManager()
server = MCPServer(
server_id="jira",
name="jira",
transport=MCPTransport.http,
)
# Input tools as would come from upstream
t1 = MagicMock()
t1.name = "create_issue"
t1.description = ""
t1.inputSchema = {}
t2 = MagicMock()
t2.name = "close_issue"
t2.description = ""
t2.inputSchema = {}
# Do not add prefix in returned objects
out_tools = manager._create_prefixed_tools([t1, t2], server, add_prefix=False)
# Returned names should be unprefixed
names = sorted([t.name for t in out_tools])
assert names == ["close_issue", "create_issue"]
# Mapping should include both original and prefixed names -> resolves calls either way
assert manager.tool_name_to_mcp_server_name_mapping["create_issue"] == "jira"
assert (
manager.tool_name_to_mcp_server_name_mapping["jira-create_issue"] == "jira"
)
assert manager.tool_name_to_mcp_server_name_mapping["close_issue"] == "jira"
assert (
manager.tool_name_to_mcp_server_name_mapping["jira-close_issue"] == "jira"
)
def test_get_mcp_server_from_tool_name_with_prefixed_and_unprefixed(self):
"""After mapping is populated, manager resolves both prefixed and unprefixed tool names to the same server."""
manager = MCPServerManager()
server = MCPServer(
server_id="zapier",
name="zapier",
server_name="zapier",
transport=MCPTransport.http,
)
# Register server so resolution can find it
manager.registry = {server.server_id: server}
# Populate mapping (add_prefix value doesn't matter for mapping population)
base_tool = MagicMock()
base_tool.name = "create_zap"
base_tool.description = ""
base_tool.inputSchema = {}
_ = manager._create_prefixed_tools([base_tool], server, add_prefix=False)
# Unprefixed resolution
resolved_server_unpref = manager._get_mcp_server_from_tool_name("create_zap")
print(resolved_server_unpref)
assert resolved_server_unpref is not None
assert resolved_server_unpref.server_id == server.server_id
# Prefixed resolution
resolved_server_pref = manager._get_mcp_server_from_tool_name(
"zapier-create_zap"
)
assert resolved_server_pref is not None
assert resolved_server_pref.server_id == server.server_id
if __name__ == "__main__":