diff --git a/litellm/llms/vertex_ai/vertex_llm_base.py b/litellm/llms/vertex_ai/vertex_llm_base.py index 21b5ff8512..f8ceb7ea93 100644 --- a/litellm/llms/vertex_ai/vertex_llm_base.py +++ b/litellm/llms/vertex_ai/vertex_llm_base.py @@ -8,8 +8,7 @@ import asyncio import json import os import threading -import weakref -from typing import TYPE_CHECKING, Any, Dict, Literal, MutableMapping, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple import litellm from litellm._logging import verbose_logger @@ -52,12 +51,10 @@ class VertexBase: self.async_handler: Optional[AsyncHTTPHandler] = None # Per-credential-key asyncio.Lock for single-flight async refresh. # Prevents thundering herd when token expires under high concurrency. - # Uses a WeakValueDictionary so a lock is auto-pruned once no coroutine - # holds it any more — keeps the dict bounded in high-cardinality - # deployments without breaking single-flight while a refresh is active. - self._async_refresh_locks: MutableMapping[tuple, asyncio.Lock] = ( - weakref.WeakValueDictionary() - ) + # Uses a regular dict (not WeakValueDictionary) so the lock identity is + # stable across concurrent callers — a weak reference can be GC'd + # between two coroutines arriving at the lock, breaking single-flight. + self._async_refresh_locks: Dict[tuple, asyncio.Lock] = {} # Tracks in-flight background refresh tasks to avoid duplicate refreshes. self._background_refresh_tasks: Dict[tuple, asyncio.Task] = {} # Protects the sync get_access_token refresh path. @@ -975,13 +972,21 @@ class VertexBase: # Clean up the entry automatically when the task finishes so # that long-running proxies with many credential keys do not - # accumulate stale references. + # accumulate stale references. Guard with an identity check + # so a stale callback can't remove a newer task that already + # replaced this one in the dict (done_callbacks are scheduled + # via call_soon, so another coroutine may have stored a fresh + # task for the same key before this callback fires). def _drop_background_refresh_task( _fut: asyncio.Future[Any], ) -> None: - self._background_refresh_tasks.pop( - credential_cache_key, None - ) + if ( + self._background_refresh_tasks.get(credential_cache_key) + is _fut + ): + self._background_refresh_tasks.pop( + credential_cache_key, None + ) task.add_done_callback(_drop_background_refresh_task) self._background_refresh_tasks[credential_cache_key] = task