From e24819afefd385fcad4a8468dcbdcf09c628b648 Mon Sep 17 00:00:00 2001 From: Ryan Crabbe Date: Fri, 27 Mar 2026 13:50:30 -0700 Subject: [PATCH 1/8] fix(sso): pass decoded JWT access token to role mapping during SSO login During SSO login, bearer tokens are stripped from the OAuth response before role mapping runs. Custom role claims encoded inside the JWT access token are lost, so map_jwt_role_to_litellm_role() returns None and the user falls back to internal_user_viewer. process_sso_jwt_access_token() now returns the decoded JWT payload, and a new _sync_user_role_from_jwt_role_map() receives it so jwt_litellm_role_map works correctly during SSO login. --- litellm/proxy/management_endpoints/ui_sso.py | 94 +++++++++- .../proxy/management_endpoints/test_ui_sso.py | 161 +++++++++++++++++- 2 files changed, 246 insertions(+), 9 deletions(-) diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 79a9e9bdf2..a5a62813b9 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -204,7 +204,7 @@ def process_sso_jwt_access_token( sso_jwt_handler: Optional[JWTHandler], result: Union[OpenID, dict, None], role_mappings: Optional["RoleMappings"] = None, -) -> None: +) -> Optional[dict]: """ Process SSO JWT access token and extract team IDs and user role if available. @@ -218,6 +218,12 @@ def process_sso_jwt_access_token( sso_jwt_handler: SSO-specific JWT handler for team ID extraction result: The SSO result object to update with team IDs and role role_mappings: Optional role mappings configuration for group-based role determination + + Returns: + The decoded access token payload dict, or None if decoding failed or + inputs were missing. Callers can pass this to _sync_user_role_from_jwt_role_map + so it has access to custom role claims (e.g. custom_roles) that are + encoded inside the JWT but stripped from received_response. """ if access_token_str and result: import jwt @@ -230,7 +236,7 @@ def process_sso_jwt_access_token( verbose_proxy_logger.debug( "Access token is not a valid JWT (possibly an opaque token), skipping JWT-based extraction" ) - return + return None # Extract team IDs from access token if sso_jwt_handler is available if sso_jwt_handler: @@ -306,6 +312,10 @@ def process_sso_jwt_access_token( f"Set user_role='{user_role}' from JWT access token" ) + return access_token_payload + + return None + @router.get("/sso/key/generate", tags=["experimental"], include_in_schema=False) async def google_login( @@ -817,7 +827,7 @@ async def get_generic_sso_response( ], # sso specific jwt handler - used for restricted sso group access control generic_client_id: str, redirect_url: str, -) -> Tuple[Union[OpenID, dict], Optional[dict]]: # return received response +) -> Tuple[Union[OpenID, dict], Optional[dict], Optional[dict]]: # (result, received_response, access_token_payload) # make generic sso provider from fastapi_sso.sso.base import DiscoveryDocument from fastapi_sso.sso.generic import create_provider @@ -872,6 +882,7 @@ async def get_generic_sso_response( code_verifier: Optional[ str ] = None # assigned inside try; initialized for type tracking + access_token_payload: Optional[dict] = None # decoded JWT access token claims try: token_exchange_params = ( @@ -958,7 +969,7 @@ async def get_generic_sso_response( ) access_token_str = generic_sso.access_token - process_sso_jwt_access_token( + access_token_payload = process_sso_jwt_access_token( access_token_str, sso_jwt_handler, result, role_mappings=role_mappings ) # Delete the single-use PKCE verifier only after all downstream processing @@ -976,7 +987,7 @@ async def get_generic_sso_response( additional_generic_sso_headers_dict, ) verbose_proxy_logger.debug("generic result: %s", result) - return result or {}, received_response + return result or {}, received_response, access_token_payload async def create_team_member_add_task(team_id, user_info): @@ -1176,6 +1187,56 @@ def _build_sso_user_update_data( return update_data +async def _sync_user_role_from_jwt_role_map( + jwt_handler: Optional[JWTHandler], + received_response: Optional[dict], + user_info: Optional[Union[LiteLLM_UserTable, NewUserResponse]], + prisma_client: PrismaClient, + user_api_key_cache: DualCache, + user_defined_values: Optional[SSOUserDefinedValues], +) -> None: + """ + Apply jwt_litellm_role_map during SSO login. + + When jwt_litellm_role_map is configured with sync_user_role_and_teams=True, + this ensures SSO users get the same role mapping as API/JWT users. Without + this, the SSO path falls back to INTERNAL_USER_VIEW_ONLY for roles that + don't directly match LitellmUserRoles enum values. + """ + if jwt_handler is None or received_response is None: + return + if not jwt_handler.litellm_jwtauth.sync_user_role_and_teams: + return + if not jwt_handler.litellm_jwtauth.jwt_litellm_role_map: + return + + mapped_role = jwt_handler.map_jwt_role_to_litellm_role(received_response) + if mapped_role is None: + return + + verbose_proxy_logger.info( + f"SSO jwt_litellm_role_map matched role: {mapped_role.value}" + ) + + # Update user_defined_values so downstream code uses the mapped role + if user_defined_values is not None: + user_defined_values["user_role"] = mapped_role.value + + # Update existing DB record if role differs + if user_info is not None and user_info.user_role != mapped_role.value: + await prisma_client.db.litellm_usertable.update( + where={"user_id": user_info.user_id}, + data={"user_role": mapped_role.value}, + ) + user_info.user_role = mapped_role.value + await user_api_key_cache.async_set_cache( + key=user_info.user_id, + value=user_info.model_dump() + if hasattr(user_info, "model_dump") + else dict(user_info), + ) + + def apply_user_info_values_to_sso_user_defined_values( user_info: Optional[Union[LiteLLM_UserTable, NewUserResponse]], user_defined_values: Optional[SSOUserDefinedValues], @@ -1279,6 +1340,7 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa: google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) received_response: Optional[dict] = None + access_token_payload: Optional[dict] = None # get url from request if master_key is None: raise ProxyException( @@ -1307,7 +1369,7 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa: ) elif generic_client_id is not None: - result, received_response = await get_generic_sso_response( + result, received_response, access_token_payload = await get_generic_sso_response( request=request, jwt_handler=jwt_handler, generic_client_id=generic_client_id, @@ -1345,6 +1407,8 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa: received_response=received_response, generic_client_id=generic_client_id, ui_access_mode=ui_access_mode, + access_token_payload=access_token_payload, + jwt_handler=jwt_handler, return_to=cp_return_to, ) @@ -2417,6 +2481,8 @@ class SSOAuthenticationHandler: received_response: Optional[dict] = None, generic_client_id: Optional[str] = None, ui_access_mode: Optional[Dict] = None, + access_token_payload: Optional[dict] = None, + jwt_handler: Optional[JWTHandler] = None, return_to: Optional[str] = None, ) -> RedirectResponse: import jwt @@ -2498,6 +2564,20 @@ class SSOAuthenticationHandler: alternate_user_id=user_id, ) + # Sync user role from JWT claims via jwt_litellm_role_map (if configured). + # This ensures SSO users get the same role mapping as API/JWT users. + # Use the decoded access_token_payload (not received_response) because + # custom role claims (e.g. custom_roles) are encoded inside the JWT + # access token, which is stripped from received_response. + await _sync_user_role_from_jwt_role_map( + jwt_handler=jwt_handler, + received_response=access_token_payload or received_response, + user_info=user_info, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + user_defined_values=user_defined_values, + ) + user_defined_values = apply_user_info_values_to_sso_user_defined_values( user_info=user_info, user_defined_values=user_defined_values ) @@ -3703,7 +3783,7 @@ async def debug_sso_callback(request: Request): ) elif generic_client_id is not None: - result, _ = await get_generic_sso_response( + result, _, _ = await get_generic_sso_response( request=request, jwt_handler=jwt_handler, generic_client_id=generic_client_id, diff --git a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py index ff636ca04a..f9c7cefcc4 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py +++ b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py @@ -24,6 +24,7 @@ from litellm.proxy.management_endpoints.ui_sso import ( MicrosoftSSOHandler, SSOAuthenticationHandler, _setup_team_mappings, + _sync_user_role_from_jwt_role_map, determine_role_from_groups, normalize_email, process_sso_jwt_access_token, @@ -1321,7 +1322,7 @@ async def test_get_generic_sso_response_with_additional_headers(): "fastapi_sso.sso.generic.create_provider", return_value=mock_sso_class ): # Act - result, received_response = await get_generic_sso_response( + result, received_response, _ = await get_generic_sso_response( request=mock_request, jwt_handler=mock_jwt_handler, generic_client_id=generic_client_id, @@ -1383,7 +1384,7 @@ async def test_get_generic_sso_response_with_empty_headers(): "fastapi_sso.sso.generic.create_provider", return_value=mock_sso_class ): # Act - result, received_response = await get_generic_sso_response( + result, received_response, _ = await get_generic_sso_response( request=mock_request, jwt_handler=mock_jwt_handler, generic_client_id=generic_client_id, @@ -5254,3 +5255,159 @@ class TestValidateReturnTo: ) SSOAuthenticationHandler._validate_return_to("https://cp.example.com:3000/ui") + +class TestSyncUserRoleFromJwtRoleMap: + """Tests for _sync_user_role_from_jwt_role_map.""" + + @staticmethod + def _make_jwt_handler(): + from litellm.caching.caching import DualCache + from litellm.proxy._types import ( + JWTLiteLLMRoleMap, + LiteLLM_JWTAuth, + LitellmUserRoles, + ) + + handler = JWTHandler() + handler.update_environment( + prisma_client=None, + user_api_key_cache=DualCache(), + litellm_jwtauth=LiteLLM_JWTAuth( + roles_jwt_field="custom_roles", + user_id_upsert=True, + sync_user_role_and_teams=True, + jwt_litellm_role_map=[ + JWTLiteLLMRoleMap( + jwt_role="my-admin", + litellm_role=LitellmUserRoles.PROXY_ADMIN, + ), + JWTLiteLLMRoleMap( + jwt_role="my-viewer", + litellm_role=LitellmUserRoles.INTERNAL_USER, + ), + ], + ), + ) + return handler + + @staticmethod + def _make_sso_values(user_role=None): + from litellm.proxy._types import SSOUserDefinedValues + + user_id = "testuser@example.com" + return SSOUserDefinedValues( + models=[], + user_id=user_id, + user_email=user_id, + user_role=user_role, + max_budget=None, + budget_duration=None, + ) + + @pytest.mark.asyncio + async def test_stripped_response_has_no_roles(self): + """Bug repro: stripped received_response lacks role claims.""" + from litellm.caching.caching import DualCache + + handler = self._make_jwt_handler() + sso_values = self._make_sso_values() + + await _sync_user_role_from_jwt_role_map( + jwt_handler=handler, + received_response={"token_type": "Bearer", "expires_in": 3600}, + user_info=None, + prisma_client=AsyncMock(), + user_api_key_cache=DualCache(), + user_defined_values=sso_values, + ) + + assert sso_values["user_role"] is None + + @pytest.mark.asyncio + async def test_decoded_access_token_maps_role(self): + """Decoded JWT payload with role claims maps correctly.""" + from litellm.caching.caching import DualCache + from litellm.proxy._types import LitellmUserRoles + + handler = self._make_jwt_handler() + sso_values = self._make_sso_values() + + await _sync_user_role_from_jwt_role_map( + jwt_handler=handler, + received_response={"sub": "testuser@example.com", "custom_roles": ["my-admin"]}, + user_info=None, + prisma_client=AsyncMock(), + user_api_key_cache=DualCache(), + user_defined_values=sso_values, + ) + + assert sso_values["user_role"] == LitellmUserRoles.PROXY_ADMIN.value + + @pytest.mark.asyncio + async def test_existing_user_role_updated_in_db_and_cache(self): + """Existing user with stale role gets updated in DB and cache.""" + from litellm.caching.caching import DualCache + from litellm.proxy._types import LitellmUserRoles + + handler = self._make_jwt_handler() + cache = DualCache() + prisma = AsyncMock() + prisma.db.litellm_usertable.update = AsyncMock() + user_id = "testuser@example.com" + + existing_user = LiteLLM_UserTable( + user_id=user_id, + user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value, + ) + await cache.async_set_cache(key=user_id, value=existing_user.model_dump(), ttl=60) + + sso_values = self._make_sso_values( + user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value, + ) + + await _sync_user_role_from_jwt_role_map( + jwt_handler=handler, + received_response={"sub": user_id, "custom_roles": ["my-admin"]}, + user_info=existing_user, + prisma_client=prisma, + user_api_key_cache=cache, + user_defined_values=sso_values, + ) + + prisma.db.litellm_usertable.update.assert_called_once_with( + where={"user_id": user_id}, + data={"user_role": LitellmUserRoles.PROXY_ADMIN.value}, + ) + assert existing_user.user_role == LitellmUserRoles.PROXY_ADMIN.value + assert sso_values["user_role"] == LitellmUserRoles.PROXY_ADMIN.value + + @pytest.mark.asyncio + async def test_same_role_no_db_write(self): + """No DB update when the mapped role matches the existing role.""" + from litellm.caching.caching import DualCache + from litellm.proxy._types import LitellmUserRoles + + handler = self._make_jwt_handler() + prisma = AsyncMock() + prisma.db.litellm_usertable.update = AsyncMock() + + existing_user = LiteLLM_UserTable( + user_id="testuser@example.com", + user_role=LitellmUserRoles.PROXY_ADMIN.value, + ) + + sso_values = self._make_sso_values( + user_role=LitellmUserRoles.PROXY_ADMIN.value, + ) + + await _sync_user_role_from_jwt_role_map( + jwt_handler=handler, + received_response={"sub": "testuser@example.com", "custom_roles": ["my-admin"]}, + user_info=existing_user, + prisma_client=prisma, + user_api_key_cache=DualCache(), + user_defined_values=sso_values, + ) + + prisma.db.litellm_usertable.update.assert_not_called() + From e36ab04a1856b072f6403d8b16236cf7cad3d960 Mon Sep 17 00:00:00 2001 From: Ryan Crabbe Date: Fri, 27 Mar 2026 16:26:00 -0700 Subject: [PATCH 2/8] fix(auth): guard JWTHandler.is_jwt() against None token When JWT auth is enabled and a request arrives without an Authorization header (e.g. health checks, monitoring), api_key is None due to APIKeyHeader(auto_error=False). The is_jwt() call crashes with AttributeError: 'NoneType' object has no attribute 'split'. Return False for None tokens since they are not JWTs. --- litellm/proxy/auth/handle_jwt.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index bfad9f0c3c..86f7d614b9 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -89,7 +89,9 @@ class JWTHandler: self.leeway = leeway @staticmethod - def is_jwt(token: str): + def is_jwt(token: Optional[str]): + if token is None: + return False parts = token.split(".") return len(parts) == 3 From 8e3755931ddbd6aff21f064a424d64bdc0139204 Mon Sep 17 00:00:00 2001 From: Ryan Crabbe Date: Fri, 27 Mar 2026 16:50:58 -0700 Subject: [PATCH 3/8] test(auth): add regression tests for JWTHandler.is_jwt(None) Add None-token test cases to both proxy_unit_tests and test_litellm to cover the guard added in the previous commit. Also add -> bool return type annotation to is_jwt(). --- litellm/proxy/auth/handle_jwt.py | 2 +- tests/proxy_unit_tests/test_jwt.py | 3 +++ tests/test_litellm/proxy/auth/test_user_api_key_auth.py | 4 ++++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 86f7d614b9..48bbd8e215 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -89,7 +89,7 @@ class JWTHandler: self.leeway = leeway @staticmethod - def is_jwt(token: Optional[str]): + def is_jwt(token: Optional[str]) -> bool: if token is None: return False parts = token.split(".") diff --git a/tests/proxy_unit_tests/test_jwt.py b/tests/proxy_unit_tests/test_jwt.py index 24cf15a321..a5be1a3a42 100644 --- a/tests/proxy_unit_tests/test_jwt.py +++ b/tests/proxy_unit_tests/test_jwt.py @@ -1331,6 +1331,9 @@ def test_jwt_handler_is_jwt_static_method(): # Test with empty string assert JWTHandler.is_jwt("") == False + # Test with None (missing Authorization header) + assert JWTHandler.is_jwt(None) == False + @pytest.mark.parametrize( "requested_model, should_work", diff --git a/tests/test_litellm/proxy/auth/test_user_api_key_auth.py b/tests/test_litellm/proxy/auth/test_user_api_key_auth.py index 81ca758983..6e000f9c0e 100644 --- a/tests/test_litellm/proxy/auth/test_user_api_key_auth.py +++ b/tests/test_litellm/proxy/auth/test_user_api_key_auth.py @@ -567,6 +567,10 @@ class TestJWTOAuth2Coexistence: assert JWTHandler.is_jwt("Bearer token") is False assert JWTHandler.is_jwt("two.parts") is False + def test_is_jwt_returns_false_for_none(self): + """None token (missing Authorization header) should not be treated as JWT.""" + assert JWTHandler.is_jwt(None) is False + @pytest.mark.asyncio async def test_both_enabled_opaque_token_uses_oauth2(self): """ From a5ff668f5e2a9584796eaf87f569d611f1ff7d5d Mon Sep 17 00:00:00 2001 From: Ryan Crabbe Date: Thu, 26 Mar 2026 14:56:32 -0700 Subject: [PATCH 4/8] fix: add /user/bulk_update to management_routes so proxy admins can access it /user/bulk_update was missing from the management_routes list in _types.py, causing it to fall through to a 403 in non_proxy_admin_allowed_routes_check even for proxy admin users. Also added it to the PROXY_ADMIN_VIEW_ONLY blocked write operations list in route_checks.py to prevent view-only admins from using it. --- litellm/proxy/_types.py | 1 + litellm/proxy/auth/route_checks.py | 1 + 2 files changed, 2 insertions(+) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index b59fc85d4b..8faf36df4c 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -525,6 +525,7 @@ class LiteLLMRoutes(enum.Enum): # user "/user/new", "/user/update", + "/user/bulk_update", "/user/delete", "/user/info", "/user/list", diff --git a/litellm/proxy/auth/route_checks.py b/litellm/proxy/auth/route_checks.py index 53cc88e3b1..26bbdef309 100644 --- a/litellm/proxy/auth/route_checks.py +++ b/litellm/proxy/auth/route_checks.py @@ -629,6 +629,7 @@ class RouteChecks: in [ "/user/new", "/user/delete", + "/user/bulk_update", "/team/new", "/team/update", "/team/delete", From 0c67f274e58f5e402c80aa6221d0b3a769f6875f Mon Sep 17 00:00:00 2001 From: Ryan Crabbe Date: Fri, 27 Mar 2026 18:01:08 -0700 Subject: [PATCH 5/8] docs: add /user/bulk_update to internal_user_endpoints module docstring --- litellm/proxy/management_endpoints/internal_user_endpoints.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index ca8c345f46..db01e817aa 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -6,6 +6,7 @@ These are members of a Team on LiteLLM /user/new /user/update +/user/bulk_update /user/delete /user/info /user/list From 98ecf1755008c55cb39ea1952fcd4cabe85d3365 Mon Sep 17 00:00:00 2001 From: Ryan Crabbe Date: Fri, 27 Mar 2026 19:34:24 -0700 Subject: [PATCH 6/8] fix(ui): refactor budget page to React Query hooks and fix crashes - Migrate budget CRUD from manual state to React Query hooks (useBudgets, useCreateBudget, useUpdateBudget, useDeleteBudget) - Fix crash when budget list contains null entries by filtering in query hook - Fix max_budget type from string to number to match DB schema (double precision) - Disable budget_id field in edit modal to prevent accidental changes - Use budget_id as React key instead of array index - Update tests to mock hooks instead of networking functions --- .../budget_management_endpoints.py | 23 ++- .../(dashboard)/hooks/budgets/useBudgets.ts | 70 +++++++ .../src/components/budgets/budget_modal.tsx | 21 +-- .../components/budgets/budget_panel.test.tsx | 178 +++++++++++++----- .../src/components/budgets/budget_panel.tsx | 48 ++--- .../components/budgets/edit_budget_modal.tsx | 41 ++-- 6 files changed, 245 insertions(+), 136 deletions(-) create mode 100644 ui/litellm-dashboard/src/app/(dashboard)/hooks/budgets/useBudgets.ts diff --git a/litellm/proxy/management_endpoints/budget_management_endpoints.py b/litellm/proxy/management_endpoints/budget_management_endpoints.py index 20c7f9ec41..37f13269b1 100644 --- a/litellm/proxy/management_endpoints/budget_management_endpoints.py +++ b/litellm/proxy/management_endpoints/budget_management_endpoints.py @@ -15,6 +15,7 @@ All /budget management endpoints from datetime import timedelta from fastapi import APIRouter, Depends, HTTPException +from prisma.errors import UniqueViolationError from litellm.litellm_core_utils.duration_parser import duration_in_seconds from litellm.proxy._types import * @@ -90,13 +91,21 @@ async def new_budget( budget_obj_json = budget_obj.model_dump(exclude_none=True) budget_obj_jsonified = jsonify_object(budget_obj_json) # json dump any dictionaries - response = await prisma_client.db.litellm_budgettable.create( - data={ - **budget_obj_jsonified, # type: ignore - "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, - "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, - } # type: ignore - ) + try: + response = await prisma_client.db.litellm_budgettable.create( + data={ + **budget_obj_jsonified, # type: ignore + "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + } # type: ignore + ) + except UniqueViolationError: + raise HTTPException( + status_code=400, + detail={ + "error": f"Budget with id '{budget_obj.budget_id}' already exists." + }, + ) return response diff --git a/ui/litellm-dashboard/src/app/(dashboard)/hooks/budgets/useBudgets.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/budgets/useBudgets.ts new file mode 100644 index 0000000000..99c170b679 --- /dev/null +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/budgets/useBudgets.ts @@ -0,0 +1,70 @@ +import { useQuery, useMutation, useQueryClient, UseQueryResult } from "@tanstack/react-query"; +import { createQueryKeys } from "../common/queryKeysFactory"; +import { getBudgetList, budgetCreateCall, budgetUpdateCall, budgetDeleteCall } from "@/components/networking"; +import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized"; +import { budgetItem } from "@/components/budgets/budget_panel"; + +export const budgetKeys = createQueryKeys("budgets"); + +export const useBudgets = (): UseQueryResult => { + const { accessToken } = useAuthorized(); + return useQuery({ + queryKey: budgetKeys.list({}), + queryFn: async () => { + const data = await getBudgetList(accessToken!); + return (data ?? []).filter((item: budgetItem | null): item is budgetItem => item != null); + }, + enabled: Boolean(accessToken), + }); +}; + +export const useCreateBudget = () => { + const { accessToken } = useAuthorized(); + const queryClient = useQueryClient(); + + return useMutation>({ + mutationFn: async (formValues) => { + if (!accessToken) { + throw new Error("Access token is required"); + } + return budgetCreateCall(accessToken, formValues); + }, + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: budgetKeys.all }); + }, + }); +}; + +export const useUpdateBudget = () => { + const { accessToken } = useAuthorized(); + const queryClient = useQueryClient(); + + return useMutation>({ + mutationFn: async (formValues) => { + if (!accessToken) { + throw new Error("Access token is required"); + } + return budgetUpdateCall(accessToken, formValues); + }, + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: budgetKeys.all }); + }, + }); +}; + +export const useDeleteBudget = () => { + const { accessToken } = useAuthorized(); + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: async (budgetId) => { + if (!accessToken) { + throw new Error("Access token is required"); + } + return budgetDeleteCall(accessToken, budgetId); + }, + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: budgetKeys.all }); + }, + }); +}; diff --git a/ui/litellm-dashboard/src/components/budgets/budget_modal.tsx b/ui/litellm-dashboard/src/components/budgets/budget_modal.tsx index 490613de25..b5ad8aaff3 100644 --- a/ui/litellm-dashboard/src/components/budgets/budget_modal.tsx +++ b/ui/litellm-dashboard/src/components/budgets/budget_modal.tsx @@ -1,17 +1,17 @@ import React from "react"; import { TextInput, Accordion, AccordionHeader, AccordionBody } from "@tremor/react"; import { Button as Button2, Modal, Form, InputNumber, Select } from "antd"; -import { budgetCreateCall } from "../networking"; +import { useCreateBudget } from "@/app/(dashboard)/hooks/budgets/useBudgets"; import NotificationsManager from "../molecules/notifications_manager"; interface BudgetModalProps { isModalVisible: boolean; - accessToken: string | null; setIsModalVisible: React.Dispatch>; - setBudgetList: React.Dispatch>; } -const BudgetModal: React.FC = ({ isModalVisible, accessToken, setIsModalVisible, setBudgetList }) => { +const BudgetModal: React.FC = ({ isModalVisible, setIsModalVisible }) => { const [form] = Form.useForm(); + const createBudget = useCreateBudget(); + const handleOk = () => { setIsModalVisible(false); form.resetFields(); @@ -23,20 +23,15 @@ const BudgetModal: React.FC = ({ isModalVisible, accessToken, }; const handleCreate = async (formValues: Record) => { - if (accessToken == null || accessToken == undefined) { - return; - } try { NotificationsManager.info("Making API Call"); - // setIsModalVisible(true); - const response = await budgetCreateCall(accessToken, formValues); - console.log("key create Response:", response); - setBudgetList((prevData) => (prevData ? [...prevData, response] : [response])); // Check if prevData is null + await createBudget.mutateAsync(formValues); NotificationsManager.success("Budget Created"); form.resetFields(); + setIsModalVisible(false); } catch (error) { - console.error("Error creating the key:", error); - NotificationsManager.fromBackend(`Error creating the key: ${error}`); + console.error("Error creating the budget:", error); + NotificationsManager.fromBackend(`Error creating the budget: ${error}`); } }; diff --git a/ui/litellm-dashboard/src/components/budgets/budget_panel.test.tsx b/ui/litellm-dashboard/src/components/budgets/budget_panel.test.tsx index 534693d398..ecae379c9f 100644 --- a/ui/litellm-dashboard/src/components/budgets/budget_panel.test.tsx +++ b/ui/litellm-dashboard/src/components/budgets/budget_panel.test.tsx @@ -1,31 +1,50 @@ -import * as networking from "../networking"; import { fireEvent, render, waitFor, screen } from "@testing-library/react"; import { act } from "@testing-library/react"; +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; import { afterEach, describe, expect, it, vi } from "vitest"; import BudgetPanel from "./budget_panel"; -vi.mock("../networking", () => ({ - getBudgetList: vi.fn(), - budgetDeleteCall: vi.fn(), +const mockBudgets = [ + { + budget_id: "budget-1", + max_budget: 100, + rpm_limit: 10, + tpm_limit: 1000, + updated_at: "2024-01-01T00:00:00Z", + }, +]; + +vi.mock("@/app/(dashboard)/hooks/budgets/useBudgets", () => ({ + useBudgets: vi.fn().mockReturnValue({ data: [], isLoading: false }), + useDeleteBudget: vi.fn().mockReturnValue({ mutateAsync: vi.fn(), isPending: false }), + useCreateBudget: vi.fn().mockReturnValue({ mutateAsync: vi.fn() }), + useUpdateBudget: vi.fn().mockReturnValue({ mutateAsync: vi.fn() }), })); +import { useBudgets, useDeleteBudget, useCreateBudget, useUpdateBudget } from "@/app/(dashboard)/hooks/budgets/useBudgets"; + +const createQueryClient = () => + new QueryClient({ + defaultOptions: { queries: { retry: false, gcTime: 0 } }, + }); + +function renderWithProviders(ui: React.ReactElement) { + const qc = createQueryClient(); + return render({ui}); +} + describe("Budget Panel", () => { afterEach(() => { vi.clearAllMocks(); }); it("should render the budget panel and load budgets", async () => { - vi.mocked(networking.getBudgetList).mockResolvedValue([ - { - budget_id: "budget-1", - max_budget: "100", - rpm_limit: 10, - tpm_limit: 1000, - updated_at: "2024-01-01T00:00:00Z", - }, - ]); + vi.mocked(useBudgets).mockReturnValue({ + data: mockBudgets, + isLoading: false, + } as any); - render(); + renderWithProviders(); await waitFor(() => { expect(screen.getByText("Create a budget to assign to customers.")).toBeInTheDocument(); @@ -34,17 +53,20 @@ describe("Budget Panel", () => { }); it("should open delete modal when clicking delete icon", async () => { - vi.mocked(networking.getBudgetList).mockResolvedValue([ - { - budget_id: "budget-to-delete", - max_budget: "200", - rpm_limit: 20, - tpm_limit: 2000, - updated_at: "2024-01-02T00:00:00Z", - }, - ]); + vi.mocked(useBudgets).mockReturnValue({ + data: [ + { + budget_id: "budget-to-delete", + max_budget: 200, + rpm_limit: 20, + tpm_limit: 2000, + updated_at: "2024-01-02T00:00:00Z", + }, + ], + isLoading: false, + } as any); - render(); + renderWithProviders(); await waitFor(() => { expect(screen.getByText("budget-to-delete")).toBeInTheDocument(); @@ -62,18 +84,25 @@ describe("Budget Panel", () => { }); it("should successfully delete a budget", async () => { - vi.mocked(networking.getBudgetList).mockResolvedValue([ - { - budget_id: "budget-to-delete", - max_budget: "200", - rpm_limit: 20, - tpm_limit: 2000, - updated_at: "2024-01-02T00:00:00Z", - }, - ]); - vi.mocked(networking.budgetDeleteCall).mockResolvedValue(undefined); + const deleteMutateAsync = vi.fn().mockResolvedValue(undefined); + vi.mocked(useBudgets).mockReturnValue({ + data: [ + { + budget_id: "budget-to-delete", + max_budget: 200, + rpm_limit: 20, + tpm_limit: 2000, + updated_at: "2024-01-02T00:00:00Z", + }, + ], + isLoading: false, + } as any); + vi.mocked(useDeleteBudget).mockReturnValue({ + mutateAsync: deleteMutateAsync, + isPending: false, + } as any); - render(); + renderWithProviders(); await waitFor(() => { expect(screen.getByText("budget-to-delete")).toBeInTheDocument(); @@ -96,24 +125,43 @@ describe("Budget Panel", () => { }); await waitFor(() => { - expect(networking.budgetDeleteCall).toHaveBeenCalledWith("token-123", "budget-to-delete"); - expect(networking.getBudgetList).toHaveBeenCalledTimes(2); // Initial load + refresh after delete + expect(deleteMutateAsync).toHaveBeenCalledWith("budget-to-delete"); + }); + }); + + it("should render empty state without crashing", async () => { + vi.mocked(useBudgets).mockReturnValue({ + data: [], + isLoading: false, + } as any); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByText("Create a budget to assign to customers.")).toBeInTheDocument(); }); }); it("should handle delete error", async () => { - vi.mocked(networking.getBudgetList).mockResolvedValue([ - { - budget_id: "budget-to-delete", - max_budget: "200", - rpm_limit: 20, - tpm_limit: 2000, - updated_at: "2024-01-02T00:00:00Z", - }, - ]); - vi.mocked(networking.budgetDeleteCall).mockRejectedValue(new Error("Delete failed")); + const deleteMutateAsync = vi.fn().mockRejectedValue(new Error("Delete failed")); + vi.mocked(useBudgets).mockReturnValue({ + data: [ + { + budget_id: "budget-to-delete", + max_budget: 200, + rpm_limit: 20, + tpm_limit: 2000, + updated_at: "2024-01-02T00:00:00Z", + }, + ], + isLoading: false, + } as any); + vi.mocked(useDeleteBudget).mockReturnValue({ + mutateAsync: deleteMutateAsync, + isPending: false, + } as any); - render(); + renderWithProviders(); await waitFor(() => { expect(screen.getByText("budget-to-delete")).toBeInTheDocument(); @@ -136,10 +184,38 @@ describe("Budget Panel", () => { }); await waitFor(() => { - expect(networking.budgetDeleteCall).toHaveBeenCalledWith("token-123", "budget-to-delete"); + expect(deleteMutateAsync).toHaveBeenCalledWith("budget-to-delete"); + }); + }); + + it("should open edit modal when clicking edit icon", async () => { + vi.mocked(useBudgets).mockReturnValue({ + data: [ + { + budget_id: "budget-to-edit", + max_budget: 300, + rpm_limit: 30, + tpm_limit: 3000, + updated_at: "2024-01-03T00:00:00Z", + }, + ], + isLoading: false, + } as any); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByText("budget-to-edit")).toBeInTheDocument(); }); - // Modal should still be open (error handling) - expect(screen.getByText("Delete Budget?")).toBeInTheDocument(); + const editButton = screen.getByTestId("edit-budget-button"); + + act(() => { + fireEvent.click(editButton); + }); + + await waitFor(() => { + expect(screen.getByText("Edit Budget")).toBeInTheDocument(); + }); }); }); diff --git a/ui/litellm-dashboard/src/components/budgets/budget_panel.tsx b/ui/litellm-dashboard/src/components/budgets/budget_panel.tsx index b52ef5ab94..e42d056965 100644 --- a/ui/litellm-dashboard/src/components/budgets/budget_panel.tsx +++ b/ui/litellm-dashboard/src/components/budgets/budget_panel.tsx @@ -19,12 +19,12 @@ import { TabPanels, Text, } from "@tremor/react"; -import React, { useEffect, useState } from "react"; +import React, { useState } from "react"; import { Prism as SyntaxHighlighter } from "react-syntax-highlighter"; import DeleteResourceModal from "../common_components/DeleteResourceModal"; import TableIconActionButton from "../common_components/IconActionButton/TableIconActionButtons/TableIconActionButton"; import NotificationsManager from "../molecules/notifications_manager"; -import { budgetDeleteCall, getBudgetList } from "../networking"; +import { useBudgets, useDeleteBudget } from "@/app/(dashboard)/hooks/budgets/useBudgets"; import BudgetModal from "./budget_modal"; import EditBudgetModal from "./edit_budget_modal"; import { CREATE_END_USER_CURL_COMMAND, CHAT_COMPLETIONS_CURL_COMMAND, OPENAI_SDK_PYTHON_CODE } from "./constants"; @@ -35,7 +35,7 @@ interface BudgetSettingsPageProps { export interface budgetItem { budget_id: string; - max_budget: string | null; + max_budget: number | null; rpm_limit: number | null; tpm_limit: number | null; updated_at: string; @@ -45,17 +45,10 @@ const BudgetPanel: React.FC = ({ accessToken }) => { const [isCreateModelVisible, setIsCreateModelVisible] = useState(false); const [isEditModalVisible, setIsEditModalVisible] = useState(false); const [selectedBudget, setSelectedBudget] = useState(null); - const [budgetList, setBudgetList] = useState([]); - const [isDeleting, setIsDeleting] = useState(false); const [isDeleteModalVisible, setIsDeleteModalVisible] = useState(false); - useEffect(() => { - if (!accessToken) { - return; - } - getBudgetList(accessToken).then((data) => { - setBudgetList(data); - }); - }, [accessToken]); + + const { data: budgetList = [] } = useBudgets(); + const deleteBudget = useDeleteBudget(); const handleEditCall = async (budget: budgetItem) => { if (accessToken == null) { @@ -74,11 +67,9 @@ const BudgetPanel: React.FC = ({ accessToken }) => { if (!selectedBudget || accessToken == null) { return; } - setIsDeleting(true); try { - await budgetDeleteCall(accessToken, selectedBudget.budget_id); + await deleteBudget.mutateAsync(selectedBudget.budget_id); NotificationsManager.success("Budget deleted."); - await handleUpdateCall(); } catch (error) { console.error("Error deleting budget:", error); if (typeof NotificationsManager.fromBackend === "function") { @@ -87,7 +78,6 @@ const BudgetPanel: React.FC = ({ accessToken }) => { NotificationsManager.info("Failed to delete budget"); } } finally { - setIsDeleting(false); setIsDeleteModalVisible(false); setSelectedBudget(null); } @@ -97,15 +87,6 @@ const BudgetPanel: React.FC = ({ accessToken }) => { setIsDeleteModalVisible(false); }; - const handleUpdateCall = async () => { - if (accessToken == null) { - return; - } - getBudgetList(accessToken).then((data) => { - setBudgetList(data); - }); - }; - return (