fix(vertex_ai): stabilize background refresh task tracking

- Guard background refresh done_callback with an identity check so a
  stale callback cannot remove a newer task that already replaced it in
  the tracking dict (done_callbacks are scheduled via call_soon, so a
  fresh task can be stored for the same credential key before the old
  callback fires).
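- Replace WeakValueDictionary with a regular dict for
  _async_refresh_locks so the per-key asyncio.Lock identity is stable
  across concurrent callers; otherwise a weakly referenced lock can be
  garbage-collected between two coroutines arriving for the same key,
  so the second caller gets a fresh lock and single-flight breaks.
  Trade-off: entries are no longer auto-pruned, so the dict grows with
  the number of distinct credential keys.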

Co-authored-by: Yassin Kortam <yassin@berri.ai>
This commit is contained in:
Cursor Agent
2026-05-20 18:05:30 +00:00
parent 349848f88d
commit dfb2524def
+17 -12
View File
@@ -8,8 +8,7 @@ import asyncio
import json
import os
import threading
import weakref
from typing import TYPE_CHECKING, Any, Dict, Literal, MutableMapping, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
import litellm
from litellm._logging import verbose_logger
@@ -52,12 +51,10 @@ class VertexBase:
self.async_handler: Optional[AsyncHTTPHandler] = None
# Per-credential-key asyncio.Lock for single-flight async refresh.
# Prevents thundering herd when token expires under high concurrency.
# Uses a WeakValueDictionary so a lock is auto-pruned once no coroutine
# holds it any more — keeps the dict bounded in high-cardinality
# deployments without breaking single-flight while a refresh is active.
self._async_refresh_locks: MutableMapping[tuple, asyncio.Lock] = (
weakref.WeakValueDictionary()
)
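# Uses a regular dict (not WeakValueDictionary) so the lock identity is
# stable across concurrent callers — a weakly referenced lock can be
# garbage-collected between two coroutines arriving for the same key,
# breaking single-flight.
# NOTE(review): unlike the weak dict, nothing visible here prunes
# entries — confirm growth is acceptable with many credential keys.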
self._async_refresh_locks: Dict[tuple, asyncio.Lock] = {}
# Tracks in-flight background refresh tasks to avoid duplicate refreshes.
self._background_refresh_tasks: Dict[tuple, asyncio.Task] = {}
# Protects the sync get_access_token refresh path.
@@ -975,13 +972,21 @@ class VertexBase:
# Clean up the entry automatically when the task finishes so
# that long-running proxies with many credential keys do not
# accumulate stale references.
# accumulate stale references. Guard with an identity check
# so a stale callback can't remove a newer task that already
# replaced this one in the dict (done_callbacks are scheduled
# via call_soon, so another coroutine may have stored a fresh
# task for the same key before this callback fires).
def _drop_background_refresh_task(
_fut: asyncio.Future[Any],
) -> None:
self._background_refresh_tasks.pop(
credential_cache_key, None
)
if (
self._background_refresh_tasks.get(credential_cache_key)
is _fut
):
self._background_refresh_tasks.pop(
credential_cache_key, None
)
task.add_done_callback(_drop_background_refresh_task)
self._background_refresh_tasks[credential_cache_key] = task