From c215b3a79fcb999759bfec8b1beb287cbf82ed61 Mon Sep 17 00:00:00 2001 From: Peter Golm Date: Wed, 14 Jan 2026 18:19:39 +0100 Subject: [PATCH] fix: preserve llm_provider-* headers in error responses (#19020) Extract and preserve provider-specific headers (llm_provider-*) when handling error responses from LLM providers. This ensures that useful debugging information from providers is available even when requests fail with BadRequestError or similar exceptions. --- litellm/exceptions.py | 22 +- litellm/llms/azure/exception_mapping.py | 12 +- litellm/proxy/common_request_processing.py | 29 +- .../test_exception_header_preservation.py | 270 ++++++++++++++++++ 4 files changed, 308 insertions(+), 25 deletions(-) create mode 100644 tests/test_litellm/test_exception_header_preservation.py diff --git a/litellm/exceptions.py b/litellm/exceptions.py index 400973d822..c2443626b8 100644 --- a/litellm/exceptions.py +++ b/litellm/exceptions.py @@ -125,16 +125,20 @@ class BadRequestError(openai.BadRequestError): # type: ignore self.model = model self.llm_provider = llm_provider self.litellm_debug_info = litellm_debug_info - response = httpx.Response( + self.max_retries = max_retries + self.num_retries = num_retries + _response_headers = ( + getattr(response, "headers", None) if response is not None else None + ) + self.response = httpx.Response( status_code=self.status_code, + headers=_response_headers, request=httpx.Request( method="GET", url="https://litellm.ai" ), # mock request object ) - self.max_retries = max_retries - self.num_retries = num_retries super().__init__( - self.message, response=response, body=body + self.message, response=self.response, body=body ) # Call the base class constructor with the parameters it needs def __str__(self): @@ -368,13 +372,11 @@ class ContextWindowExceededError(BadRequestError): # type: ignore self.model = model self.llm_provider = llm_provider self.litellm_debug_info = litellm_debug_info - request = httpx.Request(method="POST", url="https://api.openai.com/v1") - self.response = httpx.Response(status_code=400, request=request) super().__init__( message=message, model=self.model, # type: ignore llm_provider=self.llm_provider, # type: ignore - response=self.response, + response=response, litellm_debug_info=self.litellm_debug_info, ) # Call the base class constructor with the parameters it needs @@ -457,18 +459,14 @@ class ContentPolicyViolationError(BadRequestError): # type: ignore self.model = model self.llm_provider = llm_provider self.litellm_debug_info = litellm_debug_info - request = httpx.Request(method="POST", url="https://api.openai.com/v1") - self.response = httpx.Response(status_code=400, request=request) self.provider_specific_fields = provider_specific_fields - super().__init__( message=self.message, model=self.model, # type: ignore llm_provider=self.llm_provider, # type: ignore - response=self.response, + response=response, litellm_debug_info=self.litellm_debug_info, ) # Call the base class constructor with the parameters it needs - def __str__(self): return self._transform_error_to_string() diff --git a/litellm/llms/azure/exception_mapping.py b/litellm/llms/azure/exception_mapping.py index 70c2609c6b..193f3d9995 100644 --- a/litellm/llms/azure/exception_mapping.py +++ b/litellm/llms/azure/exception_mapping.py @@ -7,6 +7,7 @@ class AzureOpenAIExceptionMapping: """ Class for creating Azure OpenAI specific exceptions """ + @staticmethod def create_content_policy_violation_error( message: str, @@ -16,18 +17,20 @@ class AzureOpenAIExceptionMapping: ) -> ContentPolicyViolationError: """ Create a content policy violation error - """ + """ raise ContentPolicyViolationError( - message=f"litellm.ContentPolicyViolationError: AzureException - {message}", + message=f"AzureException - {message}", llm_provider="azure", model=model, litellm_debug_info=extra_information, response=getattr(original_exception, "response", None), provider_specific_fields={ - "innererror": AzureOpenAIExceptionMapping._get_innererror_from_exception(original_exception) + "innererror": AzureOpenAIExceptionMapping._get_innererror_from_exception( + original_exception + ) }, ) - + @staticmethod def _get_innererror_from_exception(original_exception: Exception) -> Optional[dict]: """ @@ -39,4 +42,3 @@ class AzureOpenAIExceptionMapping: if isinstance(body_dict, dict): innererror = body_dict.get("innererror") return innererror - \ No newline at end of file diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index 3db1620107..52f7f227b5 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -28,6 +28,9 @@ from litellm.constants import ( ) from litellm.litellm_core_utils.dd_tracing import tracer from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.litellm_core_utils.llm_response_utils.get_headers import ( + get_response_headers, +) from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.proxy._types import ProxyException, UserAPIKeyAuth from litellm.proxy.auth.auth_utils import check_response_size_is_safe @@ -284,7 +287,12 @@ class ProxyBaseLLMRequestProcessing: hidden_params = hidden_params or {} # Extract discount and margin info from cost_breakdown if available - original_cost, discount_amount, margin_total_amount, margin_percent = _get_cost_breakdown_from_logging_obj( + ( + original_cost, + discount_amount, + margin_total_amount, + margin_percent, + ) = _get_cost_breakdown_from_logging_obj( litellm_logging_obj=litellm_logging_obj ) @@ -527,12 +535,12 @@ class ProxyBaseLLMRequestProcessing: # Apply hierarchical router_settings (Key > Team > Global) if llm_router is not None and proxy_config is not None: from litellm.proxy.proxy_server import prisma_client - + router_settings = await proxy_config._get_hierarchical_router_settings( user_api_key_dict=user_api_key_dict, prisma_client=prisma_client, ) - + # If router_settings found (from key, team, or global), apply them # This ensures key/team settings override global settings if router_settings is not None and router_settings: @@ -541,10 +549,7 @@ class ProxyBaseLLMRequestProcessing: if model_list is not None: # Create user_config with model_list and router_settings # This creates a per-request router with the hierarchical settings - user_config = { - "model_list": model_list, - **router_settings - } + user_config = {"model_list": model_list, **router_settings} self.data["user_config"] = user_config if "messages" in self.data and self.data["messages"]: @@ -943,7 +948,15 @@ class ProxyBaseLLMRequestProcessing: timeout=timeout, litellm_logging_obj=_litellm_logging_obj, ) - headers = getattr(e, "headers", {}) or {} + # Extract headers from exception - check both e.headers and e.response.headers + headers = getattr(e, "headers", None) or {} + if not headers: + # Try to get headers from e.response.headers (httpx.Response) + _response = getattr(e, "response", None) + if _response is not None: + _response_headers = getattr(_response, "headers", None) + if _response_headers: + headers = get_response_headers(dict(_response_headers)) headers.update(custom_headers) if isinstance(e, HTTPException): diff --git a/tests/test_litellm/test_exception_header_preservation.py b/tests/test_litellm/test_exception_header_preservation.py new file mode 100644 index 0000000000..d3e33fa13b --- /dev/null +++ b/tests/test_litellm/test_exception_header_preservation.py @@ -0,0 +1,270 @@ +""" +Tests for exception header preservation. + +These tests verify that when LLM providers return error responses with headers, +those headers are preserved in the exception and can be returned to clients. + +This is important for debugging and observability - headers like x-request-id, +x-ms-region, rate limit headers, etc. should be available even when errors occur. +""" + +import httpx +import pytest + +from litellm.exceptions import ( + BadRequestError, + ContentPolicyViolationError, + ContextWindowExceededError, + ImageFetchError, +) + + +class TestExceptionHeaderPreservation: + """Test that exception classes preserve headers from provider responses.""" + + @pytest.fixture + def mock_response_with_headers(self) -> httpx.Response: + """Create a mock response with typical provider headers.""" + return httpx.Response( + status_code=400, + headers={ + "x-request-id": "req-abc123", + "x-ms-region": "eastus", + "x-ratelimit-remaining-requests": "99", + "x-ratelimit-remaining-tokens": "9999", + }, + request=httpx.Request("POST", "https://api.openai.com/v1/chat/completions"), + ) + + def test_bad_request_error_preserves_headers( + self, mock_response_with_headers: httpx.Response + ): + """BadRequestError should preserve headers from the provider response.""" + error = BadRequestError( + message="Invalid request", + model="gpt-4", + llm_provider="azure", + response=mock_response_with_headers, + ) + + assert error.response is not None + assert error.response.headers.get("x-request-id") == "req-abc123" + assert error.response.headers.get("x-ms-region") == "eastus" + assert error.response.headers.get("x-ratelimit-remaining-requests") == "99" + + def test_content_policy_violation_error_preserves_headers( + self, mock_response_with_headers: httpx.Response + ): + """ContentPolicyViolationError should preserve headers from the provider response.""" + error = ContentPolicyViolationError( + message="Content policy violation", + model="gpt-4", + llm_provider="azure", + response=mock_response_with_headers, + ) + + assert error.response is not None + assert error.response.headers.get("x-request-id") == "req-abc123" + assert error.response.headers.get("x-ms-region") == "eastus" + + def test_context_window_exceeded_error_preserves_headers( + self, mock_response_with_headers: httpx.Response + ): + """ContextWindowExceededError should preserve headers from the provider response.""" + error = ContextWindowExceededError( + message="Context window exceeded", + model="gpt-4", + llm_provider="azure", + response=mock_response_with_headers, + ) + + assert error.response is not None + assert error.response.headers.get("x-request-id") == "req-abc123" + assert error.response.headers.get("x-ms-region") == "eastus" + + def test_image_fetch_error_preserves_headers( + self, mock_response_with_headers: httpx.Response + ): + """ImageFetchError should preserve headers from the provider response.""" + error = ImageFetchError( + message="Failed to fetch image", + model="gpt-4", + llm_provider="azure", + response=mock_response_with_headers, + ) + + assert error.response is not None + assert error.response.headers.get("x-request-id") == "req-abc123" + assert error.response.headers.get("x-ms-region") == "eastus" + + def test_bad_request_error_handles_none_response(self): + """BadRequestError should handle None response gracefully.""" + error = BadRequestError( + message="Invalid request", + model="gpt-4", + llm_provider="azure", + response=None, + ) + + assert error.response is not None + # Headers should be empty but not cause an error + assert error.response.headers.get("x-request-id") is None + + def test_content_policy_violation_error_handles_none_response(self): + """ContentPolicyViolationError should handle None response gracefully.""" + error = ContentPolicyViolationError( + message="Content policy violation", + model="gpt-4", + llm_provider="azure", + response=None, + ) + + assert error.response is not None + assert error.response.headers.get("x-request-id") is None + + def test_context_window_exceeded_error_handles_none_response(self): + """ContextWindowExceededError should handle None response gracefully.""" + error = ContextWindowExceededError( + message="Context window exceeded", + model="gpt-4", + llm_provider="azure", + response=None, + ) + + assert error.response is not None + assert error.response.headers.get("x-request-id") is None + + +class TestExceptionMessageFormatting: + """Test that exception messages are formatted correctly after refactoring.""" + + def test_bad_request_error_message_format(self): + """BadRequestError should format message with litellm prefix.""" + error = BadRequestError( + message="test error", + model="gpt-4", + llm_provider="azure", + ) + + assert "litellm.BadRequestError" in error.message + assert "test error" in error.message + + def test_content_policy_violation_error_message_format(self): + """ContentPolicyViolationError should format message with specific prefix.""" + error = ContentPolicyViolationError( + message="test error", + model="gpt-4", + llm_provider="azure", + ) + + assert "litellm.ContentPolicyViolationError" in error.message + assert "test error" in error.message + + def test_context_window_exceeded_error_message_format(self): + """ContextWindowExceededError should format message with specific prefix.""" + error = ContextWindowExceededError( + message="test error", + model="gpt-4", + llm_provider="azure", + ) + + assert "litellm.ContextWindowExceededError" in error.message + assert "test error" in error.message + + +class TestExceptionAttributes: + """Test that exception attributes are set correctly.""" + + def test_content_policy_violation_error_provider_specific_fields(self): + """ContentPolicyViolationError should preserve provider_specific_fields.""" + provider_fields = {"innererror": {"code": "ResponsibleAIPolicyViolation"}} + + error = ContentPolicyViolationError( + message="test error", + model="gpt-4", + llm_provider="azure", + provider_specific_fields=provider_fields, + ) + + assert error.provider_specific_fields == provider_fields + assert ( + error.provider_specific_fields["innererror"]["code"] + == "ResponsibleAIPolicyViolation" + ) + + def test_bad_request_error_attributes(self): + """BadRequestError should set all expected attributes.""" + error = BadRequestError( + message="test error", + model="gpt-4", + llm_provider="azure", + litellm_debug_info="debug info", + max_retries=3, + num_retries=1, + ) + + assert error.model == "gpt-4" + assert error.llm_provider == "azure" + assert error.litellm_debug_info == "debug info" + assert error.max_retries == 3 + assert error.num_retries == 1 + assert error.status_code == 400 + + +class TestProxyHeaderExtraction: + """Test that proxy correctly extracts headers from exceptions.""" + + def test_get_response_headers_adds_llm_provider_prefix(self): + """get_response_headers should prefix non-OpenAI headers with llm_provider-.""" + from litellm.litellm_core_utils.llm_response_utils.get_headers import ( + get_response_headers, + ) + + response_headers = { + "x-request-id": "req-abc123", + "x-ms-region": "eastus", + "x-ratelimit-remaining-requests": "99", # OpenAI header - should not be prefixed + } + + result = get_response_headers(response_headers) + + # OpenAI ratelimit headers should be preserved as-is + assert result.get("x-ratelimit-remaining-requests") == "99" + # Other headers should be prefixed with llm_provider- + assert result.get("llm_provider-x-request-id") == "req-abc123" + assert result.get("llm_provider-x-ms-region") == "eastus" + + def test_proxy_can_extract_headers_from_exception_response(self): + """Simulate how proxy extracts headers from exception.response.headers.""" + from litellm.litellm_core_utils.llm_response_utils.get_headers import ( + get_response_headers, + ) + + # Create exception with headers in response + mock_response = httpx.Response( + status_code=400, + headers={ + "x-request-id": "req-abc123", + "x-ms-region": "eastus", + }, + request=httpx.Request("POST", "https://test.com"), + ) + error = ContentPolicyViolationError( + message="test", + model="gpt-4", + llm_provider="azure", + response=mock_response, + ) + + # Simulate proxy header extraction logic + headers = getattr(error, "headers", None) or {} + if not headers: + _response = getattr(error, "response", None) + if _response is not None: + _response_headers = getattr(_response, "headers", None) + if _response_headers: + headers = get_response_headers(dict(_response_headers)) + + # Verify headers are extracted and prefixed correctly + assert headers.get("llm_provider-x-request-id") == "req-abc123" + assert headers.get("llm_provider-x-ms-region") == "eastus"