mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 07:33:58 +00:00
[Feat] Enable async_post_call_failure_hook to transform error responses (#18348)
This commit is contained in:
@@ -17,6 +17,7 @@ import Image from '@theme/IdealImage';
|
||||
| `async_pre_call_hook` | Modify incoming request before it's sent to model | Before the LLM API call is made |
|
||||
| `async_moderation_hook` | Run checks on input in parallel to LLM API call | In parallel with the LLM API call |
|
||||
| `async_post_call_success_hook` | Modify outgoing response (non-streaming) | After successful LLM API call, for non-streaming responses |
|
||||
| `async_post_call_failure_hook` | Transform error responses sent to clients | After failed LLM API call |
|
||||
| `async_post_call_streaming_hook` | Modify outgoing response (streaming) | After successful LLM API call, for streaming responses |
|
||||
|
||||
See a complete example with our [parallel request rate limiter](https://github.com/BerriAI/litellm/blob/main/litellm/proxy/hooks/parallel_request_limiter.py)
|
||||
@@ -60,7 +61,21 @@ class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observabilit
|
||||
original_exception: Exception,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
traceback_str: Optional[str] = None,
|
||||
):
|
||||
) -> Optional[HTTPException]:
|
||||
"""
|
||||
Transform error responses sent to clients.
|
||||
|
||||
Return an HTTPException to replace the original error with a user-friendly message.
|
||||
Return None to use the original exception.
|
||||
|
||||
Example:
|
||||
if isinstance(original_exception, litellm.ContextWindowExceededError):
|
||||
return HTTPException(
|
||||
status_code=400,
|
||||
detail="Your prompt is too long. Please reduce the length and try again."
|
||||
)
|
||||
return None # Use original exception
|
||||
"""
|
||||
pass
|
||||
|
||||
async def async_post_call_success_hook(
|
||||
@@ -339,3 +354,38 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
"usage": {}
|
||||
}
|
||||
```
|
||||
|
||||
## Advanced - Transform Error Responses
|
||||
|
||||
Transform technical API errors into user-friendly messages using `async_post_call_failure_hook`. Return an `HTTPException` to replace the original error, or `None` to use the original exception.
|
||||
|
||||
```python
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from fastapi import HTTPException
|
||||
from typing import Optional
|
||||
import litellm
|
||||
|
||||
class MyErrorTransformer(CustomLogger):
|
||||
async def async_post_call_failure_hook(
|
||||
self,
|
||||
request_data: dict,
|
||||
original_exception: Exception,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
traceback_str: Optional[str] = None,
|
||||
) -> Optional[HTTPException]:
|
||||
if isinstance(original_exception, litellm.ContextWindowExceededError):
|
||||
return HTTPException(
|
||||
status_code=400,
|
||||
detail="Your prompt is too long. Please reduce the length and try again."
|
||||
)
|
||||
if isinstance(original_exception, litellm.RateLimitError):
|
||||
return HTTPException(
|
||||
status_code=429,
|
||||
detail="Rate limit exceeded. Please try again in a moment."
|
||||
)
|
||||
return None # Use original exception
|
||||
|
||||
proxy_handler_instance = MyErrorTransformer()
|
||||
```
|
||||
|
||||
**Result:** Clients receive `"Your prompt is too long..."` instead of `"ContextWindowExceededError: Prompt exceeds context window"`.
|
||||
|
||||
@@ -32,6 +32,8 @@ from litellm.types.utils import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm.caching.caching import DualCache
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
@@ -348,7 +350,20 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
|
||||
original_exception: Exception,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
traceback_str: Optional[str] = None,
|
||||
):
|
||||
) -> Optional["HTTPException"]:
|
||||
"""
|
||||
Called after an LLM API call fails. Can return or raise HTTPException to transform error responses.
|
||||
|
||||
Args:
|
||||
- request_data: dict - The request data.
|
||||
- original_exception: Exception - The original exception that occurred.
|
||||
- user_api_key_dict: UserAPIKeyAuth - The user API key dictionary.
|
||||
- traceback_str: Optional[str] - The traceback string.
|
||||
|
||||
Returns:
|
||||
- Optional[HTTPException]: Return an HTTPException to transform the error response sent to the client.
|
||||
Return None to use the original exception.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def async_post_call_success_hook(
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
Handles Authentication Errors
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
@@ -90,15 +89,17 @@ class UserAPIKeyAuthExceptionHandler:
|
||||
api_key=api_key,
|
||||
request_route=route,
|
||||
)
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.post_call_failure_hook(
|
||||
request_data=request_data,
|
||||
original_exception=e,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
error_type=ProxyErrorTypes.auth_error,
|
||||
route=route,
|
||||
)
|
||||
# Allow callbacks to transform the error response
|
||||
transformed_exception = await proxy_logging_obj.post_call_failure_hook(
|
||||
request_data=request_data,
|
||||
original_exception=e,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
error_type=ProxyErrorTypes.auth_error,
|
||||
route=route,
|
||||
)
|
||||
# Use transformed exception if callback returned one, otherwise use original
|
||||
if transformed_exception is not None:
|
||||
e = transformed_exception
|
||||
|
||||
if isinstance(e, litellm.BudgetExceededError):
|
||||
raise ProxyException(
|
||||
|
||||
@@ -786,11 +786,15 @@ class ProxyBaseLLMRequestProcessing:
|
||||
verbose_proxy_logger.exception(
|
||||
f"litellm.proxy.proxy_server._handle_llm_api_exception(): Exception occured - {str(e)}"
|
||||
)
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
# Allow callbacks to transform the error response
|
||||
transformed_exception = await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
original_exception=e,
|
||||
request_data=self.data,
|
||||
)
|
||||
# Use transformed exception if callback returned one, otherwise use original
|
||||
if transformed_exception is not None:
|
||||
e = transformed_exception
|
||||
litellm_debug_info = getattr(e, "litellm_debug_info", "")
|
||||
verbose_proxy_logger.debug(
|
||||
"\033[1;31mAn error occurred: %s %s\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`",
|
||||
@@ -970,11 +974,15 @@ class ProxyBaseLLMRequestProcessing:
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
# Allow callbacks to transform the error response
|
||||
transformed_exception = await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
original_exception=e,
|
||||
request_data=request_data,
|
||||
)
|
||||
# Use transformed exception if callback returned one, otherwise use original
|
||||
if transformed_exception is not None:
|
||||
e = transformed_exception
|
||||
verbose_proxy_logger.debug(
|
||||
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
|
||||
)
|
||||
|
||||
+26
-6
@@ -1465,9 +1465,10 @@ class ProxyLogging:
|
||||
error_type: Optional[ProxyErrorTypes] = None,
|
||||
route: Optional[str] = None,
|
||||
traceback_str: Optional[str] = None,
|
||||
):
|
||||
) -> Optional[HTTPException]:
|
||||
"""
|
||||
Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body.
|
||||
Callbacks can return or raise HTTPException to transform error responses sent to clients.
|
||||
|
||||
Covers:
|
||||
1. /chat/completions
|
||||
@@ -1481,6 +1482,10 @@ class ProxyLogging:
|
||||
- error_type: Optional[ProxyErrorTypes] - The error type.
|
||||
- route: Optional[str] - The route.
|
||||
- traceback_str: Optional[str] - The traceback string, sometimes upstream endpoints might need to send the upstream traceback. In which case we use this
|
||||
|
||||
Returns:
|
||||
- Optional[HTTPException]: If any callback returns or raises an HTTPException, the first one found is returned.
|
||||
Otherwise, returns None and the original exception is used.
|
||||
"""
|
||||
|
||||
### ALERTING ###
|
||||
@@ -1522,6 +1527,9 @@ class ProxyLogging:
|
||||
original_exception=original_exception,
|
||||
)
|
||||
|
||||
# Track the first HTTPException returned or raised by any callback
|
||||
transformed_exception: Optional[HTTPException] = None
|
||||
|
||||
for callback in litellm.callbacks:
|
||||
try:
|
||||
_callback: Optional[CustomLogger] = None
|
||||
@@ -1532,19 +1540,31 @@ class ProxyLogging:
|
||||
else:
|
||||
_callback = callback # type: ignore
|
||||
if _callback is not None and isinstance(_callback, CustomLogger):
|
||||
asyncio.create_task(
|
||||
_callback.async_post_call_failure_hook(
|
||||
try:
|
||||
hook_result = await _callback.async_post_call_failure_hook(
|
||||
request_data=request_data,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
original_exception=original_exception,
|
||||
traceback_str=traceback_str,
|
||||
)
|
||||
)
|
||||
# If callback returned an HTTPException, use it (first one wins)
|
||||
if isinstance(hook_result, HTTPException) and transformed_exception is None:
|
||||
transformed_exception = hook_result
|
||||
except HTTPException as e:
|
||||
# If callback raised an HTTPException, use it (first one wins)
|
||||
if transformed_exception is None:
|
||||
transformed_exception = e
|
||||
except Exception as e:
|
||||
# Log non-HTTPException errors from callbacks but don't break the flow
|
||||
verbose_proxy_logger.exception(
|
||||
f"[Non-Blocking] Error in async_post_call_failure_hook callback: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
f"[Non-Blocking] Error in post_call_failure_hook: {e}"
|
||||
f"[Non-Blocking] Error setting up post_call_failure_hook callback: {e}"
|
||||
)
|
||||
return
|
||||
|
||||
return transformed_exception
|
||||
|
||||
def _is_proxy_only_llm_api_error(
|
||||
self,
|
||||
|
||||
@@ -278,8 +278,8 @@ async def test_proxy_admin_expired_key_from_cache():
|
||||
mock_proxy_logging_obj.internal_usage_cache = MagicMock()
|
||||
mock_proxy_logging_obj.internal_usage_cache.dual_cache = AsyncMock()
|
||||
mock_proxy_logging_obj.internal_usage_cache.dual_cache.async_delete_cache = AsyncMock()
|
||||
# Mock post_call_failure_hook as async function
|
||||
mock_proxy_logging_obj.post_call_failure_hook = AsyncMock()
|
||||
# Mock post_call_failure_hook as async function returning None (no transformation)
|
||||
mock_proxy_logging_obj.post_call_failure_hook = AsyncMock(return_value=None)
|
||||
|
||||
# Mock prisma_client
|
||||
mock_prisma_client = MagicMock()
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
Integration tests for async_post_call_failure_hook.
|
||||
|
||||
Tests verify that the failure hook can transform error responses sent to clients,
|
||||
similar to how async_post_call_success_hook can transform successful responses.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
from typing import Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../../../.."))
|
||||
|
||||
from fastapi import HTTPException
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
|
||||
class ErrorTransformerLogger(CustomLogger):
|
||||
"""Logger that transforms errors into user-friendly messages"""
|
||||
|
||||
def __init__(self):
|
||||
self.called = False
|
||||
self.transformed_exception = None
|
||||
|
||||
async def async_post_call_failure_hook(
|
||||
self,
|
||||
request_data: dict,
|
||||
original_exception: Exception,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
traceback_str: Optional[str] = None,
|
||||
):
|
||||
self.called = True
|
||||
self.transformed_exception = HTTPException(
|
||||
status_code=400,
|
||||
detail="User-friendly error: Your request could not be processed."
|
||||
)
|
||||
return self.transformed_exception
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failure_hook_transforms_error_response():
|
||||
"""
|
||||
Test that async_post_call_failure_hook can transform error responses.
|
||||
This mirrors how async_post_call_success_hook can transform successful responses.
|
||||
"""
|
||||
transformer = ErrorTransformerLogger()
|
||||
|
||||
# Mock litellm.callbacks to include our transformer
|
||||
with patch("litellm.callbacks", [transformer]):
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
from litellm.caching.caching import DualCache
|
||||
|
||||
proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
|
||||
original_exception = Exception("Technical error message")
|
||||
request_data = {"model": "test-model"}
|
||||
user_api_key_dict = UserAPIKeyAuth(api_key="test-key")
|
||||
|
||||
# Call the hook
|
||||
result = await proxy_logging.post_call_failure_hook(
|
||||
request_data=request_data,
|
||||
original_exception=original_exception,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
|
||||
# Verify hook was called
|
||||
assert transformer.called is True
|
||||
|
||||
# Verify transformed exception is returned
|
||||
assert result is not None
|
||||
assert isinstance(result, HTTPException)
|
||||
assert result.detail == "User-friendly error: Your request could not be processed."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failure_hook_returns_none_when_no_transformation():
|
||||
"""
|
||||
Test that hook returning None uses original exception.
|
||||
"""
|
||||
class NoOpLogger(CustomLogger):
|
||||
def __init__(self):
|
||||
self.called = False
|
||||
|
||||
async def async_post_call_failure_hook(self, *args, **kwargs):
|
||||
self.called = True
|
||||
return None
|
||||
|
||||
logger = NoOpLogger()
|
||||
|
||||
with patch("litellm.callbacks", [logger]):
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
from litellm.caching.caching import DualCache
|
||||
|
||||
proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
|
||||
original_exception = Exception("Original error")
|
||||
request_data = {"model": "test"}
|
||||
user_api_key_dict = UserAPIKeyAuth(api_key="test")
|
||||
|
||||
result = await proxy_logging.post_call_failure_hook(
|
||||
request_data=request_data,
|
||||
original_exception=original_exception,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
|
||||
# Should return None (original exception will be used)
|
||||
assert result is None
|
||||
assert logger.called is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failure_hook_handles_exceptions_gracefully():
|
||||
"""
|
||||
Test that hook failures don't break the error flow.
|
||||
"""
|
||||
class FailingLogger(CustomLogger):
|
||||
def __init__(self):
|
||||
self.called = False
|
||||
|
||||
async def async_post_call_failure_hook(self, *args, **kwargs):
|
||||
self.called = True
|
||||
raise RuntimeError("Hook crashed!")
|
||||
|
||||
logger = FailingLogger()
|
||||
|
||||
with patch("litellm.callbacks", [logger]):
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
from litellm.caching.caching import DualCache
|
||||
|
||||
proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
|
||||
original_exception = Exception("Original error")
|
||||
request_data = {"model": "test"}
|
||||
user_api_key_dict = UserAPIKeyAuth(api_key="test")
|
||||
|
||||
# Should not raise, should handle gracefully
|
||||
result = await proxy_logging.post_call_failure_hook(
|
||||
request_data=request_data,
|
||||
original_exception=original_exception,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
|
||||
# Should return None (original exception will be used)
|
||||
assert result is None
|
||||
assert logger.called is True
|
||||
|
||||
@@ -1148,7 +1148,7 @@ class TestBedrockLLMProxyRoute:
|
||||
mock_user_api_key_dict.allowed_model_region = None
|
||||
|
||||
mock_proxy_logging_obj = Mock()
|
||||
mock_proxy_logging_obj.post_call_failure_hook = AsyncMock()
|
||||
mock_proxy_logging_obj.post_call_failure_hook = AsyncMock(return_value=None)
|
||||
|
||||
endpoint = "model/test-model/converse"
|
||||
model = "test-model"
|
||||
@@ -1291,7 +1291,7 @@ class TestBedrockLLMProxyRoute:
|
||||
mock_user_api_key_dict = Mock()
|
||||
mock_user_api_key_dict.api_key = "test-key"
|
||||
mock_proxy_logging_obj = Mock()
|
||||
mock_proxy_logging_obj.post_call_failure_hook = AsyncMock()
|
||||
mock_proxy_logging_obj.post_call_failure_hook = AsyncMock(return_value=None)
|
||||
|
||||
with patch(
|
||||
"litellm.passthrough.main.llm_passthrough_route",
|
||||
|
||||
Reference in New Issue
Block a user