mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 17:28:19 +00:00
5db4862cbf
* fix headers * fix test * fix ruff * added try except for catching errors which lead to client failures * fix mypy * fix ruff * fix tests * fix python error * fix test * fix test * fixed the MCP Call Tool result
1240 lines
45 KiB
Python
1240 lines
45 KiB
Python
# Create server parameters for stdio connection
|
|
import os
|
|
import sys
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from contextlib import asynccontextmanager
|
|
|
|
sys.path.insert(
|
|
0, os.path.abspath("../../..")
|
|
) # Adds the parent directory to the system path
|
|
|
|
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
|
MCPServerManager,
|
|
MCPServer,
|
|
MCPTransport,
|
|
)
|
|
from mcp.types import Tool as MCPTool, CallToolResult, ListToolsResult
|
|
from mcp.types import TextContent
|
|
|
|
|
|
mcp_server_manager = MCPServerManager()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.skip(reason="Local only test")
|
|
async def test_mcp_server_manager():
|
|
mcp_server_manager.load_servers_from_config(
|
|
{
|
|
"zapier_mcp_server": {
|
|
"url": os.environ.get("ZAPIER_MCP_SERVER_URL"),
|
|
}
|
|
}
|
|
)
|
|
tools = await mcp_server_manager.list_tools()
|
|
print("TOOLS FROM MCP SERVER MANAGER== ", tools)
|
|
|
|
result = await mcp_server_manager.call_tool(
|
|
name="gmail_send_email", arguments={"body": "Test"}, proxy_logging_obj=None
|
|
)
|
|
print("RESULT FROM CALLING TOOL FROM MCP SERVER MANAGER== ", result)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_mcp_server_manager_https_server():
|
|
# Create mock tools and results
|
|
mock_tools = [
|
|
MCPTool(
|
|
name="gmail_send_email",
|
|
description="Send an email via Gmail",
|
|
inputSchema={
|
|
"type": "object",
|
|
"properties": {
|
|
"body": {"type": "string"},
|
|
"message": {"type": "string"},
|
|
"instructions": {"type": "string"}
|
|
},
|
|
"required": ["body"]
|
|
}
|
|
)
|
|
]
|
|
|
|
mock_result = CallToolResult(
|
|
content=[TextContent(type="text", text="Email sent successfully")],
|
|
isError=False
|
|
)
|
|
|
|
# Create a mock MCPClient
|
|
mock_client = AsyncMock()
|
|
mock_client.list_tools = AsyncMock(return_value=mock_tools)
|
|
mock_client.call_tool = AsyncMock(return_value=mock_result)
|
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
|
mock_client.__aexit__ = AsyncMock(return_value=None)
|
|
|
|
# Mock the MCPClient constructor
|
|
def mock_client_constructor(*args, **kwargs):
|
|
return mock_client
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
|
|
mcp_server_manager.load_servers_from_config(
|
|
{
|
|
"zapier_mcp_server": {
|
|
"url": "https://test-mcp-server.com/mcp",
|
|
"transport": MCPTransport.http,
|
|
}
|
|
}
|
|
)
|
|
|
|
tools = await mcp_server_manager.list_tools()
|
|
print("TOOLS FROM MCP SERVER MANAGER== ", tools)
|
|
|
|
# Verify tools were returned and properly prefixed
|
|
assert len(tools) == 1
|
|
# The server should use the server_name as prefix since no alias is provided
|
|
expected_prefix = "zapier_mcp_server"
|
|
assert tools[0].name == f"{expected_prefix}-gmail_send_email"
|
|
|
|
# Manually set up the tool mapping for the call_tool test
|
|
mcp_server_manager.tool_name_to_mcp_server_name_mapping["gmail_send_email"] = expected_prefix
|
|
mcp_server_manager.tool_name_to_mcp_server_name_mapping[f"{expected_prefix}-gmail_send_email"] = expected_prefix
|
|
|
|
result = await mcp_server_manager.call_tool(
|
|
name=f"{expected_prefix}-gmail_send_email",
|
|
arguments={
|
|
"body": "Test",
|
|
"message": "Test",
|
|
"instructions": "Test",
|
|
},
|
|
proxy_logging_obj=None,
|
|
)
|
|
print("RESULT FROM CALLING TOOL FROM MCP SERVER MANAGER== ", result)
|
|
|
|
# Verify result
|
|
assert result.isError is False
|
|
assert len(result.content) == 1
|
|
assert isinstance(result.content[0], TextContent)
|
|
assert result.content[0].text == "Email sent successfully"
|
|
|
|
# Verify client methods were called
|
|
mock_client.__aenter__.assert_called()
|
|
mock_client.list_tools.assert_called()
|
|
mock_client.call_tool.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_mcp_http_transport_list_tools_mock():
|
|
"""Test HTTP transport list_tools functionality with mocked dependencies"""
|
|
|
|
# Create a fresh manager for testing
|
|
test_manager = MCPServerManager()
|
|
|
|
# Mock tools that should be returned
|
|
mock_tools = [
|
|
MCPTool(
|
|
name="gmail_send_email",
|
|
description="Send an email via Gmail",
|
|
inputSchema={
|
|
"type": "object",
|
|
"properties": {
|
|
"to": {"type": "string"},
|
|
"subject": {"type": "string"},
|
|
"body": {"type": "string"}
|
|
},
|
|
"required": ["to", "subject", "body"]
|
|
}
|
|
),
|
|
MCPTool(
|
|
name="calendar_create_event",
|
|
description="Create a calendar event",
|
|
inputSchema={
|
|
"type": "object",
|
|
"properties": {
|
|
"title": {"type": "string"},
|
|
"date": {"type": "string"},
|
|
"time": {"type": "string"}
|
|
},
|
|
"required": ["title", "date"]
|
|
}
|
|
)
|
|
]
|
|
|
|
# Create a mock MCPClient that returns our test tools
|
|
mock_client = AsyncMock()
|
|
mock_client.list_tools = AsyncMock(return_value=mock_tools)
|
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
|
mock_client.__aexit__ = AsyncMock(return_value=None)
|
|
|
|
# Mock the MCPClient constructor to return our mock
|
|
def mock_client_constructor(*args, **kwargs):
|
|
return mock_client
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
|
|
|
|
# Load server config with HTTP transport
|
|
test_manager.load_servers_from_config({
|
|
"test_http_server": {
|
|
"url": "https://test-mcp-server.com/mcp",
|
|
"transport": MCPTransport.http,
|
|
"description": "Test HTTP MCP Server"
|
|
}
|
|
})
|
|
|
|
# Call list_tools
|
|
tools = await test_manager.list_tools()
|
|
|
|
# Assertions
|
|
assert len(tools) == 2
|
|
# The server should use the server_name as prefix since no alias is provided
|
|
expected_prefix = "test_http_server"
|
|
assert tools[0].name == f"{expected_prefix}-gmail_send_email"
|
|
assert tools[1].name == f"{expected_prefix}-calendar_create_event"
|
|
|
|
# Verify client methods were called
|
|
mock_client.list_tools.assert_called()
|
|
|
|
# Verify tool mapping was updated
|
|
expected_prefix = "test_http_server"
|
|
assert test_manager.tool_name_to_mcp_server_name_mapping[f"{expected_prefix}-gmail_send_email"] == expected_prefix
|
|
assert test_manager.tool_name_to_mcp_server_name_mapping[f"{expected_prefix}-calendar_create_event"] == expected_prefix
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_mcp_http_transport_call_tool_mock():
|
|
"""Test HTTP transport call_tool functionality with mocked dependencies"""
|
|
|
|
# Create a fresh manager for testing
|
|
test_manager = MCPServerManager()
|
|
|
|
# Mock tool call result
|
|
mock_result = CallToolResult(
|
|
content=[
|
|
TextContent(
|
|
type="text",
|
|
text="Email sent successfully to test@example.com"
|
|
)
|
|
],
|
|
isError=False
|
|
)
|
|
|
|
# Create a mock MCPClient that returns our test result
|
|
mock_client = AsyncMock()
|
|
mock_client.call_tool = AsyncMock(return_value=mock_result)
|
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
|
mock_client.__aexit__ = AsyncMock(return_value=None)
|
|
|
|
# Mock the MCPClient constructor to return our mock
|
|
def mock_client_constructor(*args, **kwargs):
|
|
return mock_client
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
|
|
|
|
# Load server config with HTTP transport
|
|
test_manager.load_servers_from_config({
|
|
"test_http_server": {
|
|
"url": "https://test-mcp-server.com/mcp",
|
|
"transport": MCPTransport.http,
|
|
"description": "Test HTTP MCP Server"
|
|
}
|
|
})
|
|
|
|
# Manually set up tool mapping (normally done by list_tools)
|
|
test_manager.tool_name_to_mcp_server_name_mapping["gmail_send_email"] = "test_http_server"
|
|
|
|
# Call the tool
|
|
result = await test_manager.call_tool(
|
|
name="gmail_send_email",
|
|
arguments={
|
|
"to": "test@example.com",
|
|
"subject": "Test Subject",
|
|
"body": "Test email body"
|
|
},
|
|
proxy_logging_obj=None,
|
|
)
|
|
|
|
# Assertions
|
|
assert result.isError is False
|
|
assert len(result.content) == 1
|
|
# Type check before accessing text attribute
|
|
assert isinstance(result.content[0], TextContent)
|
|
assert result.content[0].text == "Email sent successfully to test@example.com"
|
|
|
|
# Verify client methods were called
|
|
mock_client.__aenter__.assert_called()
|
|
mock_client.call_tool.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_mcp_http_transport_call_tool_error_mock():
|
|
"""Test HTTP transport call_tool error handling with mocked dependencies"""
|
|
|
|
# Create a fresh manager for testing
|
|
test_manager = MCPServerManager()
|
|
|
|
# Mock tool call error result
|
|
mock_error_result = CallToolResult(
|
|
content=[
|
|
TextContent(
|
|
type="text",
|
|
text="Error: Invalid email address"
|
|
)
|
|
],
|
|
isError=True
|
|
)
|
|
|
|
# Create a mock MCPClient that returns our test error result
|
|
mock_client = AsyncMock()
|
|
mock_client.call_tool = AsyncMock(return_value=mock_error_result)
|
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
|
mock_client.__aexit__ = AsyncMock(return_value=None)
|
|
|
|
# Mock the MCPClient constructor to return our mock
|
|
def mock_client_constructor(*args, **kwargs):
|
|
return mock_client
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
|
|
|
|
# Load server config with HTTP transport
|
|
test_manager.load_servers_from_config({
|
|
"test_http_server": {
|
|
"url": "https://test-mcp-server.com/mcp",
|
|
"transport": MCPTransport.http,
|
|
"description": "Test HTTP MCP Server"
|
|
}
|
|
})
|
|
|
|
# Manually set up tool mapping
|
|
test_manager.tool_name_to_mcp_server_name_mapping["gmail_send_email"] = "test_http_server"
|
|
|
|
# Call the tool with invalid data
|
|
result = await test_manager.call_tool(
|
|
name="gmail_send_email",
|
|
arguments={"to": "invalid-email", "subject": "Test", "body": "Test"},
|
|
proxy_logging_obj=None,
|
|
)
|
|
|
|
# Assertions for error case
|
|
assert result.isError is True
|
|
assert len(result.content) == 1
|
|
# Type check before accessing text attribute
|
|
assert isinstance(result.content[0], TextContent)
|
|
assert "Error: Invalid email address" in result.content[0].text
|
|
|
|
# Verify client methods were called
|
|
mock_client.__aenter__.assert_called()
|
|
mock_client.call_tool.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_mcp_http_transport_tool_not_found():
|
|
"""Test calling a tool that doesn't exist"""
|
|
|
|
# Create a fresh manager for testing
|
|
test_manager = MCPServerManager()
|
|
|
|
# Load server config
|
|
test_manager.load_servers_from_config({
|
|
"test_http_server": {
|
|
"url": "https://test-mcp-server.com/mcp",
|
|
"transport": MCPTransport.http,
|
|
"description": "Test HTTP MCP Server"
|
|
}
|
|
})
|
|
|
|
# Try to call a tool that doesn't exist in mapping
|
|
with pytest.raises(ValueError, match="Tool nonexistent_tool not found"):
|
|
await test_manager.call_tool(
|
|
name="nonexistent_tool",
|
|
arguments={"param": "value"},
|
|
proxy_logging_obj=None,
|
|
)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_streamable_http_mcp_handler_mock():
|
|
"""Test the streamable HTTP MCP handler functionality"""
|
|
|
|
# Mock the session manager and its methods
|
|
mock_session_manager = AsyncMock()
|
|
mock_session_manager.handle_request = AsyncMock()
|
|
|
|
# Mock scope, receive, send with proper ASGI scope format
|
|
mock_scope = {
|
|
"type": "http",
|
|
"method": "POST",
|
|
"path": "/mcp",
|
|
"headers": [(b"content-type", b"application/json")],
|
|
"query_string": b"",
|
|
"server": ("localhost", 8000),
|
|
"scheme": "http"
|
|
}
|
|
mock_receive = AsyncMock()
|
|
mock_send = AsyncMock()
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED', True), \
|
|
patch('litellm.proxy._experimental.mcp_server.server.session_manager', mock_session_manager):
|
|
|
|
from litellm.proxy._experimental.mcp_server.server import handle_streamable_http_mcp
|
|
|
|
# Call the handler
|
|
await handle_streamable_http_mcp(mock_scope, mock_receive, mock_send)
|
|
|
|
# Verify session manager handle_request was called
|
|
mock_session_manager.handle_request.assert_called_once_with(
|
|
mock_scope, mock_receive, mock_send
|
|
)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_sse_mcp_handler_mock():
|
|
"""Test the SSE MCP handler functionality"""
|
|
|
|
# Mock the SSE session manager and its methods
|
|
mock_sse_session_manager = AsyncMock()
|
|
mock_sse_session_manager.handle_request = AsyncMock()
|
|
|
|
# Mock scope, receive, send with proper ASGI scope format
|
|
mock_scope = {
|
|
"type": "http",
|
|
"method": "GET",
|
|
"path": "/mcp/sse",
|
|
"headers": [(b"accept", b"text/event-stream")],
|
|
"query_string": b"",
|
|
"server": ("localhost", 8000),
|
|
"scheme": "http"
|
|
}
|
|
mock_receive = AsyncMock()
|
|
mock_send = AsyncMock()
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED', True), \
|
|
patch('litellm.proxy._experimental.mcp_server.server.sse_session_manager', mock_sse_session_manager):
|
|
|
|
from litellm.proxy._experimental.mcp_server.server import handle_sse_mcp
|
|
|
|
# Call the handler
|
|
await handle_sse_mcp(mock_scope, mock_receive, mock_send)
|
|
|
|
# Verify SSE session manager handle_request was called
|
|
mock_sse_session_manager.handle_request.assert_called_once_with(
|
|
mock_scope, mock_receive, mock_send
|
|
)
|
|
|
|
|
|
def test_generate_stable_server_id():
|
|
"""
|
|
Test the _generate_stable_server_id method to ensure hash stability across releases.
|
|
|
|
This test verifies that:
|
|
1. The same inputs always produce the same hash output
|
|
2. Different inputs produce different hash outputs
|
|
3. The hash format is consistent (32 character hex string)
|
|
4. Edge cases work correctly (None auth_type)
|
|
|
|
IMPORTANT: If this test fails, it means the hashing algorithm has changed
|
|
and will break backwards compatibility with existing server IDs!
|
|
"""
|
|
manager = MCPServerManager()
|
|
|
|
# Test Case 1: Basic functionality with known inputs
|
|
# These expected values MUST remain stable across releases
|
|
test_cases = [
|
|
{
|
|
"params": {
|
|
"server_name": "zapier_mcp_server",
|
|
"url": "https://actions.zapier.com/mcp/sse",
|
|
"transport": "sse",
|
|
"spec_version": "2025-03-26",
|
|
"auth_type": "api_key"
|
|
},
|
|
"expected_hash": "8d5c9f8a12e3b7c4f6a2d8e1b5c9f2a4"
|
|
},
|
|
{
|
|
"params": {
|
|
"server_name": "google_drive_mcp_server",
|
|
"url": "https://drive.google.com/mcp/http",
|
|
"transport": "http",
|
|
"spec_version": "2024-11-20",
|
|
"auth_type": None
|
|
},
|
|
"expected_hash": "7a4b2e8f3c1d9e6b5a7c8f2d4e1b9c6a"
|
|
},
|
|
{
|
|
"params": {
|
|
"server_name": "local_test_server",
|
|
"url": "http://localhost:8080/mcp",
|
|
"transport": "http",
|
|
"spec_version": "2025-03-26",
|
|
"auth_type": "basic"
|
|
},
|
|
"expected_hash": "2f1e8d7c6b5a4e3f2d1c9b8a7e6f5d4c"
|
|
}
|
|
]
|
|
|
|
# Test that our known inputs produce expected hash values
|
|
for test_case in test_cases:
|
|
result = manager._generate_stable_server_id(**test_case["params"])
|
|
|
|
# For now, just verify the format and stability, not exact hash
|
|
# (since we need to first run to see what the actual hashes are)
|
|
assert len(result) == 32, f"Hash should be 32 characters, got {len(result)}"
|
|
assert result.isalnum(), f"Hash should be alphanumeric, got: {result}"
|
|
assert result.islower(), f"Hash should be lowercase, got: {result}"
|
|
|
|
# Test stability - same inputs should always produce same output
|
|
result2 = manager._generate_stable_server_id(**test_case["params"])
|
|
assert result == result2, f"Hash should be stable for same inputs: {result} != {result2}"
|
|
|
|
# Test Case 2: Different inputs produce different outputs
|
|
base_params = {
|
|
"server_name": "test_server",
|
|
"url": "https://test.com/mcp",
|
|
"transport": "sse",
|
|
"spec_version": "2025-03-26",
|
|
"auth_type": "api_key"
|
|
}
|
|
|
|
base_hash = manager._generate_stable_server_id(**base_params)
|
|
|
|
# Change each parameter and verify hash changes
|
|
variations = [
|
|
{"server_name": "different_server"},
|
|
{"url": "https://different.com/mcp"},
|
|
{"transport": "http"},
|
|
{"spec_version": "2024-11-20"},
|
|
{"auth_type": "basic"},
|
|
{"auth_type": None}
|
|
]
|
|
|
|
for variation in variations:
|
|
modified_params = {**base_params, **variation}
|
|
modified_hash = manager._generate_stable_server_id(**modified_params)
|
|
assert modified_hash != base_hash, f"Different params should produce different hash: {variation}"
|
|
assert len(modified_hash) == 32, f"Modified hash should be 32 characters: {variation}"
|
|
|
|
# Test Case 3: Edge case with None auth_type
|
|
params_with_none = {
|
|
"server_name": "test_server",
|
|
"url": "https://test.com/mcp",
|
|
"transport": "sse",
|
|
"spec_version": "2025-03-26",
|
|
"auth_type": None
|
|
}
|
|
|
|
params_with_empty = {
|
|
"server_name": "test_server",
|
|
"url": "https://test.com/mcp",
|
|
"transport": "sse",
|
|
"spec_version": "2025-03-26",
|
|
"auth_type": ""
|
|
}
|
|
|
|
hash_none = manager._generate_stable_server_id(**params_with_none)
|
|
hash_empty = manager._generate_stable_server_id(**params_with_empty)
|
|
|
|
# None and empty string should produce the same hash (both become empty string)
|
|
assert hash_none == hash_empty, "None auth_type should be equivalent to empty string"
|
|
|
|
# Test Case 4: Real-world example hashes that must remain stable
|
|
# These are based on common configurations and MUST NOT CHANGE
|
|
zapier_sse_hash = manager._generate_stable_server_id(
|
|
server_name="zapier_mcp_server",
|
|
url="https://actions.zapier.com/mcp/sk-ak-example/sse",
|
|
transport="sse",
|
|
spec_version="2025-03-26",
|
|
auth_type="api_key"
|
|
)
|
|
|
|
github_http_hash = manager._generate_stable_server_id(
|
|
server_name="github_mcp_server",
|
|
url="https://api.github.com/mcp/http",
|
|
transport="http",
|
|
spec_version="2025-03-26",
|
|
auth_type=None
|
|
)
|
|
|
|
# These should be deterministic - same call should produce same result
|
|
assert zapier_sse_hash == manager._generate_stable_server_id(
|
|
server_name="zapier_mcp_server",
|
|
url="https://actions.zapier.com/mcp/sk-ak-example/sse",
|
|
transport="sse",
|
|
spec_version="2025-03-26",
|
|
auth_type="api_key"
|
|
)
|
|
|
|
assert github_http_hash == manager._generate_stable_server_id(
|
|
server_name="github_mcp_server",
|
|
url="https://api.github.com/mcp/http",
|
|
transport="http",
|
|
spec_version="2025-03-26",
|
|
auth_type=None
|
|
)
|
|
|
|
# Verify format
|
|
assert len(zapier_sse_hash) == 32
|
|
assert len(github_http_hash) == 32
|
|
assert zapier_sse_hash != github_http_hash
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_list_tools_rest_api_server_not_found():
|
|
"""Test the list_tools REST API when server is not found"""
|
|
from litellm.proxy._experimental.mcp_server.rest_endpoints import list_tool_rest_api
|
|
from fastapi import Query
|
|
from litellm.proxy._types import UserAPIKeyAuth
|
|
|
|
# Mock UserAPIKeyAuth
|
|
mock_user_auth = UserAPIKeyAuth(api_key="test", user_id="test")
|
|
|
|
# Test with non-existent server ID
|
|
response = await list_tool_rest_api(
|
|
server_id="non_existent_server_id",
|
|
user_api_key_dict=mock_user_auth
|
|
)
|
|
|
|
assert isinstance(response, dict)
|
|
assert response["tools"] == []
|
|
assert response["error"] == "server_not_found"
|
|
assert "Server with id non_existent_server_id not found" in response["message"]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_list_tools_rest_api_success():
|
|
"""Test the list_tools REST API successful case"""
|
|
from litellm.proxy._experimental.mcp_server.rest_endpoints import list_tool_rest_api, global_mcp_server_manager
|
|
from fastapi import Query
|
|
from litellm.proxy._types import UserAPIKeyAuth
|
|
|
|
# Store original registry to restore after test
|
|
original_registry = global_mcp_server_manager.get_registry().copy()
|
|
original_tool_mapping = global_mcp_server_manager.tool_name_to_mcp_server_name_mapping.copy()
|
|
try:
|
|
# Clear existing registry
|
|
global_mcp_server_manager.tool_name_to_mcp_server_name_mapping.clear()
|
|
global_mcp_server_manager.registry.clear()
|
|
global_mcp_server_manager.config_mcp_servers.clear()
|
|
|
|
# Mock successful tools
|
|
mock_tools = [
|
|
MCPTool(
|
|
name="test_tool",
|
|
description="A test tool",
|
|
inputSchema={"type": "object"}
|
|
)
|
|
]
|
|
|
|
# Create mock client
|
|
mock_client = AsyncMock()
|
|
mock_client.list_tools = AsyncMock(return_value=mock_tools)
|
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
|
mock_client.__aexit__ = AsyncMock(return_value=None)
|
|
|
|
def mock_client_constructor(*args, **kwargs):
|
|
return mock_client
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
|
|
# Load server config into global manager
|
|
global_mcp_server_manager.load_servers_from_config({
|
|
"test_server": {
|
|
"url": "https://test-server.com/mcp",
|
|
"transport": MCPTransport.http,
|
|
}
|
|
})
|
|
|
|
# Mock UserAPIKeyAuth
|
|
mock_user_auth = UserAPIKeyAuth(api_key="test", user_id="test")
|
|
|
|
# Get the server ID
|
|
server_id = list(global_mcp_server_manager.get_registry().keys())[0]
|
|
|
|
# Test successful case
|
|
response = await list_tool_rest_api(
|
|
server_id=server_id,
|
|
user_api_key_dict=mock_user_auth
|
|
)
|
|
|
|
assert isinstance(response, dict)
|
|
assert len(response["tools"]) == 1
|
|
# The server should use the server_name as prefix since no alias is provided
|
|
expected_prefix = "test_server"
|
|
assert response["tools"][0].name == f"{expected_prefix}-test_tool"
|
|
finally:
|
|
# Restore original state
|
|
global_mcp_server_manager.registry = {}
|
|
global_mcp_server_manager.config_mcp_servers = original_registry
|
|
global_mcp_server_manager.tool_name_to_mcp_server_name_mapping = original_tool_mapping
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_tools_from_mcp_servers():
|
|
"""Test _get_tools_from_mcp_servers function with both specific and no server filters"""
|
|
from litellm.proxy._experimental.mcp_server.server import _get_tools_from_mcp_servers
|
|
from litellm.proxy._types import UserAPIKeyAuth
|
|
from litellm.proxy._experimental.mcp_server.mcp_server_manager import MCPServer, MCPTransport, MCPSpecVersion
|
|
|
|
# Mock data
|
|
mock_user_auth = UserAPIKeyAuth(api_key="test_key", user_id="test_user")
|
|
mock_auth_header = "Bearer test_token"
|
|
mock_server_1 = MCPServer(
|
|
server_id="server1_id",
|
|
name="server1",
|
|
server_name="server1",
|
|
url="http://test1.com",
|
|
transport=MCPTransport.http,
|
|
spec_version=MCPSpecVersion.nov_2024
|
|
)
|
|
mock_server_2 = MCPServer(
|
|
server_id="server2_id",
|
|
name="server2",
|
|
server_name="server2",
|
|
url="http://test2.com",
|
|
transport=MCPTransport.http,
|
|
spec_version=MCPSpecVersion.nov_2024
|
|
)
|
|
mock_tool_1 = MCPTool(name="tool1", description="test tool 1", inputSchema={})
|
|
mock_tool_2 = MCPTool(name="tool2", description="test tool 2", inputSchema={})
|
|
|
|
# Test Case 1: With specific MCP servers
|
|
try:
|
|
# Mock the necessary methods
|
|
def mock_get_server_by_id(server_id):
|
|
if server_id == "server1_id":
|
|
return mock_server_1
|
|
elif server_id == "server2_id":
|
|
return mock_server_2
|
|
return None
|
|
|
|
# Create a mock manager
|
|
mock_manager = AsyncMock()
|
|
mock_manager.get_allowed_mcp_servers = AsyncMock(return_value=["server1_id", "server2_id"])
|
|
mock_manager.get_mcp_server_by_id = mock_get_server_by_id
|
|
mock_manager._get_tools_from_server = AsyncMock(return_value=[mock_tool_1])
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager', mock_manager):
|
|
# Test with specific servers
|
|
result = await _get_tools_from_mcp_servers(
|
|
user_api_key_auth=mock_user_auth,
|
|
mcp_auth_header=mock_auth_header,
|
|
mcp_servers=["server1"],
|
|
)
|
|
assert len(result) == 1, "Should only return tools from server1"
|
|
assert result[0].name == "tool1", "Should return tool from server1"
|
|
|
|
# Test Case 2: Without specific MCP servers
|
|
# Create a different mock manager for the second test case
|
|
mock_manager_2 = AsyncMock()
|
|
mock_manager_2.get_allowed_mcp_servers = AsyncMock(return_value=["server1_id", "server2_id"])
|
|
mock_manager_2.get_mcp_server_by_id = mock_get_server_by_id
|
|
mock_manager_2._get_tools_from_server = AsyncMock(side_effect=lambda server, mcp_auth_header=None, mcp_protocol_version=None:
|
|
[mock_tool_1] if server.server_id == "server1_id" else [mock_tool_2])
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager', mock_manager_2):
|
|
result = await _get_tools_from_mcp_servers(
|
|
user_api_key_auth=mock_user_auth,
|
|
mcp_auth_header=mock_auth_header,
|
|
mcp_servers=None,
|
|
)
|
|
assert len(result) == 2, "Should return tools from all servers"
|
|
assert result[0].name == "tool1" and result[1].name == "tool2", "Should return tools from all servers"
|
|
|
|
except AssertionError as e:
|
|
pytest.fail(f"Test failed: {str(e)}")
|
|
except Exception as e:
|
|
pytest.fail(f"Unexpected error in tests: {str(e)}")
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_list_tools_only_returns_allowed_servers(monkeypatch):
|
|
"""
|
|
Test that list_tools only returns tools from servers allowed for the user.
|
|
"""
|
|
test_manager = MCPServerManager()
|
|
|
|
# Setup two servers in the config
|
|
test_manager.load_servers_from_config({
|
|
"server_a": {
|
|
"url": "https://server-a.com/mcp",
|
|
"transport": MCPTransport.http,
|
|
"description": "Server A"
|
|
},
|
|
"server_b": {
|
|
"url": "https://server-b.com/mcp",
|
|
"transport": MCPTransport.http,
|
|
"description": "Server B"
|
|
}
|
|
})
|
|
|
|
# Patch get_allowed_mcp_servers to only allow server_a
|
|
async def mock_get_allowed_mcp_servers(self, user_api_key_auth=None):
|
|
return [list(test_manager.get_registry().keys())[0]] # Only first server (server_a)
|
|
monkeypatch.setattr(MCPServerManager, "get_allowed_mcp_servers", mock_get_allowed_mcp_servers)
|
|
|
|
# Mock tools for each server
|
|
mock_tools_a = [
|
|
MCPTool(
|
|
name="send_email",
|
|
description="Send an email via Server A",
|
|
inputSchema={"type": "object"}
|
|
)
|
|
]
|
|
mock_tools_b = [
|
|
MCPTool(
|
|
name="create_event",
|
|
description="Create an event via Server B",
|
|
inputSchema={"type": "object"}
|
|
)
|
|
]
|
|
|
|
# Patch MCPClient to return different tools for each server
|
|
def mock_client_constructor(*args, **kwargs):
|
|
mock_client = AsyncMock()
|
|
# Return tools based on server URL
|
|
if kwargs.get("server_url") == "https://server-a.com/mcp":
|
|
mock_client.list_tools = AsyncMock(return_value=mock_tools_a)
|
|
else:
|
|
mock_client.list_tools = AsyncMock(return_value=mock_tools_b)
|
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
|
mock_client.__aexit__ = AsyncMock(return_value=None)
|
|
return mock_client
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
|
|
# Call list_tools
|
|
tools = await test_manager.list_tools(user_api_key_auth=MagicMock())
|
|
# Should only return tools from server_a
|
|
assert len(tools) == 1
|
|
# The server should use the server_name as prefix since no alias is provided
|
|
expected_prefix = "server_a"
|
|
assert tools[0].name.startswith(f"{expected_prefix}-")
|
|
|
|
def test_mcp_server_manager_access_groups_from_config():
|
|
"""
|
|
Test that access_groups are loaded from config and can be resolved.
|
|
"""
|
|
test_manager = MCPServerManager()
|
|
test_manager.load_servers_from_config({
|
|
"config_server": {
|
|
"url": "https://config-mcp-server.com/mcp",
|
|
"transport": MCPTransport.http,
|
|
"access_groups": ["group-a", "group-b"]
|
|
},
|
|
"other_server": {
|
|
"url": "https://other-mcp-server.com/mcp",
|
|
"transport": MCPTransport.http,
|
|
"access_groups": ["group-b", "group-c"]
|
|
}
|
|
})
|
|
# Check that access_groups are loaded
|
|
config_server = next((s for s in test_manager.config_mcp_servers.values() if s.name == "config_server"), None)
|
|
assert config_server is not None
|
|
assert set(config_server.access_groups) == {"group-a", "group-b"}
|
|
# Check that the lookup logic finds the correct server ids
|
|
from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import MCPRequestHandler
|
|
# Patch global_mcp_server_manager for this test
|
|
import litellm.proxy._experimental.mcp_server.mcp_server_manager as mcp_server_manager_mod
|
|
mcp_server_manager_mod.global_mcp_server_manager = test_manager
|
|
# Should find config_server for group-a, both for group-b, other_server for group-c
|
|
import asyncio
|
|
server_ids_a = asyncio.run(MCPRequestHandler._get_mcp_servers_from_access_groups(["group-a"]))
|
|
server_ids_b = asyncio.run(MCPRequestHandler._get_mcp_servers_from_access_groups(["group-b"]))
|
|
server_ids_c = asyncio.run(MCPRequestHandler._get_mcp_servers_from_access_groups(["group-c"]))
|
|
assert any(config_server.server_id == sid for sid in server_ids_a)
|
|
assert set(server_ids_b) == set([s.server_id for s in test_manager.config_mcp_servers.values() if "group-b" in s.access_groups])
|
|
assert any(s.name == "other_server" and s.server_id in server_ids_c for s in test_manager.config_mcp_servers.values())
|
|
|
|
|
|
# Tests for Server Alias Functionality
|
|
def test_get_server_prefix_with_alias():
|
|
"""
|
|
Test that get_server_prefix returns alias when present.
|
|
"""
|
|
from litellm.proxy._experimental.mcp_server.utils import get_server_prefix
|
|
|
|
# Create a mock server with alias
|
|
mock_server = MagicMock()
|
|
mock_server.alias = "my_alias"
|
|
mock_server.server_name = "My Server Name"
|
|
mock_server.server_id = "server-123"
|
|
|
|
prefix = get_server_prefix(mock_server)
|
|
assert prefix == "my_alias"
|
|
|
|
|
|
def test_get_server_prefix_without_alias():
|
|
"""
|
|
Test that get_server_prefix falls back to server_name when alias is not present.
|
|
"""
|
|
from litellm.proxy._experimental.mcp_server.utils import get_server_prefix
|
|
|
|
# Create a mock server without alias
|
|
mock_server = MagicMock()
|
|
mock_server.alias = None
|
|
mock_server.server_name = "My Server Name"
|
|
mock_server.server_id = "server-123"
|
|
|
|
prefix = get_server_prefix(mock_server)
|
|
assert prefix == "My Server Name"
|
|
|
|
|
|
def test_get_server_prefix_fallback_to_server_id():
|
|
"""
|
|
Test that get_server_prefix falls back to server_id when neither alias nor server_name are present.
|
|
"""
|
|
from litellm.proxy._experimental.mcp_server.utils import get_server_prefix
|
|
|
|
# Create a mock server without alias or server_name
|
|
mock_server = MagicMock()
|
|
mock_server.alias = None
|
|
mock_server.server_name = None
|
|
mock_server.server_id = "server-123"
|
|
|
|
prefix = get_server_prefix(mock_server)
|
|
assert prefix == "server-123"
|
|
|
|
|
|
def test_get_server_prefix_empty_strings():
|
|
"""
|
|
Test that get_server_prefix handles empty strings correctly.
|
|
"""
|
|
from litellm.proxy._experimental.mcp_server.utils import get_server_prefix
|
|
|
|
# Create a mock server with empty strings
|
|
mock_server = MagicMock()
|
|
mock_server.alias = ""
|
|
mock_server.server_name = ""
|
|
mock_server.server_id = "server-123"
|
|
|
|
prefix = get_server_prefix(mock_server)
|
|
assert prefix == "server-123"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_mcp_server_manager_alias_tool_prefixing():
|
|
"""
|
|
Test that MCP server manager uses alias for tool prefixing when available.
|
|
"""
|
|
test_manager = MCPServerManager()
|
|
|
|
# Create a mock server with alias
|
|
mock_server = MCPServer(
|
|
server_id="test-server-123",
|
|
name="test_server",
|
|
alias="my_alias",
|
|
server_name="Test Server",
|
|
url="https://test-server.com/mcp",
|
|
transport=MCPTransport.http,
|
|
spec_version="2025-03-26"
|
|
)
|
|
|
|
# Add server to registry
|
|
test_manager.registry["test-server-123"] = mock_server
|
|
|
|
# Mock tools
|
|
mock_tools = [
|
|
MCPTool(
|
|
name="send_email",
|
|
description="Send an email",
|
|
inputSchema={"type": "object"}
|
|
)
|
|
]
|
|
|
|
# Mock MCPClient
|
|
mock_client = AsyncMock()
|
|
mock_client.list_tools = AsyncMock(return_value=mock_tools)
|
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
|
mock_client.__aexit__ = AsyncMock(return_value=None)
|
|
|
|
def mock_client_constructor(*args, **kwargs):
|
|
return mock_client
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
|
|
# Get tools from server
|
|
tools = await test_manager._get_tools_from_server(mock_server)
|
|
|
|
# Verify tool is prefixed with alias
|
|
assert len(tools) == 1
|
|
assert tools[0].name == "my_alias-send_email"
|
|
|
|
# Verify mapping is updated correctly
|
|
assert test_manager.tool_name_to_mcp_server_name_mapping["send_email"] == "my_alias"
|
|
assert test_manager.tool_name_to_mcp_server_name_mapping["my_alias-send_email"] == "my_alias"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_mcp_server_manager_server_name_tool_prefixing():
|
|
"""
|
|
Test that MCP server manager falls back to server_name for tool prefixing when alias is not available.
|
|
"""
|
|
test_manager = MCPServerManager()
|
|
|
|
# Create a mock server without alias
|
|
mock_server = MCPServer(
|
|
server_id="test-server-123",
|
|
name="test_server",
|
|
alias=None,
|
|
server_name="Test Server",
|
|
url="https://test-server.com/mcp",
|
|
transport=MCPTransport.http,
|
|
spec_version="2025-03-26"
|
|
)
|
|
|
|
# Add server to registry
|
|
test_manager.registry["test-server-123"] = mock_server
|
|
|
|
# Mock tools
|
|
mock_tools = [
|
|
MCPTool(
|
|
name="send_email",
|
|
description="Send an email",
|
|
inputSchema={"type": "object"}
|
|
)
|
|
]
|
|
|
|
# Mock MCPClient
|
|
mock_client = AsyncMock()
|
|
mock_client.list_tools = AsyncMock(return_value=mock_tools)
|
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
|
mock_client.__aexit__ = AsyncMock(return_value=None)
|
|
|
|
def mock_client_constructor(*args, **kwargs):
|
|
return mock_client
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
|
|
# Get tools from server
|
|
tools = await test_manager._get_tools_from_server(mock_server)
|
|
|
|
# Verify tool is prefixed with server_name (normalized)
|
|
assert len(tools) == 1
|
|
assert tools[0].name == "Test_Server-send_email"
|
|
|
|
# Verify mapping is updated correctly
|
|
assert test_manager.tool_name_to_mcp_server_name_mapping["send_email"] == "Test Server"
|
|
assert test_manager.tool_name_to_mcp_server_name_mapping["Test_Server-send_email"] == "Test Server"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_mcp_server_manager_server_id_tool_prefixing():
|
|
"""
|
|
Test that MCP server manager falls back to server_id for tool prefixing when neither alias nor server_name are available.
|
|
"""
|
|
test_manager = MCPServerManager()
|
|
|
|
# Create a mock server without alias or server_name
|
|
mock_server = MCPServer(
|
|
server_id="test-server-123",
|
|
name="test_server",
|
|
alias=None,
|
|
server_name=None,
|
|
url="https://test-server.com/mcp",
|
|
transport=MCPTransport.http,
|
|
spec_version="2025-03-26"
|
|
)
|
|
|
|
# Add server to registry
|
|
test_manager.registry["test-server-123"] = mock_server
|
|
|
|
# Mock tools
|
|
mock_tools = [
|
|
MCPTool(
|
|
name="send_email",
|
|
description="Send an email",
|
|
inputSchema={"type": "object"}
|
|
)
|
|
]
|
|
|
|
# Mock MCPClient
|
|
mock_client = AsyncMock()
|
|
mock_client.list_tools = AsyncMock(return_value=mock_tools)
|
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
|
mock_client.__aexit__ = AsyncMock(return_value=None)
|
|
|
|
def mock_client_constructor(*args, **kwargs):
|
|
return mock_client
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
|
|
# Get tools from server
|
|
tools = await test_manager._get_tools_from_server(mock_server)
|
|
|
|
# Verify tool is prefixed with server_id
|
|
assert len(tools) == 1
|
|
assert tools[0].name == "test-server-123-send_email"
|
|
|
|
# Verify mapping is updated correctly
|
|
assert test_manager.tool_name_to_mcp_server_name_mapping["send_email"] == "test-server-123"
|
|
assert test_manager.tool_name_to_mcp_server_name_mapping["test-server-123-send_email"] == "test-server-123"
|
|
|
|
|
|
def test_add_update_server_with_alias():
|
|
"""
|
|
Test that add_update_server correctly handles servers with alias.
|
|
"""
|
|
test_manager = MCPServerManager()
|
|
|
|
# Create a mock LiteLLM_MCPServerTable with alias
|
|
mock_mcp_server = MagicMock()
|
|
mock_mcp_server.server_id = "test-server-123"
|
|
mock_mcp_server.alias = "my_alias"
|
|
mock_mcp_server.server_name = "Test Server"
|
|
mock_mcp_server.url = "https://test-server.com/mcp"
|
|
mock_mcp_server.transport = MCPTransport.http
|
|
mock_mcp_server.spec_version = "2025-03-26"
|
|
mock_mcp_server.auth_type = None
|
|
mock_mcp_server.description = "Test server description"
|
|
mock_mcp_server.mcp_info = {}
|
|
mock_mcp_server.command = None
|
|
mock_mcp_server.args = []
|
|
mock_mcp_server.env = None
|
|
|
|
# Add server to manager
|
|
test_manager.add_update_server(mock_mcp_server)
|
|
|
|
# Verify server was added with correct name (should use alias)
|
|
assert "test-server-123" in test_manager.registry
|
|
added_server = test_manager.registry["test-server-123"]
|
|
assert added_server.name == "my_alias"
|
|
assert added_server.alias == "my_alias"
|
|
assert added_server.server_name == "Test Server"
|
|
|
|
|
|
def test_add_update_server_without_alias():
|
|
"""
|
|
Test that add_update_server correctly handles servers without alias.
|
|
"""
|
|
test_manager = MCPServerManager()
|
|
|
|
# Create a mock LiteLLM_MCPServerTable without alias
|
|
mock_mcp_server = MagicMock()
|
|
mock_mcp_server.server_id = "test-server-123"
|
|
mock_mcp_server.alias = None
|
|
mock_mcp_server.server_name = "Test Server"
|
|
mock_mcp_server.url = "https://test-server.com/mcp"
|
|
mock_mcp_server.transport = MCPTransport.http
|
|
mock_mcp_server.spec_version = "2025-03-26"
|
|
mock_mcp_server.auth_type = None
|
|
mock_mcp_server.description = "Test server description"
|
|
mock_mcp_server.mcp_info = {}
|
|
mock_mcp_server.command = None
|
|
mock_mcp_server.args = []
|
|
mock_mcp_server.env = None
|
|
|
|
# Add server to manager
|
|
test_manager.add_update_server(mock_mcp_server)
|
|
|
|
# Verify server was added with correct name (should use server_name)
|
|
assert "test-server-123" in test_manager.registry
|
|
added_server = test_manager.registry["test-server-123"]
|
|
assert added_server.name == "Test Server"
|
|
assert added_server.alias is None
|
|
assert added_server.server_name == "Test Server"
|
|
|
|
|
|
def test_add_update_server_fallback_to_server_id():
|
|
"""
|
|
Test that add_update_server falls back to server_id when neither alias nor server_name are available.
|
|
"""
|
|
test_manager = MCPServerManager()
|
|
|
|
# Create a mock LiteLLM_MCPServerTable without alias or server_name
|
|
mock_mcp_server = MagicMock()
|
|
mock_mcp_server.server_id = "test-server-123"
|
|
mock_mcp_server.alias = None
|
|
mock_mcp_server.server_name = None
|
|
mock_mcp_server.url = "https://test-server.com/mcp"
|
|
mock_mcp_server.transport = MCPTransport.http
|
|
mock_mcp_server.spec_version = "2025-03-26"
|
|
mock_mcp_server.auth_type = None
|
|
mock_mcp_server.description = "Test server description"
|
|
mock_mcp_server.mcp_info = {}
|
|
mock_mcp_server.command = None
|
|
mock_mcp_server.args = []
|
|
mock_mcp_server.env = None
|
|
|
|
# Add server to manager
|
|
test_manager.add_update_server(mock_mcp_server)
|
|
|
|
# Verify server was added with correct name (should use server_id)
|
|
assert "test-server-123" in test_manager.registry
|
|
added_server = test_manager.registry["test-server-123"]
|
|
assert added_server.name == "test-server-123"
|
|
assert added_server.alias is None
|
|
assert added_server.server_name is None
|
|
|
|
|
|
def test_normalize_server_name():
|
|
"""
|
|
Test that normalize_server_name correctly replaces spaces with underscores.
|
|
"""
|
|
from litellm.proxy._experimental.mcp_server.utils import normalize_server_name
|
|
|
|
# Test basic space replacement
|
|
assert normalize_server_name("My Server Name") == "My_Server_Name"
|
|
|
|
# Test multiple consecutive spaces
|
|
assert normalize_server_name("My Server Name") == "My__Server___Name"
|
|
|
|
# Test no spaces
|
|
assert normalize_server_name("MyServerName") == "MyServerName"
|
|
|
|
# Test empty string
|
|
assert normalize_server_name("") == ""
|
|
|
|
# Test string with only spaces
|
|
assert normalize_server_name(" ") == "___"
|
|
|
|
|
|
def test_add_server_prefix_to_tool_name():
|
|
"""
|
|
Test that add_server_prefix_to_tool_name correctly formats tool names.
|
|
"""
|
|
from litellm.proxy._experimental.mcp_server.utils import add_server_prefix_to_tool_name
|
|
|
|
# Test basic prefixing
|
|
result = add_server_prefix_to_tool_name("send_email", "My Server")
|
|
assert result == "My_Server-send_email"
|
|
|
|
# Test with server name that already has underscores
|
|
result = add_server_prefix_to_tool_name("create_event", "my_server")
|
|
assert result == "my_server-create_event"
|
|
|
|
# Test with empty tool name
|
|
result = add_server_prefix_to_tool_name("", "My Server")
|
|
assert result == "My_Server-"
|
|
|
|
# Test with empty server name
|
|
result = add_server_prefix_to_tool_name("send_email", "")
|
|
assert result == "-send_email"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_mcp_protocol_version_passed_to_client():
|
|
"""Test that MCP protocol version from request is correctly passed to MCPClient."""
|
|
|
|
# Create a test manager
|
|
test_manager = MCPServerManager()
|
|
|
|
# Mock MCPClient
|
|
mock_client = AsyncMock()
|
|
mock_client.list_tools = AsyncMock(return_value=[])
|
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
|
mock_client.__aexit__ = AsyncMock(return_value=None)
|
|
|
|
def mock_client_constructor(*args, **kwargs):
|
|
# Verify that the protocol version from request is used
|
|
if 'protocol_version' in kwargs:
|
|
assert kwargs['protocol_version'] == "2025-03-26"
|
|
return mock_client
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
|
|
# Load a test server
|
|
test_manager.load_servers_from_config({
|
|
"test_server": {
|
|
"url": "https://test-server.com/mcp",
|
|
"transport": "http",
|
|
"description": "Test Server"
|
|
}
|
|
})
|
|
|
|
# Call list_tools with a specific protocol version from request
|
|
await test_manager.list_tools(mcp_protocol_version="2025-03-26")
|
|
|
|
# Verify the client was created with the correct protocol version
|
|
mock_client.list_tools.assert_called()
|
|
|
|
|