From 27e6ef5f39457fa342686944fc9ae8a67e95adca Mon Sep 17 00:00:00 2001 From: Jugal Bhatt Date: Tue, 1 Jul 2025 17:12:01 -0700 Subject: [PATCH] added error handling for MCP tools not found --- litellm/experimental_mcp_client/client.py | 108 ++++++++---- .../mcp_server/mcp_server_manager.py | 46 +++-- .../mcp_server/rest_endpoints.py | 126 +++++++++----- tests/mcp_tests/test_mcp_server.py | 88 ++++++++++ .../src/components/mcp_tools/mcp_tools.tsx | 161 +++++++----------- .../src/components/networking.tsx | 22 ++- 6 files changed, 354 insertions(+), 197 deletions(-) diff --git a/litellm/experimental_mcp_client/client.py b/litellm/experimental_mcp_client/client.py index af2cb171da..f2d32a5a9f 100644 --- a/litellm/experimental_mcp_client/client.py +++ b/litellm/experimental_mcp_client/client.py @@ -4,6 +4,7 @@ LiteLLM Proxy uses this MCP Client to connnect to other MCP servers. import base64 from datetime import timedelta from typing import List, Optional +import asyncio from mcp import ClientSession from mcp.client.sse import sse_client @@ -46,6 +47,7 @@ class MCPClient: self._transport_ctx = None self._transport = None self._session_ctx = None + self._task: Optional[asyncio.Task] = None # handle the basic auth value if provided if auth_value: @@ -56,8 +58,12 @@ class MCPClient: Enable async context manager support. Initializes the transport and session. """ - await self.connect() - return self + try: + await self.connect() + return self + except Exception as e: + await self.disconnect() + raise async def connect(self): """Initialize the transport and session.""" @@ -66,47 +72,63 @@ class MCPClient: headers = self._get_auth_headers() - if self.transport_type == MCPTransport.sse: - self._transport_ctx = sse_client( - url=self.server_url, - timeout=self.timeout, - headers=headers, - ) - self._transport = await self._transport_ctx.__aenter__() - self._session_ctx = ClientSession(self._transport[0], self._transport[1]) - self._session = await self._session_ctx.__aenter__() - await self._session.initialize() - else: - self._transport_ctx = streamablehttp_client( - url=self.server_url, - timeout=timedelta(seconds=self.timeout), - headers=headers, - ) - self._transport = await self._transport_ctx.__aenter__() - self._session_ctx = ClientSession(self._transport[0], self._transport[1]) - self._session = await self._session_ctx.__aenter__() - await self._session.initialize() + try: + if self.transport_type == MCPTransport.sse: + self._transport_ctx = sse_client( + url=self.server_url, + timeout=self.timeout, + headers=headers, + ) + self._transport = await self._transport_ctx.__aenter__() + self._session_ctx = ClientSession(self._transport[0], self._transport[1]) + self._session = await self._session_ctx.__aenter__() + await self._session.initialize() + else: + self._transport_ctx = streamablehttp_client( + url=self.server_url, + timeout=timedelta(seconds=self.timeout), + headers=headers, + ) + self._transport = await self._transport_ctx.__aenter__() + self._session_ctx = ClientSession(self._transport[0], self._transport[1]) + self._session = await self._session_ctx.__aenter__() + await self._session.initialize() + except Exception as e: + await self.disconnect() + raise async def __aexit__(self, exc_type, exc_val, exc_tb): """Cleanup when exiting context manager.""" - if self._session: - await self._session_ctx.__aexit__(exc_type, exc_val, exc_tb) # type: ignore - if self._transport_ctx: - await self._transport_ctx.__aexit__(exc_type, exc_val, exc_tb) + await self.disconnect() async def disconnect(self): """Clean up session and connections.""" + if self._task and not self._task.done(): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + if self._session: try: - # Ensure session is properly closed - await self._session.close() # type: ignore + await self._session_ctx.__aexit__(None, None, None) # type: ignore except Exception: pass self._session = None + self._session_ctx = None + + if self._transport_ctx: + try: + await self._transport_ctx.__aexit__(None, None, None) + except Exception: + pass + self._transport_ctx = None + self._transport = None if self._context: try: - await self._context.__aexit__(None, None, None) # type: ignore + await self._context.__aexit__(None, None, None) # type: ignore except Exception: pass self._context = None @@ -140,8 +162,15 @@ class MCPClient: if self._session is None: raise ValueError("Session is not initialized") - result = await self._session.list_tools() - return result.tools + try: + result = await self._session.list_tools() + return result.tools + except asyncio.CancelledError: + await self.disconnect() + raise + except Exception as e: + await self.disconnect() + raise async def call_tool( self, call_tool_request_params: MCPCallToolRequestParams @@ -155,10 +184,17 @@ class MCPClient: if self._session is None: raise ValueError("Session is not initialized") - tool_result = await self._session.call_tool( - name=call_tool_request_params.name, - arguments=call_tool_request_params.arguments, - ) - return tool_result + try: + tool_result = await self._session.call_tool( + name=call_tool_request_params.name, + arguments=call_tool_request_params.arguments, + ) + return tool_result + except asyncio.CancelledError: + await self.disconnect() + raise + except Exception as e: + await self.disconnect() + raise diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index d32a177914..c993e781e6 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -227,19 +227,43 @@ class MCPServerManager: verbose_logger.debug(f"Connecting to url: {server.url}") verbose_logger.info("_get_tools_from_server...") - client = self._create_mcp_client( - server=server, - mcp_auth_header=mcp_auth_header, - ) - async with client: - tools = await client.list_tools() - verbose_logger.debug(f"Tools from {server.name}: {tools}") + client = None + try: + client = self._create_mcp_client( + server=server, + mcp_auth_header=mcp_auth_header, + ) + + # Create a task for the client operations to ensure proper cancellation handling + async def _list_tools_task(): + async with client: + tools = await client.list_tools() + verbose_logger.debug(f"Tools from {server.name}: {tools}") + return tools - # Update tool to server mapping - for tool in tools: - self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name + try: + tools = await _list_tools_task() + + # Update tool to server mapping + for tool in tools: + self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name - return tools + return tools + except asyncio.CancelledError: + verbose_logger.warning(f"Task cancelled while listing tools from {server.name}") + raise # Re-raise the cancellation + except Exception as e: + verbose_logger.exception(f"Error listing tools from {server.name}: {str(e)}") + raise + except Exception as e: + verbose_logger.exception(f"Failed to get tools from server {server.name}: {str(e)}") + return [] # Return empty list on failure + finally: + if client: + try: + await client.disconnect() + except Exception: + pass async def call_tool( self, diff --git a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py index 9094be6f42..07ff4712ed 100644 --- a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py @@ -38,54 +38,98 @@ if MCP_AVAILABLE: None, description="The server id to list tools for" ), user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), - ) -> List[ListMCPToolsRestAPIResponseObject]: + ) -> dict: """ List all available tools with information about the server they belong to. Example response: - Tools: - [ - { - "name": "create_zap", - "description": "Create a new zap", - "inputSchema": "tool_input_schema", - "mcp_info": { - "server_name": "zapier", - "logo_url": "https://www.zapier.com/logo.png", + { + "tools": [ + { + "name": "create_zap", + "description": "Create a new zap", + "inputSchema": "tool_input_schema", + "mcp_info": { + "server_name": "zapier", + "logo_url": "https://www.zapier.com/logo.png", + } } - }, - { - "name": "fetch_data", - "description": "Fetch data from a URL", - "inputSchema": "tool_input_schema", - "mcp_info": { - "server_name": "fetch", - "logo_url": "https://www.fetch.com/logo.png", - } - } - ] + ], + "error": null, + "message": "Successfully retrieved tools" + } """ - list_tools_result: List[ListMCPToolsRestAPIResponseObject] = [] - for server in global_mcp_server_manager.get_registry().values(): - if server_id and server.server_id != server_id: - continue - try: - tools = await global_mcp_server_manager._get_tools_from_server( - server=server, - ) - for tool in tools: - list_tools_result.append( - ListMCPToolsRestAPIResponseObject( - name=tool.name, - description=tool.description, - inputSchema=tool.inputSchema, - mcp_info=server.mcp_info, - ) + try: + list_tools_result = [] + error_message = None + + # If server_id is specified, only query that specific server + if server_id: + server = global_mcp_server_manager.get_mcp_server_by_id(server_id) + if server is None: + return { + "tools": [], + "error": "server_not_found", + "message": f"Server with id {server_id} not found" + } + try: + tools = await global_mcp_server_manager._get_tools_from_server( + server=server, ) - except Exception as e: - verbose_logger.exception(f"Error getting tools from {server.name}: {e}") - continue - return list_tools_result + for tool in tools: + list_tools_result.append( + ListMCPToolsRestAPIResponseObject( + name=tool.name, + description=tool.description, + inputSchema=tool.inputSchema, + mcp_info=server.mcp_info, + ) + ) + except Exception as e: + verbose_logger.exception(f"Error getting tools from {server.name}: {e}") + return { + "tools": [], + "error": "server_error", + "message": f"Failed to get tools from server {server.name}: {str(e)}" + } + else: + # Query all servers + errors = [] + for server in global_mcp_server_manager.get_registry().values(): + try: + tools = await global_mcp_server_manager._get_tools_from_server( + server=server, + ) + for tool in tools: + list_tools_result.append( + ListMCPToolsRestAPIResponseObject( + name=tool.name, + description=tool.description, + inputSchema=tool.inputSchema, + mcp_info=server.mcp_info, + ) + ) + except Exception as e: + verbose_logger.exception(f"Error getting tools from {server.name}: {e}") + errors.append(f"{server.name}: {str(e)}") + continue + + if errors and not list_tools_result: + error_message = "Failed to get tools from servers: " + "; ".join(errors) + + return { + "tools": list_tools_result, + "error": "partial_failure" if error_message else None, + "message": error_message if error_message else "Successfully retrieved tools" + } + + except Exception as e: + verbose_logger.exception("Unexpected error in list_tool_rest_api: %s", str(e)) + return { + "tools": [], + "error": "unexpected_error", + "message": f"An unexpected error occurred: {str(e)}" + } @router.post("/tools/call", dependencies=[Depends(user_api_key_auth)]) async def call_tool_rest_api( diff --git a/tests/mcp_tests/test_mcp_server.py b/tests/mcp_tests/test_mcp_server.py index 3115ab80ef..031e26cf16 100644 --- a/tests/mcp_tests/test_mcp_server.py +++ b/tests/mcp_tests/test_mcp_server.py @@ -512,3 +512,91 @@ def test_generate_stable_server_id(): assert zapier_sse_hash != github_http_hash +@pytest.mark.asyncio +async def test_list_tools_rest_api_server_not_found(): + """Test the list_tools REST API when server is not found""" + from litellm.proxy._experimental.mcp_server.rest_endpoints import list_tool_rest_api + from fastapi import Query + from litellm.proxy._types import UserAPIKeyAuth + + # Mock UserAPIKeyAuth + mock_user_auth = UserAPIKeyAuth(api_key="test", user_id="test") + + # Test with non-existent server ID + response = await list_tool_rest_api( + server_id="non_existent_server_id", + user_api_key_dict=mock_user_auth + ) + + assert isinstance(response, dict) + assert response["tools"] == [] + assert response["error"] == "server_not_found" + assert "Server with id non_existent_server_id not found" in response["message"] + +@pytest.mark.asyncio +async def test_list_tools_rest_api_success(): + """Test the list_tools REST API successful case""" + from litellm.proxy._experimental.mcp_server.rest_endpoints import list_tool_rest_api, global_mcp_server_manager + from fastapi import Query + from litellm.proxy._types import UserAPIKeyAuth + + # Store original registry to restore after test + original_registry = global_mcp_server_manager.get_registry().copy() + original_tool_mapping = global_mcp_server_manager.tool_name_to_mcp_server_name_mapping.copy() + try: + # Clear existing registry + global_mcp_server_manager.tool_name_to_mcp_server_name_mapping.clear() + global_mcp_server_manager.registry.clear() + global_mcp_server_manager.config_mcp_servers.clear() + + # Mock successful tools + mock_tools = [ + MCPTool( + name="test_tool", + description="A test tool", + inputSchema={"type": "object"} + ) + ] + + # Create mock client + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=mock_tools) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + def mock_client_constructor(*args, **kwargs): + return mock_client + + with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor): + # Load server config into global manager + global_mcp_server_manager.load_servers_from_config({ + "test_server": { + "url": "https://test-server.com/mcp", + "transport": MCPTransport.http, + } + }) + + # Mock UserAPIKeyAuth + mock_user_auth = UserAPIKeyAuth(api_key="test", user_id="test") + + # Get the server ID + server_id = list(global_mcp_server_manager.get_registry().keys())[0] + + # Test successful case + response = await list_tool_rest_api( + server_id=server_id, + user_api_key_dict=mock_user_auth + ) + + assert isinstance(response, dict) + assert len(response["tools"]) == 1 + assert response["tools"][0].name == "test_tool" + assert response["error"] is None + assert response["message"] == "Successfully retrieved tools" + finally: + # Restore original state + global_mcp_server_manager.registry = {} + global_mcp_server_manager.config_mcp_servers = original_registry + global_mcp_server_manager.tool_name_to_mcp_server_name_mapping = original_tool_mapping + + diff --git a/ui/litellm-dashboard/src/components/mcp_tools/mcp_tools.tsx b/ui/litellm-dashboard/src/components/mcp_tools/mcp_tools.tsx index 26cd758678..4a6b78cb14 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/mcp_tools.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/mcp_tools.tsx @@ -159,13 +159,11 @@ const MCPToolsViewer = ({ const [searchTerm, setSearchTerm] = useState(""); const [mcpAuthValue, setMcpAuthValue] = useState(""); const [selectedTool, setSelectedTool] = useState(null); - const [toolResult, setToolResult] = useState( - null - ); + const [toolResult, setToolResult] = useState(null); const [toolError, setToolError] = useState(null); // Query to fetch MCP tools - const { data: mcpTools, isLoading: isLoadingTools } = useQuery({ + const { data: mcpToolsResponse, isLoading: isLoadingTools, error: mcpToolsError } = useQuery({ queryKey: ["mcpTools"], queryFn: () => { if (!accessToken) throw new Error("Access Token required"); @@ -192,9 +190,9 @@ const MCPToolsViewer = ({ // Add onToolSelect handler to each tool const toolsData = React.useMemo(() => { - if (!mcpTools) return []; + if (!mcpToolsResponse) return []; - return mcpTools.map((tool: MCPTool) => ({ + return (mcpToolsResponse.tools || []).map((tool: MCPTool) => ({ ...tool, onToolSelect: (tool: MCPTool) => { setSelectedTool(tool); @@ -202,108 +200,65 @@ const MCPToolsViewer = ({ setToolError(null); }, })); - }, [mcpTools]); + }, [mcpToolsResponse]); - // Filter tools based on search term - const filteredTools = React.useMemo(() => { - return toolsData.filter((tool: MCPTool) => { - const searchLower = searchTerm.toLowerCase(); - return ( - tool.name.toLowerCase().includes(searchLower) || - (tool.description != null && - tool.description.toLowerCase().includes(searchLower)) || - tool.mcp_info.server_name.toLowerCase().includes(searchLower) - ); - }); - }, [toolsData, searchTerm]); + // Error message display + const errorMessage = mcpToolsResponse?.error ? ( +
+

Error: {mcpToolsResponse.message}

+ {mcpToolsResponse.error === "server_not_found" && ( +

The specified server could not be found. Please check the server ID and try again.

+ )} + {mcpToolsResponse.error === "server_error" && ( +

There was an error connecting to the server. Please try again later or contact support if the issue persists.

+ )} + {mcpToolsResponse.error === "unexpected_error" && ( +

An unexpected error occurred. Please try again later or contact support if the issue persists.

+ )} +
+ ) : null; - // Handle tool call submission - const handleToolSubmit = (args: Record) => { - if (!selectedTool) return; - - executeTool({ - tool: selectedTool, - arguments: args, - authValue: mcpAuthValue - }); - }; - - if (!accessToken || !userRole || !userID) { - return ( -
- Missing required authentication parameters. -
- ); - } + // No tools message + const noToolsMessage = !isLoadingTools && !mcpToolsResponse?.error && (!toolsData || toolsData.length === 0) ? ( +
+

No tools available

+

No tools were found for this server. This could be because:

+
    +
  • The server has not registered any tools
  • +
  • There was an error connecting to the server
  • +
  • The server is still initializing
  • +
+
+ ) : null; return ( -
-
-

MCP Tools

-
- - {mcpServerHasAuth(auth_type) && ( - { - setMcpAuthValue(value); - }} - /> - )} - -
-
-
-
- setSearchTerm(e.target.value)} - /> - - - -
-
- {filteredTools.length} tool{filteredTools.length !== 1 ? "s" : ""}{" "} - available -
-
-
- - -
- - {/* Tool Test Panel - Show when a tool is selected */} +
+ {errorMessage} + {noToolsMessage} + {selectedTool && ( -
- setSelectedTool(null)} - /> -
+ { + executeTool({ tool: selectedTool, arguments: args, authValue: mcpAuthValue }); + }} + result={toolResult} + error={toolError} + isLoading={isCallingTool} + onClose={() => setSelectedTool(null)} + /> + )} + {mcpServerHasAuth(auth_type) && ( + setMcpAuthValue(value)} + /> )}
); diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index 21c2ea02ef..e42a433156 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -4889,18 +4889,28 @@ export const listMCPTools = async (accessToken: string, serverId: string) => { }, }); + const data = await response.json(); + console.log("Fetched MCP tools response:", data); + if (!response.ok) { - const errorData = await response.text(); - handleError(errorData); - throw new Error("Network response was not ok"); + // If the server returned an error response, use it + if (data.error && data.message) { + throw new Error(data.message); + } + // Otherwise use a generic error + throw new Error("Failed to fetch MCP tools"); } - const data = await response.json(); - console.log("Fetched MCP tools:", data); + // Return the full response object which includes tools, error, and message return data; } catch (error) { console.error("Failed to fetch MCP tools:", error); - throw error; + // Return an error response in the same format as the API + return { + tools: [], + error: "network_error", + message: error instanceof Error ? error.message : "Failed to fetch MCP tools" + }; } };