mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 14:48:44 +00:00
292 lines
9.7 KiB
Python
292 lines
9.7 KiB
Python
import pytest
|
|
from unittest.mock import Mock, patch
|
|
|
|
|
|
def create_mock_user_api_key_auth():
    """Build a ``Mock`` standing in for an authenticated API key.

    Returns:
        Mock: with ``api_key="test-key"``, ``user_id="test-user"``,
        ``team_id="test-team"`` and fresh empty ``team_models`` / ``models``
        lists on every call.
    """
    # Mock(**kwargs) configures each keyword as an attribute on the stub.
    return Mock(
        api_key="test-key",
        user_id="test-user",
        team_id="test-team",
        team_models=[],
        models=[],
    )
|
|
|
|
|
|
def create_mock_router_with_fallbacks():
    """Build a ``Mock`` router preloaded with fallback configuration.

    Returns:
        Mock: exposes ``fallbacks``, ``context_window_fallbacks`` and
        ``content_policy_fallbacks`` as lists of single-entry dicts.
        ``get_model_names()`` returns each primary model followed by its
        general fallbacks. ``get_model_access_groups()`` returns ``{}``.
    """
    general_fallbacks = {
        "claude-4-sonnet": ["bedrock-claude-sonnet-4", "google-claude-sonnet-4"],
        "gpt-4": ["gpt-4-turbo", "gpt-3.5-turbo"],
    }

    stub_router = Mock()
    stub_router.fallbacks = [
        {primary: targets} for primary, targets in general_fallbacks.items()
    ]
    stub_router.context_window_fallbacks = [
        {"claude-4-sonnet": ["claude-3-sonnet"]},
        {"gpt-4": ["gpt-3.5-turbo"]},
    ]
    stub_router.content_policy_fallbacks = [{"claude-4-sonnet": ["claude-3-haiku"]}]

    # Flatten to: primary, its fallbacks, next primary, its fallbacks, ...
    stub_router.get_model_names.return_value = [
        name
        for primary, targets in general_fallbacks.items()
        for name in (primary, *targets)
    ]
    stub_router.get_model_access_groups.return_value = {}
    return stub_router
|
|
|
|
|
|
def test_model_list_function_signature():
    """Verify ``model_list`` accepts ``include_metadata`` and ``fallback_type`` with the expected defaults."""
    import inspect

    from litellm.proxy.proxy_server import model_list

    parameters = inspect.signature(model_list).parameters

    # Both new parameters must be part of the endpoint signature.
    assert "include_metadata" in parameters, "include_metadata parameter missing"
    assert "fallback_type" in parameters, "fallback_type parameter missing"

    # Metadata is opt-in; no fallback type is forced by default.
    assert (
        parameters["include_metadata"].default is False
    ), "include_metadata should default to False"
    assert (
        parameters["fallback_type"].default is None
    ), "fallback_type should default to None"
|
|
|
|
|
|
@patch("litellm.proxy.proxy_server.llm_router")
@patch("litellm.proxy.proxy_server.get_complete_model_list")
@patch("litellm.proxy.proxy_server.get_key_models")
@patch("litellm.proxy.proxy_server.get_team_models")
@patch("litellm.proxy.proxy_server.get_all_fallbacks")
def test_model_list_with_fallback_metadata(
    mock_get_all_fallbacks,
    mock_get_team_models,
    mock_get_key_models,
    mock_get_complete_model_list,
    mock_router,  # patched module global; not used directly by this test
):
    """Test model_list function with fallback metadata.

    This is a simplified test. It does not await the async ``model_list``
    endpoint; it builds the ``{"data": [...], "object": "list"}`` response by
    hand and checks that every model gets ``metadata["fallbacks"]``.

    Fallback lookups go through the patched ``get_all_fallbacks`` mock rather
    than calling the side-effect function directly. This lets the test also
    verify how many lookups were made and with which arguments.
    """
    # The configured router is passed explicitly to the fallback lookup.
    # Setting ``return_value`` on the patched global would not affect
    # attribute access on it (presumably the router is used as an object,
    # not called).
    mock_router_instance = create_mock_router_with_fallbacks()

    mock_get_key_models.return_value = []
    mock_get_team_models.return_value = []
    mock_get_complete_model_list.return_value = [
        "claude-4-sonnet",
        "bedrock-claude-sonnet-4",
    ]

    # Mock fallback responses: only claude-4-sonnet has fallbacks configured.
    def fallback_side_effect(model, llm_router, fallback_type):
        if model == "claude-4-sonnet":
            return ["bedrock-claude-sonnet-4", "google-claude-sonnet-4"]
        return []

    mock_get_all_fallbacks.side_effect = fallback_side_effect

    try:
        from litellm.proxy.proxy_server import DEFAULT_MODEL_CREATED_AT_TIME
    except ImportError:
        DEFAULT_MODEL_CREATED_AT_TIME = 1640995200  # Default fallback

    # Use the same model list the patched helper was configured with.
    all_models = mock_get_complete_model_list.return_value
    include_metadata = True
    fallback_type = None  # Should default to "general"
    valid_fallback_types = ["general", "context_window", "content_policy"]

    model_data = []
    for model in all_models:
        model_info = {
            "id": model,
            "object": "model",
            "created": DEFAULT_MODEL_CREATED_AT_TIME,
            "owned_by": "openai",
        }

        if include_metadata:
            effective_fallback_type = (
                fallback_type if fallback_type is not None else "general"
            )
            assert effective_fallback_type in valid_fallback_types

            # Call the mock (which delegates to its side_effect) so the
            # lookup is recorded in call_args_list for the checks below.
            fallbacks = mock_get_all_fallbacks(
                model=model,
                llm_router=mock_router_instance,
                fallback_type=effective_fallback_type,
            )
            model_info["metadata"] = {"fallbacks": fallbacks}

        model_data.append(model_info)

    response = {
        "data": model_data,
        "object": "list",
    }

    # Verify response structure
    assert "data" in response
    assert "object" in response
    assert response["object"] == "list"

    # claude-4-sonnet carries its two configured fallbacks.
    claude_model = next(
        (m for m in response["data"] if m["id"] == "claude-4-sonnet"), None
    )
    assert claude_model is not None
    assert "metadata" in claude_model
    assert "fallbacks" in claude_model["metadata"]
    assert claude_model["metadata"]["fallbacks"] == [
        "bedrock-claude-sonnet-4",
        "google-claude-sonnet-4",
    ]

    # bedrock-claude-sonnet-4 has metadata but no fallbacks.
    bedrock_model = next(
        (m for m in response["data"] if m["id"] == "bedrock-claude-sonnet-4"), None
    )
    assert bedrock_model is not None
    assert "metadata" in bedrock_model
    assert "fallbacks" in bedrock_model["metadata"]
    assert bedrock_model["metadata"]["fallbacks"] == []

    # One lookup per model, always with the default "general" fallback type
    # and the configured router.
    assert mock_get_all_fallbacks.call_count == len(all_models)
    for recorded_call in mock_get_all_fallbacks.call_args_list:
        assert recorded_call.kwargs["fallback_type"] == "general"
        assert recorded_call.kwargs["llm_router"] is mock_router_instance
|
|
|
|
|
|
def test_model_list_invalid_fallback_type_validation():
    """Test that an invalid ``fallback_type`` is rejected with a 400 error.

    The ``HTTPException`` part requires FastAPI. When FastAPI is not
    installed, the test is reported as skipped rather than silently passing.
    """
    import pytest

    valid_fallback_types = ["general", "context_window", "content_policy"]

    # An unknown value must not be among the accepted types.
    invalid_type = "invalid"
    assert invalid_type not in valid_fallback_types

    # Skip explicitly instead of swallowing ImportError. Otherwise a missing
    # dependency makes the test look green without checking anything.
    pytest.importorskip("fastapi")
    from fastapi import HTTPException

    # This is the logic from our endpoint for a bad fallback_type.
    error = HTTPException(
        status_code=400,
        detail=f"Invalid fallback_type. Must be one of: {valid_fallback_types}",
    )
    assert error.status_code == 400
    for expected_fragment in (
        "Invalid fallback_type",
        "general",
        "context_window",
        "content_policy",
    ):
        assert expected_fragment in error.detail
|
|
|
|
|
|
def test_fallback_type_defaults_to_general():
    """Check that a missing ``fallback_type`` resolves to ``"general"`` while explicit values pass through unchanged."""

    def resolve(requested):
        # The defaulting rule under test: None means "general".
        return "general" if requested is None else requested

    include_metadata = True
    if include_metadata:
        assert resolve(None) == "general"

    # Explicit values, including "general" itself, are kept as given.
    for requested, expected in (
        ("general", "general"),
        ("context_window", "context_window"),
    ):
        assert resolve(requested) == expected
|
|
|
|
|
|
def test_response_structure_compatibility():
    """Check that adding ``metadata`` keeps model entries and the list envelope OpenAI-compatible."""
    openai_keys = ("id", "object", "created", "owned_by")

    # A plain OpenAI-style model entry without metadata.
    plain_entry = dict(
        id="claude-4-sonnet",
        object="model",
        created=1640995200,
        owned_by="openai",
    )
    for key in openai_keys:
        assert key in plain_entry, f"Required OpenAI key '{key}' missing"

    # The same entry enriched with fallback metadata.
    fallback_names = ["bedrock-claude-sonnet-4", "google-claude-sonnet-4"]
    enriched_entry = {**plain_entry, "metadata": {"fallbacks": fallback_names}}
    for key in openai_keys:
        assert (
            key in enriched_entry
        ), f"Required OpenAI key '{key}' missing from metadata model"

    assert "metadata" in enriched_entry
    assert "fallbacks" in enriched_entry["metadata"]
    assert isinstance(enriched_entry["metadata"]["fallbacks"], list)

    # The list envelope holding both entries.
    envelope = {"data": [plain_entry, enriched_entry], "object": "list"}
    assert "data" in envelope
    assert "object" in envelope
    assert envelope["object"] == "list"
    entries = envelope["data"]
    assert isinstance(entries, list)
    assert len(entries) == 2
|
|
|
|
|
|
def test_get_all_fallbacks_integration():
    """Check that ``get_all_fallbacks`` is importable and keeps its expected parameters and defaults."""
    import inspect

    from litellm.proxy.auth.model_checks import get_all_fallbacks

    signature_params = inspect.signature(get_all_fallbacks).parameters

    # Parameter names and their order are part of the contract.
    expected_params = ["model", "llm_router", "fallback_type"]
    actual_params = list(signature_params)
    assert (
        actual_params == expected_params
    ), f"Expected {expected_params}, got {actual_params}"

    # Defaults: general fallbacks, and no router unless one is supplied.
    assert (
        signature_params["fallback_type"].default == "general"
    ), "fallback_type should default to 'general'"
    assert (
        signature_params["llm_router"].default is None
    ), "llm_router should default to None"
|