mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 18:48:36 +00:00
feat(handle_jwt.py): initial commit adding custom RBAC support on jwt… (#8037)
* feat(handle_jwt.py): initial commit adding custom RBAC support on jwt auth allows admin to define user role field and allowed roles which map to 'internal_user' on litellm * fix(auth_checks.py): ensure user allowed to access model, when calling via personal keys Fixes https://github.com/BerriAI/litellm/issues/8029 * feat(handle_jwt.py): support role based access with model permission control on proxy Allows admin to just grant users roles on IDP (e.g. Azure AD/Keycloak) and user can immediately start calling models * docs(rbac): add docs on rbac for model access control make it clear how admin can use roles to control model access on proxy * fix: fix linting errors * test(test_user_api_key_auth.py): add unit testing to ensure rbac role is correctly enforced * test(test_user_api_key_auth.py): add more testing * test(test_users.py): add unit testing to ensure user model access is always checked for new keys Resolves https://github.com/BerriAI/litellm/issues/8029 * test: fix unit test * fix(dot_notation_indexing.py): fix typing to work with python 3.8
This commit is contained in:
@@ -0,0 +1,116 @@
|
||||
import Image from '@theme/IdealImage';
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Control Model Access with SSO (Azure AD/Keycloak/etc.)
|
||||
|
||||
:::info
|
||||
|
||||
✨ JWT Auth is on LiteLLM Enterprise
|
||||
|
||||
[Enterprise Pricing](https://www.litellm.ai/#pricing)
|
||||
|
||||
[Get free 7-day trial key](https://www.litellm.ai/#trial)
|
||||
|
||||
:::
|
||||
|
||||
<Image img={require('../../img/control_model_access_jwt.png')} style={{ width: '100%', maxWidth: '4000px' }} />
|
||||
|
||||
## Example Token
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="Azure AD">
|
||||
|
||||
```bash
|
||||
{
|
||||
"sub": "1234567890",
|
||||
"name": "John Doe",
|
||||
"email": "john.doe@example.com",
|
||||
"roles": ["basic_user"] # 👈 ROLE
|
||||
}
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="Keycloak">
|
||||
|
||||
```bash
|
||||
{
|
||||
"sub": "1234567890",
|
||||
"name": "John Doe",
|
||||
"email": "john.doe@example.com",
|
||||
"resource_access": {
|
||||
"litellm-test-client-id": {
|
||||
"roles": ["basic_user"] # 👈 ROLE
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## Proxy Configuration
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="Azure AD">
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
enable_jwt_auth: True
|
||||
litellm_jwtauth:
|
||||
user_roles_jwt_field: "roles" # the field in the JWT that contains the roles
|
||||
user_allowed_roles: ["basic_user"] # roles that map to an 'internal_user' role on LiteLLM
|
||||
enforce_rbac: true # if true, will check if the user has the correct role to access the model
|
||||
|
||||
role_permissions: # control what models are allowed for each role
|
||||
- role: internal_user
|
||||
models: ["anthropic-claude"]
|
||||
|
||||
model_list:
|
||||
- model: anthropic-claude
|
||||
litellm_params:
|
||||
model: claude-3-5-haiku-20241022
|
||||
- model: openai-gpt-4o
|
||||
litellm_params:
|
||||
model: gpt-4o
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="Keycloak">
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
enable_jwt_auth: True
|
||||
litellm_jwtauth:
|
||||
user_roles_jwt_field: "resource_access.litellm-test-client-id.roles" # the field in the JWT that contains the roles
|
||||
user_allowed_roles: ["basic_user"] # roles that map to an 'internal_user' role on LiteLLM
|
||||
enforce_rbac: true # if true, will check if the user has the correct role to access the model
|
||||
|
||||
role_permissions: # control what models are allowed for each role
|
||||
- role: internal_user
|
||||
models: ["anthropic-claude"]
|
||||
|
||||
model_list:
|
||||
- model: anthropic-claude
|
||||
litellm_params:
|
||||
model: claude-3-5-haiku-20241022
|
||||
- model: openai-gpt-4o
|
||||
litellm_params:
|
||||
model: gpt-4o
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
## How it works
|
||||
|
||||
1. Specify JWT_PUBLIC_KEY_URL - This is the public keys endpoint of your OpenID provider. For Azure AD it's `https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys`. For Keycloak it's `{keycloak_base_url}/realms/{your-realm}/protocol/openid-connect/certs`.
|
||||
|
||||
1. Map JWT roles to LiteLLM roles - Done via `user_roles_jwt_field` and `user_allowed_roles`
|
||||
- Currently just `internal_user` is supported for role mapping.
|
||||
2. Specify model access:
|
||||
- `role_permissions`: control what models are allowed for each role.
|
||||
- `role`: the LiteLLM role to control access for. Allowed roles = ["internal_user", "proxy_admin", "team"]
|
||||
- `models`: list of models that the role is allowed to access.
|
||||
- `model_list`: parent list of models on the proxy. [Learn more](./configs.md#llm-configs-model_list)
|
||||
|
||||
3. Model Checks: The proxy will run validation checks on the received JWT. [Code](https://github.com/BerriAI/litellm/blob/3a4f5b23b5025b87b6d969f2485cc9bc741f9ba6/litellm/proxy/auth/user_api_key_auth.py#L284)
|
||||
@@ -344,3 +344,6 @@ curl -i http://localhost:4000/v1/chat/completions \
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
## [Role Based Access Control (RBAC)](./jwt_auth_arch)
|
||||
@@ -1,7 +1,7 @@
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# JWT-based Auth
|
||||
# SSO - JWT-based Auth
|
||||
|
||||
Use JWT's to auth admins / projects into the proxy.
|
||||
|
||||
@@ -183,6 +183,24 @@ Expected Scope in JWT:
|
||||
}
|
||||
```
|
||||
|
||||
### Control Model Access
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
enable_jwt_auth: True
|
||||
litellm_jwtauth:
|
||||
user_roles_jwt_field: "resource_access.litellm-test-client-id.roles"
|
||||
user_allowed_roles: ["basic_user"] # roles that map to an 'internal_user' role on LiteLLM
|
||||
enforce_rbac: true # if true, will check if the user has the correct role to access the model + endpoint
|
||||
|
||||
role_permissions: # control what models + endpointsare allowed for each role
|
||||
- role: internal_user
|
||||
models: ["anthropic-claude"]
|
||||
```
|
||||
|
||||
|
||||
**[Architecture Diagram (Control Model Access)](./jwt_auth_arch)**
|
||||
|
||||
## Advanced - Allowed Routes
|
||||
|
||||
Configure which routes a JWT can access via the config.
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 113 KiB |
@@ -51,7 +51,7 @@ const sidebars = {
|
||||
{
|
||||
type: "category",
|
||||
label: "Architecture",
|
||||
items: ["proxy/architecture", "proxy/db_info", "router_architecture", "proxy/user_management_heirarchy"],
|
||||
items: ["proxy/architecture", "proxy/db_info", "router_architecture", "proxy/user_management_heirarchy", "proxy/jwt_auth_arch"],
|
||||
},
|
||||
{
|
||||
type: "link",
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
This file contains the logic for dot notation indexing.
|
||||
|
||||
Used by JWT Auth to get the user role from the token.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_nested_value(
|
||||
data: Dict[str, Any], key_path: str, default: Optional[T] = None
|
||||
) -> Optional[T]:
|
||||
"""
|
||||
Retrieves a value from a nested dictionary using dot notation.
|
||||
|
||||
Args:
|
||||
data: The dictionary to search in
|
||||
key_path: The path to the value using dot notation (e.g., "a.b.c")
|
||||
default: The default value to return if the path is not found
|
||||
|
||||
Returns:
|
||||
The value at the specified path, or the default value if not found
|
||||
|
||||
Example:
|
||||
>>> data = {"a": {"b": {"c": "value"}}}
|
||||
>>> get_nested_value(data, "a.b.c")
|
||||
'value'
|
||||
>>> get_nested_value(data, "a.b.d", "default")
|
||||
'default'
|
||||
"""
|
||||
if not key_path:
|
||||
return default
|
||||
|
||||
# Remove metadata. prefix if it exists
|
||||
key_path = (
|
||||
key_path.replace("metadata.", "", 1)
|
||||
if key_path.startswith("metadata.")
|
||||
else key_path
|
||||
)
|
||||
|
||||
# Split the key path into parts
|
||||
parts = key_path.split(".")
|
||||
|
||||
# Traverse through the dictionary
|
||||
current: Any = data
|
||||
for part in parts:
|
||||
try:
|
||||
current = current[part]
|
||||
except (KeyError, TypeError):
|
||||
return default
|
||||
|
||||
# If default is None, we can return any type
|
||||
if default is None:
|
||||
return current
|
||||
|
||||
# Otherwise, ensure the type matches the default
|
||||
return current if isinstance(current, type(default)) else default
|
||||
@@ -1,18 +1,16 @@
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo-end-user-test
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
region_name: "eu"
|
||||
model_info:
|
||||
id: "1"
|
||||
- model_name: gpt-3.5-turbo-end-user-test
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
timeout: 2
|
||||
num_retries: 0
|
||||
- model_name: anthropic-claude
|
||||
litellm_params:
|
||||
model: anthropic.claude-3-sonnet-20240229-v1:0
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["langsmith"]
|
||||
model: claude-3-5-haiku-20241022
|
||||
- model_name: groq/*
|
||||
litellm_params:
|
||||
model: groq/*
|
||||
api_key: os.environ/GROQ_API_KEY
|
||||
mock_response: Hi!
|
||||
- model_name: deepseek/*
|
||||
litellm_params:
|
||||
model: deepseek/*
|
||||
api_key: os.environ/DEEPSEEK_API_KEY
|
||||
|
||||
@@ -445,6 +445,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
||||
user_id_jwt_field: Optional[str] = None
|
||||
user_email_jwt_field: Optional[str] = None
|
||||
user_allowed_email_domain: Optional[str] = None
|
||||
user_roles_jwt_field: Optional[str] = None
|
||||
user_allowed_roles: Optional[List[str]] = None
|
||||
user_id_upsert: bool = Field(
|
||||
default=False, description="If user doesn't exist, upsert them into the db."
|
||||
)
|
||||
@@ -458,11 +460,19 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
||||
allowed_keys = self.__annotations__.keys()
|
||||
|
||||
invalid_keys = set(kwargs.keys()) - allowed_keys
|
||||
user_roles_jwt_field = kwargs.get("user_roles_jwt_field")
|
||||
user_allowed_roles = kwargs.get("user_allowed_roles")
|
||||
|
||||
if invalid_keys:
|
||||
raise ValueError(
|
||||
f"Invalid arguments provided: {', '.join(invalid_keys)}. Allowed arguments are: {', '.join(allowed_keys)}."
|
||||
)
|
||||
if (user_roles_jwt_field is not None and user_allowed_roles is None) or (
|
||||
user_roles_jwt_field is None and user_allowed_roles is not None
|
||||
):
|
||||
raise ValueError(
|
||||
"user_allowed_roles must be provided if user_roles_jwt_field is set."
|
||||
)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -2335,3 +2345,15 @@ class ClientSideFallbackModel(TypedDict, total=False):
|
||||
|
||||
|
||||
ALL_FALLBACK_MODEL_VALUES = Union[str, ClientSideFallbackModel]
|
||||
|
||||
|
||||
RBAC_ROLES = Literal[
|
||||
LitellmUserRoles.PROXY_ADMIN,
|
||||
LitellmUserRoles.TEAM,
|
||||
LitellmUserRoles.INTERNAL_USER,
|
||||
]
|
||||
|
||||
|
||||
class RoleBasedPermissions(TypedDict):
|
||||
role: Required[RBAC_ROLES]
|
||||
models: Required[List[str]]
|
||||
|
||||
@@ -12,7 +12,7 @@ import asyncio
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast
|
||||
|
||||
from fastapi import status
|
||||
from pydantic import BaseModel
|
||||
@@ -24,6 +24,7 @@ from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
||||
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
|
||||
from litellm.proxy._types import (
|
||||
DB_CONNECTION_ERROR_TYPES,
|
||||
RBAC_ROLES,
|
||||
CallInfo,
|
||||
LiteLLM_EndUserTable,
|
||||
LiteLLM_JWTAuth,
|
||||
@@ -35,6 +36,7 @@ from litellm.proxy._types import (
|
||||
LitellmUserRoles,
|
||||
ProxyErrorTypes,
|
||||
ProxyException,
|
||||
RoleBasedPermissions,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
@@ -100,6 +102,14 @@ async def common_checks(
|
||||
llm_router=llm_router,
|
||||
)
|
||||
|
||||
## 2.1 If user can call model (if personal key)
|
||||
if team_object is None and user_object is not None:
|
||||
await can_user_call_model(
|
||||
model=_model,
|
||||
llm_router=llm_router,
|
||||
user_object=user_object,
|
||||
)
|
||||
|
||||
# 3. If team is in budget
|
||||
await _team_max_budget_check(
|
||||
team_object=team_object,
|
||||
@@ -391,6 +401,30 @@ def _update_last_db_access_time(
|
||||
last_db_access_time[key] = (value, time.time())
|
||||
|
||||
|
||||
def get_role_based_models(
|
||||
rbac_role: RBAC_ROLES,
|
||||
general_settings: dict,
|
||||
) -> Optional[List[str]]:
|
||||
"""
|
||||
Get the models allowed for a user role.
|
||||
|
||||
Used by JWT Auth.
|
||||
"""
|
||||
|
||||
role_based_permissions = cast(
|
||||
Optional[List[RoleBasedPermissions]],
|
||||
general_settings.get("role_permissions", []),
|
||||
)
|
||||
if role_based_permissions is None:
|
||||
return None
|
||||
|
||||
for role_based_permission in role_based_permissions:
|
||||
if role_based_permission["role"] == rbac_role:
|
||||
return role_based_permission["models"]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@log_db_metrics
|
||||
async def get_user_object(
|
||||
user_id: str,
|
||||
@@ -836,6 +870,68 @@ async def get_org_object(
|
||||
)
|
||||
|
||||
|
||||
async def _can_object_call_model(
|
||||
model: str,
|
||||
llm_router: Optional[Router],
|
||||
models: List[str],
|
||||
) -> Literal[True]:
|
||||
"""
|
||||
Checks if token can call a given model
|
||||
|
||||
Returns:
|
||||
- True: if token allowed to call model
|
||||
|
||||
Raises:
|
||||
- Exception: If token not allowed to call model
|
||||
"""
|
||||
if model in litellm.model_alias_map:
|
||||
model = litellm.model_alias_map[model]
|
||||
|
||||
## check if model in allowed model names
|
||||
from collections import defaultdict
|
||||
|
||||
access_groups: Dict[str, List[str]] = defaultdict(list)
|
||||
|
||||
if llm_router:
|
||||
access_groups = llm_router.get_model_access_groups(model_name=model)
|
||||
if (
|
||||
len(access_groups) > 0 and llm_router is not None
|
||||
): # check if token contains any model access groups
|
||||
for idx, m in enumerate(
|
||||
models
|
||||
): # loop token models, if any of them are an access group add the access group
|
||||
if m in access_groups:
|
||||
return True
|
||||
|
||||
# Filter out models that are access_groups
|
||||
filtered_models = [m for m in models if m not in access_groups]
|
||||
|
||||
verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")
|
||||
|
||||
if _model_matches_any_wildcard_pattern_in_list(
|
||||
model=model, allowed_model_list=filtered_models
|
||||
):
|
||||
return True
|
||||
|
||||
all_model_access: bool = False
|
||||
|
||||
if (len(filtered_models) == 0 and len(models) == 0) or "*" in filtered_models:
|
||||
all_model_access = True
|
||||
|
||||
if model is not None and model not in filtered_models and all_model_access is False:
|
||||
raise ProxyException(
|
||||
message=f"API Key not allowed to access model. This token can only access models={models}. Tried to access {model}",
|
||||
type=ProxyErrorTypes.key_model_access_denied,
|
||||
param="model",
|
||||
code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"filtered allowed_models: {filtered_models}; models: {models}"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
async def can_key_call_model(
|
||||
model: str,
|
||||
llm_model_list: Optional[list],
|
||||
@@ -851,57 +947,27 @@ async def can_key_call_model(
|
||||
Raises:
|
||||
- Exception: If token not allowed to call model
|
||||
"""
|
||||
if model in litellm.model_alias_map:
|
||||
model = litellm.model_alias_map[model]
|
||||
|
||||
## check if model in allowed model names
|
||||
verbose_proxy_logger.debug(
|
||||
f"LLM Model List pre access group check: {llm_model_list}"
|
||||
return await _can_object_call_model(
|
||||
model=model,
|
||||
llm_router=llm_router,
|
||||
models=valid_token.models,
|
||||
)
|
||||
from collections import defaultdict
|
||||
|
||||
access_groups: Dict[str, List[str]] = defaultdict(list)
|
||||
|
||||
if llm_router:
|
||||
access_groups = llm_router.get_model_access_groups(model_name=model)
|
||||
if (
|
||||
len(access_groups) > 0 and llm_router is not None
|
||||
): # check if token contains any model access groups
|
||||
for idx, m in enumerate(
|
||||
valid_token.models
|
||||
): # loop token models, if any of them are an access group add the access group
|
||||
if m in access_groups:
|
||||
return True
|
||||
async def can_user_call_model(
|
||||
model: str,
|
||||
llm_router: Optional[Router],
|
||||
user_object: Optional[LiteLLM_UserTable],
|
||||
) -> Literal[True]:
|
||||
|
||||
# Filter out models that are access_groups
|
||||
filtered_models = [m for m in valid_token.models if m not in access_groups]
|
||||
|
||||
verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")
|
||||
|
||||
if _model_matches_any_wildcard_pattern_in_list(
|
||||
model=model, allowed_model_list=filtered_models
|
||||
):
|
||||
if user_object is None:
|
||||
return True
|
||||
|
||||
all_model_access: bool = False
|
||||
|
||||
if (
|
||||
len(filtered_models) == 0 and len(valid_token.models) == 0
|
||||
) or "*" in filtered_models:
|
||||
all_model_access = True
|
||||
|
||||
if model is not None and model not in filtered_models and all_model_access is False:
|
||||
raise ProxyException(
|
||||
message=f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}",
|
||||
type=ProxyErrorTypes.key_model_access_denied,
|
||||
param="model",
|
||||
code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
valid_token.models = filtered_models
|
||||
verbose_proxy_logger.debug(
|
||||
f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}"
|
||||
return await _can_object_call_model(
|
||||
model=model,
|
||||
llm_router=llm_router,
|
||||
models=user_object.models,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
async def is_valid_fallback_model(
|
||||
@@ -1161,7 +1227,11 @@ def _model_custom_llm_provider_matches_wildcard_pattern(
|
||||
- `model=claude-3-5-sonnet-20240620`
|
||||
- `allowed_model_pattern=anthropic/*`
|
||||
"""
|
||||
model, custom_llm_provider, _, _ = get_llm_provider(model=model)
|
||||
try:
|
||||
model, custom_llm_provider, _, _ = get_llm_provider(model=model)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
return is_model_allowed_by_pattern(
|
||||
model=f"{custom_llm_provider}/{model}",
|
||||
allowed_model_pattern=allowed_model_pattern,
|
||||
|
||||
@@ -16,8 +16,10 @@ from cryptography.hazmat.primitives import serialization
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value
|
||||
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
|
||||
from litellm.proxy._types import (
|
||||
RBAC_ROLES,
|
||||
JWKKeyValue,
|
||||
JWTKeyItem,
|
||||
LiteLLM_JWTAuth,
|
||||
@@ -59,7 +61,7 @@ class JWTHandler:
|
||||
parts = token.split(".")
|
||||
return len(parts) == 3
|
||||
|
||||
def get_rbac_role(self, token: dict) -> Optional[LitellmUserRoles]:
|
||||
def get_rbac_role(self, token: dict) -> Optional[RBAC_ROLES]:
|
||||
"""
|
||||
Returns the RBAC role the token 'belongs' to.
|
||||
|
||||
@@ -78,12 +80,18 @@ class JWTHandler:
|
||||
"""
|
||||
scopes = self.get_scopes(token=token)
|
||||
is_admin = self.is_admin(scopes=scopes)
|
||||
user_roles = self.get_user_roles(token=token, default_value=None)
|
||||
|
||||
if is_admin:
|
||||
return LitellmUserRoles.PROXY_ADMIN
|
||||
elif self.get_team_id(token=token, default_value=None) is not None:
|
||||
return LitellmUserRoles.TEAM
|
||||
elif self.get_user_id(token=token, default_value=None) is not None:
|
||||
return LitellmUserRoles.INTERNAL_USER
|
||||
elif user_roles is not None and self.is_allowed_user_role(
|
||||
user_roles=user_roles
|
||||
):
|
||||
return LitellmUserRoles.INTERNAL_USER
|
||||
|
||||
return None
|
||||
|
||||
@@ -166,6 +174,43 @@ class JWTHandler:
|
||||
user_id = default_value
|
||||
return user_id
|
||||
|
||||
def get_user_roles(
|
||||
self, token: dict, default_value: Optional[List[str]]
|
||||
) -> Optional[List[str]]:
|
||||
"""
|
||||
Returns the user role from the token.
|
||||
|
||||
Set via 'user_roles_jwt_field' in the config.
|
||||
"""
|
||||
try:
|
||||
if self.litellm_jwtauth.user_roles_jwt_field is not None:
|
||||
user_roles = get_nested_value(
|
||||
data=token,
|
||||
key_path=self.litellm_jwtauth.user_roles_jwt_field,
|
||||
default=default_value,
|
||||
)
|
||||
else:
|
||||
user_roles = default_value
|
||||
except KeyError:
|
||||
user_roles = default_value
|
||||
return user_roles
|
||||
|
||||
def is_allowed_user_role(self, user_roles: Optional[List[str]]) -> bool:
|
||||
"""
|
||||
Returns the user role from the token.
|
||||
|
||||
Set via 'user_allowed_roles' in the config.
|
||||
"""
|
||||
if (
|
||||
user_roles is not None
|
||||
and self.litellm_jwtauth.user_allowed_roles is not None
|
||||
and any(
|
||||
role in self.litellm_jwtauth.user_allowed_roles for role in user_roles
|
||||
)
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_user_email(
|
||||
self, token: dict, default_value: Optional[str]
|
||||
) -> Optional[str]:
|
||||
|
||||
@@ -33,6 +33,7 @@ from litellm.proxy.auth.auth_checks import (
|
||||
get_end_user_object,
|
||||
get_key_object,
|
||||
get_org_object,
|
||||
get_role_based_models,
|
||||
get_team_object,
|
||||
get_user_object,
|
||||
is_valid_fallback_model,
|
||||
@@ -281,9 +282,34 @@ def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str:
|
||||
return LitellmUserRoles.TEAM
|
||||
|
||||
|
||||
def can_rbac_role_call_model(
|
||||
rbac_role: RBAC_ROLES,
|
||||
general_settings: dict,
|
||||
model: Optional[str],
|
||||
) -> Literal[True]:
|
||||
"""
|
||||
Checks if user is allowed to access the model, based on their role.
|
||||
"""
|
||||
role_based_models = get_role_based_models(
|
||||
rbac_role=rbac_role, general_settings=general_settings
|
||||
)
|
||||
if role_based_models is None or model is None:
|
||||
return True
|
||||
|
||||
if model not in role_based_models:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"User role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def _jwt_auth_user_api_key_auth_builder(
|
||||
api_key: str,
|
||||
jwt_handler: JWTHandler,
|
||||
request_data: dict,
|
||||
general_settings: dict,
|
||||
route: str,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
@@ -295,14 +321,20 @@ async def _jwt_auth_user_api_key_auth_builder(
|
||||
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
|
||||
|
||||
# check if unmatched token and enforce_rbac is true
|
||||
if (
|
||||
jwt_handler.litellm_jwtauth.enforce_rbac is True
|
||||
and jwt_handler.get_rbac_role(token=jwt_valid_token) is None
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user. See how to set roles in config here: https://docs.litellm.ai/docs/proxy/token_auth#advanced---spend-tracking-end-users--internal-users--team--org",
|
||||
)
|
||||
if jwt_handler.litellm_jwtauth.enforce_rbac is True:
|
||||
rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
|
||||
if rbac_role is None:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user. See how to set roles in config here: https://docs.litellm.ai/docs/proxy/token_auth#advanced---spend-tracking-end-users--internal-users--team--org",
|
||||
)
|
||||
else:
|
||||
# run rbac validation checks
|
||||
can_rbac_role_call_model(
|
||||
rbac_role=rbac_role,
|
||||
general_settings=general_settings,
|
||||
model=request_data.get("model"),
|
||||
)
|
||||
# get scopes
|
||||
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
|
||||
|
||||
@@ -431,18 +463,18 @@ async def _jwt_auth_user_api_key_auth_builder(
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
return {
|
||||
"is_proxy_admin": False,
|
||||
"team_id": team_id,
|
||||
"team_object": team_object,
|
||||
"user_id": user_id,
|
||||
"user_object": user_object,
|
||||
"org_id": org_id,
|
||||
"org_object": org_object,
|
||||
"end_user_id": end_user_id,
|
||||
"end_user_object": end_user_object,
|
||||
"token": api_key,
|
||||
}
|
||||
return JWTAuthBuilderResult(
|
||||
is_proxy_admin=False,
|
||||
team_id=team_id,
|
||||
team_object=team_object,
|
||||
user_id=user_id,
|
||||
user_object=user_object,
|
||||
org_id=org_id,
|
||||
org_object=org_object,
|
||||
end_user_id=end_user_id,
|
||||
end_user_object=end_user_object,
|
||||
token=api_key,
|
||||
)
|
||||
|
||||
|
||||
async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
@@ -581,6 +613,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
verbose_proxy_logger.debug("is_jwt: %s", is_jwt)
|
||||
if is_jwt:
|
||||
result = await _jwt_auth_user_api_key_auth_builder(
|
||||
request_data=request_data,
|
||||
general_settings=general_settings,
|
||||
api_key=api_key,
|
||||
jwt_handler=jwt_handler,
|
||||
route=route,
|
||||
|
||||
@@ -508,3 +508,43 @@ async def test_virtual_key_soft_budget_check(spend, soft_budget, expect_alert):
|
||||
assert (
|
||||
alert_triggered == expect_alert
|
||||
), f"Expected alert_triggered to be {expect_alert} for spend={spend}, soft_budget={soft_budget}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_user_call_model():
|
||||
from litellm.proxy.auth.auth_checks import can_user_call_model
|
||||
from litellm.proxy._types import ProxyException
|
||||
from litellm import Router
|
||||
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "anthropic-claude",
|
||||
"litellm_params": {"model": "anthropic/anthropic-claude"},
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {"model": "gpt-3.5-turbo", "api_key": "test-api-key"},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
args = {
|
||||
"model": "anthropic-claude",
|
||||
"llm_router": router,
|
||||
"user_object": LiteLLM_UserTable(
|
||||
user_id="testuser21@mycompany.com",
|
||||
max_budget=None,
|
||||
spend=0.0042295,
|
||||
model_max_budget={},
|
||||
model_spend={},
|
||||
user_email="testuser@mycompany.com",
|
||||
models=["gpt-3.5-turbo"],
|
||||
),
|
||||
}
|
||||
|
||||
with pytest.raises(ProxyException) as e:
|
||||
await can_user_call_model(**args)
|
||||
|
||||
args["model"] = "gpt-3.5-turbo"
|
||||
await can_user_call_model(**args)
|
||||
|
||||
@@ -855,6 +855,8 @@ async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypa
|
||||
"user_api_key_cache": Mock(),
|
||||
"parent_otel_span": None,
|
||||
"proxy_logging_obj": Mock(),
|
||||
"request_data": {},
|
||||
"general_settings": {},
|
||||
}
|
||||
|
||||
if enforce_rbac:
|
||||
@@ -877,3 +879,55 @@ def test_user_api_key_auth_end_user_str():
|
||||
|
||||
user_api_key_auth = UserAPIKeyAuth(**user_api_key_args)
|
||||
assert user_api_key_auth.end_user_id == "1"
|
||||
|
||||
|
||||
def test_can_rbac_role_call_model():
|
||||
from litellm.proxy.auth.user_api_key_auth import can_rbac_role_call_model
|
||||
from litellm.proxy._types import RoleBasedPermissions
|
||||
|
||||
roles_based_permissions = [
|
||||
RoleBasedPermissions(
|
||||
role=LitellmUserRoles.INTERNAL_USER,
|
||||
models=["gpt-4"],
|
||||
),
|
||||
RoleBasedPermissions(
|
||||
role=LitellmUserRoles.PROXY_ADMIN,
|
||||
models=["anthropic-claude"],
|
||||
),
|
||||
]
|
||||
|
||||
assert can_rbac_role_call_model(
|
||||
rbac_role=LitellmUserRoles.INTERNAL_USER,
|
||||
general_settings={"role_permissions": roles_based_permissions},
|
||||
model="gpt-4",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
can_rbac_role_call_model(
|
||||
rbac_role=LitellmUserRoles.INTERNAL_USER,
|
||||
general_settings={"role_permissions": roles_based_permissions},
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
can_rbac_role_call_model(
|
||||
rbac_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
general_settings={"role_permissions": roles_based_permissions},
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
|
||||
def test_can_rbac_role_call_model_no_role_permissions():
|
||||
from litellm.proxy.auth.user_api_key_auth import can_rbac_role_call_model
|
||||
|
||||
assert can_rbac_role_call_model(
|
||||
rbac_role=LitellmUserRoles.INTERNAL_USER,
|
||||
general_settings={},
|
||||
model="gpt-4",
|
||||
)
|
||||
|
||||
assert can_rbac_role_call_model(
|
||||
rbac_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
general_settings={"role_permissions": []},
|
||||
model="anthropic-claude",
|
||||
)
|
||||
|
||||
+107
-2
@@ -7,13 +7,17 @@ import time
|
||||
from openai import AsyncOpenAI
|
||||
from test_team import list_teams
|
||||
from typing import Optional
|
||||
from test_keys import generate_key
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
async def new_user(session, i, user_id=None, budget=None, budget_duration=None):
|
||||
async def new_user(
|
||||
session, i, user_id=None, budget=None, budget_duration=None, models=None
|
||||
):
|
||||
url = "http://0.0.0.0:4000/user/new"
|
||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||
data = {
|
||||
"models": ["azure-models"],
|
||||
"models": models or ["azure-models"],
|
||||
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
|
||||
"duration": None,
|
||||
"max_budget": budget,
|
||||
@@ -37,6 +41,51 @@ async def new_user(session, i, user_id=None, budget=None, budget_duration=None):
|
||||
return await response.json()
|
||||
|
||||
|
||||
async def generate_key(
|
||||
session,
|
||||
i,
|
||||
budget=None,
|
||||
budget_duration=None,
|
||||
models=["azure-models", "gpt-4", "dall-e-3"],
|
||||
max_parallel_requests: Optional[int] = None,
|
||||
user_id: Optional[str] = None,
|
||||
team_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
calling_key="sk-1234",
|
||||
):
|
||||
url = "http://0.0.0.0:4000/key/generate"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {calling_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
data = {
|
||||
"models": models,
|
||||
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
|
||||
"duration": None,
|
||||
"max_budget": budget,
|
||||
"budget_duration": budget_duration,
|
||||
"max_parallel_requests": max_parallel_requests,
|
||||
"user_id": user_id,
|
||||
"team_id": team_id,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
print(f"data: {data}")
|
||||
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
status = response.status
|
||||
response_text = await response.text()
|
||||
|
||||
print(f"Response {i} (Status code: {status}):")
|
||||
print(response_text)
|
||||
print()
|
||||
|
||||
if status != 200:
|
||||
raise Exception(f"Request {i} did not return a 200 status code: {status}")
|
||||
|
||||
return await response.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_new():
|
||||
"""
|
||||
@@ -210,3 +259,59 @@ async def test_global_proxy_budget_update():
|
||||
new_new_spend = user_info["user_info"]["spend"]
|
||||
print(f"new_spend: {new_spend}; original_spend: {original_spend}")
|
||||
assert new_new_spend > new_spend
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_model_access():
|
||||
"""
|
||||
- Create user with model access
|
||||
- Create key with user
|
||||
- Call model that user has access to -> should work
|
||||
- Call wildcard model that user has access to -> should work
|
||||
- Call model that user does not have access to -> should fail
|
||||
- Call wildcard model that user does not have access to -> should fail
|
||||
"""
|
||||
import openai
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
get_user = f"krrish_{time.time()}@berri.ai"
|
||||
await new_user(
|
||||
session=session,
|
||||
i=0,
|
||||
user_id=get_user,
|
||||
models=["good-model", "anthropic/*"],
|
||||
)
|
||||
|
||||
result = await generate_key(
|
||||
session=session,
|
||||
i=0,
|
||||
user_id=get_user,
|
||||
models=[], # assign no models. Allow inheritance from user
|
||||
)
|
||||
key = result["key"]
|
||||
|
||||
await chat_completion(
|
||||
session=session,
|
||||
key=key,
|
||||
model="anthropic/claude-3-5-haiku-20241022",
|
||||
)
|
||||
|
||||
await chat_completion(
|
||||
session=session,
|
||||
key=key,
|
||||
model="good-model",
|
||||
)
|
||||
|
||||
with pytest.raises(openai.AuthenticationError):
|
||||
await chat_completion(
|
||||
session=session,
|
||||
key=key,
|
||||
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
)
|
||||
|
||||
with pytest.raises(openai.AuthenticationError):
|
||||
await chat_completion(
|
||||
session=session,
|
||||
key=key,
|
||||
model="groq/claude-3-5-haiku-20241022",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user