mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-28 23:09:28 +00:00
fix: custom auth budget issue
This commit is contained in:
+26
-18
@@ -1107,7 +1107,9 @@ class NewMCPServerRequest(LiteLLMPydanticObjectBase):
|
||||
raise ValueError("args is required for stdio transport")
|
||||
elif transport in [MCPTransport.http, MCPTransport.sse]:
|
||||
if not values.get("url") and not values.get("spec_path"):
|
||||
raise ValueError("url or spec_path is required for HTTP/SSE transport")
|
||||
raise ValueError(
|
||||
"url or spec_path is required for HTTP/SSE transport"
|
||||
)
|
||||
return values
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -1170,7 +1172,9 @@ class UpdateMCPServerRequest(LiteLLMPydanticObjectBase):
|
||||
raise ValueError("args is required for stdio transport")
|
||||
elif transport in [MCPTransport.http, MCPTransport.sse]:
|
||||
if not values.get("url") and not values.get("spec_path"):
|
||||
raise ValueError("url or spec_path is required for HTTP/SSE transport")
|
||||
raise ValueError(
|
||||
"url or spec_path is required for HTTP/SSE transport"
|
||||
)
|
||||
return values
|
||||
|
||||
|
||||
@@ -1421,12 +1425,12 @@ class NewCustomerRequest(BudgetNewRequest):
|
||||
blocked: bool = False # allow/disallow requests for this end-user
|
||||
budget_id: Optional[str] = None # give either a budget_id or max_budget
|
||||
spend: Optional[float] = None
|
||||
allowed_model_region: Optional[AllowedModelRegion] = (
|
||||
None # require all user requests to use models in this specific region
|
||||
)
|
||||
default_model: Optional[str] = (
|
||||
None # if no equivalent model in allowed region - default all requests to this model
|
||||
)
|
||||
allowed_model_region: Optional[
|
||||
AllowedModelRegion
|
||||
] = None # require all user requests to use models in this specific region
|
||||
default_model: Optional[
|
||||
str
|
||||
] = None # if no equivalent model in allowed region - default all requests to this model
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionBase] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -1449,12 +1453,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase):
|
||||
blocked: bool = False # allow/disallow requests for this end-user
|
||||
max_budget: Optional[float] = None
|
||||
budget_id: Optional[str] = None # give either a budget_id or max_budget
|
||||
allowed_model_region: Optional[AllowedModelRegion] = (
|
||||
None # require all user requests to use models in this specific region
|
||||
)
|
||||
default_model: Optional[str] = (
|
||||
None # if no equivalent model in allowed region - default all requests to this model
|
||||
)
|
||||
allowed_model_region: Optional[
|
||||
AllowedModelRegion
|
||||
] = None # require all user requests to use models in this specific region
|
||||
default_model: Optional[
|
||||
str
|
||||
] = None # if no equivalent model in allowed region - default all requests to this model
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionBase] = None
|
||||
|
||||
|
||||
@@ -2279,6 +2283,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
|
||||
end_user_tpm_limit: Optional[int] = None
|
||||
end_user_rpm_limit: Optional[int] = None
|
||||
end_user_max_budget: Optional[float] = None
|
||||
end_user_model_max_budget: Optional[dict] = None
|
||||
|
||||
# Organization Params
|
||||
organization_max_budget: Optional[float] = None
|
||||
@@ -3067,7 +3072,9 @@ class SpendLogsMetadata(TypedDict):
|
||||
str
|
||||
] # S3/GCS object key for cold storage retrieval
|
||||
litellm_overhead_time_ms: Optional[float] # LiteLLM overhead time in milliseconds
|
||||
attempted_retries: Optional[int] # Number of retries attempted (0 = first attempt succeeded)
|
||||
attempted_retries: Optional[
|
||||
int
|
||||
] # Number of retries attempted (0 = first attempt succeeded)
|
||||
max_retries: Optional[int] # Max retries configured for this request
|
||||
cost_breakdown: Optional[
|
||||
CostBreakdown
|
||||
@@ -4127,10 +4134,10 @@ class SpendUpdateQueueItem(TypedDict, total=False):
|
||||
|
||||
class ToolDiscoveryQueueItem(TypedDict, total=False):
|
||||
tool_name: str
|
||||
origin: Optional[str] # MCP server name or "user_defined"
|
||||
origin: Optional[str] # MCP server name or "user_defined"
|
||||
created_by: Optional[str]
|
||||
key_hash: Optional[str] # hash of virtual key that triggered discovery
|
||||
team_id: Optional[str] # team that triggered discovery
|
||||
key_hash: Optional[str] # hash of virtual key that triggered discovery
|
||||
team_id: Optional[str] # team that triggered discovery
|
||||
key_alias: Optional[str] # human-readable key alias
|
||||
|
||||
|
||||
@@ -4154,6 +4161,7 @@ class LiteLLM_ManagedObjectTable(LiteLLMPydanticObjectBase):
|
||||
|
||||
class LiteLLM_ManagedVectorStoreTable(LiteLLMPydanticObjectBase):
|
||||
"""Table for managing vector stores with target_model_names support."""
|
||||
|
||||
unified_resource_id: str
|
||||
resource_object: Optional[Any] = None # VectorStoreCreateResponse
|
||||
model_mappings: Dict[str, str]
|
||||
|
||||
@@ -183,6 +183,9 @@ def _apply_budget_limits_to_end_user_params(
|
||||
if budget_info.max_budget is not None:
|
||||
end_user_params["end_user_max_budget"] = budget_info.max_budget
|
||||
|
||||
if budget_info.model_max_budget is not None:
|
||||
end_user_params["end_user_model_max_budget"] = budget_info.model_max_budget
|
||||
|
||||
verbose_proxy_logger.debug(f"Applied budget limits to end user {end_user_id}")
|
||||
|
||||
|
||||
@@ -237,6 +240,9 @@ def update_valid_token_with_end_user_params(
|
||||
valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit")
|
||||
valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit")
|
||||
valid_token.allowed_model_region = end_user_params.get("allowed_model_region")
|
||||
valid_token.end_user_model_max_budget = end_user_params.get(
|
||||
"end_user_model_max_budget"
|
||||
)
|
||||
return valid_token
|
||||
|
||||
|
||||
@@ -493,13 +499,29 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
request=request, api_key=api_key, user_custom_auth=user_custom_auth
|
||||
)
|
||||
if response is not None and isinstance(response, UserAPIKeyAuth):
|
||||
return UserAPIKeyAuth.model_validate(response)
|
||||
validated = UserAPIKeyAuth.model_validate(response)
|
||||
validated = await _run_post_custom_auth_checks(
|
||||
valid_token=validated,
|
||||
request=request,
|
||||
request_data=request_data,
|
||||
route=route,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
return validated
|
||||
elif response is not None and isinstance(response, str):
|
||||
api_key = response
|
||||
custom_auth_api_key = True
|
||||
elif user_custom_auth is not None:
|
||||
response = await user_custom_auth(request=request, api_key=api_key) # type: ignore
|
||||
return UserAPIKeyAuth.model_validate(response)
|
||||
validated = UserAPIKeyAuth.model_validate(response)
|
||||
validated = await _run_post_custom_auth_checks(
|
||||
valid_token=validated,
|
||||
request=request,
|
||||
request_data=request_data,
|
||||
route=route,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
return validated
|
||||
|
||||
### LITELLM-DEFINED AUTH FUNCTION ###
|
||||
#### IF JWT ####
|
||||
@@ -593,9 +615,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
user_id=user_id,
|
||||
team_id=team_id,
|
||||
team_alias=(
|
||||
team_object.team_alias
|
||||
if team_object is not None
|
||||
else None
|
||||
team_object.team_alias if team_object is not None else None
|
||||
),
|
||||
team_metadata=team_object.metadata
|
||||
if team_object is not None
|
||||
@@ -846,7 +866,9 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
)
|
||||
valid_token.parent_otel_span = parent_otel_span
|
||||
if _end_user_object is not None:
|
||||
valid_token.end_user_object_permission = _end_user_object.object_permission
|
||||
valid_token.end_user_object_permission = (
|
||||
_end_user_object.object_permission
|
||||
)
|
||||
|
||||
return valid_token
|
||||
|
||||
@@ -954,7 +976,11 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
if isinstance(
|
||||
api_key, str
|
||||
): # if generated token, make sure it starts with sk-.
|
||||
_masked_key = "{}****{}".format(api_key[:4], api_key[-4:]) if len(api_key) > 8 else "****"
|
||||
_masked_key = (
|
||||
"{}****{}".format(api_key[:4], api_key[-4:])
|
||||
if len(api_key) > 8
|
||||
else "****"
|
||||
)
|
||||
assert api_key.startswith(
|
||||
"sk-"
|
||||
), "LiteLLM Virtual Key expected. Received={}, expected to start with 'sk-'.".format(
|
||||
@@ -1201,6 +1227,21 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
model=current_model,
|
||||
)
|
||||
|
||||
# Check 5b. End-user model max budget
|
||||
end_user_mmb = valid_token.end_user_model_max_budget
|
||||
if (
|
||||
end_user_mmb is not None
|
||||
and isinstance(end_user_mmb, dict)
|
||||
and len(end_user_mmb) > 0
|
||||
and current_model is not None
|
||||
and valid_token.end_user_id is not None
|
||||
):
|
||||
await model_max_budget_limiter.is_end_user_within_model_budget(
|
||||
end_user_id=valid_token.end_user_id,
|
||||
end_user_model_max_budget=end_user_mmb,
|
||||
model=current_model,
|
||||
)
|
||||
|
||||
# Check 6: Additional Common Checks across jwt + key auth
|
||||
if valid_token.team_id is not None:
|
||||
try:
|
||||
@@ -1304,9 +1345,9 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
|
||||
if _end_user_object is not None:
|
||||
valid_token_dict.update(end_user_params)
|
||||
valid_token_dict["end_user_object_permission"] = (
|
||||
_end_user_object.object_permission
|
||||
)
|
||||
valid_token_dict[
|
||||
"end_user_object_permission"
|
||||
] = _end_user_object.object_permission
|
||||
|
||||
# check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions
|
||||
# sso/login, ui/login, /key functions and /user functions
|
||||
@@ -1492,3 +1533,218 @@ def _update_key_budget_with_temp_budget_increase(
|
||||
temp_budget_increase = _get_temp_budget_increase(valid_token) or 0.0
|
||||
valid_token.max_budget = valid_token.max_budget + temp_budget_increase
|
||||
return valid_token
|
||||
|
||||
|
||||
async def _lookup_end_user_and_apply_budget(
    valid_token: UserAPIKeyAuth,
    route: str,
    parent_otel_span: Optional[Span],
    prisma_client,
    user_api_key_cache,
    proxy_logging_obj,
):
    """
    Fetch the end-user record for ``valid_token.end_user_id`` and copy its limits onto the token.

    When a record is found, its ``allowed_model_region`` and (if present) its
    budget table are applied. When no record is found but
    ``litellm.max_end_user_budget_id`` is set, the default end-user budget is
    looked up and applied instead.

    Returns:
        Tuple ``(valid_token, end_user_object)``. ``end_user_object`` is None
        when the lookup returned nothing or raised a swallowed error.

    Raises:
        litellm.BudgetExceededError: re-raised unchanged. Every other exception
            is logged at debug level and suppressed.
    """
    end_user_id = valid_token.end_user_id
    end_user_object = None
    try:
        end_user_object = await get_end_user_object(
            end_user_id=end_user_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
            route=route,
        )

        # Stays None when there is nothing to copy onto the token.
        resolved_params = None

        if end_user_object is not None:
            resolved_params = {
                "end_user_id": end_user_id,
                "allowed_model_region": end_user_object.allowed_model_region,
            }
            budget_table = end_user_object.litellm_budget_table
            if budget_table is not None:
                _apply_budget_limits_to_end_user_params(
                    end_user_params=resolved_params,
                    budget_info=budget_table,
                    end_user_id=end_user_id,
                )
        elif litellm.max_end_user_budget_id is not None:
            from litellm.proxy.auth.auth_checks import get_default_end_user_budget

            default_budget = await get_default_end_user_budget(
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
            )
            if default_budget is not None:
                resolved_params = {"end_user_id": end_user_id}
                _apply_budget_limits_to_end_user_params(
                    end_user_params=resolved_params,
                    budget_info=default_budget,
                    end_user_id=end_user_id,
                )

        # NOTE(review): for a found end user the token is updated whether or
        # not a budget table exists (so allowed_model_region is always
        # propagated) - confirm against the original indentation.
        if resolved_params is not None:
            valid_token = update_valid_token_with_end_user_params(
                valid_token=valid_token, end_user_params=resolved_params
            )
    except litellm.BudgetExceededError:
        # Budget violations must reach the caller.
        raise
    except Exception as err:
        verbose_proxy_logger.debug(f"Unable to find user in db. Error - {str(err)}")
    return valid_token, end_user_object
|
||||
|
||||
|
||||
async def _run_post_custom_auth_checks(
    valid_token: UserAPIKeyAuth,
    request: Request,
    request_data: dict,
    route: str,
    parent_otel_span: Optional[Span],
) -> UserAPIKeyAuth:
    """
    Run the proxy's post-authentication checks on a token returned by custom auth.

    Steps, in order:
        1. Look up the end user and apply its budget limits to the token.
        2. Reject the token if it has expired.
        3. Enforce the key-level per-model budget.
        4. Enforce the end-user per-model budget.
        5. Look up the user object (best effort).
        6. Resolve team / project objects and run ``common_checks``.

    Args:
        valid_token: Token produced by the custom auth handler.
        request: Incoming request, forwarded to ``common_checks``.
        request_data: Parsed request body; its ``"model"`` entry selects the
            per-model budget to check.
        route: Route being called.
        parent_otel_span: Span forwarded to the lookup helpers.

    Returns:
        The token, possibly updated with end-user parameters.

    Raises:
        ProxyException: when the token's expiry time is in the past.
        Any exception raised by the budget limiter, ``get_project_object`` or
        ``common_checks`` propagates unchanged.
    """
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        model_max_budget_limiter,
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
    )

    # 1. Look up end_user object from DB if end_user_id is set
    end_user_object = None
    if valid_token.end_user_id is not None:
        valid_token, end_user_object = await _lookup_end_user_and_apply_budget(
            valid_token=valid_token,
            route=route,
            parent_otel_span=parent_otel_span,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )

    # 2. Check token expiry
    token_expiry = valid_token.expires
    if token_expiry is not None:
        current_time = datetime.now(timezone.utc)
        expiry_time = (
            token_expiry
            if isinstance(token_expiry, datetime)
            else datetime.fromisoformat(token_expiry)
        )
        # Naive timestamps are treated as UTC so the comparison below is
        # always between two aware datetimes.
        expiry_tz = expiry_time.tzinfo
        if expiry_tz is None or expiry_tz.utcoffset(expiry_time) is None:
            expiry_time = expiry_time.replace(tzinfo=timezone.utc)
        if expiry_time < current_time:
            abbreviated_key = (
                abbreviate_api_key(api_key=valid_token.token)
                if valid_token.token
                else ""
            )
            raise ProxyException(
                message=f"Authentication Error - Expired Key. Key Expiry time {expiry_time} and current time {current_time}",
                type=ProxyErrorTypes.expired_key,
                code=400,
                param=abbreviated_key,
            )

    current_model = request_data.get("model", None)

    # 3. Check key-level model_max_budget
    key_model_budgets = valid_token.model_max_budget
    if (
        isinstance(key_model_budgets, dict)
        and len(key_model_budgets) > 0
        and current_model is not None
        and valid_token.token is not None
    ):
        await model_max_budget_limiter.is_key_within_model_budget(
            user_api_key_dict=valid_token,
            model=current_model,
        )

    # 4. Check end-user model_max_budget
    end_user_model_budgets = valid_token.end_user_model_max_budget
    if (
        isinstance(end_user_model_budgets, dict)
        and len(end_user_model_budgets) > 0
        and current_model is not None
        and valid_token.end_user_id is not None
    ):
        await model_max_budget_limiter.is_end_user_within_model_budget(
            end_user_id=valid_token.end_user_id,
            end_user_model_max_budget=end_user_model_budgets,
            model=current_model,
        )

    # 5. Look up user object if user_id is set
    user_object = None
    if valid_token.user_id is not None:
        try:
            user_object = await get_user_object(
                user_id=valid_token.user_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                user_id_upsert=False,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )
        except Exception:
            # NOTE(review): every lookup failure is swallowed here without
            # logging - presumably because custom-auth users need not exist
            # in the DB; confirm this is intended.
            # A token that already carries PROXY_ADMIN gets a synthetic user
            # object so that admin route checks pass for custom auth.
            if valid_token.user_role == LitellmUserRoles.PROXY_ADMIN:
                user_object = LiteLLM_UserTable(
                    user_id=valid_token.user_id,
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    spend=0.0,
                )

    # 6. Run common checks
    _team_obj = None
    if valid_token.team_id is not None:
        try:
            _team_obj = await get_team_object(
                team_id=valid_token.team_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )
        except HTTPException:
            # Fall back to a team object built from the team fields already
            # present on the token.
            _team_obj = LiteLLM_TeamTableCachedObj(
                team_id=valid_token.team_id,
                max_budget=valid_token.team_max_budget,
                soft_budget=valid_token.team_soft_budget,
                spend=valid_token.team_spend,
                tpm_limit=valid_token.team_tpm_limit,
                rpm_limit=valid_token.team_rpm_limit,
                blocked=valid_token.team_blocked,
                models=valid_token.team_models,
                metadata=valid_token.team_metadata,
                object_permission_id=valid_token.team_object_permission_id,
            )

    _project_obj = None
    if valid_token.project_id is not None:
        _project_obj = await get_project_object(
            project_id=valid_token.project_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )

    _ = await common_checks(
        request=request,
        request_body=request_data,
        team_object=_team_obj,
        user_object=user_object,
        end_user_object=end_user_object,
        general_settings=general_settings,
        global_proxy_spend=None,
        route=route,
        llm_router=llm_router,
        proxy_logging_obj=proxy_logging_obj,
        valid_token=valid_token,
        skip_budget_checks=False,
        project_object=_project_obj,
    )

    return valid_token
|
||||
|
||||
@@ -15,6 +15,7 @@ from litellm.types.utils import (
|
||||
)
|
||||
|
||||
VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX = "virtual_key_spend"
|
||||
END_USER_SPEND_CACHE_KEY_PREFIX = "end_user_model_spend"
|
||||
|
||||
|
||||
class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
|
||||
@@ -83,6 +84,81 @@ class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
|
||||
|
||||
return True
|
||||
|
||||
async def is_end_user_within_model_budget(
    self,
    end_user_id: str,
    end_user_model_max_budget: dict,
    model: str,
) -> bool:
    """
    Verify that an end user has not overspent the budget configured for ``model``.

    Args:
        end_user_id: Identifier of the end user whose spend is checked.
        end_user_model_max_budget: Mapping of model name to the keyword
            arguments used to build a ``BudgetConfig``.
        model: Model requested in the current call.

    Returns:
        True when no budget applies to ``model`` or the recorded spend does
        not exceed it.

    Raises:
        BudgetExceededError: If the end_user has exceeded the model budget
    """
    internal_model_max_budget: GenericBudgetConfigType = {
        _model: BudgetConfig(**_budget_info)
        for _model, _budget_info in end_user_model_max_budget.items()
    }

    verbose_proxy_logger.debug(
        "end_user internal_model_max_budget %s",
        json.dumps(internal_model_max_budget, indent=4, default=str),
    )

    # Find the budget entry (if any) that applies to the requested model.
    budget_for_model = self._get_request_model_budget_config(
        model=model, internal_model_max_budget=internal_model_max_budget
    )
    if budget_for_model is None:
        verbose_proxy_logger.debug(
            f"Model {model} not found in end_user_model_max_budget"
        )
        return True

    # A missing, zero or negative max_budget means there is nothing to enforce.
    budget_limit = budget_for_model.max_budget
    if not (budget_limit and budget_limit > 0):
        return True

    spend_so_far = await self._get_end_user_spend_for_model(
        end_user_id=end_user_id,
        model=model,
        key_budget_config=budget_for_model,
    )
    if spend_so_far is not None and spend_so_far > budget_limit:
        raise litellm.BudgetExceededError(
            message=f"LiteLLM End User: {end_user_id}, exceeded budget for model={model}",
            current_cost=spend_so_far,
            max_budget=budget_limit,
        )

    return True
|
||||
|
||||
async def _get_end_user_spend_for_model(
    self,
    end_user_id: str,
    model: str,
    key_budget_config: BudgetConfig,
) -> Optional[float]:
    """
    Read the cached spend of ``end_user_id`` for ``model`` from ``self.dual_cache``.

    The cache key is built from the end-user spend prefix, the end-user id,
    the model name and the budget duration. The model name is tried as given
    first; if nothing is cached under it, the name returned by
    ``_get_model_without_custom_llm_provider`` is tried.

    Returns:
        The cached value, or None when neither key is present.
    """

    def _spend_cache_key(model_name: str) -> str:
        return f"{END_USER_SPEND_CACHE_KEY_PREFIX}:{end_user_id}:{model_name}:{key_budget_config.budget_duration}"

    # 1. Direct lookup under the model name exactly as requested.
    cached_spend = await self.dual_cache.async_get_cache(
        key=_spend_cache_key(model),
    )
    if cached_spend is not None:
        return cached_spend

    # 2. Fall back to the name without a `{custom_llm_provider}/` prefix.
    return await self.dual_cache.async_get_cache(
        key=_spend_cache_key(self._get_model_without_custom_llm_provider(model)),
    )
|
||||
|
||||
async def _get_virtual_key_spend_for_model(
|
||||
self,
|
||||
user_api_key_hash: Optional[str],
|
||||
@@ -163,46 +239,77 @@ class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
|
||||
user_api_key_model_max_budget: Optional[dict] = _metadata.get(
|
||||
"user_api_key_model_max_budget", None
|
||||
)
|
||||
user_api_key_end_user_model_max_budget: Optional[dict] = _metadata.get(
|
||||
"user_api_key_end_user_model_max_budget", None
|
||||
)
|
||||
if (
|
||||
user_api_key_model_max_budget is None
|
||||
or len(user_api_key_model_max_budget) == 0
|
||||
) and (
|
||||
user_api_key_end_user_model_max_budget is None
|
||||
or len(user_api_key_end_user_model_max_budget) == 0
|
||||
):
|
||||
verbose_proxy_logger.debug(
|
||||
"Not running _PROXY_VirtualKeyModelMaxBudgetLimiter.async_log_success_event because user_api_key_model_max_budget is None or empty. `user_api_key_model_max_budget`=%s",
|
||||
user_api_key_model_max_budget,
|
||||
"Not running _PROXY_VirtualKeyModelMaxBudgetLimiter.async_log_success_event because user_api_key_model_max_budget and user_api_key_end_user_model_max_budget are None or empty."
|
||||
)
|
||||
return
|
||||
|
||||
response_cost: float = standard_logging_payload.get("response_cost", 0)
|
||||
model = standard_logging_payload.get("model")
|
||||
virtual_key = standard_logging_payload.get("metadata", {}).get(
|
||||
"user_api_key_hash"
|
||||
)
|
||||
end_user_id = standard_logging_payload.get(
|
||||
"end_user"
|
||||
) or standard_logging_payload.get("metadata", {}).get(
|
||||
"user_api_key_end_user_id"
|
||||
)
|
||||
|
||||
if virtual_key is None or model is None:
|
||||
if model is None:
|
||||
return
|
||||
|
||||
# Resolve per-model budget config (same logic as is_key_within_model_budget)
|
||||
internal_model_max_budget: GenericBudgetConfigType = {}
|
||||
for _model, _budget_info in user_api_key_model_max_budget.items():
|
||||
internal_model_max_budget[_model] = BudgetConfig(**_budget_info)
|
||||
key_budget_config = self._get_request_model_budget_config(
|
||||
model=model, internal_model_max_budget=internal_model_max_budget
|
||||
)
|
||||
if key_budget_config is None or not key_budget_config.budget_duration:
|
||||
verbose_proxy_logger.debug(
|
||||
"Not incrementing model spend: no budget config or budget_duration for model=%s",
|
||||
model,
|
||||
if (
|
||||
virtual_key is not None
|
||||
and user_api_key_model_max_budget is not None
|
||||
and len(user_api_key_model_max_budget) > 0
|
||||
):
|
||||
internal_model_max_budget: GenericBudgetConfigType = {}
|
||||
for _model, _budget_info in user_api_key_model_max_budget.items():
|
||||
internal_model_max_budget[_model] = BudgetConfig(**_budget_info)
|
||||
key_budget_config = self._get_request_model_budget_config(
|
||||
model=model, internal_model_max_budget=internal_model_max_budget
|
||||
)
|
||||
return
|
||||
if key_budget_config is not None and key_budget_config.budget_duration:
|
||||
virtual_spend_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{model}:{key_budget_config.budget_duration}"
|
||||
virtual_start_time_key = f"virtual_key_budget_start_time:{virtual_key}"
|
||||
await self._increment_spend_for_key(
|
||||
budget_config=key_budget_config,
|
||||
spend_key=virtual_spend_key,
|
||||
start_time_key=virtual_start_time_key,
|
||||
response_cost=response_cost,
|
||||
)
|
||||
|
||||
if (
|
||||
end_user_id is not None
|
||||
and user_api_key_end_user_model_max_budget is not None
|
||||
and len(user_api_key_end_user_model_max_budget) > 0
|
||||
):
|
||||
internal_model_max_budget: GenericBudgetConfigType = {}
|
||||
for _model, _budget_info in user_api_key_end_user_model_max_budget.items():
|
||||
internal_model_max_budget[_model] = BudgetConfig(**_budget_info)
|
||||
key_budget_config = self._get_request_model_budget_config(
|
||||
model=model, internal_model_max_budget=internal_model_max_budget
|
||||
)
|
||||
if key_budget_config is not None and key_budget_config.budget_duration:
|
||||
end_user_spend_key = f"{END_USER_SPEND_CACHE_KEY_PREFIX}:{end_user_id}:{model}:{key_budget_config.budget_duration}"
|
||||
end_user_start_time_key = f"end_user_budget_start_time:{end_user_id}"
|
||||
await self._increment_spend_for_key(
|
||||
budget_config=key_budget_config,
|
||||
spend_key=end_user_spend_key,
|
||||
start_time_key=end_user_start_time_key,
|
||||
response_cost=response_cost,
|
||||
)
|
||||
|
||||
virtual_spend_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{model}:{key_budget_config.budget_duration}"
|
||||
virtual_start_time_key = f"virtual_key_budget_start_time:{virtual_key}"
|
||||
await self._increment_spend_for_key(
|
||||
budget_config=key_budget_config,
|
||||
spend_key=virtual_spend_key,
|
||||
start_time_key=virtual_start_time_key,
|
||||
response_cost=response_cost,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"current state of in memory cache %s",
|
||||
json.dumps(
|
||||
|
||||
@@ -10,10 +10,15 @@ import litellm
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
from litellm._service_logger import ServiceLogging
|
||||
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
|
||||
from litellm.proxy._types import (AddTeamCallback, CommonProxyErrors,
|
||||
LitellmDataForBackendLLMCall,
|
||||
LitellmUserRoles, SpecialHeaders,
|
||||
TeamCallbackMetadata, UserAPIKeyAuth)
|
||||
from litellm.proxy._types import (
|
||||
AddTeamCallback,
|
||||
CommonProxyErrors,
|
||||
LitellmDataForBackendLLMCall,
|
||||
LitellmUserRoles,
|
||||
SpecialHeaders,
|
||||
TeamCallbackMetadata,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _safe_get_request_headers
|
||||
|
||||
# Cache special headers as a frozenset for O(1) lookup performance
|
||||
@@ -23,9 +28,12 @@ _SPECIAL_HEADERS_CACHE = frozenset(
|
||||
from litellm.router import Router
|
||||
from litellm.types.llms.anthropic import ANTHROPIC_API_HEADERS
|
||||
from litellm.types.services import ServiceTypes
|
||||
from litellm.types.utils import (LlmProviders, ProviderSpecificHeader,
|
||||
StandardLoggingUserAPIKeyMetadata,
|
||||
SupportedCacheControls)
|
||||
from litellm.types.utils import (
|
||||
LlmProviders,
|
||||
ProviderSpecificHeader,
|
||||
StandardLoggingUserAPIKeyMetadata,
|
||||
SupportedCacheControls,
|
||||
)
|
||||
|
||||
service_logger_obj = ServiceLogging() # used for tracking latency on OTEL
|
||||
|
||||
@@ -654,7 +662,8 @@ class LiteLLMProxyRequestSetup:
|
||||
return data
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_ManagementEndpoint_MetadataFields,
|
||||
LiteLLM_ManagementEndpoint_MetadataFields_Premium)
|
||||
LiteLLM_ManagementEndpoint_MetadataFields_Premium,
|
||||
)
|
||||
|
||||
# ignore any special fields
|
||||
added_metadata = {}
|
||||
@@ -1025,6 +1034,9 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
||||
data[_metadata_variable_name][
|
||||
"user_api_key_model_max_budget"
|
||||
] = user_api_key_dict.model_max_budget
|
||||
data[_metadata_variable_name][
|
||||
"user_api_key_end_user_model_max_budget"
|
||||
] = user_api_key_dict.end_user_model_max_budget
|
||||
|
||||
# User spend, budget - used by prometheus.py
|
||||
# Follow same pattern as team and API key budgets
|
||||
@@ -1479,8 +1491,7 @@ async def move_guardrails_to_metadata(
|
||||
|
||||
# Only check policy engine if no local config (avoid import + registry lookup)
|
||||
if not (has_key_config or has_team_config or has_request_config):
|
||||
from litellm.proxy.policy_engine.policy_registry import \
|
||||
get_policy_registry
|
||||
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
|
||||
|
||||
if not get_policy_registry().is_initialized():
|
||||
# Nothing configured anywhere - clean up request body fields and return
|
||||
@@ -1544,16 +1555,14 @@ async def move_guardrails_to_metadata(
|
||||
|
||||
def _is_policy_version_id(s: str) -> bool:
|
||||
"""Return True if string is a policy version ID (starts with policy_<uuid> prefix)."""
|
||||
from litellm.proxy.policy_engine.policy_registry import \
|
||||
POLICY_VERSION_ID_PREFIX
|
||||
from litellm.proxy.policy_engine.policy_registry import POLICY_VERSION_ID_PREFIX
|
||||
|
||||
return isinstance(s, str) and s.startswith(POLICY_VERSION_ID_PREFIX)
|
||||
|
||||
|
||||
def _extract_policy_id(s: str) -> Optional[str]:
|
||||
"""Extract raw UUID from policy_<uuid> string, or None if not a valid version ID."""
|
||||
from litellm.proxy.policy_engine.policy_registry import \
|
||||
POLICY_VERSION_ID_PREFIX
|
||||
from litellm.proxy.policy_engine.policy_registry import POLICY_VERSION_ID_PREFIX
|
||||
|
||||
if not _is_policy_version_id(s):
|
||||
return None
|
||||
@@ -1574,9 +1583,10 @@ def _match_and_track_policies(
|
||||
"""
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.common_utils.callback_utils import (
|
||||
add_policy_sources_to_metadata, add_policy_to_applied_policies_header)
|
||||
from litellm.proxy.policy_engine.attachment_registry import \
|
||||
get_attachment_registry
|
||||
add_policy_sources_to_metadata,
|
||||
add_policy_to_applied_policies_header,
|
||||
)
|
||||
from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry
|
||||
from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher
|
||||
|
||||
# Get matching policies via attachments (with match reasons for attribution)
|
||||
@@ -1721,8 +1731,7 @@ async def add_guardrails_from_policy_engine(
|
||||
user_api_key_dict: The user's API key authentication info
|
||||
"""
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.common_utils.http_parsing_utils import \
|
||||
get_tags_from_request_body
|
||||
from litellm.proxy.common_utils.http_parsing_utils import get_tags_from_request_body
|
||||
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
|
||||
from litellm.types.proxy.policy_engine import PolicyMatchContext
|
||||
|
||||
|
||||
@@ -158,3 +158,109 @@ async def test_async_log_success_event_uses_per_model_budget_duration(budget_lim
|
||||
f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{model}:{budget_duration}"
|
||||
)
|
||||
assert call_kwargs["response_cost"] == 0.05
|
||||
|
||||
|
||||
# Test is_end_user_within_model_budget
|
||||
@pytest.mark.asyncio
async def test_is_end_user_within_model_budget(budget_limiter):
    """Cover the under-budget, over-budget and unconfigured-model paths."""

    def _gpt4_budget() -> dict:
        # A fresh mapping per call, mirroring the literal used at each call site.
        return {"gpt-4": {"budget_limit": 100.0, "time_period": "1d"}}

    # Spend below the limit -> allowed.
    with patch.object(
        budget_limiter, "_get_end_user_spend_for_model", return_value=50.0
    ):
        within_budget = await budget_limiter.is_end_user_within_model_budget(
            "test-user",
            _gpt4_budget(),
            "gpt-4",
        )
        assert within_budget is True

    # Spend above the limit -> BudgetExceededError.
    with patch.object(
        budget_limiter, "_get_end_user_spend_for_model", return_value=150.0
    ):
        with pytest.raises(litellm.BudgetExceededError):
            await budget_limiter.is_end_user_within_model_budget(
                "test-user",
                _gpt4_budget(),
                "gpt-4",
            )

    # Model without a budget entry -> allowed.
    unconfigured_result = await budget_limiter.is_end_user_within_model_budget(
        "test-user",
        _gpt4_budget(),
        "non-existent",
    )
    assert unconfigured_result is True
|
||||
|
||||
|
||||
# Test _get_end_user_spend_for_model
|
||||
@pytest.mark.asyncio
async def test_get_end_user_spend_for_model(budget_limiter):
    """``_get_end_user_spend_for_model`` returns the cached spend value.

    The dual cache's ``async_get_cache`` is patched to return 50.0, and the
    lookup is made once with a bare model name and once with a
    provider-prefixed one; both are expected to yield 50.0.
    """
    budget_config = GenericBudgetInfo(budget_limit=100.0, time_period="1d")

    cache_patch = patch.object(
        budget_limiter.dual_cache, "async_get_cache", return_value=50.0
    )
    # NOTE(review): indentation was lost in this view; both lookups are kept
    # under the cache patch because the second assertion depends on it.
    with cache_patch:
        # First the plain model name, then the provider-prefixed variant.
        for model_name in ("gpt-4", "openai/gpt-4"):
            spend = await budget_limiter._get_end_user_spend_for_model(
                end_user_id="test-user",
                model=model_name,
                key_budget_config=budget_config,
            )
            assert spend == 50.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_async_log_success_event_uses_end_user_model_budget_duration(
    budget_limiter,
):
    """
    async_log_success_event must use the per-model budget_duration for the end user cache key

    ``_increment_spend_for_key`` is replaced by an ``AsyncMock``; the test then
    checks that it is awaited exactly once with the expected ``spend_key`` and
    a ``response_cost`` of 0.05.
    """
    from litellm.proxy.hooks.model_max_budget_limiter import (
        END_USER_SPEND_CACHE_KEY_PREFIX,
    )

    end_user_id = "test-user"
    model = "gpt-4"
    budget_duration = "1d"

    # Per-model budget as carried in the request metadata.
    end_user_model_budget = {
        model: {"budget_limit": 100.0, "time_period": budget_duration},
    }
    standard_logging_object = {
        "response_cost": 0.05,
        "model": model,
        "end_user": end_user_id,
        "metadata": {"user_api_key_end_user_id": end_user_id},
    }
    litellm_params = {
        "metadata": {
            "user_api_key_end_user_model_max_budget": end_user_model_budget
        },
    }
    logging_kwargs = {
        "standard_logging_object": standard_logging_object,
        "litellm_params": litellm_params,
    }

    increment_patch = patch.object(
        budget_limiter,
        "_increment_spend_for_key",
        new_callable=AsyncMock,
    )
    with increment_patch as mock_increment:
        await budget_limiter.async_log_success_event(
            logging_kwargs, response_obj=None, start_time=None, end_time=None
        )

        mock_increment.assert_awaited_once()
        recorded = mock_increment.call_args.kwargs

        expected_key = (
            f"{END_USER_SPEND_CACHE_KEY_PREFIX}:{end_user_id}:{model}:{budget_duration}"
        )
        assert recorded["spend_key"] == expected_key
        assert recorded["response_cost"] == 0.05
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
import litellm
|
||||
from litellm.proxy.auth.user_api_key_auth import _run_post_custom_auth_checks
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_custom_auth_run_post_custom_auth_checks_without_end_user_id():
    """Backwards compatibility: a token without ``end_user_id`` still passes.

    ``common_checks`` is patched with an ``AsyncMock`` returning ``True``. The
    result must keep the original token, expose no ``end_user_id``, and
    ``common_checks`` must have been awaited exactly once.
    """
    auth_without_end_user = UserAPIKeyAuth(token="test_token")

    common_checks_patch = patch(
        "litellm.proxy.auth.user_api_key_auth.common_checks", new_callable=AsyncMock
    )
    with common_checks_patch as mock_common:
        mock_common.return_value = True

        call_args = dict(
            valid_token=auth_without_end_user,
            request=None,
            request_data={},
            route="/v1/chat/completions",
            parent_otel_span=None,
        )
        result = await _run_post_custom_auth_checks(**call_args)

        assert result.token == "test_token"
        assert getattr(result, "end_user_id", None) is None
        mock_common.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_custom_auth_run_post_custom_auth_checks_with_end_user_budget_exceeded():
    """A ``BudgetExceededError`` from the end-user budget check must propagate.

    ``common_checks`` is patched out, and ``is_end_user_within_model_budget``
    is replaced by an ``AsyncMock`` that raises ``litellm.BudgetExceededError``.
    The error is expected to surface from ``_run_post_custom_auth_checks``,
    and the budget check must have been awaited exactly once.
    """
    auth_with_end_user = UserAPIKeyAuth(
        token="test_token",
        end_user_id="test_user",
        end_user_model_max_budget={
            "gpt-4": {"budget_limit": 10.0, "time_period": "1d"}
        },
    )
    request_data = {"model": "gpt-4"}

    common_checks_patch = patch(
        "litellm.proxy.auth.user_api_key_auth.common_checks", new_callable=AsyncMock
    )
    budget_check_patch = patch(
        "litellm.proxy.proxy_server.model_max_budget_limiter.is_end_user_within_model_budget",
        new_callable=AsyncMock,
    )

    # Enter the patches in the same order as nested ``with`` statements would.
    with common_checks_patch, budget_check_patch as mock_budget_check:
        mock_budget_check.side_effect = litellm.BudgetExceededError(
            message="Exceeded budget", current_cost=20.0, max_budget=10.0
        )

        with pytest.raises(litellm.BudgetExceededError):
            await _run_post_custom_auth_checks(
                valid_token=auth_with_end_user,
                request=None,
                request_data=request_data,
                route="/v1/chat/completions",
                parent_otel_span=None,
            )

        # Checked after the raises block: a statement following the raising
        # call inside it would never execute.
        mock_budget_check.assert_awaited_once()
|
||||
Reference in New Issue
Block a user