mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 00:48:01 +00:00
625ed3f8cf
* added qwen models and gpt-5-codex * fix flaky test * fix failing test * Added retries to prisma client state * fix: prisma client state retries in pods * Revert "fix failing test" This reverts commit dbec4988a2627257fd05b905e216225664517f32. * Revert "fix flaky test" This reverts commit b0ac2f2dc35ca433af0c82f3cda770d6981caff4. * Revert "added qwen models and gpt-5-codex" This reverts commit 9a8a8f2d47ab4dc8aecb0cd9a6a4f82ed81bb056. * Revert "fix: prisma client state retries in pods" This reverts commit 04e58e5ca1a489916e3b49e9b674f5c6713fd7cd. * fix lint * Revert "fix lint" This reverts commit 5303d52a5e3bee7e131dcabd098e94f0613a7bb9. * fixed lint
375 lines
14 KiB
Python
375 lines
14 KiB
Python
"""
Test backoff retry mechanisms for PrismaClient methods during _setup_prisma_client state.

This test validates that intermittent database connection issues are handled correctly
with exponential backoff retries for critical startup operations.
"""
|
|
|
|
import asyncio
|
|
import pytest
|
|
import time
|
|
from unittest.mock import AsyncMock, MagicMock, patch, call
|
|
from unittest.mock import Mock
|
|
import sys
|
|
import os
|
|
|
|
# Add project root to path
|
|
sys.path.insert(0, os.path.abspath("../.."))
|
|
|
|
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
|
from prisma.errors import PrismaError, ClientNotConnectedError
|
|
import httpx
|
|
import backoff
|
|
|
|
|
|
@pytest.fixture
def mock_prisma_client():
    """Build a MagicMock shaped like PrismaClient with awaitable collaborators.

    The returned mock exposes:
      * ``db`` with an awaitable ``query_raw``
      * ``proxy_logging_obj`` with an awaitable ``failure_handler``
    """
    prisma_mock = MagicMock(spec=PrismaClient)

    # Database handle whose raw-query entry point can be awaited and inspected.
    db_mock = AsyncMock()
    db_mock.query_raw = AsyncMock()
    prisma_mock.db = db_mock

    # Logging collaborator with an awaitable failure hook.
    logging_mock = AsyncMock()
    logging_mock.failure_handler = AsyncMock()
    prisma_mock.proxy_logging_obj = logging_mock

    return prisma_mock
|
|
|
|
|
|
@pytest.fixture
def mock_proxy_logging():
    """Provide an AsyncMock constrained to the ProxyLogging interface."""
    logging_mock = AsyncMock(spec=ProxyLogging)
    # Explicit AsyncMock so the failure hook can be awaited and inspected.
    logging_mock.failure_handler = AsyncMock()
    return logging_mock
|
|
|
|
|
|
@pytest.fixture
def connection_errors():
    """Return sample exception instances used to simulate connection failures.

    Tests pick entries by index, so the order of the list is significant:
    httpx errors first, then Prisma errors, then built-in OS/network errors.
    """
    http_failures = [
        httpx.ConnectError("Connection failed"),
        httpx.TimeoutException("Request timeout"),
    ]
    prisma_failures = [
        ClientNotConnectedError(),
        PrismaError("Database connection lost"),
    ]
    builtin_failures = [
        ConnectionError("Network connection failed"),
        OSError("Connection refused"),
    ]
    return http_failures + prisma_failures + builtin_failures
|
|
|
|
|
|
class TestPrismaClientBackoffRetry:
    """Test suite for PrismaClient backoff retry mechanisms.

    Each test builds a real ``PrismaClient`` and replaces its ``db`` and
    ``proxy_logging_obj`` with mocks, so only ``db.query_raw`` is simulated.
    The expected attempt count (3) and time ceiling (10 seconds) asserted
    below are the values these tests expect from the retry configuration of
    the methods under test.
    """

    @staticmethod
    def _build_client(mock_prisma_client, mock_proxy_logging):
        """Return a real PrismaClient wired to the mocked db and logging objects."""
        client = PrismaClient(
            database_url="mock://test", proxy_logging_obj=mock_proxy_logging
        )
        client.db = mock_prisma_client.db
        client.proxy_logging_obj = mock_prisma_client.proxy_logging_obj
        return client

    @pytest.mark.asyncio
    async def test_health_check_success_no_retry(
        self, mock_prisma_client, mock_proxy_logging
    ):
        """Test health_check succeeds immediately without retries"""
        mock_prisma_client.db.query_raw.return_value = [{"result": 1}]
        client = self._build_client(mock_prisma_client, mock_proxy_logging)

        result = await client.health_check()

        assert result == [{"result": 1}]
        mock_prisma_client.db.query_raw.assert_called_once_with("SELECT 1")

    @pytest.mark.asyncio
    async def test_health_check_retry_then_success(
        self, mock_prisma_client, mock_proxy_logging, connection_errors
    ):
        """Test health_check retries on connection errors and eventually succeeds"""
        # First two calls fail, third succeeds.
        mock_prisma_client.db.query_raw.side_effect = [
            connection_errors[0],
            connection_errors[1],
            [{"result": 1}],
        ]
        client = self._build_client(mock_prisma_client, mock_proxy_logging)

        # Use a monotonic clock: wall-clock adjustments must not affect the
        # measured backoff delay.
        started = time.monotonic()
        result = await client.health_check()
        elapsed = time.monotonic() - started

        # Verify eventual success after exactly three attempts.
        assert result == [{"result": 1}]
        assert mock_prisma_client.db.query_raw.call_count == 3

        # Verify a backoff delay occurred (at least a few milliseconds).
        assert elapsed > 0.01

    @pytest.mark.asyncio
    async def test_health_check_max_retries_exceeded(
        self, mock_prisma_client, mock_proxy_logging, connection_errors
    ):
        """Test health_check fails after max retries (3) are exceeded"""
        # Every call fails with the same connection error.
        mock_prisma_client.db.query_raw.side_effect = connection_errors[0]
        client = self._build_client(mock_prisma_client, mock_proxy_logging)

        # The final exception must surface once retries are exhausted.
        with pytest.raises(httpx.ConnectError):
            await client.health_check()

        assert mock_prisma_client.db.query_raw.call_count == 3

    @pytest.mark.asyncio
    async def test_get_spend_logs_row_count_success_no_retry(
        self, mock_prisma_client, mock_proxy_logging
    ):
        """Test _get_spend_logs_row_count succeeds immediately"""
        mock_prisma_client.db.query_raw.return_value = [{"reltuples": 1000}]
        client = self._build_client(mock_prisma_client, mock_proxy_logging)

        result = await client._get_spend_logs_row_count()

        assert result == 1000
        mock_prisma_client.db.query_raw.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_spend_logs_row_count_retry_then_success(
        self, mock_prisma_client, mock_proxy_logging, connection_errors
    ):
        """Test _get_spend_logs_row_count retries and eventually succeeds"""
        mock_prisma_client.db.query_raw.side_effect = [
            connection_errors[2],  # First call fails (ClientNotConnectedError)
            [{"reltuples": 500}],  # Second call succeeds
        ]
        client = self._build_client(mock_prisma_client, mock_proxy_logging)

        result = await client._get_spend_logs_row_count()

        # The value from the successful second attempt must be returned,
        # not just the attempt count verified.
        assert result == 500
        assert mock_prisma_client.db.query_raw.call_count == 2

    @pytest.mark.asyncio
    async def test_get_spend_logs_row_count_handles_errors_gracefully(
        self, mock_prisma_client, mock_proxy_logging
    ):
        """Test _get_spend_logs_row_count returns 0 on persistent errors"""
        mock_prisma_client.db.query_raw.side_effect = PrismaError("Persistent DB error")
        client = self._build_client(mock_prisma_client, mock_proxy_logging)

        result = await client._get_spend_logs_row_count()

        assert result == 0
        assert mock_prisma_client.db.query_raw.call_count == 3  # Max retries attempted

    @pytest.mark.asyncio
    async def test_set_spend_logs_row_count_in_proxy_state_success(
        self, mock_prisma_client, mock_proxy_logging
    ):
        """Test _set_spend_logs_row_count_in_proxy_state succeeds"""
        mock_prisma_client.db.query_raw.return_value = [{"reltuples": 2000}]
        client = self._build_client(mock_prisma_client, mock_proxy_logging)

        with patch("litellm.proxy.proxy_server.proxy_state") as mock_proxy_state:
            mock_proxy_state.set_proxy_state_variable = Mock()

            await client._set_spend_logs_row_count_in_proxy_state()

            # Verify proxy state was updated with the queried row count.
            mock_proxy_state.set_proxy_state_variable.assert_called_once_with(
                variable_name="spend_logs_row_count", value=2000
            )

    @pytest.mark.asyncio
    async def test_set_spend_logs_row_count_retry_behavior(
        self, mock_prisma_client, mock_proxy_logging, connection_errors
    ):
        """Test _set_spend_logs_row_count_in_proxy_state retries on database errors"""
        mock_prisma_client.db.query_raw.side_effect = [
            connection_errors[3],  # First call fails (PrismaError)
            [{"reltuples": 1500}],  # Second call succeeds
        ]
        client = self._build_client(mock_prisma_client, mock_proxy_logging)

        with patch("litellm.proxy.proxy_server.proxy_state") as mock_proxy_state:
            mock_proxy_state.set_proxy_state_variable = Mock()

            await client._set_spend_logs_row_count_in_proxy_state()

            assert mock_prisma_client.db.query_raw.call_count == 2
            mock_proxy_state.set_proxy_state_variable.assert_called_once_with(
                variable_name="spend_logs_row_count", value=1500
            )

    @pytest.mark.asyncio
    async def test_backoff_configuration_parameters(self, mock_proxy_logging):
        """Test that the retried methods are wrapped by a decorator.

        Only the presence of ``__wrapped__`` is checked; the concrete retry
        parameters are verified indirectly by the behavioural tests in this
        class.
        """
        client = PrismaClient(
            database_url="mock://test", proxy_logging_obj=mock_proxy_logging
        )

        assert hasattr(client.health_check, "__wrapped__")
        assert hasattr(client._set_spend_logs_row_count_in_proxy_state, "__wrapped__")

    @pytest.mark.asyncio
    async def test_multiple_connection_error_types(
        self, mock_prisma_client, mock_proxy_logging, connection_errors
    ):
        """Test that different types of connection errors all trigger retries"""
        for error in connection_errors:
            # Fresh call history and script for each error type.
            mock_prisma_client.db.query_raw.reset_mock()
            mock_prisma_client.db.query_raw.side_effect = [
                error,  # First call fails
                [{"result": 1}],  # Second call succeeds
            ]
            client = self._build_client(mock_prisma_client, mock_proxy_logging)

            # Should succeed after exactly one retry.
            result = await client.health_check()

            error_name = type(error).__name__
            assert result == [{"result": 1}], f"unexpected result after {error_name}"
            assert (
                mock_prisma_client.db.query_raw.call_count == 2
            ), f"expected one retry after {error_name}"

    @pytest.mark.asyncio
    async def test_setup_prisma_client_integration(
        self, mock_prisma_client, mock_proxy_logging
    ):
        """Test simulated _setup_prisma_client flow with intermittent failures"""
        # This simulates the flow that happens in proxy_server.py _setup_prisma_client.
        client = self._build_client(mock_prisma_client, mock_proxy_logging)

        attempts = 0

        def mock_query_side_effect(*args, **kwargs):
            """Fail the first two queries, then answer according to the SQL text."""
            nonlocal attempts
            attempts += 1
            if attempts <= 2:
                raise httpx.ConnectError("Intermittent connection issue")
            if "SELECT 1" in str(args):
                return [{"result": 1}]
            return [{"reltuples": 1000}]

        mock_prisma_client.db.query_raw.side_effect = mock_query_side_effect

        with patch("litellm.proxy.proxy_server.proxy_state") as mock_proxy_state:
            mock_proxy_state.set_proxy_state_variable = Mock()

            # Execute the two critical calls from _setup_prisma_client.
            health_result = await client.health_check()
            await client._set_spend_logs_row_count_in_proxy_state()

            # Both operations eventually succeed despite the initial failures.
            assert health_result == [{"result": 1}]
            mock_proxy_state.set_proxy_state_variable.assert_called_once()

            # Two failed attempts + one successful health check + one
            # row-count query gives at least four query calls in total.
            assert mock_prisma_client.db.query_raw.call_count >= 4

    @pytest.mark.asyncio
    async def test_backoff_timing_constraints(
        self, mock_prisma_client, mock_proxy_logging, connection_errors
    ):
        """Test that backoff respects max_time constraint (10 seconds)"""
        # Every call fails so the retry budget is fully consumed.
        mock_prisma_client.db.query_raw.side_effect = connection_errors[0]
        client = self._build_client(mock_prisma_client, mock_proxy_logging)

        # Monotonic clock: immune to wall-clock adjustments during the wait.
        started = time.monotonic()
        with pytest.raises(httpx.ConnectError):
            await client.health_check()
        duration = time.monotonic() - started

        # Should not exceed max_time of 10 seconds significantly; the extra
        # two seconds are a buffer for test execution overhead.
        assert duration < 12.0

        # Should have made max_tries (3) attempts.
        assert mock_prisma_client.db.query_raw.call_count == 3
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly reports
    # failure to the shell/CI instead of always exiting with status 0.
    sys.exit(pytest.main([__file__, "-v"]))
|