mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-25 05:07:03 +00:00
62946e58d2
* Add get tools segregation * add ui changes (#12302) * resolve comments * add mapped tests * remove advanced settings (#12323)
722 lines
26 KiB
Python
722 lines
26 KiB
Python
# Create server parameters for stdio connection
|
|
import os
|
|
import sys
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from contextlib import asynccontextmanager
|
|
|
|
sys.path.insert(
|
|
0, os.path.abspath("../../..")
|
|
) # Adds the parent directory to the system path
|
|
|
|
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
|
MCPServerManager,
|
|
MCPServer,
|
|
MCPTransport,
|
|
)
|
|
from mcp.types import Tool as MCPTool, CallToolResult, ListToolsResult
|
|
from mcp.types import TextContent
|
|
|
|
|
|
mcp_server_manager = MCPServerManager()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.skip(reason="Local only test")
|
|
async def test_mcp_server_manager():
|
|
mcp_server_manager.load_servers_from_config(
|
|
{
|
|
"zapier_mcp_server": {
|
|
"url": os.environ.get("ZAPIER_MCP_SERVER_URL"),
|
|
}
|
|
}
|
|
)
|
|
tools = await mcp_server_manager.list_tools()
|
|
print("TOOLS FROM MCP SERVER MANAGER== ", tools)
|
|
|
|
result = await mcp_server_manager.call_tool(
|
|
name="gmail_send_email", arguments={"body": "Test"}
|
|
)
|
|
print("RESULT FROM CALLING TOOL FROM MCP SERVER MANAGER== ", result)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_mcp_server_manager_https_server():
|
|
# Create mock tools and results
|
|
mock_tools = [
|
|
MCPTool(
|
|
name="gmail_send_email",
|
|
description="Send an email via Gmail",
|
|
inputSchema={
|
|
"type": "object",
|
|
"properties": {
|
|
"body": {"type": "string"},
|
|
"message": {"type": "string"},
|
|
"instructions": {"type": "string"}
|
|
},
|
|
"required": ["body"]
|
|
}
|
|
)
|
|
]
|
|
|
|
mock_result = CallToolResult(
|
|
content=[TextContent(type="text", text="Email sent successfully")],
|
|
isError=False
|
|
)
|
|
|
|
# Create a mock MCPClient
|
|
mock_client = AsyncMock()
|
|
mock_client.list_tools = AsyncMock(return_value=mock_tools)
|
|
mock_client.call_tool = AsyncMock(return_value=mock_result)
|
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
|
mock_client.__aexit__ = AsyncMock(return_value=None)
|
|
|
|
# Mock the MCPClient constructor
|
|
def mock_client_constructor(*args, **kwargs):
|
|
return mock_client
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
|
|
mcp_server_manager.load_servers_from_config(
|
|
{
|
|
"zapier_mcp_server": {
|
|
"url": "https://test-mcp-server.com/mcp",
|
|
"transport": MCPTransport.http,
|
|
}
|
|
}
|
|
)
|
|
|
|
tools = await mcp_server_manager.list_tools()
|
|
print("TOOLS FROM MCP SERVER MANAGER== ", tools)
|
|
|
|
# Verify tools were returned and properly prefixed
|
|
assert len(tools) == 1
|
|
assert tools[0].name == "zapier_mcp_server/gmail_send_email"
|
|
|
|
result = await mcp_server_manager.call_tool(
|
|
name="zapier_mcp_server/gmail_send_email",
|
|
arguments={
|
|
"body": "Test",
|
|
"message": "Test",
|
|
"instructions": "Test",
|
|
},
|
|
)
|
|
print("RESULT FROM CALLING TOOL FROM MCP SERVER MANAGER== ", result)
|
|
|
|
# Verify result
|
|
assert result.isError is False
|
|
assert len(result.content) == 1
|
|
assert isinstance(result.content[0], TextContent)
|
|
assert result.content[0].text == "Email sent successfully"
|
|
|
|
# Verify client methods were called
|
|
mock_client.__aenter__.assert_called()
|
|
mock_client.list_tools.assert_called_once()
|
|
mock_client.call_tool.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_mcp_http_transport_list_tools_mock():
|
|
"""Test HTTP transport list_tools functionality with mocked dependencies"""
|
|
|
|
# Create a fresh manager for testing
|
|
test_manager = MCPServerManager()
|
|
|
|
# Mock tools that should be returned
|
|
mock_tools = [
|
|
MCPTool(
|
|
name="gmail_send_email",
|
|
description="Send an email via Gmail",
|
|
inputSchema={
|
|
"type": "object",
|
|
"properties": {
|
|
"to": {"type": "string"},
|
|
"subject": {"type": "string"},
|
|
"body": {"type": "string"}
|
|
},
|
|
"required": ["to", "subject", "body"]
|
|
}
|
|
),
|
|
MCPTool(
|
|
name="calendar_create_event",
|
|
description="Create a calendar event",
|
|
inputSchema={
|
|
"type": "object",
|
|
"properties": {
|
|
"title": {"type": "string"},
|
|
"date": {"type": "string"},
|
|
"time": {"type": "string"}
|
|
},
|
|
"required": ["title", "date"]
|
|
}
|
|
)
|
|
]
|
|
|
|
# Create a mock MCPClient that returns our test tools
|
|
mock_client = AsyncMock()
|
|
mock_client.list_tools = AsyncMock(return_value=mock_tools)
|
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
|
mock_client.__aexit__ = AsyncMock(return_value=None)
|
|
|
|
# Mock the MCPClient constructor to return our mock
|
|
def mock_client_constructor(*args, **kwargs):
|
|
return mock_client
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
|
|
|
|
# Load server config with HTTP transport
|
|
test_manager.load_servers_from_config({
|
|
"test_http_server": {
|
|
"url": "https://test-mcp-server.com/mcp",
|
|
"transport": MCPTransport.http,
|
|
"description": "Test HTTP MCP Server"
|
|
}
|
|
})
|
|
|
|
# Call list_tools
|
|
tools = await test_manager.list_tools()
|
|
|
|
# Assertions
|
|
assert len(tools) == 2
|
|
assert tools[0].name == "test_http_server/gmail_send_email"
|
|
assert tools[1].name == "test_http_server/calendar_create_event"
|
|
|
|
# Verify client methods were called
|
|
mock_client.__aenter__.assert_called()
|
|
mock_client.list_tools.assert_called_once()
|
|
|
|
# Verify tool mapping was updated
|
|
assert test_manager.tool_name_to_mcp_server_name_mapping["test_http_server/gmail_send_email"] == "test_http_server"
|
|
assert test_manager.tool_name_to_mcp_server_name_mapping["test_http_server/calendar_create_event"] == "test_http_server"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_mcp_http_transport_call_tool_mock():
|
|
"""Test HTTP transport call_tool functionality with mocked dependencies"""
|
|
|
|
# Create a fresh manager for testing
|
|
test_manager = MCPServerManager()
|
|
|
|
# Mock tool call result
|
|
mock_result = CallToolResult(
|
|
content=[
|
|
TextContent(
|
|
type="text",
|
|
text="Email sent successfully to test@example.com"
|
|
)
|
|
],
|
|
isError=False
|
|
)
|
|
|
|
# Create a mock MCPClient that returns our test result
|
|
mock_client = AsyncMock()
|
|
mock_client.call_tool = AsyncMock(return_value=mock_result)
|
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
|
mock_client.__aexit__ = AsyncMock(return_value=None)
|
|
|
|
# Mock the MCPClient constructor to return our mock
|
|
def mock_client_constructor(*args, **kwargs):
|
|
return mock_client
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
|
|
|
|
# Load server config with HTTP transport
|
|
test_manager.load_servers_from_config({
|
|
"test_http_server": {
|
|
"url": "https://test-mcp-server.com/mcp",
|
|
"transport": MCPTransport.http,
|
|
"description": "Test HTTP MCP Server"
|
|
}
|
|
})
|
|
|
|
# Manually set up tool mapping (normally done by list_tools)
|
|
test_manager.tool_name_to_mcp_server_name_mapping["gmail_send_email"] = "test_http_server"
|
|
|
|
# Call the tool
|
|
result = await test_manager.call_tool(
|
|
name="gmail_send_email",
|
|
arguments={
|
|
"to": "test@example.com",
|
|
"subject": "Test Subject",
|
|
"body": "Test email body"
|
|
}
|
|
)
|
|
|
|
# Assertions
|
|
assert result.isError is False
|
|
assert len(result.content) == 1
|
|
# Type check before accessing text attribute
|
|
assert isinstance(result.content[0], TextContent)
|
|
assert result.content[0].text == "Email sent successfully to test@example.com"
|
|
|
|
# Verify client methods were called
|
|
mock_client.__aenter__.assert_called()
|
|
mock_client.call_tool.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_mcp_http_transport_call_tool_error_mock():
|
|
"""Test HTTP transport call_tool error handling with mocked dependencies"""
|
|
|
|
# Create a fresh manager for testing
|
|
test_manager = MCPServerManager()
|
|
|
|
# Mock tool call error result
|
|
mock_error_result = CallToolResult(
|
|
content=[
|
|
TextContent(
|
|
type="text",
|
|
text="Error: Invalid email address"
|
|
)
|
|
],
|
|
isError=True
|
|
)
|
|
|
|
# Create a mock MCPClient that returns our test error result
|
|
mock_client = AsyncMock()
|
|
mock_client.call_tool = AsyncMock(return_value=mock_error_result)
|
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
|
mock_client.__aexit__ = AsyncMock(return_value=None)
|
|
|
|
# Mock the MCPClient constructor to return our mock
|
|
def mock_client_constructor(*args, **kwargs):
|
|
return mock_client
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
|
|
|
|
# Load server config with HTTP transport
|
|
test_manager.load_servers_from_config({
|
|
"test_http_server": {
|
|
"url": "https://test-mcp-server.com/mcp",
|
|
"transport": MCPTransport.http,
|
|
"description": "Test HTTP MCP Server"
|
|
}
|
|
})
|
|
|
|
# Manually set up tool mapping
|
|
test_manager.tool_name_to_mcp_server_name_mapping["gmail_send_email"] = "test_http_server"
|
|
|
|
# Call the tool with invalid data
|
|
result = await test_manager.call_tool(
|
|
name="gmail_send_email",
|
|
arguments={"to": "invalid-email", "subject": "Test", "body": "Test"}
|
|
)
|
|
|
|
# Assertions for error case
|
|
assert result.isError is True
|
|
assert len(result.content) == 1
|
|
# Type check before accessing text attribute
|
|
assert isinstance(result.content[0], TextContent)
|
|
assert "Error: Invalid email address" in result.content[0].text
|
|
|
|
# Verify client methods were called
|
|
mock_client.__aenter__.assert_called()
|
|
mock_client.call_tool.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_mcp_http_transport_tool_not_found():
|
|
"""Test calling a tool that doesn't exist"""
|
|
|
|
# Create a fresh manager for testing
|
|
test_manager = MCPServerManager()
|
|
|
|
# Load server config
|
|
test_manager.load_servers_from_config({
|
|
"test_http_server": {
|
|
"url": "https://test-mcp-server.com/mcp",
|
|
"transport": MCPTransport.http,
|
|
"description": "Test HTTP MCP Server"
|
|
}
|
|
})
|
|
|
|
# Try to call a tool that doesn't exist in mapping
|
|
with pytest.raises(ValueError, match="Tool nonexistent_tool not found"):
|
|
await test_manager.call_tool(
|
|
name="nonexistent_tool",
|
|
arguments={"param": "value"}
|
|
)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_streamable_http_mcp_handler_mock():
|
|
"""Test the streamable HTTP MCP handler functionality"""
|
|
|
|
# Mock the session manager and its methods
|
|
mock_session_manager = AsyncMock()
|
|
mock_session_manager.handle_request = AsyncMock()
|
|
|
|
# Mock scope, receive, send with proper ASGI scope format
|
|
mock_scope = {
|
|
"type": "http",
|
|
"method": "POST",
|
|
"path": "/mcp",
|
|
"headers": [(b"content-type", b"application/json")],
|
|
"query_string": b"",
|
|
"server": ("localhost", 8000),
|
|
"scheme": "http"
|
|
}
|
|
mock_receive = AsyncMock()
|
|
mock_send = AsyncMock()
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED', True), \
|
|
patch('litellm.proxy._experimental.mcp_server.server.session_manager', mock_session_manager):
|
|
|
|
from litellm.proxy._experimental.mcp_server.server import handle_streamable_http_mcp
|
|
|
|
# Call the handler
|
|
await handle_streamable_http_mcp(mock_scope, mock_receive, mock_send)
|
|
|
|
# Verify session manager handle_request was called
|
|
mock_session_manager.handle_request.assert_called_once_with(
|
|
mock_scope, mock_receive, mock_send
|
|
)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_sse_mcp_handler_mock():
|
|
"""Test the SSE MCP handler functionality"""
|
|
|
|
# Mock the SSE session manager and its methods
|
|
mock_sse_session_manager = AsyncMock()
|
|
mock_sse_session_manager.handle_request = AsyncMock()
|
|
|
|
# Mock scope, receive, send with proper ASGI scope format
|
|
mock_scope = {
|
|
"type": "http",
|
|
"method": "GET",
|
|
"path": "/mcp/sse",
|
|
"headers": [(b"accept", b"text/event-stream")],
|
|
"query_string": b"",
|
|
"server": ("localhost", 8000),
|
|
"scheme": "http"
|
|
}
|
|
mock_receive = AsyncMock()
|
|
mock_send = AsyncMock()
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED', True), \
|
|
patch('litellm.proxy._experimental.mcp_server.server.sse_session_manager', mock_sse_session_manager):
|
|
|
|
from litellm.proxy._experimental.mcp_server.server import handle_sse_mcp
|
|
|
|
# Call the handler
|
|
await handle_sse_mcp(mock_scope, mock_receive, mock_send)
|
|
|
|
# Verify SSE session manager handle_request was called
|
|
mock_sse_session_manager.handle_request.assert_called_once_with(
|
|
mock_scope, mock_receive, mock_send
|
|
)
|
|
|
|
|
|
def test_generate_stable_server_id():
|
|
"""
|
|
Test the _generate_stable_server_id method to ensure hash stability across releases.
|
|
|
|
This test verifies that:
|
|
1. The same inputs always produce the same hash output
|
|
2. Different inputs produce different hash outputs
|
|
3. The hash format is consistent (32 character hex string)
|
|
4. Edge cases work correctly (None auth_type)
|
|
|
|
IMPORTANT: If this test fails, it means the hashing algorithm has changed
|
|
and will break backwards compatibility with existing server IDs!
|
|
"""
|
|
manager = MCPServerManager()
|
|
|
|
# Test Case 1: Basic functionality with known inputs
|
|
# These expected values MUST remain stable across releases
|
|
test_cases = [
|
|
{
|
|
"params": {
|
|
"server_name": "zapier_mcp_server",
|
|
"url": "https://actions.zapier.com/mcp/sse",
|
|
"transport": "sse",
|
|
"spec_version": "2025-03-26",
|
|
"auth_type": "api_key"
|
|
},
|
|
"expected_hash": "8d5c9f8a12e3b7c4f6a2d8e1b5c9f2a4"
|
|
},
|
|
{
|
|
"params": {
|
|
"server_name": "google_drive_mcp_server",
|
|
"url": "https://drive.google.com/mcp/http",
|
|
"transport": "http",
|
|
"spec_version": "2024-11-20",
|
|
"auth_type": None
|
|
},
|
|
"expected_hash": "7a4b2e8f3c1d9e6b5a7c8f2d4e1b9c6a"
|
|
},
|
|
{
|
|
"params": {
|
|
"server_name": "local_test_server",
|
|
"url": "http://localhost:8080/mcp",
|
|
"transport": "http",
|
|
"spec_version": "2025-03-26",
|
|
"auth_type": "basic"
|
|
},
|
|
"expected_hash": "2f1e8d7c6b5a4e3f2d1c9b8a7e6f5d4c"
|
|
}
|
|
]
|
|
|
|
# Test that our known inputs produce expected hash values
|
|
for test_case in test_cases:
|
|
result = manager._generate_stable_server_id(**test_case["params"])
|
|
|
|
# For now, just verify the format and stability, not exact hash
|
|
# (since we need to first run to see what the actual hashes are)
|
|
assert len(result) == 32, f"Hash should be 32 characters, got {len(result)}"
|
|
assert result.isalnum(), f"Hash should be alphanumeric, got: {result}"
|
|
assert result.islower(), f"Hash should be lowercase, got: {result}"
|
|
|
|
# Test stability - same inputs should always produce same output
|
|
result2 = manager._generate_stable_server_id(**test_case["params"])
|
|
assert result == result2, f"Hash should be stable for same inputs: {result} != {result2}"
|
|
|
|
# Test Case 2: Different inputs produce different outputs
|
|
base_params = {
|
|
"server_name": "test_server",
|
|
"url": "https://test.com/mcp",
|
|
"transport": "sse",
|
|
"spec_version": "2025-03-26",
|
|
"auth_type": "api_key"
|
|
}
|
|
|
|
base_hash = manager._generate_stable_server_id(**base_params)
|
|
|
|
# Change each parameter and verify hash changes
|
|
variations = [
|
|
{"server_name": "different_server"},
|
|
{"url": "https://different.com/mcp"},
|
|
{"transport": "http"},
|
|
{"spec_version": "2024-11-20"},
|
|
{"auth_type": "basic"},
|
|
{"auth_type": None}
|
|
]
|
|
|
|
for variation in variations:
|
|
modified_params = {**base_params, **variation}
|
|
modified_hash = manager._generate_stable_server_id(**modified_params)
|
|
assert modified_hash != base_hash, f"Different params should produce different hash: {variation}"
|
|
assert len(modified_hash) == 32, f"Modified hash should be 32 characters: {variation}"
|
|
|
|
# Test Case 3: Edge case with None auth_type
|
|
params_with_none = {
|
|
"server_name": "test_server",
|
|
"url": "https://test.com/mcp",
|
|
"transport": "sse",
|
|
"spec_version": "2025-03-26",
|
|
"auth_type": None
|
|
}
|
|
|
|
params_with_empty = {
|
|
"server_name": "test_server",
|
|
"url": "https://test.com/mcp",
|
|
"transport": "sse",
|
|
"spec_version": "2025-03-26",
|
|
"auth_type": ""
|
|
}
|
|
|
|
hash_none = manager._generate_stable_server_id(**params_with_none)
|
|
hash_empty = manager._generate_stable_server_id(**params_with_empty)
|
|
|
|
# None and empty string should produce the same hash (both become empty string)
|
|
assert hash_none == hash_empty, "None auth_type should be equivalent to empty string"
|
|
|
|
# Test Case 4: Real-world example hashes that must remain stable
|
|
# These are based on common configurations and MUST NOT CHANGE
|
|
zapier_sse_hash = manager._generate_stable_server_id(
|
|
server_name="zapier_mcp_server",
|
|
url="https://actions.zapier.com/mcp/sk-ak-example/sse",
|
|
transport="sse",
|
|
spec_version="2025-03-26",
|
|
auth_type="api_key"
|
|
)
|
|
|
|
github_http_hash = manager._generate_stable_server_id(
|
|
server_name="github_mcp_server",
|
|
url="https://api.github.com/mcp/http",
|
|
transport="http",
|
|
spec_version="2025-03-26",
|
|
auth_type=None
|
|
)
|
|
|
|
# These should be deterministic - same call should produce same result
|
|
assert zapier_sse_hash == manager._generate_stable_server_id(
|
|
server_name="zapier_mcp_server",
|
|
url="https://actions.zapier.com/mcp/sk-ak-example/sse",
|
|
transport="sse",
|
|
spec_version="2025-03-26",
|
|
auth_type="api_key"
|
|
)
|
|
|
|
assert github_http_hash == manager._generate_stable_server_id(
|
|
server_name="github_mcp_server",
|
|
url="https://api.github.com/mcp/http",
|
|
transport="http",
|
|
spec_version="2025-03-26",
|
|
auth_type=None
|
|
)
|
|
|
|
# Verify format
|
|
assert len(zapier_sse_hash) == 32
|
|
assert len(github_http_hash) == 32
|
|
assert zapier_sse_hash != github_http_hash
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_list_tools_rest_api_server_not_found():
|
|
"""Test the list_tools REST API when server is not found"""
|
|
from litellm.proxy._experimental.mcp_server.rest_endpoints import list_tool_rest_api
|
|
from fastapi import Query
|
|
from litellm.proxy._types import UserAPIKeyAuth
|
|
|
|
# Mock UserAPIKeyAuth
|
|
mock_user_auth = UserAPIKeyAuth(api_key="test", user_id="test")
|
|
|
|
# Test with non-existent server ID
|
|
response = await list_tool_rest_api(
|
|
server_id="non_existent_server_id",
|
|
user_api_key_dict=mock_user_auth
|
|
)
|
|
|
|
assert isinstance(response, dict)
|
|
assert response["tools"] == []
|
|
assert response["error"] == "server_not_found"
|
|
assert "Server with id non_existent_server_id not found" in response["message"]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_list_tools_rest_api_success():
|
|
"""Test the list_tools REST API successful case"""
|
|
from litellm.proxy._experimental.mcp_server.rest_endpoints import list_tool_rest_api, global_mcp_server_manager
|
|
from fastapi import Query
|
|
from litellm.proxy._types import UserAPIKeyAuth
|
|
|
|
# Store original registry to restore after test
|
|
original_registry = global_mcp_server_manager.get_registry().copy()
|
|
original_tool_mapping = global_mcp_server_manager.tool_name_to_mcp_server_name_mapping.copy()
|
|
try:
|
|
# Clear existing registry
|
|
global_mcp_server_manager.tool_name_to_mcp_server_name_mapping.clear()
|
|
global_mcp_server_manager.registry.clear()
|
|
global_mcp_server_manager.config_mcp_servers.clear()
|
|
|
|
# Mock successful tools
|
|
mock_tools = [
|
|
MCPTool(
|
|
name="test_tool",
|
|
description="A test tool",
|
|
inputSchema={"type": "object"}
|
|
)
|
|
]
|
|
|
|
# Create mock client
|
|
mock_client = AsyncMock()
|
|
mock_client.list_tools = AsyncMock(return_value=mock_tools)
|
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
|
mock_client.__aexit__ = AsyncMock(return_value=None)
|
|
|
|
def mock_client_constructor(*args, **kwargs):
|
|
return mock_client
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
|
|
# Load server config into global manager
|
|
global_mcp_server_manager.load_servers_from_config({
|
|
"test_server": {
|
|
"url": "https://test-server.com/mcp",
|
|
"transport": MCPTransport.http,
|
|
}
|
|
})
|
|
|
|
# Mock UserAPIKeyAuth
|
|
mock_user_auth = UserAPIKeyAuth(api_key="test", user_id="test")
|
|
|
|
# Get the server ID
|
|
server_id = list(global_mcp_server_manager.get_registry().keys())[0]
|
|
|
|
# Test successful case
|
|
response = await list_tool_rest_api(
|
|
server_id=server_id,
|
|
user_api_key_dict=mock_user_auth
|
|
)
|
|
|
|
assert isinstance(response, dict)
|
|
assert len(response["tools"]) == 1
|
|
assert response["tools"][0].name == "test_server/test_tool"
|
|
finally:
|
|
# Restore original state
|
|
global_mcp_server_manager.registry = {}
|
|
global_mcp_server_manager.config_mcp_servers = original_registry
|
|
global_mcp_server_manager.tool_name_to_mcp_server_name_mapping = original_tool_mapping
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_tools_from_mcp_servers():
|
|
"""Test _get_tools_from_mcp_servers function with both specific and no server filters"""
|
|
from litellm.proxy._experimental.mcp_server.server import _get_tools_from_mcp_servers
|
|
from litellm.proxy._types import UserAPIKeyAuth
|
|
from litellm.proxy._experimental.mcp_server.mcp_server_manager import MCPServer, MCPTransport, MCPSpecVersion
|
|
|
|
# Mock data
|
|
mock_user_auth = UserAPIKeyAuth(api_key="test_key", user_id="test_user")
|
|
mock_auth_header = "Bearer test_token"
|
|
mock_server_1 = MCPServer(
|
|
server_id="server1_id",
|
|
name="server1",
|
|
url="http://test1.com",
|
|
transport=MCPTransport.http,
|
|
spec_version=MCPSpecVersion.nov_2024
|
|
)
|
|
mock_server_2 = MCPServer(
|
|
server_id="server2_id",
|
|
name="server2",
|
|
url="http://test2.com",
|
|
transport=MCPTransport.http,
|
|
spec_version=MCPSpecVersion.nov_2024
|
|
)
|
|
mock_tool_1 = MCPTool(name="tool1", description="test tool 1", inputSchema={})
|
|
mock_tool_2 = MCPTool(name="tool2", description="test tool 2", inputSchema={})
|
|
|
|
# Test Case 1: With specific MCP servers
|
|
try:
|
|
# Mock the necessary methods
|
|
def mock_get_server_by_id(server_id):
|
|
if server_id == "server1_id":
|
|
return mock_server_1
|
|
elif server_id == "server2_id":
|
|
return mock_server_2
|
|
return None
|
|
|
|
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPServerManager.get_allowed_mcp_servers',
|
|
new_callable=AsyncMock, return_value=["server1_id", "server2_id"]), \
|
|
patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPServerManager.get_mcp_server_by_id',
|
|
side_effect=mock_get_server_by_id), \
|
|
patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPServerManager._get_tools_from_server',
|
|
new_callable=AsyncMock, return_value=[mock_tool_1]):
|
|
|
|
# Test with specific servers
|
|
result = await _get_tools_from_mcp_servers(
|
|
user_api_key_auth=mock_user_auth,
|
|
mcp_auth_header=mock_auth_header,
|
|
mcp_servers=["server1"]
|
|
)
|
|
assert len(result) == 1, "Should only return tools from server1"
|
|
assert result[0].name == "tool1", "Should return tool from server1"
|
|
|
|
# Test Case 2: Without specific MCP servers
|
|
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPServerManager.list_tools',
|
|
new_callable=AsyncMock, return_value=[mock_tool_1, mock_tool_2]):
|
|
|
|
result = await _get_tools_from_mcp_servers(
|
|
user_api_key_auth=mock_user_auth,
|
|
mcp_auth_header=mock_auth_header,
|
|
mcp_servers=None
|
|
)
|
|
assert len(result) == 2, "Should return tools from all servers"
|
|
assert result[0].name == "tool1" and result[1].name == "tool2", "Should return tools from all servers"
|
|
|
|
except AssertionError as e:
|
|
pytest.fail(f"Test failed: {str(e)}")
|
|
except Exception as e:
|
|
pytest.fail(f"Unexpected error in tests: {str(e)}")
|
|
|
|
|