mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 03:31:23 +00:00
Revert "Revert "Merge pull request #14720 from uc4w6c/feat/remove-servername-prefix-mcp_tools""
This reverts commit a88d774f94.
This commit is contained in:
@@ -443,6 +443,7 @@ class MCPServerManager:
|
||||
server: MCPServer,
|
||||
mcp_auth_header: Optional[Union[str, Dict[str, str]]] = None,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
add_prefix: bool = True,
|
||||
) -> List[MCPTool]:
|
||||
"""
|
||||
Helper method to get tools from a single MCP server with prefixed names.
|
||||
@@ -468,9 +469,11 @@ class MCPServerManager:
|
||||
|
||||
tools = await self._fetch_tools_with_timeout(client, server.name)
|
||||
|
||||
prefixed_tools = self._create_prefixed_tools(tools, server)
|
||||
prefixed_or_original_tools = self._create_prefixed_tools(
|
||||
tools, server, add_prefix=add_prefix
|
||||
)
|
||||
|
||||
return prefixed_tools
|
||||
return prefixed_or_original_tools
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.warning(
|
||||
@@ -539,7 +542,7 @@ class MCPServerManager:
|
||||
return []
|
||||
|
||||
def _create_prefixed_tools(
|
||||
self, tools: List[MCPTool], server: MCPServer
|
||||
self, tools: List[MCPTool], server: MCPServer, add_prefix: bool = True
|
||||
) -> List[MCPTool]:
|
||||
"""
|
||||
Create prefixed tools and update tool mapping.
|
||||
@@ -557,14 +560,16 @@ class MCPServerManager:
|
||||
for tool in tools:
|
||||
prefixed_name = add_server_prefix_to_tool_name(tool.name, prefix)
|
||||
|
||||
prefixed_tool = MCPTool(
|
||||
name=prefixed_name,
|
||||
name_to_use = prefixed_name if add_prefix else tool.name
|
||||
|
||||
tool_obj = MCPTool(
|
||||
name=name_to_use,
|
||||
description=tool.description,
|
||||
inputSchema=tool.inputSchema,
|
||||
)
|
||||
prefixed_tools.append(prefixed_tool)
|
||||
prefixed_tools.append(tool_obj)
|
||||
|
||||
# Update tool to server mapping with both original and prefixed names
|
||||
# Update tool to server mapping for resolution (support both forms)
|
||||
self.tool_name_to_mcp_server_name_mapping[tool.name] = prefix
|
||||
self.tool_name_to_mcp_server_name_mapping[prefixed_name] = prefix
|
||||
|
||||
|
||||
@@ -73,6 +73,7 @@ if MCP_AVAILABLE:
|
||||
tools = await global_mcp_server_manager._get_tools_from_server(
|
||||
server=server,
|
||||
mcp_auth_header=server_auth_header,
|
||||
add_prefix=False,
|
||||
)
|
||||
return _create_tool_response_objects(tools, server.mcp_info)
|
||||
|
||||
|
||||
@@ -414,6 +414,9 @@ if MCP_AVAILABLE:
|
||||
allowed_mcp_servers=allowed_mcp_servers,
|
||||
)
|
||||
|
||||
# Decide whether to add prefix based on number of allowed servers
|
||||
add_prefix = not (len(allowed_mcp_servers) == 1)
|
||||
|
||||
# Get tools from each allowed server
|
||||
all_tools = []
|
||||
for server_id in allowed_mcp_servers:
|
||||
@@ -448,6 +451,7 @@ if MCP_AVAILABLE:
|
||||
server=server,
|
||||
mcp_auth_header=server_auth_header,
|
||||
extra_headers=extra_headers,
|
||||
add_prefix=add_prefix,
|
||||
)
|
||||
all_tools.extend(filter_tools_by_allowed_tools(tools, server))
|
||||
verbose_logger.debug(
|
||||
|
||||
@@ -103,7 +103,7 @@ async def test_get_tools_from_mcp_servers_continues_when_one_server_fails():
|
||||
)
|
||||
|
||||
async def mock_get_tools_from_server(
|
||||
server, mcp_auth_header=None, extra_headers=None
|
||||
server, mcp_auth_header=None, extra_headers=None, add_prefix=True
|
||||
):
|
||||
if server.name == "working_server":
|
||||
# Working server returns tools
|
||||
@@ -187,7 +187,7 @@ async def test_get_tools_from_mcp_servers_handles_all_servers_failing():
|
||||
)
|
||||
|
||||
async def mock_get_tools_from_server(
|
||||
server, mcp_auth_header=None, extra_headers=None
|
||||
server, mcp_auth_header=None, extra_headers=None, add_prefix=True
|
||||
):
|
||||
# All servers fail
|
||||
raise Exception(f"Server {server.name} connection failed")
|
||||
@@ -564,3 +564,120 @@ async def test_oauth2_headers_passed_to_mcp_client():
|
||||
captured_client_args["extra_headers"]["Authorization"]
|
||||
== "Bearer github_oauth_token_12345"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools_single_server_unprefixed_names():
|
||||
"""When only one MCP server is allowed, list tools should return unprefixed names."""
|
||||
try:
|
||||
from litellm.proxy._experimental.mcp_server.server import (
|
||||
_get_tools_from_mcp_servers,
|
||||
set_auth_context,
|
||||
)
|
||||
except ImportError:
|
||||
pytest.skip("MCP server not available")
|
||||
|
||||
# Mock user auth
|
||||
user_api_key_auth = UserAPIKeyAuth(api_key="test_key", user_id="test_user")
|
||||
set_auth_context(user_api_key_auth)
|
||||
|
||||
# One allowed server
|
||||
server = MagicMock()
|
||||
server.server_id = "server1"
|
||||
server.name = "Zapier MCP"
|
||||
server.alias = "zapier"
|
||||
|
||||
# Mock manager: allow just one server and return a tool based on add_prefix flag
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.get_allowed_mcp_servers = AsyncMock(return_value=["server1"])
|
||||
mock_manager.get_mcp_server_by_id = (
|
||||
lambda server_id: server if server_id == "server1" else None
|
||||
)
|
||||
|
||||
async def mock_get_tools_from_server(
|
||||
server, mcp_auth_header=None, extra_headers=None, add_prefix=False
|
||||
):
|
||||
tool = MagicMock()
|
||||
tool.name = f"{server.alias}-toolA" if add_prefix else "toolA"
|
||||
tool.description = "desc"
|
||||
tool.inputSchema = {}
|
||||
return [tool]
|
||||
|
||||
mock_manager._get_tools_from_server = mock_get_tools_from_server
|
||||
|
||||
with patch(
|
||||
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
|
||||
mock_manager,
|
||||
):
|
||||
tools = await _get_tools_from_mcp_servers(
|
||||
user_api_key_auth=user_api_key_auth,
|
||||
mcp_auth_header=None,
|
||||
mcp_servers=None,
|
||||
mcp_server_auth_headers=None,
|
||||
)
|
||||
|
||||
# Should be unprefixed since only one server is allowed
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name == "toolA"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools_multiple_servers_prefixed_names():
|
||||
"""When multiple MCP servers are allowed, list tools should return prefixed names."""
|
||||
try:
|
||||
from litellm.proxy._experimental.mcp_server.server import (
|
||||
_get_tools_from_mcp_servers,
|
||||
set_auth_context,
|
||||
)
|
||||
except ImportError:
|
||||
pytest.skip("MCP server not available")
|
||||
|
||||
# Mock user auth
|
||||
user_api_key_auth = UserAPIKeyAuth(api_key="test_key", user_id="test_user")
|
||||
set_auth_context(user_api_key_auth)
|
||||
|
||||
# Two allowed servers
|
||||
server1 = MagicMock()
|
||||
server1.server_id = "server1"
|
||||
server1.name = "Zapier MCP"
|
||||
server1.alias = "zapier"
|
||||
|
||||
server2 = MagicMock()
|
||||
server2.server_id = "server2"
|
||||
server2.name = "Jira MCP"
|
||||
server2.alias = "jira"
|
||||
|
||||
# Mock manager
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.get_allowed_mcp_servers = AsyncMock(
|
||||
return_value=["server1", "server2"]
|
||||
)
|
||||
mock_manager.get_mcp_server_by_id = (
|
||||
lambda server_id: server1 if server_id == "server1" else server2
|
||||
)
|
||||
|
||||
async def mock_get_tools_from_server(
|
||||
server, mcp_auth_header=None, extra_headers=None, add_prefix=True
|
||||
):
|
||||
tool = MagicMock()
|
||||
# When multiple servers, add_prefix should be True -> prefixed names
|
||||
tool.name = f"{server.alias}-toolA" if add_prefix else "toolA"
|
||||
tool.description = "desc"
|
||||
tool.inputSchema = {}
|
||||
return [tool]
|
||||
|
||||
mock_manager._get_tools_from_server = mock_get_tools_from_server
|
||||
|
||||
with patch(
|
||||
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
|
||||
mock_manager,
|
||||
):
|
||||
tools = await _get_tools_from_mcp_servers(
|
||||
user_api_key_auth=user_api_key_auth,
|
||||
mcp_auth_header=None,
|
||||
mcp_servers=None,
|
||||
mcp_server_auth_headers=None,
|
||||
)
|
||||
|
||||
# Should be prefixed since multiple servers are allowed
|
||||
names = sorted([t.name for t in tools])
|
||||
assert names == ["jira-toolA", "zapier-toolA"]
|
||||
|
||||
@@ -654,6 +654,110 @@ class TestMCPServerManager:
|
||||
"Tool tool3 is not allowed for server test-server"
|
||||
in exc_info.value.detail["error"]
|
||||
)
|
||||
async def test_get_tools_from_server_add_prefix(self):
|
||||
"""Verify _get_tools_from_server respects add_prefix True/False."""
|
||||
manager = MCPServerManager()
|
||||
|
||||
# Create a minimal server with alias used as prefix
|
||||
server = MCPServer(
|
||||
server_id="zapier",
|
||||
name="zapier",
|
||||
transport=MCPTransport.http,
|
||||
)
|
||||
|
||||
# Mock client creation and fetching tools
|
||||
manager._create_mcp_client = MagicMock(return_value=object())
|
||||
|
||||
# Tools returned upstream (unprefixed from provider)
|
||||
upstream_tool = MagicMock()
|
||||
upstream_tool.name = "send_email"
|
||||
upstream_tool.description = "Send an email"
|
||||
upstream_tool.inputSchema = {}
|
||||
|
||||
manager._fetch_tools_with_timeout = AsyncMock(return_value=[upstream_tool])
|
||||
|
||||
# Case 1: add_prefix=True (default for multi-server) -> expect prefixed
|
||||
tools_prefixed = await manager._get_tools_from_server(server, add_prefix=True)
|
||||
assert len(tools_prefixed) == 1
|
||||
assert tools_prefixed[0].name == "zapier-send_email"
|
||||
|
||||
# Case 2: add_prefix=False (single-server) -> expect unprefixed
|
||||
tools_unprefixed = await manager._get_tools_from_server(
|
||||
server, add_prefix=False
|
||||
)
|
||||
assert len(tools_unprefixed) == 1
|
||||
assert tools_unprefixed[0].name == "send_email"
|
||||
|
||||
def test_create_prefixed_tools_updates_mapping_for_both_forms(self):
|
||||
"""_create_prefixed_tools should populate mapping for prefixed and original names even when not adding prefix in output."""
|
||||
manager = MCPServerManager()
|
||||
|
||||
server = MCPServer(
|
||||
server_id="jira",
|
||||
name="jira",
|
||||
transport=MCPTransport.http,
|
||||
)
|
||||
|
||||
# Input tools as would come from upstream
|
||||
t1 = MagicMock()
|
||||
t1.name = "create_issue"
|
||||
t1.description = ""
|
||||
t1.inputSchema = {}
|
||||
t2 = MagicMock()
|
||||
t2.name = "close_issue"
|
||||
t2.description = ""
|
||||
t2.inputSchema = {}
|
||||
|
||||
# Do not add prefix in returned objects
|
||||
out_tools = manager._create_prefixed_tools([t1, t2], server, add_prefix=False)
|
||||
|
||||
# Returned names should be unprefixed
|
||||
names = sorted([t.name for t in out_tools])
|
||||
assert names == ["close_issue", "create_issue"]
|
||||
|
||||
# Mapping should include both original and prefixed names -> resolves calls either way
|
||||
assert manager.tool_name_to_mcp_server_name_mapping["create_issue"] == "jira"
|
||||
assert (
|
||||
manager.tool_name_to_mcp_server_name_mapping["jira-create_issue"] == "jira"
|
||||
)
|
||||
assert manager.tool_name_to_mcp_server_name_mapping["close_issue"] == "jira"
|
||||
assert (
|
||||
manager.tool_name_to_mcp_server_name_mapping["jira-close_issue"] == "jira"
|
||||
)
|
||||
|
||||
def test_get_mcp_server_from_tool_name_with_prefixed_and_unprefixed(self):
|
||||
"""After mapping is populated, manager resolves both prefixed and unprefixed tool names to the same server."""
|
||||
manager = MCPServerManager()
|
||||
|
||||
server = MCPServer(
|
||||
server_id="zapier",
|
||||
name="zapier",
|
||||
server_name="zapier",
|
||||
transport=MCPTransport.http,
|
||||
)
|
||||
|
||||
# Register server so resolution can find it
|
||||
manager.registry = {server.server_id: server}
|
||||
|
||||
# Populate mapping (add_prefix value doesn't matter for mapping population)
|
||||
base_tool = MagicMock()
|
||||
base_tool.name = "create_zap"
|
||||
base_tool.description = ""
|
||||
base_tool.inputSchema = {}
|
||||
_ = manager._create_prefixed_tools([base_tool], server, add_prefix=False)
|
||||
|
||||
# Unprefixed resolution
|
||||
resolved_server_unpref = manager._get_mcp_server_from_tool_name("create_zap")
|
||||
print(resolved_server_unpref)
|
||||
assert resolved_server_unpref is not None
|
||||
assert resolved_server_unpref.server_id == server.server_id
|
||||
|
||||
# Prefixed resolution
|
||||
resolved_server_pref = manager._get_mcp_server_from_tool_name(
|
||||
"zapier-create_zap"
|
||||
)
|
||||
assert resolved_server_pref is not None
|
||||
assert resolved_server_pref.server_id == server.server_id
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user