added error handling for MCP tools not found

This commit is contained in:
Jugal Bhatt
2025-07-01 17:12:01 -07:00
parent 514224d190
commit 27e6ef5f39
6 changed files with 354 additions and 197 deletions
+72 -36
View File
@@ -4,6 +4,7 @@ LiteLLM Proxy uses this MCP Client to connnect to other MCP servers.
import base64
from datetime import timedelta
from typing import List, Optional
import asyncio
from mcp import ClientSession
from mcp.client.sse import sse_client
@@ -46,6 +47,7 @@ class MCPClient:
self._transport_ctx = None
self._transport = None
self._session_ctx = None
self._task: Optional[asyncio.Task] = None
# handle the basic auth value if provided
if auth_value:
@@ -56,8 +58,12 @@ class MCPClient:
Enable async context manager support.
Initializes the transport and session.
"""
await self.connect()
return self
try:
await self.connect()
return self
except Exception as e:
await self.disconnect()
raise
async def connect(self):
"""Initialize the transport and session."""
@@ -66,47 +72,63 @@ class MCPClient:
headers = self._get_auth_headers()
if self.transport_type == MCPTransport.sse:
self._transport_ctx = sse_client(
url=self.server_url,
timeout=self.timeout,
headers=headers,
)
self._transport = await self._transport_ctx.__aenter__()
self._session_ctx = ClientSession(self._transport[0], self._transport[1])
self._session = await self._session_ctx.__aenter__()
await self._session.initialize()
else:
self._transport_ctx = streamablehttp_client(
url=self.server_url,
timeout=timedelta(seconds=self.timeout),
headers=headers,
)
self._transport = await self._transport_ctx.__aenter__()
self._session_ctx = ClientSession(self._transport[0], self._transport[1])
self._session = await self._session_ctx.__aenter__()
await self._session.initialize()
try:
if self.transport_type == MCPTransport.sse:
self._transport_ctx = sse_client(
url=self.server_url,
timeout=self.timeout,
headers=headers,
)
self._transport = await self._transport_ctx.__aenter__()
self._session_ctx = ClientSession(self._transport[0], self._transport[1])
self._session = await self._session_ctx.__aenter__()
await self._session.initialize()
else:
self._transport_ctx = streamablehttp_client(
url=self.server_url,
timeout=timedelta(seconds=self.timeout),
headers=headers,
)
self._transport = await self._transport_ctx.__aenter__()
self._session_ctx = ClientSession(self._transport[0], self._transport[1])
self._session = await self._session_ctx.__aenter__()
await self._session.initialize()
except Exception as e:
await self.disconnect()
raise
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Cleanup when exiting context manager."""
if self._session:
await self._session_ctx.__aexit__(exc_type, exc_val, exc_tb) # type: ignore
if self._transport_ctx:
await self._transport_ctx.__aexit__(exc_type, exc_val, exc_tb)
await self.disconnect()
async def disconnect(self):
"""Clean up session and connections."""
if self._task and not self._task.done():
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
if self._session:
try:
# Ensure session is properly closed
await self._session.close() # type: ignore
await self._session_ctx.__aexit__(None, None, None) # type: ignore
except Exception:
pass
self._session = None
self._session_ctx = None
if self._transport_ctx:
try:
await self._transport_ctx.__aexit__(None, None, None)
except Exception:
pass
self._transport_ctx = None
self._transport = None
if self._context:
try:
await self._context.__aexit__(None, None, None) # type: ignore
await self._context.__aexit__(None, None, None) # type: ignore
except Exception:
pass
self._context = None
@@ -140,8 +162,15 @@ class MCPClient:
if self._session is None:
raise ValueError("Session is not initialized")
result = await self._session.list_tools()
return result.tools
try:
result = await self._session.list_tools()
return result.tools
except asyncio.CancelledError:
await self.disconnect()
raise
except Exception as e:
await self.disconnect()
raise
async def call_tool(
self, call_tool_request_params: MCPCallToolRequestParams
@@ -155,10 +184,17 @@ class MCPClient:
if self._session is None:
raise ValueError("Session is not initialized")
tool_result = await self._session.call_tool(
name=call_tool_request_params.name,
arguments=call_tool_request_params.arguments,
)
return tool_result
try:
tool_result = await self._session.call_tool(
name=call_tool_request_params.name,
arguments=call_tool_request_params.arguments,
)
return tool_result
except asyncio.CancelledError:
await self.disconnect()
raise
except Exception as e:
await self.disconnect()
raise
@@ -227,19 +227,43 @@ class MCPServerManager:
verbose_logger.debug(f"Connecting to url: {server.url}")
verbose_logger.info("_get_tools_from_server...")
client = self._create_mcp_client(
server=server,
mcp_auth_header=mcp_auth_header,
)
async with client:
tools = await client.list_tools()
verbose_logger.debug(f"Tools from {server.name}: {tools}")
client = None
try:
client = self._create_mcp_client(
server=server,
mcp_auth_header=mcp_auth_header,
)
# Create a task for the client operations to ensure proper cancellation handling
async def _list_tools_task():
async with client:
tools = await client.list_tools()
verbose_logger.debug(f"Tools from {server.name}: {tools}")
return tools
# Update tool to server mapping
for tool in tools:
self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name
try:
tools = await _list_tools_task()
# Update tool to server mapping
for tool in tools:
self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name
return tools
return tools
except asyncio.CancelledError:
verbose_logger.warning(f"Task cancelled while listing tools from {server.name}")
raise # Re-raise the cancellation
except Exception as e:
verbose_logger.exception(f"Error listing tools from {server.name}: {str(e)}")
raise
except Exception as e:
verbose_logger.exception(f"Failed to get tools from server {server.name}: {str(e)}")
return [] # Return empty list on failure
finally:
if client:
try:
await client.disconnect()
except Exception:
pass
async def call_tool(
self,
@@ -38,54 +38,98 @@ if MCP_AVAILABLE:
None, description="The server id to list tools for"
),
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> List[ListMCPToolsRestAPIResponseObject]:
) -> dict:
"""
List all available tools with information about the server they belong to.
Example response:
Tools:
[
{
"name": "create_zap",
"description": "Create a new zap",
"inputSchema": "tool_input_schema",
"mcp_info": {
"server_name": "zapier",
"logo_url": "https://www.zapier.com/logo.png",
{
"tools": [
{
"name": "create_zap",
"description": "Create a new zap",
"inputSchema": "tool_input_schema",
"mcp_info": {
"server_name": "zapier",
"logo_url": "https://www.zapier.com/logo.png",
}
}
},
{
"name": "fetch_data",
"description": "Fetch data from a URL",
"inputSchema": "tool_input_schema",
"mcp_info": {
"server_name": "fetch",
"logo_url": "https://www.fetch.com/logo.png",
}
}
]
],
"error": null,
"message": "Successfully retrieved tools"
}
"""
list_tools_result: List[ListMCPToolsRestAPIResponseObject] = []
for server in global_mcp_server_manager.get_registry().values():
if server_id and server.server_id != server_id:
continue
try:
tools = await global_mcp_server_manager._get_tools_from_server(
server=server,
)
for tool in tools:
list_tools_result.append(
ListMCPToolsRestAPIResponseObject(
name=tool.name,
description=tool.description,
inputSchema=tool.inputSchema,
mcp_info=server.mcp_info,
)
try:
list_tools_result = []
error_message = None
# If server_id is specified, only query that specific server
if server_id:
server = global_mcp_server_manager.get_mcp_server_by_id(server_id)
if server is None:
return {
"tools": [],
"error": "server_not_found",
"message": f"Server with id {server_id} not found"
}
try:
tools = await global_mcp_server_manager._get_tools_from_server(
server=server,
)
except Exception as e:
verbose_logger.exception(f"Error getting tools from {server.name}: {e}")
continue
return list_tools_result
for tool in tools:
list_tools_result.append(
ListMCPToolsRestAPIResponseObject(
name=tool.name,
description=tool.description,
inputSchema=tool.inputSchema,
mcp_info=server.mcp_info,
)
)
except Exception as e:
verbose_logger.exception(f"Error getting tools from {server.name}: {e}")
return {
"tools": [],
"error": "server_error",
"message": f"Failed to get tools from server {server.name}: {str(e)}"
}
else:
# Query all servers
errors = []
for server in global_mcp_server_manager.get_registry().values():
try:
tools = await global_mcp_server_manager._get_tools_from_server(
server=server,
)
for tool in tools:
list_tools_result.append(
ListMCPToolsRestAPIResponseObject(
name=tool.name,
description=tool.description,
inputSchema=tool.inputSchema,
mcp_info=server.mcp_info,
)
)
except Exception as e:
verbose_logger.exception(f"Error getting tools from {server.name}: {e}")
errors.append(f"{server.name}: {str(e)}")
continue
if errors and not list_tools_result:
error_message = "Failed to get tools from servers: " + "; ".join(errors)
return {
"tools": list_tools_result,
"error": "partial_failure" if error_message else None,
"message": error_message if error_message else "Successfully retrieved tools"
}
except Exception as e:
verbose_logger.exception("Unexpected error in list_tool_rest_api: %s", str(e))
return {
"tools": [],
"error": "unexpected_error",
"message": f"An unexpected error occurred: {str(e)}"
}
@router.post("/tools/call", dependencies=[Depends(user_api_key_auth)])
async def call_tool_rest_api(
+88
View File
@@ -512,3 +512,91 @@ def test_generate_stable_server_id():
assert zapier_sse_hash != github_http_hash
@pytest.mark.asyncio
async def test_list_tools_rest_api_server_not_found():
"""Test the list_tools REST API when server is not found"""
from litellm.proxy._experimental.mcp_server.rest_endpoints import list_tool_rest_api
from fastapi import Query
from litellm.proxy._types import UserAPIKeyAuth
# Mock UserAPIKeyAuth
mock_user_auth = UserAPIKeyAuth(api_key="test", user_id="test")
# Test with non-existent server ID
response = await list_tool_rest_api(
server_id="non_existent_server_id",
user_api_key_dict=mock_user_auth
)
assert isinstance(response, dict)
assert response["tools"] == []
assert response["error"] == "server_not_found"
assert "Server with id non_existent_server_id not found" in response["message"]
@pytest.mark.asyncio
async def test_list_tools_rest_api_success():
"""Test the list_tools REST API successful case"""
from litellm.proxy._experimental.mcp_server.rest_endpoints import list_tool_rest_api, global_mcp_server_manager
from fastapi import Query
from litellm.proxy._types import UserAPIKeyAuth
# Store original registry to restore after test
original_registry = global_mcp_server_manager.get_registry().copy()
original_tool_mapping = global_mcp_server_manager.tool_name_to_mcp_server_name_mapping.copy()
try:
# Clear existing registry
global_mcp_server_manager.tool_name_to_mcp_server_name_mapping.clear()
global_mcp_server_manager.registry.clear()
global_mcp_server_manager.config_mcp_servers.clear()
# Mock successful tools
mock_tools = [
MCPTool(
name="test_tool",
description="A test tool",
inputSchema={"type": "object"}
)
]
# Create mock client
mock_client = AsyncMock()
mock_client.list_tools = AsyncMock(return_value=mock_tools)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
def mock_client_constructor(*args, **kwargs):
return mock_client
with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor):
# Load server config into global manager
global_mcp_server_manager.load_servers_from_config({
"test_server": {
"url": "https://test-server.com/mcp",
"transport": MCPTransport.http,
}
})
# Mock UserAPIKeyAuth
mock_user_auth = UserAPIKeyAuth(api_key="test", user_id="test")
# Get the server ID
server_id = list(global_mcp_server_manager.get_registry().keys())[0]
# Test successful case
response = await list_tool_rest_api(
server_id=server_id,
user_api_key_dict=mock_user_auth
)
assert isinstance(response, dict)
assert len(response["tools"]) == 1
assert response["tools"][0].name == "test_tool"
assert response["error"] is None
assert response["message"] == "Successfully retrieved tools"
finally:
# Restore original state
global_mcp_server_manager.registry = {}
global_mcp_server_manager.config_mcp_servers = original_registry
global_mcp_server_manager.tool_name_to_mcp_server_name_mapping = original_tool_mapping
@@ -159,13 +159,11 @@ const MCPToolsViewer = ({
const [searchTerm, setSearchTerm] = useState("");
const [mcpAuthValue, setMcpAuthValue] = useState("");
const [selectedTool, setSelectedTool] = useState<MCPTool | null>(null);
const [toolResult, setToolResult] = useState<CallMCPToolResponse | null>(
null
);
const [toolResult, setToolResult] = useState<CallMCPToolResponse | null>(null);
const [toolError, setToolError] = useState<Error | null>(null);
// Query to fetch MCP tools
const { data: mcpTools, isLoading: isLoadingTools } = useQuery({
const { data: mcpToolsResponse, isLoading: isLoadingTools, error: mcpToolsError } = useQuery({
queryKey: ["mcpTools"],
queryFn: () => {
if (!accessToken) throw new Error("Access Token required");
@@ -192,9 +190,9 @@ const MCPToolsViewer = ({
// Add onToolSelect handler to each tool
const toolsData = React.useMemo(() => {
if (!mcpTools) return [];
if (!mcpToolsResponse) return [];
return mcpTools.map((tool: MCPTool) => ({
return (mcpToolsResponse.tools || []).map((tool: MCPTool) => ({
...tool,
onToolSelect: (tool: MCPTool) => {
setSelectedTool(tool);
@@ -202,108 +200,65 @@ const MCPToolsViewer = ({
setToolError(null);
},
}));
}, [mcpTools]);
}, [mcpToolsResponse]);
// Filter tools based on search term
const filteredTools = React.useMemo(() => {
return toolsData.filter((tool: MCPTool) => {
const searchLower = searchTerm.toLowerCase();
return (
tool.name.toLowerCase().includes(searchLower) ||
(tool.description != null &&
tool.description.toLowerCase().includes(searchLower)) ||
tool.mcp_info.server_name.toLowerCase().includes(searchLower)
);
});
}, [toolsData, searchTerm]);
// Error message display
const errorMessage = mcpToolsResponse?.error ? (
<div className="p-4 mb-4 text-sm text-red-800 rounded-lg bg-red-50">
<p className="font-medium">Error: {mcpToolsResponse.message}</p>
{mcpToolsResponse.error === "server_not_found" && (
<p className="mt-2">The specified server could not be found. Please check the server ID and try again.</p>
)}
{mcpToolsResponse.error === "server_error" && (
<p className="mt-2">There was an error connecting to the server. Please try again later or contact support if the issue persists.</p>
)}
{mcpToolsResponse.error === "unexpected_error" && (
<p className="mt-2">An unexpected error occurred. Please try again later or contact support if the issue persists.</p>
)}
</div>
) : null;
// Handle tool call submission
const handleToolSubmit = (args: Record<string, any>) => {
if (!selectedTool) return;
executeTool({
tool: selectedTool,
arguments: args,
authValue: mcpAuthValue
});
};
if (!accessToken || !userRole || !userID) {
return (
<div className="p-6 text-center text-gray-500">
Missing required authentication parameters.
</div>
);
}
// No tools message
const noToolsMessage = !isLoadingTools && !mcpToolsResponse?.error && (!toolsData || toolsData.length === 0) ? (
<div className="p-4 mb-4 text-sm text-gray-800 rounded-lg bg-gray-50">
<p className="font-medium">No tools available</p>
<p className="mt-2">No tools were found for this server. This could be because:</p>
<ul className="list-disc list-inside mt-2">
<li>The server has not registered any tools</li>
<li>There was an error connecting to the server</li>
<li>The server is still initializing</li>
</ul>
</div>
) : null;
return (
<div className="w-full p-6">
<div className="flex items-center justify-between mb-4">
<h1 className="text-xl font-semibold">MCP Tools</h1>
</div>
{mcpServerHasAuth(auth_type) && (
<AuthSection
authType={auth_type}
onAuthSubmit={(value: string) => {
setMcpAuthValue(value);
}}
/>
)}
<div className="bg-white rounded-lg shadow">
<div className="border-b px-6 py-4">
<div className="flex items-center justify-between">
<div className="relative w-64">
<input
type="text"
placeholder="Search tools..."
className="w-full px-3 py-2 pl-8 border rounded-md text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
value={searchTerm}
onChange={(e) => setSearchTerm(e.target.value)}
/>
<svg
className="absolute left-2.5 top-2.5 h-4 w-4 text-gray-500"
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
>
<path
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth={2}
d="M21 21l-6-6m2-5a7 7 0 11-14 0 7 7 0 0114 0z"
/>
</svg>
</div>
<div className="text-sm text-gray-500">
{filteredTools.length} tool{filteredTools.length !== 1 ? "s" : ""}{" "}
available
</div>
</div>
</div>
<DataTableWrapper
columns={columns}
data={filteredTools}
isLoading={isLoadingTools}
/>
</div>
{/* Tool Test Panel - Show when a tool is selected */}
<div className="space-y-4">
{errorMessage}
{noToolsMessage}
<DataTableWrapper
columns={columns}
data={toolsData}
isLoading={isLoadingTools}
/>
{selectedTool && (
<div className="fixed inset-0 bg-gray-800 bg-opacity-75 flex items-center justify-center z-50 p-4">
<ToolTestPanel
tool={selectedTool}
needsAuth={mcpServerHasAuth(auth_type)}
authValue={mcpAuthValue}
onSubmit={handleToolSubmit}
isLoading={isCallingTool}
result={toolResult}
error={toolError}
onClose={() => setSelectedTool(null)}
/>
</div>
<ToolTestPanel
tool={selectedTool}
needsAuth={mcpServerHasAuth(auth_type)}
authValue={mcpAuthValue}
onSubmit={(args) => {
executeTool({ tool: selectedTool, arguments: args, authValue: mcpAuthValue });
}}
result={toolResult}
error={toolError}
isLoading={isCallingTool}
onClose={() => setSelectedTool(null)}
/>
)}
{mcpServerHasAuth(auth_type) && (
<AuthSection
authType={auth_type}
onAuthSubmit={(value) => setMcpAuthValue(value)}
/>
)}
</div>
);
@@ -4889,18 +4889,28 @@ export const listMCPTools = async (accessToken: string, serverId: string) => {
},
});
const data = await response.json();
console.log("Fetched MCP tools response:", data);
if (!response.ok) {
const errorData = await response.text();
handleError(errorData);
throw new Error("Network response was not ok");
// If the server returned an error response, use it
if (data.error && data.message) {
throw new Error(data.message);
}
// Otherwise use a generic error
throw new Error("Failed to fetch MCP tools");
}
const data = await response.json();
console.log("Fetched MCP tools:", data);
// Return the full response object which includes tools, error, and message
return data;
} catch (error) {
console.error("Failed to fetch MCP tools:", error);
throw error;
// Return an error response in the same format as the API
return {
tools: [],
error: "network_error",
message: error instanceof Error ? error.message : "Failed to fetch MCP tools"
};
}
};