mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 05:28:02 +00:00
434 lines
14 KiB
Python
434 lines
14 KiB
Python
# stdlib imports
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
|
|
# third party imports
|
|
from click.testing import CliRunner
|
|
|
|
# Put the directory three levels above the current working directory at the
# front of the import path so the local imports that follow can resolve.
_three_levels_up = os.path.abspath("../../..")
sys.path.insert(0, _three_levels_up)
|
|
|
|
|
|
# local imports
|
|
from litellm.proxy.client.cli import cli
|
|
from litellm.proxy.client.cli.commands.models import (
|
|
format_cost_per_1k_tokens,
|
|
format_iso_datetime_str,
|
|
format_timestamp,
|
|
)
|
|
|
|
|
|
@pytest.fixture
def mock_client():
    """Patch ``Client`` in the models command module and yield the mock class.

    The patch stays active for the duration of the test and is removed on
    fixture teardown.
    """
    patcher = patch("litellm.proxy.client.cli.commands.models.Client")
    client_cls = patcher.start()
    try:
        yield client_cls
    finally:
        patcher.stop()
|
|
|
|
|
|
@pytest.fixture
def cli_runner():
    """Provide a new Click ``CliRunner`` for invoking CLI commands."""
    runner = CliRunner()
    return runner
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def mock_env():
    """Set the proxy URL and API key environment variables for every test.

    Applied automatically (``autouse=True``); ``patch.dict`` undoes the
    changes to ``os.environ`` when the fixture is torn down.
    """
    proxy_env = {
        "LITELLM_PROXY_URL": "http://localhost:4000",
        "LITELLM_PROXY_API_KEY": "sk-test",
    }
    with patch.dict(os.environ, proxy_env):
        yield
|
|
|
|
|
|
@pytest.fixture
def mock_models_list(mock_client):
    """Configure the mocked client so ``models.list()`` returns two sample models.

    Returns the same ``mock_client`` mock so tests can assert on how the
    client class was constructed and used.
    """
    # The client class must not have been instantiated before the test runs.
    mock_client.assert_not_called()

    sample_models = [
        {
            "id": "model-123",
            "object": "model",
            "created": 1699848889,
            "owned_by": "organization-123",
        },
        {
            "id": "model-456",
            "object": "model",
            "created": 1699848890,
            "owned_by": "organization-456",
        },
    ]
    mock_client.return_value.models.list.return_value = sample_models
    return mock_client
|
|
|
|
|
|
@pytest.fixture
def mock_models_info(mock_client):
    """Configure the mocked client so ``models.info()`` returns one sample model.

    Returns the same ``mock_client`` mock so tests can assert on how the
    client class was constructed and used.
    """
    # The client class must not have been instantiated before the test runs.
    mock_client.assert_not_called()

    sample_model = {
        "model_name": "gpt-4",
        "litellm_params": {
            "model": "gpt-4",
            "litellm_credential_name": "openai-1",
        },
        "model_info": {
            "id": "model-123",
            "created_at": "2025-04-29T21:31:43.843000+00:00",
            "updated_at": "2025-04-29T21:31:43.843000+00:00",
            "input_cost_per_token": 0.00001,
            "output_cost_per_token": 0.00002,
        },
    }
    mock_client.return_value.models.info.return_value = [sample_model]
    return mock_client
|
|
|
|
|
|
@pytest.fixture
def force_utc_tz():
    """Run the test with ``TZ=UTC`` and restore the previous value afterwards."""

    def _reload_tz():
        # time.tzset() is not available on every platform, so guard the call.
        if hasattr(time, "tzset"):
            time.tzset()

    previous_tz = os.environ.get("TZ")
    os.environ["TZ"] = "UTC"
    _reload_tz()

    yield

    # Teardown: put TZ back exactly as it was (or remove it if it was unset).
    if previous_tz is None:
        os.environ.pop("TZ", None)
    else:
        os.environ["TZ"] = previous_tz
    _reload_tz()
|
|
|
|
|
|
def test_models_list_json_format(mock_models_list, cli_runner):
    """``models list --format json`` prints the mocked model list as JSON."""
    expected_models = mock_models_list.return_value.models.list.return_value
    cli_args = ["models", "list", "--format", "json"]

    result = cli_runner.invoke(cli, cli_args)

    # The command must succeed and echo the mocked data unchanged.
    assert result.exit_code == 0
    parsed_output = json.loads(result.output)
    assert parsed_output == expected_models

    # The client is built from the environment values and queried once.
    mock_models_list.assert_called_once_with(
        base_url="http://localhost:4000",
        api_key="sk-test",
    )
    mock_models_list.return_value.models.list.assert_called_once()
|
|
|
|
|
|
def test_models_list_table_format(mock_models_list, cli_runner):
    """``models list`` without ``--format`` renders the models as a table."""
    result = cli_runner.invoke(cli, ["models", "list"])

    assert result.exit_code == 0

    # Column headers plus the first model's values must appear in the table.
    expected_fragments = (
        "ID",
        "Object",
        "Created",
        "Owned By",
        "model-123",
        "organization-123",
        format_timestamp(1699848889),
    )
    for fragment in expected_fragments:
        assert fragment in result.output

    # The client is built from the environment values and queried once.
    mock_models_list.assert_called_once_with(
        base_url="http://localhost:4000",
        api_key="sk-test",
    )
    mock_models_list.return_value.models.list.assert_called_once()
|
|
|
|
|
|
def test_models_list_with_base_url(mock_models_list, cli_runner):
    """``--base-url`` on the command line overrides the environment variable."""
    custom_base_url = "http://custom.server:8000"
    cli_args = ["--base-url", custom_base_url, "models", "list"]

    result = cli_runner.invoke(cli, cli_args)

    assert result.exit_code == 0

    # Base URL comes from the flag; the API key still comes from the env var.
    expected_client_kwargs = {"base_url": custom_base_url, "api_key": "sk-test"}
    mock_models_list.assert_called_once_with(**expected_client_kwargs)
|
|
|
|
|
|
def test_models_list_with_api_key(mock_models_list, cli_runner):
    """``--api-key`` on the command line overrides the environment variable."""
    custom_api_key = "custom-test-key"
    cli_args = ["--api-key", custom_api_key, "models", "list"]

    result = cli_runner.invoke(cli, cli_args)

    assert result.exit_code == 0

    # API key comes from the flag; the base URL still comes from the env var.
    expected_client_kwargs = {
        "base_url": "http://localhost:4000",
        "api_key": custom_api_key,
    }
    mock_models_list.assert_called_once_with(**expected_client_kwargs)
|
|
|
|
|
|
def test_models_list_error_handling(mock_client, cli_runner):
    """An exception raised by ``models.list()`` makes the command fail."""
    list_mock = mock_client.return_value.models.list
    list_mock.side_effect = Exception("API Error")

    result = cli_runner.invoke(cli, ["models", "list"])

    # Non-zero exit code, and the original error is surfaced on the result.
    assert result.exit_code != 0
    error_text = str(result.exception)
    assert "API Error" in error_text

    # The client was still constructed from the environment values.
    mock_client.assert_called_once_with(
        base_url="http://localhost:4000",
        api_key="sk-test",
    )
|
|
|
|
|
|
def test_models_info_json_format(mock_models_info, cli_runner):
    """``models info --format json`` prints the mocked model info as JSON."""
    expected_info = mock_models_info.return_value.models.info.return_value
    cli_args = ["models", "info", "--format", "json"]

    result = cli_runner.invoke(cli, cli_args)

    # The command must succeed and echo the mocked data unchanged.
    assert result.exit_code == 0
    parsed_output = json.loads(result.output)
    assert parsed_output == expected_info

    # The client is built from the environment values and queried once.
    mock_models_info.assert_called_once_with(
        base_url="http://localhost:4000",
        api_key="sk-test",
    )
    mock_models_info.return_value.models.info.assert_called_once()
|
|
|
|
|
|
def test_models_info_table_format(mock_models_info, cli_runner):
    """``models info`` with default columns renders a table, timestamps to the minute."""
    result = cli_runner.invoke(cli, ["models", "info"])

    assert result.exit_code == 0

    # Default column headers, the model name and the truncated timestamp.
    for present in (
        "Public Model",
        "Upstream Model",
        "Updated At",
        "gpt-4",
        "2025-04-29 21:31",
    ):
        assert present in result.output

    # Seconds and microseconds must be dropped from the displayed timestamp.
    for absent in ("21:31:43", "843000"):
        assert absent not in result.output

    # The client is built from the environment values and queried once.
    mock_models_info.assert_called_once_with(
        base_url="http://localhost:4000",
        api_key="sk-test",
    )
    mock_models_info.return_value.models.info.assert_called_once()
|
|
|
|
|
|
def test_models_import_only_models_matching_regex(tmp_path, mock_client, cli_runner):
    """``models import --only-models-matching-regex`` imports only matching models.

    Writes a YAML file with four models, runs the import with the regex
    ``gpt`` and checks that ``client.models.new`` was called only for the two
    entries whose ``litellm_params.model`` contains ``gpt``.
    """
    import yaml as pyyaml

    # A mix of models: two whose upstream model contains "gpt", two that do not.
    yaml_content = {
        "model_list": [
            {
                "model_name": "gpt-4-model",
                "litellm_params": {"model": "gpt-4"},
                "model_info": {"id": "id-1"},
            },
            {
                "model_name": "gpt-3.5-model",
                "litellm_params": {"model": "gpt-3.5-turbo"},
                "model_info": {"id": "id-2"},
            },
            {
                "model_name": "llama2-model",
                "litellm_params": {"model": "llama2"},
                "model_info": {"id": "id-3"},
            },
            {
                "model_name": "other-model",
                "litellm_params": {"model": "other"},
                "model_info": {"id": "id-4"},
            },
        ]
    }
    yaml_file = tmp_path / "models.yaml"
    yaml_file.write_text(pyyaml.safe_dump(yaml_content))

    # client.models.new is a mock attribute; its call list records the imports.
    mock_new = mock_client.return_value.models.new

    result = cli_runner.invoke(
        cli,
        ["models", "import", str(yaml_file), "--only-models-matching-regex", "gpt"],
    )

    assert result.exit_code == 0

    # Only the two gpt models should have been imported.
    calls = [call.kwargs["model_params"]["model"] for call in mock_new.call_args_list]
    assert set(calls) == {"gpt-4", "gpt-3.5-turbo"}
    assert "llama2" not in calls
    assert "other" not in calls

    # The output summary should mention the imported provider prefix. The
    # original assertion was `"gpt" in output or "gpt" in output`; one check
    # is equivalent.
    assert "gpt" in result.output
|
|
|
|
|
|
def test_models_import_only_access_groups_matching_regex(
    tmp_path, mock_client, cli_runner
):
    """``models import --only-access-groups-matching-regex`` filters by access group.

    Writes a YAML file with five models (one without ``access_groups``), runs
    the import with the regex ``beta`` and checks that ``client.models.new``
    was called only for the two entries that list ``beta-models``.
    """
    import yaml as pyyaml

    # Two models are in "beta-models"; the others are in different groups or
    # have no access_groups key at all.
    yaml_content = {
        "model_list": [
            {
                "model_name": "gpt-4-model",
                "litellm_params": {"model": "gpt-4"},
                "model_info": {
                    "id": "id-1",
                    "access_groups": ["beta-models", "prod-models"],
                },
            },
            {
                "model_name": "gpt-3.5-model",
                "litellm_params": {"model": "gpt-3.5-turbo"},
                "model_info": {"id": "id-2", "access_groups": ["alpha-models"]},
            },
            {
                "model_name": "llama2-model",
                "litellm_params": {"model": "llama2"},
                "model_info": {"id": "id-3", "access_groups": ["beta-models"]},
            },
            {
                "model_name": "other-model",
                "litellm_params": {"model": "other"},
                "model_info": {"id": "id-4", "access_groups": ["other-group"]},
            },
            {
                "model_name": "no-access-group-model",
                "litellm_params": {"model": "no-access"},
                "model_info": {"id": "id-5"},
            },
        ]
    }
    yaml_file = tmp_path / "models.yaml"
    yaml_file.write_text(pyyaml.safe_dump(yaml_content))

    # client.models.new is a mock attribute; its call list records the imports.
    mock_new = mock_client.return_value.models.new

    result = cli_runner.invoke(
        cli,
        [
            "models",
            "import",
            str(yaml_file),
            "--only-access-groups-matching-regex",
            "beta",
        ],
    )

    assert result.exit_code == 0

    # Only the two models with "beta-models" in access_groups should be imported.
    calls = [call.kwargs["model_params"]["model"] for call in mock_new.call_args_list]
    assert set(calls) == {"gpt-4", "llama2"}
    assert "gpt-3.5-turbo" not in calls
    assert "other" not in calls
    assert "no-access" not in calls

    # The output summary should mention the imported provider prefix. The
    # original assertion was `"gpt" in output or "gpt" in output`; one check
    # is equivalent.
    assert "gpt" in result.output
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("input_str", "expected"),
    [
        (None, ""),
        ("", ""),
        ("2024-05-01T12:34:56Z", "2024-05-01 12:34"),
        ("2024-05-01T12:34:56+00:00", "2024-05-01 12:34"),
        ("2024-05-01T12:34:56.123456+00:00", "2024-05-01 12:34"),
        ("2024-05-01T12:34:56.123456Z", "2024-05-01 12:34"),
        ("2024-05-01T12:34:56-04:00", "2024-05-01 12:34"),
        ("2024-05-01", "2024-05-01 00:00"),
        ("not-a-date", "not-a-date"),
    ],
)
def test_format_iso_datetime_str(input_str, expected):
    """Check ``format_iso_datetime_str`` against empty, ISO and unparseable inputs.

    The cases expect minute precision, the wall-clock time kept as written
    (the ``-04:00`` case is not shifted), and unparseable text returned as-is.
    """
    formatted = format_iso_datetime_str(input_str)
    assert formatted == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    "input_val,expected",
    [
        (None, ""),
        (1699848889, "2023-11-13 04:14"),
        (1699848889.0, "2023-11-13 04:14"),
        ("not-a-timestamp", "not-a-timestamp"),
        ([1, 2, 3], "[1, 2, 3]"),
    ],
)
def test_format_timestamp(input_val, expected, force_utc_tz):
    """Check ``format_timestamp`` for ``None``, numeric and non-numeric inputs.

    ``force_utc_tz`` pins the process timezone to UTC so the expected strings
    for the numeric timestamps do not depend on the machine's local timezone.
    """
    actual = format_timestamp(input_val)
    # Carry the diagnostic in the assertion message instead of a separate
    # conditional print, so it is always attached to the failure report.
    assert (
        actual == expected
    ), f"input: {input_val}, expected: {expected}, actual: {actual}"
|
|
|
|
|
|
@pytest.mark.parametrize(
    "input_val,expected",
    [
        (None, ""),
        (0, "$0.0000"),
        (0.0, "$0.0000"),
        (0.00001, "$0.0100"),
        (0.00002, "$0.0200"),
        (1, "$1000.0000"),
        (1.5, "$1500.0000"),
        ("0.00001", "$0.0100"),
        ("1.5", "$1500.0000"),
        ("not-a-number", "not-a-number"),
        (1e-10, "$0.0000"),
    ],
)
def test_format_cost_per_1k_tokens(input_val, expected):
    """Check ``format_cost_per_1k_tokens`` for ``None``, numbers and strings.

    The cases expect the per-token cost multiplied by 1000 and rendered as a
    dollar amount with four decimals; non-numeric text is returned unchanged.
    """
    actual = format_cost_per_1k_tokens(input_val)
    # Carry the diagnostic in the assertion message instead of a separate
    # conditional print, so it is always attached to the failure report.
    assert (
        actual == expected
    ), f"input: {input_val}, expected: {expected}, actual: {actual}"
|