diff --git a/enterprise/litellm_enterprise/proxy/guardrails/endpoints.py b/enterprise/litellm_enterprise/proxy/guardrails/endpoints.py index cdf86dcea6..fcd5e82ab3 100644 --- a/enterprise/litellm_enterprise/proxy/guardrails/endpoints.py +++ b/enterprise/litellm_enterprise/proxy/guardrails/endpoints.py @@ -6,16 +6,28 @@ To see all free guardrails see litellm/proxy/guardrails/* Exposed Routes: - /mask_pii +- /virtual_key/guardrails """ -from typing import Optional +from typing import Dict, List, Optional -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.guardrails.guardrail_endpoints import GUARDRAIL_REGISTRY -from litellm.types.guardrails import ApplyGuardrailRequest, ApplyGuardrailResponse +from litellm.proxy.guardrails.guardrail_registry import IN_MEMORY_GUARDRAIL_HANDLER +from litellm.types.guardrails import ApplyGuardrailRequest, ApplyGuardrailResponse, Guardrail + +# Models for virtual key guardrail management +class VirtualKeyGuardrailRequest(BaseModel): + virtual_key_id: str + guardrail_id: str + +class VirtualKeyGuardrailsResponse(BaseModel): + virtual_key_id: str + guardrails: List[Guardrail] router = APIRouter(tags=["guardrails"], prefix="/guardrails") @@ -39,3 +51,55 @@ async def apply_guardrail( return await active_guardrail.apply_guardrail( text=request.text, language=request.language, entities=request.entities ) + +@router.post("/virtual_key/associate", response_model=Dict[str, str]) +async def associate_guardrail_with_virtual_key( + request: VirtualKeyGuardrailRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Associate a guardrail with a virtual key + """ + # Check if guardrail exists + guardrail = IN_MEMORY_GUARDRAIL_HANDLER.get_guardrail_by_id(request.guardrail_id) + if not guardrail: + raise HTTPException(status_code=404, detail=f"Guardrail {request.guardrail_id} not found") + + # Associate guardrail with virtual key + IN_MEMORY_GUARDRAIL_HANDLER.associate_guardrail_with_virtual_key( + virtual_key_id=request.virtual_key_id, + guardrail_id=request.guardrail_id + ) + + return {"message": f"Guardrail {request.guardrail_id} associated with virtual key {request.virtual_key_id}"} + +@router.post("/virtual_key/disassociate", response_model=Dict[str, str]) +async def disassociate_guardrail_from_virtual_key( + request: VirtualKeyGuardrailRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Disassociate a guardrail from a virtual key + """ + # Disassociate guardrail from virtual key + IN_MEMORY_GUARDRAIL_HANDLER.disassociate_guardrail_from_virtual_key( + virtual_key_id=request.virtual_key_id, + guardrail_id=request.guardrail_id + ) + + return {"message": f"Guardrail {request.guardrail_id} disassociated from virtual key {request.virtual_key_id}"} + +@router.get("/virtual_key/{virtual_key_id}", response_model=VirtualKeyGuardrailsResponse) +async def get_guardrails_for_virtual_key( + virtual_key_id: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Get all guardrails associated with a virtual key + """ + guardrails = IN_MEMORY_GUARDRAIL_HANDLER.get_guardrails_for_virtual_key(virtual_key_id) + + return VirtualKeyGuardrailsResponse( + virtual_key_id=virtual_key_id, + guardrails=guardrails + ) diff --git a/litellm/proxy/guardrails/guardrail_registry.py b/litellm/proxy/guardrails/guardrail_registry.py index 21429f462d..523ff9f4cd 100644 --- a/litellm/proxy/guardrails/guardrail_registry.py +++ b/litellm/proxy/guardrails/guardrail_registry.py @@ -369,6 +369,11 @@ class InMemoryGuardrailHandler: """ Guardrail id to CustomGuardrail object mapping """ + + self.virtual_key_to_guardrails: Dict[str, List[str]] = {} + """ + Virtual key id to list of guardrail ids mapping + """ def initialize_guardrail( self, @@ -538,6 +543,31 @@ class InMemoryGuardrailHandler: Get a guardrail by its ID from memory """ return self.IN_MEMORY_GUARDRAILS.get(guardrail_id) + + def associate_guardrail_with_virtual_key(self, virtual_key_id: str, guardrail_id: str) -> None: + """ + Associate a guardrail with a virtual key + """ + if virtual_key_id not in self.virtual_key_to_guardrails: + self.virtual_key_to_guardrails[virtual_key_id] = [] + + if guardrail_id not in self.virtual_key_to_guardrails[virtual_key_id]: + self.virtual_key_to_guardrails[virtual_key_id].append(guardrail_id) + + def disassociate_guardrail_from_virtual_key(self, virtual_key_id: str, guardrail_id: str) -> None: + """ + Disassociate a guardrail from a virtual key + """ + if virtual_key_id in self.virtual_key_to_guardrails: + if guardrail_id in self.virtual_key_to_guardrails[virtual_key_id]: + self.virtual_key_to_guardrails[virtual_key_id].remove(guardrail_id) + + def get_guardrails_for_virtual_key(self, virtual_key_id: str) -> List[Guardrail]: + """ + Get all guardrails associated with a virtual key + """ + guardrail_ids = self.virtual_key_to_guardrails.get(virtual_key_id, []) + return [self.IN_MEMORY_GUARDRAILS[gid] for gid in guardrail_ids if gid in self.IN_MEMORY_GUARDRAILS] ########################################################