mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 16:48:54 +00:00
Fix bug that bypasses per-team member budget limit
This commit is contained in:
committed by
Michael Riad Zaky
parent
24aec61e4b
commit
0bd49ecb8b
@@ -905,6 +905,63 @@ async def get_default_end_user_budget(
|
||||
return None
|
||||
|
||||
|
||||
@log_db_metrics
|
||||
async def get_team_member_default_budget(
|
||||
budget_id: str,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
) -> Optional[LiteLLM_BudgetTable]:
|
||||
"""
|
||||
Fetches the team-level default per-member budget referenced by team.metadata["team_member_budget_id"].
|
||||
|
||||
This budget is applied to team members whose TeamMembership row has no
|
||||
linked budget. Results are cached for performance.
|
||||
|
||||
Args:
|
||||
budget_id: The budget_id pulled from team.metadata["team_member_budget_id"]
|
||||
prisma_client: Database client instance
|
||||
user_api_key_cache: Cache for storing/retrieving budget data
|
||||
|
||||
Returns:
|
||||
LiteLLM_BudgetTable if found, None otherwise
|
||||
"""
|
||||
if prisma_client is None:
|
||||
return None
|
||||
|
||||
cache_key = f"team_member_default_budget:{budget_id}"
|
||||
|
||||
cached_budget = await user_api_key_cache.async_get_cache(key=cache_key)
|
||||
if isinstance(cached_budget, LiteLLM_BudgetTable):
|
||||
return cached_budget
|
||||
if isinstance(cached_budget, dict):
|
||||
return LiteLLM_BudgetTable(**cached_budget)
|
||||
|
||||
try:
|
||||
budget_record = await prisma_client.db.litellm_budgettable.find_unique(
|
||||
where={"budget_id": budget_id}
|
||||
)
|
||||
|
||||
if budget_record is None:
|
||||
verbose_proxy_logger.warning(
|
||||
f"Team-default member budget not found in database: {budget_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=cache_key,
|
||||
value=budget_record.dict(),
|
||||
ttl=DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL,
|
||||
)
|
||||
|
||||
return LiteLLM_BudgetTable(**budget_record.dict())
|
||||
|
||||
except Exception:
|
||||
verbose_proxy_logger.exception(
|
||||
f"Error fetching team-default member budget {budget_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def _apply_default_budget_to_end_user(
|
||||
end_user_obj: LiteLLM_EndUserTable,
|
||||
prisma_client: PrismaClient,
|
||||
@@ -3230,13 +3287,26 @@ async def _check_team_member_budget(
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
if (
|
||||
team_membership is not None
|
||||
and team_membership.litellm_budget_table is not None
|
||||
and team_membership.litellm_budget_table.max_budget is not None
|
||||
):
|
||||
# Per-member override wins; otherwise fall back to the team-level
|
||||
# default configured via team.metadata["team_member_budget_id"].
|
||||
team_member_budget: Optional[float] = None
|
||||
if team_membership is not None and team_membership.litellm_budget_table is not None:
|
||||
team_member_budget = team_membership.litellm_budget_table.max_budget
|
||||
team_member_spend = team_membership.spend or 0.0
|
||||
else:
|
||||
default_budget_id = (team_object.metadata or {}).get("team_member_budget_id")
|
||||
if isinstance(default_budget_id, str):
|
||||
default_budget = await get_team_member_default_budget(
|
||||
budget_id=default_budget_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
if default_budget is not None:
|
||||
team_member_budget = default_budget.max_budget
|
||||
|
||||
if team_member_budget is not None:
|
||||
team_member_spend = (
|
||||
team_membership.spend if team_membership is not None else 0.0
|
||||
) or 0.0
|
||||
|
||||
# Read from cross-pod counter (Redis-first) if available
|
||||
from litellm.proxy.proxy_server import get_current_spend
|
||||
|
||||
@@ -302,14 +302,15 @@ class TeamMemberBudgetHandler:
|
||||
prisma_client: PrismaClient,
|
||||
) -> None:
|
||||
"""
|
||||
Create team_memberships entries for existing members that don't have one.
|
||||
Ensure every team member has a TeamMembership row linked to the
|
||||
team_member_budget.
|
||||
|
||||
Called after team_member_budget is set/updated on a team to ensure
|
||||
members who joined before the budget was configured also get budget
|
||||
enforcement.
|
||||
|
||||
Only creates missing entries — does not touch existing memberships
|
||||
(which may carry individual per-member budgets).
|
||||
Called after team_member_budget is set/updated on a team. Creates
|
||||
rows for members who don't have one, and populates budget_id on
|
||||
existing rows where it is NULL. Rows with a non-NULL budget_id
|
||||
are left untouched, which preserves per-member overrides but also
|
||||
means rows pointing to a prior team-default budget_id are not
|
||||
migrated to the new one.
|
||||
"""
|
||||
if not members_with_roles:
|
||||
return
|
||||
@@ -347,6 +348,21 @@ class TeamMemberBudgetHandler:
|
||||
_sanitize_for_log(team_member_budget_id),
|
||||
)
|
||||
|
||||
# Heal existing membership rows that predate the team_member_budget
|
||||
# configuration: populate budget_id where it is currently NULL.
|
||||
# Rows with an explicit budget_id (per-member override) are left alone.
|
||||
updated = await prisma_client.db.litellm_teammembership.update_many(
|
||||
where={"team_id": team_id, "budget_id": None},
|
||||
data={"budget_id": team_member_budget_id},
|
||||
)
|
||||
if updated:
|
||||
verbose_proxy_logger.info(
|
||||
"Populated budget_id on %d existing team_memberships for team %s with budget %s",
|
||||
updated,
|
||||
_sanitize_for_log(team_id),
|
||||
_sanitize_for_log(team_member_budget_id),
|
||||
)
|
||||
|
||||
|
||||
def _get_default_team_param(field: str) -> Any:
|
||||
"""
|
||||
|
||||
@@ -1908,40 +1908,110 @@ async def increment_spend_counters(
|
||||
)
|
||||
|
||||
|
||||
async def _reseed_spend_from_db(counter_key: str) -> float:
|
||||
"""
|
||||
Read the authoritative spend for a missing counter from the DB. The
|
||||
counter_key prefix encodes the table to query:
|
||||
|
||||
spend:key:{token} -> LiteLLM_VerificationToken.spend
|
||||
spend:team:{team_id} -> LiteLLM_TeamTable.spend
|
||||
spend:team_member:{uid}:{tid} -> LiteLLM_TeamMembership.spend
|
||||
spend:user:{user_id} -> LiteLLM_UserTable.spend
|
||||
spend:org:{org_id} -> LiteLLM_OrganizationTable.spend
|
||||
|
||||
Returns 0.0 if prisma is unavailable, the row is missing, or the
|
||||
key format is unrecognized. On failure, logs and returns 0.0 rather
|
||||
than raising so the caller can still record the current increment.
|
||||
"""
|
||||
if prisma_client is None:
|
||||
return 0.0
|
||||
# Per-window counters (spend:*:window:{duration}) share prefixes with
|
||||
# primary counters but don't correspond to a DB row; their ambiguity
|
||||
# would otherwise be silently parsed as a regular counter and miss.
|
||||
if ":window:" in counter_key:
|
||||
return 0.0
|
||||
try:
|
||||
if counter_key.startswith("spend:key:"):
|
||||
token = counter_key[len("spend:key:") :]
|
||||
row = await prisma_client.db.litellm_verificationtoken.find_unique(
|
||||
where={"token": token}
|
||||
)
|
||||
elif counter_key.startswith("spend:team_member:"):
|
||||
suffix = counter_key[len("spend:team_member:") :]
|
||||
if ":" not in suffix:
|
||||
return 0.0
|
||||
user_id, team_id = suffix.rsplit(":", 1)
|
||||
row = await prisma_client.db.litellm_teammembership.find_unique(
|
||||
where={"user_id_team_id": {"user_id": user_id, "team_id": team_id}}
|
||||
)
|
||||
elif counter_key.startswith("spend:team:"):
|
||||
team_id = counter_key[len("spend:team:") :]
|
||||
row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
elif counter_key.startswith("spend:user:"):
|
||||
user_id = counter_key[len("spend:user:") :]
|
||||
row = await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_id}
|
||||
)
|
||||
elif counter_key.startswith("spend:org:"):
|
||||
org_id = counter_key[len("spend:org:") :]
|
||||
row = await prisma_client.db.litellm_organizationtable.find_unique(
|
||||
where={"organization_id": org_id}
|
||||
)
|
||||
else:
|
||||
return 0.0
|
||||
except Exception:
|
||||
verbose_proxy_logger.exception(
|
||||
"Failed to reseed spend counter %s from DB", counter_key
|
||||
)
|
||||
return 0.0
|
||||
if row is None:
|
||||
return 0.0
|
||||
return float(getattr(row, "spend", 0.0) or 0.0)
|
||||
|
||||
|
||||
async def _init_and_increment_spend_counter(
|
||||
counter_key: str,
|
||||
source_cache_key: str,
|
||||
increment: float,
|
||||
):
|
||||
"""
|
||||
Initialize counter from cached object's DB-loaded spend if not yet set,
|
||||
then atomically increment in both in-memory and Redis.
|
||||
Initialize counter from the authoritative DB spend value if not yet
|
||||
set, then atomically increment in both in-memory and Redis.
|
||||
|
||||
On first access per pod:
|
||||
1. Check spend_counter_cache (in-memory -> Redis via DualCache for init check)
|
||||
2. If not found anywhere, read base spend from user_api_key_cache (DB-loaded object)
|
||||
1. Check spend_counter_cache (in-memory -> Redis via DualCache)
|
||||
2. If not found, reseed from the DB (`_reseed_spend_from_db`). Falls
|
||||
back to the cached object's `.spend` via user_api_key_cache only
|
||||
if prisma is unavailable, since that value can lag the flusher.
|
||||
3. Seed counter via async_increment_cache (not async_set_cache) to avoid a
|
||||
check-then-set race: if two pods cold-start simultaneously, both may see
|
||||
the counter as absent and seed it. Using increment instead of set means
|
||||
the worst case is over-counting (conservative — blocks slightly early)
|
||||
rather than under-counting (would allow overspend).
|
||||
the counter as absent and seed it. Using increment means the worst case
|
||||
is over-counting (conservative, blocks slightly early) rather than
|
||||
under-counting (would allow overspend).
|
||||
4. Increment atomically (both in-memory + Redis)
|
||||
"""
|
||||
current = await spend_counter_cache.async_get_cache(key=counter_key)
|
||||
if current is None:
|
||||
source = await user_api_key_cache.async_get_cache(key=source_cache_key)
|
||||
base_spend = 0.0
|
||||
if source is not None:
|
||||
if isinstance(source, dict):
|
||||
base_spend = source.get("spend", 0.0) or 0.0
|
||||
else:
|
||||
base_spend = getattr(source, "spend", 0.0) or 0.0
|
||||
base_spend = await _reseed_spend_from_db(counter_key)
|
||||
if prisma_client is None:
|
||||
# Best-effort fallback when prisma is unavailable (tests or
|
||||
# early-startup paths). May be stale but avoids resetting to 0.
|
||||
source = await user_api_key_cache.async_get_cache(key=source_cache_key)
|
||||
if source is not None:
|
||||
if isinstance(source, dict):
|
||||
base_spend = source.get("spend", 0.0) or 0.0
|
||||
else:
|
||||
base_spend = getattr(source, "spend", 0.0) or 0.0
|
||||
if base_spend > 0:
|
||||
await spend_counter_cache.async_increment_cache(
|
||||
key=counter_key, value=base_spend
|
||||
)
|
||||
|
||||
await spend_counter_cache.async_increment_cache(key=counter_key, value=increment)
|
||||
await spend_counter_cache.async_increment_cache(
|
||||
key=counter_key, value=increment
|
||||
)
|
||||
|
||||
|
||||
async def update_cache( # noqa: PLR0915
|
||||
|
||||
@@ -2126,3 +2126,192 @@ class TestGuardrailModificationCheck:
|
||||
"""Unparseable strings should not trigger a 403 — they have no keys."""
|
||||
self._call({"metadata": "not-json"})
|
||||
self._call({"metadata": '"just a string"'})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_member_budget_check_falls_back_to_team_default_budget_id():
|
||||
"""When a member's TeamMembership has no linked budget row, the check
|
||||
should fall back to team.metadata["team_member_budget_id"] and still
|
||||
enforce the cap. Pre-fix, this path silently skipped enforcement."""
|
||||
from litellm.caching.dual_cache import DualCache
|
||||
from litellm.proxy._types import LiteLLM_TeamMembership
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
|
||||
team_object = LiteLLM_TeamTable(
|
||||
team_id="test-team",
|
||||
metadata={"team_member_budget_id": "budget-default"},
|
||||
)
|
||||
user_object = LiteLLM_UserTable(user_id="test-user")
|
||||
valid_token = UserAPIKeyAuth(
|
||||
token="test-token",
|
||||
user_id="test-user",
|
||||
team_id="test-team",
|
||||
)
|
||||
|
||||
# Membership row without an attached budget.
|
||||
team_membership = LiteLLM_TeamMembership(
|
||||
user_id="test-user",
|
||||
team_id="test-team",
|
||||
spend=0.0,
|
||||
budget_id=None,
|
||||
litellm_budget_table=None,
|
||||
)
|
||||
|
||||
proxy_logging_obj = ProxyLogging(user_api_key_cache=None)
|
||||
|
||||
fake_budget_row = MagicMock()
|
||||
fake_budget_row.max_budget = 50.0
|
||||
fake_budget_row.dict = MagicMock(
|
||||
return_value={"budget_id": "budget-default", "max_budget": 50.0}
|
||||
)
|
||||
|
||||
prisma_client = MagicMock()
|
||||
prisma_client.db.litellm_budgettable.find_unique = AsyncMock(
|
||||
return_value=fake_budget_row
|
||||
)
|
||||
|
||||
async def mock_get_current_spend(counter_key, fallback_spend):
|
||||
if counter_key == "spend:team_member:test-user:test-team":
|
||||
return 70.0
|
||||
return fallback_spend
|
||||
|
||||
user_api_key_cache = DualCache()
|
||||
|
||||
with (
|
||||
patch("litellm.proxy.proxy_server.get_current_spend", mock_get_current_spend),
|
||||
patch(
|
||||
"litellm.proxy.auth.auth_checks.get_team_membership",
|
||||
new_callable=AsyncMock,
|
||||
return_value=team_membership,
|
||||
),
|
||||
):
|
||||
with pytest.raises(litellm.BudgetExceededError) as exc_info:
|
||||
await _check_team_member_budget(
|
||||
team_object=team_object,
|
||||
user_object=user_object,
|
||||
valid_token=valid_token,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
assert exc_info.value.current_cost == 70.0
|
||||
assert exc_info.value.max_budget == 50.0
|
||||
|
||||
# First call did perform the fallback DB lookup.
|
||||
prisma_client.db.litellm_budgettable.find_unique.assert_awaited_once()
|
||||
|
||||
# Second call hits the cached budget row, no additional prisma read.
|
||||
prisma_client.db.litellm_budgettable.find_unique.reset_mock()
|
||||
with (
|
||||
patch("litellm.proxy.proxy_server.get_current_spend", mock_get_current_spend),
|
||||
patch(
|
||||
"litellm.proxy.auth.auth_checks.get_team_membership",
|
||||
new_callable=AsyncMock,
|
||||
return_value=team_membership,
|
||||
),
|
||||
):
|
||||
with pytest.raises(litellm.BudgetExceededError) as second_exc_info:
|
||||
await _check_team_member_budget(
|
||||
team_object=team_object,
|
||||
user_object=user_object,
|
||||
valid_token=valid_token,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
# The cached $50 cap is still being applied (not a coincidental skip)
|
||||
assert second_exc_info.value.current_cost == 70.0
|
||||
assert second_exc_info.value.max_budget == 50.0
|
||||
prisma_client.db.litellm_budgettable.find_unique.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_member_budget_check_per_member_override_wins_over_team_default():
|
||||
"""If a member has a per-member budget AND the team carries a
|
||||
team_member_budget_id default, the per-member value wins and the
|
||||
fallback prisma lookup is never performed."""
|
||||
from litellm.caching.dual_cache import DualCache
|
||||
from litellm.proxy._types import LiteLLM_BudgetTable, LiteLLM_TeamMembership
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
|
||||
team_object = LiteLLM_TeamTable(
|
||||
team_id="test-team",
|
||||
metadata={"team_member_budget_id": "budget-default"},
|
||||
)
|
||||
user_object = LiteLLM_UserTable(user_id="test-user")
|
||||
valid_token = UserAPIKeyAuth(
|
||||
token="test-token",
|
||||
user_id="test-user",
|
||||
team_id="test-team",
|
||||
)
|
||||
|
||||
team_membership = LiteLLM_TeamMembership(
|
||||
user_id="test-user",
|
||||
team_id="test-team",
|
||||
spend=0.0,
|
||||
budget_id="budget-override",
|
||||
litellm_budget_table=LiteLLM_BudgetTable(max_budget=200.0),
|
||||
)
|
||||
|
||||
proxy_logging_obj = ProxyLogging(user_api_key_cache=None)
|
||||
|
||||
# Team-default row resolves to $50. If the fallback fired (it must
|
||||
# not here), spend $70 would exceed that $50 cap and raise.
|
||||
fake_budget_row = MagicMock()
|
||||
fake_budget_row.max_budget = 50.0
|
||||
|
||||
prisma_client = MagicMock()
|
||||
prisma_client.db.litellm_budgettable.find_unique = AsyncMock(
|
||||
return_value=fake_budget_row
|
||||
)
|
||||
|
||||
mocked_spend = 70.0
|
||||
|
||||
async def mock_get_current_spend(counter_key, fallback_spend):
|
||||
if counter_key == "spend:team_member:test-user:test-team":
|
||||
return mocked_spend
|
||||
return fallback_spend
|
||||
|
||||
# 1. spend ($70) < per-member cap ($200) → no raise, no fallback lookup.
|
||||
with (
|
||||
patch("litellm.proxy.proxy_server.get_current_spend", mock_get_current_spend),
|
||||
patch(
|
||||
"litellm.proxy.auth.auth_checks.get_team_membership",
|
||||
new_callable=AsyncMock,
|
||||
return_value=team_membership,
|
||||
),
|
||||
):
|
||||
await _check_team_member_budget(
|
||||
team_object=team_object,
|
||||
user_object=user_object,
|
||||
valid_token=valid_token,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=DualCache(),
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
prisma_client.db.litellm_budgettable.find_unique.assert_not_awaited()
|
||||
|
||||
# 2. Now push spend above the per-member cap ($200). Must raise with
|
||||
# max_budget=200 to prove the per-member cap is the value being
|
||||
# enforced (not just that enforcement silently skipped).
|
||||
mocked_spend = 250.0
|
||||
with (
|
||||
patch("litellm.proxy.proxy_server.get_current_spend", mock_get_current_spend),
|
||||
patch(
|
||||
"litellm.proxy.auth.auth_checks.get_team_membership",
|
||||
new_callable=AsyncMock,
|
||||
return_value=team_membership,
|
||||
),
|
||||
):
|
||||
with pytest.raises(litellm.BudgetExceededError) as exc_info:
|
||||
await _check_team_member_budget(
|
||||
team_object=team_object,
|
||||
user_object=user_object,
|
||||
valid_token=valid_token,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=DualCache(),
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
assert exc_info.value.current_cost == 250.0
|
||||
assert exc_info.value.max_budget == 200.0
|
||||
|
||||
@@ -1795,6 +1795,7 @@ async def test_backfill_team_member_budget_entries_creates_missing_memberships()
|
||||
return_value=[existing_membership]
|
||||
)
|
||||
mock_prisma.db.litellm_teammembership.create_many = AsyncMock(return_value=None)
|
||||
mock_prisma.db.litellm_teammembership.update_many = AsyncMock(return_value=0)
|
||||
|
||||
# Test with Member instances
|
||||
members = [
|
||||
@@ -1823,6 +1824,7 @@ async def test_backfill_team_member_budget_entries_creates_missing_memberships()
|
||||
# Also test with raw dicts (members_with_roles may be dicts when deserialized from DB)
|
||||
mock_prisma.db.litellm_teammembership.find_many.reset_mock()
|
||||
mock_prisma.db.litellm_teammembership.create_many.reset_mock()
|
||||
mock_prisma.db.litellm_teammembership.update_many.reset_mock()
|
||||
|
||||
members_as_dicts = [
|
||||
{"user_id": "user-A", "role": "user"},
|
||||
@@ -1868,6 +1870,7 @@ async def test_backfill_team_member_budget_entries_no_op_when_all_exist():
|
||||
return_value=[existing_a, existing_b]
|
||||
)
|
||||
mock_prisma.db.litellm_teammembership.create_many = AsyncMock(return_value=None)
|
||||
mock_prisma.db.litellm_teammembership.update_many = AsyncMock(return_value=0)
|
||||
|
||||
members = [
|
||||
Member(user_id="user-A", role="user"),
|
||||
@@ -1884,6 +1887,55 @@ async def test_backfill_team_member_budget_entries_no_op_when_all_exist():
|
||||
mock_prisma.db.litellm_teammembership.create_many.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_team_member_budget_entries_populates_null_budget_id_on_existing_rows():
|
||||
"""
|
||||
backfill_team_member_budget_entries should populate budget_id on
|
||||
existing TeamMembership rows where it is currently NULL, so admins
|
||||
can configure a team member budget after members have already joined
|
||||
and have enforcement apply to those pre-existing members.
|
||||
"""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from litellm.proxy._types import Member
|
||||
from litellm.proxy.management_endpoints.team_endpoints import (
|
||||
TeamMemberBudgetHandler,
|
||||
)
|
||||
|
||||
team_id = "team-abc"
|
||||
budget_id = "budget-xyz"
|
||||
|
||||
# Both members already have rows, so create_many must not fire;
|
||||
# update_many must fire with the NULL-budget_id filter.
|
||||
existing_a = MagicMock()
|
||||
existing_a.user_id = "user-A"
|
||||
existing_b = MagicMock()
|
||||
existing_b.user_id = "user-B"
|
||||
|
||||
mock_prisma = MagicMock()
|
||||
mock_prisma.db.litellm_teammembership.find_many = AsyncMock(
|
||||
return_value=[existing_a, existing_b]
|
||||
)
|
||||
mock_prisma.db.litellm_teammembership.create_many = AsyncMock(return_value=None)
|
||||
mock_prisma.db.litellm_teammembership.update_many = AsyncMock(return_value=2)
|
||||
|
||||
await TeamMemberBudgetHandler.backfill_team_member_budget_entries(
|
||||
team_id=team_id,
|
||||
members_with_roles=[
|
||||
Member(user_id="user-A", role="user"),
|
||||
Member(user_id="user-B", role="user"),
|
||||
],
|
||||
team_member_budget_id=budget_id,
|
||||
prisma_client=mock_prisma,
|
||||
)
|
||||
|
||||
mock_prisma.db.litellm_teammembership.create_many.assert_not_awaited()
|
||||
mock_prisma.db.litellm_teammembership.update_many.assert_awaited_once_with(
|
||||
where={"team_id": team_id, "budget_id": None},
|
||||
data={"budget_id": budget_id},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_team_member_budget_entries_empty_members():
|
||||
"""
|
||||
|
||||
@@ -4965,3 +4965,123 @@ async def test_increment_spend_counters_team_and_member():
|
||||
finally:
|
||||
ps.user_api_key_cache = original_key_cache
|
||||
ps.spend_counter_cache = original_counter_cache
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_and_increment_spend_counter_reseeds_from_db_on_counter_miss():
|
||||
"""When the Redis counter is missing, the reseed path reads the
|
||||
authoritative spend from the DB (not a stale cache), so the next
|
||||
increment continues from the correct base value."""
|
||||
from litellm.caching.dual_cache import DualCache
|
||||
|
||||
counter_cache = DualCache()
|
||||
recorded_increments: list = []
|
||||
|
||||
async def record_increment(key, value, ttl=None, **kwargs):
|
||||
recorded_increments.append({"key": key, "value": value, "ttl": ttl})
|
||||
return value
|
||||
|
||||
fake_redis = AsyncMock()
|
||||
fake_redis.async_increment = AsyncMock(side_effect=record_increment)
|
||||
fake_redis.async_get_cache = AsyncMock(return_value=None) # counter missing
|
||||
counter_cache.redis_cache = fake_redis
|
||||
|
||||
# Prisma returns spend=42.0 (authoritative) while the stale cached
|
||||
# value (would be read only if prisma is None) is 10.0. The counter
|
||||
# must seed from 42, not 10.
|
||||
db_row = MagicMock()
|
||||
db_row.spend = 42.0
|
||||
fake_prisma = MagicMock()
|
||||
fake_prisma.db.litellm_teamtable.find_unique = AsyncMock(return_value=db_row)
|
||||
|
||||
stale_cache = DualCache()
|
||||
stale_team = MagicMock()
|
||||
stale_team.spend = 10.0
|
||||
stale_cache.in_memory_cache.set_cache(key="team_id:team-9", value=stale_team)
|
||||
|
||||
import litellm.proxy.proxy_server as ps
|
||||
from litellm.proxy.proxy_server import _init_and_increment_spend_counter
|
||||
|
||||
orig_user, orig_counter, orig_prisma = (
|
||||
ps.user_api_key_cache,
|
||||
ps.spend_counter_cache,
|
||||
ps.prisma_client,
|
||||
)
|
||||
ps.user_api_key_cache = stale_cache
|
||||
ps.spend_counter_cache = counter_cache
|
||||
ps.prisma_client = fake_prisma
|
||||
try:
|
||||
await _init_and_increment_spend_counter(
|
||||
counter_key="spend:team:team-9",
|
||||
source_cache_key="team_id:team-9",
|
||||
increment=1.5,
|
||||
)
|
||||
|
||||
fake_prisma.db.litellm_teamtable.find_unique.assert_awaited_once_with(
|
||||
where={"team_id": "team-9"}
|
||||
)
|
||||
# Two increments keyed on the counter: seed ($42) then request ($1.50).
|
||||
writes = [(c["key"], c["value"]) for c in recorded_increments]
|
||||
assert ("spend:team:team-9", 42.0) in writes
|
||||
assert ("spend:team:team-9", 1.5) in writes
|
||||
finally:
|
||||
ps.user_api_key_cache = orig_user
|
||||
ps.spend_counter_cache = orig_counter
|
||||
ps.prisma_client = orig_prisma
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reseed_spend_from_db_user_and_org_prefixes():
|
||||
"""User and org counters must reseed from their own DB tables, not
|
||||
fall through to 0.0 like the other counters do today."""
|
||||
import litellm.proxy.proxy_server as ps
|
||||
from litellm.proxy.proxy_server import _reseed_spend_from_db
|
||||
|
||||
user_row = MagicMock()
|
||||
user_row.spend = 17.0
|
||||
org_row = MagicMock()
|
||||
org_row.spend = 305.0
|
||||
|
||||
fake_prisma = MagicMock()
|
||||
fake_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=user_row)
|
||||
fake_prisma.db.litellm_organizationtable.find_unique = AsyncMock(
|
||||
return_value=org_row
|
||||
)
|
||||
|
||||
orig_prisma = ps.prisma_client
|
||||
ps.prisma_client = fake_prisma
|
||||
try:
|
||||
assert await _reseed_spend_from_db("spend:user:alice") == 17.0
|
||||
fake_prisma.db.litellm_usertable.find_unique.assert_awaited_once_with(
|
||||
where={"user_id": "alice"}
|
||||
)
|
||||
|
||||
assert await _reseed_spend_from_db("spend:org:acme") == 305.0
|
||||
fake_prisma.db.litellm_organizationtable.find_unique.assert_awaited_once_with(
|
||||
where={"organization_id": "acme"}
|
||||
)
|
||||
finally:
|
||||
ps.prisma_client = orig_prisma
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reseed_spend_from_db_skips_window_variant_keys():
|
||||
"""Window counters (spend:*:window:{duration}) share prefixes with
|
||||
primary counters but don't correspond to a DB row. The guard must
|
||||
short-circuit without querying the DB."""
|
||||
import litellm.proxy.proxy_server as ps
|
||||
from litellm.proxy.proxy_server import _reseed_spend_from_db
|
||||
|
||||
fake_prisma = MagicMock()
|
||||
fake_prisma.db.litellm_verificationtoken.find_unique = AsyncMock()
|
||||
fake_prisma.db.litellm_teamtable.find_unique = AsyncMock()
|
||||
|
||||
orig_prisma = ps.prisma_client
|
||||
ps.prisma_client = fake_prisma
|
||||
try:
|
||||
assert await _reseed_spend_from_db("spend:key:sk-abc:window:1h") == 0.0
|
||||
assert await _reseed_spend_from_db("spend:team:team-1:window:1d") == 0.0
|
||||
fake_prisma.db.litellm_verificationtoken.find_unique.assert_not_awaited()
|
||||
fake_prisma.db.litellm_teamtable.find_unique.assert_not_awaited()
|
||||
finally:
|
||||
ps.prisma_client = orig_prisma
|
||||
|
||||
Reference in New Issue
Block a user