Fix bug that bypasses per-team member budget limit

This commit is contained in:
Michael Riad Zaky
2026-04-17 16:58:26 -07:00
committed by Michael Riad Zaky
parent 24aec61e4b
commit 0bd49ecb8b
6 changed files with 545 additions and 28 deletions
+76 -6
View File
@@ -905,6 +905,63 @@ async def get_default_end_user_budget(
return None
@log_db_metrics
async def get_team_member_default_budget(
    budget_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
) -> Optional[LiteLLM_BudgetTable]:
    """
    Fetch the team-level default per-member budget.

    This is the budget referenced by team.metadata["team_member_budget_id"].
    It is applied to team members whose TeamMembership row has no linked
    budget. Successful lookups are cached under
    "team_member_default_budget:{budget_id}"; missing rows and failed
    lookups are not cached.

    Args:
        budget_id: The budget_id pulled from team.metadata["team_member_budget_id"]
        prisma_client: Database client instance. When None, no lookup is
            attempted and None is returned.
        user_api_key_cache: Cache for storing/retrieving budget data

    Returns:
        LiteLLM_BudgetTable if found. None when prisma is unavailable, the
        row does not exist (logged as a warning), or the DB lookup raises
        (logged with traceback).
    """
    if prisma_client is None:
        return None

    cache_key = f"team_member_default_budget:{budget_id}"
    cached_budget = await user_api_key_cache.async_get_cache(key=cache_key)
    # The cache may hand back either the model itself or its dict form.
    if isinstance(cached_budget, LiteLLM_BudgetTable):
        return cached_budget
    if isinstance(cached_budget, dict):
        return LiteLLM_BudgetTable(**cached_budget)

    try:
        budget_record = await prisma_client.db.litellm_budgettable.find_unique(
            where={"budget_id": budget_id}
        )
        if budget_record is None:
            verbose_proxy_logger.warning(
                "Team-default member budget not found in database: %s",
                budget_id,
            )
            return None

        # Serialize once so the cached payload and the returned model are
        # built from the same snapshot of the row.
        budget_data = budget_record.dict()
        await user_api_key_cache.async_set_cache(
            key=cache_key,
            value=budget_data,
            ttl=DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL,
        )
        return LiteLLM_BudgetTable(**budget_data)
    except Exception:
        # Best-effort lookup: log with traceback and report "no default
        # budget" to the caller instead of propagating.
        verbose_proxy_logger.exception(
            "Error fetching team-default member budget %s", budget_id
        )
        return None
async def _apply_default_budget_to_end_user(
end_user_obj: LiteLLM_EndUserTable,
prisma_client: PrismaClient,
@@ -3230,13 +3287,26 @@ async def _check_team_member_budget(
proxy_logging_obj=proxy_logging_obj,
)
if (
team_membership is not None
and team_membership.litellm_budget_table is not None
and team_membership.litellm_budget_table.max_budget is not None
):
# Per-member override wins; otherwise fall back to the team-level
# default configured via team.metadata["team_member_budget_id"].
team_member_budget: Optional[float] = None
if team_membership is not None and team_membership.litellm_budget_table is not None:
team_member_budget = team_membership.litellm_budget_table.max_budget
team_member_spend = team_membership.spend or 0.0
else:
default_budget_id = (team_object.metadata or {}).get("team_member_budget_id")
if isinstance(default_budget_id, str):
default_budget = await get_team_member_default_budget(
budget_id=default_budget_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
if default_budget is not None:
team_member_budget = default_budget.max_budget
if team_member_budget is not None:
team_member_spend = (
team_membership.spend if team_membership is not None else 0.0
) or 0.0
# Read from cross-pod counter (Redis-first) if available
from litellm.proxy.proxy_server import get_current_spend
@@ -302,14 +302,15 @@ class TeamMemberBudgetHandler:
prisma_client: PrismaClient,
) -> None:
"""
Create team_memberships entries for existing members that don't have one.
Ensure every team member has a TeamMembership row linked to the
team_member_budget.
Called after team_member_budget is set/updated on a team to ensure
members who joined before the budget was configured also get budget
enforcement.
Only creates missing entries; does not touch existing memberships
(which may carry individual per-member budgets).
Called after team_member_budget is set/updated on a team. Creates
rows for members who don't have one, and populates budget_id on
existing rows where it is NULL. Rows with a non-NULL budget_id
are left untouched, which preserves per-member overrides but also
means rows pointing to a prior team-default budget_id are not
migrated to the new one.
"""
if not members_with_roles:
return
@@ -347,6 +348,21 @@ class TeamMemberBudgetHandler:
_sanitize_for_log(team_member_budget_id),
)
# Heal existing membership rows that predate the team_member_budget
# configuration: populate budget_id where it is currently NULL.
# Rows with an explicit budget_id (per-member override) are left alone.
updated = await prisma_client.db.litellm_teammembership.update_many(
where={"team_id": team_id, "budget_id": None},
data={"budget_id": team_member_budget_id},
)
if updated:
verbose_proxy_logger.info(
"Populated budget_id on %d existing team_memberships for team %s with budget %s",
updated,
_sanitize_for_log(team_id),
_sanitize_for_log(team_member_budget_id),
)
def _get_default_team_param(field: str) -> Any:
"""
+85 -15
View File
@@ -1908,40 +1908,110 @@ async def increment_spend_counters(
)
async def _reseed_spend_from_db(counter_key: str) -> float:
    """
    Look up the persisted spend behind a counter that is missing from cache.

    The counter_key prefix selects the table that is queried:
        spend:key:{token}              -> litellm_verificationtoken (by token)
        spend:team_member:{uid}:{tid}  -> litellm_teammembership (by user_id + team_id)
        spend:team:{team_id}           -> litellm_teamtable (by team_id)
        spend:user:{user_id}           -> litellm_usertable (by user_id)
        spend:org:{org_id}             -> litellm_organizationtable (by organization_id)

    Returns:
        The row's `spend` as a float. 0.0 is returned when prisma is unset,
        the key contains ":window:", the prefix is unrecognized, a
        team_member key has no ":" between its ids, the row is missing or
        has no spend, or the query raises (the exception is logged, not
        re-raised).
    """
    # Without a DB client there is nothing to read. Window counters
    # (spend:*:window:{duration}) reuse the primary prefixes but have no
    # DB row of their own, so they must not be parsed as primary counters.
    if prisma_client is None or ":window:" in counter_key:
        return 0.0

    try:
        if counter_key.startswith("spend:key:"):
            record = await prisma_client.db.litellm_verificationtoken.find_unique(
                where={"token": counter_key[len("spend:key:") :]}
            )
        elif counter_key.startswith("spend:team_member:"):
            # The final ":" separates user_id from team_id.
            member_user_id, separator, member_team_id = counter_key[
                len("spend:team_member:") :
            ].rpartition(":")
            if not separator:
                return 0.0
            record = await prisma_client.db.litellm_teammembership.find_unique(
                where={
                    "user_id_team_id": {
                        "user_id": member_user_id,
                        "team_id": member_team_id,
                    }
                }
            )
        elif counter_key.startswith("spend:team:"):
            record = await prisma_client.db.litellm_teamtable.find_unique(
                where={"team_id": counter_key[len("spend:team:") :]}
            )
        elif counter_key.startswith("spend:user:"):
            record = await prisma_client.db.litellm_usertable.find_unique(
                where={"user_id": counter_key[len("spend:user:") :]}
            )
        elif counter_key.startswith("spend:org:"):
            record = await prisma_client.db.litellm_organizationtable.find_unique(
                where={"organization_id": counter_key[len("spend:org:") :]}
            )
        else:
            return 0.0
    except Exception:
        verbose_proxy_logger.exception(
            "Failed to reseed spend counter %s from DB", counter_key
        )
        return 0.0

    # A missing row and a NULL/absent spend both collapse to 0.0.
    persisted_spend = getattr(record, "spend", 0.0) if record is not None else 0.0
    return float(persisted_spend or 0.0)
async def _init_and_increment_spend_counter(
    counter_key: str,
    source_cache_key: str,
    increment: float,
):
    """
    Initialize counter from the authoritative DB spend value if not yet
    set, then increment it through spend_counter_cache.

    On first access per pod:
    1. Check spend_counter_cache (in-memory -> Redis via DualCache)
    2. If not found, reseed from the DB (`_reseed_spend_from_db`). Falls
       back to the cached object's `.spend` via user_api_key_cache only
       if prisma is unavailable, since that value can lag the flusher.
    3. Seed counter via async_increment_cache (not async_set_cache) to avoid a
       check-then-set race: if two pods cold-start simultaneously, both may see
       the counter as absent and seed it. Using increment means the worst case
       is over-counting (conservative, blocks slightly early) rather than
       under-counting (would allow overspend).
    4. Increment atomically (both in-memory + Redis)

    Args:
        counter_key: Spend counter key (e.g. "spend:team:{team_id}").
        source_cache_key: Key in user_api_key_cache of the cached object
            whose `spend` is used as the fallback seed.
        increment: Amount to add to the counter for the current request.
    """
    current = await spend_counter_cache.async_get_cache(key=counter_key)
    if current is None:
        base_spend = await _reseed_spend_from_db(counter_key)
        if prisma_client is None:
            # Best-effort fallback when prisma is unavailable (tests or
            # early-startup paths). May be stale but avoids resetting to 0.
            source = await user_api_key_cache.async_get_cache(key=source_cache_key)
            if source is not None:
                if isinstance(source, dict):
                    base_spend = source.get("spend", 0.0) or 0.0
                else:
                    base_spend = getattr(source, "spend", 0.0) or 0.0
        # Skip the seed write entirely when there is nothing to seed.
        if base_spend > 0:
            await spend_counter_cache.async_increment_cache(
                key=counter_key, value=base_spend
            )
    # Apply the current request's spend exactly once.
    await spend_counter_cache.async_increment_cache(
        key=counter_key, value=increment
    )
async def update_cache( # noqa: PLR0915
@@ -2126,3 +2126,192 @@ class TestGuardrailModificationCheck:
"""Unparseable strings should not trigger a 403 — they have no keys."""
self._call({"metadata": "not-json"})
self._call({"metadata": '"just a string"'})
@pytest.mark.asyncio
async def test_team_member_budget_check_falls_back_to_team_default_budget_id():
    """A membership row with no linked budget must still be capped by the
    budget referenced in team.metadata["team_member_budget_id"].

    The check runs twice: the first run has to read the budget row through
    prisma, the second has to be served from the cache while still
    enforcing the same $50 cap against the mocked $70 spend."""
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy._types import LiteLLM_TeamMembership
    from litellm.proxy.utils import ProxyLogging

    async def mock_get_current_spend(counter_key, fallback_spend):
        # Only the member counter under test reports spend above the cap.
        is_member_counter = counter_key == "spend:team_member:test-user:test-team"
        return 70.0 if is_member_counter else fallback_spend

    # Team-default budget row returned by the fallback DB lookup.
    default_budget_row = MagicMock()
    default_budget_row.max_budget = 50.0
    default_budget_row.dict = MagicMock(
        return_value={"budget_id": "budget-default", "max_budget": 50.0}
    )
    prisma_client = MagicMock()
    prisma_client.db.litellm_budgettable.find_unique = AsyncMock(
        return_value=default_budget_row
    )
    find_budget = prisma_client.db.litellm_budgettable.find_unique

    # Membership row without an attached budget.
    membership_without_budget = LiteLLM_TeamMembership(
        user_id="test-user",
        team_id="test-team",
        spend=0.0,
        budget_id=None,
        litellm_budget_table=None,
    )

    # One shared cache so the second run can observe the first run's write.
    check_kwargs = dict(
        team_object=LiteLLM_TeamTable(
            team_id="test-team",
            metadata={"team_member_budget_id": "budget-default"},
        ),
        user_object=LiteLLM_UserTable(user_id="test-user"),
        valid_token=UserAPIKeyAuth(
            token="test-token",
            user_id="test-user",
            team_id="test-team",
        ),
        prisma_client=prisma_client,
        user_api_key_cache=DualCache(),
        proxy_logging_obj=ProxyLogging(user_api_key_cache=None),
    )

    async def run_check_expecting_budget_error():
        """Run the budget check under both patches and return the raised error."""
        spend_patch = patch(
            "litellm.proxy.proxy_server.get_current_spend", mock_get_current_spend
        )
        membership_patch = patch(
            "litellm.proxy.auth.auth_checks.get_team_membership",
            new_callable=AsyncMock,
            return_value=membership_without_budget,
        )
        with spend_patch, membership_patch:
            with pytest.raises(litellm.BudgetExceededError) as raised:
                await _check_team_member_budget(**check_kwargs)
        return raised.value

    first_error = await run_check_expecting_budget_error()
    assert first_error.current_cost == 70.0
    assert first_error.max_budget == 50.0
    # First call did perform the fallback DB lookup.
    find_budget.assert_awaited_once()

    # Second call hits the cached budget row, no additional prisma read.
    find_budget.reset_mock()
    second_error = await run_check_expecting_budget_error()
    # The cached $50 cap is still being applied (not a coincidental skip)
    assert second_error.current_cost == 70.0
    assert second_error.max_budget == 50.0
    find_budget.assert_not_awaited()
@pytest.mark.asyncio
async def test_team_member_budget_check_per_member_override_wins_over_team_default():
    """A member's own budget takes precedence over the team's
    team_member_budget_id default.

    With a $200 per-member cap and a $50 team default, $70 of spend must
    pass without any fallback prisma lookup, and $250 of spend must raise
    with max_budget == 200."""
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy._types import LiteLLM_BudgetTable, LiteLLM_TeamMembership
    from litellm.proxy.utils import ProxyLogging

    # Mutable holder so the spend reported by the mock can change between runs.
    spend_state = {"member": 70.0}

    async def mock_get_current_spend(counter_key, fallback_spend):
        if counter_key != "spend:team_member:test-user:test-team":
            return fallback_spend
        return spend_state["member"]

    # Team-default row resolves to $50. If the fallback fired (it must
    # not here), spend $70 would exceed that $50 cap and raise.
    team_default_row = MagicMock()
    team_default_row.max_budget = 50.0
    prisma_client = MagicMock()
    prisma_client.db.litellm_budgettable.find_unique = AsyncMock(
        return_value=team_default_row
    )

    membership_with_override = LiteLLM_TeamMembership(
        user_id="test-user",
        team_id="test-team",
        spend=0.0,
        budget_id="budget-override",
        litellm_budget_table=LiteLLM_BudgetTable(max_budget=200.0),
    )
    team_object = LiteLLM_TeamTable(
        team_id="test-team",
        metadata={"team_member_budget_id": "budget-default"},
    )
    user_object = LiteLLM_UserTable(user_id="test-user")
    valid_token = UserAPIKeyAuth(
        token="test-token",
        user_id="test-user",
        team_id="test-team",
    )
    proxy_logging_obj = ProxyLogging(user_api_key_cache=None)

    async def run_check():
        """Run the budget check under both patches with a fresh cache."""
        spend_patch = patch(
            "litellm.proxy.proxy_server.get_current_spend", mock_get_current_spend
        )
        membership_patch = patch(
            "litellm.proxy.auth.auth_checks.get_team_membership",
            new_callable=AsyncMock,
            return_value=membership_with_override,
        )
        with spend_patch, membership_patch:
            await _check_team_member_budget(
                team_object=team_object,
                user_object=user_object,
                valid_token=valid_token,
                prisma_client=prisma_client,
                user_api_key_cache=DualCache(),
                proxy_logging_obj=proxy_logging_obj,
            )

    # 1. spend ($70) < per-member cap ($200) → no raise, no fallback lookup.
    await run_check()
    prisma_client.db.litellm_budgettable.find_unique.assert_not_awaited()

    # 2. Now push spend above the per-member cap ($200). Must raise with
    #    max_budget=200 to prove the per-member cap is the value being
    #    enforced (not just that enforcement silently skipped).
    spend_state["member"] = 250.0
    with pytest.raises(litellm.BudgetExceededError) as exc_info:
        await run_check()
    assert exc_info.value.current_cost == 250.0
    assert exc_info.value.max_budget == 200.0
@@ -1795,6 +1795,7 @@ async def test_backfill_team_member_budget_entries_creates_missing_memberships()
return_value=[existing_membership]
)
mock_prisma.db.litellm_teammembership.create_many = AsyncMock(return_value=None)
mock_prisma.db.litellm_teammembership.update_many = AsyncMock(return_value=0)
# Test with Member instances
members = [
@@ -1823,6 +1824,7 @@ async def test_backfill_team_member_budget_entries_creates_missing_memberships()
# Also test with raw dicts (members_with_roles may be dicts when deserialized from DB)
mock_prisma.db.litellm_teammembership.find_many.reset_mock()
mock_prisma.db.litellm_teammembership.create_many.reset_mock()
mock_prisma.db.litellm_teammembership.update_many.reset_mock()
members_as_dicts = [
{"user_id": "user-A", "role": "user"},
@@ -1868,6 +1870,7 @@ async def test_backfill_team_member_budget_entries_no_op_when_all_exist():
return_value=[existing_a, existing_b]
)
mock_prisma.db.litellm_teammembership.create_many = AsyncMock(return_value=None)
mock_prisma.db.litellm_teammembership.update_many = AsyncMock(return_value=0)
members = [
Member(user_id="user-A", role="user"),
@@ -1884,6 +1887,55 @@ async def test_backfill_team_member_budget_entries_no_op_when_all_exist():
mock_prisma.db.litellm_teammembership.create_many.assert_not_awaited()
@pytest.mark.asyncio
async def test_backfill_team_member_budget_entries_populates_null_budget_id_on_existing_rows():
    """
    When every member already has a TeamMembership row, the backfill must
    not create rows but must issue one update_many that sets budget_id on
    the team's rows whose budget_id is NULL. This is what lets a team
    member budget configured after members joined apply to them.
    """
    from unittest.mock import AsyncMock, MagicMock

    from litellm.proxy._types import Member
    from litellm.proxy.management_endpoints.team_endpoints import (
        TeamMemberBudgetHandler,
    )

    team_id = "team-abc"
    budget_id = "budget-xyz"
    member_ids = ("user-A", "user-B")

    def _existing_row(member_id):
        row = MagicMock()
        row.user_id = member_id
        return row

    # Both members already have rows, so create_many must not fire;
    # update_many must fire with the NULL-budget_id filter.
    memberships = MagicMock()
    memberships.find_many = AsyncMock(
        return_value=[_existing_row(member_id) for member_id in member_ids]
    )
    memberships.create_many = AsyncMock(return_value=None)
    memberships.update_many = AsyncMock(return_value=2)
    mock_prisma = MagicMock()
    mock_prisma.db.litellm_teammembership = memberships

    await TeamMemberBudgetHandler.backfill_team_member_budget_entries(
        team_id=team_id,
        members_with_roles=[
            Member(user_id=member_id, role="user") for member_id in member_ids
        ],
        team_member_budget_id=budget_id,
        prisma_client=mock_prisma,
    )

    memberships.create_many.assert_not_awaited()
    memberships.update_many.assert_awaited_once_with(
        where={"team_id": team_id, "budget_id": None},
        data={"budget_id": budget_id},
    )
@pytest.mark.asyncio
async def test_backfill_team_member_budget_entries_empty_members():
"""
@@ -4965,3 +4965,123 @@ async def test_increment_spend_counters_team_and_member():
finally:
ps.user_api_key_cache = original_key_cache
ps.spend_counter_cache = original_counter_cache
@pytest.mark.asyncio
async def test_init_and_increment_spend_counter_reseeds_from_db_on_counter_miss():
    """With the counter absent from Redis, the seed value must come from the
    prisma row (42.0) rather than the cached team object (10.0), and the
    request's own increment (1.5) must be written to the same key."""
    from litellm.caching.dual_cache import DualCache

    import litellm.proxy.proxy_server as ps
    from litellm.proxy.proxy_server import _init_and_increment_spend_counter

    seen_writes = []

    async def record_increment(key, value, ttl=None, **kwargs):
        seen_writes.append((key, value))
        return value

    # Redis stand-in: the counter is missing, and every increment is recorded.
    fake_redis = AsyncMock()
    fake_redis.async_increment = AsyncMock(side_effect=record_increment)
    fake_redis.async_get_cache = AsyncMock(return_value=None)
    counter_cache = DualCache()
    counter_cache.redis_cache = fake_redis

    # Prisma returns spend=42.0 (authoritative) while the stale cached
    # value (would be read only if prisma is None) is 10.0. The counter
    # must seed from 42, not 10.
    db_row = MagicMock()
    db_row.spend = 42.0
    fake_prisma = MagicMock()
    fake_prisma.db.litellm_teamtable.find_unique = AsyncMock(return_value=db_row)

    stale_team = MagicMock()
    stale_team.spend = 10.0
    stale_cache = DualCache()
    stale_cache.in_memory_cache.set_cache(key="team_id:team-9", value=stale_team)

    # Swap the module-level collaborators in, remembering the originals.
    replacements = {
        "user_api_key_cache": stale_cache,
        "spend_counter_cache": counter_cache,
        "prisma_client": fake_prisma,
    }
    originals = {name: getattr(ps, name) for name in replacements}
    for name, replacement in replacements.items():
        setattr(ps, name, replacement)
    try:
        await _init_and_increment_spend_counter(
            counter_key="spend:team:team-9",
            source_cache_key="team_id:team-9",
            increment=1.5,
        )

        fake_prisma.db.litellm_teamtable.find_unique.assert_awaited_once_with(
            where={"team_id": "team-9"}
        )
        # Two increments keyed on the counter: seed ($42) then request ($1.50).
        for expected_value in (42.0, 1.5):
            assert ("spend:team:team-9", expected_value) in seen_writes
    finally:
        for name, original in originals.items():
            setattr(ps, name, original)
@pytest.mark.asyncio
async def test_reseed_spend_from_db_user_and_org_prefixes():
    """spend:user:* and spend:org:* keys must each be reseeded from their
    own table, queried by user_id / organization_id respectively, and
    return that row's spend."""
    import litellm.proxy.proxy_server as ps
    from litellm.proxy.proxy_server import _reseed_spend_from_db

    def _row_with_spend(amount):
        row = MagicMock()
        row.spend = amount
        return row

    fake_prisma = MagicMock()
    fake_prisma.db.litellm_usertable.find_unique = AsyncMock(
        return_value=_row_with_spend(17.0)
    )
    fake_prisma.db.litellm_organizationtable.find_unique = AsyncMock(
        return_value=_row_with_spend(305.0)
    )

    # (counter key, lookup expected to be awaited, expected filter, expected spend)
    cases = (
        (
            "spend:user:alice",
            fake_prisma.db.litellm_usertable.find_unique,
            {"user_id": "alice"},
            17.0,
        ),
        (
            "spend:org:acme",
            fake_prisma.db.litellm_organizationtable.find_unique,
            {"organization_id": "acme"},
            305.0,
        ),
    )

    orig_prisma = ps.prisma_client
    ps.prisma_client = fake_prisma
    try:
        for counter_key, lookup, expected_where, expected_spend in cases:
            assert await _reseed_spend_from_db(counter_key) == expected_spend
            lookup.assert_awaited_once_with(where=expected_where)
    finally:
        ps.prisma_client = orig_prisma
@pytest.mark.asyncio
async def test_reseed_spend_from_db_skips_window_variant_keys():
    """Keys of the form spend:*:window:{duration} reuse the primary
    counter prefixes but have no DB row behind them, so reseeding must
    return 0.0 without awaiting any table lookup."""
    import litellm.proxy.proxy_server as ps
    from litellm.proxy.proxy_server import _reseed_spend_from_db

    fake_prisma = MagicMock()
    fake_prisma.db.litellm_verificationtoken.find_unique = AsyncMock()
    fake_prisma.db.litellm_teamtable.find_unique = AsyncMock()

    window_keys = ("spend:key:sk-abc:window:1h", "spend:team:team-1:window:1d")
    untouched_lookups = (
        fake_prisma.db.litellm_verificationtoken.find_unique,
        fake_prisma.db.litellm_teamtable.find_unique,
    )

    orig_prisma = ps.prisma_client
    ps.prisma_client = fake_prisma
    try:
        for window_key in window_keys:
            assert await _reseed_spend_from_db(window_key) == 0.0
        for lookup in untouched_lookups:
            lookup.assert_not_awaited()
    finally:
        ps.prisma_client = orig_prisma