[Feat] Enable async_post_call_failure_hook to transform error responses (#18348)

This commit is contained in:
Alexsander Hamir
2025-12-22 11:24:30 -08:00
committed by GitHub
parent 6dc11deeac
commit 30fa90f70d
8 changed files with 263 additions and 23 deletions
+51 -1
View File
@@ -17,6 +17,7 @@ import Image from '@theme/IdealImage';
| `async_pre_call_hook` | Modify incoming request before it's sent to model | Before the LLM API call is made |
| `async_moderation_hook` | Run checks on input in parallel to LLM API call | In parallel with the LLM API call |
| `async_post_call_success_hook` | Modify outgoing response (non-streaming) | After successful LLM API call, for non-streaming responses |
| `async_post_call_failure_hook` | Transform error responses sent to clients | After failed LLM API call |
| `async_post_call_streaming_hook` | Modify outgoing response (streaming) | After successful LLM API call, for streaming responses |
See a complete example with our [parallel request rate limiter](https://github.com/BerriAI/litellm/blob/main/litellm/proxy/hooks/parallel_request_limiter.py)
@@ -60,7 +61,21 @@ class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observabilit
original_exception: Exception,
user_api_key_dict: UserAPIKeyAuth,
traceback_str: Optional[str] = None,
):
) -> Optional[HTTPException]:
"""
Transform error responses sent to clients.
Return an HTTPException to replace the original error with a user-friendly message.
Return None to use the original exception.
Example:
if isinstance(original_exception, litellm.ContextWindowExceededError):
return HTTPException(
status_code=400,
detail="Your prompt is too long. Please reduce the length and try again."
)
return None # Use original exception
"""
pass
async def async_post_call_success_hook(
@@ -339,3 +354,38 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
"usage": {}
}
```
## Advanced - Transform Error Responses
Transform technical API errors into user-friendly messages using `async_post_call_failure_hook`. Return an `HTTPException` to replace the original error, or `None` to use the original exception.
```python
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
from typing import Optional
import litellm
class MyErrorTransformer(CustomLogger):
async def async_post_call_failure_hook(
self,
request_data: dict,
original_exception: Exception,
user_api_key_dict: UserAPIKeyAuth,
traceback_str: Optional[str] = None,
) -> Optional[HTTPException]:
if isinstance(original_exception, litellm.ContextWindowExceededError):
return HTTPException(
status_code=400,
detail="Your prompt is too long. Please reduce the length and try again."
)
if isinstance(original_exception, litellm.RateLimitError):
return HTTPException(
status_code=429,
detail="Rate limit exceeded. Please try again in a moment."
)
return None # Use original exception
proxy_handler_instance = MyErrorTransformer()
```
**Result:** Clients receive `"Your prompt is too long..."` instead of `"ContextWindowExceededError: Prompt exceeds context window"`.
+16 -1
View File
@@ -32,6 +32,8 @@ from litellm.types.utils import (
)
if TYPE_CHECKING:
from fastapi import HTTPException
from litellm.caching.caching import DualCache
from opentelemetry.trace import Span as _Span
@@ -348,7 +350,20 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
original_exception: Exception,
user_api_key_dict: UserAPIKeyAuth,
traceback_str: Optional[str] = None,
):
) -> Optional["HTTPException"]:
"""
Called after an LLM API call fails. Can return or raise HTTPException to transform error responses.
Args:
- request_data: dict - The request data.
- original_exception: Exception - The original exception that occurred.
- user_api_key_dict: UserAPIKeyAuth - The user API key dictionary.
- traceback_str: Optional[str] - The traceback string.
Returns:
- Optional[HTTPException]: Return an HTTPException to transform the error response sent to the client.
Return None to use the original exception.
"""
pass
async def async_post_call_success_hook(
+10 -9
View File
@@ -2,7 +2,6 @@
Handles Authentication Errors
"""
import asyncio
from typing import TYPE_CHECKING, Any, Optional, Union
from fastapi import HTTPException, Request, status
@@ -90,15 +89,17 @@ class UserAPIKeyAuthExceptionHandler:
api_key=api_key,
request_route=route,
)
asyncio.create_task(
proxy_logging_obj.post_call_failure_hook(
request_data=request_data,
original_exception=e,
user_api_key_dict=user_api_key_dict,
error_type=ProxyErrorTypes.auth_error,
route=route,
)
# Allow callbacks to transform the error response
transformed_exception = await proxy_logging_obj.post_call_failure_hook(
request_data=request_data,
original_exception=e,
user_api_key_dict=user_api_key_dict,
error_type=ProxyErrorTypes.auth_error,
route=route,
)
# Use transformed exception if callback returned one, otherwise use original
if transformed_exception is not None:
e = transformed_exception
if isinstance(e, litellm.BudgetExceededError):
raise ProxyException(
+10 -2
View File
@@ -786,11 +786,15 @@ class ProxyBaseLLMRequestProcessing:
verbose_proxy_logger.exception(
f"litellm.proxy.proxy_server._handle_llm_api_exception(): Exception occured - {str(e)}"
)
await proxy_logging_obj.post_call_failure_hook(
# Allow callbacks to transform the error response
transformed_exception = await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data=self.data,
)
# Use transformed exception if callback returned one, otherwise use original
if transformed_exception is not None:
e = transformed_exception
litellm_debug_info = getattr(e, "litellm_debug_info", "")
verbose_proxy_logger.debug(
"\033[1;31mAn error occurred: %s %s\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`",
@@ -970,11 +974,15 @@ class ProxyBaseLLMRequestProcessing:
str(e)
)
)
await proxy_logging_obj.post_call_failure_hook(
# Allow callbacks to transform the error response
transformed_exception = await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data=request_data,
)
# Use transformed exception if callback returned one, otherwise use original
if transformed_exception is not None:
e = transformed_exception
verbose_proxy_logger.debug(
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
)
+26 -6
View File
@@ -1465,9 +1465,10 @@ class ProxyLogging:
error_type: Optional[ProxyErrorTypes] = None,
route: Optional[str] = None,
traceback_str: Optional[str] = None,
):
) -> Optional[HTTPException]:
"""
Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body.
Callbacks can return or raise HTTPException to transform error responses sent to clients.
Covers:
1. /chat/completions
@@ -1481,6 +1482,10 @@ class ProxyLogging:
- error_type: Optional[ProxyErrorTypes] - The error type.
- route: Optional[str] - The route.
- traceback_str: Optional[str] - The traceback string, sometimes upstream endpoints might need to send the upstream traceback. In which case we use this
Returns:
- Optional[HTTPException]: If any callback returns or raises an HTTPException, the first one found is returned.
Otherwise, returns None and the original exception is used.
"""
### ALERTING ###
@@ -1522,6 +1527,9 @@ class ProxyLogging:
original_exception=original_exception,
)
# Track the first HTTPException returned or raised by any callback
transformed_exception: Optional[HTTPException] = None
for callback in litellm.callbacks:
try:
_callback: Optional[CustomLogger] = None
@@ -1532,19 +1540,31 @@ class ProxyLogging:
else:
_callback = callback # type: ignore
if _callback is not None and isinstance(_callback, CustomLogger):
asyncio.create_task(
_callback.async_post_call_failure_hook(
try:
hook_result = await _callback.async_post_call_failure_hook(
request_data=request_data,
user_api_key_dict=user_api_key_dict,
original_exception=original_exception,
traceback_str=traceback_str,
)
)
# If callback returned an HTTPException, use it (first one wins)
if isinstance(hook_result, HTTPException) and transformed_exception is None:
transformed_exception = hook_result
except HTTPException as e:
# If callback raised an HTTPException, use it (first one wins)
if transformed_exception is None:
transformed_exception = e
except Exception as e:
# Log non-HTTPException errors from callbacks but don't break the flow
verbose_proxy_logger.exception(
f"[Non-Blocking] Error in async_post_call_failure_hook callback: {e}"
)
except Exception as e:
verbose_proxy_logger.exception(
f"[Non-Blocking] Error in post_call_failure_hook: {e}"
f"[Non-Blocking] Error setting up post_call_failure_hook callback: {e}"
)
return
return transformed_exception
def _is_proxy_only_llm_api_error(
self,
@@ -278,8 +278,8 @@ async def test_proxy_admin_expired_key_from_cache():
mock_proxy_logging_obj.internal_usage_cache = MagicMock()
mock_proxy_logging_obj.internal_usage_cache.dual_cache = AsyncMock()
mock_proxy_logging_obj.internal_usage_cache.dual_cache.async_delete_cache = AsyncMock()
# Mock post_call_failure_hook as async function
mock_proxy_logging_obj.post_call_failure_hook = AsyncMock()
# Mock post_call_failure_hook as async function returning None (no transformation)
mock_proxy_logging_obj.post_call_failure_hook = AsyncMock(return_value=None)
# Mock prisma_client
mock_prisma_client = MagicMock()
@@ -0,0 +1,146 @@
"""
Integration tests for async_post_call_failure_hook.
Tests verify that the failure hook can transform error responses sent to clients,
similar to how async_post_call_success_hook can transform successful responses.
"""
import os
import sys
import pytest
from typing import Optional
from unittest.mock import patch
sys.path.insert(0, os.path.abspath("../../../.."))
from fastapi import HTTPException
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
class ErrorTransformerLogger(CustomLogger):
"""Logger that transforms errors into user-friendly messages"""
def __init__(self):
self.called = False
self.transformed_exception = None
async def async_post_call_failure_hook(
self,
request_data: dict,
original_exception: Exception,
user_api_key_dict: UserAPIKeyAuth,
traceback_str: Optional[str] = None,
):
self.called = True
self.transformed_exception = HTTPException(
status_code=400,
detail="User-friendly error: Your request could not be processed."
)
return self.transformed_exception
@pytest.mark.asyncio
async def test_failure_hook_transforms_error_response():
"""
Test that async_post_call_failure_hook can transform error responses.
This mirrors how async_post_call_success_hook can transform successful responses.
"""
transformer = ErrorTransformerLogger()
# Mock litellm.callbacks to include our transformer
with patch("litellm.callbacks", [transformer]):
from litellm.proxy.utils import ProxyLogging
from litellm.caching.caching import DualCache
proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
original_exception = Exception("Technical error message")
request_data = {"model": "test-model"}
user_api_key_dict = UserAPIKeyAuth(api_key="test-key")
# Call the hook
result = await proxy_logging.post_call_failure_hook(
request_data=request_data,
original_exception=original_exception,
user_api_key_dict=user_api_key_dict,
)
# Verify hook was called
assert transformer.called is True
# Verify transformed exception is returned
assert result is not None
assert isinstance(result, HTTPException)
assert result.detail == "User-friendly error: Your request could not be processed."
@pytest.mark.asyncio
async def test_failure_hook_returns_none_when_no_transformation():
"""
Test that hook returning None uses original exception.
"""
class NoOpLogger(CustomLogger):
def __init__(self):
self.called = False
async def async_post_call_failure_hook(self, *args, **kwargs):
self.called = True
return None
logger = NoOpLogger()
with patch("litellm.callbacks", [logger]):
from litellm.proxy.utils import ProxyLogging
from litellm.caching.caching import DualCache
proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
original_exception = Exception("Original error")
request_data = {"model": "test"}
user_api_key_dict = UserAPIKeyAuth(api_key="test")
result = await proxy_logging.post_call_failure_hook(
request_data=request_data,
original_exception=original_exception,
user_api_key_dict=user_api_key_dict,
)
# Should return None (original exception will be used)
assert result is None
assert logger.called is True
@pytest.mark.asyncio
async def test_failure_hook_handles_exceptions_gracefully():
"""
Test that hook failures don't break the error flow.
"""
class FailingLogger(CustomLogger):
def __init__(self):
self.called = False
async def async_post_call_failure_hook(self, *args, **kwargs):
self.called = True
raise RuntimeError("Hook crashed!")
logger = FailingLogger()
with patch("litellm.callbacks", [logger]):
from litellm.proxy.utils import ProxyLogging
from litellm.caching.caching import DualCache
proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
original_exception = Exception("Original error")
request_data = {"model": "test"}
user_api_key_dict = UserAPIKeyAuth(api_key="test")
# Should not raise, should handle gracefully
result = await proxy_logging.post_call_failure_hook(
request_data=request_data,
original_exception=original_exception,
user_api_key_dict=user_api_key_dict,
)
# Should return None (original exception will be used)
assert result is None
assert logger.called is True
@@ -1148,7 +1148,7 @@ class TestBedrockLLMProxyRoute:
mock_user_api_key_dict.allowed_model_region = None
mock_proxy_logging_obj = Mock()
mock_proxy_logging_obj.post_call_failure_hook = AsyncMock()
mock_proxy_logging_obj.post_call_failure_hook = AsyncMock(return_value=None)
endpoint = "model/test-model/converse"
model = "test-model"
@@ -1291,7 +1291,7 @@ class TestBedrockLLMProxyRoute:
mock_user_api_key_dict = Mock()
mock_user_api_key_dict.api_key = "test-key"
mock_proxy_logging_obj = Mock()
mock_proxy_logging_obj.post_call_failure_hook = AsyncMock()
mock_proxy_logging_obj.post_call_failure_hook = AsyncMock(return_value=None)
with patch(
"litellm.passthrough.main.llm_passthrough_route",