mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 00:48:01 +00:00
Feat/persist mcp credentials in db (#16308)
* feat: persist mcp credentials in db * feat: remove Auth Value field from MCP Tool Testing Playground * fix: test
This commit is contained in:
+2
@@ -0,0 +1,2 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "LiteLLM_MCPServerTable" ADD COLUMN "credentials" JSONB DEFAULT '{}';
|
||||
@@ -174,6 +174,7 @@ model LiteLLM_MCPServerTable {
|
||||
url String?
|
||||
transport String @default("sse")
|
||||
auth_type String?
|
||||
credentials Json? @default("{}")
|
||||
created_at DateTime? @default(now()) @map("created_at")
|
||||
created_by String?
|
||||
updated_at DateTime? @default(now()) @updatedAt @map("updated_at")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm._uuid import uuid
|
||||
@@ -11,7 +11,12 @@ from litellm.proxy._types import (
|
||||
UpdateMCPServerRequest,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||
_get_salt_key,
|
||||
encrypt_value_helper,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.types.mcp import MCPCredentials
|
||||
|
||||
|
||||
def _prepare_mcp_server_data(
|
||||
@@ -35,6 +40,14 @@ def _prepare_mcp_server_data(
|
||||
if "alias" not in data_dict:
|
||||
data_dict["alias"] = getattr(data, "alias", None)
|
||||
|
||||
# Handle credentials serialization
|
||||
credentials = data_dict.get("credentials")
|
||||
if credentials is not None:
|
||||
data_dict["credentials"] = encrypt_credentials(
|
||||
credentials=credentials, encryption_key=_get_salt_key()
|
||||
)
|
||||
data_dict["credentials"] = safe_dumps(data_dict["credentials"])
|
||||
|
||||
# Handle static_headers serialization
|
||||
if data.static_headers is not None:
|
||||
data_dict["static_headers"] = safe_dumps(data.static_headers)
|
||||
@@ -52,6 +65,30 @@ def _prepare_mcp_server_data(
|
||||
return data_dict
|
||||
|
||||
|
||||
def encrypt_credentials(
|
||||
credentials: MCPCredentials, encryption_key: Optional[str]
|
||||
) -> MCPCredentials:
|
||||
auth_value = credentials.get("auth_value")
|
||||
if auth_value is not None:
|
||||
credentials["auth_value"] = encrypt_value_helper(
|
||||
value=auth_value,
|
||||
new_encryption_key=encryption_key,
|
||||
)
|
||||
client_id = credentials.get("client_id")
|
||||
if client_id is not None:
|
||||
credentials["client_id"] = encrypt_value_helper(
|
||||
value=client_id,
|
||||
new_encryption_key=encryption_key,
|
||||
)
|
||||
client_secret = credentials.get("client_secret")
|
||||
if client_secret is not None:
|
||||
credentials["client_secret"] = encrypt_value_helper(
|
||||
value=client_secret,
|
||||
new_encryption_key=encryption_key,
|
||||
)
|
||||
return credentials
|
||||
|
||||
|
||||
async def get_all_mcp_servers(
|
||||
prisma_client: PrismaClient,
|
||||
) -> List[LiteLLM_MCPServerTable]:
|
||||
@@ -303,3 +340,32 @@ async def update_mcp_server(
|
||||
)
|
||||
|
||||
return updated_mcp_server
|
||||
|
||||
|
||||
async def rotate_mcp_server_credentials_master_key(
|
||||
prisma_client: PrismaClient, touched_by: str, new_master_key: str
|
||||
):
|
||||
mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many()
|
||||
|
||||
for mcp_server in mcp_servers:
|
||||
credentials = mcp_server.credentials
|
||||
if not credentials:
|
||||
continue
|
||||
|
||||
credentials_copy = dict(credentials)
|
||||
encrypted_credentials = encrypt_credentials(
|
||||
credentials=cast(MCPCredentials, credentials_copy),
|
||||
encryption_key=new_master_key,
|
||||
)
|
||||
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
|
||||
serialized_credentials = safe_dumps(encrypted_credentials)
|
||||
|
||||
await prisma_client.db.litellm_mcpservertable.update(
|
||||
where={"server_id": mcp_server.server_id},
|
||||
data={
|
||||
"credentials": serialized_credentials,
|
||||
"updated_by": touched_by,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -38,6 +38,9 @@ from litellm.proxy._types import (
|
||||
MCPTransportType,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||
decrypt_value_helper,
|
||||
)
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
from litellm.types.mcp import MCPAuth, MCPStdioConfig
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer
|
||||
@@ -400,6 +403,20 @@ class MCPServerManager:
|
||||
static_headers_dict = _deserialize_json_dict(
|
||||
getattr(mcp_server, "static_headers", None)
|
||||
)
|
||||
credentials_dict = _deserialize_json_dict(
|
||||
getattr(mcp_server, "credentials", None)
|
||||
)
|
||||
|
||||
encrypted_auth_value: Optional[str] = None
|
||||
if credentials_dict:
|
||||
encrypted_auth_value = credentials_dict.get("auth_value")
|
||||
|
||||
auth_value: Optional[str] = None
|
||||
if encrypted_auth_value:
|
||||
auth_value = decrypt_value_helper(
|
||||
value=encrypted_auth_value,
|
||||
key="auth_value",
|
||||
)
|
||||
# Use alias for name if present, else server_name
|
||||
name_for_prefix = (
|
||||
mcp_server.alias or mcp_server.server_name or mcp_server.server_id
|
||||
@@ -422,6 +439,7 @@ class MCPServerManager:
|
||||
url=mcp_server.url,
|
||||
transport=cast(MCPTransportType, mcp_server.transport),
|
||||
auth_type=cast(MCPAuthType, mcp_server.auth_type),
|
||||
authentication_token=auth_value,
|
||||
mcp_info=mcp_info,
|
||||
extra_headers=getattr(mcp_server, "extra_headers", None),
|
||||
static_headers=static_headers_dict,
|
||||
|
||||
+32
-1
@@ -17,7 +17,13 @@ from typing_extensions import Required, TypedDict
|
||||
from litellm._uuid import uuid
|
||||
from litellm.types.integrations.slack_alerting import AlertType
|
||||
from litellm.types.llms.openai import AllMessageValues, OpenAIFileObject
|
||||
from litellm.types.mcp import MCPAuthType, MCPTransport, MCPTransportType
|
||||
from litellm.types.mcp import (
|
||||
MCPAuth,
|
||||
MCPAuthType,
|
||||
MCPCredentials,
|
||||
MCPTransport,
|
||||
MCPTransportType,
|
||||
)
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPInfo
|
||||
from litellm.types.router import RouterErrors, UpdateRouterConfig
|
||||
from litellm.types.secret_managers.main import KeyManagementSystem
|
||||
@@ -959,6 +965,7 @@ class NewMCPServerRequest(LiteLLMPydanticObjectBase):
|
||||
description: Optional[str] = None
|
||||
transport: MCPTransportType = MCPTransport.sse
|
||||
auth_type: Optional[MCPAuthType] = None
|
||||
credentials: Optional[MCPCredentials] = None
|
||||
url: Optional[str] = None
|
||||
mcp_info: Optional[MCPInfo] = None
|
||||
mcp_access_groups: List[str] = Field(default_factory=list)
|
||||
@@ -985,6 +992,28 @@ class NewMCPServerRequest(LiteLLMPydanticObjectBase):
|
||||
raise ValueError("url is required for HTTP/SSE transport")
|
||||
return values
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_credentials_requirements(cls, values):
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
|
||||
auth_type = values.get("auth_type")
|
||||
if auth_type in {MCPAuth.api_key, MCPAuth.bearer_token, MCPAuth.basic}:
|
||||
credentials = values.get("credentials")
|
||||
auth_value = None
|
||||
if isinstance(credentials, dict):
|
||||
auth_value = credentials.get("auth_value")
|
||||
elif hasattr(credentials, "get"):
|
||||
auth_value = credentials.get("auth_value") # type: ignore[attr-defined]
|
||||
|
||||
if not auth_value:
|
||||
raise ValueError(
|
||||
"auth_value is required when auth_type is api_key, bearer_token, or basic"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class UpdateMCPServerRequest(LiteLLMPydanticObjectBase):
|
||||
server_id: str
|
||||
@@ -993,6 +1022,7 @@ class UpdateMCPServerRequest(LiteLLMPydanticObjectBase):
|
||||
description: Optional[str] = None
|
||||
transport: MCPTransportType = MCPTransport.sse
|
||||
auth_type: Optional[MCPAuthType] = None
|
||||
credentials: Optional[MCPCredentials] = None
|
||||
url: Optional[str] = None
|
||||
mcp_info: Optional[MCPInfo] = None
|
||||
mcp_access_groups: List[str] = Field(default_factory=list)
|
||||
@@ -1028,6 +1058,7 @@ class LiteLLM_MCPServerTable(LiteLLMPydanticObjectBase):
|
||||
url: Optional[str] = None
|
||||
transport: MCPTransportType
|
||||
auth_type: Optional[MCPAuthType] = None
|
||||
credentials: Optional[MCPCredentials] = None
|
||||
created_at: Optional[datetime] = None
|
||||
created_by: Optional[str] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
@@ -24,8 +24,15 @@ import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.caching import DualCache
|
||||
from litellm.constants import LENGTH_OF_LITELLM_GENERATED_KEY, UI_SESSION_TOKEN_TEAM_ID
|
||||
from litellm.constants import (
|
||||
LENGTH_OF_LITELLM_GENERATED_KEY,
|
||||
LITELLM_PROXY_ADMIN_NAME,
|
||||
UI_SESSION_TOKEN_TEAM_ID,
|
||||
)
|
||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||
from litellm.proxy._experimental.mcp_server.db import (
|
||||
rotate_mcp_server_credentials_master_key,
|
||||
)
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy._types import LiteLLM_VerificationToken
|
||||
from litellm.proxy.auth.auth_checks import (
|
||||
@@ -640,9 +647,9 @@ async def _common_key_generation_helper( # noqa: PLR0915
|
||||
request_type="key", **data_json, table_name="key"
|
||||
)
|
||||
|
||||
response["soft_budget"] = (
|
||||
data.soft_budget
|
||||
) # include the user-input soft budget in the response
|
||||
response[
|
||||
"soft_budget"
|
||||
] = data.soft_budget # include the user-input soft budget in the response
|
||||
|
||||
response = GenerateKeyResponse(**response)
|
||||
|
||||
@@ -1299,7 +1306,6 @@ async def prepare_key_update_data(
|
||||
data: Union[UpdateKeyRequest, RegenerateKeyRequest],
|
||||
existing_key_row: LiteLLM_VerificationToken,
|
||||
):
|
||||
|
||||
data_json: dict = data.model_dump(exclude_unset=True)
|
||||
data_json.pop("key", None)
|
||||
data_json.pop("new_key", None)
|
||||
@@ -2357,10 +2363,10 @@ async def delete_verification_tokens(
|
||||
try:
|
||||
if prisma_client:
|
||||
tokens = [_hash_token_if_needed(token=key) for key in tokens]
|
||||
_keys_being_deleted: List[LiteLLM_VerificationToken] = (
|
||||
await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
where={"token": {"in": tokens}}
|
||||
)
|
||||
_keys_being_deleted: List[
|
||||
LiteLLM_VerificationToken
|
||||
] = await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
where={"token": {"in": tokens}}
|
||||
)
|
||||
|
||||
if len(_keys_being_deleted) == 0:
|
||||
@@ -2468,9 +2474,9 @@ async def _rotate_master_key(
|
||||
from litellm.proxy.proxy_server import proxy_config
|
||||
|
||||
try:
|
||||
models: Optional[List] = (
|
||||
await prisma_client.db.litellm_proxymodeltable.find_many()
|
||||
)
|
||||
models: Optional[
|
||||
List
|
||||
] = await prisma_client.db.litellm_proxymodeltable.find_many()
|
||||
except Exception:
|
||||
models = None
|
||||
# 2. process model table
|
||||
@@ -2524,6 +2530,13 @@ async def _rotate_master_key(
|
||||
data={"param_value": jsonify_object(encrypted_env_vars)},
|
||||
)
|
||||
|
||||
# 4. process MCP server table
|
||||
await rotate_mcp_server_credentials_master_key(
|
||||
prisma_client=prisma_client,
|
||||
touched_by=user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME,
|
||||
new_master_key=new_master_key,
|
||||
)
|
||||
|
||||
|
||||
def get_new_token(data: Optional[RegenerateKeyRequest]) -> str:
|
||||
if data and data.new_key is not None:
|
||||
@@ -2781,11 +2794,11 @@ async def validate_key_list_check(
|
||||
param="user_id",
|
||||
code=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
complete_user_info_db_obj: Optional[BaseModel] = (
|
||||
await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_api_key_dict.user_id},
|
||||
include={"organization_memberships": True},
|
||||
)
|
||||
complete_user_info_db_obj: Optional[
|
||||
BaseModel
|
||||
] = await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_api_key_dict.user_id},
|
||||
include={"organization_memberships": True},
|
||||
)
|
||||
|
||||
if complete_user_info_db_obj is None:
|
||||
@@ -2871,10 +2884,10 @@ async def get_admin_team_ids(
|
||||
if complete_user_info is None:
|
||||
return []
|
||||
# Get all teams that user is an admin of
|
||||
teams: Optional[List[BaseModel]] = (
|
||||
await prisma_client.db.litellm_teamtable.find_many(
|
||||
where={"team_id": {"in": complete_user_info.teams}}
|
||||
)
|
||||
teams: Optional[
|
||||
List[BaseModel]
|
||||
] = await prisma_client.db.litellm_teamtable.find_many(
|
||||
where={"team_id": {"in": complete_user_info.teams}}
|
||||
)
|
||||
if teams is None:
|
||||
return []
|
||||
|
||||
@@ -58,6 +58,24 @@ if MCP_AVAILABLE:
|
||||
from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view
|
||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||
|
||||
def _redact_mcp_credentials(
|
||||
mcp_server: LiteLLM_MCPServerTable,
|
||||
) -> LiteLLM_MCPServerTable:
|
||||
"""Return a copy of the MCP server object with credentials removed."""
|
||||
|
||||
try:
|
||||
redacted_server = mcp_server.model_copy(deep=True)
|
||||
except AttributeError:
|
||||
redacted_server = mcp_server.copy(deep=True) # type: ignore[attr-defined]
|
||||
|
||||
redacted_server.credentials = None
|
||||
return redacted_server
|
||||
|
||||
def _redact_mcp_credentials_list(
|
||||
mcp_servers: Iterable[LiteLLM_MCPServerTable],
|
||||
) -> List[LiteLLM_MCPServerTable]:
|
||||
return [_redact_mcp_credentials(server) for server in mcp_servers]
|
||||
|
||||
def get_prisma_client_or_throw(message: str):
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
@@ -273,11 +291,12 @@ if MCP_AVAILABLE:
|
||||
```
|
||||
"""
|
||||
# Use server manager to get all servers with health and team data
|
||||
return (
|
||||
mcp_servers = (
|
||||
await global_mcp_server_manager.get_all_mcp_servers_with_health_and_teams(
|
||||
user_api_key_auth=user_api_key_dict
|
||||
)
|
||||
)
|
||||
return _redact_mcp_credentials_list(mcp_servers)
|
||||
|
||||
@router.get(
|
||||
"/server/{server_id}",
|
||||
@@ -335,7 +354,7 @@ if MCP_AVAILABLE:
|
||||
|
||||
# Implement authz restriction from requested user
|
||||
if _user_has_admin_view(user_api_key_dict):
|
||||
return mcp_server
|
||||
return _redact_mcp_credentials(mcp_server)
|
||||
|
||||
# Perform authz check to filter the mcp servers user has access to
|
||||
mcp_server_records = await get_all_mcp_servers_for_user(
|
||||
@@ -345,7 +364,7 @@ if MCP_AVAILABLE:
|
||||
|
||||
if exists:
|
||||
global_mcp_server_manager.add_update_server(mcp_server)
|
||||
return mcp_server
|
||||
return _redact_mcp_credentials(mcp_server)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
@@ -428,7 +447,7 @@ if MCP_AVAILABLE:
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"error": f"Error creating mcp server: {str(e)}"},
|
||||
)
|
||||
return new_mcp_server
|
||||
return _redact_mcp_credentials(new_mcp_server)
|
||||
|
||||
@router.delete(
|
||||
"/server/{server_id}",
|
||||
@@ -563,4 +582,4 @@ if MCP_AVAILABLE:
|
||||
if litellm.store_audit_logs:
|
||||
pass
|
||||
|
||||
return mcp_server_record_updated
|
||||
return _redact_mcp_credentials(mcp_server_record_updated)
|
||||
|
||||
@@ -174,6 +174,7 @@ model LiteLLM_MCPServerTable {
|
||||
url String?
|
||||
transport String @default("sse")
|
||||
auth_type String?
|
||||
credentials Json? @default("{}")
|
||||
created_at DateTime? @default(now()) @map("created_at")
|
||||
created_by String?
|
||||
updated_at DateTime? @default(now()) @updatedAt @map("updated_at")
|
||||
|
||||
@@ -54,6 +54,28 @@ MCPAuthType = Optional[
|
||||
]
|
||||
|
||||
|
||||
class MCPCredentials(TypedDict, total=False):
|
||||
auth_value: Optional[str]
|
||||
"""
|
||||
Authentication value
|
||||
"""
|
||||
|
||||
client_id: Optional[str]
|
||||
"""
|
||||
OAuth 2.0 client identifier used when auth_type is oauth2
|
||||
"""
|
||||
|
||||
client_secret: Optional[str]
|
||||
"""
|
||||
OAuth 2.0 client secret used when auth_type is oauth2
|
||||
"""
|
||||
|
||||
scopes: Optional[List[str]]
|
||||
"""
|
||||
OAuth 2.0 scopes to request when exchanging the client credentials
|
||||
"""
|
||||
|
||||
|
||||
class MCPServerCostInfo(TypedDict, total=False):
|
||||
default_cost_per_query: Optional[float]
|
||||
"""
|
||||
|
||||
@@ -174,6 +174,7 @@ model LiteLLM_MCPServerTable {
|
||||
url String?
|
||||
transport String @default("sse")
|
||||
auth_type String?
|
||||
credentials Json? @default("{}")
|
||||
created_at DateTime? @default(now()) @map("created_at")
|
||||
created_by String?
|
||||
updated_at DateTime? @default(now()) @updatedAt @map("updated_at")
|
||||
|
||||
@@ -1342,6 +1342,7 @@ def test_add_update_server_with_alias():
|
||||
mock_mcp_server.url = "https://test-server.com/mcp"
|
||||
mock_mcp_server.transport = MCPTransport.http
|
||||
mock_mcp_server.auth_type = None
|
||||
mock_mcp_server.credentials = {}
|
||||
mock_mcp_server.description = "Test server description"
|
||||
mock_mcp_server.mcp_info = {}
|
||||
mock_mcp_server.static_headers = {}
|
||||
@@ -1380,6 +1381,7 @@ def test_add_update_server_without_alias():
|
||||
mock_mcp_server.url = "https://test-server.com/mcp"
|
||||
mock_mcp_server.transport = MCPTransport.http
|
||||
mock_mcp_server.auth_type = None
|
||||
mock_mcp_server.credentials = {}
|
||||
mock_mcp_server.description = "Test server description"
|
||||
mock_mcp_server.mcp_info = {}
|
||||
mock_mcp_server.static_headers = {}
|
||||
@@ -1418,6 +1420,7 @@ def test_add_update_server_fallback_to_server_id():
|
||||
mock_mcp_server.url = "https://test-server.com/mcp"
|
||||
mock_mcp_server.transport = MCPTransport.http
|
||||
mock_mcp_server.auth_type = None
|
||||
mock_mcp_server.credentials = {}
|
||||
mock_mcp_server.description = "Test server description"
|
||||
mock_mcp_server.mcp_info = {}
|
||||
mock_mcp_server.static_headers = {}
|
||||
|
||||
@@ -15,6 +15,7 @@ from litellm.proxy._types import (
|
||||
MCPTransportType,
|
||||
MCPTransport,
|
||||
NewMCPServerRequest,
|
||||
UpdateMCPServerRequest,
|
||||
LiteLLM_MCPServerTable,
|
||||
LitellmUserRoles,
|
||||
UserAPIKeyAuth,
|
||||
@@ -165,6 +166,7 @@ async def test_create_mcp_server_direct():
|
||||
updated_by=LITELLM_PROXY_ADMIN_NAME,
|
||||
teams=[],
|
||||
)
|
||||
expected_response.credentials = {"auth_value": "secret"}
|
||||
|
||||
# Mock the database calls
|
||||
mock_get_server.return_value = None # Server doesn't exist yet
|
||||
@@ -188,6 +190,7 @@ async def test_create_mcp_server_direct():
|
||||
assert result.alias == expected_alias # Check against normalized alias
|
||||
assert result.url == mcp_server_request.url
|
||||
assert result.transport == mcp_server_request.transport
|
||||
assert result.credentials is None
|
||||
|
||||
# Verify mocks were called
|
||||
mock_get_server.assert_called_once_with(mock_prisma, server_id)
|
||||
@@ -353,6 +356,69 @@ async def test_create_mcp_server_invalid_alias():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edit_mcp_server_redacts_credentials():
|
||||
with mock.patch(
|
||||
"litellm.proxy.management_endpoints.mcp_management_endpoints.MCP_AVAILABLE",
|
||||
True,
|
||||
), mock.patch(
|
||||
"litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw"
|
||||
) as mock_get_prisma, mock.patch(
|
||||
"litellm.proxy.management_endpoints.mcp_management_endpoints.update_mcp_server",
|
||||
new_callable=mock.AsyncMock,
|
||||
) as mock_update, mock.patch(
|
||||
"litellm.proxy.management_endpoints.mcp_management_endpoints.validate_and_normalize_mcp_server_payload",
|
||||
autospec=True,
|
||||
) as mock_validate, mock.patch(
|
||||
"litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager"
|
||||
) as mock_manager:
|
||||
from litellm.proxy.management_endpoints.mcp_management_endpoints import (
|
||||
edit_mcp_server,
|
||||
)
|
||||
|
||||
mock_prisma = mock.Mock()
|
||||
mock_get_prisma.return_value = mock_prisma
|
||||
|
||||
mock_manager.add_update_server = mock.Mock()
|
||||
mock_manager.reload_servers_from_database = mock.AsyncMock()
|
||||
|
||||
server_id = str(uuid.uuid4())
|
||||
updated_server = LiteLLM_MCPServerTable(
|
||||
server_id=server_id,
|
||||
alias="Updated Server",
|
||||
url="https://updated.example.com/mcp",
|
||||
transport=MCPTransport.http,
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
teams=[],
|
||||
)
|
||||
updated_server.credentials = {"auth_value": "secret"}
|
||||
|
||||
mock_update.return_value = updated_server
|
||||
|
||||
payload = UpdateMCPServerRequest(
|
||||
server_id=server_id,
|
||||
alias="Updated Server",
|
||||
url="https://updated.example.com/mcp",
|
||||
transport=MCPTransport.http,
|
||||
)
|
||||
|
||||
user_auth = UserAPIKeyAuth(
|
||||
api_key=TEST_MASTER_KEY,
|
||||
user_id="test-user",
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
)
|
||||
|
||||
result = await edit_mcp_server(payload=payload, user_api_key_dict=user_auth)
|
||||
|
||||
assert result.server_id == server_id
|
||||
assert result.credentials is None
|
||||
assert updated_server.credentials == {"auth_value": "secret"}
|
||||
|
||||
mock_validate.assert_called_once()
|
||||
mock_update.assert_awaited_once()
|
||||
mock_manager.add_update_server.assert_called_once_with(updated_server)
|
||||
mock_manager.reload_servers_from_database.assert_awaited_once()
|
||||
def test_validate_mcp_server_name_direct():
|
||||
"""
|
||||
Test the validation function directly to ensure it works.
|
||||
|
||||
@@ -186,6 +186,9 @@ class TestListMCPServers:
|
||||
return_value=mock_servers_with_health
|
||||
)
|
||||
|
||||
for idx, server in enumerate(mock_servers_with_health):
|
||||
server.credentials = {"auth_value": f"secret_{idx}"}
|
||||
|
||||
with patch(
|
||||
"litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager",
|
||||
mock_manager,
|
||||
@@ -205,6 +208,7 @@ class TestListMCPServers:
|
||||
|
||||
# Verify results
|
||||
assert len(result) == 2
|
||||
assert all(server.credentials is None for server in result)
|
||||
|
||||
# Check that both config servers are returned
|
||||
server_ids = [server.server_id for server in result]
|
||||
@@ -315,6 +319,9 @@ class TestListMCPServers:
|
||||
return_value=mock_servers_with_health
|
||||
)
|
||||
|
||||
for idx, server in enumerate(mock_servers_with_health):
|
||||
server.credentials = {"auth_value": f"secret_{idx}"}
|
||||
|
||||
with patch(
|
||||
"litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager",
|
||||
mock_manager,
|
||||
@@ -334,6 +341,7 @@ class TestListMCPServers:
|
||||
|
||||
# Verify results
|
||||
assert len(result) == 4
|
||||
assert all(server.credentials is None for server in result)
|
||||
|
||||
# Check that both DB and config servers are returned
|
||||
server_ids = [server.server_id for server in result]
|
||||
@@ -428,6 +436,9 @@ class TestListMCPServers:
|
||||
return_value=mock_servers_with_health
|
||||
)
|
||||
|
||||
for idx, server in enumerate(mock_servers_with_health):
|
||||
server.credentials = {"auth_value": f"secret_{idx}"}
|
||||
|
||||
with patch(
|
||||
"litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager",
|
||||
mock_manager,
|
||||
@@ -447,6 +458,7 @@ class TestListMCPServers:
|
||||
|
||||
# Verify results - should only return servers user has access to
|
||||
assert len(result) == 2
|
||||
assert all(server.credentials is None for server in result)
|
||||
|
||||
# Check that only allowed servers are returned
|
||||
server_ids = [server.server_id for server in result]
|
||||
@@ -464,6 +476,51 @@ class TestListMCPServers:
|
||||
assert server.url == "https://actions.zapier.com/mcp/sse"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_single_mcp_server_redacts_credentials(self):
|
||||
mock_server = generate_mock_mcp_server_db_record(
|
||||
server_id="server-1", alias="Server 1"
|
||||
)
|
||||
mock_server.credentials = {"auth_value": "top-secret"}
|
||||
|
||||
mock_prisma_client = MagicMock()
|
||||
mock_health_result = {
|
||||
"status": "healthy",
|
||||
"last_health_check": datetime.now().isoformat(),
|
||||
"error": None,
|
||||
}
|
||||
|
||||
mock_user_auth = generate_mock_user_api_key_auth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN
|
||||
)
|
||||
|
||||
with patch(
|
||||
"litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw",
|
||||
return_value=mock_prisma_client,
|
||||
), patch(
|
||||
"litellm.proxy.management_endpoints.mcp_management_endpoints.get_mcp_server",
|
||||
AsyncMock(return_value=mock_server),
|
||||
), patch(
|
||||
"litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager.health_check_server",
|
||||
AsyncMock(return_value=mock_health_result),
|
||||
), patch(
|
||||
"litellm.proxy.management_endpoints.mcp_management_endpoints._user_has_admin_view",
|
||||
return_value=True,
|
||||
):
|
||||
from litellm.proxy.management_endpoints.mcp_management_endpoints import (
|
||||
fetch_mcp_server,
|
||||
)
|
||||
|
||||
result = await fetch_mcp_server(
|
||||
server_id="server-1", user_api_key_dict=mock_user_auth
|
||||
)
|
||||
|
||||
assert result.server_id == "server-1"
|
||||
assert result.credentials is None
|
||||
assert mock_server.credentials == {"auth_value": "top-secret"}
|
||||
assert result.status == "healthy"
|
||||
|
||||
|
||||
class TestMCPHealthCheckEndpoints:
|
||||
"""Test MCP health check endpoints"""
|
||||
|
||||
@@ -721,6 +778,8 @@ class TestMCPHealthCheckEndpoints:
|
||||
return_value=[mock_server]
|
||||
)
|
||||
|
||||
mock_server.credentials = {"auth_value": "secret"}
|
||||
|
||||
mock_user_auth = generate_mock_user_api_key_auth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN
|
||||
)
|
||||
@@ -749,3 +808,4 @@ class TestMCPHealthCheckEndpoints:
|
||||
assert server.status == "healthy"
|
||||
assert server.last_health_check is not None
|
||||
assert server.health_check_error is None
|
||||
assert server.credentials is None
|
||||
|
||||
@@ -7,8 +7,6 @@ import NotificationsManager from "../molecules/notifications_manager";
|
||||
|
||||
export function ToolTestPanel({
|
||||
tool,
|
||||
needsAuth,
|
||||
authValue,
|
||||
onSubmit,
|
||||
isLoading,
|
||||
result,
|
||||
@@ -16,8 +14,6 @@ export function ToolTestPanel({
|
||||
onClose,
|
||||
}: {
|
||||
tool: MCPTool;
|
||||
needsAuth: boolean;
|
||||
authValue?: string | null;
|
||||
onSubmit: (args: Record<string, any>) => void;
|
||||
isLoading: boolean;
|
||||
result: any | null;
|
||||
|
||||
@@ -3,7 +3,7 @@ import { Modal, Tooltip, Form, Select } from "antd";
|
||||
import { InfoCircleOutlined } from "@ant-design/icons";
|
||||
import { Button, TextInput } from "@tremor/react";
|
||||
import { createMCPServer } from "../networking";
|
||||
import { MCPServer, MCPServerCostInfo } from "./types";
|
||||
import { AUTH_TYPE, MCPServer, MCPServerCostInfo } from "./types";
|
||||
import MCPServerCostConfig from "./mcp_server_cost_config";
|
||||
import MCPConnectionStatus from "./mcp_connection_status";
|
||||
import MCPToolConfiguration from "./mcp_tool_configuration";
|
||||
@@ -25,6 +25,12 @@ interface CreateMCPServerProps {
|
||||
availableAccessGroups: string[];
|
||||
}
|
||||
|
||||
const AUTH_TYPES_REQUIRING_AUTH_VALUE = [
|
||||
AUTH_TYPE.API_KEY,
|
||||
AUTH_TYPE.BEARER_TOKEN,
|
||||
AUTH_TYPE.BASIC,
|
||||
];
|
||||
|
||||
const CreateMCPServer: React.FC<CreateMCPServerProps> = ({
|
||||
userRole,
|
||||
accessToken,
|
||||
@@ -43,6 +49,10 @@ const CreateMCPServer: React.FC<CreateMCPServerProps> = ({
|
||||
const [transportType, setTransportType] = useState<string>("");
|
||||
const [searchValue, setSearchValue] = useState<string>("");
|
||||
const [urlWarning, setUrlWarning] = useState<string>("");
|
||||
const authType = formValues.auth_type as string | undefined;
|
||||
const shouldShowAuthValueField = authType
|
||||
? AUTH_TYPES_REQUIRING_AUTH_VALUE.includes(authType)
|
||||
: false;
|
||||
|
||||
// Function to check URL format based on transport type
|
||||
const checkUrlFormat = (url: string, transport: string) => {
|
||||
@@ -63,7 +73,12 @@ const CreateMCPServer: React.FC<CreateMCPServerProps> = ({
|
||||
const handleCreate = async (values: Record<string, any>) => {
|
||||
setIsLoading(true);
|
||||
try {
|
||||
const { static_headers: staticHeadersList, stdio_config: rawStdioConfig, ...restValues } = values;
|
||||
const {
|
||||
static_headers: staticHeadersList,
|
||||
stdio_config: rawStdioConfig,
|
||||
credentials: credentialValues,
|
||||
...restValues
|
||||
} = values;
|
||||
|
||||
// Transform access groups into objects with name property
|
||||
const accessGroups = restValues.mcp_access_groups;
|
||||
@@ -79,6 +94,26 @@ const CreateMCPServer: React.FC<CreateMCPServerProps> = ({
|
||||
}, {})
|
||||
: {} as Record<string, string>;
|
||||
|
||||
const credentialsPayload =
|
||||
credentialValues && typeof credentialValues === "object"
|
||||
? Object.entries(credentialValues).reduce((acc: Record<string, any>, [key, value]) => {
|
||||
if (value === undefined || value === null || value === "") {
|
||||
return acc;
|
||||
}
|
||||
if (key === "scopes") {
|
||||
if (Array.isArray(value)) {
|
||||
const filteredScopes = value.filter((scope) => scope != null && scope !== "");
|
||||
if (filteredScopes.length > 0) {
|
||||
acc[key] = filteredScopes;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
acc[key] = value;
|
||||
}
|
||||
return acc;
|
||||
}, {})
|
||||
: undefined;
|
||||
|
||||
// Process stdio configuration if present
|
||||
let stdioFields = {};
|
||||
if (rawStdioConfig && transportType === "stdio") {
|
||||
@@ -135,6 +170,18 @@ const CreateMCPServer: React.FC<CreateMCPServerProps> = ({
|
||||
static_headers: staticHeaders,
|
||||
};
|
||||
|
||||
payload.static_headers = staticHeaders;
|
||||
const includeCredentials =
|
||||
restValues.auth_type && AUTH_TYPES_REQUIRING_AUTH_VALUE.includes(restValues.auth_type);
|
||||
|
||||
if (
|
||||
includeCredentials &&
|
||||
credentialsPayload &&
|
||||
Object.keys(credentialsPayload).length > 0
|
||||
) {
|
||||
payload.credentials = credentialsPayload;
|
||||
}
|
||||
|
||||
console.log(`Payload: ${JSON.stringify(payload)}`);
|
||||
|
||||
if (accessToken != null) {
|
||||
@@ -393,6 +440,27 @@ const CreateMCPServer: React.FC<CreateMCPServerProps> = ({
|
||||
</Form.Item>
|
||||
)}
|
||||
|
||||
{transportType !== "stdio" && shouldShowAuthValueField && (
|
||||
<Form.Item
|
||||
label={
|
||||
<span className="text-sm font-medium text-gray-700 flex items-center">
|
||||
Authentication Value
|
||||
<Tooltip title="Token, password, or header value to send with each request for the selected auth type.">
|
||||
<InfoCircleOutlined className="ml-2 text-blue-400 hover:text-blue-600 cursor-help" />
|
||||
</Tooltip>
|
||||
</span>
|
||||
}
|
||||
name={["credentials", "auth_value"]}
|
||||
rules={[{ required: true, message: "Please enter the authentication value" }]}
|
||||
>
|
||||
<TextInput
|
||||
type="password"
|
||||
placeholder="Enter token or secret"
|
||||
className="rounded-lg border-gray-300 focus:border-blue-500 focus:ring-blue-500"
|
||||
/>
|
||||
</Form.Item>
|
||||
)}
|
||||
|
||||
{/* Stdio Configuration - only show for stdio transport */}
|
||||
<StdioConfiguration isVisible={transportType === "stdio"} />
|
||||
</div>
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
// Utility functions for managing MCP server authentication tokens in localStorage
|
||||
|
||||
const MCP_AUTH_STORAGE_KEY = "litellm_mcp_auth_tokens";
|
||||
|
||||
export interface MCPAuthToken {
|
||||
serverId: string;
|
||||
serverAlias?: string;
|
||||
authValue: string;
|
||||
authType: string;
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
export interface MCPAuthStorage {
|
||||
[serverId: string]: MCPAuthToken;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all stored MCP authentication tokens
|
||||
*/
|
||||
export const getMCPAuthTokens = (): MCPAuthStorage => {
|
||||
try {
|
||||
const stored = localStorage.getItem(MCP_AUTH_STORAGE_KEY);
|
||||
return stored ? JSON.parse(stored) : {};
|
||||
} catch (error) {
|
||||
console.error("Error reading MCP auth tokens from localStorage:", error);
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get authentication token for a specific MCP server
|
||||
*/
|
||||
export const getMCPAuthToken = (serverId: string, serverAlias?: string): string | null => {
|
||||
try {
|
||||
const tokens = getMCPAuthTokens();
|
||||
const token = tokens[serverId];
|
||||
|
||||
// If token exists, check if serverAlias matches (both can be undefined)
|
||||
if (token && token.serverAlias === serverAlias) {
|
||||
return token.authValue;
|
||||
}
|
||||
|
||||
// If no serverAlias was provided and token exists without serverAlias, return it
|
||||
if (token && !serverAlias && !token.serverAlias) {
|
||||
return token.authValue;
|
||||
}
|
||||
|
||||
return null;
|
||||
} catch (error) {
|
||||
console.error("Error getting MCP auth token:", error);
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Store authentication token for an MCP server
|
||||
*/
|
||||
export const setMCPAuthToken = (serverId: string, authValue: string, authType: string, serverAlias?: string): void => {
|
||||
try {
|
||||
const tokens = getMCPAuthTokens();
|
||||
|
||||
tokens[serverId] = {
|
||||
serverId,
|
||||
serverAlias,
|
||||
authValue,
|
||||
authType,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
localStorage.setItem(MCP_AUTH_STORAGE_KEY, JSON.stringify(tokens));
|
||||
} catch (error) {
|
||||
console.error("Error storing MCP auth token:", error);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Remove authentication token for an MCP server
|
||||
*/
|
||||
export const removeMCPAuthToken = (serverId: string): void => {
|
||||
try {
|
||||
const tokens = getMCPAuthTokens();
|
||||
delete tokens[serverId];
|
||||
localStorage.setItem(MCP_AUTH_STORAGE_KEY, JSON.stringify(tokens));
|
||||
} catch (error) {
|
||||
console.error("Error removing MCP auth token:", error);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Clear all MCP authentication tokens (useful for logout)
|
||||
*/
|
||||
export const clearMCPAuthTokens = (): void => {
|
||||
try {
|
||||
localStorage.removeItem(MCP_AUTH_STORAGE_KEY);
|
||||
} catch (error) {
|
||||
console.error("Error clearing MCP auth tokens:", error);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Check if a token exists for a server
|
||||
*/
|
||||
export const hasMCPAuthToken = (serverId: string, serverAlias?: string): boolean => {
|
||||
const token = getMCPAuthToken(serverId, serverAlias);
|
||||
return token !== null;
|
||||
};
|
||||
@@ -1,7 +1,8 @@
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { Form, Select, Button as AntdButton } from "antd";
|
||||
import { Form, Select, Button as AntdButton, Tooltip } from "antd";
|
||||
import { InfoCircleOutlined } from "@ant-design/icons";
|
||||
import { Button, TextInput, TabGroup, TabList, Tab, TabPanels, TabPanel } from "@tremor/react";
|
||||
import { MCPServer, MCPServerCostInfo } from "./types";
|
||||
import { AUTH_TYPE, MCPServer, MCPServerCostInfo } from "./types";
|
||||
import { updateMCPServer, testMCPToolsListRequest } from "../networking";
|
||||
import MCPServerCostConfig from "./mcp_server_cost_config";
|
||||
import MCPPermissionManagement from "./MCPPermissionManagement";
|
||||
@@ -17,6 +18,12 @@ interface MCPServerEditProps {
|
||||
availableAccessGroups: string[];
|
||||
}
|
||||
|
||||
const AUTH_TYPES_REQUIRING_AUTH_VALUE = [
|
||||
AUTH_TYPE.API_KEY,
|
||||
AUTH_TYPE.BEARER_TOKEN,
|
||||
AUTH_TYPE.BASIC,
|
||||
];
|
||||
|
||||
const MCPServerEdit: React.FC<MCPServerEditProps> = ({
|
||||
mcpServer,
|
||||
accessToken,
|
||||
@@ -31,6 +38,10 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
|
||||
const [searchValue, setSearchValue] = useState<string>("");
|
||||
const [aliasManuallyEdited, setAliasManuallyEdited] = useState(false);
|
||||
const [allowedTools, setAllowedTools] = useState<string[]>([]);
|
||||
const authType = Form.useWatch("auth_type", form) as string | undefined;
|
||||
const shouldShowAuthValueField = authType
|
||||
? AUTH_TYPES_REQUIRING_AUTH_VALUE.includes(authType)
|
||||
: false;
|
||||
|
||||
const initialStaticHeaders = React.useMemo(() => {
|
||||
if (!mcpServer.static_headers) {
|
||||
@@ -148,7 +159,11 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
|
||||
if (!accessToken) return;
|
||||
try {
|
||||
// Ensure access groups is always a string array
|
||||
const { static_headers: staticHeadersList, ...restValues } = values;
|
||||
const {
|
||||
static_headers: staticHeadersList,
|
||||
credentials: credentialValues,
|
||||
...restValues
|
||||
} = values;
|
||||
|
||||
const accessGroups = (restValues.mcp_access_groups || []).map((g: any) =>
|
||||
typeof g === "string" ? g : g.name || String(g),
|
||||
@@ -165,6 +180,26 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
|
||||
}, {})
|
||||
: {} as Record<string, string>;
|
||||
|
||||
const credentialsPayload =
|
||||
credentialValues && typeof credentialValues === "object"
|
||||
? Object.entries(credentialValues).reduce((acc: Record<string, any>, [key, value]) => {
|
||||
if (value === undefined || value === null || value === "") {
|
||||
return acc;
|
||||
}
|
||||
if (key === "scopes") {
|
||||
if (Array.isArray(value)) {
|
||||
const filteredScopes = value.filter((scope) => scope != null && scope !== "");
|
||||
if (filteredScopes.length > 0) {
|
||||
acc[key] = filteredScopes;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
acc[key] = value;
|
||||
}
|
||||
return acc;
|
||||
}, {})
|
||||
: undefined;
|
||||
|
||||
// Prepare the payload with cost configuration and permission fields
|
||||
const payload = {
|
||||
...restValues,
|
||||
@@ -183,6 +218,17 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
|
||||
static_headers: staticHeaders,
|
||||
};
|
||||
|
||||
const includeCredentials =
|
||||
restValues.auth_type && AUTH_TYPES_REQUIRING_AUTH_VALUE.includes(restValues.auth_type);
|
||||
|
||||
if (
|
||||
includeCredentials &&
|
||||
credentialsPayload &&
|
||||
Object.keys(credentialsPayload).length > 0
|
||||
) {
|
||||
payload.credentials = credentialsPayload;
|
||||
}
|
||||
|
||||
const updated = await updateMCPServer(accessToken, payload);
|
||||
NotificationsManager.success("MCP Server updated successfully");
|
||||
onSuccess(updated);
|
||||
@@ -250,6 +296,34 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
|
||||
</Select>
|
||||
</Form.Item>
|
||||
|
||||
{shouldShowAuthValueField && (
|
||||
<Form.Item
|
||||
label={
|
||||
<span className="text-sm font-medium text-gray-700 flex items-center">
|
||||
Authentication Value
|
||||
<Tooltip title="Token, password, or header value to send with each request for the selected auth type.">
|
||||
<InfoCircleOutlined className="ml-2 text-blue-400 hover:text-blue-600 cursor-help" />
|
||||
</Tooltip>
|
||||
</span>
|
||||
}
|
||||
name={["credentials", "auth_value"]}
|
||||
rules={[
|
||||
{
|
||||
validator: (_, value) =>
|
||||
value && typeof value === "string" && value.trim() === ""
|
||||
? Promise.reject(new Error("Authentication value cannot be empty"))
|
||||
: Promise.resolve(),
|
||||
},
|
||||
]}
|
||||
>
|
||||
<TextInput
|
||||
type="password"
|
||||
placeholder="Enter token or secret (leave blank to keep existing)"
|
||||
className="rounded-lg border-gray-300 focus:border-blue-500 focus:ring-blue-500"
|
||||
/>
|
||||
</Form.Item>
|
||||
)}
|
||||
|
||||
{/* Permission Management / Access Control Section */}
|
||||
<div className="mt-6">
|
||||
<MCPPermissionManagement
|
||||
|
||||
@@ -1,122 +1,15 @@
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { useQuery, useMutation } from "@tanstack/react-query";
|
||||
import { ToolTestPanel } from "./ToolTestPanel";
|
||||
import { MCPTool, MCPToolsViewerProps, CallMCPToolResponse, mcpServerHasAuth } from "./types";
|
||||
import { MCPTool, MCPToolsViewerProps, CallMCPToolResponse } from "./types";
|
||||
import { listMCPTools, callMCPTool } from "../networking";
|
||||
import { getMCPAuthToken, setMCPAuthToken, removeMCPAuthToken } from "./mcp_auth_storage";
|
||||
|
||||
import { Modal, Input, Form } from "antd";
|
||||
import { Button, Card, Title, Text } from "@tremor/react";
|
||||
import { RobotOutlined, SafetyOutlined, ToolOutlined } from "@ant-design/icons";
|
||||
|
||||
import { AUTH_TYPE } from "./types";
|
||||
import NotificationsManager from "../molecules/notifications_manager";
|
||||
|
||||
type AuthModalProps = {
|
||||
visible: boolean;
|
||||
onOk: (values: any) => void;
|
||||
onCancel: () => void;
|
||||
authType?: string | null;
|
||||
};
|
||||
|
||||
export const AuthModal = ({ visible, onOk, onCancel, authType }: AuthModalProps) => {
|
||||
const [form] = Form.useForm();
|
||||
|
||||
// Handler for modal OK
|
||||
const handleOk = () => {
|
||||
form.validateFields().then((values) => {
|
||||
if (authType === AUTH_TYPE.BASIC) {
|
||||
onOk(`${values.username.trim()}:${values.password.trim()}`);
|
||||
} else {
|
||||
onOk(values.authValue.trim());
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
let content;
|
||||
if (authType === AUTH_TYPE.API_KEY || authType === AUTH_TYPE.BEARER_TOKEN) {
|
||||
const label = authType === AUTH_TYPE.API_KEY ? "API Key" : "Bearer Token";
|
||||
content = (
|
||||
<Form.Item name="authValue" label={label} rules={[{ required: true, message: `Please input your ${label}` }]}>
|
||||
<Input.Password />
|
||||
</Form.Item>
|
||||
);
|
||||
} else if (authType === AUTH_TYPE.BASIC) {
|
||||
content = (
|
||||
<>
|
||||
<Form.Item name="username" label="Username" rules={[{ required: true, message: "Please input your username" }]}>
|
||||
<Input />
|
||||
</Form.Item>
|
||||
<Form.Item name="password" label="Password" rules={[{ required: true, message: "Please input your password" }]}>
|
||||
<Input.Password />
|
||||
</Form.Item>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Modal open={visible} title="Authentication" onOk={handleOk} onCancel={onCancel} destroyOnClose>
|
||||
<Form form={form} layout="vertical">
|
||||
{content}
|
||||
</Form>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
const AuthSection = ({
|
||||
authType,
|
||||
onAuthSubmit,
|
||||
onClearAuth,
|
||||
hasAuth,
|
||||
}: {
|
||||
authType: string | null | undefined;
|
||||
onAuthSubmit: (value: string) => void;
|
||||
onClearAuth: () => void;
|
||||
hasAuth: boolean;
|
||||
}) => {
|
||||
const [modalVisible, setModalVisible] = useState(false);
|
||||
|
||||
const handleAddAuth = () => setModalVisible(true);
|
||||
|
||||
const handleModalOk = (authValue: string) => {
|
||||
onAuthSubmit(authValue);
|
||||
setModalVisible(false);
|
||||
};
|
||||
|
||||
const handleModalCancel = () => setModalVisible(false);
|
||||
|
||||
const handleClearAuth = () => {
|
||||
onClearAuth();
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<Text className="text-sm font-medium text-gray-700">Authentication {hasAuth ? "✓" : ""}</Text>
|
||||
<div className="flex gap-2">
|
||||
{hasAuth && (
|
||||
<Button
|
||||
onClick={handleClearAuth}
|
||||
size="sm"
|
||||
variant="secondary"
|
||||
className="text-xs text-red-600 hover:text-red-700"
|
||||
>
|
||||
Clear
|
||||
</Button>
|
||||
)}
|
||||
<Button onClick={handleAddAuth} size="sm" variant="secondary" className="text-xs">
|
||||
{hasAuth ? "Update" : "Add Auth"}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<Text className="text-xs text-gray-500">
|
||||
{hasAuth ? "Authentication configured and saved locally" : "Some tools may require authentication"}
|
||||
</Text>
|
||||
<AuthModal visible={modalVisible} onOk={handleModalOk} onCancel={handleModalCancel} authType={authType} />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const MCPToolsViewer = ({
|
||||
serverId,
|
||||
accessToken,
|
||||
@@ -125,47 +18,20 @@ const MCPToolsViewer = ({
|
||||
userID,
|
||||
serverAlias, // Add serverAlias prop
|
||||
}: MCPToolsViewerProps) => {
|
||||
const [mcpAuthValue, setMcpAuthValue] = useState("");
|
||||
const [selectedTool, setSelectedTool] = useState<MCPTool | null>(null);
|
||||
const [toolResult, setToolResult] = useState<CallMCPToolResponse | null>(null);
|
||||
const [toolError, setToolError] = useState<Error | null>(null);
|
||||
|
||||
// Load stored auth token on component mount
|
||||
useEffect(() => {
|
||||
if (mcpServerHasAuth(auth_type)) {
|
||||
const storedAuthValue = getMCPAuthToken(serverId, serverAlias || undefined);
|
||||
if (storedAuthValue) {
|
||||
setMcpAuthValue(storedAuthValue);
|
||||
}
|
||||
}
|
||||
}, [serverId, serverAlias, auth_type]);
|
||||
|
||||
// Function to handle auth submission with localStorage persistence
|
||||
const handleAuthSubmit = (authValue: string) => {
|
||||
setMcpAuthValue(authValue);
|
||||
if (authValue && mcpServerHasAuth(auth_type)) {
|
||||
setMCPAuthToken(serverId, authValue, auth_type || "none", serverAlias || undefined);
|
||||
NotificationsManager.success("Authentication token saved locally");
|
||||
}
|
||||
};
|
||||
|
||||
// Function to clear auth token
|
||||
const handleClearAuth = () => {
|
||||
setMcpAuthValue("");
|
||||
removeMCPAuthToken(serverId);
|
||||
NotificationsManager.info("Authentication token cleared");
|
||||
};
|
||||
|
||||
// Query to fetch MCP tools
|
||||
const {
|
||||
data: mcpToolsResponse,
|
||||
isLoading: isLoadingTools,
|
||||
error: mcpToolsError,
|
||||
} = useQuery({
|
||||
queryKey: ["mcpTools", serverId, mcpAuthValue, serverAlias],
|
||||
queryKey: ["mcpTools", serverId],
|
||||
queryFn: () => {
|
||||
if (!accessToken) throw new Error("Access Token required");
|
||||
return listMCPTools(accessToken, serverId, mcpAuthValue, serverAlias || undefined);
|
||||
return listMCPTools(accessToken, serverId);
|
||||
},
|
||||
enabled: !!accessToken,
|
||||
staleTime: 30000, // Consider data fresh for 30 seconds
|
||||
@@ -173,7 +39,7 @@ const MCPToolsViewer = ({
|
||||
|
||||
// Mutation for calling a tool
|
||||
const { mutate: executeTool, isPending: isCallingTool } = useMutation({
|
||||
mutationFn: async (args: { tool: MCPTool; arguments: Record<string, any>; authValue: string }) => {
|
||||
mutationFn: async (args: { tool: MCPTool; arguments: Record<string, any> }) => {
|
||||
if (!accessToken) throw new Error("Access Token required");
|
||||
|
||||
try {
|
||||
@@ -181,8 +47,6 @@ const MCPToolsViewer = ({
|
||||
accessToken,
|
||||
args.tool.name,
|
||||
args.arguments,
|
||||
args.authValue,
|
||||
serverAlias || undefined,
|
||||
);
|
||||
return result;
|
||||
} catch (error) {
|
||||
@@ -200,7 +64,6 @@ const MCPToolsViewer = ({
|
||||
});
|
||||
|
||||
const toolsData = mcpToolsResponse?.tools || [];
|
||||
const hasAuth = mcpAuthValue !== "";
|
||||
|
||||
return (
|
||||
<div className="w-full h-screen p-4 bg-white">
|
||||
@@ -317,44 +180,6 @@ const MCPToolsViewer = ({
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Authentication Section - Below tools list */}
|
||||
{mcpServerHasAuth(auth_type) && (
|
||||
<div className="pt-4 border-t border-gray-200 flex-shrink-0 mt-6">
|
||||
{!hasAuth ? (
|
||||
/* Prominent display when auth required but not provided */
|
||||
<div className="p-4 bg-gradient-to-r from-orange-50 to-red-50 border border-orange-200 rounded-lg">
|
||||
<div className="flex items-center mb-3">
|
||||
<SafetyOutlined className="mr-2 text-orange-600 text-lg" />
|
||||
<Text className="font-semibold text-orange-800">Authentication Required</Text>
|
||||
</div>
|
||||
<Text className="text-sm text-orange-700 mb-4">
|
||||
This MCP server requires authentication. You must add your credentials below to access the
|
||||
tools.
|
||||
</Text>
|
||||
<AuthSection
|
||||
authType={auth_type}
|
||||
onAuthSubmit={handleAuthSubmit}
|
||||
onClearAuth={handleClearAuth}
|
||||
hasAuth={hasAuth}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
/* Subtle display when already authenticated */
|
||||
<>
|
||||
<Text className="font-medium block mb-3 text-gray-700 flex items-center">
|
||||
<SafetyOutlined className="mr-2" /> Authentication
|
||||
</Text>
|
||||
<AuthSection
|
||||
authType={auth_type}
|
||||
onAuthSubmit={handleAuthSubmit}
|
||||
onClearAuth={handleClearAuth}
|
||||
hasAuth={hasAuth}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -379,10 +204,8 @@ const MCPToolsViewer = ({
|
||||
<div className="h-full">
|
||||
<ToolTestPanel
|
||||
tool={selectedTool}
|
||||
needsAuth={mcpServerHasAuth(auth_type)}
|
||||
authValue={mcpAuthValue}
|
||||
onSubmit={(args) => {
|
||||
executeTool({ tool: selectedTool, arguments: args, authValue: mcpAuthValue });
|
||||
executeTool({ tool: selectedTool, arguments: args });
|
||||
}}
|
||||
result={toolResult}
|
||||
error={toolError}
|
||||
|
||||
@@ -34,10 +34,6 @@ export const handleAuth = (authType?: string | null): string => {
|
||||
return authType;
|
||||
};
|
||||
|
||||
export const mcpServerHasAuth = (authType?: string | null): boolean => {
|
||||
return handleAuth(authType) !== AUTH_TYPE.NONE;
|
||||
};
|
||||
|
||||
// Define the structure for tool input schema properties
|
||||
export interface InputSchemaProperty {
|
||||
type: string;
|
||||
|
||||
@@ -15,7 +15,6 @@ import {
|
||||
import { clearTokenCookies } from "@/utils/cookieUtils";
|
||||
import { fetchProxySettings } from "@/utils/proxyUtils";
|
||||
import { useTheme } from "@/contexts/ThemeContext";
|
||||
import { clearMCPAuthTokens } from "./mcp_tools/mcp_auth_storage";
|
||||
import useFeatureFlags from "@/hooks/useFeatureFlags";
|
||||
|
||||
interface NavbarProps {
|
||||
@@ -71,7 +70,6 @@ const Navbar: React.FC<NavbarProps> = ({
|
||||
|
||||
const handleLogout = () => {
|
||||
clearTokenCookies();
|
||||
clearMCPAuthTokens(); // Clear MCP auth tokens on logout
|
||||
window.location.href = logoutUrl;
|
||||
};
|
||||
|
||||
|
||||
@@ -5750,7 +5750,7 @@ export const testSearchToolConnection = async (accessToken: string, litellmParam
|
||||
}
|
||||
};
|
||||
|
||||
export const listMCPTools = async (accessToken: string, serverId: string, authValue?: string, serverAlias?: string) => {
|
||||
export const listMCPTools = async (accessToken: string, serverId: string) => {
|
||||
try {
|
||||
// Construct base URL
|
||||
let url = proxyBaseUrl
|
||||
@@ -5764,14 +5764,6 @@ export const listMCPTools = async (accessToken: string, serverId: string, authVa
|
||||
"Content-Type": "application/json",
|
||||
};
|
||||
|
||||
// Use new server-specific auth header format if serverAlias is provided
|
||||
if (serverAlias && authValue) {
|
||||
headers[`x-mcp-${serverAlias}-authorization`] = authValue;
|
||||
} else if (authValue) {
|
||||
// Fall back to deprecated x-mcp-auth header for backward compatibility
|
||||
headers[MCP_AUTH_HEADER] = authValue;
|
||||
}
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: "GET",
|
||||
headers,
|
||||
@@ -5805,9 +5797,7 @@ export const listMCPTools = async (accessToken: string, serverId: string, authVa
|
||||
export const callMCPTool = async (
|
||||
accessToken: string,
|
||||
toolName: string,
|
||||
toolArguments: Record<string, any>,
|
||||
authValue: string,
|
||||
serverAlias?: string,
|
||||
toolArguments: Record<string, any>
|
||||
) => {
|
||||
try {
|
||||
// Construct base URL
|
||||
@@ -5820,14 +5810,6 @@ export const callMCPTool = async (
|
||||
"Content-Type": "application/json",
|
||||
};
|
||||
|
||||
// Use new server-specific auth header format if serverAlias is provided
|
||||
if (serverAlias) {
|
||||
headers[`x-mcp-${serverAlias}-authorization`] = authValue;
|
||||
} else {
|
||||
// Fall back to deprecated x-mcp-auth header for backward compatibility
|
||||
headers[MCP_AUTH_HEADER] = authValue;
|
||||
}
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers,
|
||||
|
||||
@@ -9,6 +9,12 @@ interface MCPServerConfig {
|
||||
auth_type?: string;
|
||||
mcp_info?: any;
|
||||
static_headers?: Record<string, string>;
|
||||
credentials?: {
|
||||
auth_value?: string;
|
||||
client_id?: string;
|
||||
client_secret?: string;
|
||||
scopes?: string[];
|
||||
};
|
||||
}
|
||||
|
||||
interface UseTestMCPConnectionProps {
|
||||
@@ -41,6 +47,7 @@ export const useTestMCPConnection = ({
|
||||
const canFetchTools = !!(formValues.url && formValues.transport && formValues.auth_type && accessToken);
|
||||
|
||||
const staticHeadersKey = JSON.stringify(formValues.static_headers ?? {});
|
||||
const credentialsKey = JSON.stringify(formValues.credentials ?? {});
|
||||
|
||||
const fetchTools = async () => {
|
||||
if (!accessToken || !formValues.url) {
|
||||
@@ -74,6 +81,29 @@ export const useTestMCPConnection = ({
|
||||
)
|
||||
: {} as Record<string, string>;
|
||||
|
||||
const credentials =
|
||||
formValues.credentials && typeof formValues.credentials === "object"
|
||||
? Object.entries(formValues.credentials).reduce(
|
||||
(acc: Record<string, any>, [key, value]) => {
|
||||
if (value === undefined || value === null || value === "") {
|
||||
return acc;
|
||||
}
|
||||
if (key === "scopes") {
|
||||
if (Array.isArray(value)) {
|
||||
const normalizedScopes = value.filter((scope) => scope != null && scope !== "");
|
||||
if (normalizedScopes.length > 0) {
|
||||
acc[key] = normalizedScopes;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
acc[key] = value;
|
||||
}
|
||||
return acc;
|
||||
},
|
||||
{},
|
||||
)
|
||||
: undefined;
|
||||
|
||||
const mcpServerConfig: MCPServerConfig = {
|
||||
server_id: formValues.server_id || "",
|
||||
server_name: formValues.server_name || "",
|
||||
@@ -84,6 +114,10 @@ export const useTestMCPConnection = ({
|
||||
static_headers: staticHeaders,
|
||||
};
|
||||
|
||||
if (credentials && Object.keys(credentials).length > 0) {
|
||||
mcpServerConfig.credentials = credentials;
|
||||
}
|
||||
|
||||
const toolsResponse = await testMCPToolsListRequest(accessToken, mcpServerConfig);
|
||||
|
||||
if (toolsResponse.tools && !toolsResponse.error) {
|
||||
@@ -126,7 +160,16 @@ export const useTestMCPConnection = ({
|
||||
clearTools();
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [formValues.url, formValues.transport, formValues.auth_type, accessToken, enabled, canFetchTools, staticHeadersKey]);
|
||||
}, [
|
||||
formValues.url,
|
||||
formValues.transport,
|
||||
formValues.auth_type,
|
||||
accessToken,
|
||||
enabled,
|
||||
canFetchTools,
|
||||
staticHeadersKey,
|
||||
credentialsKey,
|
||||
]);
|
||||
|
||||
return {
|
||||
tools,
|
||||
|
||||
Reference in New Issue
Block a user