mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 03:31:23 +00:00
3003 lines
125 KiB
Python
3003 lines
125 KiB
Python
"""
This file handles authentication for the LiteLLM Proxy.

It checks if the user passed a valid API Key to the LiteLLM Proxy.

Returns a UserAPIKeyAuth object if the API key is valid.
"""
|
|
|
|
import asyncio
|
|
import fnmatch
|
|
import re
|
|
import secrets
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Dict, Iterator, NamedTuple, List, Optional, Tuple, Union, cast
|
|
|
|
import fastapi
|
|
from fastapi import HTTPException, Request, WebSocket, status
|
|
from fastapi.security.api_key import APIKeyHeader
|
|
|
|
import litellm
|
|
from litellm._logging import verbose_logger, verbose_proxy_logger
|
|
from litellm._service_logger import ServiceLogging
|
|
from litellm.constants import LITELLM_PROXY_MASTER_KEY_ALIAS
|
|
from litellm.integrations.otel.model.config import is_otel_v2_enabled
|
|
from litellm.integrations.otel.runtime import phase_span, seed_request_identity
|
|
from litellm.litellm_core_utils.dd_tracing import tracer
|
|
from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value
|
|
from litellm.proxy._types import *
|
|
from litellm.proxy.auth.auth_checks import (
|
|
ExperimentalUIJWTToken,
|
|
_cache_key_object,
|
|
_check_end_user_budget,
|
|
_delete_cache_key_object,
|
|
_get_user_role,
|
|
_is_model_cost_zero,
|
|
_is_user_proxy_admin,
|
|
_virtual_key_max_budget_alert_check,
|
|
_virtual_key_max_budget_check,
|
|
_virtual_key_soft_budget_check,
|
|
can_key_call_model,
|
|
common_checks,
|
|
get_end_user_object,
|
|
get_jwt_key_mapping_object,
|
|
get_key_object,
|
|
get_project_object,
|
|
get_team_object,
|
|
get_user_object,
|
|
is_valid_fallback_model,
|
|
resolve_and_validate_end_user_id,
|
|
)
|
|
from litellm.proxy.auth.auth_exception_handler import UserAPIKeyAuthExceptionHandler
|
|
from litellm.proxy.auth.auth_utils import (
|
|
abbreviate_api_key,
|
|
get_end_user_id_from_request_body,
|
|
get_model_from_request,
|
|
get_request_route,
|
|
get_request_route_template,
|
|
normalize_request_route,
|
|
pre_db_read_auth_checks,
|
|
route_in_additonal_public_routes,
|
|
)
|
|
from litellm.proxy.auth.handle_jwt import JWTAuthManager, JWTHandler
|
|
from litellm.proxy.auth.oauth2_check import Oauth2Handler
|
|
from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request
|
|
from litellm.proxy.auth.route_checks import RouteChecks
|
|
from litellm.proxy.common_utils.cache_coordinator import EventDrivenCacheCoordinator
|
|
from litellm.proxy.common_utils.http_parsing_utils import (
|
|
_read_request_body,
|
|
_safe_get_request_headers,
|
|
_safe_get_request_query_params,
|
|
populate_request_with_path_params,
|
|
)
|
|
from litellm.proxy.common_utils.realtime_utils import _realtime_request_body
|
|
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
|
|
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
|
|
from litellm.proxy.utils import (
|
|
PrismaClient,
|
|
ProxyLogging,
|
|
normalize_route_for_root_path,
|
|
)
|
|
from litellm.repositories.table_repositories import TeamMembershipRepository
|
|
from litellm.secret_managers.main import get_secret_bool
|
|
from litellm.types.services import ServiceTypes
|
|
|
|
# Optional enterprise hook. When the ``litellm_enterprise`` package cannot be
# imported, ``enterprise_custom_auth`` is set to ``None`` and the failure is
# only logged at debug level.
try:
    from litellm_enterprise.proxy.auth.user_api_key_auth import (
        enterprise_custom_auth as _enterprise_custom_auth,
    )

    enterprise_custom_auth: Optional[Callable] = _enterprise_custom_auth
except ImportError as e:
    verbose_proxy_logger.debug(f"Error in enterprise custom auth: {e}")
    enterprise_custom_auth = None

# Module-wide service logger instance.
user_api_key_service_logger_obj = ServiceLogging()  # used for tracking latency on OTEL
|
|
|
|
|
|
def _normalize_public_auth_route(route: str) -> str:
|
|
if route != "/" and route.endswith("/"):
|
|
return route.rstrip("/")
|
|
return route
|
|
|
|
|
|
def _route_requires_auth_despite_public(
    route: str, general_settings: Optional[dict]
) -> bool:
    """Report whether ``route`` must still be authenticated.

    Only ``/metrics`` (compared after trailing-slash normalization) can
    return ``True``: it requires auth unless
    ``litellm.require_auth_for_metrics_endpoint`` is exactly ``False``.
    Every other route returns ``False``.

    Args:
        route: Request route to evaluate.
        general_settings: Accepted for interface compatibility; not read here.
    """
    if _normalize_public_auth_route(route) != "/metrics":
        return False
    return litellm.require_auth_for_metrics_endpoint is not False
|
|
|
|
|
|
# FastAPI ``APIKeyHeader`` extractors, one per header name in
# ``SpecialHeaders``. All use ``auto_error=False``; per FastAPI's documentation
# a missing header then resolves to ``None`` instead of an error response.
custom_litellm_key_header = APIKeyHeader(
    name=SpecialHeaders.custom_litellm_api_key.value,
    auto_error=False,
    description="Bearer token",
)
api_key_header = APIKeyHeader(
    name=SpecialHeaders.openai_authorization.value,
    auto_error=False,
    description="Bearer token",
)
azure_api_key_header = APIKeyHeader(
    name=SpecialHeaders.azure_authorization.value,
    auto_error=False,
    description="Some older versions of the openai Python package will send an API-Key header with just the API key ",
)
anthropic_api_key_header = APIKeyHeader(
    name=SpecialHeaders.anthropic_authorization.value,
    auto_error=False,
    description="If anthropic client used.",
)
google_ai_studio_api_key_header = APIKeyHeader(
    name=SpecialHeaders.google_ai_studio_authorization.value,
    auto_error=False,
    description="If google ai studio client used.",
)
azure_apim_header = APIKeyHeader(
    name=SpecialHeaders.azure_apim_authorization.value,
    auto_error=False,
    description="The default name of the subscription key header of Azure",
)
|
|
|
|
|
|
def _get_model_from_request_context(
    request_data: dict,
    route: str,
    request: Optional[Request],
    llm_router: Optional[Any] = None,
) -> Optional[Union[str, List[str]]]:
    """Resolve the requested model(s) by delegating to ``get_model_from_request``.

    Headers and query params are read from ``request`` through the
    ``_safe_get_request_*`` helpers and forwarded together with the body,
    route and router.
    """
    headers = _safe_get_request_headers(request=request)
    query_params = _safe_get_request_query_params(request=request)
    return get_model_from_request(
        request_data=request_data,
        route=route,
        request_headers=headers,
        request_query_params=query_params,
        llm_router=llm_router,
    )
|
|
|
|
|
|
def _get_model_names_for_budget_checks(
|
|
model: Optional[Union[str, List[str]]],
|
|
) -> List[str]:
|
|
if model is None:
|
|
return []
|
|
if isinstance(model, str):
|
|
return [model]
|
|
return model
|
|
|
|
|
|
def _get_bearer_token_or_received_api_key(api_key: str) -> str:
|
|
if api_key.startswith("Bearer "): # ensure Bearer token passed in
|
|
api_key = api_key.replace("Bearer ", "") # extract the token
|
|
elif api_key.startswith("Basic "):
|
|
api_key = api_key.replace("Basic ", "") # handle langfuse input
|
|
elif api_key.startswith("bearer "):
|
|
api_key = api_key.replace("bearer ", "")
|
|
elif api_key.startswith("AWS4-HMAC-SHA256"):
|
|
# Handle AWS Signature V4 format from LangChain
|
|
# Format: AWS4-HMAC-SHA256 Credential=Bearer sk-12345/date/region/service/aws4_request, SignedHeaders=..., Signature=...
|
|
# Extract the Bearer token from the Credential field
|
|
match = re.search(r"Credential=Bearer\s+([^/\s,]+)", api_key)
|
|
if match:
|
|
api_key = match.group(1)
|
|
else:
|
|
# If no Bearer token found in Credential, try to extract just the credential value
|
|
match = re.search(r"Credential=([^/\s,]+)", api_key)
|
|
if match:
|
|
api_key = match.group(1)
|
|
|
|
return api_key
|
|
|
|
|
|
def _routing_selector_matches_claim(
|
|
selector_value: Optional[Any],
|
|
claim_value: Optional[Any],
|
|
*,
|
|
split_space_delimited: bool = False,
|
|
) -> bool:
|
|
if selector_value is None:
|
|
return True
|
|
|
|
selector_list: List[str] = (
|
|
[str(v) for v in selector_value]
|
|
if isinstance(selector_value, list)
|
|
else [str(selector_value)]
|
|
)
|
|
|
|
if claim_value is None:
|
|
return False
|
|
|
|
if isinstance(claim_value, list):
|
|
claim_list = [str(v) for v in claim_value]
|
|
elif (
|
|
split_space_delimited
|
|
and isinstance(claim_value, str)
|
|
and " " in claim_value.strip()
|
|
):
|
|
# OAuth/OIDC often sends scope as a single space-delimited string. Only split
|
|
# for the scope selector: iss/aud/client_id must stay exact full-string match
|
|
# on unverified claims (see routing override security review). The elif guard
|
|
# (`" " in claim_value.strip()`) ensures at least two non-empty tokens survive.
|
|
claim_list = [v for v in claim_value.strip().split(" ") if v]
|
|
else:
|
|
claim_list = [str(claim_value)]
|
|
|
|
def _selector_matches_claim(selector: str, claim: str) -> bool:
|
|
# NOTE: wildcard matching is case-sensitive (fnmatch.fnmatchcase).
|
|
if "*" in selector or "?" in selector:
|
|
# Without scope splitting, do not let `*` span whitespace: a malformed
|
|
# iss like "trusted.example.com evil.com" must not match "trusted.*".
|
|
# Scope uses split_space_delimited so each claim token is checked separately.
|
|
if not split_space_delimited and any(ch.isspace() for ch in claim):
|
|
return False
|
|
return fnmatch.fnmatchcase(claim, selector)
|
|
return selector == claim
|
|
|
|
return any(
|
|
_selector_matches_claim(selector=s, claim=c)
|
|
for s in selector_list
|
|
for c in claim_list
|
|
)
|
|
|
|
|
|
def _matches_routing_override(
    token_claims: dict, override: "JWTRoutingOverride"
) -> bool:
    """Return ``True`` when every selector on ``override`` accepts the token claims.

    ``iss``, ``client_id`` and ``aud`` are compared against the whole claim;
    ``scope`` is compared token-by-token (space-delimited). Evaluation stops
    at the first selector that fails.
    """
    if not _routing_selector_matches_claim(override.iss, token_claims.get("iss")):
        return False
    if not _routing_selector_matches_claim(
        override.client_id, token_claims.get("client_id")
    ):
        return False
    if not _routing_selector_matches_claim(
        override.scope,
        token_claims.get("scope"),
        split_space_delimited=True,
    ):
        return False
    return _routing_selector_matches_claim(override.aud, token_claims.get("aud"))
|
|
|
|
|
|
def _should_route_jwt_to_oauth2_override(token: str, jwt_handler: JWTHandler) -> bool:
    """Decide whether ``token`` should be sent to OAuth2 introspection.

    Returns ``True`` when at least one configured routing override has
    ``path == "oauth2"`` and all of its selectors match the token's
    *unverified* claims. Returns ``False`` when no overrides are configured,
    the claims cannot be read, or nothing matches.
    """
    overrides = jwt_handler.litellm_jwtauth.routing_overrides
    if not overrides:
        return False

    claims = jwt_handler.get_unverified_claims(token=token)
    if claims is None:
        return False

    # ``any`` stops at the first matching override, like an early return.
    routed_to_oauth2 = any(
        candidate.path == "oauth2"
        and _matches_routing_override(token_claims=claims, override=candidate)
        for candidate in overrides
    )
    if routed_to_oauth2:
        verbose_proxy_logger.debug(
            "JWT routing override matched. Routing token to OAuth2 introspection."
        )
    return routed_to_oauth2
|
|
|
|
|
|
def _get_bearer_token(
|
|
api_key: str,
|
|
):
|
|
if api_key.startswith("Bearer "): # ensure Bearer token passed in
|
|
api_key = api_key.replace("Bearer ", "") # extract the token
|
|
elif api_key.startswith("Basic "):
|
|
api_key = api_key.replace("Basic ", "") # handle langfuse input
|
|
elif api_key.startswith("bearer "):
|
|
api_key = api_key.replace("bearer ", "")
|
|
elif api_key.startswith("AWS4-HMAC-SHA256"):
|
|
# Handle AWS Signature V4 format from LangChain
|
|
# Format: AWS4-HMAC-SHA256 Credential=Bearer sk-12345/date/region/service/aws4_request, SignedHeaders=..., Signature=...
|
|
# Extract the Bearer token from the Credential field
|
|
match = re.search(r"Credential=Bearer\s+([^/\s,]+)", api_key)
|
|
if match:
|
|
api_key = match.group(1)
|
|
else:
|
|
# If no Bearer token found in Credential, try to extract just the credential value
|
|
match = re.search(r"Credential=([^/\s,]+)", api_key)
|
|
if match:
|
|
api_key = match.group(1)
|
|
else:
|
|
api_key = ""
|
|
else:
|
|
api_key = ""
|
|
return api_key
|
|
|
|
|
|
def _apply_budget_limits_to_end_user_params(
    end_user_params: dict,
    budget_info: LiteLLM_BudgetTable,
    end_user_id: Optional[str],
) -> None:
    """
    Copy the limits that are set on a budget row into the end-user params dict.

    Args:
        end_user_params: Dictionary mutated in place; only receives keys for
            limits that are not ``None`` on ``budget_info``
        budget_info: Budget table object supplying tpm/rpm/max-budget limits
        end_user_id: ID of the end user, used only in the debug log line
    """
    # (attribute on the budget row, key written to end_user_params)
    limit_fields = (
        ("tpm_limit", "end_user_tpm_limit"),
        ("rpm_limit", "end_user_rpm_limit"),
        ("max_budget", "end_user_max_budget"),
        ("model_max_budget", "end_user_model_max_budget"),
    )
    for budget_attr, param_key in limit_fields:
        limit = getattr(budget_info, budget_attr)
        if limit is not None:
            end_user_params[param_key] = limit

    verbose_proxy_logger.debug(f"Applied budget limits to end user {end_user_id}")
|
|
|
|
|
|
async def user_api_key_auth_websocket(websocket: WebSocket):
    """Authenticate a WebSocket connection by delegating to ``user_api_key_auth``.

    A synthetic HTTP ``Request`` is built from the WebSocket's ASGI scope, and
    the API key is taken from, in order: the ``Authorization: Bearer`` header,
    the ``api-key`` header, or an ``openai-insecure-api-key.<key>`` entry in
    ``Sec-WebSocket-Protocol`` (browser clients).

    Raises:
        HTTPException: 403 when no key is supplied, the Authorization header
            is not ``Bearer``-prefixed, or ``user_api_key_auth`` raises. The
            socket is closed with policy-violation code 1008 first.
    """
    asgi_scope = websocket.scope or {}
    # ``get_request_route`` falls back to ``request.url.path`` when
    # ``scope["path"]`` is absent. On WebSockets that fallback reads
    # ``websocket.url``, which Starlette reconstructs from the (poisonable)
    # Host header. Carry the ASGI scope's path / root_path so the lookup
    # never reaches the fallback.
    synthetic_scope: Dict[str, Any] = {
        "type": "http",
        "headers": list(asgi_scope.get("headers") or []),
        "path": asgi_scope.get("path", ""),
    }
    synthetic_scope.update(
        {
            scope_key: asgi_scope[scope_key]
            for scope_key in ("root_path", "app_root_path")
            if scope_key in asgi_scope
        }
    )
    request = Request(scope=synthetic_scope)
    request._url = websocket.url

    model = websocket.query_params.get("model")

    async def _synthetic_body():
        # Stand-in for ``request.body()``: body derived from the ``model``
        # query parameter.
        return _realtime_request_body(model)

    request.body = _synthetic_body  # type: ignore

    authorization = websocket.headers.get("authorization")
    if authorization:
        # Only the exact "Bearer " scheme is accepted on this path.
        if not authorization.startswith("Bearer "):
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
            raise HTTPException(
                status_code=403, detail="Invalid Authorization header format"
            )
        api_key = authorization[len("Bearer ") :].strip()
    else:
        # No Authorization header: try the api-key header next.
        api_key = websocket.headers.get("api-key")
        if not api_key:
            # Last resort: WebSocket subprotocol (browser clients).
            subprotocol_prefix = "openai-insecure-api-key."
            offered = websocket.headers.get("sec-websocket-protocol", "").split(",")
            for raw_protocol in offered:
                candidate = raw_protocol.strip()
                if candidate.startswith(subprotocol_prefix):
                    api_key = candidate[len(subprotocol_prefix) :]
                    break
        if not api_key:
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
            raise HTTPException(status_code=403, detail="No API key provided")

    # Hand the extracted key to the regular HTTP auth path; any failure there
    # closes the socket and surfaces as a 403.
    try:
        return await user_api_key_auth(request=request, api_key=f"Bearer {api_key}")
    except Exception as e:
        verbose_proxy_logger.exception(e)
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        raise HTTPException(status_code=403, detail=str(e))
|
|
|
|
|
|
def update_valid_token_with_end_user_params(
    valid_token: UserAPIKeyAuth, end_user_params: dict
) -> UserAPIKeyAuth:
    """Copy end-user settings from ``end_user_params`` onto ``valid_token``.

    ``end_user_id`` is always overwritten (with ``None`` when absent). The
    remaining fields are only written when ``end_user_params`` carries a
    non-``None`` value, so a DB lookup with no budget value cannot silently
    clear something a custom auth function already put on the token.

    Returns:
        The same ``valid_token`` instance, mutated in place.
    """
    valid_token.end_user_id = end_user_params.get("end_user_id")

    # For these fields the params key and the token attribute share a name.
    overwrite_if_present = (
        "end_user_tpm_limit",
        "end_user_rpm_limit",
        "allowed_model_region",
        "end_user_model_max_budget",
    )
    for field_name in overwrite_if_present:
        db_value = end_user_params.get(field_name)
        if db_value is not None:
            setattr(valid_token, field_name, db_value)
    return valid_token
|
|
|
|
|
|
# Reusable coordinator for global spend to prevent cache stampede.
# Shared at module level and used by
# ``_fetch_global_spend_with_event_coordination``.
_global_spend_coordinator = EventDrivenCacheCoordinator(log_prefix="[GLOBAL SPEND]")
|
|
|
|
|
|
async def _fetch_global_spend_with_event_coordination(
    cache_key: str,
    user_api_key_cache: UserApiKeyCache,
    prisma_client: PrismaClient,
) -> Optional[float]:
    """
    Return the total global spend, loading it through the shared coordinator.

    The DB loader sums ``spend`` over ``MonthlyGlobalSpend``; it is handed to
    ``_global_spend_coordinator.get_or_load`` so concurrent callers do not
    stampede the database.

    Returns ``None`` when the summed value is NULL.
    """

    async def _load_global_spend() -> Optional[float]:
        rows = await prisma_client.db.query_raw(
            query="""SELECT SUM(spend) AS total_spend FROM "MonthlyGlobalSpend";"""
        )
        total = rows[0]["total_spend"]
        if total is None:
            return None
        return float(total)

    return await _global_spend_coordinator.get_or_load(
        cache_key=cache_key,
        cache=user_api_key_cache,  # pyright: ignore[reportArgumentType]
        load_fn=_load_global_spend,
    )
|
|
|
|
|
|
# Strong references to in-flight budget-alert tasks. The event loop only keeps
# weak references to tasks, so a fire-and-forget task with no other reference
# may be garbage-collected before it finishes.
_pending_budget_alert_tasks: set = set()


async def get_global_proxy_spend(
    litellm_proxy_admin_name: str,
    user_api_key_cache: UserApiKeyCache,
    prisma_client: Optional[PrismaClient],
    token: str,
    proxy_logging_obj: ProxyLogging,
) -> Optional[float]:
    """Return the proxy-wide spend and schedule a proxy budget alert check.

    The spend is only fetched when a proxy max budget is configured
    (``litellm.max_budget > 0``) and a DB client is available; otherwise
    ``None`` is returned without touching the cache or DB.

    Args:
        litellm_proxy_admin_name: Used to build the ``"<name>:spend"`` cache
            key and as the ``user_id`` on the alert payload.
        user_api_key_cache: Cache passed to the spend coordinator.
        prisma_client: DB client, or ``None`` when no DB is connected.
        token: Token recorded on the alert payload.
        proxy_logging_obj: Provides ``budget_alerts``, which runs as a
            background task.

    Returns:
        The global spend, or ``None`` if it was not fetched or is unavailable.
    """
    budget_configured = litellm.max_budget > 0 and prisma_client is not None
    if not budget_configured:
        return None

    # Use event-driven coordination to prevent cache stampede
    cache_key = "{}:spend".format(litellm_proxy_admin_name)
    global_proxy_spend = await _fetch_global_spend_with_event_coordination(
        cache_key=cache_key,
        user_api_key_cache=user_api_key_cache,
        prisma_client=prisma_client,
    )
    if global_proxy_spend is None:
        return None

    user_info = CallInfo(
        user_id=litellm_proxy_admin_name,
        max_budget=litellm.max_budget,
        spend=global_proxy_spend,
        token=token,
        event_group=Litellm_EntityType.PROXY,
    )
    # Fire-and-forget, but hold a reference until the task completes so it
    # cannot be collected mid-flight; the done-callback releases it.
    alert_task = asyncio.create_task(
        proxy_logging_obj.budget_alerts(
            type="proxy_budget",
            user_info=user_info,
        )
    )
    _pending_budget_alert_tasks.add(alert_task)
    alert_task.add_done_callback(_pending_budget_alert_tasks.discard)
    return global_proxy_spend
|
|
|
|
|
|
def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str:
    """Map JWT scopes to an RBAC role.

    Returns ``LitellmUserRoles.PROXY_ADMIN`` when ``jwt_handler.is_admin``
    accepts the scopes, otherwise ``LitellmUserRoles.TEAM``.
    """
    if jwt_handler.is_admin(scopes=scopes):
        return LitellmUserRoles.PROXY_ADMIN
    return LitellmUserRoles.TEAM
|
|
|
|
|
|
def get_api_key(
    custom_litellm_key_header: Optional[str],
    api_key: str,
    azure_api_key_header: Optional[str],
    anthropic_api_key_header: Optional[str],
    google_ai_studio_api_key_header: Optional[str],
    azure_apim_header: Optional[str],
    pass_through_endpoints: Optional[List[dict]],
    route: str,
    request: Request,
) -> Tuple[str, Optional[str]]:
    """
    Pick the API key for this request from the first credential source present.

    Sources are consulted in this order; the first one that applies wins:
      1. custom LiteLLM key header (scheme text stripped, raw value kept if
         no known scheme)
      2. ``api_key`` (the Authorization-style value; unknown schemes yield "")
      3. Azure, Anthropic, Google AI Studio, Azure APIM headers (used as-is)
      4. ``?key=`` query parameter on generate-content routes
      5. the header named by ``headers["litellm_user_api_key"]`` on a
         configured pass-through endpoint whose ``path`` equals ``route``

    Returns:
        Tuple[str, Optional[str]]: ``(api_key, passed_in_key)`` where
        ``passed_in_key`` is the raw value as received, or ``None`` when no
        source supplied a key (``api_key`` is then returned unchanged)
    """
    from litellm.proxy.auth.route_checks import RouteChecks
    from litellm.proxy.common_utils.http_parsing_utils import (
        _safe_get_request_query_params,
    )

    if isinstance(custom_litellm_key_header, str):
        return (
            _get_bearer_token_or_received_api_key(custom_litellm_key_header),
            custom_litellm_key_header,
        )

    if isinstance(api_key, str) and len(api_key) > 0:
        return _get_bearer_token(api_key=api_key), api_key

    # Provider-specific headers carry the key verbatim.
    for provider_header in (
        azure_api_key_header,
        anthropic_api_key_header,
        google_ai_studio_api_key_header,
        azure_apim_header,
    ):
        if isinstance(provider_header, str):
            return provider_header, provider_header

    if RouteChecks.is_generate_content_route(route=route) and request is not None:
        google_auth_key = _safe_get_request_query_params(request).get("key")
        if google_auth_key:
            return google_auth_key, google_auth_key

    passed_in_key: Optional[str] = None
    if pass_through_endpoints is not None:
        # No early exit: if several endpoints share the path, the last one
        # whose configured header is present on the request wins.
        for endpoint in pass_through_endpoints:
            if endpoint.get("path", "") != route:
                continue
            headers: Optional[dict] = endpoint.get("headers", None)
            if headers is None:
                continue
            header_key: str = headers.get("litellm_user_api_key", "")
            if request.headers.get(header_key) is not None:
                api_key = request.headers.get(header_key) or ""
                passed_in_key = api_key
    return api_key, passed_in_key
|
|
|
|
|
|
async def check_api_key_for_custom_headers_or_pass_through_endpoints(
    request: Request,
    route: str,
    pass_through_endpoints: Optional[List[dict]],
    api_key: str,
) -> Union[UserAPIKeyAuth, str]:
    """Resolve the API key for mapped and user-configured pass-through routes.

    Args:
        request: Incoming request; only its headers are read.
        route: Route being called. Compared by prefix (after root-path
            normalization) against the built-in mapped pass-through routes,
            and by equality against each configured endpoint's ``path``.
        pass_through_endpoints: Raw pass-through endpoint config dicts, if any.
        api_key: Key extracted so far; returned unchanged when nothing below
            overrides it.

    Returns:
        ``UserAPIKeyAuth()`` when a matching endpoint has auth disabled,
        otherwise the (possibly replaced) API key.
    """
    normalized_route = normalize_route_for_root_path(route)
    is_mapped_pass_through_route: bool = normalized_route is not None and any(
        normalized_route.startswith(mapped_route)
        for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value  # type: ignore
    )
    if (
        is_mapped_pass_through_route
        and request.headers.get("litellm_user_api_key") is not None
    ):
        api_key = request.headers.get("litellm_user_api_key") or ""

    if pass_through_endpoints is None:
        return api_key

    for endpoint in pass_through_endpoints:
        if not (isinstance(endpoint, dict) and endpoint.get("path", "") == route):
            continue

        ## IF AUTH DISABLED
        # Default to True: a config dict with no ``auth`` key
        # otherwise produced an unauthenticated forwarder. The
        # Pydantic ``PassThroughGenericEndpoint.auth`` default
        # is also True, but raw config dicts skip that path —
        # so this runtime check has to default to True too.
        if endpoint.get("auth", True) is not True:
            return UserAPIKeyAuth()

        ## IF AUTH ENABLED
        ### IF CUSTOM PARSER REQUIRED
        if endpoint.get("custom_auth_parser") == "langfuse":
            # langfuse returns {'Authorization': 'Basic <base64(username:password)>'}
            # check the langfuse public key if it contains the litellm api key
            import base64

            api_key = api_key.replace("Basic ", "").strip()
            decoded_str = base64.b64decode(api_key).decode("utf-8")
            api_key = decoded_str.split(":")[0]
            continue

        headers = endpoint.get("headers", None)
        if headers is None:
            continue
        header_key = headers.get("litellm_user_api_key", "")
        # ``dict.get`` accepts no keyword arguments, so the key is passed
        # positionally; the earlier ``get(key=...)`` form raised TypeError
        # whenever the ``isinstance(..., dict)`` guard was true.
        # NOTE(review): this guard only admits plain-dict headers — confirm
        # whether other mapping-style header objects should also take this
        # branch.
        if (
            isinstance(request.headers, dict)
            and request.headers.get(header_key) is not None
        ):
            api_key = request.headers.get(header_key)  # type: ignore
    return api_key
|
|
|
|
|
|
# Cache sentinel written when a JWT under AUTO_REGISTER resolved to a proxy
# admin via auth_builder. Proxy admins don't need a mapped virtual key (they
# have full access via auth_builder anyway), but without a cache entry every
# subsequent request from the same JWT identity would re-query the DB for a
# non-existent mapping. The sentinel tells _resolve_jwt_to_virtual_key to skip
# the lookup and return None (caller proceeds to auth_builder).
_JWT_PROXY_ADMIN_SENTINEL = "__JWT_PROXY_ADMIN__"
|
|
|
|
|
|
class _PendingAutoRegister(NamedTuple):
|
|
"""
|
|
Signal returned by ``_resolve_jwt_to_virtual_key`` when the JWT's claim is
|
|
unmapped and ``unregistered_jwt_client_behavior`` is AUTO_REGISTER.
|
|
|
|
The caller MUST run standard ``JWTAuthManager.auth_builder`` to apply RBAC,
|
|
scope mappings, ``custom_validate``, and ``user_allowed_email_domain``
|
|
policy BEFORE calling ``_auto_register_jwt_mapping`` with the validated
|
|
``team_id`` / ``user_id`` from the auth_builder result. Auto-registering
|
|
purely on a signature-valid JWT (the old behavior) bypassed every JWT
|
|
policy beyond signature verification.
|
|
"""
|
|
|
|
claim_field: str
|
|
claim_value: str
|
|
cache_key: str
|
|
|
|
|
|
async def _auto_register_jwt_mapping(
    virtual_key_claim_field: str,
    claim_value: str,
    jwt_handler: JWTHandler,
    prisma_client: PrismaClient,
    user_api_key_cache: UserApiKeyCache,
    parent_otel_span: Optional[Span],
    proxy_logging_obj: ProxyLogging,
    cache_key: str,
    team_id: Optional[str] = None,
    user_id: Optional[str] = None,
    org_id: Optional[str] = None,
    end_user_id: Optional[str] = None,
) -> Optional[UserAPIKeyAuth]:
    """
    Auto-register: create a new virtual key + mapping for an unrecognised JWT
    claim value. ``team_id`` and ``user_id`` must come from a successful
    ``JWTAuthManager.auth_builder`` run — they encode the JWT identity AFTER
    RBAC/scope/custom_validate/email-domain policy has been enforced. The key
    is stamped with those values so the cached future-request path inherits
    the same team/user/org limits the auth_builder path would have applied.

    Race safety: if two concurrent requests both reach here simultaneously (both
    saw no mapping in the DB), one will win the unique-constraint race on
    litellm_jwtkeymapping. The loser catches the conflict, deletes its orphaned
    key, fetches the winner's mapping, and proceeds — no error surfaced.

    Returns:
        The key object for the mapped token hash (with ``org_id`` and
        ``end_user_id`` set from the arguments), or ``None`` if
        ``get_key_object`` returns ``None``.

    Raises:
        HTTPException: 503 when the race was lost and the winner's mapping
            can no longer be found.
        Exception: any non-unique-constraint error from the mapping insert is
            re-raised unchanged.
    """
    # Inline import required: key_management_endpoints imports
    # user_api_key_auth, so a module-level import here would create a
    # circular dependency.
    from litellm.proxy.management_endpoints.key_management_endpoints import (
        generate_key_helper_fn,
    )

    # ``table_name="key"`` is required: without it, generate_key_helper_fn
    # falls into the user-upsert branch (`table_name is None or "user"`) and
    # attempts to insert into LiteLLM_UserTable with user_id=None, which fails
    # the NOT NULL @id constraint. Every successful key-creation caller (e.g.
    # /key/generate) passes table_name="key" explicitly.
    key_data = await generate_key_helper_fn(
        request_type="key",
        table_name="key",
        team_id=team_id,
        user_id=user_id,
        organization_id=org_id,
        metadata={
            "auto_registered": True,
            "jwt_claim_field": virtual_key_claim_field,
            "jwt_claim_value": claim_value,
        },
    )
    # generate_key_helper_fn returns the plaintext key in "token"; the persisted
    # row in LiteLLM_VerificationToken uses its hash, so hash here to get the FK
    # value referenced by LiteLLM_JWTKeyMapping.token.
    token_hash = hash_token(key_data["token"])
    # Tracks whether THIS request's key ends up being the mapped one, so the
    # final log line reports what actually happened.
    created_new_key = True

    try:
        await prisma_client.db.litellm_jwtkeymapping.create(
            data={
                "jwt_claim_name": virtual_key_claim_field,
                "jwt_claim_value": claim_value,
                "token": token_hash,
                "created_by": "auto_register",
                "updated_by": "auto_register",
            }
        )
    except Exception as e:
        error_str = str(e).lower()
        if "unique" not in error_str and "p2002" not in error_str:
            raise

        created_new_key = False
        # A concurrent request won the race. The key generate_key_helper_fn
        # just persisted to LiteLLM_VerificationToken is orphaned — nothing
        # maps to it, but it's a fully valid unrestricted API key sitting in
        # the DB and the cleartext is in memory on this request. Delete it
        # so orphans don't accumulate under sustained concurrency.
        verbose_proxy_logger.debug(
            "JWT Key Mapping (auto_register): unique conflict on create — "
            "deleting orphaned virtual key and fetching winner's mapping for %s='%s'.",
            virtual_key_claim_field,
            claim_value,
        )
        try:
            await prisma_client.db.litellm_verificationtoken.delete(
                where={"token": token_hash}
            )
        except Exception as delete_err:
            # Don't fail the request if cleanup fails — the orphan is
            # unmapped and inert. Log so an operator can prune it later.
            verbose_proxy_logger.warning(
                "JWT Key Mapping (auto_register): failed to delete orphaned key after race: %s",
                delete_err,
            )
        token_hash = await get_jwt_key_mapping_object(
            jwt_claim_name=virtual_key_claim_field,
            jwt_claim_value=claim_value,
            prisma_client=prisma_client,
        )
        if token_hash is None:
            # The winner's mapping vanished between the unique-constraint
            # conflict and our re-fetch (concurrent delete). Returning None
            # here would silently fall through to team-based JWT auth —
            # a less-restrictive path than the operator configured. Raise
            # 503 so the caller retries against a stable state instead.
            raise HTTPException(
                status_code=503,
                detail=(
                    "JWT Key Mapping: AUTO_REGISTER race resolution failed — "
                    "winner's mapping was concurrently removed. Retry the request."
                ),
            )

    await user_api_key_cache.async_set_cache(
        key=cache_key,
        value=token_hash,
        ttl=jwt_handler.litellm_jwtauth.virtual_key_mapping_cache_ttl,
    )

    if created_new_key:
        verbose_proxy_logger.info(
            "JWT Key Mapping (auto_register): created new virtual key for %s='%s'.",
            virtual_key_claim_field,
            claim_value,
        )
    else:
        # Race-loser path: our key was deleted above and the winner's mapping
        # is in use, so "created" would be misleading here.
        verbose_proxy_logger.info(
            "JWT Key Mapping (auto_register): reused concurrently registered "
            "virtual key for %s='%s'.",
            virtual_key_claim_field,
            claim_value,
        )

    auto_registered_key = await get_key_object(
        hashed_token=token_hash,
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
        parent_otel_span=parent_otel_span,
        proxy_logging_obj=proxy_logging_obj,
    )
    if auto_registered_key is not None:
        auto_registered_key.org_id = org_id
        auto_registered_key.end_user_id = end_user_id
    return auto_registered_key
|
|
|
|
|
|
async def _resolve_jwt_to_virtual_key(
    jwt_claims: dict,
    jwt_handler: JWTHandler,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: UserApiKeyCache,
    parent_otel_span: Optional[Span],
    proxy_logging_obj: ProxyLogging,
) -> Union[Optional[UserAPIKeyAuth], "_PendingAutoRegister"]:
    """
    Resolve the configured JWT claim to a virtual key, or apply the
    ``unregistered_jwt_client_behavior`` policy when no mapping exists.

    Lookup order: mapping cache first, then (when ``prisma_client`` is set)
    the JWT key-mapping table. A DB hit is written back to the cache with
    ``virtual_key_mapping_cache_ttl``.

    Returns:
    - ``UserAPIKeyAuth``: the key object loaded via ``get_key_object`` for a
      mapping found in the cache or the DB.
    - ``_PendingAutoRegister``: the claim is unmapped and the behavior is
      AUTO_REGISTER. The caller MUST run ``JWTAuthManager.auth_builder`` to
      enforce JWT policy (RBAC, scope, custom_validate, email-domain) before
      invoking ``_auto_register_jwt_mapping`` with the validated identity.
    - ``None``: the caller falls through to standard team-based JWT auth.
      This happens when no ``virtual_key_claim_field`` is configured, when the
      cache holds the proxy-admin sentinel, or when the claim is missing or
      unmapped under the FALLBACK_TEAM_MAPPING behavior.

    Raises:
        HTTPException: 403 when the claim is missing under REJECT or
            AUTO_REGISTER, or is unmapped under REJECT; 500 when
            AUTO_REGISTER is configured but ``prisma_client`` is ``None``.
    """
    jwt_auth_settings = jwt_handler.litellm_jwtauth
    virtual_key_claim_field = jwt_auth_settings.virtual_key_claim_field
    if virtual_key_claim_field is None:
        return None

    claim_value = get_nested_value(
        data=jwt_claims,
        key_path=virtual_key_claim_field,
        default=None,
    )

    if claim_value is None:
        verbose_proxy_logger.debug(
            f"JWT Key Mapping: Claim field '{virtual_key_claim_field}' not found in JWT claims."
        )
        # A JWT that omits the claim is treated as an unmapped client, so the
        # no-match policy still applies: otherwise REJECT could be bypassed by
        # simply leaving the field out. AUTO_REGISTER is denied as well because
        # there is no claim value to key a new mapping on.
        policies_denying_missing_claim = (
            UnregisteredJWTClientBehavior.REJECT,
            UnregisteredJWTClientBehavior.AUTO_REGISTER,
        )
        no_match_policy = jwt_auth_settings.unregistered_jwt_client_behavior
        if no_match_policy not in policies_denying_missing_claim:
            return None
        raise HTTPException(
            status_code=403,
            detail=(
                f"JWT Key Mapping: Required claim '{virtual_key_claim_field}' "
                "is missing from the JWT. Access denied."
            ),
        )

    cache_key = f"jwt_key_mapping:{virtual_key_claim_field}:{claim_value}"
    cached_entry = await user_api_key_cache.async_get_cache(cache_key)

    if cached_entry == _JWT_PROXY_ADMIN_SENTINEL:
        # This identity was cached as a proxy admin by the caller after
        # auth_builder ran; skip the mapping lookup (and its DB hit) and let
        # the caller run auth_builder again.
        return None

    # True when the cache already records "no mapping" for this claim. In that
    # case the lookup is skipped and the miss is not re-cached below.
    miss_already_cached = cached_entry == "__NO_MAPPING__"

    if not miss_already_cached:
        # Any other cached value is the hashed token of the mapped key.
        mapped_token_hash: Optional[str] = cached_entry
        if mapped_token_hash is None and prisma_client is not None:
            # Cache miss: consult the DB. Without a DB no mapping can exist,
            # so the miss is definitive and the no-match policy applies.
            mapped_token_hash = await get_jwt_key_mapping_object(
                jwt_claim_name=virtual_key_claim_field,
                jwt_claim_value=str(claim_value),
                prisma_client=prisma_client,
            )
            if mapped_token_hash is not None:
                await user_api_key_cache.async_set_cache(
                    key=cache_key,
                    value=mapped_token_hash,
                    ttl=jwt_auth_settings.virtual_key_mapping_cache_ttl,
                )
        if mapped_token_hash is not None:
            return await get_key_object(
                hashed_token=mapped_token_hash,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )

    # No mapping (cached miss, DB miss, or no DB) — apply the no-match policy.
    no_match_policy = jwt_auth_settings.unregistered_jwt_client_behavior

    if no_match_policy == UnregisteredJWTClientBehavior.REJECT:
        if not miss_already_cached:
            # Record the miss before raising so repeated rejections are
            # answered from cache instead of re-querying the DB.
            await user_api_key_cache.async_set_cache(
                key=cache_key,
                value="__NO_MAPPING__",
                ttl=jwt_auth_settings.virtual_key_mapping_cache_ttl,
            )
        raise HTTPException(
            status_code=403,
            detail=f"JWT Key Mapping: No registered mapping for {virtual_key_claim_field}='{claim_value}'. Access denied.",
        )

    if no_match_policy == UnregisteredJWTClientBehavior.AUTO_REGISTER:
        # The same 500 is raised whether or not the cache held the miss
        # sentinel, so the outcome does not depend on cache contents.
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail=(
                    "JWT Key Mapping: AUTO_REGISTER requires a database connection. "
                    "Configure a database or change unregistered_jwt_client_behavior."
                ),
            )
        if miss_already_cached:
            # A miss sentinel is only written under the other behaviors, so
            # this one is stale (left by an earlier configuration) — evict it.
            await user_api_key_cache.async_delete_cache(cache_key)
        # Defer key creation: the caller first runs JWTAuthManager.auth_builder
        # to enforce RBAC, scope, custom_validate and email-domain policy.
        # Registering here, on a signature-only JWT, would bypass all of it.
        return _PendingAutoRegister(
            claim_field=virtual_key_claim_field,
            claim_value=str(claim_value),
            cache_key=cache_key,
        )

    # FALLBACK_TEAM_MAPPING (default): remember the miss and let the caller
    # fall through to standard team-based JWT auth.
    if not miss_already_cached:
        await user_api_key_cache.async_set_cache(
            key=cache_key,
            value="__NO_MAPPING__",
            ttl=jwt_auth_settings.virtual_key_mapping_cache_ttl,
        )
    return None
|
|
|
|
|
|
def _ensure_parent_otel_span_on_request_state(request: Request) -> None:
    """Idempotently create the OTEL SERVER span and stash it on
    ``request.state.parent_otel_span``. Safe to call multiple times.

    Called both at the top of ``user_api_key_auth`` (so body-parse failures
    have a span to close) and inside ``_user_api_key_auth_builder`` (for
    callers that bypass ``user_api_key_auth``, e.g. MCP).

    Does nothing when no ``open_telemetry_logger`` is configured or when a
    span is already present on ``request.state``.

    Args:
        request: The incoming request; its ``state`` receives
            ``litellm_received_at`` and ``parent_otel_span``.
    """
    from litellm.proxy.proxy_server import open_telemetry_logger

    if open_telemetry_logger is None:
        return
    if getattr(request.state, "parent_otel_span", None) is not None:
        return

    # Reuse a receive-instant that was already stamped on the request.
    # ``_user_api_key_auth_builder`` stamps ``litellm_received_at`` before it
    # calls this helper; replacing it with a fresh ``datetime.now()`` would
    # move the recorded receive time (and the span start) later and shorten
    # the measured preprocessing duration.
    received_at = getattr(request.state, "litellm_received_at", None)
    start_time = received_at if isinstance(received_at, datetime) else datetime.now()
    try:
        request.state.litellm_received_at = start_time
    except Exception:
        # Best-effort: failing to record the timestamp must not block span
        # creation or fail the request.
        pass

    parent_otel_span = open_telemetry_logger.create_litellm_proxy_request_started_span(
        start_time=start_time,
        headers=_safe_get_request_headers(request),
    )
    # Under V2 the FastAPI instrumentor stamps http.route / url.path on the server
    # span; only the legacy logger needs these set explicitly.
    set_route_attrs = getattr(
        open_telemetry_logger, "set_proxy_request_route_attributes", None
    )
    if not is_otel_v2_enabled() and set_route_attrs is not None:
        set_route_attrs(
            parent_otel_span,
            url_path=get_request_route(request=request),
            http_route=get_request_route_template(request),
        )
    request.state.parent_otel_span = parent_otel_span
|
|
|
|
|
|
async def _user_api_key_auth_builder( # noqa: PLR0915
|
|
request: Request,
|
|
api_key: str,
|
|
azure_api_key_header: str,
|
|
anthropic_api_key_header: Optional[str],
|
|
google_ai_studio_api_key_header: Optional[str],
|
|
azure_apim_header: Optional[str],
|
|
request_data: dict,
|
|
custom_litellm_key_header: Optional[str] = None,
|
|
) -> UserAPIKeyAuth:
|
|
from litellm.proxy.proxy_server import (
|
|
general_settings,
|
|
jwt_handler,
|
|
litellm_proxy_admin_name,
|
|
llm_model_list,
|
|
llm_router,
|
|
master_key,
|
|
model_max_budget_limiter,
|
|
open_telemetry_logger,
|
|
prisma_client,
|
|
proxy_logging_obj,
|
|
user_api_key_cache,
|
|
user_custom_auth,
|
|
)
|
|
|
|
parent_otel_span: Optional[Span] = None
|
|
# Prefer the receive-instant stamped by the early helper in
|
|
# user_api_key_auth (before body parse) — overwriting it would shorten
|
|
# the preprocessing-duration measurement by the body-parse window.
|
|
start_time = getattr(request.state, "litellm_received_at", None) or datetime.now()
|
|
try:
|
|
request.state.litellm_received_at = start_time
|
|
except Exception:
|
|
pass
|
|
route: str = get_request_route(request=request)
|
|
valid_token: Optional[UserAPIKeyAuth] = None
|
|
custom_auth_api_key: bool = False
|
|
|
|
try:
|
|
with tracer.trace("litellm.proxy.auth.pre_db_read_auth_checks"):
|
|
await pre_db_read_auth_checks(
|
|
request_data=request_data,
|
|
request=request,
|
|
route=route,
|
|
)
|
|
pass_through_endpoints: Optional[List[dict]] = general_settings.get(
|
|
"pass_through_endpoints", None
|
|
)
|
|
## CHECK IF X-LITELM-API-KEY IS PASSED IN - supersedes Authorization header
|
|
api_key, passed_in_key = get_api_key(
|
|
custom_litellm_key_header=custom_litellm_key_header,
|
|
api_key=api_key,
|
|
azure_api_key_header=azure_api_key_header,
|
|
anthropic_api_key_header=anthropic_api_key_header,
|
|
google_ai_studio_api_key_header=google_ai_studio_api_key_header,
|
|
azure_apim_header=azure_apim_header,
|
|
pass_through_endpoints=pass_through_endpoints,
|
|
route=route,
|
|
request=request,
|
|
)
|
|
# if user wants to pass LiteLLM_Master_Key as a custom header, example pass litellm keys as X-LiteLLM-Key: Bearer sk-1234
|
|
custom_litellm_key_header_name = general_settings.get("litellm_key_header_name")
|
|
if custom_litellm_key_header_name is not None:
|
|
api_key = get_api_key_from_custom_header(
|
|
request=request,
|
|
custom_litellm_key_header_name=custom_litellm_key_header_name,
|
|
)
|
|
|
|
if open_telemetry_logger is not None:
|
|
# Reuse the span created by user_api_key_auth (before body parse)
|
|
# so it survives _read_request_body failures. For callers that
|
|
# bypass user_api_key_auth (e.g. MCP), create it lazily.
|
|
_ensure_parent_otel_span_on_request_state(request)
|
|
parent_otel_span = getattr(request.state, "parent_otel_span", None)
|
|
|
|
### USER-DEFINED AUTH FUNCTION ###
|
|
if enterprise_custom_auth is not None:
|
|
with tracer.trace("litellm.proxy.auth.enterprise_custom_auth"):
|
|
response = await enterprise_custom_auth(
|
|
request=request, api_key=api_key, user_custom_auth=user_custom_auth
|
|
)
|
|
if response is not None and isinstance(response, UserAPIKeyAuth):
|
|
validated = UserAPIKeyAuth.model_validate(response)
|
|
if getattr(litellm, "enable_post_custom_auth_checks", False):
|
|
validated = await _run_post_custom_auth_checks(
|
|
valid_token=validated,
|
|
request=request,
|
|
request_data=request_data,
|
|
route=route,
|
|
parent_otel_span=parent_otel_span,
|
|
)
|
|
return validated
|
|
elif response is not None and isinstance(response, str):
|
|
api_key = response
|
|
custom_auth_api_key = True
|
|
elif user_custom_auth is not None:
|
|
response = await user_custom_auth(request=request, api_key=api_key) # type: ignore
|
|
validated = UserAPIKeyAuth.model_validate(response)
|
|
if getattr(litellm, "enable_post_custom_auth_checks", False):
|
|
validated = await _run_post_custom_auth_checks(
|
|
valid_token=validated,
|
|
request=request,
|
|
request_data=request_data,
|
|
route=route,
|
|
parent_otel_span=parent_otel_span,
|
|
)
|
|
return validated
|
|
|
|
### LITELLM-DEFINED AUTH FUNCTION ###
|
|
#### IF JWT ####
|
|
"""
|
|
LiteLLM supports using JWTs.
|
|
|
|
Enable this in proxy config, by setting
|
|
```
|
|
general_settings:
|
|
enable_jwt_auth: true
|
|
```
|
|
"""
|
|
|
|
######## Route Checks Before Reading DB / Cache for "token" ################
|
|
if not _route_requires_auth_despite_public(
|
|
route=route, general_settings=general_settings
|
|
) and (
|
|
route in LiteLLMRoutes.public_routes.value # type: ignore
|
|
or route_in_additonal_public_routes(current_route=route)
|
|
):
|
|
# check if public endpoint
|
|
return UserAPIKeyAuth(user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY)
|
|
|
|
########## End of Route Checks Before Reading DB / Cache for "token" ########
|
|
|
|
enable_oauth2_auth = general_settings.get("enable_oauth2_auth", False) is True
|
|
enable_jwt_auth = general_settings.get("enable_jwt_auth", False) is True
|
|
is_jwt = jwt_handler.is_jwt(token=api_key) if enable_jwt_auth else False
|
|
|
|
# Routing uses unverified JWT claims only to choose auth path.
|
|
# Final authentication is enforced by the selected validator.
|
|
route_jwt_to_oauth2 = is_jwt and _should_route_jwt_to_oauth2_override(
|
|
token=api_key, jwt_handler=jwt_handler
|
|
)
|
|
|
|
# OAuth2 applies for:
|
|
# 1) when global OAuth2 auth is enabled on LLM + info routes
|
|
# 2) JWT tokens that explicitly match routing_overrides on LLM + info routes
|
|
should_apply_override_oauth2 = route_jwt_to_oauth2 and (
|
|
RouteChecks.is_llm_api_route(route=route)
|
|
or RouteChecks.is_info_route(route=route)
|
|
)
|
|
should_apply_global_oauth2 = enable_oauth2_auth and (
|
|
RouteChecks.is_llm_api_route(route=route)
|
|
or RouteChecks.is_info_route(route=route)
|
|
)
|
|
if (should_apply_global_oauth2 and not is_jwt) or should_apply_override_oauth2:
|
|
from litellm.proxy.proxy_server import premium_user
|
|
|
|
if premium_user is not True:
|
|
raise ValueError(
|
|
"Oauth2 token validation is only available for premium users"
|
|
+ CommonProxyErrors.not_premium_user.value
|
|
)
|
|
|
|
return await Oauth2Handler.check_oauth2_token(token=api_key)
|
|
|
|
if general_settings.get("enable_oauth2_proxy_auth", False) is True:
|
|
return await handle_oauth2_proxy_request(request=request)
|
|
|
|
if general_settings.get("enable_jwt_auth", False) is True:
|
|
from litellm.proxy.proxy_server import premium_user
|
|
|
|
if premium_user is not True:
|
|
raise ValueError(
|
|
f"JWT Auth is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
|
|
)
|
|
is_jwt = jwt_handler.is_jwt(token=api_key)
|
|
verbose_proxy_logger.debug("is_jwt: %s", is_jwt)
|
|
if is_jwt:
|
|
# Try JWT-to-Virtual-Key mapping first to avoid
|
|
# unnecessary DB queries in auth_builder
|
|
do_standard_jwt_auth = True
|
|
pending_auto_register: Optional[_PendingAutoRegister] = None
|
|
if jwt_handler.litellm_jwtauth.virtual_key_claim_field is not None:
|
|
# Decode JWT to get claims without running full auth_builder
|
|
jwt_claims: Optional[dict]
|
|
if jwt_handler.litellm_jwtauth.oidc_userinfo_enabled and not is_jwt:
|
|
jwt_claims = await jwt_handler.get_oidc_userinfo(token=api_key)
|
|
else:
|
|
jwt_claims = await jwt_handler.auth_jwt(token=api_key)
|
|
|
|
resolve_result = await _resolve_jwt_to_virtual_key(
|
|
jwt_claims=jwt_claims,
|
|
jwt_handler=jwt_handler,
|
|
prisma_client=prisma_client,
|
|
user_api_key_cache=user_api_key_cache,
|
|
parent_otel_span=parent_otel_span,
|
|
proxy_logging_obj=proxy_logging_obj,
|
|
)
|
|
if isinstance(resolve_result, UserAPIKeyAuth):
|
|
valid_token = resolve_result
|
|
api_key = valid_token.token or ""
|
|
valid_token.jwt_claims = jwt_claims
|
|
do_standard_jwt_auth = False
|
|
# Fall through to virtual key checks
|
|
elif isinstance(resolve_result, _PendingAutoRegister):
|
|
# Run full JWT policy (RBAC, scope, custom_validate,
|
|
# email-domain) via auth_builder, then create the key
|
|
# from the validated identity below.
|
|
pending_auto_register = resolve_result
|
|
# else: None → FALLBACK_TEAM_MAPPING, falls through to
|
|
# standard JWT auth_builder below
|
|
|
|
if do_standard_jwt_auth:
|
|
with tracer.trace("litellm.proxy.auth.jwt_auth_builder"):
|
|
result = await JWTAuthManager.auth_builder(
|
|
request_data=request_data,
|
|
general_settings=general_settings,
|
|
api_key=api_key,
|
|
jwt_handler=jwt_handler,
|
|
route=route,
|
|
prisma_client=prisma_client,
|
|
user_api_key_cache=user_api_key_cache,
|
|
proxy_logging_obj=proxy_logging_obj,
|
|
parent_otel_span=parent_otel_span,
|
|
request_headers=_safe_get_request_headers(request),
|
|
request_method=RouteChecks._get_request_method(
|
|
request=request
|
|
),
|
|
)
|
|
|
|
is_proxy_admin = result["is_proxy_admin"]
|
|
team_id = result["team_id"]
|
|
team_object = result["team_object"]
|
|
user_id = result["user_id"]
|
|
user_object = result["user_object"]
|
|
end_user_id = result["end_user_id"]
|
|
org_id = result["org_id"]
|
|
team_membership: Optional[LiteLLM_TeamMembership] = result.get(
|
|
"team_membership", None
|
|
)
|
|
jwt_claims = result.get("jwt_claims", None)
|
|
|
|
if is_proxy_admin:
|
|
# Proxy admins authenticate via auth_builder (full
|
|
# access), not via a mapped virtual key. If
|
|
# AUTO_REGISTER was pending, cache a sentinel so
|
|
# future requests from this JWT identity skip the
|
|
# DB mapping lookup in _resolve_jwt_to_virtual_key.
|
|
# Without this, every proxy-admin request under
|
|
# AUTO_REGISTER re-hits get_jwt_key_mapping_object.
|
|
if pending_auto_register is not None:
|
|
await user_api_key_cache.async_set_cache(
|
|
key=pending_auto_register.cache_key,
|
|
value=_JWT_PROXY_ADMIN_SENTINEL,
|
|
ttl=jwt_handler.litellm_jwtauth.virtual_key_mapping_cache_ttl,
|
|
)
|
|
return UserAPIKeyAuth(
|
|
api_key=None,
|
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
|
user_id=user_id,
|
|
team_id=team_id,
|
|
team_alias=(
|
|
team_object.team_alias
|
|
if team_object is not None
|
|
else None
|
|
),
|
|
team_tpm_limit=(
|
|
team_object.tpm_limit
|
|
if team_object is not None
|
|
else None
|
|
),
|
|
team_rpm_limit=(
|
|
team_object.rpm_limit
|
|
if team_object is not None
|
|
else None
|
|
),
|
|
team_models=(
|
|
team_object.models if team_object is not None else []
|
|
),
|
|
team_metadata=(
|
|
team_object.metadata
|
|
if team_object is not None
|
|
else None
|
|
),
|
|
org_id=org_id,
|
|
end_user_id=end_user_id,
|
|
parent_otel_span=parent_otel_span,
|
|
jwt_claims=jwt_claims,
|
|
)
|
|
|
|
valid_token = UserAPIKeyAuth(
|
|
api_key=None,
|
|
team_id=team_id,
|
|
team_alias=(
|
|
team_object.team_alias if team_object is not None else None
|
|
),
|
|
team_tpm_limit=(
|
|
team_object.tpm_limit if team_object is not None else None
|
|
),
|
|
team_rpm_limit=(
|
|
team_object.rpm_limit if team_object is not None else None
|
|
),
|
|
team_models=(
|
|
team_object.models if team_object is not None else []
|
|
),
|
|
user_role=(
|
|
LitellmUserRoles(user_object.user_role)
|
|
if user_object is not None
|
|
and user_object.user_role is not None
|
|
else LitellmUserRoles.INTERNAL_USER
|
|
),
|
|
user_id=user_id,
|
|
org_id=org_id,
|
|
parent_otel_span=parent_otel_span,
|
|
end_user_id=end_user_id,
|
|
user_tpm_limit=(
|
|
user_object.tpm_limit if user_object is not None else None
|
|
),
|
|
user_rpm_limit=(
|
|
user_object.rpm_limit if user_object is not None else None
|
|
),
|
|
team_member_rpm_limit=(
|
|
team_membership.safe_get_team_member_rpm_limit()
|
|
if team_membership is not None
|
|
else None
|
|
),
|
|
team_member_tpm_limit=(
|
|
team_membership.safe_get_team_member_tpm_limit()
|
|
if team_membership is not None
|
|
else None
|
|
),
|
|
team_metadata=(
|
|
team_object.metadata if team_object is not None else None
|
|
),
|
|
jwt_claims=jwt_claims,
|
|
)
|
|
valid_token.team_object_permission = (
|
|
team_object.object_permission
|
|
if team_object is not None
|
|
else None
|
|
)
|
|
|
|
# AUTO_REGISTER deferred from _resolve_jwt_to_virtual_key.
|
|
# JWT policy (RBAC, scope, custom_validate, email-domain)
|
|
# has now been enforced by auth_builder above. Create the
|
|
# mapping + virtual key from the *validated* identity, then
|
|
# replace valid_token with the new key so downstream checks
|
|
# use the key-scoped path.
|
|
if pending_auto_register is not None and prisma_client is not None:
|
|
auto_registered = await _auto_register_jwt_mapping(
|
|
virtual_key_claim_field=pending_auto_register.claim_field,
|
|
claim_value=pending_auto_register.claim_value,
|
|
jwt_handler=jwt_handler,
|
|
prisma_client=prisma_client,
|
|
user_api_key_cache=user_api_key_cache,
|
|
parent_otel_span=parent_otel_span,
|
|
proxy_logging_obj=proxy_logging_obj,
|
|
cache_key=pending_auto_register.cache_key,
|
|
team_id=team_id,
|
|
user_id=user_id,
|
|
org_id=org_id,
|
|
end_user_id=end_user_id,
|
|
)
|
|
if auto_registered is not None:
|
|
auto_registered.jwt_claims = jwt_claims
|
|
valid_token = auto_registered
|
|
api_key = valid_token.token or ""
|
|
|
|
# Check if model has zero cost - if so, skip all budget checks
|
|
model = _get_model_from_request_context(
|
|
request_data=request_data,
|
|
route=route,
|
|
request=request,
|
|
llm_router=llm_router,
|
|
)
|
|
skip_budget_checks = False
|
|
if model is not None and llm_router is not None:
|
|
from litellm.proxy.auth.auth_checks import _is_model_cost_zero
|
|
|
|
skip_budget_checks = _is_model_cost_zero(
|
|
model=model, llm_router=llm_router
|
|
)
|
|
if skip_budget_checks:
|
|
verbose_proxy_logger.info(
|
|
f"Skipping all budget checks for zero-cost model: {model}"
|
|
)
|
|
|
|
# Fetch project object for JWT path if project_id is set
|
|
_jwt_project_obj = None
|
|
if valid_token.project_id is not None:
|
|
_jwt_project_obj = await get_project_object(
|
|
project_id=valid_token.project_id,
|
|
prisma_client=prisma_client,
|
|
user_api_key_cache=user_api_key_cache,
|
|
proxy_logging_obj=proxy_logging_obj,
|
|
)
|
|
if _jwt_project_obj is not None:
|
|
valid_token.project_metadata = _jwt_project_obj.metadata
|
|
valid_token.project_alias = _jwt_project_obj.project_alias
|
|
|
|
return cast(UserAPIKeyAuth, valid_token)
|
|
|
|
#### ELSE ####
|
|
## CHECK PASS-THROUGH ENDPOINTS ##
|
|
if not custom_auth_api_key:
|
|
response = await check_api_key_for_custom_headers_or_pass_through_endpoints(
|
|
request=request,
|
|
route=route,
|
|
pass_through_endpoints=pass_through_endpoints,
|
|
api_key=api_key,
|
|
)
|
|
if isinstance(response, str):
|
|
api_key = response
|
|
elif isinstance(response, UserAPIKeyAuth):
|
|
return response
|
|
if master_key is None:
|
|
if isinstance(api_key, str):
|
|
return UserAPIKeyAuth(
|
|
api_key=api_key,
|
|
user_role=LitellmUserRoles.INTERNAL_USER,
|
|
parent_otel_span=parent_otel_span,
|
|
)
|
|
else:
|
|
return UserAPIKeyAuth(
|
|
user_role=LitellmUserRoles.INTERNAL_USER,
|
|
parent_otel_span=parent_otel_span,
|
|
)
|
|
elif api_key is None: # only require api key if master key is set
|
|
raise Exception("No api key passed in.")
|
|
elif api_key == "":
|
|
# missing 'Bearer ' prefix
|
|
raise Exception(
|
|
"Malformed API Key passed in. Ensure Key has `Bearer ` prefix."
|
|
)
|
|
|
|
if route == "/user/auth":
|
|
if general_settings.get("allow_user_auth", False) is True:
|
|
return UserAPIKeyAuth()
|
|
else:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_403_FORBIDDEN,
|
|
detail="'allow_user_auth' not set or set to False",
|
|
)
|
|
|
|
## Check END-USER OBJECT
|
|
_end_user_object = None
|
|
end_user_params = {}
|
|
|
|
raw_end_user_id = get_end_user_id_from_request_body(
|
|
request_data, _safe_get_request_headers(request)
|
|
)
|
|
end_user_id = await resolve_and_validate_end_user_id(
|
|
raw_end_user_id=raw_end_user_id,
|
|
prisma_client=prisma_client,
|
|
user_api_key_cache=user_api_key_cache,
|
|
parent_otel_span=parent_otel_span,
|
|
proxy_logging_obj=proxy_logging_obj,
|
|
route=route,
|
|
)
|
|
if end_user_id:
|
|
try:
|
|
end_user_params["end_user_id"] = end_user_id
|
|
|
|
with tracer.trace("litellm.proxy.auth.get_end_user_object"):
|
|
_end_user_object = await get_end_user_object(
|
|
end_user_id=end_user_id,
|
|
prisma_client=prisma_client,
|
|
user_api_key_cache=user_api_key_cache,
|
|
parent_otel_span=parent_otel_span,
|
|
proxy_logging_obj=proxy_logging_obj,
|
|
route=route,
|
|
)
|
|
if _end_user_object is not None:
|
|
end_user_params["allowed_model_region"] = (
|
|
_end_user_object.allowed_model_region
|
|
)
|
|
if _end_user_object.litellm_budget_table is not None:
|
|
_apply_budget_limits_to_end_user_params(
|
|
end_user_params=end_user_params,
|
|
budget_info=_end_user_object.litellm_budget_table,
|
|
end_user_id=end_user_id,
|
|
)
|
|
elif litellm.max_end_user_budget_id is not None:
|
|
# End user doesn't exist yet, but apply default budget limits if configured
|
|
from litellm.proxy.auth.auth_checks import (
|
|
get_default_end_user_budget,
|
|
)
|
|
|
|
default_budget = await get_default_end_user_budget(
|
|
prisma_client=prisma_client,
|
|
user_api_key_cache=user_api_key_cache,
|
|
parent_otel_span=parent_otel_span,
|
|
)
|
|
if default_budget is not None:
|
|
_apply_budget_limits_to_end_user_params(
|
|
end_user_params=end_user_params,
|
|
budget_info=default_budget,
|
|
end_user_id=end_user_id,
|
|
)
|
|
except Exception as e:
|
|
if isinstance(e, litellm.BudgetExceededError):
|
|
raise e
|
|
verbose_proxy_logger.debug(
|
|
"Unable to find user in db. Error - {}".format(str(e))
|
|
)
|
|
pass
|
|
|
|
### CHECK IF ADMIN ###
|
|
# note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead
|
|
### CHECK IF ADMIN ###
|
|
# note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead
|
|
if valid_token is None:
|
|
## Check CACHE
|
|
try:
|
|
with tracer.trace("litellm.proxy.auth.get_key_object_check_cache"):
|
|
valid_token = await get_key_object(
|
|
hashed_token=hash_token(api_key),
|
|
prisma_client=prisma_client,
|
|
user_api_key_cache=user_api_key_cache,
|
|
parent_otel_span=parent_otel_span,
|
|
proxy_logging_obj=proxy_logging_obj,
|
|
check_cache_only=True,
|
|
)
|
|
except Exception:
|
|
verbose_logger.debug("api key not found in cache.")
|
|
valid_token = None
|
|
|
|
## Check UI Hash Key
|
|
if valid_token is None and get_secret_bool("EXPERIMENTAL_UI_LOGIN"):
|
|
valid_token = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(
|
|
api_key
|
|
)
|
|
|
|
if (
|
|
valid_token is not None
|
|
and isinstance(valid_token, UserAPIKeyAuth)
|
|
and valid_token.user_role == LitellmUserRoles.PROXY_ADMIN
|
|
):
|
|
if valid_token.expires is not None:
|
|
current_time = datetime.now(timezone.utc)
|
|
if isinstance(valid_token.expires, datetime):
|
|
expiry_time = valid_token.expires
|
|
else:
|
|
expiry_time = datetime.fromisoformat(valid_token.expires)
|
|
if (
|
|
expiry_time.tzinfo is None
|
|
or expiry_time.tzinfo.utcoffset(expiry_time) is None
|
|
):
|
|
expiry_time = expiry_time.replace(tzinfo=timezone.utc)
|
|
if expiry_time < current_time:
|
|
await _delete_cache_key_object(
|
|
hashed_token=hash_token(api_key),
|
|
user_api_key_cache=user_api_key_cache,
|
|
proxy_logging_obj=proxy_logging_obj,
|
|
)
|
|
raise ProxyException(
|
|
message=f"Authentication Error - Expired Key. Key Expiry time {expiry_time} and current time {current_time}",
|
|
type=ProxyErrorTypes.expired_key,
|
|
code=status.HTTP_401_UNAUTHORIZED,
|
|
param=abbreviate_api_key(api_key=api_key),
|
|
)
|
|
valid_token = update_valid_token_with_end_user_params(
|
|
valid_token=valid_token, end_user_params=end_user_params
|
|
)
|
|
valid_token.parent_otel_span = parent_otel_span
|
|
if _end_user_object is not None:
|
|
valid_token.end_user_object_permission = (
|
|
_end_user_object.object_permission
|
|
)
|
|
|
|
return valid_token
|
|
|
|
if (
|
|
valid_token is not None
|
|
and isinstance(valid_token, UserAPIKeyAuth)
|
|
and valid_token.team_id is not None
|
|
):
|
|
## UPDATE TEAM VALUES BASED ON CACHED TEAM OBJECT - allows `/team/update` values to work for cached token
|
|
try:
|
|
team_obj: LiteLLM_TeamTableCachedObj = await get_team_object(
|
|
team_id=valid_token.team_id,
|
|
prisma_client=prisma_client,
|
|
user_api_key_cache=user_api_key_cache,
|
|
parent_otel_span=parent_otel_span,
|
|
proxy_logging_obj=proxy_logging_obj,
|
|
check_cache_only=True,
|
|
)
|
|
|
|
if (
|
|
team_obj.last_refreshed_at is not None
|
|
and valid_token.last_refreshed_at is not None
|
|
and team_obj.last_refreshed_at > valid_token.last_refreshed_at
|
|
):
|
|
team_obj_dict = team_obj.__dict__
|
|
|
|
for k, v in team_obj_dict.items():
|
|
field_name = f"team_{k}"
|
|
if field_name in valid_token.__fields__:
|
|
setattr(valid_token, field_name, v)
|
|
except Exception as e:
|
|
verbose_logger.debug(
|
|
e
|
|
) # moving from .warning to .debug as it spams logs when team missing from cache.
|
|
|
|
try:
|
|
is_master_key_valid = secrets.compare_digest(api_key, master_key) # type: ignore
|
|
except Exception:
|
|
is_master_key_valid = False
|
|
|
|
## VALIDATE MASTER KEY ##
|
|
if not isinstance(master_key, str):
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail={
|
|
"Master key must be a valid string. Current type={}".format(
|
|
type(master_key)
|
|
)
|
|
},
|
|
)
|
|
|
|
if is_master_key_valid:
|
|
# Substitute a stable alias for the raw master key so neither the
|
|
# master key nor its hash propagates into spend logs, Prometheus
|
|
# /metrics labels, audit trails, rate-limit buckets, or any other
|
|
# downstream consumer of UserAPIKeyAuth.api_key.
|
|
_user_api_key_obj = await _return_user_api_key_auth_obj(
|
|
user_obj=None,
|
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
|
api_key=LITELLM_PROXY_MASTER_KEY_ALIAS,
|
|
parent_otel_span=parent_otel_span,
|
|
valid_token_dict={
|
|
**end_user_params,
|
|
"user_id": litellm_proxy_admin_name,
|
|
},
|
|
route=route,
|
|
start_time=start_time,
|
|
)
|
|
asyncio.create_task(
|
|
_cache_key_object(
|
|
hashed_token=hash_token(master_key),
|
|
user_api_key_obj=_user_api_key_obj,
|
|
user_api_key_cache=user_api_key_cache,
|
|
proxy_logging_obj=proxy_logging_obj,
|
|
)
|
|
)
|
|
|
|
_user_api_key_obj = update_valid_token_with_end_user_params(
|
|
valid_token=_user_api_key_obj, end_user_params=end_user_params
|
|
)
|
|
|
|
return _user_api_key_obj
|
|
|
|
## IF it's not a master key
|
|
## Route should not be in master_key_only_routes
|
|
if route in LiteLLMRoutes.master_key_only_routes.value: # type: ignore
|
|
raise Exception(
|
|
f"Tried to access route={route}, which is only for MASTER KEY"
|
|
)
|
|
|
|
## Check DB
|
|
|
|
if (
|
|
prisma_client is None
|
|
): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
|
|
raise ProxyException(
|
|
message="No connected db.",
|
|
type=ProxyErrorTypes.no_db_connection,
|
|
code=400,
|
|
param=None,
|
|
)
|
|
|
|
if valid_token is None:
|
|
if isinstance(
|
|
api_key, str
|
|
): # if generated token, make sure it starts with sk-.
|
|
_masked_key = (
|
|
"{}****{}".format(api_key[:4], api_key[-4:])
|
|
if len(api_key) > 8
|
|
else "****"
|
|
)
|
|
if not api_key.startswith("sk-"):
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail=(
|
|
"LiteLLM Virtual Key expected. Received={}, expected to start with 'sk-'.".format(
|
|
_masked_key
|
|
)
|
|
),
|
|
) # prevent token hashes from being used
|
|
else:
|
|
verbose_logger.warning(
|
|
"litellm.proxy.proxy_server.user_api_key_auth(): Warning - Key is not a string. Got type={}".format(
|
|
type(api_key) if api_key is not None else "None"
|
|
)
|
|
)
|
|
abbreviated_api_key = abbreviate_api_key(api_key=api_key)
|
|
if api_key.startswith("sk-"):
|
|
api_key = hash_token(token=api_key)
|
|
|
|
try:
|
|
with tracer.trace("litellm.proxy.auth.get_key_object_from_db"):
|
|
valid_token = await get_key_object(
|
|
hashed_token=api_key,
|
|
prisma_client=prisma_client,
|
|
user_api_key_cache=user_api_key_cache,
|
|
parent_otel_span=parent_otel_span,
|
|
proxy_logging_obj=proxy_logging_obj,
|
|
)
|
|
except ProxyException as e:
|
|
if e.code == 401 or e.code == "401":
|
|
e.message = "Authentication Error, Invalid proxy server token passed. Received API Key = {}, Key Hash (Token) ={}. Unable to find token in cache or `LiteLLM_VerificationTokenTable`".format(
|
|
abbreviated_api_key, api_key
|
|
)
|
|
raise e
|
|
# update end-user params on valid token
|
|
# These can change per request - it's important to update them here
|
|
valid_token.end_user_id = end_user_params.get("end_user_id")
|
|
valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit")
|
|
valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit")
|
|
valid_token.allowed_model_region = end_user_params.get(
|
|
"allowed_model_region"
|
|
)
|
|
# update key budget with temp budget increase
|
|
valid_token = _update_key_budget_with_temp_budget_increase(
|
|
valid_token
|
|
) # updating it here, allows all downstream reporting / checks to use the updated budget
|
|
|
|
user_obj: Optional[LiteLLM_UserTable] = None
|
|
valid_token_dict: dict = {}
|
|
if valid_token is not None:
|
|
# Got Valid Token from Cache, DB
|
|
# Run checks for
|
|
# 1. If token can call model
|
|
## 1a. If token can call fallback models (if client-side fallbacks given)
|
|
# 2. If user_id for this token is in budget
|
|
# 3. If the user spend within their own team is within budget
|
|
# 4. If 'user' passed to /chat/completions, /embeddings endpoint is in budget
|
|
# 5. If token is expired
|
|
# 6. If token spend is under Budget for the token
|
|
# 7. If token spend per model is under budget per model
|
|
# 8. If token spend is under team budget
|
|
# 9. If team spend is under team budget
|
|
|
|
## base case ## key is disabled
|
|
if valid_token.blocked is True:
|
|
raise Exception(
|
|
"Key is blocked. Update via `/key/unblock` if you're an admin."
|
|
)
|
|
await _enforce_key_and_fallback_model_access(
|
|
valid_token=valid_token,
|
|
request_data=request_data,
|
|
route=route,
|
|
request=request,
|
|
llm_model_list=llm_model_list,
|
|
llm_router=llm_router,
|
|
)
|
|
|
|
# Check 2. If user_id for this token is in budget - done in common_checks()
|
|
if valid_token.user_id is not None:
|
|
try:
|
|
with tracer.trace("litellm.proxy.auth.get_user_object"):
|
|
user_obj = await get_user_object(
|
|
user_id=valid_token.user_id,
|
|
prisma_client=prisma_client,
|
|
user_api_key_cache=user_api_key_cache,
|
|
user_id_upsert=False,
|
|
parent_otel_span=parent_otel_span,
|
|
proxy_logging_obj=proxy_logging_obj,
|
|
)
|
|
except Exception as e:
|
|
verbose_logger.debug(
|
|
"litellm.proxy.auth.user_api_key_auth.py::user_api_key_auth() - Unable to get user from db/cache. Setting user_obj to None. Exception received - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
user_obj = None
|
|
|
|
if (
|
|
user_obj is not None
|
|
and isinstance(user_obj.metadata, dict)
|
|
and user_obj.metadata.get("scim_active") is False
|
|
):
|
|
raise Exception(
|
|
f"User={valid_token.user_id} has been deactivated via SCIM. Keys owned by this user cannot be used."
|
|
)
|
|
|
|
# Check 2a. Check if model has zero cost - if so, skip all budget checks
|
|
model = _get_model_from_request_context(
|
|
request_data=request_data,
|
|
route=route,
|
|
request=request,
|
|
llm_router=llm_router,
|
|
)
|
|
skip_budget_checks = False
|
|
if model is not None and llm_router is not None:
|
|
from litellm.proxy.auth.auth_checks import _is_model_cost_zero
|
|
|
|
skip_budget_checks = _is_model_cost_zero(
|
|
model=model, llm_router=llm_router
|
|
)
|
|
if skip_budget_checks:
|
|
verbose_proxy_logger.info(
|
|
f"Skipping all budget checks for zero-cost model: {model}"
|
|
)
|
|
|
|
# Check 3. Check if user is in their team budget
|
|
if not skip_budget_checks and valid_token.team_member_spend is not None:
|
|
if prisma_client is not None:
|
|
_cache_key = f"{valid_token.team_id}_{valid_token.user_id}"
|
|
|
|
team_member_info = await user_api_key_cache.async_get_cache(
|
|
key=_cache_key,
|
|
model_type=LiteLLM_TeamMembership,
|
|
)
|
|
if team_member_info is None:
|
|
# read from DB
|
|
_user_id = valid_token.user_id
|
|
_team_id = valid_token.team_id
|
|
|
|
if _user_id is not None and _team_id is not None:
|
|
_db_member = await TeamMembershipRepository(
|
|
prisma_client
|
|
).table.find_first(
|
|
where={
|
|
"user_id": _user_id,
|
|
"team_id": _team_id,
|
|
}, # type: ignore
|
|
include={"litellm_budget_table": True},
|
|
)
|
|
if _db_member is not None:
|
|
team_member_info = LiteLLM_TeamMembership(
|
|
**_db_member.dict()
|
|
)
|
|
await user_api_key_cache.async_set_cache(
|
|
key=_cache_key,
|
|
value=team_member_info,
|
|
model_type=LiteLLM_TeamMembership,
|
|
ttl=5,
|
|
)
|
|
|
|
if (
|
|
team_member_info is not None
|
|
and team_member_info.litellm_budget_table is not None
|
|
):
|
|
team_member_budget = (
|
|
team_member_info.litellm_budget_table.max_budget
|
|
)
|
|
if team_member_budget is not None and team_member_budget > 0:
|
|
# Read from cross-pod counter (Redis-first) if available
|
|
from litellm.proxy.proxy_server import get_current_spend
|
|
|
|
team_member_spend = valid_token.team_member_spend
|
|
if (
|
|
valid_token.user_id is not None
|
|
and valid_token.team_id is not None
|
|
):
|
|
team_member_spend = await get_current_spend(
|
|
counter_key=f"spend:team_member:{valid_token.user_id}:{valid_token.team_id}",
|
|
fallback_spend=team_member_spend,
|
|
)
|
|
if team_member_spend > team_member_budget:
|
|
raise litellm.BudgetExceededError(
|
|
current_cost=team_member_spend,
|
|
max_budget=team_member_budget,
|
|
)
|
|
|
|
# Check 3. If token is expired
|
|
if valid_token.expires is not None:
|
|
current_time = datetime.now(timezone.utc)
|
|
if isinstance(valid_token.expires, datetime):
|
|
expiry_time = valid_token.expires
|
|
else:
|
|
expiry_time = datetime.fromisoformat(valid_token.expires)
|
|
if (
|
|
expiry_time.tzinfo is None
|
|
or expiry_time.tzinfo.utcoffset(expiry_time) is None
|
|
):
|
|
expiry_time = expiry_time.replace(tzinfo=timezone.utc)
|
|
verbose_proxy_logger.debug(
|
|
f"Checking if token expired, expiry time {expiry_time} and current time {current_time}"
|
|
)
|
|
if expiry_time < current_time:
|
|
# Token exists but is expired.
|
|
raise ProxyException(
|
|
message=f"Authentication Error - Expired Key. Key Expiry time {expiry_time} and current time {current_time}",
|
|
type=ProxyErrorTypes.expired_key,
|
|
code=status.HTTP_401_UNAUTHORIZED,
|
|
param=abbreviate_api_key(api_key=api_key),
|
|
)
|
|
|
|
if not skip_budget_checks:
|
|
with tracer.trace("litellm.proxy.auth.budget_checks"):
|
|
# Check 4. Max Budget Alert Check (runs before budget enforcement
|
|
# so multi-threshold 100% alerts fire on the request that crosses
|
|
# max_budget, before BudgetExceededError is raised below)
|
|
await _virtual_key_max_budget_alert_check(
|
|
valid_token=valid_token,
|
|
proxy_logging_obj=proxy_logging_obj,
|
|
user_obj=user_obj,
|
|
)
|
|
|
|
# Check 5. Token Spend is under budget
|
|
if RouteChecks.is_llm_api_route(route=route):
|
|
await _virtual_key_max_budget_check(
|
|
valid_token=valid_token,
|
|
proxy_logging_obj=proxy_logging_obj,
|
|
user_obj=user_obj,
|
|
)
|
|
|
|
# Check 6. Soft Budget Check
|
|
await _virtual_key_soft_budget_check(
|
|
valid_token=valid_token,
|
|
proxy_logging_obj=proxy_logging_obj,
|
|
user_obj=user_obj,
|
|
)
|
|
|
|
# Check 5. Token Model Spend is under Model budget
|
|
max_budget_per_model = valid_token.model_max_budget
|
|
current_model = _get_model_from_request_context(
|
|
request_data=request_data,
|
|
route=route,
|
|
request=request,
|
|
llm_router=llm_router,
|
|
)
|
|
current_models = _get_model_names_for_budget_checks(
|
|
model=current_model
|
|
)
|
|
|
|
if (
|
|
max_budget_per_model is not None
|
|
and isinstance(max_budget_per_model, dict)
|
|
and len(max_budget_per_model) > 0
|
|
and prisma_client is not None
|
|
and current_models
|
|
and valid_token.token is not None
|
|
):
|
|
## GET THE SPEND FOR THIS MODEL
|
|
for model_name in current_models:
|
|
await model_max_budget_limiter.is_key_within_model_budget(
|
|
user_api_key_dict=valid_token,
|
|
model=model_name,
|
|
)
|
|
|
|
# Check 5b. End-user model max budget
|
|
end_user_mmb = valid_token.end_user_model_max_budget
|
|
if (
|
|
end_user_mmb is not None
|
|
and isinstance(end_user_mmb, dict)
|
|
and len(end_user_mmb) > 0
|
|
and current_models
|
|
and valid_token.end_user_id is not None
|
|
):
|
|
for model_name in current_models:
|
|
await model_max_budget_limiter.is_end_user_within_model_budget(
|
|
end_user_id=valid_token.end_user_id,
|
|
end_user_model_max_budget=end_user_mmb,
|
|
model=model_name,
|
|
)
|
|
|
|
# Check 6: Additional Common Checks across jwt + key auth
|
|
if valid_token.team_id is not None:
|
|
try:
|
|
with tracer.trace("litellm.proxy.auth.get_team_object"):
|
|
_team_obj = await get_team_object(
|
|
team_id=valid_token.team_id,
|
|
prisma_client=prisma_client,
|
|
user_api_key_cache=user_api_key_cache,
|
|
parent_otel_span=parent_otel_span,
|
|
proxy_logging_obj=proxy_logging_obj,
|
|
)
|
|
except HTTPException:
|
|
_team_obj = LiteLLM_TeamTableCachedObj(
|
|
team_id=valid_token.team_id,
|
|
max_budget=valid_token.team_max_budget,
|
|
soft_budget=valid_token.team_soft_budget,
|
|
spend=valid_token.team_spend,
|
|
tpm_limit=valid_token.team_tpm_limit,
|
|
rpm_limit=valid_token.team_rpm_limit,
|
|
blocked=valid_token.team_blocked,
|
|
models=valid_token.team_models,
|
|
metadata=valid_token.team_metadata,
|
|
object_permission_id=valid_token.team_object_permission_id,
|
|
)
|
|
else:
|
|
_team_obj = None
|
|
|
|
if _team_obj is not None:
|
|
valid_token.team_object_permission = _team_obj.object_permission
|
|
# Keep team_metadata in sync with the freshly fetched team so that
|
|
# guardrails (or any other metadata) added after the key was cached
|
|
# are picked up on subsequent requests without a cache eviction.
|
|
valid_token.team_metadata = _team_obj.metadata
|
|
else:
|
|
valid_token.team_object_permission = None
|
|
|
|
# Only cache when the key is a real team_id (non-team keys must not use key=None).
|
|
if valid_token.team_id is not None and _team_obj is not None:
|
|
await user_api_key_cache.async_set_cache(
|
|
key=valid_token.team_id,
|
|
value=_team_obj,
|
|
model_type=LiteLLM_TeamTableCachedObj,
|
|
) # save team table in cache - used for tpm/rpm limiting - tpm_rpm_limiter.py
|
|
|
|
# Fetch project object if key belongs to a project
|
|
_project_obj = None
|
|
if valid_token.project_id is not None:
|
|
_project_obj = await get_project_object(
|
|
project_id=valid_token.project_id,
|
|
prisma_client=prisma_client,
|
|
user_api_key_cache=user_api_key_cache,
|
|
proxy_logging_obj=proxy_logging_obj,
|
|
)
|
|
if _project_obj is not None:
|
|
valid_token.project_metadata = _project_obj.metadata
|
|
valid_token.project_alias = _project_obj.project_alias
|
|
|
|
global_proxy_spend = None
|
|
if (
|
|
litellm.max_budget > 0 and prisma_client is not None
|
|
): # user set proxy max budget
|
|
cache_key = "{}:spend".format(litellm_proxy_admin_name)
|
|
with tracer.trace("litellm.proxy.auth.get_global_proxy_spend"):
|
|
global_proxy_spend = (
|
|
await _fetch_global_spend_with_event_coordination(
|
|
cache_key=cache_key,
|
|
user_api_key_cache=user_api_key_cache,
|
|
prisma_client=prisma_client,
|
|
)
|
|
)
|
|
|
|
if global_proxy_spend is not None:
|
|
call_info = CallInfo(
|
|
token=valid_token.token,
|
|
spend=global_proxy_spend,
|
|
max_budget=litellm.max_budget,
|
|
user_id=litellm_proxy_admin_name,
|
|
team_id=valid_token.team_id,
|
|
event_group=Litellm_EntityType.PROXY,
|
|
)
|
|
asyncio.create_task(
|
|
proxy_logging_obj.budget_alerts(
|
|
type="proxy_budget",
|
|
user_info=call_info,
|
|
)
|
|
)
|
|
# Token passed all checks
|
|
if valid_token is None:
|
|
raise HTTPException(401, detail="Invalid API key")
|
|
if valid_token.token is None:
|
|
raise HTTPException(401, detail="Invalid API key, no token associated")
|
|
api_key = valid_token.token
|
|
|
|
# Add hashed token to cache
|
|
asyncio.create_task(
|
|
_cache_key_object(
|
|
hashed_token=api_key,
|
|
user_api_key_obj=valid_token,
|
|
user_api_key_cache=user_api_key_cache,
|
|
proxy_logging_obj=proxy_logging_obj,
|
|
)
|
|
)
|
|
|
|
valid_token_dict = valid_token.model_dump(exclude_none=True)
|
|
valid_token_dict.pop("token", None)
|
|
|
|
if _end_user_object is not None:
|
|
valid_token_dict.update(end_user_params)
|
|
valid_token_dict["end_user_object_permission"] = (
|
|
_end_user_object.object_permission
|
|
)
|
|
|
|
# check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions
|
|
# sso/login, ui/login, /key functions and /user functions
|
|
# this will never be allowed to call /chat/completions
|
|
|
|
if valid_token is None:
|
|
# No token was found when looking up in the DB
|
|
raise Exception("Invalid proxy server token passed")
|
|
if valid_token_dict is not None:
|
|
return await _return_user_api_key_auth_obj(
|
|
user_obj=user_obj,
|
|
api_key=api_key,
|
|
parent_otel_span=parent_otel_span,
|
|
valid_token_dict=valid_token_dict,
|
|
route=route,
|
|
start_time=start_time,
|
|
)
|
|
except Exception as e:
|
|
return await UserAPIKeyAuthExceptionHandler._handle_authentication_error(
|
|
e=e,
|
|
request=request,
|
|
request_data=request_data,
|
|
route=route,
|
|
parent_otel_span=parent_otel_span,
|
|
api_key=api_key,
|
|
resolved_identity=valid_token,
|
|
)
|
|
|
|
|
|
async def _safe_fetch(label: str, awaitable):
|
|
"""Run an awaitable and return its result. Re-raises authentication /
|
|
authorization failures (HTTPException, ProxyException,
|
|
BudgetExceededError) so they propagate to the caller.
|
|
Other exceptions (e.g. transient DB errors fetching context) are
|
|
swallowed with a debug log and ``None`` is returned so
|
|
``common_checks`` can still run against whatever limits are recorded
|
|
directly on the token.
|
|
"""
|
|
try:
|
|
return await awaitable
|
|
except (HTTPException, ProxyException, litellm.BudgetExceededError) as e:
|
|
verbose_proxy_logger.debug(
|
|
"centralized auth: %s fetch failed (%s: %s)",
|
|
label,
|
|
type(e).__name__,
|
|
e,
|
|
)
|
|
raise
|
|
except Exception as e:
|
|
verbose_proxy_logger.debug(
|
|
"centralized auth: %s fetch swallowed (%s: %s)",
|
|
label,
|
|
type(e).__name__,
|
|
e,
|
|
)
|
|
return None
|
|
|
|
|
|
def _team_obj_from_token(valid_token: UserAPIKeyAuth) -> LiteLLM_TeamTableCachedObj:
    """Build a ``LiteLLM_TeamTableCachedObj`` from the team fields on the token.

    Copies the ``team_*`` attributes already present on ``valid_token`` into a
    new cached-team object. Callers must only invoke this when
    ``valid_token.team_id`` is not ``None``; that precondition is asserted.
    """
    team_id = valid_token.team_id
    assert team_id is not None

    # Team-table field name -> value carried on the token.
    team_fields = {
        "team_id": team_id,
        "max_budget": valid_token.team_max_budget,
        "soft_budget": valid_token.team_soft_budget,
        "spend": valid_token.team_spend,
        "tpm_limit": valid_token.team_tpm_limit,
        "rpm_limit": valid_token.team_rpm_limit,
        "blocked": valid_token.team_blocked,
        "models": valid_token.team_models,
        "metadata": valid_token.team_metadata,
        "object_permission_id": valid_token.team_object_permission_id,
    }
    return LiteLLM_TeamTableCachedObj(**team_fields)
|
|
|
|
|
|
@tracer.wrap()
async def _run_centralized_common_checks(  # noqa: PLR0915
    user_api_key_auth_obj: UserAPIKeyAuth,
    request: Request,
    request_data: dict,
    route: str,
) -> None:
    """Run ``common_checks`` once for the request, then reserve budget.

    Invoked from the ``user_api_key_auth`` wrapper, whichever
    ``_user_api_key_auth_builder`` path produced ``user_api_key_auth_obj``.

    The function returns early, without running any check, when:

    - ``route`` is a public route (built-in or additional public routes);
    - ``route`` equals the ``path`` of a configured pass-through endpoint
      whose ``auth`` flag is not ``True``;
    - ``master_key`` is unset and none of ``enable_jwt_auth``,
      ``enable_oauth2_auth`` or ``enable_oauth2_proxy_auth`` is enabled;
    - ``user_custom_auth`` is configured and
      ``custom_auth_run_common_checks`` is not enabled.

    Otherwise it fetches the team, user, project, end-user and global-spend
    context concurrently, passes it to ``common_checks`` and finally calls
    ``_reserve_budget_after_common_checks``. ``PROXY_ADMIN`` tokens are not
    exempt from ``common_checks``; for them the ``user_object`` handed to the
    checks is forced to an admin row.

    Side effects:
        - ``project_metadata`` / ``project_alias`` are set on the token when a
          project object was fetched.
        - ``request_data`` is passed through
          ``LiteLLMProxyRequestSetup.apply_client_tag_policy_pre_auth`` before
          the checks run.
        - ``budget_reservation`` is set on the token by
          ``_reserve_budget_after_common_checks``.

    Raises:
        ProxyException, litellm.BudgetExceededError: when one of the context
            fetches produced one of these.
        Exception: anything raised by ``common_checks`` or by the budget
            reservation.
    """
    from litellm.proxy.proxy_server import (
        general_settings,
        litellm_proxy_admin_name,
        llm_router,
        master_key,
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
        user_custom_auth,
    )

    # Public routes (e.g. /health/liveness) are exempt from auth in the
    # builder; applying authz here would make unauthenticated callers such as
    # k8s readiness probes receive a 401.
    is_public_route = (
        route in LiteLLMRoutes.public_routes.value  # type: ignore[attr-defined]
        or route_in_additonal_public_routes(current_route=route)
    )
    if is_public_route:
        return

    # Pass-through endpoints configured with ``auth`` not set to True are
    # explicitly unauthenticated: the builder returns an empty UserAPIKeyAuth()
    # and common_checks would reject that empty token as admin-only.
    configured_pass_throughs = general_settings.get("pass_through_endpoints", None)
    if configured_pass_throughs is not None and any(
        isinstance(entry, dict)
        and entry.get("path", "") == route
        and entry.get("auth") is not True
        for entry in configured_pass_throughs
    ):
        return

    # No-auth dev mode: no master key and no JWT / OAuth2 / OAuth2-proxy auth.
    # The proxy is unauthenticated by configuration, so authz is skipped. As
    # soon as any of those authn mechanisms is enabled, authz must run.
    authn_flags = (
        "enable_jwt_auth",
        "enable_oauth2_auth",
        "enable_oauth2_proxy_auth",
    )
    if master_key is None and not any(
        general_settings.get(flag, False) for flag in authn_flags
    ):
        return

    # Custom auth opts in to common_checks explicitly.
    if user_custom_auth is not None and not general_settings.get(
        "custom_auth_run_common_checks", False
    ):
        return

    parent_otel_span = user_api_key_auth_obj.parent_otel_span

    # The builder normally resolves the end-user id and stores it on the
    # token. Only extract it here when it is missing (e.g. when this function
    # is called in isolation, such as from a direct unit test).
    end_user_id = user_api_key_auth_obj.end_user_id
    if end_user_id is None:
        raw_end_user_id = get_end_user_id_from_request_body(
            request_data, _safe_get_request_headers(request)
        )
        end_user_id = await resolve_and_validate_end_user_id(
            raw_end_user_id=raw_end_user_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
            route=route,
        )

    # One awaitable per context object. A no-op placeholder is used whenever a
    # fetch is unnecessary so the gather() result stays positional.
    team_fetch = _safe_fetch(
        "team",
        (
            get_team_object(
                team_id=user_api_key_auth_obj.team_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )
            if user_api_key_auth_obj.team_id is not None
            else _noop_none()
        ),
    )
    user_fetch = _safe_fetch(
        "user",
        (
            get_user_object(
                user_id=user_api_key_auth_obj.user_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                user_id_upsert=False,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )
            if user_api_key_auth_obj.user_id is not None
            else _noop_none()
        ),
    )
    project_fetch = _safe_fetch(
        "project",
        (
            get_project_object(
                project_id=user_api_key_auth_obj.project_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                proxy_logging_obj=proxy_logging_obj,
            )
            if user_api_key_auth_obj.project_id is not None
            else _noop_none()
        ),
    )
    end_user_fetch = _safe_fetch(
        "end_user",
        (
            get_end_user_object(
                end_user_id=end_user_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
                route=route,
            )
            if end_user_id
            else _noop_none()
        ),
    )
    global_spend_fetch = _safe_fetch(
        "global_spend",
        get_global_proxy_spend(
            litellm_proxy_admin_name=litellm_proxy_admin_name,
            user_api_key_cache=user_api_key_cache,
            prisma_client=prisma_client,
            token=user_api_key_auth_obj.token or "",
            proxy_logging_obj=proxy_logging_obj,
        ),
    )

    # Per-fetch error isolation: ``return_exceptions=True`` keeps one failing
    # fetch (e.g. a 404 from get_team_object for a deleted team) from wiping
    # out the other context objects; each result gets its own fallback below.
    (
        team_result,
        user_result,
        project_result,
        end_user_result,
        global_spend_result,
    ) = await asyncio.gather(
        team_fetch,
        user_fetch,
        project_fetch,
        end_user_fetch,
        global_spend_fetch,
        return_exceptions=True,
    )

    # ProxyException / BudgetExceededError are authorization failures and are
    # propagated. An HTTPException is treated as "fall back" instead.
    fetch_outcomes = (
        team_result,
        user_result,
        project_result,
        end_user_result,
        global_spend_result,
    )
    for outcome in fetch_outcomes:
        if isinstance(outcome, (ProxyException, litellm.BudgetExceededError)):
            raise outcome

    # The narrowing checks use BaseException (not HTTPException) so mypy can
    # narrow ``Any | BaseException`` to the typed object in the other branch.
    team_object: Optional[LiteLLM_TeamTableCachedObj]
    if not isinstance(team_result, BaseException):
        team_object = team_result
    elif user_api_key_auth_obj.team_id is not None:
        # Token-derived fallback; _team_obj_from_token requires a team_id.
        team_object = _team_obj_from_token(user_api_key_auth_obj)
    else:
        team_object = None

    user_object: Optional[LiteLLM_UserTable] = (
        user_result if not isinstance(user_result, BaseException) else None
    )
    project_object: Optional[LiteLLM_ProjectTableCachedObj] = (
        project_result if not isinstance(project_result, BaseException) else None
    )
    end_user_object: Optional[LiteLLM_EndUserTable] = (
        end_user_result if not isinstance(end_user_result, BaseException) else None
    )
    global_proxy_spend: Optional[float] = (
        global_spend_result
        if not isinstance(global_spend_result, BaseException)
        else None
    )

    # common_checks identifies an admin through ``user_object``, not the
    # token. For tokens that carry PROXY_ADMIN the token is the source of
    # truth, so the admin user_object is forced even when a DB row with a
    # different role was fetched; only that row's spend is carried over.
    if user_api_key_auth_obj.user_role == LitellmUserRoles.PROXY_ADMIN:
        carried_spend = user_object.spend if user_object is not None else 0.0
        user_object = LiteLLM_UserTable(
            user_id=user_api_key_auth_obj.user_id or litellm_proxy_admin_name,
            user_role=LitellmUserRoles.PROXY_ADMIN,
            spend=carried_spend,
        )

    if project_object is not None:
        user_api_key_auth_obj.project_metadata = project_object.metadata
        user_api_key_auth_obj.project_alias = project_object.project_alias

    skip_budget_checks = _should_skip_budget_checks(
        request_data=request_data,
        route=route,
        request=request,
        llm_router=llm_router,
    )

    # Merge x-litellm-tags into request_data BEFORE common_checks: the tag
    # budget check only inspects request_data, so header-supplied tags would
    # otherwise bypass tag-budget enforcement.
    LiteLLMProxyRequestSetup.apply_client_tag_policy_pre_auth(
        request=request,
        request_data=request_data,
        user_api_key_dict=user_api_key_auth_obj,
    )

    _ = await common_checks(
        request=request,
        request_body=request_data,
        team_object=team_object,
        user_object=user_object,
        end_user_object=end_user_object,
        general_settings=general_settings,
        global_proxy_spend=global_proxy_spend,
        route=route,
        llm_router=llm_router,
        proxy_logging_obj=proxy_logging_obj,
        valid_token=user_api_key_auth_obj,
        skip_budget_checks=skip_budget_checks,
        project_object=project_object,
    )

    await _reserve_budget_after_common_checks(
        user_api_key_auth_obj=user_api_key_auth_obj,
        request_data=request_data,
        route=route,
        llm_router=llm_router,
        team_object=team_object,
        user_object=user_object,
        end_user_id=end_user_id,
        end_user_object=end_user_object,
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
        skip_budget_checks=skip_budget_checks,
    )
|
|
|
|
|
|
async def _noop_none() -> None:
|
|
"""Sentinel coroutine for asyncio.gather when a fetch is unnecessary
|
|
(e.g. token has no team_id). Keeps the result tuple positional."""
|
|
return None
|
|
|
|
|
|
async def _reserve_budget_after_common_checks(
|
|
user_api_key_auth_obj: UserAPIKeyAuth,
|
|
request_data: dict,
|
|
route: str,
|
|
llm_router: Optional[Any],
|
|
team_object: Optional[LiteLLM_TeamTableCachedObj],
|
|
user_object: Optional[LiteLLM_UserTable],
|
|
prisma_client: Optional[PrismaClient],
|
|
user_api_key_cache: UserApiKeyCache,
|
|
proxy_logging_obj: ProxyLogging,
|
|
skip_budget_checks: bool,
|
|
end_user_id: Optional[str] = None,
|
|
end_user_object: Optional[LiteLLM_EndUserTable] = None,
|
|
) -> None:
|
|
user_api_key_auth_obj.budget_reservation = None
|
|
if skip_budget_checks:
|
|
return
|
|
|
|
from litellm.proxy.spend_tracking.budget_reservation import (
|
|
reserve_budget_for_request,
|
|
)
|
|
|
|
user_api_key_auth_obj.budget_reservation = await reserve_budget_for_request(
|
|
request_body=request_data,
|
|
route=route,
|
|
llm_router=llm_router,
|
|
valid_token=user_api_key_auth_obj,
|
|
team_object=team_object,
|
|
user_object=user_object,
|
|
prisma_client=prisma_client,
|
|
user_api_key_cache=user_api_key_cache,
|
|
proxy_logging_obj=proxy_logging_obj,
|
|
end_user_id=end_user_id,
|
|
end_user_object=end_user_object,
|
|
)
|
|
|
|
|
|
def _should_skip_budget_checks(
    request_data: dict,
    route: str,
    request: Optional[Request],
    llm_router: Optional[Any],
) -> bool:
    """Tell whether budget checks can be skipped for this request.

    Resolves the requested model via ``_get_model_from_request_context`` and
    returns the result of ``_is_model_cost_zero`` for it. Returns ``False``
    when no model could be resolved or when there is no router.
    """
    requested_model = _get_model_from_request_context(
        request_data=request_data,
        route=route,
        request=request,
        llm_router=llm_router,
    )
    if requested_model is None or llm_router is None:
        return False
    return _is_model_cost_zero(model=requested_model, llm_router=llm_router)
|
|
|
|
|
|
@tracer.wrap()
async def user_api_key_auth(
    request: Request,
    api_key: str = fastapi.Security(api_key_header),
    azure_api_key_header: str = fastapi.Security(azure_api_key_header),
    anthropic_api_key_header: Optional[str] = fastapi.Security(
        anthropic_api_key_header
    ),
    google_ai_studio_api_key_header: Optional[str] = fastapi.Security(
        google_ai_studio_api_key_header
    ),
    azure_apim_header: Optional[str] = fastapi.Security(azure_apim_header),
    custom_litellm_key_header: Optional[str] = fastapi.Security(
        custom_litellm_key_header
    ),
) -> UserAPIKeyAuth:
    """
    Parent function to authenticate a user api key / jwt token.

    Steps, in order: ensure a parent OTel span is on ``request.state``, read
    the request body, build the ``UserAPIKeyAuth`` object via
    ``_user_api_key_auth_builder``, apply the disabled-route check, run the
    centralized ``common_checks``, back-fill a missing ``end_user_id`` from
    the request, seed the request identity and record the normalized route.

    Returns:
        The resolved ``UserAPIKeyAuth``. When the centralized checks raise,
        the result of
        ``UserAPIKeyAuthExceptionHandler._handle_authentication_error`` is
        returned instead.
    """
    # The SERVER span must exist on request.state BEFORE the body is read:
    # _read_request_body can raise ProxyException for malformed JSON, and the
    # exception handler needs a span to close for the trace to be exported.
    _ensure_parent_otel_span_on_request_state(request)

    request_data = populate_request_with_path_params(
        request_data=await _read_request_body(request=request),
        request=request,
    )
    route: str = get_request_route(request=request)

    # The auth phase runs inside a live ``auth`` span so the DB lookups it
    # triggers nest under it instead of flattening onto the server span.
    # No-op when OTel V2 isn't active.
    # NOTE(review): the span is assumed to cover everything up to (but not
    # including) seed_request_identity — confirm against the upstream file.
    with phase_span(f"auth {route}"):
        user_api_key_auth_obj = await _user_api_key_auth_builder(
            request=request,
            api_key=api_key,
            azure_api_key_header=azure_api_key_header,
            anthropic_api_key_header=anthropic_api_key_header,
            google_ai_studio_api_key_header=google_ai_studio_api_key_header,
            azure_apim_header=azure_apim_header,
            request_data=request_data,
            custom_litellm_key_header=custom_litellm_key_header,
        )
        user_api_key_auth_obj.budget_reservation = None

        ## ENSURE DISABLE ROUTE WORKS ACROSS ALL USER AUTH FLOWS ##
        RouteChecks.should_call_route(
            route=route, valid_token=user_api_key_auth_obj, request=request
        )

        # Single authorization point: builder paths MUST NOT call
        # common_checks. Failures go through the same exception handler the
        # builder uses, so they surface as ProxyException consistently.
        try:
            await _run_centralized_common_checks(
                user_api_key_auth_obj=user_api_key_auth_obj,
                request=request,
                request_data=request_data,
                route=route,
            )
        except Exception as auth_error:
            return await UserAPIKeyAuthExceptionHandler._handle_authentication_error(
                e=auth_error,
                request=request,
                request_data=request_data,
                route=route,
                parent_otel_span=user_api_key_auth_obj.parent_otel_span,
                api_key=api_key,
                resolved_identity=user_api_key_auth_obj,
            )

        # Defense-in-depth: several early-return paths in the builder skip the
        # end-user resolution block. When no end_user_id was set, extract it
        # from the request so spend logs are still attributed correctly.
        if user_api_key_auth_obj.end_user_id is None:
            from litellm.proxy.proxy_server import (
                prisma_client,
                proxy_logging_obj,
                user_api_key_cache,
            )

            candidate_end_user_id = get_end_user_id_from_request_body(
                request_data, _safe_get_request_headers(request)
            )
            if candidate_end_user_id is not None:
                validated_end_user_id = await resolve_and_validate_end_user_id(
                    raw_end_user_id=candidate_end_user_id,
                    prisma_client=prisma_client,
                    user_api_key_cache=user_api_key_cache,
                    parent_otel_span=user_api_key_auth_obj.parent_otel_span,
                    proxy_logging_obj=proxy_logging_obj,
                    route=route,
                )
                if validated_end_user_id is not None:
                    user_api_key_auth_obj.end_user_id = validated_end_user_id

    # Identity is resolved. It is seeded AFTER the auth span closes so the
    # Baggage persists on the request task and every post-auth span inherits
    # team/key/user.
    requested_model = (
        request_data.get("model") if isinstance(request_data, dict) else None
    )
    seed_request_identity(user_api_key_auth_obj, model=requested_model)

    user_api_key_auth_obj.request_route = normalize_request_route(route)
    return user_api_key_auth_obj
|
|
|
|
|
|
async def _return_user_api_key_auth_obj(
    user_obj: Optional[LiteLLM_UserTable],
    api_key: str,
    parent_otel_span: Optional[Span],
    valid_token_dict: dict,
    route: str,
    start_time: datetime,
    user_role: Optional[LitellmUserRoles] = None,
) -> UserAPIKeyAuth:
    """Assemble the final ``UserAPIKeyAuth`` and log auth latency.

    A background task reports the auth duration (``start_time`` to now) to
    the service logger. The returned object is built from, in increasing
    precedence: ``api_key`` / ``parent_otel_span`` / the resolved role, the
    entries of ``valid_token_dict``, and, when ``user_obj`` is given, that
    user's rate limits, email, spend and max budget.

    The role is the first truthy value of ``user_role``, the role derived
    from ``user_obj``, and ``INTERNAL_USER``. If ``user_obj`` is a proxy
    admin, the role is forced to ``PROXY_ADMIN``.
    """
    finished_at = datetime.now()
    elapsed_seconds = finished_at.timestamp() - start_time.timestamp()

    asyncio.create_task(
        user_api_key_service_logger_obj.async_service_success_hook(
            service=ServiceTypes.AUTH,
            call_type=route,
            start_time=start_time,
            end_time=finished_at,
            duration=elapsed_seconds,
            parent_otel_span=parent_otel_span,
        )
    )

    resolved_role = (
        user_role or _get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
    )

    # Keys from valid_token_dict come last, so they override the three
    # defaults listed before them.
    auth_kwargs = {
        "api_key": api_key,
        "parent_otel_span": parent_otel_span,
        "user_role": resolved_role,
        **valid_token_dict,
    }

    if user_obj is not None:
        auth_kwargs.update(
            user_tpm_limit=user_obj.tpm_limit,
            user_rpm_limit=user_obj.rpm_limit,
            user_email=user_obj.user_email,
            user_spend=getattr(user_obj, "spend", None),
            user_max_budget=getattr(user_obj, "max_budget", None),
        )
        if _is_user_proxy_admin(user_obj=user_obj):
            auth_kwargs["user_role"] = LitellmUserRoles.PROXY_ADMIN

    return UserAPIKeyAuth(**auth_kwargs)
|
|
|
|
|
|
def get_api_key_from_custom_header(
    request: Request, custom_litellm_key_header_name: str
) -> str:
    """
    Get the LiteLLM virtual key from a custom request header.

    Args:
        request (Request): Request whose headers are searched.
        custom_litellm_key_header_name (str): Name of the header carrying the
            key. Matched case-insensitively.

    Returns:
        str: The header value as processed by ``_get_bearer_token``, or an
        empty string when the header is absent or empty.
    """
    # use this as the virtual key passed to litellm proxy
    custom_litellm_key_header_name = custom_litellm_key_header_name.lower()
    _headers = {k.lower(): v for k, v in request.headers.items()}

    # Log header NAMES only. The values include credentials (this very key,
    # Authorization, cookies, ...) and must not be written to logs in clear
    # text — the success log below abbreviates the key for the same reason.
    verbose_proxy_logger.debug(
        "searching for custom_litellm_key_header_name= %s, in headers=%s",
        custom_litellm_key_header_name,
        list(_headers),
    )

    custom_api_key = _headers.get(custom_litellm_key_header_name)
    if not custom_api_key:
        # error(), not exception(): there is no active exception here, so
        # exception() would append a meaningless "NoneType: None" traceback.
        verbose_proxy_logger.error(
            f"No LiteLLM Virtual Key pass. Please set header={custom_litellm_key_header_name}: Bearer <api_key>"
        )
        return ""

    api_key: str = _get_bearer_token(api_key=custom_api_key)
    verbose_proxy_logger.debug(
        "Found custom API key using header: {}, setting api_key={}".format(
            custom_litellm_key_header_name, abbreviate_api_key(api_key)
        )
    )
    return api_key
|
|
|
|
|
|
def _get_temp_budget_increase(valid_token: UserAPIKeyAuth):
    """
    Return the key's temporary budget increase if it has not expired yet.

    Reads ``temp_budget_increase`` and ``temp_budget_expiry`` from
    ``valid_token.metadata``. The expiry may be an ISO-8601 string or a
    ``datetime``.

    Args:
        valid_token: Token whose ``metadata`` mapping is inspected.

    Returns:
        The stored ``temp_budget_increase`` value when both metadata keys are
        present and the expiry lies in the future, otherwise ``None``.
    """
    valid_token_metadata = valid_token.metadata or {}
    if (
        "temp_budget_increase" in valid_token_metadata
        and "temp_budget_expiry" in valid_token_metadata
    ):
        expiry = valid_token_metadata["temp_budget_expiry"]
        if not isinstance(expiry, datetime):
            expiry = datetime.fromisoformat(expiry)
        # Compare aware with aware and naive with naive: mixing the two raises
        # TypeError. A naive expiry keeps the original local-clock comparison.
        if expiry.utcoffset() is not None:
            now = datetime.now(expiry.tzinfo)
        else:
            now = datetime.now()
        if expiry > now:
            return valid_token_metadata["temp_budget_increase"]
    return None
|
|
|
|
|
|
def _update_key_budget_with_temp_budget_increase(
    valid_token: UserAPIKeyAuth,
) -> UserAPIKeyAuth:
    """
    Add any active temporary budget increase to the key's ``max_budget``.

    Keys without a ``max_budget`` are returned untouched. Otherwise the value
    from ``_get_temp_budget_increase`` (treated as ``0.0`` when absent) is
    added to ``valid_token.max_budget`` in place.

    Returns:
        UserAPIKeyAuth: the same ``valid_token`` object that was passed in.
    """
    if valid_token.max_budget is not None:
        extra_budget = _get_temp_budget_increase(valid_token) or 0.0
        valid_token.max_budget = valid_token.max_budget + extra_budget
    return valid_token
|
|
|
|
|
|
async def _lookup_end_user_and_apply_budget(
    valid_token: UserAPIKeyAuth,
    route: str,
    parent_otel_span: Optional[Span],
    prisma_client,
    user_api_key_cache,
    proxy_logging_obj,
):
    """
    Look up end_user from DB and apply budget limits to valid_token.

    When an end-user record is found, its ``allowed_model_region`` and (if
    present) its budget table are applied to the token. When no record is
    found but ``litellm.max_end_user_budget_id`` is set, the default end-user
    budget is applied instead.

    Returns:
        tuple: ``(valid_token, end_user_object)``; ``end_user_object`` is
        ``None`` when no record was found or the lookup failed.

    Raises:
        litellm.BudgetExceededError: re-raised unchanged. Any other exception
        is logged at debug level and swallowed.
    """
    end_user_object = None
    try:
        end_user_object = await get_end_user_object(
            end_user_id=valid_token.end_user_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
            route=route,
        )
        if end_user_object is not None:
            params = {
                "end_user_id": valid_token.end_user_id,
                "allowed_model_region": end_user_object.allowed_model_region,
            }
            budget_table = end_user_object.litellm_budget_table
            if budget_table is not None:
                _apply_budget_limits_to_end_user_params(
                    end_user_params=params,
                    budget_info=budget_table,
                    end_user_id=valid_token.end_user_id or "",
                )
            valid_token = update_valid_token_with_end_user_params(
                valid_token=valid_token, end_user_params=params
            )
        elif litellm.max_end_user_budget_id is not None:
            from litellm.proxy.auth.auth_checks import get_default_end_user_budget

            fallback_budget = await get_default_end_user_budget(
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
            )
            if fallback_budget is not None:
                params = {"end_user_id": valid_token.end_user_id}
                _apply_budget_limits_to_end_user_params(
                    end_user_params=params,
                    budget_info=fallback_budget,
                    end_user_id=valid_token.end_user_id or "",
                )
                valid_token = update_valid_token_with_end_user_params(
                    valid_token=valid_token, end_user_params=params
                )
    except litellm.BudgetExceededError:
        # Budget violations must reach the caller; everything else is best-effort.
        raise
    except Exception as e:
        verbose_proxy_logger.debug(f"Unable to find user in db. Error - {str(e)}")
    return valid_token, end_user_object
|
|
|
|
|
|
async def _enforce_key_and_fallback_model_access(
|
|
*,
|
|
valid_token: UserAPIKeyAuth,
|
|
request_data: dict,
|
|
route: str,
|
|
request: Optional[Request],
|
|
llm_model_list: Optional[list],
|
|
llm_router: Optional[Any],
|
|
) -> None:
|
|
"""
|
|
Key-level model allowlist and client fallbacks (same as standard auth).
|
|
Not included in common_checks — common_checks enforces team/user/project model access only.
|
|
"""
|
|
# Three cases for the requested model, chosen from the token:
# - valid_token.config is not an empty dict: its "model_list" is only
#   debug-logged; the requested model is not checked here.
# - valid_token.models contains "all-team-models": nothing is checked here.
# - otherwise: the model resolved by _get_model_from_request_context() is
#   passed to can_key_call_model() when it is not None.
config = valid_token.config
|
|
|
|
if config != {}:
|
|
model_list = config.get("model_list", [])
|
|
new_model_list = model_list
|
|
verbose_proxy_logger.debug(f"\n new llm router model list {new_model_list}")
|
|
elif (
|
|
isinstance(valid_token.models, list) and "all-team-models" in valid_token.models
|
|
):
|
|
pass
|
|
else:
|
|
model = _get_model_from_request_context(
|
|
request_data=request_data,
|
|
route=route,
|
|
request=request,
|
|
llm_router=llm_router,
|
|
)
|
|
|
|
if model is not None:
|
|
await can_key_call_model(
|
|
model=model,
|
|
llm_model_list=llm_model_list,
|
|
valid_token=valid_token,
|
|
llm_router=llm_router,
|
|
)
|
|
|
|
# NOTE(review): indentation is not preserved in this copy of the file, so it
# is unclear whether the fallback validation below is nested under the
# `else:` branch above or runs for every request — confirm against the
# upstream file before relying on either reading.
# Validate every fallback model name reachable by this request.
|
|
# All three fields (``fallbacks``, ``context_window_fallbacks``,
|
|
# ``content_policy_fallbacks``) are forwarded to the router as
|
|
# per-request kwargs whether they appear at the top level of
|
|
# ``request_data`` or nested under ``router_settings_override``.
|
|
# Both surfaces must be validated against the API key's model
|
|
# allowlist or a caller can smuggle a restricted model. VERIA-44.
|
|
fallback_names: List[str] = []
|
|
override_settings = request_data.get("router_settings_override")
|
|
# For each fallback field, collect names from the top-level request first,
# then from router_settings_override when that value is a dict.
for _fb_key in ROUTER_FALLBACK_FIELDS:
|
|
fallback_names.extend(
|
|
iter_router_fallback_model_names(request_data.get(_fb_key))
|
|
)
|
|
if isinstance(override_settings, dict):
|
|
fallback_names.extend(
|
|
iter_router_fallback_model_names(override_settings.get(_fb_key))
|
|
)
|
|
|
|
# Every distinct fallback name goes through can_key_call_model() and then
# is_valid_fallback_model(); either call may raise to reject the request.
for _name in dict.fromkeys(fallback_names): # dedupe, preserve order
|
|
await can_key_call_model(
|
|
model=_name,
|
|
llm_model_list=llm_model_list,
|
|
valid_token=valid_token,
|
|
llm_router=llm_router,
|
|
)
|
|
await is_valid_fallback_model(
|
|
model=_name,
|
|
llm_router=llm_router,
|
|
user_model=None,
|
|
)
|
|
|
|
|
|
# Request keys whose values may name fallback models. Kept as one tuple so the
# key-level model-access check scans the same set of fields everywhere.
ROUTER_FALLBACK_FIELDS: Tuple[str, ...] = (
    "fallbacks",
    "context_window_fallbacks",
    "content_policy_fallbacks",
)
|
|
|
|
|
|
def iter_router_fallback_model_names(fallbacks: Any) -> Iterator[str]:
    """Yield leaf model names from any of the supported fallbacks shapes.

    Accepted list entries:

    * a plain ``str`` model name;
    * a dict with a string ``"model"`` value (only that name is yielded);
    * a router-config dict ``{primary: [fallback, ...]}`` whose list values
      hold ``str`` names or ``{"model": str}`` dicts.

    Anything else — including a non-list ``fallbacks`` argument — yields
    nothing.
    """
    if not isinstance(fallbacks, list):
        return
    for item in fallbacks:
        if isinstance(item, str):
            yield item
            continue
        if not isinstance(item, dict):
            continue
        direct_name = item.get("model")
        if isinstance(direct_name, str):
            # Simple ``{"model": name}`` shape: ignore the other keys.
            yield direct_name
            continue
        # Router-config shape: every list value is a group of fallbacks.
        for group in item.values():
            if not isinstance(group, list):
                continue
            for member in group:
                if isinstance(member, str):
                    yield member
                elif isinstance(member, dict):
                    nested_name = member.get("model")
                    if isinstance(nested_name, str):
                        yield nested_name
|
|
|
|
|
|
async def _run_post_custom_auth_checks(
    valid_token: UserAPIKeyAuth,
    request: Request,
    request_data: dict,
    route: str,
    parent_otel_span: Optional[Span],
) -> UserAPIKeyAuth:
    """
    Run the follow-up checks on a token after custom auth has produced it.

    In order: end-user lookup and budget limits, end-user budget enforcement
    (only when ``custom_auth_run_common_checks`` is off), key expiry,
    key/fallback model access (only when ``custom_auth_run_common_checks`` is
    on), per-model budgets for the key and for the end user, and project
    metadata enrichment.

    Returns:
        UserAPIKeyAuth: the token, updated with end-user and project fields
        where those were found.

    Raises:
        ProxyException: ``ProxyErrorTypes.expired_key`` with HTTP 401 when the
        key's ``expires`` is in the past. Exceptions raised by the called
        checks propagate unchanged.
    """
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_model_list,
        llm_router,
        model_max_budget_limiter,
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
    )

    # 1. Look up end_user object from DB if end_user_id is set
    end_user_object = None
    if valid_token.end_user_id is not None:
        valid_token, end_user_object = await _lookup_end_user_and_apply_budget(
            valid_token=valid_token,
            route=route,
            parent_otel_span=parent_otel_span,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )
        # common_checks() enforces the end-user budget, but the centralized
        # gate skips it for custom-auth deployments unless
        # custom_auth_run_common_checks is set. Enforce it here on that path
        # so an over-budget end user can't keep making requests.
        if end_user_object is not None and not general_settings.get(
            "custom_auth_run_common_checks", False
        ):
            await _check_end_user_budget(end_user_obj=end_user_object, route=route)

    # 2. Check token expiry
    if valid_token.expires is not None:
        current_time = datetime.now(timezone.utc)
        raw_expiry = valid_token.expires
        expiry_time = (
            raw_expiry
            if isinstance(raw_expiry, datetime)
            else datetime.fromisoformat(raw_expiry)
        )
        # A naive expiry is interpreted as UTC so it can be compared with
        # the aware current time.
        if expiry_time.utcoffset() is None:
            expiry_time = expiry_time.replace(tzinfo=timezone.utc)
        if expiry_time < current_time:
            raise ProxyException(
                message=f"Authentication Error - Expired Key. Key Expiry time {expiry_time} and current time {current_time}",
                type=ProxyErrorTypes.expired_key,
                code=status.HTTP_401_UNAUTHORIZED,
                param=(
                    abbreviate_api_key(api_key=valid_token.token)
                    if valid_token.token
                    else ""
                ),
            )

    if general_settings.get("custom_auth_run_common_checks", False):
        await _enforce_key_and_fallback_model_access(
            valid_token=valid_token,
            request_data=request_data,
            route=route,
            request=request,
            llm_model_list=llm_model_list,
            llm_router=llm_router,
        )

    requested_model = _get_model_from_request_context(
        request_data=request_data,
        route=route,
        request=request,
        llm_router=llm_router,
    )
    budget_model_names = _get_model_names_for_budget_checks(model=requested_model)

    # 3. Check key-level model_max_budget
    key_model_budgets = valid_token.model_max_budget
    if (
        isinstance(key_model_budgets, dict)
        and len(key_model_budgets) > 0
        and budget_model_names
        and valid_token.token is not None
    ):
        for budget_model in budget_model_names:
            await model_max_budget_limiter.is_key_within_model_budget(
                user_api_key_dict=valid_token,
                model=budget_model,
            )

    # 4. Check end-user model_max_budget
    end_user_model_budgets = valid_token.end_user_model_max_budget
    if (
        isinstance(end_user_model_budgets, dict)
        and len(end_user_model_budgets) > 0
        and budget_model_names
        and valid_token.end_user_id is not None
    ):
        for budget_model in budget_model_names:
            await model_max_budget_limiter.is_end_user_within_model_budget(
                end_user_id=valid_token.end_user_id,
                end_user_model_max_budget=end_user_model_budgets,
                model=budget_model,
            )

    # team / user / end_user / project context objects are fetched by
    # the centralized common_checks gate in user_api_key_auth after
    # this helper returns. Keep only the project fetch here because it
    # mutates the token (project_metadata / project_alias).
    if valid_token.project_id is not None:
        project = await get_project_object(
            project_id=valid_token.project_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )
        if project is not None:
            valid_token.project_metadata = project.metadata
            valid_token.project_alias = project.project_alias

    return valid_token
|