mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 03:31:23 +00:00
fix(cache): prevent DualCache async batch check-then-act race (#20986)
* fix(cache): prevent dual cache batch redis race under concurrency * chore(cache): remove unused dual cache batch key helper * chore(cache): align dual cache type hints and throttle comment
This commit is contained in:
committed by
Sameer Kankute
parent
cba3bcf1a9
commit
8a650f0170
@@ -12,7 +12,8 @@ import asyncio
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.caching import RedisPipelineIncrementOperation
|
||||
@@ -71,6 +72,7 @@ class DualCache(BaseCache):
|
||||
self.last_redis_batch_access_time = LimitedSizeOrderedDict(
|
||||
max_size=default_max_redis_batch_cache_size
|
||||
)
|
||||
self._last_redis_batch_access_time_lock = Lock()
|
||||
self.redis_batch_cache_expiry = (
|
||||
default_redis_batch_cache_expiry
|
||||
or litellm.default_redis_batch_cache_expiry
|
||||
@@ -236,22 +238,46 @@ class DualCache(BaseCache):
|
||||
except Exception:
|
||||
verbose_logger.error(traceback.format_exc())
|
||||
|
||||
def get_redis_batch_keys(
|
||||
def _reserve_redis_batch_keys(
|
||||
self,
|
||||
current_time: float,
|
||||
keys: List[str],
|
||||
result: List[Any],
|
||||
) -> List[str]:
|
||||
sublist_keys = []
|
||||
for key, value in zip(keys, result):
|
||||
if value is None:
|
||||
) -> Tuple[List[str], Dict[str, Optional[float]]]:
|
||||
"""
|
||||
Atomically choose keys to fetch from Redis and reserve their access time.
|
||||
This prevents check-then-act races under concurrent async callers.
|
||||
"""
|
||||
sublist_keys: List[str] = []
|
||||
previous_access_times: Dict[str, Optional[float]] = {}
|
||||
|
||||
with self._last_redis_batch_access_time_lock:
|
||||
for key, value in zip(keys, result):
|
||||
if value is not None:
|
||||
continue
|
||||
|
||||
if (
|
||||
key not in self.last_redis_batch_access_time
|
||||
or current_time - self.last_redis_batch_access_time[key]
|
||||
>= self.redis_batch_cache_expiry
|
||||
):
|
||||
sublist_keys.append(key)
|
||||
return sublist_keys
|
||||
previous_access_times[key] = self.last_redis_batch_access_time.get(
|
||||
key
|
||||
)
|
||||
self.last_redis_batch_access_time[key] = current_time
|
||||
|
||||
return sublist_keys, previous_access_times
|
||||
|
||||
def _rollback_redis_batch_key_reservations(
|
||||
self, previous_access_times: Dict[str, Optional[float]]
|
||||
) -> None:
|
||||
with self._last_redis_batch_access_time_lock:
|
||||
for key, previous_time in previous_access_times.items():
|
||||
if previous_time is None:
|
||||
self.last_redis_batch_access_time.pop(key, None)
|
||||
else:
|
||||
self.last_redis_batch_access_time[key] = previous_time
|
||||
|
||||
async def async_batch_get_cache(
|
||||
self,
|
||||
@@ -276,19 +302,23 @@ class DualCache(BaseCache):
|
||||
- check the redis cache
|
||||
"""
|
||||
current_time = time.time()
|
||||
sublist_keys = self.get_redis_batch_keys(current_time, keys, result)
|
||||
sublist_keys, previous_access_times = self._reserve_redis_batch_keys(
|
||||
current_time, keys, result
|
||||
)
|
||||
|
||||
# Only hit Redis if the last access time was more than 5 seconds ago
|
||||
# Only hit Redis if enough time has passed since last access.
|
||||
if len(sublist_keys) > 0:
|
||||
# If not found in in-memory cache, try fetching from Redis
|
||||
redis_result = await self.redis_cache.async_batch_get_cache(
|
||||
sublist_keys, parent_otel_span=parent_otel_span
|
||||
)
|
||||
|
||||
# Update the last access time for ALL queried keys
|
||||
# This includes keys with None values to throttle repeated Redis queries
|
||||
for key in sublist_keys:
|
||||
self.last_redis_batch_access_time[key] = current_time
|
||||
try:
|
||||
# If not found in in-memory cache, try fetching from Redis
|
||||
redis_result = await self.redis_cache.async_batch_get_cache(
|
||||
sublist_keys, parent_otel_span=parent_otel_span
|
||||
)
|
||||
except Exception:
|
||||
# Do not throttle subsequent callers if the Redis read fails.
|
||||
self._rollback_redis_batch_key_reservations(
|
||||
previous_access_times
|
||||
)
|
||||
raise
|
||||
|
||||
# Short-circuit if redis_result is None or contains only None values
|
||||
if redis_result is None or all(v is None for v in redis_result.values()):
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.caching.dual_cache import DualCache
|
||||
from litellm.caching.redis_cache import RedisCache
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dual_cache_async_batch_get_cache_coalesces_concurrent_redis_reads():
|
||||
dual_cache = DualCache(
|
||||
redis_cache=MagicMock(spec=RedisCache), default_redis_batch_cache_expiry=10
|
||||
)
|
||||
keys = ["shared_a", "shared_b"]
|
||||
start_gate = asyncio.Event()
|
||||
|
||||
async def _mock_async_batch_get_cache(key_list, parent_otel_span=None):
|
||||
await asyncio.sleep(0.05)
|
||||
return {k: None for k in key_list}
|
||||
|
||||
with patch.object(
|
||||
dual_cache.redis_cache,
|
||||
"async_batch_get_cache",
|
||||
new=AsyncMock(side_effect=_mock_async_batch_get_cache),
|
||||
) as mock_async_batch_get_cache:
|
||||
|
||||
async def worker():
|
||||
await start_gate.wait()
|
||||
return await dual_cache.async_batch_get_cache(keys=keys)
|
||||
|
||||
tasks = [asyncio.create_task(worker()) for _ in range(50)]
|
||||
start_gate.set()
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
assert mock_async_batch_get_cache.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dual_cache_async_batch_get_cache_rolls_back_redis_reservation_on_error():
|
||||
dual_cache = DualCache(
|
||||
redis_cache=MagicMock(spec=RedisCache), default_redis_batch_cache_expiry=10
|
||||
)
|
||||
keys = ["shared_a", "shared_b"]
|
||||
|
||||
with patch.object(
|
||||
dual_cache.redis_cache,
|
||||
"async_batch_get_cache",
|
||||
new=AsyncMock(side_effect=RuntimeError("redis unavailable")),
|
||||
) as mock_async_batch_get_cache:
|
||||
first_result = await dual_cache.async_batch_get_cache(keys=keys)
|
||||
second_result = await dual_cache.async_batch_get_cache(keys=keys)
|
||||
|
||||
assert first_result is None
|
||||
assert second_result is None
|
||||
assert mock_async_batch_get_cache.call_count == 2
|
||||
assert "shared_a" not in dual_cache.last_redis_batch_access_time
|
||||
assert "shared_b" not in dual_cache.last_redis_batch_access_time
|
||||
Reference in New Issue
Block a user