Feat/persist mcp credentials in db (#16308)

* feat: persist mcp credentials in db

* feat: remove Auth Value field from MCP Tool Testing Playground

* fix: test
This commit is contained in:
YutaSaito
2025-11-08 12:22:49 +09:00
committed by GitHub
parent b6f792f301
commit 6eb74bd62a
22 changed files with 529 additions and 352 deletions
@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "LiteLLM_MCPServerTable" ADD COLUMN "credentials" JSONB DEFAULT '{}';
@@ -174,6 +174,7 @@ model LiteLLM_MCPServerTable {
url String?
transport String @default("sse")
auth_type String?
credentials Json? @default("{}")
created_at DateTime? @default(now()) @map("created_at")
created_by String?
updated_at DateTime? @default(now()) @updatedAt @map("updated_at")
+67 -1
View File
@@ -1,4 +1,4 @@
from typing import Any, Dict, Iterable, List, Optional, Set, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast
from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
@@ -11,7 +11,12 @@ from litellm.proxy._types import (
UpdateMCPServerRequest,
UserAPIKeyAuth,
)
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
_get_salt_key,
encrypt_value_helper,
)
from litellm.proxy.utils import PrismaClient
from litellm.types.mcp import MCPCredentials
def _prepare_mcp_server_data(
@@ -35,6 +40,14 @@ def _prepare_mcp_server_data(
if "alias" not in data_dict:
data_dict["alias"] = getattr(data, "alias", None)
# Handle credentials serialization
credentials = data_dict.get("credentials")
if credentials is not None:
data_dict["credentials"] = encrypt_credentials(
credentials=credentials, encryption_key=_get_salt_key()
)
data_dict["credentials"] = safe_dumps(data_dict["credentials"])
# Handle static_headers serialization
if data.static_headers is not None:
data_dict["static_headers"] = safe_dumps(data.static_headers)
@@ -52,6 +65,30 @@ def _prepare_mcp_server_data(
return data_dict
def encrypt_credentials(
credentials: MCPCredentials, encryption_key: Optional[str]
) -> MCPCredentials:
auth_value = credentials.get("auth_value")
if auth_value is not None:
credentials["auth_value"] = encrypt_value_helper(
value=auth_value,
new_encryption_key=encryption_key,
)
client_id = credentials.get("client_id")
if client_id is not None:
credentials["client_id"] = encrypt_value_helper(
value=client_id,
new_encryption_key=encryption_key,
)
client_secret = credentials.get("client_secret")
if client_secret is not None:
credentials["client_secret"] = encrypt_value_helper(
value=client_secret,
new_encryption_key=encryption_key,
)
return credentials
async def get_all_mcp_servers(
prisma_client: PrismaClient,
) -> List[LiteLLM_MCPServerTable]:
@@ -303,3 +340,32 @@ async def update_mcp_server(
)
return updated_mcp_server
async def rotate_mcp_server_credentials_master_key(
prisma_client: PrismaClient, touched_by: str, new_master_key: str
):
mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many()
for mcp_server in mcp_servers:
credentials = mcp_server.credentials
if not credentials:
continue
credentials_copy = dict(credentials)
encrypted_credentials = encrypt_credentials(
credentials=cast(MCPCredentials, credentials_copy),
encryption_key=new_master_key,
)
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
serialized_credentials = safe_dumps(encrypted_credentials)
await prisma_client.db.litellm_mcpservertable.update(
where={"server_id": mcp_server.server_id},
data={
"credentials": serialized_credentials,
"updated_by": touched_by,
},
)
@@ -38,6 +38,9 @@ from litellm.proxy._types import (
MCPTransportType,
UserAPIKeyAuth,
)
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
decrypt_value_helper,
)
from litellm.proxy.utils import ProxyLogging
from litellm.types.mcp import MCPAuth, MCPStdioConfig
from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer
@@ -400,6 +403,20 @@ class MCPServerManager:
static_headers_dict = _deserialize_json_dict(
getattr(mcp_server, "static_headers", None)
)
credentials_dict = _deserialize_json_dict(
getattr(mcp_server, "credentials", None)
)
encrypted_auth_value: Optional[str] = None
if credentials_dict:
encrypted_auth_value = credentials_dict.get("auth_value")
auth_value: Optional[str] = None
if encrypted_auth_value:
auth_value = decrypt_value_helper(
value=encrypted_auth_value,
key="auth_value",
)
# Use alias for name if present, else server_name
name_for_prefix = (
mcp_server.alias or mcp_server.server_name or mcp_server.server_id
@@ -422,6 +439,7 @@ class MCPServerManager:
url=mcp_server.url,
transport=cast(MCPTransportType, mcp_server.transport),
auth_type=cast(MCPAuthType, mcp_server.auth_type),
authentication_token=auth_value,
mcp_info=mcp_info,
extra_headers=getattr(mcp_server, "extra_headers", None),
static_headers=static_headers_dict,
+32 -1
View File
@@ -17,7 +17,13 @@ from typing_extensions import Required, TypedDict
from litellm._uuid import uuid
from litellm.types.integrations.slack_alerting import AlertType
from litellm.types.llms.openai import AllMessageValues, OpenAIFileObject
from litellm.types.mcp import MCPAuthType, MCPTransport, MCPTransportType
from litellm.types.mcp import (
MCPAuth,
MCPAuthType,
MCPCredentials,
MCPTransport,
MCPTransportType,
)
from litellm.types.mcp_server.mcp_server_manager import MCPInfo
from litellm.types.router import RouterErrors, UpdateRouterConfig
from litellm.types.secret_managers.main import KeyManagementSystem
@@ -959,6 +965,7 @@ class NewMCPServerRequest(LiteLLMPydanticObjectBase):
description: Optional[str] = None
transport: MCPTransportType = MCPTransport.sse
auth_type: Optional[MCPAuthType] = None
credentials: Optional[MCPCredentials] = None
url: Optional[str] = None
mcp_info: Optional[MCPInfo] = None
mcp_access_groups: List[str] = Field(default_factory=list)
@@ -985,6 +992,28 @@ class NewMCPServerRequest(LiteLLMPydanticObjectBase):
raise ValueError("url is required for HTTP/SSE transport")
return values
@model_validator(mode="before")
@classmethod
def validate_credentials_requirements(cls, values):
if not isinstance(values, dict):
return values
auth_type = values.get("auth_type")
if auth_type in {MCPAuth.api_key, MCPAuth.bearer_token, MCPAuth.basic}:
credentials = values.get("credentials")
auth_value = None
if isinstance(credentials, dict):
auth_value = credentials.get("auth_value")
elif hasattr(credentials, "get"):
auth_value = credentials.get("auth_value") # type: ignore[attr-defined]
if not auth_value:
raise ValueError(
"auth_value is required when auth_type is api_key, bearer_token, or basic"
)
return values
class UpdateMCPServerRequest(LiteLLMPydanticObjectBase):
server_id: str
@@ -993,6 +1022,7 @@ class UpdateMCPServerRequest(LiteLLMPydanticObjectBase):
description: Optional[str] = None
transport: MCPTransportType = MCPTransport.sse
auth_type: Optional[MCPAuthType] = None
credentials: Optional[MCPCredentials] = None
url: Optional[str] = None
mcp_info: Optional[MCPInfo] = None
mcp_access_groups: List[str] = Field(default_factory=list)
@@ -1028,6 +1058,7 @@ class LiteLLM_MCPServerTable(LiteLLMPydanticObjectBase):
url: Optional[str] = None
transport: MCPTransportType
auth_type: Optional[MCPAuthType] = None
credentials: Optional[MCPCredentials] = None
created_at: Optional[datetime] = None
created_by: Optional[str] = None
updated_at: Optional[datetime] = None
@@ -24,8 +24,15 @@ import litellm
from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
from litellm.caching import DualCache
from litellm.constants import LENGTH_OF_LITELLM_GENERATED_KEY, UI_SESSION_TOKEN_TEAM_ID
from litellm.constants import (
LENGTH_OF_LITELLM_GENERATED_KEY,
LITELLM_PROXY_ADMIN_NAME,
UI_SESSION_TOKEN_TEAM_ID,
)
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.proxy._experimental.mcp_server.db import (
rotate_mcp_server_credentials_master_key,
)
from litellm.proxy._types import *
from litellm.proxy._types import LiteLLM_VerificationToken
from litellm.proxy.auth.auth_checks import (
@@ -640,9 +647,9 @@ async def _common_key_generation_helper( # noqa: PLR0915
request_type="key", **data_json, table_name="key"
)
response["soft_budget"] = (
data.soft_budget
) # include the user-input soft budget in the response
response[
"soft_budget"
] = data.soft_budget # include the user-input soft budget in the response
response = GenerateKeyResponse(**response)
@@ -1299,7 +1306,6 @@ async def prepare_key_update_data(
data: Union[UpdateKeyRequest, RegenerateKeyRequest],
existing_key_row: LiteLLM_VerificationToken,
):
data_json: dict = data.model_dump(exclude_unset=True)
data_json.pop("key", None)
data_json.pop("new_key", None)
@@ -2357,10 +2363,10 @@ async def delete_verification_tokens(
try:
if prisma_client:
tokens = [_hash_token_if_needed(token=key) for key in tokens]
_keys_being_deleted: List[LiteLLM_VerificationToken] = (
await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": tokens}}
)
_keys_being_deleted: List[
LiteLLM_VerificationToken
] = await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": tokens}}
)
if len(_keys_being_deleted) == 0:
@@ -2468,9 +2474,9 @@ async def _rotate_master_key(
from litellm.proxy.proxy_server import proxy_config
try:
models: Optional[List] = (
await prisma_client.db.litellm_proxymodeltable.find_many()
)
models: Optional[
List
] = await prisma_client.db.litellm_proxymodeltable.find_many()
except Exception:
models = None
# 2. process model table
@@ -2524,6 +2530,13 @@ async def _rotate_master_key(
data={"param_value": jsonify_object(encrypted_env_vars)},
)
# 4. process MCP server table
await rotate_mcp_server_credentials_master_key(
prisma_client=prisma_client,
touched_by=user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME,
new_master_key=new_master_key,
)
def get_new_token(data: Optional[RegenerateKeyRequest]) -> str:
if data and data.new_key is not None:
@@ -2781,11 +2794,11 @@ async def validate_key_list_check(
param="user_id",
code=status.HTTP_403_FORBIDDEN,
)
complete_user_info_db_obj: Optional[BaseModel] = (
await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_api_key_dict.user_id},
include={"organization_memberships": True},
)
complete_user_info_db_obj: Optional[
BaseModel
] = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_api_key_dict.user_id},
include={"organization_memberships": True},
)
if complete_user_info_db_obj is None:
@@ -2871,10 +2884,10 @@ async def get_admin_team_ids(
if complete_user_info is None:
return []
# Get all teams that user is an admin of
teams: Optional[List[BaseModel]] = (
await prisma_client.db.litellm_teamtable.find_many(
where={"team_id": {"in": complete_user_info.teams}}
)
teams: Optional[
List[BaseModel]
] = await prisma_client.db.litellm_teamtable.find_many(
where={"team_id": {"in": complete_user_info.teams}}
)
if teams is None:
return []
@@ -58,6 +58,24 @@ if MCP_AVAILABLE:
from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
def _redact_mcp_credentials(
mcp_server: LiteLLM_MCPServerTable,
) -> LiteLLM_MCPServerTable:
"""Return a copy of the MCP server object with credentials removed."""
try:
redacted_server = mcp_server.model_copy(deep=True)
except AttributeError:
redacted_server = mcp_server.copy(deep=True) # type: ignore[attr-defined]
redacted_server.credentials = None
return redacted_server
def _redact_mcp_credentials_list(
mcp_servers: Iterable[LiteLLM_MCPServerTable],
) -> List[LiteLLM_MCPServerTable]:
return [_redact_mcp_credentials(server) for server in mcp_servers]
def get_prisma_client_or_throw(message: str):
from litellm.proxy.proxy_server import prisma_client
@@ -273,11 +291,12 @@ if MCP_AVAILABLE:
```
"""
# Use server manager to get all servers with health and team data
return (
mcp_servers = (
await global_mcp_server_manager.get_all_mcp_servers_with_health_and_teams(
user_api_key_auth=user_api_key_dict
)
)
return _redact_mcp_credentials_list(mcp_servers)
@router.get(
"/server/{server_id}",
@@ -335,7 +354,7 @@ if MCP_AVAILABLE:
# Implement authz restriction from requested user
if _user_has_admin_view(user_api_key_dict):
return mcp_server
return _redact_mcp_credentials(mcp_server)
# Perform authz check to filter the mcp servers user has access to
mcp_server_records = await get_all_mcp_servers_for_user(
@@ -345,7 +364,7 @@ if MCP_AVAILABLE:
if exists:
global_mcp_server_manager.add_update_server(mcp_server)
return mcp_server
return _redact_mcp_credentials(mcp_server)
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@@ -428,7 +447,7 @@ if MCP_AVAILABLE:
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"error": f"Error creating mcp server: {str(e)}"},
)
return new_mcp_server
return _redact_mcp_credentials(new_mcp_server)
@router.delete(
"/server/{server_id}",
@@ -563,4 +582,4 @@ if MCP_AVAILABLE:
if litellm.store_audit_logs:
pass
return mcp_server_record_updated
return _redact_mcp_credentials(mcp_server_record_updated)
+1
View File
@@ -174,6 +174,7 @@ model LiteLLM_MCPServerTable {
url String?
transport String @default("sse")
auth_type String?
credentials Json? @default("{}")
created_at DateTime? @default(now()) @map("created_at")
created_by String?
updated_at DateTime? @default(now()) @updatedAt @map("updated_at")
+22
View File
@@ -54,6 +54,28 @@ MCPAuthType = Optional[
]
class MCPCredentials(TypedDict, total=False):
auth_value: Optional[str]
"""
Authentication value
"""
client_id: Optional[str]
"""
OAuth 2.0 client identifier used when auth_type is oauth2
"""
client_secret: Optional[str]
"""
OAuth 2.0 client secret used when auth_type is oauth2
"""
scopes: Optional[List[str]]
"""
OAuth 2.0 scopes to request when exchanging the client credentials
"""
class MCPServerCostInfo(TypedDict, total=False):
default_cost_per_query: Optional[float]
"""
+1
View File
@@ -174,6 +174,7 @@ model LiteLLM_MCPServerTable {
url String?
transport String @default("sse")
auth_type String?
credentials Json? @default("{}")
created_at DateTime? @default(now()) @map("created_at")
created_by String?
updated_at DateTime? @default(now()) @updatedAt @map("updated_at")
+3
View File
@@ -1342,6 +1342,7 @@ def test_add_update_server_with_alias():
mock_mcp_server.url = "https://test-server.com/mcp"
mock_mcp_server.transport = MCPTransport.http
mock_mcp_server.auth_type = None
mock_mcp_server.credentials = {}
mock_mcp_server.description = "Test server description"
mock_mcp_server.mcp_info = {}
mock_mcp_server.static_headers = {}
@@ -1380,6 +1381,7 @@ def test_add_update_server_without_alias():
mock_mcp_server.url = "https://test-server.com/mcp"
mock_mcp_server.transport = MCPTransport.http
mock_mcp_server.auth_type = None
mock_mcp_server.credentials = {}
mock_mcp_server.description = "Test server description"
mock_mcp_server.mcp_info = {}
mock_mcp_server.static_headers = {}
@@ -1418,6 +1420,7 @@ def test_add_update_server_fallback_to_server_id():
mock_mcp_server.url = "https://test-server.com/mcp"
mock_mcp_server.transport = MCPTransport.http
mock_mcp_server.auth_type = None
mock_mcp_server.credentials = {}
mock_mcp_server.description = "Test server description"
mock_mcp_server.mcp_info = {}
mock_mcp_server.static_headers = {}
@@ -15,6 +15,7 @@ from litellm.proxy._types import (
MCPTransportType,
MCPTransport,
NewMCPServerRequest,
UpdateMCPServerRequest,
LiteLLM_MCPServerTable,
LitellmUserRoles,
UserAPIKeyAuth,
@@ -165,6 +166,7 @@ async def test_create_mcp_server_direct():
updated_by=LITELLM_PROXY_ADMIN_NAME,
teams=[],
)
expected_response.credentials = {"auth_value": "secret"}
# Mock the database calls
mock_get_server.return_value = None # Server doesn't exist yet
@@ -188,6 +190,7 @@ async def test_create_mcp_server_direct():
assert result.alias == expected_alias # Check against normalized alias
assert result.url == mcp_server_request.url
assert result.transport == mcp_server_request.transport
assert result.credentials is None
# Verify mocks were called
mock_get_server.assert_called_once_with(mock_prisma, server_id)
@@ -353,6 +356,69 @@ async def test_create_mcp_server_invalid_alias():
)
@pytest.mark.asyncio
async def test_edit_mcp_server_redacts_credentials():
with mock.patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.MCP_AVAILABLE",
True,
), mock.patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw"
) as mock_get_prisma, mock.patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.update_mcp_server",
new_callable=mock.AsyncMock,
) as mock_update, mock.patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.validate_and_normalize_mcp_server_payload",
autospec=True,
) as mock_validate, mock.patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager"
) as mock_manager:
from litellm.proxy.management_endpoints.mcp_management_endpoints import (
edit_mcp_server,
)
mock_prisma = mock.Mock()
mock_get_prisma.return_value = mock_prisma
mock_manager.add_update_server = mock.Mock()
mock_manager.reload_servers_from_database = mock.AsyncMock()
server_id = str(uuid.uuid4())
updated_server = LiteLLM_MCPServerTable(
server_id=server_id,
alias="Updated Server",
url="https://updated.example.com/mcp",
transport=MCPTransport.http,
created_at=datetime.now(),
updated_at=datetime.now(),
teams=[],
)
updated_server.credentials = {"auth_value": "secret"}
mock_update.return_value = updated_server
payload = UpdateMCPServerRequest(
server_id=server_id,
alias="Updated Server",
url="https://updated.example.com/mcp",
transport=MCPTransport.http,
)
user_auth = UserAPIKeyAuth(
api_key=TEST_MASTER_KEY,
user_id="test-user",
user_role=LitellmUserRoles.PROXY_ADMIN,
)
result = await edit_mcp_server(payload=payload, user_api_key_dict=user_auth)
assert result.server_id == server_id
assert result.credentials is None
assert updated_server.credentials == {"auth_value": "secret"}
mock_validate.assert_called_once()
mock_update.assert_awaited_once()
mock_manager.add_update_server.assert_called_once_with(updated_server)
mock_manager.reload_servers_from_database.assert_awaited_once()
def test_validate_mcp_server_name_direct():
"""
Test the validation function directly to ensure it works.
@@ -186,6 +186,9 @@ class TestListMCPServers:
return_value=mock_servers_with_health
)
for idx, server in enumerate(mock_servers_with_health):
server.credentials = {"auth_value": f"secret_{idx}"}
with patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager",
mock_manager,
@@ -205,6 +208,7 @@ class TestListMCPServers:
# Verify results
assert len(result) == 2
assert all(server.credentials is None for server in result)
# Check that both config servers are returned
server_ids = [server.server_id for server in result]
@@ -315,6 +319,9 @@ class TestListMCPServers:
return_value=mock_servers_with_health
)
for idx, server in enumerate(mock_servers_with_health):
server.credentials = {"auth_value": f"secret_{idx}"}
with patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager",
mock_manager,
@@ -334,6 +341,7 @@ class TestListMCPServers:
# Verify results
assert len(result) == 4
assert all(server.credentials is None for server in result)
# Check that both DB and config servers are returned
server_ids = [server.server_id for server in result]
@@ -428,6 +436,9 @@ class TestListMCPServers:
return_value=mock_servers_with_health
)
for idx, server in enumerate(mock_servers_with_health):
server.credentials = {"auth_value": f"secret_{idx}"}
with patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager",
mock_manager,
@@ -447,6 +458,7 @@ class TestListMCPServers:
# Verify results - should only return servers user has access to
assert len(result) == 2
assert all(server.credentials is None for server in result)
# Check that only allowed servers are returned
server_ids = [server.server_id for server in result]
@@ -464,6 +476,51 @@ class TestListMCPServers:
assert server.url == "https://actions.zapier.com/mcp/sse"
@pytest.mark.asyncio
async def test_fetch_single_mcp_server_redacts_credentials(self):
mock_server = generate_mock_mcp_server_db_record(
server_id="server-1", alias="Server 1"
)
mock_server.credentials = {"auth_value": "top-secret"}
mock_prisma_client = MagicMock()
mock_health_result = {
"status": "healthy",
"last_health_check": datetime.now().isoformat(),
"error": None,
}
mock_user_auth = generate_mock_user_api_key_auth(
user_role=LitellmUserRoles.PROXY_ADMIN
)
with patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw",
return_value=mock_prisma_client,
), patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.get_mcp_server",
AsyncMock(return_value=mock_server),
), patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager.health_check_server",
AsyncMock(return_value=mock_health_result),
), patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints._user_has_admin_view",
return_value=True,
):
from litellm.proxy.management_endpoints.mcp_management_endpoints import (
fetch_mcp_server,
)
result = await fetch_mcp_server(
server_id="server-1", user_api_key_dict=mock_user_auth
)
assert result.server_id == "server-1"
assert result.credentials is None
assert mock_server.credentials == {"auth_value": "top-secret"}
assert result.status == "healthy"
class TestMCPHealthCheckEndpoints:
"""Test MCP health check endpoints"""
@@ -721,6 +778,8 @@ class TestMCPHealthCheckEndpoints:
return_value=[mock_server]
)
mock_server.credentials = {"auth_value": "secret"}
mock_user_auth = generate_mock_user_api_key_auth(
user_role=LitellmUserRoles.PROXY_ADMIN
)
@@ -749,3 +808,4 @@ class TestMCPHealthCheckEndpoints:
assert server.status == "healthy"
assert server.last_health_check is not None
assert server.health_check_error is None
assert server.credentials is None
@@ -7,8 +7,6 @@ import NotificationsManager from "../molecules/notifications_manager";
export function ToolTestPanel({
tool,
needsAuth,
authValue,
onSubmit,
isLoading,
result,
@@ -16,8 +14,6 @@ export function ToolTestPanel({
onClose,
}: {
tool: MCPTool;
needsAuth: boolean;
authValue?: string | null;
onSubmit: (args: Record<string, any>) => void;
isLoading: boolean;
result: any | null;
@@ -3,7 +3,7 @@ import { Modal, Tooltip, Form, Select } from "antd";
import { InfoCircleOutlined } from "@ant-design/icons";
import { Button, TextInput } from "@tremor/react";
import { createMCPServer } from "../networking";
import { MCPServer, MCPServerCostInfo } from "./types";
import { AUTH_TYPE, MCPServer, MCPServerCostInfo } from "./types";
import MCPServerCostConfig from "./mcp_server_cost_config";
import MCPConnectionStatus from "./mcp_connection_status";
import MCPToolConfiguration from "./mcp_tool_configuration";
@@ -25,6 +25,12 @@ interface CreateMCPServerProps {
availableAccessGroups: string[];
}
const AUTH_TYPES_REQUIRING_AUTH_VALUE = [
AUTH_TYPE.API_KEY,
AUTH_TYPE.BEARER_TOKEN,
AUTH_TYPE.BASIC,
];
const CreateMCPServer: React.FC<CreateMCPServerProps> = ({
userRole,
accessToken,
@@ -43,6 +49,10 @@ const CreateMCPServer: React.FC<CreateMCPServerProps> = ({
const [transportType, setTransportType] = useState<string>("");
const [searchValue, setSearchValue] = useState<string>("");
const [urlWarning, setUrlWarning] = useState<string>("");
const authType = formValues.auth_type as string | undefined;
const shouldShowAuthValueField = authType
? AUTH_TYPES_REQUIRING_AUTH_VALUE.includes(authType)
: false;
// Function to check URL format based on transport type
const checkUrlFormat = (url: string, transport: string) => {
@@ -63,7 +73,12 @@ const CreateMCPServer: React.FC<CreateMCPServerProps> = ({
const handleCreate = async (values: Record<string, any>) => {
setIsLoading(true);
try {
const { static_headers: staticHeadersList, stdio_config: rawStdioConfig, ...restValues } = values;
const {
static_headers: staticHeadersList,
stdio_config: rawStdioConfig,
credentials: credentialValues,
...restValues
} = values;
// Transform access groups into objects with name property
const accessGroups = restValues.mcp_access_groups;
@@ -79,6 +94,26 @@ const CreateMCPServer: React.FC<CreateMCPServerProps> = ({
}, {})
: {} as Record<string, string>;
const credentialsPayload =
credentialValues && typeof credentialValues === "object"
? Object.entries(credentialValues).reduce((acc: Record<string, any>, [key, value]) => {
if (value === undefined || value === null || value === "") {
return acc;
}
if (key === "scopes") {
if (Array.isArray(value)) {
const filteredScopes = value.filter((scope) => scope != null && scope !== "");
if (filteredScopes.length > 0) {
acc[key] = filteredScopes;
}
}
} else {
acc[key] = value;
}
return acc;
}, {})
: undefined;
// Process stdio configuration if present
let stdioFields = {};
if (rawStdioConfig && transportType === "stdio") {
@@ -135,6 +170,18 @@ const CreateMCPServer: React.FC<CreateMCPServerProps> = ({
static_headers: staticHeaders,
};
payload.static_headers = staticHeaders;
const includeCredentials =
restValues.auth_type && AUTH_TYPES_REQUIRING_AUTH_VALUE.includes(restValues.auth_type);
if (
includeCredentials &&
credentialsPayload &&
Object.keys(credentialsPayload).length > 0
) {
payload.credentials = credentialsPayload;
}
console.log(`Payload: ${JSON.stringify(payload)}`);
if (accessToken != null) {
@@ -393,6 +440,27 @@ const CreateMCPServer: React.FC<CreateMCPServerProps> = ({
</Form.Item>
)}
{transportType !== "stdio" && shouldShowAuthValueField && (
<Form.Item
label={
<span className="text-sm font-medium text-gray-700 flex items-center">
Authentication Value
<Tooltip title="Token, password, or header value to send with each request for the selected auth type.">
<InfoCircleOutlined className="ml-2 text-blue-400 hover:text-blue-600 cursor-help" />
</Tooltip>
</span>
}
name={["credentials", "auth_value"]}
rules={[{ required: true, message: "Please enter the authentication value" }]}
>
<TextInput
type="password"
placeholder="Enter token or secret"
className="rounded-lg border-gray-300 focus:border-blue-500 focus:ring-blue-500"
/>
</Form.Item>
)}
{/* Stdio Configuration - only show for stdio transport */}
<StdioConfiguration isVisible={transportType === "stdio"} />
</div>
@@ -1,106 +0,0 @@
// Utility functions for managing MCP server authentication tokens in localStorage
const MCP_AUTH_STORAGE_KEY = "litellm_mcp_auth_tokens";
export interface MCPAuthToken {
serverId: string;
serverAlias?: string;
authValue: string;
authType: string;
timestamp: number;
}
export interface MCPAuthStorage {
[serverId: string]: MCPAuthToken;
}
/**
* Get all stored MCP authentication tokens
*/
export const getMCPAuthTokens = (): MCPAuthStorage => {
try {
const stored = localStorage.getItem(MCP_AUTH_STORAGE_KEY);
return stored ? JSON.parse(stored) : {};
} catch (error) {
console.error("Error reading MCP auth tokens from localStorage:", error);
return {};
}
};
/**
* Get authentication token for a specific MCP server
*/
export const getMCPAuthToken = (serverId: string, serverAlias?: string): string | null => {
try {
const tokens = getMCPAuthTokens();
const token = tokens[serverId];
// If token exists, check if serverAlias matches (both can be undefined)
if (token && token.serverAlias === serverAlias) {
return token.authValue;
}
// If no serverAlias was provided and token exists without serverAlias, return it
if (token && !serverAlias && !token.serverAlias) {
return token.authValue;
}
return null;
} catch (error) {
console.error("Error getting MCP auth token:", error);
return null;
}
};
/**
* Store authentication token for an MCP server
*/
export const setMCPAuthToken = (serverId: string, authValue: string, authType: string, serverAlias?: string): void => {
try {
const tokens = getMCPAuthTokens();
tokens[serverId] = {
serverId,
serverAlias,
authValue,
authType,
timestamp: Date.now(),
};
localStorage.setItem(MCP_AUTH_STORAGE_KEY, JSON.stringify(tokens));
} catch (error) {
console.error("Error storing MCP auth token:", error);
}
};
/**
* Remove authentication token for an MCP server
*/
export const removeMCPAuthToken = (serverId: string): void => {
try {
const tokens = getMCPAuthTokens();
delete tokens[serverId];
localStorage.setItem(MCP_AUTH_STORAGE_KEY, JSON.stringify(tokens));
} catch (error) {
console.error("Error removing MCP auth token:", error);
}
};
/**
* Clear all MCP authentication tokens (useful for logout)
*/
export const clearMCPAuthTokens = (): void => {
try {
localStorage.removeItem(MCP_AUTH_STORAGE_KEY);
} catch (error) {
console.error("Error clearing MCP auth tokens:", error);
}
};
/**
* Check if a token exists for a server
*/
export const hasMCPAuthToken = (serverId: string, serverAlias?: string): boolean => {
const token = getMCPAuthToken(serverId, serverAlias);
return token !== null;
};
@@ -1,7 +1,8 @@
import React, { useState, useEffect } from "react";
import { Form, Select, Button as AntdButton } from "antd";
import { Form, Select, Button as AntdButton, Tooltip } from "antd";
import { InfoCircleOutlined } from "@ant-design/icons";
import { Button, TextInput, TabGroup, TabList, Tab, TabPanels, TabPanel } from "@tremor/react";
import { MCPServer, MCPServerCostInfo } from "./types";
import { AUTH_TYPE, MCPServer, MCPServerCostInfo } from "./types";
import { updateMCPServer, testMCPToolsListRequest } from "../networking";
import MCPServerCostConfig from "./mcp_server_cost_config";
import MCPPermissionManagement from "./MCPPermissionManagement";
@@ -17,6 +18,12 @@ interface MCPServerEditProps {
availableAccessGroups: string[];
}
const AUTH_TYPES_REQUIRING_AUTH_VALUE = [
AUTH_TYPE.API_KEY,
AUTH_TYPE.BEARER_TOKEN,
AUTH_TYPE.BASIC,
];
const MCPServerEdit: React.FC<MCPServerEditProps> = ({
mcpServer,
accessToken,
@@ -31,6 +38,10 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
const [searchValue, setSearchValue] = useState<string>("");
const [aliasManuallyEdited, setAliasManuallyEdited] = useState(false);
const [allowedTools, setAllowedTools] = useState<string[]>([]);
const authType = Form.useWatch("auth_type", form) as string | undefined;
const shouldShowAuthValueField = authType
? AUTH_TYPES_REQUIRING_AUTH_VALUE.includes(authType)
: false;
const initialStaticHeaders = React.useMemo(() => {
if (!mcpServer.static_headers) {
@@ -148,7 +159,11 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
if (!accessToken) return;
try {
// Ensure access groups is always a string array
const { static_headers: staticHeadersList, ...restValues } = values;
const {
static_headers: staticHeadersList,
credentials: credentialValues,
...restValues
} = values;
const accessGroups = (restValues.mcp_access_groups || []).map((g: any) =>
typeof g === "string" ? g : g.name || String(g),
@@ -165,6 +180,26 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
}, {})
: {} as Record<string, string>;
const credentialsPayload =
credentialValues && typeof credentialValues === "object"
? Object.entries(credentialValues).reduce((acc: Record<string, any>, [key, value]) => {
if (value === undefined || value === null || value === "") {
return acc;
}
if (key === "scopes") {
if (Array.isArray(value)) {
const filteredScopes = value.filter((scope) => scope != null && scope !== "");
if (filteredScopes.length > 0) {
acc[key] = filteredScopes;
}
}
} else {
acc[key] = value;
}
return acc;
}, {})
: undefined;
// Prepare the payload with cost configuration and permission fields
const payload = {
...restValues,
@@ -183,6 +218,17 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
static_headers: staticHeaders,
};
const includeCredentials =
restValues.auth_type && AUTH_TYPES_REQUIRING_AUTH_VALUE.includes(restValues.auth_type);
if (
includeCredentials &&
credentialsPayload &&
Object.keys(credentialsPayload).length > 0
) {
payload.credentials = credentialsPayload;
}
const updated = await updateMCPServer(accessToken, payload);
NotificationsManager.success("MCP Server updated successfully");
onSuccess(updated);
@@ -250,6 +296,34 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
</Select>
</Form.Item>
{shouldShowAuthValueField && (
<Form.Item
label={
<span className="text-sm font-medium text-gray-700 flex items-center">
Authentication Value
<Tooltip title="Token, password, or header value to send with each request for the selected auth type.">
<InfoCircleOutlined className="ml-2 text-blue-400 hover:text-blue-600 cursor-help" />
</Tooltip>
</span>
}
name={["credentials", "auth_value"]}
rules={[
{
validator: (_, value) =>
value && typeof value === "string" && value.trim() === ""
? Promise.reject(new Error("Authentication value cannot be empty"))
: Promise.resolve(),
},
]}
>
<TextInput
type="password"
placeholder="Enter token or secret (leave blank to keep existing)"
className="rounded-lg border-gray-300 focus:border-blue-500 focus:ring-blue-500"
/>
</Form.Item>
)}
{/* Permission Management / Access Control Section */}
<div className="mt-6">
<MCPPermissionManagement
@@ -1,122 +1,15 @@
import React, { useState, useEffect } from "react";
import { useQuery, useMutation } from "@tanstack/react-query";
import { ToolTestPanel } from "./ToolTestPanel";
import { MCPTool, MCPToolsViewerProps, CallMCPToolResponse, mcpServerHasAuth } from "./types";
import { MCPTool, MCPToolsViewerProps, CallMCPToolResponse } from "./types";
import { listMCPTools, callMCPTool } from "../networking";
import { getMCPAuthToken, setMCPAuthToken, removeMCPAuthToken } from "./mcp_auth_storage";
import { Modal, Input, Form } from "antd";
import { Button, Card, Title, Text } from "@tremor/react";
import { RobotOutlined, SafetyOutlined, ToolOutlined } from "@ant-design/icons";
import { AUTH_TYPE } from "./types";
import NotificationsManager from "../molecules/notifications_manager";
type AuthModalProps = {
visible: boolean;
onOk: (values: any) => void;
onCancel: () => void;
authType?: string | null;
};
export const AuthModal = ({ visible, onOk, onCancel, authType }: AuthModalProps) => {
const [form] = Form.useForm();
// Handler for modal OK
const handleOk = () => {
form.validateFields().then((values) => {
if (authType === AUTH_TYPE.BASIC) {
onOk(`${values.username.trim()}:${values.password.trim()}`);
} else {
onOk(values.authValue.trim());
}
});
};
let content;
if (authType === AUTH_TYPE.API_KEY || authType === AUTH_TYPE.BEARER_TOKEN) {
const label = authType === AUTH_TYPE.API_KEY ? "API Key" : "Bearer Token";
content = (
<Form.Item name="authValue" label={label} rules={[{ required: true, message: `Please input your ${label}` }]}>
<Input.Password />
</Form.Item>
);
} else if (authType === AUTH_TYPE.BASIC) {
content = (
<>
<Form.Item name="username" label="Username" rules={[{ required: true, message: "Please input your username" }]}>
<Input />
</Form.Item>
<Form.Item name="password" label="Password" rules={[{ required: true, message: "Please input your password" }]}>
<Input.Password />
</Form.Item>
</>
);
}
return (
<Modal open={visible} title="Authentication" onOk={handleOk} onCancel={onCancel} destroyOnClose>
<Form form={form} layout="vertical">
{content}
</Form>
</Modal>
);
};
const AuthSection = ({
authType,
onAuthSubmit,
onClearAuth,
hasAuth,
}: {
authType: string | null | undefined;
onAuthSubmit: (value: string) => void;
onClearAuth: () => void;
hasAuth: boolean;
}) => {
const [modalVisible, setModalVisible] = useState(false);
const handleAddAuth = () => setModalVisible(true);
const handleModalOk = (authValue: string) => {
onAuthSubmit(authValue);
setModalVisible(false);
};
const handleModalCancel = () => setModalVisible(false);
const handleClearAuth = () => {
onClearAuth();
};
return (
<div className="space-y-3">
<div className="flex items-center justify-between">
<Text className="text-sm font-medium text-gray-700">Authentication {hasAuth ? "✓" : ""}</Text>
<div className="flex gap-2">
{hasAuth && (
<Button
onClick={handleClearAuth}
size="sm"
variant="secondary"
className="text-xs text-red-600 hover:text-red-700"
>
Clear
</Button>
)}
<Button onClick={handleAddAuth} size="sm" variant="secondary" className="text-xs">
{hasAuth ? "Update" : "Add Auth"}
</Button>
</div>
</div>
<Text className="text-xs text-gray-500">
{hasAuth ? "Authentication configured and saved locally" : "Some tools may require authentication"}
</Text>
<AuthModal visible={modalVisible} onOk={handleModalOk} onCancel={handleModalCancel} authType={authType} />
</div>
);
};
const MCPToolsViewer = ({
serverId,
accessToken,
@@ -125,47 +18,20 @@ const MCPToolsViewer = ({
userID,
serverAlias, // Add serverAlias prop
}: MCPToolsViewerProps) => {
const [mcpAuthValue, setMcpAuthValue] = useState("");
const [selectedTool, setSelectedTool] = useState<MCPTool | null>(null);
const [toolResult, setToolResult] = useState<CallMCPToolResponse | null>(null);
const [toolError, setToolError] = useState<Error | null>(null);
// Load stored auth token on component mount
useEffect(() => {
if (mcpServerHasAuth(auth_type)) {
const storedAuthValue = getMCPAuthToken(serverId, serverAlias || undefined);
if (storedAuthValue) {
setMcpAuthValue(storedAuthValue);
}
}
}, [serverId, serverAlias, auth_type]);
// Function to handle auth submission with localStorage persistence
const handleAuthSubmit = (authValue: string) => {
setMcpAuthValue(authValue);
if (authValue && mcpServerHasAuth(auth_type)) {
setMCPAuthToken(serverId, authValue, auth_type || "none", serverAlias || undefined);
NotificationsManager.success("Authentication token saved locally");
}
};
// Function to clear auth token
const handleClearAuth = () => {
setMcpAuthValue("");
removeMCPAuthToken(serverId);
NotificationsManager.info("Authentication token cleared");
};
// Query to fetch MCP tools
const {
data: mcpToolsResponse,
isLoading: isLoadingTools,
error: mcpToolsError,
} = useQuery({
queryKey: ["mcpTools", serverId, mcpAuthValue, serverAlias],
queryKey: ["mcpTools", serverId],
queryFn: () => {
if (!accessToken) throw new Error("Access Token required");
return listMCPTools(accessToken, serverId, mcpAuthValue, serverAlias || undefined);
return listMCPTools(accessToken, serverId);
},
enabled: !!accessToken,
staleTime: 30000, // Consider data fresh for 30 seconds
@@ -173,7 +39,7 @@ const MCPToolsViewer = ({
// Mutation for calling a tool
const { mutate: executeTool, isPending: isCallingTool } = useMutation({
mutationFn: async (args: { tool: MCPTool; arguments: Record<string, any>; authValue: string }) => {
mutationFn: async (args: { tool: MCPTool; arguments: Record<string, any> }) => {
if (!accessToken) throw new Error("Access Token required");
try {
@@ -181,8 +47,6 @@ const MCPToolsViewer = ({
accessToken,
args.tool.name,
args.arguments,
args.authValue,
serverAlias || undefined,
);
return result;
} catch (error) {
@@ -200,7 +64,6 @@ const MCPToolsViewer = ({
});
const toolsData = mcpToolsResponse?.tools || [];
const hasAuth = mcpAuthValue !== "";
return (
<div className="w-full h-screen p-4 bg-white">
@@ -317,44 +180,6 @@ const MCPToolsViewer = ({
</div>
)}
</div>
{/* Authentication Section - Below tools list */}
{mcpServerHasAuth(auth_type) && (
<div className="pt-4 border-t border-gray-200 flex-shrink-0 mt-6">
{!hasAuth ? (
/* Prominent display when auth required but not provided */
<div className="p-4 bg-gradient-to-r from-orange-50 to-red-50 border border-orange-200 rounded-lg">
<div className="flex items-center mb-3">
<SafetyOutlined className="mr-2 text-orange-600 text-lg" />
<Text className="font-semibold text-orange-800">Authentication Required</Text>
</div>
<Text className="text-sm text-orange-700 mb-4">
This MCP server requires authentication. You must add your credentials below to access the
tools.
</Text>
<AuthSection
authType={auth_type}
onAuthSubmit={handleAuthSubmit}
onClearAuth={handleClearAuth}
hasAuth={hasAuth}
/>
</div>
) : (
/* Subtle display when already authenticated */
<>
<Text className="font-medium block mb-3 text-gray-700 flex items-center">
<SafetyOutlined className="mr-2" /> Authentication
</Text>
<AuthSection
authType={auth_type}
onAuthSubmit={handleAuthSubmit}
onClearAuth={handleClearAuth}
hasAuth={hasAuth}
/>
</>
)}
</div>
)}
</div>
</div>
@@ -379,10 +204,8 @@ const MCPToolsViewer = ({
<div className="h-full">
<ToolTestPanel
tool={selectedTool}
needsAuth={mcpServerHasAuth(auth_type)}
authValue={mcpAuthValue}
onSubmit={(args) => {
executeTool({ tool: selectedTool, arguments: args, authValue: mcpAuthValue });
executeTool({ tool: selectedTool, arguments: args });
}}
result={toolResult}
error={toolError}
@@ -34,10 +34,6 @@ export const handleAuth = (authType?: string | null): string => {
return authType;
};
export const mcpServerHasAuth = (authType?: string | null): boolean => {
return handleAuth(authType) !== AUTH_TYPE.NONE;
};
// Define the structure for tool input schema properties
export interface InputSchemaProperty {
type: string;
@@ -15,7 +15,6 @@ import {
import { clearTokenCookies } from "@/utils/cookieUtils";
import { fetchProxySettings } from "@/utils/proxyUtils";
import { useTheme } from "@/contexts/ThemeContext";
import { clearMCPAuthTokens } from "./mcp_tools/mcp_auth_storage";
import useFeatureFlags from "@/hooks/useFeatureFlags";
interface NavbarProps {
@@ -71,7 +70,6 @@ const Navbar: React.FC<NavbarProps> = ({
const handleLogout = () => {
clearTokenCookies();
clearMCPAuthTokens(); // Clear MCP auth tokens on logout
window.location.href = logoutUrl;
};
@@ -5750,7 +5750,7 @@ export const testSearchToolConnection = async (accessToken: string, litellmParam
}
};
export const listMCPTools = async (accessToken: string, serverId: string, authValue?: string, serverAlias?: string) => {
export const listMCPTools = async (accessToken: string, serverId: string) => {
try {
// Construct base URL
let url = proxyBaseUrl
@@ -5764,14 +5764,6 @@ export const listMCPTools = async (accessToken: string, serverId: string, authVa
"Content-Type": "application/json",
};
// Use new server-specific auth header format if serverAlias is provided
if (serverAlias && authValue) {
headers[`x-mcp-${serverAlias}-authorization`] = authValue;
} else if (authValue) {
// Fall back to deprecated x-mcp-auth header for backward compatibility
headers[MCP_AUTH_HEADER] = authValue;
}
const response = await fetch(url, {
method: "GET",
headers,
@@ -5805,9 +5797,7 @@ export const listMCPTools = async (accessToken: string, serverId: string, authVa
export const callMCPTool = async (
accessToken: string,
toolName: string,
toolArguments: Record<string, any>,
authValue: string,
serverAlias?: string,
toolArguments: Record<string, any>
) => {
try {
// Construct base URL
@@ -5820,14 +5810,6 @@ export const callMCPTool = async (
"Content-Type": "application/json",
};
// Use new server-specific auth header format if serverAlias is provided
if (serverAlias) {
headers[`x-mcp-${serverAlias}-authorization`] = authValue;
} else {
// Fall back to deprecated x-mcp-auth header for backward compatibility
headers[MCP_AUTH_HEADER] = authValue;
}
const response = await fetch(url, {
method: "POST",
headers,
@@ -9,6 +9,12 @@ interface MCPServerConfig {
auth_type?: string;
mcp_info?: any;
static_headers?: Record<string, string>;
credentials?: {
auth_value?: string;
client_id?: string;
client_secret?: string;
scopes?: string[];
};
}
interface UseTestMCPConnectionProps {
@@ -41,6 +47,7 @@ export const useTestMCPConnection = ({
const canFetchTools = !!(formValues.url && formValues.transport && formValues.auth_type && accessToken);
const staticHeadersKey = JSON.stringify(formValues.static_headers ?? {});
const credentialsKey = JSON.stringify(formValues.credentials ?? {});
const fetchTools = async () => {
if (!accessToken || !formValues.url) {
@@ -74,6 +81,29 @@ export const useTestMCPConnection = ({
)
: {} as Record<string, string>;
const credentials =
formValues.credentials && typeof formValues.credentials === "object"
? Object.entries(formValues.credentials).reduce(
(acc: Record<string, any>, [key, value]) => {
if (value === undefined || value === null || value === "") {
return acc;
}
if (key === "scopes") {
if (Array.isArray(value)) {
const normalizedScopes = value.filter((scope) => scope != null && scope !== "");
if (normalizedScopes.length > 0) {
acc[key] = normalizedScopes;
}
}
} else {
acc[key] = value;
}
return acc;
},
{},
)
: undefined;
const mcpServerConfig: MCPServerConfig = {
server_id: formValues.server_id || "",
server_name: formValues.server_name || "",
@@ -84,6 +114,10 @@ export const useTestMCPConnection = ({
static_headers: staticHeaders,
};
if (credentials && Object.keys(credentials).length > 0) {
mcpServerConfig.credentials = credentials;
}
const toolsResponse = await testMCPToolsListRequest(accessToken, mcpServerConfig);
if (toolsResponse.tools && !toolsResponse.error) {
@@ -126,7 +160,16 @@ export const useTestMCPConnection = ({
clearTools();
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [formValues.url, formValues.transport, formValues.auth_type, accessToken, enabled, canFetchTools, staticHeadersKey]);
}, [
formValues.url,
formValues.transport,
formValues.auth_type,
accessToken,
enabled,
canFetchTools,
staticHeadersKey,
credentialsKey,
]);
return {
tools,