mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 03:31:23 +00:00
313 lines
12 KiB
Python
313 lines
12 KiB
Python
"""
|
|
Tests for exception header preservation.
|
|
|
|
These tests verify that when LLM providers return error responses with headers,
|
|
those headers are preserved in the exception and can be returned to clients.
|
|
|
|
This is important for debugging and observability - headers like x-request-id,
|
|
x-ms-region, rate limit headers, etc. should be available even when errors occur.
|
|
"""
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from litellm.exceptions import (
|
|
BadRequestError,
|
|
ContentPolicyViolationError,
|
|
ContextWindowExceededError,
|
|
ImageFetchError,
|
|
MidStreamFallbackError,
|
|
RateLimitError,
|
|
)
|
|
|
|
|
|
class TestExceptionHeaderPreservation:
    """Check that provider response headers stay reachable on litellm exceptions."""

    @pytest.fixture
    def mock_response_with_headers(self) -> httpx.Response:
        """Return a 400 response that carries a handful of provider headers."""
        provider_headers = {
            "x-request-id": "req-abc123",
            "x-ms-region": "eastus",
            "x-ratelimit-remaining-requests": "99",
            "x-ratelimit-remaining-tokens": "9999",
        }
        upstream_request = httpx.Request(
            "POST", "https://api.openai.com/v1/chat/completions"
        )
        return httpx.Response(
            status_code=400,
            headers=provider_headers,
            request=upstream_request,
        )

    def test_bad_request_error_preserves_headers(
        self, mock_response_with_headers: httpx.Response
    ):
        """A BadRequestError built from a response exposes that response's headers."""
        exc = BadRequestError(
            message="Invalid request",
            model="gpt-4",
            llm_provider="azure",
            response=mock_response_with_headers,
        )

        assert exc.response is not None
        kept = exc.response.headers
        assert kept.get("x-request-id") == "req-abc123"
        assert kept.get("x-ms-region") == "eastus"
        assert kept.get("x-ratelimit-remaining-requests") == "99"

    def test_content_policy_violation_error_preserves_headers(
        self, mock_response_with_headers: httpx.Response
    ):
        """A ContentPolicyViolationError exposes the headers of the response it wraps."""
        exc = ContentPolicyViolationError(
            message="Content policy violation",
            model="gpt-4",
            llm_provider="azure",
            response=mock_response_with_headers,
        )

        assert exc.response is not None
        kept = exc.response.headers
        assert kept.get("x-request-id") == "req-abc123"
        assert kept.get("x-ms-region") == "eastus"

    def test_context_window_exceeded_error_preserves_headers(
        self, mock_response_with_headers: httpx.Response
    ):
        """A ContextWindowExceededError exposes the headers of the response it wraps."""
        exc = ContextWindowExceededError(
            message="Context window exceeded",
            model="gpt-4",
            llm_provider="azure",
            response=mock_response_with_headers,
        )

        assert exc.response is not None
        kept = exc.response.headers
        assert kept.get("x-request-id") == "req-abc123"
        assert kept.get("x-ms-region") == "eastus"

    def test_image_fetch_error_preserves_headers(
        self, mock_response_with_headers: httpx.Response
    ):
        """An ImageFetchError exposes the headers of the response it wraps."""
        exc = ImageFetchError(
            message="Failed to fetch image",
            model="gpt-4",
            llm_provider="azure",
            response=mock_response_with_headers,
        )

        assert exc.response is not None
        kept = exc.response.headers
        assert kept.get("x-request-id") == "req-abc123"
        assert kept.get("x-ms-region") == "eastus"

    def test_bad_request_error_handles_none_response(self):
        """BadRequestError given response=None must still expose a usable response."""
        exc = BadRequestError(
            message="Invalid request",
            model="gpt-4",
            llm_provider="azure",
            response=None,
        )

        assert exc.response is not None
        # Looking up a missing header must yield None rather than raising.
        assert exc.response.headers.get("x-request-id") is None

    def test_content_policy_violation_error_handles_none_response(self):
        """ContentPolicyViolationError given response=None must still expose a response."""
        exc = ContentPolicyViolationError(
            message="Content policy violation",
            model="gpt-4",
            llm_provider="azure",
            response=None,
        )

        assert exc.response is not None
        assert exc.response.headers.get("x-request-id") is None

    def test_context_window_exceeded_error_handles_none_response(self):
        """ContextWindowExceededError given response=None must still expose a response."""
        exc = ContextWindowExceededError(
            message="Context window exceeded",
            model="gpt-4",
            llm_provider="azure",
            response=None,
        )

        assert exc.response is not None
        assert exc.response.headers.get("x-request-id") is None
|
|
|
|
|
|
class TestExceptionMessageFormatting:
    """Check that each exception's message carries its litellm class prefix."""

    def test_bad_request_error_message_format(self):
        """The message contains 'litellm.BadRequestError' and the caller's text."""
        exc = BadRequestError(
            message="test error",
            model="gpt-4",
            llm_provider="azure",
        )
        rendered = exc.message

        assert "litellm.BadRequestError" in rendered
        assert "test error" in rendered

    def test_content_policy_violation_error_message_format(self):
        """The message contains 'litellm.ContentPolicyViolationError' and the caller's text."""
        exc = ContentPolicyViolationError(
            message="test error",
            model="gpt-4",
            llm_provider="azure",
        )
        rendered = exc.message

        assert "litellm.ContentPolicyViolationError" in rendered
        assert "test error" in rendered

    def test_context_window_exceeded_error_message_format(self):
        """The message contains 'litellm.ContextWindowExceededError' and the caller's text."""
        exc = ContextWindowExceededError(
            message="test error",
            model="gpt-4",
            llm_provider="azure",
        )
        rendered = exc.message

        assert "litellm.ContextWindowExceededError" in rendered
        assert "test error" in rendered
|
|
|
|
|
|
class TestExceptionAttributes:
    """Check the attributes exposed by constructed litellm exceptions."""

    def test_content_policy_violation_error_provider_specific_fields(self):
        """provider_specific_fields passed to the constructor is readable afterwards."""
        azure_payload = {"innererror": {"code": "ResponsibleAIPolicyViolation"}}

        exc = ContentPolicyViolationError(
            message="test error",
            model="gpt-4",
            llm_provider="azure",
            provider_specific_fields=azure_payload,
        )

        assert exc.provider_specific_fields == azure_payload
        inner_code = exc.provider_specific_fields["innererror"]["code"]
        assert inner_code == "ResponsibleAIPolicyViolation"

    def test_bad_request_error_attributes(self):
        """Constructor arguments and the 400 status code are visible as attributes."""
        exc = BadRequestError(
            message="test error",
            model="gpt-4",
            llm_provider="azure",
            litellm_debug_info="debug info",
            max_retries=3,
            num_retries=1,
        )

        assert exc.model == "gpt-4"
        assert exc.llm_provider == "azure"
        assert exc.litellm_debug_info == "debug info"
        assert exc.max_retries == 3
        assert exc.num_retries == 1
        assert exc.status_code == 400

    def test_midstream_fallback_error_status_code_propagation(self):
        """
        MidStreamFallbackError takes its status code from the wrapped exception
        (429 here) and falls back to 503 when there is none; message, args and
        the response/request fields are checked for consistency in both cases.
        """
        upstream_request = httpx.Request(
            "POST", "https://api.openai.com/v1/chat/completions"
        )
        upstream_response = httpx.Response(status_code=429, request=upstream_request)

        wrapped = RateLimitError(
            message="Rate limit exceeded",
            llm_provider="openai",
            model="gpt-4o-mini",
            response=upstream_response,
        )

        with_original = MidStreamFallbackError(
            message="stream broke",
            model="gpt-4o-mini",
            llm_provider="openai",
            original_exception=wrapped,
        )

        expected_message = "litellm.MidStreamFallbackError: stream broke"
        assert with_original.status_code == 429
        assert with_original.response.status_code == 429
        assert str(with_original.response.request.url) == "https://openai.com/v1/"
        assert with_original.message == expected_message
        assert with_original.args == (expected_message,)

        # No wrapped exception: the status code is expected to default to 503.
        without_original = MidStreamFallbackError(
            message="stream broke without original",
            model="gpt-4o-mini",
            llm_provider="openai",
            original_exception=None,
        )

        assert without_original.status_code == 503
        assert without_original.response.status_code == 503
        assert str(without_original.response.request.url) == "https://openai.com/v1/"
|
|
|
|
|
|
class TestProxyHeaderExtraction:
    """Check header extraction from exceptions as the proxy would perform it."""

    def test_get_response_headers_adds_llm_provider_prefix(self):
        """Rate-limit headers pass through; other headers gain an 'llm_provider-' prefix."""
        from litellm.litellm_core_utils.llm_response_utils.get_headers import (
            get_response_headers,
        )

        raw_headers = {
            "x-request-id": "req-abc123",
            "x-ms-region": "eastus",
            # OpenAI-style rate-limit header: expected to come back unprefixed.
            "x-ratelimit-remaining-requests": "99",
        }

        processed = get_response_headers(raw_headers)

        # The rate-limit header keeps its original name.
        assert processed.get("x-ratelimit-remaining-requests") == "99"
        # Everything else is exposed under the llm_provider- prefix.
        assert processed.get("llm_provider-x-request-id") == "req-abc123"
        assert processed.get("llm_provider-x-ms-region") == "eastus"

    def test_proxy_can_extract_headers_from_exception_response(self):
        """Replay the proxy's fallback from exception.headers to exception.response.headers."""
        from litellm.litellm_core_utils.llm_response_utils.get_headers import (
            get_response_headers,
        )

        # Build an exception whose only header source is its attached response.
        attached_response = httpx.Response(
            status_code=400,
            headers={
                "x-request-id": "req-abc123",
                "x-ms-region": "eastus",
            },
            request=httpx.Request("POST", "https://test.com"),
        )
        exc = ContentPolicyViolationError(
            message="test",
            model="gpt-4",
            llm_provider="azure",
            response=attached_response,
        )

        # Mirror of the proxy logic: prefer exc.headers, else derive them
        # from the response attached to the exception.
        extracted = getattr(exc, "headers", None) or {}
        if not extracted:
            carried = getattr(exc, "response", None)
            carried_headers = (
                getattr(carried, "headers", None) if carried is not None else None
            )
            if carried_headers:
                extracted = get_response_headers(dict(carried_headers))

        # The extracted headers must be present under their prefixed names.
        assert extracted.get("llm_provider-x-request-id") == "req-abc123"
        assert extracted.get("llm_provider-x-ms-region") == "eastus"
|