Files
litellm/tests/logging_callback_tests/test_gcs_pub_sub.py
T
ishaan-berri 8a9faa81b2 feat(guardrails): LLM-as-a-Judge guardrail (#26360)
* feat(guardrails): add LLM_AS_A_JUDGE to SupportedGuardrailIntegrations

* feat(types): add EvalVerdict, StandardLoggingEvalInformation; wire eval_information into SpendLogsMetadata

* feat(guardrails): add self-contained llm_as_a_judge guardrail hook

* fix(a2a): filter agent-only litellm_params from acompletion kwargs; pass agent_id into body

* feat(ui): add LLMJudgeFields criteria builder component

* feat(ui): wire LLM-as-a-Judge into add guardrail form

* feat(ui): update EvalViewer — title 'LLM Judge Results', weighted score column, summary row

* fix(ui): wire EvalViewer into LogDetailContent to show LLM judge results on logs page

* fix(guardrails-ui): route llm_as_a_judge to criteria builder step; rename to LiteLLM LLM as a Judge; add litellm logo

* fix(guardrail-viewer): stack lifecycle + eval details vertically to avoid badge overflow in narrow drawer

* fix(guardrail-create): surface config validation errors on create instead of silently orphaning guardrail in DB

* fix(guardrail-registry): hardcode llm_as_a_judge in initializer registry so it loads regardless of package install path

* fix(llm-as-a-judge): fix P1 code quality issues - validate weights/on_failure, guard pre_call, handle multimodal, move imports to module level, fix spurious finally logging

* fix(guardrail_endpoints): use correct PK field in rollback delete and log rollback failure

* fix(llm_as_a_judge): support Pydantic object in _get_litellm_param fallback chain

* fix(LLMJudgeFields): replace @tremor/react Button with antd Button

* fix(llm_as_a_judge): remove dead registry dicts, fix KeyError in prompt builder, set correct status on judge failure

* test(llm_as_a_judge): add unit tests for guardrail hook

* fix(llm_as_a_judge): remove @log_guardrail_information decorator to fix duplicate guardrail_information entries

The decorator and the manual finally block both called add_standard_logging_guardrail_information_to_request_data, producing two entries per request. The decorator also misclassified HTTPException(422) blocks as guardrail_failed_to_respond (it checks for 400). The finally block correctly tracks status throughout, so removing the decorator is sufficient.

* fix(test_gcs_pub_sub): ignore metadata.eval_information in comparison

* fix(test_spend_management): ignore metadata.eval_information in payload comparison

* fix(types/guardrails): add input_type and messages to ApplyGuardrailRequest

* fix(guardrail_endpoints): pass input_type and messages through apply_guardrail endpoint

* fix(guardrail_endpoints): auto-detect post_call guardrails and use input_type=response

* fix(a2a_endpoints): merge agent litellm_params guardrails into data before post_call hooks

* fix(llm_as_a_judge): use float sum with tolerance for weight validation

* fix(guardrail_registry): split long import line for black formatting

* fix(llm_as_a_judge): guard guardrail_name Optional for mypy

* fix(llm_as_a_judge): set guardrail_status=guardrail_intervened when score fails, regardless of on_failure mode

* fix(a2a_endpoints): use try/finally so deferred spend log fires even when guardrail blocks with 422

* fix(litellm_logging): declare _defer_async_logging and _enqueue_deferred_logging on Logging class for mypy

* fix(logging_worker): restore queue.join() in flush() to wait for in-flight callbacks
2026-04-24 17:15:32 -07:00

298 lines
9.7 KiB
Python

import io
import os
import sys
sys.path.insert(0, os.path.abspath("../.."))
import asyncio
import litellm
import gzip
import json
import logging
import time
from unittest.mock import AsyncMock, patch
import pytest
import litellm
from litellm import completion
from litellm._logging import verbose_logger
from litellm.integrations.gcs_pubsub.pub_sub import *
from datetime import datetime, timedelta
from litellm.types.utils import (
StandardLoggingPayload,
StandardLoggingModelInformation,
StandardLoggingMetadata,
StandardLoggingHiddenParams,
)
# Set litellm's verbose logger to DEBUG as soon as this test module is imported.
verbose_logger.setLevel(logging.DEBUG)
# Dot-separated key paths that are skipped when a logged payload is compared
# against its expected JSON file. The list is handed to _compare_nested_dicts
# as ``ignore_keys`` by assert_gcs_pubsub_request_matches_expected.
# NOTE(review): presumably these are the fields whose values vary from run to
# run (ids, timestamps, cost/usage details) -- confirm before adding new keys.
ignored_keys = [
    # Top-level keys
    "request_id",
    "session_id",
    "startTime",
    "endTime",
    "completionStartTime",
    "request_duration_ms",
    # Keys nested under "metadata"
    "metadata.model_map_information",
    "metadata.usage_object",
    "metadata.cold_storage_object_key",
    "metadata.litellm_overhead_time_ms",
    "metadata.cost_breakdown",
    "metadata.eval_information",
]
def _compare_nested_dicts(
actual: dict, expected: dict, path: str = "", ignore_keys: list[str] = []
) -> list[str]:
"""Compare nested dictionaries and return a list of differences in a human-friendly format."""
differences = []
# Check if current path should be ignored
if path in ignore_keys:
return differences
# Check for keys in actual but not in expected
for key in actual.keys():
current_path = f"{path}.{key}" if path else key
if current_path not in ignore_keys and key not in expected:
differences.append(f"Extra key in actual: {current_path}")
for key, expected_value in expected.items():
current_path = f"{path}.{key}" if path else key
if current_path in ignore_keys:
continue
if key not in actual:
differences.append(f"Missing key: {current_path}")
continue
actual_value = actual[key]
# Try to parse JSON strings
if isinstance(expected_value, str):
try:
expected_value = json.loads(expected_value)
except json.JSONDecodeError:
pass
if isinstance(actual_value, str):
try:
actual_value = json.loads(actual_value)
except json.JSONDecodeError:
pass
if isinstance(expected_value, dict) and isinstance(actual_value, dict):
differences.extend(
_compare_nested_dicts(
actual_value, expected_value, current_path, ignore_keys
)
)
elif isinstance(expected_value, dict) or isinstance(actual_value, dict):
differences.append(
f"Type mismatch at {current_path}: expected dict, got {type(actual_value).__name__}"
)
else:
# For non-dict values, only report if they're different
if actual_value != expected_value:
# Format the values to be more readable
actual_str = str(actual_value)
expected_str = str(expected_value)
if len(actual_str) > 50 or len(expected_str) > 50:
actual_str = f"{actual_str[:50]}..."
expected_str = f"{expected_str[:50]}..."
differences.append(
f"Value mismatch at {current_path}:\n expected: {expected_str}\n got: {actual_str}"
)
return differences
def assert_gcs_pubsub_request_matches_expected(
    actual_request_body: dict,
    expected_file_name: str,
):
    """
    Assert that a GCS PubSub request body matches the body stored in a JSON file.

    The expected body is loaded from the ``gcs_pub_sub_body`` directory that sits
    next to this file, and the comparison skips every path in ``ignored_keys``.

    Args:
        actual_request_body (dict): The request body that was actually sent
        expected_file_name (str): Name of the JSON file holding the expected body

    Raises:
        AssertionError: If any difference is reported; the message lists them all
    """
    fixture_dir = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "gcs_pub_sub_body"
    )
    with open(os.path.join(fixture_dir, expected_file_name), "r") as fixture:
        expected_request_body = json.load(fixture)

    differences = _compare_nested_dicts(
        actual_request_body, expected_request_body, ignore_keys=ignored_keys
    )
    assert not differences, f"Dictionary mismatch: {differences}"
def assert_gcs_pubsub_request_matches_expected_standard_logging_payload(
    actual_request_body: dict,
    expected_file_name: str,
):
    """
    Check that a GCS PubSub request body has the fields of a standard logging payload.

    Only the *presence* of the required top-level fields is asserted; their values
    are not compared against the expected file. After the check,
    ``response.id`` and ``response.created`` of ``actual_request_body`` are
    overwritten in place with the values from the expected file.

    Args:
        actual_request_body (dict): The actual request body received from the API call
        expected_file_name (str): Name of the JSON file (inside ``gcs_pub_sub_body``,
            next to this file) containing the expected request body

    Raises:
        AssertionError: If any required field is missing; the message names them
    """
    pwd = os.path.dirname(os.path.realpath(__file__))
    expected_body_path = os.path.join(pwd, "gcs_pub_sub_body", expected_file_name)
    with open(expected_body_path, "r", encoding="utf-8") as f:
        expected_request_body = json.load(f)

    FIELDS_TO_VALIDATE = [
        "custom_llm_provider",
        "hidden_params",
        "messages",
        "response",
        "model",
        "status",
        "stream",
    ]
    FIELDS_EXISTENCE_CHECKS = [
        "response_cost",
        "response_time",
        "completion_tokens",
        "prompt_tokens",
        "total_tokens",
    ]

    # Check presence first so a missing "response" fails with a readable
    # assertion rather than a KeyError from the normalisation below.
    missing_fields = [
        field
        for field in FIELDS_TO_VALIDATE + FIELDS_EXISTENCE_CHECKS
        if field not in actual_request_body
    ]
    assert not missing_fields, f"Missing fields in actual request body: {missing_fields}"

    # Replace dynamic values in actual request body
    actual_request_body["response"]["id"] = expected_request_body["response"]["id"]
    actual_request_body["response"]["created"] = expected_request_body["response"][
        "created"
    ]
@pytest.mark.asyncio
async def test_async_gcs_pub_sub():
    """A mocked completion is published once to the PubSub topic and the decoded
    message carries the standard logging payload fields."""
    import base64

    # Stand-ins for the HTTP publish call and for the auth-header construction
    publish_mock = AsyncMock()
    publish_mock.return_value.status_code = 202
    publish_mock.return_value.text = "Accepted"
    headers_mock = AsyncMock(return_value={"Authorization": "Bearer mock_token"})

    # Build the logger and swap in the stand-ins
    gcs_pub_sub_logger = GcsPubSubLogger(
        project_id="STUBBED_PROJECT_ID", topic_id="STUBBED_TOPIC_ID", flush_interval=1
    )
    gcs_pub_sub_logger.async_httpx_client.post = publish_mock
    gcs_pub_sub_logger.construct_request_headers = headers_mock
    litellm.callbacks = [gcs_pub_sub_logger]

    # Trigger one logged completion
    await litellm.acompletion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hello, world!"}],
        mock_response="hi",
    )
    await asyncio.sleep(3)  # Wait for async flush

    # Exactly one publish request must have been made
    publish_mock.assert_called_once()
    publish_kwargs = publish_mock.call_args.kwargs

    actual_url = publish_kwargs["url"]
    print("sent to url", actual_url)
    assert (
        actual_url
        == "https://pubsub.googleapis.com/v1/projects/STUBBED_PROJECT_ID/topics/STUBBED_TOPIC_ID:publish"
    )

    # The payload travels base64-encoded inside the first message
    encoded_message = publish_kwargs["json"]["messages"][0]["data"]
    actual_request = json.loads(base64.b64decode(encoded_message).decode("utf-8"))
    print("##########\n")
    print(json.dumps(actual_request, indent=4))
    print("##########\n")

    # Verify the request body matches expected format
    assert_gcs_pubsub_request_matches_expected_standard_logging_payload(
        actual_request, "standard_logging_payload.json"
    )
@pytest.mark.asyncio
async def test_async_gcs_pub_sub_v1():
    """With ``litellm.gcs_pub_sub_use_v1`` enabled, a mocked completion is published
    once to the PubSub topic and the decoded message matches ``spend_logs_payload.json``.

    The global flag is restored when the test ends, whether it passes or fails,
    so that the v1 mode does not leak into tests that run afterwards.
    """
    import base64

    previous_use_v1 = getattr(litellm, "gcs_pub_sub_use_v1", None)
    litellm.gcs_pub_sub_use_v1 = True
    try:
        # Create a mock for the async_httpx_client's post method
        mock_post = AsyncMock()
        mock_post.return_value.status_code = 202
        mock_post.return_value.text = "Accepted"

        # Initialize the GcsPubSubLogger and set the mock
        gcs_pub_sub_logger = GcsPubSubLogger(
            project_id="STUBBED_PROJECT_ID",
            topic_id="STUBBED_TOPIC_ID",
            flush_interval=1,
        )
        gcs_pub_sub_logger.async_httpx_client.post = mock_post

        mock_construct_request_headers = AsyncMock()
        mock_construct_request_headers.return_value = {
            "Authorization": "Bearer mock_token"
        }
        gcs_pub_sub_logger.construct_request_headers = mock_construct_request_headers
        litellm.callbacks = [gcs_pub_sub_logger]

        # Make the completion call
        await litellm.acompletion(
            model="gpt-4o",
            messages=[{"role": "user", "content": "Hello, world!"}],
            mock_response="hi",
        )
        await asyncio.sleep(3)  # Wait for async flush

        # Assert httpx post was called
        mock_post.assert_called_once()

        # Get the actual request body from the mock
        actual_url = mock_post.call_args[1]["url"]
        print("sent to url", actual_url)
        assert (
            actual_url
            == "https://pubsub.googleapis.com/v1/projects/STUBBED_PROJECT_ID/topics/STUBBED_TOPIC_ID:publish"
        )
        actual_request = mock_post.call_args[1]["json"]

        # Extract and decode the base64 encoded message
        encoded_message = actual_request["messages"][0]["data"]
        decoded_message = base64.b64decode(encoded_message).decode("utf-8")

        # Parse the JSON string into a dictionary
        actual_request = json.loads(decoded_message)
        print("##########\n")
        print(json.dumps(actual_request, indent=4))
        print("##########\n")

        # Verify the request body matches expected format
        assert_gcs_pubsub_request_matches_expected(
            actual_request, "spend_logs_payload.json"
        )
    finally:
        # Undo the global toggle so later tests see the value they started with.
        litellm.gcs_pub_sub_use_v1 = previous_use_v1