# Mirror of https://github.com/tiennm99/litellm.git
# Synced 2026-06-18 03:31:23 +00:00 (2850 lines, 95 KiB, Python).
import copy
|
||
import json
|
||
import os
|
||
import sys
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
from fastapi.testclient import TestClient
|
||
|
||
sys.path.insert(
|
||
0, os.path.abspath("../../..")
|
||
) # Adds the parent directory to the system path
|
||
|
||
|
||
import litellm
|
||
from litellm.router_utils.fallback_event_handlers import run_async_fallback
|
||
|
||
|
||
def test_update_kwargs_does_not_mutate_defaults_and_merges_metadata():
    """
    Verify ``Router._update_kwargs_with_default_litellm_params``.

    The helper must copy the router's ``default_litellm_params`` into the
    request kwargs without mutating the router-level defaults, and must place
    the default ``metadata`` dict under the caller-supplied
    ``metadata_variable_name`` key (``"litellm_metadata"`` here).
    """
    # Initialize a real Router. The Azure env vars may be unset (None) because
    # this test never sends a request; it only calls the kwargs helper.
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "azure/gpt-4.1-mini",
                    "api_key": os.getenv("AZURE_AI_API_KEY"),
                    "api_version": os.getenv("AZURE_API_VERSION"),
                    "api_base": os.getenv("AZURE_AI_API_BASE"),
                },
            }
        ],
    )

    # Override with known defaults so the assertions are deterministic.
    router.default_litellm_params = {
        "foo": "bar",
        "metadata": {"baz": 123},
    }
    original = copy.deepcopy(router.default_litellm_params)
    kwargs: dict = {}

    # Invoke the helper under test.
    router._update_kwargs_with_default_litellm_params(
        kwargs=kwargs,
        metadata_variable_name="litellm_metadata",
    )

    # 1) The router-level defaults must be left untouched.
    assert router.default_litellm_params == original

    # 2) Non-metadata keys are merged into kwargs.
    assert kwargs["foo"] == "bar"

    # 3) Default metadata lands under the requested metadata_variable_name
    #    key ("litellm_metadata"), not under a plain "metadata" key.
    assert kwargs["litellm_metadata"] == {"baz": 123}
|
||
|
||
|
||
def test_router_with_model_info_and_model_group():
    """
    Edge case: the deployment's ``model_info`` already carries a
    ``model_group`` key.

    Smoke test only. ``_set_model_group_info`` must not raise for this
    configuration; its return value is not inspected.
    """
    deployment_with_group = {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "gpt-3.5-turbo"},
        "model_info": {
            "tpm": 1000,
            "rpm": 1000,
            "model_group": "gpt-3.5-turbo",
        },
    }
    router = litellm.Router(model_list=[deployment_with_group])

    group_name = "gpt-3.5-turbo"
    router._set_model_group_info(
        model_group=group_name,
        user_facing_model_group_name=group_name,
    )
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_arouter_with_tags_and_fallbacks():
    """
    With tag filtering enabled, falling back to a deployment that lacks the
    requested tag must not succeed, so the call is expected to raise.
    """
    from litellm import Router

    tagged_primary = {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo",
            "mock_response": "Hello, world!",
            "tags": ["test"],
        },
    }
    untagged_fallback = {
        "model_name": "anthropic-claude-3-5-sonnet",
        "litellm_params": {
            "model": "claude-sonnet-4-5-20250929",
            "mock_response": "Hello, world 2!",
        },
    }
    router = Router(
        model_list=[tagged_primary, untagged_fallback],
        fallbacks=[
            {"gpt-3.5-turbo": ["anthropic-claude-3-5-sonnet"]},
        ],
        enable_tag_filtering=True,
    )

    with pytest.raises(Exception):
        # mock_testing_fallbacks=True is passed to exercise the fallback path
        # (assumption: it simulates a failure of the primary deployment).
        await router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello, world!"}],
            mock_testing_fallbacks=True,
            metadata={"tags": ["test"]},
        )
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_async_router_acreate_file():
    """
    router.acreate_file should create the file on every deployment of the
    requested model group: two deployments here, so two underlying calls.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            },
            {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-4o-mini"}},
        ],
    )

    with patch("litellm.acreate_file", return_value=MagicMock()) as mock_acreate_file:
        await router.acreate_file(
            model="gpt-3.5-turbo",
            purpose="test",
            file=MagicMock(),
        )

    # One litellm.acreate_file call per deployment in the group.
    assert mock_acreate_file.call_count == 2
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_async_router_acreate_file_with_jsonl():
    """
    router.acreate_file with JSONL and non-JSONL uploads.

    - JSONL (batch) file: the ``body.model`` field is rewritten to the
      deployment's model name.
    - Non-JSONL file: the content is forwarded unchanged.
    """
    from io import BytesIO

    def _named_buffer(text, filename):
        # In-memory upload with a .name, as a real file object would have.
        buffer = BytesIO(text.encode("utf-8"))
        buffer.name = filename
        return buffer

    def _first_call_text(mock):
        # Decoded content of the `file` kwarg from the first recorded call.
        return mock.call_args_list[0][1]["file"].read().decode("utf-8")

    jsonl_rows = [
        {
            "body": {
                "model": "gpt-3.5-turbo-router",
                "messages": [{"role": "user", "content": "test"}],
            }
        },
        {
            "body": {
                "model": "gpt-3.5-turbo-router",
                "messages": [{"role": "user", "content": "test2"}],
            }
        },
    ]
    jsonl_file = _named_buffer(
        "\n".join(json.dumps(row) for row in jsonl_rows), "test.jsonl"
    )

    non_jsonl_content = "This is not a JSONL file"
    non_jsonl_file = _named_buffer(non_jsonl_content, "test.txt")

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo-router",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            },
            {
                "model_name": "gpt-3.5-turbo-router",
                "litellm_params": {"model": "gpt-4o-mini"},
            },
        ],
    )

    with patch("litellm.acreate_file", return_value=MagicMock()) as mock_acreate_file:
        # --- JSONL upload: model name rewritten per deployment ---
        await router.acreate_file(
            model="gpt-3.5-turbo-router",
            purpose="batch",
            file=jsonl_file,
        )

        print(f"mock_acreate_file.call_count: {mock_acreate_file.call_count}")
        print(f"mock_acreate_file.call_args_list: {mock_acreate_file.call_args_list}")
        assert mock_acreate_file.call_count == 2  # one call per deployment

        first_row = json.loads(_first_call_text(mock_acreate_file).split("\n")[0])
        assert first_row["body"]["model"] == "gpt-3.5-turbo"

        mock_acreate_file.reset_mock()

        # --- Non-JSONL upload: content forwarded untouched ---
        await router.acreate_file(
            model="gpt-3.5-turbo-router",
            purpose="user_data",
            file=non_jsonl_file,
        )

        assert mock_acreate_file.call_count == 2
        assert _first_call_text(mock_acreate_file) == non_jsonl_content
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_arouter_async_get_healthy_deployments():
    """
    Test that router.async_get_healthy_deployments returns the single
    deployment registered for the "gpt-3.5-turbo" model group.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            },
        ],
    )

    result = await router.async_get_healthy_deployments(
        model="gpt-3.5-turbo",
        request_kwargs={},
        messages=None,
        input=None,
        specific_deployment=False,
        parent_otel_span=None,
    )

    assert len(result) == 1
    assert result[0]["model_name"] == "gpt-3.5-turbo"
    assert result[0]["litellm_params"]["model"] == "gpt-3.5-turbo"
|
||
|
||
|
||
@pytest.mark.asyncio
@patch("litellm.amoderation")
async def test_arouter_amoderation_with_credential_name(mock_amoderation):
    """
    router.amoderation must forward the deployment's
    ``litellm_credential_name`` and model to litellm.amoderation.
    """
    mock_amoderation.return_value = AsyncMock()

    moderation_deployment = {
        "model_name": "text-moderation-stable",
        "litellm_params": {
            "model": "text-moderation-stable",
            "litellm_credential_name": "my-custom-auth",
        },
    }
    router = litellm.Router(model_list=[moderation_deployment])

    await router.amoderation(input="I love everyone!", model="text-moderation-stable")

    mock_amoderation.assert_called_once()
    forwarded_kwargs = mock_amoderation.call_args[1]  # keyword args of the single call
    print(
        "call kwargs for router.amoderation=",
        json.dumps(forwarded_kwargs, indent=4, default=str),
    )
    assert forwarded_kwargs["litellm_credential_name"] == "my-custom-auth"
    assert forwarded_kwargs["model"] == "text-moderation-stable"
|
||
|
||
|
||
def test_arouter_test_team_model():
    """
    Test that router.map_team_model resolves a team's public model name
    (``model_info.team_public_model_name``) for the matching ``team_id`` to a
    non-None result.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
                "model_info": {
                    "team_id": "test-team",
                    "team_public_model_name": "test-model",
                },
            },
        ],
    )

    result = router.map_team_model(team_model_name="test-model", team_id="test-team")
    assert result is not None
|
||
|
||
|
||
def test_arouter_ignore_invalid_deployments():
    """
    With ``ignore_invalid_deployments=True``, an invalid deployment
    ("my-bad-model") is skipped rather than registered. This holds both when
    it is passed to the constructor and when it is later supplied to
    upsert_deployment, so the model list stays empty.
    """
    from litellm.types.router import Deployment

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "my-bad-model"},
            },
        ],
        ignore_invalid_deployments=True,
    )

    assert router.ignore_invalid_deployments is True
    assert router.get_model_list() == []

    # Upserting the same invalid deployment must be ignored as well.
    invalid_deployment = Deployment(
        model_name="gpt-3.5-turbo",
        litellm_params={"model": "my-bad-model"},  # type: ignore
        model_info={"tpm": 1000, "rpm": 1000},
    )
    router.upsert_deployment(invalid_deployment)

    assert router.get_model_list() == []
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_arouter_aretrieve_batch():
    """
    router.aretrieve_batch must forward the deployment's api_key and api_base
    to litellm.aretrieve_batch.
    """
    azure_params = {
        "model": "gpt-3.5-turbo",
        "custom_llm_provider": "azure",
        "api_key": "my-custom-key",
        "api_base": "my-custom-base",
    }
    router = litellm.Router(
        model_list=[{"model_name": "gpt-3.5-turbo", "litellm_params": azure_params}],
    )

    with patch.object(
        litellm, "aretrieve_batch", return_value=AsyncMock()
    ) as mock_aretrieve_batch:
        try:
            await router.aretrieve_batch(model="gpt-3.5-turbo")
        except Exception as e:
            # Handling of the mocked return value may fail; only the
            # arguments passed to litellm.aretrieve_batch matter here.
            print(f"Error: {e}")

        mock_aretrieve_batch.assert_called_once()

        forwarded_kwargs = mock_aretrieve_batch.call_args.kwargs
        print(forwarded_kwargs)
        assert forwarded_kwargs["api_key"] == "my-custom-key"
        assert forwarded_kwargs["api_base"] == "my-custom-base"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_arouter_aretrieve_file_content():
    """
    Test that router.afile_content forwards the deployment's api_key and
    api_base to litellm.afile_content.
    """
    with patch.object(
        litellm, "afile_content", return_value=AsyncMock()
    ) as mock_afile_content:
        router = litellm.Router(
            model_list=[
                {
                    "model_name": "gpt-3.5-turbo",
                    "litellm_params": {
                        "model": "gpt-3.5-turbo",
                        "custom_llm_provider": "azure",
                        "api_key": "my-custom-key",
                        "api_base": "my-custom-base",
                    },
                }
            ],
        )
        try:
            await router.afile_content(
                model="gpt-3.5-turbo",
                file_id="my-unique-file-id",
            )  # type: ignore
        except Exception as e:
            # Handling of the mocked return value may fail; only the
            # arguments passed to litellm.afile_content matter here.
            print(f"Error: {e}")

        mock_afile_content.assert_called_once()

        print(mock_afile_content.call_args.kwargs)
        assert mock_afile_content.call_args.kwargs["api_key"] == "my-custom-key"
        assert mock_afile_content.call_args.kwargs["api_base"] == "my-custom-base"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_arouter_filter_team_based_models():
    """
    Deployments tagged with ``model_info.team_id`` must only be routable for
    requests from that team.

    1. A request from the owning team ("test-team") succeeds.
    2. A request from another team ("test-team-2") raises, with
       "No deployments available" in the error message.
    3. After adding a deployment of the same model group that has no team_id,
       the request from "test-team-2" succeeds.
    """
    from litellm.types.router import Deployment

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
                "model_info": {
                    "team_id": "test-team",
                },
            },
        ],
    )

    # WORKS: the request comes from the team that owns the deployment.
    result = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello, world!"}],
        metadata={"user_api_key_team_id": "test-team"},
        mock_response="Hello, world!",
    )
    assert result is not None

    # FAILS: the only deployment belongs to a different team.
    with pytest.raises(Exception) as exc_info:
        await router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello, world!"}],
            metadata={"user_api_key_team_id": "test-team-2"},
            mock_response="Hello, world!",
        )
    # This check must sit outside the `with` block. Inside pytest.raises,
    # statements after the raising call never execute.
    assert "No deployments available" in str(exc_info.value)

    ## ADD A MODEL THAT IS NOT IN THE TEAM (no team_id, so not team-restricted)
    router.add_deployment(
        Deployment(
            model_name="gpt-3.5-turbo",
            litellm_params={"model": "gpt-3.5-turbo"},  # type: ignore
            model_info={"tpm": 1000, "rpm": 1000},
        )
    )

    result = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello, world!"}],
        metadata={"user_api_key_team_id": "test-team-2"},
        mock_response="Hello, world!",
    )
    assert result is not None
|
||
|
||
|
||
def test_arouter_should_include_deployment():
    """
    Test Router.should_include_deployment for team and non-team deployments.

    Behaviour pinned by the cases below:

    1. A team deployment matches when ``team_id`` equals its ``team_id`` AND
       ``model_name`` equals its ``team_public_model_name`` (case 1). A
       mismatch in either returns False (cases 2 and 3).
    2. A deployment also matches when ``model_name`` equals its own
       ``model_name``. This is exercised with a matching team_id (case 4),
       with ``team_id=None`` (cases 5 and 8), and with an unrelated team_id
       on a deployment that has no team (case 6).
    3. Otherwise the deployment is excluded (case 7).
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
                "model_info": {
                    "team_id": "test-team",
                },
            },
        ],
    )

    # Deployment dicts handed straight to should_include_deployment.
    deployment_with_team_and_public_name = {
        "model_name": "gpt-3.5-turbo",
        "model_info": {
            "team_id": "test-team",
            "team_public_model_name": "team-gpt-model",
        },
    }
    deployment_with_team_no_public_name = {
        "model_name": "gpt-3.5-turbo",
        "model_info": {
            "team_id": "test-team",
        },
    }
    deployment_without_team = {
        "model_name": "gpt-4",
        "model_info": {},
    }

    # (case id, requested model_name, deployment, team_id, expected, reason)
    cases = [
        (
            1,
            "team-gpt-model",
            deployment_with_team_and_public_name,
            "test-team",
            True,
            "Should return True when team_id and team_public_model_name match",
        ),
        (
            2,
            "different-model",
            deployment_with_team_and_public_name,
            "test-team",
            False,
            "Should return False when team_id matches but model_name doesn't match team_public_model_name",
        ),
        (
            3,
            "team-gpt-model",
            deployment_with_team_and_public_name,
            "different-team",
            False,
            "Should return False when team_id doesn't match",
        ),
        (
            # Team deployment without a public name: matched by its own model_name.
            4,
            "gpt-3.5-turbo",
            deployment_with_team_no_public_name,
            "test-team",
            True,
            "Should return True when team deployment has no team_public_model_name to match",
        ),
        (
            5,
            "gpt-4",
            deployment_without_team,
            None,
            True,
            "Should return True when model_name matches and deployment has no team_id",
        ),
        (
            6,
            "gpt-4",
            deployment_without_team,
            "any-team",
            True,
            "Should return True when model_name matches non-team deployment, regardless of team_id param",
        ),
        (
            7,
            "different-model",
            deployment_without_team,
            None,
            False,
            "Should return False when model_name doesn't match",
        ),
        (
            # Team deployment, no team_id supplied, exact model_name match.
            8,
            "gpt-3.5-turbo",
            deployment_with_team_and_public_name,
            None,
            True,
            "Should return True when matching model with exact model_name",
        ),
    ]

    for case_id, model_name, deployment, team_id, expected, reason in cases:
        result = router.should_include_deployment(
            model_name=model_name,
            model=deployment,
            team_id=team_id,
        )
        # Case id in the message: the loop alone would hide which case failed.
        assert result is expected, f"Case {case_id}: {reason}"
|
||
|
||
|
||
def test_arouter_responses_api_bridge():
    """
    A deployment whose model is ``azure/responses/o_series/...`` is served
    through the Responses API bridge:

    1. router.completion dispatches to litellm.responses.
    2. The outgoing HTTP POST targets the deployment's
       ``/openai/v1/responses?api-version=preview`` URL. Its JSON payload
       carries the bare model name "webinterface-o3-pro" rather than the full
       "azure/responses/o_series/webinterface-o3-pro" string.
    """
    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    public_model_name = "[IP-approved] o3-pro"
    webhook_base = "https://webhook.site/fba79dae-220a-4bb7-9a3a-8caa49604e55"

    router = litellm.Router(
        model_list=[
            {
                "model_name": public_model_name,
                "litellm_params": {
                    "model": "azure/responses/o_series/webinterface-o3-pro",
                    "api_base": webhook_base,
                    "api_key": "sk-1234567890",
                    "api_version": "preview",
                    "stream": True,
                },
                "model_info": {
                    "input_cost_per_token": 0.00002,
                    "output_cost_per_token": 0.00008,
                },
            }
        ],
    )

    ## CONFIRM BRIDGE IS CALLED
    with patch.object(litellm, "responses", return_value=AsyncMock()) as responses_spy:
        router.completion(
            model=public_model_name,
            messages=[{"role": "user", "content": "Hello, world!"}],
        )
        assert responses_spy.call_count == 1

    ## CONFIRM MODEL NAME IS STRIPPED
    http_client = HTTPHandler()

    canned_response = MagicMock()
    canned_response.status_code = 200
    canned_response.headers = {"content-type": "application/json"}
    canned_response.json.return_value = {
        "id": "resp_test",
        "object": "response",
        "status": "completed",
        "output": [],
    }
    canned_response.text = (
        '{"id": "resp_test", "object": "response", "status": "completed", "output": []}'
    )

    with patch.object(http_client, "post", return_value=canned_response) as post_spy:
        try:
            router.completion(
                model=public_model_name,
                messages=[{"role": "user", "content": "Hello, world!"}],
                client=http_client,
                num_retries=0,
            )
        except Exception as e:
            # Parsing the canned response may fail; only the outgoing request
            # is inspected below.
            print(f"Error: {e}")

        assert post_spy.call_count == 1
        sent = post_spy.call_args.kwargs
        assert sent["url"] == f"{webhook_base}/openai/v1/responses?api-version=preview"
        assert sent["json"]["model"] == "webinterface-o3-pro"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_router_v1_messages_fallbacks():
    """
    Test that router.aanthropic_messages honours router fallbacks.

    The primary "claude-sonnet-4-5-20250929" deployment has mock_response set
    to "litellm.InternalServerError" (a simulated failure). The call is
    therefore expected to be served by the "bedrock-claude" fallback, whose
    mock text is asserted on the result.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "claude-sonnet-4-5-20250929",
                "litellm_params": {
                    "model": "anthropic/claude-sonnet-4-5-20250929",
                    "mock_response": "litellm.InternalServerError",
                },
            },
            {
                "model_name": "bedrock-claude",
                "litellm_params": {
                    "model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
                    "mock_response": "Hello, world I am a fallback!",
                },
            },
        ],
        fallbacks=[
            {"claude-sonnet-4-5-20250929": ["bedrock-claude"]},
        ],
    )

    result = await router.aanthropic_messages(
        model="claude-sonnet-4-5-20250929",
        messages=[{"role": "user", "content": "Hello, world!"}],
        max_tokens=256,
    )
    assert result is not None

    print(result)
    assert result["content"][0]["text"] == "Hello, world I am a fallback!"
|
||
|
||
|
||
def test_add_invalid_provider_to_router():
    """
    Adding a wildcard deployment with an unknown ``custom_llm_provider``
    ("vertex_ai_eu") must raise. It must also leave no pattern registered in
    the router's pattern_router.
    """
    from litellm.types.router import Deployment

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ],
    )

    # Both building and adding the deployment stay inside pytest.raises:
    # either step is allowed to be the one that raises.
    with pytest.raises(Exception):
        router.add_deployment(
            Deployment(
                model_name="vertex_ai/*",
                litellm_params={
                    "model": "vertex_ai/*",
                    "custom_llm_provider": "vertex_ai_eu",
                },
            )
        )

    # The rejected wildcard must not have been registered.
    assert router.pattern_router.patterns == {}
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_router_ageneric_api_call_with_fallbacks_helper():
    """
    Test Router._ageneric_api_call_with_fallbacks_helper.

    Scenarios:
    1. A deployment is selected and the generic function succeeds.
    2. Deployment selection fails with passthrough_on_no_deployment=True: the
       generic function is still called, with the original model name.
    3. Deployment selection fails with passthrough_on_no_deployment=False:
       the selection error propagates.
    4. _get_client returns an asyncio.Semaphore (rate limiting): the call
       still succeeds and _get_client is consulted once.
    5. The generic function raises: the error propagates and
       router.fail_calls for the model group is incremented by one.
    """
    import asyncio

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": "test-key",
                    "api_base": "https://api.openai.com/v1",
                },
                "model_info": {
                    "tpm": 1000,
                    "rpm": 1000,
                },
            },
        ],
    )

    def _selected_deployment():
        # Deployment returned by the patched async_get_available_deployment.
        # A fresh dict per scenario keeps the scenarios independent.
        return {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": "test-key",
                "api_base": "https://api.openai.com/v1",
            },
        }

    # Test 1: Successful call
    async def mock_generic_function(**kwargs):
        return {"result": "success", "model": kwargs.get("model")}

    with patch.object(router, "async_get_available_deployment") as mock_get_deployment:
        mock_get_deployment.return_value = _selected_deployment()
        with patch.object(router, "_update_kwargs_with_deployment") as mock_update_kwargs:
            with patch.object(
                router, "async_routing_strategy_pre_call_checks"
            ) as mock_pre_call_checks:
                with patch.object(router, "_get_client", return_value=None):
                    result = await router._ageneric_api_call_with_fallbacks_helper(
                        model="gpt-3.5-turbo",
                        original_generic_function=mock_generic_function,
                        messages=[{"role": "user", "content": "test"}],
                    )

                    assert result is not None
                    assert result["result"] == "success"
                    mock_get_deployment.assert_called_once()
                    mock_update_kwargs.assert_called_once()
                    mock_pre_call_checks.assert_called_once()

    # Test 2: Passthrough on no deployment (success case)
    async def mock_passthrough_function(**kwargs):
        return {"result": "passthrough", "model": kwargs.get("model")}

    with patch.object(router, "async_get_available_deployment") as mock_get_deployment:
        mock_get_deployment.side_effect = Exception("No deployment available")

        result = await router._ageneric_api_call_with_fallbacks_helper(
            model="gpt-3.5-turbo",
            original_generic_function=mock_passthrough_function,
            passthrough_on_no_deployment=True,
            messages=[{"role": "user", "content": "test"}],
        )

        assert result is not None
        assert result["result"] == "passthrough"
        assert result["model"] == "gpt-3.5-turbo"

    # Test 3: No deployment available and passthrough=False (should raise exception)
    with patch.object(router, "async_get_available_deployment") as mock_get_deployment:
        mock_get_deployment.side_effect = Exception("No deployment available")

        with pytest.raises(Exception) as exc_info:
            await router._ageneric_api_call_with_fallbacks_helper(
                model="gpt-3.5-turbo",
                original_generic_function=mock_generic_function,
                passthrough_on_no_deployment=False,
                messages=[{"role": "user", "content": "test"}],
            )

        assert "No deployment available" in str(exc_info.value)

    # Test 4: Test with semaphore (rate limiting)
    async def mock_semaphore_function(**kwargs):
        return {"result": "semaphore_success", "model": kwargs.get("model")}

    with patch.object(router, "async_get_available_deployment") as mock_get_deployment:
        mock_get_deployment.return_value = _selected_deployment()
        semaphore = asyncio.Semaphore(1)

        with patch.object(router, "_update_kwargs_with_deployment"):
            with patch.object(
                router, "_get_client", return_value=semaphore
            ) as mock_get_client:
                with patch.object(
                    router, "async_routing_strategy_pre_call_checks"
                ) as mock_pre_call_checks:
                    result = await router._ageneric_api_call_with_fallbacks_helper(
                        model="gpt-3.5-turbo",
                        original_generic_function=mock_semaphore_function,
                        messages=[{"role": "user", "content": "test"}],
                    )

                    assert result is not None
                    assert result["result"] == "semaphore_success"
                    mock_get_client.assert_called_once()
                    mock_pre_call_checks.assert_called_once()

    # Test 5: failure tracking - the error propagates and fail_calls is bumped.
    initial_fail_count = router.fail_calls.get("gpt-3.5-turbo", 0)

    async def mock_failing_function(**kwargs):
        raise Exception("Mock failure")

    with patch.object(router, "async_get_available_deployment") as mock_get_deployment:
        mock_get_deployment.return_value = _selected_deployment()

        with patch.object(router, "_update_kwargs_with_deployment"):
            with patch.object(router, "_get_client", return_value=None):
                with patch.object(router, "async_routing_strategy_pre_call_checks"):
                    with pytest.raises(Exception) as exc_info:
                        await router._ageneric_api_call_with_fallbacks_helper(
                            model="gpt-3.5-turbo",
                            original_generic_function=mock_failing_function,
                            messages=[{"role": "user", "content": "test"}],
                        )

                    assert "Mock failure" in str(exc_info.value)
                    # Check that fail_calls was incremented
                    assert router.fail_calls["gpt-3.5-turbo"] == initial_fail_count + 1
|
||
|
||
|
||
def test_router_get_model_access_groups_team_only_models():
    """
    A team-only deployment (``model_info.team_id`` set) exposes its access
    groups under its team_public_model_name only when queried with that
    team's id. With ``team_id=None`` no access groups are returned.
    """
    team_only_deployment = {
        "model_name": "my-custom-model-name",
        "litellm_params": {"model": "gpt-3.5-turbo"},
        "model_info": {
            "team_id": "team_1",
            "access_groups": ["default-models"],
            "team_public_model_name": "gpt-3.5-turbo",
        },
    }
    router = litellm.Router(model_list=[team_only_deployment])

    groups_without_team = router.get_model_access_groups(
        model_name="gpt-3.5-turbo", team_id=None
    )
    assert len(groups_without_team) == 0

    groups_for_team = router.get_model_access_groups(
        model_name="gpt-3.5-turbo", team_id="team_1"
    )
    assert list(groups_for_team.keys()) == ["default-models"]
|
||
|
||
|
||
def test_cached_get_model_group_info():
    """
    _cached_get_model_group_info returns the same object on repeated calls.
    It is recomputed after add_deployment, delete_deployment and
    set_model_list; the recomputed tpm reflects the current deployments.
    """
    from litellm.types.router import Deployment, LiteLLM_Params

    group = "gpt-4"
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4", "api_key": "fake"},
                "model_info": {"tpm": 1000, "rpm": 100},
            },
        ]
    )

    # Cold call: computed and cached.
    cold = router._cached_get_model_group_info(group)
    assert cold is not None
    assert cold.tpm == 1000

    # Warm call: the very same object comes back.
    warm = router._cached_get_model_group_info(group)
    assert cold is warm

    # add_deployment invalidates; tpm is summed across both deployments.
    router.add_deployment(
        Deployment(
            model_name="gpt-4",
            litellm_params=LiteLLM_Params(model="gpt-4", api_key="fake2"),
            model_info={"tpm": 2000, "rpm": 200},
        )
    )
    after_add = router._cached_get_model_group_info(group)
    assert after_add is not warm
    assert after_add is not None
    assert after_add.tpm == 3000  # 1000 + 2000

    # delete_deployment of the newest entry invalidates again.
    newest_id = router.model_list[-1]["model_info"]["id"]
    router.delete_deployment(id=newest_id)
    after_delete = router._cached_get_model_group_info(group)
    assert after_delete is not after_add
    assert after_delete is not None
    assert after_delete.tpm == 1000

    # set_model_list invalidates as well.
    router.set_model_list(
        [
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4", "api_key": "fake"},
                "model_info": {"tpm": 5000},
            },
        ]
    )
    after_reset = router._cached_get_model_group_info(group)
    assert after_reset is not after_delete
    assert after_reset is not None
    assert after_reset.tpm == 5000

    # The cache keeps working after an invalidation.
    assert after_reset is router._cached_get_model_group_info(group)
|
||
|
||
|
||
def test_get_model_access_groups_caching():
    """
    get_model_access_groups() with no arguments is cached. A call with
    arguments returns a different object, and add_deployment /
    delete_deployment invalidate the cached no-argument result.
    """
    from litellm.types.router import Deployment, LiteLLM_Params

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4"},
                "model_info": {"access_groups": ["premium"]},
            },
        ]
    )

    # The first no-arg call computes the mapping and caches it.
    cached = router.get_model_access_groups()
    assert "premium" in cached

    # Repeated no-arg calls hand back the identical cached object.
    assert router.get_model_access_groups() is cached

    # A call with arguments does not return the cached object.
    filtered = router.get_model_access_groups(model_name="gpt-4")
    assert filtered is not cached

    # add_deployment invalidates the cache.
    router.add_deployment(
        Deployment(
            model_name="gpt-3.5",
            litellm_params=LiteLLM_Params(model="gpt-3.5-turbo"),
            model_info={"access_groups": ["default"]},
        )
    )
    after_add = router.get_model_access_groups()
    assert after_add is not cached
    assert "premium" in after_add
    assert "default" in after_add

    # delete_deployment of the added entry invalidates it again.
    added_id = next(
        (
            deployment.get("model_info", {}).get("id")
            for deployment in router.model_list
            if deployment.get("model_name") == "gpt-3.5"
        ),
        None,
    )
    assert added_id is not None
    router.delete_deployment(id=added_id)

    after_delete = router.get_model_access_groups()
    assert after_delete is not after_add
    assert "default" not in after_delete
    assert "premium" in after_delete
|
||
|
||
|
||
def test_get_model_access_groups_cache_invalidation_set_model_list():
    """
    Replacing the whole model list via set_model_list() must discard the
    memoized access-group mapping.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4"},
                "model_info": {"access_groups": ["premium"]},
            },
        ]
    )

    before = router.get_model_access_groups()  # warms the cache
    assert "premium" in before

    replacement = [
        {
            "model_name": "claude-3",
            "litellm_params": {"model": "anthropic/claude-3-opus-20240229"},
            "model_info": {"access_groups": ["research"]},
        },
    ]
    router.set_model_list(replacement)

    after = router.get_model_access_groups()
    assert after is not before
    assert "research" in after
    assert "premium" not in after
|
||
|
||
|
||
def test_get_model_access_groups_cache_invalidation_upsert_deployment():
    """
    Upserting an existing deployment (same id, different params) must discard
    the memoized access-group mapping.
    """
    from litellm.types.router import Deployment, LiteLLM_Params

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4"},
                "model_info": {"access_groups": ["premium"]},
            },
        ]
    )

    cached = router.get_model_access_groups()  # warms the cache
    assert "premium" in cached

    # Reuse the id the router assigned; same id with different params
    # triggers pop + re-add inside upsert_deployment.
    original_id = router.model_list[0]["model_info"]["id"]
    replacement = Deployment(
        model_name="gpt-4-updated",
        litellm_params=LiteLLM_Params(model="gpt-4-turbo"),
        model_info={"id": original_id, "access_groups": ["updated-group"]},
    )
    router.upsert_deployment(replacement)

    refreshed = router.get_model_access_groups()
    assert refreshed is not cached
    assert "updated-group" in refreshed
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_acompletion_streaming_iterator():
    """Test _acompletion_streaming_iterator for normal streaming and fallback behavior."""
    from litellm.exceptions import MidStreamFallbackError

    class AsyncIterator:
        """
        Minimal async iterator over ``items``.

        If ``error`` is given, it is raised once ``error_after_index`` items
        have been yielded, simulating a stream that breaks mid-way.
        """

        def __init__(self, items, error=None, error_after_index=None):
            self.items = items
            self.index = 0
            self.error = error
            self.error_after_index = error_after_index
            # NOTE(review): mirrors the ``chunks`` attribute the original
            # error-path mock exposed; presumably read by the router when it
            # builds the fallback request — confirm.
            self.chunks = []

        def __aiter__(self):
            return self

        async def __anext__(self):
            if self.index >= len(self.items):
                raise StopAsyncIteration
            if self.error is not None and self.index == self.error_after_index:
                raise self.error
            item = self.items[self.index]
            self.index += 1
            return item

    def _as_primary_stream(stream):
        # Mimic the attributes of a CustomStreamWrapper response.
        stream.model = "gpt-4"
        stream.custom_llm_provider = "openai"
        stream.logging_obj = MagicMock()
        return stream

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4", "api_key": "fake-key-1"},
            },
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "fake-key-2"},
            },
        ],
        fallbacks=[{"gpt-4": ["gpt-3.5-turbo"]}],
        set_verbose=True,
    )

    messages = [{"role": "user", "content": "Hello"}]
    initial_kwargs = {"model": "gpt-4", "stream": True, "temperature": 0.7}

    # Test 1: successful streaming (no errors) yields the original chunks.
    mock_chunks = [
        MagicMock(choices=[MagicMock(delta=MagicMock(content="Hello"))]),
        MagicMock(choices=[MagicMock(delta=MagicMock(content=" there"))]),
        MagicMock(choices=[MagicMock(delta=MagicMock(content="!"))]),
    ]
    mock_response = _as_primary_stream(AsyncIterator(mock_chunks))

    result = await router._acompletion_streaming_iterator(
        model_response=mock_response, messages=messages, initial_kwargs=initial_kwargs
    )
    collected_chunks = [chunk async for chunk in result]

    assert len(collected_chunks) == 3
    assert all(chunk in mock_chunks for chunk in collected_chunks)

    # Test 2: a MidStreamFallbackError after the first chunk triggers a fallback.
    error = MidStreamFallbackError(
        message="Connection lost",
        model="gpt-4",
        llm_provider="openai",
        generated_content="Hello",
    )
    mock_error_response = _as_primary_stream(
        AsyncIterator(mock_chunks, error=error, error_after_index=1)
    )

    fallback_chunks = [
        MagicMock(choices=[MagicMock(delta=MagicMock(content=" world"))]),
        MagicMock(choices=[MagicMock(delta=MagicMock(content="!"))]),
    ]
    mock_fallback_response = AsyncIterator(fallback_chunks)

    with patch.object(
        router,
        "async_function_with_fallbacks_common_utils",
        return_value=mock_fallback_response,
    ) as mock_fallback_utils:
        result = await router._acompletion_streaming_iterator(
            model_response=mock_error_response,
            messages=messages,
            initial_kwargs=initial_kwargs,
        )
        collected_chunks = [chunk async for chunk in result]

    assert mock_fallback_utils.called
    call_args = mock_fallback_utils.call_args

    # Generated content is replayed to the fallback:
    # original user message + continuation system prompt + assistant prefix.
    modified_messages = call_args.kwargs["kwargs"]["messages"]
    assert len(modified_messages) == 3
    assert modified_messages[0] == {"role": "user", "content": "Hello"}
    assert modified_messages[1]["role"] == "system"
    assert "continuation" in modified_messages[1]["content"]
    assert modified_messages[2]["role"] == "assistant"
    assert modified_messages[2]["content"] == "Hello"
    assert modified_messages[2]["prefix"] is True

    assert call_args.kwargs["disable_fallbacks"] is False
    assert call_args.kwargs["model_group"] == "gpt-4"

    # One chunk from the primary stream plus both fallback chunks.
    assert len(collected_chunks) == 3
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_acompletion_streaming_iterator_edge_cases():
    """Test edge cases for _acompletion_streaming_iterator.

    Covers a stream that fails before yielding anything, with empty
    ``generated_content``: the fallback must still be invoked, with the
    original messages and no continuation prompt.
    """
    from litellm.exceptions import MidStreamFallbackError

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4", "api_key": "fake-key"},
            }
        ],
        set_verbose=True,
    )

    messages = [{"role": "user", "content": "Test"}]
    initial_kwargs = {"model": "gpt-4", "stream": True}

    empty_error = MidStreamFallbackError(
        message="Error",
        model="gpt-4",
        llm_provider="openai",
        generated_content="",  # nothing was streamed before the failure
    )

    class AsyncIteratorImmediateError:
        """Stream mock whose very first ``__anext__`` raises ``empty_error``."""

        def __init__(self):
            self.model = "gpt-4"
            self.custom_llm_provider = "openai"
            self.logging_obj = MagicMock()
            self.chunks = []

        def __aiter__(self):
            return self

        async def __anext__(self):
            raise empty_error

    class EmptyAsyncIterator:
        """Fallback stream that ends immediately."""

        def __aiter__(self):
            return self

        async def __anext__(self):
            raise StopAsyncIteration

    with patch.object(
        router,
        "async_function_with_fallbacks_common_utils",
        return_value=EmptyAsyncIterator(),
    ) as mock_fallback_utils:
        iterator = await router._acompletion_streaming_iterator(
            model_response=AsyncIteratorImmediateError(),
            messages=messages,
            initial_kwargs=initial_kwargs,
        )
        collected_chunks = [chunk async for chunk in iterator]

    # The fallback is still called even with empty generated content.
    assert mock_fallback_utils.called
    # Neither the failed primary stream nor the empty fallback produced chunks.
    assert collected_chunks == []

    # Empty content -> pre-first-chunk path uses original messages
    # (no continuation prompt added).
    fallback_kwargs = mock_fallback_utils.call_args.kwargs["kwargs"]
    assert fallback_kwargs["messages"] == messages
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_acompletion_streaming_iterator_preserves_hidden_params():
    """
    Regression test: the FallbackStreamWrapper returned by
    _acompletion_streaming_iterator must carry over ``_hidden_params`` from
    the original CustomStreamWrapper. Those params (e.g.
    ``litellm_overhead_time_ms``) feed the x-litellm-overhead-duration-ms
    header of streaming proxy responses.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4", "api_key": "fake-key"},
            }
        ],
    )

    # Stand-in for a CustomStreamWrapper whose timing metadata has already
    # been filled in by update_response_metadata.
    stream = MagicMock()
    stream.model = "gpt-4"
    stream.custom_llm_provider = "openai"
    stream.logging_obj = MagicMock()
    stream._hidden_params = {
        "litellm_overhead_time_ms": 12.34,
        "_response_ms": 500.0,
        "litellm_call_id": "test-call-id",
        "api_base": "https://api.openai.com",
        "additional_headers": {},
    }

    async def _no_chunks():
        # Async generator that finishes without yielding; only the hidden
        # params are under test here.
        return
        yield

    stream.__aiter__ = lambda self: _no_chunks().__aiter__()

    wrapped = await router._acompletion_streaming_iterator(
        model_response=stream,
        messages=[{"role": "user", "content": "hi"}],
        initial_kwargs={"model": "gpt-4", "stream": True},
    )

    assert hasattr(wrapped, "_hidden_params"), "result must have _hidden_params"
    copied = wrapped._hidden_params
    assert copied.get("litellm_overhead_time_ms") == 12.34, (
        "litellm_overhead_time_ms must be preserved — "
        "this is what drives x-litellm-overhead-duration-ms in streaming responses"
    )
    assert copied.get("litellm_call_id") == "test-call-id"
    assert copied.get("_response_ms") == 500.0
|
||
|
||
|
||
def test_completion_streaming_iterator_fallback_on_429():
    """Sync streaming: MidStreamFallbackError (429 pre-first-chunk) triggers fallback.

    This is the sync counterpart of test_acompletion_streaming_iterator.
    Before this fix, __next__ raised RateLimitError directly and the Router
    never got a chance to fall back.
    """
    from litellm.exceptions import MidStreamFallbackError

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4", "api_key": "fake-key"},
            }
        ],
    )

    messages = [{"role": "user", "content": "Test"}]
    initial_kwargs = {"model": "gpt-4", "stream": True}

    rate_limit_error = MidStreamFallbackError(
        message="Resource exhausted",
        model="gpt-4",
        llm_provider="vertex_ai",
        generated_content="",
        is_pre_first_chunk=True,
    )

    class SyncIteratorImmediateError:
        """Sync stream mock whose first ``__next__`` raises the 429 error."""

        def __init__(self):
            self.model = "gpt-4"
            self.custom_llm_provider = "openai"
            self.logging_obj = MagicMock()
            self.chunks = []

        def __iter__(self):
            return self

        def __next__(self):
            raise rate_limit_error

    # The fallback response iterates as empty (the fallback may not stream).
    fallback_response = MagicMock()
    fallback_response.__iter__ = MagicMock(return_value=iter([]))

    with patch.object(
        router, "function_with_fallbacks", return_value=fallback_response
    ) as mock_fallback:
        stream = router._completion_streaming_iterator(
            model_response=SyncIteratorImmediateError(),
            messages=messages,
            initial_kwargs=initial_kwargs,
        )
        # Drain the stream so the error, and hence the fallback, fires.
        for _ in stream:
            pass

    assert mock_fallback.called
    call = mock_fallback.call_args
    # Pre-first-chunk: original messages, no continuation prompt.
    assert call.kwargs.get("messages") == messages
    # The sync path must fall back through _completion.
    assert call.kwargs.get("original_function") == router._completion
|
||
|
||
|
||
def test_completion_streaming_iterator_preserves_hidden_params():
    """SyncFallbackStreamWrapper must copy _hidden_params from original response."""
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4", "api_key": "fake-key"},
            }
        ],
    )

    stream = MagicMock()
    stream.model = "gpt-4"
    stream.custom_llm_provider = "openai"
    stream.logging_obj = MagicMock()
    stream._hidden_params = {
        "litellm_overhead_time_ms": 42.0,
        "litellm_call_id": "test-sync-call",
    }
    # Iterating the mock yields nothing; only the hidden params matter here.
    stream.__iter__ = MagicMock(return_value=iter([]))

    wrapped = router._completion_streaming_iterator(
        model_response=stream,
        messages=[{"role": "user", "content": "hi"}],
        initial_kwargs={"model": "gpt-4", "stream": True},
    )

    assert hasattr(wrapped, "_hidden_params")
    hidden = wrapped._hidden_params
    assert hidden.get("litellm_overhead_time_ms") == 42.0
    assert hidden.get("litellm_call_id") == "test-sync-call"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_acompletion_streaming_iterator_pre_first_chunk_skips_continuation():
    """When MidStreamFallbackError has is_pre_first_chunk=True, use original messages."""
    from litellm.exceptions import MidStreamFallbackError

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4", "api_key": "fake-key"},
            }
        ],
    )

    messages = [{"role": "user", "content": "Hello"}]
    initial_kwargs = {"model": "gpt-4", "stream": True}

    pre_first_chunk_error = MidStreamFallbackError(
        message="429 Resource exhausted",
        model="gpt-4",
        llm_provider="vertex_ai",
        generated_content="",
        is_pre_first_chunk=True,
    )

    class FailsBeforeFirstChunk:
        """Async stream mock that raises on the very first ``__anext__``."""

        def __init__(self):
            self.model = "gpt-4"
            self.custom_llm_provider = "openai"
            self.logging_obj = MagicMock()
            self.chunks = []

        def __aiter__(self):
            return self

        async def __anext__(self):
            raise pre_first_chunk_error

    class EmptyAsyncIterator:
        """Fallback stream that ends immediately."""

        def __aiter__(self):
            return self

        async def __anext__(self):
            raise StopAsyncIteration

    with patch.object(
        router,
        "async_function_with_fallbacks_common_utils",
        return_value=EmptyAsyncIterator(),
    ) as mock_fallback_utils:
        stream = await router._acompletion_streaming_iterator(
            model_response=FailsBeforeFirstChunk(),
            messages=messages,
            initial_kwargs=initial_kwargs,
        )
        # Drain the stream so the error, and hence the fallback, fires.
        async for _ in stream:
            pass

    assert mock_fallback_utils.called
    # Pre-first-chunk: original messages, no continuation prompt.
    assert mock_fallback_utils.call_args.kwargs["kwargs"]["messages"] == messages
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_async_function_with_fallbacks_common_utils():
    """Test the async_function_with_fallbacks_common_utils method.

    In both scenarios below the helper must re-raise the exception it was
    given instead of attempting a fallback.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                },
            }
        ],
        max_fallbacks=5,
    )

    test_exception = Exception("Test error")
    scenarios = [
        # disable_fallbacks=True: the original exception is raised.
        {"disable_fallbacks": True, "kwargs": MagicMock()},
        # original_model_group=None (no "model" key in kwargs): raised as well.
        {"disable_fallbacks": False, "kwargs": {}},
    ]

    for scenario in scenarios:
        with pytest.raises(Exception, match="Test error"):
            await router.async_function_with_fallbacks_common_utils(
                e=test_exception,
                disable_fallbacks=scenario["disable_fallbacks"],
                fallbacks=None,
                context_window_fallbacks=None,
                content_policy_fallbacks=None,
                model_group="gpt-3.5-turbo",
                args=(),
                kwargs=scenario["kwargs"],
            )
|
||
|
||
|
||
def test_should_include_deployment():
    """
    A team-scoped wildcard deployment (``team_public_model_name="openai/*"``)
    must make ``get_model_list`` return a non-empty result for a concrete
    model name under that wildcard when queried with the owning team's id.

    NOTE(review): per this test's name, the team matching is presumably done
    by ``Router.should_include_deployment`` — confirm.
    """
    team_id = "a28a12f9-3e44-4861-bd4f-325f2d309ce8"
    router = litellm.Router(
        model_list=[
            {
                "model_name": "model_name_a28a12f9-3e44-4861-bd4f-325f2d309ce8_cd5dc6fb-b046-4e05-ae1d-32ba4d936266",
                "litellm_params": {"model": "openai/*"},
                "model_info": {
                    "team_id": team_id,
                    "team_public_model_name": "openai/*",
                },
            }
        ],
    )

    assert router.get_model_list(
        model_name="openai/o4-mini-deep-research",
        team_id=team_id,
    )
|
||
|
||
|
||
def test_get_deployment_model_info_base_model_flow():
    """Test that get_deployment_model_info correctly handles the base model flow.

    Cases:
    1. Custom model info with ``base_model``. Base info supplies defaults,
       custom info overrides it, and that result overrides the info looked up
       for ``model_name``.
    2. Custom model info without ``base_model``.
    3. No custom model info at all.
    4. ``get_model_info`` raising for the configured ``base_model``.
    5. No custom model info and ``get_model_info`` raising, which gives ``None``.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "test-model",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ],
    )

    # Mock data for the test
    mock_custom_model_info = {
        "base_model": "gpt-3.5-turbo",
        "input_cost_per_token": 0.001,
        "output_cost_per_token": 0.002,
        "custom_field": "custom_value",
    }

    mock_base_model_info = {
        "key": "gpt-3.5-turbo",
        "max_tokens": 4096,
        "max_input_tokens": 4096,
        "max_output_tokens": 4096,
        "input_cost_per_token": 0.0015,  # overridden by custom model info
        "output_cost_per_token": 0.002,
        "litellm_provider": "openai",
        "mode": "chat",
        "supported_openai_params": ["temperature", "max_tokens"],
    }

    mock_litellm_model_name_info = {
        "key": "test-model",
        "max_tokens": 2048,
        "max_input_tokens": 2048,
        "max_output_tokens": 2048,
        "input_cost_per_token": 0.0005,
        "output_cost_per_token": 0.001,
        "litellm_provider": "test_provider",
        "mode": "completion",
        "supported_openai_params": ["temperature"],
    }

    # Test Case 1: Base model flow with custom model info that has base_model
    with patch.object(
        litellm, "model_cost", {"test-custom-model": mock_custom_model_info}
    ):
        with patch.object(litellm, "get_model_info") as mock_get_model_info:
            mock_get_model_info.side_effect = lambda model: {
                "gpt-3.5-turbo": mock_base_model_info,
                "test-model": mock_litellm_model_name_info,
            }.get(model)

            result = router.get_deployment_model_info(
                model_id="test-custom-model", model_name="test-model"
            )

            # get_model_info is called for both the base model and the model name
            assert mock_get_model_info.call_count == 2
            mock_get_model_info.assert_any_call(model="gpt-3.5-turbo")  # base model
            mock_get_model_info.assert_any_call(model="test-model")  # model name

            assert result is not None

            # Merge order:
            # 1. base_model_info provides defaults, custom_model_info overrides it
            # 2. that result is merged over litellm_model_name_info

            # Fields from the custom model (override base model values)
            assert result["input_cost_per_token"] == 0.001  # overrides base 0.0015
            assert result["output_cost_per_token"] == 0.002  # same as base
            assert result["custom_field"] == "custom_value"

            # Fields from the base model that custom info did not override
            assert result["max_tokens"] == 4096
            assert result["litellm_provider"] == "openai"
            assert result["mode"] == "chat"  # overrides model-name "completion"

            # Both base and model-name info have "key"; the base value wins
            # because custom+base is merged over the model-name info.
            assert result["key"] == "gpt-3.5-turbo"

    # Test Case 2: Custom model info without base_model
    mock_custom_model_info_no_base = {
        "input_cost_per_token": 0.001,
        "output_cost_per_token": 0.002,
        "custom_field": "custom_value",
    }

    with patch.object(
        litellm,
        "model_cost",
        {"test-custom-model-no-base": mock_custom_model_info_no_base},
    ):
        with patch.object(litellm, "get_model_info") as mock_get_model_info:
            mock_get_model_info.side_effect = lambda model: {
                "test-model": mock_litellm_model_name_info,
            }.get(model)

            result = router.get_deployment_model_info(
                model_id="test-custom-model-no-base", model_name="test-model"
            )

            # Only the model-name lookup happens (no base model configured)
            assert mock_get_model_info.call_count == 1
            mock_get_model_info.assert_called_with(model="test-model")

            assert result is not None
            assert result["input_cost_per_token"] == 0.001  # custom model
            assert result["max_tokens"] == 2048  # model-name info
            assert result["custom_field"] == "custom_value"  # custom model
            assert result["mode"] == "completion"  # model-name info

    # Test Case 3: No custom model info, only litellm model name info
    with patch.object(litellm, "model_cost", {}):
        with patch.object(litellm, "get_model_info") as mock_get_model_info:
            mock_get_model_info.side_effect = lambda model: {
                "test-model": mock_litellm_model_name_info,
            }.get(model)

            result = router.get_deployment_model_info(
                model_id="non-existent-model", model_name="test-model"
            )

            assert mock_get_model_info.call_count == 1
            mock_get_model_info.assert_called_with(model="test-model")

            # The result is exactly the model-name info
            assert result is not None
            assert result == mock_litellm_model_name_info

    # Test Case 4: Base model info retrieval fails (exception handling)
    mock_custom_model_info_invalid_base = {
        "base_model": "invalid-base-model",
        "input_cost_per_token": 0.001,
        "output_cost_per_token": 0.002,
    }

    with patch.object(
        litellm,
        "model_cost",
        {"test-custom-model-invalid": mock_custom_model_info_invalid_base},
    ):
        with patch.object(litellm, "get_model_info") as mock_get_model_info:

            def mock_get_model_info_side_effect(model):
                # Raise for the invalid base model only
                if model == "invalid-base-model":
                    raise Exception("Model not found")
                elif model == "test-model":
                    return mock_litellm_model_name_info
                return None

            mock_get_model_info.side_effect = mock_get_model_info_side_effect

            result = router.get_deployment_model_info(
                model_id="test-custom-model-invalid", model_name="test-model"
            )

            # The exception is handled and a merged result is still returned
            assert result is not None
            assert result["input_cost_per_token"] == 0.001  # custom model
            assert result["mode"] == "completion"  # model-name info

    # Test Case 5: No custom model info and get_model_info() raises
    with patch.object(litellm, "model_cost", {}):
        with patch.object(
            litellm, "get_model_info", side_effect=Exception("Not found")
        ):
            result = router.get_deployment_model_info(
                model_id="non-existent", model_name="non-existent"
            )

            # Nothing was found anywhere, so the result is None
            assert result is None
|
||
|
||
|
||
@patch("litellm.model_cost", {})
def test_get_deployment_model_info_base_model_merge_priority():
    """Test that base model info merging respects the correct priority order.

    Expected precedence, highest first:
    1. custom model info (``litellm.model_cost[model_id]``)
    2. base model info (``get_model_info`` for the custom ``base_model``)
    3. model-name info (``get_model_info`` for ``model_name``)
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "test-model",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ],
    )

    # Overlapping fields across the three sources exercise the precedence.
    mock_custom_model_info = {
        "base_model": "gpt-4",
        "input_cost_per_token": 0.01,  # overrides base and model-name values
        "max_tokens": 8000,  # overrides base and model-name values
        "custom_only_field": "custom_value",
    }

    mock_base_model_info = {
        "key": "gpt-4",  # overrides model-name "key"
        "max_tokens": 4096,  # overridden by custom model
        "input_cost_per_token": 0.03,  # overridden by custom model
        "output_cost_per_token": 0.06,  # not in custom; overrides model-name value
        "litellm_provider": "openai",
        "base_only_field": "base_value",
    }

    mock_litellm_model_name_info = {
        "key": "test-model",  # overridden by base model
        "max_tokens": 2048,  # overridden by custom model
        "input_cost_per_token": 0.005,  # overridden by custom model
        "output_cost_per_token": 0.01,  # overridden by base model
        "mode": "completion",  # kept (not in custom or base)
        "litellm_only_field": "litellm_value",  # kept (not in custom or base)
    }

    with patch.object(
        litellm, "model_cost", {"custom-model-id": mock_custom_model_info}
    ):
        with patch.object(litellm, "get_model_info") as mock_get_model_info:
            mock_get_model_info.side_effect = lambda model: {
                "gpt-4": mock_base_model_info,
                "test-model": mock_litellm_model_name_info,
            }.get(model)

            result = router.get_deployment_model_info(
                model_id="custom-model-id", model_name="test-model"
            )

            assert result is not None

            # Highest priority: custom model info
            assert result["input_cost_per_token"] == 0.01  # overrides base 0.03
            assert result["max_tokens"] == 8000  # overrides base 4096
            assert result["custom_only_field"] == "custom_value"

            # Base model fields not overridden by custom info
            assert result["output_cost_per_token"] == 0.06
            assert result["litellm_provider"] == "openai"
            assert result["base_only_field"] == "base_value"

            # Model-name fields not present in custom or base info
            assert result["mode"] == "completion"
            assert result["litellm_only_field"] == "litellm_value"

            # Both base and model-name info have "key"; merged custom+base
            # overrides the model-name info in the final merge.
            assert result["key"] == "gpt-4"
|
||
|
||
|
||
def test_add_deployment_model_to_endpoint_for_llm_passthrough_route():
    """
    Verify that _add_deployment_model_to_endpoint_for_llm_passthrough_route
    rewrites the router-facing model name inside a passthrough endpoint into the
    deployment's underlying model, with the "bedrock/" provider prefix removed.

    Covers the invoke, invoke-with-response-stream and converse routes, plus a
    request whose kwargs carry no custom_llm_provider at all.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "special-bedrock-model",
                "litellm_params": {
                    "model": "bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
                },
            }
        ],
    )

    claude_id = "us.anthropic.claude-3-5-sonnet-20240620-v1:0"
    llama_id = "us.meta.llama3-8b-instruct-v1:0"

    # (incoming endpoint, router model, deployment model_name,
    #  custom_llm_provider or None, expected endpoint, failure label or None)
    # A None label means the failure message quotes the expected endpoint.
    cases = [
        (
            "/model/special-bedrock-model/invoke",
            "special-bedrock-model",
            f"bedrock/{claude_id}",
            "bedrock",
            f"/model/{claude_id}/invoke",
            None,
        ),
        (
            "/model/special-bedrock-model/invoke-with-response-stream",
            "special-bedrock-model",
            f"bedrock/{claude_id}",
            "bedrock",
            f"/model/{claude_id}/invoke-with-response-stream",
            "streaming endpoint with stripped prefix",
        ),
        (
            "/model/bedrock-model/converse",
            "bedrock-model",
            f"bedrock/{llama_id}",
            "bedrock",
            f"/model/{llama_id}/converse",
            None,
        ),
        # No custom_llm_provider: the bedrock prefix must be detected from model_name.
        (
            "/model/router-model/invoke",
            "router-model",
            f"bedrock/{llama_id}",
            None,
            f"/model/{llama_id}/invoke",
            None,
        ),
    ]

    for endpoint, model, model_name, provider, expected, label in cases:
        request_kwargs = {"endpoint": endpoint}
        if provider is not None:
            request_kwargs["custom_llm_provider"] = provider

        result = router._add_deployment_model_to_endpoint_for_llm_passthrough_route(
            kwargs=request_kwargs,
            model=model,
            model_name=model_name,
        )

        actual = result["endpoint"]
        description = label if label is not None else f"'{expected}'"
        assert actual == expected, f"Expected {description}, got '{actual}'"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_router_acompletion_with_unknown_model_and_default_fallback():
    """
    Request a model name that is not in the router's model_list and confirm
    that the configured default fallback ("gpt-4o") answers instead of the call
    failing with a BadRequestError.

    Regression test for issue #15114.
    """
    fallback_deployment = {
        "model_name": "gpt-4o",  # alias the default fallback points at
        "litellm_params": {
            "model": "azure/gpt-4o-real",  # real model behind the alias
            "api_key": "fake-key",
            "api_base": "https://fake-endpoint.openai.azure.com/",
            # mock_response keeps the test from making a real API call
            "mock_response": "this is the fallback response",
        },
    }
    router = litellm.Router(
        model_list=[fallback_deployment], default_fallbacks=["gpt-4o"]
    )

    response = await router.acompletion(
        model="completely-unknown-model",  # deliberately absent from model_list
        messages=[
            {"role": "user", "content": "This call should succeed by falling back."}
        ],
    )

    # The call must succeed rather than raise.
    assert response is not None
    # The content is the fallback deployment's mocked response ...
    assert response.choices[0].message.content == "this is the fallback response"
    # ... and the reported model is the one that was actually called.
    assert response.model == "gpt-4o-real"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_router_acompletion_with_unknown_model_and_no_fallback():
    """
    Without default_fallbacks, requesting an unknown model must still raise
    litellm.BadRequestError, so the original behavior is preserved.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4o",
                "litellm_params": {
                    "model": "azure/gpt-4o-real",
                    "api_key": "fake-key",
                    "mock_response": "this should not be called",
                },
            }
        ]
    )  # note: no default_fallbacks configured

    with pytest.raises(litellm.BadRequestError) as excinfo:
        await router.acompletion(
            model="completely-unknown-model",
            messages=[{"role": "user", "content": "This call should fail."}],
        )

    # The router reports "no healthy deployments" for the unknown name
    # (get_model_list returns [] rather than None here).
    error_text = str(excinfo.value)
    assert "no healthy deployments for this model" in error_text
|
||
|
||
|
||
def test_get_deployment_credentials_with_provider_aws_bedrock_runtime_endpoint():
    """
    get_deployment_credentials_with_provider should carry the deployment's
    aws_bedrock_runtime_endpoint into the returned credentials. It should also
    carry the AWS key pair and region, and report "bedrock" as the
    custom_llm_provider.
    """
    aws_params = {
        "aws_access_key_id": "test-access-key",
        "aws_secret_access_key": "test-secret-key",
        "aws_region_name": "us-east-1",
        "aws_bedrock_runtime_endpoint": "https://bedrock-runtime.us-east-1.amazonaws.com",
    }
    router = litellm.Router(
        model_list=[
            {
                "model_name": "bedrock-claude-model",
                "litellm_params": {
                    "model": "bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
                    **aws_params,
                },
            }
        ],
    )

    credentials = router.get_deployment_credentials_with_provider(
        model_id="bedrock-claude-model"
    )

    assert credentials is not None
    # Every AWS litellm_param configured on the deployment must be copied over.
    for param_name, param_value in aws_params.items():
        assert credentials[param_name] == param_value
    assert credentials["custom_llm_provider"] == "bedrock"
|
||
|
||
|
||
def test_get_deployment_credentials_with_provider_resolves_credential_name():
    """
    Test that get_deployment_credentials_with_provider resolves a deployment's
    litellm_credential_name (as used by UI-created models) into the actual
    credential values from litellm.credential_list.

    The resolved credentials must contain the stored api_key, api_base and
    api_version, and the provider from credential_info. The
    litellm_credential_name key must be dropped once resolved.

    litellm.credential_list is process-global, so its previous value is
    restored in a ``finally`` block even if an assertion fails.
    """
    from litellm.types.utils import CredentialItem

    # Save the global so the injected credential cannot leak into other tests.
    original_credential_list = litellm.credential_list
    litellm.credential_list = [
        CredentialItem(
            credential_name="test-azure-cred",
            credential_info={"custom_llm_provider": "azure"},
            credential_values={
                "api_key": "resolved-api-key",
                "api_base": "https://resolved.openai.azure.com",
                "api_version": "2024-02-01",
            },
        )
    ]

    try:
        router = litellm.Router(
            model_list=[
                {
                    "model_name": "azure-gpt-4",
                    "litellm_params": {
                        "model": "azure/gpt-4",
                        "litellm_credential_name": "test-azure-cred",
                    },
                }
            ],
        )

        credentials = router.get_deployment_credentials_with_provider(
            model_id="azure-gpt-4"
        )

        assert credentials is not None
        assert credentials["api_key"] == "resolved-api-key"
        assert credentials["api_base"] == "https://resolved.openai.azure.com"
        assert credentials["api_version"] == "2024-02-01"
        assert credentials["custom_llm_provider"] == "azure"
        # Ensure credential name is removed after resolution
        assert "litellm_credential_name" not in credentials
    finally:
        # Restore the prior global value rather than blindly resetting to [].
        litellm.credential_list = original_credential_list
|
||
|
||
|
||
def test_get_available_guardrail_single_deployment():
    """
    With exactly one guardrail registered under a name, get_available_guardrail
    must return that guardrail's config.
    """
    only_guardrail = {
        "guardrail_name": "content-filter",
        "litellm_params": {"guardrail": "custom", "mode": "pre_call"},
        "id": "guardrail-1",
    }
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ],
        guardrail_list=[only_guardrail],
    )

    picked = router.get_available_guardrail(guardrail_name="content-filter")
    assert picked == only_guardrail
|
||
|
||
|
||
def test_get_available_guardrail_multiple_deployments():
    """
    Test get_available_guardrail load balances across multiple guardrails that
    share the same guardrail_name.

    Every pick must be one of the registered guardrails, and over repeated calls
    both guardrails must be selected at least once.
    """
    guardrail_1 = {
        "guardrail_name": "content-filter",
        "litellm_params": {"guardrail": "custom", "mode": "pre_call"},
        "id": "guardrail-1",
    }
    guardrail_2 = {
        "guardrail_name": "content-filter",
        "litellm_params": {"guardrail": "custom", "mode": "pre_call"},
        "id": "guardrail-2",
    }

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ],
        guardrail_list=[guardrail_1, guardrail_2],
    )

    # Call repeatedly to observe the load balancing. Assuming independent,
    # uniform picks (NOTE(review): confirm against the router's selection
    # strategy), missing one of two guardrails in 50 calls has probability
    # about 2 * 0.5**50, so this is not flaky in practice.
    results = set()
    for _ in range(50):
        result = router.get_available_guardrail(guardrail_name="content-filter")
        results.add(result["id"])

    # Both guardrails should be selected at least once, and nothing else should be.
    assert results == {"guardrail-1", "guardrail-2"}
|
||
|
||
|
||
def test_get_available_guardrail_not_found():
    """
    Asking for a guardrail name that was never registered must raise a
    ValueError whose message contains "No guardrail found with name".
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ],
        guardrail_list=[],  # nothing registered
    )

    expected_error = "No guardrail found with name"
    with pytest.raises(ValueError, match=expected_error):
        router.get_available_guardrail(guardrail_name="non-existent")
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_aguardrail_helper():
    """
    _aguardrail_helper should choose a guardrail for the given name and await
    the supplied function. The chosen config is passed to it as the
    ``selected_guardrail`` keyword argument.
    """
    guardrail_config = {
        "guardrail_name": "content-filter",
        "litellm_params": {"guardrail": "custom", "mode": "pre_call"},
        "id": "guardrail-1",
    }
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ],
        guardrail_list=[guardrail_config],
    )

    # Stand-in for the real guardrail call: echo back what was selected.
    async def mock_original_function(**kwargs):
        picked = kwargs.get("selected_guardrail")
        return {"result": "success", "selected_guardrail": picked}

    outcome = await router._aguardrail_helper(
        model="content-filter",
        original_generic_function=mock_original_function,
    )

    assert outcome["result"] == "success"
    assert outcome["selected_guardrail"] == guardrail_config
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_aguardrail():
    """
    aguardrail should route a call for a guardrail name through the router.
    The supplied function must receive the chosen guardrail config as
    ``selected_guardrail``.
    """
    guardrail_config = {
        "guardrail_name": "content-filter",
        "litellm_params": {"guardrail": "custom", "mode": "pre_call"},
        "id": "guardrail-1",
    }
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ],
        guardrail_list=[guardrail_config],
    )

    # Stand-in for the real guardrail call: echo back what was selected.
    async def mock_original_function(**kwargs):
        picked = kwargs.get("selected_guardrail")
        return {"result": "success", "selected_guardrail": picked}

    outcome = await router.aguardrail(
        guardrail_name="content-filter",
        original_function=mock_original_function,
    )

    assert outcome["result"] == "success"
    assert outcome["selected_guardrail"]["id"] == "guardrail-1"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_anthropic_messages_call_type_is_cached():
    """
    Regression test: PromptCachingDeploymentCheck.async_log_success_event must
    accept the anthropic_messages call type instead of filtering it out.

    A success event for an anthropic_messages call is logged through the
    deployment check. The model id it carried must then be retrievable from a
    PromptCachingCache that shares the same DualCache, keyed by the same
    messages.
    """
    import asyncio

    from litellm.caching.dual_cache import DualCache
    from litellm.router_utils.pre_call_checks.prompt_caching_deployment_check import (
        PromptCachingDeploymentCheck,
    )
    from litellm.router_utils.prompt_caching_cache import PromptCachingCache
    from litellm.types.utils import (
        CallTypes,
        StandardLoggingHiddenParams,
        StandardLoggingMetadata,
        StandardLoggingModelInformation,
        StandardLoggingPayload,
    )

    # Writer (deployment check) and reader (prompt cache) share one DualCache.
    shared_cache = DualCache()
    deployment_check = PromptCachingDeploymentCheck(cache=shared_cache)
    prompt_cache = PromptCachingCache(cache=shared_cache)

    # Long, cache_control-tagged prompt so it clears the caching token threshold.
    cacheable_messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "test long message here" * 1024,
                    "cache_control": {"type": "ephemeral", "ttl": "5m"},
                }
            ],
        }
    ]
    expected_model_id = "test-model-id-123"

    # Standard logging payload for an anthropic_messages call that used the
    # prompt above and was served by expected_model_id.
    logging_payload = StandardLoggingPayload(
        id="test_id",
        call_type=CallTypes.anthropic_messages.value,
        response_cost=0.1,
        response_cost_failure_debug_info=None,
        status="success",
        total_tokens=30,
        prompt_tokens=20,
        completion_tokens=10,
        startTime=1234567890.0,
        endTime=1234567891.0,
        completionStartTime=1234567890.5,
        model_map_information=StandardLoggingModelInformation(
            model_map_key="gpt-3.5-turbo", model_map_value=None
        ),
        model="anthropic/claude-3-5-sonnet-20240620",
        model_id=expected_model_id,
        model_group="openai-gpt",
        api_base="https://api.openai.com",
        metadata=StandardLoggingMetadata(
            user_api_key_hash="test_hash",
            user_api_key_org_id=None,
            user_api_key_alias="test_alias",
            user_api_key_team_id="test_team",
            user_api_key_user_id="test_user",
            user_api_key_team_alias="test_team_alias",
            spend_logs_metadata=None,
            requester_ip_address="127.0.0.1",
            requester_metadata=None,
        ),
        cache_hit=False,
        cache_key=None,
        saved_cache_cost=0.0,
        request_tags=[],
        end_user=None,
        requester_ip_address="127.0.0.1",
        messages=cacheable_messages,
        response={"choices": [{"message": {"content": "Hi there!"}}]},
        error_str=None,
        model_parameters={"stream": True},
        hidden_params=StandardLoggingHiddenParams(
            model_id="model-123",
            cache_key=None,
            api_base="https://api.openai.com",
            response_cost="0.1",
            additional_headers=None,
        ),
    )

    # Logging the success event is what should store the model_id.
    await deployment_check.async_log_success_event(
        kwargs={"standard_logging_object": logging_payload},
        response_obj={},
        start_time=1234567890.0,
        end_time=1234567891.0,
    )

    # Give the cache write a moment to complete.
    await asyncio.sleep(0.1)

    cached = await prompt_cache.async_get_model_id(
        messages=cacheable_messages,
        tools=None,
    )

    # Fails when anthropic_messages is filtered out by the deployment check.
    assert (
        cached is not None
    ), "Model ID should be cached for anthropic_messages call type"
    assert (
        cached["model_id"] == expected_model_id
    ), f"Expected {expected_model_id}, got {cached['model_id']}"
|
||
|
||
|
||
def test_update_kwargs_with_deployment_propagates_model_tags():
    """
    Test that deployment-level tags from litellm_params are merged into
    kwargs metadata when _update_kwargs_with_deployment is called.

    This ensures model-level tags defined in config.yaml appear in SpendLogs.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4o-mini",
                "litellm_params": {
                    "model": "openai/gpt-4o-mini",
                    "api_key": "fake-key",
                    "tags": ["openai-account", "production"],
                },
            },
        ],
    )

    kwargs: dict = {"metadata": {}}
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-4o-mini"
    )
    router._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)

    # Deployment tags should be propagated to kwargs metadata
    assert "tags" in kwargs["metadata"]
    assert "openai-account" in kwargs["metadata"]["tags"]
    assert "production" in kwargs["metadata"]["tags"]
|
||
|
||
|
||
def test_update_kwargs_with_deployment_merges_tags_without_duplicates():
    """
    Tags already on the request and tags configured on the deployment must be
    combined. A tag present in both sources must appear only once.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4o-mini",
                "litellm_params": {
                    "model": "openai/gpt-4o-mini",
                    "api_key": "fake-key",
                    "tags": ["openai-account", "shared-tag"],
                },
            },
        ],
    )

    # Tags the request already carries (e.g. from the body or key/team settings).
    request_kwargs: dict = {"metadata": {"tags": ["user-tag", "shared-tag"]}}
    chosen = router.get_deployment_by_model_group_name(model_group_name="gpt-4o-mini")
    router._update_kwargs_with_deployment(deployment=chosen, kwargs=request_kwargs)

    merged_tags = request_kwargs["metadata"]["tags"]
    for expected_tag in ("user-tag", "openai-account", "shared-tag"):
        assert expected_tag in merged_tags
    # The overlapping tag must not be duplicated.
    assert merged_tags.count("shared-tag") == 1
|
||
|
||
|
||
def test_update_kwargs_with_deployment_no_tags():
    """
    A deployment without tags must leave the request metadata untouched: no
    "tags" key may be introduced.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4o-mini",
                "litellm_params": {
                    "model": "openai/gpt-4o-mini",
                    "api_key": "fake-key",
                },
            },
        ],
    )

    request_kwargs: dict = {"metadata": {}}
    chosen = router.get_deployment_by_model_group_name(model_group_name="gpt-4o-mini")
    router._update_kwargs_with_deployment(deployment=chosen, kwargs=request_kwargs)

    assert "tags" not in request_kwargs["metadata"]
|
||
|
||
|
||
def test_update_kwargs_with_deployment_merges_tools():
    """
    When both the deployment's litellm_params and the request define tools,
    _update_kwargs_with_deployment concatenates them. Deployment tools come
    first, followed by the request's tools.

    This allows proxy-configured tools (e.g. web_search for o3 deep research)
    to be combined with tools supplied by the client.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "o3-deep-research",
                "litellm_params": {
                    "model": "openai/o3-deep-research",
                    "api_key": "fake-key",
                    "tools": [{"type": "web_search"}],
                    "tool_choice": "auto",
                },
            },
        ],
    )

    client_tool = {
        "type": "function",
        "function": {"name": "get_weather", "description": "Get weather"},
    }
    request_kwargs: dict = {"metadata": {}, "tools": [client_tool]}
    chosen = router.get_deployment_by_model_group_name(
        model_group_name="o3-deep-research"
    )
    router._update_kwargs_with_deployment(deployment=chosen, kwargs=request_kwargs)

    assert "tools" in request_kwargs
    assert len(request_kwargs["tools"]) == 2
    first_tool, second_tool = request_kwargs["tools"]
    assert first_tool == {"type": "web_search"}  # deployment tool leads
    assert second_tool["function"]["name"] == "get_weather"  # then the client's
    # The request set no tool_choice, so the deployment's value is applied.
    assert request_kwargs["tool_choice"] == "auto"
|
||
|
||
|
||
def test_update_kwargs_with_deployment_merge_tools_deployment_only():
    """
    When only the deployment defines tools and tool_choice, both must be copied
    onto a request that had neither.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "o3-deep-research",
                "litellm_params": {
                    "model": "openai/o3-deep-research",
                    "api_key": "fake-key",
                    "tools": [{"type": "web_search"}],
                    "tool_choice": "required",
                },
            },
        ],
    )

    request_kwargs: dict = {"metadata": {}}
    chosen = router.get_deployment_by_model_group_name(
        model_group_name="o3-deep-research"
    )
    router._update_kwargs_with_deployment(deployment=chosen, kwargs=request_kwargs)

    assert request_kwargs["tools"] == [{"type": "web_search"}]
    assert request_kwargs["tool_choice"] == "required"
|
||
|
||
|
||
def test_update_kwargs_with_deployment_merge_tools_request_overrides_tool_choice():
    """
    Test that a tool_choice supplied on the request overrides the deployment's.

    The request sends no tools of its own, so the deployment's tools must still
    be applied alongside the request's tool_choice.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "o3-deep-research",
                "litellm_params": {
                    "model": "openai/o3-deep-research",
                    "api_key": "fake-key",
                    "tools": [{"type": "web_search"}],
                    "tool_choice": "auto",
                },
            },
        ],
    )

    kwargs: dict = {
        "metadata": {},
        "tool_choice": "none",
    }
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="o3-deep-research"
    )
    router._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)

    # Request tool_choice should be preserved ...
    assert kwargs["tool_choice"] == "none"
    # ... while the deployment's tools are still applied.
    assert kwargs["tools"] == [{"type": "web_search"}]
|
||
|
||
|
||
def test_credential_name_injected_as_tag():
    """
    _update_kwargs_with_deployment should add a "Credential: <name>" tag built
    from the deployment's litellm_credential_name, and keep the request's
    existing tags.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "xai-model",
                "litellm_params": {
                    "model": "xai/grok-4-1-fast",
                    "litellm_credential_name": "xAI",
                },
            }
        ],
    )

    request_kwargs: dict = {"metadata": {"tags": ["A.101"]}}
    chosen = router.get_deployment_by_model_group_name(model_group_name="xai-model")
    router._update_kwargs_with_deployment(deployment=chosen, kwargs=request_kwargs)

    final_tags = request_kwargs["metadata"]["tags"]
    assert "Credential: xAI" in final_tags
    assert "A.101" in final_tags
|
||
|
||
|
||
def test_credential_name_not_duplicated_in_tags():
    """
    If the request already carries the "Credential: <name>" tag, adding the
    deployment's credential tag must not create a second copy.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "xai-model",
                "litellm_params": {
                    "model": "xai/grok-4-1-fast",
                    "litellm_credential_name": "xAI",
                },
            }
        ],
    )

    request_kwargs: dict = {"metadata": {"tags": ["Credential: xAI", "A.101"]}}
    chosen = router.get_deployment_by_model_group_name(model_group_name="xai-model")
    router._update_kwargs_with_deployment(deployment=chosen, kwargs=request_kwargs)

    occurrences = request_kwargs["metadata"]["tags"].count("Credential: xAI")
    assert occurrences == 1
|
||
|
||
|
||
def test_credential_name_not_injected_when_absent():
    """
    A deployment without litellm_credential_name must leave the request's tags
    exactly as they were.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-model",
                "litellm_params": {
                    "model": "gpt-4o",
                },
            }
        ],
    )

    request_kwargs: dict = {"metadata": {"tags": ["A.101"]}}
    chosen = router.get_deployment_by_model_group_name(model_group_name="gpt-model")
    router._update_kwargs_with_deployment(deployment=chosen, kwargs=request_kwargs)

    assert request_kwargs["metadata"]["tags"] == ["A.101"]
|
||
|
||
|
||
def test_update_kwargs_with_deployment_model_info_in_litellm_metadata():
    """For generic_api_call, the deployment's model_info (including custom
    pricing) must be placed under kwargs["litellm_metadata"].

    Routes such as /messages and /responses go through generic_api_call, which
    reads model_info from litellm_metadata. Regression test for #23185.
    """
    pricing_info = {
        "id": "custom-pricing-id",
        "input_cost_per_token": 0.0003,
        "output_cost_per_token": 0.0015,
    }
    router = litellm.Router(
        model_list=[
            {
                "model_name": "claude-sonnet-4",
                "litellm_params": {
                    "model": "anthropic/claude-sonnet-4-20250514",
                    "api_key": "fake-key",
                },
                "model_info": dict(pricing_info),
            },
        ],
    )

    request_kwargs: dict = {}
    chosen = router.get_deployment_by_model_group_name(
        model_group_name="claude-sonnet-4"
    )
    router._update_kwargs_with_deployment(
        deployment=chosen, kwargs=request_kwargs, function_name="generic_api_call"
    )

    assert "litellm_metadata" in request_kwargs
    stored_info = request_kwargs["litellm_metadata"]["model_info"]
    for field, expected in pricing_info.items():
        assert stored_info[field] == expected
|
||
|
||
|
||
def test_update_kwargs_with_deployment_model_info_in_metadata():
    """For acompletion (function_name=None), the deployment's model_info must be
    placed under kwargs["metadata"].

    /chat/completions goes through acompletion, which reads model_info from
    metadata.
    """
    pricing_info = {
        "id": "custom-pricing-id",
        "input_cost_per_token": 0.0003,
        "output_cost_per_token": 0.0015,
    }
    router = litellm.Router(
        model_list=[
            {
                "model_name": "claude-sonnet-4",
                "litellm_params": {
                    "model": "anthropic/claude-sonnet-4-20250514",
                    "api_key": "fake-key",
                },
                "model_info": dict(pricing_info),
            },
        ],
    )

    request_kwargs: dict = {}
    chosen = router.get_deployment_by_model_group_name(
        model_group_name="claude-sonnet-4"
    )
    router._update_kwargs_with_deployment(
        deployment=chosen, kwargs=request_kwargs, function_name=None
    )

    assert "metadata" in request_kwargs
    stored_info = request_kwargs["metadata"]["model_info"]
    for field, expected in pricing_info.items():
        assert stored_info[field] == expected
|
||
|
||
|
||
def test_combine_fallback_usage():
|
||
"""Test that _combine_fallback_usage merges partial and fallback usage."""
|
||
from litellm.router import Router
|
||
from litellm.types.utils import Usage
|
||
|
||
# Create a stream chunk with usage
|
||
chunk = litellm.ModelResponseStream(
|
||
id="test",
|
||
model="gpt-4o",
|
||
choices=[],
|
||
usage=Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
|
||
)
|
||
|
||
# Call _combine_fallback_usage with no extra usage
|
||
Router._combine_fallback_usage(chunk, None)
|
||
assert chunk.usage is not None
|
||
assert chunk.usage.prompt_tokens == 10
|
||
assert chunk.usage.completion_tokens == 5
|
||
assert chunk.usage.total_tokens == 15
|