mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 00:48:01 +00:00
Manual revert #19078
This commit is contained in:
@@ -1408,15 +1408,6 @@ class Router:
|
||||
async for item in model_response:
|
||||
yield item
|
||||
except MidStreamFallbackError as e:
|
||||
# Check if fallbacks are disabled by user
|
||||
if initial_kwargs.get("disable_fallbacks", False):
|
||||
verbose_router_logger.info(
|
||||
"Mid stream fallback disabled by user, re-raising original error"
|
||||
)
|
||||
if e.original_exception is not None:
|
||||
raise e.original_exception
|
||||
raise e
|
||||
|
||||
from litellm.main import stream_chunk_builder
|
||||
|
||||
complete_response_object = stream_chunk_builder(
|
||||
|
||||
@@ -1171,191 +1171,6 @@ async def test_acompletion_streaming_iterator_edge_cases():
|
||||
print("✓ Edge case tests passed!")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acompletion_streaming_disable_fallbacks_midstream():
|
||||
"""Test that disable_fallbacks=True prevents mid-stream fallback attempts."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from litellm.exceptions import MidStreamFallbackError
|
||||
|
||||
# Set up router with fallback configuration
|
||||
router = litellm.Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gpt-4",
|
||||
"litellm_params": {"model": "gpt-4", "api_key": "fake-key-1"},
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {"model": "gpt-3.5-turbo", "api_key": "fake-key-2"},
|
||||
},
|
||||
],
|
||||
fallbacks=[{"gpt-4": ["gpt-3.5-turbo"]}],
|
||||
set_verbose=True,
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
|
||||
# Test 1: disable_fallbacks=True with original_exception
|
||||
print("\n=== Test 1: disable_fallbacks=True with original_exception ===")
|
||||
|
||||
# Create an original exception to wrap
|
||||
from litellm.llms.anthropic.common_utils import AnthropicError
|
||||
|
||||
original_error = AnthropicError(
|
||||
status_code=500,
|
||||
message="An unexpected error occurred while processing the response",
|
||||
)
|
||||
|
||||
# Create MidStreamFallbackError with original_exception
|
||||
error_with_original = MidStreamFallbackError(
|
||||
message="Connection lost",
|
||||
model="gpt-4",
|
||||
llm_provider="openai",
|
||||
generated_content="Hello",
|
||||
original_exception=original_error,
|
||||
)
|
||||
|
||||
class AsyncIteratorWithError:
|
||||
def __init__(self, items, error_after_index, error):
|
||||
self.items = items
|
||||
self.index = 0
|
||||
self.error_after_index = error_after_index
|
||||
self.error = error
|
||||
self.chunks = []
|
||||
self.model = "gpt-4"
|
||||
self.custom_llm_provider = "openai"
|
||||
self.logging_obj = MagicMock()
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if self.index >= len(self.items):
|
||||
raise StopAsyncIteration
|
||||
if self.index == self.error_after_index:
|
||||
raise self.error
|
||||
item = self.items[self.index]
|
||||
self.index += 1
|
||||
self.chunks.append(item)
|
||||
return item
|
||||
|
||||
mock_chunks = [
|
||||
MagicMock(choices=[MagicMock(delta=MagicMock(content="Hello"))]),
|
||||
]
|
||||
|
||||
mock_error_response = AsyncIteratorWithError(
|
||||
mock_chunks, 1, error_with_original
|
||||
) # Error after first chunk
|
||||
|
||||
initial_kwargs = {"model": "gpt-4", "stream": True, "disable_fallbacks": True}
|
||||
|
||||
# Mock the fallback function to ensure it's NOT called
|
||||
with patch.object(
|
||||
router,
|
||||
"async_function_with_fallbacks_common_utils",
|
||||
return_value=MagicMock(),
|
||||
) as mock_fallback_utils:
|
||||
with pytest.raises(AnthropicError, match="An unexpected error occurred"):
|
||||
result = await router._acompletion_streaming_iterator(
|
||||
model_response=mock_error_response,
|
||||
messages=messages,
|
||||
initial_kwargs=initial_kwargs,
|
||||
)
|
||||
|
||||
async for chunk in result:
|
||||
pass # Should not reach here; exception should be raised
|
||||
|
||||
# Verify fallback was NOT called
|
||||
mock_fallback_utils.assert_not_called()
|
||||
print("✓ Original exception raised correctly when disable_fallbacks=True")
|
||||
|
||||
# Test 2: disable_fallbacks=True without original_exception
|
||||
print("\n=== Test 2: disable_fallbacks=True without original_exception ===")
|
||||
|
||||
error_without_original = MidStreamFallbackError(
|
||||
message="Connection lost",
|
||||
model="gpt-4",
|
||||
llm_provider="openai",
|
||||
generated_content="Hello",
|
||||
original_exception=None,
|
||||
)
|
||||
|
||||
mock_error_response_2 = AsyncIteratorWithError(
|
||||
mock_chunks, 1, error_without_original
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
router,
|
||||
"async_function_with_fallbacks_common_utils",
|
||||
return_value=MagicMock(),
|
||||
) as mock_fallback_utils:
|
||||
with pytest.raises(MidStreamFallbackError, match="Connection lost"):
|
||||
result = await router._acompletion_streaming_iterator(
|
||||
model_response=mock_error_response_2,
|
||||
messages=messages,
|
||||
initial_kwargs=initial_kwargs,
|
||||
)
|
||||
|
||||
async for chunk in result:
|
||||
pass # Should not reach here
|
||||
|
||||
# Verify fallback was NOT called
|
||||
mock_fallback_utils.assert_not_called()
|
||||
print(
|
||||
"✓ MidStreamFallbackError raised correctly when no original_exception and disable_fallbacks=True"
|
||||
)
|
||||
|
||||
# Test 3: disable_fallbacks=False (default behavior - fallback should work)
|
||||
print("\n=== Test 3: disable_fallbacks=False (fallback enabled) ===")
|
||||
|
||||
error_for_fallback = MidStreamFallbackError(
|
||||
message="Connection lost",
|
||||
model="gpt-4",
|
||||
llm_provider="openai",
|
||||
generated_content="Hello",
|
||||
)
|
||||
|
||||
mock_error_response_3 = AsyncIteratorWithError(mock_chunks, 1, error_for_fallback)
|
||||
|
||||
# Mock successful fallback response
|
||||
class EmptyAsyncIterator:
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
raise StopAsyncIteration
|
||||
|
||||
mock_fallback_response = EmptyAsyncIterator()
|
||||
|
||||
initial_kwargs_fallback_enabled = {
|
||||
"model": "gpt-4",
|
||||
"stream": True,
|
||||
"disable_fallbacks": False,
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
router,
|
||||
"async_function_with_fallbacks_common_utils",
|
||||
return_value=mock_fallback_response,
|
||||
) as mock_fallback_utils:
|
||||
collected_chunks = []
|
||||
result = await router._acompletion_streaming_iterator(
|
||||
model_response=mock_error_response_3,
|
||||
messages=messages,
|
||||
initial_kwargs=initial_kwargs_fallback_enabled,
|
||||
)
|
||||
|
||||
async for chunk in result:
|
||||
collected_chunks.append(chunk)
|
||||
|
||||
# Verify fallback WAS called
|
||||
assert mock_fallback_utils.called
|
||||
print("✓ Fallback called correctly when disable_fallbacks=False")
|
||||
|
||||
print("\n=== All disable_fallbacks tests passed! ===")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_function_with_fallbacks_common_utils():
|
||||
"""Test the async_function_with_fallbacks_common_utils method"""
|
||||
|
||||
Reference in New Issue
Block a user