mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 20:48:32 +00:00
fix(caching): store background task references in LLMClientCache._remove_key to prevent unawaited coroutine warnings
Fixes #22128
This commit is contained in:
@@ -3,11 +3,37 @@ Add the event loop to the cache key, to prevent event loop closed errors.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Set
|
||||
|
||||
from .in_memory_cache import InMemoryCache
|
||||
|
||||
|
||||
class LLMClientCache(InMemoryCache):
|
||||
# Background tasks must be stored to prevent garbage collection, which would
|
||||
# trigger "coroutine was never awaited" warnings. See:
|
||||
# https://docs.python.org/3/library/asyncio-task.html#creating-tasks
|
||||
# Intentionally shared across all instances as a global task registry.
|
||||
_background_tasks: Set[asyncio.Task] = set()
|
||||
|
||||
def _remove_key(self, key: str) -> None:
|
||||
"""Close async clients before evicting them to prevent connection pool leaks."""
|
||||
value = self.cache_dict.get(key)
|
||||
super()._remove_key(key)
|
||||
if value is not None:
|
||||
close_fn = getattr(value, "aclose", None) or getattr(value, "close", None)
|
||||
if close_fn and asyncio.iscoroutinefunction(close_fn):
|
||||
try:
|
||||
task = asyncio.get_running_loop().create_task(close_fn())
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
except RuntimeError:
|
||||
pass
|
||||
elif close_fn and callable(close_fn):
|
||||
try:
|
||||
close_fn()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def update_cache_key_with_event_loop(self, key):
|
||||
"""
|
||||
Add the event loop to the cache key, to prevent event loop closed errors.
|
||||
|
||||
@@ -0,0 +1,158 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.caching.llm_caching_handler import LLMClientCache
|
||||
|
||||
|
||||
class MockAsyncClient:
|
||||
"""Mock async HTTP client with an async close method."""
|
||||
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
|
||||
async def close(self):
|
||||
self.closed = True
|
||||
|
||||
|
||||
class MockSyncClient:
|
||||
"""Mock sync HTTP client with a sync close method."""
|
||||
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_key_no_unawaited_coroutine_warning():
|
||||
"""
|
||||
Test that evicting an async client from LLMClientCache does not produce
|
||||
'coroutine was never awaited' warnings.
|
||||
|
||||
Regression test for https://github.com/BerriAI/litellm/issues/22128
|
||||
"""
|
||||
cache = LLMClientCache(max_size_in_memory=2)
|
||||
|
||||
mock_client = MockAsyncClient()
|
||||
cache.cache_dict["test-key"] = mock_client
|
||||
cache.ttl_dict["test-key"] = 0 # expired
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
warnings.simplefilter("always")
|
||||
cache._remove_key("test-key")
|
||||
# Let the event loop process the close task
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
coroutine_warnings = [
|
||||
w for w in caught_warnings if "coroutine" in str(w.message).lower()
|
||||
]
|
||||
assert (
|
||||
len(coroutine_warnings) == 0
|
||||
), f"Got unawaited coroutine warnings: {coroutine_warnings}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_key_closes_async_client():
|
||||
"""
|
||||
Test that evicting an async client from the cache properly closes it.
|
||||
"""
|
||||
cache = LLMClientCache(max_size_in_memory=2)
|
||||
|
||||
mock_client = MockAsyncClient()
|
||||
cache.cache_dict["test-key"] = mock_client
|
||||
cache.ttl_dict["test-key"] = 0
|
||||
|
||||
cache._remove_key("test-key")
|
||||
# Let the event loop process the close task
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert mock_client.closed is True
|
||||
assert "test-key" not in cache.cache_dict
|
||||
assert "test-key" not in cache.ttl_dict
|
||||
|
||||
|
||||
def test_remove_key_closes_sync_client():
|
||||
"""
|
||||
Test that evicting a sync client from the cache properly closes it.
|
||||
"""
|
||||
cache = LLMClientCache(max_size_in_memory=2)
|
||||
|
||||
mock_client = MockSyncClient()
|
||||
cache.cache_dict["test-key"] = mock_client
|
||||
cache.ttl_dict["test-key"] = 0
|
||||
|
||||
cache._remove_key("test-key")
|
||||
|
||||
assert mock_client.closed is True
|
||||
assert "test-key" not in cache.cache_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_eviction_closes_async_clients():
|
||||
"""
|
||||
Test that cache eviction (when cache is full) properly closes async clients
|
||||
without producing warnings.
|
||||
"""
|
||||
cache = LLMClientCache(max_size_in_memory=2, default_ttl=1)
|
||||
|
||||
clients = []
|
||||
for i in range(2):
|
||||
client = MockAsyncClient()
|
||||
clients.append(client)
|
||||
cache.set_cache(f"key-{i}", client)
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
warnings.simplefilter("always")
|
||||
# This should trigger eviction of one of the existing entries
|
||||
cache.set_cache("key-new", "new-value")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
coroutine_warnings = [
|
||||
w for w in caught_warnings if "coroutine" in str(w.message).lower()
|
||||
]
|
||||
assert (
|
||||
len(coroutine_warnings) == 0
|
||||
), f"Got unawaited coroutine warnings: {coroutine_warnings}"
|
||||
|
||||
|
||||
def test_remove_key_no_event_loop():
|
||||
"""
|
||||
Test that _remove_key doesn't raise when there's no running event loop
|
||||
(falls through to the RuntimeError except branch).
|
||||
"""
|
||||
cache = LLMClientCache(max_size_in_memory=2)
|
||||
|
||||
mock_client = MockAsyncClient()
|
||||
cache.cache_dict["test-key"] = mock_client
|
||||
cache.ttl_dict["test-key"] = 0
|
||||
|
||||
# Should not raise even though there's no running event loop
|
||||
cache._remove_key("test-key")
|
||||
assert "test-key" not in cache.cache_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_background_tasks_cleaned_up_after_completion():
|
||||
"""
|
||||
Test that completed close tasks are removed from the _background_tasks set.
|
||||
"""
|
||||
cache = LLMClientCache(max_size_in_memory=2)
|
||||
|
||||
mock_client = MockAsyncClient()
|
||||
cache.cache_dict["test-key"] = mock_client
|
||||
cache.ttl_dict["test-key"] = 0
|
||||
|
||||
cache._remove_key("test-key")
|
||||
# Let the task complete
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert len(cache._background_tasks) == 0
|
||||
Reference in New Issue
Block a user