mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 03:31:23 +00:00
Fix test_standard_logging_payload_includes_guardrail_information
This commit is contained in:
@@ -12,6 +12,8 @@ from litellm.proxy.guardrails.guardrail_hooks.presidio import _OPTIONAL_Presidio
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.types.utils import StandardLoggingPayload, StandardLoggingGuardrailInformation
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.caching.caching import DualCache
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@@ -64,9 +66,13 @@ async def test_standard_logging_payload_includes_guardrail_information():
|
||||
|
||||
# Create mock response objects
|
||||
mock_analyze_resp = MagicMock()
|
||||
mock_analyze_resp.status = 200
|
||||
mock_analyze_resp.content_type = "application/json"
|
||||
mock_analyze_resp.json = AsyncMock(return_value=mock_analyze_response)
|
||||
|
||||
|
||||
mock_anonymize_resp = MagicMock()
|
||||
mock_anonymize_resp.status = 200
|
||||
mock_anonymize_resp.content_type = "application/json"
|
||||
mock_anonymize_resp.json = AsyncMock(return_value=mock_anonymize_response)
|
||||
|
||||
# Mock the aiohttp ClientSession with global call tracking
|
||||
@@ -85,7 +91,7 @@ async def test_standard_logging_payload_includes_guardrail_information():
|
||||
async def close(self):
|
||||
self.closed = True
|
||||
|
||||
def post(self, url, json=None):
|
||||
def post(self, url, json=None, **kwargs):
|
||||
class MockResponse:
|
||||
def __init__(self, response_obj):
|
||||
self.response_obj = response_obj
|
||||
@@ -116,8 +122,8 @@ async def test_standard_logging_payload_includes_guardrail_information():
|
||||
|
||||
with patch("aiohttp.ClientSession", MockClientSession):
|
||||
await presidio_guard.async_pre_call_hook(
|
||||
user_api_key_dict={},
|
||||
cache=None,
|
||||
user_api_key_dict=UserAPIKeyAuth(),
|
||||
cache=DualCache(),
|
||||
data=request_data,
|
||||
call_type="acompletion"
|
||||
)
|
||||
@@ -136,11 +142,11 @@ async def test_standard_logging_payload_includes_guardrail_information():
|
||||
assert len(test_custom_logger.standard_logging_payload["guardrail_information"]) > 0
|
||||
|
||||
guardrail_info = test_custom_logger.standard_logging_payload["guardrail_information"][0]
|
||||
assert guardrail_info["guardrail_name"] == "presidio_guard"
|
||||
assert guardrail_info["guardrail_mode"] == GuardrailEventHooks.pre_call
|
||||
assert guardrail_info.get("guardrail_name") == "presidio_guard"
|
||||
assert guardrail_info.get("guardrail_mode") == GuardrailEventHooks.pre_call
|
||||
|
||||
# assert that the guardrail_response is a response from presidio analyze
|
||||
presidio_response = guardrail_info["guardrail_response"]
|
||||
presidio_response = guardrail_info.get("guardrail_response")
|
||||
assert isinstance(presidio_response, list)
|
||||
for response_item in presidio_response:
|
||||
assert "analysis_explanation" in response_item
|
||||
@@ -150,12 +156,14 @@ async def test_standard_logging_payload_includes_guardrail_information():
|
||||
assert "entity_type" in response_item
|
||||
|
||||
# assert that the duration is not None
|
||||
assert guardrail_info["duration"] is not None
|
||||
assert guardrail_info["duration"] > 0
|
||||
duration = guardrail_info.get("duration")
|
||||
assert duration is not None
|
||||
assert duration > 0
|
||||
|
||||
# assert that we get the count of masked entities
|
||||
assert guardrail_info["masked_entity_count"] is not None
|
||||
assert guardrail_info["masked_entity_count"]["PHONE_NUMBER"] == 1
|
||||
masked_entity_count = guardrail_info.get("masked_entity_count")
|
||||
assert masked_entity_count is not None
|
||||
assert masked_entity_count["PHONE_NUMBER"] == 1
|
||||
|
||||
|
||||
|
||||
@@ -201,8 +209,8 @@ async def test_langfuse_trace_includes_guardrail_information():
|
||||
"metadata": {},
|
||||
}
|
||||
await presidio_guard.async_pre_call_hook(
|
||||
user_api_key_dict={},
|
||||
cache=None,
|
||||
user_api_key_dict=UserAPIKeyAuth(),
|
||||
cache=DualCache(),
|
||||
data=request_data,
|
||||
call_type="acompletion"
|
||||
)
|
||||
@@ -310,7 +318,7 @@ async def test_bedrock_guardrail_status_blocked():
|
||||
try:
|
||||
await bedrock_guard.async_pre_call_hook(
|
||||
user_api_key_dict=UserAPIKeyAuth(),
|
||||
cache=None,
|
||||
cache=DualCache(),
|
||||
data=request_data,
|
||||
call_type="completion"
|
||||
)
|
||||
@@ -331,8 +339,8 @@ async def test_bedrock_guardrail_status_blocked():
|
||||
|
||||
# Verify guardrail information fields (guardrail_information is now a list)
|
||||
guardrail_info = test_custom_logger.standard_logging_payload["guardrail_information"][0]
|
||||
assert guardrail_info["guardrail_status"] == "guardrail_intervened"
|
||||
assert guardrail_info["guardrail_provider"] == "bedrock"
|
||||
assert guardrail_info.get("guardrail_status") == "guardrail_intervened"
|
||||
assert guardrail_info.get("guardrail_provider") == "bedrock"
|
||||
|
||||
# Verify the new typed status fields
|
||||
# guardrail_status should be "guardrail_intervened" when content is blocked
|
||||
@@ -395,7 +403,7 @@ async def test_bedrock_guardrail_status_success():
|
||||
with patch.object(bedrock_guard, 'should_run_guardrail', return_value=True):
|
||||
await bedrock_guard.async_pre_call_hook(
|
||||
user_api_key_dict=UserAPIKeyAuth(),
|
||||
cache=None,
|
||||
cache=DualCache(),
|
||||
data=request_data,
|
||||
call_type="completion"
|
||||
)
|
||||
@@ -411,8 +419,8 @@ async def test_bedrock_guardrail_status_success():
|
||||
assert len(test_custom_logger.standard_logging_payload["guardrail_information"]) > 0
|
||||
|
||||
guardrail_info = test_custom_logger.standard_logging_payload["guardrail_information"][0]
|
||||
assert guardrail_info["guardrail_status"] == "success"
|
||||
assert guardrail_info["guardrail_provider"] == "bedrock"
|
||||
assert guardrail_info.get("guardrail_status") == "success"
|
||||
assert guardrail_info.get("guardrail_provider") == "bedrock"
|
||||
|
||||
# Check status fields
|
||||
status_fields = test_custom_logger.standard_logging_payload.get("status_fields", {})
|
||||
@@ -469,7 +477,7 @@ async def test_bedrock_guardrail_status_failure():
|
||||
try:
|
||||
await bedrock_guard.async_pre_call_hook(
|
||||
user_api_key_dict=UserAPIKeyAuth(),
|
||||
cache=None,
|
||||
cache=DualCache(),
|
||||
data=request_data,
|
||||
call_type="completion"
|
||||
)
|
||||
@@ -488,8 +496,8 @@ async def test_bedrock_guardrail_status_failure():
|
||||
assert len(test_custom_logger.standard_logging_payload["guardrail_information"]) > 0
|
||||
|
||||
guardrail_info = test_custom_logger.standard_logging_payload["guardrail_information"][0]
|
||||
assert guardrail_info["guardrail_status"] == "guardrail_failed_to_respond"
|
||||
assert guardrail_info["guardrail_provider"] == "bedrock"
|
||||
assert guardrail_info.get("guardrail_status") == "guardrail_failed_to_respond"
|
||||
assert guardrail_info.get("guardrail_provider") == "bedrock"
|
||||
|
||||
# Check status fields
|
||||
status_fields = test_custom_logger.standard_logging_payload.get("status_fields", {})
|
||||
@@ -554,7 +562,7 @@ async def test_noma_guardrail_status_blocked():
|
||||
try:
|
||||
await noma_guard.async_pre_call_hook(
|
||||
user_api_key_dict=UserAPIKeyAuth(),
|
||||
cache=None,
|
||||
cache=DualCache(),
|
||||
data=request_data,
|
||||
call_type="completion"
|
||||
)
|
||||
@@ -572,8 +580,8 @@ async def test_noma_guardrail_status_blocked():
|
||||
assert len(test_custom_logger.standard_logging_payload["guardrail_information"]) > 0
|
||||
|
||||
guardrail_info = test_custom_logger.standard_logging_payload["guardrail_information"][0]
|
||||
assert guardrail_info["guardrail_status"] == "guardrail_intervened"
|
||||
assert guardrail_info["guardrail_provider"] == "noma"
|
||||
assert guardrail_info.get("guardrail_status") == "guardrail_intervened"
|
||||
assert guardrail_info.get("guardrail_provider") == "noma"
|
||||
|
||||
# Check status fields
|
||||
status_fields = test_custom_logger.standard_logging_payload.get("status_fields", {})
|
||||
@@ -632,7 +640,7 @@ async def test_noma_guardrail_status_success():
|
||||
with patch.object(noma_guard, 'should_run_guardrail', return_value=True):
|
||||
await noma_guard.async_pre_call_hook(
|
||||
user_api_key_dict=UserAPIKeyAuth(),
|
||||
cache=None,
|
||||
cache=DualCache(),
|
||||
data=request_data,
|
||||
call_type="completion"
|
||||
)
|
||||
@@ -648,8 +656,8 @@ async def test_noma_guardrail_status_success():
|
||||
assert len(test_custom_logger.standard_logging_payload["guardrail_information"]) > 0
|
||||
|
||||
guardrail_info = test_custom_logger.standard_logging_payload["guardrail_information"][0]
|
||||
assert guardrail_info["guardrail_status"] == "success"
|
||||
assert guardrail_info["guardrail_provider"] == "noma"
|
||||
assert guardrail_info.get("guardrail_status") == "success"
|
||||
assert guardrail_info.get("guardrail_provider") == "noma"
|
||||
|
||||
# Check status fields
|
||||
status_fields = test_custom_logger.standard_logging_payload.get("status_fields", {})
|
||||
@@ -679,8 +687,8 @@ def test_guardrail_status_fields_computation():
|
||||
guardrail_information=intervened_info,
|
||||
error_str=None
|
||||
)
|
||||
assert status_fields_intervened["llm_api_status"] == "success"
|
||||
assert status_fields_intervened["guardrail_status"] == "guardrail_intervened"
|
||||
assert status_fields_intervened.get("llm_api_status") == "success"
|
||||
assert status_fields_intervened.get("guardrail_status") == "guardrail_intervened"
|
||||
|
||||
# Test legacy blocked status (for backward compatibility)
|
||||
blocked_info = [{"guardrail_status": "blocked"}]
|
||||
@@ -689,8 +697,8 @@ def test_guardrail_status_fields_computation():
|
||||
guardrail_information=blocked_info,
|
||||
error_str=None
|
||||
)
|
||||
assert status_fields_blocked["llm_api_status"] == "success"
|
||||
assert status_fields_blocked["guardrail_status"] == "guardrail_intervened"
|
||||
assert status_fields_blocked.get("llm_api_status") == "success"
|
||||
assert status_fields_blocked.get("guardrail_status") == "guardrail_intervened"
|
||||
|
||||
# Test success status
|
||||
success_info = [{"guardrail_status": "success"}]
|
||||
@@ -699,8 +707,8 @@ def test_guardrail_status_fields_computation():
|
||||
guardrail_information=success_info,
|
||||
error_str=None
|
||||
)
|
||||
assert status_fields_success["llm_api_status"] == "success"
|
||||
assert status_fields_success["guardrail_status"] == "success"
|
||||
assert status_fields_success.get("llm_api_status") == "success"
|
||||
assert status_fields_success.get("guardrail_status") == "success"
|
||||
|
||||
# Test guardrail_failed_to_respond status
|
||||
failed_info = [{"guardrail_status": "guardrail_failed_to_respond"}]
|
||||
@@ -709,8 +717,8 @@ def test_guardrail_status_fields_computation():
|
||||
guardrail_information=failed_info,
|
||||
error_str=None
|
||||
)
|
||||
assert status_fields_failed["llm_api_status"] == "failure"
|
||||
assert status_fields_failed["guardrail_status"] == "guardrail_failed_to_respond"
|
||||
assert status_fields_failed.get("llm_api_status") == "failure"
|
||||
assert status_fields_failed.get("guardrail_status") == "guardrail_failed_to_respond"
|
||||
|
||||
# Test legacy failure status (for backward compatibility)
|
||||
failure_info = [{"guardrail_status": "failure"}]
|
||||
@@ -719,8 +727,8 @@ def test_guardrail_status_fields_computation():
|
||||
guardrail_information=failure_info,
|
||||
error_str=None
|
||||
)
|
||||
assert status_fields_failure["llm_api_status"] == "failure"
|
||||
assert status_fields_failure["guardrail_status"] == "guardrail_failed_to_respond"
|
||||
assert status_fields_failure.get("llm_api_status") == "failure"
|
||||
assert status_fields_failure.get("guardrail_status") == "guardrail_failed_to_respond"
|
||||
|
||||
# Test no guardrail run
|
||||
no_guardrail = None
|
||||
@@ -729,5 +737,5 @@ def test_guardrail_status_fields_computation():
|
||||
guardrail_information=no_guardrail,
|
||||
error_str=None
|
||||
)
|
||||
assert status_fields_no_guardrail["llm_api_status"] == "success"
|
||||
assert status_fields_no_guardrail["guardrail_status"] == "not_run"
|
||||
assert status_fields_no_guardrail.get("llm_api_status") == "success"
|
||||
assert status_fields_no_guardrail.get("guardrail_status") == "not_run"
|
||||
Reference in New Issue
Block a user