mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-27 01:07:02 +00:00
Mcp user permissions (#21462)
* feat(schema.prisma): add object permissions for end users
allows controlling if end user can call specific mcp servers
* feat: cleanup for customer_endpoints support of object permission id
* fix: cleanup str
* feat(customers/): enforce end user can only call allowed mcps - if configured
* docs: document customer/end user object permission usage
* feat: enforce end user permissions on MCP tool calls
This commit implements end user permission enforcement for MCP servers:
1. Always add server prefixes to MCP tool names
- Removed conditional logic that only added prefixes when multiple servers existed
- Now always adds server prefix for consistent tool naming across all scenarios
- Updated 5 locations in server.py (list_tools, get_prompts, get_resources,
get_resource_templates, get_prompt)
2. Created MCP End User Permission Guardrail Hook
- New guardrail hook: litellm/proxy/guardrails/guardrail_hooks/mcp_end_user_permission.py
- Runs on post_call to validate tool calls in LLM responses
- Extracts MCP server name from tool names (splits on first '-')
- Checks if end_user_id has permissions for the MCP server
- Raises GuardrailRaisedException if end user lacks permission
- Supports both streaming and non-streaming responses
3. Added comprehensive tests
- Test file: tests/test_litellm/proxy/guardrails/guardrail_hooks/test_mcp_end_user_permission.py
- Tests cover: authorized/unauthorized tools, non-MCP tools, no end_user scenarios
- Tests permission checking logic and exception raising
The hook integrates with the existing MCPRequestHandler._get_allowed_mcp_servers_for_end_user
to fetch end user permissions and enforce access control at the response level.
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
* refactor: remove redundant add_prefix variable assignments
Simplified the code by removing intermediate `add_prefix` variable
assignments and passing `True` directly to function calls since
we now always add server prefixes.
Changes:
- Removed `add_prefix = True` variable assignments in 5 locations
- Changed `add_prefix=add_prefix` to `add_prefix=True` in function calls
- Added inline comments to clarify the behavior
This makes the code more concise and clearer in intent.
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
* feat(auth_utils.py): support safety_identifier as a valid way of passing the end user id for responses api
* feat(llms): ensure 'tools' is correctly updated for responses api
* fix: fix greptile feedback
* feat: transformation.py
proper responses api tool handling for guardrail translation layer
---------
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,36 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(git show:*)",
|
||||
"Bash(git worktree add:*)",
|
||||
"Read(//Users/krrishdholakia/Documents/litellm/**)",
|
||||
"Read(//Users/krrishdholakia/Documents/litellm-claude-code-guardrails/litellm/types/**)",
|
||||
"Read(//Users/krrishdholakia/Documents/litellm-claude-code-guardrails/**)",
|
||||
"Read(//Users/krrishdholakia/Documents/litellm-claude-code-guardrails/litellm/**)",
|
||||
"Bash(python:*)",
|
||||
"Bash(python -c \"\nimport sys; sys.path.insert\\(0, ''.''\\)\nfrom litellm.proxy.guardrails.guardrail_hooks.claude_code.guardrail import ClaudeCodeGuardrail, HOSTED_TOOL_PREFIXES\nprint\\(''HOSTED_TOOL_PREFIXES:'', HOSTED_TOOL_PREFIXES\\)\nprint\\(''ClaudeCodeGuardrail imported OK''\\)\n\")",
|
||||
"Read(//Users/krrishdholakia/Documents/litellm-mcp-jwt-groups/litellm/proxy/**)",
|
||||
"Read(//Users/krrishdholakia/Documents/litellm-mcp-jwt-groups/**)",
|
||||
"Bash(poetry run pytest:*)",
|
||||
"Bash(git add:*)",
|
||||
"Bash(git commit:*)",
|
||||
"Bash(poetry run python:*)",
|
||||
"Bash(poetry run pip:*)",
|
||||
"Bash(git reset:*)",
|
||||
"Bash(git cherry-pick:*)",
|
||||
"Bash(git checkout:*)",
|
||||
"Read(//Users/krrishdholakia/Documents/litellm/litellm/proxy/guardrails/guardrail_hooks/**)",
|
||||
"Read(//Users/krrishdholakia/Documents/**)",
|
||||
"Bash(git -C /Users/krrishdholakia/Documents/litellm-mcp-user-permissions worktree list)",
|
||||
"Bash(ls:*)"
|
||||
],
|
||||
"additionalDirectories": [
|
||||
"/Users/krrishdholakia/Documents/litellm-mcp-group-plan/plan",
|
||||
"/Users/krrishdholakia/Documents/litellm-claude-code-guardrails/litellm/proxy/guardrails/guardrail_hooks/claude_code",
|
||||
"/Users/krrishdholakia/Documents/litellm-claude-code-guardrails/litellm/types",
|
||||
"/Users/krrishdholakia/Documents/litellm-claude-code-guardrails",
|
||||
"/Users/krrishdholakia/Documents/litellm-mcp-jwt-groups/litellm/proxy",
|
||||
"/Users/krrishdholakia/Documents/litellm-mcp-jwt-groups/tests/test_litellm/proxy/auth"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -808,6 +808,68 @@ If your stdio MCP server needs per-request credentials, you can map HTTP headers
|
||||
|
||||
In this example, when a client makes a request with the `X-GITHUB_PERSONAL_ACCESS_TOKEN` header, the proxy forwards that value into the stdio process as the `GITHUB_PERSONAL_ACCESS_TOKEN` environment variable.
|
||||
|
||||
## Control MCP Access for End Users
|
||||
|
||||
Control which MCP servers end users of your AI application can access (e.g. users of an internal chat UI). Pass the customer ID in the `x-litellm-end-user-id` header to:
|
||||
- Enforce object permissions (limit which MCP servers they can access)
|
||||
- Apply customer-specific budgets
|
||||
- Track spend per customer
|
||||
|
||||
**FastMCP Client Example:**
|
||||
|
||||
```python title="Track customer spend with x-litellm-end-user-id" showLineNumbers
|
||||
from fastmcp import Client
|
||||
import asyncio
|
||||
|
||||
# MCP client configuration with customer tracking
|
||||
config = {
|
||||
"mcpServers": {
|
||||
"github": {
|
||||
"url": "http://localhost:4000/github_mcp/mcp",
|
||||
"headers": {
|
||||
"x-litellm-api-key": "Bearer sk-1234",
|
||||
"x-litellm-end-user-id": "customer_123", # 👈 CUSTOMER ID
|
||||
"Authorization": "Bearer gho_token"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client = Client(config)
|
||||
|
||||
async def main():
|
||||
async with client:
|
||||
# All MCP calls will be tracked under customer_123
|
||||
tools = await client.list_tools()
|
||||
result = await client.call_tool(tools[0].name, {})
|
||||
print(f"Tool result: {result}")
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
**Cursor IDE Example:**
|
||||
|
||||
```json title="Cursor config with customer tracking" showLineNumbers
|
||||
{
|
||||
"mcpServers": {
|
||||
"GitHub": {
|
||||
"url": "http://localhost:4000/github_mcp/mcp",
|
||||
"headers": {
|
||||
"x-litellm-api-key": "Bearer $LITELLM_API_KEY",
|
||||
"x-litellm-end-user-id": "customer_123"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
- Customer-specific object permissions are enforced (only allowed MCP servers are accessible)
|
||||
- Customer budgets are applied
|
||||
- All tool calls are tracked under `customer_123`
|
||||
|
||||
[Learn more about customer management →](./proxy/customers)
|
||||
|
||||
## Using your MCP with client side credentials
|
||||
|
||||
Use this if you want to pass a client side authentication token to LiteLLM to then pass to your MCP to auth to your MCP.
|
||||
|
||||
@@ -2,29 +2,98 @@ import Image from '@theme/IdealImage';
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Customers / End-User Budgets
|
||||
# Customers / End-Users
|
||||
|
||||
Track spend, set budgets for your customers.
|
||||
Track spend, set budgets and permissions for your customers.
|
||||
|
||||
## Tracking Customer Spend
|
||||
## Tracking Customer Spend + Permissions
|
||||
|
||||
### 1. Make LLM API call w/ Customer ID
|
||||
|
||||
Make a /chat/completions call, pass 'user' - First call Works
|
||||
LiteLLM checks for a customer/end-user ID in the following order (first match wins):
|
||||
|
||||
```bash showLineNumbers title="Make request with customer ID"
|
||||
| Priority | Method | Where | Notes |
|
||||
|----------|--------|-------|-------|
|
||||
| 1 | `x-litellm-customer-id` header | Request headers | Standard header, always checked |
|
||||
| 2 | `x-litellm-end-user-id` header | Request headers | Standard header, always checked |
|
||||
| 3 | Custom header via `user_header_mappings` | Request headers | Configured in `general_settings` |
|
||||
| 4 | Custom header via `user_header_name` | Request headers | Deprecated — use `user_header_mappings` |
|
||||
| 5 | `user` field | Request body | Standard OpenAI field |
|
||||
| 6 | `litellm_metadata.user` field | Request body | Anthropic-style metadata |
|
||||
| 7 | `metadata.user_id` field | Request body | Generic metadata pattern |
|
||||
| 8 | `safety_identifier` field | Request body | Responses API |
|
||||
|
||||
**Option 1: Standard headers** (recommended — no request body modification needed)
|
||||
|
||||
```bash showLineNumbers title="Make request with customer ID in header"
|
||||
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer sk-1234' \ # 👈 YOUR PROXY KEY
|
||||
--data ' {
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'x-litellm-end-user-id: ishaan3' \
|
||||
--data '{
|
||||
"model": "azure-gpt-3.5",
|
||||
"user": "ishaan3", # 👈 CUSTOMER ID
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what time is it"
|
||||
}
|
||||
]
|
||||
"messages": [{"role": "user", "content": "what time is it"}]
|
||||
}'
|
||||
```
|
||||
|
||||
Both `x-litellm-customer-id` and `x-litellm-end-user-id` are supported and always checked without any configuration.
|
||||
|
||||
**Option 2: `user` field in request body** (OpenAI-compatible)
|
||||
|
||||
```bash showLineNumbers title="Make request with customer ID in body"
|
||||
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--data '{
|
||||
"model": "azure-gpt-3.5",
|
||||
"user": "ishaan3",
|
||||
"messages": [{"role": "user", "content": "what time is it"}]
|
||||
}'
|
||||
```
|
||||
|
||||
**Option 3: Custom header via `user_header_mappings`** (configurable)
|
||||
|
||||
```yaml showLineNumbers title="config.yaml"
|
||||
general_settings:
|
||||
user_header_mappings:
|
||||
- header_name: "x-my-app-user-id"
|
||||
litellm_user_role: "customer"
|
||||
```
|
||||
|
||||
```bash showLineNumbers title="Make request with custom header"
|
||||
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'x-my-app-user-id: ishaan3' \
|
||||
--data '{
|
||||
"model": "azure-gpt-3.5",
|
||||
"messages": [{"role": "user", "content": "what time is it"}]
|
||||
}'
|
||||
```
|
||||
|
||||
**Option 4: `litellm_metadata.user`** (Anthropic-style)
|
||||
|
||||
```bash showLineNumbers title="Make request with litellm_metadata.user"
|
||||
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--data '{
|
||||
"model": "claude-3-5-sonnet",
|
||||
"messages": [{"role": "user", "content": "what time is it"}],
|
||||
"litellm_metadata": {"user": "ishaan3"}
|
||||
}'
|
||||
```
|
||||
|
||||
**Option 5: `metadata.user_id`**
|
||||
|
||||
```bash showLineNumbers title="Make request with metadata.user_id"
|
||||
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--data '{
|
||||
"model": "azure-gpt-3.5",
|
||||
"messages": [{"role": "user", "content": "what time is it"}],
|
||||
"metadata": {"user_id": "ishaan3"}
|
||||
}'
|
||||
```
|
||||
|
||||
@@ -123,7 +192,171 @@ Expected Response
|
||||
</Tabs>
|
||||
|
||||
|
||||
## Setting Customer Budgets
|
||||
## Setting Customer Object Permissions
|
||||
|
||||
Control which resources (MCP servers, vector stores, agents) a customer can access.
|
||||
|
||||
### What are Object Permissions?
|
||||
|
||||
Object permissions allow you to restrict customer access to specific:
|
||||
- **MCP Servers**: Limit which MCP servers the customer can call
|
||||
- **MCP Access Groups**: Assign customers to predefined groups of MCP servers
|
||||
- **MCP Tool Permissions**: Granular control over which tools within an MCP server the customer can use
|
||||
- **Vector Stores**: Control which vector stores the customer can query
|
||||
- **Agents**: Restrict which agents the customer can interact with
|
||||
- **Agent Access Groups**: Assign customers to predefined groups of agents
|
||||
|
||||
### Creating a Customer with Object Permissions
|
||||
|
||||
```bash showLineNumbers title="Create customer with object permissions"
|
||||
curl -L -X POST 'http://localhost:4000/customer/new' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"user_id": "user_1",
|
||||
"object_permission": {
|
||||
"mcp_servers": ["server_1", "server_2"],
|
||||
"mcp_access_groups": ["public_group"],
|
||||
"mcp_tool_permissions": {
|
||||
"server_1": ["tool_a", "tool_b"]
|
||||
},
|
||||
"vector_stores": ["vector_store_1"],
|
||||
"agents": ["agent_1"],
|
||||
"agent_access_groups": ["basic_agents"]
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `mcp_servers` (Optional[List[str]]): List of allowed MCP server IDs
|
||||
- `mcp_access_groups` (Optional[List[str]]): List of MCP access group names
|
||||
- `mcp_tool_permissions` (Optional[Dict[str, List[str]]]): Map of server ID to allowed tool names
|
||||
- `vector_stores` (Optional[List[str]]): List of allowed vector store IDs
|
||||
- `agents` (Optional[List[str]]): List of allowed agent IDs
|
||||
- `agent_access_groups` (Optional[List[str]]): List of agent access group names
|
||||
|
||||
**Note:** If `object_permission` is `null` or `{}`, the customer has no object-level restrictions.
|
||||
|
||||
### Updating Customer Object Permissions
|
||||
|
||||
You can update object permissions for existing customers:
|
||||
|
||||
```bash showLineNumbers title="Update customer object permissions"
|
||||
curl -L -X POST 'http://localhost:4000/customer/update' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"user_id": "user_1",
|
||||
"object_permission": {
|
||||
"mcp_servers": ["server_3"],
|
||||
"vector_stores": ["vector_store_2", "vector_store_3"]
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
### Viewing Customer Object Permissions
|
||||
|
||||
When you query customer info, object permissions are included in the response:
|
||||
|
||||
```bash showLineNumbers title="Get customer info with object permissions"
|
||||
curl -X GET 'http://0.0.0.0:4000/customer/info?end_user_id=user_1' \
|
||||
-H 'Authorization: Bearer sk-1234'
|
||||
```
|
||||
|
||||
**Response:**
|
||||
```json showLineNumbers title="Response with object permissions"
|
||||
{
|
||||
"user_id": "user_1",
|
||||
"blocked": false,
|
||||
"alias": "John Doe",
|
||||
"spend": 0.0,
|
||||
"object_permission": {
|
||||
"object_permission_id": "perm_abc123",
|
||||
"mcp_servers": ["server_1", "server_2"],
|
||||
"mcp_access_groups": ["public_group"],
|
||||
"mcp_tool_permissions": {
|
||||
"server_1": ["tool_a", "tool_b"]
|
||||
},
|
||||
"vector_stores": ["vector_store_1"],
|
||||
"agents": ["agent_1"],
|
||||
"agent_access_groups": ["basic_agents"]
|
||||
},
|
||||
"litellm_budget_table": null
|
||||
}
|
||||
```
|
||||
|
||||
### Use Cases
|
||||
|
||||
**1. Tiered Access Control**
|
||||
Create different permission tiers for your customers:
|
||||
|
||||
```bash showLineNumbers title="Free tier customer"
|
||||
# Free tier - limited access
|
||||
curl -L -X POST 'http://localhost:4000/customer/new' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"user_id": "free_user",
|
||||
"budget_id": "free_tier",
|
||||
"object_permission": {
|
||||
"mcp_access_groups": ["public_group"],
|
||||
"agent_access_groups": ["basic_agents"]
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
```bash showLineNumbers title="Premium tier customer"
|
||||
# Premium tier - full access
|
||||
curl -L -X POST 'http://localhost:4000/customer/new' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"user_id": "premium_user",
|
||||
"budget_id": "premium_tier",
|
||||
"object_permission": {
|
||||
"mcp_servers": ["server_1", "server_2", "server_3"],
|
||||
"vector_stores": ["vector_store_1", "vector_store_2"],
|
||||
"agents": ["agent_1", "agent_2", "agent_3"]
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
**2. Department-Specific Access**
|
||||
Restrict customers to resources relevant to their department:
|
||||
|
||||
```bash showLineNumbers title="Sales team customer"
|
||||
curl -L -X POST 'http://localhost:4000/customer/new' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"user_id": "sales_user",
|
||||
"object_permission": {
|
||||
"mcp_servers": ["crm_server", "email_server"],
|
||||
"agents": ["sales_assistant"],
|
||||
"vector_stores": ["sales_knowledge_base"]
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
**3. Tool-Level Restrictions**
|
||||
Grant access to specific tools within an MCP server:
|
||||
|
||||
```bash showLineNumbers title="Limited tool access"
|
||||
curl -L -X POST 'http://localhost:4000/customer/new' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"user_id": "restricted_user",
|
||||
"object_permission": {
|
||||
"mcp_servers": ["database_server"],
|
||||
"mcp_tool_permissions": {
|
||||
"database_server": ["read_only_query", "get_table_schema"]
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
## Setting Customer Budgets
|
||||
|
||||
Set customer budgets (e.g. monthly budgets, tpm/rpm limits) on LiteLLM Proxy
|
||||
|
||||
|
||||
@@ -20,6 +20,10 @@ By default, LiteLLM does not forward client headers to LLM provider APIs. Howeve
|
||||
|
||||
`x-litellm-spend-logs-metadata`: Optional[str]: JSON string containing custom metadata to include in spend logs. Example: `{"user_id": "12345", "project_id": "proj_abc", "request_type": "chat_completion"}`. [Learn More](../proxy/enterprise#tracking-spend-with-custom-metadata)
|
||||
|
||||
`x-litellm-customer-id`: Optional[str]: Standard header for passing a customer/end-user ID. Always checked without any configuration. [Learn More](./customers)
|
||||
|
||||
`x-litellm-end-user-id`: Optional[str]: Standard header for passing a customer/end-user ID. Always checked without any configuration. [Learn More](./customers)
|
||||
|
||||
## Anthropic Headers
|
||||
|
||||
`anthropic-version` Optional[str]: The version of the Anthropic API to use.
|
||||
|
||||
+6
@@ -0,0 +1,6 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "LiteLLM_EndUserTable" ADD COLUMN "object_permission_id" TEXT;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LiteLLM_EndUserTable" ADD CONSTRAINT "LiteLLM_EndUserTable_object_permission_id_fkey" FOREIGN KEY ("object_permission_id") REFERENCES "LiteLLM_ObjectPermissionTable"("object_permission_id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
|
||||
@@ -233,6 +233,7 @@ model LiteLLM_ObjectPermissionTable {
|
||||
verification_tokens LiteLLM_VerificationToken[]
|
||||
organizations LiteLLM_OrganizationTable[]
|
||||
users LiteLLM_UserTable[]
|
||||
end_users LiteLLM_EndUserTable[]
|
||||
}
|
||||
|
||||
// Holds the MCP server configuration
|
||||
@@ -403,7 +404,9 @@ model LiteLLM_EndUserTable {
|
||||
allowed_model_region String? // require all user requests to use models in this specific region
|
||||
default_model String? // use along with 'allowed_model_region'. if no available model in region, default to this model.
|
||||
budget_id String?
|
||||
object_permission_id String?
|
||||
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
|
||||
object_permission LiteLLM_ObjectPermissionTable? @relation(fields: [object_permission_id], references: [object_permission_id])
|
||||
blocked Boolean @default(false)
|
||||
}
|
||||
|
||||
|
||||
@@ -124,6 +124,9 @@ class AnthropicMessagesHandler(BaseTranslation):
|
||||
)
|
||||
|
||||
guardrailed_texts = guardrailed_inputs.get("texts", [])
|
||||
guardrailed_tools = guardrailed_inputs.get("tools")
|
||||
if guardrailed_tools is not None:
|
||||
data["tools"] = guardrailed_tools
|
||||
|
||||
# Step 3: Map guardrail responses back to original message structure
|
||||
await self._apply_guardrail_responses_to_input(
|
||||
@@ -194,7 +197,7 @@ class AnthropicMessagesHandler(BaseTranslation):
|
||||
openai_tools = self.adapter.translate_anthropic_tools_to_openai(
|
||||
tools=cast(List[AllAnthropicToolsValues], tools)
|
||||
)
|
||||
tools_to_check.extend(openai_tools)
|
||||
tools_to_check.extend(openai_tools) # type: ignore
|
||||
|
||||
async def _apply_guardrail_responses_to_input(
|
||||
self,
|
||||
|
||||
@@ -107,6 +107,9 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
|
||||
|
||||
guardrailed_texts = guardrailed_inputs.get("texts", [])
|
||||
guardrailed_tool_calls = guardrailed_inputs.get("tool_calls", [])
|
||||
guardrailed_tools = guardrailed_inputs.get("tools")
|
||||
if guardrailed_tools is not None:
|
||||
data["tools"] = guardrailed_tools
|
||||
|
||||
# Step 3: Map guardrail responses back to original message structure
|
||||
if guardrailed_texts and texts_to_check:
|
||||
|
||||
@@ -96,10 +96,11 @@ class OpenAIResponsesHandler(BaseTranslation):
|
||||
# Handle simple string input
|
||||
if isinstance(input_data, str):
|
||||
inputs = GenericGuardrailAPIInputs(texts=[input_data])
|
||||
original_tools: List[Dict[str, Any]] = []
|
||||
|
||||
# Extract and transform tools if present
|
||||
|
||||
if "tools" in data and data["tools"]:
|
||||
original_tools = list(data["tools"])
|
||||
self._extract_and_transform_tools(data["tools"], tools_to_check)
|
||||
if tools_to_check:
|
||||
inputs["tools"] = tools_to_check
|
||||
@@ -118,6 +119,9 @@ class OpenAIResponsesHandler(BaseTranslation):
|
||||
)
|
||||
guardrailed_texts = guardrailed_inputs.get("texts", [])
|
||||
data["input"] = guardrailed_texts[0] if guardrailed_texts else input_data
|
||||
self._apply_guardrailed_tools_to_data(
|
||||
data, original_tools, guardrailed_inputs.get("tools")
|
||||
)
|
||||
verbose_proxy_logger.debug("OpenAI Responses API: Processed string input")
|
||||
return data
|
||||
|
||||
@@ -128,8 +132,7 @@ class OpenAIResponsesHandler(BaseTranslation):
|
||||
texts_to_check: List[str] = []
|
||||
images_to_check: List[str] = []
|
||||
task_mappings: List[Tuple[int, Optional[int]]] = []
|
||||
# Track (message_index, content_index) for each text
|
||||
# content_index is None for string content, int for list content
|
||||
original_tools_list: List[Dict[str, Any]] = list(data.get("tools") or [])
|
||||
|
||||
# Step 1: Extract all text content, images, and tools
|
||||
for msg_idx, message in enumerate(input_data):
|
||||
@@ -166,6 +169,11 @@ class OpenAIResponsesHandler(BaseTranslation):
|
||||
)
|
||||
|
||||
guardrailed_texts = guardrailed_inputs.get("texts", [])
|
||||
self._apply_guardrailed_tools_to_data(
|
||||
data,
|
||||
original_tools_list,
|
||||
guardrailed_inputs.get("tools"),
|
||||
)
|
||||
|
||||
# Step 3: Map guardrail responses back to original input structure
|
||||
await self._apply_guardrail_responses_to_input(
|
||||
@@ -203,6 +211,53 @@ class OpenAIResponsesHandler(BaseTranslation):
|
||||
cast(List[ChatCompletionToolParam], transformed_tools)
|
||||
)
|
||||
|
||||
def _remap_tools_to_responses_api_format(
|
||||
self, guardrailed_tools: List[Any]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Remap guardrail-returned tools (Chat Completion format) back to
|
||||
Responses API request tool format.
|
||||
"""
|
||||
return LiteLLMCompletionResponsesConfig.transform_chat_completion_tool_params_to_responses_api_tools(
|
||||
guardrailed_tools # type: ignore
|
||||
)
|
||||
|
||||
def _merge_tools_after_guardrail(
|
||||
self,
|
||||
original_tools: List[Dict[str, Any]],
|
||||
remapped: List[Dict[str, Any]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Merge remapped guardrailed tools with original tools that were not sent
|
||||
to the guardrail (e.g. web_search, web_search_preview), preserving order.
|
||||
"""
|
||||
if not original_tools:
|
||||
return remapped
|
||||
result: List[Dict[str, Any]] = []
|
||||
j = 0
|
||||
for tool in original_tools:
|
||||
if isinstance(tool, dict) and tool.get("type") in (
|
||||
"web_search",
|
||||
"web_search_preview",
|
||||
):
|
||||
result.append(tool)
|
||||
else:
|
||||
if j < len(remapped):
|
||||
result.append(remapped[j])
|
||||
j += 1
|
||||
return result
|
||||
|
||||
def _apply_guardrailed_tools_to_data(
|
||||
self,
|
||||
data: dict,
|
||||
original_tools: List[Dict[str, Any]],
|
||||
guardrailed_tools: Optional[List[Any]],
|
||||
) -> None:
|
||||
"""Remap guardrailed tools to Responses API format and merge with original, then set data['tools']."""
|
||||
if guardrailed_tools is not None:
|
||||
remapped = self._remap_tools_to_responses_api_format(guardrailed_tools)
|
||||
data["tools"] = self._merge_tools_after_guardrail(original_tools, remapped)
|
||||
|
||||
def _extract_input_text_and_images(
|
||||
self,
|
||||
message: Any, # Can be Dict[str, Any] or ResponseInputParam
|
||||
@@ -407,7 +462,10 @@ class OpenAIResponsesHandler(BaseTranslation):
|
||||
List[ChatCompletionToolCallChunk], tool_calls
|
||||
)
|
||||
# Include model information if available
|
||||
if hasattr(model_response_stream, "model") and model_response_stream.model:
|
||||
if (
|
||||
hasattr(model_response_stream, "model")
|
||||
and model_response_stream.model
|
||||
):
|
||||
inputs["model"] = model_response_stream.model
|
||||
_guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
|
||||
inputs=inputs,
|
||||
@@ -448,7 +506,9 @@ class OpenAIResponsesHandler(BaseTranslation):
|
||||
)
|
||||
return responses_so_far
|
||||
else:
|
||||
verbose_proxy_logger.debug("Skipping output guardrail - model response has no choices")
|
||||
verbose_proxy_logger.debug(
|
||||
"Skipping output guardrail - model response has no choices"
|
||||
)
|
||||
# model_response_stream = OpenAiResponsesToChatCompletionStreamIterator.translate_responses_chunk_to_openai_stream(final_chunk)
|
||||
# tool_calls = model_response_stream.choices[0].tool_calls
|
||||
# convert openai response to model response
|
||||
@@ -456,7 +516,11 @@ class OpenAIResponsesHandler(BaseTranslation):
|
||||
inputs = GenericGuardrailAPIInputs(texts=[string_so_far])
|
||||
# Try to get model from the final chunk if available
|
||||
if isinstance(final_chunk, dict):
|
||||
response_model = final_chunk.get("response", {}).get("model") if isinstance(final_chunk.get("response"), dict) else None
|
||||
response_model = (
|
||||
final_chunk.get("response", {}).get("model")
|
||||
if isinstance(final_chunk.get("response"), dict)
|
||||
else None
|
||||
)
|
||||
if response_model:
|
||||
inputs["model"] = response_model
|
||||
_guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
|
||||
@@ -591,8 +655,8 @@ class OpenAIResponsesHandler(BaseTranslation):
|
||||
content = generic_response_output_item.content
|
||||
except Exception:
|
||||
# Try to extract content directly from output_item if validation fails
|
||||
if hasattr(output_item, "content") and output_item.content:
|
||||
content = output_item.content
|
||||
if hasattr(output_item, "content") and output_item.content: # type: ignore
|
||||
content = output_item.content # type: ignore
|
||||
else:
|
||||
return
|
||||
elif isinstance(output_item, dict):
|
||||
@@ -669,10 +733,10 @@ class OpenAIResponsesHandler(BaseTranslation):
|
||||
if isinstance(content_item, OutputText):
|
||||
content_item.text = guardrail_response
|
||||
# Update the original response output
|
||||
if hasattr(output_item, "content") and output_item.content:
|
||||
original_content = output_item.content[content_idx]
|
||||
if hasattr(output_item, "content") and output_item.content: # type: ignore
|
||||
original_content = output_item.content[content_idx] # type: ignore
|
||||
if hasattr(original_content, "text"):
|
||||
original_content.text = guardrail_response
|
||||
original_content.text = guardrail_response # type: ignore
|
||||
except Exception:
|
||||
pass
|
||||
elif isinstance(output_item, dict):
|
||||
|
||||
@@ -336,15 +336,21 @@ class MCPRequestHandler:
|
||||
user_api_key_auth: Optional[UserAPIKeyAuth] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get list of allowed MCP servers for the given user/key based on permissions
|
||||
Get list of allowed MCP servers for the given user/key based on permissions.
|
||||
|
||||
Permission hierarchy (all rules are intersections):
|
||||
1. Get allowed servers from key permissions
|
||||
2. Get allowed servers from team permissions
|
||||
3. Get allowed servers from end_user permissions
|
||||
4. Final result = intersection of key/team AND end_user (if end_user has permissions set)
|
||||
|
||||
Returns:
|
||||
List[str]: List of allowed MCP servers by server id
|
||||
"""
|
||||
from typing import List
|
||||
from litellm.proxy.proxy_server import general_settings
|
||||
|
||||
try:
|
||||
allowed_mcp_servers: List[str] = []
|
||||
# Get allowed servers from key and team
|
||||
allowed_mcp_servers_for_key = (
|
||||
await MCPRequestHandler._get_allowed_mcp_servers_for_key(
|
||||
user_api_key_auth
|
||||
@@ -357,8 +363,9 @@ class MCPRequestHandler:
|
||||
)
|
||||
|
||||
#########################################################
|
||||
# If team has mcp_servers, handle inheritance and intersection logic
|
||||
# Calculate key/team allowed servers using inheritance and intersection logic
|
||||
#########################################################
|
||||
allowed_mcp_servers: List[str] = []
|
||||
if len(allowed_mcp_servers_for_team) > 0:
|
||||
if len(allowed_mcp_servers_for_key) > 0:
|
||||
# Key has its own MCP permissions - use intersection with team permissions
|
||||
@@ -371,6 +378,40 @@ class MCPRequestHandler:
|
||||
else:
|
||||
allowed_mcp_servers = allowed_mcp_servers_for_key
|
||||
|
||||
#########################################################
|
||||
# Check end_user permissions if end_user_id is set
|
||||
#########################################################
|
||||
if user_api_key_auth and user_api_key_auth.end_user_id:
|
||||
allowed_mcp_servers_for_end_user = (
|
||||
await MCPRequestHandler._get_allowed_mcp_servers_for_end_user(
|
||||
user_api_key_auth
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# If end_user has explicit MCP server permissions, apply intersection
|
||||
if len(allowed_mcp_servers_for_end_user) > 0:
|
||||
verbose_logger.debug(
|
||||
f"End user {user_api_key_auth.end_user_id} has explicit MCP permissions: {allowed_mcp_servers_for_end_user}"
|
||||
)
|
||||
|
||||
# Always apply intersection: key/team AND end_user
|
||||
# This ensures end_user can only access servers that both they AND their key/team are authorized for
|
||||
filtered_servers = []
|
||||
for _mcp_server in allowed_mcp_servers:
|
||||
if _mcp_server in allowed_mcp_servers_for_end_user:
|
||||
filtered_servers.append(_mcp_server)
|
||||
allowed_mcp_servers = filtered_servers
|
||||
verbose_logger.debug(
|
||||
f"Applied end_user intersection filter. Final allowed servers: {allowed_mcp_servers}"
|
||||
)
|
||||
# If flag is enabled but end_user has no permissions, block all access
|
||||
elif general_settings.get("require_end_user_mcp_access_defined", False):
|
||||
verbose_logger.debug(
|
||||
f"require_end_user_mcp_access_defined=True and end_user {user_api_key_auth.end_user_id} has no MCP permissions - blocking MCP access"
|
||||
)
|
||||
return []
|
||||
|
||||
return list(set(allowed_mcp_servers))
|
||||
except Exception as e:
|
||||
verbose_logger.warning(f"Failed to get allowed MCP servers: {str(e)}")
|
||||
@@ -614,6 +655,66 @@ class MCPRequestHandler:
|
||||
)
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def _get_allowed_mcp_servers_for_end_user(
|
||||
user_api_key_auth: Optional[UserAPIKeyAuth] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get allowed MCP servers for an end user.
|
||||
|
||||
Returns the MCP servers from the end_user's object_permission.
|
||||
"""
|
||||
from litellm.proxy.auth.auth_checks import get_end_user_object
|
||||
from litellm.proxy.proxy_server import (
|
||||
prisma_client,
|
||||
proxy_logging_obj,
|
||||
user_api_key_cache,
|
||||
)
|
||||
|
||||
if not user_api_key_auth or not user_api_key_auth.end_user_id:
|
||||
return []
|
||||
|
||||
if prisma_client is None:
|
||||
|
||||
verbose_logger.debug("prisma_client is None")
|
||||
return []
|
||||
|
||||
try:
|
||||
# Use optimized get_end_user_object function with caching
|
||||
end_user_obj = await get_end_user_object(
|
||||
end_user_id=user_api_key_auth.end_user_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=user_api_key_auth.parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
route="/mcp",
|
||||
)
|
||||
|
||||
|
||||
if end_user_obj is None or end_user_obj.object_permission is None:
|
||||
return []
|
||||
|
||||
# Get direct MCP servers
|
||||
direct_mcp_servers = end_user_obj.object_permission.mcp_servers or []
|
||||
|
||||
|
||||
|
||||
# Get MCP servers from access groups
|
||||
access_group_servers = (
|
||||
await MCPRequestHandler._get_mcp_servers_from_access_groups(
|
||||
end_user_obj.object_permission.mcp_access_groups or []
|
||||
)
|
||||
)
|
||||
|
||||
# Combine both lists
|
||||
all_servers = direct_mcp_servers + access_group_servers
|
||||
return list(set(all_servers))
|
||||
except Exception as e:
|
||||
verbose_logger.warning(
|
||||
f"Failed to get allowed MCP servers for end_user: {str(e)}"
|
||||
)
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _get_config_server_ids_for_access_groups(
|
||||
config_mcp_servers, access_groups: List[str]
|
||||
@@ -691,8 +792,6 @@ class MCPRequestHandler:
|
||||
"""
|
||||
Get list of MCP access groups for the given user/key based on permissions
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
access_groups: List[str] = []
|
||||
access_groups_for_key = await MCPRequestHandler._get_mcp_access_groups_for_key(
|
||||
user_api_key_auth
|
||||
|
||||
@@ -43,9 +43,7 @@ from litellm.proxy._experimental.mcp_server.utils import (
|
||||
)
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.ip_address_utils import IPAddressUtils
|
||||
from litellm.proxy.litellm_pre_call_utils import (
|
||||
LiteLLMProxyRequestSetup,
|
||||
)
|
||||
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
|
||||
from litellm.types.mcp import MCPAuth
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer
|
||||
from litellm.types.utils import CallTypes, StandardLoggingMCPToolCall
|
||||
@@ -795,6 +793,7 @@ if MCP_AVAILABLE:
|
||||
mcp_servers=mcp_servers,
|
||||
allowed_mcp_servers=allowed_mcp_servers,
|
||||
)
|
||||
|
||||
|
||||
return allowed_mcp_servers
|
||||
|
||||
@@ -938,9 +937,6 @@ if MCP_AVAILABLE:
|
||||
mcp_servers=mcp_servers,
|
||||
)
|
||||
|
||||
# Decide whether to add prefix based on number of allowed servers
|
||||
add_prefix = not (len(allowed_mcp_servers) == 1)
|
||||
|
||||
async def _fetch_and_filter_server_tools(
|
||||
server: MCPServer,
|
||||
) -> List[MCPTool]:
|
||||
@@ -961,7 +957,7 @@ if MCP_AVAILABLE:
|
||||
server=server,
|
||||
mcp_auth_header=server_auth_header,
|
||||
extra_headers=extra_headers,
|
||||
add_prefix=add_prefix,
|
||||
add_prefix=True, # Always add server prefix
|
||||
raw_headers=raw_headers,
|
||||
)
|
||||
filtered_tools = filter_tools_by_allowed_tools(tools, server)
|
||||
@@ -1079,8 +1075,6 @@ if MCP_AVAILABLE:
|
||||
mcp_servers=mcp_servers,
|
||||
)
|
||||
|
||||
# Decide whether to add prefix based on number of allowed servers
|
||||
add_prefix = not (len(allowed_mcp_servers) == 1)
|
||||
|
||||
# Get prompts from each allowed server
|
||||
all_prompts = []
|
||||
@@ -1101,7 +1095,7 @@ if MCP_AVAILABLE:
|
||||
server=server,
|
||||
mcp_auth_header=server_auth_header,
|
||||
extra_headers=extra_headers,
|
||||
add_prefix=add_prefix,
|
||||
add_prefix=True, # Always add server prefix
|
||||
raw_headers=raw_headers,
|
||||
)
|
||||
|
||||
@@ -1140,7 +1134,6 @@ if MCP_AVAILABLE:
|
||||
mcp_servers=mcp_servers,
|
||||
)
|
||||
|
||||
add_prefix = not (len(allowed_mcp_servers) == 1)
|
||||
|
||||
all_resources: List[Resource] = []
|
||||
for server in allowed_mcp_servers:
|
||||
@@ -1160,7 +1153,7 @@ if MCP_AVAILABLE:
|
||||
server=server,
|
||||
mcp_auth_header=server_auth_header,
|
||||
extra_headers=extra_headers,
|
||||
add_prefix=add_prefix,
|
||||
add_prefix=True, # Always add server prefix
|
||||
raw_headers=raw_headers,
|
||||
)
|
||||
all_resources.extend(resources)
|
||||
@@ -1197,7 +1190,6 @@ if MCP_AVAILABLE:
|
||||
mcp_servers=mcp_servers,
|
||||
)
|
||||
|
||||
add_prefix = not (len(allowed_mcp_servers) == 1)
|
||||
|
||||
all_resource_templates: List[ResourceTemplate] = []
|
||||
for server in allowed_mcp_servers:
|
||||
@@ -1218,7 +1210,7 @@ if MCP_AVAILABLE:
|
||||
server=server,
|
||||
mcp_auth_header=server_auth_header,
|
||||
extra_headers=extra_headers,
|
||||
add_prefix=add_prefix,
|
||||
add_prefix=True, # Always add server prefix
|
||||
raw_headers=raw_headers,
|
||||
)
|
||||
)
|
||||
@@ -1676,14 +1668,9 @@ if MCP_AVAILABLE:
|
||||
detail="User not allowed to get this prompt.",
|
||||
)
|
||||
|
||||
# Decide whether to add prefix based on number of allowed servers
|
||||
add_prefix = not (len(allowed_mcp_servers) == 1)
|
||||
|
||||
if add_prefix:
|
||||
original_prompt_name, server_name = split_server_prefix_from_name(name)
|
||||
else:
|
||||
original_prompt_name = name
|
||||
server_name = allowed_mcp_servers[0].name
|
||||
# Extract server name from prefixed prompt name
|
||||
original_prompt_name, server_name = split_server_prefix_from_name(name)
|
||||
|
||||
server = next((s for s in allowed_mcp_servers if s.name == server_name), None)
|
||||
if server is None:
|
||||
|
||||
@@ -13,26 +13,25 @@ model_list:
|
||||
- model_name: gpt-4.1-mini
|
||||
litellm_params:
|
||||
model: openai/gpt-4.1-mini
|
||||
|
||||
|
||||
# guardrails:
|
||||
# - guardrail_name: generic-guardrail
|
||||
# litellm_params:
|
||||
# guardrail: generic_guardrail_api
|
||||
# mode: ["pre_call"]
|
||||
# headers:
|
||||
# Authorization: Bearer mock-bedrock-token-12345
|
||||
# api_base: http://localhost:8080
|
||||
# default_on: true
|
||||
|
||||
prompts:
|
||||
- prompt_id: "simple_prompt"
|
||||
- model_name: gpt-5-mini
|
||||
litellm_params:
|
||||
prompt_integration: "generic_prompt_management"
|
||||
provider_specific_query_params:
|
||||
project_name: litellm
|
||||
slug: hello-world-prompt-2bac
|
||||
api_base: http://localhost:8080
|
||||
api_key: os.environ/BRAINTRUST_API_KEY
|
||||
ignore_prompt_manager_model: true
|
||||
ignore_prompt_manager_optional_params: true
|
||||
model: openai/gpt-5-mini
|
||||
|
||||
|
||||
guardrails:
|
||||
- guardrail_name: mcp-user-permissions
|
||||
litellm_params:
|
||||
guardrail: mcp_end_user_permission
|
||||
mode: pre_call
|
||||
default_on: true
|
||||
|
||||
mcp_servers:
|
||||
my_http_server:
|
||||
url: "http://0.0.0.0:8001/mcp"
|
||||
transport: "http"
|
||||
description: "My custom MCP server"
|
||||
available_on_public_internet: true
|
||||
|
||||
general_settings:
|
||||
store_model_in_db: true
|
||||
store_prompts_in_spend_logs: true
|
||||
|
||||
+17
-12
@@ -1409,12 +1409,13 @@ class NewCustomerRequest(BudgetNewRequest):
|
||||
blocked: bool = False # allow/disallow requests for this end-user
|
||||
budget_id: Optional[str] = None # give either a budget_id or max_budget
|
||||
spend: Optional[float] = None
|
||||
allowed_model_region: Optional[
|
||||
AllowedModelRegion
|
||||
] = None # require all user requests to use models in this specific region
|
||||
default_model: Optional[
|
||||
str
|
||||
] = None # if no equivalent model in allowed region - default all requests to this model
|
||||
allowed_model_region: Optional[AllowedModelRegion] = (
|
||||
None # require all user requests to use models in this specific region
|
||||
)
|
||||
default_model: Optional[str] = (
|
||||
None # if no equivalent model in allowed region - default all requests to this model
|
||||
)
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionBase] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -1436,12 +1437,13 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase):
|
||||
blocked: bool = False # allow/disallow requests for this end-user
|
||||
max_budget: Optional[float] = None
|
||||
budget_id: Optional[str] = None # give either a budget_id or max_budget
|
||||
allowed_model_region: Optional[
|
||||
AllowedModelRegion
|
||||
] = None # require all user requests to use models in this specific region
|
||||
default_model: Optional[
|
||||
str
|
||||
] = None # if no equivalent model in allowed region - default all requests to this model
|
||||
allowed_model_region: Optional[AllowedModelRegion] = (
|
||||
None # require all user requests to use models in this specific region
|
||||
)
|
||||
default_model: Optional[str] = (
|
||||
None # if no equivalent model in allowed region - default all requests to this model
|
||||
)
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionBase] = None
|
||||
|
||||
|
||||
class DeleteCustomerRequest(LiteLLMPydanticObjectBase):
|
||||
@@ -2301,6 +2303,7 @@ class UserAPIKeyAuth(
|
||||
user_max_budget: Optional[float] = None
|
||||
request_route: Optional[str] = None
|
||||
user: Optional[Any] = None # Expanded user object when expand=user is used
|
||||
end_user_object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@@ -2535,6 +2538,8 @@ class LiteLLM_EndUserTable(LiteLLMPydanticObjectBase):
|
||||
allowed_model_region: Optional[AllowedModelRegion] = None
|
||||
default_model: Optional[str] = None
|
||||
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
|
||||
object_permission_id: Optional[str] = None
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
||||
@@ -11,7 +11,8 @@ Run checks for:
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union,
|
||||
cast)
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
@@ -20,41 +21,27 @@ import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
||||
from litellm.constants import (
|
||||
CLI_JWT_EXPIRATION_HOURS,
|
||||
CLI_JWT_TOKEN_NAME,
|
||||
DEFAULT_ACCESS_GROUP_CACHE_TTL,
|
||||
DEFAULT_IN_MEMORY_TTL,
|
||||
DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL,
|
||||
DEFAULT_MAX_RECURSE_DEPTH,
|
||||
EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE,
|
||||
)
|
||||
from litellm.constants import (CLI_JWT_EXPIRATION_HOURS, CLI_JWT_TOKEN_NAME,
|
||||
DEFAULT_ACCESS_GROUP_CACHE_TTL,
|
||||
DEFAULT_IN_MEMORY_TTL,
|
||||
DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL,
|
||||
DEFAULT_MAX_RECURSE_DEPTH,
|
||||
EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE)
|
||||
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
|
||||
from litellm.proxy._types import (
|
||||
RBAC_ROLES,
|
||||
CallInfo,
|
||||
LiteLLM_AccessGroupTable,
|
||||
LiteLLM_BudgetTable,
|
||||
LiteLLM_EndUserTable,
|
||||
Litellm_EntityType,
|
||||
LiteLLM_JWTAuth,
|
||||
LiteLLM_ObjectPermissionTable,
|
||||
LiteLLM_OrganizationMembershipTable,
|
||||
LiteLLM_OrganizationTable,
|
||||
LiteLLM_TagTable,
|
||||
LiteLLM_TeamMembership,
|
||||
LiteLLM_TeamTable,
|
||||
LiteLLM_TeamTableCachedObj,
|
||||
LiteLLM_UserTable,
|
||||
LiteLLMRoutes,
|
||||
LitellmUserRoles,
|
||||
NewTeamRequest,
|
||||
ProxyErrorTypes,
|
||||
ProxyException,
|
||||
RoleBasedPermissions,
|
||||
SpecialModelNames,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy._types import (RBAC_ROLES, CallInfo,
|
||||
LiteLLM_AccessGroupTable,
|
||||
LiteLLM_BudgetTable, LiteLLM_EndUserTable,
|
||||
Litellm_EntityType, LiteLLM_JWTAuth,
|
||||
LiteLLM_ObjectPermissionTable,
|
||||
LiteLLM_OrganizationMembershipTable,
|
||||
LiteLLM_OrganizationTable, LiteLLM_TagTable,
|
||||
LiteLLM_TeamMembership, LiteLLM_TeamTable,
|
||||
LiteLLM_TeamTableCachedObj,
|
||||
LiteLLM_UserTable, LiteLLMRoutes,
|
||||
LitellmUserRoles, NewTeamRequest,
|
||||
ProxyErrorTypes, ProxyException,
|
||||
RoleBasedPermissions, SpecialModelNames,
|
||||
UserAPIKeyAuth)
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
from litellm.proxy.route_llm_request import route_request
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
|
||||
@@ -366,7 +353,8 @@ async def common_checks(
|
||||
_request_metadata: dict = request_body.get("metadata", {}) or {}
|
||||
if _request_metadata.get("guardrails"):
|
||||
# check if team allowed to modify guardrails
|
||||
from litellm.proxy.guardrails.guardrail_helpers import can_modify_guardrails
|
||||
from litellm.proxy.guardrails.guardrail_helpers import \
|
||||
can_modify_guardrails
|
||||
|
||||
can_modify: bool = can_modify_guardrails(team_object)
|
||||
if can_modify is False:
|
||||
@@ -792,7 +780,7 @@ async def get_end_user_object(
|
||||
try:
|
||||
response = await prisma_client.db.litellm_endusertable.find_unique(
|
||||
where={"user_id": end_user_id},
|
||||
include={"litellm_budget_table": True},
|
||||
include={"litellm_budget_table": True, "object_permission": True},
|
||||
)
|
||||
|
||||
if response is None:
|
||||
@@ -1812,9 +1800,8 @@ class ExperimentalUIJWTToken:
|
||||
def get_experimental_ui_login_jwt_auth_token(user_info: LiteLLM_UserTable) -> str:
|
||||
from datetime import timedelta
|
||||
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||
encrypt_value_helper,
|
||||
)
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import \
|
||||
encrypt_value_helper
|
||||
|
||||
if user_info.user_role is None:
|
||||
raise Exception("User role is required for experimental UI login")
|
||||
@@ -1860,9 +1847,8 @@ class ExperimentalUIJWTToken:
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||
encrypt_value_helper,
|
||||
)
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import \
|
||||
encrypt_value_helper
|
||||
|
||||
if user_info.user_role is None:
|
||||
raise Exception("User role is required for CLI JWT login")
|
||||
@@ -1901,9 +1887,8 @@ class ExperimentalUIJWTToken:
|
||||
import json
|
||||
|
||||
from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||
decrypt_value_helper,
|
||||
)
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import \
|
||||
decrypt_value_helper
|
||||
|
||||
decrypted_token = decrypt_value_helper(
|
||||
hashed_token, key="ui_hash_key", exception_type="debug"
|
||||
@@ -2150,11 +2135,11 @@ async def _get_resources_from_access_groups(
|
||||
|
||||
# Lazy import to avoid circular imports
|
||||
if prisma_client is None or user_api_key_cache is None:
|
||||
from litellm.proxy.proxy_server import (
|
||||
prisma_client as _prisma_client,
|
||||
proxy_logging_obj as _proxy_logging_obj,
|
||||
user_api_key_cache as _user_api_key_cache,
|
||||
)
|
||||
from litellm.proxy.proxy_server import prisma_client as _prisma_client
|
||||
from litellm.proxy.proxy_server import \
|
||||
proxy_logging_obj as _proxy_logging_obj
|
||||
from litellm.proxy.proxy_server import \
|
||||
user_api_key_cache as _user_api_key_cache
|
||||
|
||||
prisma_client = prisma_client or _prisma_client
|
||||
user_api_key_cache = user_api_key_cache or _user_api_key_cache
|
||||
@@ -2936,7 +2921,8 @@ async def _tag_max_budget_check(
|
||||
BudgetExceededError if any tag is over its max budget.
|
||||
Triggers a budget alert if any tag is over its max budget.
|
||||
"""
|
||||
from litellm.proxy.common_utils.http_parsing_utils import get_tags_from_request_body
|
||||
from litellm.proxy.common_utils.http_parsing_utils import \
|
||||
get_tags_from_request_body
|
||||
|
||||
if prisma_client is None:
|
||||
return
|
||||
|
||||
@@ -736,6 +736,16 @@ def get_end_user_id_from_request_body(
|
||||
user_id_from_metadata_field = metadata_dict.get("user_id")
|
||||
if user_id_from_metadata_field is not None:
|
||||
return str(user_id_from_metadata_field)
|
||||
|
||||
|
||||
# Check 6: 'safety_identifier' in request body (OpenAI Responses API parameter)
|
||||
# SECURITY NOTE: safety_identifier can be set by any caller in the request body.
|
||||
# Only use this for end-user identification in trusted environments where you control
|
||||
# the calling application. For untrusted callers, prefer using headers or server-side
|
||||
# middleware to set the end_user_id to prevent impersonation.
|
||||
if request_body.get("safety_identifier") is not None:
|
||||
user_from_body_user_field = request_body["safety_identifier"]
|
||||
return str(user_from_body_user_field)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -644,7 +644,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
if team_object is not None
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
# Check if model has zero cost - if so, skip all budget checks
|
||||
model = get_model_from_request(request_data, route)
|
||||
skip_budget_checks = False
|
||||
@@ -831,6 +831,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
valid_token=valid_token, end_user_params=end_user_params
|
||||
)
|
||||
valid_token.parent_otel_span = parent_otel_span
|
||||
if _end_user_object is not None:
|
||||
valid_token.end_user_object_permission = _end_user_object.object_permission
|
||||
|
||||
return valid_token
|
||||
|
||||
@@ -1277,6 +1279,9 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
|
||||
if _end_user_object is not None:
|
||||
valid_token_dict.update(end_user_params)
|
||||
valid_token_dict["end_user_object_permission"] = (
|
||||
_end_user_object.object_permission
|
||||
)
|
||||
|
||||
# check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions
|
||||
# sso/login, ui/login, /key functions and /user functions
|
||||
|
||||
@@ -205,12 +205,6 @@ class ContentFilterGuardrail(CustomGuardrail):
|
||||
# Load categories if provided
|
||||
if categories:
|
||||
self._load_categories(categories)
|
||||
else:
|
||||
verbose_proxy_logger.warning(
|
||||
"ContentFilterGuardrail has no content categories configured. "
|
||||
"Toxic/abuse and other category-based keyword filtering will not run. "
|
||||
"Add categories (e.g. harm_toxic_abuse) in the guardrail config to enable them."
|
||||
)
|
||||
|
||||
# Normalize inputs: convert dicts to Pydantic models for consistent handling
|
||||
normalized_patterns: List[ContentFilterPattern] = []
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from litellm.types.guardrails import SupportedGuardrailIntegrations
|
||||
|
||||
from .mcp_end_user_permission import MCPEndUserPermissionGuardrail
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.guardrails import Guardrail, LitellmParams
|
||||
|
||||
|
||||
def initialize_guardrail(litellm_params: "LitellmParams", guardrail: "Guardrail"):
|
||||
import litellm
|
||||
|
||||
# Default to always-on. Only disable if the user explicitly sets default_on: false.
|
||||
# We check the raw guardrail dict because LitellmParams normalizes None → False,
|
||||
# making it impossible to distinguish "not set" from "explicitly false" via litellm_params.
|
||||
_raw_default_on = guardrail.get("litellm_params", {}).get("default_on")
|
||||
_default_on = False if _raw_default_on is False else True
|
||||
|
||||
_callback = MCPEndUserPermissionGuardrail(
|
||||
guardrail_name=guardrail.get("guardrail_name", ""),
|
||||
event_hook=litellm_params.mode,
|
||||
default_on=_default_on,
|
||||
)
|
||||
litellm.logging_callback_manager.add_litellm_callback(_callback)
|
||||
return _callback
|
||||
|
||||
|
||||
guardrail_initializer_registry = {
|
||||
SupportedGuardrailIntegrations.MCP_END_USER_PERMISSION.value: initialize_guardrail,
|
||||
}
|
||||
|
||||
guardrail_class_registry = {
|
||||
SupportedGuardrailIntegrations.MCP_END_USER_PERMISSION.value: MCPEndUserPermissionGuardrail,
|
||||
}
|
||||
+262
@@ -0,0 +1,262 @@
|
||||
"""
|
||||
MCP End User Permission Guardrail Hook
|
||||
|
||||
Enforces end user permissions for MCP server access via apply_guardrail:
|
||||
- input_type="request" → filter tools the end user cannot access
|
||||
|
||||
Permission logic:
|
||||
- No end_user_id → allow all (key/team-level permissions apply)
|
||||
- end_user_id, no mcp_servers → allow all (default)
|
||||
- end_user_id + mcp_servers → allow only those servers
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Type
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_guardrail import (
|
||||
CustomGuardrail,
|
||||
log_guardrail_information,
|
||||
)
|
||||
from litellm.proxy._types import LiteLLM_ObjectPermissionTable
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
from litellm.types.utils import GenericGuardrailAPIInputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.proxy.guardrails.guardrail_hooks.base import GuardrailConfigModel
|
||||
|
||||
GUARDRAIL_NAME = "mcp_end_user_permission"
|
||||
|
||||
|
||||
class MCPEndUserPermissionGuardrail(CustomGuardrail):
|
||||
"""
|
||||
Guardrail that enforces end user permissions for MCP server access.
|
||||
|
||||
Runs on input only (pre-call). Filters tools in the request that the
|
||||
end user is not permitted to call based on their object_permission.
|
||||
|
||||
end_user_object_permission is populated on UserAPIKeyAuth during auth.
|
||||
The guardrail resolves it via a cached get_end_user_object lookup —
|
||||
no extra DB round-trip when the cache is warm.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if "supported_event_hooks" not in kwargs:
|
||||
kwargs["supported_event_hooks"] = [
|
||||
GuardrailEventHooks.pre_call,
|
||||
]
|
||||
super().__init__(**kwargs)
|
||||
verbose_proxy_logger.debug("MCP End User Permission Guardrail initialized")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# apply_guardrail — filters MCP tools on the request side only
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@log_guardrail_information
|
||||
async def apply_guardrail(
|
||||
self,
|
||||
inputs: GenericGuardrailAPIInputs,
|
||||
request_data: dict,
|
||||
input_type: Literal["request", "response"] = "request",
|
||||
logging_obj: Optional[Any] = None,
|
||||
) -> GenericGuardrailAPIInputs:
|
||||
"""
|
||||
Filters MCP tools the end user cannot access based on their
|
||||
object_permission.mcp_servers / mcp_access_groups settings.
|
||||
"""
|
||||
object_permission = await self._resolve_end_user_object_permission(request_data)
|
||||
return await self._check_request_tools(inputs, object_permission)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private — request-side tool filtering
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _check_request_tools(
|
||||
self,
|
||||
inputs: GenericGuardrailAPIInputs,
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionTable],
|
||||
) -> GenericGuardrailAPIInputs:
|
||||
tools = inputs.get("tools")
|
||||
if not tools:
|
||||
return inputs
|
||||
|
||||
allowed_mcp_servers = (
|
||||
await self._get_allowed_mcp_servers_from_object_permission(
|
||||
object_permission
|
||||
)
|
||||
)
|
||||
if allowed_mcp_servers is None:
|
||||
return inputs # No restrictions → pass through unchanged
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"MCP guardrail: end user restricted to MCP servers: {allowed_mcp_servers}"
|
||||
)
|
||||
|
||||
filtered_tools = []
|
||||
removed_tools = []
|
||||
|
||||
for tool in tools:
|
||||
tool_name = self._get_tool_name_from_definition(tool)
|
||||
server_name = (
|
||||
self._extract_mcp_server_name(tool_name) if tool_name else None
|
||||
)
|
||||
|
||||
if server_name is None:
|
||||
# Not an MCP tool (no prefix) or unrecognised format → keep
|
||||
filtered_tools.append(tool)
|
||||
elif server_name in allowed_mcp_servers:
|
||||
filtered_tools.append(tool)
|
||||
else:
|
||||
removed_tools.append(tool_name)
|
||||
verbose_proxy_logger.warning(
|
||||
f"MCP guardrail: removing tool '{tool_name}' "
|
||||
f"(server: '{server_name}') — not in end user's allowed servers"
|
||||
)
|
||||
|
||||
if removed_tools:
|
||||
verbose_proxy_logger.debug(
|
||||
f"MCP guardrail: removed {len(removed_tools)} unauthorized MCP tool(s): {removed_tools}"
|
||||
)
|
||||
inputs["tools"] = filtered_tools
|
||||
|
||||
return inputs
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private — end user permission resolution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
async def _resolve_end_user_object_permission(
|
||||
request_data: dict,
|
||||
) -> Optional[LiteLLM_ObjectPermissionTable]:
|
||||
"""
|
||||
Resolve the end user's object_permission via the cached auth lookup.
|
||||
|
||||
Uses get_end_user_object (same path as auth) so no extra DB round-trip
|
||||
when the cache is warm.
|
||||
"""
|
||||
end_user_id = MCPEndUserPermissionGuardrail._get_end_user_id_from_request_data(
|
||||
request_data
|
||||
)
|
||||
if not end_user_id:
|
||||
return None
|
||||
|
||||
end_user_object = await MCPEndUserPermissionGuardrail._fetch_end_user_object(
|
||||
end_user_id
|
||||
)
|
||||
return (
|
||||
end_user_object.object_permission if end_user_object is not None else None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_end_user_id_from_request_data(request_data: dict) -> Optional[str]:
|
||||
return request_data.get("user_api_key_end_user_id") or request_data.get(
|
||||
"litellm_metadata", {}
|
||||
).get("user_api_key_end_user_id")
|
||||
|
||||
@staticmethod
|
||||
async def _fetch_end_user_object(end_user_id: str): # type: ignore[return]
|
||||
"""
|
||||
Fetch end user object via the same cached path used during auth.
|
||||
No extra DB round-trip when the cache is warm.
|
||||
"""
|
||||
from litellm.proxy.auth.auth_checks import get_end_user_object
|
||||
from litellm.proxy.proxy_server import (
|
||||
prisma_client,
|
||||
proxy_logging_obj,
|
||||
user_api_key_cache,
|
||||
)
|
||||
|
||||
if prisma_client is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return await get_end_user_object(
|
||||
end_user_id=end_user_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
route="/mcp",
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.warning(
|
||||
f"MCP guardrail: failed to fetch end_user_object for '{end_user_id}': {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private — permission derivation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
async def _get_allowed_mcp_servers_from_object_permission(
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionTable],
|
||||
) -> Optional[List[str]]:
|
||||
"""
|
||||
Returns:
|
||||
None — no restrictions configured, allow all MCP servers
|
||||
list — restrict to exactly these server names
|
||||
"""
|
||||
if object_permission is None:
|
||||
return None
|
||||
|
||||
direct_mcp_servers = object_permission.mcp_servers or []
|
||||
mcp_access_groups = object_permission.mcp_access_groups or []
|
||||
|
||||
if not direct_mcp_servers and not mcp_access_groups:
|
||||
return None # Both empty → no restrictions
|
||||
|
||||
from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
|
||||
MCPRequestHandler,
|
||||
)
|
||||
|
||||
access_group_servers = (
|
||||
await MCPRequestHandler._get_mcp_servers_from_access_groups(
|
||||
mcp_access_groups
|
||||
)
|
||||
)
|
||||
|
||||
return list(set(direct_mcp_servers + access_group_servers))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Config model — exposes this guardrail in the UI
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def get_config_model() -> Optional[Type["GuardrailConfigModel"]]:
|
||||
from litellm.types.proxy.guardrails.guardrail_hooks.mcp_end_user_permission import (
|
||||
MCPEndUserPermissionGuardrailConfigModel,
|
||||
)
|
||||
|
||||
return MCPEndUserPermissionGuardrailConfigModel
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private — tool name extraction
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _extract_mcp_server_name(tool_name: str) -> Optional[str]:
|
||||
"""
|
||||
Split "github-create_issue" → "github".
|
||||
Returns None if the tool name has no '-' prefix (not an MCP tool).
|
||||
"""
|
||||
if not tool_name or "-" not in tool_name:
|
||||
return None
|
||||
return tool_name.split("-", 1)[0]
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_name_from_definition(tool: Any) -> Optional[str]:
|
||||
"""
|
||||
Extract tool name from a definition dict.
|
||||
|
||||
OpenAI format: {"type": "function", "function": {"name": "..."}}
|
||||
Anthropic format: {"name": "...", "input_schema": {...}}
|
||||
"""
|
||||
if not isinstance(tool, dict):
|
||||
return None
|
||||
function_def = tool.get("function")
|
||||
if isinstance(function_def, dict):
|
||||
name = function_def.get("name")
|
||||
if name:
|
||||
return name
|
||||
return tool.get("name")
|
||||
@@ -85,6 +85,7 @@ class UnifiedLLMGuardrails(CustomLogger):
|
||||
add_guardrail_to_applied_guardrails_header,
|
||||
)
|
||||
|
||||
|
||||
verbose_proxy_logger.debug("Running UnifiedLLMGuardrails pre-call hook")
|
||||
|
||||
guardrail_to_apply: CustomGuardrail = data.pop("guardrail_to_apply", None)
|
||||
|
||||
@@ -34,7 +34,7 @@ def init_guardrails_v2(
|
||||
if initialized_guardrail:
|
||||
guardrail_list.append(initialized_guardrail)
|
||||
|
||||
verbose_proxy_logger.debug(f"\nGuardrail List:{guardrail_list}\n")
|
||||
# verbose_proxy_logger.debug(f"\nGuardrail List:{guardrail_list}\n")
|
||||
|
||||
# Populate router's guardrail_list for load balancing support
|
||||
_populate_router_guardrail_list(guardrail_list=guardrail_list)
|
||||
|
||||
@@ -879,7 +879,7 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
||||
general_settings, user_api_key_dict, _headers
|
||||
)
|
||||
|
||||
# Parse user info from headers
|
||||
# Parse user info from headers (fallback to general_settings.user_header_name)
|
||||
user = LiteLLMProxyRequestSetup.get_user_from_headers(_headers, general_settings)
|
||||
if user is not None:
|
||||
if user_api_key_dict.end_user_id is None:
|
||||
@@ -1540,9 +1540,7 @@ def _match_and_track_policies(
|
||||
add_policy_sources_to_metadata,
|
||||
add_policy_to_applied_policies_header,
|
||||
)
|
||||
from litellm.proxy.policy_engine.attachment_registry import (
|
||||
get_attachment_registry,
|
||||
)
|
||||
from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry
|
||||
from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher
|
||||
|
||||
# Get matching policies via attachments (with match reasons for attribution)
|
||||
@@ -1677,9 +1675,7 @@ def add_guardrails_from_policy_engine(
|
||||
user_api_key_dict: The user's API key authentication info
|
||||
"""
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.common_utils.http_parsing_utils import (
|
||||
get_tags_from_request_body,
|
||||
)
|
||||
from litellm.proxy.common_utils.http_parsing_utils import get_tags_from_request_body
|
||||
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
|
||||
from litellm.types.proxy.policy_engine import PolicyMatchContext
|
||||
|
||||
|
||||
@@ -19,11 +19,13 @@ import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_endpoints.common_daily_activity import \
|
||||
get_daily_activity
|
||||
from litellm.proxy.management_helpers.object_permission_utils import (
|
||||
_set_object_permission, handle_update_object_permission_common)
|
||||
from litellm.proxy.utils import handle_exception_on_proxy
|
||||
from litellm.types.proxy.management_endpoints.common_daily_activity import (
|
||||
SpendAnalyticsPaginatedResponse,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.common_daily_activity import get_daily_activity
|
||||
from litellm.types.proxy.management_endpoints.common_daily_activity import \
|
||||
SpendAnalyticsPaginatedResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -107,9 +109,8 @@ async def unblock_user(data: BlockUsers):
|
||||
```
|
||||
"""
|
||||
try:
|
||||
from enterprise.enterprise_hooks.blocked_user_list import (
|
||||
_ENTERPRISE_BlockedUserList,
|
||||
)
|
||||
from enterprise.enterprise_hooks.blocked_user_list import \
|
||||
_ENTERPRISE_BlockedUserList
|
||||
except ImportError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@@ -164,6 +165,38 @@ def new_budget_request(data: NewCustomerRequest) -> Optional[BudgetNewRequest]:
|
||||
return None
|
||||
|
||||
|
||||
async def _handle_customer_object_permission_update(
|
||||
non_default_values: dict,
|
||||
end_user_table_data_typed: Optional[LiteLLM_EndUserTable],
|
||||
update_end_user_table_data: dict,
|
||||
prisma_client,
|
||||
) -> None:
|
||||
"""
|
||||
Handle object permission updates for customer endpoints.
|
||||
|
||||
Updates the update_end_user_table_data dict in place with the new object_permission_id.
|
||||
|
||||
Args:
|
||||
non_default_values: Dictionary containing the update values including object_permission
|
||||
end_user_table_data_typed: Existing end user table data
|
||||
update_end_user_table_data: Dictionary to update with new object_permission_id
|
||||
prisma_client: Prisma database client
|
||||
"""
|
||||
if "object_permission" in non_default_values:
|
||||
existing_object_permission_id = (
|
||||
end_user_table_data_typed.object_permission_id
|
||||
if end_user_table_data_typed is not None
|
||||
else None
|
||||
)
|
||||
object_permission_id = await handle_update_object_permission_common(
|
||||
data_json=non_default_values,
|
||||
existing_object_permission_id=existing_object_permission_id,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
if object_permission_id is not None:
|
||||
update_end_user_table_data["object_permission_id"] = object_permission_id
|
||||
|
||||
|
||||
@router.post(
|
||||
"/end_user/new",
|
||||
tags=["Customer Management"],
|
||||
@@ -200,6 +233,16 @@ async def new_end_user(
|
||||
- soft_budget: Optional[float] - [Not Implemented Yet] Get alerts when customer crosses given budget, doesn't block requests.
|
||||
- spend: Optional[float] - Specify initial spend for a given customer.
|
||||
- budget_reset_at: Optional[str] - Specify the date and time when the budget should be reset.
|
||||
- object_permission: Optional[LiteLLM_ObjectPermissionBase] - Customer-specific object permissions to control access to resources.
|
||||
Supported fields:
|
||||
* mcp_servers: List[str] - List of allowed MCP server IDs
|
||||
* mcp_access_groups: List[str] - List of MCP access group names
|
||||
* mcp_tool_permissions: Dict[str, List[str]] - Map of server ID to allowed tool names (e.g., {"server_1": ["tool_a", "tool_b"]})
|
||||
* vector_stores: List[str] - List of allowed vector store IDs
|
||||
* agents: List[str] - List of allowed agent IDs
|
||||
* agent_access_groups: List[str] - List of agent access group names
|
||||
Example: {"mcp_servers": ["server_1", "server_2"], "vector_stores": ["vector_store_1"], "agents": ["agent_1"]}
|
||||
IF null or {} then no object-level restrictions apply.
|
||||
|
||||
|
||||
- Allow specifying allowed regions
|
||||
@@ -214,9 +257,22 @@ async def new_end_user(
|
||||
"user_id" : "ishaan-jaff-3",
|
||||
"allowed_region": "eu",
|
||||
"budget_id": "free_tier",
|
||||
"default_model": "azure/gpt-3.5-turbo-eu" <- all calls from this user, use this model?
|
||||
"default_model": "azure/gpt-3.5-turbo-eu"
|
||||
}'
|
||||
|
||||
# With object permissions
|
||||
curl -L -X POST 'http://localhost:4000/customer/new' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"user_id": "user_1",
|
||||
"object_permission": {
|
||||
"mcp_servers": ["server_1"],
|
||||
"mcp_access_groups": ["public_group"],
|
||||
"vector_stores": ["vector_store_1"]
|
||||
}
|
||||
}'
|
||||
|
||||
# return end-user object
|
||||
```
|
||||
|
||||
@@ -233,11 +289,8 @@ async def new_end_user(
|
||||
- end-user object
|
||||
- currently allowed models
|
||||
"""
|
||||
from litellm.proxy.proxy_server import (
|
||||
litellm_proxy_admin_name,
|
||||
llm_router,
|
||||
prisma_client,
|
||||
)
|
||||
from litellm.proxy.proxy_server import (litellm_proxy_admin_name,
|
||||
llm_router, prisma_client)
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(
|
||||
@@ -289,13 +342,34 @@ async def new_end_user(
|
||||
if k not in BudgetNewRequest.model_fields.keys():
|
||||
new_end_user_obj[k] = v
|
||||
|
||||
## Handle Object Permission - MCP Servers, Vector Stores etc.
|
||||
new_end_user_obj = await _set_object_permission(
|
||||
data_json=new_end_user_obj,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
# Ensure object_permission is not in the data being sent to create
|
||||
# It should have been converted to object_permission_id by _set_object_permission
|
||||
if "object_permission" in new_end_user_obj:
|
||||
verbose_proxy_logger.warning(
|
||||
f"object_permission still in new_end_user_obj after _set_object_permission: {new_end_user_obj.get('object_permission')}"
|
||||
)
|
||||
new_end_user_obj.pop("object_permission", None)
|
||||
|
||||
## WRITE TO DB ##
|
||||
end_user_record = await prisma_client.db.litellm_endusertable.create(
|
||||
data=new_end_user_obj, # type: ignore
|
||||
include={"litellm_budget_table": True},
|
||||
include={"litellm_budget_table": True, "object_permission": True},
|
||||
)
|
||||
|
||||
return end_user_record
|
||||
# Convert to dict and clean up recursive fields
|
||||
response_dict = end_user_record.model_dump()
|
||||
if response_dict.get("object_permission"):
|
||||
# Remove reverse relations from object_permission
|
||||
for field in ["teams", "verification_tokens", "organizations", "users", "end_users"]:
|
||||
response_dict["object_permission"].pop(field, None)
|
||||
|
||||
return response_dict
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.management_endpoints.customer_endpoints.new_end_user(): Exception occured - {}".format(
|
||||
@@ -351,7 +425,7 @@ async def end_user_info(
|
||||
)
|
||||
|
||||
user_info = await prisma_client.db.litellm_endusertable.find_first(
|
||||
where={"user_id": end_user_id}, include={"litellm_budget_table": True}
|
||||
where={"user_id": end_user_id}, include={"litellm_budget_table": True, "object_permission": True}
|
||||
)
|
||||
|
||||
if user_info is None:
|
||||
@@ -361,7 +435,15 @@ async def end_user_info(
|
||||
code=404,
|
||||
param="end_user_id",
|
||||
)
|
||||
return user_info.model_dump(exclude_none=True)
|
||||
|
||||
# Convert to dict and clean up recursive fields
|
||||
response_dict = user_info.model_dump(exclude_none=True)
|
||||
if response_dict.get("object_permission"):
|
||||
# Remove reverse relations from object_permission
|
||||
for field in ["teams", "verification_tokens", "organizations", "users", "end_users"]:
|
||||
response_dict["object_permission"].pop(field, None)
|
||||
|
||||
return response_dict
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
@@ -401,6 +483,16 @@ async def update_end_user(
|
||||
- default_model: Optional[str] = (
|
||||
None # if no equivalent model in allowed region - default all requests to this model
|
||||
)
|
||||
- object_permission: Optional[LiteLLM_ObjectPermissionBase] - Customer-specific object permissions to control access to resources.
|
||||
Supported fields:
|
||||
* mcp_servers: List[str] - List of allowed MCP server IDs
|
||||
* mcp_access_groups: List[str] - List of MCP access group names
|
||||
* mcp_tool_permissions: Dict[str, List[str]] - Map of server ID to allowed tool names
|
||||
* vector_stores: List[str] - List of allowed vector store IDs
|
||||
* agents: List[str] - List of allowed agent IDs
|
||||
* agent_access_groups: List[str] - List of agent access group names
|
||||
Example: {"mcp_servers": ["server_1"], "vector_stores": ["vector_store_1"]}
|
||||
IF null or {} then no object-level restrictions apply.
|
||||
|
||||
Example curl:
|
||||
```
|
||||
@@ -412,11 +504,24 @@ async def update_end_user(
|
||||
"budget_id": "paid_tier"
|
||||
}'
|
||||
|
||||
See below for all params
|
||||
# Updating object permissions
|
||||
curl -L -X POST 'http://localhost:4000/customer/update' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"user_id": "user_1",
|
||||
"object_permission": {
|
||||
"mcp_servers": ["server_3"],
|
||||
"vector_stores": ["vector_store_2", "vector_store_3"]
|
||||
}
|
||||
}'
|
||||
|
||||
See below for all params
|
||||
```
|
||||
"""
|
||||
|
||||
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
|
||||
from litellm.proxy.proxy_server import (litellm_proxy_admin_name,
|
||||
prisma_client)
|
||||
|
||||
try:
|
||||
data_json: dict = data.json()
|
||||
@@ -467,6 +572,14 @@ async def update_end_user(
|
||||
elif k in LiteLLM_EndUserTable.model_fields.keys():
|
||||
update_end_user_table_data[k] = v
|
||||
|
||||
## Handle object permission updates (MCP servers, vector stores, etc.)
|
||||
await _handle_customer_object_permission_update(
|
||||
non_default_values=non_default_values,
|
||||
end_user_table_data_typed=end_user_table_data_typed,
|
||||
update_end_user_table_data=update_end_user_table_data,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
## Check if we need to create a new budget (only if budget fields are provided, not just budget_id) ##
|
||||
if budget_table_data:
|
||||
if end_user_budget_table is None:
|
||||
@@ -498,11 +611,20 @@ async def update_end_user(
|
||||
|
||||
## Update user table, with update params + new budget id (if set) ##
|
||||
verbose_proxy_logger.debug("/customer/update: Received data = %s", data)
|
||||
|
||||
# Ensure object_permission is not in the update data
|
||||
# It should have been converted to object_permission_id by handle_update_object_permission_common
|
||||
if "object_permission" in update_end_user_table_data:
|
||||
verbose_proxy_logger.warning(
|
||||
f"object_permission still in update_end_user_table_data: {update_end_user_table_data.get('object_permission')}"
|
||||
)
|
||||
update_end_user_table_data.pop("object_permission", None)
|
||||
|
||||
if data.user_id is not None and len(data.user_id) > 0:
|
||||
update_end_user_table_data["user_id"] = data.user_id # type: ignore
|
||||
verbose_proxy_logger.debug("In update customer, user_id condition block.")
|
||||
response = await prisma_client.db.litellm_endusertable.update(
|
||||
where={"user_id": data.user_id}, data=update_end_user_table_data, include={"litellm_budget_table": True} # type: ignore
|
||||
where={"user_id": data.user_id}, data=update_end_user_table_data, include={"litellm_budget_table": True, "object_permission": True} # type: ignore
|
||||
)
|
||||
if response is None:
|
||||
raise ValueError(
|
||||
@@ -511,7 +633,15 @@ async def update_end_user(
|
||||
verbose_proxy_logger.debug(
|
||||
f"received response from updating prisma client. response={response}"
|
||||
)
|
||||
return response
|
||||
|
||||
# Convert to dict and clean up recursive fields
|
||||
response_dict = response.model_dump()
|
||||
if response_dict.get("object_permission"):
|
||||
# Remove reverse relations from object_permission
|
||||
for field in ["teams", "verification_tokens", "organizations", "users", "end_users"]:
|
||||
response_dict["object_permission"].pop(field, None)
|
||||
|
||||
return response_dict
|
||||
else:
|
||||
raise ValueError(f"user_id is required, passed user_id = {data.user_id}")
|
||||
|
||||
@@ -663,12 +793,17 @@ async def list_end_user(
|
||||
)
|
||||
|
||||
response = await prisma_client.db.litellm_endusertable.find_many(
|
||||
include={"litellm_budget_table": True}
|
||||
include={"litellm_budget_table": True, "object_permission": True}
|
||||
)
|
||||
|
||||
returned_response: List[LiteLLM_EndUserTable] = []
|
||||
for item in response:
|
||||
returned_response.append(LiteLLM_EndUserTable(**item.model_dump()))
|
||||
item_dict = item.model_dump()
|
||||
# Remove reverse relations from object_permission
|
||||
if item_dict.get("object_permission"):
|
||||
for field in ["teams", "verification_tokens", "organizations", "users", "end_users"]:
|
||||
item_dict["object_permission"].pop(field, None)
|
||||
returned_response.append(LiteLLM_EndUserTable(**item_dict))
|
||||
return returned_response
|
||||
|
||||
except Exception as e:
|
||||
@@ -706,9 +841,7 @@ async def get_customer_daily_activity(
|
||||
"""
|
||||
Get daily activity for specific organizations or all accessible organizations.
|
||||
"""
|
||||
from litellm.proxy.proxy_server import (
|
||||
prisma_client,
|
||||
)
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -233,6 +233,7 @@ model LiteLLM_ObjectPermissionTable {
|
||||
verification_tokens LiteLLM_VerificationToken[]
|
||||
organizations LiteLLM_OrganizationTable[]
|
||||
users LiteLLM_UserTable[]
|
||||
end_users LiteLLM_EndUserTable[]
|
||||
}
|
||||
|
||||
// Holds the MCP server configuration
|
||||
@@ -403,7 +404,9 @@ model LiteLLM_EndUserTable {
|
||||
allowed_model_region String? // require all user requests to use models in this specific region
|
||||
default_model String? // use along with 'allowed_model_region'. if no available model in region, default to this model.
|
||||
budget_id String?
|
||||
object_permission_id String?
|
||||
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
|
||||
object_permission LiteLLM_ObjectPermissionTable? @relation(fields: [object_permission_id], references: [object_permission_id])
|
||||
blocked Boolean @default(false)
|
||||
}
|
||||
|
||||
|
||||
@@ -1321,6 +1321,7 @@ class ProxyLogging:
|
||||
metadata = data.get("metadata", data.get("litellm_metadata", {})) or {}
|
||||
pipeline_managed: set = metadata.get("_pipeline_managed_guardrails", set())
|
||||
|
||||
|
||||
for callback in litellm.callbacks:
|
||||
start_time = time.time()
|
||||
_callback = None
|
||||
|
||||
@@ -6,6 +6,7 @@ from collections.abc import Sequence
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, cast
|
||||
|
||||
from openai.types.responses import ResponseFunctionToolCall
|
||||
from openai.types.responses.response_create_params import ResponseInputParam
|
||||
from openai.types.responses.tool_param import FunctionToolParam
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
@@ -32,7 +33,6 @@ from litellm.types.llms.openai import (
|
||||
OpenAIWebSearchUserLocation,
|
||||
OutputTokensDetails,
|
||||
ResponseAPIUsage,
|
||||
ResponseInputParam,
|
||||
ResponsesAPIOptionalRequestParams,
|
||||
ResponsesAPIResponse,
|
||||
ResponsesAPIStatus,
|
||||
@@ -738,9 +738,25 @@ class LiteLLMCompletionResponsesConfig:
|
||||
|
||||
@staticmethod
|
||||
def _ensure_tool_results_have_corresponding_tool_calls(
|
||||
messages: List[Union[AllMessageValues, GenericChatCompletionMessage, ChatCompletionResponseMessage]],
|
||||
messages: Sequence[
|
||||
Union[
|
||||
AllMessageValues,
|
||||
GenericChatCompletionMessage,
|
||||
ChatCompletionResponseMessage,
|
||||
ChatCompletionMessageToolCall,
|
||||
Message,
|
||||
]
|
||||
],
|
||||
tools: Optional[List[Any]] = None,
|
||||
) -> List[Union[AllMessageValues, GenericChatCompletionMessage, ChatCompletionResponseMessage]]:
|
||||
) -> List[
|
||||
Union[
|
||||
AllMessageValues,
|
||||
GenericChatCompletionMessage,
|
||||
ChatCompletionResponseMessage,
|
||||
ChatCompletionMessageToolCall,
|
||||
Message,
|
||||
]
|
||||
]:
|
||||
"""
|
||||
Ensure that tool_result messages have corresponding tool_calls in the previous assistant message.
|
||||
|
||||
@@ -755,11 +771,19 @@ class LiteLLMCompletionResponsesConfig:
|
||||
List of messages with tool_calls added to assistant messages when needed
|
||||
"""
|
||||
if not messages:
|
||||
return messages
|
||||
|
||||
# Create a deep copy to avoid modifying the original
|
||||
return list(messages)
|
||||
|
||||
# Create a deep copy to avoid modifying the original (use list() so we can mutate and return List)
|
||||
import copy
|
||||
fixed_messages = copy.deepcopy(messages)
|
||||
fixed_messages: List[
|
||||
Union[
|
||||
AllMessageValues,
|
||||
GenericChatCompletionMessage,
|
||||
ChatCompletionResponseMessage,
|
||||
ChatCompletionMessageToolCall,
|
||||
Message,
|
||||
]
|
||||
] = list(copy.deepcopy(messages))
|
||||
messages_to_remove = []
|
||||
|
||||
# Count non-tool messages to avoid removing all messages
|
||||
@@ -1306,6 +1330,50 @@ class LiteLLMCompletionResponsesConfig:
|
||||
chat_completion_tools.append(cast(Union[ChatCompletionToolParam, OpenAIMcpServerTool], tool))
|
||||
return chat_completion_tools, web_search_options
|
||||
|
||||
@staticmethod
|
||||
def transform_chat_completion_tool_params_to_responses_api_tools(
|
||||
chat_completion_tools: Optional[
|
||||
List[Union[ChatCompletionToolParam, OpenAIMcpServerTool]]
|
||||
],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Transform Chat Completion tool params (e.g. from guardrail output) back to
|
||||
Responses API request tool format. Inverse of
|
||||
transform_responses_api_tools_to_chat_completion_tools for the tools list.
|
||||
"""
|
||||
if chat_completion_tools is None or not chat_completion_tools:
|
||||
return []
|
||||
result: List[Dict[str, Any]] = []
|
||||
for tool in chat_completion_tools:
|
||||
if not isinstance(tool, dict):
|
||||
result.append(tool) # type: ignore
|
||||
continue
|
||||
if tool.get("type") == "function":
|
||||
fn = tool.get("function") or {}
|
||||
parameters = dict(fn.get("parameters", {}) or {})
|
||||
if not parameters or "type" not in parameters:
|
||||
parameters["type"] = "object"
|
||||
responses_tool: Dict[str, Any] = {
|
||||
"type": "function",
|
||||
"name": fn.get("name") or "",
|
||||
"description": fn.get("description") or "",
|
||||
"parameters": parameters,
|
||||
"strict": fn.get("strict", False) or False,
|
||||
}
|
||||
if tool.get("cache_control") is not None:
|
||||
responses_tool["cache_control"] = tool.get("cache_control")
|
||||
if tool.get("defer_loading") is not None:
|
||||
responses_tool["defer_loading"] = tool.get("defer_loading")
|
||||
if tool.get("allowed_callers") is not None:
|
||||
responses_tool["allowed_callers"] = tool.get("allowed_callers")
|
||||
if tool.get("input_examples") is not None:
|
||||
responses_tool["input_examples"] = tool.get("input_examples")
|
||||
result.append(responses_tool)
|
||||
else:
|
||||
# mcp or other: pass through unchanged
|
||||
result.append(dict(tool))
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def transform_chat_completion_tools_to_responses_tools(
|
||||
chat_completion_response: ModelResponse,
|
||||
|
||||
@@ -70,6 +70,7 @@ class SupportedGuardrailIntegrations(Enum):
|
||||
GENERIC_GUARDRAIL_API = "generic_guardrail_api"
|
||||
QUALIFIRE = "qualifire"
|
||||
CUSTOM_CODE = "custom_code"
|
||||
MCP_END_USER_PERMISSION = "mcp_end_user_permission"
|
||||
|
||||
|
||||
class Role(Enum):
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
from .base import GuardrailConfigModel
|
||||
|
||||
|
||||
class MCPEndUserPermissionGuardrailConfigModel(GuardrailConfigModel):
|
||||
"""
|
||||
No provider-specific params required — permissions come from the end user
|
||||
object already stored in the database.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def ui_friendly_name() -> str:
|
||||
return "MCP End User Permission"
|
||||
@@ -233,6 +233,7 @@ model LiteLLM_ObjectPermissionTable {
|
||||
verification_tokens LiteLLM_VerificationToken[]
|
||||
organizations LiteLLM_OrganizationTable[]
|
||||
users LiteLLM_UserTable[]
|
||||
end_users LiteLLM_EndUserTable[]
|
||||
}
|
||||
|
||||
// Holds the MCP server configuration
|
||||
@@ -403,7 +404,9 @@ model LiteLLM_EndUserTable {
|
||||
allowed_model_region String? // require all user requests to use models in this specific region
|
||||
default_model String? // use along with 'allowed_model_region'. if no available model in region, default to this model.
|
||||
budget_id String?
|
||||
object_permission_id String?
|
||||
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
|
||||
object_permission LiteLLM_ObjectPermissionTable? @relation(fields: [object_permission_id], references: [object_permission_id])
|
||||
blocked Boolean @default(false)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,414 @@
|
||||
"""
|
||||
Tests for MCP End User Permission Guardrail Hook
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.exceptions import GuardrailRaisedException
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.guardrails.guardrail_hooks.mcp_end_user_permission import (
|
||||
MCPEndUserPermissionGuardrail,
|
||||
)
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Choices,
|
||||
Function,
|
||||
Message,
|
||||
ModelResponse,
|
||||
)
|
||||
|
||||
|
||||
class TestMCPEndUserPermissionGuardrail:
|
||||
"""Test the MCP End User Permission Guardrail"""
|
||||
|
||||
def test_extract_mcp_server_name(self):
|
||||
"""Test extracting MCP server name from tool name"""
|
||||
guardrail = MCPEndUserPermissionGuardrail()
|
||||
|
||||
# Test valid MCP tool names
|
||||
assert guardrail._extract_mcp_server_name("github-create_issue") == "github"
|
||||
assert guardrail._extract_mcp_server_name("slack-send_message") == "slack"
|
||||
assert guardrail._extract_mcp_server_name("jira-create-ticket") == "jira"
|
||||
|
||||
# Test invalid/non-MCP tool names
|
||||
assert guardrail._extract_mcp_server_name("search") is None
|
||||
assert guardrail._extract_mcp_server_name("") is None
|
||||
assert guardrail._extract_mcp_server_name(None) is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_guardrail_no_end_user(self):
|
||||
"""Test guardrail when no end_user_id is present"""
|
||||
guardrail = MCPEndUserPermissionGuardrail()
|
||||
|
||||
# Create inputs with MCP tools
|
||||
inputs = {
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "github-create_issue",
|
||||
"description": "Create an issue",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
request_data = {}
|
||||
|
||||
# Should pass through all tools when no end_user_id
|
||||
result = await guardrail.apply_guardrail(
|
||||
inputs=inputs,
|
||||
request_data=request_data,
|
||||
input_type="request",
|
||||
)
|
||||
|
||||
assert len(result.get("tools", [])) == 1
|
||||
assert result["tools"][0]["function"]["name"] == "github-create_issue"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_guardrail_with_authorized_tools(self):
|
||||
"""Test guardrail when end user has access to MCP servers"""
|
||||
from litellm.proxy._types import LiteLLM_ObjectPermissionTable
|
||||
|
||||
guardrail = MCPEndUserPermissionGuardrail()
|
||||
|
||||
# Create inputs with multiple tools
|
||||
inputs = {
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "github-create_issue",
|
||||
"description": "Create an issue",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "slack-send_message",
|
||||
"description": "Send a message",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search",
|
||||
"description": "Regular search tool",
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
request_data = {"user_api_key_end_user_id": "end-user-123"}
|
||||
|
||||
# Mock fetching end user object with permissions
|
||||
with patch.object(
|
||||
MCPEndUserPermissionGuardrail,
|
||||
"_fetch_end_user_object",
|
||||
return_value=MagicMock(
|
||||
object_permission=LiteLLM_ObjectPermissionTable(
|
||||
object_permission_id="perm-1",
|
||||
mcp_servers=["github", "slack"],
|
||||
)
|
||||
),
|
||||
):
|
||||
result = await guardrail.apply_guardrail(
|
||||
inputs=inputs,
|
||||
request_data=request_data,
|
||||
input_type="request",
|
||||
)
|
||||
|
||||
# Should keep all authorized MCP tools + non-MCP tools
|
||||
assert len(result.get("tools", [])) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_guardrail_with_unauthorized_tools(self):
|
||||
"""Test guardrail filters out unauthorized MCP tools"""
|
||||
from litellm.proxy._types import LiteLLM_ObjectPermissionTable
|
||||
|
||||
guardrail = MCPEndUserPermissionGuardrail()
|
||||
|
||||
# Create inputs with tools where end user only has access to some
|
||||
inputs = {
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "github-create_issue",
|
||||
"description": "Create an issue",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "slack-send_message",
|
||||
"description": "Send a message",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "jira-create_ticket",
|
||||
"description": "Create a ticket",
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
request_data = {"user_api_key_end_user_id": "end-user-123"}
|
||||
|
||||
# Mock fetching end user object with limited permissions (only slack and jira)
|
||||
with patch.object(
|
||||
MCPEndUserPermissionGuardrail,
|
||||
"_fetch_end_user_object",
|
||||
return_value=MagicMock(
|
||||
object_permission=LiteLLM_ObjectPermissionTable(
|
||||
object_permission_id="perm-1",
|
||||
mcp_servers=["slack", "jira"],
|
||||
)
|
||||
),
|
||||
):
|
||||
result = await guardrail.apply_guardrail(
|
||||
inputs=inputs,
|
||||
request_data=request_data,
|
||||
input_type="request",
|
||||
)
|
||||
|
||||
# Should filter out github tool
|
||||
assert len(result.get("tools", [])) == 2
|
||||
tool_names = [t["function"]["name"] for t in result["tools"]]
|
||||
assert "slack-send_message" in tool_names
|
||||
assert "jira-create_ticket" in tool_names
|
||||
assert "github-create_issue" not in tool_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_guardrail_no_permissions_configured(self):
|
||||
"""Test guardrail when end user has no MCP permissions configured"""
|
||||
guardrail = MCPEndUserPermissionGuardrail()
|
||||
|
||||
# Create inputs with MCP tools
|
||||
inputs = {
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "github-create_issue",
|
||||
"description": "Create an issue",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
request_data = {"user_api_key_end_user_id": "end-user-123"}
|
||||
|
||||
# Mock fetching end user object with no object_permission
|
||||
with patch.object(
|
||||
MCPEndUserPermissionGuardrail,
|
||||
"_fetch_end_user_object",
|
||||
return_value=MagicMock(object_permission=None),
|
||||
):
|
||||
result = await guardrail.apply_guardrail(
|
||||
inputs=inputs,
|
||||
request_data=request_data,
|
||||
input_type="request",
|
||||
)
|
||||
|
||||
# Should pass through all tools when no permissions configured
|
||||
assert len(result.get("tools", [])) == 1
|
||||
assert result["tools"][0]["function"]["name"] == "github-create_issue"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_guardrail_with_non_mcp_tools(self):
|
||||
"""Test guardrail passes through non-MCP tools"""
|
||||
from litellm.proxy._types import LiteLLM_ObjectPermissionTable
|
||||
|
||||
guardrail = MCPEndUserPermissionGuardrail()
|
||||
|
||||
# Create inputs with non-MCP tools
|
||||
inputs = {
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search",
|
||||
"description": "Search tool",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calculate",
|
||||
"description": "Calculate something",
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
request_data = {"user_api_key_end_user_id": "end-user-123"}
|
||||
|
||||
# Mock fetching end user object with MCP restrictions
|
||||
with patch.object(
|
||||
MCPEndUserPermissionGuardrail,
|
||||
"_fetch_end_user_object",
|
||||
return_value=MagicMock(
|
||||
object_permission=LiteLLM_ObjectPermissionTable(
|
||||
object_permission_id="perm-1",
|
||||
mcp_servers=["github"],
|
||||
)
|
||||
),
|
||||
):
|
||||
result = await guardrail.apply_guardrail(
|
||||
inputs=inputs,
|
||||
request_data=request_data,
|
||||
input_type="request",
|
||||
)
|
||||
|
||||
# Should keep all non-MCP tools even with MCP restrictions
|
||||
assert len(result.get("tools", [])) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_guardrail_filters_unauthorized_mcp_tools(self):
|
||||
"""Test guardrail filters out unauthorized MCP tools"""
|
||||
from litellm.proxy._types import LiteLLM_ObjectPermissionTable
|
||||
|
||||
guardrail = MCPEndUserPermissionGuardrail()
|
||||
|
||||
# Create inputs with MCP tools where user only has access to some
|
||||
inputs = {
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "github-create_issue",
|
||||
"description": "Create an issue",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "slack-send_message",
|
||||
"description": "Send a message",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "jira-create_ticket",
|
||||
"description": "Create a ticket",
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
request_data = {"user_api_key_end_user_id": "end-user-123"}
|
||||
|
||||
# Mock fetching end user object - only has access to slack and jira, not github
|
||||
with patch.object(
|
||||
MCPEndUserPermissionGuardrail,
|
||||
"_fetch_end_user_object",
|
||||
return_value=MagicMock(
|
||||
object_permission=LiteLLM_ObjectPermissionTable(
|
||||
object_permission_id="perm-1",
|
||||
mcp_servers=["slack", "jira"],
|
||||
)
|
||||
),
|
||||
):
|
||||
result = await guardrail.apply_guardrail(
|
||||
inputs=inputs,
|
||||
request_data=request_data,
|
||||
input_type="request",
|
||||
)
|
||||
|
||||
# Should filter out github tool
|
||||
assert len(result.get("tools", [])) == 2
|
||||
tool_names = [t["function"]["name"] for t in result["tools"]]
|
||||
assert "slack-send_message" in tool_names
|
||||
assert "jira-create_ticket" in tool_names
|
||||
assert "github-create_issue" not in tool_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_guardrail_with_mixed_tools(self):
|
||||
"""Test guardrail with both MCP and non-MCP tools"""
|
||||
from litellm.proxy._types import LiteLLM_ObjectPermissionTable
|
||||
|
||||
guardrail = MCPEndUserPermissionGuardrail()
|
||||
|
||||
# Create inputs with both MCP and non-MCP tools
|
||||
inputs = {
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "github-create_issue",
|
||||
"description": "Create an issue",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search",
|
||||
"description": "Search tool",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "slack-send_message",
|
||||
"description": "Send a message",
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
request_data = {"user_api_key_end_user_id": "end-user-123"}
|
||||
|
||||
# Mock fetching end user object - only has access to slack
|
||||
with patch.object(
|
||||
MCPEndUserPermissionGuardrail,
|
||||
"_fetch_end_user_object",
|
||||
return_value=MagicMock(
|
||||
object_permission=LiteLLM_ObjectPermissionTable(
|
||||
object_permission_id="perm-1",
|
||||
mcp_servers=["slack"],
|
||||
)
|
||||
),
|
||||
):
|
||||
result = await guardrail.apply_guardrail(
|
||||
inputs=inputs,
|
||||
request_data=request_data,
|
||||
input_type="request",
|
||||
)
|
||||
|
||||
# Should keep slack MCP tool and non-MCP search tool, filter out github
|
||||
assert len(result.get("tools", [])) == 2
|
||||
tool_names = [t["function"]["name"] for t in result["tools"]]
|
||||
assert "search" in tool_names
|
||||
assert "slack-send_message" in tool_names
|
||||
assert "github-create_issue" not in tool_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_guardrail_no_tools_in_request(self):
|
||||
"""Test guardrail when request has no tools"""
|
||||
guardrail = MCPEndUserPermissionGuardrail()
|
||||
|
||||
# Create inputs without tools
|
||||
inputs = {"model": "gpt-4", "messages": [{"role": "user", "content": "test"}]}
|
||||
|
||||
request_data = {"user_api_key_end_user_id": "end-user-123"}
|
||||
|
||||
# Should pass through unchanged
|
||||
result = await guardrail.apply_guardrail(
|
||||
inputs=inputs,
|
||||
request_data=request_data,
|
||||
input_type="request",
|
||||
)
|
||||
|
||||
assert result == inputs
|
||||
assert "tools" not in result
|
||||
Reference in New Issue
Block a user