From c9cdce96fa02d6798fe02dddc91c82d73bdd03ca Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Thu, 19 Feb 2026 14:09:20 -0800
Subject: [PATCH] feat(policy): test playground for AI policy suggestions (#21608)

* fix aviation safety topic filter: remove overly broad exceptions, add cockpit access block words

* fix airline brand protection filter: add identifier words, competitor/ops block words, tighten exceptions

* feat(policy): add POST /policy/templates/test endpoint for testing guardrails before creating them

* feat(ui): add testPolicyTemplate networking function

* feat(ui): add test playground to AI policy suggestion modal

* test(policy): add tests for POST /policy/templates/test endpoint
---
 .../policy_endpoints/endpoints.py             | 191 ++++-
 .../policy_endpoints/test_endpoints.py        | 213 ++++++
 .../src/components/networking.tsx             |  35 +
 .../policies/ai_suggestion_modal.tsx          | 622 ++++++++++++++----
 4 files changed, 923 insertions(+), 138 deletions(-)
 create mode 100644 tests/test_litellm/proxy/management_endpoints/policy_endpoints/test_endpoints.py

diff --git a/litellm/proxy/management_endpoints/policy_endpoints/endpoints.py b/litellm/proxy/management_endpoints/policy_endpoints/endpoints.py
index efecd172d7..36745a607c 100644
--- a/litellm/proxy/management_endpoints/policy_endpoints/endpoints.py
+++ b/litellm/proxy/management_endpoints/policy_endpoints/endpoints.py
@@ -18,13 +18,13 @@ from typing import (
     List,
     Literal,
     Optional,
-    TypedDict,
     cast,
 )
 
 from fastapi import APIRouter, Depends, HTTPException, Request
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, Field
+from typing_extensions import TypedDict
 
 from litellm._logging import verbose_proxy_logger
 from litellm.constants import (
@@ -993,3 +993,192 @@ async def suggest_policy_templates(
         description=data.description,
         model=data.model,
     )
+
+
+class GuardrailTestResultEntry(TypedDict):
+    """Outcome of running a single guardrail definition against the test text."""
+
+    guardrail_name: str
+    # One of: "passed" | "blocked" | "masked" | "unsupported" | "error"
+    action: str
+    # Text after the guardrail ran; "" when the input was blocked.
+    output_text: str
+    details: str
+
+
+class TestPolicyTemplateRequest(BaseModel):
+    """Request body for POST /policy/templates/test."""
+
+    guardrail_definitions: List[dict] = Field(
+        description="All guardrailDefinitions from the policy template"
+    )
+    text: str = Field(description="Test input text to run guardrails against")
+
+
+class TestPolicyTemplateResponse(TypedDict):
+    """Response body for POST /policy/templates/test."""
+
+    overall_action: str  # worst-case across all guardrails
+    results: List[GuardrailTestResultEntry]
+
+
+# Severity ranking used to pick the overall action; a higher value wins.
+# Actions missing from this table rank the same as "passed".
+_GUARDRAIL_ACTION_PRIORITY = {
+    "blocked": 4,
+    "masked": 3,
+    "error": 2,
+    "unsupported": 1,
+    "passed": 0,
+}
+
+
+@router.post(
+    "/policy/templates/test",
+    tags=["policy management"],
+    dependencies=[Depends(user_api_key_auth)],
+)
+@management_endpoint_wrapper
+async def test_policy_template(
+    data: TestPolicyTemplateRequest,
+    request: Request,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+) -> TestPolicyTemplateResponse:
+    """
+    Test a policy template's guardrails against a text input without creating them.
+
+    Instantiates temporary guardrails from the template definitions, runs them
+    against the provided text, and returns per-guardrail results so users can
+    verify the template solves their problem before creating it.
+
+    Returns ``overall_action`` (the most severe per-guardrail action) together
+    with the per-guardrail ``results``. Any exception raised while testing is
+    passed through ``handle_exception_on_proxy`` and re-raised.
+    """
+    from litellm.proxy.utils import handle_exception_on_proxy
+
+    try:
+        results = await _test_guardrail_definitions(
+            guardrail_definitions=data.guardrail_definitions,
+            text=data.text,
+        )
+        overall = _compute_overall_action(results)
+        return TestPolicyTemplateResponse(
+            overall_action=overall,
+            results=results,
+        )
+    except Exception as e:
+        raise handle_exception_on_proxy(e)
+
+
+async def _test_guardrail_definitions(
+    guardrail_definitions: List[dict],
+    text: str,
+) -> List[GuardrailTestResultEntry]:
+    """
+    Instantiate and run each guardrail definition against the text.
+
+    Only ``litellm_content_filter`` guardrails are executed; every other type is
+    reported as "unsupported". A failure in one definition is recorded in that
+    definition's entry ("blocked" for an HTTPException, "error" for anything
+    else) so the remaining definitions still run.
+    """
+    from litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter import (
+        ContentFilterGuardrail,
+    )
+
+    results: List[GuardrailTestResultEntry] = []
+
+    for guardrail_def in guardrail_definitions:
+        guardrail_name = guardrail_def.get("guardrail_name") or "unknown"
+        # A null or non-dict ``litellm_params`` would raise AttributeError below and
+        # fail the whole request; treat it as empty so it is reported per guardrail.
+        raw_params = guardrail_def.get("litellm_params")
+        litellm_params = raw_params if isinstance(raw_params, dict) else {}
+        guardrail_type = litellm_params.get("guardrail", "")
+
+        if guardrail_type != "litellm_content_filter":
+            results.append(
+                GuardrailTestResultEntry(
+                    guardrail_name=guardrail_name,
+                    action="unsupported",
+                    output_text=text,
+                    details=f"Preview not available for guardrail type: {guardrail_type}",
+                )
+            )
+            continue
+
+        try:
+            guardrail = ContentFilterGuardrail(
+                guardrail_name=guardrail_name,
+                patterns=litellm_params.get("patterns"),
+                blocked_words=litellm_params.get("blocked_words"),
+                categories=litellm_params.get("categories"),
+                pattern_redaction_format=litellm_params.get("pattern_redaction_format"),
+                default_on=litellm_params.get("default_on", False),
+            )
+
+            output = await guardrail.apply_guardrail(
+                inputs={"texts": [text]},
+                request_data={},
+                input_type="request",
+            )
+            # Fall back to the input text when the guardrail returns no texts.
+            output_texts = output.get("texts")
+            output_text = output_texts[0] if output_texts else text
+
+            if output_text != text:
+                action = "masked"
+                details = "Content was modified (masked)"
+            else:
+                action = "passed"
+                details = "No issues detected"
+
+            results.append(
+                GuardrailTestResultEntry(
+                    guardrail_name=guardrail_name,
+                    action=action,
+                    output_text=output_text,
+                    details=details,
+                )
+            )
+        except HTTPException as e:
+            # An HTTPException raised by the guardrail is reported as a block.
+            detail = e.detail
+            if isinstance(detail, dict):
+                detail = detail.get("error", str(detail))
+            results.append(
+                GuardrailTestResultEntry(
+                    guardrail_name=guardrail_name,
+                    action="blocked",
+                    output_text="",
+                    details=str(detail),
+                )
+            )
+        except Exception as e:
+            # Best effort: report the failure for this guardrail and keep going.
+            results.append(
+                GuardrailTestResultEntry(
+                    guardrail_name=guardrail_name,
+                    action="error",
+                    output_text=text,
+                    details=str(e),
+                )
+            )
+
+    return results
+
+
+def _compute_overall_action(results: List[GuardrailTestResultEntry]) -> str:
+    """
+    Return the worst-case action: blocked > masked > error > unsupported > passed.
+
+    Returns "passed" for an empty list; unrecognized actions rank like "passed".
+    """
+    worst = "passed"
+    worst_rank = _GUARDRAIL_ACTION_PRIORITY[worst]
+    for result in results:
+        rank = _GUARDRAIL_ACTION_PRIORITY.get(result["action"], 0)
+        if rank > worst_rank:
+            worst, worst_rank = result["action"], rank
+    return worst
diff --git a/tests/test_litellm/proxy/management_endpoints/policy_endpoints/test_endpoints.py b/tests/test_litellm/proxy/management_endpoints/policy_endpoints/test_endpoints.py
new file mode 100644
index 0000000000..4d3063fdcc
--- /dev/null
+++ b/tests/test_litellm/proxy/management_endpoints/policy_endpoints/test_endpoints.py
@@ -0,0 +1,213 @@
+"""
+Tests for POST /policy/templates/test endpoint logic.
+
+Tests _test_guardrail_definitions and _compute_overall_action directly
+without needing a running proxy.
+""" + +import pytest + +from litellm.proxy.management_endpoints.policy_endpoints.endpoints import ( + GuardrailTestResultEntry, + _compute_overall_action, + _test_guardrail_definitions, +) + + +@pytest.mark.asyncio +async def test_pattern_based_guardrail_masks_pii(): + """A pattern-based guardrail should mask matching PII.""" + guardrail_defs = [ + { + "guardrail_name": "test-ssn-masker", + "litellm_params": { + "guardrail": "litellm_content_filter", + "mode": "pre_call", + "patterns": [ + { + "pattern_type": "prebuilt", + "pattern_name": "us_ssn", + "action": "MASK", + } + ], + "pattern_redaction_format": "[{pattern_name}_REDACTED]", + }, + "guardrail_info": {"description": "Masks US SSNs"}, + } + ] + + results = await _test_guardrail_definitions( + guardrail_definitions=guardrail_defs, + text="My SSN is 123-45-6789", + ) + + assert len(results) == 1 + assert results[0]["guardrail_name"] == "test-ssn-masker" + assert results[0]["action"] == "masked" + assert "123-45-6789" not in results[0]["output_text"] + assert "REDACTED" in results[0]["output_text"] + + +@pytest.mark.asyncio +async def test_blocked_words_guardrail_blocks(): + """A blocked_words guardrail should block matching text.""" + guardrail_defs = [ + { + "guardrail_name": "test-word-blocker", + "litellm_params": { + "guardrail": "litellm_content_filter", + "mode": "pre_call", + "blocked_words": [ + { + "keyword": "forbidden_word", + "action": "BLOCK", + "description": "test block", + } + ], + }, + "guardrail_info": {"description": "Blocks forbidden words"}, + } + ] + + results = await _test_guardrail_definitions( + guardrail_definitions=guardrail_defs, + text="This contains forbidden_word in it", + ) + + assert len(results) == 1 + assert results[0]["guardrail_name"] == "test-word-blocker" + assert results[0]["action"] == "blocked" + + +@pytest.mark.asyncio +async def test_clean_text_passes(): + """Clean text should pass all guardrails.""" + guardrail_defs = [ + { + "guardrail_name": "test-ssn-masker", + 
"litellm_params": { + "guardrail": "litellm_content_filter", + "mode": "pre_call", + "patterns": [ + { + "pattern_type": "prebuilt", + "pattern_name": "us_ssn", + "action": "MASK", + } + ], + }, + "guardrail_info": {"description": "Masks US SSNs"}, + } + ] + + results = await _test_guardrail_definitions( + guardrail_definitions=guardrail_defs, + text="Hello, this is a perfectly clean message.", + ) + + assert len(results) == 1 + assert results[0]["action"] == "passed" + assert results[0]["output_text"] == "Hello, this is a perfectly clean message." + + +@pytest.mark.asyncio +async def test_unsupported_guardrail_type(): + """Non-litellm_content_filter types should return unsupported.""" + guardrail_defs = [ + { + "guardrail_name": "test-mcp", + "litellm_params": { + "guardrail": "mcp_security", + "mode": "pre_call", + }, + "guardrail_info": {"description": "MCP guardrail"}, + } + ] + + results = await _test_guardrail_definitions( + guardrail_definitions=guardrail_defs, + text="Any text", + ) + + assert len(results) == 1 + assert results[0]["action"] == "unsupported" + assert "mcp_security" in results[0]["details"] + + +@pytest.mark.asyncio +async def test_multiple_guardrails_mixed_results(): + """Multiple guardrails with different outcomes.""" + guardrail_defs = [ + { + "guardrail_name": "ssn-masker", + "litellm_params": { + "guardrail": "litellm_content_filter", + "mode": "pre_call", + "patterns": [ + { + "pattern_type": "prebuilt", + "pattern_name": "us_ssn", + "action": "MASK", + } + ], + "pattern_redaction_format": "[{pattern_name}_REDACTED]", + }, + "guardrail_info": {"description": "Masks SSNs"}, + }, + { + "guardrail_name": "email-masker", + "litellm_params": { + "guardrail": "litellm_content_filter", + "mode": "pre_call", + "patterns": [ + { + "pattern_type": "prebuilt", + "pattern_name": "email", + "action": "MASK", + } + ], + "pattern_redaction_format": "[{pattern_name}_REDACTED]", + }, + "guardrail_info": {"description": "Masks emails"}, + }, + ] + + 
results = await _test_guardrail_definitions( + guardrail_definitions=guardrail_defs, + text="My SSN is 123-45-6789 but no email here", + ) + + assert len(results) == 2 + ssn_result = next(r for r in results if r["guardrail_name"] == "ssn-masker") + email_result = next(r for r in results if r["guardrail_name"] == "email-masker") + assert ssn_result["action"] == "masked" + assert email_result["action"] == "passed" + + +def test_compute_overall_action_blocked_wins(): + results: list[GuardrailTestResultEntry] = [ + GuardrailTestResultEntry(guardrail_name="a", action="passed", output_text="", details=""), + GuardrailTestResultEntry(guardrail_name="b", action="blocked", output_text="", details=""), + GuardrailTestResultEntry(guardrail_name="c", action="masked", output_text="", details=""), + ] + assert _compute_overall_action(results) == "blocked" + + +def test_compute_overall_action_masked_wins_over_passed(): + results: list[GuardrailTestResultEntry] = [ + GuardrailTestResultEntry(guardrail_name="a", action="passed", output_text="", details=""), + GuardrailTestResultEntry(guardrail_name="b", action="masked", output_text="", details=""), + ] + assert _compute_overall_action(results) == "masked" + + +def test_compute_overall_action_all_passed(): + results: list[GuardrailTestResultEntry] = [ + GuardrailTestResultEntry(guardrail_name="a", action="passed", output_text="", details=""), + GuardrailTestResultEntry(guardrail_name="b", action="passed", output_text="", details=""), + ] + assert _compute_overall_action(results) == "passed" + + +def test_compute_overall_action_empty(): + assert _compute_overall_action([]) == "passed" diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index 3517822985..8b624f7caf 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -5627,6 +5627,41 @@ export const suggestPolicyTemplates = async ( } }; +export const 
testPolicyTemplate = async ( + accessToken: string, + guardrailDefinitions: any[], + text: string +) => { + try { + const url = proxyBaseUrl + ? `${proxyBaseUrl}/policy/templates/test` + : `/policy/templates/test`; + const response = await fetch(url, { + method: "POST", + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + body: JSON.stringify({ + guardrail_definitions: guardrailDefinitions, + text, + }), + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + return response.json(); + } catch (error) { + console.error("Failed to test policy template:", error); + throw error; + } +}; + export const enrichPolicyTemplateStream = async ( accessToken: string, templateId: string, diff --git a/ui/litellm-dashboard/src/components/policies/ai_suggestion_modal.tsx b/ui/litellm-dashboard/src/components/policies/ai_suggestion_modal.tsx index 623c291f07..60b49a7755 100644 --- a/ui/litellm-dashboard/src/components/policies/ai_suggestion_modal.tsx +++ b/ui/litellm-dashboard/src/components/policies/ai_suggestion_modal.tsx @@ -1,13 +1,24 @@ import React, { useState, useEffect } from "react"; -import { Modal, Spin, Checkbox, Select } from "antd"; -import { Button } from "@tremor/react"; -import { suggestPolicyTemplates, modelHubCall } from "../networking"; +import { Modal, Spin, Checkbox, Select, Input, Typography, Tooltip } from "antd"; +import { Button, Card } from "@tremor/react"; +import { CheckCircleOutlined, CloseCircleOutlined, InfoCircleOutlined, DownOutlined, RightOutlined } from "@ant-design/icons"; +import { suggestPolicyTemplates, modelHubCall, testPolicyTemplate, enrichPolicyTemplate } from "../networking"; + +const { TextArea } = Input; +const { Text } = Typography; interface SuggestedTemplate { template_id: string; reason: string; } +interface GuardrailTestResult { + 
guardrail_name: string; + action: string; + output_text: string; + details: string; +} + interface AiSuggestionModalProps { visible: boolean; onSelectTemplates: (templates: any[]) => void; @@ -34,6 +45,17 @@ const AiSuggestionModal: React.FC = ({ const [selectedModel, setSelectedModel] = useState(undefined); const [availableModels, setAvailableModels] = useState([]); const [isLoadingModels, setIsLoadingModels] = useState(false); + // Test panel state + const [showTestPanel, setShowTestPanel] = useState(false); + const [testInputText, setTestInputText] = useState(""); + const [isTestLoading, setIsTestLoading] = useState(false); + const [testResults, setTestResults] = useState(null); + const [testOverallAction, setTestOverallAction] = useState(null); + const [collapsedResults, setCollapsedResults] = useState>(new Set()); + // Enrichment state for competitor templates + const [enrichedDefs, setEnrichedDefs] = useState>({}); + const [isEnriching, setIsEnriching] = useState(false); + const [enrichBrandName, setEnrichBrandName] = useState(""); useEffect(() => { if (visible && availableModels.length === 0) { @@ -67,6 +89,15 @@ const AiSuggestionModal: React.FC = ({ setExplanation(null); setSelectedIds(new Set()); setSelectedModel(undefined); + setShowTestPanel(false); + setTestInputText(""); + setIsTestLoading(false); + setTestResults(null); + setTestOverallAction(null); + setCollapsedResults(new Set()); + setEnrichedDefs({}); + setIsEnriching(false); + setEnrichBrandName(""); }; const handleCancel = () => { @@ -126,6 +157,11 @@ const AiSuggestionModal: React.FC = ({ setSuggestions(null); setExplanation(null); setSelectedIds(new Set()); + setShowTestPanel(false); + setTestInputText(""); + setTestResults(null); + setTestOverallAction(null); + setCollapsedResults(new Set()); }; const handleUseSelected = () => { @@ -149,14 +185,433 @@ const AiSuggestionModal: React.FC = ({ const getTemplateById = (id: string) => allTemplates.find((t) => t.id === id); + const 
toggleResultCollapse = (name: string) => { + setCollapsedResults((prev) => { + const next = new Set(prev); + if (next.has(name)) next.delete(name); + else next.add(name); + return next; + }); + }; + + const getSelectedTemplatesNeedingEnrichment = (): any[] => { + return Array.from(selectedIds) + .map((id) => getTemplateById(id)) + .filter((t) => t?.llm_enrichment); + }; + + const needsEnrichment = getSelectedTemplatesNeedingEnrichment().length > 0; + + const getAllSelectedGuardrailDefs = (): any[] => { + const defs: any[] = []; + for (const id of selectedIds) { + // Use enriched defs if available, otherwise use template's original + if (enrichedDefs[id]) { + defs.push(...enrichedDefs[id]); + } else { + const t = getTemplateById(id); + if (t?.guardrailDefinitions) { + defs.push(...t.guardrailDefinitions); + } + } + } + return defs; + }; + + const handleEnrichCompetitors = async () => { + if (!accessToken || !selectedModel) return; + const templatesToEnrich = getSelectedTemplatesNeedingEnrichment(); + if (templatesToEnrich.length === 0) return; + + setIsEnriching(true); + try { + for (const template of templatesToEnrich) { + const paramName = template.llm_enrichment.parameter; + const result = await enrichPolicyTemplate( + accessToken, + template.id, + { [paramName]: enrichBrandName }, + selectedModel, + ); + setEnrichedDefs((prev) => ({ + ...prev, + [template.id]: result.guardrailDefinitions, + })); + } + } catch (e) { + console.error("Failed to enrich templates:", e); + } finally { + setIsEnriching(false); + } + }; + + const handleRunTest = async () => { + if (!accessToken || !testInputText.trim()) return; + const allDefs = getAllSelectedGuardrailDefs(); + if (allDefs.length === 0) return; + + setIsTestLoading(true); + setTestResults(null); + setTestOverallAction(null); + setCollapsedResults(new Set()); + + try { + const result = await testPolicyTemplate(accessToken, allDefs, testInputText); + setTestResults(result.results || []); + 
setTestOverallAction(result.overall_action || "passed"); + } catch { + setTestResults([]); + setTestOverallAction("error"); + } finally { + setIsTestLoading(false); + } + }; + + const handleTestKeyDown = (e: React.KeyboardEvent) => { + if (e.key === "Enter" && !e.shiftKey && !e.ctrlKey && !e.metaKey) { + e.preventDefault(); + handleRunTest(); + } + }; + const showResults = suggestions !== null && !isLoading; + // Helper to render the suggestions list (reused in both layouts) + const renderSuggestionsList = () => { + if (!suggestions || suggestions.length === 0) { + return ( +
+ + + +

No matching templates found

+

Try adjusting your examples or description.

+
+ ); + } + + return ( +
+ {suggestions.map((suggestion) => { + const template = getTemplateById(suggestion.template_id); + if (!template) return null; + const isSelected = selectedIds.has(suggestion.template_id); + return ( +
+
toggleTemplate(suggestion.template_id)} + > +
+ toggleTemplate(suggestion.template_id)} + className="mt-0.5" + /> +
+
+ + {template.title} + + {template.complexity && ( + + {template.complexity} + + )} +
+

+ {template.description} +

+
+ {template.guardrails && template.guardrails.slice(0, 4).map((g: string) => ( + + {g} + + ))} + {template.guardrails && template.guardrails.length > 4 && ( + + +{template.guardrails.length - 4} more + + )} +
+
+ +

+ {suggestion.reason} +

+
+
+
+
+
+ ); + })} + + {/* Explanation */} + {explanation && ( +
+
+ + + Why these templates + +
+

{explanation}

+
+ )} +
+ ); + }; + + // Helper to render the test panel + const renderTestPanel = () => ( +
+ {/* Test header */} +
+
+

Test Guardrails

+ +
+
+ {Array.from(selectedIds).map((id) => { + const t = getTemplateById(id); + return t ? ( + + {t.title} + + ) : null; + })} +
+

+ {getAllSelectedGuardrailDefs().length} guardrails across {selectedIds.size} template{selectedIds.size !== 1 ? "s" : ""} +

+
+ + {/* Enrichment section for competitor templates */} + {needsEnrichment && Object.keys(enrichedDefs).length > 0 && ( +
+
+ + + Competitor names loaded for {enrichBrandName} + +
+
+ )} + {needsEnrichment && Object.keys(enrichedDefs).length === 0 && ( +
+
+ + + + + Competitor template requires your brand name to discover competitors + +
+
+ setEnrichBrandName(e.target.value)} + onPressEnter={() => enrichBrandName.trim() && handleEnrichCompetitors()} + className="flex-1" + /> + +
+
+ )} + + {/* Input */} +
+
+
+
+ + + + +
+ Characters: {testInputText.length} +
+