diff --git a/docs/my-website/docs/proxy/guardrails/grayswan.md b/docs/my-website/docs/proxy/guardrails/grayswan.md index b510c870a1..7cc75b9f3b 100644 --- a/docs/my-website/docs/proxy/guardrails/grayswan.md +++ b/docs/my-website/docs/proxy/guardrails/grayswan.md @@ -142,8 +142,8 @@ Provides the strongest enforcement by inspecting both prompts and responses. |---------------------------------------|-----------------|-------------| | `api_key` | string | Gray Swan Cygnal API key. Reads from `GRAYSWAN_API_KEY` if omitted. | | `mode` | string or list | Guardrail stages (`pre_call`, `during_call`, `post_call`). | -| `optional_params.on_flagged_action` | string | `monitor` (log only) or `block` (raise `HTTPException`). | +| `optional_params.on_flagged_action` | string | `monitor` (log only), `block` (raise `HTTPException`), or `passthrough` (include detection info in response without blocking). | | `.optional_params.violation_threshold`| number (0-1) | Scores at or above this value are considered violations. | -| `optional_params.reasoning_mode` | string | `off`, `hybrid`, or `thinking`. Enables Cygnal’s reasoning capabilities. | +| `optional_params.reasoning_mode` | string | `off`, `hybrid`, or `thinking`. Enables Cygnal's reasoning capabilities. | | `optional_params.categories` | object | Map of custom category names to descriptions. | | `optional_params.policy_id` | string | Gray Swan policy identifier. | diff --git a/litellm/proxy/guardrails/guardrail_hooks/grayswan/grayswan.py b/litellm/proxy/guardrails/guardrail_hooks/grayswan/grayswan.py index fb93dec47b..7f9f900a5a 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/grayswan/grayswan.py +++ b/litellm/proxy/guardrails/guardrail_hooks/grayswan/grayswan.py @@ -38,7 +38,7 @@ class GraySwanGuardrail(CustomGuardrail): see: https://docs.grayswan.ai/cygnal/monitor-requests """ - SUPPORTED_ON_FLAGGED_ACTIONS = {"block", "monitor"} + SUPPORTED_ON_FLAGGED_ACTIONS = {"block", "monitor", "passthrough"} DEFAULT_ON_FLAGGED_ACTION = "monitor" BASE_API_URL = "https://api.grayswan.ai" MONITOR_PATH = "/cygnal/monitor" @@ -147,7 +147,7 @@ class GraySwanGuardrail(CustomGuardrail): ) return data - await self.run_grayswan_guardrail(payload) + await self.run_grayswan_guardrail(payload, data) add_guardrail_to_applied_guardrails_header( request_data=data, guardrail_name=self.guardrail_name ) @@ -193,7 +193,7 @@ class GraySwanGuardrail(CustomGuardrail): ) return data - await self.run_grayswan_guardrail(payload) + await self.run_grayswan_guardrail(payload, data) add_guardrail_to_applied_guardrails_header( request_data=data, guardrail_name=self.guardrail_name ) @@ -240,7 +240,24 @@ class GraySwanGuardrail(CustomGuardrail): ) return response - await self.run_grayswan_guardrail(payload) + await self.run_grayswan_guardrail(payload, data) + + # If passthrough mode and detection info exists, add it to response + if self.on_flagged_action == "passthrough" and "metadata" in data: + guardrail_detections = data.get("metadata", {}).get( + "guardrail_detections", [] + ) + if guardrail_detections: + # Add guardrail detections to response hidden params for client visibility + hidden_params = getattr(response, "_hidden_params", None) + if hidden_params is not None: + if not hidden_params: + hidden_params = {} + setattr(response, "_hidden_params", hidden_params) + + hidden_params["guardrail_detections"] = guardrail_detections + setattr(response, "_hidden_params", hidden_params) + add_guardrail_to_applied_guardrails_header( request_data=data, guardrail_name=self.guardrail_name ) @@ -250,7 +267,7 @@ class GraySwanGuardrail(CustomGuardrail): # Core GraySwan interaction # ------------------------------------------------------------------ - async def run_grayswan_guardrail(self, payload: dict): + async def run_grayswan_guardrail(self, payload: dict, data: Optional[dict] = None): headers = self._prepare_headers() try: @@ -273,7 +290,7 @@ class GraySwanGuardrail(CustomGuardrail): ) raise GraySwanGuardrailAPIError(str(exc)) from exc - self._process_grayswan_response(result) + self._process_grayswan_response(result, data) # ------------------------------------------------------------------ # Helpers @@ -306,7 +323,9 @@ class GraySwanGuardrail(CustomGuardrail): return payload - def _process_grayswan_response(self, response_json: Dict[str, Any]) -> None: + def _process_grayswan_response( + self, response_json: Dict[str, Any], data: Optional[dict] = None + ) -> None: violation_score = float(response_json.get("violation", 0.0) or 0.0) violated_rules = response_json.get("violated_rules", []) mutation_detected = response_json.get("mutation") @@ -338,6 +357,30 @@ class GraySwanGuardrail(CustomGuardrail): "ipi": ipi_detected, }, ) + elif self.on_flagged_action == "monitor": + verbose_proxy_logger.info( + "Gray Swan Guardrail: Monitoring mode - allowing flagged content to proceed" + ) + elif self.on_flagged_action == "passthrough": + verbose_proxy_logger.info( + "Gray Swan Guardrail: Passthrough mode - storing detection info in metadata" + ) + if data is not None: + # Store guardrail detection info in metadata to be included in response + if "metadata" not in data: + data["metadata"] = {} + if "guardrail_detections" not in data["metadata"]: + data["metadata"]["guardrail_detections"] = [] + + detection_info = { + "guardrail": "grayswan", + "flagged": True, + "violation_score": violation_score, + "violated_rules": violated_rules, + "mutation": mutation_detected, + "ipi": ipi_detected, + } + data["metadata"]["guardrail_detections"].append(detection_info) def _resolve_threshold(self, threshold: Optional[float]) -> float: if threshold is not None: diff --git a/litellm/types/proxy/guardrails/guardrail_hooks/grayswan.py b/litellm/types/proxy/guardrails/guardrail_hooks/grayswan.py index d50ae95ee3..6a6e5a1e48 100644 --- a/litellm/types/proxy/guardrails/guardrail_hooks/grayswan.py +++ b/litellm/types/proxy/guardrails/guardrail_hooks/grayswan.py @@ -12,7 +12,7 @@ class GraySwanGuardrailConfigModelOptionalParams(BaseModel): on_flagged_action: Optional[str] = Field( default="monitor", - description="Action when a violation is detected: 'block' rejects the call, 'monitor' logs only.", + description="Action when a violation is detected: 'block' rejects the call, 'monitor' logs only, 'passthrough' includes detection info in response without blocking.", ) violation_threshold: Optional[float] = Field( default=0.5, diff --git a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_grayswan.py b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_grayswan.py index d6bd0a251b..9ee31cf6cb 100644 --- a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_grayswan.py +++ b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_grayswan.py @@ -1,3 +1,5 @@ +from typing import Optional + import pytest from fastapi import HTTPException @@ -22,7 +24,9 @@ def grayswan_guardrail() -> GraySwanGuardrail: ) -def test_prepare_payload_uses_dynamic_overrides(grayswan_guardrail: GraySwanGuardrail) -> None: +def test_prepare_payload_uses_dynamic_overrides( + grayswan_guardrail: GraySwanGuardrail, +) -> None: messages = [{"role": "user", "content": "hello"}] dynamic_body = { "categories": {"custom": "override"}, @@ -38,7 +42,9 @@ def test_prepare_payload_uses_dynamic_overrides(grayswan_guardrail: GraySwanGuar assert payload["reasoning_mode"] == "thinking" -def test_prepare_payload_falls_back_to_guardrail_defaults(grayswan_guardrail: GraySwanGuardrail) -> None: +def test_prepare_payload_falls_back_to_guardrail_defaults( + grayswan_guardrail: GraySwanGuardrail, +) -> None: messages = [{"role": "user", "content": "hello"}] payload = grayswan_guardrail._prepare_payload(messages, {}) @@ -48,8 +54,12 @@ def test_prepare_payload_falls_back_to_guardrail_defaults(grayswan_guardrail: Gr assert payload["reasoning_mode"] == "hybrid" -def test_process_response_does_not_block_under_threshold(grayswan_guardrail: GraySwanGuardrail) -> None: - grayswan_guardrail._process_grayswan_response({"violation": 0.3, "violated_rules": []}) +def test_process_response_does_not_block_under_threshold( + grayswan_guardrail: GraySwanGuardrail, +) -> None: + grayswan_guardrail._process_grayswan_response( + {"violation": 0.3, "violated_rules": []} + ) def test_process_response_blocks_when_threshold_exceeded() -> None: @@ -85,18 +95,22 @@ class _DummyClient: self.calls: list[dict] = [] async def post(self, *, url: str, headers: dict, json: dict, timeout: float): - self.calls.append({"url": url, "headers": headers, "json": json, "timeout": timeout}) + self.calls.append( + {"url": url, "headers": headers, "json": json, "timeout": timeout} + ) return _DummyResponse(self.payload) @pytest.mark.asyncio -async def test_run_guardrail_posts_payload(monkeypatch, grayswan_guardrail: GraySwanGuardrail) -> None: +async def test_run_guardrail_posts_payload( + monkeypatch, grayswan_guardrail: GraySwanGuardrail +) -> None: dummy_client = _DummyClient({"violation": 0.1}) grayswan_guardrail.async_handler = dummy_client captured = {} - def fake_process(response_json: dict) -> None: + def fake_process(response_json: dict, data: Optional[dict] = None) -> None: captured["response"] = response_json monkeypatch.setattr(grayswan_guardrail, "_process_grayswan_response", fake_process) @@ -110,7 +124,9 @@ async def test_run_guardrail_posts_payload(monkeypatch, grayswan_guardrail: Gray @pytest.mark.asyncio -async def test_run_guardrail_raises_api_error(grayswan_guardrail: GraySwanGuardrail) -> None: +async def test_run_guardrail_raises_api_error( + grayswan_guardrail: GraySwanGuardrail, +) -> None: class _FailingClient: async def post(self, **_kwargs): raise RuntimeError("boom") @@ -121,3 +137,61 @@ async def test_run_guardrail_raises_api_error(grayswan_guardrail: GraySwanGuardr with pytest.raises(GraySwanGuardrailAPIError): await grayswan_guardrail.run_grayswan_guardrail(payload) + + +def test_process_response_passthrough_stores_detection_info() -> None: + """Test that passthrough mode stores detection info in metadata without blocking.""" + guardrail = GraySwanGuardrail( + guardrail_name="grayswan-passthrough", + api_key="test-key", + on_flagged_action="passthrough", + violation_threshold=0.2, + event_hook=GuardrailEventHooks.pre_call, + ) + + data = {"messages": [{"role": "user", "content": "test"}]} + response_json = { + "violation": 0.8, + "violated_rules": [1, 2], + "mutation": True, + "ipi": False, + } + + # Should not raise an exception + guardrail._process_grayswan_response(response_json, data) + + # Verify detection info was stored in metadata + assert "metadata" in data + assert "guardrail_detections" in data["metadata"] + assert len(data["metadata"]["guardrail_detections"]) == 1 + + detection = data["metadata"]["guardrail_detections"][0] + assert detection["guardrail"] == "grayswan" + assert detection["flagged"] is True + assert detection["violation_score"] == 0.8 + assert detection["violated_rules"] == [1, 2] + assert detection["mutation"] is True + assert detection["ipi"] is False + + +def test_process_response_passthrough_does_not_store_if_under_threshold() -> None: + """Test that passthrough mode doesn't store anything if violation is under threshold.""" + guardrail = GraySwanGuardrail( + guardrail_name="grayswan-passthrough", + api_key="test-key", + on_flagged_action="passthrough", + violation_threshold=0.5, + event_hook=GuardrailEventHooks.pre_call, + ) + + data = {"messages": [{"role": "user", "content": "test"}]} + response_json = { + "violation": 0.3, + "violated_rules": [], + } + + # Should not raise an exception + guardrail._process_grayswan_response(response_json, data) + + # Should not have any detection info since it didn't exceed threshold + assert "guardrail_detections" not in data.get("metadata", {})