feat: add persistence and management for guardrails on virtual keys

This commit is contained in:
mubashir1osmani (aider)
2025-09-01 00:56:22 -04:00
parent 04dc1a5351
commit 6cd5afa8b1
2 changed files with 97 additions and 3 deletions
@@ -6,16 +6,28 @@ To see all free guardrails see litellm/proxy/guardrails/*
Exposed Routes:
- /mask_pii
- /virtual_key/guardrails
"""
from typing import Optional
from typing import Dict, List, Optional
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.guardrails.guardrail_endpoints import GUARDRAIL_REGISTRY
from litellm.types.guardrails import ApplyGuardrailRequest, ApplyGuardrailResponse
from litellm.proxy.guardrails.guardrail_registry import IN_MEMORY_GUARDRAIL_HANDLER
from litellm.types.guardrails import ApplyGuardrailRequest, ApplyGuardrailResponse, Guardrail
# Models for virtual key guardrail management
class VirtualKeyGuardrailRequest(BaseModel):
virtual_key_id: str
guardrail_id: str
class VirtualKeyGuardrailsResponse(BaseModel):
virtual_key_id: str
guardrails: List[Guardrail]
router = APIRouter(tags=["guardrails"], prefix="/guardrails")
@@ -39,3 +51,55 @@ async def apply_guardrail(
return await active_guardrail.apply_guardrail(
text=request.text, language=request.language, entities=request.entities
)
@router.post("/virtual_key/associate", response_model=Dict[str, str])
async def associate_guardrail_with_virtual_key(
request: VirtualKeyGuardrailRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Associate a guardrail with a virtual key
"""
# Check if guardrail exists
guardrail = IN_MEMORY_GUARDRAIL_HANDLER.get_guardrail_by_id(request.guardrail_id)
if not guardrail:
raise HTTPException(status_code=404, detail=f"Guardrail {request.guardrail_id} not found")
# Associate guardrail with virtual key
IN_MEMORY_GUARDRAIL_HANDLER.associate_guardrail_with_virtual_key(
virtual_key_id=request.virtual_key_id,
guardrail_id=request.guardrail_id
)
return {"message": f"Guardrail {request.guardrail_id} associated with virtual key {request.virtual_key_id}"}
@router.post("/virtual_key/disassociate", response_model=Dict[str, str])
async def disassociate_guardrail_from_virtual_key(
request: VirtualKeyGuardrailRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Disassociate a guardrail from a virtual key
"""
# Disassociate guardrail from virtual key
IN_MEMORY_GUARDRAIL_HANDLER.disassociate_guardrail_from_virtual_key(
virtual_key_id=request.virtual_key_id,
guardrail_id=request.guardrail_id
)
return {"message": f"Guardrail {request.guardrail_id} disassociated from virtual key {request.virtual_key_id}"}
@router.get("/virtual_key/{virtual_key_id}", response_model=VirtualKeyGuardrailsResponse)
async def get_guardrails_for_virtual_key(
virtual_key_id: str,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Get all guardrails associated with a virtual key
"""
guardrails = IN_MEMORY_GUARDRAIL_HANDLER.get_guardrails_for_virtual_key(virtual_key_id)
return VirtualKeyGuardrailsResponse(
virtual_key_id=virtual_key_id,
guardrails=guardrails
)
@@ -369,6 +369,11 @@ class InMemoryGuardrailHandler:
"""
Guardrail id to CustomGuardrail object mapping
"""
self.virtual_key_to_guardrails: Dict[str, List[str]] = {}
"""
Virtual key id to list of guardrail ids mapping
"""
def initialize_guardrail(
self,
@@ -538,6 +543,31 @@ class InMemoryGuardrailHandler:
Get a guardrail by its ID from memory
"""
return self.IN_MEMORY_GUARDRAILS.get(guardrail_id)
def associate_guardrail_with_virtual_key(self, virtual_key_id: str, guardrail_id: str) -> None:
"""
Associate a guardrail with a virtual key
"""
if virtual_key_id not in self.virtual_key_to_guardrails:
self.virtual_key_to_guardrails[virtual_key_id] = []
if guardrail_id not in self.virtual_key_to_guardrails[virtual_key_id]:
self.virtual_key_to_guardrails[virtual_key_id].append(guardrail_id)
def disassociate_guardrail_from_virtual_key(self, virtual_key_id: str, guardrail_id: str) -> None:
"""
Disassociate a guardrail from a virtual key
"""
if virtual_key_id in self.virtual_key_to_guardrails:
if guardrail_id in self.virtual_key_to_guardrails[virtual_key_id]:
self.virtual_key_to_guardrails[virtual_key_id].remove(guardrail_id)
def get_guardrails_for_virtual_key(self, virtual_key_id: str) -> List[Guardrail]:
"""
Get all guardrails associated with a virtual key
"""
guardrail_ids = self.virtual_key_to_guardrails.get(virtual_key_id, [])
return [self.IN_MEMORY_GUARDRAILS[gid] for gid in guardrail_ids if gid in self.IN_MEMORY_GUARDRAILS]
########################################################