feat(handle_jwt.py): initial commit adding custom RBAC support on jwt… (#8037)

* feat(handle_jwt.py): initial commit adding custom RBAC support on jwt auth

allows admin to define user role field and allowed roles which map to 'internal_user' on litellm

* fix(auth_checks.py): ensure user allowed to access model, when calling via personal keys

Fixes https://github.com/BerriAI/litellm/issues/8029

* feat(handle_jwt.py): support role based access with model permission control on proxy

Allows admin to just grant users roles on IDP (e.g. Azure AD/Keycloak) and user can immediately start calling models

* docs(rbac): add docs on rbac for model access control

make it clear how admin can use roles to control model access on proxy

* fix: fix linting errors

* test(test_user_api_key_auth.py): add unit testing to ensure rbac role is correctly enforced

* test(test_user_api_key_auth.py): add more testing

* test(test_users.py): add unit testing to ensure user model access is always checked for new keys

Resolves https://github.com/BerriAI/litellm/issues/8029

* test: fix unit test

* fix(dot_notation_indexing.py): fix typing to work with python 3.8
This commit is contained in:
Krish Dholakia
2025-01-28 16:27:06 -08:00
committed by GitHub
parent 9644e197f7
commit 2eaa0079f2
14 changed files with 648 additions and 84 deletions
+116
View File
@@ -0,0 +1,116 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Control Model Access with SSO (Azure AD/Keycloak/etc.)
:::info
✨ JWT Auth is on LiteLLM Enterprise
[Enterprise Pricing](https://www.litellm.ai/#pricing)
[Get free 7-day trial key](https://www.litellm.ai/#trial)
:::
<Image img={require('../../img/control_model_access_jwt.png')} style={{ width: '100%', maxWidth: '4000px' }} />
## Example Token
<Tabs>
<TabItem value="Azure AD">
```bash
{
"sub": "1234567890",
"name": "John Doe",
"email": "john.doe@example.com",
"roles": ["basic_user"] # 👈 ROLE
}
```
</TabItem>
<TabItem value="Keycloak">
```bash
{
"sub": "1234567890",
"name": "John Doe",
"email": "john.doe@example.com",
"resource_access": {
"litellm-test-client-id": {
"roles": ["basic_user"] # 👈 ROLE
}
}
}
```
</TabItem>
</Tabs>
## Proxy Configuration
<Tabs>
<TabItem value="Azure AD">
```yaml
general_settings:
enable_jwt_auth: True
litellm_jwtauth:
user_roles_jwt_field: "roles" # the field in the JWT that contains the roles
user_allowed_roles: ["basic_user"] # roles that map to an 'internal_user' role on LiteLLM
enforce_rbac: true # if true, will check if the user has the correct role to access the model
role_permissions: # control what models are allowed for each role
- role: internal_user
models: ["anthropic-claude"]
model_list:
- model: anthropic-claude
litellm_params:
model: claude-3-5-haiku-20241022
- model: openai-gpt-4o
litellm_params:
model: gpt-4o
```
</TabItem>
<TabItem value="Keycloak">
```yaml
general_settings:
enable_jwt_auth: True
litellm_jwtauth:
user_roles_jwt_field: "resource_access.litellm-test-client-id.roles" # the field in the JWT that contains the roles
user_allowed_roles: ["basic_user"] # roles that map to an 'internal_user' role on LiteLLM
enforce_rbac: true # if true, will check if the user has the correct role to access the model
role_permissions: # control what models are allowed for each role
- role: internal_user
models: ["anthropic-claude"]
model_list:
- model: anthropic-claude
litellm_params:
model: claude-3-5-haiku-20241022
- model: openai-gpt-4o
litellm_params:
model: gpt-4o
```
</TabItem>
</Tabs>
## How it works
1. Specify JWT_PUBLIC_KEY_URL - This is the public keys endpoint of your OpenID provider. For Azure AD it's `https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys`. For Keycloak it's `{keycloak_base_url}/realms/{your-realm}/protocol/openid-connect/certs`.
1. Map JWT roles to LiteLLM roles - Done via `user_roles_jwt_field` and `user_allowed_roles`
- Currently just `internal_user` is supported for role mapping.
2. Specify model access:
- `role_permissions`: control what models are allowed for each role.
- `role`: the LiteLLM role to control access for. Allowed roles = ["internal_user", "proxy_admin", "team"]
- `models`: list of models that the role is allowed to access.
- `model_list`: parent list of models on the proxy. [Learn more](./configs.md#llm-configs-model_list)
3. Model Checks: The proxy will run validation checks on the received JWT. [Code](https://github.com/BerriAI/litellm/blob/3a4f5b23b5025b87b6d969f2485cc9bc741f9ba6/litellm/proxy/auth/user_api_key_auth.py#L284)
@@ -344,3 +344,6 @@ curl -i http://localhost:4000/v1/chat/completions \
</TabItem>
</Tabs>
## [Role Based Access Control (RBAC)](./jwt_auth_arch)
+19 -1
View File
@@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# JWT-based Auth
# SSO - JWT-based Auth
Use JWT's to auth admins / projects into the proxy.
@@ -183,6 +183,24 @@ Expected Scope in JWT:
}
```
### Control Model Access
```yaml
general_settings:
enable_jwt_auth: True
litellm_jwtauth:
user_roles_jwt_field: "resource_access.litellm-test-client-id.roles"
user_allowed_roles: ["basic_user"] # roles that map to an 'internal_user' role on LiteLLM
enforce_rbac: true # if true, will check if the user has the correct role to access the model + endpoint
role_permissions: # control what models + endpointsare allowed for each role
- role: internal_user
models: ["anthropic-claude"]
```
**[Architecture Diagram (Control Model Access)](./jwt_auth_arch)**
## Advanced - Allowed Routes
Configure which routes a JWT can access via the config.
Binary file not shown.

After

Width:  |  Height:  |  Size: 113 KiB

+1 -1
View File
@@ -51,7 +51,7 @@ const sidebars = {
{
type: "category",
label: "Architecture",
items: ["proxy/architecture", "proxy/db_info", "router_architecture", "proxy/user_management_heirarchy"],
items: ["proxy/architecture", "proxy/db_info", "router_architecture", "proxy/user_management_heirarchy", "proxy/jwt_auth_arch"],
},
{
type: "link",
@@ -0,0 +1,59 @@
"""
This file contains the logic for dot notation indexing.
Used by JWT Auth to get the user role from the token.
"""
from typing import Any, Dict, Optional, TypeVar
T = TypeVar("T")
def get_nested_value(
data: Dict[str, Any], key_path: str, default: Optional[T] = None
) -> Optional[T]:
"""
Retrieves a value from a nested dictionary using dot notation.
Args:
data: The dictionary to search in
key_path: The path to the value using dot notation (e.g., "a.b.c")
default: The default value to return if the path is not found
Returns:
The value at the specified path, or the default value if not found
Example:
>>> data = {"a": {"b": {"c": "value"}}}
>>> get_nested_value(data, "a.b.c")
'value'
>>> get_nested_value(data, "a.b.d", "default")
'default'
"""
if not key_path:
return default
# Remove metadata. prefix if it exists
key_path = (
key_path.replace("metadata.", "", 1)
if key_path.startswith("metadata.")
else key_path
)
# Split the key path into parts
parts = key_path.split(".")
# Traverse through the dictionary
current: Any = data
for part in parts:
try:
current = current[part]
except (KeyError, TypeError):
return default
# If default is None, we can return any type
if default is None:
return current
# Otherwise, ensure the type matches the default
return current if isinstance(current, type(default)) else default
+11 -13
View File
@@ -1,18 +1,16 @@
model_list:
- model_name: gpt-3.5-turbo-end-user-test
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
region_name: "eu"
model_info:
id: "1"
- model_name: gpt-3.5-turbo-end-user-test
litellm_params:
model: gpt-3.5-turbo
timeout: 2
num_retries: 0
- model_name: anthropic-claude
litellm_params:
model: anthropic.claude-3-sonnet-20240229-v1:0
litellm_settings:
callbacks: ["langsmith"]
model: claude-3-5-haiku-20241022
- model_name: groq/*
litellm_params:
model: groq/*
api_key: os.environ/GROQ_API_KEY
mock_response: Hi!
- model_name: deepseek/*
litellm_params:
model: deepseek/*
api_key: os.environ/DEEPSEEK_API_KEY
+22
View File
@@ -445,6 +445,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
user_id_jwt_field: Optional[str] = None
user_email_jwt_field: Optional[str] = None
user_allowed_email_domain: Optional[str] = None
user_roles_jwt_field: Optional[str] = None
user_allowed_roles: Optional[List[str]] = None
user_id_upsert: bool = Field(
default=False, description="If user doesn't exist, upsert them into the db."
)
@@ -458,11 +460,19 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
allowed_keys = self.__annotations__.keys()
invalid_keys = set(kwargs.keys()) - allowed_keys
user_roles_jwt_field = kwargs.get("user_roles_jwt_field")
user_allowed_roles = kwargs.get("user_allowed_roles")
if invalid_keys:
raise ValueError(
f"Invalid arguments provided: {', '.join(invalid_keys)}. Allowed arguments are: {', '.join(allowed_keys)}."
)
if (user_roles_jwt_field is not None and user_allowed_roles is None) or (
user_roles_jwt_field is None and user_allowed_roles is not None
):
raise ValueError(
"user_allowed_roles must be provided if user_roles_jwt_field is set."
)
super().__init__(**kwargs)
@@ -2335,3 +2345,15 @@ class ClientSideFallbackModel(TypedDict, total=False):
ALL_FALLBACK_MODEL_VALUES = Union[str, ClientSideFallbackModel]
RBAC_ROLES = Literal[
LitellmUserRoles.PROXY_ADMIN,
LitellmUserRoles.TEAM,
LitellmUserRoles.INTERNAL_USER,
]
class RoleBasedPermissions(TypedDict):
role: Required[RBAC_ROLES]
models: Required[List[str]]
+116 -46
View File
@@ -12,7 +12,7 @@ import asyncio
import re
import time
import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast
from fastapi import status
from pydantic import BaseModel
@@ -24,6 +24,7 @@ from litellm.caching.dual_cache import LimitedSizeOrderedDict
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES,
RBAC_ROLES,
CallInfo,
LiteLLM_EndUserTable,
LiteLLM_JWTAuth,
@@ -35,6 +36,7 @@ from litellm.proxy._types import (
LitellmUserRoles,
ProxyErrorTypes,
ProxyException,
RoleBasedPermissions,
UserAPIKeyAuth,
)
from litellm.proxy.auth.route_checks import RouteChecks
@@ -100,6 +102,14 @@ async def common_checks(
llm_router=llm_router,
)
## 2.1 If user can call model (if personal key)
if team_object is None and user_object is not None:
await can_user_call_model(
model=_model,
llm_router=llm_router,
user_object=user_object,
)
# 3. If team is in budget
await _team_max_budget_check(
team_object=team_object,
@@ -391,6 +401,30 @@ def _update_last_db_access_time(
last_db_access_time[key] = (value, time.time())
def get_role_based_models(
rbac_role: RBAC_ROLES,
general_settings: dict,
) -> Optional[List[str]]:
"""
Get the models allowed for a user role.
Used by JWT Auth.
"""
role_based_permissions = cast(
Optional[List[RoleBasedPermissions]],
general_settings.get("role_permissions", []),
)
if role_based_permissions is None:
return None
for role_based_permission in role_based_permissions:
if role_based_permission["role"] == rbac_role:
return role_based_permission["models"]
return None
@log_db_metrics
async def get_user_object(
user_id: str,
@@ -836,6 +870,68 @@ async def get_org_object(
)
async def _can_object_call_model(
model: str,
llm_router: Optional[Router],
models: List[str],
) -> Literal[True]:
"""
Checks if token can call a given model
Returns:
- True: if token allowed to call model
Raises:
- Exception: If token not allowed to call model
"""
if model in litellm.model_alias_map:
model = litellm.model_alias_map[model]
## check if model in allowed model names
from collections import defaultdict
access_groups: Dict[str, List[str]] = defaultdict(list)
if llm_router:
access_groups = llm_router.get_model_access_groups(model_name=model)
if (
len(access_groups) > 0 and llm_router is not None
): # check if token contains any model access groups
for idx, m in enumerate(
models
): # loop token models, if any of them are an access group add the access group
if m in access_groups:
return True
# Filter out models that are access_groups
filtered_models = [m for m in models if m not in access_groups]
verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")
if _model_matches_any_wildcard_pattern_in_list(
model=model, allowed_model_list=filtered_models
):
return True
all_model_access: bool = False
if (len(filtered_models) == 0 and len(models) == 0) or "*" in filtered_models:
all_model_access = True
if model is not None and model not in filtered_models and all_model_access is False:
raise ProxyException(
message=f"API Key not allowed to access model. This token can only access models={models}. Tried to access {model}",
type=ProxyErrorTypes.key_model_access_denied,
param="model",
code=status.HTTP_401_UNAUTHORIZED,
)
verbose_proxy_logger.debug(
f"filtered allowed_models: {filtered_models}; models: {models}"
)
return True
async def can_key_call_model(
model: str,
llm_model_list: Optional[list],
@@ -851,57 +947,27 @@ async def can_key_call_model(
Raises:
- Exception: If token not allowed to call model
"""
if model in litellm.model_alias_map:
model = litellm.model_alias_map[model]
## check if model in allowed model names
verbose_proxy_logger.debug(
f"LLM Model List pre access group check: {llm_model_list}"
return await _can_object_call_model(
model=model,
llm_router=llm_router,
models=valid_token.models,
)
from collections import defaultdict
access_groups: Dict[str, List[str]] = defaultdict(list)
if llm_router:
access_groups = llm_router.get_model_access_groups(model_name=model)
if (
len(access_groups) > 0 and llm_router is not None
): # check if token contains any model access groups
for idx, m in enumerate(
valid_token.models
): # loop token models, if any of them are an access group add the access group
if m in access_groups:
return True
async def can_user_call_model(
model: str,
llm_router: Optional[Router],
user_object: Optional[LiteLLM_UserTable],
) -> Literal[True]:
# Filter out models that are access_groups
filtered_models = [m for m in valid_token.models if m not in access_groups]
verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")
if _model_matches_any_wildcard_pattern_in_list(
model=model, allowed_model_list=filtered_models
):
if user_object is None:
return True
all_model_access: bool = False
if (
len(filtered_models) == 0 and len(valid_token.models) == 0
) or "*" in filtered_models:
all_model_access = True
if model is not None and model not in filtered_models and all_model_access is False:
raise ProxyException(
message=f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}",
type=ProxyErrorTypes.key_model_access_denied,
param="model",
code=status.HTTP_401_UNAUTHORIZED,
)
valid_token.models = filtered_models
verbose_proxy_logger.debug(
f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}"
return await _can_object_call_model(
model=model,
llm_router=llm_router,
models=user_object.models,
)
return True
async def is_valid_fallback_model(
@@ -1161,7 +1227,11 @@ def _model_custom_llm_provider_matches_wildcard_pattern(
- `model=claude-3-5-sonnet-20240620`
- `allowed_model_pattern=anthropic/*`
"""
model, custom_llm_provider, _, _ = get_llm_provider(model=model)
try:
model, custom_llm_provider, _, _ = get_llm_provider(model=model)
except Exception:
return False
return is_model_allowed_by_pattern(
model=f"{custom_llm_provider}/{model}",
allowed_model_pattern=allowed_model_pattern,
+46 -1
View File
@@ -16,8 +16,10 @@ from cryptography.hazmat.primitives import serialization
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy._types import (
RBAC_ROLES,
JWKKeyValue,
JWTKeyItem,
LiteLLM_JWTAuth,
@@ -59,7 +61,7 @@ class JWTHandler:
parts = token.split(".")
return len(parts) == 3
def get_rbac_role(self, token: dict) -> Optional[LitellmUserRoles]:
def get_rbac_role(self, token: dict) -> Optional[RBAC_ROLES]:
"""
Returns the RBAC role the token 'belongs' to.
@@ -78,12 +80,18 @@ class JWTHandler:
"""
scopes = self.get_scopes(token=token)
is_admin = self.is_admin(scopes=scopes)
user_roles = self.get_user_roles(token=token, default_value=None)
if is_admin:
return LitellmUserRoles.PROXY_ADMIN
elif self.get_team_id(token=token, default_value=None) is not None:
return LitellmUserRoles.TEAM
elif self.get_user_id(token=token, default_value=None) is not None:
return LitellmUserRoles.INTERNAL_USER
elif user_roles is not None and self.is_allowed_user_role(
user_roles=user_roles
):
return LitellmUserRoles.INTERNAL_USER
return None
@@ -166,6 +174,43 @@ class JWTHandler:
user_id = default_value
return user_id
def get_user_roles(
self, token: dict, default_value: Optional[List[str]]
) -> Optional[List[str]]:
"""
Returns the user role from the token.
Set via 'user_roles_jwt_field' in the config.
"""
try:
if self.litellm_jwtauth.user_roles_jwt_field is not None:
user_roles = get_nested_value(
data=token,
key_path=self.litellm_jwtauth.user_roles_jwt_field,
default=default_value,
)
else:
user_roles = default_value
except KeyError:
user_roles = default_value
return user_roles
def is_allowed_user_role(self, user_roles: Optional[List[str]]) -> bool:
"""
Returns the user role from the token.
Set via 'user_allowed_roles' in the config.
"""
if (
user_roles is not None
and self.litellm_jwtauth.user_allowed_roles is not None
and any(
role in self.litellm_jwtauth.user_allowed_roles for role in user_roles
)
):
return True
return False
def get_user_email(
self, token: dict, default_value: Optional[str]
) -> Optional[str]:
+54 -20
View File
@@ -33,6 +33,7 @@ from litellm.proxy.auth.auth_checks import (
get_end_user_object,
get_key_object,
get_org_object,
get_role_based_models,
get_team_object,
get_user_object,
is_valid_fallback_model,
@@ -281,9 +282,34 @@ def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str:
return LitellmUserRoles.TEAM
def can_rbac_role_call_model(
rbac_role: RBAC_ROLES,
general_settings: dict,
model: Optional[str],
) -> Literal[True]:
"""
Checks if user is allowed to access the model, based on their role.
"""
role_based_models = get_role_based_models(
rbac_role=rbac_role, general_settings=general_settings
)
if role_based_models is None or model is None:
return True
if model not in role_based_models:
raise HTTPException(
status_code=403,
detail=f"User role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
)
return True
async def _jwt_auth_user_api_key_auth_builder(
api_key: str,
jwt_handler: JWTHandler,
request_data: dict,
general_settings: dict,
route: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
@@ -295,14 +321,20 @@ async def _jwt_auth_user_api_key_auth_builder(
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
# check if unmatched token and enforce_rbac is true
if (
jwt_handler.litellm_jwtauth.enforce_rbac is True
and jwt_handler.get_rbac_role(token=jwt_valid_token) is None
):
raise HTTPException(
status_code=403,
detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user. See how to set roles in config here: https://docs.litellm.ai/docs/proxy/token_auth#advanced---spend-tracking-end-users--internal-users--team--org",
)
if jwt_handler.litellm_jwtauth.enforce_rbac is True:
rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
if rbac_role is None:
raise HTTPException(
status_code=403,
detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user. See how to set roles in config here: https://docs.litellm.ai/docs/proxy/token_auth#advanced---spend-tracking-end-users--internal-users--team--org",
)
else:
# run rbac validation checks
can_rbac_role_call_model(
rbac_role=rbac_role,
general_settings=general_settings,
model=request_data.get("model"),
)
# get scopes
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
@@ -431,18 +463,18 @@ async def _jwt_auth_user_api_key_auth_builder(
proxy_logging_obj=proxy_logging_obj,
)
return {
"is_proxy_admin": False,
"team_id": team_id,
"team_object": team_object,
"user_id": user_id,
"user_object": user_object,
"org_id": org_id,
"org_object": org_object,
"end_user_id": end_user_id,
"end_user_object": end_user_object,
"token": api_key,
}
return JWTAuthBuilderResult(
is_proxy_admin=False,
team_id=team_id,
team_object=team_object,
user_id=user_id,
user_object=user_object,
org_id=org_id,
org_object=org_object,
end_user_id=end_user_id,
end_user_object=end_user_object,
token=api_key,
)
async def _user_api_key_auth_builder( # noqa: PLR0915
@@ -581,6 +613,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
verbose_proxy_logger.debug("is_jwt: %s", is_jwt)
if is_jwt:
result = await _jwt_auth_user_api_key_auth_builder(
request_data=request_data,
general_settings=general_settings,
api_key=api_key,
jwt_handler=jwt_handler,
route=route,
@@ -508,3 +508,43 @@ async def test_virtual_key_soft_budget_check(spend, soft_budget, expect_alert):
assert (
alert_triggered == expect_alert
), f"Expected alert_triggered to be {expect_alert} for spend={spend}, soft_budget={soft_budget}"
@pytest.mark.asyncio
async def test_can_user_call_model():
from litellm.proxy.auth.auth_checks import can_user_call_model
from litellm.proxy._types import ProxyException
from litellm import Router
router = Router(
model_list=[
{
"model_name": "anthropic-claude",
"litellm_params": {"model": "anthropic/anthropic-claude"},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo", "api_key": "test-api-key"},
},
]
)
args = {
"model": "anthropic-claude",
"llm_router": router,
"user_object": LiteLLM_UserTable(
user_id="testuser21@mycompany.com",
max_budget=None,
spend=0.0042295,
model_max_budget={},
model_spend={},
user_email="testuser@mycompany.com",
models=["gpt-3.5-turbo"],
),
}
with pytest.raises(ProxyException) as e:
await can_user_call_model(**args)
args["model"] = "gpt-3.5-turbo"
await can_user_call_model(**args)
@@ -855,6 +855,8 @@ async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypa
"user_api_key_cache": Mock(),
"parent_otel_span": None,
"proxy_logging_obj": Mock(),
"request_data": {},
"general_settings": {},
}
if enforce_rbac:
@@ -877,3 +879,55 @@ def test_user_api_key_auth_end_user_str():
user_api_key_auth = UserAPIKeyAuth(**user_api_key_args)
assert user_api_key_auth.end_user_id == "1"
def test_can_rbac_role_call_model():
from litellm.proxy.auth.user_api_key_auth import can_rbac_role_call_model
from litellm.proxy._types import RoleBasedPermissions
roles_based_permissions = [
RoleBasedPermissions(
role=LitellmUserRoles.INTERNAL_USER,
models=["gpt-4"],
),
RoleBasedPermissions(
role=LitellmUserRoles.PROXY_ADMIN,
models=["anthropic-claude"],
),
]
assert can_rbac_role_call_model(
rbac_role=LitellmUserRoles.INTERNAL_USER,
general_settings={"role_permissions": roles_based_permissions},
model="gpt-4",
)
with pytest.raises(HTTPException):
can_rbac_role_call_model(
rbac_role=LitellmUserRoles.INTERNAL_USER,
general_settings={"role_permissions": roles_based_permissions},
model="gpt-4o",
)
with pytest.raises(HTTPException):
can_rbac_role_call_model(
rbac_role=LitellmUserRoles.PROXY_ADMIN,
general_settings={"role_permissions": roles_based_permissions},
model="gpt-4o",
)
def test_can_rbac_role_call_model_no_role_permissions():
from litellm.proxy.auth.user_api_key_auth import can_rbac_role_call_model
assert can_rbac_role_call_model(
rbac_role=LitellmUserRoles.INTERNAL_USER,
general_settings={},
model="gpt-4",
)
assert can_rbac_role_call_model(
rbac_role=LitellmUserRoles.PROXY_ADMIN,
general_settings={"role_permissions": []},
model="anthropic-claude",
)
+107 -2
View File
@@ -7,13 +7,17 @@ import time
from openai import AsyncOpenAI
from test_team import list_teams
from typing import Optional
from test_keys import generate_key
from fastapi import HTTPException
async def new_user(session, i, user_id=None, budget=None, budget_duration=None):
async def new_user(
session, i, user_id=None, budget=None, budget_duration=None, models=None
):
url = "http://0.0.0.0:4000/user/new"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {
"models": ["azure-models"],
"models": models or ["azure-models"],
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
"duration": None,
"max_budget": budget,
@@ -37,6 +41,51 @@ async def new_user(session, i, user_id=None, budget=None, budget_duration=None):
return await response.json()
async def generate_key(
session,
i,
budget=None,
budget_duration=None,
models=["azure-models", "gpt-4", "dall-e-3"],
max_parallel_requests: Optional[int] = None,
user_id: Optional[str] = None,
team_id: Optional[str] = None,
metadata: Optional[dict] = None,
calling_key="sk-1234",
):
url = "http://0.0.0.0:4000/key/generate"
headers = {
"Authorization": f"Bearer {calling_key}",
"Content-Type": "application/json",
}
data = {
"models": models,
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
"duration": None,
"max_budget": budget,
"budget_duration": budget_duration,
"max_parallel_requests": max_parallel_requests,
"user_id": user_id,
"team_id": team_id,
"metadata": metadata,
}
print(f"data: {data}")
async with session.post(url, headers=headers, json=data) as response:
status = response.status
response_text = await response.text()
print(f"Response {i} (Status code: {status}):")
print(response_text)
print()
if status != 200:
raise Exception(f"Request {i} did not return a 200 status code: {status}")
return await response.json()
@pytest.mark.asyncio
async def test_user_new():
"""
@@ -210,3 +259,59 @@ async def test_global_proxy_budget_update():
new_new_spend = user_info["user_info"]["spend"]
print(f"new_spend: {new_spend}; original_spend: {original_spend}")
assert new_new_spend > new_spend
@pytest.mark.asyncio
async def test_user_model_access():
"""
- Create user with model access
- Create key with user
- Call model that user has access to -> should work
- Call wildcard model that user has access to -> should work
- Call model that user does not have access to -> should fail
- Call wildcard model that user does not have access to -> should fail
"""
import openai
async with aiohttp.ClientSession() as session:
get_user = f"krrish_{time.time()}@berri.ai"
await new_user(
session=session,
i=0,
user_id=get_user,
models=["good-model", "anthropic/*"],
)
result = await generate_key(
session=session,
i=0,
user_id=get_user,
models=[], # assign no models. Allow inheritance from user
)
key = result["key"]
await chat_completion(
session=session,
key=key,
model="anthropic/claude-3-5-haiku-20241022",
)
await chat_completion(
session=session,
key=key,
model="good-model",
)
with pytest.raises(openai.AuthenticationError):
await chat_completion(
session=session,
key=key,
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
)
with pytest.raises(openai.AuthenticationError):
await chat_completion(
session=session,
key=key,
model="groq/claude-3-5-haiku-20241022",
)