mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-25 05:07:03 +00:00
7e49b4e2a0
* fix LiteLLM_ObjectPermissionTable * fix include object_permission for list key * fix key list to inclue obj permissions * fix object permissions for vector stores on key info * add key edit view with vector stores * allow editing vector stores permissions * fixes obj permissions * feat: add obj permission on UI * fix: add object_permission:true * ui show org vector stores on org info * fix: show object permissions on /org/info * feat: allow updating obj permissions for keys * fixes: key object permissions * fixes: team object permissions * fixes: org object permissions * fix vector store selector for Orgs * feat: add auth checks for vector store permissions * feat: working auth checks for vector store permissions * test vector stores auth checks * Update litellm/proxy/_types.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: linting --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
337 lines
11 KiB
Python
337 lines
11 KiB
Python
import asyncio
|
|
import json
|
|
import os
|
|
import sys
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
sys.path.insert(
|
|
0, os.path.abspath("../../..")
|
|
) # Adds the parent directory to the system path
|
|
|
|
from datetime import datetime, timedelta
|
|
|
|
import pytest
|
|
|
|
import litellm
|
|
from litellm.proxy._types import (
|
|
LiteLLM_ObjectPermissionTable,
|
|
LiteLLM_TeamTable,
|
|
LiteLLM_UserTable,
|
|
LitellmUserRoles,
|
|
ProxyErrorTypes,
|
|
ProxyException,
|
|
SSOUserDefinedValues,
|
|
UserAPIKeyAuth,
|
|
)
|
|
from litellm.proxy.auth.auth_checks import (
|
|
ExperimentalUIJWTToken,
|
|
_can_object_call_vector_stores,
|
|
get_user_object,
|
|
vector_store_access_check,
|
|
)
|
|
from litellm.proxy.common_utils.encrypt_decrypt_utils import decrypt_value_helper
|
|
from litellm.utils import get_utc_datetime
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def set_salt_key(monkeypatch):
    """Give every test in this module a fixed LITELLM_SALT_KEY.

    The variable is set through ``monkeypatch``, so pytest undoes the change
    once each test finishes.
    """
    salt_env_var, salt_value = "LITELLM_SALT_KEY", "sk-1234"
    monkeypatch.setenv(salt_env_var, salt_value)
|
|
|
|
|
|
@pytest.fixture
def valid_sso_user_defined_values():
    """Return a proxy-admin ``LiteLLM_UserTable`` for the happy-path token tests."""
    user_fields = dict(
        user_id="test_user",
        user_email="test@example.com",
        user_role=LitellmUserRoles.PROXY_ADMIN.value,
        models=["gpt-3.5-turbo"],
        max_budget=100.0,
    )
    return LiteLLM_UserTable(**user_fields)
|
|
|
|
|
|
@pytest.fixture
def invalid_sso_user_defined_values():
    """Return a ``LiteLLM_UserTable`` whose ``user_role`` is deliberately unset."""
    user_fields = dict(
        user_id="test_user",
        user_email="test@example.com",
        user_role=None,  # the missing role is what makes this user "invalid"
        models=["gpt-3.5-turbo"],
        max_budget=100.0,
    )
    return LiteLLM_UserTable(**user_fields)
|
|
|
|
|
|
def test_get_experimental_ui_login_jwt_auth_token_valid(valid_sso_user_defined_values):
    """A UI-login token for a user with a role decodes to that user's identity.

    The token payload must also carry an expiry that is in the future and no
    more than 10 minutes away.
    """
    encrypted = ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
        valid_sso_user_defined_values
    )

    # Decrypt the token so its JSON payload can be inspected.
    plaintext = decrypt_value_helper(encrypted, exception_type="debug")
    assert plaintext is not None  # guard: json.loads cannot parse None
    payload = json.loads(plaintext)

    expected_fields = {
        "user_id": "test_user",
        "user_role": LitellmUserRoles.PROXY_ADMIN.value,
        "models": ["gpt-3.5-turbo"],
        # The budget in the token is the UI session cap, not the user's own max_budget.
        "max_budget": litellm.max_ui_session_budget,
    }
    for field_name, expected_value in expected_fields.items():
        assert payload[field_name] == expected_value

    # Normalise a trailing "Z" so fromisoformat accepts the timestamp.
    assert "expires" in payload
    expires_at = datetime.fromisoformat(payload["expires"].replace("Z", "+00:00"))
    assert expires_at > get_utc_datetime()
    assert expires_at <= get_utc_datetime() + timedelta(minutes=10)
|
|
|
|
|
|
def test_get_experimental_ui_login_jwt_auth_token_invalid(
    invalid_sso_user_defined_values,
):
    """Minting a UI-login token for a user without a role raises an error."""
    expected_message = "User role is required for experimental UI login"

    with pytest.raises(Exception) as raised:
        ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
            invalid_sso_user_defined_values
        )

    assert str(raised.value) == expected_message
|
|
|
|
|
|
def test_get_key_object_from_ui_hash_key_valid(
    valid_sso_user_defined_values, monkeypatch
):
    """A freshly minted UI token decodes back into a matching key object."""
    # Turn on the experimental UI login flag for this test only
    # (presumably required by the decoder — flag is set before minting).
    monkeypatch.setenv("EXPERIMENTAL_UI_LOGIN", "True")

    ui_token = ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
        valid_sso_user_defined_values
    )
    decoded = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(ui_token)

    assert decoded is not None
    assert decoded.user_id == "test_user"
    # The decoded key exposes the enum member, not its string value.
    assert decoded.user_role == LitellmUserRoles.PROXY_ADMIN
    assert decoded.models == ["gpt-3.5-turbo"]
    assert decoded.max_budget == litellm.max_ui_session_budget
|
|
|
|
|
|
def test_get_key_object_from_ui_hash_key_invalid():
    """Decoding a token that is not a valid UI hash key yields ``None``."""
    bogus_token = "invalid_token"
    assert ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(bogus_token) is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_default_internal_user_params_with_get_user_object(monkeypatch):
    """Upserting an unknown user applies ``litellm.default_internal_user_params``.

    The test inspects the ``data`` passed to ``litellm_usertable.create``. It does
    not inspect the value ``get_user_object`` returns.
    """
    monkeypatch.setattr(
        litellm,
        "default_internal_user_params",
        {
            "models": ["gpt-4", "claude-3-opus"],
            "max_budget": 200.0,
            "user_role": "internal_user",
        },
    )

    # Prisma stub: lookup finds no user, so get_user_object goes down the create path.
    prisma_stub = MagicMock()
    prisma_stub.db = AsyncMock()
    user_table = prisma_stub.db.litellm_usertable

    # Row handed back by create(); it is a MagicMock rather than a real
    # LiteLLM_UserTable, with a .dict() that returns plain data.
    created_row = MagicMock()
    created_row.user_id = "new_test_user"
    created_row.models = ["gpt-4", "claude-3-opus"]
    created_row.max_budget = 200.0
    created_row.user_role = "internal_user"
    created_row.organization_memberships = []
    created_row.dict = lambda: {
        "user_id": "new_test_user",
        "models": ["gpt-4", "claude-3-opus"],
        "max_budget": 200.0,
        "user_role": "internal_user",
        "organization_memberships": [],
    }

    user_table.find_unique = AsyncMock(return_value=None)
    user_table.create = AsyncMock(return_value=created_row)

    # Cache stub: always a miss, writes are accepted and ignored.
    cache_stub = MagicMock()
    cache_stub.async_get_cache = AsyncMock(return_value=None)
    cache_stub.async_set_cache = AsyncMock()

    # user_id_upsert=True is what triggers the create() call.
    try:
        await get_user_object(
            user_id="new_test_user",
            prisma_client=prisma_stub,
            user_api_key_cache=cache_stub,
            user_id_upsert=True,
            proxy_logging_obj=None,
        )
    except Exception as err:
        # Tolerated on purpose: the MagicMock row is not a LiteLLM_UserTable.
        # The assertions below only need create() to have been called.
        print(err)

    user_table.create.assert_called_once()
    created_with = user_table.create.call_args.kwargs["data"]

    assert "models" in created_with
    assert created_with["models"] == ["gpt-4", "claude-3-opus"]
    assert created_with["max_budget"] == 200.0
    assert created_with["user_role"] == "internal_user"
|
|
|
|
|
|
# Vector Store Auth Check Tests
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "prisma_client,vector_store_registry,expected_result",
    [
        (None, MagicMock(), True),  # No prisma client
        (MagicMock(), None, True),  # No vector store registry
        (MagicMock(), MagicMock(), True),  # No vector stores to run
    ],
)
async def test_vector_store_access_check_early_returns(
    prisma_client, vector_store_registry, expected_result
):
    """Test that vector_store_access_check returns True for early-exit conditions.

    Each parametrized case removes one prerequisite: the prisma client, the
    vector store registry, or the list of vector stores to run.
    """
    request_body = {"messages": [{"role": "user", "content": "test"}]}

    if vector_store_registry is not None:
        if prisma_client is None:
            # Report a store to run so the "no prisma client" case does not also
            # hit the "no vector stores to run" exit. Otherwise it would still
            # pass if the prisma-client guard were removed.
            vector_store_registry.get_vector_store_ids_to_run.return_value = [
                "store-1"
            ]
        else:
            vector_store_registry.get_vector_store_ids_to_run.return_value = None

    with patch("litellm.proxy.proxy_server.prisma_client", prisma_client), patch(
        "litellm.vector_store_registry", vector_store_registry
    ):
        result = await vector_store_access_check(
            request_body=request_body,
            team_object=None,
            valid_token=None,
        )

    assert result == expected_result
|
|
|
|
|
|
@pytest.mark.parametrize(
    "object_permissions,vector_store_ids,should_raise,error_type",
    [
        (None, ["store-1"], False, None),  # None permissions - should pass
        (
            {"vector_stores": []},
            ["store-1"],
            False,
            None,
        ),  # Empty vector_stores - should pass (access to all)
        (
            {"vector_stores": ["store-1", "store-2"]},
            ["store-1"],
            False,
            None,
        ),  # Has access
        (
            {"vector_stores": ["store-1", "store-2"]},
            ["store-3"],
            True,
            ProxyErrorTypes.key_vector_store_access_denied,
        ),  # No access
        (
            {"vector_stores": ["store-1"]},
            ["store-1", "store-3"],
            True,
            ProxyErrorTypes.team_vector_store_access_denied,
        ),  # Partial access
    ],
)
def test_can_object_call_vector_stores_scenarios(
    object_permissions, vector_store_ids, should_raise, error_type
):
    """Test _can_object_call_vector_stores across permission scenarios.

    Allowed cases must return True. Denied cases must raise ProxyException
    with the expected error type.
    """
    # Turn the dict spec into an object with a ``vector_stores`` attribute;
    # None is passed through unchanged.
    permissions_obj = None
    if object_permissions is not None:
        permissions_obj = MagicMock()
        permissions_obj.vector_stores = object_permissions["vector_stores"]

    # Only the key-denial case is checked as a "key"; every other case as a "team".
    if error_type == ProxyErrorTypes.key_vector_store_access_denied:
        object_type = "key"
    else:
        object_type = "team"

    call_kwargs = dict(
        object_type=object_type,
        vector_store_ids_to_run=vector_store_ids,
        object_permissions=permissions_obj,
    )

    if not should_raise:
        assert _can_object_call_vector_stores(**call_kwargs) is True
        return

    with pytest.raises(ProxyException) as exc_info:
        _can_object_call_vector_stores(**call_kwargs)
    assert exc_info.value.type == error_type
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_vector_store_access_check_with_permissions():
    """Test vector_store_access_check against a key's object permissions.

    The key's permission row (fetched via the mocked prisma client) allows
    store-1 and store-2. A request for store-1 passes. A request for store-3
    is denied with key_vector_store_access_denied.
    """
    request_body = {"tools": [{"type": "function", "function": {"name": "test"}}]}

    key_auth = UserAPIKeyAuth(
        token="test-token",
        object_permission_id="perm-123",
        models=["gpt-4"],
        max_budget=100.0,
    )

    key_permissions = MagicMock()
    key_permissions.vector_stores = ["store-1", "store-2"]

    prisma_stub = MagicMock()
    prisma_stub.db.litellm_objectpermissiontable.find_unique = AsyncMock(
        return_value=key_permissions
    )

    registry_stub = MagicMock()

    async def _run_check():
        # Patch the module-level prisma client and registry for one call only.
        with patch("litellm.proxy.proxy_server.prisma_client", prisma_stub), patch(
            "litellm.vector_store_registry", registry_stub
        ):
            return await vector_store_access_check(
                request_body=request_body,
                team_object=None,
                valid_token=key_auth,
            )

    # Allowed: the only store to run is in the key's permission list.
    registry_stub.get_vector_store_ids_to_run.return_value = ["store-1"]
    assert await _run_check() is True

    # Denied: store-3 is not in the key's permission list.
    registry_stub.get_vector_store_ids_to_run.return_value = ["store-3"]
    with pytest.raises(ProxyException) as exc_info:
        await _run_check()
    assert exc_info.value.type == ProxyErrorTypes.key_vector_store_access_denied
|