From a54cf53ffb554a0afa5fea2e61e0895c7cecbdfa Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Thu, 26 Feb 2026 12:13:26 +0530 Subject: [PATCH] Fix test_standard_logging_payload_includes_guardrail_information --- .../test_tracing_guardrails.py | 88 ++++++++++--------- 1 file changed, 48 insertions(+), 40 deletions(-) diff --git a/tests/guardrails_tests/test_tracing_guardrails.py b/tests/guardrails_tests/test_tracing_guardrails.py index 8e7ce27bc2..f119d6df3d 100644 --- a/tests/guardrails_tests/test_tracing_guardrails.py +++ b/tests/guardrails_tests/test_tracing_guardrails.py @@ -12,6 +12,8 @@ from litellm.proxy.guardrails.guardrail_hooks.presidio import _OPTIONAL_Presidio from litellm.integrations.custom_logger import CustomLogger from litellm.types.utils import StandardLoggingPayload, StandardLoggingGuardrailInformation from litellm.types.guardrails import GuardrailEventHooks +from litellm.proxy._types import UserAPIKeyAuth +from litellm.caching.caching import DualCache from typing import Optional @@ -64,9 +66,13 @@ async def test_standard_logging_payload_includes_guardrail_information(): # Create mock response objects mock_analyze_resp = MagicMock() + mock_analyze_resp.status = 200 + mock_analyze_resp.content_type = "application/json" mock_analyze_resp.json = AsyncMock(return_value=mock_analyze_response) - + mock_anonymize_resp = MagicMock() + mock_anonymize_resp.status = 200 + mock_anonymize_resp.content_type = "application/json" mock_anonymize_resp.json = AsyncMock(return_value=mock_anonymize_response) # Mock the aiohttp ClientSession with global call tracking @@ -85,7 +91,7 @@ async def test_standard_logging_payload_includes_guardrail_information(): async def close(self): self.closed = True - def post(self, url, json=None): + def post(self, url, json=None, **kwargs): class MockResponse: def __init__(self, response_obj): self.response_obj = response_obj @@ -116,8 +122,8 @@ async def test_standard_logging_payload_includes_guardrail_information(): with patch("aiohttp.ClientSession", MockClientSession): await presidio_guard.async_pre_call_hook( - user_api_key_dict={}, - cache=None, + user_api_key_dict=UserAPIKeyAuth(), + cache=DualCache(), data=request_data, call_type="acompletion" ) @@ -136,11 +142,11 @@ async def test_standard_logging_payload_includes_guardrail_information(): assert len(test_custom_logger.standard_logging_payload["guardrail_information"]) > 0 guardrail_info = test_custom_logger.standard_logging_payload["guardrail_information"][0] - assert guardrail_info["guardrail_name"] == "presidio_guard" - assert guardrail_info["guardrail_mode"] == GuardrailEventHooks.pre_call + assert guardrail_info.get("guardrail_name") == "presidio_guard" + assert guardrail_info.get("guardrail_mode") == GuardrailEventHooks.pre_call # assert that the guardrail_response is a response from presidio analyze - presidio_response = guardrail_info["guardrail_response"] + presidio_response = guardrail_info.get("guardrail_response") assert isinstance(presidio_response, list) for response_item in presidio_response: assert "analysis_explanation" in response_item @@ -150,12 +156,14 @@ async def test_standard_logging_payload_includes_guardrail_information(): assert "entity_type" in response_item # assert that the duration is not None - assert guardrail_info["duration"] is not None - assert guardrail_info["duration"] > 0 + duration = guardrail_info.get("duration") + assert duration is not None + assert duration > 0 # assert that we get the count of masked entities - assert guardrail_info["masked_entity_count"] is not None - assert guardrail_info["masked_entity_count"]["PHONE_NUMBER"] == 1 + masked_entity_count = guardrail_info.get("masked_entity_count") + assert masked_entity_count is not None + assert masked_entity_count["PHONE_NUMBER"] == 1 @@ -201,8 +209,8 @@ async def test_langfuse_trace_includes_guardrail_information(): "metadata": {}, } await presidio_guard.async_pre_call_hook( - user_api_key_dict={}, - cache=None, + user_api_key_dict=UserAPIKeyAuth(), + cache=DualCache(), data=request_data, call_type="acompletion" ) @@ -310,7 +318,7 @@ async def test_bedrock_guardrail_status_blocked(): try: await bedrock_guard.async_pre_call_hook( user_api_key_dict=UserAPIKeyAuth(), - cache=None, + cache=DualCache(), data=request_data, call_type="completion" ) @@ -331,8 +339,8 @@ async def test_bedrock_guardrail_status_blocked(): # Verify guardrail information fields (guardrail_information is now a list) guardrail_info = test_custom_logger.standard_logging_payload["guardrail_information"][0] - assert guardrail_info["guardrail_status"] == "guardrail_intervened" - assert guardrail_info["guardrail_provider"] == "bedrock" + assert guardrail_info.get("guardrail_status") == "guardrail_intervened" + assert guardrail_info.get("guardrail_provider") == "bedrock" # Verify the new typed status fields # guardrail_status should be "guardrail_intervened" when content is blocked @@ -395,7 +403,7 @@ async def test_bedrock_guardrail_status_success(): with patch.object(bedrock_guard, 'should_run_guardrail', return_value=True): await bedrock_guard.async_pre_call_hook( user_api_key_dict=UserAPIKeyAuth(), - cache=None, + cache=DualCache(), data=request_data, call_type="completion" ) @@ -411,8 +419,8 @@ async def test_bedrock_guardrail_status_success(): assert len(test_custom_logger.standard_logging_payload["guardrail_information"]) > 0 guardrail_info = test_custom_logger.standard_logging_payload["guardrail_information"][0] - assert guardrail_info["guardrail_status"] == "success" - assert guardrail_info["guardrail_provider"] == "bedrock" + assert guardrail_info.get("guardrail_status") == "success" + assert guardrail_info.get("guardrail_provider") == "bedrock" # Check status fields status_fields = test_custom_logger.standard_logging_payload.get("status_fields", {}) @@ -469,7 +477,7 @@ async def test_bedrock_guardrail_status_failure(): try: await bedrock_guard.async_pre_call_hook( user_api_key_dict=UserAPIKeyAuth(), - cache=None, + cache=DualCache(), data=request_data, call_type="completion" ) @@ -488,8 +496,8 @@ async def test_bedrock_guardrail_status_failure(): assert len(test_custom_logger.standard_logging_payload["guardrail_information"]) > 0 guardrail_info = test_custom_logger.standard_logging_payload["guardrail_information"][0] - assert guardrail_info["guardrail_status"] == "guardrail_failed_to_respond" - assert guardrail_info["guardrail_provider"] == "bedrock" + assert guardrail_info.get("guardrail_status") == "guardrail_failed_to_respond" + assert guardrail_info.get("guardrail_provider") == "bedrock" # Check status fields status_fields = test_custom_logger.standard_logging_payload.get("status_fields", {}) @@ -554,7 +562,7 @@ async def test_noma_guardrail_status_blocked(): try: await noma_guard.async_pre_call_hook( user_api_key_dict=UserAPIKeyAuth(), - cache=None, + cache=DualCache(), data=request_data, call_type="completion" ) @@ -572,8 +580,8 @@ async def test_noma_guardrail_status_blocked(): assert len(test_custom_logger.standard_logging_payload["guardrail_information"]) > 0 guardrail_info = test_custom_logger.standard_logging_payload["guardrail_information"][0] - assert guardrail_info["guardrail_status"] == "guardrail_intervened" - assert guardrail_info["guardrail_provider"] == "noma" + assert guardrail_info.get("guardrail_status") == "guardrail_intervened" + assert guardrail_info.get("guardrail_provider") == "noma" # Check status fields status_fields = test_custom_logger.standard_logging_payload.get("status_fields", {}) @@ -632,7 +640,7 @@ async def test_noma_guardrail_status_success(): with patch.object(noma_guard, 'should_run_guardrail', return_value=True): await noma_guard.async_pre_call_hook( user_api_key_dict=UserAPIKeyAuth(), - cache=None, + cache=DualCache(), data=request_data, call_type="completion" ) @@ -648,8 +656,8 @@ async def test_noma_guardrail_status_success(): assert len(test_custom_logger.standard_logging_payload["guardrail_information"]) > 0 guardrail_info = test_custom_logger.standard_logging_payload["guardrail_information"][0] - assert guardrail_info["guardrail_status"] == "success" - assert guardrail_info["guardrail_provider"] == "noma" + assert guardrail_info.get("guardrail_status") == "success" + assert guardrail_info.get("guardrail_provider") == "noma" # Check status fields status_fields = test_custom_logger.standard_logging_payload.get("status_fields", {}) @@ -679,8 +687,8 @@ def test_guardrail_status_fields_computation(): guardrail_information=intervened_info, error_str=None ) - assert status_fields_intervened["llm_api_status"] == "success" - assert status_fields_intervened["guardrail_status"] == "guardrail_intervened" + assert status_fields_intervened.get("llm_api_status") == "success" + assert status_fields_intervened.get("guardrail_status") == "guardrail_intervened" # Test legacy blocked status (for backward compatibility) blocked_info = [{"guardrail_status": "blocked"}] @@ -689,8 +697,8 @@ def test_guardrail_status_fields_computation(): guardrail_information=blocked_info, error_str=None ) - assert status_fields_blocked["llm_api_status"] == "success" - assert status_fields_blocked["guardrail_status"] == "guardrail_intervened" + assert status_fields_blocked.get("llm_api_status") == "success" + assert status_fields_blocked.get("guardrail_status") == "guardrail_intervened" # Test success status success_info = [{"guardrail_status": "success"}] @@ -699,8 +707,8 @@ def test_guardrail_status_fields_computation(): guardrail_information=success_info, error_str=None ) - assert status_fields_success["llm_api_status"] == "success" - assert status_fields_success["guardrail_status"] == "success" + assert status_fields_success.get("llm_api_status") == "success" + assert status_fields_success.get("guardrail_status") == "success" # Test guardrail_failed_to_respond status failed_info = [{"guardrail_status": "guardrail_failed_to_respond"}] @@ -709,8 +717,8 @@ def test_guardrail_status_fields_computation(): guardrail_information=failed_info, error_str=None ) - assert status_fields_failed["llm_api_status"] == "failure" - assert status_fields_failed["guardrail_status"] == "guardrail_failed_to_respond" + assert status_fields_failed.get("llm_api_status") == "failure" + assert status_fields_failed.get("guardrail_status") == "guardrail_failed_to_respond" # Test legacy failure status (for backward compatibility) failure_info = [{"guardrail_status": "failure"}] @@ -719,8 +727,8 @@ def test_guardrail_status_fields_computation(): guardrail_information=failure_info, error_str=None ) - assert status_fields_failure["llm_api_status"] == "failure" - assert status_fields_failure["guardrail_status"] == "guardrail_failed_to_respond" + assert status_fields_failure.get("llm_api_status") == "failure" + assert status_fields_failure.get("guardrail_status") == "guardrail_failed_to_respond" # Test no guardrail run no_guardrail = None @@ -729,5 +737,5 @@ def test_guardrail_status_fields_computation(): guardrail_information=no_guardrail, error_str=None ) - assert status_fields_no_guardrail["llm_api_status"] == "success" - assert status_fields_no_guardrail["guardrail_status"] == "not_run" \ No newline at end of file + assert status_fields_no_guardrail.get("llm_api_status") == "success" + assert status_fields_no_guardrail.get("guardrail_status") == "not_run" \ No newline at end of file