Merge pull request #24718 from BerriAI/litellm_ryan-march-26

litellm ryan march 26
This commit is contained in:
ryan-crabbe-berri
2026-03-28 09:01:11 -07:00
committed by GitHub
15 changed files with 636 additions and 146 deletions
+1
View File
@@ -525,6 +525,7 @@ class LiteLLMRoutes(enum.Enum):
# user
"/user/new",
"/user/update",
"/user/bulk_update",
"/user/delete",
"/user/info",
"/user/list",
+18 -1
View File
@@ -18,6 +18,7 @@ from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.constants import DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL
from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy._types import (
@@ -89,7 +90,9 @@ class JWTHandler:
self.leeway = leeway
@staticmethod
def is_jwt(token: str):
def is_jwt(token: Optional[str]) -> bool:
if token is None:
return False
parts = token.split(".")
return len(parts) == 3
@@ -1324,6 +1327,7 @@ class JWTAuthManager:
jwt_valid_token: dict,
user_object: Optional[LiteLLM_UserTable],
prisma_client: Optional[PrismaClient],
user_api_key_cache: Optional[DualCache] = None,
) -> None:
"""
Sync user role and team memberships with JWT claims
@@ -1348,6 +1352,12 @@ class JWTAuthManager:
data={"user_role": new_role.value},
)
user_object.user_role = new_role.value
if user_api_key_cache is not None:
await user_api_key_cache.async_set_cache(
key=user_object.user_id,
value=user_object.model_dump(),
ttl=DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL,
)
# Sync team memberships
jwt_team_ids = set(jwt_handler.get_team_ids_from_jwt(jwt_valid_token))
@@ -1365,6 +1375,12 @@ class JWTAuthManager:
teams_ids_to_remove_user_from=list(teams_to_remove),
)
user_object.teams = list(jwt_team_ids)
if user_api_key_cache is not None:
await user_api_key_cache.async_set_cache(
key=user_object.user_id,
value=user_object.model_dump(),
ttl=DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL,
)
return None
@staticmethod
@@ -1536,6 +1552,7 @@ class JWTAuthManager:
jwt_valid_token=jwt_valid_token,
user_object=user_object,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
## MAP USER TO TEAMS
+1
View File
@@ -629,6 +629,7 @@ class RouteChecks:
in [
"/user/new",
"/user/delete",
"/user/bulk_update",
"/team/new",
"/team/update",
"/team/delete",
@@ -15,6 +15,7 @@ All /budget management endpoints
from datetime import timedelta
from fastapi import APIRouter, Depends, HTTPException
from prisma.errors import UniqueViolationError
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.proxy._types import *
@@ -90,13 +91,21 @@ async def new_budget(
budget_obj_json = budget_obj.model_dump(exclude_none=True)
budget_obj_jsonified = jsonify_object(budget_obj_json) # json dump any dictionaries
response = await prisma_client.db.litellm_budgettable.create(
data={
**budget_obj_jsonified, # type: ignore
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
} # type: ignore
)
try:
response = await prisma_client.db.litellm_budgettable.create(
data={
**budget_obj_jsonified, # type: ignore
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
} # type: ignore
)
except UniqueViolationError:
raise HTTPException(
status_code=400,
detail={
"error": f"Budget with id '{budget_obj.budget_id}' already exists."
},
)
return response
@@ -6,6 +6,7 @@ These are members of a Team on LiteLLM
/user/new
/user/update
/user/bulk_update
/user/delete
/user/info
/user/list
+87 -7
View File
@@ -204,7 +204,7 @@ def process_sso_jwt_access_token(
sso_jwt_handler: Optional[JWTHandler],
result: Union[OpenID, dict, None],
role_mappings: Optional["RoleMappings"] = None,
) -> None:
) -> Optional[dict]:
"""
Process SSO JWT access token and extract team IDs and user role if available.
@@ -218,6 +218,12 @@ def process_sso_jwt_access_token(
sso_jwt_handler: SSO-specific JWT handler for team ID extraction
result: The SSO result object to update with team IDs and role
role_mappings: Optional role mappings configuration for group-based role determination
Returns:
The decoded access token payload dict, or None if decoding failed or
inputs were missing. Callers can pass this to _sync_user_role_from_jwt_role_map
so it has access to custom role claims (e.g. custom_roles) that are
encoded inside the JWT but stripped from received_response.
"""
if access_token_str and result:
import jwt
@@ -230,7 +236,7 @@ def process_sso_jwt_access_token(
verbose_proxy_logger.debug(
"Access token is not a valid JWT (possibly an opaque token), skipping JWT-based extraction"
)
return
return None
# Extract team IDs from access token if sso_jwt_handler is available
if sso_jwt_handler:
@@ -306,6 +312,10 @@ def process_sso_jwt_access_token(
f"Set user_role='{user_role}' from JWT access token"
)
return access_token_payload
return None
@router.get("/sso/key/generate", tags=["experimental"], include_in_schema=False)
async def google_login(
@@ -817,7 +827,7 @@ async def get_generic_sso_response(
], # sso specific jwt handler - used for restricted sso group access control
generic_client_id: str,
redirect_url: str,
) -> Tuple[Union[OpenID, dict], Optional[dict]]: # return received response
) -> Tuple[Union[OpenID, dict], Optional[dict], Optional[dict]]: # (result, received_response, access_token_payload)
# make generic sso provider
from fastapi_sso.sso.base import DiscoveryDocument
from fastapi_sso.sso.generic import create_provider
@@ -872,6 +882,7 @@ async def get_generic_sso_response(
code_verifier: Optional[
str
] = None # assigned inside try; initialized for type tracking
access_token_payload: Optional[dict] = None # decoded JWT access token claims
try:
token_exchange_params = (
@@ -958,7 +969,7 @@ async def get_generic_sso_response(
)
access_token_str = generic_sso.access_token
process_sso_jwt_access_token(
access_token_payload = process_sso_jwt_access_token(
access_token_str, sso_jwt_handler, result, role_mappings=role_mappings
)
# Delete the single-use PKCE verifier only after all downstream processing
@@ -976,7 +987,7 @@ async def get_generic_sso_response(
additional_generic_sso_headers_dict,
)
verbose_proxy_logger.debug("generic result: %s", result)
return result or {}, received_response
return result or {}, received_response, access_token_payload
async def create_team_member_add_task(team_id, user_info):
@@ -1176,6 +1187,56 @@ def _build_sso_user_update_data(
return update_data
async def _sync_user_role_from_jwt_role_map(
jwt_handler: Optional[JWTHandler],
received_response: Optional[dict],
user_info: Optional[Union[LiteLLM_UserTable, NewUserResponse]],
prisma_client: PrismaClient,
user_api_key_cache: DualCache,
user_defined_values: Optional[SSOUserDefinedValues],
) -> None:
"""
Apply jwt_litellm_role_map during SSO login.
When jwt_litellm_role_map is configured with sync_user_role_and_teams=True,
this ensures SSO users get the same role mapping as API/JWT users. Without
this, the SSO path falls back to INTERNAL_USER_VIEW_ONLY for roles that
don't directly match LitellmUserRoles enum values.
"""
if jwt_handler is None or received_response is None:
return
if not jwt_handler.litellm_jwtauth.sync_user_role_and_teams:
return
if not jwt_handler.litellm_jwtauth.jwt_litellm_role_map:
return
mapped_role = jwt_handler.map_jwt_role_to_litellm_role(received_response)
if mapped_role is None:
return
verbose_proxy_logger.info(
f"SSO jwt_litellm_role_map matched role: {mapped_role.value}"
)
# Update user_defined_values so downstream code uses the mapped role
if user_defined_values is not None:
user_defined_values["user_role"] = mapped_role.value
# Update existing DB record if role differs
if user_info is not None and user_info.user_role != mapped_role.value:
await prisma_client.db.litellm_usertable.update(
where={"user_id": user_info.user_id},
data={"user_role": mapped_role.value},
)
user_info.user_role = mapped_role.value
await user_api_key_cache.async_set_cache(
key=user_info.user_id,
value=user_info.model_dump()
if hasattr(user_info, "model_dump")
else dict(user_info),
)
def apply_user_info_values_to_sso_user_defined_values(
user_info: Optional[Union[LiteLLM_UserTable, NewUserResponse]],
user_defined_values: Optional[SSOUserDefinedValues],
@@ -1279,6 +1340,7 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa:
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
received_response: Optional[dict] = None
access_token_payload: Optional[dict] = None
# get url from request
if master_key is None:
raise ProxyException(
@@ -1307,7 +1369,7 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa:
)
elif generic_client_id is not None:
result, received_response = await get_generic_sso_response(
result, received_response, access_token_payload = await get_generic_sso_response(
request=request,
jwt_handler=jwt_handler,
generic_client_id=generic_client_id,
@@ -1345,6 +1407,8 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa:
received_response=received_response,
generic_client_id=generic_client_id,
ui_access_mode=ui_access_mode,
access_token_payload=access_token_payload,
jwt_handler=jwt_handler,
return_to=cp_return_to,
)
@@ -2417,6 +2481,8 @@ class SSOAuthenticationHandler:
received_response: Optional[dict] = None,
generic_client_id: Optional[str] = None,
ui_access_mode: Optional[Dict] = None,
access_token_payload: Optional[dict] = None,
jwt_handler: Optional[JWTHandler] = None,
return_to: Optional[str] = None,
) -> RedirectResponse:
import jwt
@@ -2498,6 +2564,20 @@ class SSOAuthenticationHandler:
alternate_user_id=user_id,
)
# Sync user role from JWT claims via jwt_litellm_role_map (if configured).
# This ensures SSO users get the same role mapping as API/JWT users.
# Use the decoded access_token_payload (not received_response) because
# custom role claims (e.g. custom_roles) are encoded inside the JWT
# access token, which is stripped from received_response.
await _sync_user_role_from_jwt_role_map(
jwt_handler=jwt_handler,
received_response=access_token_payload or received_response,
user_info=user_info,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
user_defined_values=user_defined_values,
)
user_defined_values = apply_user_info_values_to_sso_user_defined_values(
user_info=user_info, user_defined_values=user_defined_values
)
@@ -3703,7 +3783,7 @@ async def debug_sso_callback(request: Request):
)
elif generic_client_id is not None:
result, _ = await get_generic_sso_response(
result, _, _ = await get_generic_sso_response(
request=request,
jwt_handler=jwt_handler,
generic_client_id=generic_client_id,
+3
View File
@@ -1331,6 +1331,9 @@ def test_jwt_handler_is_jwt_static_method():
# Test with empty string
assert JWTHandler.is_jwt("") == False
# Test with None (missing Authorization header)
assert JWTHandler.is_jwt(None) == False
@pytest.mark.parametrize(
"requested_model, should_work",
@@ -339,6 +339,123 @@ async def test_sync_user_role_and_teams():
assert set(user.teams) == {"team1", "team2"}
@pytest.mark.asyncio
async def test_sync_user_role_and_teams_cache_invalidation_on_role_change():
"""Test that user cache is updated when role changes."""
mock_cache = AsyncMock()
jwt_handler = JWTHandler()
jwt_handler.update_environment(
prisma_client=None,
user_api_key_cache=AsyncMock(),
litellm_jwtauth=LiteLLM_JWTAuth(
jwt_litellm_role_map=[
JWTLiteLLMRoleMap(jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN)
],
roles_jwt_field="roles",
team_ids_jwt_field="my_id_teams",
sync_user_role_and_teams=True,
),
)
token = {"roles": ["ADMIN"], "my_id_teams": ["team1"]}
user = LiteLLM_UserTable(
user_id="u1",
user_role=LitellmUserRoles.INTERNAL_USER.value,
teams=["team1"], # teams already match — only role differs
)
prisma = AsyncMock()
prisma.db.litellm_usertable.update = AsyncMock()
await JWTAuthManager.sync_user_role_and_teams(
jwt_handler, token, user, prisma, user_api_key_cache=mock_cache
)
mock_cache.async_set_cache.assert_called_once()
call_kwargs = mock_cache.async_set_cache.call_args
assert call_kwargs.kwargs["key"] == "u1"
assert call_kwargs.kwargs["value"]["user_role"] == LitellmUserRoles.PROXY_ADMIN.value
@pytest.mark.asyncio
async def test_sync_user_role_and_teams_cache_invalidation_on_team_change():
"""Test that user cache is updated when team memberships change."""
mock_cache = AsyncMock()
jwt_handler = JWTHandler()
jwt_handler.update_environment(
prisma_client=None,
user_api_key_cache=AsyncMock(),
litellm_jwtauth=LiteLLM_JWTAuth(
jwt_litellm_role_map=[
JWTLiteLLMRoleMap(jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN)
],
roles_jwt_field="roles",
team_ids_jwt_field="my_id_teams",
sync_user_role_and_teams=True,
),
)
token = {"roles": ["ADMIN"], "my_id_teams": ["team1", "team2"]}
user = LiteLLM_UserTable(
user_id="u1",
user_role=LitellmUserRoles.PROXY_ADMIN.value, # role already matches
teams=["team2"], # teams differ
)
prisma = AsyncMock()
prisma.db.litellm_usertable.update = AsyncMock()
with patch(
"litellm.proxy.management_endpoints.scim.scim_v2.patch_team_membership",
new_callable=AsyncMock,
):
await JWTAuthManager.sync_user_role_and_teams(
jwt_handler, token, user, prisma, user_api_key_cache=mock_cache
)
mock_cache.async_set_cache.assert_called_once()
call_kwargs = mock_cache.async_set_cache.call_args
assert call_kwargs.kwargs["key"] == "u1"
assert set(call_kwargs.kwargs["value"]["teams"]) == {"team1", "team2"}
@pytest.mark.asyncio
async def test_sync_user_role_and_teams_no_cache_write_when_nothing_changes():
"""Test that cache is NOT written when role and teams already match."""
mock_cache = AsyncMock()
jwt_handler = JWTHandler()
jwt_handler.update_environment(
prisma_client=None,
user_api_key_cache=AsyncMock(),
litellm_jwtauth=LiteLLM_JWTAuth(
jwt_litellm_role_map=[
JWTLiteLLMRoleMap(jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN)
],
roles_jwt_field="roles",
team_ids_jwt_field="my_id_teams",
sync_user_role_and_teams=True,
),
)
token = {"roles": ["ADMIN"], "my_id_teams": ["team1"]}
user = LiteLLM_UserTable(
user_id="u1",
user_role=LitellmUserRoles.PROXY_ADMIN.value,
teams=["team1"],
)
prisma = AsyncMock()
await JWTAuthManager.sync_user_role_and_teams(
jwt_handler, token, user, prisma, user_api_key_cache=mock_cache
)
mock_cache.async_set_cache.assert_not_called()
@pytest.mark.asyncio
async def test_map_jwt_role_to_litellm_role():
"""Test JWT role mapping to LiteLLM roles with various patterns"""
@@ -567,6 +567,10 @@ class TestJWTOAuth2Coexistence:
assert JWTHandler.is_jwt("Bearer token") is False
assert JWTHandler.is_jwt("two.parts") is False
def test_is_jwt_returns_false_for_none(self):
"""None token (missing Authorization header) should not be treated as JWT."""
assert JWTHandler.is_jwt(None) is False
@pytest.mark.asyncio
async def test_both_enabled_opaque_token_uses_oauth2(self):
"""
@@ -24,6 +24,7 @@ from litellm.proxy.management_endpoints.ui_sso import (
MicrosoftSSOHandler,
SSOAuthenticationHandler,
_setup_team_mappings,
_sync_user_role_from_jwt_role_map,
determine_role_from_groups,
normalize_email,
process_sso_jwt_access_token,
@@ -1321,7 +1322,7 @@ async def test_get_generic_sso_response_with_additional_headers():
"fastapi_sso.sso.generic.create_provider", return_value=mock_sso_class
):
# Act
result, received_response = await get_generic_sso_response(
result, received_response, _ = await get_generic_sso_response(
request=mock_request,
jwt_handler=mock_jwt_handler,
generic_client_id=generic_client_id,
@@ -1383,7 +1384,7 @@ async def test_get_generic_sso_response_with_empty_headers():
"fastapi_sso.sso.generic.create_provider", return_value=mock_sso_class
):
# Act
result, received_response = await get_generic_sso_response(
result, received_response, _ = await get_generic_sso_response(
request=mock_request,
jwt_handler=mock_jwt_handler,
generic_client_id=generic_client_id,
@@ -5254,3 +5255,159 @@ class TestValidateReturnTo:
)
SSOAuthenticationHandler._validate_return_to("https://cp.example.com:3000/ui")
class TestSyncUserRoleFromJwtRoleMap:
"""Tests for _sync_user_role_from_jwt_role_map."""
@staticmethod
def _make_jwt_handler():
from litellm.caching.caching import DualCache
from litellm.proxy._types import (
JWTLiteLLMRoleMap,
LiteLLM_JWTAuth,
LitellmUserRoles,
)
handler = JWTHandler()
handler.update_environment(
prisma_client=None,
user_api_key_cache=DualCache(),
litellm_jwtauth=LiteLLM_JWTAuth(
roles_jwt_field="custom_roles",
user_id_upsert=True,
sync_user_role_and_teams=True,
jwt_litellm_role_map=[
JWTLiteLLMRoleMap(
jwt_role="my-admin",
litellm_role=LitellmUserRoles.PROXY_ADMIN,
),
JWTLiteLLMRoleMap(
jwt_role="my-viewer",
litellm_role=LitellmUserRoles.INTERNAL_USER,
),
],
),
)
return handler
@staticmethod
def _make_sso_values(user_role=None):
from litellm.proxy._types import SSOUserDefinedValues
user_id = "testuser@example.com"
return SSOUserDefinedValues(
models=[],
user_id=user_id,
user_email=user_id,
user_role=user_role,
max_budget=None,
budget_duration=None,
)
@pytest.mark.asyncio
async def test_stripped_response_has_no_roles(self):
"""Bug repro: stripped received_response lacks role claims."""
from litellm.caching.caching import DualCache
handler = self._make_jwt_handler()
sso_values = self._make_sso_values()
await _sync_user_role_from_jwt_role_map(
jwt_handler=handler,
received_response={"token_type": "Bearer", "expires_in": 3600},
user_info=None,
prisma_client=AsyncMock(),
user_api_key_cache=DualCache(),
user_defined_values=sso_values,
)
assert sso_values["user_role"] is None
@pytest.mark.asyncio
async def test_decoded_access_token_maps_role(self):
"""Decoded JWT payload with role claims maps correctly."""
from litellm.caching.caching import DualCache
from litellm.proxy._types import LitellmUserRoles
handler = self._make_jwt_handler()
sso_values = self._make_sso_values()
await _sync_user_role_from_jwt_role_map(
jwt_handler=handler,
received_response={"sub": "testuser@example.com", "custom_roles": ["my-admin"]},
user_info=None,
prisma_client=AsyncMock(),
user_api_key_cache=DualCache(),
user_defined_values=sso_values,
)
assert sso_values["user_role"] == LitellmUserRoles.PROXY_ADMIN.value
@pytest.mark.asyncio
async def test_existing_user_role_updated_in_db_and_cache(self):
"""Existing user with stale role gets updated in DB and cache."""
from litellm.caching.caching import DualCache
from litellm.proxy._types import LitellmUserRoles
handler = self._make_jwt_handler()
cache = DualCache()
prisma = AsyncMock()
prisma.db.litellm_usertable.update = AsyncMock()
user_id = "testuser@example.com"
existing_user = LiteLLM_UserTable(
user_id=user_id,
user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value,
)
await cache.async_set_cache(key=user_id, value=existing_user.model_dump(), ttl=60)
sso_values = self._make_sso_values(
user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value,
)
await _sync_user_role_from_jwt_role_map(
jwt_handler=handler,
received_response={"sub": user_id, "custom_roles": ["my-admin"]},
user_info=existing_user,
prisma_client=prisma,
user_api_key_cache=cache,
user_defined_values=sso_values,
)
prisma.db.litellm_usertable.update.assert_called_once_with(
where={"user_id": user_id},
data={"user_role": LitellmUserRoles.PROXY_ADMIN.value},
)
assert existing_user.user_role == LitellmUserRoles.PROXY_ADMIN.value
assert sso_values["user_role"] == LitellmUserRoles.PROXY_ADMIN.value
@pytest.mark.asyncio
async def test_same_role_no_db_write(self):
"""No DB update when the mapped role matches the existing role."""
from litellm.caching.caching import DualCache
from litellm.proxy._types import LitellmUserRoles
handler = self._make_jwt_handler()
prisma = AsyncMock()
prisma.db.litellm_usertable.update = AsyncMock()
existing_user = LiteLLM_UserTable(
user_id="testuser@example.com",
user_role=LitellmUserRoles.PROXY_ADMIN.value,
)
sso_values = self._make_sso_values(
user_role=LitellmUserRoles.PROXY_ADMIN.value,
)
await _sync_user_role_from_jwt_role_map(
jwt_handler=handler,
received_response={"sub": "testuser@example.com", "custom_roles": ["my-admin"]},
user_info=existing_user,
prisma_client=prisma,
user_api_key_cache=DualCache(),
user_defined_values=sso_values,
)
prisma.db.litellm_usertable.update.assert_not_called()
@@ -0,0 +1,70 @@
import { useQuery, useMutation, useQueryClient, UseQueryResult } from "@tanstack/react-query";
import { createQueryKeys } from "../common/queryKeysFactory";
import { getBudgetList, budgetCreateCall, budgetUpdateCall, budgetDeleteCall } from "@/components/networking";
import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized";
import { budgetItem } from "@/components/budgets/budget_panel";
export const budgetKeys = createQueryKeys("budgets");
export const useBudgets = (): UseQueryResult<budgetItem[]> => {
const { accessToken } = useAuthorized();
return useQuery<budgetItem[]>({
queryKey: budgetKeys.list({}),
queryFn: async () => {
const data = await getBudgetList(accessToken!);
return (data ?? []).filter((item: budgetItem | null): item is budgetItem => item != null);
},
enabled: Boolean(accessToken),
});
};
export const useCreateBudget = () => {
const { accessToken } = useAuthorized();
const queryClient = useQueryClient();
return useMutation<unknown, Error, Record<string, any>>({
mutationFn: async (formValues) => {
if (!accessToken) {
throw new Error("Access token is required");
}
return budgetCreateCall(accessToken, formValues);
},
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: budgetKeys.all });
},
});
};
export const useUpdateBudget = () => {
const { accessToken } = useAuthorized();
const queryClient = useQueryClient();
return useMutation<unknown, Error, Record<string, any>>({
mutationFn: async (formValues) => {
if (!accessToken) {
throw new Error("Access token is required");
}
return budgetUpdateCall(accessToken, formValues);
},
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: budgetKeys.all });
},
});
};
export const useDeleteBudget = () => {
const { accessToken } = useAuthorized();
const queryClient = useQueryClient();
return useMutation<unknown, Error, string>({
mutationFn: async (budgetId) => {
if (!accessToken) {
throw new Error("Access token is required");
}
return budgetDeleteCall(accessToken, budgetId);
},
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: budgetKeys.all });
},
});
};
@@ -1,17 +1,17 @@
import React from "react";
import { TextInput, Accordion, AccordionHeader, AccordionBody } from "@tremor/react";
import { Button as Button2, Modal, Form, InputNumber, Select } from "antd";
import { budgetCreateCall } from "../networking";
import { useCreateBudget } from "@/app/(dashboard)/hooks/budgets/useBudgets";
import NotificationsManager from "../molecules/notifications_manager";
interface BudgetModalProps {
isModalVisible: boolean;
accessToken: string | null;
setIsModalVisible: React.Dispatch<React.SetStateAction<boolean>>;
setBudgetList: React.Dispatch<React.SetStateAction<any[]>>;
}
const BudgetModal: React.FC<BudgetModalProps> = ({ isModalVisible, accessToken, setIsModalVisible, setBudgetList }) => {
const BudgetModal: React.FC<BudgetModalProps> = ({ isModalVisible, setIsModalVisible }) => {
const [form] = Form.useForm();
const createBudget = useCreateBudget();
const handleOk = () => {
setIsModalVisible(false);
form.resetFields();
@@ -23,20 +23,15 @@ const BudgetModal: React.FC<BudgetModalProps> = ({ isModalVisible, accessToken,
};
const handleCreate = async (formValues: Record<string, any>) => {
if (accessToken == null || accessToken == undefined) {
return;
}
try {
NotificationsManager.info("Making API Call");
// setIsModalVisible(true);
const response = await budgetCreateCall(accessToken, formValues);
console.log("key create Response:", response);
setBudgetList((prevData) => (prevData ? [...prevData, response] : [response])); // Check if prevData is null
await createBudget.mutateAsync(formValues);
NotificationsManager.success("Budget Created");
form.resetFields();
setIsModalVisible(false);
} catch (error) {
console.error("Error creating the key:", error);
NotificationsManager.fromBackend(`Error creating the key: ${error}`);
console.error("Error creating the budget:", error);
NotificationsManager.fromBackend(`Error creating the budget: ${error}`);
}
};
@@ -1,31 +1,50 @@
import * as networking from "../networking";
import { fireEvent, render, waitFor, screen } from "@testing-library/react";
import { act } from "@testing-library/react";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { afterEach, describe, expect, it, vi } from "vitest";
import BudgetPanel from "./budget_panel";
vi.mock("../networking", () => ({
getBudgetList: vi.fn(),
budgetDeleteCall: vi.fn(),
const mockBudgets = [
{
budget_id: "budget-1",
max_budget: 100,
rpm_limit: 10,
tpm_limit: 1000,
updated_at: "2024-01-01T00:00:00Z",
},
];
vi.mock("@/app/(dashboard)/hooks/budgets/useBudgets", () => ({
useBudgets: vi.fn().mockReturnValue({ data: [], isLoading: false }),
useDeleteBudget: vi.fn().mockReturnValue({ mutateAsync: vi.fn(), isPending: false }),
useCreateBudget: vi.fn().mockReturnValue({ mutateAsync: vi.fn() }),
useUpdateBudget: vi.fn().mockReturnValue({ mutateAsync: vi.fn() }),
}));
import { useBudgets, useDeleteBudget, useCreateBudget, useUpdateBudget } from "@/app/(dashboard)/hooks/budgets/useBudgets";
const createQueryClient = () =>
new QueryClient({
defaultOptions: { queries: { retry: false, gcTime: 0 } },
});
function renderWithProviders(ui: React.ReactElement) {
const qc = createQueryClient();
return render(<QueryClientProvider client={qc}>{ui}</QueryClientProvider>);
}
describe("Budget Panel", () => {
afterEach(() => {
vi.clearAllMocks();
});
it("should render the budget panel and load budgets", async () => {
vi.mocked(networking.getBudgetList).mockResolvedValue([
{
budget_id: "budget-1",
max_budget: "100",
rpm_limit: 10,
tpm_limit: 1000,
updated_at: "2024-01-01T00:00:00Z",
},
]);
vi.mocked(useBudgets).mockReturnValue({
data: mockBudgets,
isLoading: false,
} as any);
render(<BudgetPanel accessToken="token-123" />);
renderWithProviders(<BudgetPanel accessToken="token-123" />);
await waitFor(() => {
expect(screen.getByText("Create a budget to assign to customers.")).toBeInTheDocument();
@@ -34,17 +53,20 @@ describe("Budget Panel", () => {
});
it("should open delete modal when clicking delete icon", async () => {
vi.mocked(networking.getBudgetList).mockResolvedValue([
{
budget_id: "budget-to-delete",
max_budget: "200",
rpm_limit: 20,
tpm_limit: 2000,
updated_at: "2024-01-02T00:00:00Z",
},
]);
vi.mocked(useBudgets).mockReturnValue({
data: [
{
budget_id: "budget-to-delete",
max_budget: 200,
rpm_limit: 20,
tpm_limit: 2000,
updated_at: "2024-01-02T00:00:00Z",
},
],
isLoading: false,
} as any);
render(<BudgetPanel accessToken="token-123" />);
renderWithProviders(<BudgetPanel accessToken="token-123" />);
await waitFor(() => {
expect(screen.getByText("budget-to-delete")).toBeInTheDocument();
@@ -62,18 +84,25 @@ describe("Budget Panel", () => {
});
it("should successfully delete a budget", async () => {
vi.mocked(networking.getBudgetList).mockResolvedValue([
{
budget_id: "budget-to-delete",
max_budget: "200",
rpm_limit: 20,
tpm_limit: 2000,
updated_at: "2024-01-02T00:00:00Z",
},
]);
vi.mocked(networking.budgetDeleteCall).mockResolvedValue(undefined);
const deleteMutateAsync = vi.fn().mockResolvedValue(undefined);
vi.mocked(useBudgets).mockReturnValue({
data: [
{
budget_id: "budget-to-delete",
max_budget: 200,
rpm_limit: 20,
tpm_limit: 2000,
updated_at: "2024-01-02T00:00:00Z",
},
],
isLoading: false,
} as any);
vi.mocked(useDeleteBudget).mockReturnValue({
mutateAsync: deleteMutateAsync,
isPending: false,
} as any);
render(<BudgetPanel accessToken="token-123" />);
renderWithProviders(<BudgetPanel accessToken="token-123" />);
await waitFor(() => {
expect(screen.getByText("budget-to-delete")).toBeInTheDocument();
@@ -96,24 +125,43 @@ describe("Budget Panel", () => {
});
await waitFor(() => {
expect(networking.budgetDeleteCall).toHaveBeenCalledWith("token-123", "budget-to-delete");
expect(networking.getBudgetList).toHaveBeenCalledTimes(2); // Initial load + refresh after delete
expect(deleteMutateAsync).toHaveBeenCalledWith("budget-to-delete");
});
});
it("should render empty state without crashing", async () => {
vi.mocked(useBudgets).mockReturnValue({
data: [],
isLoading: false,
} as any);
renderWithProviders(<BudgetPanel accessToken="token-123" />);
await waitFor(() => {
expect(screen.getByText("Create a budget to assign to customers.")).toBeInTheDocument();
});
});
it("should handle delete error", async () => {
vi.mocked(networking.getBudgetList).mockResolvedValue([
{
budget_id: "budget-to-delete",
max_budget: "200",
rpm_limit: 20,
tpm_limit: 2000,
updated_at: "2024-01-02T00:00:00Z",
},
]);
vi.mocked(networking.budgetDeleteCall).mockRejectedValue(new Error("Delete failed"));
const deleteMutateAsync = vi.fn().mockRejectedValue(new Error("Delete failed"));
vi.mocked(useBudgets).mockReturnValue({
data: [
{
budget_id: "budget-to-delete",
max_budget: 200,
rpm_limit: 20,
tpm_limit: 2000,
updated_at: "2024-01-02T00:00:00Z",
},
],
isLoading: false,
} as any);
vi.mocked(useDeleteBudget).mockReturnValue({
mutateAsync: deleteMutateAsync,
isPending: false,
} as any);
render(<BudgetPanel accessToken="token-123" />);
renderWithProviders(<BudgetPanel accessToken="token-123" />);
await waitFor(() => {
expect(screen.getByText("budget-to-delete")).toBeInTheDocument();
@@ -136,10 +184,38 @@ describe("Budget Panel", () => {
});
await waitFor(() => {
expect(networking.budgetDeleteCall).toHaveBeenCalledWith("token-123", "budget-to-delete");
expect(deleteMutateAsync).toHaveBeenCalledWith("budget-to-delete");
});
});
it("should open edit modal when clicking edit icon", async () => {
vi.mocked(useBudgets).mockReturnValue({
data: [
{
budget_id: "budget-to-edit",
max_budget: 300,
rpm_limit: 30,
tpm_limit: 3000,
updated_at: "2024-01-03T00:00:00Z",
},
],
isLoading: false,
} as any);
renderWithProviders(<BudgetPanel accessToken="token-123" />);
await waitFor(() => {
expect(screen.getByText("budget-to-edit")).toBeInTheDocument();
});
// Modal should still be open (error handling)
expect(screen.getByText("Delete Budget?")).toBeInTheDocument();
const editButton = screen.getByTestId("edit-budget-button");
act(() => {
fireEvent.click(editButton);
});
await waitFor(() => {
expect(screen.getByText("Edit Budget")).toBeInTheDocument();
});
});
});
@@ -19,12 +19,12 @@ import {
TabPanels,
Text,
} from "@tremor/react";
import React, { useEffect, useState } from "react";
import React, { useState } from "react";
import { Prism as SyntaxHighlighter } from "react-syntax-highlighter";
import DeleteResourceModal from "../common_components/DeleteResourceModal";
import TableIconActionButton from "../common_components/IconActionButton/TableIconActionButtons/TableIconActionButton";
import NotificationsManager from "../molecules/notifications_manager";
import { budgetDeleteCall, getBudgetList } from "../networking";
import { useBudgets, useDeleteBudget } from "@/app/(dashboard)/hooks/budgets/useBudgets";
import BudgetModal from "./budget_modal";
import EditBudgetModal from "./edit_budget_modal";
import { CREATE_END_USER_CURL_COMMAND, CHAT_COMPLETIONS_CURL_COMMAND, OPENAI_SDK_PYTHON_CODE } from "./constants";
@@ -35,7 +35,7 @@ interface BudgetSettingsPageProps {
export interface budgetItem {
budget_id: string;
max_budget: string | null;
max_budget: number | null;
rpm_limit: number | null;
tpm_limit: number | null;
updated_at: string;
@@ -45,17 +45,10 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
const [isCreateModelVisible, setIsCreateModelVisible] = useState(false);
const [isEditModalVisible, setIsEditModalVisible] = useState(false);
const [selectedBudget, setSelectedBudget] = useState<budgetItem | null>(null);
const [budgetList, setBudgetList] = useState<budgetItem[]>([]);
const [isDeleting, setIsDeleting] = useState(false);
const [isDeleteModalVisible, setIsDeleteModalVisible] = useState(false);
useEffect(() => {
if (!accessToken) {
return;
}
getBudgetList(accessToken).then((data) => {
setBudgetList(data);
});
}, [accessToken]);
const { data: budgetList = [] } = useBudgets();
const deleteBudget = useDeleteBudget();
const handleEditCall = async (budget: budgetItem) => {
if (accessToken == null) {
@@ -74,11 +67,9 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
if (!selectedBudget || accessToken == null) {
return;
}
setIsDeleting(true);
try {
await budgetDeleteCall(accessToken, selectedBudget.budget_id);
await deleteBudget.mutateAsync(selectedBudget.budget_id);
NotificationsManager.success("Budget deleted.");
await handleUpdateCall();
} catch (error) {
console.error("Error deleting budget:", error);
if (typeof NotificationsManager.fromBackend === "function") {
@@ -87,7 +78,6 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
NotificationsManager.info("Failed to delete budget");
}
} finally {
setIsDeleting(false);
setIsDeleteModalVisible(false);
setSelectedBudget(null);
}
@@ -97,15 +87,6 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
setIsDeleteModalVisible(false);
};
const handleUpdateCall = async () => {
if (accessToken == null) {
return;
}
getBudgetList(accessToken).then((data) => {
setBudgetList(data);
});
};
return (
<div className="w-full mx-auto flex-auto overflow-y-auto m-8 p-2">
<Button size="sm" variant="primary" className="mb-2" onClick={() => setIsCreateModelVisible(true)}>
@@ -120,19 +101,14 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
<TabPanel>
<div className="mt-6">
<BudgetModal
accessToken={accessToken}
isModalVisible={isCreateModelVisible}
setIsModalVisible={setIsCreateModelVisible}
setBudgetList={setBudgetList}
/>
{selectedBudget && (
<EditBudgetModal
accessToken={accessToken}
isModalVisible={isEditModalVisible}
setIsModalVisible={setIsEditModalVisible}
setBudgetList={setBudgetList}
existingBudget={selectedBudget}
handleUpdateCall={handleUpdateCall}
/>
)}
<Card>
@@ -149,10 +125,10 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
<TableBody>
{budgetList
.slice() // Creates a shallow copy to avoid mutating the original array
.sort((a, b) => new Date(b.updated_at).getTime() - new Date(a.updated_at).getTime()) // Sort by updated_at in descending order
.map((value: budgetItem, index: number) => (
<TableRow key={index}>
.slice()
.sort((a, b) => new Date(b.updated_at).getTime() - new Date(a.updated_at).getTime())
.map((value: budgetItem) => (
<TableRow key={value.budget_id}>
<TableCell>{value.budget_id}</TableCell>
<TableCell>{value.max_budget ? value.max_budget : "n/a"}</TableCell>
<TableCell>{value.tpm_limit ? value.tpm_limit : "n/a"}</TableCell>
@@ -187,7 +163,7 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
]}
onCancel={handleDeleteCancel}
onOk={handleDeleteConfirm}
confirmLoading={isDeleting}
confirmLoading={deleteBudget.isPending}
/>
</div>
</TabPanel>
@@ -1,28 +1,22 @@
import React, { useEffect } from "react";
import { TextInput, Accordion, AccordionHeader, AccordionBody } from "@tremor/react";
import { Button as Button2, Modal, Form, InputNumber, Select } from "antd";
import { budgetUpdateCall } from "../networking";
import { useUpdateBudget } from "@/app/(dashboard)/hooks/budgets/useBudgets";
import { budgetItem } from "./budget_panel";
import NotificationsManager from "../molecules/notifications_manager";
interface BudgetModalProps {
interface EditBudgetModalProps {
isModalVisible: boolean;
accessToken: string | null;
setIsModalVisible: React.Dispatch<React.SetStateAction<boolean>>;
setBudgetList: React.Dispatch<React.SetStateAction<any[]>>;
existingBudget: budgetItem;
handleUpdateCall: () => void;
}
const EditBudgetModal: React.FC<BudgetModalProps> = ({
const EditBudgetModal: React.FC<EditBudgetModalProps> = ({
isModalVisible,
accessToken,
setIsModalVisible,
setBudgetList,
existingBudget,
handleUpdateCall,
}) => {
console.log("existingBudget", existingBudget);
const [form] = Form.useForm();
const updateBudget = useUpdateBudget();
useEffect(() => {
form.setFieldsValue(existingBudget);
@@ -38,21 +32,16 @@ const EditBudgetModal: React.FC<BudgetModalProps> = ({
form.resetFields();
};
const handleCreate = async (formValues: Record<string, any>) => {
if (accessToken == null || accessToken == undefined) {
return;
}
const handleUpdate = async (formValues: Record<string, any>) => {
try {
NotificationsManager.info("Making API Call");
setIsModalVisible(true);
const response = await budgetUpdateCall(accessToken, formValues);
setBudgetList((prevData) => (prevData ? [...prevData, response] : [response])); // Check if prevData is null
await updateBudget.mutateAsync(formValues);
NotificationsManager.success("Budget Updated");
form.resetFields();
handleUpdateCall();
setIsModalVisible(false);
} catch (error) {
console.error("Error creating the key:", error);
NotificationsManager.fromBackend(`Error creating the key: ${error}`);
console.error("Error updating the budget:", error);
NotificationsManager.fromBackend(`Error updating the budget: ${error}`);
}
};
@@ -67,7 +56,7 @@ const EditBudgetModal: React.FC<BudgetModalProps> = ({
>
<Form
form={form}
onFinish={handleCreate}
onFinish={handleUpdate}
labelCol={{ span: 8 }}
wrapperCol={{ span: 16 }}
labelAlign="left"
@@ -77,15 +66,9 @@ const EditBudgetModal: React.FC<BudgetModalProps> = ({
<Form.Item
label="Budget ID"
name="budget_id"
rules={[
{
required: true,
message: "Please input a human-friendly name for the budget",
},
]}
help="A human-friendly name for the budget"
help="Budget ID cannot be changed after creation"
>
<TextInput placeholder="" />
<TextInput placeholder="" disabled={true} />
</Form.Item>
<Form.Item label="Max Tokens per minute" name="tpm_limit" help="Default is model limit.">
<InputNumber step={1} precision={2} width={200} />