From 8a650f0170db2f47499fc92deb30693df65002ca Mon Sep 17 00:00:00 2001 From: Emerson Gomes Date: Wed, 11 Feb 2026 23:37:28 -0600 Subject: [PATCH] fix(cache): prevent DualCache async batch check-then-act race (#20986) * fix(cache): prevent dual cache batch redis race under concurrency * chore(cache): remove unused dual cache batch key helper * chore(cache): align dual cache type hints and throttle comment --- litellm/caching/dual_cache.py | 66 ++++++++++++++----- tests/test_litellm/caching/test_dual_cache.py | 58 ++++++++++++++++ 2 files changed, 106 insertions(+), 18 deletions(-) create mode 100644 tests/test_litellm/caching/test_dual_cache.py diff --git a/litellm/caching/dual_cache.py b/litellm/caching/dual_cache.py index 3edc3f4282..6df570c72b 100644 --- a/litellm/caching/dual_cache.py +++ b/litellm/caching/dual_cache.py @@ -12,7 +12,8 @@ import asyncio import time import traceback from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING, Any, List, Optional, Union +from threading import Lock +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union if TYPE_CHECKING: from litellm.types.caching import RedisPipelineIncrementOperation @@ -71,6 +72,7 @@ class DualCache(BaseCache): self.last_redis_batch_access_time = LimitedSizeOrderedDict( max_size=default_max_redis_batch_cache_size ) + self._last_redis_batch_access_time_lock = Lock() self.redis_batch_cache_expiry = ( default_redis_batch_cache_expiry or litellm.default_redis_batch_cache_expiry @@ -236,22 +238,46 @@ class DualCache(BaseCache): except Exception: verbose_logger.error(traceback.format_exc()) - def get_redis_batch_keys( + def _reserve_redis_batch_keys( self, current_time: float, keys: List[str], result: List[Any], - ) -> List[str]: - sublist_keys = [] - for key, value in zip(keys, result): - if value is None: + ) -> Tuple[List[str], Dict[str, Optional[float]]]: + """ + Atomically choose keys to fetch from Redis and reserve their access time. + This prevents check-then-act races under concurrent async callers. + """ + sublist_keys: List[str] = [] + previous_access_times: Dict[str, Optional[float]] = {} + + with self._last_redis_batch_access_time_lock: + for key, value in zip(keys, result): + if value is not None: + continue + if ( key not in self.last_redis_batch_access_time or current_time - self.last_redis_batch_access_time[key] >= self.redis_batch_cache_expiry ): sublist_keys.append(key) - return sublist_keys + previous_access_times[key] = self.last_redis_batch_access_time.get( + key + ) + self.last_redis_batch_access_time[key] = current_time + + return sublist_keys, previous_access_times + + def _rollback_redis_batch_key_reservations( + self, previous_access_times: Dict[str, Optional[float]] + ) -> None: + with self._last_redis_batch_access_time_lock: + for key, previous_time in previous_access_times.items(): + if previous_time is None: + self.last_redis_batch_access_time.pop(key, None) + else: + self.last_redis_batch_access_time[key] = previous_time async def async_batch_get_cache( self, @@ -276,19 +302,23 @@ class DualCache(BaseCache): - check the redis cache """ current_time = time.time() - sublist_keys = self.get_redis_batch_keys(current_time, keys, result) + sublist_keys, previous_access_times = self._reserve_redis_batch_keys( + current_time, keys, result + ) - # Only hit Redis if the last access time was more than 5 seconds ago + # Only hit Redis if enough time has passed since last access. if len(sublist_keys) > 0: - # If not found in in-memory cache, try fetching from Redis - redis_result = await self.redis_cache.async_batch_get_cache( - sublist_keys, parent_otel_span=parent_otel_span - ) - - # Update the last access time for ALL queried keys - # This includes keys with None values to throttle repeated Redis queries - for key in sublist_keys: - self.last_redis_batch_access_time[key] = current_time + try: + # If not found in in-memory cache, try fetching from Redis + redis_result = await self.redis_cache.async_batch_get_cache( + sublist_keys, parent_otel_span=parent_otel_span + ) + except Exception: + # Do not throttle subsequent callers if the Redis read fails. + self._rollback_redis_batch_key_reservations( + previous_access_times + ) + raise # Short-circuit if redis_result is None or contains only None values if redis_result is None or all(v is None for v in redis_result.values()): diff --git a/tests/test_litellm/caching/test_dual_cache.py b/tests/test_litellm/caching/test_dual_cache.py new file mode 100644 index 0000000000..9974c23e4b --- /dev/null +++ b/tests/test_litellm/caching/test_dual_cache.py @@ -0,0 +1,58 @@ +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from litellm.caching.dual_cache import DualCache +from litellm.caching.redis_cache import RedisCache + + +@pytest.mark.asyncio +async def test_dual_cache_async_batch_get_cache_coalesces_concurrent_redis_reads(): + dual_cache = DualCache( + redis_cache=MagicMock(spec=RedisCache), default_redis_batch_cache_expiry=10 + ) + keys = ["shared_a", "shared_b"] + start_gate = asyncio.Event() + + async def _mock_async_batch_get_cache(key_list, parent_otel_span=None): + await asyncio.sleep(0.05) + return {k: None for k in key_list} + + with patch.object( + dual_cache.redis_cache, + "async_batch_get_cache", + new=AsyncMock(side_effect=_mock_async_batch_get_cache), + ) as mock_async_batch_get_cache: + + async def worker(): + await start_gate.wait() + return await dual_cache.async_batch_get_cache(keys=keys) + + tasks = [asyncio.create_task(worker()) for _ in range(50)] + start_gate.set() + await asyncio.gather(*tasks) + + assert mock_async_batch_get_cache.call_count == 1 + + +@pytest.mark.asyncio +async def test_dual_cache_async_batch_get_cache_rolls_back_redis_reservation_on_error(): + dual_cache = DualCache( + redis_cache=MagicMock(spec=RedisCache), default_redis_batch_cache_expiry=10 + ) + keys = ["shared_a", "shared_b"] + + with patch.object( + dual_cache.redis_cache, + "async_batch_get_cache", + new=AsyncMock(side_effect=RuntimeError("redis unavailable")), + ) as mock_async_batch_get_cache: + first_result = await dual_cache.async_batch_get_cache(keys=keys) + second_result = await dual_cache.async_batch_get_cache(keys=keys) + + assert first_result is None + assert second_result is None + assert mock_async_batch_get_cache.call_count == 2 + assert "shared_a" not in dual_cache.last_redis_batch_access_time + assert "shared_b" not in dual_cache.last_redis_batch_access_time