diff --git a/.claude/settings.json b/.claude/settings.json new file mode 100644 index 0000000000..8c1d85f96e --- /dev/null +++ b/.claude/settings.json @@ -0,0 +1,36 @@ +{ + "permissions": { + "allow": [ + "Bash(git show:*)", + "Bash(git worktree add:*)", + "Read(//Users/krrishdholakia/Documents/litellm/**)", + "Read(//Users/krrishdholakia/Documents/litellm-claude-code-guardrails/litellm/types/**)", + "Read(//Users/krrishdholakia/Documents/litellm-claude-code-guardrails/**)", + "Read(//Users/krrishdholakia/Documents/litellm-claude-code-guardrails/litellm/**)", + "Bash(python:*)", + "Bash(python -c \"\nimport sys; sys.path.insert\\(0, ''.''\\)\nfrom litellm.proxy.guardrails.guardrail_hooks.claude_code.guardrail import ClaudeCodeGuardrail, HOSTED_TOOL_PREFIXES\nprint\\(''HOSTED_TOOL_PREFIXES:'', HOSTED_TOOL_PREFIXES\\)\nprint\\(''ClaudeCodeGuardrail imported OK''\\)\n\")", + "Read(//Users/krrishdholakia/Documents/litellm-mcp-jwt-groups/litellm/proxy/**)", + "Read(//Users/krrishdholakia/Documents/litellm-mcp-jwt-groups/**)", + "Bash(poetry run pytest:*)", + "Bash(git add:*)", + "Bash(git commit:*)", + "Bash(poetry run python:*)", + "Bash(poetry run pip:*)", + "Bash(git reset:*)", + "Bash(git cherry-pick:*)", + "Bash(git checkout:*)", + "Read(//Users/krrishdholakia/Documents/litellm/litellm/proxy/guardrails/guardrail_hooks/**)", + "Read(//Users/krrishdholakia/Documents/**)", + "Bash(git -C /Users/krrishdholakia/Documents/litellm-mcp-user-permissions worktree list)", + "Bash(ls:*)" + ], + "additionalDirectories": [ + "/Users/krrishdholakia/Documents/litellm-mcp-group-plan/plan", + "/Users/krrishdholakia/Documents/litellm-claude-code-guardrails/litellm/proxy/guardrails/guardrail_hooks/claude_code", + "/Users/krrishdholakia/Documents/litellm-claude-code-guardrails/litellm/types", + "/Users/krrishdholakia/Documents/litellm-claude-code-guardrails", + "/Users/krrishdholakia/Documents/litellm-mcp-jwt-groups/litellm/proxy", + 
"/Users/krrishdholakia/Documents/litellm-mcp-jwt-groups/tests/test_litellm/proxy/auth" + ] + } +} diff --git a/docs/my-website/docs/mcp.md b/docs/my-website/docs/mcp.md index 84d10c2593..50973f220f 100644 --- a/docs/my-website/docs/mcp.md +++ b/docs/my-website/docs/mcp.md @@ -808,6 +808,68 @@ If your stdio MCP server needs per-request credentials, you can map HTTP headers In this example, when a client makes a request with the `X-GITHUB_PERSONAL_ACCESS_TOKEN` header, the proxy forwards that value into the stdio process as the `GITHUB_PERSONAL_ACCESS_TOKEN` environment variable. +## Control MCP Access for End Users + +Control which MCP servers end users of your AI application can access (e.g. users of an internal chat UI). Pass the customer ID in the `x-litellm-end-user-id` header to: +- Enforce object permissions (limit which MCP servers they can access) +- Apply customer-specific budgets +- Track spend per customer + +**FastMCP Client Example:** + +```python title="Track customer spend with x-litellm-end-user-id" showLineNumbers +from fastmcp import Client +import asyncio + +# MCP client configuration with customer tracking +config = { + "mcpServers": { + "github": { + "url": "http://localhost:4000/github_mcp/mcp", + "headers": { + "x-litellm-api-key": "Bearer sk-1234", + "x-litellm-end-user-id": "customer_123", # 👈 CUSTOMER ID + "Authorization": "Bearer gho_token" + } + } + } +} + +client = Client(config) + +async def main(): + async with client: + # All MCP calls will be tracked under customer_123 + tools = await client.list_tools() + result = await client.call_tool(tools[0].name, {}) + print(f"Tool result: {result}") + +asyncio.run(main()) +``` + +**Cursor IDE Example:** + +```json title="Cursor config with customer tracking" showLineNumbers +{ + "mcpServers": { + "GitHub": { + "url": "http://localhost:4000/github_mcp/mcp", + "headers": { + "x-litellm-api-key": "Bearer $LITELLM_API_KEY", + "x-litellm-end-user-id": "customer_123" + } + } + } +} +``` + +**What 
happens:** +- Customer-specific object permissions are enforced (only allowed MCP servers are accessible) +- Customer budgets are applied +- All tool calls are tracked under `customer_123` + +[Learn more about customer management →](./proxy/customers) + ## Using your MCP with client side credentials Use this if you want to pass a client side authentication token to LiteLLM to then pass to your MCP to auth to your MCP. diff --git a/docs/my-website/docs/proxy/customers.md b/docs/my-website/docs/proxy/customers.md index 1101884c36..50a5f994fa 100644 --- a/docs/my-website/docs/proxy/customers.md +++ b/docs/my-website/docs/proxy/customers.md @@ -2,29 +2,98 @@ import Image from '@theme/IdealImage'; import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; -# Customers / End-User Budgets +# Customers / End-Users -Track spend, set budgets for your customers. +Track spend, set budgets and permissions for your customers. -## Tracking Customer Spend +## Tracking Customer Spend + Permissions ### 1. 
Make LLM API call w/ Customer ID -Make a /chat/completions call, pass 'user' - First call Works +LiteLLM checks for a customer/end-user ID in the following order (first match wins): -```bash showLineNumbers title="Make request with customer ID" +| Priority | Method | Where | Notes | +|----------|--------|-------|-------| +| 1 | `x-litellm-customer-id` header | Request headers | Standard header, always checked | +| 2 | `x-litellm-end-user-id` header | Request headers | Standard header, always checked | +| 3 | Custom header via `user_header_mappings` | Request headers | Configured in `general_settings` | +| 4 | Custom header via `user_header_name` | Request headers | Deprecated — use `user_header_mappings` | +| 5 | `user` field | Request body | Standard OpenAI field | +| 6 | `litellm_metadata.user` field | Request body | Anthropic-style metadata | +| 7 | `metadata.user_id` field | Request body | Generic metadata pattern | +| 8 | `safety_identifier` field | Request body | Responses API | + +**Option 1: Standard headers** (recommended — no request body modification needed) + +```bash showLineNumbers title="Make request with customer ID in header" curl -X POST 'http://0.0.0.0:4000/chat/completions' \ --header 'Content-Type: application/json' \ - --header 'Authorization: Bearer sk-1234' \ # 👈 YOUR PROXY KEY - --data ' { + --header 'Authorization: Bearer sk-1234' \ + --header 'x-litellm-end-user-id: ishaan3' \ + --data '{ "model": "azure-gpt-3.5", - "user": "ishaan3", # 👈 CUSTOMER ID - "messages": [ - { - "role": "user", - "content": "what time is it" - } - ] + "messages": [{"role": "user", "content": "what time is it"}] + }' +``` + +Both `x-litellm-customer-id` and `x-litellm-end-user-id` are supported and always checked without any configuration. 
+ +**Option 2: `user` field in request body** (OpenAI-compatible) + +```bash showLineNumbers title="Make request with customer ID in body" +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ + --header 'Content-Type: application/json' \ + --header 'Authorization: Bearer sk-1234' \ + --data '{ + "model": "azure-gpt-3.5", + "user": "ishaan3", + "messages": [{"role": "user", "content": "what time is it"}] + }' +``` + +**Option 3: Custom header via `user_header_mappings`** (configurable) + +```yaml showLineNumbers title="config.yaml" +general_settings: + user_header_mappings: + - header_name: "x-my-app-user-id" + litellm_user_role: "customer" +``` + +```bash showLineNumbers title="Make request with custom header" +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ + --header 'Content-Type: application/json' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'x-my-app-user-id: ishaan3' \ + --data '{ + "model": "azure-gpt-3.5", + "messages": [{"role": "user", "content": "what time is it"}] + }' +``` + +**Option 4: `litellm_metadata.user`** (Anthropic-style) + +```bash showLineNumbers title="Make request with litellm_metadata.user" +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ + --header 'Content-Type: application/json' \ + --header 'Authorization: Bearer sk-1234' \ + --data '{ + "model": "claude-3-5-sonnet", + "messages": [{"role": "user", "content": "what time is it"}], + "litellm_metadata": {"user": "ishaan3"} + }' +``` + +**Option 5: `metadata.user_id`** + +```bash showLineNumbers title="Make request with metadata.user_id" +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ + --header 'Content-Type: application/json' \ + --header 'Authorization: Bearer sk-1234' \ + --data '{ + "model": "azure-gpt-3.5", + "messages": [{"role": "user", "content": "what time is it"}], + "metadata": {"user_id": "ishaan3"} }' ``` @@ -123,7 +192,171 @@ Expected Response -## Setting Customer Budgets +## Setting Customer Object Permissions + +Control which resources 
(MCP servers, vector stores, agents) a customer can access. + +### What are Object Permissions? + +Object permissions allow you to restrict customer access to specific: +- **MCP Servers**: Limit which MCP servers the customer can call +- **MCP Access Groups**: Assign customers to predefined groups of MCP servers +- **MCP Tool Permissions**: Granular control over which tools within an MCP server the customer can use +- **Vector Stores**: Control which vector stores the customer can query +- **Agents**: Restrict which agents the customer can interact with +- **Agent Access Groups**: Assign customers to predefined groups of agents + +### Creating a Customer with Object Permissions + +```bash showLineNumbers title="Create customer with object permissions" +curl -L -X POST 'http://localhost:4000/customer/new' \ +-H 'Authorization: Bearer sk-1234' \ +-H 'Content-Type: application/json' \ +-d '{ + "user_id": "user_1", + "object_permission": { + "mcp_servers": ["server_1", "server_2"], + "mcp_access_groups": ["public_group"], + "mcp_tool_permissions": { + "server_1": ["tool_a", "tool_b"] + }, + "vector_stores": ["vector_store_1"], + "agents": ["agent_1"], + "agent_access_groups": ["basic_agents"] + } + }' +``` + +**Parameters:** +- `mcp_servers` (Optional[List[str]]): List of allowed MCP server IDs +- `mcp_access_groups` (Optional[List[str]]): List of MCP access group names +- `mcp_tool_permissions` (Optional[Dict[str, List[str]]]): Map of server ID to allowed tool names +- `vector_stores` (Optional[List[str]]): List of allowed vector store IDs +- `agents` (Optional[List[str]]): List of allowed agent IDs +- `agent_access_groups` (Optional[List[str]]): List of agent access group names + +**Note:** If `object_permission` is `null` or `{}`, the customer has no object-level restrictions. 
+ +### Updating Customer Object Permissions + +You can update object permissions for existing customers: + +```bash showLineNumbers title="Update customer object permissions" +curl -L -X POST 'http://localhost:4000/customer/update' \ +-H 'Authorization: Bearer sk-1234' \ +-H 'Content-Type: application/json' \ +-d '{ + "user_id": "user_1", + "object_permission": { + "mcp_servers": ["server_3"], + "vector_stores": ["vector_store_2", "vector_store_3"] + } + }' +``` + +### Viewing Customer Object Permissions + +When you query customer info, object permissions are included in the response: + +```bash showLineNumbers title="Get customer info with object permissions" +curl -X GET 'http://0.0.0.0:4000/customer/info?end_user_id=user_1' \ + -H 'Authorization: Bearer sk-1234' +``` + +**Response:** +```json showLineNumbers title="Response with object permissions" +{ + "user_id": "user_1", + "blocked": false, + "alias": "John Doe", + "spend": 0.0, + "object_permission": { + "object_permission_id": "perm_abc123", + "mcp_servers": ["server_1", "server_2"], + "mcp_access_groups": ["public_group"], + "mcp_tool_permissions": { + "server_1": ["tool_a", "tool_b"] + }, + "vector_stores": ["vector_store_1"], + "agents": ["agent_1"], + "agent_access_groups": ["basic_agents"] + }, + "litellm_budget_table": null +} +``` + +### Use Cases + +**1. 
Tiered Access Control** +Create different permission tiers for your customers: + +```bash showLineNumbers title="Free tier customer" +# Free tier - limited access +curl -L -X POST 'http://localhost:4000/customer/new' \ +-H 'Authorization: Bearer sk-1234' \ +-H 'Content-Type: application/json' \ +-d '{ + "user_id": "free_user", + "budget_id": "free_tier", + "object_permission": { + "mcp_access_groups": ["public_group"], + "agent_access_groups": ["basic_agents"] + } + }' +``` + +```bash showLineNumbers title="Premium tier customer" +# Premium tier - full access +curl -L -X POST 'http://localhost:4000/customer/new' \ +-H 'Authorization: Bearer sk-1234' \ +-H 'Content-Type: application/json' \ +-d '{ + "user_id": "premium_user", + "budget_id": "premium_tier", + "object_permission": { + "mcp_servers": ["server_1", "server_2", "server_3"], + "vector_stores": ["vector_store_1", "vector_store_2"], + "agents": ["agent_1", "agent_2", "agent_3"] + } + }' +``` + +**2. Department-Specific Access** +Restrict customers to resources relevant to their department: + +```bash showLineNumbers title="Sales team customer" +curl -L -X POST 'http://localhost:4000/customer/new' \ +-H 'Authorization: Bearer sk-1234' \ +-H 'Content-Type: application/json' \ +-d '{ + "user_id": "sales_user", + "object_permission": { + "mcp_servers": ["crm_server", "email_server"], + "agents": ["sales_assistant"], + "vector_stores": ["sales_knowledge_base"] + } + }' +``` + +**3. Tool-Level Restrictions** +Grant access to specific tools within an MCP server: + +```bash showLineNumbers title="Limited tool access" +curl -L -X POST 'http://localhost:4000/customer/new' \ +-H 'Authorization: Bearer sk-1234' \ +-H 'Content-Type: application/json' \ +-d '{ + "user_id": "restricted_user", + "object_permission": { + "mcp_servers": ["database_server"], + "mcp_tool_permissions": { + "database_server": ["read_only_query", "get_table_schema"] + } + } + }' +``` + +## Setting Customer Budgets Set customer budgets (e.g. 
monthly budgets, tpm/rpm limits) on LiteLLM Proxy diff --git a/docs/my-website/docs/proxy/request_headers.md b/docs/my-website/docs/proxy/request_headers.md index 090c201f88..d76964611a 100644 --- a/docs/my-website/docs/proxy/request_headers.md +++ b/docs/my-website/docs/proxy/request_headers.md @@ -20,6 +20,10 @@ By default, LiteLLM does not forward client headers to LLM provider APIs. Howeve `x-litellm-spend-logs-metadata`: Optional[str]: JSON string containing custom metadata to include in spend logs. Example: `{"user_id": "12345", "project_id": "proj_abc", "request_type": "chat_completion"}`. [Learn More](../proxy/enterprise#tracking-spend-with-custom-metadata) +`x-litellm-customer-id`: Optional[str]: Standard header for passing a customer/end-user ID. Always checked without any configuration. [Learn More](./customers) + +`x-litellm-end-user-id`: Optional[str]: Standard header for passing a customer/end-user ID. Always checked without any configuration. [Learn More](./customers) + ## Anthropic Headers `anthropic-version` Optional[str]: The version of the Anthropic API to use. 
diff --git a/litellm-proxy-extras/litellm_proxy_extras/migrations/20260214185341_object_permissions_for_end_users/migration.sql b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260214185341_object_permissions_for_end_users/migration.sql new file mode 100644 index 0000000000..5c5dc6fd6f --- /dev/null +++ b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260214185341_object_permissions_for_end_users/migration.sql @@ -0,0 +1,6 @@ +-- AlterTable +ALTER TABLE "LiteLLM_EndUserTable" ADD COLUMN "object_permission_id" TEXT; + +-- AddForeignKey +ALTER TABLE "LiteLLM_EndUserTable" ADD CONSTRAINT "LiteLLM_EndUserTable_object_permission_id_fkey" FOREIGN KEY ("object_permission_id") REFERENCES "LiteLLM_ObjectPermissionTable"("object_permission_id") ON DELETE SET NULL ON UPDATE CASCADE; + diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index 441c2cdf70..8bd46672ae 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -233,6 +233,7 @@ model LiteLLM_ObjectPermissionTable { verification_tokens LiteLLM_VerificationToken[] organizations LiteLLM_OrganizationTable[] users LiteLLM_UserTable[] + end_users LiteLLM_EndUserTable[] } // Holds the MCP server configuration @@ -403,7 +404,9 @@ model LiteLLM_EndUserTable { allowed_model_region String? // require all user requests to use models in this specific region default_model String? // use along with 'allowed_model_region'. if no available model in region, default to this model. budget_id String? + object_permission_id String? litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id]) + object_permission LiteLLM_ObjectPermissionTable? 
@relation(fields: [object_permission_id], references: [object_permission_id]) blocked Boolean @default(false) } diff --git a/litellm/llms/anthropic/chat/guardrail_translation/handler.py b/litellm/llms/anthropic/chat/guardrail_translation/handler.py index a14e7d118e..98650a238e 100644 --- a/litellm/llms/anthropic/chat/guardrail_translation/handler.py +++ b/litellm/llms/anthropic/chat/guardrail_translation/handler.py @@ -124,6 +124,9 @@ class AnthropicMessagesHandler(BaseTranslation): ) guardrailed_texts = guardrailed_inputs.get("texts", []) + guardrailed_tools = guardrailed_inputs.get("tools") + if guardrailed_tools is not None: + data["tools"] = guardrailed_tools # Step 3: Map guardrail responses back to original message structure await self._apply_guardrail_responses_to_input( @@ -194,7 +197,7 @@ class AnthropicMessagesHandler(BaseTranslation): openai_tools = self.adapter.translate_anthropic_tools_to_openai( tools=cast(List[AllAnthropicToolsValues], tools) ) - tools_to_check.extend(openai_tools) + tools_to_check.extend(openai_tools) # type: ignore async def _apply_guardrail_responses_to_input( self, diff --git a/litellm/llms/openai/chat/guardrail_translation/handler.py b/litellm/llms/openai/chat/guardrail_translation/handler.py index c406f502b4..683e165c31 100644 --- a/litellm/llms/openai/chat/guardrail_translation/handler.py +++ b/litellm/llms/openai/chat/guardrail_translation/handler.py @@ -107,6 +107,9 @@ class OpenAIChatCompletionsHandler(BaseTranslation): guardrailed_texts = guardrailed_inputs.get("texts", []) guardrailed_tool_calls = guardrailed_inputs.get("tool_calls", []) + guardrailed_tools = guardrailed_inputs.get("tools") + if guardrailed_tools is not None: + data["tools"] = guardrailed_tools # Step 3: Map guardrail responses back to original message structure if guardrailed_texts and texts_to_check: diff --git a/litellm/llms/openai/responses/guardrail_translation/handler.py b/litellm/llms/openai/responses/guardrail_translation/handler.py index 
ad3d4c932d..6b092911d3 100644 --- a/litellm/llms/openai/responses/guardrail_translation/handler.py +++ b/litellm/llms/openai/responses/guardrail_translation/handler.py @@ -96,10 +96,11 @@ class OpenAIResponsesHandler(BaseTranslation): # Handle simple string input if isinstance(input_data, str): inputs = GenericGuardrailAPIInputs(texts=[input_data]) + original_tools: List[Dict[str, Any]] = [] # Extract and transform tools if present - if "tools" in data and data["tools"]: + original_tools = list(data["tools"]) self._extract_and_transform_tools(data["tools"], tools_to_check) if tools_to_check: inputs["tools"] = tools_to_check @@ -118,6 +119,9 @@ class OpenAIResponsesHandler(BaseTranslation): ) guardrailed_texts = guardrailed_inputs.get("texts", []) data["input"] = guardrailed_texts[0] if guardrailed_texts else input_data + self._apply_guardrailed_tools_to_data( + data, original_tools, guardrailed_inputs.get("tools") + ) verbose_proxy_logger.debug("OpenAI Responses API: Processed string input") return data @@ -128,8 +132,7 @@ class OpenAIResponsesHandler(BaseTranslation): texts_to_check: List[str] = [] images_to_check: List[str] = [] task_mappings: List[Tuple[int, Optional[int]]] = [] - # Track (message_index, content_index) for each text - # content_index is None for string content, int for list content + original_tools_list: List[Dict[str, Any]] = list(data.get("tools") or []) # Step 1: Extract all text content, images, and tools for msg_idx, message in enumerate(input_data): @@ -166,6 +169,11 @@ class OpenAIResponsesHandler(BaseTranslation): ) guardrailed_texts = guardrailed_inputs.get("texts", []) + self._apply_guardrailed_tools_to_data( + data, + original_tools_list, + guardrailed_inputs.get("tools"), + ) # Step 3: Map guardrail responses back to original input structure await self._apply_guardrail_responses_to_input( @@ -203,6 +211,53 @@ class OpenAIResponsesHandler(BaseTranslation): cast(List[ChatCompletionToolParam], transformed_tools) ) + def 
_remap_tools_to_responses_api_format( + self, guardrailed_tools: List[Any] + ) -> List[Dict[str, Any]]: + """ + Remap guardrail-returned tools (Chat Completion format) back to + Responses API request tool format. + """ + return LiteLLMCompletionResponsesConfig.transform_chat_completion_tool_params_to_responses_api_tools( + guardrailed_tools # type: ignore + ) + + def _merge_tools_after_guardrail( + self, + original_tools: List[Dict[str, Any]], + remapped: List[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + """ + Merge remapped guardrailed tools with original tools that were not sent + to the guardrail (e.g. web_search, web_search_preview), preserving order. + """ + if not original_tools: + return remapped + result: List[Dict[str, Any]] = [] + j = 0 + for tool in original_tools: + if isinstance(tool, dict) and tool.get("type") in ( + "web_search", + "web_search_preview", + ): + result.append(tool) + else: + if j < len(remapped): + result.append(remapped[j]) + j += 1 + return result + + def _apply_guardrailed_tools_to_data( + self, + data: dict, + original_tools: List[Dict[str, Any]], + guardrailed_tools: Optional[List[Any]], + ) -> None: + """Remap guardrailed tools to Responses API format and merge with original, then set data['tools'].""" + if guardrailed_tools is not None: + remapped = self._remap_tools_to_responses_api_format(guardrailed_tools) + data["tools"] = self._merge_tools_after_guardrail(original_tools, remapped) + def _extract_input_text_and_images( self, message: Any, # Can be Dict[str, Any] or ResponseInputParam @@ -407,7 +462,10 @@ class OpenAIResponsesHandler(BaseTranslation): List[ChatCompletionToolCallChunk], tool_calls ) # Include model information if available - if hasattr(model_response_stream, "model") and model_response_stream.model: + if ( + hasattr(model_response_stream, "model") + and model_response_stream.model + ): inputs["model"] = model_response_stream.model _guardrailed_inputs = await guardrail_to_apply.apply_guardrail( 
inputs=inputs, @@ -448,7 +506,9 @@ class OpenAIResponsesHandler(BaseTranslation): ) return responses_so_far else: - verbose_proxy_logger.debug("Skipping output guardrail - model response has no choices") + verbose_proxy_logger.debug( + "Skipping output guardrail - model response has no choices" + ) # model_response_stream = OpenAiResponsesToChatCompletionStreamIterator.translate_responses_chunk_to_openai_stream(final_chunk) # tool_calls = model_response_stream.choices[0].tool_calls # convert openai response to model response @@ -456,7 +516,11 @@ class OpenAIResponsesHandler(BaseTranslation): inputs = GenericGuardrailAPIInputs(texts=[string_so_far]) # Try to get model from the final chunk if available if isinstance(final_chunk, dict): - response_model = final_chunk.get("response", {}).get("model") if isinstance(final_chunk.get("response"), dict) else None + response_model = ( + final_chunk.get("response", {}).get("model") + if isinstance(final_chunk.get("response"), dict) + else None + ) if response_model: inputs["model"] = response_model _guardrailed_inputs = await guardrail_to_apply.apply_guardrail( @@ -591,8 +655,8 @@ class OpenAIResponsesHandler(BaseTranslation): content = generic_response_output_item.content except Exception: # Try to extract content directly from output_item if validation fails - if hasattr(output_item, "content") and output_item.content: - content = output_item.content + if hasattr(output_item, "content") and output_item.content: # type: ignore + content = output_item.content # type: ignore else: return elif isinstance(output_item, dict): @@ -669,10 +733,10 @@ class OpenAIResponsesHandler(BaseTranslation): if isinstance(content_item, OutputText): content_item.text = guardrail_response # Update the original response output - if hasattr(output_item, "content") and output_item.content: - original_content = output_item.content[content_idx] + if hasattr(output_item, "content") and output_item.content: # type: ignore + original_content = 
output_item.content[content_idx] # type: ignore if hasattr(original_content, "text"): - original_content.text = guardrail_response + original_content.text = guardrail_response # type: ignore except Exception: pass elif isinstance(output_item, dict): diff --git a/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py b/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py index ed4fb13347..60b29b975f 100644 --- a/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py +++ b/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py @@ -336,15 +336,21 @@ class MCPRequestHandler: user_api_key_auth: Optional[UserAPIKeyAuth] = None, ) -> List[str]: """ - Get list of allowed MCP servers for the given user/key based on permissions + Get list of allowed MCP servers for the given user/key based on permissions. + + Permission hierarchy (all rules are intersections): + 1. Get allowed servers from key permissions + 2. Get allowed servers from team permissions + 3. Get allowed servers from end_user permissions + 4. 
Final result = intersection of key/team AND end_user (if end_user has permissions set) Returns: List[str]: List of allowed MCP servers by server id """ - from typing import List + from litellm.proxy.proxy_server import general_settings try: - allowed_mcp_servers: List[str] = [] + # Get allowed servers from key and team allowed_mcp_servers_for_key = ( await MCPRequestHandler._get_allowed_mcp_servers_for_key( user_api_key_auth @@ -357,8 +363,9 @@ class MCPRequestHandler: ) ######################################################### - # If team has mcp_servers, handle inheritance and intersection logic + # Calculate key/team allowed servers using inheritance and intersection logic ######################################################### + allowed_mcp_servers: List[str] = [] if len(allowed_mcp_servers_for_team) > 0: if len(allowed_mcp_servers_for_key) > 0: # Key has its own MCP permissions - use intersection with team permissions @@ -371,6 +378,40 @@ class MCPRequestHandler: else: allowed_mcp_servers = allowed_mcp_servers_for_key + ######################################################### + # Check end_user permissions if end_user_id is set + ######################################################### + if user_api_key_auth and user_api_key_auth.end_user_id: + allowed_mcp_servers_for_end_user = ( + await MCPRequestHandler._get_allowed_mcp_servers_for_end_user( + user_api_key_auth + ) + ) + + + # If end_user has explicit MCP server permissions, apply intersection + if len(allowed_mcp_servers_for_end_user) > 0: + verbose_logger.debug( + f"End user {user_api_key_auth.end_user_id} has explicit MCP permissions: {allowed_mcp_servers_for_end_user}" + ) + + # Always apply intersection: key/team AND end_user + # This ensures end_user can only access servers that both they AND their key/team are authorized for + filtered_servers = [] + for _mcp_server in allowed_mcp_servers: + if _mcp_server in allowed_mcp_servers_for_end_user: + filtered_servers.append(_mcp_server) + 
@staticmethod
async def _get_allowed_mcp_servers_for_end_user(
    user_api_key_auth: Optional[UserAPIKeyAuth] = None,
) -> List[str]:
    """
    Resolve the MCP server ids granted to the request's end user.

    The end-user record is loaded through ``get_end_user_object`` and the
    result is the de-duplicated combination of:
      * ``object_permission.mcp_servers`` (direct grants), and
      * the servers resolved from ``object_permission.mcp_access_groups``.

    Returns:
        List[str]: Allowed MCP server ids, in no particular order. An empty
        list is returned when there is no end-user id on the auth object, no
        ``prisma_client``, no end-user record / object permission, or when
        the lookup raises (the error is logged as a warning).
    """
    from litellm.proxy.auth.auth_checks import get_end_user_object
    from litellm.proxy.proxy_server import (
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
    )

    # Nothing to resolve without an authenticated end-user id.
    if not (user_api_key_auth and user_api_key_auth.end_user_id):
        return []

    # End-user permissions live in the DB; without a client there is nothing to read.
    if prisma_client is None:
        verbose_logger.debug("prisma_client is None")
        return []

    try:
        end_user = await get_end_user_object(
            end_user_id=user_api_key_auth.end_user_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=user_api_key_auth.parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
            route="/mcp",
        )

        permission = None if end_user is None else end_user.object_permission
        if permission is None:
            return []

        # Servers granted directly on the end user.
        granted_directly = permission.mcp_servers or []

        # Servers granted indirectly through MCP access groups.
        granted_via_groups = (
            await MCPRequestHandler._get_mcp_servers_from_access_groups(
                permission.mcp_access_groups or []
            )
        )

        # De-duplicate the combined grants.
        return list(set(granted_directly + granted_via_groups))
    except Exception as err:
        verbose_logger.warning(
            f"Failed to get allowed MCP servers for end_user: {str(err)}"
        )
        return []
Always add server prefix raw_headers=raw_headers, ) filtered_tools = filter_tools_by_allowed_tools(tools, server) @@ -1079,8 +1075,6 @@ if MCP_AVAILABLE: mcp_servers=mcp_servers, ) - # Decide whether to add prefix based on number of allowed servers - add_prefix = not (len(allowed_mcp_servers) == 1) # Get prompts from each allowed server all_prompts = [] @@ -1101,7 +1095,7 @@ if MCP_AVAILABLE: server=server, mcp_auth_header=server_auth_header, extra_headers=extra_headers, - add_prefix=add_prefix, + add_prefix=True, # Always add server prefix raw_headers=raw_headers, ) @@ -1140,7 +1134,6 @@ if MCP_AVAILABLE: mcp_servers=mcp_servers, ) - add_prefix = not (len(allowed_mcp_servers) == 1) all_resources: List[Resource] = [] for server in allowed_mcp_servers: @@ -1160,7 +1153,7 @@ if MCP_AVAILABLE: server=server, mcp_auth_header=server_auth_header, extra_headers=extra_headers, - add_prefix=add_prefix, + add_prefix=True, # Always add server prefix raw_headers=raw_headers, ) all_resources.extend(resources) @@ -1197,7 +1190,6 @@ if MCP_AVAILABLE: mcp_servers=mcp_servers, ) - add_prefix = not (len(allowed_mcp_servers) == 1) all_resource_templates: List[ResourceTemplate] = [] for server in allowed_mcp_servers: @@ -1218,7 +1210,7 @@ if MCP_AVAILABLE: server=server, mcp_auth_header=server_auth_header, extra_headers=extra_headers, - add_prefix=add_prefix, + add_prefix=True, # Always add server prefix raw_headers=raw_headers, ) ) @@ -1676,14 +1668,9 @@ if MCP_AVAILABLE: detail="User not allowed to get this prompt.", ) - # Decide whether to add prefix based on number of allowed servers - add_prefix = not (len(allowed_mcp_servers) == 1) - if add_prefix: - original_prompt_name, server_name = split_server_prefix_from_name(name) - else: - original_prompt_name = name - server_name = allowed_mcp_servers[0].name + # Extract server name from prefixed prompt name + original_prompt_name, server_name = split_server_prefix_from_name(name) server = next((s for s in allowed_mcp_servers if s.name 
== server_name), None) if server is None: diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index edfe8abea5..67088e79ef 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -13,26 +13,25 @@ model_list: - model_name: gpt-4.1-mini litellm_params: model: openai/gpt-4.1-mini - - -# guardrails: -# - guardrail_name: generic-guardrail -# litellm_params: -# guardrail: generic_guardrail_api -# mode: ["pre_call"] -# headers: -# Authorization: Bearer mock-bedrock-token-12345 -# api_base: http://localhost:8080 -# default_on: true - -prompts: - - prompt_id: "simple_prompt" + - model_name: gpt-5-mini litellm_params: - prompt_integration: "generic_prompt_management" - provider_specific_query_params: - project_name: litellm - slug: hello-world-prompt-2bac - api_base: http://localhost:8080 - api_key: os.environ/BRAINTRUST_API_KEY - ignore_prompt_manager_model: true - ignore_prompt_manager_optional_params: true + model: openai/gpt-5-mini + + +guardrails: + - guardrail_name: mcp-user-permissions + litellm_params: + guardrail: mcp_end_user_permission + mode: pre_call + default_on: true + +mcp_servers: + my_http_server: + url: "http://0.0.0.0:8001/mcp" + transport: "http" + description: "My custom MCP server" + available_on_public_internet: true + +general_settings: + store_model_in_db: true + store_prompts_in_spend_logs: true diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 71b376ea28..612cb0e1e7 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1409,12 +1409,13 @@ class NewCustomerRequest(BudgetNewRequest): blocked: bool = False # allow/disallow requests for this end-user budget_id: Optional[str] = None # give either a budget_id or max_budget spend: Optional[float] = None - allowed_model_region: Optional[ - AllowedModelRegion - ] = None # require all user requests to use models in this specific region - default_model: Optional[ - str - ] = None # if no 
equivalent model in allowed region - default all requests to this model + allowed_model_region: Optional[AllowedModelRegion] = ( + None # require all user requests to use models in this specific region + ) + default_model: Optional[str] = ( + None # if no equivalent model in allowed region - default all requests to this model + ) + object_permission: Optional[LiteLLM_ObjectPermissionBase] = None @model_validator(mode="before") @classmethod @@ -1436,12 +1437,13 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase): blocked: bool = False # allow/disallow requests for this end-user max_budget: Optional[float] = None budget_id: Optional[str] = None # give either a budget_id or max_budget - allowed_model_region: Optional[ - AllowedModelRegion - ] = None # require all user requests to use models in this specific region - default_model: Optional[ - str - ] = None # if no equivalent model in allowed region - default all requests to this model + allowed_model_region: Optional[AllowedModelRegion] = ( + None # require all user requests to use models in this specific region + ) + default_model: Optional[str] = ( + None # if no equivalent model in allowed region - default all requests to this model + ) + object_permission: Optional[LiteLLM_ObjectPermissionBase] = None class DeleteCustomerRequest(LiteLLMPydanticObjectBase): @@ -2301,6 +2303,7 @@ class UserAPIKeyAuth( user_max_budget: Optional[float] = None request_route: Optional[str] = None user: Optional[Any] = None # Expanded user object when expand=user is used + end_user_object_permission: Optional[LiteLLM_ObjectPermissionTable] = None model_config = ConfigDict(arbitrary_types_allowed=True) @@ -2535,6 +2538,8 @@ class LiteLLM_EndUserTable(LiteLLMPydanticObjectBase): allowed_model_region: Optional[AllowedModelRegion] = None default_model: Optional[str] = None litellm_budget_table: Optional[LiteLLM_BudgetTable] = None + object_permission_id: Optional[str] = None + object_permission: Optional[LiteLLM_ObjectPermissionTable] 
= None @model_validator(mode="before") @classmethod diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 3eb6f28ddf..5a8a9f4e4d 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -11,7 +11,8 @@ Run checks for: import asyncio import re import time -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast +from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, + cast) from fastapi import HTTPException, Request, status from pydantic import BaseModel @@ -20,41 +21,27 @@ import litellm from litellm._logging import verbose_proxy_logger from litellm.caching.caching import DualCache from litellm.caching.dual_cache import LimitedSizeOrderedDict -from litellm.constants import ( - CLI_JWT_EXPIRATION_HOURS, - CLI_JWT_TOKEN_NAME, - DEFAULT_ACCESS_GROUP_CACHE_TTL, - DEFAULT_IN_MEMORY_TTL, - DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL, - DEFAULT_MAX_RECURSE_DEPTH, - EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE, -) +from litellm.constants import (CLI_JWT_EXPIRATION_HOURS, CLI_JWT_TOKEN_NAME, + DEFAULT_ACCESS_GROUP_CACHE_TTL, + DEFAULT_IN_MEMORY_TTL, + DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL, + DEFAULT_MAX_RECURSE_DEPTH, + EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE) from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider -from litellm.proxy._types import ( - RBAC_ROLES, - CallInfo, - LiteLLM_AccessGroupTable, - LiteLLM_BudgetTable, - LiteLLM_EndUserTable, - Litellm_EntityType, - LiteLLM_JWTAuth, - LiteLLM_ObjectPermissionTable, - LiteLLM_OrganizationMembershipTable, - LiteLLM_OrganizationTable, - LiteLLM_TagTable, - LiteLLM_TeamMembership, - LiteLLM_TeamTable, - LiteLLM_TeamTableCachedObj, - LiteLLM_UserTable, - LiteLLMRoutes, - LitellmUserRoles, - NewTeamRequest, - ProxyErrorTypes, - ProxyException, - RoleBasedPermissions, - SpecialModelNames, - UserAPIKeyAuth, -) +from litellm.proxy._types import (RBAC_ROLES, CallInfo, + 
LiteLLM_AccessGroupTable, + LiteLLM_BudgetTable, LiteLLM_EndUserTable, + Litellm_EntityType, LiteLLM_JWTAuth, + LiteLLM_ObjectPermissionTable, + LiteLLM_OrganizationMembershipTable, + LiteLLM_OrganizationTable, LiteLLM_TagTable, + LiteLLM_TeamMembership, LiteLLM_TeamTable, + LiteLLM_TeamTableCachedObj, + LiteLLM_UserTable, LiteLLMRoutes, + LitellmUserRoles, NewTeamRequest, + ProxyErrorTypes, ProxyException, + RoleBasedPermissions, SpecialModelNames, + UserAPIKeyAuth) from litellm.proxy.auth.route_checks import RouteChecks from litellm.proxy.route_llm_request import route_request from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics @@ -366,7 +353,8 @@ async def common_checks( _request_metadata: dict = request_body.get("metadata", {}) or {} if _request_metadata.get("guardrails"): # check if team allowed to modify guardrails - from litellm.proxy.guardrails.guardrail_helpers import can_modify_guardrails + from litellm.proxy.guardrails.guardrail_helpers import \ + can_modify_guardrails can_modify: bool = can_modify_guardrails(team_object) if can_modify is False: @@ -792,7 +780,7 @@ async def get_end_user_object( try: response = await prisma_client.db.litellm_endusertable.find_unique( where={"user_id": end_user_id}, - include={"litellm_budget_table": True}, + include={"litellm_budget_table": True, "object_permission": True}, ) if response is None: @@ -1812,9 +1800,8 @@ class ExperimentalUIJWTToken: def get_experimental_ui_login_jwt_auth_token(user_info: LiteLLM_UserTable) -> str: from datetime import timedelta - from litellm.proxy.common_utils.encrypt_decrypt_utils import ( - encrypt_value_helper, - ) + from litellm.proxy.common_utils.encrypt_decrypt_utils import \ + encrypt_value_helper if user_info.user_role is None: raise Exception("User role is required for experimental UI login") @@ -1860,9 +1847,8 @@ class ExperimentalUIJWTToken: """ from datetime import timedelta - from litellm.proxy.common_utils.encrypt_decrypt_utils import ( - 
encrypt_value_helper, - ) + from litellm.proxy.common_utils.encrypt_decrypt_utils import \ + encrypt_value_helper if user_info.user_role is None: raise Exception("User role is required for CLI JWT login") @@ -1901,9 +1887,8 @@ class ExperimentalUIJWTToken: import json from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth - from litellm.proxy.common_utils.encrypt_decrypt_utils import ( - decrypt_value_helper, - ) + from litellm.proxy.common_utils.encrypt_decrypt_utils import \ + decrypt_value_helper decrypted_token = decrypt_value_helper( hashed_token, key="ui_hash_key", exception_type="debug" @@ -2150,11 +2135,11 @@ async def _get_resources_from_access_groups( # Lazy import to avoid circular imports if prisma_client is None or user_api_key_cache is None: - from litellm.proxy.proxy_server import ( - prisma_client as _prisma_client, - proxy_logging_obj as _proxy_logging_obj, - user_api_key_cache as _user_api_key_cache, - ) + from litellm.proxy.proxy_server import prisma_client as _prisma_client + from litellm.proxy.proxy_server import \ + proxy_logging_obj as _proxy_logging_obj + from litellm.proxy.proxy_server import \ + user_api_key_cache as _user_api_key_cache prisma_client = prisma_client or _prisma_client user_api_key_cache = user_api_key_cache or _user_api_key_cache @@ -2936,7 +2921,8 @@ async def _tag_max_budget_check( BudgetExceededError if any tag is over its max budget. Triggers a budget alert if any tag is over its max budget. 
""" - from litellm.proxy.common_utils.http_parsing_utils import get_tags_from_request_body + from litellm.proxy.common_utils.http_parsing_utils import \ + get_tags_from_request_body if prisma_client is None: return diff --git a/litellm/proxy/auth/auth_utils.py b/litellm/proxy/auth/auth_utils.py index fbfc88228c..59bc4190fd 100644 --- a/litellm/proxy/auth/auth_utils.py +++ b/litellm/proxy/auth/auth_utils.py @@ -736,6 +736,16 @@ def get_end_user_id_from_request_body( user_id_from_metadata_field = metadata_dict.get("user_id") if user_id_from_metadata_field is not None: return str(user_id_from_metadata_field) + + + # Check 6: 'safety_identifier' in request body (OpenAI Responses API parameter) + # SECURITY NOTE: safety_identifier can be set by any caller in the request body. + # Only use this for end-user identification in trusted environments where you control + # the calling application. For untrusted callers, prefer using headers or server-side + # middleware to set the end_user_id to prevent impersonation. 
+ if request_body.get("safety_identifier") is not None: + user_from_body_user_field = request_body["safety_identifier"] + return str(user_from_body_user_field) return None diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 77d96e2d39..e2c7dbdaba 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -644,7 +644,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 if team_object is not None else None, ) - + # Check if model has zero cost - if so, skip all budget checks model = get_model_from_request(request_data, route) skip_budget_checks = False @@ -831,6 +831,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 valid_token=valid_token, end_user_params=end_user_params ) valid_token.parent_otel_span = parent_otel_span + if _end_user_object is not None: + valid_token.end_user_object_permission = _end_user_object.object_permission return valid_token @@ -1277,6 +1279,9 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 if _end_user_object is not None: valid_token_dict.update(end_user_params) + valid_token_dict["end_user_object_permission"] = ( + _end_user_object.object_permission + ) # check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. 
These keys can only be used for LiteLLM UI functions # sso/login, ui/login, /key functions and /user functions diff --git a/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/content_filter.py b/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/content_filter.py index 7058e7644c..d82548363e 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/content_filter.py +++ b/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/content_filter.py @@ -205,12 +205,6 @@ class ContentFilterGuardrail(CustomGuardrail): # Load categories if provided if categories: self._load_categories(categories) - else: - verbose_proxy_logger.warning( - "ContentFilterGuardrail has no content categories configured. " - "Toxic/abuse and other category-based keyword filtering will not run. " - "Add categories (e.g. harm_toxic_abuse) in the guardrail config to enable them." - ) # Normalize inputs: convert dicts to Pydantic models for consistent handling normalized_patterns: List[ContentFilterPattern] = [] diff --git a/litellm/proxy/guardrails/guardrail_hooks/mcp_end_user_permission/__init__.py b/litellm/proxy/guardrails/guardrail_hooks/mcp_end_user_permission/__init__.py new file mode 100644 index 0000000000..02b38fe04e --- /dev/null +++ b/litellm/proxy/guardrails/guardrail_hooks/mcp_end_user_permission/__init__.py @@ -0,0 +1,35 @@ +from typing import TYPE_CHECKING + +from litellm.types.guardrails import SupportedGuardrailIntegrations + +from .mcp_end_user_permission import MCPEndUserPermissionGuardrail + +if TYPE_CHECKING: + from litellm.types.guardrails import Guardrail, LitellmParams + + +def initialize_guardrail(litellm_params: "LitellmParams", guardrail: "Guardrail"): + import litellm + + # Default to always-on. Only disable if the user explicitly sets default_on: false. 
+ # We check the raw guardrail dict because LitellmParams normalizes None → False, + # making it impossible to distinguish "not set" from "explicitly false" via litellm_params. + _raw_default_on = guardrail.get("litellm_params", {}).get("default_on") + _default_on = False if _raw_default_on is False else True + + _callback = MCPEndUserPermissionGuardrail( + guardrail_name=guardrail.get("guardrail_name", ""), + event_hook=litellm_params.mode, + default_on=_default_on, + ) + litellm.logging_callback_manager.add_litellm_callback(_callback) + return _callback + + +guardrail_initializer_registry = { + SupportedGuardrailIntegrations.MCP_END_USER_PERMISSION.value: initialize_guardrail, +} + +guardrail_class_registry = { + SupportedGuardrailIntegrations.MCP_END_USER_PERMISSION.value: MCPEndUserPermissionGuardrail, +} diff --git a/litellm/proxy/guardrails/guardrail_hooks/mcp_end_user_permission/mcp_end_user_permission.py b/litellm/proxy/guardrails/guardrail_hooks/mcp_end_user_permission/mcp_end_user_permission.py new file mode 100644 index 0000000000..a485ff2646 --- /dev/null +++ b/litellm/proxy/guardrails/guardrail_hooks/mcp_end_user_permission/mcp_end_user_permission.py @@ -0,0 +1,262 @@ +""" +MCP End User Permission Guardrail Hook + +Enforces end user permissions for MCP server access via apply_guardrail: +- input_type="request" → filter tools the end user cannot access + +Permission logic: +- No end_user_id → allow all (key/team-level permissions apply) +- end_user_id, no mcp_servers → allow all (default) +- end_user_id + mcp_servers → allow only those servers +""" + +from typing import TYPE_CHECKING, Any, List, Literal, Optional, Type + +from litellm._logging import verbose_proxy_logger +from litellm.integrations.custom_guardrail import ( + CustomGuardrail, + log_guardrail_information, +) +from litellm.proxy._types import LiteLLM_ObjectPermissionTable +from litellm.types.guardrails import GuardrailEventHooks +from litellm.types.utils import GenericGuardrailAPIInputs + 
+if TYPE_CHECKING: + from litellm.types.proxy.guardrails.guardrail_hooks.base import GuardrailConfigModel + +GUARDRAIL_NAME = "mcp_end_user_permission" + + +class MCPEndUserPermissionGuardrail(CustomGuardrail): + """ + Guardrail that enforces end user permissions for MCP server access. + + Runs on input only (pre-call). Filters tools in the request that the + end user is not permitted to call based on their object_permission. + + end_user_object_permission is populated on UserAPIKeyAuth during auth. + The guardrail resolves it via a cached get_end_user_object lookup — + no extra DB round-trip when the cache is warm. + """ + + def __init__(self, **kwargs): + if "supported_event_hooks" not in kwargs: + kwargs["supported_event_hooks"] = [ + GuardrailEventHooks.pre_call, + ] + super().__init__(**kwargs) + verbose_proxy_logger.debug("MCP End User Permission Guardrail initialized") + + # ------------------------------------------------------------------ + # apply_guardrail — filters MCP tools on the request side only + # ------------------------------------------------------------------ + + @log_guardrail_information + async def apply_guardrail( + self, + inputs: GenericGuardrailAPIInputs, + request_data: dict, + input_type: Literal["request", "response"] = "request", + logging_obj: Optional[Any] = None, + ) -> GenericGuardrailAPIInputs: + """ + Filters MCP tools the end user cannot access based on their + object_permission.mcp_servers / mcp_access_groups settings. 
+ """ + object_permission = await self._resolve_end_user_object_permission(request_data) + return await self._check_request_tools(inputs, object_permission) + + # ------------------------------------------------------------------ + # Private — request-side tool filtering + # ------------------------------------------------------------------ + + async def _check_request_tools( + self, + inputs: GenericGuardrailAPIInputs, + object_permission: Optional[LiteLLM_ObjectPermissionTable], + ) -> GenericGuardrailAPIInputs: + tools = inputs.get("tools") + if not tools: + return inputs + + allowed_mcp_servers = ( + await self._get_allowed_mcp_servers_from_object_permission( + object_permission + ) + ) + if allowed_mcp_servers is None: + return inputs # No restrictions → pass through unchanged + + verbose_proxy_logger.debug( + f"MCP guardrail: end user restricted to MCP servers: {allowed_mcp_servers}" + ) + + filtered_tools = [] + removed_tools = [] + + for tool in tools: + tool_name = self._get_tool_name_from_definition(tool) + server_name = ( + self._extract_mcp_server_name(tool_name) if tool_name else None + ) + + if server_name is None: + # Not an MCP tool (no prefix) or unrecognised format → keep + filtered_tools.append(tool) + elif server_name in allowed_mcp_servers: + filtered_tools.append(tool) + else: + removed_tools.append(tool_name) + verbose_proxy_logger.warning( + f"MCP guardrail: removing tool '{tool_name}' " + f"(server: '{server_name}') — not in end user's allowed servers" + ) + + if removed_tools: + verbose_proxy_logger.debug( + f"MCP guardrail: removed {len(removed_tools)} unauthorized MCP tool(s): {removed_tools}" + ) + inputs["tools"] = filtered_tools + + return inputs + + # ------------------------------------------------------------------ + # Private — end user permission resolution + # ------------------------------------------------------------------ + + @staticmethod + async def _resolve_end_user_object_permission( + request_data: dict, + ) -> 
Optional[LiteLLM_ObjectPermissionTable]: + """ + Resolve the end user's object_permission via the cached auth lookup. + + Uses get_end_user_object (same path as auth) so no extra DB round-trip + when the cache is warm. + """ + end_user_id = MCPEndUserPermissionGuardrail._get_end_user_id_from_request_data( + request_data + ) + if not end_user_id: + return None + + end_user_object = await MCPEndUserPermissionGuardrail._fetch_end_user_object( + end_user_id + ) + return ( + end_user_object.object_permission if end_user_object is not None else None + ) + + @staticmethod + def _get_end_user_id_from_request_data(request_data: dict) -> Optional[str]: + return request_data.get("user_api_key_end_user_id") or request_data.get( + "litellm_metadata", {} + ).get("user_api_key_end_user_id") + + @staticmethod + async def _fetch_end_user_object(end_user_id: str): # type: ignore[return] + """ + Fetch end user object via the same cached path used during auth. + No extra DB round-trip when the cache is warm. 
+ """ + from litellm.proxy.auth.auth_checks import get_end_user_object + from litellm.proxy.proxy_server import ( + prisma_client, + proxy_logging_obj, + user_api_key_cache, + ) + + if prisma_client is None: + return None + + try: + return await get_end_user_object( + end_user_id=end_user_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=None, + proxy_logging_obj=proxy_logging_obj, + route="/mcp", + ) + except Exception as e: + verbose_proxy_logger.warning( + f"MCP guardrail: failed to fetch end_user_object for '{end_user_id}': {e}" + ) + return None + + # ------------------------------------------------------------------ + # Private — permission derivation + # ------------------------------------------------------------------ + + @staticmethod + async def _get_allowed_mcp_servers_from_object_permission( + object_permission: Optional[LiteLLM_ObjectPermissionTable], + ) -> Optional[List[str]]: + """ + Returns: + None — no restrictions configured, allow all MCP servers + list — restrict to exactly these server names + """ + if object_permission is None: + return None + + direct_mcp_servers = object_permission.mcp_servers or [] + mcp_access_groups = object_permission.mcp_access_groups or [] + + if not direct_mcp_servers and not mcp_access_groups: + return None # Both empty → no restrictions + + from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import ( + MCPRequestHandler, + ) + + access_group_servers = ( + await MCPRequestHandler._get_mcp_servers_from_access_groups( + mcp_access_groups + ) + ) + + return list(set(direct_mcp_servers + access_group_servers)) + + # ------------------------------------------------------------------ + # Config model — exposes this guardrail in the UI + # ------------------------------------------------------------------ + + @staticmethod + def get_config_model() -> Optional[Type["GuardrailConfigModel"]]: + from 
litellm.types.proxy.guardrails.guardrail_hooks.mcp_end_user_permission import ( + MCPEndUserPermissionGuardrailConfigModel, + ) + + return MCPEndUserPermissionGuardrailConfigModel + + # ------------------------------------------------------------------ + # Private — tool name extraction + # ------------------------------------------------------------------ + + @staticmethod + def _extract_mcp_server_name(tool_name: str) -> Optional[str]: + """ + Split "github-create_issue" → "github". + Returns None if the tool name has no '-' prefix (not an MCP tool). + """ + if not tool_name or "-" not in tool_name: + return None + return tool_name.split("-", 1)[0] + + @staticmethod + def _get_tool_name_from_definition(tool: Any) -> Optional[str]: + """ + Extract tool name from a definition dict. + + OpenAI format: {"type": "function", "function": {"name": "..."}} + Anthropic format: {"name": "...", "input_schema": {...}} + """ + if not isinstance(tool, dict): + return None + function_def = tool.get("function") + if isinstance(function_def, dict): + name = function_def.get("name") + if name: + return name + return tool.get("name") diff --git a/litellm/proxy/guardrails/guardrail_hooks/unified_guardrail/unified_guardrail.py b/litellm/proxy/guardrails/guardrail_hooks/unified_guardrail/unified_guardrail.py index 84bbf6d20e..c35eadcb6f 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/unified_guardrail/unified_guardrail.py +++ b/litellm/proxy/guardrails/guardrail_hooks/unified_guardrail/unified_guardrail.py @@ -85,6 +85,7 @@ class UnifiedLLMGuardrails(CustomLogger): add_guardrail_to_applied_guardrails_header, ) + verbose_proxy_logger.debug("Running UnifiedLLMGuardrails pre-call hook") guardrail_to_apply: CustomGuardrail = data.pop("guardrail_to_apply", None) diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py index 5db61eb9c5..d742cc223b 100644 --- a/litellm/proxy/guardrails/init_guardrails.py +++ 
b/litellm/proxy/guardrails/init_guardrails.py @@ -34,7 +34,7 @@ def init_guardrails_v2( if initialized_guardrail: guardrail_list.append(initialized_guardrail) - verbose_proxy_logger.debug(f"\nGuardrail List:{guardrail_list}\n") + # verbose_proxy_logger.debug(f"\nGuardrail List:{guardrail_list}\n") # Populate router's guardrail_list for load balancing support _populate_router_guardrail_list(guardrail_list=guardrail_list) diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 3e8cc46521..3f50aad355 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -879,7 +879,7 @@ async def add_litellm_data_to_request( # noqa: PLR0915 general_settings, user_api_key_dict, _headers ) - # Parse user info from headers + # Parse user info from headers (fallback to general_settings.user_header_name) user = LiteLLMProxyRequestSetup.get_user_from_headers(_headers, general_settings) if user is not None: if user_api_key_dict.end_user_id is None: @@ -1540,9 +1540,7 @@ def _match_and_track_policies( add_policy_sources_to_metadata, add_policy_to_applied_policies_header, ) - from litellm.proxy.policy_engine.attachment_registry import ( - get_attachment_registry, - ) + from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher # Get matching policies via attachments (with match reasons for attribution) @@ -1677,9 +1675,7 @@ def add_guardrails_from_policy_engine( user_api_key_dict: The user's API key authentication info """ from litellm._logging import verbose_proxy_logger - from litellm.proxy.common_utils.http_parsing_utils import ( - get_tags_from_request_body, - ) + from litellm.proxy.common_utils.http_parsing_utils import get_tags_from_request_body from litellm.proxy.policy_engine.policy_registry import get_policy_registry from litellm.types.proxy.policy_engine import PolicyMatchContext diff --git 
a/litellm/proxy/management_endpoints/customer_endpoints.py b/litellm/proxy/management_endpoints/customer_endpoints.py index 9ff0fe6e59..bfafc943c4 100644 --- a/litellm/proxy/management_endpoints/customer_endpoints.py +++ b/litellm/proxy/management_endpoints/customer_endpoints.py @@ -19,11 +19,13 @@ import litellm from litellm._logging import verbose_proxy_logger from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.management_endpoints.common_daily_activity import \ + get_daily_activity +from litellm.proxy.management_helpers.object_permission_utils import ( + _set_object_permission, handle_update_object_permission_common) from litellm.proxy.utils import handle_exception_on_proxy -from litellm.types.proxy.management_endpoints.common_daily_activity import ( - SpendAnalyticsPaginatedResponse, -) -from litellm.proxy.management_endpoints.common_daily_activity import get_daily_activity +from litellm.types.proxy.management_endpoints.common_daily_activity import \ + SpendAnalyticsPaginatedResponse router = APIRouter() @@ -107,9 +109,8 @@ async def unblock_user(data: BlockUsers): ``` """ try: - from enterprise.enterprise_hooks.blocked_user_list import ( - _ENTERPRISE_BlockedUserList, - ) + from enterprise.enterprise_hooks.blocked_user_list import \ + _ENTERPRISE_BlockedUserList except ImportError: raise HTTPException( status_code=400, @@ -164,6 +165,38 @@ def new_budget_request(data: NewCustomerRequest) -> Optional[BudgetNewRequest]: return None +async def _handle_customer_object_permission_update( + non_default_values: dict, + end_user_table_data_typed: Optional[LiteLLM_EndUserTable], + update_end_user_table_data: dict, + prisma_client, +) -> None: + """ + Handle object permission updates for customer endpoints. + + Updates the update_end_user_table_data dict in place with the new object_permission_id. 
+ + Args: + non_default_values: Dictionary containing the update values including object_permission + end_user_table_data_typed: Existing end user table data + update_end_user_table_data: Dictionary to update with new object_permission_id + prisma_client: Prisma database client + """ + if "object_permission" in non_default_values: + existing_object_permission_id = ( + end_user_table_data_typed.object_permission_id + if end_user_table_data_typed is not None + else None + ) + object_permission_id = await handle_update_object_permission_common( + data_json=non_default_values, + existing_object_permission_id=existing_object_permission_id, + prisma_client=prisma_client, + ) + if object_permission_id is not None: + update_end_user_table_data["object_permission_id"] = object_permission_id + + @router.post( "/end_user/new", tags=["Customer Management"], @@ -200,6 +233,16 @@ async def new_end_user( - soft_budget: Optional[float] - [Not Implemented Yet] Get alerts when customer crosses given budget, doesn't block requests. - spend: Optional[float] - Specify initial spend for a given customer. - budget_reset_at: Optional[str] - Specify the date and time when the budget should be reset. + - object_permission: Optional[LiteLLM_ObjectPermissionBase] - Customer-specific object permissions to control access to resources. + Supported fields: + * mcp_servers: List[str] - List of allowed MCP server IDs + * mcp_access_groups: List[str] - List of MCP access group names + * mcp_tool_permissions: Dict[str, List[str]] - Map of server ID to allowed tool names (e.g., {"server_1": ["tool_a", "tool_b"]}) + * vector_stores: List[str] - List of allowed vector store IDs + * agents: List[str] - List of allowed agent IDs + * agent_access_groups: List[str] - List of agent access group names + Example: {"mcp_servers": ["server_1", "server_2"], "vector_stores": ["vector_store_1"], "agents": ["agent_1"]} + IF null or {} then no object-level restrictions apply. 
- Allow specifying allowed regions @@ -214,9 +257,22 @@ async def new_end_user( "user_id" : "ishaan-jaff-3", "allowed_region": "eu", "budget_id": "free_tier", - "default_model": "azure/gpt-3.5-turbo-eu" <- all calls from this user, use this model? + "default_model": "azure/gpt-3.5-turbo-eu" }' + # With object permissions + curl -L -X POST 'http://localhost:4000/customer/new' \ + -H 'Authorization: Bearer sk-1234' \ + -H 'Content-Type: application/json' \ + -d '{ + "user_id": "user_1", + "object_permission": { + "mcp_servers": ["server_1"], + "mcp_access_groups": ["public_group"], + "vector_stores": ["vector_store_1"] + } + }' + # return end-user object ``` @@ -233,11 +289,8 @@ async def new_end_user( - end-user object - currently allowed models """ - from litellm.proxy.proxy_server import ( - litellm_proxy_admin_name, - llm_router, - prisma_client, - ) + from litellm.proxy.proxy_server import (litellm_proxy_admin_name, + llm_router, prisma_client) if prisma_client is None: raise HTTPException( @@ -289,13 +342,34 @@ async def new_end_user( if k not in BudgetNewRequest.model_fields.keys(): new_end_user_obj[k] = v + ## Handle Object Permission - MCP Servers, Vector Stores etc. 
+ new_end_user_obj = await _set_object_permission( + data_json=new_end_user_obj, + prisma_client=prisma_client, + ) + + # Ensure object_permission is not in the data being sent to create + # It should have been converted to object_permission_id by _set_object_permission + if "object_permission" in new_end_user_obj: + verbose_proxy_logger.warning( + f"object_permission still in new_end_user_obj after _set_object_permission: {new_end_user_obj.get('object_permission')}" + ) + new_end_user_obj.pop("object_permission", None) + ## WRITE TO DB ## end_user_record = await prisma_client.db.litellm_endusertable.create( data=new_end_user_obj, # type: ignore - include={"litellm_budget_table": True}, + include={"litellm_budget_table": True, "object_permission": True}, ) - return end_user_record + # Convert to dict and clean up recursive fields + response_dict = end_user_record.model_dump() + if response_dict.get("object_permission"): + # Remove reverse relations from object_permission + for field in ["teams", "verification_tokens", "organizations", "users", "end_users"]: + response_dict["object_permission"].pop(field, None) + + return response_dict except Exception as e: verbose_proxy_logger.exception( "litellm.proxy.management_endpoints.customer_endpoints.new_end_user(): Exception occured - {}".format( @@ -351,7 +425,7 @@ async def end_user_info( ) user_info = await prisma_client.db.litellm_endusertable.find_first( - where={"user_id": end_user_id}, include={"litellm_budget_table": True} + where={"user_id": end_user_id}, include={"litellm_budget_table": True, "object_permission": True} ) if user_info is None: @@ -361,7 +435,15 @@ async def end_user_info( code=404, param="end_user_id", ) - return user_info.model_dump(exclude_none=True) + + # Convert to dict and clean up recursive fields + response_dict = user_info.model_dump(exclude_none=True) + if response_dict.get("object_permission"): + # Remove reverse relations from object_permission + for field in ["teams", 
"verification_tokens", "organizations", "users", "end_users"]: + response_dict["object_permission"].pop(field, None) + + return response_dict except Exception as e: verbose_proxy_logger.exception( @@ -401,6 +483,16 @@ async def update_end_user( - default_model: Optional[str] = ( None # if no equivalent model in allowed region - default all requests to this model ) + - object_permission: Optional[LiteLLM_ObjectPermissionBase] - Customer-specific object permissions to control access to resources. + Supported fields: + * mcp_servers: List[str] - List of allowed MCP server IDs + * mcp_access_groups: List[str] - List of MCP access group names + * mcp_tool_permissions: Dict[str, List[str]] - Map of server ID to allowed tool names + * vector_stores: List[str] - List of allowed vector store IDs + * agents: List[str] - List of allowed agent IDs + * agent_access_groups: List[str] - List of agent access group names + Example: {"mcp_servers": ["server_1"], "vector_stores": ["vector_store_1"]} + IF null or {} then no object-level restrictions apply. Example curl: ``` @@ -412,11 +504,24 @@ async def update_end_user( "budget_id": "paid_tier" }' - See below for all params + # Updating object permissions + curl -L -X POST 'http://localhost:4000/customer/update' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "user_id": "user_1", + "object_permission": { + "mcp_servers": ["server_3"], + "vector_stores": ["vector_store_2", "vector_store_3"] + } + }' + + See below for all params ``` """ - from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client + from litellm.proxy.proxy_server import (litellm_proxy_admin_name, + prisma_client) try: data_json: dict = data.json() @@ -467,6 +572,14 @@ async def update_end_user( elif k in LiteLLM_EndUserTable.model_fields.keys(): update_end_user_table_data[k] = v + ## Handle object permission updates (MCP servers, vector stores, etc.) 
+ await _handle_customer_object_permission_update( + non_default_values=non_default_values, + end_user_table_data_typed=end_user_table_data_typed, + update_end_user_table_data=update_end_user_table_data, + prisma_client=prisma_client, + ) + ## Check if we need to create a new budget (only if budget fields are provided, not just budget_id) ## if budget_table_data: if end_user_budget_table is None: @@ -498,11 +611,20 @@ async def update_end_user( ## Update user table, with update params + new budget id (if set) ## verbose_proxy_logger.debug("/customer/update: Received data = %s", data) + + # Ensure object_permission is not in the update data + # It should have been converted to object_permission_id by handle_update_object_permission_common + if "object_permission" in update_end_user_table_data: + verbose_proxy_logger.warning( + f"object_permission still in update_end_user_table_data: {update_end_user_table_data.get('object_permission')}" + ) + update_end_user_table_data.pop("object_permission", None) + if data.user_id is not None and len(data.user_id) > 0: update_end_user_table_data["user_id"] = data.user_id # type: ignore verbose_proxy_logger.debug("In update customer, user_id condition block.") response = await prisma_client.db.litellm_endusertable.update( - where={"user_id": data.user_id}, data=update_end_user_table_data, include={"litellm_budget_table": True} # type: ignore + where={"user_id": data.user_id}, data=update_end_user_table_data, include={"litellm_budget_table": True, "object_permission": True} # type: ignore ) if response is None: raise ValueError( @@ -511,7 +633,15 @@ async def update_end_user( verbose_proxy_logger.debug( f"received response from updating prisma client. 
response={response}" ) - return response + + # Convert to dict and clean up recursive fields + response_dict = response.model_dump() + if response_dict.get("object_permission"): + # Remove reverse relations from object_permission + for field in ["teams", "verification_tokens", "organizations", "users", "end_users"]: + response_dict["object_permission"].pop(field, None) + + return response_dict else: raise ValueError(f"user_id is required, passed user_id = {data.user_id}") @@ -663,12 +793,17 @@ async def list_end_user( ) response = await prisma_client.db.litellm_endusertable.find_many( - include={"litellm_budget_table": True} + include={"litellm_budget_table": True, "object_permission": True} ) returned_response: List[LiteLLM_EndUserTable] = [] for item in response: - returned_response.append(LiteLLM_EndUserTable(**item.model_dump())) + item_dict = item.model_dump() + # Remove reverse relations from object_permission + if item_dict.get("object_permission"): + for field in ["teams", "verification_tokens", "organizations", "users", "end_users"]: + item_dict["object_permission"].pop(field, None) + returned_response.append(LiteLLM_EndUserTable(**item_dict)) return returned_response except Exception as e: @@ -706,9 +841,7 @@ async def get_customer_daily_activity( """ Get daily activity for specific organizations or all accessible organizations. 
""" - from litellm.proxy.proxy_server import ( - prisma_client, - ) + from litellm.proxy.proxy_server import prisma_client if prisma_client is None: raise HTTPException( diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 441c2cdf70..8bd46672ae 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -233,6 +233,7 @@ model LiteLLM_ObjectPermissionTable { verification_tokens LiteLLM_VerificationToken[] organizations LiteLLM_OrganizationTable[] users LiteLLM_UserTable[] + end_users LiteLLM_EndUserTable[] } // Holds the MCP server configuration @@ -403,7 +404,9 @@ model LiteLLM_EndUserTable { allowed_model_region String? // require all user requests to use models in this specific region default_model String? // use along with 'allowed_model_region'. if no available model in region, default to this model. budget_id String? + object_permission_id String? litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id]) + object_permission LiteLLM_ObjectPermissionTable? 
@relation(fields: [object_permission_id], references: [object_permission_id]) blocked Boolean @default(false) } diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 7cfcef8155..94471f0e32 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1321,6 +1321,7 @@ class ProxyLogging: metadata = data.get("metadata", data.get("litellm_metadata", {})) or {} pipeline_managed: set = metadata.get("_pipeline_managed_guardrails", set()) + for callback in litellm.callbacks: start_time = time.time() _callback = None diff --git a/litellm/responses/litellm_completion_transformation/transformation.py b/litellm/responses/litellm_completion_transformation/transformation.py index b8379b28c3..609e19b244 100644 --- a/litellm/responses/litellm_completion_transformation/transformation.py +++ b/litellm/responses/litellm_completion_transformation/transformation.py @@ -6,6 +6,7 @@ from collections.abc import Sequence from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, cast from openai.types.responses import ResponseFunctionToolCall +from openai.types.responses.response_create_params import ResponseInputParam from openai.types.responses.tool_param import FunctionToolParam from typing_extensions import TypedDict @@ -32,7 +33,6 @@ from litellm.types.llms.openai import ( OpenAIWebSearchUserLocation, OutputTokensDetails, ResponseAPIUsage, - ResponseInputParam, ResponsesAPIOptionalRequestParams, ResponsesAPIResponse, ResponsesAPIStatus, @@ -738,9 +738,25 @@ class LiteLLMCompletionResponsesConfig: @staticmethod def _ensure_tool_results_have_corresponding_tool_calls( - messages: List[Union[AllMessageValues, GenericChatCompletionMessage, ChatCompletionResponseMessage]], + messages: Sequence[ + Union[ + AllMessageValues, + GenericChatCompletionMessage, + ChatCompletionResponseMessage, + ChatCompletionMessageToolCall, + Message, + ] + ], tools: Optional[List[Any]] = None, - ) -> List[Union[AllMessageValues, GenericChatCompletionMessage, 
ChatCompletionResponseMessage]]: + ) -> List[ + Union[ + AllMessageValues, + GenericChatCompletionMessage, + ChatCompletionResponseMessage, + ChatCompletionMessageToolCall, + Message, + ] + ]: """ Ensure that tool_result messages have corresponding tool_calls in the previous assistant message. @@ -755,11 +771,19 @@ class LiteLLMCompletionResponsesConfig: List of messages with tool_calls added to assistant messages when needed """ if not messages: - return messages - - # Create a deep copy to avoid modifying the original + return list(messages) + + # Create a deep copy to avoid modifying the original (use list() so we can mutate and return List) import copy - fixed_messages = copy.deepcopy(messages) + fixed_messages: List[ + Union[ + AllMessageValues, + GenericChatCompletionMessage, + ChatCompletionResponseMessage, + ChatCompletionMessageToolCall, + Message, + ] + ] = list(copy.deepcopy(messages)) messages_to_remove = [] # Count non-tool messages to avoid removing all messages @@ -1306,6 +1330,50 @@ class LiteLLMCompletionResponsesConfig: chat_completion_tools.append(cast(Union[ChatCompletionToolParam, OpenAIMcpServerTool], tool)) return chat_completion_tools, web_search_options + @staticmethod + def transform_chat_completion_tool_params_to_responses_api_tools( + chat_completion_tools: Optional[ + List[Union[ChatCompletionToolParam, OpenAIMcpServerTool]] + ], + ) -> List[Dict[str, Any]]: + """ + Transform Chat Completion tool params (e.g. from guardrail output) back to Responses API request tool format. + Function tools are flattened to top-level name/description/parameters/strict (and optional extras); other dict tools are shallow-copied; non-dict items are returned as-is. + Returns [] for None or an empty list. Inverse of transform_responses_api_tools_to_chat_completion_tools for the tools list. 
+ """ + if chat_completion_tools is None or not chat_completion_tools: + return [] + result: List[Dict[str, Any]] = [] + for tool in chat_completion_tools: + if not isinstance(tool, dict): + result.append(tool) # type: ignore + continue + if tool.get("type") == "function": + fn = tool.get("function") or {} + parameters = dict(fn.get("parameters", {}) or {}) # copy so the caller's schema dict is not mutated below + if not parameters or "type" not in parameters: + parameters["type"] = "object" # default the schema type when it is missing + responses_tool: Dict[str, Any] = { + "type": "function", + "name": fn.get("name") or "", + "description": fn.get("description") or "", + "parameters": parameters, + "strict": fn.get("strict", False) or False, + } + if tool.get("cache_control") is not None: # optional extras are copied only when explicitly set (not None) + responses_tool["cache_control"] = tool.get("cache_control") + if tool.get("defer_loading") is not None: + responses_tool["defer_loading"] = tool.get("defer_loading") + if tool.get("allowed_callers") is not None: + responses_tool["allowed_callers"] = tool.get("allowed_callers") + if tool.get("input_examples") is not None: + responses_tool["input_examples"] = tool.get("input_examples") + result.append(responses_tool) + else: + # mcp or other non-function tool: pass through as a shallow copy + result.append(dict(tool)) + return result + @staticmethod def transform_chat_completion_tools_to_responses_tools( chat_completion_response: ModelResponse, diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py index bbbd60dab7..9a894418f0 100644 --- a/litellm/types/guardrails.py +++ b/litellm/types/guardrails.py @@ -70,6 +70,7 @@ class SupportedGuardrailIntegrations(Enum): GENERIC_GUARDRAIL_API = "generic_guardrail_api" QUALIFIRE = "qualifire" CUSTOM_CODE = "custom_code" + MCP_END_USER_PERMISSION = "mcp_end_user_permission" class Role(Enum): diff --git a/litellm/types/proxy/guardrails/guardrail_hooks/mcp_end_user_permission.py 
b/litellm/types/proxy/guardrails/guardrail_hooks/mcp_end_user_permission.py new file mode 100644 index 0000000000..8c4d80e98a --- /dev/null +++ 
b/litellm/types/proxy/guardrails/guardrail_hooks/mcp_end_user_permission.py @@ -0,0 +1,12 @@ +from .base import GuardrailConfigModel + + +class MCPEndUserPermissionGuardrailConfigModel(GuardrailConfigModel): + """ + Config model for the MCP end user permission guardrail. No provider-specific params are required — + permissions come from the end user object already stored in the database. + """ + + @staticmethod + def ui_friendly_name() -> str: + return "MCP End User Permission" diff --git a/schema.prisma b/schema.prisma index 441c2cdf70..8bd46672ae 100644 --- a/schema.prisma +++ b/schema.prisma @@ -233,6 +233,7 @@ model LiteLLM_ObjectPermissionTable { verification_tokens LiteLLM_VerificationToken[] organizations LiteLLM_OrganizationTable[] users LiteLLM_UserTable[] + end_users LiteLLM_EndUserTable[] } // Holds the MCP server configuration @@ -403,7 +404,9 @@ model LiteLLM_EndUserTable { allowed_model_region String? // require all user requests to use models in this specific region default_model String? // use along with 'allowed_model_region'. if no available model in region, default to this model. budget_id String? + object_permission_id String? litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id]) + object_permission LiteLLM_ObjectPermissionTable? 
@relation(fields: [object_permission_id], references: [object_permission_id]) blocked Boolean @default(false) } diff --git a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_mcp_end_user_permission.py b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_mcp_end_user_permission.py new file mode 100644 index 0000000000..d0dd445dcc --- /dev/null +++ b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_mcp_end_user_permission.py @@ -0,0 +1,414 @@ +""" +Tests for MCP End User Permission Guardrail Hook +""" +import os +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +sys.path.insert( + 0, os.path.abspath("../../../..") +) # Adds the directory four levels above the current working directory to sys.path + +from litellm.exceptions import GuardrailRaisedException +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.guardrails.guardrail_hooks.mcp_end_user_permission import ( + MCPEndUserPermissionGuardrail, +) +from litellm.types.utils import ( + ChatCompletionMessageToolCall, + Choices, + Function, + Message, + ModelResponse, +) + + +class TestMCPEndUserPermissionGuardrail: + """Unit tests for MCPEndUserPermissionGuardrail: server-name extraction and apply_guardrail tool filtering""" + + def test_extract_mcp_server_name(self): + """Test extracting MCP server name from tool name""" + guardrail = MCPEndUserPermissionGuardrail() + + # Expected: the server name is the text before the first "-" (see "jira-create-ticket") + assert guardrail._extract_mcp_server_name("github-create_issue") == "github" + assert guardrail._extract_mcp_server_name("slack-send_message") == "slack" + assert guardrail._extract_mcp_server_name("jira-create-ticket") == "jira" + + # Names without a "-" separator, the empty string and None are all expected to return None + assert guardrail._extract_mcp_server_name("search") is None + assert guardrail._extract_mcp_server_name("") is None + assert guardrail._extract_mcp_server_name(None) is None + + 
@pytest.mark.asyncio + async def test_apply_guardrail_no_end_user(self): + """Test guardrail when no end_user_id is present""" + guardrail = MCPEndUserPermissionGuardrail() + + # Create inputs with MCP 
tools + inputs = { + "tools": [ + { + "type": "function", + "function": { + "name": "github-create_issue", + "description": "Create an issue", + }, + } + ] + } + + request_data = {} + + # Should pass through all tools when no end_user_id + result = await guardrail.apply_guardrail( + inputs=inputs, + request_data=request_data, + input_type="request", + ) + + assert len(result.get("tools", [])) == 1 + assert result["tools"][0]["function"]["name"] == "github-create_issue" + + @pytest.mark.asyncio + async def test_apply_guardrail_with_authorized_tools(self): + """Test guardrail when end user has access to MCP servers""" + from litellm.proxy._types import LiteLLM_ObjectPermissionTable + + guardrail = MCPEndUserPermissionGuardrail() + + # Create inputs with multiple tools + inputs = { + "tools": [ + { + "type": "function", + "function": { + "name": "github-create_issue", + "description": "Create an issue", + }, + }, + { + "type": "function", + "function": { + "name": "slack-send_message", + "description": "Send a message", + }, + }, + { + "type": "function", + "function": { + "name": "search", + "description": "Regular search tool", + }, + }, + ] + } + + request_data = {"user_api_key_end_user_id": "end-user-123"} + + # Mock fetching end user object with permissions + with patch.object( + MCPEndUserPermissionGuardrail, + "_fetch_end_user_object", + return_value=MagicMock( + object_permission=LiteLLM_ObjectPermissionTable( + object_permission_id="perm-1", + mcp_servers=["github", "slack"], + ) + ), + ): + result = await guardrail.apply_guardrail( + inputs=inputs, + request_data=request_data, + input_type="request", + ) + + # Should keep all authorized MCP tools + non-MCP tools + assert len(result.get("tools", [])) == 3 + + @pytest.mark.asyncio + async def test_apply_guardrail_with_unauthorized_tools(self): + """Test guardrail filters out unauthorized MCP tools""" + from litellm.proxy._types import LiteLLM_ObjectPermissionTable + + guardrail = 
MCPEndUserPermissionGuardrail() + + # Create inputs with tools where end user only has access to some + inputs = { + "tools": [ + { + "type": "function", + "function": { + "name": "github-create_issue", + "description": "Create an issue", + }, + }, + { + "type": "function", + "function": { + "name": "slack-send_message", + "description": "Send a message", + }, + }, + { + "type": "function", + "function": { + "name": "jira-create_ticket", + "description": "Create a ticket", + }, + }, + ] + } + + request_data = {"user_api_key_end_user_id": "end-user-123"} + + # Mock fetching end user object with limited permissions (only slack and jira) + with patch.object( + MCPEndUserPermissionGuardrail, + "_fetch_end_user_object", + return_value=MagicMock( + object_permission=LiteLLM_ObjectPermissionTable( + object_permission_id="perm-1", + mcp_servers=["slack", "jira"], + ) + ), + ): + result = await guardrail.apply_guardrail( + inputs=inputs, + request_data=request_data, + input_type="request", + ) + + # Should filter out github tool + assert len(result.get("tools", [])) == 2 + tool_names = [t["function"]["name"] for t in result["tools"]] + assert "slack-send_message" in tool_names + assert "jira-create_ticket" in tool_names + assert "github-create_issue" not in tool_names + + @pytest.mark.asyncio + async def test_apply_guardrail_no_permissions_configured(self): + """Test guardrail when end user has no MCP permissions configured""" + guardrail = MCPEndUserPermissionGuardrail() + + # Create inputs with MCP tools + inputs = { + "tools": [ + { + "type": "function", + "function": { + "name": "github-create_issue", + "description": "Create an issue", + }, + } + ] + } + + request_data = {"user_api_key_end_user_id": "end-user-123"} + + # Mock fetching end user object with no object_permission + with patch.object( + MCPEndUserPermissionGuardrail, + "_fetch_end_user_object", + return_value=MagicMock(object_permission=None), + ): + result = await guardrail.apply_guardrail( + 
inputs=inputs, + request_data=request_data, + input_type="request", + ) + + # Should pass through all tools when no permissions configured + assert len(result.get("tools", [])) == 1 + assert result["tools"][0]["function"]["name"] == "github-create_issue" + + @pytest.mark.asyncio + async def test_apply_guardrail_with_non_mcp_tools(self): + """Test guardrail passes through non-MCP tools""" + from litellm.proxy._types import LiteLLM_ObjectPermissionTable + + guardrail = MCPEndUserPermissionGuardrail() + + # Create inputs with non-MCP tools + inputs = { + "tools": [ + { + "type": "function", + "function": { + "name": "search", + "description": "Search tool", + }, + }, + { + "type": "function", + "function": { + "name": "calculate", + "description": "Calculate something", + }, + }, + ] + } + + request_data = {"user_api_key_end_user_id": "end-user-123"} + + # Mock fetching end user object with MCP restrictions + with patch.object( + MCPEndUserPermissionGuardrail, + "_fetch_end_user_object", + return_value=MagicMock( + object_permission=LiteLLM_ObjectPermissionTable( + object_permission_id="perm-1", + mcp_servers=["github"], + ) + ), + ): + result = await guardrail.apply_guardrail( + inputs=inputs, + request_data=request_data, + input_type="request", + ) + + # Should keep all non-MCP tools even with MCP restrictions + assert len(result.get("tools", [])) == 2 + + @pytest.mark.asyncio + async def test_apply_guardrail_filters_unauthorized_mcp_tools(self): + """Test guardrail filters out unauthorized MCP tools. 
NOTE(review): same inputs, mock and assertions as test_apply_guardrail_with_unauthorized_tools""" + from litellm.proxy._types import LiteLLM_ObjectPermissionTable + + guardrail = MCPEndUserPermissionGuardrail() + + # Create inputs with MCP tools where user only has access to some + inputs = { + "tools": [ + { + "type": "function", + "function": { + "name": "github-create_issue", + "description": "Create an issue", + }, + }, + { + "type": "function", + "function": { + "name": "slack-send_message", + "description": "Send a message", + }, + }, + { + "type": 
"function", + "function": { + "name": "jira-create_ticket", + "description": "Create a ticket", + }, + }, + ] + } + + request_data = {"user_api_key_end_user_id": "end-user-123"} + + # Mock fetching end user object - only has access to slack and jira, not github + with patch.object( + MCPEndUserPermissionGuardrail, + "_fetch_end_user_object", + return_value=MagicMock( + object_permission=LiteLLM_ObjectPermissionTable( + object_permission_id="perm-1", + mcp_servers=["slack", "jira"], + ) + ), + ): + result = await guardrail.apply_guardrail( + inputs=inputs, + request_data=request_data, + input_type="request", + ) + + # Should filter out github tool + assert len(result.get("tools", [])) == 2 + tool_names = [t["function"]["name"] for t in result["tools"]] + assert "slack-send_message" in tool_names + assert "jira-create_ticket" in tool_names + assert "github-create_issue" not in tool_names + + @pytest.mark.asyncio + async def test_apply_guardrail_with_mixed_tools(self): + """Test guardrail with both MCP and non-MCP tools""" + from litellm.proxy._types import LiteLLM_ObjectPermissionTable + + guardrail = MCPEndUserPermissionGuardrail() + + # Create inputs with both MCP and non-MCP tools + inputs = { + "tools": [ + { + "type": "function", + "function": { + "name": "github-create_issue", + "description": "Create an issue", + }, + }, + { + "type": "function", + "function": { + "name": "search", + "description": "Search tool", + }, + }, + { + "type": "function", + "function": { + "name": "slack-send_message", + "description": "Send a message", + }, + }, + ] + } + + request_data = {"user_api_key_end_user_id": "end-user-123"} + + # Mock fetching end user object - only has access to slack + with patch.object( + MCPEndUserPermissionGuardrail, + "_fetch_end_user_object", + return_value=MagicMock( + object_permission=LiteLLM_ObjectPermissionTable( + object_permission_id="perm-1", + mcp_servers=["slack"], + ) + ), + ): + result = await guardrail.apply_guardrail( + inputs=inputs, + 
request_data=request_data, + input_type="request", + ) + + # Should keep slack MCP tool and non-MCP search tool, filter out github + assert len(result.get("tools", [])) == 2 + tool_names = [t["function"]["name"] for t in result["tools"]] + assert "search" in tool_names + assert "slack-send_message" in tool_names + assert "github-create_issue" not in tool_names + + @pytest.mark.asyncio + async def test_apply_guardrail_no_tools_in_request(self): + """Test guardrail when request has no tools""" + guardrail = MCPEndUserPermissionGuardrail() + + # Create inputs without tools + inputs = {"model": "gpt-4", "messages": [{"role": "user", "content": "test"}]} + + request_data = {"user_api_key_end_user_id": "end-user-123"} + + # Should pass through unchanged + result = await guardrail.apply_guardrail( + inputs=inputs, + request_data=request_data, + input_type="request", + ) + + assert result == inputs + assert "tools" not in result