Grayswan guardrail passthrough on flagged (#16891)

* attempt to implement the passthrough feature

* Formatting and small change

* Fix formatting

* Format test file

---------

Co-authored-by: Xiaohan Fu <xiaohan@grayswan.ai>
This commit is contained in:
Derek Duenas
2025-11-21 23:01:35 -05:00
committed by GitHub
parent f56c7e1ef9
commit bbaf0af907
4 changed files with 135 additions and 18 deletions
@@ -142,8 +142,8 @@ Provides the strongest enforcement by inspecting both prompts and responses.
|---------------------------------------|-----------------|-------------|
| `api_key` | string | Gray Swan Cygnal API key. Reads from `GRAYSWAN_API_KEY` if omitted. |
| `mode` | string or list | Guardrail stages (`pre_call`, `during_call`, `post_call`). |
| `optional_params.on_flagged_action` | string | `monitor` (log only) or `block` (raise `HTTPException`). |
| `optional_params.on_flagged_action` | string | `monitor` (log only), `block` (raise `HTTPException`), or `passthrough` (include detection info in response without blocking). |
| `.optional_params.violation_threshold`| number (0-1) | Scores at or above this value are considered violations. |
| `optional_params.reasoning_mode` | string | `off`, `hybrid`, or `thinking`. Enables Cygnals reasoning capabilities. |
| `optional_params.reasoning_mode` | string | `off`, `hybrid`, or `thinking`. Enables Cygnal's reasoning capabilities. |
| `optional_params.categories` | object | Map of custom category names to descriptions. |
| `optional_params.policy_id` | string | Gray Swan policy identifier. |
@@ -38,7 +38,7 @@ class GraySwanGuardrail(CustomGuardrail):
see: https://docs.grayswan.ai/cygnal/monitor-requests
"""
SUPPORTED_ON_FLAGGED_ACTIONS = {"block", "monitor"}
SUPPORTED_ON_FLAGGED_ACTIONS = {"block", "monitor", "passthrough"}
DEFAULT_ON_FLAGGED_ACTION = "monitor"
BASE_API_URL = "https://api.grayswan.ai"
MONITOR_PATH = "/cygnal/monitor"
@@ -147,7 +147,7 @@ class GraySwanGuardrail(CustomGuardrail):
)
return data
await self.run_grayswan_guardrail(payload)
await self.run_grayswan_guardrail(payload, data)
add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=self.guardrail_name
)
@@ -193,7 +193,7 @@ class GraySwanGuardrail(CustomGuardrail):
)
return data
await self.run_grayswan_guardrail(payload)
await self.run_grayswan_guardrail(payload, data)
add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=self.guardrail_name
)
@@ -240,7 +240,24 @@ class GraySwanGuardrail(CustomGuardrail):
)
return response
await self.run_grayswan_guardrail(payload)
await self.run_grayswan_guardrail(payload, data)
# If passthrough mode and detection info exists, add it to response
if self.on_flagged_action == "passthrough" and "metadata" in data:
guardrail_detections = data.get("metadata", {}).get(
"guardrail_detections", []
)
if guardrail_detections:
# Add guardrail detections to response hidden params for client visibility
hidden_params = getattr(response, "_hidden_params", None)
if hidden_params is not None:
if not hidden_params:
hidden_params = {}
setattr(response, "_hidden_params", hidden_params)
hidden_params["guardrail_detections"] = guardrail_detections
setattr(response, "_hidden_params", hidden_params)
add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=self.guardrail_name
)
@@ -250,7 +267,7 @@ class GraySwanGuardrail(CustomGuardrail):
# Core GraySwan interaction
# ------------------------------------------------------------------
async def run_grayswan_guardrail(self, payload: dict):
async def run_grayswan_guardrail(self, payload: dict, data: Optional[dict] = None):
headers = self._prepare_headers()
try:
@@ -273,7 +290,7 @@ class GraySwanGuardrail(CustomGuardrail):
)
raise GraySwanGuardrailAPIError(str(exc)) from exc
self._process_grayswan_response(result)
self._process_grayswan_response(result, data)
# ------------------------------------------------------------------
# Helpers
@@ -306,7 +323,9 @@ class GraySwanGuardrail(CustomGuardrail):
return payload
def _process_grayswan_response(self, response_json: Dict[str, Any]) -> None:
def _process_grayswan_response(
self, response_json: Dict[str, Any], data: Optional[dict] = None
) -> None:
violation_score = float(response_json.get("violation", 0.0) or 0.0)
violated_rules = response_json.get("violated_rules", [])
mutation_detected = response_json.get("mutation")
@@ -338,6 +357,30 @@ class GraySwanGuardrail(CustomGuardrail):
"ipi": ipi_detected,
},
)
elif self.on_flagged_action == "monitor":
verbose_proxy_logger.info(
"Gray Swan Guardrail: Monitoring mode - allowing flagged content to proceed"
)
elif self.on_flagged_action == "passthrough":
verbose_proxy_logger.info(
"Gray Swan Guardrail: Passthrough mode - storing detection info in metadata"
)
if data is not None:
# Store guardrail detection info in metadata to be included in response
if "metadata" not in data:
data["metadata"] = {}
if "guardrail_detections" not in data["metadata"]:
data["metadata"]["guardrail_detections"] = []
detection_info = {
"guardrail": "grayswan",
"flagged": True,
"violation_score": violation_score,
"violated_rules": violated_rules,
"mutation": mutation_detected,
"ipi": ipi_detected,
}
data["metadata"]["guardrail_detections"].append(detection_info)
def _resolve_threshold(self, threshold: Optional[float]) -> float:
if threshold is not None:
@@ -12,7 +12,7 @@ class GraySwanGuardrailConfigModelOptionalParams(BaseModel):
on_flagged_action: Optional[str] = Field(
default="monitor",
description="Action when a violation is detected: 'block' rejects the call, 'monitor' logs only.",
description="Action when a violation is detected: 'block' rejects the call, 'monitor' logs only, 'passthrough' includes detection info in response without blocking.",
)
violation_threshold: Optional[float] = Field(
default=0.5,
@@ -1,3 +1,5 @@
from typing import Optional
import pytest
from fastapi import HTTPException
@@ -22,7 +24,9 @@ def grayswan_guardrail() -> GraySwanGuardrail:
)
def test_prepare_payload_uses_dynamic_overrides(grayswan_guardrail: GraySwanGuardrail) -> None:
def test_prepare_payload_uses_dynamic_overrides(
grayswan_guardrail: GraySwanGuardrail,
) -> None:
messages = [{"role": "user", "content": "hello"}]
dynamic_body = {
"categories": {"custom": "override"},
@@ -38,7 +42,9 @@ def test_prepare_payload_uses_dynamic_overrides(grayswan_guardrail: GraySwanGuar
assert payload["reasoning_mode"] == "thinking"
def test_prepare_payload_falls_back_to_guardrail_defaults(grayswan_guardrail: GraySwanGuardrail) -> None:
def test_prepare_payload_falls_back_to_guardrail_defaults(
grayswan_guardrail: GraySwanGuardrail,
) -> None:
messages = [{"role": "user", "content": "hello"}]
payload = grayswan_guardrail._prepare_payload(messages, {})
@@ -48,8 +54,12 @@ def test_prepare_payload_falls_back_to_guardrail_defaults(grayswan_guardrail: Gr
assert payload["reasoning_mode"] == "hybrid"
def test_process_response_does_not_block_under_threshold(grayswan_guardrail: GraySwanGuardrail) -> None:
grayswan_guardrail._process_grayswan_response({"violation": 0.3, "violated_rules": []})
def test_process_response_does_not_block_under_threshold(
grayswan_guardrail: GraySwanGuardrail,
) -> None:
grayswan_guardrail._process_grayswan_response(
{"violation": 0.3, "violated_rules": []}
)
def test_process_response_blocks_when_threshold_exceeded() -> None:
@@ -85,18 +95,22 @@ class _DummyClient:
self.calls: list[dict] = []
async def post(self, *, url: str, headers: dict, json: dict, timeout: float):
self.calls.append({"url": url, "headers": headers, "json": json, "timeout": timeout})
self.calls.append(
{"url": url, "headers": headers, "json": json, "timeout": timeout}
)
return _DummyResponse(self.payload)
@pytest.mark.asyncio
async def test_run_guardrail_posts_payload(monkeypatch, grayswan_guardrail: GraySwanGuardrail) -> None:
async def test_run_guardrail_posts_payload(
monkeypatch, grayswan_guardrail: GraySwanGuardrail
) -> None:
dummy_client = _DummyClient({"violation": 0.1})
grayswan_guardrail.async_handler = dummy_client
captured = {}
def fake_process(response_json: dict) -> None:
def fake_process(response_json: dict, data: Optional[dict] = None) -> None:
captured["response"] = response_json
monkeypatch.setattr(grayswan_guardrail, "_process_grayswan_response", fake_process)
@@ -110,7 +124,9 @@ async def test_run_guardrail_posts_payload(monkeypatch, grayswan_guardrail: Gray
@pytest.mark.asyncio
async def test_run_guardrail_raises_api_error(grayswan_guardrail: GraySwanGuardrail) -> None:
async def test_run_guardrail_raises_api_error(
grayswan_guardrail: GraySwanGuardrail,
) -> None:
class _FailingClient:
async def post(self, **_kwargs):
raise RuntimeError("boom")
@@ -121,3 +137,61 @@ async def test_run_guardrail_raises_api_error(grayswan_guardrail: GraySwanGuardr
with pytest.raises(GraySwanGuardrailAPIError):
await grayswan_guardrail.run_grayswan_guardrail(payload)
def test_process_response_passthrough_stores_detection_info() -> None:
"""Test that passthrough mode stores detection info in metadata without blocking."""
guardrail = GraySwanGuardrail(
guardrail_name="grayswan-passthrough",
api_key="test-key",
on_flagged_action="passthrough",
violation_threshold=0.2,
event_hook=GuardrailEventHooks.pre_call,
)
data = {"messages": [{"role": "user", "content": "test"}]}
response_json = {
"violation": 0.8,
"violated_rules": [1, 2],
"mutation": True,
"ipi": False,
}
# Should not raise an exception
guardrail._process_grayswan_response(response_json, data)
# Verify detection info was stored in metadata
assert "metadata" in data
assert "guardrail_detections" in data["metadata"]
assert len(data["metadata"]["guardrail_detections"]) == 1
detection = data["metadata"]["guardrail_detections"][0]
assert detection["guardrail"] == "grayswan"
assert detection["flagged"] is True
assert detection["violation_score"] == 0.8
assert detection["violated_rules"] == [1, 2]
assert detection["mutation"] is True
assert detection["ipi"] is False
def test_process_response_passthrough_does_not_store_if_under_threshold() -> None:
"""Test that passthrough mode doesn't store anything if violation is under threshold."""
guardrail = GraySwanGuardrail(
guardrail_name="grayswan-passthrough",
api_key="test-key",
on_flagged_action="passthrough",
violation_threshold=0.5,
event_hook=GuardrailEventHooks.pre_call,
)
data = {"messages": [{"role": "user", "content": "test"}]}
response_json = {
"violation": 0.3,
"violated_rules": [],
}
# Should not raise an exception
guardrail._process_grayswan_response(response_json, data)
# Should not have any detection info since it didn't exceed threshold
assert "guardrail_detections" not in data.get("metadata", {})