mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-30 21:06:03 +00:00
feed274aa3
This reverts commit 3d2df7e8b5.
223 lines
7.5 KiB
Python
223 lines
7.5 KiB
Python
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
|
|
from litellm.proxy._types import LiteLLM_TeamTable, LiteLLM_UserTable, Member
|
|
from litellm.proxy.auth.handle_jwt import JWTAuthManager
|
|
|
|
|
|
def test_get_team_models_for_all_models_and_team_only_models():
    """
    A team that lists 'all-proxy-models' next to its own team-only models
    should resolve to the union of its own entries and the proxy model list.
    """
    from litellm.proxy.auth.model_checks import get_team_models

    models_for_team = ["all-proxy-models", "team-only-model", "team-only-model-2"]
    models_on_proxy = ["model1", "model2", "model3"]
    no_access_groups = {}
    with_group_names = False

    resolved = get_team_models(
        models_for_team, models_on_proxy, no_access_groups, with_group_names
    )

    # Ordering is not part of this check, so compare as sets.
    expected = {*models_for_team, *models_on_proxy}
    assert set(resolved) == expected
|
|
|
|
|
|
def test_get_team_models_all_proxy_models_includes_access_groups():
    """
    With 'all-proxy-models' on the team and include_model_access_groups=True,
    the resolved list must carry the model access group names (here 'group-a'
    and 'group-b') alongside the individual model names, without duplicates.
    """
    from litellm.proxy.auth.model_checks import get_team_models

    access_groups = {
        "group-a": ["model1"],
        "group-b": ["model2"],
    }

    resolved = get_team_models(
        ["all-proxy-models"],
        ["model1", "model2"],
        access_groups,
        include_model_access_groups=True,
    )

    # Both the group names and the plain model names have to be present.
    for expected_name in ("group-a", "group-b", "model1", "model2"):
        assert expected_name in resolved
    assert len(set(resolved)) == len(resolved), "result should have no duplicates"
|
|
|
|
|
|
def test_get_team_models_all_proxy_models_without_include_flag():
    """
    With include_model_access_groups=False the access group names must stay
    out of the result, even though the team has 'all-proxy-models'; the plain
    proxy model names are still expected.
    """
    from litellm.proxy.auth.model_checks import get_team_models

    access_groups = {
        "group-a": ["model1"],
        "group-b": ["model2"],
    }

    resolved = get_team_models(
        ["all-proxy-models"],
        ["model1", "model2"],
        access_groups,
        include_model_access_groups=False,
    )

    # Group names are excluded ...
    for group_name in ("group-a", "group-b"):
        assert group_name not in resolved
    # ... while the individual models remain listed.
    for model_name in ("model1", "model2"):
        assert model_name in resolved
|
|
|
|
|
|
def test_get_key_models_all_proxy_models_includes_access_groups():
    """
    For a key whose models are ['all-proxy-models'], asking for
    include_model_access_groups=True must yield the access group name as well
    as every proxy model name, with no duplicate entries.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.auth.model_checks import get_key_models

    key_auth = UserAPIKeyAuth(api_key="test-key", models=["all-proxy-models"])

    resolved = get_key_models(
        user_api_key_dict=key_auth,
        proxy_model_list=["model1", "model2"],
        model_access_groups={"group-a": ["model1"]},
        include_model_access_groups=True,
    )

    for expected_name in ("group-a", "model1", "model2"):
        assert expected_name in resolved
    assert len(set(resolved)) == len(resolved), "result should have no duplicates"
|
|
|
|
|
|
def test_get_key_models_passes_include_model_access_groups():
    """
    A key that names an access group directly in its models list must keep
    that group name in the result when include_model_access_groups=True
    (i.e. it is not stripped by _get_models_from_access_groups), and the
    group's member models must be present too.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.auth.model_checks import get_key_models

    group_members = ["model1", "model2"]
    key_auth = UserAPIKeyAuth(api_key="test-key", models=["group-a"])

    resolved = get_key_models(
        user_api_key_dict=key_auth,
        proxy_model_list=["model1", "model2"],
        model_access_groups={"group-a": group_members},
        include_model_access_groups=True,
    )

    for expected_name in ("group-a", "model1", "model2"):
        assert expected_name in resolved
|
|
|
|
|
|
def test_get_key_models_does_not_mutate_input():
    """
    get_key_models must leave user_api_key_dict.models untouched.

    _get_models_from_access_groups uses .pop()/.extend(), which would corrupt
    cached UserAPIKeyAuth objects if all_models aliased the key's list instead
    of being a copy of it.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.auth.model_checks import get_key_models

    expected_models = ["group-a", "extra-model"]
    key_auth = UserAPIKeyAuth(
        api_key="test-key",
        models=expected_models.copy(),  # hand over a separate list object
    )

    # The return value is irrelevant here; only the side effect is checked.
    get_key_models(
        user_api_key_dict=key_auth,
        proxy_model_list=["model1", "model2"],
        model_access_groups={"group-a": ["model1", "model2"]},
        include_model_access_groups=False,
    )

    # The models list on the auth object must be exactly what went in.
    assert key_auth.models == expected_models
|
|
|
|
|
|
@pytest.mark.parametrize(
    "key_models,team_models,proxy_model_list,model_list,expected",
    [
        # Models supplied through key_models.
        pytest.param(
            [
                "anthropic/claude-3-haiku-20240307",
                "anthropic/claude-3-5-haiku-20241022",
            ],
            [],
            [],
            [{"model_name": "anthropic/*", "litellm_params": {"model": "anthropic/*"}}],
            [
                "anthropic/claude-3-haiku-20240307",
                "anthropic/claude-3-5-haiku-20241022",
            ],
        ),
        # Models supplied through team_models.
        pytest.param(
            [],
            [
                "anthropic/claude-3-haiku-20240307",
                "anthropic/claude-3-5-haiku-20241022",
            ],
            [],
            [{"model_name": "anthropic/*", "litellm_params": {"model": "anthropic/*"}}],
            [
                "anthropic/claude-3-haiku-20240307",
                "anthropic/claude-3-5-haiku-20241022",
            ],
        ),
        # Models supplied through proxy_model_list.
        pytest.param(
            [],
            [],
            [
                "anthropic/claude-3-haiku-20240307",
                "anthropic/claude-3-5-haiku-20241022",
            ],
            [{"model_name": "anthropic/*", "litellm_params": {"model": "anthropic/*"}}],
            [
                "anthropic/claude-3-haiku-20240307",
                "anthropic/claude-3-5-haiku-20241022",
            ],
        ),
    ],
)
def test_get_complete_model_list_order(key_models, team_models, proxy_model_list, model_list, expected):
    """
    get_complete_model_list must return the models in the order they were
    given, whichever of the three input lists they arrive through.
    """
    from litellm.proxy.auth.model_checks import get_complete_model_list
    from litellm import Router

    router = Router(model_list=model_list)

    actual = get_complete_model_list(
        key_models=key_models,
        team_models=team_models,
        proxy_model_list=proxy_model_list,
        user_model=None,
        infer_model_from_keys=False,
        llm_router=router,
    )

    # List equality, so element order is part of the check.
    assert actual == expected
|
|
|
|
|
|
def test_get_complete_model_list_byok_wildcard_expansion():
    """
    BYOK case: the team has the wildcard 'openai/*' while the router has no
    deployment for it (the proxy has no openai config). The wildcard is then
    expected to be expanded into concrete 'openai/...' model names.
    """
    from litellm.proxy.auth.model_checks import get_complete_model_list
    from litellm import Router

    # An empty model_list means the router knows no openai/* deployment.
    empty_router = Router(model_list=[])

    expanded = get_complete_model_list(
        key_models=[],
        team_models=["openai/*"],
        proxy_model_list=[],
        user_model=None,
        infer_model_from_keys=False,
        llm_router=empty_router,
    )

    # The wildcard must be replaced by actual OpenAI model names.
    assert len(expanded) > 0
    assert all(name.startswith("openai/") for name in expanded)
    assert "openai/*" not in expanded
|