mirror of
https://github.com/tiennm99/litellm.git
synced 2026-07-01 03:04:51 +00:00
9918a9c78c
* fix(guardrails): restore disable_global_guardrails persistence for keys The per-key/team "Disable Global Guardrails" toggle silently stopped working after #17042, which removed `disable_global_guardrails` from the key/team request models and from the premium metadata allowlist. Without those, the UI's top-level field was dropped by pydantic and never folded into key `metadata`, so the runtime gate always read False and global default_on guardrails kept running. Restore the request-model fields (KeyRequestBase, NewTeamRequest, UpdateTeamRequest) and the `LiteLLM_ManagementEndpoint_MetadataFields_Premium` entry so the flag is promoted into metadata again. Because the key edit form always submits the flag (false by default), guard the UI so it is only sent when it actually changed (edit) or is enabled (create) — this keeps the premium gate on enabling intact while not 403-ing non-premium users who edit unrelated key fields, mirroring how guardrails/tags are already stripped. * test(guardrails): cover disable_global_guardrails toggle-off + clarify premium field comment Add a prepare_metadata_fields case asserting `disable_global_guardrails: False` overwrites an existing `True`, and rewrite the PREMIUM_METADATA_FIELDS comment to explain why boolean premium fields are excluded from the empty-value strip loop.
1346 lines
43 KiB
Python
1346 lines
43 KiB
Python
import os
|
|
import sys
|
|
import traceback
|
|
from litellm._uuid import uuid
|
|
import datetime as dt
|
|
from datetime import datetime
|
|
|
|
from dotenv import load_dotenv
|
|
from fastapi import Request
|
|
from fastapi.routing import APIRoute
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
# Load variables from a local .env file (if any) before the litellm imports below.
load_dotenv()
|
|
import io
|
|
import os
|
|
import time
|
|
|
|
# This file tests litellm/proxy.
|
|
|
|
sys.path.insert(
    0, os.path.abspath("../..")
)  # Puts the directory two levels above the current working directory first on the import path
|
|
import asyncio
|
|
import logging
|
|
|
|
import pytest
|
|
import litellm
|
|
from litellm._logging import verbose_proxy_logger
|
|
from litellm.proxy.management_endpoints.team_endpoints import list_team
|
|
from litellm.proxy._types import *
|
|
from litellm.proxy.management_endpoints.internal_user_endpoints import (
|
|
new_user,
|
|
user_info,
|
|
user_update,
|
|
get_users,
|
|
)
|
|
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
|
delete_key_fn,
|
|
generate_key_fn,
|
|
generate_key_helper_fn,
|
|
info_key_fn,
|
|
regenerate_key_fn,
|
|
update_key_fn,
|
|
)
|
|
from litellm.proxy.management_endpoints.team_endpoints import (
|
|
new_team,
|
|
team_info,
|
|
update_team,
|
|
)
|
|
from litellm.proxy.proxy_server import (
|
|
LitellmUserRoles,
|
|
audio_transcriptions,
|
|
chat_completion,
|
|
completion,
|
|
embeddings,
|
|
model_list,
|
|
moderations,
|
|
user_api_key_auth,
|
|
)
|
|
from litellm.proxy.management_endpoints.customer_endpoints import (
|
|
new_end_user,
|
|
)
|
|
from litellm.proxy.spend_tracking.spend_management_endpoints import (
|
|
global_spend,
|
|
global_spend_logs,
|
|
global_spend_models,
|
|
global_spend_keys,
|
|
spend_key_fn,
|
|
spend_user_fn,
|
|
view_spend_logs,
|
|
)
|
|
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend
|
|
|
|
# Run the proxy logger at DEBUG level for every test in this module.
verbose_proxy_logger.setLevel(level=logging.DEBUG)
|
|
|
|
from starlette.datastructures import URL
|
|
|
|
from litellm.caching.caching import DualCache
|
|
from litellm.proxy._types import (
|
|
DynamoDBArgs,
|
|
GenerateKeyRequest,
|
|
KeyRequest,
|
|
NewCustomerRequest,
|
|
NewTeamRequest,
|
|
NewUserRequest,
|
|
ProxyErrorTypes,
|
|
ProxyException,
|
|
UpdateKeyRequest,
|
|
RegenerateKeyRequest,
|
|
UpdateTeamRequest,
|
|
UpdateUserRequest,
|
|
UserAPIKeyAuth,
|
|
)
|
|
from litellm.types.proxy.management_endpoints.ui_sso import (
|
|
LiteLLM_UpperboundKeyGenerateParams,
|
|
)
|
|
|
|
# Module-level ProxyLogging instance (with its own DualCache) that the
# prisma_client fixture passes to PrismaClient.
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
|
|
|
|
|
|
@pytest.fixture
def prisma_client():
    """Build the ``PrismaClient`` used by the DB-backed tests.

    Connection-pool settings are appended to ``DATABASE_URL`` and the result
    is written back to the environment. Before the client is returned, the
    proxy server's budget name is set to a time-stamped value and its custom
    key-generate hook is cleared. The client is not connected here; each
    test calls ``connect()`` itself.
    """
    from litellm.proxy.proxy_cli import append_query_params

    # Add connection pool + pool timeout args to the configured URL.
    pool_params = {"connection_limit": 100, "pool_timeout": 60}
    os.environ["DATABASE_URL"] = append_query_params(
        os.getenv("DATABASE_URL"), pool_params
    )

    client = PrismaClient(
        database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
    )

    # Proxy-server globals the tests rely on: a fresh budget name and no
    # custom key-generation hook.
    litellm.proxy.proxy_server.litellm_proxy_budget_name = (
        f"litellm-proxy-budget-{time.time()}"
    )
    litellm.proxy.proxy_server.user_custom_key_generate = None

    return client
|
|
|
|
|
|
################ Unit Tests for testing regeneration of keys ###########
|
|
@pytest.mark.asyncio()
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_regenerate_api_key(prisma_client):
    """Regenerating a key yields a working replacement and rejects the original.

    Steps: create a key, authenticate with it, regenerate it, authenticate
    with the replacement, confirm the original key is refused, then check
    spend, max_budget, key_alias, models and key_name on the regenerated key.
    """
    litellm.set_verbose = True
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    def admin_auth():
        # A fresh proxy-admin identity for every endpoint call.
        return UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            api_key="sk-1234",
            user_id="1234",
        )

    def chat_completions_request():
        # Fake HTTP request aimed at /chat/completions with a JSON body.
        async def read_body():
            payload = f'{{"model": "fake-openai-endpoint"}}'
            # the body is handed back as bytes
            return payload.encode()

        fake_request = Request(scope={"type": "http"})
        fake_request._url = URL(url="/chat/completions")
        fake_request.body = read_body
        return fake_request

    # Create the key that will later be regenerated.
    key_alias = f"test_alias_regenerate_key-{uuid.uuid4()}"
    spend = 100
    max_budget = 400
    models = ["fake-openai-endpoint"]
    created = await generate_key_fn(
        data=GenerateKeyRequest(
            key_alias=key_alias, spend=spend, max_budget=max_budget, models=models
        ),
        user_api_key_dict=admin_auth(),
    )

    generated_key = created.key
    print(generated_key)

    # The freshly generated key must authenticate.
    first_request = chat_completions_request()
    auth_result = await user_api_key_auth(
        request=first_request, api_key=f"Bearer {generated_key}"
    )
    print(auth_result)

    # Regenerate it.
    print("regenerating key: {}".format(generated_key))
    new_key = await regenerate_key_fn(
        key=generated_key,
        user_api_key_dict=admin_auth(),
    )
    print("response from regenerate_key_fn", new_key)

    # The replacement key must authenticate.
    second_request = chat_completions_request()
    auth_result = await user_api_key_auth(
        request=second_request, api_key=f"Bearer {new_key.key}"
    )
    print(auth_result)

    # The original key must now be refused.
    stale_request = chat_completions_request()
    try:
        auth_result = await user_api_key_auth(
            request=stale_request, api_key=f"Bearer {generated_key}"
        )
        print(auth_result)
        pytest.fail(f"This should have failed!. the key has been regenerated")
    except Exception as e:
        print("got expected exception", e)
        assert "Invalid proxy server token passed" in e.message

    # The regenerated key keeps the original spend, max_budget, models and key_alias.
    assert new_key.spend == spend, f"Expected spend {spend} but got {new_key.spend}"
    assert (
        new_key.max_budget == max_budget
    ), f"Expected max_budget {max_budget} but got {new_key.max_budget}"
    assert (
        new_key.key_alias == key_alias
    ), f"Expected key_alias {key_alias} but got {new_key.key_alias}"
    assert (
        new_key.models == models
    ), f"Expected models {models} but got {new_key.models}"

    # key_name is the masked form built from the last four characters of the key.
    assert new_key.key_name == f"sk-...{new_key.key[-4:]}"
|
|
|
|
|
|
@pytest.mark.asyncio()
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_regenerate_api_key_with_new_alias_and_expiration(prisma_client):
    """Regenerating with a new alias and a 30d duration updates both fields.

    The regenerated key must carry the alias "very_new_alias" and an
    ``expires`` timestamp between 29 and 31 days from now (UTC).
    """
    litellm.set_verbose = True
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()
    from litellm._uuid import uuid

    def admin_auth():
        # A fresh proxy-admin identity for every endpoint call.
        return UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            api_key="sk-1234",
            user_id="1234",
        )

    # Create the key that will be regenerated.
    key_alias = f"test_alias_regenerate_key-{uuid.uuid4()}"
    spend = 100
    max_budget = 400
    models = ["fake-openai-endpoint"]
    created = await generate_key_fn(
        data=GenerateKeyRequest(
            key_alias=key_alias, spend=spend, max_budget=max_budget, models=models
        ),
        user_api_key_dict=admin_auth(),
    )

    generated_key = created.key
    print(generated_key)

    # Regenerate, overriding the alias and asking for a 30-day lifetime.
    overrides = RegenerateKeyRequest(
        key_alias="very_new_alias",
        duration="30d",
    )
    new_key = await regenerate_key_fn(
        key=generated_key,
        data=overrides,
        user_api_key_dict=admin_auth(),
    )
    print("response from regenerate_key_fn", new_key)

    # The alias was replaced.
    assert new_key.key_alias == "very_new_alias"

    # Expiry lands roughly 30 days out; a one-day margin on either side
    # absorbs the time spent running the test.
    now = datetime.now(dt.timezone.utc)
    earliest = now + dt.timedelta(days=29)
    assert new_key.expires > earliest
    latest = now + dt.timedelta(days=31)
    assert new_key.expires < latest
|
|
|
|
|
|
@pytest.mark.asyncio()
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_regenerate_key_ui(prisma_client):
    """Regenerating a key with an empty-string duration completes without raising.

    The test makes no assertions on the regenerated key; it passes as long
    as key creation, authentication and regeneration do not raise.
    NOTE(review): the empty duration presumably mirrors what the UI submits -
    confirm against the UI payload.
    """
    litellm.set_verbose = True
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()
    from litellm._uuid import uuid

    def admin_auth():
        # A fresh proxy-admin identity for every endpoint call.
        return UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            api_key="sk-1234",
            user_id="1234",
        )

    # Create the key that will be regenerated.
    key_alias = f"test_alias_regenerate_key-{uuid.uuid4()}"
    spend = 100
    max_budget = 400
    models = ["fake-openai-endpoint"]
    created = await generate_key_fn(
        data=GenerateKeyRequest(
            key_alias=key_alias, spend=spend, max_budget=max_budget, models=models
        ),
        user_api_key_dict=admin_auth(),
    )

    generated_key = created.key
    print(generated_key)

    # The freshly generated key must authenticate against /chat/completions.
    async def read_body():
        payload = f'{{"model": "fake-openai-endpoint"}}'
        # the body is handed back as bytes
        return payload.encode()

    fake_request = Request(scope={"type": "http"})
    fake_request._url = URL(url="/chat/completions")
    fake_request.body = read_body
    auth_result = await user_api_key_auth(
        request=fake_request, api_key=f"Bearer {generated_key}"
    )
    print(auth_result)

    # Regenerate with an empty duration string.
    new_key = await regenerate_key_fn(
        key=generated_key,
        data=RegenerateKeyRequest(duration=""),
        user_api_key_dict=admin_auth(),
    )
    print("response from regenerate_key_fn", new_key)
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_get_users(prisma_client):
    """
    Tests /users/list endpoint

    Creates five users (alternating internal-user and proxy-admin roles),
    lists users without filters, and checks that the response has a "users"
    entry whose items are all ``LiteLLM_UserTable`` instances.
    """
    litellm.set_verbose = True
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    # Build the user requests: even indices are internal users, odd are admins.
    test_users = []
    for i in range(5):
        generated_id = f"test_user_{i}_{uuid.uuid4()}"
        if i % 2 == 0:
            role_value = LitellmUserRoles.INTERNAL_USER.value
        else:
            role_value = LitellmUserRoles.PROXY_ADMIN.value
        test_users.append(NewUserRequest(user_id=generated_id, user_role=role_value))

    # Register each user through the endpoint as a proxy admin.
    for pending_user in test_users:
        await new_user(
            pending_user,
            UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="admin",
            ),
        )

    # List users with no role filter.
    result = await get_users(
        role=None,
        page=1,
        page_size=20,
    )
    print("get users result", result)
    assert "users" in result

    for listed_user in result["users"]:
        assert isinstance(listed_user, LiteLLM_UserTable)

    # Remove the users created above.
    for created_user in test_users:
        await prisma_client.db.litellm_usertable.delete(
            where={"user_id": created_user.user_id}
        )
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_get_users_filters_dashboard_keys(prisma_client):
    """
    Tests that /users/list does not count keys with team_id='litellm-dashboard'

    Three keys are created for one user: one with team_id="litellm-dashboard",
    one with team_id="NEW_TEAM" and one with no team. The user's ``key_count``
    must be 2, i.e. the dashboard key is left out.
    """
    litellm.set_verbose = True
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    # Create a test user without an automatically created key.
    new_user_id = f"test_user_with_keys-{uuid.uuid4()}"
    test_user = NewUserRequest(
        user_id=new_user_id,
        user_role=LitellmUserRoles.INTERNAL_USER.value,
        auto_create_key=False,
    )

    await new_user(
        test_user,
        UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            api_key="sk-1234",
            user_id="admin",
        ),
    )

    # Create three keys for the user, in this order:
    #   "litellm-dashboard" -> UI dashboard key, must NOT be counted
    #   "NEW_TEAM"          -> regular team key, counted
    #   None                -> regular personal key, counted
    for key_team_id in ("litellm-dashboard", "NEW_TEAM", None):
        await generate_key_helper_fn(
            user_id=test_user.user_id,
            request_type="key",
            team_id=key_team_id,
            models=[],
            aliases={},
            config={},
            spend=0,
            duration=None,
        )

    # Test get_users for the specific user
    result = await get_users(
        user_ids=test_user.user_id,
        role=None,
        page=1,
        page_size=20,
    )

    print("get users result", result)
    assert "users" in result
    assert len(result["users"]) == 1

    # Verify the key count: 2 regular keys, the dashboard key is not counted.
    user = result["users"][0]
    assert user.user_id == test_user.user_id
    assert user.key_count == 2  # Only count the regular keys, not the UI dashboard key

    # Clean up the test user row.
    # NOTE(review): the generated keys are not deleted explicitly here -
    # confirm whether they are removed together with the user.
    await prisma_client.db.litellm_usertable.delete(
        where={"user_id": test_user.user_id}
    )
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_get_users_key_count(prisma_client):
    """
    Verifies that key_count reported by get_users goes from 0 to 1 after one
    key is generated for a user created with auto_create_key=False.
    """
    litellm.set_verbose = True
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    def admin_auth():
        # A fresh proxy-admin identity for every endpoint call.
        return UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            api_key="sk-1234",
            user_id="admin",
        )

    async def lookup_test_user():
        # Query the user list restricted to the test user's id.
        return await get_users(
            user_ids=test_user_id,
            role=None,
            page=1,
            page_size=20,
        )

    # Start from a user with no keys so the counts are deterministic.
    test_user_id = f"test_user_key_count-{uuid.uuid4()}"
    test_user_request = NewUserRequest(
        user_id=test_user_id,
        user_role=LitellmUserRoles.INTERNAL_USER.value,
        auto_create_key=False,  # Ensure we start with 0 keys
    )
    await new_user(test_user_request, admin_auth())

    # Baseline: the user exists and owns zero keys.
    initial_users = await lookup_test_user()
    print("initial_users", initial_users)
    assert len(initial_users["users"]) == 1, "Test user should be found"
    test_user = initial_users["users"][0]
    assert test_user.user_id == test_user_id
    initial_key_count = test_user.key_count
    assert (
        initial_key_count == 0
    ), f"Expected initial key count to be 0, but got {initial_key_count}"

    # Add one key for the user.
    await generate_key_fn(
        data=GenerateKeyRequest(
            user_id=test_user_id,
            key_alias=f"test_key_{uuid.uuid4()}",
            models=["fake-model"],
        ),
        user_api_key_dict=admin_auth(),
    )

    # The count must have grown by exactly one.
    updated_users = await lookup_test_user()
    print("updated_users", updated_users)
    assert len(updated_users["users"]) == 1, "Test user should still be found"
    updated_key_count = updated_users["users"][0].key_count

    assert (
        updated_key_count == initial_key_count + 1
    ), f"Expected key count to increase by 1, but got {updated_key_count} (was {initial_key_count})"

    # Remove the test user row.
    await prisma_client.db.litellm_usertable.delete(where={"user_id": test_user_id})
|
|
|
|
|
|
async def cleanup_existing_teams(prisma_client):
    """Delete every team returned by ``litellm_teamtable.find_many()``.

    Teams are removed one at a time, each through its own
    ``prisma_client.delete_data(..., table_name="team")`` call.
    """
    existing_teams = await prisma_client.db.litellm_teamtable.find_many()
    for row in existing_teams:
        single_id = [row.team_id]
        await prisma_client.delete_data(team_id_list=single_id, table_name="team")
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_list_teams(prisma_client):
    """
    Tests /team/list endpoint to verify it returns both keys and members_with_roles

    All existing teams are deleted first. A team with two members and one
    key is then created and looked up in the ``list_team`` response.
    """
    litellm.set_verbose = True
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    def admin_auth():
        # A fresh proxy-admin identity for every endpoint call.
        return UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="admin"
        )

    # Start from an empty team table.
    await cleanup_existing_teams(prisma_client)

    # Create the team under test with an admin member and a user member.
    team_id = f"test_team_{uuid.uuid4()}"
    team_alias = f"test_team_alias_{uuid.uuid4()}"
    team_request = NewTeamRequest(
        team_id=team_id,
        team_alias=team_alias,
        members_with_roles=[
            Member(role="admin", user_id="test_user_1"),
            Member(role="user", user_id="test_user_2"),
        ],
        models=["gpt-4"],
        tpm_limit=1000,
        rpm_limit=1000,
        budget_duration="30d",
        max_budget=1000,
    )
    await new_team(
        data=team_request,
        http_request=Request(scope={"type": "http"}),
        user_api_key_dict=admin_auth(),
    )

    # Attach one key to the team.
    await generate_key_fn(
        data=GenerateKeyRequest(
            team_id=team_id,
            key_alias=f"test_key_{uuid.uuid4()}",
        ),
        user_api_key_dict=admin_auth(),
    )

    # Fetch the team list as the admin.
    teams = await list_team(
        http_request=Request(scope={"type": "http"}),
        user_api_key_dict=admin_auth(),
        user_id=None,
    )

    print("teams", teams)

    # Pick our team out of the response (None when it is missing).
    test_team_response = next(
        (candidate for candidate in teams if candidate.team_id == team_id), None
    )

    assert (
        test_team_response is not None
    ), f"Could not find test team {team_id} in response"

    # members_with_roles: the two requested members plus the creating admin.
    assert (
        len(test_team_response.members_with_roles) == 3
    ), "Expected 3 members in team"  # 2 members + 1 team admin
    member_roles = set()
    for member in test_team_response.members_with_roles:
        member_roles.add(member.role)
    assert "admin" in member_roles, "Expected admin role in members"
    assert "user" in member_roles, "Expected user role in members"

    # Scalar fields echo what was requested at creation time.
    assert (
        test_team_response.team_id == team_id
    ), f"team_id should be expected value {team_id}"
    assert (
        test_team_response.team_alias == team_alias
    ), f"team_alias should be expected value {team_alias}"
    assert test_team_response.spend is not None, "spend should not be None"
    assert (
        test_team_response.max_budget == 1000
    ), f"max_budget should be expected value 1000"
    assert test_team_response.models == [
        "gpt-4"
    ], f"models should be expected value ['gpt-4']"
    assert (
        test_team_response.tpm_limit == 1000
    ), f"tpm_limit should be expected value 1000"
    assert (
        test_team_response.rpm_limit == 1000
    ), f"rpm_limit should be expected value 1000"
    assert (
        test_team_response.budget_reset_at is not None
    ), "budget_reset_at should not be None since budget_duration is 30d"

    # The team's key is part of the response.
    assert len(test_team_response.keys) > 0, "Expected at least one key for team"
    assert any(
        team_key.team_id == team_id for team_key in test_team_response.keys
    ), "Expected to find team key in response"

    # Remove the team created by this test.
    await prisma_client.delete_data(team_id_list=[team_id], table_name="team")
|
|
|
|
|
|
def test_is_team_key():
    """``_is_team_key`` is truthy for a request with a team_id and falsy for one with only a user_id."""
    from litellm.proxy.management_endpoints.key_management_endpoints import _is_team_key

    team_scoped_request = GenerateKeyRequest(team_id="test_team_id")
    assert _is_team_key(team_scoped_request)

    user_scoped_request = GenerateKeyRequest(user_id="test_user_id")
    assert not _is_team_key(user_scoped_request)
|
|
|
|
|
|
def test_team_key_generation_team_member_check():
    """Team key generation honours ``allowed_team_member_roles``.

    With only the "admin" team-member role allowed, the check passes for a
    caller whose team membership is "admin" and raises ``HTTPException`` for
    a caller whose membership is "user".

    ``litellm.key_generation_settings`` is a process-wide global, so the
    value found on entry is put back when the test finishes (pass or fail)
    to avoid leaking this configuration into other tests.
    """
    from litellm.proxy.management_endpoints.key_management_endpoints import (
        _team_key_generation_check,
    )
    from fastapi import HTTPException
    from litellm.proxy._types import LiteLLM_TeamTableCachedObj

    previous_settings = getattr(litellm, "key_generation_settings", None)
    litellm.key_generation_settings = {
        "team_key_generation": {"allowed_team_member_roles": ["admin"]}
    }
    try:
        # Caller is an admin member of the team -> allowed.
        team_table = LiteLLM_TeamTableCachedObj(
            team_id="test_team_id",
            team_alias="test_team_alias",
            members_with_roles=[Member(role="admin", user_id="test_user_id")],
        )

        assert _team_key_generation_check(
            team_table=team_table,
            user_api_key_dict=UserAPIKeyAuth(
                user_id="test_user_id",
                user_role=LitellmUserRoles.INTERNAL_USER,
                api_key="sk-1234",
                team_member=Member(role="admin", user_id="test_user_id"),
            ),
            data=GenerateKeyRequest(),
            route=KeyManagementRoutes.KEY_GENERATE,
        )

        # Caller is only a "user" member of the team -> rejected.
        team_table = LiteLLM_TeamTableCachedObj(
            team_id="test_team_id",
            team_alias="test_team_alias",
            members_with_roles=[Member(role="user", user_id="test_user_id")],
        )

        with pytest.raises(HTTPException):
            _team_key_generation_check(
                team_table=team_table,
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.INTERNAL_USER,
                    api_key="sk-1234",
                    user_id="test_user_id",
                    team_member=Member(role="user", user_id="test_user_id"),
                ),
                data=GenerateKeyRequest(),
                route=KeyManagementRoutes.KEY_GENERATE,
            )
    finally:
        # Restore the global so later tests see the settings they expect.
        litellm.key_generation_settings = previous_settings
|
|
|
|
|
|
@pytest.mark.parametrize(
    "team_key_generation_settings, input_data, expected_result",
    [
        ({"required_params": ["tags"]}, GenerateKeyRequest(tags=["test_tags"]), True),
        ({}, GenerateKeyRequest(), True),
        (
            {"required_params": ["models"]},
            GenerateKeyRequest(tags=["test_tags"]),
            False,
        ),
    ],
)
@pytest.mark.parametrize("key_type", ["team_key", "personal_key"])
def test_key_generation_required_params_check(
    team_key_generation_settings, input_data, expected_result, key_type
):
    """Key generation enforces ``required_params`` for team and personal keys.

    Args:
        team_key_generation_settings: kwargs for the team or personal
            key-generation config (the same dict is used for both key types).
        input_data: the ``GenerateKeyRequest`` being checked.
        expected_result: True when the check should pass, False when it
            should raise ``HTTPException``.
        key_type: "team_key" or "personal_key"; selects which config is
            installed and which check function is called.

    ``litellm.key_generation_settings`` is a process-wide global, so the
    value found on entry is put back when the test finishes (pass or fail).
    """
    from litellm.proxy.management_endpoints.key_management_endpoints import (
        _team_key_generation_check,
        _personal_key_generation_check,
    )
    from litellm.types.utils import (
        TeamUIKeyGenerationConfig,
        StandardKeyGenerationConfig,
        PersonalUIKeyGenerationConfig,
    )
    from litellm.proxy._types import LiteLLM_TeamTableCachedObj
    from fastapi import HTTPException

    user_api_key_dict = UserAPIKeyAuth(
        user_role=LitellmUserRoles.INTERNAL_USER,
        api_key="sk-1234",
        user_id="test_user_id",
        team_id="test_team_id",
        team_member=None,
    )

    team_table = LiteLLM_TeamTableCachedObj(
        team_id="test_team_id",
        team_alias="test_team_alias",
        members_with_roles=[Member(role="admin", user_id="test_user_id")],
    )

    previous_settings = getattr(litellm, "key_generation_settings", None)
    try:
        # Install the config matching the key type under test.
        if key_type == "team_key":
            litellm.key_generation_settings = StandardKeyGenerationConfig(
                team_key_generation=TeamUIKeyGenerationConfig(
                    **team_key_generation_settings
                )
            )
        elif key_type == "personal_key":
            litellm.key_generation_settings = StandardKeyGenerationConfig(
                personal_key_generation=PersonalUIKeyGenerationConfig(
                    **team_key_generation_settings
                )
            )

        if expected_result:
            # Required params are satisfied -> the check returns truthy.
            if key_type == "team_key":
                assert _team_key_generation_check(
                    team_table=team_table,
                    user_api_key_dict=user_api_key_dict,
                    data=input_data,
                    route=KeyManagementRoutes.KEY_GENERATE,
                )
            elif key_type == "personal_key":
                assert _personal_key_generation_check(
                    user_api_key_dict=user_api_key_dict,
                    data=input_data,
                )
        else:
            # A required param is missing -> the check raises.
            if key_type == "team_key":
                with pytest.raises(HTTPException):
                    _team_key_generation_check(
                        team_table=team_table,
                        user_api_key_dict=user_api_key_dict,
                        data=input_data,
                        route=KeyManagementRoutes.KEY_GENERATE,
                    )
            elif key_type == "personal_key":
                with pytest.raises(HTTPException):
                    _personal_key_generation_check(user_api_key_dict, input_data)
    finally:
        # Restore the global so later tests see the settings they expect.
        litellm.key_generation_settings = previous_settings
|
|
|
|
|
|
def test_personal_key_generation_check():
    """Personal key generation honours ``allowed_user_roles``.

    With only "proxy_admin" allowed, the check passes for a PROXY_ADMIN
    caller and raises ``HTTPException`` for an INTERNAL_USER caller.

    ``litellm.key_generation_settings`` is a process-wide global, so the
    value found on entry is put back when the test finishes (pass or fail)
    to avoid leaking this configuration into other tests.
    """
    from litellm.proxy.management_endpoints.key_management_endpoints import (
        _personal_key_generation_check,
    )
    from fastapi import HTTPException

    previous_settings = getattr(litellm, "key_generation_settings", None)
    litellm.key_generation_settings = {
        "personal_key_generation": {"allowed_user_roles": ["proxy_admin"]}
    }
    try:
        # Proxy admin is in the allowed roles -> allowed.
        assert _personal_key_generation_check(
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="admin",
            ),
            data=GenerateKeyRequest(),
        )

        # Internal user is not in the allowed roles -> rejected.
        with pytest.raises(HTTPException):
            _personal_key_generation_check(
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.INTERNAL_USER,
                    api_key="sk-1234",
                    user_id="admin",
                ),
                data=GenerateKeyRequest(),
            )
    finally:
        # Restore the global so later tests see the settings they expect.
        litellm.key_generation_settings = previous_settings
|
|
|
|
|
|
@pytest.mark.parametrize(
    "update_request_data, non_default_values, existing_metadata, expected_result",
    [
        (
            {"metadata": {"test": "new"}},
            {"metadata": {"test": "new"}},
            {"test": "test"},
            {"metadata": {"test": "new"}},
        ),
        (
            {"tags": ["new_tag"]},
            {},
            {"tags": ["old_tag"]},
            {"metadata": {"tags": ["new_tag"]}},
        ),
        (
            {"enforced_params": ["metadata.tags"]},
            {},
            {"tags": ["old_tag"]},
            {"metadata": {"tags": ["old_tag"], "enforced_params": ["metadata.tags"]}},
        ),
        (
            {"disable_global_guardrails": True},
            {},
            {},
            {"metadata": {"disable_global_guardrails": True}},
        ),
        (
            {"disable_global_guardrails": False},
            {},
            {"disable_global_guardrails": True},
            {"metadata": {"disable_global_guardrails": False}},
        ),
    ],
)
def test_prepare_metadata_fields(
    update_request_data, non_default_values, existing_metadata, expected_result
):
    """``prepare_metadata_fields`` folds top-level request fields into ``metadata``.

    Cases, in order: explicit metadata is passed through; new ``tags``
    replace existing ones; ``enforced_params`` is added next to existing
    tags; ``disable_global_guardrails=True`` is added to empty metadata;
    ``disable_global_guardrails=False`` overwrites an existing ``True``.
    """
    from litellm.proxy.management_endpoints.key_management_endpoints import (
        prepare_metadata_fields,
    )

    update_request = UpdateKeyRequest(
        key="sk-1qGQUJJTcljeaPfzgWRrXQ", **update_request_data
    )

    prepared = prepare_metadata_fields(
        data=update_request,
        non_default_values=non_default_values,
        existing_metadata=existing_metadata,
    )
    assert prepared == expected_result
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_key_update_with_model_specific_params(prisma_client):
    """Updating a key whose metadata carries per-model tpm/rpm limits does not raise.

    A key is generated, then ``update_key_fn`` is called with a full update
    payload. The test makes no assertions on the stored record; it passes
    as long as the update call completes.
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    from litellm.proxy.management_endpoints.key_management_endpoints import (
        update_key_fn,
    )
    from litellm.proxy._types import UpdateKeyRequest

    def admin_auth():
        # A fresh proxy-admin identity for every endpoint call.
        return UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            api_key="sk-1234",
            user_id="1234",
        )

    created = await generate_key_fn(
        data=GenerateKeyRequest(models=["gpt-4"]),
        user_api_key_dict=admin_auth(),
    )

    generated_key = created.key
    token_hash = created.token_id
    print(generated_key)

    update_request = Request(scope={"type": "http"})
    update_request._url = URL(url="/update/key")

    # Per-model limits live under "metadata"; the top-level model_rpm_limit /
    # model_tpm_limit entries are sent as None.
    per_model_limits = {
        "model_tpm_limit": {"fake-openai-endpoint": 10},
        "model_rpm_limit": {"fake-openai-endpoint": 0},
    }
    update_payload = {
        "key_alias": f"test-key_{uuid.uuid4()}",
        "duration": None,
        "models": ["all-team-models"],
        "spend": 0,
        "max_budget": None,
        "user_id": "default_user_id",
        "team_id": None,
        "max_parallel_requests": None,
        "metadata": per_model_limits,
        "tpm_limit": None,
        "rpm_limit": None,
        "budget_duration": None,
        "allowed_cache_controls": [],
        "soft_budget": None,
        "config": {},
        "permissions": {},
        "model_max_budget": {},
        "send_invite_email": None,
        "model_rpm_limit": None,
        "model_tpm_limit": None,
        "guardrails": None,
        "blocked": None,
        "aliases": {},
        "key": token_hash,
        "budget_id": None,
        "key_name": "sk-...2GWA",
        "expires": None,
        "token_id": token_hash,
        "litellm_budget_table": None,
        "token": token_hash,
    }
    await update_key_fn(
        request=update_request,
        data=UpdateKeyRequest(**update_payload),
        user_api_key_dict=admin_auth(),
    )
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_list_key_helper(prisma_client):
    """
    Test _list_key_helper function with various scenarios:
    1. Basic pagination
    2. Filtering by user_id
    3. Filtering by team_id
    4. Filtering by key_alias
    5. Return full object vs token only

    Keys are created and checked inside a ``try`` block and deleted in
    ``finally`` so that a failing assertion (or a failure part-way through key
    creation) does not leave test keys behind.
    """
    from litellm.proxy.management_endpoints.key_management_endpoints import (
        _list_key_helper,
    )

    def admin_auth():
        # Build a fresh admin auth object for every endpoint call so that no
        # two calls share one instance.
        return UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            api_key="sk-1234",
            user_id="admin",
        )

    async def list_keys(**overrides):
        # Defaults describe an unfiltered first page of size 10; each scenario
        # overrides only the arguments it cares about.
        params = {
            "prisma_client": prisma_client,
            "page": 1,
            "size": 10,
            "user_id": None,
            "team_id": None,
            "key_alias": None,
            "key_hash": None,
            "organization_id": None,
        }
        params.update(overrides)
        return await _list_key_helper(**params)

    # Setup - point the proxy server at the test DB client
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    # Unique identifiers so filters only match keys created by this run
    test_user_id = f"test_user_{uuid.uuid4()}"
    test_team_id = f"test_team_{uuid.uuid4()}"
    test_key_alias = f"test_alias_{uuid.uuid4()}"

    # Every key created below is recorded here so `finally` can delete it
    test_keys = []

    try:
        # 1. Create 2 keys for test user + test team
        for _ in range(2):
            key = await generate_key_fn(
                data=GenerateKeyRequest(
                    user_id=test_user_id,
                    team_id=test_team_id,
                    key_alias=f"team_key_{uuid.uuid4()}",  # Make unique with UUID
                ),
                user_api_key_dict=admin_auth(),
            )
            test_keys.append(key)

        # 2. Create 1 key for test user (no team)
        key = await generate_key_fn(
            data=GenerateKeyRequest(
                user_id=test_user_id,
                key_alias=test_key_alias,  # Already unique from earlier UUID generation
            ),
            user_api_key_dict=admin_auth(),
        )
        test_keys.append(key)

        # 3. Create 2 keys for other users
        for i in range(2):
            key = await generate_key_fn(
                data=GenerateKeyRequest(
                    user_id=f"other_user_{i}",
                    key_alias=f"other_key_{uuid.uuid4()}",  # Make unique with UUID
                ),
                user_api_key_dict=admin_auth(),
            )
            test_keys.append(key)

        # Test 1: Basic pagination
        result = await list_keys(size=2)
        assert len(result["keys"]) == 2, "Should return exactly 2 keys"
        assert result["total_count"] >= 5, "Should have at least 5 total keys"
        assert result["current_page"] == 1
        assert isinstance(
            result["keys"][0], str
        ), "Should return token strings by default"

        # Test 2: Filter by user_id
        result = await list_keys(user_id=test_user_id)
        assert len(result["keys"]) == 3, "Should return exactly 3 keys for test user"

        # Test 3: Filter by team_id
        result = await list_keys(team_id=test_team_id)
        assert len(result["keys"]) == 2, "Should return exactly 2 keys for test team"

        # Test 4: Filter by key_alias
        result = await list_keys(key_alias=test_key_alias)
        assert len(result["keys"]) == 1, "Should return exactly 1 key with test alias"

        # Test 5: Return full object
        result = await list_keys(user_id=test_user_id, return_full_object=True)
        assert all(
            isinstance(key, UserAPIKeyAuth) for key in result["keys"]
        ), "Should return UserAPIKeyAuth objects"
        assert len(result["keys"]) == 3, "Should return exactly 3 keys for test user"
    finally:
        # Clean up test keys even when an assertion above failed
        for key in test_keys:
            await delete_key_fn(
                data=KeyRequest(keys=[key.key]),
                user_api_key_dict=admin_auth(),
                litellm_changed_by=None,
            )
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_list_key_helper_team_filtering(prisma_client):
    """
    Check how _list_key_helper treats keys depending on their team_id.

    Keys are generated with no team, with team_id "litellm-dashboard" and with
    a freshly generated team id. Up to three pages (100 keys each) of the
    unfiltered listing are then collected, and the test asserts that:

    - no returned key has team_id "litellm-dashboard"
    - at least one returned key has team_id None

    All generated keys are deleted afterwards, whether or not the assertions
    pass.
    """
    from litellm.proxy.management_endpoints.key_management_endpoints import (
        _list_key_helper,
    )
    from litellm._uuid import uuid

    # Point the proxy server at the test DB client
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    # Every generated key is recorded here for cleanup
    test_keys = []

    async def create_key(**request_fields):
        # Only the given fields are set on the request; the admin auth object
        # is rebuilt for each call.
        created = await generate_key_fn(
            data=GenerateKeyRequest(**request_fields),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="admin",
            ),
        )
        test_keys.append(created)

    # 3 keys without any team
    for i in range(3):
        await create_key(key_alias=f"no_team_key_{i}.{uuid.uuid4()}")

    # 2 keys on the litellm-dashboard team
    for i in range(2):
        await create_key(
            team_id="litellm-dashboard",
            key_alias=f"dashboard_key_{i}.{uuid.uuid4()}",
        )

    # 2 keys on some other team
    other_team_id = f"other_team_{uuid.uuid4()}"
    for i in range(2):
        await create_key(
            team_id=other_team_id,
            key_alias=f"other_team_key_{i}.{uuid.uuid4()}",
        )

    try:
        # Collect the unfiltered listing, looking at no more than 3 pages
        max_pages_to_check = 3
        all_keys = []
        for page in range(1, max_pages_to_check + 1):
            result = await _list_key_helper(
                prisma_client=prisma_client,
                size=100,
                page=page,
                user_id=None,
                team_id=None,
                key_alias=None,
                key_hash=None,
                return_full_object=True,
                organization_id=None,
            )
            all_keys.extend(result["keys"])

            # Stop early once the last available page has been read
            if page >= result["total_pages"]:
                break

        # Dump what was found to ease debugging
        print(f"Total keys found: {len(all_keys)}")
        for key in all_keys:
            print(f"Key team_id: {key.team_id}, alias: {key.key_alias}")

        # litellm-dashboard keys must be absent from the listing
        dashboard_keys = [
            found for found in all_keys if found.team_id == "litellm-dashboard"
        ]
        assert not dashboard_keys, "Should not include litellm-dashboard keys"

        # Keys without a team must be present in the listing
        no_team_keys = [found for found in all_keys if found.team_id is None]
        assert (
            len(no_team_keys) > 0
        ), f"Expected more than 0 keys with no team, got {len(no_team_keys)}"
    finally:
        # Remove everything this test generated
        for key in test_keys:
            await delete_key_fn(
                data=KeyRequest(keys=[key.key]),
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="admin",
                ),
                litellm_changed_by=None,
            )
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("litellm.proxy.management_endpoints.key_management_endpoints.get_team_object")
async def test_key_generate_always_db_team(mock_get_team_object):
    """
    Verify that generate_key_fn calls get_team_object with check_db_only=True.

    get_team_object is patched to return None and the proxy server's
    prisma_client is swapped for a MagicMock for the duration of the call.
    Any exception raised by generate_key_fn is printed and ignored: the test
    only inspects how get_team_object was invoked.
    """
    from litellm.proxy.management_endpoints.key_management_endpoints import (
        generate_key_fn,
    )

    mock_get_team_object.return_value = None

    # patch.object restores the previous prisma_client on exit, so the mock
    # client does not leak into tests that run afterwards.
    with patch.object(litellm.proxy.proxy_server, "prisma_client", MagicMock()):
        try:
            await generate_key_fn(
                data=GenerateKeyRequest(team_id="1234"),
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="admin",
                ),
            )
        except Exception as e:
            # Deliberate best-effort: the outcome of the call is not under test.
            print(f"Error: {e}")

    mock_get_team_object.assert_called_once()
    assert mock_get_team_object.call_args.kwargs["check_db_only"] is True
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "requested_model, should_pass",
    [
        ("gpt-4o", True),  # Should pass - exact match in aliases
        ("gpt-4o-team1", True),  # Should pass - team has access to this deployment
        ("gpt-4o-mini", False),  # Should fail - not in aliases
        ("o-3", False),  # Should fail - not in aliases
    ],
)
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_team_model_alias(prisma_client, requested_model, should_pass):
    """
    Exercise team model aliases through user_api_key_auth.

    Steps:
    1. Create a team with models `[gpt-4o-team1]` and model alias
       `{gpt-4o: gpt-4o-team1}`
    2. Generate a key for that team with models `[gpt-4o-team1]`
    3. Authenticate a `/chat/completions` request whose body names
       `requested_model`

    When `should_pass` is true, the auth result must carry the key's model
    list and the team's aliases. Otherwise auth must raise an exception whose
    `type` is `ProxyErrorTypes.key_model_access_denied`.
    """
    litellm.set_verbose = True
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    # Team carrying the alias mapping
    team_id = f"test_team_{uuid.uuid4()}"
    team_request = NewTeamRequest(
        team_id=team_id,
        team_alias=f"test_team_alias_{uuid.uuid4()}",
        models=["gpt-4o-team1"],
        model_aliases={"gpt-4o": "gpt-4o-team1"},
    )
    await new_team(
        data=team_request,
        http_request=Request(scope={"type": "http"}),
        user_api_key_dict=UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="admin"
        ),
    )

    # Key that belongs to the team
    key_request = GenerateKeyRequest(
        team_id=team_id,
        models=["gpt-4o-team1"],
    )
    new_key = await generate_key_fn(
        data=key_request,
        user_api_key_dict=UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="admin"
        ),
    )
    generated_key = new_key.key

    # Fake chat completion request whose body names the requested model
    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")

    async def return_body():
        payload = f'{{"model": "{requested_model}"}}'
        return payload.encode()

    request.body = return_body

    if not should_pass:
        # Models outside the key's access must be rejected by auth
        with pytest.raises(Exception) as exc_info:
            await user_api_key_auth(request=request, api_key=f"Bearer {generated_key}")
        assert exc_info.value.type == ProxyErrorTypes.key_model_access_denied
        return

    # Allowed models: auth succeeds and exposes the models and aliases
    result = await user_api_key_auth(
        request=request, api_key=f"Bearer {generated_key}"
    )
    expected_models = ["gpt-4o-team1"]
    expected_aliases = {"gpt-4o": "gpt-4o-team1"}
    assert (
        result.models == expected_models
    ), "Expected model list to contain aliased model"
    assert (
        result.team_model_aliases == expected_aliases
    ), "Expected model aliases to be present"
|