diff --git a/docs/my-website/docs/proxy/call_hooks.md b/docs/my-website/docs/proxy/call_hooks.md index fa420009cf..fe865f67e0 100644 --- a/docs/my-website/docs/proxy/call_hooks.md +++ b/docs/my-website/docs/proxy/call_hooks.md @@ -17,6 +17,7 @@ import Image from '@theme/IdealImage'; | `async_pre_call_hook` | Modify incoming request before it's sent to model | Before the LLM API call is made | | `async_moderation_hook` | Run checks on input in parallel to LLM API call | In parallel with the LLM API call | | `async_post_call_success_hook` | Modify outgoing response (non-streaming) | After successful LLM API call, for non-streaming responses | +| `async_post_call_failure_hook` | Transform error responses sent to clients | After failed LLM API call | | `async_post_call_streaming_hook` | Modify outgoing response (streaming) | After successful LLM API call, for streaming responses | See a complete example with our [parallel request rate limiter](https://github.com/BerriAI/litellm/blob/main/litellm/proxy/hooks/parallel_request_limiter.py) @@ -60,7 +61,21 @@ class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observabilit original_exception: Exception, user_api_key_dict: UserAPIKeyAuth, traceback_str: Optional[str] = None, - ): + ) -> Optional[HTTPException]: + """ + Transform error responses sent to clients. + + Return an HTTPException to replace the original error with a user-friendly message. + Return None to use the original exception. + + Example: + if isinstance(original_exception, litellm.ContextWindowExceededError): + return HTTPException( + status_code=400, + detail="Your prompt is too long. Please reduce the length and try again." 
+ ) + return None # Use original exception + """ pass async def async_post_call_success_hook( @@ -339,3 +354,39 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \ "usage": {} } ``` + +## Advanced - Transform Error Responses + +Transform technical API errors into user-friendly messages using `async_post_call_failure_hook`. Return an `HTTPException` to replace the original error, or `None` to use the original exception. When several callbacks return or raise an `HTTPException`, the first one is used. + +```python +from litellm.integrations.custom_logger import CustomLogger +from litellm.proxy._types import UserAPIKeyAuth +from fastapi import HTTPException +from typing import Optional +import litellm + +class MyErrorTransformer(CustomLogger): + async def async_post_call_failure_hook( + self, + request_data: dict, + original_exception: Exception, + user_api_key_dict: UserAPIKeyAuth, + traceback_str: Optional[str] = None, + ) -> Optional[HTTPException]: + if isinstance(original_exception, litellm.ContextWindowExceededError): + return HTTPException( + status_code=400, + detail="Your prompt is too long. Please reduce the length and try again." + ) + if isinstance(original_exception, litellm.RateLimitError): + return HTTPException( + status_code=429, + detail="Rate limit exceeded. Please try again in a moment." + ) + return None # Use original exception + +proxy_handler_instance = MyErrorTransformer() +``` + +**Result:** Clients receive `"Your prompt is too long..."` instead of `"ContextWindowExceededError: Prompt exceeds context window"`. 
diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index 6771999cd3..4c4e6fa634 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -32,6 +32,8 @@ from litellm.types.utils import ( ) if TYPE_CHECKING: + from fastapi import HTTPException + from litellm.caching.caching import DualCache from opentelemetry.trace import Span as _Span @@ -348,7 +350,20 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac original_exception: Exception, user_api_key_dict: UserAPIKeyAuth, traceback_str: Optional[str] = None, - ): + ) -> Optional["HTTPException"]: + """ + Called after an LLM API call fails. Can return or raise HTTPException to transform error responses. + + Args: + - request_data: dict - The request data. + - original_exception: Exception - The original exception that occurred. + - user_api_key_dict: UserAPIKeyAuth - The user API key dictionary. + - traceback_str: Optional[str] - The traceback string. + + Returns: + - Optional[HTTPException]: Return an HTTPException to transform the error response sent to the client. + Return None to use the original exception. 
+ """ pass async def async_post_call_success_hook( diff --git a/litellm/proxy/auth/auth_exception_handler.py b/litellm/proxy/auth/auth_exception_handler.py index 2b9c4cdce6..9c306acd2c 100644 --- a/litellm/proxy/auth/auth_exception_handler.py +++ b/litellm/proxy/auth/auth_exception_handler.py @@ -2,7 +2,6 @@ Handles Authentication Errors """ -import asyncio from typing import TYPE_CHECKING, Any, Optional, Union from fastapi import HTTPException, Request, status @@ -90,15 +89,17 @@ class UserAPIKeyAuthExceptionHandler: api_key=api_key, request_route=route, ) - asyncio.create_task( - proxy_logging_obj.post_call_failure_hook( - request_data=request_data, - original_exception=e, - user_api_key_dict=user_api_key_dict, - error_type=ProxyErrorTypes.auth_error, - route=route, - ) + # Allow callbacks to transform the error response + transformed_exception = await proxy_logging_obj.post_call_failure_hook( + request_data=request_data, + original_exception=e, + user_api_key_dict=user_api_key_dict, + error_type=ProxyErrorTypes.auth_error, + route=route, ) + # Use transformed exception if callback returned one, otherwise use original + if transformed_exception is not None: + e = transformed_exception if isinstance(e, litellm.BudgetExceededError): raise ProxyException( diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index f798d218f1..302dd5639e 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -786,11 +786,15 @@ class ProxyBaseLLMRequestProcessing: verbose_proxy_logger.exception( f"litellm.proxy.proxy_server._handle_llm_api_exception(): Exception occured - {str(e)}" ) - await proxy_logging_obj.post_call_failure_hook( + # Allow callbacks to transform the error response + transformed_exception = await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=self.data, ) + # Use transformed exception if callback 
returned one, otherwise use original + if transformed_exception is not None: + e = transformed_exception litellm_debug_info = getattr(e, "litellm_debug_info", "") verbose_proxy_logger.debug( "\033[1;31mAn error occurred: %s %s\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`", @@ -970,11 +974,15 @@ class ProxyBaseLLMRequestProcessing: str(e) ) ) - await proxy_logging_obj.post_call_failure_hook( + # Allow callbacks to transform the error response + transformed_exception = await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=request_data, ) + # Use transformed exception if callback returned one, otherwise use original + if transformed_exception is not None: + e = transformed_exception verbose_proxy_logger.debug( f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`" ) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index ec86139c73..d595db4a2e 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1465,9 +1465,10 @@ class ProxyLogging: error_type: Optional[ProxyErrorTypes] = None, route: Optional[str] = None, traceback_str: Optional[str] = None, - ): + ) -> Optional[HTTPException]: """ Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body. + Callbacks can return or raise HTTPException to transform error responses sent to clients. Covers: 1. /chat/completions @@ -1481,6 +1482,10 @@ class ProxyLogging: - error_type: Optional[ProxyErrorTypes] - The error type. - route: Optional[str] - The route. - traceback_str: Optional[str] - The traceback string, sometimes upstream endpoints might need to send the upstream traceback. In which case we use this + + Returns: + - Optional[HTTPException]: If any callback returns or raises an HTTPException, the first one found is returned. 
+ Otherwise, returns None and the original exception is used. """ ### ALERTING ### @@ -1522,6 +1527,9 @@ class ProxyLogging: original_exception=original_exception, ) + # Track the first HTTPException returned or raised by any callback + transformed_exception: Optional[HTTPException] = None + for callback in litellm.callbacks: try: _callback: Optional[CustomLogger] = None @@ -1532,19 +1540,31 @@ class ProxyLogging: else: _callback = callback # type: ignore if _callback is not None and isinstance(_callback, CustomLogger): - asyncio.create_task( - _callback.async_post_call_failure_hook( + try: + hook_result = await _callback.async_post_call_failure_hook( request_data=request_data, user_api_key_dict=user_api_key_dict, original_exception=original_exception, traceback_str=traceback_str, ) - ) + # If callback returned an HTTPException, use it (first one wins) + if isinstance(hook_result, HTTPException) and transformed_exception is None: + transformed_exception = hook_result + except HTTPException as e: + # If callback raised an HTTPException, use it (first one wins) + if transformed_exception is None: + transformed_exception = e + except Exception as e: + # Log non-HTTPException errors from callbacks but don't break the flow + verbose_proxy_logger.exception( + f"[Non-Blocking] Error in async_post_call_failure_hook callback: {e}" + ) except Exception as e: verbose_proxy_logger.exception( - f"[Non-Blocking] Error in post_call_failure_hook: {e}" + f"[Non-Blocking] Error setting up post_call_failure_hook callback: {e}" ) - return + + return transformed_exception def _is_proxy_only_llm_api_error( self, diff --git a/tests/test_litellm/proxy/auth/test_user_api_key_auth.py b/tests/test_litellm/proxy/auth/test_user_api_key_auth.py index 04aeddb8f2..fcc8c1f0f2 100644 --- a/tests/test_litellm/proxy/auth/test_user_api_key_auth.py +++ b/tests/test_litellm/proxy/auth/test_user_api_key_auth.py @@ -278,8 +278,8 @@ async def test_proxy_admin_expired_key_from_cache(): 
mock_proxy_logging_obj.internal_usage_cache = MagicMock() mock_proxy_logging_obj.internal_usage_cache.dual_cache = AsyncMock() mock_proxy_logging_obj.internal_usage_cache.dual_cache.async_delete_cache = AsyncMock() - # Mock post_call_failure_hook as async function - mock_proxy_logging_obj.post_call_failure_hook = AsyncMock() + # Mock post_call_failure_hook as async function returning None (no transformation) + mock_proxy_logging_obj.post_call_failure_hook = AsyncMock(return_value=None) # Mock prisma_client mock_prisma_client = MagicMock() diff --git a/tests/test_litellm/proxy/hooks/test_post_call_failure_hook_integration.py b/tests/test_litellm/proxy/hooks/test_post_call_failure_hook_integration.py new file mode 100644 index 0000000000..7223c2e1f0 --- /dev/null +++ b/tests/test_litellm/proxy/hooks/test_post_call_failure_hook_integration.py @@ -0,0 +1,146 @@ +""" +Integration tests for async_post_call_failure_hook. + +Tests verify that the failure hook can transform error responses sent to clients, +similar to how async_post_call_success_hook can transform successful responses. +""" + +import os +import sys +import pytest +from typing import Optional +from unittest.mock import patch + +sys.path.insert(0, os.path.abspath("../../../..")) + +from fastapi import HTTPException +from litellm.integrations.custom_logger import CustomLogger +from litellm.proxy._types import UserAPIKeyAuth + + +class ErrorTransformerLogger(CustomLogger): + """Logger that transforms errors into user-friendly messages""" + + def __init__(self): + self.called = False + self.transformed_exception = None + + async def async_post_call_failure_hook( + self, + request_data: dict, + original_exception: Exception, + user_api_key_dict: UserAPIKeyAuth, + traceback_str: Optional[str] = None, + ): + self.called = True + self.transformed_exception = HTTPException( + status_code=400, + detail="User-friendly error: Your request could not be processed." 
+ ) + return self.transformed_exception + + +@pytest.mark.asyncio +async def test_failure_hook_transforms_error_response(): + """ + Test that async_post_call_failure_hook can transform error responses. + This mirrors how async_post_call_success_hook can transform successful responses. + """ + transformer = ErrorTransformerLogger() + + # Mock litellm.callbacks to include our transformer + with patch("litellm.callbacks", [transformer]): + from litellm.proxy.utils import ProxyLogging + from litellm.caching.caching import DualCache + + proxy_logging = ProxyLogging(user_api_key_cache=DualCache()) + original_exception = Exception("Technical error message") + request_data = {"model": "test-model"} + user_api_key_dict = UserAPIKeyAuth(api_key="test-key") + + # Call the hook + result = await proxy_logging.post_call_failure_hook( + request_data=request_data, + original_exception=original_exception, + user_api_key_dict=user_api_key_dict, + ) + + # Verify hook was called + assert transformer.called is True + + # Verify transformed exception is returned + assert result is not None + assert isinstance(result, HTTPException) + assert result.detail == "User-friendly error: Your request could not be processed." + + +@pytest.mark.asyncio +async def test_failure_hook_returns_none_when_no_transformation(): + """ + Test that hook returning None uses original exception. 
+ """ + class NoOpLogger(CustomLogger): + def __init__(self): + self.called = False + + async def async_post_call_failure_hook(self, *args, **kwargs): + self.called = True + return None + + logger = NoOpLogger() + + with patch("litellm.callbacks", [logger]): + from litellm.proxy.utils import ProxyLogging + from litellm.caching.caching import DualCache + + proxy_logging = ProxyLogging(user_api_key_cache=DualCache()) + original_exception = Exception("Original error") + request_data = {"model": "test"} + user_api_key_dict = UserAPIKeyAuth(api_key="test") + + result = await proxy_logging.post_call_failure_hook( + request_data=request_data, + original_exception=original_exception, + user_api_key_dict=user_api_key_dict, + ) + + # Should return None (original exception will be used) + assert result is None + assert logger.called is True + + +@pytest.mark.asyncio +async def test_failure_hook_handles_exceptions_gracefully(): + """ + Test that hook failures don't break the error flow. + """ + class FailingLogger(CustomLogger): + def __init__(self): + self.called = False + + async def async_post_call_failure_hook(self, *args, **kwargs): + self.called = True + raise RuntimeError("Hook crashed!") + + logger = FailingLogger() + + with patch("litellm.callbacks", [logger]): + from litellm.proxy.utils import ProxyLogging + from litellm.caching.caching import DualCache + + proxy_logging = ProxyLogging(user_api_key_cache=DualCache()) + original_exception = Exception("Original error") + request_data = {"model": "test"} + user_api_key_dict = UserAPIKeyAuth(api_key="test") + + # Should not raise, should handle gracefully + result = await proxy_logging.post_call_failure_hook( + request_data=request_data, + original_exception=original_exception, + user_api_key_dict=user_api_key_dict, + ) + + # Should return None (original exception will be used) + assert result is None + assert logger.called is True + diff --git 
a/tests/test_litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py b/tests/test_litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py index b0e198d5e7..0bb9924af8 100644 --- a/tests/test_litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py +++ b/tests/test_litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py @@ -1148,7 +1148,7 @@ class TestBedrockLLMProxyRoute: mock_user_api_key_dict.allowed_model_region = None mock_proxy_logging_obj = Mock() - mock_proxy_logging_obj.post_call_failure_hook = AsyncMock() + mock_proxy_logging_obj.post_call_failure_hook = AsyncMock(return_value=None) endpoint = "model/test-model/converse" model = "test-model" @@ -1291,7 +1291,7 @@ class TestBedrockLLMProxyRoute: mock_user_api_key_dict = Mock() mock_user_api_key_dict.api_key = "test-key" mock_proxy_logging_obj = Mock() - mock_proxy_logging_obj.post_call_failure_hook = AsyncMock() + mock_proxy_logging_obj.post_call_failure_hook = AsyncMock(return_value=None) with patch( "litellm.passthrough.main.llm_passthrough_route",