mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 12:48:57 +00:00
Merge pull request #24718 from BerriAI/litellm_ryan-march-26
litellm ryan march 26
This commit is contained in:
@@ -525,6 +525,7 @@ class LiteLLMRoutes(enum.Enum):
|
||||
# user
|
||||
"/user/new",
|
||||
"/user/update",
|
||||
"/user/bulk_update",
|
||||
"/user/delete",
|
||||
"/user/info",
|
||||
"/user/list",
|
||||
|
||||
@@ -18,6 +18,7 @@ from fastapi import HTTPException
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.constants import DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL
|
||||
from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value
|
||||
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
|
||||
from litellm.proxy._types import (
|
||||
@@ -89,7 +90,9 @@ class JWTHandler:
|
||||
self.leeway = leeway
|
||||
|
||||
@staticmethod
|
||||
def is_jwt(token: str):
|
||||
def is_jwt(token: Optional[str]) -> bool:
|
||||
if token is None:
|
||||
return False
|
||||
parts = token.split(".")
|
||||
return len(parts) == 3
|
||||
|
||||
@@ -1324,6 +1327,7 @@ class JWTAuthManager:
|
||||
jwt_valid_token: dict,
|
||||
user_object: Optional[LiteLLM_UserTable],
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: Optional[DualCache] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Sync user role and team memberships with JWT claims
|
||||
@@ -1348,6 +1352,12 @@ class JWTAuthManager:
|
||||
data={"user_role": new_role.value},
|
||||
)
|
||||
user_object.user_role = new_role.value
|
||||
if user_api_key_cache is not None:
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=user_object.user_id,
|
||||
value=user_object.model_dump(),
|
||||
ttl=DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL,
|
||||
)
|
||||
|
||||
# Sync team memberships
|
||||
jwt_team_ids = set(jwt_handler.get_team_ids_from_jwt(jwt_valid_token))
|
||||
@@ -1365,6 +1375,12 @@ class JWTAuthManager:
|
||||
teams_ids_to_remove_user_from=list(teams_to_remove),
|
||||
)
|
||||
user_object.teams = list(jwt_team_ids)
|
||||
if user_api_key_cache is not None:
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=user_object.user_id,
|
||||
value=user_object.model_dump(),
|
||||
ttl=DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL,
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
@@ -1536,6 +1552,7 @@ class JWTAuthManager:
|
||||
jwt_valid_token=jwt_valid_token,
|
||||
user_object=user_object,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
|
||||
## MAP USER TO TEAMS
|
||||
|
||||
@@ -629,6 +629,7 @@ class RouteChecks:
|
||||
in [
|
||||
"/user/new",
|
||||
"/user/delete",
|
||||
"/user/bulk_update",
|
||||
"/team/new",
|
||||
"/team/update",
|
||||
"/team/delete",
|
||||
|
||||
@@ -15,6 +15,7 @@ All /budget management endpoints
|
||||
from datetime import timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from prisma.errors import UniqueViolationError
|
||||
|
||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||
from litellm.proxy._types import *
|
||||
@@ -90,13 +91,21 @@ async def new_budget(
|
||||
|
||||
budget_obj_json = budget_obj.model_dump(exclude_none=True)
|
||||
budget_obj_jsonified = jsonify_object(budget_obj_json) # json dump any dictionaries
|
||||
response = await prisma_client.db.litellm_budgettable.create(
|
||||
data={
|
||||
**budget_obj_jsonified, # type: ignore
|
||||
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
} # type: ignore
|
||||
)
|
||||
try:
|
||||
response = await prisma_client.db.litellm_budgettable.create(
|
||||
data={
|
||||
**budget_obj_jsonified, # type: ignore
|
||||
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
} # type: ignore
|
||||
)
|
||||
except UniqueViolationError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": f"Budget with id '{budget_obj.budget_id}' already exists."
|
||||
},
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ These are members of a Team on LiteLLM
|
||||
|
||||
/user/new
|
||||
/user/update
|
||||
/user/bulk_update
|
||||
/user/delete
|
||||
/user/info
|
||||
/user/list
|
||||
|
||||
@@ -204,7 +204,7 @@ def process_sso_jwt_access_token(
|
||||
sso_jwt_handler: Optional[JWTHandler],
|
||||
result: Union[OpenID, dict, None],
|
||||
role_mappings: Optional["RoleMappings"] = None,
|
||||
) -> None:
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Process SSO JWT access token and extract team IDs and user role if available.
|
||||
|
||||
@@ -218,6 +218,12 @@ def process_sso_jwt_access_token(
|
||||
sso_jwt_handler: SSO-specific JWT handler for team ID extraction
|
||||
result: The SSO result object to update with team IDs and role
|
||||
role_mappings: Optional role mappings configuration for group-based role determination
|
||||
|
||||
Returns:
|
||||
The decoded access token payload dict, or None if decoding failed or
|
||||
inputs were missing. Callers can pass this to _sync_user_role_from_jwt_role_map
|
||||
so it has access to custom role claims (e.g. custom_roles) that are
|
||||
encoded inside the JWT but stripped from received_response.
|
||||
"""
|
||||
if access_token_str and result:
|
||||
import jwt
|
||||
@@ -230,7 +236,7 @@ def process_sso_jwt_access_token(
|
||||
verbose_proxy_logger.debug(
|
||||
"Access token is not a valid JWT (possibly an opaque token), skipping JWT-based extraction"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
# Extract team IDs from access token if sso_jwt_handler is available
|
||||
if sso_jwt_handler:
|
||||
@@ -306,6 +312,10 @@ def process_sso_jwt_access_token(
|
||||
f"Set user_role='{user_role}' from JWT access token"
|
||||
)
|
||||
|
||||
return access_token_payload
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/sso/key/generate", tags=["experimental"], include_in_schema=False)
|
||||
async def google_login(
|
||||
@@ -817,7 +827,7 @@ async def get_generic_sso_response(
|
||||
], # sso specific jwt handler - used for restricted sso group access control
|
||||
generic_client_id: str,
|
||||
redirect_url: str,
|
||||
) -> Tuple[Union[OpenID, dict], Optional[dict]]: # return received response
|
||||
) -> Tuple[Union[OpenID, dict], Optional[dict], Optional[dict]]: # (result, received_response, access_token_payload)
|
||||
# make generic sso provider
|
||||
from fastapi_sso.sso.base import DiscoveryDocument
|
||||
from fastapi_sso.sso.generic import create_provider
|
||||
@@ -872,6 +882,7 @@ async def get_generic_sso_response(
|
||||
code_verifier: Optional[
|
||||
str
|
||||
] = None # assigned inside try; initialized for type tracking
|
||||
access_token_payload: Optional[dict] = None # decoded JWT access token claims
|
||||
|
||||
try:
|
||||
token_exchange_params = (
|
||||
@@ -958,7 +969,7 @@ async def get_generic_sso_response(
|
||||
)
|
||||
access_token_str = generic_sso.access_token
|
||||
|
||||
process_sso_jwt_access_token(
|
||||
access_token_payload = process_sso_jwt_access_token(
|
||||
access_token_str, sso_jwt_handler, result, role_mappings=role_mappings
|
||||
)
|
||||
# Delete the single-use PKCE verifier only after all downstream processing
|
||||
@@ -976,7 +987,7 @@ async def get_generic_sso_response(
|
||||
additional_generic_sso_headers_dict,
|
||||
)
|
||||
verbose_proxy_logger.debug("generic result: %s", result)
|
||||
return result or {}, received_response
|
||||
return result or {}, received_response, access_token_payload
|
||||
|
||||
|
||||
async def create_team_member_add_task(team_id, user_info):
|
||||
@@ -1176,6 +1187,56 @@ def _build_sso_user_update_data(
|
||||
return update_data
|
||||
|
||||
|
||||
async def _sync_user_role_from_jwt_role_map(
|
||||
jwt_handler: Optional[JWTHandler],
|
||||
received_response: Optional[dict],
|
||||
user_info: Optional[Union[LiteLLM_UserTable, NewUserResponse]],
|
||||
prisma_client: PrismaClient,
|
||||
user_api_key_cache: DualCache,
|
||||
user_defined_values: Optional[SSOUserDefinedValues],
|
||||
) -> None:
|
||||
"""
|
||||
Apply jwt_litellm_role_map during SSO login.
|
||||
|
||||
When jwt_litellm_role_map is configured with sync_user_role_and_teams=True,
|
||||
this ensures SSO users get the same role mapping as API/JWT users. Without
|
||||
this, the SSO path falls back to INTERNAL_USER_VIEW_ONLY for roles that
|
||||
don't directly match LitellmUserRoles enum values.
|
||||
"""
|
||||
if jwt_handler is None or received_response is None:
|
||||
return
|
||||
if not jwt_handler.litellm_jwtauth.sync_user_role_and_teams:
|
||||
return
|
||||
if not jwt_handler.litellm_jwtauth.jwt_litellm_role_map:
|
||||
return
|
||||
|
||||
mapped_role = jwt_handler.map_jwt_role_to_litellm_role(received_response)
|
||||
if mapped_role is None:
|
||||
return
|
||||
|
||||
verbose_proxy_logger.info(
|
||||
f"SSO jwt_litellm_role_map matched role: {mapped_role.value}"
|
||||
)
|
||||
|
||||
# Update user_defined_values so downstream code uses the mapped role
|
||||
if user_defined_values is not None:
|
||||
user_defined_values["user_role"] = mapped_role.value
|
||||
|
||||
# Update existing DB record if role differs
|
||||
if user_info is not None and user_info.user_role != mapped_role.value:
|
||||
await prisma_client.db.litellm_usertable.update(
|
||||
where={"user_id": user_info.user_id},
|
||||
data={"user_role": mapped_role.value},
|
||||
)
|
||||
user_info.user_role = mapped_role.value
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=user_info.user_id,
|
||||
value=user_info.model_dump()
|
||||
if hasattr(user_info, "model_dump")
|
||||
else dict(user_info),
|
||||
)
|
||||
|
||||
|
||||
def apply_user_info_values_to_sso_user_defined_values(
|
||||
user_info: Optional[Union[LiteLLM_UserTable, NewUserResponse]],
|
||||
user_defined_values: Optional[SSOUserDefinedValues],
|
||||
@@ -1279,6 +1340,7 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa:
|
||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
||||
received_response: Optional[dict] = None
|
||||
access_token_payload: Optional[dict] = None
|
||||
# get url from request
|
||||
if master_key is None:
|
||||
raise ProxyException(
|
||||
@@ -1307,7 +1369,7 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa:
|
||||
)
|
||||
|
||||
elif generic_client_id is not None:
|
||||
result, received_response = await get_generic_sso_response(
|
||||
result, received_response, access_token_payload = await get_generic_sso_response(
|
||||
request=request,
|
||||
jwt_handler=jwt_handler,
|
||||
generic_client_id=generic_client_id,
|
||||
@@ -1345,6 +1407,8 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa:
|
||||
received_response=received_response,
|
||||
generic_client_id=generic_client_id,
|
||||
ui_access_mode=ui_access_mode,
|
||||
access_token_payload=access_token_payload,
|
||||
jwt_handler=jwt_handler,
|
||||
return_to=cp_return_to,
|
||||
)
|
||||
|
||||
@@ -2417,6 +2481,8 @@ class SSOAuthenticationHandler:
|
||||
received_response: Optional[dict] = None,
|
||||
generic_client_id: Optional[str] = None,
|
||||
ui_access_mode: Optional[Dict] = None,
|
||||
access_token_payload: Optional[dict] = None,
|
||||
jwt_handler: Optional[JWTHandler] = None,
|
||||
return_to: Optional[str] = None,
|
||||
) -> RedirectResponse:
|
||||
import jwt
|
||||
@@ -2498,6 +2564,20 @@ class SSOAuthenticationHandler:
|
||||
alternate_user_id=user_id,
|
||||
)
|
||||
|
||||
# Sync user role from JWT claims via jwt_litellm_role_map (if configured).
|
||||
# This ensures SSO users get the same role mapping as API/JWT users.
|
||||
# Use the decoded access_token_payload (not received_response) because
|
||||
# custom role claims (e.g. custom_roles) are encoded inside the JWT
|
||||
# access token, which is stripped from received_response.
|
||||
await _sync_user_role_from_jwt_role_map(
|
||||
jwt_handler=jwt_handler,
|
||||
received_response=access_token_payload or received_response,
|
||||
user_info=user_info,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
user_defined_values=user_defined_values,
|
||||
)
|
||||
|
||||
user_defined_values = apply_user_info_values_to_sso_user_defined_values(
|
||||
user_info=user_info, user_defined_values=user_defined_values
|
||||
)
|
||||
@@ -3703,7 +3783,7 @@ async def debug_sso_callback(request: Request):
|
||||
)
|
||||
|
||||
elif generic_client_id is not None:
|
||||
result, _ = await get_generic_sso_response(
|
||||
result, _, _ = await get_generic_sso_response(
|
||||
request=request,
|
||||
jwt_handler=jwt_handler,
|
||||
generic_client_id=generic_client_id,
|
||||
|
||||
@@ -1331,6 +1331,9 @@ def test_jwt_handler_is_jwt_static_method():
|
||||
# Test with empty string
|
||||
assert JWTHandler.is_jwt("") == False
|
||||
|
||||
# Test with None (missing Authorization header)
|
||||
assert JWTHandler.is_jwt(None) == False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"requested_model, should_work",
|
||||
|
||||
@@ -339,6 +339,123 @@ async def test_sync_user_role_and_teams():
|
||||
assert set(user.teams) == {"team1", "team2"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_user_role_and_teams_cache_invalidation_on_role_change():
|
||||
"""Test that user cache is updated when role changes."""
|
||||
mock_cache = AsyncMock()
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.update_environment(
|
||||
prisma_client=None,
|
||||
user_api_key_cache=AsyncMock(),
|
||||
litellm_jwtauth=LiteLLM_JWTAuth(
|
||||
jwt_litellm_role_map=[
|
||||
JWTLiteLLMRoleMap(jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN)
|
||||
],
|
||||
roles_jwt_field="roles",
|
||||
team_ids_jwt_field="my_id_teams",
|
||||
sync_user_role_and_teams=True,
|
||||
),
|
||||
)
|
||||
|
||||
token = {"roles": ["ADMIN"], "my_id_teams": ["team1"]}
|
||||
user = LiteLLM_UserTable(
|
||||
user_id="u1",
|
||||
user_role=LitellmUserRoles.INTERNAL_USER.value,
|
||||
teams=["team1"], # teams already match — only role differs
|
||||
)
|
||||
|
||||
prisma = AsyncMock()
|
||||
prisma.db.litellm_usertable.update = AsyncMock()
|
||||
|
||||
await JWTAuthManager.sync_user_role_and_teams(
|
||||
jwt_handler, token, user, prisma, user_api_key_cache=mock_cache
|
||||
)
|
||||
|
||||
mock_cache.async_set_cache.assert_called_once()
|
||||
call_kwargs = mock_cache.async_set_cache.call_args
|
||||
assert call_kwargs.kwargs["key"] == "u1"
|
||||
assert call_kwargs.kwargs["value"]["user_role"] == LitellmUserRoles.PROXY_ADMIN.value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_user_role_and_teams_cache_invalidation_on_team_change():
|
||||
"""Test that user cache is updated when team memberships change."""
|
||||
mock_cache = AsyncMock()
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.update_environment(
|
||||
prisma_client=None,
|
||||
user_api_key_cache=AsyncMock(),
|
||||
litellm_jwtauth=LiteLLM_JWTAuth(
|
||||
jwt_litellm_role_map=[
|
||||
JWTLiteLLMRoleMap(jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN)
|
||||
],
|
||||
roles_jwt_field="roles",
|
||||
team_ids_jwt_field="my_id_teams",
|
||||
sync_user_role_and_teams=True,
|
||||
),
|
||||
)
|
||||
|
||||
token = {"roles": ["ADMIN"], "my_id_teams": ["team1", "team2"]}
|
||||
user = LiteLLM_UserTable(
|
||||
user_id="u1",
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN.value, # role already matches
|
||||
teams=["team2"], # teams differ
|
||||
)
|
||||
|
||||
prisma = AsyncMock()
|
||||
prisma.db.litellm_usertable.update = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"litellm.proxy.management_endpoints.scim.scim_v2.patch_team_membership",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
await JWTAuthManager.sync_user_role_and_teams(
|
||||
jwt_handler, token, user, prisma, user_api_key_cache=mock_cache
|
||||
)
|
||||
|
||||
mock_cache.async_set_cache.assert_called_once()
|
||||
call_kwargs = mock_cache.async_set_cache.call_args
|
||||
assert call_kwargs.kwargs["key"] == "u1"
|
||||
assert set(call_kwargs.kwargs["value"]["teams"]) == {"team1", "team2"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_user_role_and_teams_no_cache_write_when_nothing_changes():
|
||||
"""Test that cache is NOT written when role and teams already match."""
|
||||
mock_cache = AsyncMock()
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.update_environment(
|
||||
prisma_client=None,
|
||||
user_api_key_cache=AsyncMock(),
|
||||
litellm_jwtauth=LiteLLM_JWTAuth(
|
||||
jwt_litellm_role_map=[
|
||||
JWTLiteLLMRoleMap(jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN)
|
||||
],
|
||||
roles_jwt_field="roles",
|
||||
team_ids_jwt_field="my_id_teams",
|
||||
sync_user_role_and_teams=True,
|
||||
),
|
||||
)
|
||||
|
||||
token = {"roles": ["ADMIN"], "my_id_teams": ["team1"]}
|
||||
user = LiteLLM_UserTable(
|
||||
user_id="u1",
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN.value,
|
||||
teams=["team1"],
|
||||
)
|
||||
|
||||
prisma = AsyncMock()
|
||||
|
||||
await JWTAuthManager.sync_user_role_and_teams(
|
||||
jwt_handler, token, user, prisma, user_api_key_cache=mock_cache
|
||||
)
|
||||
|
||||
mock_cache.async_set_cache.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_jwt_role_to_litellm_role():
|
||||
"""Test JWT role mapping to LiteLLM roles with various patterns"""
|
||||
|
||||
@@ -567,6 +567,10 @@ class TestJWTOAuth2Coexistence:
|
||||
assert JWTHandler.is_jwt("Bearer token") is False
|
||||
assert JWTHandler.is_jwt("two.parts") is False
|
||||
|
||||
def test_is_jwt_returns_false_for_none(self):
|
||||
"""None token (missing Authorization header) should not be treated as JWT."""
|
||||
assert JWTHandler.is_jwt(None) is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_both_enabled_opaque_token_uses_oauth2(self):
|
||||
"""
|
||||
|
||||
@@ -24,6 +24,7 @@ from litellm.proxy.management_endpoints.ui_sso import (
|
||||
MicrosoftSSOHandler,
|
||||
SSOAuthenticationHandler,
|
||||
_setup_team_mappings,
|
||||
_sync_user_role_from_jwt_role_map,
|
||||
determine_role_from_groups,
|
||||
normalize_email,
|
||||
process_sso_jwt_access_token,
|
||||
@@ -1321,7 +1322,7 @@ async def test_get_generic_sso_response_with_additional_headers():
|
||||
"fastapi_sso.sso.generic.create_provider", return_value=mock_sso_class
|
||||
):
|
||||
# Act
|
||||
result, received_response = await get_generic_sso_response(
|
||||
result, received_response, _ = await get_generic_sso_response(
|
||||
request=mock_request,
|
||||
jwt_handler=mock_jwt_handler,
|
||||
generic_client_id=generic_client_id,
|
||||
@@ -1383,7 +1384,7 @@ async def test_get_generic_sso_response_with_empty_headers():
|
||||
"fastapi_sso.sso.generic.create_provider", return_value=mock_sso_class
|
||||
):
|
||||
# Act
|
||||
result, received_response = await get_generic_sso_response(
|
||||
result, received_response, _ = await get_generic_sso_response(
|
||||
request=mock_request,
|
||||
jwt_handler=mock_jwt_handler,
|
||||
generic_client_id=generic_client_id,
|
||||
@@ -5254,3 +5255,159 @@ class TestValidateReturnTo:
|
||||
)
|
||||
SSOAuthenticationHandler._validate_return_to("https://cp.example.com:3000/ui")
|
||||
|
||||
|
||||
class TestSyncUserRoleFromJwtRoleMap:
|
||||
"""Tests for _sync_user_role_from_jwt_role_map."""
|
||||
|
||||
@staticmethod
|
||||
def _make_jwt_handler():
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.proxy._types import (
|
||||
JWTLiteLLMRoleMap,
|
||||
LiteLLM_JWTAuth,
|
||||
LitellmUserRoles,
|
||||
)
|
||||
|
||||
handler = JWTHandler()
|
||||
handler.update_environment(
|
||||
prisma_client=None,
|
||||
user_api_key_cache=DualCache(),
|
||||
litellm_jwtauth=LiteLLM_JWTAuth(
|
||||
roles_jwt_field="custom_roles",
|
||||
user_id_upsert=True,
|
||||
sync_user_role_and_teams=True,
|
||||
jwt_litellm_role_map=[
|
||||
JWTLiteLLMRoleMap(
|
||||
jwt_role="my-admin",
|
||||
litellm_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
),
|
||||
JWTLiteLLMRoleMap(
|
||||
jwt_role="my-viewer",
|
||||
litellm_role=LitellmUserRoles.INTERNAL_USER,
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
return handler
|
||||
|
||||
@staticmethod
|
||||
def _make_sso_values(user_role=None):
|
||||
from litellm.proxy._types import SSOUserDefinedValues
|
||||
|
||||
user_id = "testuser@example.com"
|
||||
return SSOUserDefinedValues(
|
||||
models=[],
|
||||
user_id=user_id,
|
||||
user_email=user_id,
|
||||
user_role=user_role,
|
||||
max_budget=None,
|
||||
budget_duration=None,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stripped_response_has_no_roles(self):
|
||||
"""Bug repro: stripped received_response lacks role claims."""
|
||||
from litellm.caching.caching import DualCache
|
||||
|
||||
handler = self._make_jwt_handler()
|
||||
sso_values = self._make_sso_values()
|
||||
|
||||
await _sync_user_role_from_jwt_role_map(
|
||||
jwt_handler=handler,
|
||||
received_response={"token_type": "Bearer", "expires_in": 3600},
|
||||
user_info=None,
|
||||
prisma_client=AsyncMock(),
|
||||
user_api_key_cache=DualCache(),
|
||||
user_defined_values=sso_values,
|
||||
)
|
||||
|
||||
assert sso_values["user_role"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decoded_access_token_maps_role(self):
|
||||
"""Decoded JWT payload with role claims maps correctly."""
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.proxy._types import LitellmUserRoles
|
||||
|
||||
handler = self._make_jwt_handler()
|
||||
sso_values = self._make_sso_values()
|
||||
|
||||
await _sync_user_role_from_jwt_role_map(
|
||||
jwt_handler=handler,
|
||||
received_response={"sub": "testuser@example.com", "custom_roles": ["my-admin"]},
|
||||
user_info=None,
|
||||
prisma_client=AsyncMock(),
|
||||
user_api_key_cache=DualCache(),
|
||||
user_defined_values=sso_values,
|
||||
)
|
||||
|
||||
assert sso_values["user_role"] == LitellmUserRoles.PROXY_ADMIN.value
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_existing_user_role_updated_in_db_and_cache(self):
|
||||
"""Existing user with stale role gets updated in DB and cache."""
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.proxy._types import LitellmUserRoles
|
||||
|
||||
handler = self._make_jwt_handler()
|
||||
cache = DualCache()
|
||||
prisma = AsyncMock()
|
||||
prisma.db.litellm_usertable.update = AsyncMock()
|
||||
user_id = "testuser@example.com"
|
||||
|
||||
existing_user = LiteLLM_UserTable(
|
||||
user_id=user_id,
|
||||
user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value,
|
||||
)
|
||||
await cache.async_set_cache(key=user_id, value=existing_user.model_dump(), ttl=60)
|
||||
|
||||
sso_values = self._make_sso_values(
|
||||
user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value,
|
||||
)
|
||||
|
||||
await _sync_user_role_from_jwt_role_map(
|
||||
jwt_handler=handler,
|
||||
received_response={"sub": user_id, "custom_roles": ["my-admin"]},
|
||||
user_info=existing_user,
|
||||
prisma_client=prisma,
|
||||
user_api_key_cache=cache,
|
||||
user_defined_values=sso_values,
|
||||
)
|
||||
|
||||
prisma.db.litellm_usertable.update.assert_called_once_with(
|
||||
where={"user_id": user_id},
|
||||
data={"user_role": LitellmUserRoles.PROXY_ADMIN.value},
|
||||
)
|
||||
assert existing_user.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
||||
assert sso_values["user_role"] == LitellmUserRoles.PROXY_ADMIN.value
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_role_no_db_write(self):
|
||||
"""No DB update when the mapped role matches the existing role."""
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.proxy._types import LitellmUserRoles
|
||||
|
||||
handler = self._make_jwt_handler()
|
||||
prisma = AsyncMock()
|
||||
prisma.db.litellm_usertable.update = AsyncMock()
|
||||
|
||||
existing_user = LiteLLM_UserTable(
|
||||
user_id="testuser@example.com",
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN.value,
|
||||
)
|
||||
|
||||
sso_values = self._make_sso_values(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN.value,
|
||||
)
|
||||
|
||||
await _sync_user_role_from_jwt_role_map(
|
||||
jwt_handler=handler,
|
||||
received_response={"sub": "testuser@example.com", "custom_roles": ["my-admin"]},
|
||||
user_info=existing_user,
|
||||
prisma_client=prisma,
|
||||
user_api_key_cache=DualCache(),
|
||||
user_defined_values=sso_values,
|
||||
)
|
||||
|
||||
prisma.db.litellm_usertable.update.assert_not_called()
|
||||
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
import { useQuery, useMutation, useQueryClient, UseQueryResult } from "@tanstack/react-query";
|
||||
import { createQueryKeys } from "../common/queryKeysFactory";
|
||||
import { getBudgetList, budgetCreateCall, budgetUpdateCall, budgetDeleteCall } from "@/components/networking";
|
||||
import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized";
|
||||
import { budgetItem } from "@/components/budgets/budget_panel";
|
||||
|
||||
export const budgetKeys = createQueryKeys("budgets");
|
||||
|
||||
export const useBudgets = (): UseQueryResult<budgetItem[]> => {
|
||||
const { accessToken } = useAuthorized();
|
||||
return useQuery<budgetItem[]>({
|
||||
queryKey: budgetKeys.list({}),
|
||||
queryFn: async () => {
|
||||
const data = await getBudgetList(accessToken!);
|
||||
return (data ?? []).filter((item: budgetItem | null): item is budgetItem => item != null);
|
||||
},
|
||||
enabled: Boolean(accessToken),
|
||||
});
|
||||
};
|
||||
|
||||
export const useCreateBudget = () => {
|
||||
const { accessToken } = useAuthorized();
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
return useMutation<unknown, Error, Record<string, any>>({
|
||||
mutationFn: async (formValues) => {
|
||||
if (!accessToken) {
|
||||
throw new Error("Access token is required");
|
||||
}
|
||||
return budgetCreateCall(accessToken, formValues);
|
||||
},
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: budgetKeys.all });
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const useUpdateBudget = () => {
|
||||
const { accessToken } = useAuthorized();
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
return useMutation<unknown, Error, Record<string, any>>({
|
||||
mutationFn: async (formValues) => {
|
||||
if (!accessToken) {
|
||||
throw new Error("Access token is required");
|
||||
}
|
||||
return budgetUpdateCall(accessToken, formValues);
|
||||
},
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: budgetKeys.all });
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const useDeleteBudget = () => {
|
||||
const { accessToken } = useAuthorized();
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
return useMutation<unknown, Error, string>({
|
||||
mutationFn: async (budgetId) => {
|
||||
if (!accessToken) {
|
||||
throw new Error("Access token is required");
|
||||
}
|
||||
return budgetDeleteCall(accessToken, budgetId);
|
||||
},
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: budgetKeys.all });
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -1,17 +1,17 @@
|
||||
import React from "react";
|
||||
import { TextInput, Accordion, AccordionHeader, AccordionBody } from "@tremor/react";
|
||||
import { Button as Button2, Modal, Form, InputNumber, Select } from "antd";
|
||||
import { budgetCreateCall } from "../networking";
|
||||
import { useCreateBudget } from "@/app/(dashboard)/hooks/budgets/useBudgets";
|
||||
import NotificationsManager from "../molecules/notifications_manager";
|
||||
|
||||
interface BudgetModalProps {
|
||||
isModalVisible: boolean;
|
||||
accessToken: string | null;
|
||||
setIsModalVisible: React.Dispatch<React.SetStateAction<boolean>>;
|
||||
setBudgetList: React.Dispatch<React.SetStateAction<any[]>>;
|
||||
}
|
||||
const BudgetModal: React.FC<BudgetModalProps> = ({ isModalVisible, accessToken, setIsModalVisible, setBudgetList }) => {
|
||||
const BudgetModal: React.FC<BudgetModalProps> = ({ isModalVisible, setIsModalVisible }) => {
|
||||
const [form] = Form.useForm();
|
||||
const createBudget = useCreateBudget();
|
||||
|
||||
const handleOk = () => {
|
||||
setIsModalVisible(false);
|
||||
form.resetFields();
|
||||
@@ -23,20 +23,15 @@ const BudgetModal: React.FC<BudgetModalProps> = ({ isModalVisible, accessToken,
|
||||
};
|
||||
|
||||
const handleCreate = async (formValues: Record<string, any>) => {
|
||||
if (accessToken == null || accessToken == undefined) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
NotificationsManager.info("Making API Call");
|
||||
// setIsModalVisible(true);
|
||||
const response = await budgetCreateCall(accessToken, formValues);
|
||||
console.log("key create Response:", response);
|
||||
setBudgetList((prevData) => (prevData ? [...prevData, response] : [response])); // Check if prevData is null
|
||||
await createBudget.mutateAsync(formValues);
|
||||
NotificationsManager.success("Budget Created");
|
||||
form.resetFields();
|
||||
setIsModalVisible(false);
|
||||
} catch (error) {
|
||||
console.error("Error creating the key:", error);
|
||||
NotificationsManager.fromBackend(`Error creating the key: ${error}`);
|
||||
console.error("Error creating the budget:", error);
|
||||
NotificationsManager.fromBackend(`Error creating the budget: ${error}`);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,31 +1,50 @@
|
||||
import * as networking from "../networking";
|
||||
import { fireEvent, render, waitFor, screen } from "@testing-library/react";
|
||||
import { act } from "@testing-library/react";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import BudgetPanel from "./budget_panel";
|
||||
|
||||
vi.mock("../networking", () => ({
|
||||
getBudgetList: vi.fn(),
|
||||
budgetDeleteCall: vi.fn(),
|
||||
const mockBudgets = [
|
||||
{
|
||||
budget_id: "budget-1",
|
||||
max_budget: 100,
|
||||
rpm_limit: 10,
|
||||
tpm_limit: 1000,
|
||||
updated_at: "2024-01-01T00:00:00Z",
|
||||
},
|
||||
];
|
||||
|
||||
vi.mock("@/app/(dashboard)/hooks/budgets/useBudgets", () => ({
|
||||
useBudgets: vi.fn().mockReturnValue({ data: [], isLoading: false }),
|
||||
useDeleteBudget: vi.fn().mockReturnValue({ mutateAsync: vi.fn(), isPending: false }),
|
||||
useCreateBudget: vi.fn().mockReturnValue({ mutateAsync: vi.fn() }),
|
||||
useUpdateBudget: vi.fn().mockReturnValue({ mutateAsync: vi.fn() }),
|
||||
}));
|
||||
|
||||
import { useBudgets, useDeleteBudget, useCreateBudget, useUpdateBudget } from "@/app/(dashboard)/hooks/budgets/useBudgets";
|
||||
|
||||
const createQueryClient = () =>
|
||||
new QueryClient({
|
||||
defaultOptions: { queries: { retry: false, gcTime: 0 } },
|
||||
});
|
||||
|
||||
function renderWithProviders(ui: React.ReactElement) {
|
||||
const qc = createQueryClient();
|
||||
return render(<QueryClientProvider client={qc}>{ui}</QueryClientProvider>);
|
||||
}
|
||||
|
||||
describe("Budget Panel", () => {
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("should render the budget panel and load budgets", async () => {
|
||||
vi.mocked(networking.getBudgetList).mockResolvedValue([
|
||||
{
|
||||
budget_id: "budget-1",
|
||||
max_budget: "100",
|
||||
rpm_limit: 10,
|
||||
tpm_limit: 1000,
|
||||
updated_at: "2024-01-01T00:00:00Z",
|
||||
},
|
||||
]);
|
||||
vi.mocked(useBudgets).mockReturnValue({
|
||||
data: mockBudgets,
|
||||
isLoading: false,
|
||||
} as any);
|
||||
|
||||
render(<BudgetPanel accessToken="token-123" />);
|
||||
renderWithProviders(<BudgetPanel accessToken="token-123" />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Create a budget to assign to customers.")).toBeInTheDocument();
|
||||
@@ -34,17 +53,20 @@ describe("Budget Panel", () => {
|
||||
});
|
||||
|
||||
it("should open delete modal when clicking delete icon", async () => {
|
||||
vi.mocked(networking.getBudgetList).mockResolvedValue([
|
||||
{
|
||||
budget_id: "budget-to-delete",
|
||||
max_budget: "200",
|
||||
rpm_limit: 20,
|
||||
tpm_limit: 2000,
|
||||
updated_at: "2024-01-02T00:00:00Z",
|
||||
},
|
||||
]);
|
||||
vi.mocked(useBudgets).mockReturnValue({
|
||||
data: [
|
||||
{
|
||||
budget_id: "budget-to-delete",
|
||||
max_budget: 200,
|
||||
rpm_limit: 20,
|
||||
tpm_limit: 2000,
|
||||
updated_at: "2024-01-02T00:00:00Z",
|
||||
},
|
||||
],
|
||||
isLoading: false,
|
||||
} as any);
|
||||
|
||||
render(<BudgetPanel accessToken="token-123" />);
|
||||
renderWithProviders(<BudgetPanel accessToken="token-123" />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("budget-to-delete")).toBeInTheDocument();
|
||||
@@ -62,18 +84,25 @@ describe("Budget Panel", () => {
|
||||
});
|
||||
|
||||
it("should successfully delete a budget", async () => {
|
||||
vi.mocked(networking.getBudgetList).mockResolvedValue([
|
||||
{
|
||||
budget_id: "budget-to-delete",
|
||||
max_budget: "200",
|
||||
rpm_limit: 20,
|
||||
tpm_limit: 2000,
|
||||
updated_at: "2024-01-02T00:00:00Z",
|
||||
},
|
||||
]);
|
||||
vi.mocked(networking.budgetDeleteCall).mockResolvedValue(undefined);
|
||||
const deleteMutateAsync = vi.fn().mockResolvedValue(undefined);
|
||||
vi.mocked(useBudgets).mockReturnValue({
|
||||
data: [
|
||||
{
|
||||
budget_id: "budget-to-delete",
|
||||
max_budget: 200,
|
||||
rpm_limit: 20,
|
||||
tpm_limit: 2000,
|
||||
updated_at: "2024-01-02T00:00:00Z",
|
||||
},
|
||||
],
|
||||
isLoading: false,
|
||||
} as any);
|
||||
vi.mocked(useDeleteBudget).mockReturnValue({
|
||||
mutateAsync: deleteMutateAsync,
|
||||
isPending: false,
|
||||
} as any);
|
||||
|
||||
render(<BudgetPanel accessToken="token-123" />);
|
||||
renderWithProviders(<BudgetPanel accessToken="token-123" />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("budget-to-delete")).toBeInTheDocument();
|
||||
@@ -96,24 +125,43 @@ describe("Budget Panel", () => {
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(networking.budgetDeleteCall).toHaveBeenCalledWith("token-123", "budget-to-delete");
|
||||
expect(networking.getBudgetList).toHaveBeenCalledTimes(2); // Initial load + refresh after delete
|
||||
expect(deleteMutateAsync).toHaveBeenCalledWith("budget-to-delete");
|
||||
});
|
||||
});
|
||||
|
||||
it("should render empty state without crashing", async () => {
|
||||
vi.mocked(useBudgets).mockReturnValue({
|
||||
data: [],
|
||||
isLoading: false,
|
||||
} as any);
|
||||
|
||||
renderWithProviders(<BudgetPanel accessToken="token-123" />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Create a budget to assign to customers.")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle delete error", async () => {
|
||||
vi.mocked(networking.getBudgetList).mockResolvedValue([
|
||||
{
|
||||
budget_id: "budget-to-delete",
|
||||
max_budget: "200",
|
||||
rpm_limit: 20,
|
||||
tpm_limit: 2000,
|
||||
updated_at: "2024-01-02T00:00:00Z",
|
||||
},
|
||||
]);
|
||||
vi.mocked(networking.budgetDeleteCall).mockRejectedValue(new Error("Delete failed"));
|
||||
const deleteMutateAsync = vi.fn().mockRejectedValue(new Error("Delete failed"));
|
||||
vi.mocked(useBudgets).mockReturnValue({
|
||||
data: [
|
||||
{
|
||||
budget_id: "budget-to-delete",
|
||||
max_budget: 200,
|
||||
rpm_limit: 20,
|
||||
tpm_limit: 2000,
|
||||
updated_at: "2024-01-02T00:00:00Z",
|
||||
},
|
||||
],
|
||||
isLoading: false,
|
||||
} as any);
|
||||
vi.mocked(useDeleteBudget).mockReturnValue({
|
||||
mutateAsync: deleteMutateAsync,
|
||||
isPending: false,
|
||||
} as any);
|
||||
|
||||
render(<BudgetPanel accessToken="token-123" />);
|
||||
renderWithProviders(<BudgetPanel accessToken="token-123" />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("budget-to-delete")).toBeInTheDocument();
|
||||
@@ -136,10 +184,38 @@ describe("Budget Panel", () => {
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(networking.budgetDeleteCall).toHaveBeenCalledWith("token-123", "budget-to-delete");
|
||||
expect(deleteMutateAsync).toHaveBeenCalledWith("budget-to-delete");
|
||||
});
|
||||
});
|
||||
|
||||
it("should open edit modal when clicking edit icon", async () => {
|
||||
vi.mocked(useBudgets).mockReturnValue({
|
||||
data: [
|
||||
{
|
||||
budget_id: "budget-to-edit",
|
||||
max_budget: 300,
|
||||
rpm_limit: 30,
|
||||
tpm_limit: 3000,
|
||||
updated_at: "2024-01-03T00:00:00Z",
|
||||
},
|
||||
],
|
||||
isLoading: false,
|
||||
} as any);
|
||||
|
||||
renderWithProviders(<BudgetPanel accessToken="token-123" />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("budget-to-edit")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Modal should still be open (error handling)
|
||||
expect(screen.getByText("Delete Budget?")).toBeInTheDocument();
|
||||
const editButton = screen.getByTestId("edit-budget-button");
|
||||
|
||||
act(() => {
|
||||
fireEvent.click(editButton);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Edit Budget")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -19,12 +19,12 @@ import {
|
||||
TabPanels,
|
||||
Text,
|
||||
} from "@tremor/react";
|
||||
import React, { useEffect, useState } from "react";
|
||||
import React, { useState } from "react";
|
||||
import { Prism as SyntaxHighlighter } from "react-syntax-highlighter";
|
||||
import DeleteResourceModal from "../common_components/DeleteResourceModal";
|
||||
import TableIconActionButton from "../common_components/IconActionButton/TableIconActionButtons/TableIconActionButton";
|
||||
import NotificationsManager from "../molecules/notifications_manager";
|
||||
import { budgetDeleteCall, getBudgetList } from "../networking";
|
||||
import { useBudgets, useDeleteBudget } from "@/app/(dashboard)/hooks/budgets/useBudgets";
|
||||
import BudgetModal from "./budget_modal";
|
||||
import EditBudgetModal from "./edit_budget_modal";
|
||||
import { CREATE_END_USER_CURL_COMMAND, CHAT_COMPLETIONS_CURL_COMMAND, OPENAI_SDK_PYTHON_CODE } from "./constants";
|
||||
@@ -35,7 +35,7 @@ interface BudgetSettingsPageProps {
|
||||
|
||||
export interface budgetItem {
|
||||
budget_id: string;
|
||||
max_budget: string | null;
|
||||
max_budget: number | null;
|
||||
rpm_limit: number | null;
|
||||
tpm_limit: number | null;
|
||||
updated_at: string;
|
||||
@@ -45,17 +45,10 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
|
||||
const [isCreateModelVisible, setIsCreateModelVisible] = useState(false);
|
||||
const [isEditModalVisible, setIsEditModalVisible] = useState(false);
|
||||
const [selectedBudget, setSelectedBudget] = useState<budgetItem | null>(null);
|
||||
const [budgetList, setBudgetList] = useState<budgetItem[]>([]);
|
||||
const [isDeleting, setIsDeleting] = useState(false);
|
||||
const [isDeleteModalVisible, setIsDeleteModalVisible] = useState(false);
|
||||
useEffect(() => {
|
||||
if (!accessToken) {
|
||||
return;
|
||||
}
|
||||
getBudgetList(accessToken).then((data) => {
|
||||
setBudgetList(data);
|
||||
});
|
||||
}, [accessToken]);
|
||||
|
||||
const { data: budgetList = [] } = useBudgets();
|
||||
const deleteBudget = useDeleteBudget();
|
||||
|
||||
const handleEditCall = async (budget: budgetItem) => {
|
||||
if (accessToken == null) {
|
||||
@@ -74,11 +67,9 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
|
||||
if (!selectedBudget || accessToken == null) {
|
||||
return;
|
||||
}
|
||||
setIsDeleting(true);
|
||||
try {
|
||||
await budgetDeleteCall(accessToken, selectedBudget.budget_id);
|
||||
await deleteBudget.mutateAsync(selectedBudget.budget_id);
|
||||
NotificationsManager.success("Budget deleted.");
|
||||
await handleUpdateCall();
|
||||
} catch (error) {
|
||||
console.error("Error deleting budget:", error);
|
||||
if (typeof NotificationsManager.fromBackend === "function") {
|
||||
@@ -87,7 +78,6 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
|
||||
NotificationsManager.info("Failed to delete budget");
|
||||
}
|
||||
} finally {
|
||||
setIsDeleting(false);
|
||||
setIsDeleteModalVisible(false);
|
||||
setSelectedBudget(null);
|
||||
}
|
||||
@@ -97,15 +87,6 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
|
||||
setIsDeleteModalVisible(false);
|
||||
};
|
||||
|
||||
const handleUpdateCall = async () => {
|
||||
if (accessToken == null) {
|
||||
return;
|
||||
}
|
||||
getBudgetList(accessToken).then((data) => {
|
||||
setBudgetList(data);
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="w-full mx-auto flex-auto overflow-y-auto m-8 p-2">
|
||||
<Button size="sm" variant="primary" className="mb-2" onClick={() => setIsCreateModelVisible(true)}>
|
||||
@@ -120,19 +101,14 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
|
||||
<TabPanel>
|
||||
<div className="mt-6">
|
||||
<BudgetModal
|
||||
accessToken={accessToken}
|
||||
isModalVisible={isCreateModelVisible}
|
||||
setIsModalVisible={setIsCreateModelVisible}
|
||||
setBudgetList={setBudgetList}
|
||||
/>
|
||||
{selectedBudget && (
|
||||
<EditBudgetModal
|
||||
accessToken={accessToken}
|
||||
isModalVisible={isEditModalVisible}
|
||||
setIsModalVisible={setIsEditModalVisible}
|
||||
setBudgetList={setBudgetList}
|
||||
existingBudget={selectedBudget}
|
||||
handleUpdateCall={handleUpdateCall}
|
||||
/>
|
||||
)}
|
||||
<Card>
|
||||
@@ -149,10 +125,10 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
|
||||
|
||||
<TableBody>
|
||||
{budgetList
|
||||
.slice() // Creates a shallow copy to avoid mutating the original array
|
||||
.sort((a, b) => new Date(b.updated_at).getTime() - new Date(a.updated_at).getTime()) // Sort by updated_at in descending order
|
||||
.map((value: budgetItem, index: number) => (
|
||||
<TableRow key={index}>
|
||||
.slice()
|
||||
.sort((a, b) => new Date(b.updated_at).getTime() - new Date(a.updated_at).getTime())
|
||||
.map((value: budgetItem) => (
|
||||
<TableRow key={value.budget_id}>
|
||||
<TableCell>{value.budget_id}</TableCell>
|
||||
<TableCell>{value.max_budget ? value.max_budget : "n/a"}</TableCell>
|
||||
<TableCell>{value.tpm_limit ? value.tpm_limit : "n/a"}</TableCell>
|
||||
@@ -187,7 +163,7 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
|
||||
]}
|
||||
onCancel={handleDeleteCancel}
|
||||
onOk={handleDeleteConfirm}
|
||||
confirmLoading={isDeleting}
|
||||
confirmLoading={deleteBudget.isPending}
|
||||
/>
|
||||
</div>
|
||||
</TabPanel>
|
||||
|
||||
@@ -1,28 +1,22 @@
|
||||
import React, { useEffect } from "react";
|
||||
import { TextInput, Accordion, AccordionHeader, AccordionBody } from "@tremor/react";
|
||||
import { Button as Button2, Modal, Form, InputNumber, Select } from "antd";
|
||||
import { budgetUpdateCall } from "../networking";
|
||||
import { useUpdateBudget } from "@/app/(dashboard)/hooks/budgets/useBudgets";
|
||||
import { budgetItem } from "./budget_panel";
|
||||
import NotificationsManager from "../molecules/notifications_manager";
|
||||
|
||||
interface BudgetModalProps {
|
||||
interface EditBudgetModalProps {
|
||||
isModalVisible: boolean;
|
||||
accessToken: string | null;
|
||||
setIsModalVisible: React.Dispatch<React.SetStateAction<boolean>>;
|
||||
setBudgetList: React.Dispatch<React.SetStateAction<any[]>>;
|
||||
existingBudget: budgetItem;
|
||||
handleUpdateCall: () => void;
|
||||
}
|
||||
const EditBudgetModal: React.FC<BudgetModalProps> = ({
|
||||
const EditBudgetModal: React.FC<EditBudgetModalProps> = ({
|
||||
isModalVisible,
|
||||
accessToken,
|
||||
setIsModalVisible,
|
||||
setBudgetList,
|
||||
existingBudget,
|
||||
handleUpdateCall,
|
||||
}) => {
|
||||
console.log("existingBudget", existingBudget);
|
||||
const [form] = Form.useForm();
|
||||
const updateBudget = useUpdateBudget();
|
||||
|
||||
useEffect(() => {
|
||||
form.setFieldsValue(existingBudget);
|
||||
@@ -38,21 +32,16 @@ const EditBudgetModal: React.FC<BudgetModalProps> = ({
|
||||
form.resetFields();
|
||||
};
|
||||
|
||||
const handleCreate = async (formValues: Record<string, any>) => {
|
||||
if (accessToken == null || accessToken == undefined) {
|
||||
return;
|
||||
}
|
||||
const handleUpdate = async (formValues: Record<string, any>) => {
|
||||
try {
|
||||
NotificationsManager.info("Making API Call");
|
||||
setIsModalVisible(true);
|
||||
const response = await budgetUpdateCall(accessToken, formValues);
|
||||
setBudgetList((prevData) => (prevData ? [...prevData, response] : [response])); // Check if prevData is null
|
||||
await updateBudget.mutateAsync(formValues);
|
||||
NotificationsManager.success("Budget Updated");
|
||||
form.resetFields();
|
||||
handleUpdateCall();
|
||||
setIsModalVisible(false);
|
||||
} catch (error) {
|
||||
console.error("Error creating the key:", error);
|
||||
NotificationsManager.fromBackend(`Error creating the key: ${error}`);
|
||||
console.error("Error updating the budget:", error);
|
||||
NotificationsManager.fromBackend(`Error updating the budget: ${error}`);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -67,7 +56,7 @@ const EditBudgetModal: React.FC<BudgetModalProps> = ({
|
||||
>
|
||||
<Form
|
||||
form={form}
|
||||
onFinish={handleCreate}
|
||||
onFinish={handleUpdate}
|
||||
labelCol={{ span: 8 }}
|
||||
wrapperCol={{ span: 16 }}
|
||||
labelAlign="left"
|
||||
@@ -77,15 +66,9 @@ const EditBudgetModal: React.FC<BudgetModalProps> = ({
|
||||
<Form.Item
|
||||
label="Budget ID"
|
||||
name="budget_id"
|
||||
rules={[
|
||||
{
|
||||
required: true,
|
||||
message: "Please input a human-friendly name for the budget",
|
||||
},
|
||||
]}
|
||||
help="A human-friendly name for the budget"
|
||||
help="Budget ID cannot be changed after creation"
|
||||
>
|
||||
<TextInput placeholder="" />
|
||||
<TextInput placeholder="" disabled={true} />
|
||||
</Form.Item>
|
||||
<Form.Item label="Max Tokens per minute" name="tpm_limit" help="Default is model limit.">
|
||||
<InputNumber step={1} precision={2} width={200} />
|
||||
|
||||
Reference in New Issue
Block a user