# Source: litellm/tests/proxy_unit_tests/test_models_fallback_endpoint.py
# Timestamp: 2026-04-17 13:02:59 -07:00
# Size: 292 lines, 9.7 KiB (Python)

import pytest
from unittest.mock import Mock, patch
def create_mock_user_api_key_auth() -> Mock:
    """Create mock user API key authentication.

    Returns:
        A fresh ``Mock`` carrying fixed ``api_key``, ``user_id`` and
        ``team_id`` strings plus empty ``team_models`` and ``models`` lists.
        The lists are created on every call, so callers never share them.
    """
    mock_auth = Mock()
    mock_auth.api_key = "test-key"
    mock_auth.user_id = "test-user"
    mock_auth.team_id = "test-team"
    mock_auth.team_models = []
    mock_auth.models = []
    return mock_auth
def create_mock_router_with_fallbacks() -> Mock:
    """Create a mock router with fallback configurations.

    Returns:
        A ``Mock`` whose ``fallbacks``, ``context_window_fallbacks`` and
        ``content_policy_fallbacks`` attributes are lists of single-key
        ``{model: [fallback, ...]}`` dictionaries, whose ``get_model_names()``
        returns six model names and whose ``get_model_access_groups()``
        returns an empty dict.
    """
    router = Mock()
    router.fallbacks = [
        {"claude-4-sonnet": ["bedrock-claude-sonnet-4", "google-claude-sonnet-4"]},
        {"gpt-4": ["gpt-4-turbo", "gpt-3.5-turbo"]},
    ]
    router.context_window_fallbacks = [
        {"claude-4-sonnet": ["claude-3-sonnet"]},
        {"gpt-4": ["gpt-3.5-turbo"]},
    ]
    router.content_policy_fallbacks = [{"claude-4-sonnet": ["claude-3-haiku"]}]
    router.get_model_names.return_value = [
        "claude-4-sonnet",
        "bedrock-claude-sonnet-4",
        "google-claude-sonnet-4",
        "gpt-4",
        "gpt-4-turbo",
        "gpt-3.5-turbo",
    ]
    router.get_model_access_groups.return_value = {}
    return router
def test_model_list_function_signature():
    """Test that model_list function has the correct signature with new parameters.

    Inspects ``litellm.proxy.proxy_server.model_list`` and checks that it
    accepts ``include_metadata`` (default ``False``) and ``fallback_type``
    (default ``None``). The function itself is never called.
    """
    from litellm.proxy.proxy_server import model_list
    import inspect

    parameters = inspect.signature(model_list).parameters

    # Check that our new parameters are present
    assert "include_metadata" in parameters, "include_metadata parameter missing"
    assert "fallback_type" in parameters, "fallback_type parameter missing"

    # Check parameter defaults
    assert (
        parameters["include_metadata"].default is False
    ), "include_metadata should default to False"
    assert (
        parameters["fallback_type"].default is None
    ), "fallback_type should default to None"
@patch("litellm.proxy.proxy_server.llm_router")
@patch("litellm.proxy.proxy_server.get_complete_model_list")
@patch("litellm.proxy.proxy_server.get_key_models")
@patch("litellm.proxy.proxy_server.get_team_models")
@patch("litellm.proxy.proxy_server.get_all_fallbacks")
def test_model_list_with_fallback_metadata(
    mock_get_all_fallbacks,
    mock_get_team_models,
    mock_get_key_models,
    mock_get_complete_model_list,
    mock_router,
):
    """Test the fallback-metadata response shape produced for ``include_metadata``.

    This is a simplified test: it does not invoke the async endpoint. It
    rebuilds the response dictionary by hand, resolving fallbacks through the
    patched ``get_all_fallbacks`` mock, and then checks the structure.

    Args:
        mock_get_all_fallbacks: Patch for ``proxy_server.get_all_fallbacks``
            (bottom-most decorator, therefore the first argument).
        mock_get_team_models: Patch for ``proxy_server.get_team_models``.
        mock_get_key_models: Patch for ``proxy_server.get_key_models``.
        mock_get_complete_model_list: Patch for
            ``proxy_server.get_complete_model_list``.
        mock_router: Patch for ``proxy_server.llm_router``.
    """
    # Setup mocks
    mock_router_instance = create_mock_router_with_fallbacks()
    # NOTE(review): nothing in this test reads ``mock_router.return_value``;
    # confirm whether configuring ``mock_router`` itself was intended.
    mock_router.return_value = mock_router_instance
    mock_get_key_models.return_value = []
    mock_get_team_models.return_value = []

    all_models = ["claude-4-sonnet", "bedrock-claude-sonnet-4"]
    mock_get_complete_model_list.return_value = list(all_models)

    # Mock fallback responses
    def fallback_side_effect(model, llm_router, fallback_type):
        if model == "claude-4-sonnet":
            return ["bedrock-claude-sonnet-4", "google-claude-sonnet-4"]
        return []

    mock_get_all_fallbacks.side_effect = fallback_side_effect

    # Import the constants we need
    try:
        from litellm.proxy.proxy_server import DEFAULT_MODEL_CREATED_AT_TIME
    except ImportError:
        DEFAULT_MODEL_CREATED_AT_TIME = 1640995200  # Default fallback

    # Test with include_metadata=True (should default to general fallbacks).
    # These settings are the same for every model, so they live outside the loop.
    include_metadata = True
    fallback_type = None  # Should default to "general"
    valid_fallback_types = ["general", "context_window", "content_policy"]

    # Build response manually to test our logic
    model_data = []
    for model in all_models:
        model_info = {
            "id": model,
            "object": "model",
            "created": DEFAULT_MODEL_CREATED_AT_TIME,
            "owned_by": "openai",
        }
        if include_metadata:
            effective_fallback_type = (
                fallback_type if fallback_type is not None else "general"
            )
            # Validate fallback_type
            assert effective_fallback_type in valid_fallback_types
            # Go through the patched mock (not the bare helper) so that the
            # side-effect wiring and the call arguments are checked below.
            fallbacks = mock_get_all_fallbacks(
                model, mock_router_instance, effective_fallback_type
            )
            model_info["metadata"] = {"fallbacks": fallbacks}
        model_data.append(model_info)

    response = {
        "data": model_data,
        "object": "list",
    }

    # The fallback lookup must have happened once per model, using "general".
    assert mock_get_all_fallbacks.call_count == len(all_models)
    for model in all_models:
        mock_get_all_fallbacks.assert_any_call(
            model, mock_router_instance, "general"
        )

    # Verify response structure
    assert "data" in response
    assert "object" in response
    assert response["object"] == "list"

    # Find claude-4-sonnet in response
    claude_model = next(
        (m for m in response["data"] if m["id"] == "claude-4-sonnet"), None
    )
    assert claude_model is not None
    assert "metadata" in claude_model
    assert "fallbacks" in claude_model["metadata"]
    assert claude_model["metadata"]["fallbacks"] == [
        "bedrock-claude-sonnet-4",
        "google-claude-sonnet-4",
    ]

    # Find bedrock-claude-sonnet-4 in response (should have no fallbacks)
    bedrock_model = next(
        (m for m in response["data"] if m["id"] == "bedrock-claude-sonnet-4"), None
    )
    assert bedrock_model is not None
    assert "metadata" in bedrock_model
    assert "fallbacks" in bedrock_model["metadata"]
    assert bedrock_model["metadata"]["fallbacks"] == []
def test_model_list_invalid_fallback_type_validation():
    """Test the validation logic applied to an invalid ``fallback_type``.

    Checks membership against the list of valid fallback types and, when
    FastAPI is importable, that an ``HTTPException`` built with status 400
    carries a detail message naming every valid type. The endpoint itself is
    not called.
    """
    valid_fallback_types = ["general", "context_window", "content_policy"]

    # Valid types should pass
    for valid_type in valid_fallback_types:
        assert valid_type in valid_fallback_types

    # Invalid type should fail validation
    invalid_type = "invalid"
    assert invalid_type not in valid_fallback_types

    # Only the import is guarded, so an ImportError handler can never mask a
    # failure in the assertions that follow.
    try:
        from fastapi import HTTPException
    except ImportError:
        # FastAPI not available, skip the HTTPException part
        return

    # Test HTTPException creation logic (mirrors the endpoint's error)
    error = HTTPException(
        status_code=400,
        detail=f"Invalid fallback_type. Must be one of: {valid_fallback_types}",
    )
    assert error.status_code == 400
    assert "Invalid fallback_type" in error.detail
    assert "general" in error.detail
    assert "context_window" in error.detail
    assert "content_policy" in error.detail
def test_fallback_type_defaults_to_general():
    """Test that fallback_type defaults to 'general' when include_metadata=True.

    Exercises the defaulting expression only: ``None`` becomes ``"general"``
    and any explicit value is kept as given.
    """
    # (requested fallback_type, expected effective fallback_type)
    cases = [
        (None, "general"),
        ("general", "general"),
        ("context_window", "context_window"),
        ("content_policy", "content_policy"),
    ]
    for fallback_type, expected in cases:
        effective_fallback_type = (
            fallback_type if fallback_type is not None else "general"
        )
        assert (
            effective_fallback_type == expected
        ), f"fallback_type={fallback_type!r} should resolve to {expected!r}"
def test_response_structure_compatibility():
    """Test that response structure maintains OpenAI compatibility.

    Works on literal dictionaries built inside the test: a plain model entry,
    the same entry extended with ``metadata.fallbacks``, and the list
    envelope wrapping both. The endpoint is not called.
    """
    required_keys = ["id", "object", "created", "owned_by"]

    # Test basic model structure (without metadata)
    basic_model = {
        "id": "claude-4-sonnet",
        "object": "model",
        "created": 1640995200,
        "owned_by": "openai",
    }
    for key in required_keys:
        assert key in basic_model, f"Required OpenAI key '{key}' missing"

    # Test model with metadata
    metadata_model = {
        **basic_model,
        "metadata": {
            "fallbacks": ["bedrock-claude-sonnet-4", "google-claude-sonnet-4"]
        },
    }

    # Adding metadata must not remove any of the required keys
    for key in required_keys:
        assert (
            key in metadata_model
        ), f"Required OpenAI key '{key}' missing from metadata model"

    # Should have metadata
    assert "metadata" in metadata_model
    assert "fallbacks" in metadata_model["metadata"]
    assert isinstance(metadata_model["metadata"]["fallbacks"], list)

    # Test complete response structure
    response = {"data": [basic_model, metadata_model], "object": "list"}
    assert "data" in response
    assert "object" in response
    assert response["object"] == "list"
    assert isinstance(response["data"], list)
    assert len(response["data"]) == 2
def test_get_all_fallbacks_integration():
    """Test that get_all_fallbacks function can be imported and has correct signature.

    Expects exactly the parameters ``model``, ``llm_router`` and
    ``fallback_type`` in that order, with ``llm_router`` defaulting to
    ``None`` and ``fallback_type`` defaulting to ``"general"``. The function
    is inspected, not called.
    """
    from litellm.proxy.auth.model_checks import get_all_fallbacks
    import inspect

    # Test function signature
    parameters = inspect.signature(get_all_fallbacks).parameters
    params = list(parameters)
    expected_params = ["model", "llm_router", "fallback_type"]
    assert params == expected_params, f"Expected {expected_params}, got {params}"

    # Test default parameter values
    assert (
        parameters["fallback_type"].default == "general"
    ), "fallback_type should default to 'general'"
    assert (
        parameters["llm_router"].default is None
    ), "llm_router should default to None"