mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 05:28:02 +00:00
test(test_mcp_server_manager.py): add unit testing
This commit is contained in:
@@ -97,3 +97,5 @@ litellm_config.yaml
|
||||
.vscode/launch.json
|
||||
litellm/proxy/to_delete_loadtest_work/*
|
||||
update_model_cost_map.py
|
||||
tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py
|
||||
litellm/proxy/_experimental/out/guardrails/index.html
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1222,6 +1222,104 @@ class TestMCPServerManager:
|
||||
"Contact proxy admin to allow this tool" in exc_info.value.detail["error"]
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_without_broken_pipe_error(self):
|
||||
"""
|
||||
Test that call_tool properly uses async context manager to avoid broken pipe errors.
|
||||
This test ensures that tasks are awaited INSIDE the context manager, keeping the connection alive.
|
||||
"""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from mcp.types import CallToolResult
|
||||
|
||||
manager = MCPServerManager()
|
||||
|
||||
# Create a test server
|
||||
server = MCPServer(
|
||||
server_id="test-server",
|
||||
name="test-server",
|
||||
transport=MCPTransport.http,
|
||||
url="http://test-server.com",
|
||||
)
|
||||
|
||||
# Register the server and map a tool to it
|
||||
manager.registry = {"test-server": server}
|
||||
manager.tool_name_to_mcp_server_name_mapping["test_tool"] = "test-server"
|
||||
|
||||
# Create mock client that tracks context manager usage
|
||||
mock_client = MagicMock()
|
||||
context_entered = False
|
||||
context_exited = False
|
||||
call_tool_called_inside_context = False
|
||||
|
||||
async def mock_aenter(self):
|
||||
nonlocal context_entered
|
||||
context_entered = True
|
||||
return self
|
||||
|
||||
async def mock_aexit(self, exc_type, exc_val, exc_tb):
|
||||
nonlocal context_exited
|
||||
context_exited = True
|
||||
# Verify that call_tool was called before context exit
|
||||
assert (
|
||||
call_tool_called_inside_context
|
||||
), "call_tool must be awaited inside context manager"
|
||||
return False
|
||||
|
||||
async def mock_call_tool(params):
|
||||
nonlocal call_tool_called_inside_context
|
||||
# Verify we're inside the context when this is called
|
||||
assert context_entered, "call_tool called outside context manager"
|
||||
assert not context_exited, "call_tool called after context exit"
|
||||
call_tool_called_inside_context = True
|
||||
|
||||
# Return a mock CallToolResult
|
||||
result = MagicMock(spec=CallToolResult)
|
||||
result.content = [{"type": "text", "text": "Tool executed successfully"}]
|
||||
result.isError = False
|
||||
return result
|
||||
|
||||
mock_client.__aenter__ = mock_aenter
|
||||
mock_client.__aexit__ = mock_aexit
|
||||
mock_client.call_tool = mock_call_tool
|
||||
|
||||
# Mock _create_mcp_client to return our mock client
|
||||
manager._create_mcp_client = MagicMock(return_value=mock_client)
|
||||
|
||||
# Mock user auth with no restrictions
|
||||
user_api_key_auth = MagicMock()
|
||||
user_api_key_auth.object_permission = None
|
||||
user_api_key_auth.object_permission_id = None
|
||||
|
||||
# Mock proxy logging
|
||||
proxy_logging_obj = MagicMock()
|
||||
proxy_logging_obj._create_mcp_request_object_from_kwargs = MagicMock(
|
||||
return_value={}
|
||||
)
|
||||
proxy_logging_obj._convert_mcp_to_llm_format = MagicMock(return_value={})
|
||||
proxy_logging_obj.pre_call_hook = AsyncMock(return_value={})
|
||||
proxy_logging_obj.during_call_hook = AsyncMock(return_value=None)
|
||||
|
||||
# Call the tool
|
||||
result = await manager.call_tool(
|
||||
name="test_tool",
|
||||
arguments={"param": "value"},
|
||||
user_api_key_auth=user_api_key_auth,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result is not None
|
||||
assert result.isError is False
|
||||
assert len(result.content) > 0
|
||||
|
||||
# Verify context manager was used properly
|
||||
assert context_entered, "Context manager __aenter__ was not called"
|
||||
assert context_exited, "Context manager __aexit__ was not called"
|
||||
assert (
|
||||
call_tool_called_inside_context
|
||||
), "call_tool was not awaited inside context"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
||||
Reference in New Issue
Block a user