Manual revert #19078

This commit is contained in:
Yuta Saito
2026-01-17 06:54:08 +09:00
parent 8d4bfabb6a
commit 18bcb429fc
2 changed files with 0 additions and 194 deletions
-9
View File
@@ -1408,15 +1408,6 @@ class Router:
async for item in model_response:
yield item
except MidStreamFallbackError as e:
# Check if fallbacks are disabled by user
if initial_kwargs.get("disable_fallbacks", False):
verbose_router_logger.info(
"Mid stream fallback disabled by user, re-raising original error"
)
if e.original_exception is not None:
raise e.original_exception
raise e
from litellm.main import stream_chunk_builder
complete_response_object = stream_chunk_builder(
-185
View File
@@ -1171,191 +1171,6 @@ async def test_acompletion_streaming_iterator_edge_cases():
print("✓ Edge case tests passed!")
@pytest.mark.asyncio
async def test_acompletion_streaming_disable_fallbacks_midstream():
"""Test that disable_fallbacks=True prevents mid-stream fallback attempts."""
from unittest.mock import MagicMock
from litellm.exceptions import MidStreamFallbackError
# Set up router with fallback configuration
router = litellm.Router(
model_list=[
{
"model_name": "gpt-4",
"litellm_params": {"model": "gpt-4", "api_key": "fake-key-1"},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo", "api_key": "fake-key-2"},
},
],
fallbacks=[{"gpt-4": ["gpt-3.5-turbo"]}],
set_verbose=True,
)
messages = [{"role": "user", "content": "Hello"}]
# Test 1: disable_fallbacks=True with original_exception
print("\n=== Test 1: disable_fallbacks=True with original_exception ===")
# Create an original exception to wrap
from litellm.llms.anthropic.common_utils import AnthropicError
original_error = AnthropicError(
status_code=500,
message="An unexpected error occurred while processing the response",
)
# Create MidStreamFallbackError with original_exception
error_with_original = MidStreamFallbackError(
message="Connection lost",
model="gpt-4",
llm_provider="openai",
generated_content="Hello",
original_exception=original_error,
)
class AsyncIteratorWithError:
def __init__(self, items, error_after_index, error):
self.items = items
self.index = 0
self.error_after_index = error_after_index
self.error = error
self.chunks = []
self.model = "gpt-4"
self.custom_llm_provider = "openai"
self.logging_obj = MagicMock()
def __aiter__(self):
return self
async def __anext__(self):
if self.index >= len(self.items):
raise StopAsyncIteration
if self.index == self.error_after_index:
raise self.error
item = self.items[self.index]
self.index += 1
self.chunks.append(item)
return item
mock_chunks = [
MagicMock(choices=[MagicMock(delta=MagicMock(content="Hello"))]),
]
mock_error_response = AsyncIteratorWithError(
mock_chunks, 1, error_with_original
) # Error after first chunk
initial_kwargs = {"model": "gpt-4", "stream": True, "disable_fallbacks": True}
# Mock the fallback function to ensure it's NOT called
with patch.object(
router,
"async_function_with_fallbacks_common_utils",
return_value=MagicMock(),
) as mock_fallback_utils:
with pytest.raises(AnthropicError, match="An unexpected error occurred"):
result = await router._acompletion_streaming_iterator(
model_response=mock_error_response,
messages=messages,
initial_kwargs=initial_kwargs,
)
async for chunk in result:
pass # Should not reach here; exception should be raised
# Verify fallback was NOT called
mock_fallback_utils.assert_not_called()
print("✓ Original exception raised correctly when disable_fallbacks=True")
# Test 2: disable_fallbacks=True without original_exception
print("\n=== Test 2: disable_fallbacks=True without original_exception ===")
error_without_original = MidStreamFallbackError(
message="Connection lost",
model="gpt-4",
llm_provider="openai",
generated_content="Hello",
original_exception=None,
)
mock_error_response_2 = AsyncIteratorWithError(
mock_chunks, 1, error_without_original
)
with patch.object(
router,
"async_function_with_fallbacks_common_utils",
return_value=MagicMock(),
) as mock_fallback_utils:
with pytest.raises(MidStreamFallbackError, match="Connection lost"):
result = await router._acompletion_streaming_iterator(
model_response=mock_error_response_2,
messages=messages,
initial_kwargs=initial_kwargs,
)
async for chunk in result:
pass # Should not reach here
# Verify fallback was NOT called
mock_fallback_utils.assert_not_called()
print(
"✓ MidStreamFallbackError raised correctly when no original_exception and disable_fallbacks=True"
)
# Test 3: disable_fallbacks=False (default behavior - fallback should work)
print("\n=== Test 3: disable_fallbacks=False (fallback enabled) ===")
error_for_fallback = MidStreamFallbackError(
message="Connection lost",
model="gpt-4",
llm_provider="openai",
generated_content="Hello",
)
mock_error_response_3 = AsyncIteratorWithError(mock_chunks, 1, error_for_fallback)
# Mock successful fallback response
class EmptyAsyncIterator:
def __aiter__(self):
return self
async def __anext__(self):
raise StopAsyncIteration
mock_fallback_response = EmptyAsyncIterator()
initial_kwargs_fallback_enabled = {
"model": "gpt-4",
"stream": True,
"disable_fallbacks": False,
}
with patch.object(
router,
"async_function_with_fallbacks_common_utils",
return_value=mock_fallback_response,
) as mock_fallback_utils:
collected_chunks = []
result = await router._acompletion_streaming_iterator(
model_response=mock_error_response_3,
messages=messages,
initial_kwargs=initial_kwargs_fallback_enabled,
)
async for chunk in result:
collected_chunks.append(chunk)
# Verify fallback WAS called
assert mock_fallback_utils.called
print("✓ Fallback called correctly when disable_fallbacks=False")
print("\n=== All disable_fallbacks tests passed! ===")
@pytest.mark.asyncio
async def test_async_function_with_fallbacks_common_utils():
"""Test the async_function_with_fallbacks_common_utils method"""