mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 05:28:02 +00:00
Fix xdist test isolation: capture true defaults and poll instead of sleep
The conftest fixtures were saving/restoring the current (potentially contaminated) values of litellm globals like num_retries instead of resetting to true defaults. Under xdist, module-level assignments (e.g. `litellm.num_retries = 3` in 12+ test files) pollute the shared module state and leak across tests in the same worker. - Capture true litellm defaults at conftest import time and reset before each test (local_testing + llm_translation) - Make llm_translation/conftest.py xdist-safe (skip reload, add state isolation) - Replace asyncio.sleep(2) with polling in cooldown handler tests - Add @pytest.mark.flaky to tests making real API calls under xdist Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,4 +1,9 @@
|
||||
# conftest.py
|
||||
#
|
||||
# xdist-compatible test isolation for llm_translation tests.
|
||||
# Mirrors the pattern in tests/local_testing/conftest.py:
|
||||
# - Function-scoped fixture resets litellm globals to true defaults
|
||||
# - Module-scoped reload only in single-process mode
|
||||
|
||||
import importlib
|
||||
import os
|
||||
@@ -11,7 +16,20 @@ sys.path.insert(
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
|
||||
import asyncio
|
||||
# ---------------------------------------------------------------------------
|
||||
# Capture TRUE defaults at conftest import time (before test modules pollute).
|
||||
# ---------------------------------------------------------------------------
|
||||
_SCALAR_DEFAULTS = {
|
||||
"num_retries": getattr(litellm, "num_retries", None),
|
||||
"set_verbose": getattr(litellm, "set_verbose", False),
|
||||
"cache": getattr(litellm, "cache", None),
|
||||
"allowed_fails": getattr(litellm, "allowed_fails", 3),
|
||||
"disable_aiohttp_transport": getattr(litellm, "disable_aiohttp_transport", False),
|
||||
"force_ipv4": getattr(litellm, "force_ipv4", False),
|
||||
"drop_params": getattr(litellm, "drop_params", None),
|
||||
"modify_params": getattr(litellm, "modify_params", False),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
@@ -29,20 +47,63 @@ def setup_and_teardown(event_loop): # Add event_loop as a dependency
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
|
||||
import litellm
|
||||
from litellm import Router
|
||||
|
||||
from litellm.litellm_core_utils.logging_worker import GLOBAL_LOGGING_WORKER
|
||||
# flush all logs
|
||||
asyncio.run(GLOBAL_LOGGING_WORKER.clear_queue())
|
||||
# ---- Save current state (for teardown restore) ----
|
||||
original_state = {}
|
||||
for attr in (
|
||||
"callbacks",
|
||||
"success_callback",
|
||||
"failure_callback",
|
||||
"_async_success_callback",
|
||||
"_async_failure_callback",
|
||||
):
|
||||
if hasattr(litellm, attr):
|
||||
val = getattr(litellm, attr)
|
||||
original_state[attr] = val.copy() if val else []
|
||||
|
||||
importlib.reload(litellm)
|
||||
for attr in _SCALAR_DEFAULTS:
|
||||
if hasattr(litellm, attr):
|
||||
original_state[attr] = getattr(litellm, attr)
|
||||
|
||||
# ---- Reset to true defaults before the test ----
|
||||
worker_id = os.environ.get("PYTEST_XDIST_WORKER", None)
|
||||
if worker_id is None:
|
||||
# Single-process mode: reload for full reset
|
||||
from litellm.litellm_core_utils.logging_worker import GLOBAL_LOGGING_WORKER
|
||||
asyncio.run(GLOBAL_LOGGING_WORKER.clear_queue())
|
||||
importlib.reload(litellm)
|
||||
else:
|
||||
# xdist mode: reset globals without reload
|
||||
for attr in (
|
||||
"callbacks",
|
||||
"success_callback",
|
||||
"failure_callback",
|
||||
"_async_success_callback",
|
||||
"_async_failure_callback",
|
||||
):
|
||||
if hasattr(litellm, attr):
|
||||
setattr(litellm, attr, [])
|
||||
|
||||
for attr, default_val in _SCALAR_DEFAULTS.items():
|
||||
if hasattr(litellm, attr):
|
||||
setattr(litellm, attr, default_val)
|
||||
|
||||
if hasattr(litellm, "in_memory_llm_clients_cache"):
|
||||
litellm.in_memory_llm_clients_cache.flush_cache()
|
||||
|
||||
# Set the event loop from the fixture
|
||||
asyncio.set_event_loop(event_loop)
|
||||
|
||||
print(litellm)
|
||||
yield
|
||||
|
||||
# ---- Teardown ----
|
||||
if hasattr(litellm, "in_memory_llm_clients_cache"):
|
||||
litellm.in_memory_llm_clients_cache.flush_cache()
|
||||
|
||||
for attr, original_value in original_state.items():
|
||||
if hasattr(litellm, attr):
|
||||
setattr(litellm, attr, original_value)
|
||||
|
||||
# Clean up any pending tasks
|
||||
pending = asyncio.all_tasks(event_loop)
|
||||
for task in pending:
|
||||
|
||||
@@ -65,6 +65,7 @@ async def test_chat_completion_cohere_citations(stream):
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
def test_completion_cohere_command_r_plus_function_call():
|
||||
litellm.set_verbose = True
|
||||
tools = [
|
||||
|
||||
@@ -4,6 +4,12 @@
|
||||
# Pattern matches tests/test_litellm/conftest.py:
|
||||
# - Function-scoped fixture saves/restores litellm globals (no reload)
|
||||
# - Module-scoped fixture reloads only in single-process mode
|
||||
#
|
||||
# IMPORTANT: True defaults are captured at conftest import time (before any
|
||||
# test module can pollute them via module-level assignments like
|
||||
# `litellm.num_retries = 3`). The function-scoped fixture resets globals to
|
||||
# these true defaults before every test, preventing cross-test contamination
|
||||
# under xdist where module reload is skipped.
|
||||
|
||||
import importlib
|
||||
import os
|
||||
@@ -16,16 +22,40 @@ sys.path.insert(
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Capture TRUE defaults at conftest import time. This runs before any test
|
||||
# module's top-level code (e.g. `litellm.num_retries = 3`) executes, so
|
||||
# the values here are guaranteed to be the real package defaults.
|
||||
# ---------------------------------------------------------------------------
|
||||
_SCALAR_DEFAULTS = {
|
||||
"num_retries": getattr(litellm, "num_retries", None),
|
||||
"num_retries_per_request": getattr(litellm, "num_retries_per_request", None),
|
||||
"request_timeout": getattr(litellm, "request_timeout", None),
|
||||
"set_verbose": getattr(litellm, "set_verbose", False),
|
||||
"cache": getattr(litellm, "cache", None),
|
||||
"allowed_fails": getattr(litellm, "allowed_fails", 3),
|
||||
"default_fallbacks": getattr(litellm, "default_fallbacks", None),
|
||||
"enable_azure_ad_token_refresh": getattr(litellm, "enable_azure_ad_token_refresh", None),
|
||||
"tag_budget_config": getattr(litellm, "tag_budget_config", None),
|
||||
"model_cost": getattr(litellm, "model_cost", None),
|
||||
"token_counter": getattr(litellm, "token_counter", None),
|
||||
"disable_aiohttp_transport": getattr(litellm, "disable_aiohttp_transport", False),
|
||||
"force_ipv4": getattr(litellm, "force_ipv4", False),
|
||||
"drop_params": getattr(litellm, "drop_params", None),
|
||||
"modify_params": getattr(litellm, "modify_params", False),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def isolate_litellm_state():
|
||||
"""
|
||||
Per-function isolation fixture.
|
||||
|
||||
Saves and restores litellm callback/global state so tests don't leak
|
||||
side effects. Works safely under pytest-xdist parallel execution.
|
||||
Resets litellm globals to their true defaults before each test and
|
||||
restores them afterward, so tests don't leak side effects.
|
||||
Works safely under pytest-xdist parallel execution.
|
||||
"""
|
||||
# Save original callback state
|
||||
# ---- Save current callback state (for teardown restore) ----
|
||||
original_state = {}
|
||||
for attr in (
|
||||
"callbacks",
|
||||
@@ -38,38 +68,23 @@ def isolate_litellm_state():
|
||||
val = getattr(litellm, attr)
|
||||
original_state[attr] = val.copy() if val else []
|
||||
|
||||
# Save other globals that tests commonly mutate
|
||||
for attr in (
|
||||
"set_verbose",
|
||||
"cache",
|
||||
"num_retries",
|
||||
"num_retries_per_request",
|
||||
"request_timeout",
|
||||
"default_fallbacks",
|
||||
"enable_azure_ad_token_refresh",
|
||||
"tag_budget_config",
|
||||
"model_cost",
|
||||
"token_counter",
|
||||
):
|
||||
if hasattr(litellm, attr):
|
||||
original_state[attr] = getattr(litellm, attr)
|
||||
|
||||
# Save rules that tests may set (e.g. test_rules.py)
|
||||
# Save list-type globals
|
||||
for attr in ("pre_call_rules", "post_call_rules"):
|
||||
if hasattr(litellm, attr):
|
||||
val = getattr(litellm, attr)
|
||||
original_state[attr] = val.copy() if val else []
|
||||
|
||||
# Save transport/network globals
|
||||
for attr in ("disable_aiohttp_transport", "force_ipv4"):
|
||||
# Save scalar globals
|
||||
for attr in _SCALAR_DEFAULTS:
|
||||
if hasattr(litellm, attr):
|
||||
original_state[attr] = getattr(litellm, attr)
|
||||
|
||||
# Flush cache before test
|
||||
# ---- Reset to true defaults before the test ----
|
||||
# Flush HTTP client cache
|
||||
if hasattr(litellm, "in_memory_llm_clients_cache"):
|
||||
litellm.in_memory_llm_clients_cache.flush_cache()
|
||||
|
||||
# Clear callbacks and rules before test
|
||||
# Clear callbacks and rules
|
||||
for attr in (
|
||||
"callbacks",
|
||||
"success_callback",
|
||||
@@ -82,9 +97,15 @@ def isolate_litellm_state():
|
||||
if hasattr(litellm, attr):
|
||||
setattr(litellm, attr, [])
|
||||
|
||||
# Reset scalar globals to true defaults (prevents contamination from
|
||||
# module-level code like `litellm.num_retries = 3` in test files)
|
||||
for attr, default_val in _SCALAR_DEFAULTS.items():
|
||||
if hasattr(litellm, attr):
|
||||
setattr(litellm, attr, default_val)
|
||||
|
||||
yield
|
||||
|
||||
# Restore all saved state
|
||||
# ---- Teardown: restore saved state ----
|
||||
if hasattr(litellm, "in_memory_llm_clients_cache"):
|
||||
litellm.in_memory_llm_clients_cache.flush_cache()
|
||||
|
||||
|
||||
@@ -134,6 +134,7 @@ def test_multiple_deployments_parallel():
|
||||
|
||||
|
||||
# test_multiple_deployments_parallel()
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_cooldown_same_model_name(sync_mode):
|
||||
|
||||
@@ -51,6 +51,7 @@ router = Router(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
|
||||
@@ -376,7 +376,12 @@ async def test_single_deployment_cooldown_with_allowed_fails():
|
||||
except litellm.Timeout:
|
||||
pass
|
||||
|
||||
await asyncio.sleep(2)
|
||||
# Poll instead of fixed sleep — under xdist CPU contention 2s may
|
||||
# not be enough for the async callback to fire.
|
||||
for _ in range(100): # up to 10s
|
||||
if mock_client.call_count >= 1:
|
||||
break
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
mock_client.assert_called_once()
|
||||
|
||||
@@ -426,7 +431,12 @@ async def test_single_deployment_cooldown_with_allowed_fail_policy():
|
||||
except litellm.Timeout:
|
||||
pass
|
||||
|
||||
await asyncio.sleep(2)
|
||||
# Poll instead of fixed sleep — under xdist CPU contention 2s may
|
||||
# not be enough for the async callback to fire.
|
||||
for _ in range(100): # up to 10s
|
||||
if mock_client.call_count >= 1:
|
||||
break
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
mock_client.assert_called_once()
|
||||
|
||||
|
||||
@@ -154,6 +154,7 @@ async def _handle_router_calls(router):
|
||||
print("done", chunk)
|
||||
|
||||
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_parallel_requests_rpm_rate_limiting():
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user