mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-26 23:05:45 +00:00
d533b432fd
Budget checks on API keys, teams, and team members were not enforced in multi-pod deployments because user_api_key_cache is intentionally in-memory-only. Each pod tracked spend independently, so with N pods the effective budget was N × max_budget.

Introduces a separate spend_counter_cache (DualCache wired to redis_usage_cache) with atomic increment/read helpers:
- increment_spend_counters(): awaited in the cost callback (not create_task) so that both the in-memory and Redis counters are updated before the next auth check
- get_current_spend(): reads Redis first (cross-pod authoritative), falls back to in-memory, then to the cached object's .spend from the DB

Budget check functions (_virtual_key_max_budget_check, _team_max_budget_check, _check_team_member_budget) now read spend via get_current_spend() instead of the cached object's .spend fields. When Redis is not configured, the counters fall back to in-memory only (the same as the current single-instance behavior).

Fixes #23714
1783 lines
59 KiB
Python
1783 lines
59 KiB
Python
import asyncio
|
|
import json
|
|
import os
|
|
import sys
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
sys.path.insert(
|
|
0, os.path.abspath("../../..")
|
|
) # Adds the parent directory to the system path
|
|
|
|
from datetime import datetime, timedelta
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
import litellm
|
|
from litellm.proxy._types import (
|
|
CallInfo,
|
|
Litellm_EntityType,
|
|
LiteLLM_ObjectPermissionTable,
|
|
LiteLLM_TeamTable,
|
|
LiteLLM_UserTable,
|
|
LitellmUserRoles,
|
|
ProxyErrorTypes,
|
|
ProxyException,
|
|
SSOUserDefinedValues,
|
|
UserAPIKeyAuth,
|
|
)
|
|
from litellm.proxy.auth.auth_checks import (
|
|
ExperimentalUIJWTToken,
|
|
_can_object_call_vector_stores,
|
|
_check_team_member_budget,
|
|
_get_fuzzy_user_object,
|
|
_get_team_db_check,
|
|
_log_budget_lookup_failure,
|
|
_team_max_budget_check,
|
|
_virtual_key_max_budget_alert_check,
|
|
_virtual_key_max_budget_check,
|
|
_virtual_key_soft_budget_check,
|
|
get_key_object,
|
|
get_user_object,
|
|
vector_store_access_check,
|
|
)
|
|
from litellm.proxy.common_utils.encrypt_decrypt_utils import decrypt_value_helper
|
|
from litellm.utils import get_utc_datetime
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def set_salt_key(monkeypatch):
    """Export LITELLM_SALT_KEY into the environment for every test (autouse)."""
    salt_env_var, salt_value = "LITELLM_SALT_KEY", "sk-1234"
    monkeypatch.setenv(salt_env_var, salt_value)
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def reset_constants_module():
    """Reload `constants` and `auth_checks` before and after each test.

    Both modules are reloaded in the same order (constants first) on setup
    and again on teardown, so that values changed by one test are recomputed
    for the next.
    """
    import importlib

    from litellm import constants
    from litellm.proxy.auth import auth_checks

    def _reload_in_order():
        # constants is reloaded first, then auth_checks, as in setup/teardown.
        for module in (constants, auth_checks):
            importlib.reload(module)

    # Setup: start the test from freshly evaluated modules.
    _reload_in_order()

    yield

    # Teardown: discard whatever the test changed.
    _reload_in_order()
|
|
|
|
|
|
@pytest.fixture
def valid_sso_user_defined_values():
    """Return a LiteLLM_UserTable for a proxy-admin user with a role set."""
    user_fields = {
        "user_id": "test_user",
        "user_email": "test@example.com",
        "user_role": LitellmUserRoles.PROXY_ADMIN.value,
        "models": ["gpt-3.5-turbo"],
        "max_budget": 100.0,
    }
    return LiteLLM_UserTable(**user_fields)
|
|
|
|
|
|
@pytest.fixture
def invalid_sso_user_defined_values():
    """Return a LiteLLM_UserTable whose user_role is None (role missing)."""
    user_fields = {
        "user_id": "test_user",
        "user_email": "test@example.com",
        "user_role": None,  # deliberately absent
        "models": ["gpt-3.5-turbo"],
        "max_budget": 100.0,
    }
    return LiteLLM_UserTable(**user_fields)
|
|
|
|
|
|
def test_get_experimental_ui_login_jwt_auth_token_valid(valid_sso_user_defined_values):
    """Check the payload and expiry of a UI login token built for a valid user."""
    encrypted = ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
        valid_sso_user_defined_values
    )

    # The token is decrypted and then parsed as JSON to inspect its fields.
    plaintext = decrypt_value_helper(
        encrypted, key="ui_hash_key", exception_type="debug"
    )
    # json.loads cannot take None, so fail with a clear assertion first.
    assert plaintext is not None
    claims = json.loads(plaintext)

    assert claims["user_id"] == "test_user"
    assert claims["user_role"] == LitellmUserRoles.PROXY_ADMIN.value
    assert claims["models"] == ["gpt-3.5-turbo"]
    assert claims["max_budget"] == litellm.max_ui_session_budget

    # Expiry must be in the future, at most 10 minutes out (the Experimental UI
    # uses a fixed 10-minute expiry); 2 seconds are allowed for test execution.
    assert "expires" in claims
    expiry = datetime.fromisoformat(claims["expires"].replace("Z", "+00:00"))
    current_time = get_utc_datetime()
    latest_allowed = current_time + timedelta(minutes=10, seconds=2)
    assert expiry > current_time
    assert expiry <= latest_allowed
|
|
|
|
|
|
def test_get_experimental_ui_login_jwt_auth_token_uses_10_min_expiry(
    valid_sso_user_defined_values,
):
    """Experimental UI tokens expire in roughly 10 minutes (LITELLM_UI_SESSION_DURATION is not used)."""
    encrypted = ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
        valid_sso_user_defined_values
    )
    plaintext = decrypt_value_helper(
        encrypted, key="ui_hash_key", exception_type="debug"
    )
    assert plaintext is not None

    claims = json.loads(plaintext)
    expiry = datetime.fromisoformat(claims["expires"].replace("Z", "+00:00"))
    current_time = get_utc_datetime()

    # Window: later than 9 minutes from now, no later than 10 minutes + 2 seconds.
    earliest_allowed = current_time + timedelta(minutes=9)
    latest_allowed = current_time + timedelta(minutes=10, seconds=2)
    assert expiry > earliest_allowed
    assert expiry <= latest_allowed
|
|
|
|
|
|
def test_experimental_ui_token_ignores_litellm_ui_session_duration(
    valid_sso_user_defined_values,
):
    """Regression guard: the Experimental UI token must keep its fixed 10-minute expiry.

    LITELLM_UI_SESSION_DURATION defaults to "24h". Were that constant wired
    into the experimental flow by mistake, the expiry would be about a day
    away and the upper-bound assertion below would fail.
    """
    encrypted = ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
        valid_sso_user_defined_values
    )
    plaintext = decrypt_value_helper(
        encrypted, key="ui_hash_key", exception_type="debug"
    )
    assert plaintext is not None

    claims = json.loads(plaintext)
    expiry = datetime.fromisoformat(claims["expires"].replace("Z", "+00:00"))
    upper_bound = get_utc_datetime() + timedelta(minutes=11)

    # Roughly 10 minutes, never 24 hours.
    assert (
        expiry <= upper_bound
    ), "Experimental UI must use 10-min expiry, not LITELLM_UI_SESSION_DURATION"
|
|
|
|
|
|
def test_get_experimental_ui_login_jwt_auth_token_invalid(
    invalid_sso_user_defined_values,
):
    """Token generation must raise when the user has no role."""
    expected_message = "User role is required for experimental UI login"

    with pytest.raises(Exception) as raised:
        ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
            invalid_sso_user_defined_values
        )

    assert str(raised.value) == expected_message
|
|
|
|
|
|
def test_get_key_object_from_ui_hash_key_valid(
    valid_sso_user_defined_values, monkeypatch
):
    """A freshly generated UI token must round-trip into a key object."""
    monkeypatch.setenv("EXPERIMENTAL_UI_LOGIN", "True")

    # Mint a token, then feed it straight back through the lookup.
    ui_token = ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
        valid_sso_user_defined_values
    )
    resolved = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(ui_token)

    assert resolved is not None
    assert resolved.user_id == "test_user"
    assert resolved.user_role == LitellmUserRoles.PROXY_ADMIN
    assert resolved.models == ["gpt-3.5-turbo"]
    assert resolved.max_budget == litellm.max_ui_session_budget
|
|
|
|
|
|
def test_get_key_object_from_ui_hash_key_invalid():
    """A string that is not a valid UI token must resolve to None."""
    bogus_token = "invalid_token"
    resolved = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(bogus_token)
    assert resolved is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_key_object_should_reconnect_once_on_db_connection_error():
    """A connection error on the first DB read triggers one reconnect and one retry."""
    # First get_data call raises a connection error, the second returns the key.
    db_reads = AsyncMock(
        side_effect=[
            httpx.ConnectError("db connection reset"),
            UserAPIKeyAuth(token="hashed-token-1"),
        ]
    )
    reconnect = AsyncMock(return_value=True)
    prisma = MagicMock(get_data=db_reads, attempt_db_reconnect=reconnect)

    # Cache always misses, so the lookup has to go to the (mocked) DB.
    cache = MagicMock(
        async_get_cache=AsyncMock(return_value=None),
        async_set_cache=AsyncMock(),
    )

    found_key = await get_key_object(
        hashed_token="hashed-token-1",
        prisma_client=prisma,
        user_api_key_cache=cache,
    )

    assert found_key.token == "hashed-token-1"
    assert db_reads.await_count == 2
    reconnect.assert_awaited_once_with(
        reason="auth_get_key_object_lookup_failure",
        timeout_seconds=2.0,
        lock_timeout_seconds=0.1,
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_key_object_should_raise_if_reconnect_fails_on_db_connection_error():
    """When the reconnect attempt reports failure, the DB error propagates with no retry."""
    db_reads = AsyncMock(
        side_effect=httpx.ConnectError("db not reachable after outage")
    )
    reconnect = AsyncMock(return_value=False)
    prisma = MagicMock(get_data=db_reads, attempt_db_reconnect=reconnect)

    # Cache always misses, so the lookup has to go to the (mocked) DB.
    cache = MagicMock(
        async_get_cache=AsyncMock(return_value=None),
        async_set_cache=AsyncMock(),
    )

    with pytest.raises(Exception, match="db not reachable after outage"):
        await get_key_object(
            hashed_token="hashed-token-2",
            prisma_client=prisma,
            user_api_key_cache=cache,
        )

    reconnect.assert_awaited_once_with(
        reason="auth_get_key_object_lookup_failure",
        timeout_seconds=2.0,
        lock_timeout_seconds=0.1,
    )
    # Only the initial read happened; a failed reconnect must not retry.
    assert db_reads.await_count == 1
|
|
|
|
|
|
def test_get_cli_jwt_auth_token_default_expiration(valid_sso_user_defined_values):
    """Test generating CLI JWT token with default 24-hour expiration."""
    token = ExperimentalUIJWTToken.get_cli_jwt_auth_token(valid_sso_user_defined_values)

    # Decrypt and verify token contents
    decrypted_token = decrypt_value_helper(
        token, key="ui_hash_key", exception_type="debug"
    )
    assert decrypted_token is not None
    token_data = json.loads(decrypted_token)

    assert token_data["user_id"] == "test_user"
    assert token_data["user_role"] == LitellmUserRoles.PROXY_ADMIN.value
    assert token_data["models"] == ["gpt-3.5-turbo"]
    assert token_data["max_budget"] == litellm.max_ui_session_budget

    # Verify expiration time is set to 24 hours (default)
    assert "expires" in token_data
    expires = datetime.fromisoformat(token_data["expires"].replace("Z", "+00:00"))
    # Read the clock once so all three bounds are measured from the same
    # instant, as the other expiry tests in this module do.
    now = get_utc_datetime()
    assert expires > now
    assert expires <= now + timedelta(hours=24, minutes=1)
    assert expires >= now + timedelta(hours=23, minutes=59)
|
|
|
|
|
|
def test_get_cli_jwt_auth_token_custom_expiration(
    valid_sso_user_defined_values, monkeypatch
):
    """Test generating CLI JWT token with custom expiration via environment variable."""
    import importlib

    from litellm import constants
    from litellm.proxy.auth import auth_checks

    # Set custom expiration to 48 hours
    monkeypatch.setenv("LITELLM_CLI_JWT_EXPIRATION_HOURS", "48")

    # Reload the constants module to pick up the new env var
    importlib.reload(constants)
    # Also reload auth_checks to pick up the new constant value
    importlib.reload(auth_checks)

    # Use the class from the reloaded module rather than the top-of-file import.
    token = auth_checks.ExperimentalUIJWTToken.get_cli_jwt_auth_token(
        valid_sso_user_defined_values
    )

    # Decrypt and verify token contents
    decrypted_token = decrypt_value_helper(
        token, key="ui_hash_key", exception_type="debug"
    )
    assert decrypted_token is not None
    token_data = json.loads(decrypted_token)

    # Verify expiration time is set to 48 hours
    assert "expires" in token_data
    expires = datetime.fromisoformat(token_data["expires"].replace("Z", "+00:00"))
    # Read the clock once so both bounds are measured from the same instant,
    # as the other expiry tests in this module do.
    now = get_utc_datetime()
    assert expires > now + timedelta(hours=47, minutes=59)
    assert expires <= now + timedelta(hours=48, minutes=1)
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_default_internal_user_params_with_get_user_object(monkeypatch):
    """get_user_object must apply litellm.default_internal_user_params when it creates a user."""
    monkeypatch.setattr(
        litellm,
        "default_internal_user_params",
        {
            "models": ["gpt-4", "claude-3-opus"],
            "max_budget": 200.0,
            "user_role": "internal_user",
        },
    )

    # Stand-in for the row returned by the create call.
    created_user = MagicMock(
        user_id="new_test_user",
        models=["gpt-4", "claude-3-opus"],
        max_budget=200.0,
        user_role="internal_user",
        organization_memberships=[],
    )
    # Give the stand-in a dict() method that returns plain data.
    created_user.dict = lambda: {
        "user_id": "new_test_user",
        "models": ["gpt-4", "claude-3-opus"],
        "max_budget": 200.0,
        "user_role": "internal_user",
        "organization_memberships": [],
    }

    # DB: the user is not found, so the upsert path has to create it.
    create_call = AsyncMock(return_value=created_user)
    prisma = MagicMock()
    prisma.db = AsyncMock()
    prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=None)
    prisma.db.litellm_usertable.create = create_call

    # Cache: async methods need AsyncMock; reads always miss.
    cache = MagicMock(
        async_get_cache=AsyncMock(return_value=None),
        async_set_cache=AsyncMock(),
    )

    # user_id_upsert=True is what triggers creation of the missing user.
    try:
        await get_user_object(
            user_id="new_test_user",
            prisma_client=prisma,
            user_api_key_cache=cache,
            user_id_upsert=True,
            proxy_logging_obj=None,
        )
    except Exception as e:
        # this fails since the mock object is a MagicMock and not a LiteLLM_UserTable
        print(e)

    # Inspect the payload that was handed to the create call.
    create_call.assert_called_once()
    payload = create_call.call_args.kwargs["data"]

    assert "models" in payload
    assert payload["models"] == ["gpt-4", "claude-3-opus"]
    assert payload["max_budget"] == 200.0
    assert payload["user_role"] == "internal_user"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_user_object_upsert_includes_user_email():
    """The user_email passed to get_user_object must reach the create call on upsert."""
    # Stand-in for the row returned by the create call.
    created_user = MagicMock(
        user_id="new_test_user",
        user_email="test@example.com",
        models=[],
        max_budget=None,
        user_role=None,
        organization_memberships=[],
    )
    created_user.dict = lambda: {
        "user_id": "new_test_user",
        "user_email": "test@example.com",
        "models": [],
        "max_budget": None,
        "user_role": None,
        "organization_memberships": [],
    }

    # DB: neither lookup finds the user, so it has to be created.
    create_call = AsyncMock(return_value=created_user)
    prisma = MagicMock()
    prisma.db = AsyncMock()
    prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=None)
    prisma.db.litellm_usertable.find_first = AsyncMock(return_value=None)
    prisma.db.litellm_usertable.create = create_call

    # Cache: reads always miss.
    cache = MagicMock(
        async_get_cache=AsyncMock(return_value=None),
        async_set_cache=AsyncMock(),
    )

    try:
        await get_user_object(
            user_id="new_test_user",
            prisma_client=prisma,
            user_api_key_cache=cache,
            user_id_upsert=True,
            proxy_logging_obj=None,
            user_email="test@example.com",
        )
    except Exception as e:
        # May fail since mock object is not a real LiteLLM_UserTable
        print(e)

    # Inspect the payload that was handed to the create call.
    create_call.assert_called_once()
    payload = create_call.call_args.kwargs["data"]

    assert (
        "user_email" in payload
    ), "user_email should be included when upserting a new user"
    assert payload["user_email"] == "test@example.com"
    assert payload["user_id"] == "new_test_user"
|
|
|
|
|
|
def test_log_budget_lookup_failure_dry_run():
    """_log_budget_lookup_failure must log an error for a schema/DB style exception."""
    schema_error = Exception("column 'policies' does not exist in prisma schema")

    with patch("litellm.proxy.auth.auth_checks.verbose_proxy_logger") as logger_mock:
        _log_budget_lookup_failure("user", schema_error)

    # The mock keeps its recorded calls after the patch is undone.
    logger_mock.error.assert_called_once()
    logged = logger_mock.error.call_args.args[0]

    assert "user" in logged
    assert "cache will not be populated" in logged
    assert any(keyword in logged for keyword in ("policies", "prisma"))
    assert "prisma db push" in logged
|
|
|
|
|
|
def test_log_budget_lookup_failure_skips_user_not_found():
    """A message-less Exception (the expected user-not-found case) must not be logged."""
    # get_user_object raises a bare Exception when the user is not found.
    not_found = Exception()

    with patch("litellm.proxy.auth.auth_checks.verbose_proxy_logger") as logger_mock:
        _log_budget_lookup_failure("user", not_found)

    logger_mock.error.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("litellm.proxy.management_endpoints.team_endpoints.new_team", new_callable=AsyncMock)
async def test_get_team_db_check_calls_new_team_on_upsert(mock_new_team, monkeypatch):
    """
    With upsert enabled and no matching team in the DB, _get_team_db_check
    must delegate creation to `new_team`.
    """
    new_team_id = "new-jwt-team"

    # DB lookup reports that the team does not exist.
    prisma = MagicMock(db=AsyncMock())
    prisma.db.litellm_teamtable.find_unique.return_value = None

    # Value the patched `new_team` hands back.
    mock_new_team.return_value = {"team_id": new_team_id, "max_budget": 123.45}

    await _get_team_db_check(
        team_id=new_team_id,
        prisma_client=prisma,
        team_id_upsert=True,
    )

    # Exactly one delegation to `new_team`.
    mock_new_team.assert_called_once()
    request = mock_new_team.call_args.kwargs["data"]

    # The request carries the team id and leaves max_budget unset: the
    # function under test delegates and does not choose defaults itself.
    assert request.team_id == new_team_id
    assert request.max_budget is None
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("litellm.proxy.management_endpoints.team_endpoints.new_team", new_callable=AsyncMock)
async def test_get_team_db_check_does_not_call_new_team_if_exists(mock_new_team, monkeypatch):
    """
    When the DB lookup finds the team, _get_team_db_check must not call
    `new_team`, even with upsert enabled.
    """
    existing_team_id = "existing-jwt-team"

    # DB lookup returns a team object.
    prisma = MagicMock(db=AsyncMock())
    prisma.db.litellm_teamtable.find_unique.return_value = MagicMock()

    await _get_team_db_check(
        team_id=existing_team_id,
        prisma_client=prisma,
        team_id_upsert=True,
    )

    # The team was found, so nothing should have been created.
    mock_new_team.assert_not_called()
|
|
|
|
|
|
# Vector Store Auth Check Tests
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "prisma_client,vector_store_registry,expected_result",
    [
        (None, MagicMock(), True),  # No prisma client
        (MagicMock(), None, True),  # No vector store registry
        (MagicMock(), MagicMock(), True),  # No vector stores to run
    ],
)
async def test_vector_store_access_check_early_returns(
    prisma_client, vector_store_registry, expected_result
):
    """vector_store_access_check returns True in each early-exit situation."""
    body = {"messages": [{"role": "user", "content": "test"}]}

    # When a registry is supplied, make it report no vector stores to run.
    if vector_store_registry:
        vector_store_registry.get_vector_store_ids_to_run.return_value = None

    with patch("litellm.proxy.proxy_server.prisma_client", prisma_client):
        with patch("litellm.vector_store_registry", vector_store_registry):
            outcome = await vector_store_access_check(
                request_body=body,
                team_object=None,
                valid_token=None,
            )

    assert outcome == expected_result
|
|
|
|
|
|
@pytest.mark.parametrize(
    "object_permissions,vector_store_ids,should_raise,error_type",
    [
        (None, ["store-1"], False, None),  # None permissions - should pass
        (
            {"vector_stores": []},
            ["store-1"],
            False,
            None,
        ),  # Empty vector_stores - should pass (access to all)
        (
            {"vector_stores": ["store-1", "store-2"]},
            ["store-1"],
            False,
            None,
        ),  # Has access
        (
            {"vector_stores": ["store-1", "store-2"]},
            ["store-3"],
            True,
            ProxyErrorTypes.key_vector_store_access_denied,
        ),  # No access
        (
            {"vector_stores": ["store-1"]},
            ["store-1", "store-3"],
            True,
            ProxyErrorTypes.team_vector_store_access_denied,
        ),  # Partial access
    ],
)
def test_can_object_call_vector_stores_scenarios(
    object_permissions, vector_store_ids, should_raise, error_type
):
    """Exercise _can_object_call_vector_stores across allow and deny cases."""
    # Wrap the dict form in an object exposing `.vector_stores`; None stays None.
    permissions_obj = None
    if object_permissions is not None:
        permissions_obj = MagicMock()
        permissions_obj.vector_stores = object_permissions["vector_stores"]

    # Only the key-denied case is checked as a "key"; everything else as "team".
    if error_type == ProxyErrorTypes.key_vector_store_access_denied:
        object_type = "key"
    else:
        object_type = "team"

    call_kwargs = {
        "object_type": object_type,
        "vector_store_ids_to_run": vector_store_ids,
        "object_permissions": permissions_obj,
    }

    if not should_raise:
        assert _can_object_call_vector_stores(**call_kwargs) is True
        return

    with pytest.raises(ProxyException) as raised:
        _can_object_call_vector_stores(**call_kwargs)
    assert raised.value.type == error_type
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_vector_store_access_check_with_permissions():
    """Key-level permissions allow a listed vector store and reject an unlisted one."""
    body = {"tools": [{"type": "function", "function": {"name": "test"}}]}

    key_auth = UserAPIKeyAuth(
        token="test-token",
        object_permission_id="perm-123",
        models=["gpt-4"],
        max_budget=100.0,
    )

    # Permission row returned by the mocked DB lookup for the key.
    key_permissions = MagicMock(vector_stores=["store-1", "store-2"])
    prisma = MagicMock()
    prisma.db.litellm_objectpermissiontable.find_unique = AsyncMock(
        return_value=key_permissions
    )

    registry = MagicMock()

    # Case 1: the requested store is in the permitted list.
    registry.get_vector_store_ids_to_run.return_value = ["store-1"]
    with patch("litellm.proxy.proxy_server.prisma_client", prisma):
        with patch("litellm.vector_store_registry", registry):
            allowed = await vector_store_access_check(
                request_body=body,
                team_object=None,
                valid_token=key_auth,
            )
    assert allowed is True

    # Case 2: the requested store is not in the permitted list.
    registry.get_vector_store_ids_to_run.return_value = ["store-3"]
    with patch("litellm.proxy.proxy_server.prisma_client", prisma):
        with patch("litellm.vector_store_registry", registry):
            with pytest.raises(ProxyException) as raised:
                await vector_store_access_check(
                    request_body=body,
                    team_object=None,
                    valid_token=key_auth,
                )
    assert raised.value.type == ProxyErrorTypes.key_vector_store_access_denied
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_vector_store_access_check_with_team_permissions():
    """A team limited to specific vector stores is denied access to any other store."""
    body = {}
    # The key carries no permission id of its own; only the team does.
    key_auth = UserAPIKeyAuth(token="team-test-token", object_permission_id=None)

    team = MagicMock()
    team.object_permission_id = "team-permission"

    # Permission row returned by the mocked DB lookup.
    team_permission_row = MagicMock(vector_stores=["team-store-allowed"])
    prisma = MagicMock()
    prisma.db.litellm_objectpermissiontable.find_unique = AsyncMock(
        return_value=team_permission_row
    )

    registry = MagicMock()

    # Case 1: the requested store is the one the team may use.
    registry.get_vector_store_ids_to_run.return_value = ["team-store-allowed"]
    with patch("litellm.proxy.proxy_server.prisma_client", prisma):
        with patch("litellm.vector_store_registry", registry):
            allowed = await vector_store_access_check(
                request_body=body,
                team_object=team,
                valid_token=key_auth,
            )
    assert allowed is True

    # Case 2: the requested store is outside the team's list.
    registry.get_vector_store_ids_to_run.return_value = ["team-store-denied"]
    with patch("litellm.proxy.proxy_server.prisma_client", prisma):
        with patch("litellm.vector_store_registry", registry):
            with pytest.raises(ProxyException) as raised:
                await vector_store_access_check(
                    request_body=body,
                    team_object=team,
                    valid_token=key_auth,
                )
    assert raised.value.type == ProxyErrorTypes.team_vector_store_access_denied
|
|
|
|
|
|
def test_can_object_call_model_with_alias():
    """Test that _can_object_call_model works with model aliases.

    The router maps the hidden alias "[ip-approved] gpt-4o" to
    "gpt-3.5-turbo", and the key is allowed "gpt-3.5-turbo" only, so the
    request made through the alias is expected to be permitted.
    """
    from litellm import Router
    from litellm.proxy.auth.auth_checks import _can_object_call_model

    model = "[ip-approved] gpt-4o"
    llm_router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": "test-api-key",
                },
            }
        ],
        model_group_alias={
            "[ip-approved] gpt-4o": {
                "model": "gpt-3.5-turbo",
                "hidden": True,
            },
        },
    )

    result = _can_object_call_model(
        model=model,
        llm_router=llm_router,
        models=["gpt-3.5-turbo"],
        team_model_aliases=None,
        object_type="key",
        fallback_depth=0,
    )

    print(result)
    # Previously the result was only printed, so the test passed whatever was
    # returned. Pin the outcome the same way the sibling alias tests do.
    assert result is True
|
|
|
|
|
|
def test_can_object_call_model_access_via_alias_only():
    """
    Access granted on an alias is enough, even without access to its target model.

    Setup:
    - Router alias: "my-fake-gpt" -> "gpt-4"
    - Key may call: ["my-fake-gpt"] (the alias)
    - Key may NOT call: ["gpt-4"] (the underlying model)
    - Expected: the call is permitted because of the alias entry
    """
    from litellm import Router
    from litellm.proxy.auth.auth_checks import _can_object_call_model

    deployments = [
        {
            "model_name": "gpt-4",
            "litellm_params": {
                "model": "gpt-4",
                "api_key": "test-api-key",
            },
        }
    ]
    alias_map = {
        "my-fake-gpt": {
            "model": "gpt-4",
            "hidden": False,
        },
    }
    router = Router(model_list=deployments, model_group_alias=alias_map)

    # The key lists the alias only; "gpt-4" itself is absent.
    permitted = _can_object_call_model(
        model="my-fake-gpt",
        llm_router=router,
        models=["my-fake-gpt"],
        team_model_aliases=None,
        object_type="key",
        fallback_depth=0,
    )

    assert permitted is True
|
|
|
|
|
|
def test_can_object_call_model_access_via_underlying_model_only():
    """
    Access granted on the underlying model is enough when the request uses an alias.

    Setup:
    - Router alias: "my-fake-gpt" -> "gpt-4"
    - Key may call: ["gpt-4"] (the underlying model)
    - Key may NOT call: ["my-fake-gpt"] (the alias)
    - Expected: the call is permitted because of the underlying-model entry
    """
    from litellm import Router
    from litellm.proxy.auth.auth_checks import _can_object_call_model

    deployments = [
        {
            "model_name": "gpt-4",
            "litellm_params": {
                "model": "gpt-4",
                "api_key": "test-api-key",
            },
        }
    ]
    alias_map = {
        "my-fake-gpt": {
            "model": "gpt-4",
            "hidden": False,
        },
    }
    router = Router(model_list=deployments, model_group_alias=alias_map)

    # The key lists "gpt-4" only; the alias itself is absent.
    permitted = _can_object_call_model(
        model="my-fake-gpt",
        llm_router=router,
        models=["gpt-4"],
        team_model_aliases=None,
        object_type="key",
        fallback_depth=0,
    )

    assert permitted is True
|
|
|
|
|
|
def test_can_object_call_model_no_access_to_alias_or_underlying():
    """
    A key with access to neither the alias nor its underlying model is rejected.
    """
    from litellm import Router
    from litellm.proxy._types import ProxyErrorTypes, ProxyException
    from litellm.proxy.auth.auth_checks import _can_object_call_model

    requested_model = "my-fake-gpt"
    deployments = [
        {
            "model_name": "gpt-4",
            "litellm_params": {
                "model": "gpt-4",
                "api_key": "test-api-key",
            },
        }
    ]
    alias_map = {
        "my-fake-gpt": {
            "model": "gpt-4",
            "hidden": False,
        },
    }
    router = Router(model_list=deployments, model_group_alias=alias_map)

    # The key's only permitted model is unrelated to both alias and target.
    with pytest.raises(ProxyException) as raised:
        _can_object_call_model(
            model=requested_model,
            llm_router=router,
            models=["gpt-3.5-turbo"],
            team_model_aliases=None,
            object_type="key",
            fallback_depth=0,
        )

    # The error must be a key-level model denial that names the requested model.
    denial = raised.value
    assert denial.type == ProxyErrorTypes.key_model_access_denied
    assert "key not allowed to access model" in str(denial.message)
    assert "my-fake-gpt" in str(denial.message)
|
|
|
|
|
|
# Tag Budget Enforcement Tests
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_tag_objects_batch():
    """
    Verify how ``get_tag_objects_batch`` combines the cache and the database.

    The assertions check that:
    - tags present in the cache are returned without being part of the DB query
    - all tags missing from the cache are requested in ONE ``find_many`` call
    - every tag fetched from the DB is written back to the cache
    """
    from litellm.proxy.auth.auth_checks import get_tag_objects_batch

    def _tag_payload(tag_name, spend):
        """Build the plain-dict representation of a tag used by the mocks."""
        return {
            "tag_name": tag_name,
            "spend": spend,
            "models": [],
            "litellm_budget_table": None,
        }

    def _db_tag(tag_name, spend):
        """Build a mock DB row exposing tag attributes and a matching ``dict()``."""
        row = MagicMock()
        row.tag_name = tag_name
        row.spend = spend
        row.models = []
        row.litellm_budget_table = None
        row.dict = MagicMock(return_value=_tag_payload(tag_name, spend))
        return row

    mock_prisma = MagicMock()
    mock_cache = MagicMock()
    mock_proxy_logging = MagicMock()

    # Simulate 5 tags: 2 cached, 3 uncached
    tag_names = ["cached-1", "uncached-1", "cached-2", "uncached-2", "uncached-3"]

    # Cache contents: only the two "cached-*" tags are present.
    cache_contents = {
        "tag:cached-1": _tag_payload("cached-1", 10.0),
        "tag:cached-2": _tag_payload("cached-2", 20.0),
    }

    # Rows the DB mock returns for the uncached tags.
    db_rows = [
        _db_tag("uncached-1", 30.0),
        _db_tag("uncached-2", 40.0),
        _db_tag("uncached-3", 50.0),
    ]

    # Cache lookups return the stored payload, or None for anything else.
    async def mock_get_cache(key):
        return cache_contents.get(key)

    mock_cache.async_get_cache = AsyncMock(side_effect=mock_get_cache)
    mock_cache.async_set_cache = AsyncMock()

    # The DB returns all uncached tags from a single query.
    mock_prisma.db.litellm_tagtable.find_many = AsyncMock(return_value=db_rows)

    # Call batch fetch
    tag_objects = await get_tag_objects_batch(
        tag_names=tag_names,
        prisma_client=mock_prisma,
        user_api_key_cache=mock_cache,
        proxy_logging_obj=mock_proxy_logging,
    )

    # Every requested tag is present in the result.
    assert len(tag_objects) == 5
    assert "cached-1" in tag_objects
    assert "cached-2" in tag_objects
    assert "uncached-1" in tag_objects
    assert "uncached-2" in tag_objects
    assert "uncached-3" in tag_objects

    # Cached tags keep the values held in the cache.
    assert tag_objects["cached-1"].spend == 10.0
    assert tag_objects["cached-2"].spend == 20.0

    # Uncached tags carry the values returned by the DB.
    assert tag_objects["uncached-1"].spend == 30.0
    assert tag_objects["uncached-2"].spend == 40.0
    assert tag_objects["uncached-3"].spend == 50.0

    # The DB was queried ONCE, and only for the 3 uncached tags.
    mock_prisma.db.litellm_tagtable.find_many.assert_called_once()
    call_args = mock_prisma.db.litellm_tagtable.find_many.call_args
    assert call_args.kwargs["where"]["tag_name"]["in"] == [
        "uncached-1",
        "uncached-2",
        "uncached-3",
    ]

    # Each tag fetched from the DB was written back to the cache.
    assert mock_cache.async_set_cache.call_count == 3
    cached_keys = [
        recorded.kwargs["key"]
        for recorded in mock_cache.async_set_cache.call_args_list
    ]
    assert "tag:uncached-1" in cached_keys
    assert "tag:uncached-2" in cached_keys
    assert "tag:uncached-3" in cached_keys
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_team_object_raises_404_when_not_found():
    """get_team_object should raise a 404 HTTPException when the DB lookup finds no team."""
    from unittest.mock import AsyncMock, MagicMock

    from fastapi import HTTPException

    from litellm.proxy.auth.auth_checks import get_team_object

    # Cache stub: every async lookup returns None.
    cache = MagicMock()
    cache.async_get_cache = AsyncMock(return_value=None)

    # Prisma stub: the team-table lookup returns None.
    prisma = MagicMock()
    prisma.db = AsyncMock()
    prisma.db.litellm_teamtable.find_unique = AsyncMock(return_value=None)

    with pytest.raises(HTTPException) as not_found:
        await get_team_object(
            team_id="nonexistent-team",
            prisma_client=prisma,
            user_api_key_cache=cache,
            check_cache_only=False,
            check_db_only=True,
        )

    error = not_found.value
    assert error.status_code == 404
    assert "Team doesn't exist in db" in str(error.detail)
|
|
|
|
|
|
# Reject Client-Side Metadata Tags Tests
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_reject_clientside_metadata_tags_enabled_with_tags():
    """common_checks must reject a request carrying metadata.tags when reject_clientside_metadata_tags is True."""
    from fastapi import Request

    from litellm.proxy.auth.auth_checks import common_checks

    payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "test"}],
        "metadata": {"tags": ["custom-tag"]},
    }
    settings = {"reject_clientside_metadata_tags": True}

    # Stand-in for the incoming HTTP request.
    fake_request = MagicMock(spec=Request)

    # Key that is allowed to call the requested model.
    token = UserAPIKeyAuth(token="test-token", models=["gpt-3.5-turbo"])

    check_kwargs = dict(
        request_body=payload,
        team_object=None,
        user_object=None,
        end_user_object=None,
        global_proxy_spend=None,
        general_settings=settings,
        route="/chat/completions",
        llm_router=None,
        proxy_logging_obj=MagicMock(),
        valid_token=token,
        request=fake_request,
    )

    with pytest.raises(ProxyException) as rejected:
        await common_checks(**check_kwargs)

    # The rejection is a 400 bad-request error pointing at metadata.tags.
    error = rejected.value
    assert error.type == ProxyErrorTypes.bad_request_error
    assert "metadata.tags" in error.message
    assert error.param == "metadata.tags"
    assert error.code == "400"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_reject_clientside_metadata_tags_enabled_without_tags():
    """common_checks must accept a request without metadata.tags even when reject_clientside_metadata_tags is True."""
    from fastapi import Request

    from litellm.proxy.auth.auth_checks import common_checks

    # Metadata is present but has no "tags" entry.
    payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "test"}],
        "metadata": {"custom_field": "value"},
    }
    settings = {"reject_clientside_metadata_tags": True}

    # Stand-in for the incoming HTTP request.
    fake_request = MagicMock(spec=Request)

    # Key that is allowed to call the requested model.
    token = UserAPIKeyAuth(token="test-token", models=["gpt-3.5-turbo"])

    check_kwargs = dict(
        request_body=payload,
        team_object=None,
        user_object=None,
        end_user_object=None,
        global_proxy_spend=None,
        general_settings=settings,
        route="/chat/completions",
        llm_router=None,
        proxy_logging_obj=MagicMock(),
        valid_token=token,
        request=fake_request,
    )

    # No exception is expected; the checks should simply pass.
    outcome = await common_checks(**check_kwargs)

    assert outcome is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_reject_clientside_metadata_tags_disabled_with_tags():
    """common_checks must accept metadata.tags when reject_clientside_metadata_tags is False."""
    from fastapi import Request

    from litellm.proxy.auth.auth_checks import common_checks

    payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "test"}],
        "metadata": {"tags": ["custom-tag"]},
    }
    # The rejection feature is explicitly switched off.
    settings = {"reject_clientside_metadata_tags": False}

    # Stand-in for the incoming HTTP request.
    fake_request = MagicMock(spec=Request)

    # Key that is allowed to call the requested model.
    token = UserAPIKeyAuth(token="test-token", models=["gpt-3.5-turbo"])

    check_kwargs = dict(
        request_body=payload,
        team_object=None,
        user_object=None,
        end_user_object=None,
        global_proxy_spend=None,
        general_settings=settings,
        route="/chat/completions",
        llm_router=None,
        proxy_logging_obj=MagicMock(),
        valid_token=token,
        request=fake_request,
    )

    # No exception is expected; the checks should simply pass.
    outcome = await common_checks(**check_kwargs)

    assert outcome is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_reject_clientside_metadata_tags_not_set_with_tags():
    """common_checks must accept metadata.tags when reject_clientside_metadata_tags is absent from the settings."""
    from fastapi import Request

    from litellm.proxy.auth.auth_checks import common_checks

    payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "test"}],
        "metadata": {"tags": ["custom-tag"]},
    }
    # Empty settings: the reject_clientside_metadata_tags key is not present.
    settings = {}

    # Stand-in for the incoming HTTP request.
    fake_request = MagicMock(spec=Request)

    # Key that is allowed to call the requested model.
    token = UserAPIKeyAuth(token="test-token", models=["gpt-3.5-turbo"])

    check_kwargs = dict(
        request_body=payload,
        team_object=None,
        user_object=None,
        end_user_object=None,
        global_proxy_spend=None,
        general_settings=settings,
        route="/chat/completions",
        llm_router=None,
        proxy_logging_obj=MagicMock(),
        valid_token=token,
        request=fake_request,
    )

    # No exception is expected; the checks should simply pass.
    outcome = await common_checks(**check_kwargs)

    assert outcome is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_reject_clientside_metadata_tags_non_llm_route():
    """The reject_clientside_metadata_tags check should not reject a request sent to a management route."""
    from fastapi import Request

    from litellm.proxy.auth.auth_checks import common_checks

    # Body carries metadata.tags but targets a management endpoint.
    payload = {
        "metadata": {"tags": ["custom-tag"]},
    }
    settings = {"reject_clientside_metadata_tags": True}

    # Stand-in for the incoming HTTP request.
    fake_request = MagicMock(spec=Request)

    token = UserAPIKeyAuth(token="test-token", models=["gpt-3.5-turbo"])

    # The caller is a proxy admin, as used for the management route.
    proxy_admin = LiteLLM_UserTable(
        user_id="admin-user",
        user_role=LitellmUserRoles.PROXY_ADMIN,
    )

    check_kwargs = dict(
        request_body=payload,
        team_object=None,
        user_object=proxy_admin,
        end_user_object=None,
        global_proxy_spend=None,
        general_settings=settings,
        route="/key/generate",  # Management route, not LLM route
        llm_router=None,
        proxy_logging_obj=MagicMock(),
        valid_token=token,
        request=fake_request,
    )

    # No exception is expected for a non-LLM route.
    outcome = await common_checks(**check_kwargs)

    assert outcome is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_virtual_key_soft_budget_check_with_user_obj():
    """Test _virtual_key_soft_budget_check includes user_email when user_obj is provided.

    The mocked ``budget_alerts`` only records its arguments; all assertions are
    made in the test body so that a mismatch actually fails the test.
    """
    recorded_alerts = []

    class MockProxyLogging:
        # Parameter names are kept as-is in case the caller passes them by keyword.
        async def budget_alerts(self, type, user_info):
            # Record only. The alert appears to be delivered in the background
            # (hence the sleep below); an AssertionError raised in here would
            # then be lost instead of failing the test.
            recorded_alerts.append((type, user_info))

    valid_token = UserAPIKeyAuth(
        token="test-token",
        spend=100.0,
        soft_budget=50.0,
        user_id="test-user",
        team_id="test-team",
        team_alias="test-team-alias",
        org_id="test-org",
        key_alias="test-key",
        max_budget=200.0,
    )

    user_obj = LiteLLM_UserTable(
        user_id="test-user",
        user_email="test@example.com",
        max_budget=None,
    )

    proxy_logging_obj = MockProxyLogging()

    await _virtual_key_soft_budget_check(
        valid_token=valid_token,
        proxy_logging_obj=proxy_logging_obj,
        user_obj=user_obj,
    )

    # Give the alert a moment to be delivered before inspecting it.
    await asyncio.sleep(0.1)

    assert recorded_alerts, "budget_alerts was never called"
    for alert_type, call_info in recorded_alerts:
        assert alert_type == "soft_budget"
        assert isinstance(call_info, CallInfo)

    # Inspect the most recent alert payload field by field.
    captured_call_info = recorded_alerts[-1][1]
    assert captured_call_info.user_email == "test@example.com"
    assert captured_call_info.token == "test-token"
    assert captured_call_info.spend == 100.0
    assert captured_call_info.soft_budget == 50.0
    assert captured_call_info.max_budget == 200.0
    assert captured_call_info.user_id == "test-user"
    assert captured_call_info.team_id == "test-team"
    assert captured_call_info.team_alias == "test-team-alias"
    assert captured_call_info.organization_id == "test-org"
    assert captured_call_info.key_alias == "test-key"
    assert captured_call_info.event_group == Litellm_EntityType.KEY
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_virtual_key_soft_budget_check_without_user_obj():
    """Test _virtual_key_soft_budget_check sets user_email to None when user_obj is not provided.

    The mocked ``budget_alerts`` only records its arguments; all assertions are
    made in the test body so that a mismatch actually fails the test.
    """
    recorded_alerts = []

    class MockProxyLogging:
        # Parameter names are kept as-is in case the caller passes them by keyword.
        async def budget_alerts(self, type, user_info):
            # Record only. The alert appears to be delivered in the background
            # (hence the sleep below); an AssertionError raised in here would
            # then be lost instead of failing the test.
            recorded_alerts.append((type, user_info))

    valid_token = UserAPIKeyAuth(
        token="test-token",
        spend=100.0,
        soft_budget=50.0,
        user_id="test-user",
        team_id="test-team",
        key_alias="test-key",
    )

    proxy_logging_obj = MockProxyLogging()

    await _virtual_key_soft_budget_check(
        valid_token=valid_token,
        proxy_logging_obj=proxy_logging_obj,
        user_obj=None,
    )

    # Give the alert a moment to be delivered before inspecting it.
    await asyncio.sleep(0.1)

    assert recorded_alerts, "budget_alerts was never called"
    for alert_type, call_info in recorded_alerts:
        assert alert_type == "soft_budget"
        assert isinstance(call_info, CallInfo)

    # With no user object supplied, the alert carries no email address.
    captured_call_info = recorded_alerts[-1][1]
    assert captured_call_info.user_email is None
|
|
|
|
|
|
@pytest.mark.parametrize(
    "spend, soft_budget, expect_alert",
    [
        (100.0, 50.0, True),  # Over soft budget
        (50.0, 50.0, True),  # At soft budget
        (25.0, 50.0, False),  # Under soft budget
        (100.0, None, False),  # No soft budget set
    ],
)
@pytest.mark.asyncio
async def test_virtual_key_soft_budget_check_scenarios(
    spend, soft_budget, expect_alert
):
    """Test _virtual_key_soft_budget_check with various spend and soft_budget scenarios.

    The mocked ``budget_alerts`` only records its arguments; the alert type and
    payload type are asserted in the test body so a mismatch fails the test.
    """
    recorded_alerts = []

    class MockProxyLogging:
        # Parameter names are kept as-is in case the caller passes them by keyword.
        async def budget_alerts(self, type, user_info):
            # Record only. The alert appears to be delivered in the background
            # (hence the sleep below); an AssertionError raised in here would
            # then be lost instead of failing the test.
            recorded_alerts.append((type, user_info))

    valid_token = UserAPIKeyAuth(
        token="test-token",
        spend=spend,
        soft_budget=soft_budget,
        user_id="test-user",
        key_alias="test-key",
    )

    proxy_logging_obj = MockProxyLogging()

    await _virtual_key_soft_budget_check(
        valid_token=valid_token,
        proxy_logging_obj=proxy_logging_obj,
        user_obj=None,
    )

    # Give the alert a moment to be delivered before inspecting it.
    await asyncio.sleep(0.1)

    alert_triggered = bool(recorded_alerts)
    assert (
        alert_triggered == expect_alert
    ), f"Expected alert_triggered to be {expect_alert} for spend={spend}, soft_budget={soft_budget}"

    # Any alert that did fire must be a soft-budget alert carrying a CallInfo.
    for alert_type, call_info in recorded_alerts:
        assert alert_type == "soft_budget"
        assert isinstance(call_info, CallInfo)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_virtual_key_max_budget_alert_check_with_user_obj():
    """Test _virtual_key_max_budget_alert_check includes user_email when user_obj is provided.

    The mocked ``budget_alerts`` only records its arguments; all assertions are
    made in the test body so that a mismatch actually fails the test.
    """
    recorded_alerts = []

    class MockProxyLogging:
        # Parameter names are kept as-is in case the caller passes them by keyword.
        async def budget_alerts(self, type, user_info):
            # Record only. The alert appears to be delivered in the background
            # (hence the sleep below); an AssertionError raised in here would
            # then be lost instead of failing the test.
            recorded_alerts.append((type, user_info))

    valid_token = UserAPIKeyAuth(
        token="test-token",
        spend=90.0,
        max_budget=100.0,
        user_id="test-user",
        team_id="test-team",
        team_alias="test-team-alias",
        org_id="test-org",
        key_alias="test-key",
        soft_budget=50.0,
    )

    user_obj = LiteLLM_UserTable(
        user_id="test-user",
        user_email="test@example.com",
        max_budget=None,
    )

    proxy_logging_obj = MockProxyLogging()

    await _virtual_key_max_budget_alert_check(
        valid_token=valid_token,
        proxy_logging_obj=proxy_logging_obj,
        user_obj=user_obj,
    )

    # Give the alert a moment to be delivered before inspecting it.
    await asyncio.sleep(0.1)

    assert recorded_alerts, "budget_alerts was never called"
    for alert_type, call_info in recorded_alerts:
        assert alert_type == "max_budget_alert"
        assert isinstance(call_info, CallInfo)

    # Inspect the most recent alert payload field by field.
    captured_call_info = recorded_alerts[-1][1]
    assert captured_call_info.user_email == "test@example.com"
    assert captured_call_info.token == "test-token"
    assert captured_call_info.spend == 90.0
    assert captured_call_info.max_budget == 100.0
    assert captured_call_info.soft_budget == 50.0
    assert captured_call_info.user_id == "test-user"
    assert captured_call_info.team_id == "test-team"
    assert captured_call_info.team_alias == "test-team-alias"
    assert captured_call_info.organization_id == "test-org"
    assert captured_call_info.key_alias == "test-key"
    assert captured_call_info.event_group == Litellm_EntityType.KEY
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_virtual_key_max_budget_alert_check_without_user_obj():
    """Test _virtual_key_max_budget_alert_check sets user_email to None when user_obj is not provided.

    The mocked ``budget_alerts`` only records its arguments; all assertions are
    made in the test body so that a mismatch actually fails the test.
    """
    recorded_alerts = []

    class MockProxyLogging:
        # Parameter names are kept as-is in case the caller passes them by keyword.
        async def budget_alerts(self, type, user_info):
            # Record only. The alert appears to be delivered in the background
            # (hence the sleep below); an AssertionError raised in here would
            # then be lost instead of failing the test.
            recorded_alerts.append((type, user_info))

    valid_token = UserAPIKeyAuth(
        token="test-token",
        spend=90.0,
        max_budget=100.0,
        user_id="test-user",
        team_id="test-team",
        key_alias="test-key",
    )

    proxy_logging_obj = MockProxyLogging()

    await _virtual_key_max_budget_alert_check(
        valid_token=valid_token,
        proxy_logging_obj=proxy_logging_obj,
        user_obj=None,
    )

    # Give the alert a moment to be delivered before inspecting it.
    await asyncio.sleep(0.1)

    assert recorded_alerts, "budget_alerts was never called"
    for alert_type, call_info in recorded_alerts:
        assert alert_type == "max_budget_alert"
        assert isinstance(call_info, CallInfo)

    # With no user object supplied, the alert carries no email address.
    captured_call_info = recorded_alerts[-1][1]
    assert captured_call_info.user_email is None
|
|
|
|
|
|
@pytest.mark.parametrize(
    "spend, max_budget, expect_alert",
    [
        (80.0, 100.0, True),  # At 80% threshold (alert threshold)
        (90.0, 100.0, True),  # Above threshold, below max_budget
        (79.0, 100.0, False),  # Below threshold
        (100.0, 100.0, False),  # At max_budget (not below, so no alert)
        (110.0, 100.0, False),  # Above max_budget (already exceeded)
        (100.0, None, False),  # No max_budget set
        (0.0, 100.0, False),  # Spend is 0
    ],
)
@pytest.mark.asyncio
async def test_virtual_key_max_budget_alert_check_scenarios(
    spend, max_budget, expect_alert
):
    """Test _virtual_key_max_budget_alert_check with various spend and max_budget scenarios.

    The mocked ``budget_alerts`` only records its arguments; the alert type and
    payload type are asserted in the test body so a mismatch fails the test.
    """
    recorded_alerts = []

    class MockProxyLogging:
        # Parameter names are kept as-is in case the caller passes them by keyword.
        async def budget_alerts(self, type, user_info):
            # Record only. The alert appears to be delivered in the background
            # (hence the sleep below); an AssertionError raised in here would
            # then be lost instead of failing the test.
            recorded_alerts.append((type, user_info))

    valid_token = UserAPIKeyAuth(
        token="test-token",
        spend=spend,
        max_budget=max_budget,
        user_id="test-user",
        key_alias="test-key",
    )

    proxy_logging_obj = MockProxyLogging()

    await _virtual_key_max_budget_alert_check(
        valid_token=valid_token,
        proxy_logging_obj=proxy_logging_obj,
        user_obj=None,
    )

    # Give the alert a moment to be delivered before inspecting it.
    await asyncio.sleep(0.1)

    alert_triggered = bool(recorded_alerts)
    assert (
        alert_triggered == expect_alert
    ), f"Expected alert_triggered to be {expect_alert} for spend={spend}, max_budget={max_budget}"

    # Any alert that did fire must be a max-budget alert carrying a CallInfo.
    for alert_type, call_info in recorded_alerts:
        assert alert_type == "max_budget_alert"
        assert isinstance(call_info, CallInfo)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_fuzzy_user_object_case_insensitive_email():
    """_get_fuzzy_user_object should look a user up by email without regard to case."""
    # Stored user whose email uses mixed case.
    stored_user = LiteLLM_UserTable(
        user_id="test_123",
        sso_user_id=None,
        user_email="Test@Example.com",
        organization_memberships=[],
        max_budget=None,
    )

    # Prisma stub: the unique lookup misses, the first-match lookup hits.
    user_table = MagicMock()
    user_table.find_unique = AsyncMock(return_value=None)
    user_table.find_first = AsyncMock(return_value=stored_user)
    prisma = MagicMock()
    prisma.db = MagicMock()
    prisma.db.litellm_usertable = user_table

    # Query with an all-lowercase address that differs from the stored casing.
    found = await _get_fuzzy_user_object(
        prisma_client=prisma,
        sso_user_id=None,
        user_email="test@example.com",
    )

    # The stored user is returned despite the casing difference.
    assert found == stored_user

    # The single find_first query must request insensitive matching.
    user_table.find_first.assert_called_once()
    query_kwargs = user_table.find_first.call_args.kwargs
    email_filter = query_kwargs["where"]["user_email"]
    assert email_filter["equals"] == "test@example.com"
    assert email_filter["mode"] == "insensitive"
    assert query_kwargs["include"] == {"organization_memberships": True}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_custom_auth_common_checks_opt_in():
    """
    _run_post_custom_auth_checks should call common_checks only when
    general_settings explicitly sets custom_auth_run_common_checks to True.

    When the flag is absent, common_checks is skipped. This default keeps
    backwards compatibility with custom auth flows that existed before
    PR #22164.
    """
    from litellm.proxy.auth.user_api_key_auth import _run_post_custom_auth_checks

    token = UserAPIKeyAuth(token="test-token")
    fake_request = MagicMock()

    async def run_with(settings, route):
        """Run the post-auth checks under patched settings; return the common_checks mock."""
        with patch(
            "litellm.proxy.auth.user_api_key_auth.common_checks",
            new_callable=AsyncMock,
        ) as common_checks_mock, patch(
            "litellm.proxy.proxy_server.general_settings",
            settings,
        ):
            common_checks_mock.return_value = True
            await _run_post_custom_auth_checks(
                valid_token=token,
                request=fake_request,
                request_data={},
                route=route,
                parent_otel_span=None,
            )
        return common_checks_mock

    # Default (no flag) — common_checks should NOT be called
    skipped = await run_with({}, "/ldap/ngs/ready")
    skipped.assert_not_called()

    # With flag=True — common_checks SHOULD be called
    invoked = await run_with(
        {"custom_auth_run_common_checks": True}, "/chat/completions"
    )
    invoked.assert_called_once()
|
|
|
|
|
|
# =====================================================================
|
|
# Spend counter budget check tests (v2 — Redis-backed spend counters)
|
|
# =====================================================================
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_virtual_key_budget_check_reads_from_spend_counter():
    """The key budget check must use the value from get_current_spend,
    even when the cached key object reports a lower spend."""
    from litellm.proxy.utils import ProxyLogging

    # Only the key's own counter has a value; other keys fall back.
    counter_values = {"spend:key:test-hashed-token": 1.5}

    async def mock_get_current_spend(counter_key, fallback_spend):
        return counter_values.get(counter_key, fallback_spend)

    # The cached object still shows 0.0 spend while the counter holds 1.5.
    key = UserAPIKeyAuth(
        token="test-hashed-token",
        spend=0.0,
        max_budget=1.0,
        user_id="test-user",
    )

    logging_obj = ProxyLogging(user_api_key_cache=None)
    logging_obj.budget_alerts = AsyncMock()

    with patch(
        "litellm.proxy.proxy_server.get_current_spend", mock_get_current_spend
    ), pytest.raises(litellm.BudgetExceededError) as exceeded:
        await _virtual_key_max_budget_check(
            valid_token=key,
            proxy_logging_obj=logging_obj,
        )

    # The error reports the counter value, not the stale cached spend.
    assert exceeded.value.current_cost == 1.5
    assert exceeded.value.max_budget == 1.0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_virtual_key_budget_check_fallback_no_counter():
    """Without a spend counter, the key budget check must fall back to
    the cached object's spend, passed in as fallback_spend."""
    from litellm.proxy.utils import ProxyLogging

    # Simulates a missing counter: the fallback value is always returned.
    async def mock_get_current_spend(counter_key, fallback_spend):
        return fallback_spend

    # The cached spend (15.0) already exceeds the budget (10.0).
    key = UserAPIKeyAuth(
        token="test-hashed-token",
        spend=15.0,
        max_budget=10.0,
        user_id="test-user",
    )

    logging_obj = ProxyLogging(user_api_key_cache=None)
    logging_obj.budget_alerts = AsyncMock()

    with patch(
        "litellm.proxy.proxy_server.get_current_spend", mock_get_current_spend
    ), pytest.raises(litellm.BudgetExceededError) as exceeded:
        await _virtual_key_max_budget_check(
            valid_token=key,
            proxy_logging_obj=logging_obj,
        )

    # The reported cost is the cached object's spend.
    assert exceeded.value.current_cost == 15.0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_team_budget_check_reads_from_spend_counter():
    """The team budget check must use the value from get_current_spend when a counter exists."""
    from litellm.proxy.utils import ProxyLogging

    # Only the team's counter has a value; other keys fall back.
    counter_values = {"spend:team:test-team": 1.5}

    async def mock_get_current_spend(counter_key, fallback_spend):
        return counter_values.get(counter_key, fallback_spend)

    # The team object still shows 0.0 spend while the counter holds 1.5.
    team = LiteLLM_TeamTable(
        team_id="test-team",
        spend=0.0,
        max_budget=1.0,
    )
    key = UserAPIKeyAuth(token="test-token", team_id="test-team")

    logging_obj = ProxyLogging(user_api_key_cache=None)
    logging_obj.budget_alerts = AsyncMock()

    with patch(
        "litellm.proxy.proxy_server.get_current_spend", mock_get_current_spend
    ), pytest.raises(litellm.BudgetExceededError) as exceeded:
        await _team_max_budget_check(
            team_object=team,
            valid_token=key,
            proxy_logging_obj=logging_obj,
        )

    # The error reports the counter value, not the stale team spend.
    assert exceeded.value.current_cost == 1.5
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_team_member_budget_check_reads_from_spend_counter():
    """The team member budget check must use the value from get_current_spend when a counter exists."""
    from litellm.proxy._types import LiteLLM_BudgetTable, LiteLLM_TeamMembership
    from litellm.proxy.utils import ProxyLogging

    # Only the member's counter has a value; other keys fall back.
    counter_values = {"spend:team_member:test-user:test-team": 1.5}

    async def mock_get_current_spend(counter_key, fallback_spend):
        return counter_values.get(counter_key, fallback_spend)

    team = LiteLLM_TeamTable(team_id="test-team")
    user = LiteLLM_UserTable(user_id="test-user")
    key = UserAPIKeyAuth(
        token="test-token",
        user_id="test-user",
        team_id="test-team",
    )

    # The membership still shows 0.0 spend against a 1.0 budget.
    membership = LiteLLM_TeamMembership(
        user_id="test-user",
        team_id="test-team",
        spend=0.0,
        litellm_budget_table=LiteLLM_BudgetTable(max_budget=1.0),
    )

    logging_obj = ProxyLogging(user_api_key_cache=None)

    with patch(
        "litellm.proxy.proxy_server.get_current_spend", mock_get_current_spend
    ), patch(
        "litellm.proxy.auth.auth_checks.get_team_membership",
        new_callable=AsyncMock,
        return_value=membership,
    ), pytest.raises(litellm.BudgetExceededError) as exceeded:
        await _check_team_member_budget(
            team_object=team,
            user_object=user,
            valid_token=key,
            prisma_client=MagicMock(),
            user_api_key_cache=MagicMock(),
            proxy_logging_obj=logging_obj,
        )

    # The error reports the counter value, not the stale membership spend.
    assert exceeded.value.current_cost == 1.5
|