Files
litellm/tests/test_litellm/test_router.py
T
Krrish Dholakia bc829d51f2 test: test
2026-03-28 19:17:38 -07:00

2850 lines
95 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import copy
import json
import os
import sys
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
import litellm
from litellm.router_utils.fallback_event_handlers import run_async_fallback
def test_update_kwargs_does_not_mutate_defaults_and_merges_metadata():
    """
    `_update_kwargs_with_default_litellm_params` copies the router's default
    litellm params into the call kwargs without mutating the router's own
    defaults, and places the default metadata under the caller-supplied
    metadata variable name.
    """
    # A real Router is used; the Azure env vars may be unset.
    azure_deployment = {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/gpt-4.1-mini",
            "api_key": os.getenv("AZURE_AI_API_KEY"),
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_AI_API_BASE"),
        },
    }
    router = litellm.Router(model_list=[azure_deployment])

    # Replace the defaults with known values so the expectations are fixed.
    router.default_litellm_params = {"foo": "bar", "metadata": {"baz": 123}}
    snapshot = copy.deepcopy(router.default_litellm_params)

    call_kwargs: dict = {}
    router._update_kwargs_with_default_litellm_params(
        kwargs=call_kwargs,
        metadata_variable_name="litellm_metadata",
    )

    # The router's defaults are left untouched.
    assert router.default_litellm_params == snapshot
    # Non-metadata defaults are copied across verbatim.
    assert call_kwargs["foo"] == "bar"
    # Default metadata is written under the requested variable name.
    assert call_kwargs["litellm_metadata"] == {"baz": 123}
def test_router_with_model_info_and_model_group():
    """
    Edge case: the deployment's model_info already carries a `model_group`
    key. Building the model-group info for it must complete without raising.
    """
    deployment = {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "gpt-3.5-turbo"},
        "model_info": {
            "tpm": 1000,
            "rpm": 1000,
            "model_group": "gpt-3.5-turbo",
        },
    }
    router = litellm.Router(model_list=[deployment])

    # Only the absence of an exception is checked; the result is not inspected.
    router._set_model_group_info(
        model_group="gpt-3.5-turbo",
        user_facing_model_group_name="gpt-3.5-turbo",
    )
@pytest.mark.asyncio
async def test_arouter_with_tags_and_fallbacks():
    """
    With tag filtering enabled, a fallback deployment that lacks the request's
    tag cannot serve it, so the forced fallback path raises an error.
    """
    from litellm import Router

    tagged_primary = {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo",
            "mock_response": "Hello, world!",
            "tags": ["test"],
        },
    }
    untagged_fallback = {
        "model_name": "anthropic-claude-3-5-sonnet",
        "litellm_params": {
            "model": "claude-sonnet-4-5-20250929",
            "mock_response": "Hello, world 2!",
        },
    }
    router = Router(
        model_list=[tagged_primary, untagged_fallback],
        fallbacks=[{"gpt-3.5-turbo": ["anthropic-claude-3-5-sonnet"]}],
        enable_tag_filtering=True,
    )

    with pytest.raises(Exception):
        await router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello, world!"}],
            # Force the router down the fallback path.
            mock_testing_fallbacks=True,
            metadata={"tags": ["test"]},
        )
@pytest.mark.asyncio
async def test_async_router_acreate_file():
    """
    `router.acreate_file` writes the file to every deployment of the model
    group, so `litellm.acreate_file` is invoked once per deployment.
    """
    router = litellm.Router(
        model_list=[
            {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
            {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-4o-mini"}},
        ],
    )

    with patch("litellm.acreate_file", return_value=MagicMock()) as mock_acreate_file:
        await router.acreate_file(
            model="gpt-3.5-turbo",
            purpose="test",
            file=MagicMock(),
        )
        # Two deployments share the group name, so two upstream calls are made.
        assert mock_acreate_file.call_count == 2
@pytest.mark.asyncio
async def test_async_router_acreate_file_with_jsonl():
    """
    `router.acreate_file` with a JSONL batch file rewrites each line's
    `body.model` from the router's model-group name to the deployment's model.
    A non-JSONL file is forwarded with its content unchanged.
    """
    from io import BytesIO

    def _as_named_file(text, filename):
        # In-memory upload carrying a filename.
        buffer = BytesIO(text.encode("utf-8"))
        buffer.name = filename
        return buffer

    batch_lines = [
        {
            "body": {
                "model": "gpt-3.5-turbo-router",
                "messages": [{"role": "user", "content": prompt}],
            }
        }
        for prompt in ("test", "test2")
    ]
    jsonl_file = _as_named_file(
        "\n".join(json.dumps(line) for line in batch_lines), "test.jsonl"
    )
    plain_text = "This is not a JSONL file"
    non_jsonl_file = _as_named_file(plain_text, "test.txt")

    router = litellm.Router(
        model_list=[
            {"model_name": "gpt-3.5-turbo-router", "litellm_params": {"model": model}}
            for model in ("gpt-3.5-turbo", "gpt-4o-mini")
        ],
    )

    with patch("litellm.acreate_file", return_value=MagicMock()) as mock_acreate_file:
        # --- JSONL upload ---
        await router.acreate_file(
            model="gpt-3.5-turbo-router",
            purpose="batch",
            file=jsonl_file,
        )
        print(f"mock_acreate_file.call_count: {mock_acreate_file.call_count}")
        print(f"mock_acreate_file.call_args_list: {mock_acreate_file.call_args_list}")
        # One upstream upload per deployment in the group.
        assert mock_acreate_file.call_count == 2
        uploaded = mock_acreate_file.call_args_list[0][1]["file"].read().decode("utf-8")
        first_record = json.loads(uploaded.split("\n")[0])
        # The group alias is replaced by the deployment's real model name.
        assert first_record["body"]["model"] == "gpt-3.5-turbo"

        mock_acreate_file.reset_mock()

        # --- non-JSONL upload ---
        await router.acreate_file(
            model="gpt-3.5-turbo-router",
            purpose="user_data",
            file=non_jsonl_file,
        )
        assert mock_acreate_file.call_count == 2
        uploaded = mock_acreate_file.call_args_list[0][1]["file"].read().decode("utf-8")
        # Content that is not JSONL is passed through untouched.
        assert uploaded == plain_text
@pytest.mark.asyncio
async def test_arouter_async_get_healthy_deployments():
    """
    Test that router.async_get_healthy_deployments returns the single
    configured deployment for its model group, with its model_name and
    litellm_params intact.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            },
        ],
    )

    result = await router.async_get_healthy_deployments(
        model="gpt-3.5-turbo",
        request_kwargs={},
        messages=None,
        input=None,
        specific_deployment=False,
        parent_otel_span=None,
    )

    assert len(result) == 1
    assert result[0]["model_name"] == "gpt-3.5-turbo"
    assert result[0]["litellm_params"]["model"] == "gpt-3.5-turbo"
@pytest.mark.asyncio
@patch("litellm.amoderation")
async def test_arouter_amoderation_with_credential_name(mock_amoderation):
    """
    `router.amoderation` forwards the deployment's `litellm_credential_name`
    and its model to the underlying `litellm.amoderation` call.
    """
    mock_amoderation.return_value = AsyncMock()
    moderation_deployment = {
        "model_name": "text-moderation-stable",
        "litellm_params": {
            "model": "text-moderation-stable",
            "litellm_credential_name": "my-custom-auth",
        },
    }
    router = litellm.Router(model_list=[moderation_deployment])

    await router.amoderation(input="I love everyone!", model="text-moderation-stable")

    mock_amoderation.assert_called_once()
    forwarded = mock_amoderation.call_args.kwargs
    print(
        "call kwargs for router.amoderation=",
        json.dumps(forwarded, indent=4, default=str),
    )
    assert forwarded["litellm_credential_name"] == "my-custom-auth"
    assert forwarded["model"] == "text-moderation-stable"
def test_arouter_test_team_model():
    """
    Test that router.map_team_model resolves a team's public model name
    (`team_public_model_name`) for the matching team_id. Only a non-None
    result is asserted.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
                "model_info": {
                    "team_id": "test-team",
                    "team_public_model_name": "test-model",
                },
            },
        ],
    )

    result = router.map_team_model(team_model_name="test-model", team_id="test-team")
    assert result is not None
def test_arouter_ignore_invalid_deployments():
    """
    With `ignore_invalid_deployments=True`, a deployment with an unknown model
    is dropped rather than aborting router construction. The same invalid
    deployment passed to `upsert_deployment` is dropped as well.
    """
    from litellm.types.router import Deployment

    router = litellm.Router(
        model_list=[
            {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "my-bad-model"}},
        ],
        ignore_invalid_deployments=True,
    )
    assert router.ignore_invalid_deployments is True
    # The invalid deployment given to the constructor was skipped.
    assert router.get_model_list() == []

    # Upserting the invalid deployment leaves the model list empty too.
    router.upsert_deployment(
        Deployment(
            model_name="gpt-3.5-turbo",
            litellm_params={"model": "my-bad-model"},  # type: ignore
            model_info={"tpm": 1000, "rpm": 1000},
        )
    )
    assert router.get_model_list() == []
@pytest.mark.asyncio
async def test_arouter_aretrieve_batch():
    """
    `router.aretrieve_batch` forwards the deployment's `api_key` and
    `api_base` to `litellm.aretrieve_batch`.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "custom_llm_provider": "azure",
                    "api_key": "my-custom-key",
                    "api_base": "my-custom-base",
                },
            }
        ],
    )

    with patch.object(
        litellm, "aretrieve_batch", return_value=AsyncMock()
    ) as mock_aretrieve_batch:
        try:
            await router.aretrieve_batch(model="gpt-3.5-turbo")
        except Exception as e:
            # Router errors are printed rather than failing the test.
            # The assertions on the mocked call below decide the outcome.
            print(f"Error: {e}")

        mock_aretrieve_batch.assert_called_once()
        forwarded = mock_aretrieve_batch.call_args.kwargs
        print(forwarded)
        assert forwarded["api_key"] == "my-custom-key"
        assert forwarded["api_base"] == "my-custom-base"
@pytest.mark.asyncio
async def test_arouter_aretrieve_file_content():
    """
    Test that router.afile_content forwards the deployment's api_key and
    api_base to litellm.afile_content.
    """
    with patch.object(
        litellm, "afile_content", return_value=AsyncMock()
    ) as mock_afile_content:
        router = litellm.Router(
            model_list=[
                {
                    "model_name": "gpt-3.5-turbo",
                    "litellm_params": {
                        "model": "gpt-3.5-turbo",
                        "custom_llm_provider": "azure",
                        "api_key": "my-custom-key",
                        "api_base": "my-custom-base",
                    },
                }
            ],
        )
        try:
            response = await router.afile_content(
                **{
                    "model": "gpt-3.5-turbo",
                    "file_id": "my-unique-file-id",
                }
            )  # type: ignore
        except Exception as e:
            # Router errors are printed rather than failing the test.
            # The assertions on the mocked call below decide the outcome.
            print(f"Error: {e}")

        mock_afile_content.assert_called_once()
        print(mock_afile_content.call_args.kwargs)
        assert mock_afile_content.call_args.kwargs["api_key"] == "my-custom-key"
        assert mock_afile_content.call_args.kwargs["api_base"] == "my-custom-base"
@pytest.mark.asyncio
async def test_arouter_filter_team_based_models():
    """
    A deployment tagged with a `team_id` is only routable for requests from
    that team. Once a team-less deployment is added to the same model group,
    other teams can be served as well.
    """
    from litellm.types.router import Deployment

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
                "model_info": {"team_id": "test-team"},
            },
        ],
    )

    async def _call_as_team(team_id):
        # Mocked completion issued on behalf of `team_id`.
        return await router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello, world!"}],
            metadata={"user_api_key_team_id": team_id},
            mock_response="Hello, world!",
        )

    # The owning team can reach its deployment.
    assert await _call_as_team("test-team") is not None

    # Any other team finds no eligible deployment.
    with pytest.raises(Exception) as exc_info:
        await _call_as_team("test-team-2")
    assert "No deployments available" in str(exc_info.value)

    # Add a deployment with no team restriction to the same group...
    router.add_deployment(
        Deployment(
            model_name="gpt-3.5-turbo",
            litellm_params={"model": "gpt-3.5-turbo"},  # type: ignore
            model_info={"tpm": 1000, "rpm": 1000},
        )
    )
    # ...which the other team can now use.
    assert await _call_as_team("test-team-2") is not None
def test_arouter_should_include_deployment():
    """
    Test the should_include_deployment method with various scenarios.

    Behavior pinned by the cases below:
    1. Returns True if team_id matches the deployment's team_id AND model_name
       matches its team_public_model_name.
    2. Returns True if model_name equals the deployment's own model_name. This
       holds whether or not the deployment has a team_id and whatever team_id
       is passed (cases 4, 5, 6 and 8).
    3. Otherwise returns False.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
                "model_info": {
                    "team_id": "test-team",
                },
            },
        ],
    )

    # Test deployment structures
    deployment_with_team_and_public_name = {
        "model_name": "gpt-3.5-turbo",
        "model_info": {
            "team_id": "test-team",
            "team_public_model_name": "team-gpt-model",
        },
    }
    deployment_with_team_no_public_name = {
        "model_name": "gpt-3.5-turbo",
        "model_info": {
            "team_id": "test-team",
        },
    }
    deployment_without_team = {
        "model_name": "gpt-4",
        "model_info": {},
    }

    # Test Case 1: Team-specific deployment - team_id and team_public_model_name match
    result = router.should_include_deployment(
        model_name="team-gpt-model",
        model=deployment_with_team_and_public_name,
        team_id="test-team",
    )
    assert (
        result is True
    ), "Should return True when team_id and team_public_model_name match"

    # Test Case 2: Team-specific deployment - team_id matches but model_name
    # matches neither team_public_model_name nor the deployment's model_name
    result = router.should_include_deployment(
        model_name="different-model",
        model=deployment_with_team_and_public_name,
        team_id="test-team",
    )
    assert (
        result is False
    ), "Should return False when team_id matches but model_name doesn't match team_public_model_name"

    # Test Case 3: Team-specific deployment - team_id doesn't match (and the
    # public name is not the deployment's own model_name)
    result = router.should_include_deployment(
        model_name="team-gpt-model",
        model=deployment_with_team_and_public_name,
        team_id="different-team",
    )
    assert result is False, "Should return False when team_id doesn't match"

    # Test Case 4: Team-specific deployment with no team_public_model_name -
    # still included because model_name equals the deployment's model_name
    result = router.should_include_deployment(
        model_name="gpt-3.5-turbo",
        model=deployment_with_team_no_public_name,
        team_id="test-team",
    )
    assert (
        result is True
    ), "Should return True when team deployment has no team_public_model_name to match"

    # Test Case 5: Non-team deployment - model_name matches and no team_id
    result = router.should_include_deployment(
        model_name="gpt-4", model=deployment_without_team, team_id=None
    )
    assert (
        result is True
    ), "Should return True when model_name matches and deployment has no team_id"

    # Test Case 6: Non-team deployment - model_name matches, team_id provided
    result = router.should_include_deployment(
        model_name="gpt-4", model=deployment_without_team, team_id="any-team"
    )
    assert (
        result is True
    ), "Should return True when model_name matches non-team deployment, regardless of team_id param"

    # Test Case 7: Non-team deployment - model_name doesn't match
    result = router.should_include_deployment(
        model_name="different-model", model=deployment_without_team, team_id=None
    )
    assert result is False, "Should return False when model_name doesn't match"

    # Test Case 8: Team deployment requested by its exact model_name with no team_id
    result = router.should_include_deployment(
        model_name="gpt-3.5-turbo",
        model=deployment_with_team_and_public_name,
        team_id=None,
    )
    assert (
        result is True
    ), "Should return True when matching model with exact model_name"
def test_arouter_responses_api_bridge():
    """
    A deployment whose model is `azure/responses/...` is served through the
    Responses API bridge. First, `router.completion` must delegate to
    `litellm.responses`. Second, the raw HTTP request must target the Azure
    responses endpoint with the model reduced to its final path segment.
    """
    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    router = litellm.Router(
        model_list=[
            {
                "model_name": "[IP-approved] o3-pro",
                "litellm_params": {
                    "model": "azure/responses/o_series/webinterface-o3-pro",
                    "api_base": "https://webhook.site/fba79dae-220a-4bb7-9a3a-8caa49604e55",
                    "api_key": "sk-1234567890",
                    "api_version": "preview",
                    "stream": True,
                },
                "model_info": {
                    "input_cost_per_token": 0.00002,
                    "output_cost_per_token": 0.00008,
                },
            }
        ],
    )

    # Step 1: completion is routed through litellm.responses.
    with patch.object(litellm, "responses", return_value=AsyncMock()) as mock_responses:
        router.completion(
            model="[IP-approved] o3-pro",
            messages=[{"role": "user", "content": "Hello, world!"}],
        )
        assert mock_responses.call_count == 1

    # Step 2: the outgoing HTTP request carries the stripped deployment name.
    client = HTTPHandler()
    canned_response = MagicMock()
    canned_response.status_code = 200
    canned_response.headers = {"content-type": "application/json"}
    canned_response.json.return_value = {
        "id": "resp_test",
        "object": "response",
        "status": "completed",
        "output": [],
    }
    canned_response.text = (
        '{"id": "resp_test", "object": "response", "status": "completed", "output": []}'
    )

    with patch.object(client, "post", return_value=canned_response) as mock_post:
        try:
            router.completion(
                model="[IP-approved] o3-pro",
                messages=[{"role": "user", "content": "Hello, world!"}],
                client=client,
                num_retries=0,
            )
        except Exception as e:
            # Errors are printed, not fatal. Only the outgoing request is asserted on.
            print(f"Error: {e}")

        assert mock_post.call_count == 1
        sent = mock_post.call_args.kwargs
        assert (
            sent["url"]
            == "https://webhook.site/fba79dae-220a-4bb7-9a3a-8caa49604e55/openai/v1/responses?api-version=preview"
        )
        assert sent["json"]["model"] == "webinterface-o3-pro"
@pytest.mark.asyncio
async def test_router_v1_messages_fallbacks():
    """
    `router.aanthropic_messages` falls back to the configured fallback group
    when the primary deployment fails with a mocked InternalServerError.
    """
    primary = {
        "model_name": "claude-sonnet-4-5-20250929",
        "litellm_params": {
            "model": "anthropic/claude-sonnet-4-5-20250929",
            "mock_response": "litellm.InternalServerError",
        },
    }
    fallback = {
        "model_name": "bedrock-claude",
        "litellm_params": {
            "model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
            "mock_response": "Hello, world I am a fallback!",
        },
    }
    router = litellm.Router(
        model_list=[primary, fallback],
        fallbacks=[{"claude-sonnet-4-5-20250929": ["bedrock-claude"]}],
    )

    result = await router.aanthropic_messages(
        model="claude-sonnet-4-5-20250929",
        messages=[{"role": "user", "content": "Hello, world!"}],
        max_tokens=256,
    )

    assert result is not None
    print(result)
    # The reply text comes from the fallback deployment's mock response.
    assert result["content"][0]["text"] == "Hello, world I am a fallback!"
def test_add_invalid_provider_to_router():
    """
    `router.add_deployment` raises for a wildcard deployment with an unknown
    `custom_llm_provider`, and no wildcard pattern is left registered.
    """
    from litellm.types.router import Deployment

    router = litellm.Router(
        model_list=[
            {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
        ],
    )

    with pytest.raises(Exception):
        router.add_deployment(
            Deployment(
                model_name="vertex_ai/*",
                litellm_params={"model": "vertex_ai/*", "custom_llm_provider": "vertex_ai_eu"},
            )
        )

    # The rejected deployment must not leave a wildcard route behind.
    assert router.pattern_router.patterns == {}
@pytest.mark.asyncio
async def test_router_ageneric_api_call_with_fallbacks_helper():
    """
    Test the _ageneric_api_call_with_fallbacks_helper method with various scenarios:

    1. A deployment is found: the generic function is called and its result returned.
    2. No deployment and passthrough_on_no_deployment=True: the generic function is
       still called (with the requested model) and its result returned.
    3. No deployment and passthrough_on_no_deployment=False: the lookup error propagates.
    4. `_get_client` returns an asyncio.Semaphore: the call still succeeds.
    5. The generic function raises: the error propagates and `router.fail_calls`
       for the model group is incremented by one.
    """
    import asyncio

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": "test-key",
                    "api_base": "https://api.openai.com/v1",
                },
                "model_info": {
                    "tpm": 1000,
                    "rpm": 1000,
                },
            },
        ],
    )

    def _mock_deployment():
        # Fresh dict per scenario so no scenario can observe another's mutations.
        return {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": "test-key",
                "api_base": "https://api.openai.com/v1",
            },
        }

    # Test 1: Successful call
    async def mock_generic_function(**kwargs):
        return {"result": "success", "model": kwargs.get("model")}

    with patch.object(router, "async_get_available_deployment") as mock_get_deployment:
        mock_get_deployment.return_value = _mock_deployment()
        with patch.object(
            router, "_update_kwargs_with_deployment"
        ) as mock_update_kwargs:
            with patch.object(
                router, "async_routing_strategy_pre_call_checks"
            ) as mock_pre_call_checks:
                with patch.object(router, "_get_client", return_value=None):
                    result = await router._ageneric_api_call_with_fallbacks_helper(
                        model="gpt-3.5-turbo",
                        original_generic_function=mock_generic_function,
                        messages=[{"role": "user", "content": "test"}],
                    )
                    assert result is not None
                    assert result["result"] == "success"
                    mock_get_deployment.assert_called_once()
                    mock_update_kwargs.assert_called_once()
                    mock_pre_call_checks.assert_called_once()

    # Test 2: Passthrough on no deployment (success case)
    async def mock_passthrough_function(**kwargs):
        return {"result": "passthrough", "model": kwargs.get("model")}

    with patch.object(router, "async_get_available_deployment") as mock_get_deployment:
        mock_get_deployment.side_effect = Exception("No deployment available")
        result = await router._ageneric_api_call_with_fallbacks_helper(
            model="gpt-3.5-turbo",
            original_generic_function=mock_passthrough_function,
            passthrough_on_no_deployment=True,
            messages=[{"role": "user", "content": "test"}],
        )
        assert result is not None
        assert result["result"] == "passthrough"
        assert result["model"] == "gpt-3.5-turbo"

    # Test 3: No deployment available and passthrough=False (should raise exception)
    with patch.object(router, "async_get_available_deployment") as mock_get_deployment:
        mock_get_deployment.side_effect = Exception("No deployment available")
        with pytest.raises(Exception) as exc_info:
            await router._ageneric_api_call_with_fallbacks_helper(
                model="gpt-3.5-turbo",
                original_generic_function=mock_generic_function,
                passthrough_on_no_deployment=False,
                messages=[{"role": "user", "content": "test"}],
            )
        assert "No deployment available" in str(exc_info.value)

    # Test 4: Test with semaphore (rate limiting)
    async def mock_semaphore_function(**kwargs):
        return {"result": "semaphore_success", "model": kwargs.get("model")}

    with patch.object(router, "async_get_available_deployment") as mock_get_deployment:
        mock_get_deployment.return_value = _mock_deployment()
        mock_semaphore = asyncio.Semaphore(1)
        with patch.object(router, "_update_kwargs_with_deployment"):
            with patch.object(
                router, "_get_client", return_value=mock_semaphore
            ) as mock_get_client:
                with patch.object(
                    router, "async_routing_strategy_pre_call_checks"
                ) as mock_pre_call_checks:
                    result = await router._ageneric_api_call_with_fallbacks_helper(
                        model="gpt-3.5-turbo",
                        original_generic_function=mock_semaphore_function,
                        messages=[{"role": "user", "content": "test"}],
                    )
                    assert result is not None
                    assert result["result"] == "semaphore_success"
                    mock_get_client.assert_called_once()
                    mock_pre_call_checks.assert_called_once()

    # Test 5: Failure tracking - a raising generic function bumps fail_calls by one
    initial_fail_count = router.fail_calls.get("gpt-3.5-turbo", 0)

    async def mock_failing_function(**kwargs):
        raise Exception("Mock failure")

    with patch.object(router, "async_get_available_deployment") as mock_get_deployment:
        mock_get_deployment.return_value = _mock_deployment()
        with patch.object(router, "_update_kwargs_with_deployment"):
            with patch.object(router, "_get_client", return_value=None):
                with patch.object(router, "async_routing_strategy_pre_call_checks"):
                    with pytest.raises(Exception) as exc_info:
                        await router._ageneric_api_call_with_fallbacks_helper(
                            model="gpt-3.5-turbo",
                            original_generic_function=mock_failing_function,
                            messages=[{"role": "user", "content": "test"}],
                        )
                    assert "Mock failure" in str(exc_info.value)

    # Check that fail_calls was incremented
    assert router.fail_calls["gpt-3.5-turbo"] == initial_fail_count + 1
def test_router_get_model_access_groups_team_only_models():
    """
    A team-only deployment, exposed through `team_public_model_name`,
    contributes its access groups only when the caller's team_id matches.
    """
    team_only_deployment = {
        "model_name": "my-custom-model-name",
        "litellm_params": {"model": "gpt-3.5-turbo"},
        "model_info": {
            "team_id": "team_1",
            "access_groups": ["default-models"],
            "team_public_model_name": "gpt-3.5-turbo",
        },
    }
    router = litellm.Router(model_list=[team_only_deployment])

    # Without a team, the public name yields no access groups.
    no_team_groups = router.get_model_access_groups(
        model_name="gpt-3.5-turbo", team_id=None
    )
    assert len(no_team_groups) == 0

    # With the owning team, the deployment's access group is visible.
    team_groups = router.get_model_access_groups(
        model_name="gpt-3.5-turbo", team_id="team_1"
    )
    assert list(team_groups.keys()) == ["default-models"]
def test_cached_get_model_group_info():
    """
    `_cached_get_model_group_info` returns the same object on repeated calls
    and recomputes after `add_deployment`, `delete_deployment` and
    `set_model_list`.
    """
    from litellm.types.router import Deployment, LiteLLM_Params

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4", "api_key": "fake"},
                "model_info": {"tpm": 1000, "rpm": 100},
            },
        ]
    )

    # A cold call computes the value; a warm call returns the identical cached object.
    cold = router._cached_get_model_group_info("gpt-4")
    assert cold is not None
    assert cold.tpm == 1000
    warm = router._cached_get_model_group_info("gpt-4")
    assert cold is warm

    # add_deployment invalidates the cache.
    router.add_deployment(
        Deployment(
            model_name="gpt-4",
            litellm_params=LiteLLM_Params(model="gpt-4", api_key="fake2"),
            model_info={"tpm": 2000, "rpm": 200},
        )
    )
    after_add = router._cached_get_model_group_info("gpt-4")
    assert after_add is not warm
    assert after_add is not None
    assert after_add.tpm == 3000  # 1000 + 2000

    # delete_deployment of the last model_list entry invalidates it again.
    newest_id = router.model_list[-1]["model_info"]["id"]
    router.delete_deployment(id=newest_id)
    after_delete = router._cached_get_model_group_info("gpt-4")
    assert after_delete is not after_add
    assert after_delete is not None
    assert after_delete.tpm == 1000

    # set_model_list replaces the deployments and invalidates as well.
    router.set_model_list(
        [
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4", "api_key": "fake"},
                "model_info": {"tpm": 5000},
            },
        ]
    )
    after_reset = router._cached_get_model_group_info("gpt-4")
    assert after_reset is not after_delete
    assert after_reset is not None
    assert after_reset.tpm == 5000

    # Caching resumes once the cache has been repopulated.
    assert router._cached_get_model_group_info("gpt-4") is after_reset
def test_get_model_access_groups_caching():
    """
    `get_model_access_groups()` with no arguments is cached: repeat calls
    return the same object. Calls with arguments bypass the cache, and
    `add_deployment` / `delete_deployment` invalidate it.
    """
    from litellm.types.router import Deployment, LiteLLM_Params

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4"},
                "model_info": {"access_groups": ["premium"]},
            },
        ]
    )

    # The first call populates the cache; the next returns the identical object.
    cold = router.get_model_access_groups()
    assert "premium" in cold
    warm = router.get_model_access_groups()
    assert cold is warm

    # Passing arguments skips the cache and builds a fresh result.
    filtered = router.get_model_access_groups(model_name="gpt-4")
    assert filtered is not warm

    # Adding a deployment invalidates the cache.
    router.add_deployment(
        Deployment(
            model_name="gpt-3.5",
            litellm_params=LiteLLM_Params(model="gpt-3.5-turbo"),
            model_info={"access_groups": ["default"]},
        )
    )
    after_add = router.get_model_access_groups()
    assert after_add is not warm
    assert "premium" in after_add
    assert "default" in after_add

    # Deleting that deployment invalidates it again.
    added_id = next(
        (
            entry.get("model_info", {}).get("id")
            for entry in router.model_list
            if entry.get("model_name") == "gpt-3.5"
        ),
        None,
    )
    assert added_id is not None
    router.delete_deployment(id=added_id)
    after_delete = router.get_model_access_groups()
    assert after_delete is not after_add
    assert "default" not in after_delete
    assert "premium" in after_delete
def test_get_model_access_groups_cache_invalidation_set_model_list():
    """
    Replacing the model list via `set_model_list` invalidates the cached
    no-argument `get_model_access_groups()` result.
    """
    premium_deployment = {
        "model_name": "gpt-4",
        "litellm_params": {"model": "gpt-4"},
        "model_info": {"access_groups": ["premium"]},
    }
    router = litellm.Router(model_list=[premium_deployment])

    # Warm the cache.
    before = router.get_model_access_groups()
    assert "premium" in before

    research_deployment = {
        "model_name": "claude-3",
        "litellm_params": {"model": "anthropic/claude-3-opus-20240229"},
        "model_info": {"access_groups": ["research"]},
    }
    router.set_model_list([research_deployment])

    # A fresh result reflecting only the new list is returned.
    after = router.get_model_access_groups()
    assert after is not before
    assert "research" in after
    assert "premium" not in after
def test_get_model_access_groups_cache_invalidation_upsert_deployment():
    """
    Upserting an existing deployment (same id, changed params) invalidates the
    cached no-argument `get_model_access_groups()` result.
    """
    from litellm.types.router import Deployment, LiteLLM_Params

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4"},
                "model_info": {"access_groups": ["premium"]},
            },
        ]
    )

    # Warm the cache.
    before = router.get_model_access_groups()
    assert "premium" in before

    # Reuse the existing id with new params so the upsert replaces the entry.
    existing_id = router.model_list[0]["model_info"]["id"]
    replacement = Deployment(
        model_name="gpt-4-updated",
        litellm_params=LiteLLM_Params(model="gpt-4-turbo"),
        model_info={"id": existing_id, "access_groups": ["updated-group"]},
    )
    router.upsert_deployment(replacement)

    after = router.get_model_access_groups()
    assert after is not before
    assert "updated-group" in after
@pytest.mark.asyncio
async def test_acompletion_streaming_iterator():
    """Test _acompletion_streaming_iterator for normal streaming and fallback behavior.

    Test 1: a stream with no errors yields every chunk unchanged.
    Test 2: a MidStreamFallbackError raised after the first chunk triggers
    `async_function_with_fallbacks_common_utils`. The fallback request carries
    the already generated content as an assistant prefix, and the caller sees
    the pre-error chunk followed by the fallback chunks.
    """
    from litellm.exceptions import MidStreamFallbackError

    # Helper class for creating async iterators
    class AsyncIterator:
        """Async iterator over `items`.

        If both `error_after` and `error` are given, `error` is raised once
        `error_after` items have been yielded.
        """

        def __init__(self, items, error_after=None, error=None):
            self.items = items
            self.index = 0
            self.error_after = error_after
            self.error = error

        def __aiter__(self):
            return self

        async def __anext__(self):
            # Raise the configured exception, never the integer threshold
            # itself (`raise <int>` would only produce a TypeError).
            if (
                self.error is not None
                and self.error_after is not None
                and self.index >= self.error_after
            ):
                raise self.error
            if self.index >= len(self.items):
                raise StopAsyncIteration
            item = self.items[self.index]
            self.index += 1
            return item

    # Set up router with fallback configuration
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4", "api_key": "fake-key-1"},
            },
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "fake-key-2"},
            },
        ],
        fallbacks=[{"gpt-4": ["gpt-3.5-turbo"]}],
        set_verbose=True,
    )

    # Test data
    messages = [{"role": "user", "content": "Hello"}]
    initial_kwargs = {"model": "gpt-4", "stream": True, "temperature": 0.7}

    # Test 1: Successful streaming (no errors)
    print("\n=== Test 1: Successful streaming ===")
    mock_chunks = [
        MagicMock(choices=[MagicMock(delta=MagicMock(content="Hello"))]),
        MagicMock(choices=[MagicMock(delta=MagicMock(content=" there"))]),
        MagicMock(choices=[MagicMock(delta=MagicMock(content="!"))]),
    ]
    mock_response = AsyncIterator(mock_chunks)
    mock_response.model = "gpt-4"
    mock_response.custom_llm_provider = "openai"
    mock_response.logging_obj = MagicMock()

    result = await router._acompletion_streaming_iterator(
        model_response=mock_response, messages=messages, initial_kwargs=initial_kwargs
    )

    collected_chunks = []
    async for chunk in result:
        collected_chunks.append(chunk)
    assert len(collected_chunks) == 3
    assert all(chunk in mock_chunks for chunk in collected_chunks)
    print("✓ Successfully streamed all chunks")

    # Test 2: MidStreamFallbackError with fallback
    print("\n=== Test 2: MidStreamFallbackError with fallback ===")
    # Error raised after the first chunk has been streamed
    error = MidStreamFallbackError(
        message="Connection lost",
        model="gpt-4",
        llm_provider="openai",
        generated_content="Hello",
    )

    class AsyncIteratorWithError:
        """Yields `items` until `error_after_index`, then raises `error`."""

        def __init__(self, items, error_after_index):
            self.items = items
            self.index = 0
            self.error_after_index = error_after_index
            self.chunks = []

        def __aiter__(self):
            return self

        async def __anext__(self):
            if self.index >= len(self.items):
                raise StopAsyncIteration
            if self.index == self.error_after_index:
                raise error
            item = self.items[self.index]
            self.index += 1
            return item

    mock_error_response = AsyncIteratorWithError(mock_chunks, 1)  # Error after first chunk
    mock_error_response.model = "gpt-4"
    mock_error_response.custom_llm_provider = "openai"
    mock_error_response.logging_obj = MagicMock()

    # Stream returned by the (mocked) fallback call
    fallback_chunks = [
        MagicMock(choices=[MagicMock(delta=MagicMock(content=" world"))]),
        MagicMock(choices=[MagicMock(delta=MagicMock(content="!"))]),
    ]
    mock_fallback_response = AsyncIterator(fallback_chunks)

    with patch.object(
        router,
        "async_function_with_fallbacks_common_utils",
        return_value=mock_fallback_response,
    ) as mock_fallback_utils:
        collected_chunks = []
        result = await router._acompletion_streaming_iterator(
            model_response=mock_error_response,
            messages=messages,
            initial_kwargs=initial_kwargs,
        )
        async for chunk in result:
            collected_chunks.append(chunk)

        # Verify fallback was called
        assert mock_fallback_utils.called
        call_args = mock_fallback_utils.call_args

        # Generated content is carried into the fallback request's messages
        fallback_kwargs = call_args.kwargs["kwargs"]
        modified_messages = fallback_kwargs["messages"]
        # Original message + system message + assistant message with prefix
        assert len(modified_messages) == 3
        assert modified_messages[0] == {"role": "user", "content": "Hello"}
        assert modified_messages[1]["role"] == "system"
        assert "continuation" in modified_messages[1]["content"]
        assert modified_messages[2]["role"] == "assistant"
        assert modified_messages[2]["content"] == "Hello"
        assert modified_messages[2]["prefix"] is True

        # Verify fallback parameters
        assert call_args.kwargs["disable_fallbacks"] is False
        assert call_args.kwargs["model_group"] == "gpt-4"

        # Original chunk + fallback chunks
        assert len(collected_chunks) == 3  # 1 original + 2 fallback
        print("✓ Fallback system called correctly with proper message modification")

    print("\n=== All tests passed! ===")
@pytest.mark.asyncio
async def test_acompletion_streaming_iterator_edge_cases():
    """Edge cases for Router._acompletion_streaming_iterator.

    A MidStreamFallbackError whose generated content is empty must still
    trigger the fallback, and the fallback must receive the caller's original
    messages unchanged (no continuation prompt is injected).
    """
    from litellm.exceptions import MidStreamFallbackError

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4", "api_key": "fake-key"},
            }
        ],
        set_verbose=True,
    )
    messages = [{"role": "user", "content": "Test"}]
    initial_kwargs = {"model": "gpt-4", "stream": True}

    # Raised before any chunk is produced, carrying no generated text.
    no_content_error = MidStreamFallbackError(
        message="Error",
        model="gpt-4",
        llm_provider="openai",
        generated_content="",
    )

    class FailsOnFirstRead:
        """Async stream stand-in whose very first __anext__ raises."""

        def __init__(self):
            self.model = "gpt-4"
            self.custom_llm_provider = "openai"
            self.logging_obj = MagicMock()
            self.chunks = []

        def __aiter__(self):
            return self

        async def __anext__(self):
            raise no_content_error

    upstream = FailsOnFirstRead()

    class NoChunks:
        """Async iterator that is exhausted immediately (empty fallback stream)."""

        def __aiter__(self):
            return self

        async def __anext__(self):
            raise StopAsyncIteration

    with patch.object(
        router,
        "async_function_with_fallbacks_common_utils",
        return_value=NoChunks(),
    ) as fallback_mock:
        stream = await router._acompletion_streaming_iterator(
            model_response=upstream,
            messages=messages,
            initial_kwargs=initial_kwargs,
        )
        # Drain the wrapper so the error -> fallback path actually executes.
        received = [chunk async for chunk in stream]

        # The fallback must fire even though nothing was generated.
        assert fallback_mock.called
        forwarded_messages = fallback_mock.call_args.kwargs["kwargs"]["messages"]
        # Empty content → pre-first-chunk path uses original messages
        # (no continuation prompt added).
        assert forwarded_messages == messages
        print("✓ Handles empty generated content correctly")
    print("✓ Edge case tests passed!")
@pytest.mark.asyncio
async def test_acompletion_streaming_iterator_preserves_hidden_params():
    """
    Regression test: the FallbackStreamWrapper returned by
    _acompletion_streaming_iterator must inherit _hidden_params from the
    wrapped CustomStreamWrapper; otherwise x-litellm-overhead-duration-ms (and
    other hidden params) disappear from streaming proxy response headers.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4", "api_key": "fake-key"},
            }
        ],
    )

    # Stand-in for a CustomStreamWrapper that update_response_metadata has
    # already stamped with timing data.
    upstream = MagicMock()
    upstream.model = "gpt-4"
    upstream.custom_llm_provider = "openai"
    upstream.logging_obj = MagicMock()
    upstream._hidden_params = {
        "litellm_overhead_time_ms": 12.34,
        "_response_ms": 500.0,
        "litellm_call_id": "test-call-id",
        "api_base": "https://api.openai.com",
        "additional_headers": {},
    }

    async def _no_chunks():
        # Async generator that yields nothing; only hidden params matter here.
        return
        yield

    upstream.__aiter__ = lambda self: _no_chunks().__aiter__()

    wrapped = await router._acompletion_streaming_iterator(
        model_response=upstream,
        messages=[{"role": "user", "content": "hi"}],
        initial_kwargs={"model": "gpt-4", "stream": True},
    )

    assert hasattr(wrapped, "_hidden_params"), "result must have _hidden_params"
    carried = wrapped._hidden_params
    assert carried.get("litellm_overhead_time_ms") == 12.34, (
        "litellm_overhead_time_ms must be preserved — "
        "this is what drives x-litellm-overhead-duration-ms in streaming responses"
    )
    for key, expected in (
        ("litellm_call_id", "test-call-id"),
        ("_response_ms", 500.0),
    ):
        assert carried.get(key) == expected
def test_completion_streaming_iterator_fallback_on_429():
    """Sync streaming: a pre-first-chunk MidStreamFallbackError (e.g. a 429)
    must be routed through Router.function_with_fallbacks.

    Sync counterpart of test_acompletion_streaming_iterator. Previously
    __next__ surfaced RateLimitError directly, so the Router never had a
    chance to fall back.
    """
    from litellm.exceptions import MidStreamFallbackError

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4", "api_key": "fake-key"},
            }
        ],
    )
    messages = [{"role": "user", "content": "Test"}]
    initial_kwargs = {"model": "gpt-4", "stream": True}

    exhausted_error = MidStreamFallbackError(
        message="Resource exhausted",
        model="gpt-4",
        llm_provider="vertex_ai",
        generated_content="",
        is_pre_first_chunk=True,
    )

    class RaisesOnFirstNext:
        """Sync stream stand-in whose first __next__ raises the 429-style error."""

        def __init__(self):
            self.model = "gpt-4"
            self.custom_llm_provider = "openai"
            self.logging_obj = MagicMock()
            self.chunks = []

        def __iter__(self):
            return self

        def __next__(self):
            raise exhausted_error

    upstream = RaisesOnFirstNext()

    # The fallback result only needs to be iterable; it yields nothing.
    fallback_result = MagicMock()
    fallback_result.__iter__ = MagicMock(return_value=iter([]))

    with patch.object(
        router, "function_with_fallbacks", return_value=fallback_result
    ) as fallback_mock:
        wrapped = router._completion_streaming_iterator(
            model_response=upstream,
            messages=messages,
            initial_kwargs=initial_kwargs,
        )
        list(wrapped)  # drain the stream so the fallback path runs

        assert fallback_mock.called
        passed = fallback_mock.call_args.kwargs
        # Pre-first-chunk: should use original messages, no continuation prompt
        assert passed.get("messages") == messages
        # The sync entry point must fall back via the sync _completion.
        assert passed.get("original_function") == router._completion
def test_completion_streaming_iterator_preserves_hidden_params():
    """The sync SyncFallbackStreamWrapper must carry over the original
    response's _hidden_params."""
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4", "api_key": "fake-key"},
            }
        ],
    )

    upstream = MagicMock()
    upstream.model = "gpt-4"
    upstream.custom_llm_provider = "openai"
    upstream.logging_obj = MagicMock()
    upstream._hidden_params = {
        "litellm_overhead_time_ms": 42.0,
        "litellm_call_id": "test-sync-call",
    }
    # Empty stream: only the propagated hidden params are under test.
    upstream.__iter__ = MagicMock(return_value=iter([]))

    wrapped = router._completion_streaming_iterator(
        model_response=upstream,
        messages=[{"role": "user", "content": "hi"}],
        initial_kwargs={"model": "gpt-4", "stream": True},
    )

    assert hasattr(wrapped, "_hidden_params")
    carried = wrapped._hidden_params
    assert carried.get("litellm_overhead_time_ms") == 42.0
    assert carried.get("litellm_call_id") == "test-sync-call"
@pytest.mark.asyncio
async def test_acompletion_streaming_iterator_pre_first_chunk_skips_continuation():
    """A MidStreamFallbackError flagged ``is_pre_first_chunk=True`` must cause
    the fallback to be invoked with the caller's original messages."""
    from litellm.exceptions import MidStreamFallbackError

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4", "api_key": "fake-key"},
            }
        ],
    )
    messages = [{"role": "user", "content": "Hello"}]
    initial_kwargs = {"model": "gpt-4", "stream": True}

    rate_limited = MidStreamFallbackError(
        message="429 Resource exhausted",
        model="gpt-4",
        llm_provider="vertex_ai",
        generated_content="",
        is_pre_first_chunk=True,
    )

    class FailsBeforeFirstChunk:
        """Async stream stand-in that errors before emitting anything."""

        def __init__(self):
            self.model = "gpt-4"
            self.custom_llm_provider = "openai"
            self.logging_obj = MagicMock()
            self.chunks = []

        def __aiter__(self):
            return self

        async def __anext__(self):
            raise rate_limited

    upstream = FailsBeforeFirstChunk()

    class NoChunks:
        """Async iterator that is exhausted immediately."""

        def __aiter__(self):
            return self

        async def __anext__(self):
            raise StopAsyncIteration

    with patch.object(
        router,
        "async_function_with_fallbacks_common_utils",
        return_value=NoChunks(),
    ) as fallback_mock:
        stream = await router._acompletion_streaming_iterator(
            model_response=upstream,
            messages=messages,
            initial_kwargs=initial_kwargs,
        )
        # Consume the stream so the fallback is actually triggered.
        async for _ in stream:
            pass

        assert fallback_mock.called
        forwarded = fallback_mock.call_args.kwargs["kwargs"]
        # Pre-first-chunk: should use original messages, no continuation prompt
        assert forwarded["messages"] == messages
@pytest.mark.asyncio
async def test_async_function_with_fallbacks_common_utils():
    """async_function_with_fallbacks_common_utils must re-raise the original
    exception when fallbacks are disabled or when no original model group can
    be derived from kwargs."""
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                },
            }
        ],
        max_fallbacks=5,
    )

    original_error = Exception("Test error")
    scenarios = [
        # 1) disable_fallbacks=True -> the original exception is re-raised.
        {"disable_fallbacks": True, "kwargs": MagicMock()},
        # 2) kwargs has no "model" key (no original model group) -> re-raised.
        {"disable_fallbacks": False, "kwargs": {}},
    ]

    for scenario in scenarios:
        with pytest.raises(Exception, match="Test error"):
            await router.async_function_with_fallbacks_common_utils(
                e=original_error,
                fallbacks=None,
                context_window_fallbacks=None,
                content_policy_fallbacks=None,
                model_group="gpt-3.5-turbo",
                args=(),
                **scenario,
            )
def test_should_include_deployment():
    """
    A team-owned wildcard deployment ("openai/*" carrying a team_id and a
    team_public_model_name) must make Router.get_model_list return a
    non-empty result when that same team requests a concrete model name
    matching the wildcard.
    """
    team_id = "a28a12f9-3e44-4861-bd4f-325f2d309ce8"
    router = litellm.Router(
        model_list=[
            {
                "model_name": "model_name_a28a12f9-3e44-4861-bd4f-325f2d309ce8_cd5dc6fb-b046-4e05-ae1d-32ba4d936266",
                "litellm_params": {"model": "openai/*"},
                "model_info": {
                    "team_id": team_id,
                    "team_public_model_name": "openai/*",
                },
            }
        ],
    )

    assert router.get_model_list(
        model_name="openai/o4-mini-deep-research",
        team_id=team_id,
    )
def test_get_deployment_model_info_base_model_flow():
    """Test that get_deployment_model_info correctly handles the base model flow.

    Scenarios covered:
      1. custom model info declares ``base_model`` -> base, custom and
         model-name info are merged (custom > base > model-name);
      2. custom model info without ``base_model`` -> merged with the
         model-name info only;
      3. no custom model info -> the model-name info is returned unchanged;
      4. the base-model lookup raises -> tolerated, remaining info merged;
      5. no info is found anywhere -> ``None``.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "test-model",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ],
    )
    # Mock data for the test
    mock_custom_model_info = {
        "base_model": "gpt-3.5-turbo",
        "input_cost_per_token": 0.001,
        "output_cost_per_token": 0.002,
        "custom_field": "custom_value",
    }
    mock_base_model_info = {
        "key": "gpt-3.5-turbo",
        "max_tokens": 4096,
        "max_input_tokens": 4096,
        "max_output_tokens": 4096,
        "input_cost_per_token": 0.0015,  # This should be overridden by custom model info
        "output_cost_per_token": 0.002,
        "litellm_provider": "openai",
        "mode": "chat",
        "supported_openai_params": ["temperature", "max_tokens"],
    }
    mock_litellm_model_name_info = {
        "key": "test-model",
        "max_tokens": 2048,
        "max_input_tokens": 2048,
        "max_output_tokens": 2048,
        "input_cost_per_token": 0.0005,
        "output_cost_per_token": 0.001,
        "litellm_provider": "test_provider",
        "mode": "completion",
        "supported_openai_params": ["temperature"],
    }
    # Test Case 1: Base model flow with custom model info that has base_model
    with patch.object(
        litellm, "model_cost", {"test-custom-model": mock_custom_model_info}
    ):
        with patch.object(litellm, "get_model_info") as mock_get_model_info:
            # Configure mock returns
            mock_get_model_info.side_effect = lambda model: {
                "gpt-3.5-turbo": mock_base_model_info,
                "test-model": mock_litellm_model_name_info,
            }.get(model)
            result = router.get_deployment_model_info(
                model_id="test-custom-model", model_name="test-model"
            )
            # Verify that get_model_info was called for both base model and model name
            assert mock_get_model_info.call_count == 2
            mock_get_model_info.assert_any_call(
                model="gpt-3.5-turbo"
            )  # base model call
            mock_get_model_info.assert_any_call(model="test-model")  # model name call
            # Verify the result contains merged information
            assert result is not None
            # Expected merge behavior:
            # 1. base_model_info provides defaults, custom_model_info overrides (correct priority)
            # 2. The result of step 1 gets merged into litellm_model_name_info (custom+base override litellm)
            # Fields from custom model (should override base model values)
            assert (
                result["input_cost_per_token"] == 0.001
            )  # From custom model (overrides base 0.0015)
            assert (
                result["output_cost_per_token"] == 0.002
            )  # From custom model (same as base)
            assert result["custom_field"] == "custom_value"  # From custom model
            # Fields from base model that weren't overridden by custom
            assert result["max_tokens"] == 4096  # From base model
            assert result["litellm_provider"] == "openai"  # From base model
            assert (
                result["mode"] == "chat"
            )  # From base model (overrides litellm "completion")
            # The key field comes from base model since both base and litellm have it
            # and base model info overrides litellm model name info in final merge
            assert (
                result["key"] == "gpt-3.5-turbo"
            )  # From base model (overrides litellm key)
    # Test Case 2: Custom model info without base_model
    mock_custom_model_info_no_base = {
        "input_cost_per_token": 0.001,
        "output_cost_per_token": 0.002,
        "custom_field": "custom_value",
    }
    with patch.object(
        litellm,
        "model_cost",
        {"test-custom-model-no-base": mock_custom_model_info_no_base},
    ):
        with patch.object(litellm, "get_model_info") as mock_get_model_info:
            mock_get_model_info.side_effect = lambda model: {
                "test-model": mock_litellm_model_name_info,
            }.get(model)
            result = router.get_deployment_model_info(
                model_id="test-custom-model-no-base", model_name="test-model"
            )
            # Should only call get_model_info once for model name (no base model)
            assert mock_get_model_info.call_count == 1
            mock_get_model_info.assert_called_with(model="test-model")
            # Verify the result contains merged information
            assert result is not None
            assert result["input_cost_per_token"] == 0.001  # From custom model
            assert result["max_tokens"] == 2048  # From litellm model name info
            assert result["custom_field"] == "custom_value"  # From custom model
            assert result["mode"] == "completion"  # From litellm model name info
    # Test Case 3: No custom model info, only litellm model name info
    with patch.object(litellm, "model_cost", {}):  # Empty model cost
        with patch.object(litellm, "get_model_info") as mock_get_model_info:
            mock_get_model_info.side_effect = lambda model: {
                "test-model": mock_litellm_model_name_info,
            }.get(model)
            result = router.get_deployment_model_info(
                model_id="non-existent-model", model_name="test-model"
            )
            # Should only call get_model_info once for model name
            assert mock_get_model_info.call_count == 1
            mock_get_model_info.assert_called_with(model="test-model")
            # Result should be just the litellm model name info
            assert result is not None
            assert result == mock_litellm_model_name_info
    # Test Case 4: Base model info retrieval fails (exception handling)
    mock_custom_model_info_invalid_base = {
        "base_model": "invalid-base-model",
        "input_cost_per_token": 0.001,
        "output_cost_per_token": 0.002,
    }
    with patch.object(
        litellm,
        "model_cost",
        {"test-custom-model-invalid": mock_custom_model_info_invalid_base},
    ):
        with patch.object(litellm, "get_model_info") as mock_get_model_info:
            # Mock get_model_info to raise exception for invalid base model
            def mock_get_model_info_side_effect(model):
                if model == "invalid-base-model":
                    raise Exception("Model not found")
                elif model == "test-model":
                    return mock_litellm_model_name_info
                return None

            mock_get_model_info.side_effect = mock_get_model_info_side_effect
            result = router.get_deployment_model_info(
                model_id="test-custom-model-invalid", model_name="test-model"
            )
            # Should handle exception gracefully and still return merged result
            assert result is not None
            assert result["input_cost_per_token"] == 0.001  # From custom model
            assert result["mode"] == "completion"  # From litellm model name info
    # Test Case 5: model_cost has no entry for the id (model_cost.get() -> None)
    # and get_model_info() raises for every model name.
    with patch.object(litellm, "model_cost", {}):
        with patch.object(
            litellm, "get_model_info", side_effect=Exception("Not found")
        ):
            result = router.get_deployment_model_info(
                model_id="non-existent", model_name="non-existent"
            )
            # Should return None when no model info is found
            assert result is None
    print("✓ All base model flow test cases passed!")
@patch("litellm.model_cost", {})
def test_get_deployment_model_info_base_model_merge_priority():
    """Check the precedence get_deployment_model_info applies when merging.

    Highest priority wins: custom model info > base model info > litellm
    model-name info. Fields present in only one source are carried through.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "test-model",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ],
    )

    # Overlapping fields across the three sources exercise each precedence rule.
    custom_info = {
        "base_model": "gpt-4",
        "input_cost_per_token": 0.01,  # beats base (0.03) and model-name (0.005)
        "max_tokens": 8000,  # beats base (4096) and model-name (2048)
        "custom_only_field": "custom_value",
    }
    base_info = {
        "key": "gpt-4",
        "max_tokens": 4096,
        "input_cost_per_token": 0.03,
        "output_cost_per_token": 0.06,  # not in custom -> survives, beats 0.01
        "litellm_provider": "openai",
        "base_only_field": "base_value",
    }
    model_name_info = {
        "key": "test-model",
        "max_tokens": 2048,
        "input_cost_per_token": 0.005,
        "output_cost_per_token": 0.01,
        "mode": "completion",
        "litellm_only_field": "litellm_value",
    }
    lookup = {"gpt-4": base_info, "test-model": model_name_info}

    with patch.object(litellm, "model_cost", {"custom-model-id": custom_info}):
        with patch.object(litellm, "get_model_info") as get_model_info_mock:
            get_model_info_mock.side_effect = lambda model: lookup.get(model)

            merged = router.get_deployment_model_info(
                model_id="custom-model-id", model_name="test-model"
            )

            assert merged is not None
            expected_fields = {
                # Custom model info (highest priority).
                "input_cost_per_token": 0.01,
                "max_tokens": 8000,
                "custom_only_field": "custom_value",
                # Base model info, for fields custom info does not set.
                "output_cost_per_token": 0.06,
                "litellm_provider": "openai",
                "base_only_field": "base_value",
                # Model-name info, for fields neither custom nor base set.
                "mode": "completion",
                "litellm_only_field": "litellm_value",
                # "key" exists in base and model-name info; merged custom+base wins.
                "key": "gpt-4",
            }
            for field, value in expected_fields.items():
                assert merged[field] == value
            print("✓ Base model merge priority test passed!")
def test_add_deployment_model_to_endpoint_for_llm_passthrough_route():
    """
    _add_deployment_model_to_endpoint_for_llm_passthrough_route must replace
    the router model name in a Bedrock passthrough endpoint with the
    deployment's model name, stripping the "bedrock/" provider prefix.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "special-bedrock-model",
                "litellm_params": {
                    "model": "bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
                },
            }
        ],
    )

    def rewritten_endpoint(endpoint, model, model_name, **extra_kwargs):
        """Run the rewrite on a fresh kwargs dict and return the new endpoint."""
        request_kwargs = {"endpoint": endpoint, **extra_kwargs}
        updated = router._add_deployment_model_to_endpoint_for_llm_passthrough_route(
            kwargs=request_kwargs,
            model=model,
            model_name=model_name,
        )
        return updated["endpoint"]

    claude = "bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0"
    llama = "bedrock/us.meta.llama3-8b-instruct-v1:0"

    # Case 1: explicit bedrock provider, /invoke -> "bedrock/" prefix stripped
    actual = rewritten_endpoint(
        "/model/special-bedrock-model/invoke",
        "special-bedrock-model",
        claude,
        custom_llm_provider="bedrock",
    )
    expected = "/model/us.anthropic.claude-3-5-sonnet-20240620-v1:0/invoke"
    assert actual == expected, f"Expected '{expected}', got '{actual}'"

    # Case 2: /invoke-with-response-stream
    actual = rewritten_endpoint(
        "/model/special-bedrock-model/invoke-with-response-stream",
        "special-bedrock-model",
        claude,
        custom_llm_provider="bedrock",
    )
    assert (
        actual
        == "/model/us.anthropic.claude-3-5-sonnet-20240620-v1:0/invoke-with-response-stream"
    ), f"Expected streaming endpoint with stripped prefix, got '{actual}'"

    # Case 3: /converse
    actual = rewritten_endpoint(
        "/model/bedrock-model/converse",
        "bedrock-model",
        llama,
        custom_llm_provider="bedrock",
    )
    expected = "/model/us.meta.llama3-8b-instruct-v1:0/converse"
    assert actual == expected, f"Expected '{expected}', got '{actual}'"

    # Case 4: no custom_llm_provider -> prefix auto-detected from model_name
    actual = rewritten_endpoint(
        "/model/router-model/invoke",
        "router-model",
        llama,
    )
    expected = "/model/us.meta.llama3-8b-instruct-v1:0/invoke"
    assert actual == expected, f"Expected '{expected}', got '{actual}'"
@pytest.mark.asyncio
async def test_router_acompletion_with_unknown_model_and_default_fallback():
    """
    Requesting a model name absent from model_list must succeed when
    default_fallbacks is configured, rather than raising BadRequestError.

    Regression test for issue #15114.
    """
    fallback_deployment = {
        "model_name": "gpt-4o",  # the default fallback group
        "litellm_params": {
            "model": "azure/gpt-4o-real",  # underlying deployment model
            "api_key": "fake-key",
            "api_base": "https://fake-endpoint.openai.azure.com/",
            # mock_response keeps the test from making a real API call
            "mock_response": "this is the fallback response",
        },
    }
    router = litellm.Router(
        model_list=[fallback_deployment], default_fallbacks=["gpt-4o"]
    )

    response = await router.acompletion(
        model="completely-unknown-model",
        messages=[
            {"role": "user", "content": "This call should succeed by falling back."}
        ],
    )

    assert response is not None
    # The content comes from the mocked fallback deployment...
    assert response.choices[0].message.content == "this is the fallback response"
    # ...and the reported model is the deployment that actually served it.
    assert response.model == "gpt-4o-real"
@pytest.mark.asyncio
async def test_router_acompletion_with_unknown_model_and_no_fallback():
    """
    Without default fallbacks, an unknown model must still raise
    BadRequestError, preserving the original behavior.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4o",
                "litellm_params": {
                    "model": "azure/gpt-4o-real",
                    "api_key": "fake-key",
                    "mock_response": "this should not be called",
                },
            }
        ]
    )

    with pytest.raises(litellm.BadRequestError) as exc_info:
        await router.acompletion(
            model="completely-unknown-model",
            messages=[{"role": "user", "content": "This call should fail."}],
        )

    # get_model_list returns [] (not None) for the unknown name, so the router
    # reports "no healthy deployments".
    assert "no healthy deployments for this model" in str(exc_info.value)
def test_get_deployment_credentials_with_provider_aws_bedrock_runtime_endpoint():
    """
    get_deployment_credentials_with_provider must copy
    aws_bedrock_runtime_endpoint, along with the other AWS credentials, from
    the deployment's litellm_params into the returned credentials.
    """
    runtime_endpoint = "https://bedrock-runtime.us-east-1.amazonaws.com"
    router = litellm.Router(
        model_list=[
            {
                "model_name": "bedrock-claude-model",
                "litellm_params": {
                    "model": "bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
                    "aws_access_key_id": "test-access-key",
                    "aws_secret_access_key": "test-secret-key",
                    "aws_region_name": "us-east-1",
                    "aws_bedrock_runtime_endpoint": runtime_endpoint,
                },
            }
        ],
    )

    credentials = router.get_deployment_credentials_with_provider(
        model_id="bedrock-claude-model"
    )

    assert credentials is not None
    expected_credentials = {
        "aws_bedrock_runtime_endpoint": runtime_endpoint,
        "aws_access_key_id": "test-access-key",
        "aws_secret_access_key": "test-secret-key",
        "aws_region_name": "us-east-1",
        "custom_llm_provider": "bedrock",
    }
    for key, value in expected_credentials.items():
        assert credentials[key] == value
def test_get_deployment_credentials_with_provider_resolves_credential_name():
    """
    get_deployment_credentials_with_provider must resolve a deployment's
    litellm_credential_name (as used by UI-created models) into the concrete
    credential values, and drop the credential name from the result.

    litellm.credential_list is module-global state, so it is restored in a
    ``finally`` block. Otherwise a failing assertion would leak the test
    credential into every test that runs afterwards.
    """
    from litellm.types.utils import CredentialItem

    saved_credential_list = litellm.credential_list
    # Setup credential list with a test credential
    litellm.credential_list = [
        CredentialItem(
            credential_name="test-azure-cred",
            credential_info={"custom_llm_provider": "azure"},
            credential_values={
                "api_key": "resolved-api-key",
                "api_base": "https://resolved.openai.azure.com",
                "api_version": "2024-02-01",
            },
        )
    ]
    try:
        router = litellm.Router(
            model_list=[
                {
                    "model_name": "azure-gpt-4",
                    "litellm_params": {
                        "model": "azure/gpt-4",
                        "litellm_credential_name": "test-azure-cred",
                    },
                }
            ],
        )
        credentials = router.get_deployment_credentials_with_provider(
            model_id="azure-gpt-4"
        )
        assert credentials is not None
        assert credentials["api_key"] == "resolved-api-key"
        assert credentials["api_base"] == "https://resolved.openai.azure.com"
        assert credentials["api_version"] == "2024-02-01"
        assert credentials["custom_llm_provider"] == "azure"
        # Ensure credential name is removed after resolution
        assert "litellm_credential_name" not in credentials
    finally:
        # Restore the global exactly as it was, even if an assertion failed.
        litellm.credential_list = saved_credential_list
def test_get_available_guardrail_single_deployment():
    """
    With exactly one guardrail registered under a name,
    get_available_guardrail returns that guardrail's config.
    """
    only_guardrail = {
        "guardrail_name": "content-filter",
        "litellm_params": {"guardrail": "custom", "mode": "pre_call"},
        "id": "guardrail-1",
    }
    model_list = [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ]
    router = litellm.Router(model_list=model_list, guardrail_list=[only_guardrail])

    selected = router.get_available_guardrail(guardrail_name="content-filter")

    assert selected == only_guardrail
def test_get_available_guardrail_multiple_deployments():
    """
    Test get_available_guardrail load balances across multiple guardrails
    registered under the same guardrail_name: over repeated calls, every
    registered guardrail must be selected at least once.
    """
    guardrail_1 = {
        "guardrail_name": "content-filter",
        "litellm_params": {"guardrail": "custom", "mode": "pre_call"},
        "id": "guardrail-1",
    }
    guardrail_2 = {
        "guardrail_name": "content-filter",
        "litellm_params": {"guardrail": "custom", "mode": "pre_call"},
        "id": "guardrail-2",
    }
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ],
        guardrail_list=[guardrail_1, guardrail_2],
    )
    # Assuming uniform random selection between the two candidates, the chance
    # that all 50 picks land on the same guardrail is 2 * 0.5**50, which is
    # negligible. Requiring both IDs therefore does not make the test flaky.
    results = {
        router.get_available_guardrail(guardrail_name="content-filter")["id"]
        for _ in range(50)
    }
    # Both guardrails should be selected at least once. An `or` check would
    # pass even if selection always returned the same guardrail.
    assert results == {"guardrail-1", "guardrail-2"}
def test_get_available_guardrail_not_found():
    """
    get_available_guardrail raises ValueError ("No guardrail found with
    name ...") when nothing is registered under the requested name.
    """
    model_list = [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ]
    router = litellm.Router(model_list=model_list, guardrail_list=[])

    with pytest.raises(ValueError, match="No guardrail found with name"):
        router.get_available_guardrail(guardrail_name="non-existent")
@pytest.mark.asyncio
async def test_aguardrail_helper():
    """
    _aguardrail_helper must pick a guardrail for the requested name and pass
    it to the wrapped function as the ``selected_guardrail`` keyword.
    """
    content_filter = {
        "guardrail_name": "content-filter",
        "litellm_params": {"guardrail": "custom", "mode": "pre_call"},
        "id": "guardrail-1",
    }
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ],
        guardrail_list=[content_filter],
    )

    async def echo_selection(**kwargs):
        # Report back whichever guardrail the helper injected.
        return {
            "result": "success",
            "selected_guardrail": kwargs.get("selected_guardrail"),
        }

    outcome = await router._aguardrail_helper(
        model="content-filter",
        original_generic_function=echo_selection,
    )

    assert outcome["result"] == "success"
    assert outcome["selected_guardrail"] == content_filter
@pytest.mark.asyncio
async def test_aguardrail():
    """
    Router.aguardrail must select a guardrail by name and invoke the supplied
    function with it as ``selected_guardrail``.
    """
    content_filter = {
        "guardrail_name": "content-filter",
        "litellm_params": {"guardrail": "custom", "mode": "pre_call"},
        "id": "guardrail-1",
    }
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ],
        guardrail_list=[content_filter],
    )

    async def echo_selection(**kwargs):
        # Report back whichever guardrail the router injected.
        return {
            "result": "success",
            "selected_guardrail": kwargs.get("selected_guardrail"),
        }

    outcome = await router.aguardrail(
        guardrail_name="content-filter",
        original_function=echo_selection,
    )

    assert outcome["result"] == "success"
    assert outcome["selected_guardrail"]["id"] == "guardrail-1"
@pytest.mark.asyncio
async def test_anthropic_messages_call_type_is_cached():
    """
    Regression test: PromptCachingDeploymentCheck.async_log_success_event must
    not filter out the anthropic_messages call type. After logging a success
    event for such a call, the deployment's model_id must be retrievable from
    the prompt-caching cache for the same messages.
    """
    import asyncio

    from litellm.caching.dual_cache import DualCache
    from litellm.router_utils.pre_call_checks.prompt_caching_deployment_check import (
        PromptCachingDeploymentCheck,
    )
    from litellm.router_utils.prompt_caching_cache import PromptCachingCache
    from litellm.types.utils import (
        CallTypes,
        StandardLoggingHiddenParams,
        StandardLoggingMetadata,
        StandardLoggingModelInformation,
        StandardLoggingPayload,
    )

    # The deployment check writes to, and the prompt cache reads from, one shared cache.
    shared_cache = DualCache()
    checker = PromptCachingDeploymentCheck(cache=shared_cache)
    prompt_cache = PromptCachingCache(cache=shared_cache)

    expected_model_id = "test-model-id-123"
    # Repeat the text so the prompt is long enough to pass the caching threshold.
    cacheable_messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "test long message here" * 1024,
                    "cache_control": {"type": "ephemeral", "ttl": "5m"},
                }
            ],
        }
    ]

    # Start from a generic "completion" logging payload...
    logging_payload = StandardLoggingPayload(
        id="test_id",
        call_type="completion",
        response_cost=0.1,
        response_cost_failure_debug_info=None,
        status="success",
        total_tokens=30,
        prompt_tokens=20,
        completion_tokens=10,
        startTime=1234567890.0,
        endTime=1234567891.0,
        completionStartTime=1234567890.5,
        model_map_information=StandardLoggingModelInformation(
            model_map_key="gpt-3.5-turbo", model_map_value=None
        ),
        model="gpt-3.5-turbo",
        model_id="model-123",
        model_group="openai-gpt",
        api_base="https://api.openai.com",
        metadata=StandardLoggingMetadata(
            user_api_key_hash="test_hash",
            user_api_key_org_id=None,
            user_api_key_alias="test_alias",
            user_api_key_team_id="test_team",
            user_api_key_user_id="test_user",
            user_api_key_team_alias="test_team_alias",
            spend_logs_metadata=None,
            requester_ip_address="127.0.0.1",
            requester_metadata=None,
        ),
        cache_hit=False,
        cache_key=None,
        saved_cache_cost=0.0,
        request_tags=[],
        end_user=None,
        requester_ip_address="127.0.0.1",
        messages=[{"role": "user", "content": "Hello, world!"}],
        response={"choices": [{"message": {"content": "Hi there!"}}]},
        error_str=None,
        model_parameters={"stream": True},
        hidden_params=StandardLoggingHiddenParams(
            model_id="model-123",
            cache_key=None,
            api_base="https://api.openai.com",
            response_cost="0.1",
            additional_headers=None,
        ),
    )
    # ...then turn it into an anthropic_messages call for the cacheable prompt.
    logging_payload.update(
        {
            "call_type": CallTypes.anthropic_messages.value,
            "messages": cacheable_messages,
            "model": "anthropic/claude-3-5-sonnet-20240620",
            "model_id": expected_model_id,
        }
    )

    await checker.async_log_success_event(
        kwargs={"standard_logging_object": logging_payload},
        response_obj={},
        start_time=1234567890.0,
        end_time=1234567891.0,
    )
    # Brief pause so any pending cache write lands before reading it back.
    await asyncio.sleep(0.1)

    lookup = await prompt_cache.async_get_model_id(
        messages=cacheable_messages,
        tools=None,
    )

    # A None here means anthropic_messages was filtered out of caching.
    assert (
        lookup is not None
    ), "Model ID should be cached for anthropic_messages call type"
    assert (
        lookup["model_id"] == expected_model_id
    ), f"Expected {expected_model_id}, got {lookup['model_id']}"
def test_update_kwargs_with_deployment_propagates_model_tags():
    """
    Test that deployment-level tags from litellm_params are merged into
    kwargs metadata when _update_kwargs_with_deployment is called.

    This ensures model-level tags defined in config.yaml appear in SpendLogs.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4o-mini",
                "litellm_params": {
                    "model": "openai/gpt-4o-mini",
                    "api_key": "fake-key",
                    "tags": ["openai-account", "production"],
                },
            },
        ],
    )
    kwargs: dict = {"metadata": {}}
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-4o-mini"
    )

    router._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)

    # Deployment tags should be propagated to kwargs metadata.
    assert "tags" in kwargs["metadata"]
    assert "openai-account" in kwargs["metadata"]["tags"]
    assert "production" in kwargs["metadata"]["tags"]
def test_update_kwargs_with_deployment_merges_tags_without_duplicates():
    """
    Request-level tags and deployment-level tags are combined, and a tag that
    appears in both sources ends up in the merged list only once.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4o-mini",
                "litellm_params": {
                    "model": "openai/gpt-4o-mini",
                    "api_key": "fake-key",
                    "tags": ["openai-account", "shared-tag"],
                },
            },
        ],
    )
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-4o-mini"
    )

    # The request already carries tags (e.g. from the request body or key/team config).
    request_kwargs: dict = {"metadata": {"tags": ["user-tag", "shared-tag"]}}
    router._update_kwargs_with_deployment(deployment=deployment, kwargs=request_kwargs)

    merged_tags = request_kwargs["metadata"]["tags"]
    for expected_tag in ("user-tag", "openai-account", "shared-tag"):
        assert expected_tag in merged_tags
    assert merged_tags.count("shared-tag") == 1
def test_update_kwargs_with_deployment_no_tags():
    """
    A deployment without tags must leave the request metadata free of a
    "tags" key.
    """
    untagged_deployment_config = {
        "model_name": "gpt-4o-mini",
        "litellm_params": {
            "model": "openai/gpt-4o-mini",
            "api_key": "fake-key",
        },
    }
    router = litellm.Router(model_list=[untagged_deployment_config])
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-4o-mini"
    )

    request_kwargs: dict = {"metadata": {}}
    router._update_kwargs_with_deployment(deployment=deployment, kwargs=request_kwargs)

    assert "tags" not in request_kwargs["metadata"]
def test_update_kwargs_with_deployment_merges_tools():
    """
    Test that when both deployment litellm_params and request have tools,
    they are merged (deployment tools first, then request tools).

    Supports proxy-configured tools (e.g. for o3 deep research) merged with
    client-provided tools.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "o3-deep-research",
                "litellm_params": {
                    "model": "openai/o3-deep-research",
                    "api_key": "fake-key",
                    "tools": [{"type": "web_search"}],
                    "tool_choice": "auto",
                },
            },
        ],
    )
    kwargs: dict = {
        "metadata": {},
        "tools": [
            {
                "type": "function",
                "function": {"name": "get_weather", "description": "Get weather"},
            },
        ],
    }
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="o3-deep-research"
    )

    router._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)

    # Tools should be merged: deployment tools first, then request tools.
    assert "tools" in kwargs
    assert len(kwargs["tools"]) == 2
    assert kwargs["tools"][0] == {"type": "web_search"}
    assert kwargs["tools"][1]["function"]["name"] == "get_weather"
    # The request did not set tool_choice at all (the key is absent, which is
    # different from tool_choice="none"), so the deployment's value is applied.
    assert kwargs["tool_choice"] == "auto"
def test_update_kwargs_with_deployment_merge_tools_deployment_only():
    """
    When only the deployment configures tools and tool_choice, both are
    copied onto the request kwargs unchanged.
    """
    deployment_tools = [{"type": "web_search"}]
    router = litellm.Router(
        model_list=[
            {
                "model_name": "o3-deep-research",
                "litellm_params": {
                    "model": "openai/o3-deep-research",
                    "api_key": "fake-key",
                    "tools": [{"type": "web_search"}],
                    "tool_choice": "required",
                },
            },
        ],
    )
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="o3-deep-research"
    )

    request_kwargs: dict = {"metadata": {}}
    router._update_kwargs_with_deployment(deployment=deployment, kwargs=request_kwargs)

    assert request_kwargs["tools"] == deployment_tools
    assert request_kwargs["tool_choice"] == "required"
def test_update_kwargs_with_deployment_merge_tools_request_overrides_tool_choice():
    """
    Test that when request has tool_choice, it overrides deployment's,
    while the deployment's tools are still applied to the request.
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "o3-deep-research",
                "litellm_params": {
                    "model": "openai/o3-deep-research",
                    "api_key": "fake-key",
                    "tools": [{"type": "web_search"}],
                    "tool_choice": "auto",
                },
            },
        ],
    )
    kwargs: dict = {
        "metadata": {},
        "tool_choice": "none",
    }
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="o3-deep-research"
    )

    router._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)

    # Request tool_choice should be preserved...
    assert kwargs["tool_choice"] == "none"
    # ...and the deployment's tools are still applied. The request sent no
    # tools, so the result should match the deployment-only case.
    assert kwargs["tools"] == [{"type": "web_search"}]
def test_credential_name_injected_as_tag():
    """
    A deployment's litellm_credential_name is added to the request metadata
    as a "Credential: <name>" tag, alongside the tags the request already had.
    """
    xai_deployment_config = {
        "model_name": "xai-model",
        "litellm_params": {
            "model": "xai/grok-4-1-fast",
            "litellm_credential_name": "xAI",
        },
    }
    router = litellm.Router(model_list=[xai_deployment_config])
    deployment = router.get_deployment_by_model_group_name(model_group_name="xai-model")

    request_kwargs: dict = {"metadata": {"tags": ["A.101"]}}
    router._update_kwargs_with_deployment(deployment=deployment, kwargs=request_kwargs)

    resulting_tags = request_kwargs["metadata"]["tags"]
    assert "Credential: xAI" in resulting_tags
    assert "A.101" in resulting_tags
def test_credential_name_not_duplicated_in_tags():
    """
    If the request's tags already contain the credential tag, injecting the
    deployment's credential name must not add a second copy.
    """
    xai_deployment_config = {
        "model_name": "xai-model",
        "litellm_params": {
            "model": "xai/grok-4-1-fast",
            "litellm_credential_name": "xAI",
        },
    }
    router = litellm.Router(model_list=[xai_deployment_config])
    deployment = router.get_deployment_by_model_group_name(model_group_name="xai-model")

    request_kwargs: dict = {"metadata": {"tags": ["Credential: xAI", "A.101"]}}
    router._update_kwargs_with_deployment(deployment=deployment, kwargs=request_kwargs)

    credential_tag_count = request_kwargs["metadata"]["tags"].count("Credential: xAI")
    assert credential_tag_count == 1
def test_credential_name_not_injected_when_absent():
    """
    Without a litellm_credential_name on the deployment, the request's tags
    come through exactly as they were.
    """
    plain_deployment_config = {
        "model_name": "gpt-model",
        "litellm_params": {
            "model": "gpt-4o",
        },
    }
    router = litellm.Router(model_list=[plain_deployment_config])
    deployment = router.get_deployment_by_model_group_name(model_group_name="gpt-model")

    request_kwargs: dict = {"metadata": {"tags": ["A.101"]}}
    router._update_kwargs_with_deployment(deployment=deployment, kwargs=request_kwargs)

    assert request_kwargs["metadata"]["tags"] == ["A.101"]
def test_update_kwargs_with_deployment_model_info_in_litellm_metadata():
    """For generic_api_call, model_info with pricing must go to litellm_metadata.

    Routes like /messages and /responses go through generic_api_call, which
    stores model_info under litellm_metadata. Regression test for #23185.
    """
    expected_model_info = {
        "id": "custom-pricing-id",
        "input_cost_per_token": 0.0003,
        "output_cost_per_token": 0.0015,
    }
    router = litellm.Router(
        model_list=[
            {
                "model_name": "claude-sonnet-4",
                "litellm_params": {
                    "model": "anthropic/claude-sonnet-4-20250514",
                    "api_key": "fake-key",
                },
                # Pass a copy so the expected values stay untouched by the Router.
                "model_info": dict(expected_model_info),
            },
        ],
    )
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="claude-sonnet-4"
    )

    request_kwargs: dict = {}
    router._update_kwargs_with_deployment(
        deployment=deployment, kwargs=request_kwargs, function_name="generic_api_call"
    )

    assert "litellm_metadata" in request_kwargs
    model_info = request_kwargs["litellm_metadata"]["model_info"]
    for field_name, expected_value in expected_model_info.items():
        assert model_info[field_name] == expected_value
def test_update_kwargs_with_deployment_model_info_in_metadata():
    """For acompletion (function_name=None), model_info goes to metadata.

    /chat/completions goes through acompletion, which stores model_info
    under metadata.
    """
    expected_model_info = {
        "id": "custom-pricing-id",
        "input_cost_per_token": 0.0003,
        "output_cost_per_token": 0.0015,
    }
    router = litellm.Router(
        model_list=[
            {
                "model_name": "claude-sonnet-4",
                "litellm_params": {
                    "model": "anthropic/claude-sonnet-4-20250514",
                    "api_key": "fake-key",
                },
                # Pass a copy so the expected values stay untouched by the Router.
                "model_info": dict(expected_model_info),
            },
        ],
    )
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="claude-sonnet-4"
    )

    request_kwargs: dict = {}
    router._update_kwargs_with_deployment(
        deployment=deployment, kwargs=request_kwargs, function_name=None
    )

    assert "metadata" in request_kwargs
    model_info = request_kwargs["metadata"]["model_info"]
    for field_name, expected_value in expected_model_info.items():
        assert model_info[field_name] == expected_value
def test_combine_fallback_usage():
"""Test that _combine_fallback_usage merges partial and fallback usage."""
from litellm.router import Router
from litellm.types.utils import Usage
# Create a stream chunk with usage
chunk = litellm.ModelResponseStream(
id="test",
model="gpt-4o",
choices=[],
usage=Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
)
# Call _combine_fallback_usage with no extra usage
Router._combine_fallback_usage(chunk, None)
assert chunk.usage is not None
assert chunk.usage.prompt_tokens == 10
assert chunk.usage.completion_tokens == 5
assert chunk.usage.total_tokens == 15