From 0bd49ecb8b497efd8f7ddfb56a07a48ac7c5a017 Mon Sep 17 00:00:00 2001
From: Michael Riad Zaky
Date: Fri, 17 Apr 2026 16:58:26 -0700
Subject: [PATCH] Fix bug that bypasses per-team member budget limit

---
 litellm/proxy/auth/auth_checks.py             |  82 +++++++-
 .../management_endpoints/team_endpoints.py    |  30 ++-
 litellm/proxy/proxy_server.py                 | 100 +++++++--
 .../proxy/auth/test_auth_checks.py            | 189 ++++++++++++++++++
 .../test_team_endpoints.py                    |  52 +++++
 tests/test_litellm/proxy/test_proxy_server.py | 120 +++++++++++
 6 files changed, 545 insertions(+), 28 deletions(-)

Review notes (commentary only: text between the "---" line and the first
"diff --git" is not part of the commit).

  What the patch does
  * auth_checks.py: adds get_team_member_default_budget() and makes
    _check_team_member_budget() fall back to the budget named by
    team.metadata["team_member_budget_id"] when the TeamMembership row has
    no linked budget table.
  * team_endpoints.py: backfill_team_member_budget_entries() additionally
    calls update_many() to set budget_id on the team's membership rows
    where it is NULL.
  * proxy_server.py: adds _reseed_spend_from_db(); a missing spend counter
    is seeded from the DB row rather than from the cached object's spend.

  Points to confirm before merging
  * _check_team_member_budget(): a membership whose litellm_budget_table
    exists but whose max_budget is None takes the per-member branch, so
    the team default is never consulted and no cap is enforced. Confirm
    that "override without max_budget" is meant to read as "unlimited".
  * _init_and_increment_spend_counter(): the cached-spend fallback runs
    only when prisma_client is None. If the DB lookup raises,
    _reseed_spend_from_db() returns 0.0 and the counter is seeded from the
    increment alone, i.e. it under-counts.
  * _reseed_spend_from_db(): "spend:team_member:{uid}:{tid}" is split with
    rsplit(":", 1). Presumably team ids never contain ":" (verify); a team
    id with a colon would be split in the wrong place.
  * _reseed_spend_from_db(): keys containing ":window:" return 0.0 before
    any lookup, and the cached fallback is skipped whenever prisma_client
    is set. If window keys reach _init_and_increment_spend_counter(), a
    missing window counter now starts from zero (verify against callers).
  * get_team_member_default_budget(): a budget_id that is not found is not
    cached, so every call logs a warning and re-queries. A found budget is
    cached for DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL and nothing in
    this patch invalidates that entry when the budget row changes.
  * test_reseed_spend_from_db_user_and_org_prefixes: the docstring says the
    other counters "fall through to 0.0 ... today", but in this patch the
    key, team and team_member prefixes are looked up as well; only
    unrecognized prefixes and window keys return 0.0. Consider rewording.

  NOTE(review): blank context lines and the indentation of context lines
  were placed to satisfy the @@ hunk counts; in the proxy_server.py hunk
  the blank context line before the final "await" is an assumption. Run
  "git apply --check" against the base tree before using this patch.

diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py
index 2c8299e77a..1c89b0bfc0 100644
--- a/litellm/proxy/auth/auth_checks.py
+++ b/litellm/proxy/auth/auth_checks.py
@@ -905,6 +905,63 @@ async def get_default_end_user_budget(
         return None
 
 
+@log_db_metrics
+async def get_team_member_default_budget(
+    budget_id: str,
+    prisma_client: Optional[PrismaClient],
+    user_api_key_cache: DualCache,
+) -> Optional[LiteLLM_BudgetTable]:
+    """
+    Fetches the team-level default per-member budget referenced by team.metadata["team_member_budget_id"].
+
+    This budget is applied to team members whose TeamMembership row has no
+    linked budget. Results are cached for performance.
+
+    Args:
+        budget_id: The budget_id pulled from team.metadata["team_member_budget_id"]
+        prisma_client: Database client instance
+        user_api_key_cache: Cache for storing/retrieving budget data
+
+    Returns:
+        LiteLLM_BudgetTable if found, None otherwise
+    """
+    if prisma_client is None:
+        return None
+
+    cache_key = f"team_member_default_budget:{budget_id}"
+
+    cached_budget = await user_api_key_cache.async_get_cache(key=cache_key)
+    if isinstance(cached_budget, LiteLLM_BudgetTable):
+        return cached_budget
+    if isinstance(cached_budget, dict):
+        return LiteLLM_BudgetTable(**cached_budget)
+
+    try:
+        budget_record = await prisma_client.db.litellm_budgettable.find_unique(
+            where={"budget_id": budget_id}
+        )
+
+        if budget_record is None:
+            verbose_proxy_logger.warning(
+                f"Team-default member budget not found in database: {budget_id}"
+            )
+            return None
+
+        await user_api_key_cache.async_set_cache(
+            key=cache_key,
+            value=budget_record.dict(),
+            ttl=DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL,
+        )
+
+        return LiteLLM_BudgetTable(**budget_record.dict())
+
+    except Exception:
+        verbose_proxy_logger.exception(
+            f"Error fetching team-default member budget {budget_id}"
+        )
+        return None
+
+
 async def _apply_default_budget_to_end_user(
     end_user_obj: LiteLLM_EndUserTable,
     prisma_client: PrismaClient,
@@ -3230,13 +3287,26 @@ async def _check_team_member_budget(
             proxy_logging_obj=proxy_logging_obj,
         )
 
-        if (
-            team_membership is not None
-            and team_membership.litellm_budget_table is not None
-            and team_membership.litellm_budget_table.max_budget is not None
-        ):
+        # Per-member override wins; otherwise fall back to the team-level
+        # default configured via team.metadata["team_member_budget_id"].
+        team_member_budget: Optional[float] = None
+        if team_membership is not None and team_membership.litellm_budget_table is not None:
             team_member_budget = team_membership.litellm_budget_table.max_budget
-            team_member_spend = team_membership.spend or 0.0
+        else:
+            default_budget_id = (team_object.metadata or {}).get("team_member_budget_id")
+            if isinstance(default_budget_id, str):
+                default_budget = await get_team_member_default_budget(
+                    budget_id=default_budget_id,
+                    prisma_client=prisma_client,
+                    user_api_key_cache=user_api_key_cache,
+                )
+                if default_budget is not None:
+                    team_member_budget = default_budget.max_budget
+
+        if team_member_budget is not None:
+            team_member_spend = (
+                team_membership.spend if team_membership is not None else 0.0
+            ) or 0.0
 
             # Read from cross-pod counter (Redis-first) if available
             from litellm.proxy.proxy_server import get_current_spend
diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py
index bf912fba4f..8357b1c0fe 100644
--- a/litellm/proxy/management_endpoints/team_endpoints.py
+++ b/litellm/proxy/management_endpoints/team_endpoints.py
@@ -302,14 +302,15 @@ class TeamMemberBudgetHandler:
         prisma_client: PrismaClient,
     ) -> None:
         """
-        Create team_memberships entries for existing members that don't have one.
+        Ensure every team member has a TeamMembership row linked to the
+        team_member_budget.
 
-        Called after team_member_budget is set/updated on a team to ensure
-        members who joined before the budget was configured also get budget
-        enforcement.
-
-        Only creates missing entries — does not touch existing memberships
-        (which may carry individual per-member budgets).
+        Called after team_member_budget is set/updated on a team. Creates
+        rows for members who don't have one, and populates budget_id on
+        existing rows where it is NULL. Rows with a non-NULL budget_id
+        are left untouched, which preserves per-member overrides but also
+        means rows pointing to a prior team-default budget_id are not
+        migrated to the new one.
         """
         if not members_with_roles:
             return
@@ -347,6 +348,21 @@ class TeamMemberBudgetHandler:
                 _sanitize_for_log(team_member_budget_id),
             )
 
+        # Heal existing membership rows that predate the team_member_budget
+        # configuration: populate budget_id where it is currently NULL.
+        # Rows with an explicit budget_id (per-member override) are left alone.
+        updated = await prisma_client.db.litellm_teammembership.update_many(
+            where={"team_id": team_id, "budget_id": None},
+            data={"budget_id": team_member_budget_id},
+        )
+        if updated:
+            verbose_proxy_logger.info(
+                "Populated budget_id on %d existing team_memberships for team %s with budget %s",
+                updated,
+                _sanitize_for_log(team_id),
+                _sanitize_for_log(team_member_budget_id),
+            )
+
 
 def _get_default_team_param(field: str) -> Any:
     """
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index aa8122d8fd..546d8df14c 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1908,40 +1908,110 @@ async def increment_spend_counters(
         )
 
 
+async def _reseed_spend_from_db(counter_key: str) -> float:
+    """
+    Read the authoritative spend for a missing counter from the DB. The
+    counter_key prefix encodes the table to query:
+
+        spend:key:{token}               -> LiteLLM_VerificationToken.spend
+        spend:team:{team_id}            -> LiteLLM_TeamTable.spend
+        spend:team_member:{uid}:{tid}   -> LiteLLM_TeamMembership.spend
+        spend:user:{user_id}            -> LiteLLM_UserTable.spend
+        spend:org:{org_id}              -> LiteLLM_OrganizationTable.spend
+
+    Returns 0.0 if prisma is unavailable, the row is missing, or the
+    key format is unrecognized. On failure, logs and returns 0.0 rather
+    than raising so the caller can still record the current increment.
+    """
+    if prisma_client is None:
+        return 0.0
+    # Per-window counters (spend:*:window:{duration}) share prefixes with
+    # primary counters but don't correspond to a DB row; their ambiguity
+    # would otherwise be silently parsed as a regular counter and miss.
+    if ":window:" in counter_key:
+        return 0.0
+    try:
+        if counter_key.startswith("spend:key:"):
+            token = counter_key[len("spend:key:") :]
+            row = await prisma_client.db.litellm_verificationtoken.find_unique(
+                where={"token": token}
+            )
+        elif counter_key.startswith("spend:team_member:"):
+            suffix = counter_key[len("spend:team_member:") :]
+            if ":" not in suffix:
+                return 0.0
+            user_id, team_id = suffix.rsplit(":", 1)
+            row = await prisma_client.db.litellm_teammembership.find_unique(
+                where={"user_id_team_id": {"user_id": user_id, "team_id": team_id}}
+            )
+        elif counter_key.startswith("spend:team:"):
+            team_id = counter_key[len("spend:team:") :]
+            row = await prisma_client.db.litellm_teamtable.find_unique(
+                where={"team_id": team_id}
+            )
+        elif counter_key.startswith("spend:user:"):
+            user_id = counter_key[len("spend:user:") :]
+            row = await prisma_client.db.litellm_usertable.find_unique(
+                where={"user_id": user_id}
+            )
+        elif counter_key.startswith("spend:org:"):
+            org_id = counter_key[len("spend:org:") :]
+            row = await prisma_client.db.litellm_organizationtable.find_unique(
+                where={"organization_id": org_id}
+            )
+        else:
+            return 0.0
+    except Exception:
+        verbose_proxy_logger.exception(
+            "Failed to reseed spend counter %s from DB", counter_key
+        )
+        return 0.0
+    if row is None:
+        return 0.0
+    return float(getattr(row, "spend", 0.0) or 0.0)
+
+
 async def _init_and_increment_spend_counter(
     counter_key: str,
     source_cache_key: str,
     increment: float,
 ):
     """
-    Initialize counter from cached object's DB-loaded spend if not yet set,
-    then atomically increment in both in-memory and Redis.
+    Initialize counter from the authoritative DB spend value if not yet
+    set, then atomically increment in both in-memory and Redis.
 
     On first access per pod:
-    1. Check spend_counter_cache (in-memory -> Redis via DualCache for init check)
-    2. If not found anywhere, read base spend from user_api_key_cache (DB-loaded object)
+    1. Check spend_counter_cache (in-memory -> Redis via DualCache)
+    2. If not found, reseed from the DB (`_reseed_spend_from_db`). Falls
+       back to the cached object's `.spend` via user_api_key_cache only
+       if prisma is unavailable, since that value can lag the flusher.
     3. Seed counter via async_increment_cache (not async_set_cache) to avoid a
        check-then-set race: if two pods cold-start simultaneously, both may see
-       the counter as absent and seed it. Using increment instead of set means
-       the worst case is over-counting (conservative — blocks slightly early)
-       rather than under-counting (would allow overspend).
+       the counter as absent and seed it. Using increment means the worst case
+       is over-counting (conservative, blocks slightly early) rather than
+       under-counting (would allow overspend).
     4. Increment atomically (both in-memory + Redis)
     """
     current = await spend_counter_cache.async_get_cache(key=counter_key)
     if current is None:
-        source = await user_api_key_cache.async_get_cache(key=source_cache_key)
-        base_spend = 0.0
-        if source is not None:
-            if isinstance(source, dict):
-                base_spend = source.get("spend", 0.0) or 0.0
-            else:
-                base_spend = getattr(source, "spend", 0.0) or 0.0
+        base_spend = await _reseed_spend_from_db(counter_key)
+        if prisma_client is None:
+            # Best-effort fallback when prisma is unavailable (tests or
+            # early-startup paths). May be stale but avoids resetting to 0.
+            source = await user_api_key_cache.async_get_cache(key=source_cache_key)
+            if source is not None:
+                if isinstance(source, dict):
+                    base_spend = source.get("spend", 0.0) or 0.0
+                else:
+                    base_spend = getattr(source, "spend", 0.0) or 0.0
         if base_spend > 0:
             await spend_counter_cache.async_increment_cache(
                 key=counter_key, value=base_spend
             )
 
-    await spend_counter_cache.async_increment_cache(key=counter_key, value=increment)
+    await spend_counter_cache.async_increment_cache(
+        key=counter_key, value=increment
+    )
 
 
 async def update_cache(  # noqa: PLR0915
diff --git a/tests/test_litellm/proxy/auth/test_auth_checks.py b/tests/test_litellm/proxy/auth/test_auth_checks.py
index 19fffffc65..8612d243c4 100644
--- a/tests/test_litellm/proxy/auth/test_auth_checks.py
+++ b/tests/test_litellm/proxy/auth/test_auth_checks.py
@@ -2126,3 +2126,192 @@ class TestGuardrailModificationCheck:
         """Unparseable strings should not trigger a 403 — they have no keys."""
         self._call({"metadata": "not-json"})
         self._call({"metadata": '"just a string"'})
+
+
+@pytest.mark.asyncio
+async def test_team_member_budget_check_falls_back_to_team_default_budget_id():
+    """When a member's TeamMembership has no linked budget row, the check
+    should fall back to team.metadata["team_member_budget_id"] and still
+    enforce the cap. Pre-fix, this path silently skipped enforcement."""
+    from litellm.caching.dual_cache import DualCache
+    from litellm.proxy._types import LiteLLM_TeamMembership
+    from litellm.proxy.utils import ProxyLogging
+
+    team_object = LiteLLM_TeamTable(
+        team_id="test-team",
+        metadata={"team_member_budget_id": "budget-default"},
+    )
+    user_object = LiteLLM_UserTable(user_id="test-user")
+    valid_token = UserAPIKeyAuth(
+        token="test-token",
+        user_id="test-user",
+        team_id="test-team",
+    )
+
+    # Membership row without an attached budget.
+    team_membership = LiteLLM_TeamMembership(
+        user_id="test-user",
+        team_id="test-team",
+        spend=0.0,
+        budget_id=None,
+        litellm_budget_table=None,
+    )
+
+    proxy_logging_obj = ProxyLogging(user_api_key_cache=None)
+
+    fake_budget_row = MagicMock()
+    fake_budget_row.max_budget = 50.0
+    fake_budget_row.dict = MagicMock(
+        return_value={"budget_id": "budget-default", "max_budget": 50.0}
+    )
+
+    prisma_client = MagicMock()
+    prisma_client.db.litellm_budgettable.find_unique = AsyncMock(
+        return_value=fake_budget_row
+    )
+
+    async def mock_get_current_spend(counter_key, fallback_spend):
+        if counter_key == "spend:team_member:test-user:test-team":
+            return 70.0
+        return fallback_spend
+
+    user_api_key_cache = DualCache()
+
+    with (
+        patch("litellm.proxy.proxy_server.get_current_spend", mock_get_current_spend),
+        patch(
+            "litellm.proxy.auth.auth_checks.get_team_membership",
+            new_callable=AsyncMock,
+            return_value=team_membership,
+        ),
+    ):
+        with pytest.raises(litellm.BudgetExceededError) as exc_info:
+            await _check_team_member_budget(
+                team_object=team_object,
+                user_object=user_object,
+                valid_token=valid_token,
+                prisma_client=prisma_client,
+                user_api_key_cache=user_api_key_cache,
+                proxy_logging_obj=proxy_logging_obj,
+            )
+        assert exc_info.value.current_cost == 70.0
+        assert exc_info.value.max_budget == 50.0
+
+    # First call did perform the fallback DB lookup.
+    prisma_client.db.litellm_budgettable.find_unique.assert_awaited_once()
+
+    # Second call hits the cached budget row, no additional prisma read.
+    prisma_client.db.litellm_budgettable.find_unique.reset_mock()
+    with (
+        patch("litellm.proxy.proxy_server.get_current_spend", mock_get_current_spend),
+        patch(
+            "litellm.proxy.auth.auth_checks.get_team_membership",
+            new_callable=AsyncMock,
+            return_value=team_membership,
+        ),
+    ):
+        with pytest.raises(litellm.BudgetExceededError) as second_exc_info:
+            await _check_team_member_budget(
+                team_object=team_object,
+                user_object=user_object,
+                valid_token=valid_token,
+                prisma_client=prisma_client,
+                user_api_key_cache=user_api_key_cache,
+                proxy_logging_obj=proxy_logging_obj,
+            )
+        # The cached $50 cap is still being applied (not a coincidental skip)
+        assert second_exc_info.value.current_cost == 70.0
+        assert second_exc_info.value.max_budget == 50.0
+    prisma_client.db.litellm_budgettable.find_unique.assert_not_awaited()
+
+
+@pytest.mark.asyncio
+async def test_team_member_budget_check_per_member_override_wins_over_team_default():
+    """If a member has a per-member budget AND the team carries a
+    team_member_budget_id default, the per-member value wins and the
+    fallback prisma lookup is never performed."""
+    from litellm.caching.dual_cache import DualCache
+    from litellm.proxy._types import LiteLLM_BudgetTable, LiteLLM_TeamMembership
+    from litellm.proxy.utils import ProxyLogging
+
+    team_object = LiteLLM_TeamTable(
+        team_id="test-team",
+        metadata={"team_member_budget_id": "budget-default"},
+    )
+    user_object = LiteLLM_UserTable(user_id="test-user")
+    valid_token = UserAPIKeyAuth(
+        token="test-token",
+        user_id="test-user",
+        team_id="test-team",
+    )
+
+    team_membership = LiteLLM_TeamMembership(
+        user_id="test-user",
+        team_id="test-team",
+        spend=0.0,
+        budget_id="budget-override",
+        litellm_budget_table=LiteLLM_BudgetTable(max_budget=200.0),
+    )
+
+    proxy_logging_obj = ProxyLogging(user_api_key_cache=None)
+
+    # Team-default row resolves to $50. If the fallback fired (it must
+    # not here), spend $70 would exceed that $50 cap and raise.
+    fake_budget_row = MagicMock()
+    fake_budget_row.max_budget = 50.0
+
+    prisma_client = MagicMock()
+    prisma_client.db.litellm_budgettable.find_unique = AsyncMock(
+        return_value=fake_budget_row
+    )
+
+    mocked_spend = 70.0
+
+    async def mock_get_current_spend(counter_key, fallback_spend):
+        if counter_key == "spend:team_member:test-user:test-team":
+            return mocked_spend
+        return fallback_spend
+
+    # 1. spend ($70) < per-member cap ($200) → no raise, no fallback lookup.
+    with (
+        patch("litellm.proxy.proxy_server.get_current_spend", mock_get_current_spend),
+        patch(
+            "litellm.proxy.auth.auth_checks.get_team_membership",
+            new_callable=AsyncMock,
+            return_value=team_membership,
+        ),
+    ):
+        await _check_team_member_budget(
+            team_object=team_object,
+            user_object=user_object,
+            valid_token=valid_token,
+            prisma_client=prisma_client,
+            user_api_key_cache=DualCache(),
+            proxy_logging_obj=proxy_logging_obj,
+        )
+
+    prisma_client.db.litellm_budgettable.find_unique.assert_not_awaited()
+
+    # 2. Now push spend above the per-member cap ($200). Must raise with
+    #    max_budget=200 to prove the per-member cap is the value being
+    #    enforced (not just that enforcement silently skipped).
+    mocked_spend = 250.0
+    with (
+        patch("litellm.proxy.proxy_server.get_current_spend", mock_get_current_spend),
+        patch(
+            "litellm.proxy.auth.auth_checks.get_team_membership",
+            new_callable=AsyncMock,
+            return_value=team_membership,
+        ),
+    ):
+        with pytest.raises(litellm.BudgetExceededError) as exc_info:
+            await _check_team_member_budget(
+                team_object=team_object,
+                user_object=user_object,
+                valid_token=valid_token,
+                prisma_client=prisma_client,
+                user_api_key_cache=DualCache(),
+                proxy_logging_obj=proxy_logging_obj,
+            )
+        assert exc_info.value.current_cost == 250.0
+        assert exc_info.value.max_budget == 200.0
diff --git a/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py
index 8da0ef19f8..65187fb52d 100644
--- a/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py
+++ b/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py
@@ -1795,6 +1795,7 @@ async def test_backfill_team_member_budget_entries_creates_missing_memberships()
         return_value=[existing_membership]
     )
     mock_prisma.db.litellm_teammembership.create_many = AsyncMock(return_value=None)
+    mock_prisma.db.litellm_teammembership.update_many = AsyncMock(return_value=0)
 
     # Test with Member instances
     members = [
@@ -1823,6 +1824,7 @@ async def test_backfill_team_member_budget_entries_creates_missing_memberships()
     # Also test with raw dicts (members_with_roles may be dicts when deserialized from DB)
     mock_prisma.db.litellm_teammembership.find_many.reset_mock()
     mock_prisma.db.litellm_teammembership.create_many.reset_mock()
+    mock_prisma.db.litellm_teammembership.update_many.reset_mock()
 
     members_as_dicts = [
         {"user_id": "user-A", "role": "user"},
@@ -1868,6 +1870,7 @@ async def test_backfill_team_member_budget_entries_no_op_when_all_exist():
         return_value=[existing_a, existing_b]
     )
     mock_prisma.db.litellm_teammembership.create_many = AsyncMock(return_value=None)
+    mock_prisma.db.litellm_teammembership.update_many = AsyncMock(return_value=0)
 
     members = [
         Member(user_id="user-A", role="user"),
@@ -1884,6 +1887,55 @@ async def test_backfill_team_member_budget_entries_no_op_when_all_exist():
     mock_prisma.db.litellm_teammembership.create_many.assert_not_awaited()
 
 
+@pytest.mark.asyncio
+async def test_backfill_team_member_budget_entries_populates_null_budget_id_on_existing_rows():
+    """
+    backfill_team_member_budget_entries should populate budget_id on
+    existing TeamMembership rows where it is currently NULL, so admins
+    can configure a team member budget after members have already joined
+    and have enforcement apply to those pre-existing members.
+    """
+    from unittest.mock import AsyncMock, MagicMock
+
+    from litellm.proxy._types import Member
+    from litellm.proxy.management_endpoints.team_endpoints import (
+        TeamMemberBudgetHandler,
+    )
+
+    team_id = "team-abc"
+    budget_id = "budget-xyz"
+
+    # Both members already have rows, so create_many must not fire;
+    # update_many must fire with the NULL-budget_id filter.
+    existing_a = MagicMock()
+    existing_a.user_id = "user-A"
+    existing_b = MagicMock()
+    existing_b.user_id = "user-B"
+
+    mock_prisma = MagicMock()
+    mock_prisma.db.litellm_teammembership.find_many = AsyncMock(
+        return_value=[existing_a, existing_b]
+    )
+    mock_prisma.db.litellm_teammembership.create_many = AsyncMock(return_value=None)
+    mock_prisma.db.litellm_teammembership.update_many = AsyncMock(return_value=2)
+
+    await TeamMemberBudgetHandler.backfill_team_member_budget_entries(
+        team_id=team_id,
+        members_with_roles=[
+            Member(user_id="user-A", role="user"),
+            Member(user_id="user-B", role="user"),
+        ],
+        team_member_budget_id=budget_id,
+        prisma_client=mock_prisma,
+    )
+
+    mock_prisma.db.litellm_teammembership.create_many.assert_not_awaited()
+    mock_prisma.db.litellm_teammembership.update_many.assert_awaited_once_with(
+        where={"team_id": team_id, "budget_id": None},
+        data={"budget_id": budget_id},
+    )
+
+
 @pytest.mark.asyncio
 async def test_backfill_team_member_budget_entries_empty_members():
     """
diff --git a/tests/test_litellm/proxy/test_proxy_server.py b/tests/test_litellm/proxy/test_proxy_server.py
index 79eba81dc4..efd1abbb38 100644
--- a/tests/test_litellm/proxy/test_proxy_server.py
+++ b/tests/test_litellm/proxy/test_proxy_server.py
@@ -4965,3 +4965,123 @@ async def test_increment_spend_counters_team_and_member():
     finally:
         ps.user_api_key_cache = original_key_cache
         ps.spend_counter_cache = original_counter_cache
+
+
+@pytest.mark.asyncio
+async def test_init_and_increment_spend_counter_reseeds_from_db_on_counter_miss():
+    """When the Redis counter is missing, the reseed path reads the
+    authoritative spend from the DB (not a stale cache), so the next
+    increment continues from the correct base value."""
+    from litellm.caching.dual_cache import DualCache
+
+    counter_cache = DualCache()
+    recorded_increments: list = []
+
+    async def record_increment(key, value, ttl=None, **kwargs):
+        recorded_increments.append({"key": key, "value": value, "ttl": ttl})
+        return value
+
+    fake_redis = AsyncMock()
+    fake_redis.async_increment = AsyncMock(side_effect=record_increment)
+    fake_redis.async_get_cache = AsyncMock(return_value=None)  # counter missing
+    counter_cache.redis_cache = fake_redis
+
+    # Prisma returns spend=42.0 (authoritative) while the stale cached
+    # value (would be read only if prisma is None) is 10.0. The counter
+    # must seed from 42, not 10.
+    db_row = MagicMock()
+    db_row.spend = 42.0
+    fake_prisma = MagicMock()
+    fake_prisma.db.litellm_teamtable.find_unique = AsyncMock(return_value=db_row)
+
+    stale_cache = DualCache()
+    stale_team = MagicMock()
+    stale_team.spend = 10.0
+    stale_cache.in_memory_cache.set_cache(key="team_id:team-9", value=stale_team)
+
+    import litellm.proxy.proxy_server as ps
+    from litellm.proxy.proxy_server import _init_and_increment_spend_counter
+
+    orig_user, orig_counter, orig_prisma = (
+        ps.user_api_key_cache,
+        ps.spend_counter_cache,
+        ps.prisma_client,
+    )
+    ps.user_api_key_cache = stale_cache
+    ps.spend_counter_cache = counter_cache
+    ps.prisma_client = fake_prisma
+    try:
+        await _init_and_increment_spend_counter(
+            counter_key="spend:team:team-9",
+            source_cache_key="team_id:team-9",
+            increment=1.5,
+        )
+
+        fake_prisma.db.litellm_teamtable.find_unique.assert_awaited_once_with(
+            where={"team_id": "team-9"}
+        )
+        # Two increments keyed on the counter: seed ($42) then request ($1.50).
+        writes = [(c["key"], c["value"]) for c in recorded_increments]
+        assert ("spend:team:team-9", 42.0) in writes
+        assert ("spend:team:team-9", 1.5) in writes
+    finally:
+        ps.user_api_key_cache = orig_user
+        ps.spend_counter_cache = orig_counter
+        ps.prisma_client = orig_prisma
+
+
+@pytest.mark.asyncio
+async def test_reseed_spend_from_db_user_and_org_prefixes():
+    """User and org counters must reseed from their own DB tables, not
+    fall through to 0.0 like the other counters do today."""
+    import litellm.proxy.proxy_server as ps
+    from litellm.proxy.proxy_server import _reseed_spend_from_db
+
+    user_row = MagicMock()
+    user_row.spend = 17.0
+    org_row = MagicMock()
+    org_row.spend = 305.0
+
+    fake_prisma = MagicMock()
+    fake_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=user_row)
+    fake_prisma.db.litellm_organizationtable.find_unique = AsyncMock(
+        return_value=org_row
+    )
+
+    orig_prisma = ps.prisma_client
+    ps.prisma_client = fake_prisma
+    try:
+        assert await _reseed_spend_from_db("spend:user:alice") == 17.0
+        fake_prisma.db.litellm_usertable.find_unique.assert_awaited_once_with(
+            where={"user_id": "alice"}
+        )
+
+        assert await _reseed_spend_from_db("spend:org:acme") == 305.0
+        fake_prisma.db.litellm_organizationtable.find_unique.assert_awaited_once_with(
+            where={"organization_id": "acme"}
+        )
+    finally:
+        ps.prisma_client = orig_prisma
+
+
+@pytest.mark.asyncio
+async def test_reseed_spend_from_db_skips_window_variant_keys():
+    """Window counters (spend:*:window:{duration}) share prefixes with
+    primary counters but don't correspond to a DB row. The guard must
+    short-circuit without querying the DB."""
+    import litellm.proxy.proxy_server as ps
+    from litellm.proxy.proxy_server import _reseed_spend_from_db
+
+    fake_prisma = MagicMock()
+    fake_prisma.db.litellm_verificationtoken.find_unique = AsyncMock()
+    fake_prisma.db.litellm_teamtable.find_unique = AsyncMock()
+
+    orig_prisma = ps.prisma_client
+    ps.prisma_client = fake_prisma
+    try:
+        assert await _reseed_spend_from_db("spend:key:sk-abc:window:1h") == 0.0
+        assert await _reseed_spend_from_db("spend:team:team-1:window:1d") == 0.0
+        fake_prisma.db.litellm_verificationtoken.find_unique.assert_not_awaited()
+        fake_prisma.db.litellm_teamtable.find_unique.assert_not_awaited()
+    finally:
+        ps.prisma_client = orig_prisma