diff --git a/litellm-proxy-extras/litellm_proxy_extras/migrations/20251104220043_add_credentials_to_mcp_servers/migration.sql b/litellm-proxy-extras/litellm_proxy_extras/migrations/20251104220043_add_credentials_to_mcp_servers/migration.sql new file mode 100644 index 0000000000..800c96f18b --- /dev/null +++ b/litellm-proxy-extras/litellm_proxy_extras/migrations/20251104220043_add_credentials_to_mcp_servers/migration.sql @@ -0,0 +1,2 @@ +-- AlterTable +ALTER TABLE "LiteLLM_MCPServerTable" ADD COLUMN "credentials" JSONB DEFAULT '{}'; diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index 9702c39028..1ab193bba7 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -174,6 +174,7 @@ model LiteLLM_MCPServerTable { url String? transport String @default("sse") auth_type String? + credentials Json? @default("{}") created_at DateTime? @default(now()) @map("created_at") created_by String? updated_at DateTime? @default(now()) @updatedAt @map("updated_at") diff --git a/litellm/proxy/_experimental/mcp_server/db.py b/litellm/proxy/_experimental/mcp_server/db.py index dca9af62c7..a9734233a6 100644 --- a/litellm/proxy/_experimental/mcp_server/db.py +++ b/litellm/proxy/_experimental/mcp_server/db.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Iterable, List, Optional, Set, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast from litellm._logging import verbose_proxy_logger from litellm._uuid import uuid @@ -11,7 +11,12 @@ from litellm.proxy._types import ( UpdateMCPServerRequest, UserAPIKeyAuth, ) +from litellm.proxy.common_utils.encrypt_decrypt_utils import ( + _get_salt_key, + encrypt_value_helper, +) from litellm.proxy.utils import PrismaClient +from litellm.types.mcp import MCPCredentials def _prepare_mcp_server_data( @@ -35,6 +40,14 @@ def _prepare_mcp_server_data( if "alias" not in data_dict: data_dict["alias"] = getattr(data, "alias", None) + # Handle credentials serialization + credentials = data_dict.get("credentials") + if credentials is not None: + data_dict["credentials"] = encrypt_credentials( + credentials=credentials, encryption_key=_get_salt_key() + ) + data_dict["credentials"] = safe_dumps(data_dict["credentials"]) + # Handle static_headers serialization if data.static_headers is not None: data_dict["static_headers"] = safe_dumps(data.static_headers) @@ -52,6 +65,30 @@ def _prepare_mcp_server_data( return data_dict +def encrypt_credentials( + credentials: MCPCredentials, encryption_key: Optional[str] +) -> MCPCredentials: + auth_value = credentials.get("auth_value") + if auth_value is not None: + credentials["auth_value"] = encrypt_value_helper( + value=auth_value, + new_encryption_key=encryption_key, + ) + client_id = credentials.get("client_id") + if client_id is not None: + credentials["client_id"] = encrypt_value_helper( + value=client_id, + new_encryption_key=encryption_key, + ) + client_secret = credentials.get("client_secret") + if client_secret is not None: + credentials["client_secret"] = encrypt_value_helper( + value=client_secret, + new_encryption_key=encryption_key, + ) + return credentials + + async def get_all_mcp_servers( prisma_client: PrismaClient, ) -> List[LiteLLM_MCPServerTable]: @@ -303,3 +340,32 @@ async def update_mcp_server( ) return updated_mcp_server + + +async def rotate_mcp_server_credentials_master_key( + prisma_client: PrismaClient, touched_by: str, new_master_key: str +): + mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many() + + for mcp_server in mcp_servers: + credentials = mcp_server.credentials + if not credentials: + continue + + credentials_copy = dict(credentials) + encrypted_credentials = encrypt_credentials( + credentials=cast(MCPCredentials, credentials_copy), + encryption_key=new_master_key, + ) + + from litellm.litellm_core_utils.safe_json_dumps import safe_dumps + + serialized_credentials = safe_dumps(encrypted_credentials) + + await prisma_client.db.litellm_mcpservertable.update( + where={"server_id": mcp_server.server_id}, + data={ + "credentials": serialized_credentials, + "updated_by": touched_by, + }, + ) diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index 49ff1effd1..5b1dc5933c 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -38,6 +38,9 @@ from litellm.proxy._types import ( MCPTransportType, UserAPIKeyAuth, ) +from litellm.proxy.common_utils.encrypt_decrypt_utils import ( + decrypt_value_helper, +) from litellm.proxy.utils import ProxyLogging from litellm.types.mcp import MCPAuth, MCPStdioConfig from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer @@ -400,6 +403,20 @@ class MCPServerManager: static_headers_dict = _deserialize_json_dict( getattr(mcp_server, "static_headers", None) ) + credentials_dict = _deserialize_json_dict( + getattr(mcp_server, "credentials", None) + ) + + encrypted_auth_value: Optional[str] = None + if credentials_dict: + encrypted_auth_value = credentials_dict.get("auth_value") + + auth_value: Optional[str] = None + if encrypted_auth_value: + auth_value = decrypt_value_helper( + value=encrypted_auth_value, + key="auth_value", + ) # Use alias for name if present, else server_name name_for_prefix = ( mcp_server.alias or mcp_server.server_name or mcp_server.server_id @@ -422,6 +439,7 @@ class MCPServerManager: url=mcp_server.url, transport=cast(MCPTransportType, mcp_server.transport), auth_type=cast(MCPAuthType, mcp_server.auth_type), + authentication_token=auth_value, mcp_info=mcp_info, extra_headers=getattr(mcp_server, "extra_headers", None), static_headers=static_headers_dict, diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index d739727cca..e8f89baeff 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -17,7 +17,13 @@ from typing_extensions import Required, TypedDict from litellm._uuid import uuid from litellm.types.integrations.slack_alerting import AlertType from litellm.types.llms.openai import AllMessageValues, OpenAIFileObject -from litellm.types.mcp import MCPAuthType, MCPTransport, MCPTransportType +from litellm.types.mcp import ( + MCPAuth, + MCPAuthType, + MCPCredentials, + MCPTransport, + MCPTransportType, +) from litellm.types.mcp_server.mcp_server_manager import MCPInfo from litellm.types.router import RouterErrors, UpdateRouterConfig from litellm.types.secret_managers.main import KeyManagementSystem @@ -959,6 +965,7 @@ class NewMCPServerRequest(LiteLLMPydanticObjectBase): description: Optional[str] = None transport: MCPTransportType = MCPTransport.sse auth_type: Optional[MCPAuthType] = None + credentials: Optional[MCPCredentials] = None url: Optional[str] = None mcp_info: Optional[MCPInfo] = None mcp_access_groups: List[str] = Field(default_factory=list) @@ -985,6 +992,28 @@ class NewMCPServerRequest(LiteLLMPydanticObjectBase): raise ValueError("url is required for HTTP/SSE transport") return values + @model_validator(mode="before") + @classmethod + def validate_credentials_requirements(cls, values): + if not isinstance(values, dict): + return values + + auth_type = values.get("auth_type") + if auth_type in {MCPAuth.api_key, MCPAuth.bearer_token, MCPAuth.basic}: + credentials = values.get("credentials") + auth_value = None + if isinstance(credentials, dict): + auth_value = credentials.get("auth_value") + elif hasattr(credentials, "get"): + auth_value = credentials.get("auth_value") # type: ignore[attr-defined] + + if not auth_value: + raise ValueError( + "auth_value is required when auth_type is api_key, bearer_token, or basic" + ) + + return values + class UpdateMCPServerRequest(LiteLLMPydanticObjectBase): server_id: str @@ -993,6 +1022,7 @@ class UpdateMCPServerRequest(LiteLLMPydanticObjectBase): description: Optional[str] = None transport: MCPTransportType = MCPTransport.sse auth_type: Optional[MCPAuthType] = None + credentials: Optional[MCPCredentials] = None url: Optional[str] = None mcp_info: Optional[MCPInfo] = None mcp_access_groups: List[str] = Field(default_factory=list) @@ -1028,6 +1058,7 @@ class LiteLLM_MCPServerTable(LiteLLMPydanticObjectBase): url: Optional[str] = None transport: MCPTransportType auth_type: Optional[MCPAuthType] = None + credentials: Optional[MCPCredentials] = None created_at: Optional[datetime] = None created_by: Optional[str] = None updated_at: Optional[datetime] = None diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 472cb81f93..a8def13869 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -24,8 +24,15 @@ import litellm from litellm._logging import verbose_proxy_logger from litellm._uuid import uuid from litellm.caching import DualCache -from litellm.constants import LENGTH_OF_LITELLM_GENERATED_KEY, UI_SESSION_TOKEN_TEAM_ID +from litellm.constants import ( + LENGTH_OF_LITELLM_GENERATED_KEY, + LITELLM_PROXY_ADMIN_NAME, + UI_SESSION_TOKEN_TEAM_ID, +) from litellm.litellm_core_utils.duration_parser import duration_in_seconds +from litellm.proxy._experimental.mcp_server.db import ( + rotate_mcp_server_credentials_master_key, +) from litellm.proxy._types import * from litellm.proxy._types import LiteLLM_VerificationToken from litellm.proxy.auth.auth_checks import ( @@ -640,9 +647,9 @@ async def _common_key_generation_helper( # noqa: PLR0915 request_type="key", **data_json, table_name="key" ) - response["soft_budget"] = ( - data.soft_budget - ) # include the user-input soft budget in the response + response[ + "soft_budget" + ] = data.soft_budget # include the user-input soft budget in the response response = GenerateKeyResponse(**response) @@ -1299,7 +1306,6 @@ async def prepare_key_update_data( data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row: LiteLLM_VerificationToken, ): - data_json: dict = data.model_dump(exclude_unset=True) data_json.pop("key", None) data_json.pop("new_key", None) @@ -2357,10 +2363,10 @@ async def delete_verification_tokens( try: if prisma_client: tokens = [_hash_token_if_needed(token=key) for key in tokens] - _keys_being_deleted: List[LiteLLM_VerificationToken] = ( - await prisma_client.db.litellm_verificationtoken.find_many( - where={"token": {"in": tokens}} - ) + _keys_being_deleted: List[ + LiteLLM_VerificationToken + ] = await prisma_client.db.litellm_verificationtoken.find_many( + where={"token": {"in": tokens}} ) if len(_keys_being_deleted) == 0: @@ -2468,9 +2474,9 @@ async def _rotate_master_key( from litellm.proxy.proxy_server import proxy_config try: - models: Optional[List] = ( - await prisma_client.db.litellm_proxymodeltable.find_many() - ) + models: Optional[ + List + ] = await prisma_client.db.litellm_proxymodeltable.find_many() except Exception: models = None # 2. process model table @@ -2524,6 +2530,13 @@ async def _rotate_master_key( data={"param_value": jsonify_object(encrypted_env_vars)}, ) + # 4. process MCP server table + await rotate_mcp_server_credentials_master_key( + prisma_client=prisma_client, + touched_by=user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME, + new_master_key=new_master_key, + ) + def get_new_token(data: Optional[RegenerateKeyRequest]) -> str: if data and data.new_key is not None: @@ -2781,11 +2794,11 @@ async def validate_key_list_check( param="user_id", code=status.HTTP_403_FORBIDDEN, ) - complete_user_info_db_obj: Optional[BaseModel] = ( - await prisma_client.db.litellm_usertable.find_unique( - where={"user_id": user_api_key_dict.user_id}, - include={"organization_memberships": True}, - ) + complete_user_info_db_obj: Optional[ + BaseModel + ] = await prisma_client.db.litellm_usertable.find_unique( + where={"user_id": user_api_key_dict.user_id}, + include={"organization_memberships": True}, ) if complete_user_info_db_obj is None: @@ -2871,10 +2884,10 @@ async def get_admin_team_ids( if complete_user_info is None: return [] # Get all teams that user is an admin of - teams: Optional[List[BaseModel]] = ( - await prisma_client.db.litellm_teamtable.find_many( - where={"team_id": {"in": complete_user_info.teams}} - ) + teams: Optional[ + List[BaseModel] + ] = await prisma_client.db.litellm_teamtable.find_many( + where={"team_id": {"in": complete_user_info.teams}} ) if teams is None: return [] diff --git a/litellm/proxy/management_endpoints/mcp_management_endpoints.py b/litellm/proxy/management_endpoints/mcp_management_endpoints.py index b0c09b1e2e..22874ca8f1 100644 --- a/litellm/proxy/management_endpoints/mcp_management_endpoints.py +++ b/litellm/proxy/management_endpoints/mcp_management_endpoints.py @@ -58,6 +58,24 @@ if MCP_AVAILABLE: from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view from litellm.proxy.management_helpers.utils import management_endpoint_wrapper + def _redact_mcp_credentials( + mcp_server: LiteLLM_MCPServerTable, + ) -> LiteLLM_MCPServerTable: + """Return a copy of the MCP server object with credentials removed.""" + + try: + redacted_server = mcp_server.model_copy(deep=True) + except AttributeError: + redacted_server = mcp_server.copy(deep=True) # type: ignore[attr-defined] + + redacted_server.credentials = None + return redacted_server + + def _redact_mcp_credentials_list( + mcp_servers: Iterable[LiteLLM_MCPServerTable], + ) -> List[LiteLLM_MCPServerTable]: + return [_redact_mcp_credentials(server) for server in mcp_servers] + def get_prisma_client_or_throw(message: str): from litellm.proxy.proxy_server import prisma_client @@ -273,11 +291,12 @@ if MCP_AVAILABLE: ``` """ # Use server manager to get all servers with health and team data - return ( + mcp_servers = ( await global_mcp_server_manager.get_all_mcp_servers_with_health_and_teams( user_api_key_auth=user_api_key_dict ) ) + return _redact_mcp_credentials_list(mcp_servers) @router.get( "/server/{server_id}", @@ -335,7 +354,7 @@ if MCP_AVAILABLE: # Implement authz restriction from requested user if _user_has_admin_view(user_api_key_dict): - return mcp_server + return _redact_mcp_credentials(mcp_server) # Perform authz check to filter the mcp servers user has access to mcp_server_records = await get_all_mcp_servers_for_user( @@ -345,7 +364,7 @@ if MCP_AVAILABLE: if exists: global_mcp_server_manager.add_update_server(mcp_server) - return mcp_server + return _redact_mcp_credentials(mcp_server) else: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -428,7 +447,7 @@ if MCP_AVAILABLE: status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail={"error": f"Error creating mcp server: {str(e)}"}, ) - return new_mcp_server + return _redact_mcp_credentials(new_mcp_server) @router.delete( "/server/{server_id}", @@ -563,4 +582,4 @@ if MCP_AVAILABLE: if litellm.store_audit_logs: pass - return mcp_server_record_updated + return _redact_mcp_credentials(mcp_server_record_updated) diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 760954a1d6..025a1a0e3c 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -174,6 +174,7 @@ model LiteLLM_MCPServerTable { url String? transport String @default("sse") auth_type String? + credentials Json? @default("{}") created_at DateTime? @default(now()) @map("created_at") created_by String? updated_at DateTime? @default(now()) @updatedAt @map("updated_at") diff --git a/litellm/types/mcp.py b/litellm/types/mcp.py index bbb9e7d7b4..ce771c40d3 100644 --- a/litellm/types/mcp.py +++ b/litellm/types/mcp.py @@ -54,6 +54,28 @@ MCPAuthType = Optional[ ] +class MCPCredentials(TypedDict, total=False): + auth_value: Optional[str] + """ + Authentication value + """ + + client_id: Optional[str] + """ + OAuth 2.0 client identifier used when auth_type is oauth2 + """ + + client_secret: Optional[str] + """ + OAuth 2.0 client secret used when auth_type is oauth2 + """ + + scopes: Optional[List[str]] + """ + OAuth 2.0 scopes to request when exchanging the client credentials + """ + + class MCPServerCostInfo(TypedDict, total=False): default_cost_per_query: Optional[float] """ diff --git a/schema.prisma b/schema.prisma index 760954a1d6..025a1a0e3c 100644 --- a/schema.prisma +++ b/schema.prisma @@ -174,6 +174,7 @@ model LiteLLM_MCPServerTable { url String? transport String @default("sse") auth_type String? + credentials Json? @default("{}") created_at DateTime? @default(now()) @map("created_at") created_by String? updated_at DateTime? @default(now()) @updatedAt @map("updated_at") diff --git a/tests/mcp_tests/test_mcp_server.py b/tests/mcp_tests/test_mcp_server.py index 6f6127ed26..9ee5f6a9a6 100644 --- a/tests/mcp_tests/test_mcp_server.py +++ b/tests/mcp_tests/test_mcp_server.py @@ -1342,6 +1342,7 @@ def test_add_update_server_with_alias(): mock_mcp_server.url = "https://test-server.com/mcp" mock_mcp_server.transport = MCPTransport.http mock_mcp_server.auth_type = None + mock_mcp_server.credentials = {} mock_mcp_server.description = "Test server description" mock_mcp_server.mcp_info = {} mock_mcp_server.static_headers = {} @@ -1380,6 +1381,7 @@ def test_add_update_server_without_alias(): mock_mcp_server.url = "https://test-server.com/mcp" mock_mcp_server.transport = MCPTransport.http mock_mcp_server.auth_type = None + mock_mcp_server.credentials = {} mock_mcp_server.description = "Test server description" mock_mcp_server.mcp_info = {} mock_mcp_server.static_headers = {} @@ -1418,6 +1420,7 @@ def test_add_update_server_fallback_to_server_id(): mock_mcp_server.url = "https://test-server.com/mcp" mock_mcp_server.transport = MCPTransport.http mock_mcp_server.auth_type = None + mock_mcp_server.credentials = {} mock_mcp_server.description = "Test server description" mock_mcp_server.mcp_info = {} mock_mcp_server.static_headers = {} diff --git a/tests/store_model_in_db_tests/test_mcp_servers.py b/tests/store_model_in_db_tests/test_mcp_servers.py index a396c37832..ca9b92afb9 100644 --- a/tests/store_model_in_db_tests/test_mcp_servers.py +++ b/tests/store_model_in_db_tests/test_mcp_servers.py @@ -15,6 +15,7 @@ from litellm.proxy._types import ( MCPTransportType, MCPTransport, NewMCPServerRequest, + UpdateMCPServerRequest, LiteLLM_MCPServerTable, LitellmUserRoles, UserAPIKeyAuth, @@ -165,6 +166,7 @@ async def test_create_mcp_server_direct(): updated_by=LITELLM_PROXY_ADMIN_NAME, teams=[], ) + expected_response.credentials = {"auth_value": "secret"} # Mock the database calls mock_get_server.return_value = None # Server doesn't exist yet @@ -188,6 +190,7 @@ async def test_create_mcp_server_direct(): assert result.alias == expected_alias # Check against normalized alias assert result.url == mcp_server_request.url assert result.transport == mcp_server_request.transport + assert result.credentials is None # Verify mocks were called mock_get_server.assert_called_once_with(mock_prisma, server_id) @@ -353,6 +356,69 @@ async def test_create_mcp_server_invalid_alias(): ) +@pytest.mark.asyncio +async def test_edit_mcp_server_redacts_credentials(): + with mock.patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.MCP_AVAILABLE", + True, + ), mock.patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw" + ) as mock_get_prisma, mock.patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.update_mcp_server", + new_callable=mock.AsyncMock, + ) as mock_update, mock.patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.validate_and_normalize_mcp_server_payload", + autospec=True, + ) as mock_validate, mock.patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager" + ) as mock_manager: + from litellm.proxy.management_endpoints.mcp_management_endpoints import ( + edit_mcp_server, + ) + + mock_prisma = mock.Mock() + mock_get_prisma.return_value = mock_prisma + + mock_manager.add_update_server = mock.Mock() + mock_manager.reload_servers_from_database = mock.AsyncMock() + + server_id = str(uuid.uuid4()) + updated_server = LiteLLM_MCPServerTable( + server_id=server_id, + alias="Updated Server", + url="https://updated.example.com/mcp", + transport=MCPTransport.http, + created_at=datetime.now(), + updated_at=datetime.now(), + teams=[], + ) + updated_server.credentials = {"auth_value": "secret"} + + mock_update.return_value = updated_server + + payload = UpdateMCPServerRequest( + server_id=server_id, + alias="Updated Server", + url="https://updated.example.com/mcp", + transport=MCPTransport.http, + ) + + user_auth = UserAPIKeyAuth( + api_key=TEST_MASTER_KEY, + user_id="test-user", + user_role=LitellmUserRoles.PROXY_ADMIN, + ) + + result = await edit_mcp_server(payload=payload, user_api_key_dict=user_auth) + + assert result.server_id == server_id + assert result.credentials is None + assert updated_server.credentials == {"auth_value": "secret"} + + mock_validate.assert_called_once() + mock_update.assert_awaited_once() + mock_manager.add_update_server.assert_called_once_with(updated_server) + mock_manager.reload_servers_from_database.assert_awaited_once() def test_validate_mcp_server_name_direct(): """ Test the validation function directly to ensure it works. diff --git a/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py index b62fd6f177..2102fe71b1 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py @@ -186,6 +186,9 @@ class TestListMCPServers: return_value=mock_servers_with_health ) + for idx, server in enumerate(mock_servers_with_health): + server.credentials = {"auth_value": f"secret_{idx}"} + with patch( "litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager", mock_manager, @@ -205,6 +208,7 @@ class TestListMCPServers: # Verify results assert len(result) == 2 + assert all(server.credentials is None for server in result) # Check that both config servers are returned server_ids = [server.server_id for server in result] @@ -315,6 +319,9 @@ class TestListMCPServers: return_value=mock_servers_with_health ) + for idx, server in enumerate(mock_servers_with_health): + server.credentials = {"auth_value": f"secret_{idx}"} + with patch( "litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager", mock_manager, @@ -334,6 +341,7 @@ class TestListMCPServers: # Verify results assert len(result) == 4 + assert all(server.credentials is None for server in result) # Check that both DB and config servers are returned server_ids = [server.server_id for server in result] @@ -428,6 +436,9 @@ class TestListMCPServers: return_value=mock_servers_with_health ) + for idx, server in enumerate(mock_servers_with_health): + server.credentials = {"auth_value": f"secret_{idx}"} + with patch( "litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager", mock_manager, @@ -447,6 +458,7 @@ class TestListMCPServers: # Verify results - should only return servers user has access to assert len(result) == 2 + assert all(server.credentials is None for server in result) # Check that only allowed servers are returned server_ids = [server.server_id for server in result] @@ -464,6 +476,51 @@ class TestListMCPServers: assert server.url == "https://actions.zapier.com/mcp/sse" + @pytest.mark.asyncio + async def test_fetch_single_mcp_server_redacts_credentials(self): + mock_server = generate_mock_mcp_server_db_record( + server_id="server-1", alias="Server 1" + ) + mock_server.credentials = {"auth_value": "top-secret"} + + mock_prisma_client = MagicMock() + mock_health_result = { + "status": "healthy", + "last_health_check": datetime.now().isoformat(), + "error": None, + } + + mock_user_auth = generate_mock_user_api_key_auth( + user_role=LitellmUserRoles.PROXY_ADMIN + ) + + with patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw", + return_value=mock_prisma_client, + ), patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.get_mcp_server", + AsyncMock(return_value=mock_server), + ), patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager.health_check_server", + AsyncMock(return_value=mock_health_result), + ), patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints._user_has_admin_view", + return_value=True, + ): + from litellm.proxy.management_endpoints.mcp_management_endpoints import ( + fetch_mcp_server, + ) + + result = await fetch_mcp_server( + server_id="server-1", user_api_key_dict=mock_user_auth + ) + + assert result.server_id == "server-1" + assert result.credentials is None + assert mock_server.credentials == {"auth_value": "top-secret"} + assert result.status == "healthy" + + class TestMCPHealthCheckEndpoints: """Test MCP health check endpoints""" @@ -721,6 +778,8 @@ class TestMCPHealthCheckEndpoints: return_value=[mock_server] ) + mock_server.credentials = {"auth_value": "secret"} + mock_user_auth = generate_mock_user_api_key_auth( user_role=LitellmUserRoles.PROXY_ADMIN ) @@ -749,3 +808,4 @@ class TestMCPHealthCheckEndpoints: assert server.status == "healthy" assert server.last_health_check is not None assert server.health_check_error is None + assert server.credentials is None diff --git a/ui/litellm-dashboard/src/components/mcp_tools/ToolTestPanel.tsx b/ui/litellm-dashboard/src/components/mcp_tools/ToolTestPanel.tsx index 2cabfff581..e0edb506fc 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/ToolTestPanel.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/ToolTestPanel.tsx @@ -7,8 +7,6 @@ import NotificationsManager from "../molecules/notifications_manager"; export function ToolTestPanel({ tool, - needsAuth, - authValue, onSubmit, isLoading, result, @@ -16,8 +14,6 @@ export function ToolTestPanel({ onClose, }: { tool: MCPTool; - needsAuth: boolean; - authValue?: string | null; onSubmit: (args: Record) => void; isLoading: boolean; result: any | null; diff --git a/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.tsx b/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.tsx index 81842b5d74..231f29d409 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.tsx @@ -3,7 +3,7 @@ import { Modal, Tooltip, Form, Select } from "antd"; import { InfoCircleOutlined } from "@ant-design/icons"; import { Button, TextInput } from "@tremor/react"; import { createMCPServer } from "../networking"; -import { MCPServer, MCPServerCostInfo } from "./types"; +import { AUTH_TYPE, MCPServer, MCPServerCostInfo } from "./types"; import MCPServerCostConfig from "./mcp_server_cost_config"; import MCPConnectionStatus from "./mcp_connection_status"; import MCPToolConfiguration from "./mcp_tool_configuration"; @@ -25,6 +25,12 @@ interface CreateMCPServerProps { availableAccessGroups: string[]; } +const AUTH_TYPES_REQUIRING_AUTH_VALUE = [ + AUTH_TYPE.API_KEY, + AUTH_TYPE.BEARER_TOKEN, + AUTH_TYPE.BASIC, +]; + const CreateMCPServer: React.FC = ({ userRole, accessToken, @@ -43,6 +49,10 @@ const CreateMCPServer: React.FC = ({ const [transportType, setTransportType] = useState(""); const [searchValue, setSearchValue] = useState(""); const [urlWarning, setUrlWarning] = useState(""); + const authType = formValues.auth_type as string | undefined; + const shouldShowAuthValueField = authType + ? AUTH_TYPES_REQUIRING_AUTH_VALUE.includes(authType) + : false; // Function to check URL format based on transport type const checkUrlFormat = (url: string, transport: string) => { @@ -63,7 +73,12 @@ const CreateMCPServer: React.FC = ({ const handleCreate = async (values: Record) => { setIsLoading(true); try { - const { static_headers: staticHeadersList, stdio_config: rawStdioConfig, ...restValues } = values; + const { + static_headers: staticHeadersList, + stdio_config: rawStdioConfig, + credentials: credentialValues, + ...restValues + } = values; // Transform access groups into objects with name property const accessGroups = restValues.mcp_access_groups; @@ -79,6 +94,26 @@ const CreateMCPServer: React.FC = ({ }, {}) : {} as Record; + const credentialsPayload = + credentialValues && typeof credentialValues === "object" + ? Object.entries(credentialValues).reduce((acc: Record, [key, value]) => { + if (value === undefined || value === null || value === "") { + return acc; + } + if (key === "scopes") { + if (Array.isArray(value)) { + const filteredScopes = value.filter((scope) => scope != null && scope !== ""); + if (filteredScopes.length > 0) { + acc[key] = filteredScopes; + } + } + } else { + acc[key] = value; + } + return acc; + }, {}) + : undefined; + // Process stdio configuration if present let stdioFields = {}; if (rawStdioConfig && transportType === "stdio") { @@ -135,6 +170,18 @@ const CreateMCPServer: React.FC = ({ static_headers: staticHeaders, }; + payload.static_headers = staticHeaders; + const includeCredentials = + restValues.auth_type && AUTH_TYPES_REQUIRING_AUTH_VALUE.includes(restValues.auth_type); + + if ( + includeCredentials && + credentialsPayload && + Object.keys(credentialsPayload).length > 0 + ) { + payload.credentials = credentialsPayload; + } + console.log(`Payload: ${JSON.stringify(payload)}`); if (accessToken != null) { @@ -393,6 +440,27 @@ const CreateMCPServer: React.FC = ({ )} + {transportType !== "stdio" && shouldShowAuthValueField && ( + + Authentication Value + + + + + } + name={["credentials", "auth_value"]} + rules={[{ required: true, message: "Please enter the authentication value" }]} + > + + + )} + {/* Stdio Configuration - only show for stdio transport */} diff --git a/ui/litellm-dashboard/src/components/mcp_tools/mcp_auth_storage.ts b/ui/litellm-dashboard/src/components/mcp_tools/mcp_auth_storage.ts deleted file mode 100644 index 9f81e58fe6..0000000000 --- a/ui/litellm-dashboard/src/components/mcp_tools/mcp_auth_storage.ts +++ /dev/null @@ -1,106 +0,0 @@ -// Utility functions for managing MCP server authentication tokens in localStorage - -const MCP_AUTH_STORAGE_KEY = "litellm_mcp_auth_tokens"; - -export interface MCPAuthToken { - serverId: string; - serverAlias?: string; - authValue: string; - authType: string; - timestamp: number; -} - -export interface MCPAuthStorage { - [serverId: string]: MCPAuthToken; -} - -/** - * Get all stored MCP authentication tokens - */ -export const getMCPAuthTokens = (): MCPAuthStorage => { - try { - const stored = localStorage.getItem(MCP_AUTH_STORAGE_KEY); - return stored ? JSON.parse(stored) : {}; - } catch (error) { - console.error("Error reading MCP auth tokens from localStorage:", error); - return {}; - } -}; - -/** - * Get authentication token for a specific MCP server - */ -export const getMCPAuthToken = (serverId: string, serverAlias?: string): string | null => { - try { - const tokens = getMCPAuthTokens(); - const token = tokens[serverId]; - - // If token exists, check if serverAlias matches (both can be undefined) - if (token && token.serverAlias === serverAlias) { - return token.authValue; - } - - // If no serverAlias was provided and token exists without serverAlias, return it - if (token && !serverAlias && !token.serverAlias) { - return token.authValue; - } - - return null; - } catch (error) { - console.error("Error getting MCP auth token:", error); - return null; - } -}; - -/** - * Store authentication token for an MCP server - */ -export const setMCPAuthToken = (serverId: string, authValue: string, authType: string, serverAlias?: string): void => { - try { - const tokens = getMCPAuthTokens(); - - tokens[serverId] = { - serverId, - serverAlias, - authValue, - authType, - timestamp: Date.now(), - }; - - localStorage.setItem(MCP_AUTH_STORAGE_KEY, JSON.stringify(tokens)); - } catch (error) { - console.error("Error storing MCP auth token:", error); - } -}; - -/** - * Remove authentication token for an MCP server - */ -export const removeMCPAuthToken = (serverId: string): void => { - try { - const tokens = getMCPAuthTokens(); - delete tokens[serverId]; - localStorage.setItem(MCP_AUTH_STORAGE_KEY, JSON.stringify(tokens)); - } catch (error) { - console.error("Error removing MCP auth token:", error); - } -}; - -/** - * Clear all MCP authentication tokens (useful for logout) - */ -export const clearMCPAuthTokens = (): void => { - try { - localStorage.removeItem(MCP_AUTH_STORAGE_KEY); - } catch (error) { - console.error("Error clearing MCP auth tokens:", error); - } -}; - -/** - * Check if a token exists for a server - */ -export const hasMCPAuthToken = (serverId: string, serverAlias?: string): boolean => { - const token = getMCPAuthToken(serverId, serverAlias); - return token !== null; -}; diff --git a/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_edit.tsx b/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_edit.tsx index 54e1f9f957..bbfe0d1332 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_edit.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_edit.tsx @@ -1,7 +1,8 @@ import React, { useState, useEffect } from "react"; -import { Form, Select, Button as AntdButton } from "antd"; +import { Form, Select, Button as AntdButton, Tooltip } from "antd"; +import { InfoCircleOutlined } from "@ant-design/icons"; import { Button, TextInput, TabGroup, TabList, Tab, TabPanels, TabPanel } from "@tremor/react"; -import { MCPServer, MCPServerCostInfo } from "./types"; +import { AUTH_TYPE, MCPServer, MCPServerCostInfo } from "./types"; import { updateMCPServer, testMCPToolsListRequest } from "../networking"; import MCPServerCostConfig from "./mcp_server_cost_config"; import MCPPermissionManagement from "./MCPPermissionManagement"; @@ -17,6 +18,12 @@ interface MCPServerEditProps { availableAccessGroups: string[]; } +const AUTH_TYPES_REQUIRING_AUTH_VALUE = [ + AUTH_TYPE.API_KEY, + AUTH_TYPE.BEARER_TOKEN, + AUTH_TYPE.BASIC, +]; + const MCPServerEdit: React.FC = ({ mcpServer, accessToken, @@ -31,6 +38,10 @@ const MCPServerEdit: React.FC = ({ const [searchValue, setSearchValue] = useState(""); const [aliasManuallyEdited, setAliasManuallyEdited] = useState(false); const [allowedTools, setAllowedTools] = useState([]); + const authType = Form.useWatch("auth_type", form) as string | undefined; + const shouldShowAuthValueField = authType + ? AUTH_TYPES_REQUIRING_AUTH_VALUE.includes(authType) + : false; const initialStaticHeaders = React.useMemo(() => { if (!mcpServer.static_headers) { @@ -148,7 +159,11 @@ const MCPServerEdit: React.FC = ({ if (!accessToken) return; try { // Ensure access groups is always a string array - const { static_headers: staticHeadersList, ...restValues } = values; + const { + static_headers: staticHeadersList, + credentials: credentialValues, + ...restValues + } = values; const accessGroups = (restValues.mcp_access_groups || []).map((g: any) => typeof g === "string" ? g : g.name || String(g), @@ -165,6 +180,26 @@ const MCPServerEdit: React.FC = ({ }, {}) : {} as Record; + const credentialsPayload = + credentialValues && typeof credentialValues === "object" + ? Object.entries(credentialValues).reduce((acc: Record, [key, value]) => { + if (value === undefined || value === null || value === "") { + return acc; + } + if (key === "scopes") { + if (Array.isArray(value)) { + const filteredScopes = value.filter((scope) => scope != null && scope !== ""); + if (filteredScopes.length > 0) { + acc[key] = filteredScopes; + } + } + } else { + acc[key] = value; + } + return acc; + }, {}) + : undefined; + // Prepare the payload with cost configuration and permission fields const payload = { ...restValues, @@ -183,6 +218,17 @@ const MCPServerEdit: React.FC = ({ static_headers: staticHeaders, }; + const includeCredentials = + restValues.auth_type && AUTH_TYPES_REQUIRING_AUTH_VALUE.includes(restValues.auth_type); + + if ( + includeCredentials && + credentialsPayload && + Object.keys(credentialsPayload).length > 0 + ) { + payload.credentials = credentialsPayload; + } + const updated = await updateMCPServer(accessToken, payload); NotificationsManager.success("MCP Server updated successfully"); onSuccess(updated); @@ -250,6 +296,34 @@ const MCPServerEdit: React.FC = ({ + {shouldShowAuthValueField && ( + + Authentication Value + + + + + } + name={["credentials", "auth_value"]} + rules={[ + { + validator: (_, value) => + value && typeof value === "string" && value.trim() === "" + ? Promise.reject(new Error("Authentication value cannot be empty")) + : Promise.resolve(), + }, + ]} + > + + + )} + {/* Permission Management / Access Control Section */}
void; - onCancel: () => void; - authType?: string | null; -}; - -export const AuthModal = ({ visible, onOk, onCancel, authType }: AuthModalProps) => { - const [form] = Form.useForm(); - - // Handler for modal OK - const handleOk = () => { - form.validateFields().then((values) => { - if (authType === AUTH_TYPE.BASIC) { - onOk(`${values.username.trim()}:${values.password.trim()}`); - } else { - onOk(values.authValue.trim()); - } - }); - }; - - let content; - if (authType === AUTH_TYPE.API_KEY || authType === AUTH_TYPE.BEARER_TOKEN) { - const label = authType === AUTH_TYPE.API_KEY ? "API Key" : "Bearer Token"; - content = ( - - - - ); - } else if (authType === AUTH_TYPE.BASIC) { - content = ( - <> - - - - - - - - ); - } - - return ( - -
- {content} -
-
- ); -}; - -const AuthSection = ({ - authType, - onAuthSubmit, - onClearAuth, - hasAuth, -}: { - authType: string | null | undefined; - onAuthSubmit: (value: string) => void; - onClearAuth: () => void; - hasAuth: boolean; -}) => { - const [modalVisible, setModalVisible] = useState(false); - - const handleAddAuth = () => setModalVisible(true); - - const handleModalOk = (authValue: string) => { - onAuthSubmit(authValue); - setModalVisible(false); - }; - - const handleModalCancel = () => setModalVisible(false); - - const handleClearAuth = () => { - onClearAuth(); - }; - - return ( -
-
- Authentication {hasAuth ? "✓" : ""} -
- {hasAuth && ( - - )} - -
-
- - {hasAuth ? "Authentication configured and saved locally" : "Some tools may require authentication"} - - -
- ); -}; - const MCPToolsViewer = ({ serverId, accessToken, @@ -125,47 +18,20 @@ const MCPToolsViewer = ({ userID, serverAlias, // Add serverAlias prop }: MCPToolsViewerProps) => { - const [mcpAuthValue, setMcpAuthValue] = useState(""); const [selectedTool, setSelectedTool] = useState(null); const [toolResult, setToolResult] = useState(null); const [toolError, setToolError] = useState(null); - // Load stored auth token on component mount - useEffect(() => { - if (mcpServerHasAuth(auth_type)) { - const storedAuthValue = getMCPAuthToken(serverId, serverAlias || undefined); - if (storedAuthValue) { - setMcpAuthValue(storedAuthValue); - } - } - }, [serverId, serverAlias, auth_type]); - - // Function to handle auth submission with localStorage persistence - const handleAuthSubmit = (authValue: string) => { - setMcpAuthValue(authValue); - if (authValue && mcpServerHasAuth(auth_type)) { - setMCPAuthToken(serverId, authValue, auth_type || "none", serverAlias || undefined); - NotificationsManager.success("Authentication token saved locally"); - } - }; - - // Function to clear auth token - const handleClearAuth = () => { - setMcpAuthValue(""); - removeMCPAuthToken(serverId); - NotificationsManager.info("Authentication token cleared"); - }; - // Query to fetch MCP tools const { data: mcpToolsResponse, isLoading: isLoadingTools, error: mcpToolsError, } = useQuery({ - queryKey: ["mcpTools", serverId, mcpAuthValue, serverAlias], + queryKey: ["mcpTools", serverId], queryFn: () => { if (!accessToken) throw new Error("Access Token required"); - return listMCPTools(accessToken, serverId, mcpAuthValue, serverAlias || undefined); + return listMCPTools(accessToken, serverId); }, enabled: !!accessToken, staleTime: 30000, // Consider data fresh for 30 seconds @@ -173,7 +39,7 @@ const MCPToolsViewer = ({ // Mutation for calling a tool const { mutate: executeTool, isPending: isCallingTool } = useMutation({ - mutationFn: async (args: { tool: MCPTool; arguments: Record; authValue: string }) => { + mutationFn: async (args: { tool: MCPTool; arguments: Record }) => { if (!accessToken) throw new Error("Access Token required"); try { @@ -181,8 +47,6 @@ const MCPToolsViewer = ({ accessToken, args.tool.name, args.arguments, - args.authValue, - serverAlias || undefined, ); return result; } catch (error) { @@ -200,7 +64,6 @@ const MCPToolsViewer = ({ }); const toolsData = mcpToolsResponse?.tools || []; - const hasAuth = mcpAuthValue !== ""; return (
@@ -317,44 +180,6 @@ const MCPToolsViewer = ({
)}
- - {/* Authentication Section - Below tools list */} - {mcpServerHasAuth(auth_type) && ( -
- {!hasAuth ? ( - /* Prominent display when auth required but not provided */ -
-
- - Authentication Required -
- - This MCP server requires authentication. You must add your credentials below to access the - tools. - - -
- ) : ( - /* Subtle display when already authenticated */ - <> - - Authentication - - - - )} -
- )} @@ -379,10 +204,8 @@ const MCPToolsViewer = ({
{ - executeTool({ tool: selectedTool, arguments: args, authValue: mcpAuthValue }); + executeTool({ tool: selectedTool, arguments: args }); }} result={toolResult} error={toolError} diff --git a/ui/litellm-dashboard/src/components/mcp_tools/types.tsx b/ui/litellm-dashboard/src/components/mcp_tools/types.tsx index 29bb19d294..3d5360207a 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/types.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/types.tsx @@ -34,10 +34,6 @@ export const handleAuth = (authType?: string | null): string => { return authType; }; -export const mcpServerHasAuth = (authType?: string | null): boolean => { - return handleAuth(authType) !== AUTH_TYPE.NONE; -}; - // Define the structure for tool input schema properties export interface InputSchemaProperty { type: string; diff --git a/ui/litellm-dashboard/src/components/navbar.tsx b/ui/litellm-dashboard/src/components/navbar.tsx index 0316f3bf6a..667691168c 100644 --- a/ui/litellm-dashboard/src/components/navbar.tsx +++ b/ui/litellm-dashboard/src/components/navbar.tsx @@ -15,7 +15,6 @@ import { import { clearTokenCookies } from "@/utils/cookieUtils"; import { fetchProxySettings } from "@/utils/proxyUtils"; import { useTheme } from "@/contexts/ThemeContext"; -import { clearMCPAuthTokens } from "./mcp_tools/mcp_auth_storage"; import useFeatureFlags from "@/hooks/useFeatureFlags"; interface NavbarProps { @@ -71,7 +70,6 @@ const Navbar: React.FC = ({ const handleLogout = () => { clearTokenCookies(); - clearMCPAuthTokens(); // Clear MCP auth tokens on logout window.location.href = logoutUrl; }; diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index 5e7bb698cb..a8cb05310c 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -5750,7 +5750,7 @@ export const testSearchToolConnection = async (accessToken: string, litellmParam } }; -export const listMCPTools = async (accessToken: string, serverId: string, authValue?: string, serverAlias?: string) => { +export const listMCPTools = async (accessToken: string, serverId: string) => { try { // Construct base URL let url = proxyBaseUrl @@ -5764,14 +5764,6 @@ export const listMCPTools = async (accessToken: string, serverId: string, authVa "Content-Type": "application/json", }; - // Use new server-specific auth header format if serverAlias is provided - if (serverAlias && authValue) { - headers[`x-mcp-${serverAlias}-authorization`] = authValue; - } else if (authValue) { - // Fall back to deprecated x-mcp-auth header for backward compatibility - headers[MCP_AUTH_HEADER] = authValue; - } - const response = await fetch(url, { method: "GET", headers, @@ -5805,9 +5797,7 @@ export const listMCPTools = async (accessToken: string, serverId: string, authVa export const callMCPTool = async ( accessToken: string, toolName: string, - toolArguments: Record, - authValue: string, - serverAlias?: string, + toolArguments: Record ) => { try { // Construct base URL @@ -5820,14 +5810,6 @@ export const callMCPTool = async ( "Content-Type": "application/json", }; - // Use new server-specific auth header format if serverAlias is provided - if (serverAlias) { - headers[`x-mcp-${serverAlias}-authorization`] = authValue; - } else { - // Fall back to deprecated x-mcp-auth header for backward compatibility - headers[MCP_AUTH_HEADER] = authValue; - } - const response = await fetch(url, { method: "POST", headers, diff --git a/ui/litellm-dashboard/src/hooks/useTestMCPConnection.tsx b/ui/litellm-dashboard/src/hooks/useTestMCPConnection.tsx index 4c1f88ea2b..0cad2fe81a 100644 --- a/ui/litellm-dashboard/src/hooks/useTestMCPConnection.tsx +++ b/ui/litellm-dashboard/src/hooks/useTestMCPConnection.tsx @@ -9,6 +9,12 @@ interface MCPServerConfig { auth_type?: string; mcp_info?: any; static_headers?: Record; + credentials?: { + auth_value?: string; + client_id?: string; + client_secret?: string; + scopes?: string[]; + }; } interface UseTestMCPConnectionProps { @@ -41,6 +47,7 @@ export const useTestMCPConnection = ({ const canFetchTools = !!(formValues.url && formValues.transport && formValues.auth_type && accessToken); const staticHeadersKey = JSON.stringify(formValues.static_headers ?? {}); + const credentialsKey = JSON.stringify(formValues.credentials ?? {}); const fetchTools = async () => { if (!accessToken || !formValues.url) { @@ -74,6 +81,29 @@ export const useTestMCPConnection = ({ ) : {} as Record; + const credentials = + formValues.credentials && typeof formValues.credentials === "object" + ? Object.entries(formValues.credentials).reduce( + (acc: Record, [key, value]) => { + if (value === undefined || value === null || value === "") { + return acc; + } + if (key === "scopes") { + if (Array.isArray(value)) { + const normalizedScopes = value.filter((scope) => scope != null && scope !== ""); + if (normalizedScopes.length > 0) { + acc[key] = normalizedScopes; + } + } + } else { + acc[key] = value; + } + return acc; + }, + {}, + ) + : undefined; + const mcpServerConfig: MCPServerConfig = { server_id: formValues.server_id || "", server_name: formValues.server_name || "", @@ -84,6 +114,10 @@ export const useTestMCPConnection = ({ static_headers: staticHeaders, }; + if (credentials && Object.keys(credentials).length > 0) { + mcpServerConfig.credentials = credentials; + } + const toolsResponse = await testMCPToolsListRequest(accessToken, mcpServerConfig); if (toolsResponse.tools && !toolsResponse.error) { @@ -126,7 +160,16 @@ export const useTestMCPConnection = ({ clearTools(); } // eslint-disable-next-line react-hooks/exhaustive-deps - }, [formValues.url, formValues.transport, formValues.auth_type, accessToken, enabled, canFetchTools, staticHeadersKey]); + }, [ + formValues.url, + formValues.transport, + formValues.auth_type, + accessToken, + enabled, + canFetchTools, + staticHeadersKey, + credentialsKey, + ]); return { tools,