mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 17:28:19 +00:00
fix: preserve llm_provider-* headers in error responses (#19020)
Extract and preserve provider-specific headers (llm_provider-*) when handling error responses from LLM providers. This ensures that useful debugging information from providers is available even when requests fail with BadRequestError or similar exceptions.
This commit is contained in:
+10
-12
@@ -125,16 +125,20 @@ class BadRequestError(openai.BadRequestError): # type: ignore
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
response = httpx.Response(
|
||||
self.max_retries = max_retries
|
||||
self.num_retries = num_retries
|
||||
_response_headers = (
|
||||
getattr(response, "headers", None) if response is not None else None
|
||||
)
|
||||
self.response = httpx.Response(
|
||||
status_code=self.status_code,
|
||||
headers=_response_headers,
|
||||
request=httpx.Request(
|
||||
method="GET", url="https://litellm.ai"
|
||||
), # mock request object
|
||||
)
|
||||
self.max_retries = max_retries
|
||||
self.num_retries = num_retries
|
||||
super().__init__(
|
||||
self.message, response=response, body=body
|
||||
self.message, response=self.response, body=body
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
def __str__(self):
|
||||
@@ -368,13 +372,11 @@ class ContextWindowExceededError(BadRequestError): # type: ignore
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||
self.response = httpx.Response(status_code=400, request=request)
|
||||
super().__init__(
|
||||
message=message,
|
||||
model=self.model, # type: ignore
|
||||
llm_provider=self.llm_provider, # type: ignore
|
||||
response=self.response,
|
||||
response=response,
|
||||
litellm_debug_info=self.litellm_debug_info,
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
@@ -457,18 +459,14 @@ class ContentPolicyViolationError(BadRequestError): # type: ignore
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||
self.response = httpx.Response(status_code=400, request=request)
|
||||
self.provider_specific_fields = provider_specific_fields
|
||||
|
||||
super().__init__(
|
||||
message=self.message,
|
||||
model=self.model, # type: ignore
|
||||
llm_provider=self.llm_provider, # type: ignore
|
||||
response=self.response,
|
||||
response=response,
|
||||
litellm_debug_info=self.litellm_debug_info,
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return self._transform_error_to_string()
|
||||
|
||||
@@ -7,6 +7,7 @@ class AzureOpenAIExceptionMapping:
|
||||
"""
|
||||
Class for creating Azure OpenAI specific exceptions
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_content_policy_violation_error(
|
||||
message: str,
|
||||
@@ -16,18 +17,20 @@ class AzureOpenAIExceptionMapping:
|
||||
) -> ContentPolicyViolationError:
|
||||
"""
|
||||
Create a content policy violation error
|
||||
"""
|
||||
"""
|
||||
raise ContentPolicyViolationError(
|
||||
message=f"litellm.ContentPolicyViolationError: AzureException - {message}",
|
||||
message=f"AzureException - {message}",
|
||||
llm_provider="azure",
|
||||
model=model,
|
||||
litellm_debug_info=extra_information,
|
||||
response=getattr(original_exception, "response", None),
|
||||
provider_specific_fields={
|
||||
"innererror": AzureOpenAIExceptionMapping._get_innererror_from_exception(original_exception)
|
||||
"innererror": AzureOpenAIExceptionMapping._get_innererror_from_exception(
|
||||
original_exception
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_innererror_from_exception(original_exception: Exception) -> Optional[dict]:
|
||||
"""
|
||||
@@ -39,4 +42,3 @@ class AzureOpenAIExceptionMapping:
|
||||
if isinstance(body_dict, dict):
|
||||
innererror = body_dict.get("innererror")
|
||||
return innererror
|
||||
|
||||
@@ -28,6 +28,9 @@ from litellm.constants import (
|
||||
)
|
||||
from litellm.litellm_core_utils.dd_tracing import tracer
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.llm_response_utils.get_headers import (
|
||||
get_response_headers,
|
||||
)
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy._types import ProxyException, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.auth_utils import check_response_size_is_safe
|
||||
@@ -284,7 +287,12 @@ class ProxyBaseLLMRequestProcessing:
|
||||
hidden_params = hidden_params or {}
|
||||
|
||||
# Extract discount and margin info from cost_breakdown if available
|
||||
original_cost, discount_amount, margin_total_amount, margin_percent = _get_cost_breakdown_from_logging_obj(
|
||||
(
|
||||
original_cost,
|
||||
discount_amount,
|
||||
margin_total_amount,
|
||||
margin_percent,
|
||||
) = _get_cost_breakdown_from_logging_obj(
|
||||
litellm_logging_obj=litellm_logging_obj
|
||||
)
|
||||
|
||||
@@ -527,12 +535,12 @@ class ProxyBaseLLMRequestProcessing:
|
||||
# Apply hierarchical router_settings (Key > Team > Global)
|
||||
if llm_router is not None and proxy_config is not None:
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
|
||||
router_settings = await proxy_config._get_hierarchical_router_settings(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
|
||||
# If router_settings found (from key, team, or global), apply them
|
||||
# This ensures key/team settings override global settings
|
||||
if router_settings is not None and router_settings:
|
||||
@@ -541,10 +549,7 @@ class ProxyBaseLLMRequestProcessing:
|
||||
if model_list is not None:
|
||||
# Create user_config with model_list and router_settings
|
||||
# This creates a per-request router with the hierarchical settings
|
||||
user_config = {
|
||||
"model_list": model_list,
|
||||
**router_settings
|
||||
}
|
||||
user_config = {"model_list": model_list, **router_settings}
|
||||
self.data["user_config"] = user_config
|
||||
|
||||
if "messages" in self.data and self.data["messages"]:
|
||||
@@ -943,7 +948,15 @@ class ProxyBaseLLMRequestProcessing:
|
||||
timeout=timeout,
|
||||
litellm_logging_obj=_litellm_logging_obj,
|
||||
)
|
||||
headers = getattr(e, "headers", {}) or {}
|
||||
# Extract headers from exception - check both e.headers and e.response.headers
|
||||
headers = getattr(e, "headers", None) or {}
|
||||
if not headers:
|
||||
# Try to get headers from e.response.headers (httpx.Response)
|
||||
_response = getattr(e, "response", None)
|
||||
if _response is not None:
|
||||
_response_headers = getattr(_response, "headers", None)
|
||||
if _response_headers:
|
||||
headers = get_response_headers(dict(_response_headers))
|
||||
headers.update(custom_headers)
|
||||
|
||||
if isinstance(e, HTTPException):
|
||||
|
||||
@@ -0,0 +1,270 @@
|
||||
"""
|
||||
Tests for exception header preservation.
|
||||
|
||||
These tests verify that when LLM providers return error responses with headers,
|
||||
those headers are preserved in the exception and can be returned to clients.
|
||||
|
||||
This is important for debugging and observability - headers like x-request-id,
|
||||
x-ms-region, rate limit headers, etc. should be available even when errors occur.
|
||||
"""
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from litellm.exceptions import (
|
||||
BadRequestError,
|
||||
ContentPolicyViolationError,
|
||||
ContextWindowExceededError,
|
||||
ImageFetchError,
|
||||
)
|
||||
|
||||
|
||||
class TestExceptionHeaderPreservation:
    """Check that provider response headers remain readable on litellm exceptions.

    Every test constructs one exception type and then inspects
    ``response.headers`` on the resulting exception object.
    """

    @pytest.fixture
    def mock_response_with_headers(self) -> httpx.Response:
        """Build a 400 response carrying request-id, region and rate-limit headers."""
        provider_headers = {
            "x-request-id": "req-abc123",
            "x-ms-region": "eastus",
            "x-ratelimit-remaining-requests": "99",
            "x-ratelimit-remaining-tokens": "9999",
        }
        upstream_request = httpx.Request(
            "POST", "https://api.openai.com/v1/chat/completions"
        )
        return httpx.Response(
            status_code=400,
            headers=provider_headers,
            request=upstream_request,
        )

    def test_bad_request_error_preserves_headers(
        self, mock_response_with_headers: httpx.Response
    ):
        """A BadRequestError built from a response exposes that response's headers."""
        exc = BadRequestError(
            message="Invalid request",
            model="gpt-4",
            llm_provider="azure",
            response=mock_response_with_headers,
        )

        assert exc.response is not None
        preserved = exc.response.headers
        assert preserved.get("x-request-id") == "req-abc123"
        assert preserved.get("x-ms-region") == "eastus"
        assert preserved.get("x-ratelimit-remaining-requests") == "99"

    def test_content_policy_violation_error_preserves_headers(
        self, mock_response_with_headers: httpx.Response
    ):
        """A ContentPolicyViolationError exposes the headers of the given response."""
        exc = ContentPolicyViolationError(
            message="Content policy violation",
            model="gpt-4",
            llm_provider="azure",
            response=mock_response_with_headers,
        )

        assert exc.response is not None
        preserved = exc.response.headers
        assert preserved.get("x-request-id") == "req-abc123"
        assert preserved.get("x-ms-region") == "eastus"

    def test_context_window_exceeded_error_preserves_headers(
        self, mock_response_with_headers: httpx.Response
    ):
        """A ContextWindowExceededError exposes the headers of the given response."""
        exc = ContextWindowExceededError(
            message="Context window exceeded",
            model="gpt-4",
            llm_provider="azure",
            response=mock_response_with_headers,
        )

        assert exc.response is not None
        preserved = exc.response.headers
        assert preserved.get("x-request-id") == "req-abc123"
        assert preserved.get("x-ms-region") == "eastus"

    def test_image_fetch_error_preserves_headers(
        self, mock_response_with_headers: httpx.Response
    ):
        """An ImageFetchError exposes the headers of the given response."""
        exc = ImageFetchError(
            message="Failed to fetch image",
            model="gpt-4",
            llm_provider="azure",
            response=mock_response_with_headers,
        )

        assert exc.response is not None
        preserved = exc.response.headers
        assert preserved.get("x-request-id") == "req-abc123"
        assert preserved.get("x-ms-region") == "eastus"

    def test_bad_request_error_handles_none_response(self):
        """BadRequestError built with ``response=None`` still has a usable response."""
        exc = BadRequestError(
            message="Invalid request",
            model="gpt-4",
            llm_provider="azure",
            response=None,
        )

        assert exc.response is not None
        # Looking up a provider header must simply yield None, not raise.
        assert exc.response.headers.get("x-request-id") is None

    def test_content_policy_violation_error_handles_none_response(self):
        """ContentPolicyViolationError with ``response=None`` still has a response."""
        exc = ContentPolicyViolationError(
            message="Content policy violation",
            model="gpt-4",
            llm_provider="azure",
            response=None,
        )

        assert exc.response is not None
        assert exc.response.headers.get("x-request-id") is None

    def test_context_window_exceeded_error_handles_none_response(self):
        """ContextWindowExceededError with ``response=None`` still has a response."""
        exc = ContextWindowExceededError(
            message="Context window exceeded",
            model="gpt-4",
            llm_provider="azure",
            response=None,
        )

        assert exc.response is not None
        assert exc.response.headers.get("x-request-id") is None
|
||||
|
||||
|
||||
class TestExceptionMessageFormatting:
    """Check that each exception's ``message`` carries its litellm class prefix."""

    def test_bad_request_error_message_format(self):
        """The message contains ``litellm.BadRequestError`` and the caller's text."""
        exc = BadRequestError(
            message="test error",
            model="gpt-4",
            llm_provider="azure",
        )
        rendered = exc.message

        assert "litellm.BadRequestError" in rendered
        assert "test error" in rendered

    def test_content_policy_violation_error_message_format(self):
        """The message contains ``litellm.ContentPolicyViolationError`` and the text."""
        exc = ContentPolicyViolationError(
            message="test error",
            model="gpt-4",
            llm_provider="azure",
        )
        rendered = exc.message

        assert "litellm.ContentPolicyViolationError" in rendered
        assert "test error" in rendered

    def test_context_window_exceeded_error_message_format(self):
        """The message contains ``litellm.ContextWindowExceededError`` and the text."""
        exc = ContextWindowExceededError(
            message="test error",
            model="gpt-4",
            llm_provider="azure",
        )
        rendered = exc.message

        assert "litellm.ContextWindowExceededError" in rendered
        assert "test error" in rendered
|
||||
|
||||
|
||||
class TestExceptionAttributes:
    """Check that constructor arguments end up as attributes on the exception."""

    def test_content_policy_violation_error_provider_specific_fields(self):
        """``provider_specific_fields`` passed to the constructor is kept as given."""
        expected_fields = {"innererror": {"code": "ResponsibleAIPolicyViolation"}}

        exc = ContentPolicyViolationError(
            message="test error",
            model="gpt-4",
            llm_provider="azure",
            provider_specific_fields=expected_fields,
        )

        assert exc.provider_specific_fields == expected_fields
        inner_code = exc.provider_specific_fields["innererror"]["code"]
        assert inner_code == "ResponsibleAIPolicyViolation"

    def test_bad_request_error_attributes(self):
        """BadRequestError exposes its constructor arguments and a 400 status code."""
        exc = BadRequestError(
            message="test error",
            model="gpt-4",
            llm_provider="azure",
            litellm_debug_info="debug info",
            max_retries=3,
            num_retries=1,
        )

        # Checked in the same order as the individual assertions they replace.
        expected_attributes = {
            "model": "gpt-4",
            "llm_provider": "azure",
            "litellm_debug_info": "debug info",
            "max_retries": 3,
            "num_retries": 1,
            "status_code": 400,
        }
        for attribute_name, expected_value in expected_attributes.items():
            assert getattr(exc, attribute_name) == expected_value
|
||||
assert error.status_code == 400
|
||||
|
||||
|
||||
class TestProxyHeaderExtraction:
    """Check the header-extraction path the proxy applies to failed requests."""

    def test_get_response_headers_adds_llm_provider_prefix(self):
        """Rate-limit headers pass through; other headers gain ``llm_provider-``."""
        from litellm.litellm_core_utils.llm_response_utils.get_headers import (
            get_response_headers,
        )

        raw_headers = {
            "x-request-id": "req-abc123",
            "x-ms-region": "eastus",
            # OpenAI-style rate-limit header: expected back under its own name.
            "x-ratelimit-remaining-requests": "99",
        }

        processed = get_response_headers(raw_headers)

        # The rate-limit header keeps its original key.
        assert processed.get("x-ratelimit-remaining-requests") == "99"
        # The remaining headers are exposed under the llm_provider- prefix.
        assert processed.get("llm_provider-x-request-id") == "req-abc123"
        assert processed.get("llm_provider-x-ms-region") == "eastus"

    def test_proxy_can_extract_headers_from_exception_response(self):
        """Replay the proxy's fallback from ``e.headers`` to ``e.response.headers``."""
        from litellm.litellm_core_utils.llm_response_utils.get_headers import (
            get_response_headers,
        )

        # Exception whose only source of headers is the attached response.
        provider_response = httpx.Response(
            status_code=400,
            headers={
                "x-request-id": "req-abc123",
                "x-ms-region": "eastus",
            },
            request=httpx.Request("POST", "https://test.com"),
        )
        exc = ContentPolicyViolationError(
            message="test",
            model="gpt-4",
            llm_provider="azure",
            response=provider_response,
        )

        # Same two-step lookup as the proxy: direct attribute first, then response.
        headers = getattr(exc, "headers", None) or {}
        if not headers:
            attached_response = getattr(exc, "response", None)
            raw_headers = (
                getattr(attached_response, "headers", None)
                if attached_response is not None
                else None
            )
            if raw_headers:
                headers = get_response_headers(dict(raw_headers))

        # The extracted headers must carry the llm_provider- prefix.
        assert headers.get("llm_provider-x-request-id") == "req-abc123"
        assert headers.get("llm_provider-x-ms-region") == "eastus"
|
||||
Reference in New Issue
Block a user