mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 03:31:23 +00:00
227 lines
8.0 KiB
Python
227 lines
8.0 KiB
Python
"""
Unit tests for the MCPClient class - critical functionality only.
"""
|
|
import base64
|
|
import os
|
|
import sys
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch, ANY
|
|
|
|
# Add the project root to the path
|
|
sys.path.insert(0, os.path.abspath("../../.."))
|
|
|
|
import litellm.experimental_mcp_client.client as mcp_client_module
|
|
from litellm.experimental_mcp_client.client import MCPClient
|
|
from litellm.types.mcp import MCPAuth, MCPTransport
|
|
from mcp.types import Tool as MCPTool, CallToolResult as MCPCallToolResult
|
|
|
|
|
|
def test_mcp_client_uses_configurable_default_timeout():
    """A client built without ``timeout`` reports the patched MCP_CLIENT_TIMEOUT value."""
    patched_default = 120.0
    constant_path = "litellm.experimental_mcp_client.client.MCP_CLIENT_TIMEOUT"

    # The constant is swapped before the client is built; the assertion only
    # holds if the client resolves the default at runtime rather than at import.
    with patch(constant_path, patched_default):
        client = MCPClient(
            server_url="http://example.com",
            transport_type=MCPTransport.sse,
        )
        assert client.timeout == patched_default
|
|
|
|
|
|
class TestMCPClientUnitTests:
    """Unit tests for MCPClient: construction, auth headers and session-backed calls."""

    @staticmethod
    def _wire_session_mocks(mock_session_class, mock_transport):
        """Configure the patched transport/session factories as async context managers.

        Both patched callables are made to return an ``AsyncMock`` whose
        ``__aenter__`` yields a stand-in object. The stand-in for the session
        is returned so tests can assert on the calls made against it.
        """
        transport_ctx = AsyncMock()
        transport_ctx.__aenter__ = AsyncMock(return_value=MagicMock())
        mock_transport.return_value = transport_ctx

        session = AsyncMock()
        session_ctx = AsyncMock()
        session_ctx.__aenter__ = AsyncMock(return_value=session)
        mock_session_class.return_value = session_ctx
        return session

    def test_init_with_auth(self):
        """Constructor arguments are exposed unchanged on the client instance."""
        client = MCPClient(
            server_url="http://example.com",
            transport_type=MCPTransport.sse,
            auth_type=MCPAuth.bearer_token,
            auth_value="test_token",
            timeout=30.0,
        )

        assert client.server_url == "http://example.com"
        assert client.transport_type == MCPTransport.sse
        assert client.auth_type == MCPAuth.bearer_token
        assert client.timeout == 30.0
        assert client._mcp_auth_value == "test_token"

    def test_get_auth_headers(self):
        """``_get_auth_headers()`` yields the expected header dict for each auth type."""
        basic_credentials = base64.b64encode("user:pass".encode("utf-8")).decode()

        # (auth_type, auth_value, expected headers), checked in this order.
        single_auth_cases = [
            # Bearer token
            (
                MCPAuth.bearer_token,
                "test_token",
                {"Authorization": "Bearer test_token"},
            ),
            # Basic auth: the value is expected base64-encoded in the header
            (
                MCPAuth.basic,
                "user:pass",
                {"Authorization": f"Basic {basic_credentials}"},
            ),
            # API key
            (
                MCPAuth.api_key,
                "api_key_123",
                {"X-API-Key": "api_key_123"},
            ),
            # Custom authorization header, passed through verbatim
            (
                MCPAuth.authorization,
                "Token custom_token",
                {"Authorization": "Token custom_token"},
            ),
            # OAuth2
            (
                MCPAuth.oauth2,
                "oauth2-access-token-xyz",
                {"Authorization": "Bearer oauth2-access-token-xyz"},
            ),
        ]
        for auth_type, auth_value, expected_headers in single_auth_cases:
            client = MCPClient(
                "http://example.com",
                auth_type=auth_type,
                auth_value=auth_value,
            )
            assert client._get_auth_headers() == expected_headers

        # OAuth2 with extra_headers (per-user flow overrides auth_value)
        per_user_client = MCPClient(
            "http://example.com",
            auth_type=MCPAuth.oauth2,
            auth_value="static-server-token",
            extra_headers={"Authorization": "Bearer per-user-token"},
        )
        per_user_headers = per_user_client._get_auth_headers()
        assert per_user_headers["Authorization"] == "Bearer per-user-token"

        # No auth
        anonymous_client = MCPClient("http://example.com")
        assert anonymous_client._get_auth_headers() == {}

    @pytest.mark.asyncio
    @patch.object(mcp_client_module, "streamable_http_client")
    @patch.object(mcp_client_module, "ClientSession")
    async def test_run_with_session(self, mock_session_class, mock_transport):
        """``run_with_session`` passes auth headers to the transport and initializes the session."""
        session = self._wire_session_mocks(mock_session_class, mock_transport)

        client = MCPClient(
            "http://example.com",
            auth_type=MCPAuth.bearer_token,
            auth_value="test_token",
        )

        async def _trivial_operation(session):
            return "ok"

        await client.run_with_session(_trivial_operation)

        # The transport factory must have received an http_client carrying the auth header.
        transport_kwargs = mock_transport.call_args.kwargs
        http_client = transport_kwargs["http_client"]
        assert http_client.headers.get("Authorization") == "Bearer test_token"

        # The session must have been initialized exactly once.
        session.initialize.assert_called_once()

    @pytest.mark.asyncio
    @patch.object(mcp_client_module, "streamable_http_client")
    @patch.object(mcp_client_module, "ClientSession")
    async def test_list_tools(self, mock_session_class, mock_transport):
        """``list_tools`` returns the ``tools`` attribute of the session's response."""
        session = self._wire_session_mocks(mock_session_class, mock_transport)

        input_schema = {
            "type": "object",
            "properties": {"arg1": {"type": "string"}},
            "required": ["arg1"],
        }
        expected_tools = [
            MCPTool(
                name="test_tool",
                description="Test tool",
                inputSchema=input_schema,
            )
        ]
        session.list_tools.return_value = MagicMock(tools=expected_tools)

        client = MCPClient("http://example.com")
        listed = await client.list_tools()

        assert listed == expected_tools
        session.initialize.assert_called_once()
        session.list_tools.assert_called_once()

    @pytest.mark.asyncio
    @patch.object(mcp_client_module, "streamable_http_client")
    @patch.object(mcp_client_module, "ClientSession")
    async def test_call_tool(self, mock_session_class, mock_transport):
        """``call_tool`` forwards name/arguments to the session and returns its result."""
        from mcp.types import CallToolRequestParams

        session = self._wire_session_mocks(mock_session_class, mock_transport)

        expected_result = MCPCallToolResult(content=[])
        session.call_tool.return_value = expected_result

        client = MCPClient("http://example.com")
        request = CallToolRequestParams(
            name="test_tool",
            arguments={"arg1": "value1"},
        )
        outcome = await client.call_tool(request)

        assert outcome == expected_result
        session.initialize.assert_called_once()
        # Only the presence of a progress callback is checked, not which one.
        session.call_tool.assert_called_once_with(
            name="test_tool",
            arguments={"arg1": "value1"},
            progress_callback=ANY,
        )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct execution: hand this file to pytest in verbose mode.
    cli_args = [__file__, "-v"]
    pytest.main(cli_args)
|