diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index b59fc85d4b..8faf36df4c 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -525,6 +525,7 @@ class LiteLLMRoutes(enum.Enum): # user "/user/new", "/user/update", + "/user/bulk_update", "/user/delete", "/user/info", "/user/list", diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index bfad9f0c3c..880ce3fb32 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -18,6 +18,7 @@ from fastapi import HTTPException from litellm._logging import verbose_proxy_logger from litellm.caching.caching import DualCache +from litellm.constants import DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value from litellm.llms.custom_httpx.httpx_handler import HTTPHandler from litellm.proxy._types import ( @@ -89,7 +90,9 @@ class JWTHandler: self.leeway = leeway @staticmethod - def is_jwt(token: str): + def is_jwt(token: Optional[str]) -> bool: + if token is None: + return False parts = token.split(".") return len(parts) == 3 @@ -1324,6 +1327,7 @@ class JWTAuthManager: jwt_valid_token: dict, user_object: Optional[LiteLLM_UserTable], prisma_client: Optional[PrismaClient], + user_api_key_cache: Optional[DualCache] = None, ) -> None: """ Sync user role and team memberships with JWT claims @@ -1348,6 +1352,12 @@ class JWTAuthManager: data={"user_role": new_role.value}, ) user_object.user_role = new_role.value + if user_api_key_cache is not None: + await user_api_key_cache.async_set_cache( + key=user_object.user_id, + value=user_object.model_dump(), + ttl=DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL, + ) # Sync team memberships jwt_team_ids = set(jwt_handler.get_team_ids_from_jwt(jwt_valid_token)) @@ -1365,6 +1375,12 @@ class JWTAuthManager: teams_ids_to_remove_user_from=list(teams_to_remove), ) user_object.teams = list(jwt_team_ids) + if user_api_key_cache is not None: + await user_api_key_cache.async_set_cache( + key=user_object.user_id, + value=user_object.model_dump(), + ttl=DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL, + ) return None @staticmethod @@ -1536,6 +1552,7 @@ class JWTAuthManager: jwt_valid_token=jwt_valid_token, user_object=user_object, prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, ) ## MAP USER TO TEAMS diff --git a/litellm/proxy/auth/route_checks.py b/litellm/proxy/auth/route_checks.py index 53cc88e3b1..26bbdef309 100644 --- a/litellm/proxy/auth/route_checks.py +++ b/litellm/proxy/auth/route_checks.py @@ -629,6 +629,7 @@ class RouteChecks: in [ "/user/new", "/user/delete", + "/user/bulk_update", "/team/new", "/team/update", "/team/delete", diff --git a/litellm/proxy/management_endpoints/budget_management_endpoints.py b/litellm/proxy/management_endpoints/budget_management_endpoints.py index 20c7f9ec41..37f13269b1 100644 --- a/litellm/proxy/management_endpoints/budget_management_endpoints.py +++ b/litellm/proxy/management_endpoints/budget_management_endpoints.py @@ -15,6 +15,7 @@ All /budget management endpoints from datetime import timedelta from fastapi import APIRouter, Depends, HTTPException +from prisma.errors import UniqueViolationError from litellm.litellm_core_utils.duration_parser import duration_in_seconds from litellm.proxy._types import * @@ -90,13 +91,21 @@ async def new_budget( budget_obj_json = budget_obj.model_dump(exclude_none=True) budget_obj_jsonified = jsonify_object(budget_obj_json) # json dump any dictionaries - response = await prisma_client.db.litellm_budgettable.create( - data={ - **budget_obj_jsonified, # type: ignore - "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, - "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, - } # type: ignore - ) + try: + response = await prisma_client.db.litellm_budgettable.create( + data={ + **budget_obj_jsonified, # type: ignore + "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + } # type: ignore + ) + except UniqueViolationError: + raise HTTPException( + status_code=400, + detail={ + "error": f"Budget with id '{budget_obj.budget_id}' already exists." + }, + ) return response diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index ff7b89fe19..7b459cd850 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -6,6 +6,7 @@ These are members of a Team on LiteLLM /user/new /user/update +/user/bulk_update /user/delete /user/info /user/list diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 79a9e9bdf2..a5a62813b9 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -204,7 +204,7 @@ def process_sso_jwt_access_token( sso_jwt_handler: Optional[JWTHandler], result: Union[OpenID, dict, None], role_mappings: Optional["RoleMappings"] = None, -) -> None: +) -> Optional[dict]: """ Process SSO JWT access token and extract team IDs and user role if available. @@ -218,6 +218,12 @@ def process_sso_jwt_access_token( sso_jwt_handler: SSO-specific JWT handler for team ID extraction result: The SSO result object to update with team IDs and role role_mappings: Optional role mappings configuration for group-based role determination + + Returns: + The decoded access token payload dict, or None if decoding failed or + inputs were missing. Callers can pass this to _sync_user_role_from_jwt_role_map + so it has access to custom role claims (e.g. custom_roles) that are + encoded inside the JWT but stripped from received_response. """ if access_token_str and result: import jwt @@ -230,7 +236,7 @@ def process_sso_jwt_access_token( verbose_proxy_logger.debug( "Access token is not a valid JWT (possibly an opaque token), skipping JWT-based extraction" ) - return + return None # Extract team IDs from access token if sso_jwt_handler is available if sso_jwt_handler: @@ -306,6 +312,10 @@ def process_sso_jwt_access_token( f"Set user_role='{user_role}' from JWT access token" ) + return access_token_payload + + return None + @router.get("/sso/key/generate", tags=["experimental"], include_in_schema=False) async def google_login( @@ -817,7 +827,7 @@ async def get_generic_sso_response( ], # sso specific jwt handler - used for restricted sso group access control generic_client_id: str, redirect_url: str, -) -> Tuple[Union[OpenID, dict], Optional[dict]]: # return received response +) -> Tuple[Union[OpenID, dict], Optional[dict], Optional[dict]]: # (result, received_response, access_token_payload) # make generic sso provider from fastapi_sso.sso.base import DiscoveryDocument from fastapi_sso.sso.generic import create_provider @@ -872,6 +882,7 @@ async def get_generic_sso_response( code_verifier: Optional[ str ] = None # assigned inside try; initialized for type tracking + access_token_payload: Optional[dict] = None # decoded JWT access token claims try: token_exchange_params = ( @@ -958,7 +969,7 @@ async def get_generic_sso_response( ) access_token_str = generic_sso.access_token - process_sso_jwt_access_token( + access_token_payload = process_sso_jwt_access_token( access_token_str, sso_jwt_handler, result, role_mappings=role_mappings ) # Delete the single-use PKCE verifier only after all downstream processing @@ -976,7 +987,7 @@ async def get_generic_sso_response( additional_generic_sso_headers_dict, ) verbose_proxy_logger.debug("generic result: %s", result) - return result or {}, received_response + return result or {}, received_response, access_token_payload async def create_team_member_add_task(team_id, user_info): @@ -1176,6 +1187,56 @@ def _build_sso_user_update_data( return update_data +async def _sync_user_role_from_jwt_role_map( + jwt_handler: Optional[JWTHandler], + received_response: Optional[dict], + user_info: Optional[Union[LiteLLM_UserTable, NewUserResponse]], + prisma_client: PrismaClient, + user_api_key_cache: DualCache, + user_defined_values: Optional[SSOUserDefinedValues], +) -> None: + """ + Apply jwt_litellm_role_map during SSO login. + + When jwt_litellm_role_map is configured with sync_user_role_and_teams=True, + this ensures SSO users get the same role mapping as API/JWT users. Without + this, the SSO path falls back to INTERNAL_USER_VIEW_ONLY for roles that + don't directly match LitellmUserRoles enum values. + """ + if jwt_handler is None or received_response is None: + return + if not jwt_handler.litellm_jwtauth.sync_user_role_and_teams: + return + if not jwt_handler.litellm_jwtauth.jwt_litellm_role_map: + return + + mapped_role = jwt_handler.map_jwt_role_to_litellm_role(received_response) + if mapped_role is None: + return + + verbose_proxy_logger.info( + f"SSO jwt_litellm_role_map matched role: {mapped_role.value}" + ) + + # Update user_defined_values so downstream code uses the mapped role + if user_defined_values is not None: + user_defined_values["user_role"] = mapped_role.value + + # Update existing DB record if role differs + if user_info is not None and user_info.user_role != mapped_role.value: + await prisma_client.db.litellm_usertable.update( + where={"user_id": user_info.user_id}, + data={"user_role": mapped_role.value}, + ) + user_info.user_role = mapped_role.value + await user_api_key_cache.async_set_cache( + key=user_info.user_id, + value=user_info.model_dump() + if hasattr(user_info, "model_dump") + else dict(user_info), + ) + + def apply_user_info_values_to_sso_user_defined_values( user_info: Optional[Union[LiteLLM_UserTable, NewUserResponse]], user_defined_values: Optional[SSOUserDefinedValues], @@ -1279,6 +1340,7 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa: google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) received_response: Optional[dict] = None + access_token_payload: Optional[dict] = None # get url from request if master_key is None: raise ProxyException( @@ -1307,7 +1369,7 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa: ) elif generic_client_id is not None: - result, received_response = await get_generic_sso_response( + result, received_response, access_token_payload = await get_generic_sso_response( request=request, jwt_handler=jwt_handler, generic_client_id=generic_client_id, @@ -1345,6 +1407,8 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa: received_response=received_response, generic_client_id=generic_client_id, ui_access_mode=ui_access_mode, + access_token_payload=access_token_payload, + jwt_handler=jwt_handler, return_to=cp_return_to, ) @@ -2417,6 +2481,8 @@ class SSOAuthenticationHandler: received_response: Optional[dict] = None, generic_client_id: Optional[str] = None, ui_access_mode: Optional[Dict] = None, + access_token_payload: Optional[dict] = None, + jwt_handler: Optional[JWTHandler] = None, return_to: Optional[str] = None, ) -> RedirectResponse: import jwt @@ -2498,6 +2564,20 @@ class SSOAuthenticationHandler: alternate_user_id=user_id, ) + # Sync user role from JWT claims via jwt_litellm_role_map (if configured). + # This ensures SSO users get the same role mapping as API/JWT users. + # Use the decoded access_token_payload (not received_response) because + # custom role claims (e.g. custom_roles) are encoded inside the JWT + # access token, which is stripped from received_response. + await _sync_user_role_from_jwt_role_map( + jwt_handler=jwt_handler, + received_response=access_token_payload or received_response, + user_info=user_info, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + user_defined_values=user_defined_values, + ) + user_defined_values = apply_user_info_values_to_sso_user_defined_values( user_info=user_info, user_defined_values=user_defined_values ) @@ -3703,7 +3783,7 @@ async def debug_sso_callback(request: Request): ) elif generic_client_id is not None: - result, _ = await get_generic_sso_response( + result, _, _ = await get_generic_sso_response( request=request, jwt_handler=jwt_handler, generic_client_id=generic_client_id, diff --git a/tests/proxy_unit_tests/test_jwt.py b/tests/proxy_unit_tests/test_jwt.py index 24cf15a321..a5be1a3a42 100644 --- a/tests/proxy_unit_tests/test_jwt.py +++ b/tests/proxy_unit_tests/test_jwt.py @@ -1331,6 +1331,9 @@ def test_jwt_handler_is_jwt_static_method(): # Test with empty string assert JWTHandler.is_jwt("") == False + # Test with None (missing Authorization header) + assert JWTHandler.is_jwt(None) == False + @pytest.mark.parametrize( "requested_model, should_work", diff --git a/tests/test_litellm/proxy/auth/test_handle_jwt.py b/tests/test_litellm/proxy/auth/test_handle_jwt.py index 11939f0fdd..ada67fbba8 100644 --- a/tests/test_litellm/proxy/auth/test_handle_jwt.py +++ b/tests/test_litellm/proxy/auth/test_handle_jwt.py @@ -339,6 +339,123 @@ async def test_sync_user_role_and_teams(): assert set(user.teams) == {"team1", "team2"} +@pytest.mark.asyncio +async def test_sync_user_role_and_teams_cache_invalidation_on_role_change(): + """Test that user cache is updated when role changes.""" + mock_cache = AsyncMock() + + jwt_handler = JWTHandler() + jwt_handler.update_environment( + prisma_client=None, + user_api_key_cache=AsyncMock(), + litellm_jwtauth=LiteLLM_JWTAuth( + jwt_litellm_role_map=[ + JWTLiteLLMRoleMap(jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN) + ], + roles_jwt_field="roles", + team_ids_jwt_field="my_id_teams", + sync_user_role_and_teams=True, + ), + ) + + token = {"roles": ["ADMIN"], "my_id_teams": ["team1"]} + user = LiteLLM_UserTable( + user_id="u1", + user_role=LitellmUserRoles.INTERNAL_USER.value, + teams=["team1"], # teams already match — only role differs + ) + + prisma = AsyncMock() + prisma.db.litellm_usertable.update = AsyncMock() + + await JWTAuthManager.sync_user_role_and_teams( + jwt_handler, token, user, prisma, user_api_key_cache=mock_cache + ) + + mock_cache.async_set_cache.assert_called_once() + call_kwargs = mock_cache.async_set_cache.call_args + assert call_kwargs.kwargs["key"] == "u1" + assert call_kwargs.kwargs["value"]["user_role"] == LitellmUserRoles.PROXY_ADMIN.value + + +@pytest.mark.asyncio +async def test_sync_user_role_and_teams_cache_invalidation_on_team_change(): + """Test that user cache is updated when team memberships change.""" + mock_cache = AsyncMock() + + jwt_handler = JWTHandler() + jwt_handler.update_environment( + prisma_client=None, + user_api_key_cache=AsyncMock(), + litellm_jwtauth=LiteLLM_JWTAuth( + jwt_litellm_role_map=[ + JWTLiteLLMRoleMap(jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN) + ], + roles_jwt_field="roles", + team_ids_jwt_field="my_id_teams", + sync_user_role_and_teams=True, + ), + ) + + token = {"roles": ["ADMIN"], "my_id_teams": ["team1", "team2"]} + user = LiteLLM_UserTable( + user_id="u1", + user_role=LitellmUserRoles.PROXY_ADMIN.value, # role already matches + teams=["team2"], # teams differ + ) + + prisma = AsyncMock() + prisma.db.litellm_usertable.update = AsyncMock() + + with patch( + "litellm.proxy.management_endpoints.scim.scim_v2.patch_team_membership", + new_callable=AsyncMock, + ): + await JWTAuthManager.sync_user_role_and_teams( + jwt_handler, token, user, prisma, user_api_key_cache=mock_cache + ) + + mock_cache.async_set_cache.assert_called_once() + call_kwargs = mock_cache.async_set_cache.call_args + assert call_kwargs.kwargs["key"] == "u1" + assert set(call_kwargs.kwargs["value"]["teams"]) == {"team1", "team2"} + + +@pytest.mark.asyncio +async def test_sync_user_role_and_teams_no_cache_write_when_nothing_changes(): + """Test that cache is NOT written when role and teams already match.""" + mock_cache = AsyncMock() + + jwt_handler = JWTHandler() + jwt_handler.update_environment( + prisma_client=None, + user_api_key_cache=AsyncMock(), + litellm_jwtauth=LiteLLM_JWTAuth( + jwt_litellm_role_map=[ + JWTLiteLLMRoleMap(jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN) + ], + roles_jwt_field="roles", + team_ids_jwt_field="my_id_teams", + sync_user_role_and_teams=True, + ), + ) + + token = {"roles": ["ADMIN"], "my_id_teams": ["team1"]} + user = LiteLLM_UserTable( + user_id="u1", + user_role=LitellmUserRoles.PROXY_ADMIN.value, + teams=["team1"], + ) + + prisma = AsyncMock() + + await JWTAuthManager.sync_user_role_and_teams( + jwt_handler, token, user, prisma, user_api_key_cache=mock_cache + ) + + mock_cache.async_set_cache.assert_not_called() + + @pytest.mark.asyncio async def test_map_jwt_role_to_litellm_role(): """Test JWT role mapping to LiteLLM roles with various patterns""" diff --git a/tests/test_litellm/proxy/auth/test_user_api_key_auth.py b/tests/test_litellm/proxy/auth/test_user_api_key_auth.py index 81ca758983..6e000f9c0e 100644 --- a/tests/test_litellm/proxy/auth/test_user_api_key_auth.py +++ b/tests/test_litellm/proxy/auth/test_user_api_key_auth.py @@ -567,6 +567,10 @@ class TestJWTOAuth2Coexistence: assert JWTHandler.is_jwt("Bearer token") is False assert JWTHandler.is_jwt("two.parts") is False + def test_is_jwt_returns_false_for_none(self): + """None token (missing Authorization header) should not be treated as JWT.""" + assert JWTHandler.is_jwt(None) is False + @pytest.mark.asyncio async def test_both_enabled_opaque_token_uses_oauth2(self): """ diff --git a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py index ff636ca04a..f9c7cefcc4 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py +++ b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py @@ -24,6 +24,7 @@ from litellm.proxy.management_endpoints.ui_sso import ( MicrosoftSSOHandler, SSOAuthenticationHandler, _setup_team_mappings, + _sync_user_role_from_jwt_role_map, determine_role_from_groups, normalize_email, process_sso_jwt_access_token, @@ -1321,7 +1322,7 @@ async def test_get_generic_sso_response_with_additional_headers(): "fastapi_sso.sso.generic.create_provider", return_value=mock_sso_class ): # Act - result, received_response = await get_generic_sso_response( + result, received_response, _ = await get_generic_sso_response( request=mock_request, jwt_handler=mock_jwt_handler, generic_client_id=generic_client_id, @@ -1383,7 +1384,7 @@ async def test_get_generic_sso_response_with_empty_headers(): "fastapi_sso.sso.generic.create_provider", return_value=mock_sso_class ): # Act - result, received_response = await get_generic_sso_response( + result, received_response, _ = await get_generic_sso_response( request=mock_request, jwt_handler=mock_jwt_handler, generic_client_id=generic_client_id, @@ -5254,3 +5255,159 @@ class TestValidateReturnTo: ) SSOAuthenticationHandler._validate_return_to("https://cp.example.com:3000/ui") + +class TestSyncUserRoleFromJwtRoleMap: + """Tests for _sync_user_role_from_jwt_role_map.""" + + @staticmethod + def _make_jwt_handler(): + from litellm.caching.caching import DualCache + from litellm.proxy._types import ( + JWTLiteLLMRoleMap, + LiteLLM_JWTAuth, + LitellmUserRoles, + ) + + handler = JWTHandler() + handler.update_environment( + prisma_client=None, + user_api_key_cache=DualCache(), + litellm_jwtauth=LiteLLM_JWTAuth( + roles_jwt_field="custom_roles", + user_id_upsert=True, + sync_user_role_and_teams=True, + jwt_litellm_role_map=[ + JWTLiteLLMRoleMap( + jwt_role="my-admin", + litellm_role=LitellmUserRoles.PROXY_ADMIN, + ), + JWTLiteLLMRoleMap( + jwt_role="my-viewer", + litellm_role=LitellmUserRoles.INTERNAL_USER, + ), + ], + ), + ) + return handler + + @staticmethod + def _make_sso_values(user_role=None): + from litellm.proxy._types import SSOUserDefinedValues + + user_id = "testuser@example.com" + return SSOUserDefinedValues( + models=[], + user_id=user_id, + user_email=user_id, + user_role=user_role, + max_budget=None, + budget_duration=None, + ) + + @pytest.mark.asyncio + async def test_stripped_response_has_no_roles(self): + """Bug repro: stripped received_response lacks role claims.""" + from litellm.caching.caching import DualCache + + handler = self._make_jwt_handler() + sso_values = self._make_sso_values() + + await _sync_user_role_from_jwt_role_map( + jwt_handler=handler, + received_response={"token_type": "Bearer", "expires_in": 3600}, + user_info=None, + prisma_client=AsyncMock(), + user_api_key_cache=DualCache(), + user_defined_values=sso_values, + ) + + assert sso_values["user_role"] is None + + @pytest.mark.asyncio + async def test_decoded_access_token_maps_role(self): + """Decoded JWT payload with role claims maps correctly.""" + from litellm.caching.caching import DualCache + from litellm.proxy._types import LitellmUserRoles + + handler = self._make_jwt_handler() + sso_values = self._make_sso_values() + + await _sync_user_role_from_jwt_role_map( + jwt_handler=handler, + received_response={"sub": "testuser@example.com", "custom_roles": ["my-admin"]}, + user_info=None, + prisma_client=AsyncMock(), + user_api_key_cache=DualCache(), + user_defined_values=sso_values, + ) + + assert sso_values["user_role"] == LitellmUserRoles.PROXY_ADMIN.value + + @pytest.mark.asyncio + async def test_existing_user_role_updated_in_db_and_cache(self): + """Existing user with stale role gets updated in DB and cache.""" + from litellm.caching.caching import DualCache + from litellm.proxy._types import LitellmUserRoles + + handler = self._make_jwt_handler() + cache = DualCache() + prisma = AsyncMock() + prisma.db.litellm_usertable.update = AsyncMock() + user_id = "testuser@example.com" + + existing_user = LiteLLM_UserTable( + user_id=user_id, + user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value, + ) + await cache.async_set_cache(key=user_id, value=existing_user.model_dump(), ttl=60) + + sso_values = self._make_sso_values( + user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value, + ) + + await _sync_user_role_from_jwt_role_map( + jwt_handler=handler, + received_response={"sub": user_id, "custom_roles": ["my-admin"]}, + user_info=existing_user, + prisma_client=prisma, + user_api_key_cache=cache, + user_defined_values=sso_values, + ) + + prisma.db.litellm_usertable.update.assert_called_once_with( + where={"user_id": user_id}, + data={"user_role": LitellmUserRoles.PROXY_ADMIN.value}, + ) + assert existing_user.user_role == LitellmUserRoles.PROXY_ADMIN.value + assert sso_values["user_role"] == LitellmUserRoles.PROXY_ADMIN.value + + @pytest.mark.asyncio + async def test_same_role_no_db_write(self): + """No DB update when the mapped role matches the existing role.""" + from litellm.caching.caching import DualCache + from litellm.proxy._types import LitellmUserRoles + + handler = self._make_jwt_handler() + prisma = AsyncMock() + prisma.db.litellm_usertable.update = AsyncMock() + + existing_user = LiteLLM_UserTable( + user_id="testuser@example.com", + user_role=LitellmUserRoles.PROXY_ADMIN.value, + ) + + sso_values = self._make_sso_values( + user_role=LitellmUserRoles.PROXY_ADMIN.value, + ) + + await _sync_user_role_from_jwt_role_map( + jwt_handler=handler, + received_response={"sub": "testuser@example.com", "custom_roles": ["my-admin"]}, + user_info=existing_user, + prisma_client=prisma, + user_api_key_cache=DualCache(), + user_defined_values=sso_values, + ) + + prisma.db.litellm_usertable.update.assert_not_called() + diff --git a/ui/litellm-dashboard/src/app/(dashboard)/hooks/budgets/useBudgets.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/budgets/useBudgets.ts new file mode 100644 index 0000000000..99c170b679 --- /dev/null +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/budgets/useBudgets.ts @@ -0,0 +1,70 @@ +import { useQuery, useMutation, useQueryClient, UseQueryResult } from "@tanstack/react-query"; +import { createQueryKeys } from "../common/queryKeysFactory"; +import { getBudgetList, budgetCreateCall, budgetUpdateCall, budgetDeleteCall } from "@/components/networking"; +import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized"; +import { budgetItem } from "@/components/budgets/budget_panel"; + +export const budgetKeys = createQueryKeys("budgets"); + +export const useBudgets = (): UseQueryResult => { + const { accessToken } = useAuthorized(); + return useQuery({ + queryKey: budgetKeys.list({}), + queryFn: async () => { + const data = await getBudgetList(accessToken!); + return (data ?? []).filter((item: budgetItem | null): item is budgetItem => item != null); + }, + enabled: Boolean(accessToken), + }); +}; + +export const useCreateBudget = () => { + const { accessToken } = useAuthorized(); + const queryClient = useQueryClient(); + + return useMutation>({ + mutationFn: async (formValues) => { + if (!accessToken) { + throw new Error("Access token is required"); + } + return budgetCreateCall(accessToken, formValues); + }, + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: budgetKeys.all }); + }, + }); +}; + +export const useUpdateBudget = () => { + const { accessToken } = useAuthorized(); + const queryClient = useQueryClient(); + + return useMutation>({ + mutationFn: async (formValues) => { + if (!accessToken) { + throw new Error("Access token is required"); + } + return budgetUpdateCall(accessToken, formValues); + }, + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: budgetKeys.all }); + }, + }); +}; + +export const useDeleteBudget = () => { + const { accessToken } = useAuthorized(); + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: async (budgetId) => { + if (!accessToken) { + throw new Error("Access token is required"); + } + return budgetDeleteCall(accessToken, budgetId); + }, + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: budgetKeys.all }); + }, + }); +}; diff --git a/ui/litellm-dashboard/src/components/budgets/budget_modal.tsx b/ui/litellm-dashboard/src/components/budgets/budget_modal.tsx index 490613de25..b5ad8aaff3 100644 --- a/ui/litellm-dashboard/src/components/budgets/budget_modal.tsx +++ b/ui/litellm-dashboard/src/components/budgets/budget_modal.tsx @@ -1,17 +1,17 @@ import React from "react"; import { TextInput, Accordion, AccordionHeader, AccordionBody } from "@tremor/react"; import { Button as Button2, Modal, Form, InputNumber, Select } from "antd"; -import { budgetCreateCall } from "../networking"; +import { useCreateBudget } from "@/app/(dashboard)/hooks/budgets/useBudgets"; import NotificationsManager from "../molecules/notifications_manager"; interface BudgetModalProps { isModalVisible: boolean; - accessToken: string | null; setIsModalVisible: React.Dispatch>; - setBudgetList: React.Dispatch>; } -const BudgetModal: React.FC = ({ isModalVisible, accessToken, setIsModalVisible, setBudgetList }) => { +const BudgetModal: React.FC = ({ isModalVisible, setIsModalVisible }) => { const [form] = Form.useForm(); + const createBudget = useCreateBudget(); + const handleOk = () => { setIsModalVisible(false); form.resetFields(); @@ -23,20 +23,15 @@ const BudgetModal: React.FC = ({ isModalVisible, accessToken, }; const handleCreate = async (formValues: Record) => { - if (accessToken == null || accessToken == undefined) { - return; - } try { NotificationsManager.info("Making API Call"); - // setIsModalVisible(true); - const response = await budgetCreateCall(accessToken, formValues); - console.log("key create Response:", response); - setBudgetList((prevData) => (prevData ? [...prevData, response] : [response])); // Check if prevData is null + await createBudget.mutateAsync(formValues); NotificationsManager.success("Budget Created"); form.resetFields(); + setIsModalVisible(false); } catch (error) { - console.error("Error creating the key:", error); - NotificationsManager.fromBackend(`Error creating the key: ${error}`); + console.error("Error creating the budget:", error); + NotificationsManager.fromBackend(`Error creating the budget: ${error}`); } }; diff --git a/ui/litellm-dashboard/src/components/budgets/budget_panel.test.tsx b/ui/litellm-dashboard/src/components/budgets/budget_panel.test.tsx index 534693d398..ecae379c9f 100644 --- a/ui/litellm-dashboard/src/components/budgets/budget_panel.test.tsx +++ b/ui/litellm-dashboard/src/components/budgets/budget_panel.test.tsx @@ -1,31 +1,50 @@ -import * as networking from "../networking"; import { fireEvent, render, waitFor, screen } from "@testing-library/react"; import { act } from "@testing-library/react"; +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; import { afterEach, describe, expect, it, vi } from "vitest"; import BudgetPanel from "./budget_panel"; -vi.mock("../networking", () => ({ - getBudgetList: vi.fn(), - budgetDeleteCall: vi.fn(), +const mockBudgets = [ + { + budget_id: "budget-1", + max_budget: 100, + rpm_limit: 10, + tpm_limit: 1000, + updated_at: "2024-01-01T00:00:00Z", + }, +]; + +vi.mock("@/app/(dashboard)/hooks/budgets/useBudgets", () => ({ + useBudgets: vi.fn().mockReturnValue({ data: [], isLoading: false }), + useDeleteBudget: vi.fn().mockReturnValue({ mutateAsync: vi.fn(), isPending: false }), + useCreateBudget: vi.fn().mockReturnValue({ mutateAsync: vi.fn() }), + useUpdateBudget: vi.fn().mockReturnValue({ mutateAsync: vi.fn() }), })); +import { useBudgets, useDeleteBudget, useCreateBudget, useUpdateBudget } from "@/app/(dashboard)/hooks/budgets/useBudgets"; + +const createQueryClient = () => + new QueryClient({ + defaultOptions: { queries: { retry: false, gcTime: 0 } }, + }); + +function renderWithProviders(ui: React.ReactElement) { + const qc = createQueryClient(); + return render({ui}); +} + describe("Budget Panel", () => { afterEach(() => { vi.clearAllMocks(); }); it("should render the budget panel and load budgets", async () => { - vi.mocked(networking.getBudgetList).mockResolvedValue([ - { - budget_id: "budget-1", - max_budget: "100", - rpm_limit: 10, - tpm_limit: 1000, - updated_at: "2024-01-01T00:00:00Z", - }, - ]); + vi.mocked(useBudgets).mockReturnValue({ + data: mockBudgets, + isLoading: false, + } as any); - render(); + renderWithProviders(); await waitFor(() => { expect(screen.getByText("Create a budget to assign to customers.")).toBeInTheDocument(); @@ -34,17 +53,20 @@ describe("Budget Panel", () => { }); it("should open delete modal when clicking delete icon", async () => { - vi.mocked(networking.getBudgetList).mockResolvedValue([ - { - budget_id: "budget-to-delete", - max_budget: "200", - rpm_limit: 20, - tpm_limit: 2000, - updated_at: "2024-01-02T00:00:00Z", - }, - ]); + vi.mocked(useBudgets).mockReturnValue({ + data: [ + { + budget_id: "budget-to-delete", + max_budget: 200, + rpm_limit: 20, + tpm_limit: 2000, + updated_at: "2024-01-02T00:00:00Z", + }, + ], + isLoading: false, + } as any); - render(); + renderWithProviders(); await waitFor(() => { expect(screen.getByText("budget-to-delete")).toBeInTheDocument(); @@ -62,18 +84,25 @@ describe("Budget Panel", () => { }); it("should successfully delete a budget", async () => { - vi.mocked(networking.getBudgetList).mockResolvedValue([ - { - budget_id: "budget-to-delete", - max_budget: "200", - rpm_limit: 20, - tpm_limit: 2000, - updated_at: "2024-01-02T00:00:00Z", - }, - ]); - vi.mocked(networking.budgetDeleteCall).mockResolvedValue(undefined); + const deleteMutateAsync = vi.fn().mockResolvedValue(undefined); + vi.mocked(useBudgets).mockReturnValue({ + data: [ + { + budget_id: "budget-to-delete", + max_budget: 200, + rpm_limit: 20, + tpm_limit: 2000, + updated_at: "2024-01-02T00:00:00Z", + }, + ], + isLoading: false, + } as any); + vi.mocked(useDeleteBudget).mockReturnValue({ + mutateAsync: deleteMutateAsync, + isPending: false, + } as any); - render(); + renderWithProviders(); await waitFor(() => { expect(screen.getByText("budget-to-delete")).toBeInTheDocument(); @@ -96,24 +125,43 @@ describe("Budget Panel", () => { }); await waitFor(() => { - expect(networking.budgetDeleteCall).toHaveBeenCalledWith("token-123", "budget-to-delete"); - expect(networking.getBudgetList).toHaveBeenCalledTimes(2); // Initial load + refresh after delete + expect(deleteMutateAsync).toHaveBeenCalledWith("budget-to-delete"); + }); + }); + + it("should render empty state without crashing", async () => { + vi.mocked(useBudgets).mockReturnValue({ + data: [], + isLoading: false, + } as any); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByText("Create a budget to assign to customers.")).toBeInTheDocument(); }); }); it("should handle delete error", async () => { - vi.mocked(networking.getBudgetList).mockResolvedValue([ - { - budget_id: "budget-to-delete", - max_budget: "200", - rpm_limit: 20, - tpm_limit: 2000, - updated_at: "2024-01-02T00:00:00Z", - }, - ]); - vi.mocked(networking.budgetDeleteCall).mockRejectedValue(new Error("Delete failed")); + const deleteMutateAsync = vi.fn().mockRejectedValue(new Error("Delete failed")); + vi.mocked(useBudgets).mockReturnValue({ + data: [ + { + budget_id: "budget-to-delete", + max_budget: 200, + rpm_limit: 20, + tpm_limit: 2000, + updated_at: "2024-01-02T00:00:00Z", + }, + ], + isLoading: false, + } as any); + vi.mocked(useDeleteBudget).mockReturnValue({ + mutateAsync: deleteMutateAsync, + isPending: false, + } as any); - render(); + renderWithProviders(); await waitFor(() => { expect(screen.getByText("budget-to-delete")).toBeInTheDocument(); @@ -136,10 +184,38 @@ describe("Budget Panel", () => { }); await waitFor(() => { - expect(networking.budgetDeleteCall).toHaveBeenCalledWith("token-123", "budget-to-delete"); + expect(deleteMutateAsync).toHaveBeenCalledWith("budget-to-delete"); + }); + }); + + it("should open edit modal when clicking edit icon", async () => { + vi.mocked(useBudgets).mockReturnValue({ + data: [ + { + budget_id: "budget-to-edit", + max_budget: 300, + rpm_limit: 30, + tpm_limit: 3000, + updated_at: "2024-01-03T00:00:00Z", + }, + ], + isLoading: false, + } as any); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByText("budget-to-edit")).toBeInTheDocument(); }); - // Modal should still be open (error handling) - expect(screen.getByText("Delete Budget?")).toBeInTheDocument(); + const editButton = screen.getByTestId("edit-budget-button"); + + act(() => { + fireEvent.click(editButton); + }); + + await waitFor(() => { + expect(screen.getByText("Edit Budget")).toBeInTheDocument(); + }); }); }); diff --git a/ui/litellm-dashboard/src/components/budgets/budget_panel.tsx b/ui/litellm-dashboard/src/components/budgets/budget_panel.tsx index b52ef5ab94..e42d056965 100644 --- a/ui/litellm-dashboard/src/components/budgets/budget_panel.tsx +++ b/ui/litellm-dashboard/src/components/budgets/budget_panel.tsx @@ -19,12 +19,12 @@ import { TabPanels, Text, } from "@tremor/react"; -import React, { useEffect, useState } from "react"; +import React, { useState } from "react"; import { Prism as SyntaxHighlighter } from "react-syntax-highlighter"; import DeleteResourceModal from "../common_components/DeleteResourceModal"; import TableIconActionButton from "../common_components/IconActionButton/TableIconActionButtons/TableIconActionButton"; import NotificationsManager from "../molecules/notifications_manager"; -import { budgetDeleteCall, getBudgetList } from "../networking"; +import { useBudgets, useDeleteBudget } from "@/app/(dashboard)/hooks/budgets/useBudgets"; import BudgetModal from "./budget_modal"; import EditBudgetModal from "./edit_budget_modal"; import { CREATE_END_USER_CURL_COMMAND, CHAT_COMPLETIONS_CURL_COMMAND, OPENAI_SDK_PYTHON_CODE } from "./constants"; @@ -35,7 +35,7 @@ interface BudgetSettingsPageProps { export interface budgetItem { budget_id: string; - max_budget: string | null; + max_budget: number | null; rpm_limit: number | null; tpm_limit: number | null; updated_at: string; @@ -45,17 +45,10 @@ const BudgetPanel: React.FC = ({ accessToken }) => { const [isCreateModelVisible, setIsCreateModelVisible] = useState(false); const [isEditModalVisible, setIsEditModalVisible] = useState(false); const [selectedBudget, setSelectedBudget] = useState(null); - const [budgetList, setBudgetList] = useState([]); - const [isDeleting, setIsDeleting] = useState(false); const [isDeleteModalVisible, setIsDeleteModalVisible] = useState(false); - useEffect(() => { - if (!accessToken) { - return; - } - getBudgetList(accessToken).then((data) => { - setBudgetList(data); - }); - }, [accessToken]); + + const { data: budgetList = [] } = useBudgets(); + const deleteBudget = useDeleteBudget(); const handleEditCall = async (budget: budgetItem) => { if (accessToken == null) { @@ -74,11 +67,9 @@ const BudgetPanel: React.FC = ({ accessToken }) => { if (!selectedBudget || accessToken == null) { return; } - setIsDeleting(true); try { - await budgetDeleteCall(accessToken, selectedBudget.budget_id); + await deleteBudget.mutateAsync(selectedBudget.budget_id); NotificationsManager.success("Budget deleted."); - await handleUpdateCall(); } catch (error) { console.error("Error deleting budget:", error); if (typeof NotificationsManager.fromBackend === "function") { @@ -87,7 +78,6 @@ const BudgetPanel: React.FC = ({ accessToken }) => { NotificationsManager.info("Failed to delete budget"); } } finally { - setIsDeleting(false); setIsDeleteModalVisible(false); setSelectedBudget(null); } @@ -97,15 +87,6 @@ const BudgetPanel: React.FC = ({ accessToken }) => { setIsDeleteModalVisible(false); }; - const handleUpdateCall = async () => { - if (accessToken == null) { - return; - } - getBudgetList(accessToken).then((data) => { - setBudgetList(data); - }); - }; - return (