diff --git a/docs/my-website/docs/proxy/jwt_auth_arch.md b/docs/my-website/docs/proxy/jwt_auth_arch.md new file mode 100644 index 0000000000..e48fa71f8b --- /dev/null +++ b/docs/my-website/docs/proxy/jwt_auth_arch.md @@ -0,0 +1,116 @@ +import Image from '@theme/IdealImage'; +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# Control Model Access with SSO (Azure AD/Keycloak/etc.) + +:::info + +✨ JWT Auth is on LiteLLM Enterprise + +[Enterprise Pricing](https://www.litellm.ai/#pricing) + +[Get free 7-day trial key](https://www.litellm.ai/#trial) + +::: + + + +## Example Token + + + + +```bash +{ + "sub": "1234567890", + "name": "John Doe", + "email": "john.doe@example.com", + "roles": ["basic_user"] # 👈 ROLE +} +``` + + + +```bash +{ + "sub": "1234567890", + "name": "John Doe", + "email": "john.doe@example.com", + "resource_access": { + "litellm-test-client-id": { + "roles": ["basic_user"] # 👈 ROLE + } + } +} +``` + + + +## Proxy Configuration + + + + +```yaml +general_settings: + enable_jwt_auth: True + litellm_jwtauth: + user_roles_jwt_field: "roles" # the field in the JWT that contains the roles + user_allowed_roles: ["basic_user"] # roles that map to an 'internal_user' role on LiteLLM + enforce_rbac: true # if true, will check if the user has the correct role to access the model + + role_permissions: # control what models are allowed for each role + - role: internal_user + models: ["anthropic-claude"] + +model_list: + - model: anthropic-claude + litellm_params: + model: claude-3-5-haiku-20241022 + - model: openai-gpt-4o + litellm_params: + model: gpt-4o +``` + + + + +```yaml +general_settings: + enable_jwt_auth: True + litellm_jwtauth: + user_roles_jwt_field: "resource_access.litellm-test-client-id.roles" # the field in the JWT that contains the roles + user_allowed_roles: ["basic_user"] # roles that map to an 'internal_user' role on LiteLLM + enforce_rbac: true # if true, will check if the user has the correct role to access the model + + role_permissions: # control what models are allowed for each role + - role: internal_user + models: ["anthropic-claude"] + +model_list: + - model: anthropic-claude + litellm_params: + model: claude-3-5-haiku-20241022 + - model: openai-gpt-4o + litellm_params: + model: gpt-4o +``` + + + + + +## How it works + +1. Specify JWT_PUBLIC_KEY_URL - This is the public keys endpoint of your OpenID provider. For Azure AD it's `https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys`. For Keycloak it's `{keycloak_base_url}/realms/{your-realm}/protocol/openid-connect/certs`. + +1. Map JWT roles to LiteLLM roles - Done via `user_roles_jwt_field` and `user_allowed_roles` + - Currently just `internal_user` is supported for role mapping. +2. Specify model access: + - `role_permissions`: control what models are allowed for each role. + - `role`: the LiteLLM role to control access for. Allowed roles = ["internal_user", "proxy_admin", "team"] + - `models`: list of models that the role is allowed to access. + - `model_list`: parent list of models on the proxy. [Learn more](./configs.md#llm-configs-model_list) + +3. Model Checks: The proxy will run validation checks on the received JWT. [Code](https://github.com/BerriAI/litellm/blob/3a4f5b23b5025b87b6d969f2485cc9bc741f9ba6/litellm/proxy/auth/user_api_key_auth.py#L284) \ No newline at end of file diff --git a/docs/my-website/docs/proxy/model_access.md b/docs/my-website/docs/proxy/model_access.md index 545d74865b..854baa2edb 100644 --- a/docs/my-website/docs/proxy/model_access.md +++ b/docs/my-website/docs/proxy/model_access.md @@ -344,3 +344,6 @@ curl -i http://localhost:4000/v1/chat/completions \ + + +## [Role Based Access Control (RBAC)](./jwt_auth_arch) \ No newline at end of file diff --git a/docs/my-website/docs/proxy/token_auth.md b/docs/my-website/docs/proxy/token_auth.md index ffff2694fe..df57cadd3b 100644 --- a/docs/my-website/docs/proxy/token_auth.md +++ b/docs/my-website/docs/proxy/token_auth.md @@ -1,7 +1,7 @@ import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; -# JWT-based Auth +# SSO - JWT-based Auth Use JWT's to auth admins / projects into the proxy. @@ -183,6 +183,24 @@ Expected Scope in JWT: } ``` +### Control Model Access + +```yaml +general_settings: + enable_jwt_auth: True + litellm_jwtauth: + user_roles_jwt_field: "resource_access.litellm-test-client-id.roles" + user_allowed_roles: ["basic_user"] # roles that map to an 'internal_user' role on LiteLLM + enforce_rbac: true # if true, will check if the user has the correct role to access the model + endpoint + + role_permissions: # control what models + endpointsare allowed for each role + - role: internal_user + models: ["anthropic-claude"] +``` + + +**[Architecture Diagram (Control Model Access)](./jwt_auth_arch)** + ## Advanced - Allowed Routes Configure which routes a JWT can access via the config. diff --git a/docs/my-website/img/control_model_access_jwt.png b/docs/my-website/img/control_model_access_jwt.png new file mode 100644 index 0000000000..ab6cda5396 Binary files /dev/null and b/docs/my-website/img/control_model_access_jwt.png differ diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 29d674f3f4..d20f2a73e4 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -51,7 +51,7 @@ const sidebars = { { type: "category", label: "Architecture", - items: ["proxy/architecture", "proxy/db_info", "router_architecture", "proxy/user_management_heirarchy"], + items: ["proxy/architecture", "proxy/db_info", "router_architecture", "proxy/user_management_heirarchy", "proxy/jwt_auth_arch"], }, { type: "link", diff --git a/litellm/litellm_core_utils/dot_notation_indexing.py b/litellm/litellm_core_utils/dot_notation_indexing.py new file mode 100644 index 0000000000..fda37f6500 --- /dev/null +++ b/litellm/litellm_core_utils/dot_notation_indexing.py @@ -0,0 +1,59 @@ +""" +This file contains the logic for dot notation indexing. + +Used by JWT Auth to get the user role from the token. +""" + +from typing import Any, Dict, Optional, TypeVar + +T = TypeVar("T") + + +def get_nested_value( + data: Dict[str, Any], key_path: str, default: Optional[T] = None +) -> Optional[T]: + """ + Retrieves a value from a nested dictionary using dot notation. + + Args: + data: The dictionary to search in + key_path: The path to the value using dot notation (e.g., "a.b.c") + default: The default value to return if the path is not found + + Returns: + The value at the specified path, or the default value if not found + + Example: + >>> data = {"a": {"b": {"c": "value"}}} + >>> get_nested_value(data, "a.b.c") + 'value' + >>> get_nested_value(data, "a.b.d", "default") + 'default' + """ + if not key_path: + return default + + # Remove metadata. prefix if it exists + key_path = ( + key_path.replace("metadata.", "", 1) + if key_path.startswith("metadata.") + else key_path + ) + + # Split the key path into parts + parts = key_path.split(".") + + # Traverse through the dictionary + current: Any = data + for part in parts: + try: + current = current[part] + except (KeyError, TypeError): + return default + + # If default is None, we can return any type + if default is None: + return current + + # Otherwise, ensure the type matches the default + return current if isinstance(current, type(default)) else default diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 983525f495..423032ac86 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,18 +1,16 @@ model_list: - - model_name: gpt-3.5-turbo-end-user-test + - model_name: gpt-3.5-turbo litellm_params: model: gpt-3.5-turbo - region_name: "eu" - model_info: - id: "1" - - model_name: gpt-3.5-turbo-end-user-test - litellm_params: - model: gpt-3.5-turbo - timeout: 2 - num_retries: 0 - model_name: anthropic-claude litellm_params: - model: anthropic.claude-3-sonnet-20240229-v1:0 - -litellm_settings: - callbacks: ["langsmith"] \ No newline at end of file + model: claude-3-5-haiku-20241022 + - model_name: groq/* + litellm_params: + model: groq/* + api_key: os.environ/GROQ_API_KEY + mock_response: Hi! + - model_name: deepseek/* + litellm_params: + model: deepseek/* + api_key: os.environ/DEEPSEEK_API_KEY diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 5a456aec97..bf3f6b6543 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -445,6 +445,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): user_id_jwt_field: Optional[str] = None user_email_jwt_field: Optional[str] = None user_allowed_email_domain: Optional[str] = None + user_roles_jwt_field: Optional[str] = None + user_allowed_roles: Optional[List[str]] = None user_id_upsert: bool = Field( default=False, description="If user doesn't exist, upsert them into the db." ) @@ -458,11 +460,19 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): allowed_keys = self.__annotations__.keys() invalid_keys = set(kwargs.keys()) - allowed_keys + user_roles_jwt_field = kwargs.get("user_roles_jwt_field") + user_allowed_roles = kwargs.get("user_allowed_roles") if invalid_keys: raise ValueError( f"Invalid arguments provided: {', '.join(invalid_keys)}. Allowed arguments are: {', '.join(allowed_keys)}." ) + if (user_roles_jwt_field is not None and user_allowed_roles is None) or ( + user_roles_jwt_field is None and user_allowed_roles is not None + ): + raise ValueError( + "user_allowed_roles must be provided if user_roles_jwt_field is set." + ) super().__init__(**kwargs) @@ -2335,3 +2345,15 @@ class ClientSideFallbackModel(TypedDict, total=False): ALL_FALLBACK_MODEL_VALUES = Union[str, ClientSideFallbackModel] + + +RBAC_ROLES = Literal[ + LitellmUserRoles.PROXY_ADMIN, + LitellmUserRoles.TEAM, + LitellmUserRoles.INTERNAL_USER, +] + + +class RoleBasedPermissions(TypedDict): + role: Required[RBAC_ROLES] + models: Required[List[str]] diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index d6bbf760bd..8d0132709c 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -12,7 +12,7 @@ import asyncio import re import time import traceback -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast from fastapi import status from pydantic import BaseModel @@ -24,6 +24,7 @@ from litellm.caching.dual_cache import LimitedSizeOrderedDict from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider from litellm.proxy._types import ( DB_CONNECTION_ERROR_TYPES, + RBAC_ROLES, CallInfo, LiteLLM_EndUserTable, LiteLLM_JWTAuth, @@ -35,6 +36,7 @@ from litellm.proxy._types import ( LitellmUserRoles, ProxyErrorTypes, ProxyException, + RoleBasedPermissions, UserAPIKeyAuth, ) from litellm.proxy.auth.route_checks import RouteChecks @@ -100,6 +102,14 @@ async def common_checks( llm_router=llm_router, ) + ## 2.1 If user can call model (if personal key) + if team_object is None and user_object is not None: + await can_user_call_model( + model=_model, + llm_router=llm_router, + user_object=user_object, + ) + # 3. If team is in budget await _team_max_budget_check( team_object=team_object, @@ -391,6 +401,30 @@ def _update_last_db_access_time( last_db_access_time[key] = (value, time.time()) +def get_role_based_models( + rbac_role: RBAC_ROLES, + general_settings: dict, +) -> Optional[List[str]]: + """ + Get the models allowed for a user role. + + Used by JWT Auth. + """ + + role_based_permissions = cast( + Optional[List[RoleBasedPermissions]], + general_settings.get("role_permissions", []), + ) + if role_based_permissions is None: + return None + + for role_based_permission in role_based_permissions: + if role_based_permission["role"] == rbac_role: + return role_based_permission["models"] + + return None + + @log_db_metrics async def get_user_object( user_id: str, @@ -836,6 +870,68 @@ async def get_org_object( ) +async def _can_object_call_model( + model: str, + llm_router: Optional[Router], + models: List[str], +) -> Literal[True]: + """ + Checks if token can call a given model + + Returns: + - True: if token allowed to call model + + Raises: + - Exception: If token not allowed to call model + """ + if model in litellm.model_alias_map: + model = litellm.model_alias_map[model] + + ## check if model in allowed model names + from collections import defaultdict + + access_groups: Dict[str, List[str]] = defaultdict(list) + + if llm_router: + access_groups = llm_router.get_model_access_groups(model_name=model) + if ( + len(access_groups) > 0 and llm_router is not None + ): # check if token contains any model access groups + for idx, m in enumerate( + models + ): # loop token models, if any of them are an access group add the access group + if m in access_groups: + return True + + # Filter out models that are access_groups + filtered_models = [m for m in models if m not in access_groups] + + verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}") + + if _model_matches_any_wildcard_pattern_in_list( + model=model, allowed_model_list=filtered_models + ): + return True + + all_model_access: bool = False + + if (len(filtered_models) == 0 and len(models) == 0) or "*" in filtered_models: + all_model_access = True + + if model is not None and model not in filtered_models and all_model_access is False: + raise ProxyException( + message=f"API Key not allowed to access model. This token can only access models={models}. Tried to access {model}", + type=ProxyErrorTypes.key_model_access_denied, + param="model", + code=status.HTTP_401_UNAUTHORIZED, + ) + + verbose_proxy_logger.debug( + f"filtered allowed_models: {filtered_models}; models: {models}" + ) + return True + + async def can_key_call_model( model: str, llm_model_list: Optional[list], @@ -851,57 +947,27 @@ async def can_key_call_model( Raises: - Exception: If token not allowed to call model """ - if model in litellm.model_alias_map: - model = litellm.model_alias_map[model] - - ## check if model in allowed model names - verbose_proxy_logger.debug( - f"LLM Model List pre access group check: {llm_model_list}" + return await _can_object_call_model( + model=model, + llm_router=llm_router, + models=valid_token.models, ) - from collections import defaultdict - access_groups: Dict[str, List[str]] = defaultdict(list) - if llm_router: - access_groups = llm_router.get_model_access_groups(model_name=model) - if ( - len(access_groups) > 0 and llm_router is not None - ): # check if token contains any model access groups - for idx, m in enumerate( - valid_token.models - ): # loop token models, if any of them are an access group add the access group - if m in access_groups: - return True +async def can_user_call_model( + model: str, + llm_router: Optional[Router], + user_object: Optional[LiteLLM_UserTable], +) -> Literal[True]: - # Filter out models that are access_groups - filtered_models = [m for m in valid_token.models if m not in access_groups] - - verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}") - - if _model_matches_any_wildcard_pattern_in_list( - model=model, allowed_model_list=filtered_models - ): + if user_object is None: return True - all_model_access: bool = False - - if ( - len(filtered_models) == 0 and len(valid_token.models) == 0 - ) or "*" in filtered_models: - all_model_access = True - - if model is not None and model not in filtered_models and all_model_access is False: - raise ProxyException( - message=f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}", - type=ProxyErrorTypes.key_model_access_denied, - param="model", - code=status.HTTP_401_UNAUTHORIZED, - ) - valid_token.models = filtered_models - verbose_proxy_logger.debug( - f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}" + return await _can_object_call_model( + model=model, + llm_router=llm_router, + models=user_object.models, ) - return True async def is_valid_fallback_model( @@ -1161,7 +1227,11 @@ def _model_custom_llm_provider_matches_wildcard_pattern( - `model=claude-3-5-sonnet-20240620` - `allowed_model_pattern=anthropic/*` """ - model, custom_llm_provider, _, _ = get_llm_provider(model=model) + try: + model, custom_llm_provider, _, _ = get_llm_provider(model=model) + except Exception: + return False + return is_model_allowed_by_pattern( model=f"{custom_llm_provider}/{model}", allowed_model_pattern=allowed_model_pattern, diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index cf57011546..bcda413b68 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -16,8 +16,10 @@ from cryptography.hazmat.primitives import serialization from litellm._logging import verbose_proxy_logger from litellm.caching.caching import DualCache +from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value from litellm.llms.custom_httpx.httpx_handler import HTTPHandler from litellm.proxy._types import ( + RBAC_ROLES, JWKKeyValue, JWTKeyItem, LiteLLM_JWTAuth, @@ -59,7 +61,7 @@ class JWTHandler: parts = token.split(".") return len(parts) == 3 - def get_rbac_role(self, token: dict) -> Optional[LitellmUserRoles]: + def get_rbac_role(self, token: dict) -> Optional[RBAC_ROLES]: """ Returns the RBAC role the token 'belongs' to. @@ -78,12 +80,18 @@ class JWTHandler: """ scopes = self.get_scopes(token=token) is_admin = self.is_admin(scopes=scopes) + user_roles = self.get_user_roles(token=token, default_value=None) + if is_admin: return LitellmUserRoles.PROXY_ADMIN elif self.get_team_id(token=token, default_value=None) is not None: return LitellmUserRoles.TEAM elif self.get_user_id(token=token, default_value=None) is not None: return LitellmUserRoles.INTERNAL_USER + elif user_roles is not None and self.is_allowed_user_role( + user_roles=user_roles + ): + return LitellmUserRoles.INTERNAL_USER return None @@ -166,6 +174,43 @@ class JWTHandler: user_id = default_value return user_id + def get_user_roles( + self, token: dict, default_value: Optional[List[str]] + ) -> Optional[List[str]]: + """ + Returns the user role from the token. + + Set via 'user_roles_jwt_field' in the config. + """ + try: + if self.litellm_jwtauth.user_roles_jwt_field is not None: + user_roles = get_nested_value( + data=token, + key_path=self.litellm_jwtauth.user_roles_jwt_field, + default=default_value, + ) + else: + user_roles = default_value + except KeyError: + user_roles = default_value + return user_roles + + def is_allowed_user_role(self, user_roles: Optional[List[str]]) -> bool: + """ + Returns the user role from the token. + + Set via 'user_allowed_roles' in the config. + """ + if ( + user_roles is not None + and self.litellm_jwtauth.user_allowed_roles is not None + and any( + role in self.litellm_jwtauth.user_allowed_roles for role in user_roles + ) + ): + return True + return False + def get_user_email( self, token: dict, default_value: Optional[str] ) -> Optional[str]: diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 33247308f6..7d499af5b2 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -33,6 +33,7 @@ from litellm.proxy.auth.auth_checks import ( get_end_user_object, get_key_object, get_org_object, + get_role_based_models, get_team_object, get_user_object, is_valid_fallback_model, @@ -281,9 +282,34 @@ def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str: return LitellmUserRoles.TEAM +def can_rbac_role_call_model( + rbac_role: RBAC_ROLES, + general_settings: dict, + model: Optional[str], +) -> Literal[True]: + """ + Checks if user is allowed to access the model, based on their role. + """ + role_based_models = get_role_based_models( + rbac_role=rbac_role, general_settings=general_settings + ) + if role_based_models is None or model is None: + return True + + if model not in role_based_models: + raise HTTPException( + status_code=403, + detail=f"User role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}", + ) + + return True + + async def _jwt_auth_user_api_key_auth_builder( api_key: str, jwt_handler: JWTHandler, + request_data: dict, + general_settings: dict, route: str, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache, @@ -295,14 +321,20 @@ async def _jwt_auth_user_api_key_auth_builder( jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key) # check if unmatched token and enforce_rbac is true - if ( - jwt_handler.litellm_jwtauth.enforce_rbac is True - and jwt_handler.get_rbac_role(token=jwt_valid_token) is None - ): - raise HTTPException( - status_code=403, - detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user. See how to set roles in config here: https://docs.litellm.ai/docs/proxy/token_auth#advanced---spend-tracking-end-users--internal-users--team--org", - ) + if jwt_handler.litellm_jwtauth.enforce_rbac is True: + rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token) + if rbac_role is None: + raise HTTPException( + status_code=403, + detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user. See how to set roles in config here: https://docs.litellm.ai/docs/proxy/token_auth#advanced---spend-tracking-end-users--internal-users--team--org", + ) + else: + # run rbac validation checks + can_rbac_role_call_model( + rbac_role=rbac_role, + general_settings=general_settings, + model=request_data.get("model"), + ) # get scopes scopes = jwt_handler.get_scopes(token=jwt_valid_token) @@ -431,18 +463,18 @@ async def _jwt_auth_user_api_key_auth_builder( proxy_logging_obj=proxy_logging_obj, ) - return { - "is_proxy_admin": False, - "team_id": team_id, - "team_object": team_object, - "user_id": user_id, - "user_object": user_object, - "org_id": org_id, - "org_object": org_object, - "end_user_id": end_user_id, - "end_user_object": end_user_object, - "token": api_key, - } + return JWTAuthBuilderResult( + is_proxy_admin=False, + team_id=team_id, + team_object=team_object, + user_id=user_id, + user_object=user_object, + org_id=org_id, + org_object=org_object, + end_user_id=end_user_id, + end_user_object=end_user_object, + token=api_key, + ) async def _user_api_key_auth_builder( # noqa: PLR0915 @@ -581,6 +613,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 verbose_proxy_logger.debug("is_jwt: %s", is_jwt) if is_jwt: result = await _jwt_auth_user_api_key_auth_builder( + request_data=request_data, + general_settings=general_settings, api_key=api_key, jwt_handler=jwt_handler, route=route, diff --git a/tests/proxy_unit_tests/test_auth_checks.py b/tests/proxy_unit_tests/test_auth_checks.py index 85b5b216a5..04af3d6e29 100644 --- a/tests/proxy_unit_tests/test_auth_checks.py +++ b/tests/proxy_unit_tests/test_auth_checks.py @@ -508,3 +508,43 @@ async def test_virtual_key_soft_budget_check(spend, soft_budget, expect_alert): assert ( alert_triggered == expect_alert ), f"Expected alert_triggered to be {expect_alert} for spend={spend}, soft_budget={soft_budget}" + + +@pytest.mark.asyncio +async def test_can_user_call_model(): + from litellm.proxy.auth.auth_checks import can_user_call_model + from litellm.proxy._types import ProxyException + from litellm import Router + + router = Router( + model_list=[ + { + "model_name": "anthropic-claude", + "litellm_params": {"model": "anthropic/anthropic-claude"}, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "test-api-key"}, + }, + ] + ) + + args = { + "model": "anthropic-claude", + "llm_router": router, + "user_object": LiteLLM_UserTable( + user_id="testuser21@mycompany.com", + max_budget=None, + spend=0.0042295, + model_max_budget={}, + model_spend={}, + user_email="testuser@mycompany.com", + models=["gpt-3.5-turbo"], + ), + } + + with pytest.raises(ProxyException) as e: + await can_user_call_model(**args) + + args["model"] = "gpt-3.5-turbo" + await can_user_call_model(**args) diff --git a/tests/proxy_unit_tests/test_user_api_key_auth.py b/tests/proxy_unit_tests/test_user_api_key_auth.py index a428a29c63..3e9ba17889 100644 --- a/tests/proxy_unit_tests/test_user_api_key_auth.py +++ b/tests/proxy_unit_tests/test_user_api_key_auth.py @@ -855,6 +855,8 @@ async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypa "user_api_key_cache": Mock(), "parent_otel_span": None, "proxy_logging_obj": Mock(), + "request_data": {}, + "general_settings": {}, } if enforce_rbac: @@ -877,3 +879,55 @@ def test_user_api_key_auth_end_user_str(): user_api_key_auth = UserAPIKeyAuth(**user_api_key_args) assert user_api_key_auth.end_user_id == "1" + + +def test_can_rbac_role_call_model(): + from litellm.proxy.auth.user_api_key_auth import can_rbac_role_call_model + from litellm.proxy._types import RoleBasedPermissions + + roles_based_permissions = [ + RoleBasedPermissions( + role=LitellmUserRoles.INTERNAL_USER, + models=["gpt-4"], + ), + RoleBasedPermissions( + role=LitellmUserRoles.PROXY_ADMIN, + models=["anthropic-claude"], + ), + ] + + assert can_rbac_role_call_model( + rbac_role=LitellmUserRoles.INTERNAL_USER, + general_settings={"role_permissions": roles_based_permissions}, + model="gpt-4", + ) + + with pytest.raises(HTTPException): + can_rbac_role_call_model( + rbac_role=LitellmUserRoles.INTERNAL_USER, + general_settings={"role_permissions": roles_based_permissions}, + model="gpt-4o", + ) + + with pytest.raises(HTTPException): + can_rbac_role_call_model( + rbac_role=LitellmUserRoles.PROXY_ADMIN, + general_settings={"role_permissions": roles_based_permissions}, + model="gpt-4o", + ) + + +def test_can_rbac_role_call_model_no_role_permissions(): + from litellm.proxy.auth.user_api_key_auth import can_rbac_role_call_model + + assert can_rbac_role_call_model( + rbac_role=LitellmUserRoles.INTERNAL_USER, + general_settings={}, + model="gpt-4", + ) + + assert can_rbac_role_call_model( + rbac_role=LitellmUserRoles.PROXY_ADMIN, + general_settings={"role_permissions": []}, + model="anthropic-claude", + ) diff --git a/tests/test_users.py b/tests/test_users.py index 7e267ac4df..812783681c 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -7,13 +7,17 @@ import time from openai import AsyncOpenAI from test_team import list_teams from typing import Optional +from test_keys import generate_key +from fastapi import HTTPException -async def new_user(session, i, user_id=None, budget=None, budget_duration=None): +async def new_user( + session, i, user_id=None, budget=None, budget_duration=None, models=None +): url = "http://0.0.0.0:4000/user/new" headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} data = { - "models": ["azure-models"], + "models": models or ["azure-models"], "aliases": {"mistral-7b": "gpt-3.5-turbo"}, "duration": None, "max_budget": budget, @@ -37,6 +41,51 @@ async def new_user(session, i, user_id=None, budget=None, budget_duration=None): return await response.json() +async def generate_key( + session, + i, + budget=None, + budget_duration=None, + models=["azure-models", "gpt-4", "dall-e-3"], + max_parallel_requests: Optional[int] = None, + user_id: Optional[str] = None, + team_id: Optional[str] = None, + metadata: Optional[dict] = None, + calling_key="sk-1234", +): + url = "http://0.0.0.0:4000/key/generate" + headers = { + "Authorization": f"Bearer {calling_key}", + "Content-Type": "application/json", + } + data = { + "models": models, + "aliases": {"mistral-7b": "gpt-3.5-turbo"}, + "duration": None, + "max_budget": budget, + "budget_duration": budget_duration, + "max_parallel_requests": max_parallel_requests, + "user_id": user_id, + "team_id": team_id, + "metadata": metadata, + } + + print(f"data: {data}") + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(f"Response {i} (Status code: {status}):") + print(response_text) + print() + + if status != 200: + raise Exception(f"Request {i} did not return a 200 status code: {status}") + + return await response.json() + + @pytest.mark.asyncio async def test_user_new(): """ @@ -210,3 +259,59 @@ async def test_global_proxy_budget_update(): new_new_spend = user_info["user_info"]["spend"] print(f"new_spend: {new_spend}; original_spend: {original_spend}") assert new_new_spend > new_spend + + +@pytest.mark.asyncio +async def test_user_model_access(): + """ + - Create user with model access + - Create key with user + - Call model that user has access to -> should work + - Call wildcard model that user has access to -> should work + - Call model that user does not have access to -> should fail + - Call wildcard model that user does not have access to -> should fail + """ + import openai + + async with aiohttp.ClientSession() as session: + get_user = f"krrish_{time.time()}@berri.ai" + await new_user( + session=session, + i=0, + user_id=get_user, + models=["good-model", "anthropic/*"], + ) + + result = await generate_key( + session=session, + i=0, + user_id=get_user, + models=[], # assign no models. Allow inheritance from user + ) + key = result["key"] + + await chat_completion( + session=session, + key=key, + model="anthropic/claude-3-5-haiku-20241022", + ) + + await chat_completion( + session=session, + key=key, + model="good-model", + ) + + with pytest.raises(openai.AuthenticationError): + await chat_completion( + session=session, + key=key, + model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0", + ) + + with pytest.raises(openai.AuthenticationError): + await chat_completion( + session=session, + key=key, + model="groq/claude-3-5-haiku-20241022", + )