mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-30 19:04:14 +00:00
feat: add persistence and management for guardrails on virtual keys
This commit is contained in:
@@ -6,16 +6,28 @@ To see all free guardrails see litellm/proxy/guardrails/*
|
||||
|
||||
Exposed Routes:
|
||||
- /mask_pii
|
||||
- /virtual_key/guardrails
|
||||
"""
|
||||
from typing import Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.guardrails.guardrail_endpoints import GUARDRAIL_REGISTRY
|
||||
from litellm.types.guardrails import ApplyGuardrailRequest, ApplyGuardrailResponse
|
||||
from litellm.proxy.guardrails.guardrail_registry import IN_MEMORY_GUARDRAIL_HANDLER
|
||||
from litellm.types.guardrails import ApplyGuardrailRequest, ApplyGuardrailResponse, Guardrail
|
||||
|
||||
# Models for virtual key guardrail management
|
||||
class VirtualKeyGuardrailRequest(BaseModel):
|
||||
virtual_key_id: str
|
||||
guardrail_id: str
|
||||
|
||||
class VirtualKeyGuardrailsResponse(BaseModel):
|
||||
virtual_key_id: str
|
||||
guardrails: List[Guardrail]
|
||||
|
||||
router = APIRouter(tags=["guardrails"], prefix="/guardrails")
|
||||
|
||||
@@ -39,3 +51,55 @@ async def apply_guardrail(
|
||||
return await active_guardrail.apply_guardrail(
|
||||
text=request.text, language=request.language, entities=request.entities
|
||||
)
|
||||
|
||||
@router.post("/virtual_key/associate", response_model=Dict[str, str])
|
||||
async def associate_guardrail_with_virtual_key(
|
||||
request: VirtualKeyGuardrailRequest,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Associate a guardrail with a virtual key
|
||||
"""
|
||||
# Check if guardrail exists
|
||||
guardrail = IN_MEMORY_GUARDRAIL_HANDLER.get_guardrail_by_id(request.guardrail_id)
|
||||
if not guardrail:
|
||||
raise HTTPException(status_code=404, detail=f"Guardrail {request.guardrail_id} not found")
|
||||
|
||||
# Associate guardrail with virtual key
|
||||
IN_MEMORY_GUARDRAIL_HANDLER.associate_guardrail_with_virtual_key(
|
||||
virtual_key_id=request.virtual_key_id,
|
||||
guardrail_id=request.guardrail_id
|
||||
)
|
||||
|
||||
return {"message": f"Guardrail {request.guardrail_id} associated with virtual key {request.virtual_key_id}"}
|
||||
|
||||
@router.post("/virtual_key/disassociate", response_model=Dict[str, str])
|
||||
async def disassociate_guardrail_from_virtual_key(
|
||||
request: VirtualKeyGuardrailRequest,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Disassociate a guardrail from a virtual key
|
||||
"""
|
||||
# Disassociate guardrail from virtual key
|
||||
IN_MEMORY_GUARDRAIL_HANDLER.disassociate_guardrail_from_virtual_key(
|
||||
virtual_key_id=request.virtual_key_id,
|
||||
guardrail_id=request.guardrail_id
|
||||
)
|
||||
|
||||
return {"message": f"Guardrail {request.guardrail_id} disassociated from virtual key {request.virtual_key_id}"}
|
||||
|
||||
@router.get("/virtual_key/{virtual_key_id}", response_model=VirtualKeyGuardrailsResponse)
|
||||
async def get_guardrails_for_virtual_key(
|
||||
virtual_key_id: str,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Get all guardrails associated with a virtual key
|
||||
"""
|
||||
guardrails = IN_MEMORY_GUARDRAIL_HANDLER.get_guardrails_for_virtual_key(virtual_key_id)
|
||||
|
||||
return VirtualKeyGuardrailsResponse(
|
||||
virtual_key_id=virtual_key_id,
|
||||
guardrails=guardrails
|
||||
)
|
||||
|
||||
@@ -369,6 +369,11 @@ class InMemoryGuardrailHandler:
|
||||
"""
|
||||
Guardrail id to CustomGuardrail object mapping
|
||||
"""
|
||||
|
||||
self.virtual_key_to_guardrails: Dict[str, List[str]] = {}
|
||||
"""
|
||||
Virtual key id to list of guardrail ids mapping
|
||||
"""
|
||||
|
||||
def initialize_guardrail(
|
||||
self,
|
||||
@@ -538,6 +543,31 @@ class InMemoryGuardrailHandler:
|
||||
Get a guardrail by its ID from memory
|
||||
"""
|
||||
return self.IN_MEMORY_GUARDRAILS.get(guardrail_id)
|
||||
|
||||
def associate_guardrail_with_virtual_key(self, virtual_key_id: str, guardrail_id: str) -> None:
|
||||
"""
|
||||
Associate a guardrail with a virtual key
|
||||
"""
|
||||
if virtual_key_id not in self.virtual_key_to_guardrails:
|
||||
self.virtual_key_to_guardrails[virtual_key_id] = []
|
||||
|
||||
if guardrail_id not in self.virtual_key_to_guardrails[virtual_key_id]:
|
||||
self.virtual_key_to_guardrails[virtual_key_id].append(guardrail_id)
|
||||
|
||||
def disassociate_guardrail_from_virtual_key(self, virtual_key_id: str, guardrail_id: str) -> None:
|
||||
"""
|
||||
Disassociate a guardrail from a virtual key
|
||||
"""
|
||||
if virtual_key_id in self.virtual_key_to_guardrails:
|
||||
if guardrail_id in self.virtual_key_to_guardrails[virtual_key_id]:
|
||||
self.virtual_key_to_guardrails[virtual_key_id].remove(guardrail_id)
|
||||
|
||||
def get_guardrails_for_virtual_key(self, virtual_key_id: str) -> List[Guardrail]:
|
||||
"""
|
||||
Get all guardrails associated with a virtual key
|
||||
"""
|
||||
guardrail_ids = self.virtual_key_to_guardrails.get(virtual_key_id, [])
|
||||
return [self.IN_MEMORY_GUARDRAILS[gid] for gid in guardrail_ids if gid in self.IN_MEMORY_GUARDRAILS]
|
||||
|
||||
|
||||
########################################################
|
||||
|
||||
Reference in New Issue
Block a user