mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-22 03:44:54 +00:00
Grayswan guardrail passthrough on flagged (#16891)
* attempt to implement the passthrough feature * Formatting and small change * Fix formatting * Format test file --------- Co-authored-by: Xiaohan Fu <xiaohan@grayswan.ai>
This commit is contained in:
@@ -142,8 +142,8 @@ Provides the strongest enforcement by inspecting both prompts and responses.
|
||||
|---------------------------------------|-----------------|-------------|
|
||||
| `api_key` | string | Gray Swan Cygnal API key. Reads from `GRAYSWAN_API_KEY` if omitted. |
|
||||
| `mode` | string or list | Guardrail stages (`pre_call`, `during_call`, `post_call`). |
|
||||
| `optional_params.on_flagged_action` | string | `monitor` (log only) or `block` (raise `HTTPException`). |
|
||||
| `optional_params.on_flagged_action` | string | `monitor` (log only), `block` (raise `HTTPException`), or `passthrough` (include detection info in response without blocking). |
|
||||
| `.optional_params.violation_threshold`| number (0-1) | Scores at or above this value are considered violations. |
|
||||
| `optional_params.reasoning_mode` | string | `off`, `hybrid`, or `thinking`. Enables Cygnal’s reasoning capabilities. |
|
||||
| `optional_params.reasoning_mode` | string | `off`, `hybrid`, or `thinking`. Enables Cygnal's reasoning capabilities. |
|
||||
| `optional_params.categories` | object | Map of custom category names to descriptions. |
|
||||
| `optional_params.policy_id` | string | Gray Swan policy identifier. |
|
||||
|
||||
@@ -38,7 +38,7 @@ class GraySwanGuardrail(CustomGuardrail):
|
||||
see: https://docs.grayswan.ai/cygnal/monitor-requests
|
||||
"""
|
||||
|
||||
SUPPORTED_ON_FLAGGED_ACTIONS = {"block", "monitor"}
|
||||
SUPPORTED_ON_FLAGGED_ACTIONS = {"block", "monitor", "passthrough"}
|
||||
DEFAULT_ON_FLAGGED_ACTION = "monitor"
|
||||
BASE_API_URL = "https://api.grayswan.ai"
|
||||
MONITOR_PATH = "/cygnal/monitor"
|
||||
@@ -147,7 +147,7 @@ class GraySwanGuardrail(CustomGuardrail):
|
||||
)
|
||||
return data
|
||||
|
||||
await self.run_grayswan_guardrail(payload)
|
||||
await self.run_grayswan_guardrail(payload, data)
|
||||
add_guardrail_to_applied_guardrails_header(
|
||||
request_data=data, guardrail_name=self.guardrail_name
|
||||
)
|
||||
@@ -193,7 +193,7 @@ class GraySwanGuardrail(CustomGuardrail):
|
||||
)
|
||||
return data
|
||||
|
||||
await self.run_grayswan_guardrail(payload)
|
||||
await self.run_grayswan_guardrail(payload, data)
|
||||
add_guardrail_to_applied_guardrails_header(
|
||||
request_data=data, guardrail_name=self.guardrail_name
|
||||
)
|
||||
@@ -240,7 +240,24 @@ class GraySwanGuardrail(CustomGuardrail):
|
||||
)
|
||||
return response
|
||||
|
||||
await self.run_grayswan_guardrail(payload)
|
||||
await self.run_grayswan_guardrail(payload, data)
|
||||
|
||||
# If passthrough mode and detection info exists, add it to response
|
||||
if self.on_flagged_action == "passthrough" and "metadata" in data:
|
||||
guardrail_detections = data.get("metadata", {}).get(
|
||||
"guardrail_detections", []
|
||||
)
|
||||
if guardrail_detections:
|
||||
# Add guardrail detections to response hidden params for client visibility
|
||||
hidden_params = getattr(response, "_hidden_params", None)
|
||||
if hidden_params is not None:
|
||||
if not hidden_params:
|
||||
hidden_params = {}
|
||||
setattr(response, "_hidden_params", hidden_params)
|
||||
|
||||
hidden_params["guardrail_detections"] = guardrail_detections
|
||||
setattr(response, "_hidden_params", hidden_params)
|
||||
|
||||
add_guardrail_to_applied_guardrails_header(
|
||||
request_data=data, guardrail_name=self.guardrail_name
|
||||
)
|
||||
@@ -250,7 +267,7 @@ class GraySwanGuardrail(CustomGuardrail):
|
||||
# Core GraySwan interaction
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def run_grayswan_guardrail(self, payload: dict):
|
||||
async def run_grayswan_guardrail(self, payload: dict, data: Optional[dict] = None):
|
||||
headers = self._prepare_headers()
|
||||
|
||||
try:
|
||||
@@ -273,7 +290,7 @@ class GraySwanGuardrail(CustomGuardrail):
|
||||
)
|
||||
raise GraySwanGuardrailAPIError(str(exc)) from exc
|
||||
|
||||
self._process_grayswan_response(result)
|
||||
self._process_grayswan_response(result, data)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -306,7 +323,9 @@ class GraySwanGuardrail(CustomGuardrail):
|
||||
|
||||
return payload
|
||||
|
||||
def _process_grayswan_response(self, response_json: Dict[str, Any]) -> None:
|
||||
def _process_grayswan_response(
|
||||
self, response_json: Dict[str, Any], data: Optional[dict] = None
|
||||
) -> None:
|
||||
violation_score = float(response_json.get("violation", 0.0) or 0.0)
|
||||
violated_rules = response_json.get("violated_rules", [])
|
||||
mutation_detected = response_json.get("mutation")
|
||||
@@ -338,6 +357,30 @@ class GraySwanGuardrail(CustomGuardrail):
|
||||
"ipi": ipi_detected,
|
||||
},
|
||||
)
|
||||
elif self.on_flagged_action == "monitor":
|
||||
verbose_proxy_logger.info(
|
||||
"Gray Swan Guardrail: Monitoring mode - allowing flagged content to proceed"
|
||||
)
|
||||
elif self.on_flagged_action == "passthrough":
|
||||
verbose_proxy_logger.info(
|
||||
"Gray Swan Guardrail: Passthrough mode - storing detection info in metadata"
|
||||
)
|
||||
if data is not None:
|
||||
# Store guardrail detection info in metadata to be included in response
|
||||
if "metadata" not in data:
|
||||
data["metadata"] = {}
|
||||
if "guardrail_detections" not in data["metadata"]:
|
||||
data["metadata"]["guardrail_detections"] = []
|
||||
|
||||
detection_info = {
|
||||
"guardrail": "grayswan",
|
||||
"flagged": True,
|
||||
"violation_score": violation_score,
|
||||
"violated_rules": violated_rules,
|
||||
"mutation": mutation_detected,
|
||||
"ipi": ipi_detected,
|
||||
}
|
||||
data["metadata"]["guardrail_detections"].append(detection_info)
|
||||
|
||||
def _resolve_threshold(self, threshold: Optional[float]) -> float:
|
||||
if threshold is not None:
|
||||
|
||||
@@ -12,7 +12,7 @@ class GraySwanGuardrailConfigModelOptionalParams(BaseModel):
|
||||
|
||||
on_flagged_action: Optional[str] = Field(
|
||||
default="monitor",
|
||||
description="Action when a violation is detected: 'block' rejects the call, 'monitor' logs only.",
|
||||
description="Action when a violation is detected: 'block' rejects the call, 'monitor' logs only, 'passthrough' includes detection info in response without blocking.",
|
||||
)
|
||||
violation_threshold: Optional[float] = Field(
|
||||
default=0.5,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
@@ -22,7 +24,9 @@ def grayswan_guardrail() -> GraySwanGuardrail:
|
||||
)
|
||||
|
||||
|
||||
def test_prepare_payload_uses_dynamic_overrides(grayswan_guardrail: GraySwanGuardrail) -> None:
|
||||
def test_prepare_payload_uses_dynamic_overrides(
|
||||
grayswan_guardrail: GraySwanGuardrail,
|
||||
) -> None:
|
||||
messages = [{"role": "user", "content": "hello"}]
|
||||
dynamic_body = {
|
||||
"categories": {"custom": "override"},
|
||||
@@ -38,7 +42,9 @@ def test_prepare_payload_uses_dynamic_overrides(grayswan_guardrail: GraySwanGuar
|
||||
assert payload["reasoning_mode"] == "thinking"
|
||||
|
||||
|
||||
def test_prepare_payload_falls_back_to_guardrail_defaults(grayswan_guardrail: GraySwanGuardrail) -> None:
|
||||
def test_prepare_payload_falls_back_to_guardrail_defaults(
|
||||
grayswan_guardrail: GraySwanGuardrail,
|
||||
) -> None:
|
||||
messages = [{"role": "user", "content": "hello"}]
|
||||
|
||||
payload = grayswan_guardrail._prepare_payload(messages, {})
|
||||
@@ -48,8 +54,12 @@ def test_prepare_payload_falls_back_to_guardrail_defaults(grayswan_guardrail: Gr
|
||||
assert payload["reasoning_mode"] == "hybrid"
|
||||
|
||||
|
||||
def test_process_response_does_not_block_under_threshold(grayswan_guardrail: GraySwanGuardrail) -> None:
|
||||
grayswan_guardrail._process_grayswan_response({"violation": 0.3, "violated_rules": []})
|
||||
def test_process_response_does_not_block_under_threshold(
|
||||
grayswan_guardrail: GraySwanGuardrail,
|
||||
) -> None:
|
||||
grayswan_guardrail._process_grayswan_response(
|
||||
{"violation": 0.3, "violated_rules": []}
|
||||
)
|
||||
|
||||
|
||||
def test_process_response_blocks_when_threshold_exceeded() -> None:
|
||||
@@ -85,18 +95,22 @@ class _DummyClient:
|
||||
self.calls: list[dict] = []
|
||||
|
||||
async def post(self, *, url: str, headers: dict, json: dict, timeout: float):
|
||||
self.calls.append({"url": url, "headers": headers, "json": json, "timeout": timeout})
|
||||
self.calls.append(
|
||||
{"url": url, "headers": headers, "json": json, "timeout": timeout}
|
||||
)
|
||||
return _DummyResponse(self.payload)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_guardrail_posts_payload(monkeypatch, grayswan_guardrail: GraySwanGuardrail) -> None:
|
||||
async def test_run_guardrail_posts_payload(
|
||||
monkeypatch, grayswan_guardrail: GraySwanGuardrail
|
||||
) -> None:
|
||||
dummy_client = _DummyClient({"violation": 0.1})
|
||||
grayswan_guardrail.async_handler = dummy_client
|
||||
|
||||
captured = {}
|
||||
|
||||
def fake_process(response_json: dict) -> None:
|
||||
def fake_process(response_json: dict, data: Optional[dict] = None) -> None:
|
||||
captured["response"] = response_json
|
||||
|
||||
monkeypatch.setattr(grayswan_guardrail, "_process_grayswan_response", fake_process)
|
||||
@@ -110,7 +124,9 @@ async def test_run_guardrail_posts_payload(monkeypatch, grayswan_guardrail: Gray
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_guardrail_raises_api_error(grayswan_guardrail: GraySwanGuardrail) -> None:
|
||||
async def test_run_guardrail_raises_api_error(
|
||||
grayswan_guardrail: GraySwanGuardrail,
|
||||
) -> None:
|
||||
class _FailingClient:
|
||||
async def post(self, **_kwargs):
|
||||
raise RuntimeError("boom")
|
||||
@@ -121,3 +137,61 @@ async def test_run_guardrail_raises_api_error(grayswan_guardrail: GraySwanGuardr
|
||||
|
||||
with pytest.raises(GraySwanGuardrailAPIError):
|
||||
await grayswan_guardrail.run_grayswan_guardrail(payload)
|
||||
|
||||
|
||||
def test_process_response_passthrough_stores_detection_info() -> None:
|
||||
"""Test that passthrough mode stores detection info in metadata without blocking."""
|
||||
guardrail = GraySwanGuardrail(
|
||||
guardrail_name="grayswan-passthrough",
|
||||
api_key="test-key",
|
||||
on_flagged_action="passthrough",
|
||||
violation_threshold=0.2,
|
||||
event_hook=GuardrailEventHooks.pre_call,
|
||||
)
|
||||
|
||||
data = {"messages": [{"role": "user", "content": "test"}]}
|
||||
response_json = {
|
||||
"violation": 0.8,
|
||||
"violated_rules": [1, 2],
|
||||
"mutation": True,
|
||||
"ipi": False,
|
||||
}
|
||||
|
||||
# Should not raise an exception
|
||||
guardrail._process_grayswan_response(response_json, data)
|
||||
|
||||
# Verify detection info was stored in metadata
|
||||
assert "metadata" in data
|
||||
assert "guardrail_detections" in data["metadata"]
|
||||
assert len(data["metadata"]["guardrail_detections"]) == 1
|
||||
|
||||
detection = data["metadata"]["guardrail_detections"][0]
|
||||
assert detection["guardrail"] == "grayswan"
|
||||
assert detection["flagged"] is True
|
||||
assert detection["violation_score"] == 0.8
|
||||
assert detection["violated_rules"] == [1, 2]
|
||||
assert detection["mutation"] is True
|
||||
assert detection["ipi"] is False
|
||||
|
||||
|
||||
def test_process_response_passthrough_does_not_store_if_under_threshold() -> None:
|
||||
"""Test that passthrough mode doesn't store anything if violation is under threshold."""
|
||||
guardrail = GraySwanGuardrail(
|
||||
guardrail_name="grayswan-passthrough",
|
||||
api_key="test-key",
|
||||
on_flagged_action="passthrough",
|
||||
violation_threshold=0.5,
|
||||
event_hook=GuardrailEventHooks.pre_call,
|
||||
)
|
||||
|
||||
data = {"messages": [{"role": "user", "content": "test"}]}
|
||||
response_json = {
|
||||
"violation": 0.3,
|
||||
"violated_rules": [],
|
||||
}
|
||||
|
||||
# Should not raise an exception
|
||||
guardrail._process_grayswan_response(response_json, data)
|
||||
|
||||
# Should not have any detection info since it didn't exceed threshold
|
||||
assert "guardrail_detections" not in data.get("metadata", {})
|
||||
|
||||
Reference in New Issue
Block a user