Fix xdist test isolation: capture true defaults and poll instead of sleep

The conftest fixtures were saving/restoring the current (potentially
contaminated) values of litellm globals like num_retries instead of
resetting to true defaults. Under xdist, module-level assignments
(e.g. `litellm.num_retries = 3` in 12+ test files) pollute the
shared module state and leak across tests in the same worker.

- Capture true litellm defaults at conftest import time and reset
  before each test (local_testing + llm_translation)
- Make llm_translation/conftest.py xdist-safe (skip reload, add
  state isolation)
- Replace asyncio.sleep(2) with polling in cooldown handler tests
- Add @pytest.mark.flaky to tests making real API calls under xdist

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yuneng-jiang
2026-03-15 22:26:42 -07:00
parent 65b3335735
commit 9711e3adfe
7 changed files with 130 additions and 34 deletions
+68 -7
View File
@@ -1,4 +1,9 @@
# conftest.py
#
# xdist-compatible test isolation for llm_translation tests.
# Mirrors the pattern in tests/local_testing/conftest.py:
# - Function-scoped fixture resets litellm globals to true defaults
# - Module-scoped reload only in single-process mode
import importlib
import os
@@ -11,7 +16,20 @@ sys.path.insert(
) # Adds the parent directory to the system path
import litellm
import asyncio
# ---------------------------------------------------------------------------
# Capture TRUE defaults at conftest import time (before test modules pollute).
# ---------------------------------------------------------------------------
_SCALAR_DEFAULTS = {
"num_retries": getattr(litellm, "num_retries", None),
"set_verbose": getattr(litellm, "set_verbose", False),
"cache": getattr(litellm, "cache", None),
"allowed_fails": getattr(litellm, "allowed_fails", 3),
"disable_aiohttp_transport": getattr(litellm, "disable_aiohttp_transport", False),
"force_ipv4": getattr(litellm, "force_ipv4", False),
"drop_params": getattr(litellm, "drop_params", None),
"modify_params": getattr(litellm, "modify_params", False),
}
@pytest.fixture(scope="session")
def event_loop():
@@ -29,20 +47,63 @@ def setup_and_teardown(event_loop): # Add event_loop as a dependency
sys.path.insert(0, os.path.abspath("../.."))
import litellm
from litellm import Router
from litellm.litellm_core_utils.logging_worker import GLOBAL_LOGGING_WORKER
# flush all logs
asyncio.run(GLOBAL_LOGGING_WORKER.clear_queue())
# ---- Save current state (for teardown restore) ----
original_state = {}
for attr in (
"callbacks",
"success_callback",
"failure_callback",
"_async_success_callback",
"_async_failure_callback",
):
if hasattr(litellm, attr):
val = getattr(litellm, attr)
original_state[attr] = val.copy() if val else []
importlib.reload(litellm)
for attr in _SCALAR_DEFAULTS:
if hasattr(litellm, attr):
original_state[attr] = getattr(litellm, attr)
# ---- Reset to true defaults before the test ----
worker_id = os.environ.get("PYTEST_XDIST_WORKER", None)
if worker_id is None:
# Single-process mode: reload for full reset
from litellm.litellm_core_utils.logging_worker import GLOBAL_LOGGING_WORKER
asyncio.run(GLOBAL_LOGGING_WORKER.clear_queue())
importlib.reload(litellm)
else:
# xdist mode: reset globals without reload
for attr in (
"callbacks",
"success_callback",
"failure_callback",
"_async_success_callback",
"_async_failure_callback",
):
if hasattr(litellm, attr):
setattr(litellm, attr, [])
for attr, default_val in _SCALAR_DEFAULTS.items():
if hasattr(litellm, attr):
setattr(litellm, attr, default_val)
if hasattr(litellm, "in_memory_llm_clients_cache"):
litellm.in_memory_llm_clients_cache.flush_cache()
# Set the event loop from the fixture
asyncio.set_event_loop(event_loop)
print(litellm)
yield
# ---- Teardown ----
if hasattr(litellm, "in_memory_llm_clients_cache"):
litellm.in_memory_llm_clients_cache.flush_cache()
for attr, original_value in original_state.items():
if hasattr(litellm, attr):
setattr(litellm, attr, original_value)
# Clean up any pending tasks
pending = asyncio.all_tasks(event_loop)
for task in pending:
+1
View File
@@ -65,6 +65,7 @@ async def test_chat_completion_cohere_citations(stream):
pytest.fail(f"Error occurred: {e}")
@pytest.mark.flaky(retries=3, delay=1)
def test_completion_cohere_command_r_plus_function_call():
litellm.set_verbose = True
tools = [
+46 -25
View File
@@ -4,6 +4,12 @@
# Pattern matches tests/test_litellm/conftest.py:
# - Function-scoped fixture saves/restores litellm globals (no reload)
# - Module-scoped fixture reloads only in single-process mode
#
# IMPORTANT: True defaults are captured at conftest import time (before any
# test module can pollute them via module-level assignments like
# `litellm.num_retries = 3`). The function-scoped fixture resets globals to
# these true defaults before every test, preventing cross-test contamination
# under xdist where module reload is skipped.
import importlib
import os
@@ -16,16 +22,40 @@ sys.path.insert(
) # Adds the parent directory to the system path
import litellm
# ---------------------------------------------------------------------------
# Capture TRUE defaults at conftest import time. This runs before any test
# module's top-level code (e.g. `litellm.num_retries = 3`) executes, so
# the values here are guaranteed to be the real package defaults.
# ---------------------------------------------------------------------------
_SCALAR_DEFAULTS = {
"num_retries": getattr(litellm, "num_retries", None),
"num_retries_per_request": getattr(litellm, "num_retries_per_request", None),
"request_timeout": getattr(litellm, "request_timeout", None),
"set_verbose": getattr(litellm, "set_verbose", False),
"cache": getattr(litellm, "cache", None),
"allowed_fails": getattr(litellm, "allowed_fails", 3),
"default_fallbacks": getattr(litellm, "default_fallbacks", None),
"enable_azure_ad_token_refresh": getattr(litellm, "enable_azure_ad_token_refresh", None),
"tag_budget_config": getattr(litellm, "tag_budget_config", None),
"model_cost": getattr(litellm, "model_cost", None),
"token_counter": getattr(litellm, "token_counter", None),
"disable_aiohttp_transport": getattr(litellm, "disable_aiohttp_transport", False),
"force_ipv4": getattr(litellm, "force_ipv4", False),
"drop_params": getattr(litellm, "drop_params", None),
"modify_params": getattr(litellm, "modify_params", False),
}
@pytest.fixture(scope="function", autouse=True)
def isolate_litellm_state():
"""
Per-function isolation fixture.
Saves and restores litellm callback/global state so tests don't leak
side effects. Works safely under pytest-xdist parallel execution.
Resets litellm globals to their true defaults before each test and
restores them afterward, so tests don't leak side effects.
Works safely under pytest-xdist parallel execution.
"""
# Save original callback state
# ---- Save current callback state (for teardown restore) ----
original_state = {}
for attr in (
"callbacks",
@@ -38,38 +68,23 @@ def isolate_litellm_state():
val = getattr(litellm, attr)
original_state[attr] = val.copy() if val else []
# Save other globals that tests commonly mutate
for attr in (
"set_verbose",
"cache",
"num_retries",
"num_retries_per_request",
"request_timeout",
"default_fallbacks",
"enable_azure_ad_token_refresh",
"tag_budget_config",
"model_cost",
"token_counter",
):
if hasattr(litellm, attr):
original_state[attr] = getattr(litellm, attr)
# Save rules that tests may set (e.g. test_rules.py)
# Save list-type globals
for attr in ("pre_call_rules", "post_call_rules"):
if hasattr(litellm, attr):
val = getattr(litellm, attr)
original_state[attr] = val.copy() if val else []
# Save transport/network globals
for attr in ("disable_aiohttp_transport", "force_ipv4"):
# Save scalar globals
for attr in _SCALAR_DEFAULTS:
if hasattr(litellm, attr):
original_state[attr] = getattr(litellm, attr)
# Flush cache before test
# ---- Reset to true defaults before the test ----
# Flush HTTP client cache
if hasattr(litellm, "in_memory_llm_clients_cache"):
litellm.in_memory_llm_clients_cache.flush_cache()
# Clear callbacks and rules before test
# Clear callbacks and rules
for attr in (
"callbacks",
"success_callback",
@@ -82,9 +97,15 @@ def isolate_litellm_state():
if hasattr(litellm, attr):
setattr(litellm, attr, [])
# Reset scalar globals to true defaults (prevents contamination from
# module-level code like `litellm.num_retries = 3` in test files)
for attr, default_val in _SCALAR_DEFAULTS.items():
if hasattr(litellm, attr):
setattr(litellm, attr, default_val)
yield
# Restore all saved state
# ---- Teardown: restore saved state ----
if hasattr(litellm, "in_memory_llm_clients_cache"):
litellm.in_memory_llm_clients_cache.flush_cache()
@@ -134,6 +134,7 @@ def test_multiple_deployments_parallel():
# test_multiple_deployments_parallel()
@pytest.mark.flaky(retries=3, delay=1)
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_cooldown_same_model_name(sync_mode):
+1
View File
@@ -51,6 +51,7 @@ router = Router(
)
@pytest.mark.flaky(retries=3, delay=1)
@pytest.mark.parametrize(
"model",
[
@@ -376,7 +376,12 @@ async def test_single_deployment_cooldown_with_allowed_fails():
except litellm.Timeout:
pass
await asyncio.sleep(2)
# Poll instead of fixed sleep — under xdist CPU contention 2s may
# not be enough for the async callback to fire.
for _ in range(100): # up to 10s
if mock_client.call_count >= 1:
break
await asyncio.sleep(0.1)
mock_client.assert_called_once()
@@ -426,7 +431,12 @@ async def test_single_deployment_cooldown_with_allowed_fail_policy():
except litellm.Timeout:
pass
await asyncio.sleep(2)
# Poll instead of fixed sleep — under xdist CPU contention 2s may
# not be enough for the async callback to fire.
for _ in range(100): # up to 10s
if mock_client.call_count >= 1:
break
await asyncio.sleep(0.1)
mock_client.assert_called_once()
@@ -154,6 +154,7 @@ async def _handle_router_calls(router):
print("done", chunk)
@pytest.mark.flaky(retries=3, delay=1)
@pytest.mark.asyncio
async def test_max_parallel_requests_rpm_rate_limiting():
"""