fix(cache): prevent DualCache async batch check-then-act race (#20986)

* fix(cache): prevent dual cache batch redis race under concurrency

* chore(cache): remove unused dual cache batch key helper

* chore(cache): align dual cache type hints and throttle comment
This commit is contained in:
Emerson Gomes
2026-02-11 23:37:28 -06:00
committed by Sameer Kankute
parent cba3bcf1a9
commit 8a650f0170
2 changed files with 106 additions and 18 deletions
+48 -18
View File
@@ -12,7 +12,8 @@ import asyncio
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, List, Optional, Union
from threading import Lock
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
if TYPE_CHECKING:
from litellm.types.caching import RedisPipelineIncrementOperation
@@ -71,6 +72,7 @@ class DualCache(BaseCache):
self.last_redis_batch_access_time = LimitedSizeOrderedDict(
max_size=default_max_redis_batch_cache_size
)
self._last_redis_batch_access_time_lock = Lock()
self.redis_batch_cache_expiry = (
default_redis_batch_cache_expiry
or litellm.default_redis_batch_cache_expiry
@@ -236,22 +238,46 @@ class DualCache(BaseCache):
except Exception:
verbose_logger.error(traceback.format_exc())
def get_redis_batch_keys(
def _reserve_redis_batch_keys(
self,
current_time: float,
keys: List[str],
result: List[Any],
) -> List[str]:
sublist_keys = []
for key, value in zip(keys, result):
if value is None:
) -> Tuple[List[str], Dict[str, Optional[float]]]:
"""
Atomically choose keys to fetch from Redis and reserve their access time.
This prevents check-then-act races under concurrent async callers.
"""
sublist_keys: List[str] = []
previous_access_times: Dict[str, Optional[float]] = {}
with self._last_redis_batch_access_time_lock:
for key, value in zip(keys, result):
if value is not None:
continue
if (
key not in self.last_redis_batch_access_time
or current_time - self.last_redis_batch_access_time[key]
>= self.redis_batch_cache_expiry
):
sublist_keys.append(key)
return sublist_keys
previous_access_times[key] = self.last_redis_batch_access_time.get(
key
)
self.last_redis_batch_access_time[key] = current_time
return sublist_keys, previous_access_times
def _rollback_redis_batch_key_reservations(
self, previous_access_times: Dict[str, Optional[float]]
) -> None:
with self._last_redis_batch_access_time_lock:
for key, previous_time in previous_access_times.items():
if previous_time is None:
self.last_redis_batch_access_time.pop(key, None)
else:
self.last_redis_batch_access_time[key] = previous_time
async def async_batch_get_cache(
self,
@@ -276,19 +302,23 @@ class DualCache(BaseCache):
- check the redis cache
"""
current_time = time.time()
sublist_keys = self.get_redis_batch_keys(current_time, keys, result)
sublist_keys, previous_access_times = self._reserve_redis_batch_keys(
current_time, keys, result
)
# Only hit Redis if the last access time was more than 5 seconds ago
# Only hit Redis if enough time has passed since last access.
if len(sublist_keys) > 0:
# If not found in in-memory cache, try fetching from Redis
redis_result = await self.redis_cache.async_batch_get_cache(
sublist_keys, parent_otel_span=parent_otel_span
)
# Update the last access time for ALL queried keys
# This includes keys with None values to throttle repeated Redis queries
for key in sublist_keys:
self.last_redis_batch_access_time[key] = current_time
try:
# If not found in in-memory cache, try fetching from Redis
redis_result = await self.redis_cache.async_batch_get_cache(
sublist_keys, parent_otel_span=parent_otel_span
)
except Exception:
# Do not throttle subsequent callers if the Redis read fails.
self._rollback_redis_batch_key_reservations(
previous_access_times
)
raise
# Short-circuit if redis_result is None or contains only None values
if redis_result is None or all(v is None for v in redis_result.values()):
@@ -0,0 +1,58 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from litellm.caching.dual_cache import DualCache
from litellm.caching.redis_cache import RedisCache
@pytest.mark.asyncio
async def test_dual_cache_async_batch_get_cache_coalesces_concurrent_redis_reads():
dual_cache = DualCache(
redis_cache=MagicMock(spec=RedisCache), default_redis_batch_cache_expiry=10
)
keys = ["shared_a", "shared_b"]
start_gate = asyncio.Event()
async def _mock_async_batch_get_cache(key_list, parent_otel_span=None):
await asyncio.sleep(0.05)
return {k: None for k in key_list}
with patch.object(
dual_cache.redis_cache,
"async_batch_get_cache",
new=AsyncMock(side_effect=_mock_async_batch_get_cache),
) as mock_async_batch_get_cache:
async def worker():
await start_gate.wait()
return await dual_cache.async_batch_get_cache(keys=keys)
tasks = [asyncio.create_task(worker()) for _ in range(50)]
start_gate.set()
await asyncio.gather(*tasks)
assert mock_async_batch_get_cache.call_count == 1
@pytest.mark.asyncio
async def test_dual_cache_async_batch_get_cache_rolls_back_redis_reservation_on_error():
dual_cache = DualCache(
redis_cache=MagicMock(spec=RedisCache), default_redis_batch_cache_expiry=10
)
keys = ["shared_a", "shared_b"]
with patch.object(
dual_cache.redis_cache,
"async_batch_get_cache",
new=AsyncMock(side_effect=RuntimeError("redis unavailable")),
) as mock_async_batch_get_cache:
first_result = await dual_cache.async_batch_get_cache(keys=keys)
second_result = await dual_cache.async_batch_get_cache(keys=keys)
assert first_result is None
assert second_result is None
assert mock_async_batch_get_cache.call_count == 2
assert "shared_a" not in dual_cache.last_redis_batch_access_time
assert "shared_b" not in dual_cache.last_redis_batch_access_time