fix: custom auth budget issue

This commit is contained in:
Harshit28j
2026-02-26 13:03:01 +05:30
parent 00ab4d2067
commit 14badde13c
6 changed files with 615 additions and 70 deletions
+26 -18
View File
@@ -1107,7 +1107,9 @@ class NewMCPServerRequest(LiteLLMPydanticObjectBase):
raise ValueError("args is required for stdio transport")
elif transport in [MCPTransport.http, MCPTransport.sse]:
if not values.get("url") and not values.get("spec_path"):
raise ValueError("url or spec_path is required for HTTP/SSE transport")
raise ValueError(
"url or spec_path is required for HTTP/SSE transport"
)
return values
@model_validator(mode="before")
@@ -1170,7 +1172,9 @@ class UpdateMCPServerRequest(LiteLLMPydanticObjectBase):
raise ValueError("args is required for stdio transport")
elif transport in [MCPTransport.http, MCPTransport.sse]:
if not values.get("url") and not values.get("spec_path"):
raise ValueError("url or spec_path is required for HTTP/SSE transport")
raise ValueError(
"url or spec_path is required for HTTP/SSE transport"
)
return values
@@ -1421,12 +1425,12 @@ class NewCustomerRequest(BudgetNewRequest):
blocked: bool = False # allow/disallow requests for this end-user
budget_id: Optional[str] = None # give either a budget_id or max_budget
spend: Optional[float] = None
allowed_model_region: Optional[AllowedModelRegion] = (
None # require all user requests to use models in this specific region
)
default_model: Optional[str] = (
None # if no equivalent model in allowed region - default all requests to this model
)
allowed_model_region: Optional[
AllowedModelRegion
] = None # require all user requests to use models in this specific region
default_model: Optional[
str
] = None # if no equivalent model in allowed region - default all requests to this model
object_permission: Optional[LiteLLM_ObjectPermissionBase] = None
@model_validator(mode="before")
@@ -1449,12 +1453,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase):
blocked: bool = False # allow/disallow requests for this end-user
max_budget: Optional[float] = None
budget_id: Optional[str] = None # give either a budget_id or max_budget
allowed_model_region: Optional[AllowedModelRegion] = (
None # require all user requests to use models in this specific region
)
default_model: Optional[str] = (
None # if no equivalent model in allowed region - default all requests to this model
)
allowed_model_region: Optional[
AllowedModelRegion
] = None # require all user requests to use models in this specific region
default_model: Optional[
str
] = None # if no equivalent model in allowed region - default all requests to this model
object_permission: Optional[LiteLLM_ObjectPermissionBase] = None
@@ -2279,6 +2283,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
end_user_tpm_limit: Optional[int] = None
end_user_rpm_limit: Optional[int] = None
end_user_max_budget: Optional[float] = None
end_user_model_max_budget: Optional[dict] = None
# Organization Params
organization_max_budget: Optional[float] = None
@@ -3067,7 +3072,9 @@ class SpendLogsMetadata(TypedDict):
str
] # S3/GCS object key for cold storage retrieval
litellm_overhead_time_ms: Optional[float] # LiteLLM overhead time in milliseconds
attempted_retries: Optional[int] # Number of retries attempted (0 = first attempt succeeded)
attempted_retries: Optional[
int
] # Number of retries attempted (0 = first attempt succeeded)
max_retries: Optional[int] # Max retries configured for this request
cost_breakdown: Optional[
CostBreakdown
@@ -4127,10 +4134,10 @@ class SpendUpdateQueueItem(TypedDict, total=False):
class ToolDiscoveryQueueItem(TypedDict, total=False):
tool_name: str
origin: Optional[str] # MCP server name or "user_defined"
origin: Optional[str] # MCP server name or "user_defined"
created_by: Optional[str]
key_hash: Optional[str] # hash of virtual key that triggered discovery
team_id: Optional[str] # team that triggered discovery
key_hash: Optional[str] # hash of virtual key that triggered discovery
team_id: Optional[str] # team that triggered discovery
key_alias: Optional[str] # human-readable key alias
@@ -4154,6 +4161,7 @@ class LiteLLM_ManagedObjectTable(LiteLLMPydanticObjectBase):
class LiteLLM_ManagedVectorStoreTable(LiteLLMPydanticObjectBase):
"""Table for managing vector stores with target_model_names support."""
unified_resource_id: str
resource_object: Optional[Any] = None # VectorStoreCreateResponse
model_mappings: Dict[str, str]
+266 -10
View File
@@ -183,6 +183,9 @@ def _apply_budget_limits_to_end_user_params(
if budget_info.max_budget is not None:
end_user_params["end_user_max_budget"] = budget_info.max_budget
if budget_info.model_max_budget is not None:
end_user_params["end_user_model_max_budget"] = budget_info.model_max_budget
verbose_proxy_logger.debug(f"Applied budget limits to end user {end_user_id}")
@@ -237,6 +240,9 @@ def update_valid_token_with_end_user_params(
valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit")
valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit")
valid_token.allowed_model_region = end_user_params.get("allowed_model_region")
valid_token.end_user_model_max_budget = end_user_params.get(
"end_user_model_max_budget"
)
return valid_token
@@ -493,13 +499,29 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
request=request, api_key=api_key, user_custom_auth=user_custom_auth
)
if response is not None and isinstance(response, UserAPIKeyAuth):
return UserAPIKeyAuth.model_validate(response)
validated = UserAPIKeyAuth.model_validate(response)
validated = await _run_post_custom_auth_checks(
valid_token=validated,
request=request,
request_data=request_data,
route=route,
parent_otel_span=parent_otel_span,
)
return validated
elif response is not None and isinstance(response, str):
api_key = response
custom_auth_api_key = True
elif user_custom_auth is not None:
response = await user_custom_auth(request=request, api_key=api_key) # type: ignore
return UserAPIKeyAuth.model_validate(response)
validated = UserAPIKeyAuth.model_validate(response)
validated = await _run_post_custom_auth_checks(
valid_token=validated,
request=request,
request_data=request_data,
route=route,
parent_otel_span=parent_otel_span,
)
return validated
### LITELLM-DEFINED AUTH FUNCTION ###
#### IF JWT ####
@@ -593,9 +615,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
user_id=user_id,
team_id=team_id,
team_alias=(
team_object.team_alias
if team_object is not None
else None
team_object.team_alias if team_object is not None else None
),
team_metadata=team_object.metadata
if team_object is not None
@@ -846,7 +866,9 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
)
valid_token.parent_otel_span = parent_otel_span
if _end_user_object is not None:
valid_token.end_user_object_permission = _end_user_object.object_permission
valid_token.end_user_object_permission = (
_end_user_object.object_permission
)
return valid_token
@@ -954,7 +976,11 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
if isinstance(
api_key, str
): # if generated token, make sure it starts with sk-.
_masked_key = "{}****{}".format(api_key[:4], api_key[-4:]) if len(api_key) > 8 else "****"
_masked_key = (
"{}****{}".format(api_key[:4], api_key[-4:])
if len(api_key) > 8
else "****"
)
assert api_key.startswith(
"sk-"
), "LiteLLM Virtual Key expected. Received={}, expected to start with 'sk-'.".format(
@@ -1201,6 +1227,21 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
model=current_model,
)
# Check 5b. End-user model max budget
end_user_mmb = valid_token.end_user_model_max_budget
if (
end_user_mmb is not None
and isinstance(end_user_mmb, dict)
and len(end_user_mmb) > 0
and current_model is not None
and valid_token.end_user_id is not None
):
await model_max_budget_limiter.is_end_user_within_model_budget(
end_user_id=valid_token.end_user_id,
end_user_model_max_budget=end_user_mmb,
model=current_model,
)
# Check 6: Additional Common Checks across jwt + key auth
if valid_token.team_id is not None:
try:
@@ -1304,9 +1345,9 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
if _end_user_object is not None:
valid_token_dict.update(end_user_params)
valid_token_dict["end_user_object_permission"] = (
_end_user_object.object_permission
)
valid_token_dict[
"end_user_object_permission"
] = _end_user_object.object_permission
# check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions
# sso/login, ui/login, /key functions and /user functions
@@ -1492,3 +1533,218 @@ def _update_key_budget_with_temp_budget_increase(
temp_budget_increase = _get_temp_budget_increase(valid_token) or 0.0
valid_token.max_budget = valid_token.max_budget + temp_budget_increase
return valid_token
async def _lookup_end_user_and_apply_budget(
    valid_token: UserAPIKeyAuth,
    route: str,
    parent_otel_span: Optional[Span],
    prisma_client,
    user_api_key_cache,
    proxy_logging_obj,
):
    """
    Fetch the end-user record for ``valid_token.end_user_id`` and copy its
    budget settings onto the token.

    If no end-user record is found and ``litellm.max_end_user_budget_id`` is
    set, the default end-user budget is fetched and applied instead.

    Returns:
        A ``(valid_token, end_user_object)`` tuple; ``end_user_object`` is
        ``None`` when the lookup returned nothing or failed.

    Raises:
        litellm.BudgetExceededError: re-raised unchanged. Every other
            exception is logged at debug level and swallowed.
    """
    end_user_id = valid_token.end_user_id
    end_user_object = None
    try:
        end_user_object = await get_end_user_object(
            end_user_id=end_user_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
            route=route,
        )

        # Work out which params (if any) to push onto the token and which
        # budget record (if any) feeds them.
        params = None
        budget_record = None
        if end_user_object is not None:
            params = {
                "end_user_id": end_user_id,
                "allowed_model_region": end_user_object.allowed_model_region,
            }
            budget_record = end_user_object.litellm_budget_table
        elif litellm.max_end_user_budget_id is not None:
            from litellm.proxy.auth.auth_checks import get_default_end_user_budget

            budget_record = await get_default_end_user_budget(
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
            )
            if budget_record is not None:
                params = {"end_user_id": end_user_id}

        if params is not None:
            if budget_record is not None:
                _apply_budget_limits_to_end_user_params(
                    end_user_params=params,
                    budget_info=budget_record,
                    end_user_id=end_user_id,
                )
            valid_token = update_valid_token_with_end_user_params(
                valid_token=valid_token, end_user_params=params
            )
    except litellm.BudgetExceededError:
        # Budget violations must reach the caller.
        raise
    except Exception as e:
        verbose_proxy_logger.debug(f"Unable to find user in db. Error - {str(e)}")
    return valid_token, end_user_object
async def _run_post_custom_auth_checks(
    valid_token: UserAPIKeyAuth,
    request: Request,
    request_data: dict,
    route: str,
    parent_otel_span: Optional[Span],
) -> UserAPIKeyAuth:
    """
    Run the proxy's post-authentication checks on a token produced by a
    custom auth hook.

    Steps, in order:
        1. Look up the end user (if ``end_user_id`` is set) and apply its
           budget params to the token.
        2. Reject the token if ``expires`` is in the past.
        3. Check the key-level ``model_max_budget`` for the requested model.
        4. Check the end-user ``end_user_model_max_budget`` for the requested
           model.
        5. Look up the user object (if ``user_id`` is set).
        6. Resolve team / project objects and run ``common_checks``.

    Args:
        valid_token: Token returned by the custom auth hook.
        request: Incoming request, forwarded to ``common_checks``.
        request_data: Parsed request body; ``"model"`` is read from it.
        route: Route being called.
        parent_otel_span: Optional tracing span passed to the DB lookups.

    Returns:
        The token, updated with end-user params when an end user was found.

    Raises:
        ProxyException: if the token has expired.
        litellm.BudgetExceededError: if the end user is over its per-model
            budget. Errors raised by the key budget check and by
            ``common_checks`` also propagate.
    """
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        model_max_budget_limiter,
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
    )

    # 1. Look up end_user object from DB if end_user_id is set
    end_user_object = None
    if valid_token.end_user_id is not None:
        valid_token, end_user_object = await _lookup_end_user_and_apply_budget(
            valid_token=valid_token,
            route=route,
            parent_otel_span=parent_otel_span,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )

    # 2. Check token expiry
    if valid_token.expires is not None:
        current_time = datetime.now(timezone.utc)
        if isinstance(valid_token.expires, datetime):
            expiry_time = valid_token.expires
        else:
            expiry_time = datetime.fromisoformat(valid_token.expires)
        # Treat naive timestamps as UTC so the comparison below never mixes
        # naive and aware datetimes.
        if (
            expiry_time.tzinfo is None
            or expiry_time.tzinfo.utcoffset(expiry_time) is None
        ):
            expiry_time = expiry_time.replace(tzinfo=timezone.utc)
        if expiry_time < current_time:
            raise ProxyException(
                message=f"Authentication Error - Expired Key. Key Expiry time {expiry_time} and current time {current_time}",
                type=ProxyErrorTypes.expired_key,
                code=400,
                param=abbreviate_api_key(api_key=valid_token.token)
                if valid_token.token
                else "",
            )

    current_model = request_data.get("model", None)

    # 3. Check key-level model_max_budget
    max_budget_per_model = valid_token.model_max_budget
    if (
        max_budget_per_model is not None
        and isinstance(max_budget_per_model, dict)
        and len(max_budget_per_model) > 0
        and current_model is not None
        and valid_token.token is not None
    ):
        await model_max_budget_limiter.is_key_within_model_budget(
            user_api_key_dict=valid_token,
            model=current_model,
        )

    # 4. Check end-user model_max_budget
    end_user_mmb = valid_token.end_user_model_max_budget
    if (
        end_user_mmb is not None
        and isinstance(end_user_mmb, dict)
        and len(end_user_mmb) > 0
        and current_model is not None
        and valid_token.end_user_id is not None
    ):
        await model_max_budget_limiter.is_end_user_within_model_budget(
            end_user_id=valid_token.end_user_id,
            end_user_model_max_budget=end_user_mmb,
            model=current_model,
        )

    # 5. Look up user object if user_id is set
    user_object = None
    if valid_token.user_id is not None:
        try:
            user_object = await get_user_object(
                user_id=valid_token.user_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                user_id_upsert=False,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )
        except Exception as e:
            # Best-effort lookup: custom auth may return user ids that do not
            # exist in the DB. Log the failure so it is not silently lost.
            verbose_proxy_logger.debug(
                "Unable to look up user_id=%s during post custom auth checks. Error - %s",
                valid_token.user_id,
                str(e),
            )
            # If user_role is PROXY_ADMIN on the token, create a synthetic user object
            # so that admin route checks pass for custom auth
            if valid_token.user_role == LitellmUserRoles.PROXY_ADMIN:
                user_object = LiteLLM_UserTable(
                    user_id=valid_token.user_id,
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    spend=0.0,
                )

    # 6. Run common checks
    if valid_token.team_id is not None:
        try:
            _team_obj = await get_team_object(
                team_id=valid_token.team_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )
        except HTTPException:
            # Fall back to a team object built from the values carried on
            # the token itself.
            _team_obj = LiteLLM_TeamTableCachedObj(
                team_id=valid_token.team_id,
                max_budget=valid_token.team_max_budget,
                soft_budget=valid_token.team_soft_budget,
                spend=valid_token.team_spend,
                tpm_limit=valid_token.team_tpm_limit,
                rpm_limit=valid_token.team_rpm_limit,
                blocked=valid_token.team_blocked,
                models=valid_token.team_models,
                metadata=valid_token.team_metadata,
                object_permission_id=valid_token.team_object_permission_id,
            )
    else:
        _team_obj = None

    _project_obj = None
    if valid_token.project_id is not None:
        _project_obj = await get_project_object(
            project_id=valid_token.project_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )

    _ = await common_checks(
        request=request,
        request_body=request_data,
        team_object=_team_obj,
        user_object=user_object,
        end_user_object=end_user_object,
        general_settings=general_settings,
        global_proxy_spend=None,
        route=route,
        llm_router=llm_router,
        proxy_logging_obj=proxy_logging_obj,
        valid_token=valid_token,
        skip_budget_checks=False,
        project_object=_project_obj,
    )
    return valid_token
+130 -23
View File
@@ -15,6 +15,7 @@ from litellm.types.utils import (
)
VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX = "virtual_key_spend"
END_USER_SPEND_CACHE_KEY_PREFIX = "end_user_model_spend"
class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
@@ -83,6 +84,81 @@ class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
return True
async def is_end_user_within_model_budget(
    self,
    end_user_id: str,
    end_user_model_max_budget: dict,
    model: str,
) -> bool:
    """
    Verify that ``end_user_id`` has not overspent its budget for ``model``.

    Returns:
        True when the model has no budget entry, the entry has no positive
        ``max_budget``, or the recorded spend does not exceed it.

    Raises:
        BudgetExceededError: If the end_user has exceeded the model budget
    """
    budgets_by_model: GenericBudgetConfigType = {
        name: BudgetConfig(**raw_budget)
        for name, raw_budget in end_user_model_max_budget.items()
    }

    verbose_proxy_logger.debug(
        "end_user internal_model_max_budget %s",
        json.dumps(budgets_by_model, indent=4, default=str),
    )

    # Resolve the budget entry that applies to the requested model.
    model_budget = self._get_request_model_budget_config(
        model=model, internal_model_max_budget=budgets_by_model
    )
    if model_budget is None:
        verbose_proxy_logger.debug(
            f"Model {model} not found in end_user_model_max_budget"
        )
        return True

    # Only a positive max_budget is enforced.
    limit = model_budget.max_budget
    if not (limit and limit > 0):
        return True

    spent = await self._get_end_user_spend_for_model(
        end_user_id=end_user_id,
        model=model,
        key_budget_config=model_budget,
    )
    if spent is not None and spent > limit:
        raise litellm.BudgetExceededError(
            message=f"LiteLLM End User: {end_user_id}, exceeded budget for model={model}",
            current_cost=spent,
            max_budget=limit,
        )

    return True
async def _get_end_user_spend_for_model(
    self,
    end_user_id: str,
    model: str,
    key_budget_config: BudgetConfig,
) -> Optional[float]:
    """
    Read the cached spend for ``end_user_id`` on ``model``.

    The cache is queried with ``model`` as given; if that misses, it is
    queried again with the custom_llm_provider prefix removed. Returns
    whatever the cache holds, or ``None`` when both lookups miss.
    """

    def _spend_key(model_name: str) -> str:
        return f"{END_USER_SPEND_CACHE_KEY_PREFIX}:{end_user_id}:{model_name}:{key_budget_config.budget_duration}"

    # 1. model: directly look up `model`
    cached_spend = await self.dual_cache.async_get_cache(key=_spend_key(model))
    if cached_spend is not None:
        return cached_spend

    # 2. If 1, does not exist, check if passed as {custom_llm_provider}/model
    return await self.dual_cache.async_get_cache(
        key=_spend_key(self._get_model_without_custom_llm_provider(model)),
    )
async def _get_virtual_key_spend_for_model(
self,
user_api_key_hash: Optional[str],
@@ -163,46 +239,77 @@ class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
user_api_key_model_max_budget: Optional[dict] = _metadata.get(
"user_api_key_model_max_budget", None
)
user_api_key_end_user_model_max_budget: Optional[dict] = _metadata.get(
"user_api_key_end_user_model_max_budget", None
)
if (
user_api_key_model_max_budget is None
or len(user_api_key_model_max_budget) == 0
) and (
user_api_key_end_user_model_max_budget is None
or len(user_api_key_end_user_model_max_budget) == 0
):
verbose_proxy_logger.debug(
"Not running _PROXY_VirtualKeyModelMaxBudgetLimiter.async_log_success_event because user_api_key_model_max_budget is None or empty. `user_api_key_model_max_budget`=%s",
user_api_key_model_max_budget,
"Not running _PROXY_VirtualKeyModelMaxBudgetLimiter.async_log_success_event because user_api_key_model_max_budget and user_api_key_end_user_model_max_budget are None or empty."
)
return
response_cost: float = standard_logging_payload.get("response_cost", 0)
model = standard_logging_payload.get("model")
virtual_key = standard_logging_payload.get("metadata", {}).get(
"user_api_key_hash"
)
end_user_id = standard_logging_payload.get(
"end_user"
) or standard_logging_payload.get("metadata", {}).get(
"user_api_key_end_user_id"
)
if virtual_key is None or model is None:
if model is None:
return
# Resolve per-model budget config (same logic as is_key_within_model_budget)
internal_model_max_budget: GenericBudgetConfigType = {}
for _model, _budget_info in user_api_key_model_max_budget.items():
internal_model_max_budget[_model] = BudgetConfig(**_budget_info)
key_budget_config = self._get_request_model_budget_config(
model=model, internal_model_max_budget=internal_model_max_budget
)
if key_budget_config is None or not key_budget_config.budget_duration:
verbose_proxy_logger.debug(
"Not incrementing model spend: no budget config or budget_duration for model=%s",
model,
if (
virtual_key is not None
and user_api_key_model_max_budget is not None
and len(user_api_key_model_max_budget) > 0
):
internal_model_max_budget: GenericBudgetConfigType = {}
for _model, _budget_info in user_api_key_model_max_budget.items():
internal_model_max_budget[_model] = BudgetConfig(**_budget_info)
key_budget_config = self._get_request_model_budget_config(
model=model, internal_model_max_budget=internal_model_max_budget
)
return
if key_budget_config is not None and key_budget_config.budget_duration:
virtual_spend_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{model}:{key_budget_config.budget_duration}"
virtual_start_time_key = f"virtual_key_budget_start_time:{virtual_key}"
await self._increment_spend_for_key(
budget_config=key_budget_config,
spend_key=virtual_spend_key,
start_time_key=virtual_start_time_key,
response_cost=response_cost,
)
if (
end_user_id is not None
and user_api_key_end_user_model_max_budget is not None
and len(user_api_key_end_user_model_max_budget) > 0
):
internal_model_max_budget: GenericBudgetConfigType = {}
for _model, _budget_info in user_api_key_end_user_model_max_budget.items():
internal_model_max_budget[_model] = BudgetConfig(**_budget_info)
key_budget_config = self._get_request_model_budget_config(
model=model, internal_model_max_budget=internal_model_max_budget
)
if key_budget_config is not None and key_budget_config.budget_duration:
end_user_spend_key = f"{END_USER_SPEND_CACHE_KEY_PREFIX}:{end_user_id}:{model}:{key_budget_config.budget_duration}"
end_user_start_time_key = f"end_user_budget_start_time:{end_user_id}"
await self._increment_spend_for_key(
budget_config=key_budget_config,
spend_key=end_user_spend_key,
start_time_key=end_user_start_time_key,
response_cost=response_cost,
)
virtual_spend_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{model}:{key_budget_config.budget_duration}"
virtual_start_time_key = f"virtual_key_budget_start_time:{virtual_key}"
await self._increment_spend_for_key(
budget_config=key_budget_config,
spend_key=virtual_spend_key,
start_time_key=virtual_start_time_key,
response_cost=response_cost,
)
verbose_proxy_logger.debug(
"current state of in memory cache %s",
json.dumps(
+28 -19
View File
@@ -10,10 +10,15 @@ import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm._service_logger import ServiceLogging
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
from litellm.proxy._types import (AddTeamCallback, CommonProxyErrors,
LitellmDataForBackendLLMCall,
LitellmUserRoles, SpecialHeaders,
TeamCallbackMetadata, UserAPIKeyAuth)
from litellm.proxy._types import (
AddTeamCallback,
CommonProxyErrors,
LitellmDataForBackendLLMCall,
LitellmUserRoles,
SpecialHeaders,
TeamCallbackMetadata,
UserAPIKeyAuth,
)
from litellm.proxy.common_utils.http_parsing_utils import _safe_get_request_headers
# Cache special headers as a frozenset for O(1) lookup performance
@@ -23,9 +28,12 @@ _SPECIAL_HEADERS_CACHE = frozenset(
from litellm.router import Router
from litellm.types.llms.anthropic import ANTHROPIC_API_HEADERS
from litellm.types.services import ServiceTypes
from litellm.types.utils import (LlmProviders, ProviderSpecificHeader,
StandardLoggingUserAPIKeyMetadata,
SupportedCacheControls)
from litellm.types.utils import (
LlmProviders,
ProviderSpecificHeader,
StandardLoggingUserAPIKeyMetadata,
SupportedCacheControls,
)
service_logger_obj = ServiceLogging() # used for tracking latency on OTEL
@@ -654,7 +662,8 @@ class LiteLLMProxyRequestSetup:
return data
from litellm.proxy._types import (
LiteLLM_ManagementEndpoint_MetadataFields,
LiteLLM_ManagementEndpoint_MetadataFields_Premium)
LiteLLM_ManagementEndpoint_MetadataFields_Premium,
)
# ignore any special fields
added_metadata = {}
@@ -1025,6 +1034,9 @@ async def add_litellm_data_to_request( # noqa: PLR0915
data[_metadata_variable_name][
"user_api_key_model_max_budget"
] = user_api_key_dict.model_max_budget
data[_metadata_variable_name][
"user_api_key_end_user_model_max_budget"
] = user_api_key_dict.end_user_model_max_budget
# User spend, budget - used by prometheus.py
# Follow same pattern as team and API key budgets
@@ -1479,8 +1491,7 @@ async def move_guardrails_to_metadata(
# Only check policy engine if no local config (avoid import + registry lookup)
if not (has_key_config or has_team_config or has_request_config):
from litellm.proxy.policy_engine.policy_registry import \
get_policy_registry
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
if not get_policy_registry().is_initialized():
# Nothing configured anywhere - clean up request body fields and return
@@ -1544,16 +1555,14 @@ async def move_guardrails_to_metadata(
def _is_policy_version_id(s: str) -> bool:
"""Return True if string is a policy version ID (starts with policy_<uuid> prefix)."""
from litellm.proxy.policy_engine.policy_registry import \
POLICY_VERSION_ID_PREFIX
from litellm.proxy.policy_engine.policy_registry import POLICY_VERSION_ID_PREFIX
return isinstance(s, str) and s.startswith(POLICY_VERSION_ID_PREFIX)
def _extract_policy_id(s: str) -> Optional[str]:
"""Extract raw UUID from policy_<uuid> string, or None if not a valid version ID."""
from litellm.proxy.policy_engine.policy_registry import \
POLICY_VERSION_ID_PREFIX
from litellm.proxy.policy_engine.policy_registry import POLICY_VERSION_ID_PREFIX
if not _is_policy_version_id(s):
return None
@@ -1574,9 +1583,10 @@ def _match_and_track_policies(
"""
from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.callback_utils import (
add_policy_sources_to_metadata, add_policy_to_applied_policies_header)
from litellm.proxy.policy_engine.attachment_registry import \
get_attachment_registry
add_policy_sources_to_metadata,
add_policy_to_applied_policies_header,
)
from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry
from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher
# Get matching policies via attachments (with match reasons for attribution)
@@ -1721,8 +1731,7 @@ async def add_guardrails_from_policy_engine(
user_api_key_dict: The user's API key authentication info
"""
from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.http_parsing_utils import \
get_tags_from_request_body
from litellm.proxy.common_utils.http_parsing_utils import get_tags_from_request_body
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
from litellm.types.proxy.policy_engine import PolicyMatchContext
@@ -158,3 +158,109 @@ async def test_async_log_success_event_uses_per_model_budget_duration(budget_lim
f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{model}:{budget_duration}"
)
assert call_kwargs["response_cost"] == 0.05
# Test is_end_user_within_model_budget
@pytest.mark.asyncio
async def test_is_end_user_within_model_budget(budget_limiter):
    """Covers the within-budget, over-budget and unknown-model paths."""
    end_user = "test-user"

    def gpt4_budget():
        # Fresh dict per call so no case can see another case's input.
        return {"gpt-4": {"budget_limit": 100.0, "time_period": "1d"}}

    # Spend below the limit -> allowed
    with patch.object(
        budget_limiter, "_get_end_user_spend_for_model", return_value=50.0
    ):
        within_budget = await budget_limiter.is_end_user_within_model_budget(
            end_user, gpt4_budget(), "gpt-4"
        )
        assert within_budget is True

    # Spend above the limit -> BudgetExceededError
    with patch.object(
        budget_limiter, "_get_end_user_spend_for_model", return_value=150.0
    ):
        with pytest.raises(litellm.BudgetExceededError):
            await budget_limiter.is_end_user_within_model_budget(
                end_user, gpt4_budget(), "gpt-4"
            )

    # Requested model has no budget entry -> allowed
    unknown_model_result = await budget_limiter.is_end_user_within_model_budget(
        end_user, gpt4_budget(), "non-existent"
    )
    assert unknown_model_result is True
# Test _get_end_user_spend_for_model
@pytest.mark.asyncio
async def test_get_end_user_spend_for_model(budget_limiter):
    """Spend is read from the cache for both bare and provider-prefixed models."""
    config = GenericBudgetInfo(budget_limit=100.0, time_period="1d")

    # Every cache read returns the same spend value.
    with patch.object(budget_limiter.dual_cache, "async_get_cache", return_value=50.0):
        # Bare model name first, then the same model with a provider prefix.
        for model_name in ("gpt-4", "openai/gpt-4"):
            observed = await budget_limiter._get_end_user_spend_for_model(
                end_user_id="test-user",
                model=model_name,
                key_budget_config=config,
            )
            assert observed == 50.0
@pytest.mark.asyncio
async def test_async_log_success_event_uses_end_user_model_budget_duration(
    budget_limiter,
):
    """
    The end-user spend cache key written by async_log_success_event must
    carry the per-model budget_duration.
    """
    from litellm.proxy.hooks.model_max_budget_limiter import (
        END_USER_SPEND_CACHE_KEY_PREFIX,
    )

    end_user_id, model, budget_duration = "test-user", "gpt-4", "1d"

    standard_logging_object = {
        "response_cost": 0.05,
        "model": model,
        "end_user": end_user_id,
        "metadata": {"user_api_key_end_user_id": end_user_id},
    }
    litellm_params = {
        "metadata": {
            "user_api_key_end_user_model_max_budget": {
                model: {"budget_limit": 100.0, "time_period": budget_duration},
            }
        },
    }
    kwargs = {
        "standard_logging_object": standard_logging_object,
        "litellm_params": litellm_params,
    }

    with patch.object(
        budget_limiter,
        "_increment_spend_for_key",
        new_callable=AsyncMock,
    ) as mock_increment:
        await budget_limiter.async_log_success_event(
            kwargs, response_obj=None, start_time=None, end_time=None
        )

    mock_increment.assert_awaited_once()
    recorded = mock_increment.call_args.kwargs
    expected_key = (
        f"{END_USER_SPEND_CACHE_KEY_PREFIX}:{end_user_id}:{model}:{budget_duration}"
    )
    assert recorded["spend_key"] == expected_key
    assert recorded["response_cost"] == 0.05
@@ -0,0 +1,59 @@
import pytest
from unittest.mock import AsyncMock, patch
import litellm
from litellm.proxy.auth.user_api_key_auth import _run_post_custom_auth_checks
from litellm.proxy._types import UserAPIKeyAuth
@pytest.mark.asyncio
async def test_custom_auth_run_post_custom_auth_checks_without_end_user_id():
    """Backwards compatibility: a token without end_user_id still reaches common_checks."""
    token = UserAPIKeyAuth(token="test_token")

    common_checks_patch = patch(
        "litellm.proxy.auth.user_api_key_auth.common_checks", new_callable=AsyncMock
    )
    with common_checks_patch as mock_common:
        mock_common.return_value = True
        checked = await _run_post_custom_auth_checks(
            valid_token=token,
            request=None,
            request_data={},
            route="/v1/chat/completions",
            parent_otel_span=None,
        )

        assert checked.token == "test_token"
        assert getattr(checked, "end_user_id", None) is None
        mock_common.assert_awaited_once()
@pytest.mark.asyncio
async def test_custom_auth_run_post_custom_auth_checks_with_end_user_budget_exceeded():
    """A BudgetExceededError raised by the end-user model budget check propagates."""
    token = UserAPIKeyAuth(
        token="test_token",
        end_user_id="test_user",
        end_user_model_max_budget={
            "gpt-4": {"budget_limit": 10.0, "time_period": "1d"}
        },
    )
    body = {"model": "gpt-4"}

    # Both patches are entered in the same order as a nested `with` would use.
    with patch(
        "litellm.proxy.auth.user_api_key_auth.common_checks", new_callable=AsyncMock
    ), patch(
        "litellm.proxy.proxy_server.model_max_budget_limiter.is_end_user_within_model_budget",
        new_callable=AsyncMock,
    ) as mock_budget_check:
        mock_budget_check.side_effect = litellm.BudgetExceededError(
            message="Exceeded budget", current_cost=20.0, max_budget=10.0
        )

        with pytest.raises(litellm.BudgetExceededError):
            await _run_post_custom_auth_checks(
                valid_token=token,
                request=None,
                request_data=body,
                route="/v1/chat/completions",
                parent_otel_span=None,
            )

        mock_budget_check.assert_awaited_once()