Mcp user permissions (#21462)

* feat(schema.prisma): add object permissions for end users

allows controlling if end user can call specific mcp servers

* feat: cleanup for customer_endpoints support of object permission id

* fix: cleanup str

* feat(customers/): enforce end user can only call allowed mcps - if configured

* docs: document customer/end user object permission usage

* feat: enforce end user permissions on MCP tool calls

This commit implements end user permission enforcement for MCP servers:

1. Always add server prefixes to MCP tool names
   - Removed conditional logic that only added prefixes when multiple servers existed
   - Now always adds server prefix for consistent tool naming across all scenarios
   - Updated 5 locations in server.py (list_tools, get_prompts, get_resources,
     get_resource_templates, get_prompt)

2. Created MCP End User Permission Guardrail Hook
   - New guardrail hook: litellm/proxy/guardrails/guardrail_hooks/mcp_end_user_permission.py
   - Runs on post_call to validate tool calls in LLM responses
   - Extracts MCP server name from tool names (splits on first '-')
   - Checks if end_user_id has permissions for the MCP server
   - Raises GuardrailRaisedException if end user lacks permission
   - Supports both streaming and non-streaming responses

3. Added comprehensive tests
   - Test file: tests/test_litellm/proxy/guardrails/guardrail_hooks/test_mcp_end_user_permission.py
   - Tests cover: authorized/unauthorized tools, non-MCP tools, no end_user scenarios
   - Tests permission checking logic and exception raising

The hook integrates with the existing MCPRequestHandler._get_allowed_mcp_servers_for_end_user
to fetch end user permissions and enforce access control at the response level.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* refactor: remove redundant add_prefix variable assignments

Simplified the code by removing intermediate `add_prefix` variable
assignments and passing `True` directly to function calls since
we now always add server prefixes.

Changes:
- Removed `add_prefix = True` variable assignments in 5 locations
- Changed `add_prefix=add_prefix` to `add_prefix=True` in function calls
- Added inline comments to clarify the behavior

This makes the code more concise and clearer in intent.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* feat(auth_utils.py): support safety_identifier as a valid way of passing the end user id for responses api

* feat(llms): ensure 'tools' is correctly updated for responses api

* fix: fix greptile feedback

* feat: transformation.py

proper responses api tool handling for guardrail translation layer

---------

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
Krish Dholakia
2026-02-18 18:53:59 -08:00
committed by GitHub
parent 936e04e0e1
commit e00c181f0c
30 changed files with 1616 additions and 188 deletions
+36
View File
@@ -0,0 +1,36 @@
{
"permissions": {
"allow": [
"Bash(git show:*)",
"Bash(git worktree add:*)",
"Read(//Users/krrishdholakia/Documents/litellm/**)",
"Read(//Users/krrishdholakia/Documents/litellm-claude-code-guardrails/litellm/types/**)",
"Read(//Users/krrishdholakia/Documents/litellm-claude-code-guardrails/**)",
"Read(//Users/krrishdholakia/Documents/litellm-claude-code-guardrails/litellm/**)",
"Bash(python:*)",
"Bash(python -c \"\nimport sys; sys.path.insert\\(0, ''.''\\)\nfrom litellm.proxy.guardrails.guardrail_hooks.claude_code.guardrail import ClaudeCodeGuardrail, HOSTED_TOOL_PREFIXES\nprint\\(''HOSTED_TOOL_PREFIXES:'', HOSTED_TOOL_PREFIXES\\)\nprint\\(''ClaudeCodeGuardrail imported OK''\\)\n\")",
"Read(//Users/krrishdholakia/Documents/litellm-mcp-jwt-groups/litellm/proxy/**)",
"Read(//Users/krrishdholakia/Documents/litellm-mcp-jwt-groups/**)",
"Bash(poetry run pytest:*)",
"Bash(git add:*)",
"Bash(git commit:*)",
"Bash(poetry run python:*)",
"Bash(poetry run pip:*)",
"Bash(git reset:*)",
"Bash(git cherry-pick:*)",
"Bash(git checkout:*)",
"Read(//Users/krrishdholakia/Documents/litellm/litellm/proxy/guardrails/guardrail_hooks/**)",
"Read(//Users/krrishdholakia/Documents/**)",
"Bash(git -C /Users/krrishdholakia/Documents/litellm-mcp-user-permissions worktree list)",
"Bash(ls:*)"
],
"additionalDirectories": [
"/Users/krrishdholakia/Documents/litellm-mcp-group-plan/plan",
"/Users/krrishdholakia/Documents/litellm-claude-code-guardrails/litellm/proxy/guardrails/guardrail_hooks/claude_code",
"/Users/krrishdholakia/Documents/litellm-claude-code-guardrails/litellm/types",
"/Users/krrishdholakia/Documents/litellm-claude-code-guardrails",
"/Users/krrishdholakia/Documents/litellm-mcp-jwt-groups/litellm/proxy",
"/Users/krrishdholakia/Documents/litellm-mcp-jwt-groups/tests/test_litellm/proxy/auth"
]
}
}
+62
View File
@@ -808,6 +808,68 @@ If your stdio MCP server needs per-request credentials, you can map HTTP headers
In this example, when a client makes a request with the `X-GITHUB_PERSONAL_ACCESS_TOKEN` header, the proxy forwards that value into the stdio process as the `GITHUB_PERSONAL_ACCESS_TOKEN` environment variable.
## Control MCP Access for End Users
Control which MCP servers end users of your AI application can access (e.g. users of an internal chat UI). Pass the customer ID in the `x-litellm-end-user-id` header to:
- Enforce object permissions (limit which MCP servers they can access)
- Apply customer-specific budgets
- Track spend per customer
**FastMCP Client Example:**
```python title="Track customer spend with x-litellm-end-user-id" showLineNumbers
from fastmcp import Client
import asyncio
# MCP client configuration with customer tracking
config = {
"mcpServers": {
"github": {
"url": "http://localhost:4000/github_mcp/mcp",
"headers": {
"x-litellm-api-key": "Bearer sk-1234",
"x-litellm-end-user-id": "customer_123", # 👈 CUSTOMER ID
"Authorization": "Bearer gho_token"
}
}
}
}
client = Client(config)
async def main():
async with client:
# All MCP calls will be tracked under customer_123
tools = await client.list_tools()
result = await client.call_tool(tools[0].name, {})
print(f"Tool result: {result}")
asyncio.run(main())
```
**Cursor IDE Example:**
```json title="Cursor config with customer tracking" showLineNumbers
{
"mcpServers": {
"GitHub": {
"url": "http://localhost:4000/github_mcp/mcp",
"headers": {
"x-litellm-api-key": "Bearer $LITELLM_API_KEY",
"x-litellm-end-user-id": "customer_123"
}
}
}
}
```
**What happens:**
- Customer-specific object permissions are enforced (only allowed MCP servers are accessible)
- Customer budgets are applied
- All tool calls are tracked under `customer_123`
[Learn more about customer management →](./proxy/customers)
## Using your MCP with client side credentials
Use this if you want to pass a client side authentication token to LiteLLM to then pass to your MCP to auth to your MCP.
+248 -15
View File
@@ -2,29 +2,98 @@ import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Customers / End-User Budgets
# Customers / End-Users
Track spend, set budgets for your customers.
Track spend, set budgets and permissions for your customers.
## Tracking Customer Spend
## Tracking Customer Spend + Permissions
### 1. Make LLM API call w/ Customer ID
Make a /chat/completions call, pass 'user' - First call Works
LiteLLM checks for a customer/end-user ID in the following order (first match wins):
```bash showLineNumbers title="Make request with customer ID"
| Priority | Method | Where | Notes |
|----------|--------|-------|-------|
| 1 | `x-litellm-customer-id` header | Request headers | Standard header, always checked |
| 2 | `x-litellm-end-user-id` header | Request headers | Standard header, always checked |
| 3 | Custom header via `user_header_mappings` | Request headers | Configured in `general_settings` |
| 4 | Custom header via `user_header_name` | Request headers | Deprecated — use `user_header_mappings` |
| 5 | `user` field | Request body | Standard OpenAI field |
| 6 | `litellm_metadata.user` field | Request body | Anthropic-style metadata |
| 7 | `metadata.user_id` field | Request body | Generic metadata pattern |
| 8 | `safety_identifier` field | Request body | Responses API |
**Option 1: Standard headers** (recommended — no request body modification needed)
```bash showLineNumbers title="Make request with customer ID in header"
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \ # 👈 YOUR PROXY KEY
--data ' {
--header 'Authorization: Bearer sk-1234' \
--header 'x-litellm-end-user-id: ishaan3' \
--data '{
"model": "azure-gpt-3.5",
"user": "ishaan3", # 👈 CUSTOMER ID
"messages": [
{
"role": "user",
"content": "what time is it"
}
]
"messages": [{"role": "user", "content": "what time is it"}]
}'
```
Both `x-litellm-customer-id` and `x-litellm-end-user-id` are supported and always checked without any configuration.
**Option 2: `user` field in request body** (OpenAI-compatible)
```bash showLineNumbers title="Make request with customer ID in body"
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data '{
"model": "azure-gpt-3.5",
"user": "ishaan3",
"messages": [{"role": "user", "content": "what time is it"}]
}'
```
**Option 3: Custom header via `user_header_mappings`** (configurable)
```yaml showLineNumbers title="config.yaml"
general_settings:
user_header_mappings:
- header_name: "x-my-app-user-id"
litellm_user_role: "customer"
```
```bash showLineNumbers title="Make request with custom header"
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--header 'x-my-app-user-id: ishaan3' \
--data '{
"model": "azure-gpt-3.5",
"messages": [{"role": "user", "content": "what time is it"}]
}'
```
**Option 4: `litellm_metadata.user`** (Anthropic-style)
```bash showLineNumbers title="Make request with litellm_metadata.user"
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data '{
"model": "claude-3-5-sonnet",
"messages": [{"role": "user", "content": "what time is it"}],
"litellm_metadata": {"user": "ishaan3"}
}'
```
**Option 5: `metadata.user_id`**
```bash showLineNumbers title="Make request with metadata.user_id"
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data '{
"model": "azure-gpt-3.5",
"messages": [{"role": "user", "content": "what time is it"}],
"metadata": {"user_id": "ishaan3"}
}'
```
@@ -123,7 +192,171 @@ Expected Response
</Tabs>
## Setting Customer Budgets
## Setting Customer Object Permissions
Control which resources (MCP servers, vector stores, agents) a customer can access.
### What are Object Permissions?
Object permissions allow you to restrict customer access to specific:
- **MCP Servers**: Limit which MCP servers the customer can call
- **MCP Access Groups**: Assign customers to predefined groups of MCP servers
- **MCP Tool Permissions**: Granular control over which tools within an MCP server the customer can use
- **Vector Stores**: Control which vector stores the customer can query
- **Agents**: Restrict which agents the customer can interact with
- **Agent Access Groups**: Assign customers to predefined groups of agents
### Creating a Customer with Object Permissions
```bash showLineNumbers title="Create customer with object permissions"
curl -L -X POST 'http://localhost:4000/customer/new' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"user_id": "user_1",
"object_permission": {
"mcp_servers": ["server_1", "server_2"],
"mcp_access_groups": ["public_group"],
"mcp_tool_permissions": {
"server_1": ["tool_a", "tool_b"]
},
"vector_stores": ["vector_store_1"],
"agents": ["agent_1"],
"agent_access_groups": ["basic_agents"]
}
}'
```
**Parameters:**
- `mcp_servers` (Optional[List[str]]): List of allowed MCP server IDs
- `mcp_access_groups` (Optional[List[str]]): List of MCP access group names
- `mcp_tool_permissions` (Optional[Dict[str, List[str]]]): Map of server ID to allowed tool names
- `vector_stores` (Optional[List[str]]): List of allowed vector store IDs
- `agents` (Optional[List[str]]): List of allowed agent IDs
- `agent_access_groups` (Optional[List[str]]): List of agent access group names
**Note:** If `object_permission` is `null` or `{}`, the customer has no object-level restrictions.
### Updating Customer Object Permissions
You can update object permissions for existing customers:
```bash showLineNumbers title="Update customer object permissions"
curl -L -X POST 'http://localhost:4000/customer/update' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"user_id": "user_1",
"object_permission": {
"mcp_servers": ["server_3"],
"vector_stores": ["vector_store_2", "vector_store_3"]
}
}'
```
### Viewing Customer Object Permissions
When you query customer info, object permissions are included in the response:
```bash showLineNumbers title="Get customer info with object permissions"
curl -X GET 'http://0.0.0.0:4000/customer/info?end_user_id=user_1' \
-H 'Authorization: Bearer sk-1234'
```
**Response:**
```json showLineNumbers title="Response with object permissions"
{
"user_id": "user_1",
"blocked": false,
"alias": "John Doe",
"spend": 0.0,
"object_permission": {
"object_permission_id": "perm_abc123",
"mcp_servers": ["server_1", "server_2"],
"mcp_access_groups": ["public_group"],
"mcp_tool_permissions": {
"server_1": ["tool_a", "tool_b"]
},
"vector_stores": ["vector_store_1"],
"agents": ["agent_1"],
"agent_access_groups": ["basic_agents"]
},
"litellm_budget_table": null
}
```
### Use Cases
**1. Tiered Access Control**
Create different permission tiers for your customers:
```bash showLineNumbers title="Free tier customer"
# Free tier - limited access
curl -L -X POST 'http://localhost:4000/customer/new' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"user_id": "free_user",
"budget_id": "free_tier",
"object_permission": {
"mcp_access_groups": ["public_group"],
"agent_access_groups": ["basic_agents"]
}
}'
```
```bash showLineNumbers title="Premium tier customer"
# Premium tier - full access
curl -L -X POST 'http://localhost:4000/customer/new' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"user_id": "premium_user",
"budget_id": "premium_tier",
"object_permission": {
"mcp_servers": ["server_1", "server_2", "server_3"],
"vector_stores": ["vector_store_1", "vector_store_2"],
"agents": ["agent_1", "agent_2", "agent_3"]
}
}'
```
**2. Department-Specific Access**
Restrict customers to resources relevant to their department:
```bash showLineNumbers title="Sales team customer"
curl -L -X POST 'http://localhost:4000/customer/new' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"user_id": "sales_user",
"object_permission": {
"mcp_servers": ["crm_server", "email_server"],
"agents": ["sales_assistant"],
"vector_stores": ["sales_knowledge_base"]
}
}'
```
**3. Tool-Level Restrictions**
Grant access to specific tools within an MCP server:
```bash showLineNumbers title="Limited tool access"
curl -L -X POST 'http://localhost:4000/customer/new' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"user_id": "restricted_user",
"object_permission": {
"mcp_servers": ["database_server"],
"mcp_tool_permissions": {
"database_server": ["read_only_query", "get_table_schema"]
}
}
}'
```
## Setting Customer Budgets
Set customer budgets (e.g. monthly budgets, tpm/rpm limits) on LiteLLM Proxy
@@ -20,6 +20,10 @@ By default, LiteLLM does not forward client headers to LLM provider APIs. Howeve
`x-litellm-spend-logs-metadata`: Optional[str]: JSON string containing custom metadata to include in spend logs. Example: `{"user_id": "12345", "project_id": "proj_abc", "request_type": "chat_completion"}`. [Learn More](../proxy/enterprise#tracking-spend-with-custom-metadata)
`x-litellm-customer-id`: Optional[str]: Standard header for passing a customer/end-user ID. Always checked without any configuration. [Learn More](./customers)
`x-litellm-end-user-id`: Optional[str]: Standard header for passing a customer/end-user ID. Always checked without any configuration. [Learn More](./customers)
## Anthropic Headers
`anthropic-version` Optional[str]: The version of the Anthropic API to use.
@@ -0,0 +1,6 @@
-- AlterTable
ALTER TABLE "LiteLLM_EndUserTable" ADD COLUMN "object_permission_id" TEXT;
-- AddForeignKey
ALTER TABLE "LiteLLM_EndUserTable" ADD CONSTRAINT "LiteLLM_EndUserTable_object_permission_id_fkey" FOREIGN KEY ("object_permission_id") REFERENCES "LiteLLM_ObjectPermissionTable"("object_permission_id") ON DELETE SET NULL ON UPDATE CASCADE;
@@ -233,6 +233,7 @@ model LiteLLM_ObjectPermissionTable {
verification_tokens LiteLLM_VerificationToken[]
organizations LiteLLM_OrganizationTable[]
users LiteLLM_UserTable[]
end_users LiteLLM_EndUserTable[]
}
// Holds the MCP server configuration
@@ -403,7 +404,9 @@ model LiteLLM_EndUserTable {
allowed_model_region String? // require all user requests to use models in this specific region
default_model String? // use along with 'allowed_model_region'. if no available model in region, default to this model.
budget_id String?
object_permission_id String?
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
object_permission LiteLLM_ObjectPermissionTable? @relation(fields: [object_permission_id], references: [object_permission_id])
blocked Boolean @default(false)
}
@@ -124,6 +124,9 @@ class AnthropicMessagesHandler(BaseTranslation):
)
guardrailed_texts = guardrailed_inputs.get("texts", [])
guardrailed_tools = guardrailed_inputs.get("tools")
if guardrailed_tools is not None:
data["tools"] = guardrailed_tools
# Step 3: Map guardrail responses back to original message structure
await self._apply_guardrail_responses_to_input(
@@ -194,7 +197,7 @@ class AnthropicMessagesHandler(BaseTranslation):
openai_tools = self.adapter.translate_anthropic_tools_to_openai(
tools=cast(List[AllAnthropicToolsValues], tools)
)
tools_to_check.extend(openai_tools)
tools_to_check.extend(openai_tools) # type: ignore
async def _apply_guardrail_responses_to_input(
self,
@@ -107,6 +107,9 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
guardrailed_texts = guardrailed_inputs.get("texts", [])
guardrailed_tool_calls = guardrailed_inputs.get("tool_calls", [])
guardrailed_tools = guardrailed_inputs.get("tools")
if guardrailed_tools is not None:
data["tools"] = guardrailed_tools
# Step 3: Map guardrail responses back to original message structure
if guardrailed_texts and texts_to_check:
@@ -96,10 +96,11 @@ class OpenAIResponsesHandler(BaseTranslation):
# Handle simple string input
if isinstance(input_data, str):
inputs = GenericGuardrailAPIInputs(texts=[input_data])
original_tools: List[Dict[str, Any]] = []
# Extract and transform tools if present
if "tools" in data and data["tools"]:
original_tools = list(data["tools"])
self._extract_and_transform_tools(data["tools"], tools_to_check)
if tools_to_check:
inputs["tools"] = tools_to_check
@@ -118,6 +119,9 @@ class OpenAIResponsesHandler(BaseTranslation):
)
guardrailed_texts = guardrailed_inputs.get("texts", [])
data["input"] = guardrailed_texts[0] if guardrailed_texts else input_data
self._apply_guardrailed_tools_to_data(
data, original_tools, guardrailed_inputs.get("tools")
)
verbose_proxy_logger.debug("OpenAI Responses API: Processed string input")
return data
@@ -128,8 +132,7 @@ class OpenAIResponsesHandler(BaseTranslation):
texts_to_check: List[str] = []
images_to_check: List[str] = []
task_mappings: List[Tuple[int, Optional[int]]] = []
# Track (message_index, content_index) for each text
# content_index is None for string content, int for list content
original_tools_list: List[Dict[str, Any]] = list(data.get("tools") or [])
# Step 1: Extract all text content, images, and tools
for msg_idx, message in enumerate(input_data):
@@ -166,6 +169,11 @@ class OpenAIResponsesHandler(BaseTranslation):
)
guardrailed_texts = guardrailed_inputs.get("texts", [])
self._apply_guardrailed_tools_to_data(
data,
original_tools_list,
guardrailed_inputs.get("tools"),
)
# Step 3: Map guardrail responses back to original input structure
await self._apply_guardrail_responses_to_input(
@@ -203,6 +211,53 @@ class OpenAIResponsesHandler(BaseTranslation):
cast(List[ChatCompletionToolParam], transformed_tools)
)
def _remap_tools_to_responses_api_format(
self, guardrailed_tools: List[Any]
) -> List[Dict[str, Any]]:
"""
Remap guardrail-returned tools (Chat Completion format) back to
Responses API request tool format.
"""
return LiteLLMCompletionResponsesConfig.transform_chat_completion_tool_params_to_responses_api_tools(
guardrailed_tools # type: ignore
)
def _merge_tools_after_guardrail(
self,
original_tools: List[Dict[str, Any]],
remapped: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""
Merge remapped guardrailed tools with original tools that were not sent
to the guardrail (e.g. web_search, web_search_preview), preserving order.
"""
if not original_tools:
return remapped
result: List[Dict[str, Any]] = []
j = 0
for tool in original_tools:
if isinstance(tool, dict) and tool.get("type") in (
"web_search",
"web_search_preview",
):
result.append(tool)
else:
if j < len(remapped):
result.append(remapped[j])
j += 1
return result
def _apply_guardrailed_tools_to_data(
self,
data: dict,
original_tools: List[Dict[str, Any]],
guardrailed_tools: Optional[List[Any]],
) -> None:
"""Remap guardrailed tools to Responses API format and merge with original, then set data['tools']."""
if guardrailed_tools is not None:
remapped = self._remap_tools_to_responses_api_format(guardrailed_tools)
data["tools"] = self._merge_tools_after_guardrail(original_tools, remapped)
def _extract_input_text_and_images(
self,
message: Any, # Can be Dict[str, Any] or ResponseInputParam
@@ -407,7 +462,10 @@ class OpenAIResponsesHandler(BaseTranslation):
List[ChatCompletionToolCallChunk], tool_calls
)
# Include model information if available
if hasattr(model_response_stream, "model") and model_response_stream.model:
if (
hasattr(model_response_stream, "model")
and model_response_stream.model
):
inputs["model"] = model_response_stream.model
_guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
inputs=inputs,
@@ -448,7 +506,9 @@ class OpenAIResponsesHandler(BaseTranslation):
)
return responses_so_far
else:
verbose_proxy_logger.debug("Skipping output guardrail - model response has no choices")
verbose_proxy_logger.debug(
"Skipping output guardrail - model response has no choices"
)
# model_response_stream = OpenAiResponsesToChatCompletionStreamIterator.translate_responses_chunk_to_openai_stream(final_chunk)
# tool_calls = model_response_stream.choices[0].tool_calls
# convert openai response to model response
@@ -456,7 +516,11 @@ class OpenAIResponsesHandler(BaseTranslation):
inputs = GenericGuardrailAPIInputs(texts=[string_so_far])
# Try to get model from the final chunk if available
if isinstance(final_chunk, dict):
response_model = final_chunk.get("response", {}).get("model") if isinstance(final_chunk.get("response"), dict) else None
response_model = (
final_chunk.get("response", {}).get("model")
if isinstance(final_chunk.get("response"), dict)
else None
)
if response_model:
inputs["model"] = response_model
_guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
@@ -591,8 +655,8 @@ class OpenAIResponsesHandler(BaseTranslation):
content = generic_response_output_item.content
except Exception:
# Try to extract content directly from output_item if validation fails
if hasattr(output_item, "content") and output_item.content:
content = output_item.content
if hasattr(output_item, "content") and output_item.content: # type: ignore
content = output_item.content # type: ignore
else:
return
elif isinstance(output_item, dict):
@@ -669,10 +733,10 @@ class OpenAIResponsesHandler(BaseTranslation):
if isinstance(content_item, OutputText):
content_item.text = guardrail_response
# Update the original response output
if hasattr(output_item, "content") and output_item.content:
original_content = output_item.content[content_idx]
if hasattr(output_item, "content") and output_item.content: # type: ignore
original_content = output_item.content[content_idx] # type: ignore
if hasattr(original_content, "text"):
original_content.text = guardrail_response
original_content.text = guardrail_response # type: ignore
except Exception:
pass
elif isinstance(output_item, dict):
@@ -336,15 +336,21 @@ class MCPRequestHandler:
user_api_key_auth: Optional[UserAPIKeyAuth] = None,
) -> List[str]:
"""
Get list of allowed MCP servers for the given user/key based on permissions
Get list of allowed MCP servers for the given user/key based on permissions.
Permission hierarchy (all rules are intersections):
1. Get allowed servers from key permissions
2. Get allowed servers from team permissions
3. Get allowed servers from end_user permissions
4. Final result = intersection of key/team AND end_user (if end_user has permissions set)
Returns:
List[str]: List of allowed MCP servers by server id
"""
from typing import List
from litellm.proxy.proxy_server import general_settings
try:
allowed_mcp_servers: List[str] = []
# Get allowed servers from key and team
allowed_mcp_servers_for_key = (
await MCPRequestHandler._get_allowed_mcp_servers_for_key(
user_api_key_auth
@@ -357,8 +363,9 @@ class MCPRequestHandler:
)
#########################################################
# If team has mcp_servers, handle inheritance and intersection logic
# Calculate key/team allowed servers using inheritance and intersection logic
#########################################################
allowed_mcp_servers: List[str] = []
if len(allowed_mcp_servers_for_team) > 0:
if len(allowed_mcp_servers_for_key) > 0:
# Key has its own MCP permissions - use intersection with team permissions
@@ -371,6 +378,40 @@ class MCPRequestHandler:
else:
allowed_mcp_servers = allowed_mcp_servers_for_key
#########################################################
# Check end_user permissions if end_user_id is set
#########################################################
if user_api_key_auth and user_api_key_auth.end_user_id:
allowed_mcp_servers_for_end_user = (
await MCPRequestHandler._get_allowed_mcp_servers_for_end_user(
user_api_key_auth
)
)
# If end_user has explicit MCP server permissions, apply intersection
if len(allowed_mcp_servers_for_end_user) > 0:
verbose_logger.debug(
f"End user {user_api_key_auth.end_user_id} has explicit MCP permissions: {allowed_mcp_servers_for_end_user}"
)
# Always apply intersection: key/team AND end_user
# This ensures end_user can only access servers that both they AND their key/team are authorized for
filtered_servers = []
for _mcp_server in allowed_mcp_servers:
if _mcp_server in allowed_mcp_servers_for_end_user:
filtered_servers.append(_mcp_server)
allowed_mcp_servers = filtered_servers
verbose_logger.debug(
f"Applied end_user intersection filter. Final allowed servers: {allowed_mcp_servers}"
)
# If flag is enabled but end_user has no permissions, block all access
elif general_settings.get("require_end_user_mcp_access_defined", False):
verbose_logger.debug(
f"require_end_user_mcp_access_defined=True and end_user {user_api_key_auth.end_user_id} has no MCP permissions - blocking MCP access"
)
return []
return list(set(allowed_mcp_servers))
except Exception as e:
verbose_logger.warning(f"Failed to get allowed MCP servers: {str(e)}")
@@ -614,6 +655,66 @@ class MCPRequestHandler:
)
return []
@staticmethod
async def _get_allowed_mcp_servers_for_end_user(
user_api_key_auth: Optional[UserAPIKeyAuth] = None,
) -> List[str]:
"""
Get allowed MCP servers for an end user.
Returns the MCP servers from the end_user's object_permission.
"""
from litellm.proxy.auth.auth_checks import get_end_user_object
from litellm.proxy.proxy_server import (
prisma_client,
proxy_logging_obj,
user_api_key_cache,
)
if not user_api_key_auth or not user_api_key_auth.end_user_id:
return []
if prisma_client is None:
verbose_logger.debug("prisma_client is None")
return []
try:
# Use optimized get_end_user_object function with caching
end_user_obj = await get_end_user_object(
end_user_id=user_api_key_auth.end_user_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=user_api_key_auth.parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
route="/mcp",
)
if end_user_obj is None or end_user_obj.object_permission is None:
return []
# Get direct MCP servers
direct_mcp_servers = end_user_obj.object_permission.mcp_servers or []
# Get MCP servers from access groups
access_group_servers = (
await MCPRequestHandler._get_mcp_servers_from_access_groups(
end_user_obj.object_permission.mcp_access_groups or []
)
)
# Combine both lists
all_servers = direct_mcp_servers + access_group_servers
return list(set(all_servers))
except Exception as e:
verbose_logger.warning(
f"Failed to get allowed MCP servers for end_user: {str(e)}"
)
return []
@staticmethod
def _get_config_server_ids_for_access_groups(
config_mcp_servers, access_groups: List[str]
@@ -691,8 +792,6 @@ class MCPRequestHandler:
"""
Get list of MCP access groups for the given user/key based on permissions
"""
from typing import List
access_groups: List[str] = []
access_groups_for_key = await MCPRequestHandler._get_mcp_access_groups_for_key(
user_api_key_auth
@@ -43,9 +43,7 @@ from litellm.proxy._experimental.mcp_server.utils import (
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.ip_address_utils import IPAddressUtils
from litellm.proxy.litellm_pre_call_utils import (
LiteLLMProxyRequestSetup,
)
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm.types.mcp import MCPAuth
from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer
from litellm.types.utils import CallTypes, StandardLoggingMCPToolCall
@@ -795,6 +793,7 @@ if MCP_AVAILABLE:
mcp_servers=mcp_servers,
allowed_mcp_servers=allowed_mcp_servers,
)
return allowed_mcp_servers
@@ -938,9 +937,6 @@ if MCP_AVAILABLE:
mcp_servers=mcp_servers,
)
# Decide whether to add prefix based on number of allowed servers
add_prefix = not (len(allowed_mcp_servers) == 1)
async def _fetch_and_filter_server_tools(
server: MCPServer,
) -> List[MCPTool]:
@@ -961,7 +957,7 @@ if MCP_AVAILABLE:
server=server,
mcp_auth_header=server_auth_header,
extra_headers=extra_headers,
add_prefix=add_prefix,
add_prefix=True, # Always add server prefix
raw_headers=raw_headers,
)
filtered_tools = filter_tools_by_allowed_tools(tools, server)
@@ -1079,8 +1075,6 @@ if MCP_AVAILABLE:
mcp_servers=mcp_servers,
)
# Decide whether to add prefix based on number of allowed servers
add_prefix = not (len(allowed_mcp_servers) == 1)
# Get prompts from each allowed server
all_prompts = []
@@ -1101,7 +1095,7 @@ if MCP_AVAILABLE:
server=server,
mcp_auth_header=server_auth_header,
extra_headers=extra_headers,
add_prefix=add_prefix,
add_prefix=True, # Always add server prefix
raw_headers=raw_headers,
)
@@ -1140,7 +1134,6 @@ if MCP_AVAILABLE:
mcp_servers=mcp_servers,
)
add_prefix = not (len(allowed_mcp_servers) == 1)
all_resources: List[Resource] = []
for server in allowed_mcp_servers:
@@ -1160,7 +1153,7 @@ if MCP_AVAILABLE:
server=server,
mcp_auth_header=server_auth_header,
extra_headers=extra_headers,
add_prefix=add_prefix,
add_prefix=True, # Always add server prefix
raw_headers=raw_headers,
)
all_resources.extend(resources)
@@ -1197,7 +1190,6 @@ if MCP_AVAILABLE:
mcp_servers=mcp_servers,
)
add_prefix = not (len(allowed_mcp_servers) == 1)
all_resource_templates: List[ResourceTemplate] = []
for server in allowed_mcp_servers:
@@ -1218,7 +1210,7 @@ if MCP_AVAILABLE:
server=server,
mcp_auth_header=server_auth_header,
extra_headers=extra_headers,
add_prefix=add_prefix,
add_prefix=True, # Always add server prefix
raw_headers=raw_headers,
)
)
@@ -1676,14 +1668,9 @@ if MCP_AVAILABLE:
detail="User not allowed to get this prompt.",
)
# Decide whether to add prefix based on number of allowed servers
add_prefix = not (len(allowed_mcp_servers) == 1)
if add_prefix:
original_prompt_name, server_name = split_server_prefix_from_name(name)
else:
original_prompt_name = name
server_name = allowed_mcp_servers[0].name
# Extract server name from prefixed prompt name
original_prompt_name, server_name = split_server_prefix_from_name(name)
server = next((s for s in allowed_mcp_servers if s.name == server_name), None)
if server is None:
+21 -22
View File
@@ -13,26 +13,25 @@ model_list:
- model_name: gpt-4.1-mini
litellm_params:
model: openai/gpt-4.1-mini
# guardrails:
# - guardrail_name: generic-guardrail
# litellm_params:
# guardrail: generic_guardrail_api
# mode: ["pre_call"]
# headers:
# Authorization: Bearer mock-bedrock-token-12345
# api_base: http://localhost:8080
# default_on: true
prompts:
- prompt_id: "simple_prompt"
- model_name: gpt-5-mini
litellm_params:
prompt_integration: "generic_prompt_management"
provider_specific_query_params:
project_name: litellm
slug: hello-world-prompt-2bac
api_base: http://localhost:8080
api_key: os.environ/BRAINTRUST_API_KEY
ignore_prompt_manager_model: true
ignore_prompt_manager_optional_params: true
model: openai/gpt-5-mini
guardrails:
- guardrail_name: mcp-user-permissions
litellm_params:
guardrail: mcp_end_user_permission
mode: pre_call
default_on: true
mcp_servers:
my_http_server:
url: "http://0.0.0.0:8001/mcp"
transport: "http"
description: "My custom MCP server"
available_on_public_internet: true
general_settings:
store_model_in_db: true
store_prompts_in_spend_logs: true
+17 -12
View File
@@ -1409,12 +1409,13 @@ class NewCustomerRequest(BudgetNewRequest):
blocked: bool = False # allow/disallow requests for this end-user
budget_id: Optional[str] = None # give either a budget_id or max_budget
spend: Optional[float] = None
allowed_model_region: Optional[
AllowedModelRegion
] = None # require all user requests to use models in this specific region
default_model: Optional[
str
] = None # if no equivalent model in allowed region - default all requests to this model
allowed_model_region: Optional[AllowedModelRegion] = (
None # require all user requests to use models in this specific region
)
default_model: Optional[str] = (
None # if no equivalent model in allowed region - default all requests to this model
)
object_permission: Optional[LiteLLM_ObjectPermissionBase] = None
@model_validator(mode="before")
@classmethod
@@ -1436,12 +1437,13 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase):
blocked: bool = False # allow/disallow requests for this end-user
max_budget: Optional[float] = None
budget_id: Optional[str] = None # give either a budget_id or max_budget
allowed_model_region: Optional[
AllowedModelRegion
] = None # require all user requests to use models in this specific region
default_model: Optional[
str
] = None # if no equivalent model in allowed region - default all requests to this model
allowed_model_region: Optional[AllowedModelRegion] = (
None # require all user requests to use models in this specific region
)
default_model: Optional[str] = (
None # if no equivalent model in allowed region - default all requests to this model
)
object_permission: Optional[LiteLLM_ObjectPermissionBase] = None
class DeleteCustomerRequest(LiteLLMPydanticObjectBase):
@@ -2301,6 +2303,7 @@ class UserAPIKeyAuth(
user_max_budget: Optional[float] = None
request_route: Optional[str] = None
user: Optional[Any] = None # Expanded user object when expand=user is used
end_user_object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -2535,6 +2538,8 @@ class LiteLLM_EndUserTable(LiteLLMPydanticObjectBase):
allowed_model_region: Optional[AllowedModelRegion] = None
default_model: Optional[str] = None
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
object_permission_id: Optional[str] = None
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
@model_validator(mode="before")
@classmethod
+38 -52
View File
@@ -11,7 +11,8 @@ Run checks for:
import asyncio
import re
import time
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union,
cast)
from fastapi import HTTPException, Request, status
from pydantic import BaseModel
@@ -20,41 +21,27 @@ import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.caching.dual_cache import LimitedSizeOrderedDict
from litellm.constants import (
CLI_JWT_EXPIRATION_HOURS,
CLI_JWT_TOKEN_NAME,
DEFAULT_ACCESS_GROUP_CACHE_TTL,
DEFAULT_IN_MEMORY_TTL,
DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL,
DEFAULT_MAX_RECURSE_DEPTH,
EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE,
)
from litellm.constants import (CLI_JWT_EXPIRATION_HOURS, CLI_JWT_TOKEN_NAME,
DEFAULT_ACCESS_GROUP_CACHE_TTL,
DEFAULT_IN_MEMORY_TTL,
DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL,
DEFAULT_MAX_RECURSE_DEPTH,
EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE)
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.proxy._types import (
RBAC_ROLES,
CallInfo,
LiteLLM_AccessGroupTable,
LiteLLM_BudgetTable,
LiteLLM_EndUserTable,
Litellm_EntityType,
LiteLLM_JWTAuth,
LiteLLM_ObjectPermissionTable,
LiteLLM_OrganizationMembershipTable,
LiteLLM_OrganizationTable,
LiteLLM_TagTable,
LiteLLM_TeamMembership,
LiteLLM_TeamTable,
LiteLLM_TeamTableCachedObj,
LiteLLM_UserTable,
LiteLLMRoutes,
LitellmUserRoles,
NewTeamRequest,
ProxyErrorTypes,
ProxyException,
RoleBasedPermissions,
SpecialModelNames,
UserAPIKeyAuth,
)
from litellm.proxy._types import (RBAC_ROLES, CallInfo,
LiteLLM_AccessGroupTable,
LiteLLM_BudgetTable, LiteLLM_EndUserTable,
Litellm_EntityType, LiteLLM_JWTAuth,
LiteLLM_ObjectPermissionTable,
LiteLLM_OrganizationMembershipTable,
LiteLLM_OrganizationTable, LiteLLM_TagTable,
LiteLLM_TeamMembership, LiteLLM_TeamTable,
LiteLLM_TeamTableCachedObj,
LiteLLM_UserTable, LiteLLMRoutes,
LitellmUserRoles, NewTeamRequest,
ProxyErrorTypes, ProxyException,
RoleBasedPermissions, SpecialModelNames,
UserAPIKeyAuth)
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.route_llm_request import route_request
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
@@ -366,7 +353,8 @@ async def common_checks(
_request_metadata: dict = request_body.get("metadata", {}) or {}
if _request_metadata.get("guardrails"):
# check if team allowed to modify guardrails
from litellm.proxy.guardrails.guardrail_helpers import can_modify_guardrails
from litellm.proxy.guardrails.guardrail_helpers import \
can_modify_guardrails
can_modify: bool = can_modify_guardrails(team_object)
if can_modify is False:
@@ -792,7 +780,7 @@ async def get_end_user_object(
try:
response = await prisma_client.db.litellm_endusertable.find_unique(
where={"user_id": end_user_id},
include={"litellm_budget_table": True},
include={"litellm_budget_table": True, "object_permission": True},
)
if response is None:
@@ -1812,9 +1800,8 @@ class ExperimentalUIJWTToken:
def get_experimental_ui_login_jwt_auth_token(user_info: LiteLLM_UserTable) -> str:
from datetime import timedelta
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
encrypt_value_helper,
)
from litellm.proxy.common_utils.encrypt_decrypt_utils import \
encrypt_value_helper
if user_info.user_role is None:
raise Exception("User role is required for experimental UI login")
@@ -1860,9 +1847,8 @@ class ExperimentalUIJWTToken:
"""
from datetime import timedelta
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
encrypt_value_helper,
)
from litellm.proxy.common_utils.encrypt_decrypt_utils import \
encrypt_value_helper
if user_info.user_role is None:
raise Exception("User role is required for CLI JWT login")
@@ -1901,9 +1887,8 @@ class ExperimentalUIJWTToken:
import json
from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
decrypt_value_helper,
)
from litellm.proxy.common_utils.encrypt_decrypt_utils import \
decrypt_value_helper
decrypted_token = decrypt_value_helper(
hashed_token, key="ui_hash_key", exception_type="debug"
@@ -2150,11 +2135,11 @@ async def _get_resources_from_access_groups(
# Lazy import to avoid circular imports
if prisma_client is None or user_api_key_cache is None:
from litellm.proxy.proxy_server import (
prisma_client as _prisma_client,
proxy_logging_obj as _proxy_logging_obj,
user_api_key_cache as _user_api_key_cache,
)
from litellm.proxy.proxy_server import prisma_client as _prisma_client
from litellm.proxy.proxy_server import \
proxy_logging_obj as _proxy_logging_obj
from litellm.proxy.proxy_server import \
user_api_key_cache as _user_api_key_cache
prisma_client = prisma_client or _prisma_client
user_api_key_cache = user_api_key_cache or _user_api_key_cache
@@ -2936,7 +2921,8 @@ async def _tag_max_budget_check(
BudgetExceededError if any tag is over its max budget.
Triggers a budget alert if any tag is over its max budget.
"""
from litellm.proxy.common_utils.http_parsing_utils import get_tags_from_request_body
from litellm.proxy.common_utils.http_parsing_utils import \
get_tags_from_request_body
if prisma_client is None:
return
+10
View File
@@ -736,6 +736,16 @@ def get_end_user_id_from_request_body(
user_id_from_metadata_field = metadata_dict.get("user_id")
if user_id_from_metadata_field is not None:
return str(user_id_from_metadata_field)
# Check 6: 'safety_identifier' in request body (OpenAI Responses API parameter)
# SECURITY NOTE: safety_identifier can be set by any caller in the request body.
# Only use this for end-user identification in trusted environments where you control
# the calling application. For untrusted callers, prefer using headers or server-side
# middleware to set the end_user_id to prevent impersonation.
if request_body.get("safety_identifier") is not None:
user_from_body_user_field = request_body["safety_identifier"]
return str(user_from_body_user_field)
return None
+6 -1
View File
@@ -644,7 +644,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
if team_object is not None
else None,
)
# Check if model has zero cost - if so, skip all budget checks
model = get_model_from_request(request_data, route)
skip_budget_checks = False
@@ -831,6 +831,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
valid_token=valid_token, end_user_params=end_user_params
)
valid_token.parent_otel_span = parent_otel_span
if _end_user_object is not None:
valid_token.end_user_object_permission = _end_user_object.object_permission
return valid_token
@@ -1277,6 +1279,9 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
if _end_user_object is not None:
valid_token_dict.update(end_user_params)
valid_token_dict["end_user_object_permission"] = (
_end_user_object.object_permission
)
# check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions
# sso/login, ui/login, /key functions and /user functions
@@ -205,12 +205,6 @@ class ContentFilterGuardrail(CustomGuardrail):
# Load categories if provided
if categories:
self._load_categories(categories)
else:
verbose_proxy_logger.warning(
"ContentFilterGuardrail has no content categories configured. "
"Toxic/abuse and other category-based keyword filtering will not run. "
"Add categories (e.g. harm_toxic_abuse) in the guardrail config to enable them."
)
# Normalize inputs: convert dicts to Pydantic models for consistent handling
normalized_patterns: List[ContentFilterPattern] = []
@@ -0,0 +1,35 @@
from typing import TYPE_CHECKING
from litellm.types.guardrails import SupportedGuardrailIntegrations
from .mcp_end_user_permission import MCPEndUserPermissionGuardrail
if TYPE_CHECKING:
from litellm.types.guardrails import Guardrail, LitellmParams
def initialize_guardrail(litellm_params: "LitellmParams", guardrail: "Guardrail"):
import litellm
# Default to always-on. Only disable if the user explicitly sets default_on: false.
# We check the raw guardrail dict because LitellmParams normalizes None → False,
# making it impossible to distinguish "not set" from "explicitly false" via litellm_params.
_raw_default_on = guardrail.get("litellm_params", {}).get("default_on")
_default_on = False if _raw_default_on is False else True
_callback = MCPEndUserPermissionGuardrail(
guardrail_name=guardrail.get("guardrail_name", ""),
event_hook=litellm_params.mode,
default_on=_default_on,
)
litellm.logging_callback_manager.add_litellm_callback(_callback)
return _callback
guardrail_initializer_registry = {
SupportedGuardrailIntegrations.MCP_END_USER_PERMISSION.value: initialize_guardrail,
}
guardrail_class_registry = {
SupportedGuardrailIntegrations.MCP_END_USER_PERMISSION.value: MCPEndUserPermissionGuardrail,
}
@@ -0,0 +1,262 @@
"""
MCP End User Permission Guardrail Hook
Enforces end user permissions for MCP server access via apply_guardrail:
- input_type="request" filter tools the end user cannot access
Permission logic:
- No end_user_id allow all (key/team-level permissions apply)
- end_user_id, no mcp_servers allow all (default)
- end_user_id + mcp_servers allow only those servers
"""
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Type
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import (
CustomGuardrail,
log_guardrail_information,
)
from litellm.proxy._types import LiteLLM_ObjectPermissionTable
from litellm.types.guardrails import GuardrailEventHooks
from litellm.types.utils import GenericGuardrailAPIInputs
if TYPE_CHECKING:
from litellm.types.proxy.guardrails.guardrail_hooks.base import GuardrailConfigModel
GUARDRAIL_NAME = "mcp_end_user_permission"
class MCPEndUserPermissionGuardrail(CustomGuardrail):
"""
Guardrail that enforces end user permissions for MCP server access.
Runs on input only (pre-call). Filters tools in the request that the
end user is not permitted to call based on their object_permission.
end_user_object_permission is populated on UserAPIKeyAuth during auth.
The guardrail resolves it via a cached get_end_user_object lookup
no extra DB round-trip when the cache is warm.
"""
def __init__(self, **kwargs):
if "supported_event_hooks" not in kwargs:
kwargs["supported_event_hooks"] = [
GuardrailEventHooks.pre_call,
]
super().__init__(**kwargs)
verbose_proxy_logger.debug("MCP End User Permission Guardrail initialized")
# ------------------------------------------------------------------
# apply_guardrail — filters MCP tools on the request side only
# ------------------------------------------------------------------
@log_guardrail_information
async def apply_guardrail(
self,
inputs: GenericGuardrailAPIInputs,
request_data: dict,
input_type: Literal["request", "response"] = "request",
logging_obj: Optional[Any] = None,
) -> GenericGuardrailAPIInputs:
"""
Filters MCP tools the end user cannot access based on their
object_permission.mcp_servers / mcp_access_groups settings.
"""
object_permission = await self._resolve_end_user_object_permission(request_data)
return await self._check_request_tools(inputs, object_permission)
# ------------------------------------------------------------------
# Private — request-side tool filtering
# ------------------------------------------------------------------
async def _check_request_tools(
self,
inputs: GenericGuardrailAPIInputs,
object_permission: Optional[LiteLLM_ObjectPermissionTable],
) -> GenericGuardrailAPIInputs:
tools = inputs.get("tools")
if not tools:
return inputs
allowed_mcp_servers = (
await self._get_allowed_mcp_servers_from_object_permission(
object_permission
)
)
if allowed_mcp_servers is None:
return inputs # No restrictions → pass through unchanged
verbose_proxy_logger.debug(
f"MCP guardrail: end user restricted to MCP servers: {allowed_mcp_servers}"
)
filtered_tools = []
removed_tools = []
for tool in tools:
tool_name = self._get_tool_name_from_definition(tool)
server_name = (
self._extract_mcp_server_name(tool_name) if tool_name else None
)
if server_name is None:
# Not an MCP tool (no prefix) or unrecognised format → keep
filtered_tools.append(tool)
elif server_name in allowed_mcp_servers:
filtered_tools.append(tool)
else:
removed_tools.append(tool_name)
verbose_proxy_logger.warning(
f"MCP guardrail: removing tool '{tool_name}' "
f"(server: '{server_name}') — not in end user's allowed servers"
)
if removed_tools:
verbose_proxy_logger.debug(
f"MCP guardrail: removed {len(removed_tools)} unauthorized MCP tool(s): {removed_tools}"
)
inputs["tools"] = filtered_tools
return inputs
# ------------------------------------------------------------------
# Private — end user permission resolution
# ------------------------------------------------------------------
@staticmethod
async def _resolve_end_user_object_permission(
request_data: dict,
) -> Optional[LiteLLM_ObjectPermissionTable]:
"""
Resolve the end user's object_permission via the cached auth lookup.
Uses get_end_user_object (same path as auth) so no extra DB round-trip
when the cache is warm.
"""
end_user_id = MCPEndUserPermissionGuardrail._get_end_user_id_from_request_data(
request_data
)
if not end_user_id:
return None
end_user_object = await MCPEndUserPermissionGuardrail._fetch_end_user_object(
end_user_id
)
return (
end_user_object.object_permission if end_user_object is not None else None
)
@staticmethod
def _get_end_user_id_from_request_data(request_data: dict) -> Optional[str]:
return request_data.get("user_api_key_end_user_id") or request_data.get(
"litellm_metadata", {}
).get("user_api_key_end_user_id")
@staticmethod
async def _fetch_end_user_object(end_user_id: str): # type: ignore[return]
"""
Fetch end user object via the same cached path used during auth.
No extra DB round-trip when the cache is warm.
"""
from litellm.proxy.auth.auth_checks import get_end_user_object
from litellm.proxy.proxy_server import (
prisma_client,
proxy_logging_obj,
user_api_key_cache,
)
if prisma_client is None:
return None
try:
return await get_end_user_object(
end_user_id=end_user_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=proxy_logging_obj,
route="/mcp",
)
except Exception as e:
verbose_proxy_logger.warning(
f"MCP guardrail: failed to fetch end_user_object for '{end_user_id}': {e}"
)
return None
# ------------------------------------------------------------------
# Private — permission derivation
# ------------------------------------------------------------------
@staticmethod
async def _get_allowed_mcp_servers_from_object_permission(
object_permission: Optional[LiteLLM_ObjectPermissionTable],
) -> Optional[List[str]]:
"""
Returns:
None no restrictions configured, allow all MCP servers
list restrict to exactly these server names
"""
if object_permission is None:
return None
direct_mcp_servers = object_permission.mcp_servers or []
mcp_access_groups = object_permission.mcp_access_groups or []
if not direct_mcp_servers and not mcp_access_groups:
return None # Both empty → no restrictions
from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
MCPRequestHandler,
)
access_group_servers = (
await MCPRequestHandler._get_mcp_servers_from_access_groups(
mcp_access_groups
)
)
return list(set(direct_mcp_servers + access_group_servers))
# ------------------------------------------------------------------
# Config model — exposes this guardrail in the UI
# ------------------------------------------------------------------
@staticmethod
def get_config_model() -> Optional[Type["GuardrailConfigModel"]]:
from litellm.types.proxy.guardrails.guardrail_hooks.mcp_end_user_permission import (
MCPEndUserPermissionGuardrailConfigModel,
)
return MCPEndUserPermissionGuardrailConfigModel
# ------------------------------------------------------------------
# Private — tool name extraction
# ------------------------------------------------------------------
@staticmethod
def _extract_mcp_server_name(tool_name: str) -> Optional[str]:
"""
Split "github-create_issue" "github".
Returns None if the tool name has no '-' prefix (not an MCP tool).
"""
if not tool_name or "-" not in tool_name:
return None
return tool_name.split("-", 1)[0]
@staticmethod
def _get_tool_name_from_definition(tool: Any) -> Optional[str]:
"""
Extract tool name from a definition dict.
OpenAI format: {"type": "function", "function": {"name": "..."}}
Anthropic format: {"name": "...", "input_schema": {...}}
"""
if not isinstance(tool, dict):
return None
function_def = tool.get("function")
if isinstance(function_def, dict):
name = function_def.get("name")
if name:
return name
return tool.get("name")
@@ -85,6 +85,7 @@ class UnifiedLLMGuardrails(CustomLogger):
add_guardrail_to_applied_guardrails_header,
)
verbose_proxy_logger.debug("Running UnifiedLLMGuardrails pre-call hook")
guardrail_to_apply: CustomGuardrail = data.pop("guardrail_to_apply", None)
+1 -1
View File
@@ -34,7 +34,7 @@ def init_guardrails_v2(
if initialized_guardrail:
guardrail_list.append(initialized_guardrail)
verbose_proxy_logger.debug(f"\nGuardrail List:{guardrail_list}\n")
# verbose_proxy_logger.debug(f"\nGuardrail List:{guardrail_list}\n")
# Populate router's guardrail_list for load balancing support
_populate_router_guardrail_list(guardrail_list=guardrail_list)
+3 -7
View File
@@ -879,7 +879,7 @@ async def add_litellm_data_to_request( # noqa: PLR0915
general_settings, user_api_key_dict, _headers
)
# Parse user info from headers
# Parse user info from headers (fallback to general_settings.user_header_name)
user = LiteLLMProxyRequestSetup.get_user_from_headers(_headers, general_settings)
if user is not None:
if user_api_key_dict.end_user_id is None:
@@ -1540,9 +1540,7 @@ def _match_and_track_policies(
add_policy_sources_to_metadata,
add_policy_to_applied_policies_header,
)
from litellm.proxy.policy_engine.attachment_registry import (
get_attachment_registry,
)
from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry
from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher
# Get matching policies via attachments (with match reasons for attribution)
@@ -1677,9 +1675,7 @@ def add_guardrails_from_policy_engine(
user_api_key_dict: The user's API key authentication info
"""
from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.http_parsing_utils import (
get_tags_from_request_body,
)
from litellm.proxy.common_utils.http_parsing_utils import get_tags_from_request_body
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
from litellm.types.proxy.policy_engine import PolicyMatchContext
@@ -19,11 +19,13 @@ import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_endpoints.common_daily_activity import \
get_daily_activity
from litellm.proxy.management_helpers.object_permission_utils import (
_set_object_permission, handle_update_object_permission_common)
from litellm.proxy.utils import handle_exception_on_proxy
from litellm.types.proxy.management_endpoints.common_daily_activity import (
SpendAnalyticsPaginatedResponse,
)
from litellm.proxy.management_endpoints.common_daily_activity import get_daily_activity
from litellm.types.proxy.management_endpoints.common_daily_activity import \
SpendAnalyticsPaginatedResponse
router = APIRouter()
@@ -107,9 +109,8 @@ async def unblock_user(data: BlockUsers):
```
"""
try:
from enterprise.enterprise_hooks.blocked_user_list import (
_ENTERPRISE_BlockedUserList,
)
from enterprise.enterprise_hooks.blocked_user_list import \
_ENTERPRISE_BlockedUserList
except ImportError:
raise HTTPException(
status_code=400,
@@ -164,6 +165,38 @@ def new_budget_request(data: NewCustomerRequest) -> Optional[BudgetNewRequest]:
return None
async def _handle_customer_object_permission_update(
non_default_values: dict,
end_user_table_data_typed: Optional[LiteLLM_EndUserTable],
update_end_user_table_data: dict,
prisma_client,
) -> None:
"""
Handle object permission updates for customer endpoints.
Updates the update_end_user_table_data dict in place with the new object_permission_id.
Args:
non_default_values: Dictionary containing the update values including object_permission
end_user_table_data_typed: Existing end user table data
update_end_user_table_data: Dictionary to update with new object_permission_id
prisma_client: Prisma database client
"""
if "object_permission" in non_default_values:
existing_object_permission_id = (
end_user_table_data_typed.object_permission_id
if end_user_table_data_typed is not None
else None
)
object_permission_id = await handle_update_object_permission_common(
data_json=non_default_values,
existing_object_permission_id=existing_object_permission_id,
prisma_client=prisma_client,
)
if object_permission_id is not None:
update_end_user_table_data["object_permission_id"] = object_permission_id
@router.post(
"/end_user/new",
tags=["Customer Management"],
@@ -200,6 +233,16 @@ async def new_end_user(
- soft_budget: Optional[float] - [Not Implemented Yet] Get alerts when customer crosses given budget, doesn't block requests.
- spend: Optional[float] - Specify initial spend for a given customer.
- budget_reset_at: Optional[str] - Specify the date and time when the budget should be reset.
- object_permission: Optional[LiteLLM_ObjectPermissionBase] - Customer-specific object permissions to control access to resources.
Supported fields:
* mcp_servers: List[str] - List of allowed MCP server IDs
* mcp_access_groups: List[str] - List of MCP access group names
* mcp_tool_permissions: Dict[str, List[str]] - Map of server ID to allowed tool names (e.g., {"server_1": ["tool_a", "tool_b"]})
* vector_stores: List[str] - List of allowed vector store IDs
* agents: List[str] - List of allowed agent IDs
* agent_access_groups: List[str] - List of agent access group names
Example: {"mcp_servers": ["server_1", "server_2"], "vector_stores": ["vector_store_1"], "agents": ["agent_1"]}
IF null or {} then no object-level restrictions apply.
- Allow specifying allowed regions
@@ -214,9 +257,22 @@ async def new_end_user(
"user_id" : "ishaan-jaff-3",
"allowed_region": "eu",
"budget_id": "free_tier",
"default_model": "azure/gpt-3.5-turbo-eu" <- all calls from this user, use this model?
"default_model": "azure/gpt-3.5-turbo-eu"
}'
# With object permissions
curl -L -X POST 'http://localhost:4000/customer/new' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"user_id": "user_1",
"object_permission": {
"mcp_servers": ["server_1"],
"mcp_access_groups": ["public_group"],
"vector_stores": ["vector_store_1"]
}
}'
# return end-user object
```
@@ -233,11 +289,8 @@ async def new_end_user(
- end-user object
- currently allowed models
"""
from litellm.proxy.proxy_server import (
litellm_proxy_admin_name,
llm_router,
prisma_client,
)
from litellm.proxy.proxy_server import (litellm_proxy_admin_name,
llm_router, prisma_client)
if prisma_client is None:
raise HTTPException(
@@ -289,13 +342,34 @@ async def new_end_user(
if k not in BudgetNewRequest.model_fields.keys():
new_end_user_obj[k] = v
## Handle Object Permission - MCP Servers, Vector Stores etc.
new_end_user_obj = await _set_object_permission(
data_json=new_end_user_obj,
prisma_client=prisma_client,
)
# Ensure object_permission is not in the data being sent to create
# It should have been converted to object_permission_id by _set_object_permission
if "object_permission" in new_end_user_obj:
verbose_proxy_logger.warning(
f"object_permission still in new_end_user_obj after _set_object_permission: {new_end_user_obj.get('object_permission')}"
)
new_end_user_obj.pop("object_permission", None)
## WRITE TO DB ##
end_user_record = await prisma_client.db.litellm_endusertable.create(
data=new_end_user_obj, # type: ignore
include={"litellm_budget_table": True},
include={"litellm_budget_table": True, "object_permission": True},
)
return end_user_record
# Convert to dict and clean up recursive fields
response_dict = end_user_record.model_dump()
if response_dict.get("object_permission"):
# Remove reverse relations from object_permission
for field in ["teams", "verification_tokens", "organizations", "users", "end_users"]:
response_dict["object_permission"].pop(field, None)
return response_dict
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.management_endpoints.customer_endpoints.new_end_user(): Exception occured - {}".format(
@@ -351,7 +425,7 @@ async def end_user_info(
)
user_info = await prisma_client.db.litellm_endusertable.find_first(
where={"user_id": end_user_id}, include={"litellm_budget_table": True}
where={"user_id": end_user_id}, include={"litellm_budget_table": True, "object_permission": True}
)
if user_info is None:
@@ -361,7 +435,15 @@ async def end_user_info(
code=404,
param="end_user_id",
)
return user_info.model_dump(exclude_none=True)
# Convert to dict and clean up recursive fields
response_dict = user_info.model_dump(exclude_none=True)
if response_dict.get("object_permission"):
# Remove reverse relations from object_permission
for field in ["teams", "verification_tokens", "organizations", "users", "end_users"]:
response_dict["object_permission"].pop(field, None)
return response_dict
except Exception as e:
verbose_proxy_logger.exception(
@@ -401,6 +483,16 @@ async def update_end_user(
- default_model: Optional[str] = (
None # if no equivalent model in allowed region - default all requests to this model
)
- object_permission: Optional[LiteLLM_ObjectPermissionBase] - Customer-specific object permissions to control access to resources.
Supported fields:
* mcp_servers: List[str] - List of allowed MCP server IDs
* mcp_access_groups: List[str] - List of MCP access group names
* mcp_tool_permissions: Dict[str, List[str]] - Map of server ID to allowed tool names
* vector_stores: List[str] - List of allowed vector store IDs
* agents: List[str] - List of allowed agent IDs
* agent_access_groups: List[str] - List of agent access group names
Example: {"mcp_servers": ["server_1"], "vector_stores": ["vector_store_1"]}
IF null or {} then no object-level restrictions apply.
Example curl:
```
@@ -412,11 +504,24 @@ async def update_end_user(
"budget_id": "paid_tier"
}'
See below for all params
# Updating object permissions
curl -L -X POST 'http://localhost:4000/customer/update' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"user_id": "user_1",
"object_permission": {
"mcp_servers": ["server_3"],
"vector_stores": ["vector_store_2", "vector_store_3"]
}
}'
See below for all params
```
"""
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
from litellm.proxy.proxy_server import (litellm_proxy_admin_name,
prisma_client)
try:
data_json: dict = data.json()
@@ -467,6 +572,14 @@ async def update_end_user(
elif k in LiteLLM_EndUserTable.model_fields.keys():
update_end_user_table_data[k] = v
## Handle object permission updates (MCP servers, vector stores, etc.)
await _handle_customer_object_permission_update(
non_default_values=non_default_values,
end_user_table_data_typed=end_user_table_data_typed,
update_end_user_table_data=update_end_user_table_data,
prisma_client=prisma_client,
)
## Check if we need to create a new budget (only if budget fields are provided, not just budget_id) ##
if budget_table_data:
if end_user_budget_table is None:
@@ -498,11 +611,20 @@ async def update_end_user(
## Update user table, with update params + new budget id (if set) ##
verbose_proxy_logger.debug("/customer/update: Received data = %s", data)
# Ensure object_permission is not in the update data
# It should have been converted to object_permission_id by handle_update_object_permission_common
if "object_permission" in update_end_user_table_data:
verbose_proxy_logger.warning(
f"object_permission still in update_end_user_table_data: {update_end_user_table_data.get('object_permission')}"
)
update_end_user_table_data.pop("object_permission", None)
if data.user_id is not None and len(data.user_id) > 0:
update_end_user_table_data["user_id"] = data.user_id # type: ignore
verbose_proxy_logger.debug("In update customer, user_id condition block.")
response = await prisma_client.db.litellm_endusertable.update(
where={"user_id": data.user_id}, data=update_end_user_table_data, include={"litellm_budget_table": True} # type: ignore
where={"user_id": data.user_id}, data=update_end_user_table_data, include={"litellm_budget_table": True, "object_permission": True} # type: ignore
)
if response is None:
raise ValueError(
@@ -511,7 +633,15 @@ async def update_end_user(
verbose_proxy_logger.debug(
f"received response from updating prisma client. response={response}"
)
return response
# Convert to dict and clean up recursive fields
response_dict = response.model_dump()
if response_dict.get("object_permission"):
# Remove reverse relations from object_permission
for field in ["teams", "verification_tokens", "organizations", "users", "end_users"]:
response_dict["object_permission"].pop(field, None)
return response_dict
else:
raise ValueError(f"user_id is required, passed user_id = {data.user_id}")
@@ -663,12 +793,17 @@ async def list_end_user(
)
response = await prisma_client.db.litellm_endusertable.find_many(
include={"litellm_budget_table": True}
include={"litellm_budget_table": True, "object_permission": True}
)
returned_response: List[LiteLLM_EndUserTable] = []
for item in response:
returned_response.append(LiteLLM_EndUserTable(**item.model_dump()))
item_dict = item.model_dump()
# Remove reverse relations from object_permission
if item_dict.get("object_permission"):
for field in ["teams", "verification_tokens", "organizations", "users", "end_users"]:
item_dict["object_permission"].pop(field, None)
returned_response.append(LiteLLM_EndUserTable(**item_dict))
return returned_response
except Exception as e:
@@ -706,9 +841,7 @@ async def get_customer_daily_activity(
"""
Get daily activity for specific organizations or all accessible organizations.
"""
from litellm.proxy.proxy_server import (
prisma_client,
)
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise HTTPException(
+3
View File
@@ -233,6 +233,7 @@ model LiteLLM_ObjectPermissionTable {
verification_tokens LiteLLM_VerificationToken[]
organizations LiteLLM_OrganizationTable[]
users LiteLLM_UserTable[]
end_users LiteLLM_EndUserTable[]
}
// Holds the MCP server configuration
@@ -403,7 +404,9 @@ model LiteLLM_EndUserTable {
allowed_model_region String? // require all user requests to use models in this specific region
default_model String? // use along with 'allowed_model_region'. if no available model in region, default to this model.
budget_id String?
object_permission_id String?
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
object_permission LiteLLM_ObjectPermissionTable? @relation(fields: [object_permission_id], references: [object_permission_id])
blocked Boolean @default(false)
}
+1
View File
@@ -1321,6 +1321,7 @@ class ProxyLogging:
metadata = data.get("metadata", data.get("litellm_metadata", {})) or {}
pipeline_managed: set = metadata.get("_pipeline_managed_guardrails", set())
for callback in litellm.callbacks:
start_time = time.time()
_callback = None
@@ -6,6 +6,7 @@ from collections.abc import Sequence
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, cast
from openai.types.responses import ResponseFunctionToolCall
from openai.types.responses.response_create_params import ResponseInputParam
from openai.types.responses.tool_param import FunctionToolParam
from typing_extensions import TypedDict
@@ -32,7 +33,6 @@ from litellm.types.llms.openai import (
OpenAIWebSearchUserLocation,
OutputTokensDetails,
ResponseAPIUsage,
ResponseInputParam,
ResponsesAPIOptionalRequestParams,
ResponsesAPIResponse,
ResponsesAPIStatus,
@@ -738,9 +738,25 @@ class LiteLLMCompletionResponsesConfig:
@staticmethod
def _ensure_tool_results_have_corresponding_tool_calls(
messages: List[Union[AllMessageValues, GenericChatCompletionMessage, ChatCompletionResponseMessage]],
messages: Sequence[
Union[
AllMessageValues,
GenericChatCompletionMessage,
ChatCompletionResponseMessage,
ChatCompletionMessageToolCall,
Message,
]
],
tools: Optional[List[Any]] = None,
) -> List[Union[AllMessageValues, GenericChatCompletionMessage, ChatCompletionResponseMessage]]:
) -> List[
Union[
AllMessageValues,
GenericChatCompletionMessage,
ChatCompletionResponseMessage,
ChatCompletionMessageToolCall,
Message,
]
]:
"""
Ensure that tool_result messages have corresponding tool_calls in the previous assistant message.
@@ -755,11 +771,19 @@ class LiteLLMCompletionResponsesConfig:
List of messages with tool_calls added to assistant messages when needed
"""
if not messages:
return messages
# Create a deep copy to avoid modifying the original
return list(messages)
# Create a deep copy to avoid modifying the original (use list() so we can mutate and return List)
import copy
fixed_messages = copy.deepcopy(messages)
fixed_messages: List[
Union[
AllMessageValues,
GenericChatCompletionMessage,
ChatCompletionResponseMessage,
ChatCompletionMessageToolCall,
Message,
]
] = list(copy.deepcopy(messages))
messages_to_remove = []
# Count non-tool messages to avoid removing all messages
@@ -1306,6 +1330,50 @@ class LiteLLMCompletionResponsesConfig:
chat_completion_tools.append(cast(Union[ChatCompletionToolParam, OpenAIMcpServerTool], tool))
return chat_completion_tools, web_search_options
@staticmethod
def transform_chat_completion_tool_params_to_responses_api_tools(
chat_completion_tools: Optional[
List[Union[ChatCompletionToolParam, OpenAIMcpServerTool]]
],
) -> List[Dict[str, Any]]:
"""
Transform Chat Completion tool params (e.g. from guardrail output) back to
Responses API request tool format. Inverse of
transform_responses_api_tools_to_chat_completion_tools for the tools list.
"""
if chat_completion_tools is None or not chat_completion_tools:
return []
result: List[Dict[str, Any]] = []
for tool in chat_completion_tools:
if not isinstance(tool, dict):
result.append(tool) # type: ignore
continue
if tool.get("type") == "function":
fn = tool.get("function") or {}
parameters = dict(fn.get("parameters", {}) or {})
if not parameters or "type" not in parameters:
parameters["type"] = "object"
responses_tool: Dict[str, Any] = {
"type": "function",
"name": fn.get("name") or "",
"description": fn.get("description") or "",
"parameters": parameters,
"strict": fn.get("strict", False) or False,
}
if tool.get("cache_control") is not None:
responses_tool["cache_control"] = tool.get("cache_control")
if tool.get("defer_loading") is not None:
responses_tool["defer_loading"] = tool.get("defer_loading")
if tool.get("allowed_callers") is not None:
responses_tool["allowed_callers"] = tool.get("allowed_callers")
if tool.get("input_examples") is not None:
responses_tool["input_examples"] = tool.get("input_examples")
result.append(responses_tool)
else:
# mcp or other: pass through unchanged
result.append(dict(tool))
return result
@staticmethod
def transform_chat_completion_tools_to_responses_tools(
chat_completion_response: ModelResponse,
+1
View File
@@ -70,6 +70,7 @@ class SupportedGuardrailIntegrations(Enum):
GENERIC_GUARDRAIL_API = "generic_guardrail_api"
QUALIFIRE = "qualifire"
CUSTOM_CODE = "custom_code"
MCP_END_USER_PERMISSION = "mcp_end_user_permission"
class Role(Enum):
@@ -0,0 +1,12 @@
from .base import GuardrailConfigModel
class MCPEndUserPermissionGuardrailConfigModel(GuardrailConfigModel):
"""
No provider-specific params required permissions come from the end user
object already stored in the database.
"""
@staticmethod
def ui_friendly_name() -> str:
return "MCP End User Permission"
+3
View File
@@ -233,6 +233,7 @@ model LiteLLM_ObjectPermissionTable {
verification_tokens LiteLLM_VerificationToken[]
organizations LiteLLM_OrganizationTable[]
users LiteLLM_UserTable[]
end_users LiteLLM_EndUserTable[]
}
// Holds the MCP server configuration
@@ -403,7 +404,9 @@ model LiteLLM_EndUserTable {
allowed_model_region String? // require all user requests to use models in this specific region
default_model String? // use along with 'allowed_model_region'. if no available model in region, default to this model.
budget_id String?
object_permission_id String?
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
object_permission LiteLLM_ObjectPermissionTable? @relation(fields: [object_permission_id], references: [object_permission_id])
blocked Boolean @default(false)
}
@@ -0,0 +1,414 @@
"""
Tests for MCP End User Permission Guardrail Hook
"""
import os
import sys
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
sys.path.insert(
0, os.path.abspath("../../../..")
) # Adds the parent directory to the system path
from litellm.exceptions import GuardrailRaisedException
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_hooks.mcp_end_user_permission import (
MCPEndUserPermissionGuardrail,
)
from litellm.types.utils import (
ChatCompletionMessageToolCall,
Choices,
Function,
Message,
ModelResponse,
)
class TestMCPEndUserPermissionGuardrail:
"""Test the MCP End User Permission Guardrail"""
def test_extract_mcp_server_name(self):
"""Test extracting MCP server name from tool name"""
guardrail = MCPEndUserPermissionGuardrail()
# Test valid MCP tool names
assert guardrail._extract_mcp_server_name("github-create_issue") == "github"
assert guardrail._extract_mcp_server_name("slack-send_message") == "slack"
assert guardrail._extract_mcp_server_name("jira-create-ticket") == "jira"
# Test invalid/non-MCP tool names
assert guardrail._extract_mcp_server_name("search") is None
assert guardrail._extract_mcp_server_name("") is None
assert guardrail._extract_mcp_server_name(None) is None
@pytest.mark.asyncio
async def test_apply_guardrail_no_end_user(self):
"""Test guardrail when no end_user_id is present"""
guardrail = MCPEndUserPermissionGuardrail()
# Create inputs with MCP tools
inputs = {
"tools": [
{
"type": "function",
"function": {
"name": "github-create_issue",
"description": "Create an issue",
},
}
]
}
request_data = {}
# Should pass through all tools when no end_user_id
result = await guardrail.apply_guardrail(
inputs=inputs,
request_data=request_data,
input_type="request",
)
assert len(result.get("tools", [])) == 1
assert result["tools"][0]["function"]["name"] == "github-create_issue"
@pytest.mark.asyncio
async def test_apply_guardrail_with_authorized_tools(self):
"""Test guardrail when end user has access to MCP servers"""
from litellm.proxy._types import LiteLLM_ObjectPermissionTable
guardrail = MCPEndUserPermissionGuardrail()
# Create inputs with multiple tools
inputs = {
"tools": [
{
"type": "function",
"function": {
"name": "github-create_issue",
"description": "Create an issue",
},
},
{
"type": "function",
"function": {
"name": "slack-send_message",
"description": "Send a message",
},
},
{
"type": "function",
"function": {
"name": "search",
"description": "Regular search tool",
},
},
]
}
request_data = {"user_api_key_end_user_id": "end-user-123"}
# Mock fetching end user object with permissions
with patch.object(
MCPEndUserPermissionGuardrail,
"_fetch_end_user_object",
return_value=MagicMock(
object_permission=LiteLLM_ObjectPermissionTable(
object_permission_id="perm-1",
mcp_servers=["github", "slack"],
)
),
):
result = await guardrail.apply_guardrail(
inputs=inputs,
request_data=request_data,
input_type="request",
)
# Should keep all authorized MCP tools + non-MCP tools
assert len(result.get("tools", [])) == 3
@pytest.mark.asyncio
async def test_apply_guardrail_with_unauthorized_tools(self):
"""Test guardrail filters out unauthorized MCP tools"""
from litellm.proxy._types import LiteLLM_ObjectPermissionTable
guardrail = MCPEndUserPermissionGuardrail()
# Create inputs with tools where end user only has access to some
inputs = {
"tools": [
{
"type": "function",
"function": {
"name": "github-create_issue",
"description": "Create an issue",
},
},
{
"type": "function",
"function": {
"name": "slack-send_message",
"description": "Send a message",
},
},
{
"type": "function",
"function": {
"name": "jira-create_ticket",
"description": "Create a ticket",
},
},
]
}
request_data = {"user_api_key_end_user_id": "end-user-123"}
# Mock fetching end user object with limited permissions (only slack and jira)
with patch.object(
MCPEndUserPermissionGuardrail,
"_fetch_end_user_object",
return_value=MagicMock(
object_permission=LiteLLM_ObjectPermissionTable(
object_permission_id="perm-1",
mcp_servers=["slack", "jira"],
)
),
):
result = await guardrail.apply_guardrail(
inputs=inputs,
request_data=request_data,
input_type="request",
)
# Should filter out github tool
assert len(result.get("tools", [])) == 2
tool_names = [t["function"]["name"] for t in result["tools"]]
assert "slack-send_message" in tool_names
assert "jira-create_ticket" in tool_names
assert "github-create_issue" not in tool_names
@pytest.mark.asyncio
async def test_apply_guardrail_no_permissions_configured(self):
"""Test guardrail when end user has no MCP permissions configured"""
guardrail = MCPEndUserPermissionGuardrail()
# Create inputs with MCP tools
inputs = {
"tools": [
{
"type": "function",
"function": {
"name": "github-create_issue",
"description": "Create an issue",
},
}
]
}
request_data = {"user_api_key_end_user_id": "end-user-123"}
# Mock fetching end user object with no object_permission
with patch.object(
MCPEndUserPermissionGuardrail,
"_fetch_end_user_object",
return_value=MagicMock(object_permission=None),
):
result = await guardrail.apply_guardrail(
inputs=inputs,
request_data=request_data,
input_type="request",
)
# Should pass through all tools when no permissions configured
assert len(result.get("tools", [])) == 1
assert result["tools"][0]["function"]["name"] == "github-create_issue"
@pytest.mark.asyncio
async def test_apply_guardrail_with_non_mcp_tools(self):
"""Test guardrail passes through non-MCP tools"""
from litellm.proxy._types import LiteLLM_ObjectPermissionTable
guardrail = MCPEndUserPermissionGuardrail()
# Create inputs with non-MCP tools
inputs = {
"tools": [
{
"type": "function",
"function": {
"name": "search",
"description": "Search tool",
},
},
{
"type": "function",
"function": {
"name": "calculate",
"description": "Calculate something",
},
},
]
}
request_data = {"user_api_key_end_user_id": "end-user-123"}
# Mock fetching end user object with MCP restrictions
with patch.object(
MCPEndUserPermissionGuardrail,
"_fetch_end_user_object",
return_value=MagicMock(
object_permission=LiteLLM_ObjectPermissionTable(
object_permission_id="perm-1",
mcp_servers=["github"],
)
),
):
result = await guardrail.apply_guardrail(
inputs=inputs,
request_data=request_data,
input_type="request",
)
# Should keep all non-MCP tools even with MCP restrictions
assert len(result.get("tools", [])) == 2
@pytest.mark.asyncio
async def test_apply_guardrail_filters_unauthorized_mcp_tools(self):
"""Test guardrail filters out unauthorized MCP tools"""
from litellm.proxy._types import LiteLLM_ObjectPermissionTable
guardrail = MCPEndUserPermissionGuardrail()
# Create inputs with MCP tools where user only has access to some
inputs = {
"tools": [
{
"type": "function",
"function": {
"name": "github-create_issue",
"description": "Create an issue",
},
},
{
"type": "function",
"function": {
"name": "slack-send_message",
"description": "Send a message",
},
},
{
"type": "function",
"function": {
"name": "jira-create_ticket",
"description": "Create a ticket",
},
},
]
}
request_data = {"user_api_key_end_user_id": "end-user-123"}
# Mock fetching end user object - only has access to slack and jira, not github
with patch.object(
MCPEndUserPermissionGuardrail,
"_fetch_end_user_object",
return_value=MagicMock(
object_permission=LiteLLM_ObjectPermissionTable(
object_permission_id="perm-1",
mcp_servers=["slack", "jira"],
)
),
):
result = await guardrail.apply_guardrail(
inputs=inputs,
request_data=request_data,
input_type="request",
)
# Should filter out github tool
assert len(result.get("tools", [])) == 2
tool_names = [t["function"]["name"] for t in result["tools"]]
assert "slack-send_message" in tool_names
assert "jira-create_ticket" in tool_names
assert "github-create_issue" not in tool_names
@pytest.mark.asyncio
async def test_apply_guardrail_with_mixed_tools(self):
"""Test guardrail with both MCP and non-MCP tools"""
from litellm.proxy._types import LiteLLM_ObjectPermissionTable
guardrail = MCPEndUserPermissionGuardrail()
# Create inputs with both MCP and non-MCP tools
inputs = {
"tools": [
{
"type": "function",
"function": {
"name": "github-create_issue",
"description": "Create an issue",
},
},
{
"type": "function",
"function": {
"name": "search",
"description": "Search tool",
},
},
{
"type": "function",
"function": {
"name": "slack-send_message",
"description": "Send a message",
},
},
]
}
request_data = {"user_api_key_end_user_id": "end-user-123"}
# Mock fetching end user object - only has access to slack
with patch.object(
MCPEndUserPermissionGuardrail,
"_fetch_end_user_object",
return_value=MagicMock(
object_permission=LiteLLM_ObjectPermissionTable(
object_permission_id="perm-1",
mcp_servers=["slack"],
)
),
):
result = await guardrail.apply_guardrail(
inputs=inputs,
request_data=request_data,
input_type="request",
)
# Should keep slack MCP tool and non-MCP search tool, filter out github
assert len(result.get("tools", [])) == 2
tool_names = [t["function"]["name"] for t in result["tools"]]
assert "search" in tool_names
assert "slack-send_message" in tool_names
assert "github-create_issue" not in tool_names
@pytest.mark.asyncio
async def test_apply_guardrail_no_tools_in_request(self):
"""Test guardrail when request has no tools"""
guardrail = MCPEndUserPermissionGuardrail()
# Create inputs without tools
inputs = {"model": "gpt-4", "messages": [{"role": "user", "content": "test"}]}
request_data = {"user_api_key_end_user_id": "end-user-123"}
# Should pass through unchanged
result = await guardrail.apply_guardrail(
inputs=inputs,
request_data=request_data,
input_type="request",
)
assert result == inputs
assert "tools" not in result