Files
litellm/tests/test_litellm/proxy/test_proxy_server.py
T
Sameer Kankute c908505e6a fix(proxy): omit OpenAI [DONE] on google-genai streamGenerateContent (#29426)
* fix(proxy): omit OpenAI [DONE] on google-genai streamGenerateContent

google-genai SDK uses ?alt=sse and cannot parse the proxy's trailing
data: [DONE] chunk. Skip that terminator for agenerate_content_stream.

Co-authored-by: Cursor <cursoragent@cursor.com>

* fix(proxy): address Greptile review on google-genai stream fix

Always yield stream error_message; only gate data: [DONE] on the skip flag.
Set _litellm_skip_openai_stream_done in google_endpoints instead of common_request_processing.

Co-authored-by: Cursor <cursoragent@cursor.com>

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: mateo-berri <277851410+mateo-berri@users.noreply.github.com>
2026-06-01 14:38:19 -07:00

7799 lines
283 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import importlib
import json
import os
import socket
import subprocess
import sys
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest import mock
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
import click
import httpx
import pytest
import yaml
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system-path
import litellm
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.proxy_server import app, initialize
from litellm.utils import _invalidate_model_cost_lowercase_map
# Canned OpenAI-style embeddings response returned by mocked router calls.
# The vector is a fixed 4-value pattern repeated three times (12 floats in total).
example_embedding_result = {
    "object": "list",
    "data": [
        {
            "object": "embedding",
            "index": 0,
            "embedding": [
                -0.006929283495992422,
                -0.005336422007530928,
                -4.547132266452536e-05,
                -0.024047505110502243,
            ]
            * 3,
        }
    ],
    "model": "text-embedding-3-small",
    "usage": {"prompt_tokens": 5, "total_tokens": 5},
}
def mock_patch_aembedding():
    """Build (but do not start) a patcher for ``llm_router.aembedding``.

    While active, the patched attribute returns ``example_embedding_result``.
    """
    patch_target = "litellm.proxy.proxy_server.llm_router.aembedding"
    return mock.patch(patch_target, return_value=example_embedding_result)
@pytest.fixture(scope="function")
def client_no_auth():
    """Return a TestClient for the proxy app configured from test_config_no_auth.yaml.

    ``cleanup_router_config_variables`` is called first (presumably to reset
    router/config globals left by earlier tests) before re-initializing.
    """
    from litellm.proxy.proxy_server import cleanup_router_config_variables

    cleanup_router_config_variables()

    this_dir = os.path.dirname(os.path.abspath(__file__))
    config_file = f"{this_dir}/test_configs/test_config_no_auth.yaml"

    # initialize() sets module-level variables used by the FastAPI app; when tests
    # run in parallel, one test can end up seeing another test's settings.
    asyncio.run(initialize(config=config_file, debug=True))
    return TestClient(app)
def test_login_v2_returns_redirect_url_and_sets_cookie(monkeypatch):
    """/v2/login signs a JWT and returns it with a redirect URL.

    The token is also set as the ``token`` cookie.
    """
    login_result = {"user_id": "test-user"}
    prisma = MagicMock()
    authenticate = AsyncMock(return_value=login_result)
    build_token = MagicMock(return_value={"user_id": "test-user"})
    encode = MagicMock(return_value="signed-token")

    # Dotted target -> replacement; applied in insertion order.
    replacements = {
        "litellm.proxy.auth.login_utils.authenticate_user": authenticate,
        "litellm.proxy.auth.login_utils.create_ui_token_object": build_token,
        "jwt.encode": encode,
        "litellm.proxy.proxy_server.master_key": "test-master-key",
        "litellm.proxy.proxy_server.general_settings": {},
        "litellm.proxy.proxy_server.premium_user": False,
        "litellm.proxy.proxy_server.prisma_client": prisma,
        "litellm.proxy.utils.get_server_root_path": lambda: "",
        "litellm.proxy.utils.get_proxy_base_url": lambda: None,
    }
    for dotted_path, replacement in replacements.items():
        monkeypatch.setattr(dotted_path, replacement)

    response = TestClient(app).post(
        "/v2/login", json={"username": "alice", "password": "secret"}
    )

    assert response.status_code == 200
    assert response.json() == {
        "redirect_url": "http://testserver/ui/?login=success",
        "token": "signed-token",
    }
    assert response.cookies.get("token") == "signed-token"

    # Each collaborator receives exactly the values configured above.
    authenticate.assert_awaited_once_with(
        username="alice",
        password="secret",
        master_key="test-master-key",
        prisma_client=prisma,
    )
    build_token.assert_called_once_with(
        login_result=login_result, general_settings={}, premium_user=False
    )
    encode.assert_called_once_with(
        {"user_id": "test-user"}, "test-master-key", algorithm="HS256"
    )
def test_login_v2_returns_json_on_proxy_exception(monkeypatch):
    """A ProxyException raised during /v2/login authentication becomes a JSON error.

    The response carries the exception's code, message and type.
    """
    from litellm.proxy._types import ProxyErrorTypes, ProxyException

    bad_credentials = ProxyException(
        message="Invalid credentials",
        type=ProxyErrorTypes.auth_error,
        param="password",
        code=401,
    )
    monkeypatch.setattr(
        "litellm.proxy.auth.login_utils.authenticate_user",
        AsyncMock(side_effect=bad_credentials),
    )
    monkeypatch.setattr("litellm.proxy.proxy_server.master_key", "test-master-key")
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", MagicMock())

    response = TestClient(app).post(
        "/v2/login", json={"username": "alice", "password": "wrong"}
    )

    assert response.status_code == 401
    assert response.headers["content-type"] == "application/json"
    body = response.json()
    assert "error" in body
    assert body["error"]["message"] == "Invalid credentials"
    assert body["error"]["type"] == "auth_error"
def test_login_v2_returns_json_on_http_exception(monkeypatch):
    """An HTTPException raised during /v2/login authentication is rendered as a JSON error object."""
    from fastapi import HTTPException

    unauthorized = HTTPException(status_code=401, detail="Unauthorized")
    monkeypatch.setattr(
        "litellm.proxy.auth.login_utils.authenticate_user",
        AsyncMock(side_effect=unauthorized),
    )
    monkeypatch.setattr("litellm.proxy.proxy_server.master_key", "test-master-key")
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", MagicMock())

    response = TestClient(app).post(
        "/v2/login", json={"username": "alice", "password": "secret"}
    )

    assert response.status_code == 401
    assert response.headers["content-type"] == "application/json"
    body = response.json()
    assert "error" in body
    assert isinstance(body["error"], dict)
def test_login_v2_returns_json_on_unexpected_exception(monkeypatch):
    """A non-HTTP exception in /v2/login yields a 500 JSON error that includes the original message."""
    monkeypatch.setattr(
        "litellm.proxy.auth.login_utils.authenticate_user",
        AsyncMock(side_effect=ValueError("Unexpected error")),
    )
    monkeypatch.setattr("litellm.proxy.proxy_server.master_key", "test-master-key")
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", MagicMock())

    response = TestClient(app).post(
        "/v2/login", json={"username": "alice", "password": "secret"}
    )

    assert response.status_code == 500
    assert response.headers["content-type"] == "application/json"
    body = response.json()
    assert "error" in body
    error = body["error"]
    assert isinstance(error, dict)
    assert "Unexpected error" in error["message"]
def test_login_v2_returns_json_on_invalid_json_body(monkeypatch):
    """A /v2/login body that is not valid JSON still gets a JSON-formatted error.

    The current status code for this case is 500.
    """
    monkeypatch.setattr("litellm.proxy.proxy_server.master_key", "test-master-key")

    response = TestClient(app).post(
        "/v2/login",
        content="invalid json",
        headers={"Content-Type": "application/json"},
    )

    assert response.status_code == 500
    assert response.headers["content-type"] == "application/json"
    body = response.json()
    assert "error" in body
    assert isinstance(body["error"], dict)
def test_login_v3_rejected_without_control_plane_url(monkeypatch):
    """Without general_settings.control_plane_url, /v3/login answers 404.

    The error message names the missing setting.
    """
    for dotted_path, replacement in (
        ("litellm.proxy.proxy_server.master_key", "test-master-key"),
        ("litellm.proxy.proxy_server.general_settings", {}),
        ("litellm.proxy.proxy_server.prisma_client", MagicMock()),
    ):
        monkeypatch.setattr(dotted_path, replacement)

    response = TestClient(app).post(
        "/v3/login", json={"username": "alice", "password": "secret"}
    )

    assert response.status_code == 404
    assert "control_plane_url" in response.json()["error"]["message"]
def test_login_v3_returns_code(monkeypatch):
    """/v3/login returns a code that expires in 60 seconds instead of the signed JWT itself."""
    worker_config = MagicMock()
    worker_config.worker_registry = []

    replacements = {
        "litellm.proxy.auth.login_utils.authenticate_user": AsyncMock(
            return_value={"user_id": "test-user"}
        ),
        "litellm.proxy.auth.login_utils.create_ui_token_object": MagicMock(
            return_value={"user_id": "test-user"}
        ),
        "jwt.encode": MagicMock(return_value="signed-token"),
        "litellm.proxy.proxy_server.master_key": "test-master-key",
        "litellm.proxy.proxy_server.general_settings": {
            "control_plane_url": "https://cp.example.com"
        },
        "litellm.proxy.proxy_server.premium_user": False,
        "litellm.proxy.proxy_server.prisma_client": MagicMock(),
        "litellm.proxy.proxy_server.proxy_config": worker_config,
        "litellm.proxy.utils.get_server_root_path": lambda: "",
        "litellm.proxy.utils.get_proxy_base_url": lambda: None,
    }
    for dotted_path, replacement in replacements.items():
        monkeypatch.setattr(dotted_path, replacement)

    response = TestClient(app).post(
        "/v3/login", json={"username": "alice", "password": "secret"}
    )

    assert response.status_code == 200
    body = response.json()
    assert "code" in body
    assert body["expires_in"] == 60
    assert "token" not in body
def test_login_v3_exchange_happy_path(monkeypatch):
    """End to end: /v3/login issues a code, and /v3/login/exchange redeems it.

    The exchange returns the JWT in the body and sets it as a cookie.
    """
    worker_config = MagicMock()
    worker_config.worker_registry = []

    replacements = {
        "litellm.proxy.auth.login_utils.authenticate_user": AsyncMock(
            return_value={"user_id": "test-user"}
        ),
        "litellm.proxy.auth.login_utils.create_ui_token_object": MagicMock(
            return_value={"user_id": "test-user"}
        ),
        "jwt.encode": MagicMock(return_value="signed-token"),
        "litellm.proxy.proxy_server.master_key": "test-master-key",
        "litellm.proxy.proxy_server.general_settings": {
            "control_plane_url": "https://cp.example.com"
        },
        "litellm.proxy.proxy_server.premium_user": False,
        "litellm.proxy.proxy_server.prisma_client": MagicMock(),
        "litellm.proxy.proxy_server.proxy_config": worker_config,
        "litellm.proxy.utils.get_server_root_path": lambda: "",
        "litellm.proxy.utils.get_proxy_base_url": lambda: None,
    }
    for dotted_path, replacement in replacements.items():
        monkeypatch.setattr(dotted_path, replacement)

    client = TestClient(app)

    # Step 1: authenticate and receive the code.
    issued = client.post("/v3/login", json={"username": "alice", "password": "secret"})
    assert issued.status_code == 200
    one_time_code = issued.json()["code"]

    # Step 2: redeem the code for the signed token.
    redeemed = client.post("/v3/login/exchange", json={"code": one_time_code})
    assert redeemed.status_code == 200
    payload = redeemed.json()
    assert payload["token"] == "signed-token"
    assert "redirect_url" in payload
    assert redeemed.cookies.get("token") == "signed-token"
def test_login_v3_exchange_single_use(monkeypatch):
    """A /v3/login code can be redeemed exactly once; a second exchange gets 401."""
    worker_config = MagicMock()
    worker_config.worker_registry = []

    replacements = {
        "litellm.proxy.auth.login_utils.authenticate_user": AsyncMock(
            return_value={"user_id": "test-user"}
        ),
        "litellm.proxy.auth.login_utils.create_ui_token_object": MagicMock(
            return_value={"user_id": "test-user"}
        ),
        "jwt.encode": MagicMock(return_value="signed-token"),
        "litellm.proxy.proxy_server.master_key": "test-master-key",
        "litellm.proxy.proxy_server.general_settings": {
            "control_plane_url": "https://cp.example.com"
        },
        "litellm.proxy.proxy_server.premium_user": False,
        "litellm.proxy.proxy_server.prisma_client": MagicMock(),
        "litellm.proxy.proxy_server.proxy_config": worker_config,
        "litellm.proxy.utils.get_server_root_path": lambda: "",
        "litellm.proxy.utils.get_proxy_base_url": lambda: None,
    }
    for dotted_path, replacement in replacements.items():
        monkeypatch.setattr(dotted_path, replacement)

    client = TestClient(app)
    one_time_code = client.post(
        "/v3/login", json={"username": "alice", "password": "secret"}
    ).json()["code"]

    exchange_body = {"code": one_time_code}
    # The first redemption succeeds...
    assert client.post("/v3/login/exchange", json=exchange_body).status_code == 200
    # ...and replaying the same code is rejected.
    assert client.post("/v3/login/exchange", json=exchange_body).status_code == 401
def test_login_v3_exchange_invalid_code(monkeypatch):
    """Exchanging a code that was never issued is rejected with 401."""
    monkeypatch.setattr(
        "litellm.proxy.proxy_server.general_settings",
        {"control_plane_url": "https://cp.example.com"},
    )

    exchange = TestClient(app).post(
        "/v3/login/exchange", json={"code": "nonexistent-code"}
    )

    assert exchange.status_code == 401
def test_login_v3_exchange_rejected_without_control_plane_url(monkeypatch):
    """Without control_plane_url configured, /v3/login/exchange answers 404.

    The error message names the missing setting.
    """
    monkeypatch.setattr("litellm.proxy.proxy_server.general_settings", {})

    exchange = TestClient(app).post("/v3/login/exchange", json={"code": "some-code"})

    assert exchange.status_code == 404
    error_message = exchange.json()["error"]["message"]
    assert "control_plane_url" in error_message
def test_login_v3_returns_json_on_proxy_exception(monkeypatch):
    """A ProxyException raised during /v3/login authentication becomes a JSON error.

    The response carries the exception's code, message and type.
    """
    from litellm.proxy._types import ProxyErrorTypes, ProxyException

    bad_credentials = ProxyException(
        message="Invalid credentials",
        type=ProxyErrorTypes.auth_error,
        param="password",
        code=401,
    )
    for dotted_path, replacement in (
        (
            "litellm.proxy.auth.login_utils.authenticate_user",
            AsyncMock(side_effect=bad_credentials),
        ),
        ("litellm.proxy.proxy_server.master_key", "test-master-key"),
        (
            "litellm.proxy.proxy_server.general_settings",
            {"control_plane_url": "https://cp.example.com"},
        ),
        ("litellm.proxy.proxy_server.prisma_client", MagicMock()),
    ):
        monkeypatch.setattr(dotted_path, replacement)

    response = TestClient(app).post(
        "/v3/login", json={"username": "alice", "password": "wrong"}
    )

    assert response.status_code == 401
    assert response.headers["content-type"] == "application/json"
    body = response.json()
    assert "error" in body
    assert body["error"]["message"] == "Invalid credentials"
    assert body["error"]["type"] == "auth_error"
def test_fallback_login_has_no_deprecation_banner(client_no_auth):
    """/fallback/login renders a login form and shows no deprecation banner."""
    page = client_no_auth.get("/fallback/login")
    assert page.status_code == 200

    markup = page.text
    assert '<div class="deprecation-banner">' not in markup
    assert "Deprecated:" not in markup
    assert "<form" in markup
@pytest.mark.parametrize(
    "ui_logo_path",
    [
        "/etc/litellm/secret-config.json",
        "/var/secrets/admin.key",
        "/proc/self/environ",
        "relative/path/logo.png",
    ],
)
def test_get_logo_url_does_not_disclose_local_paths(
    client_no_auth, monkeypatch, ui_logo_path
):
    """/get_logo_url must never echo a filesystem path from UI_LOGO_PATH.

    The endpoint is unauthenticated, so echoing a local path would disclose
    admin configuration to any caller. Only browser-loadable http(s) URLs are
    returned; for local paths the endpoint reports an empty logo_url and the
    dashboard falls back to /get_image.
    """
    monkeypatch.setenv("UI_LOGO_PATH", ui_logo_path)

    reply = client_no_auth.get("/get_logo_url")

    assert reply.status_code == 200
    assert reply.json() == {"logo_url": ""}
def test_get_logo_url_returns_https_url(client_no_auth, monkeypatch):
    """An https UI_LOGO_PATH is returned unchanged."""
    logo = "https://cdn.public.example/logo.png"
    monkeypatch.setenv("UI_LOGO_PATH", logo)

    reply = client_no_auth.get("/get_logo_url")

    assert reply.status_code == 200
    assert reply.json() == {"logo_url": logo}
def test_get_logo_url_returns_http_url(client_no_auth, monkeypatch):
    """A plain-http UI_LOGO_PATH is also returned unchanged.

    Such URLs, typically on an internal CDN, are meant to be loaded directly
    by the browser.
    """
    logo = "http://internal-cdn.corp:8080/logo.png"
    monkeypatch.setenv("UI_LOGO_PATH", logo)

    reply = client_no_auth.get("/get_logo_url")

    assert reply.status_code == 200
    assert reply.json() == {"logo_url": logo}
def test_get_logo_url_returns_empty_when_unset(client_no_auth, monkeypatch):
    """With UI_LOGO_PATH absent from the environment, logo_url is the empty string."""
    monkeypatch.delenv("UI_LOGO_PATH", raising=False)

    reply = client_no_auth.get("/get_logo_url")

    assert reply.status_code == 200
    assert reply.json() == {"logo_url": ""}
def test_sso_key_generate_shows_deprecation_banner(client_no_auth, monkeypatch):
    """With the SSO handler disabled, /sso/key/generate renders the HTML login form.

    That form carries the deprecation banner.
    """
    stubs = [
        # Make the route return the HTML form instead of redirecting.
        (
            "litellm.proxy.management_endpoints.ui_sso.show_missing_vars_in_env",
            lambda: None,
        ),
        (
            "litellm.proxy.management_endpoints.ui_sso.SSOAuthenticationHandler.get_redirect_url_for_sso",
            lambda *args, **kwargs: "http://test/redirect",
        ),
        (
            "litellm.proxy.management_endpoints.ui_sso.SSOAuthenticationHandler._get_cli_state",
            lambda *args, **kwargs: None,
        ),
        (
            "litellm.proxy.management_endpoints.ui_sso.SSOAuthenticationHandler.should_use_sso_handler",
            lambda *args, **kwargs: False,
        ),
        # premium_user=True gets past the enterprise gate (avoids 403 Forbidden).
        ("litellm.proxy.proxy_server.premium_user", True),
    ]
    for dotted_path, replacement in stubs:
        monkeypatch.setattr(dotted_path, replacement)
    monkeypatch.setenv("UI_USERNAME", "admin")

    page = client_no_auth.get("/sso/key/generate")

    assert page.status_code == 200
    markup = page.text
    assert '<div class="deprecation-banner">' in markup
    assert "Deprecated:" in markup
def test_restructure_ui_html_files_handles_nested_routes(tmp_path):
    """_restructure_ui_html_files moves each ``<route>.html`` to ``<route>/index.html``.

    This applies at any depth. Existing index.html files and anything under
    ``_next`` / ``litellm-asset-prefix`` are left alone. The function runs in
    both development and non-root Docker setups.
    """
    from litellm.proxy import proxy_server

    ui_root = tmp_path / "ui"
    ui_root.mkdir()

    seeded_files = {
        ui_root / "home.html": "home",
        ui_root / "mcp" / "oauth" / "callback.html": "callback",
        ui_root / "existing" / "index.html": "keep",
        ui_root / "_next" / "ignore.html": "asset",
        ui_root / "litellm-asset-prefix" / "ignore.html": "asset",
    }
    for html_path, content in seeded_files.items():
        html_path.parent.mkdir(parents=True, exist_ok=True)
        html_path.write_text(content)

    proxy_server._restructure_ui_html_files(str(ui_root))

    # Top-level and nested route pages are moved into directory/index.html form.
    assert not (ui_root / "home.html").exists()
    assert (ui_root / "home" / "index.html").read_text() == "home"
    assert not (ui_root / "mcp" / "oauth" / "callback.html").exists()
    callback_page = ui_root / "mcp" / "oauth" / "callback" / "index.html"
    assert callback_page.read_text() == "callback"

    # Pre-existing index pages and asset directories are untouched.
    assert (ui_root / "existing" / "index.html").read_text() == "keep"
    assert (ui_root / "_next" / "ignore.html").read_text() == "asset"
    assert (ui_root / "litellm-asset-prefix" / "ignore.html").read_text() == "asset"
def test_ui_extensionless_route_requires_restructure(tmp_path):
    """Regression for the non-root fallback: /ui/login only resolves once login.html
    has been restructured into login/index.html.

    Restructuring now always happens, in development and in non-root Docker.
    """
    from litellm.proxy import proxy_server

    ui_root = tmp_path / "ui"
    ui_root.mkdir()
    for file_name, text in (("index.html", "index"), ("login.html", "login")):
        (ui_root / file_name).write_text(text)

    static_app = FastAPI()
    static_app.mount("/ui", StaticFiles(directory=str(ui_root), html=True), name="ui")
    client = TestClient(static_app)

    # Before restructuring, only the literal file name is served.
    assert client.get("/ui/login.html").status_code == 200
    assert client.get("/ui/login").status_code == 404

    proxy_server._restructure_ui_html_files(str(ui_root))

    after = client.get("/ui/login")
    assert after.status_code == 200
    assert "login" in after.text
def test_admin_ui_export_serves_nested_extensionless_routes():
    """The packaged admin UI export must lay out nested routes as ``<route>/index.html``.

    This layout lets StaticFiles(html=True) serve them, notably the MCP OAuth
    callback page.
    """
    out_dir = Path(litellm.__file__).parent / "proxy" / "_experimental" / "out"
    assert out_dir.is_dir(), f"missing UI export at {out_dir}"

    asset_dirs = {"_next", "litellm-asset-prefix"}
    nested_html_offenders = []
    for html_file in out_dir.rglob("*.html"):
        # Top-level pages, index pages and asset bundles are allowed as-is.
        if html_file.parent == out_dir or html_file.name == "index.html":
            continue
        if asset_dirs.intersection(html_file.parts):
            continue
        nested_html_offenders.append(html_file.relative_to(out_dir).as_posix())

    assert not nested_html_offenders, (
        f"Nested routes must be named index.html. Offenders: {nested_html_offenders}"
    )

    callback_index = out_dir / "mcp" / "oauth" / "callback" / "index.html"
    assert callback_index.is_file(), (
        f"MCP OAuth callback page must exist at {callback_index}; "
        "without it /ui/mcp/oauth/callback 404s after Linear redirects back."
    )

    export_app = FastAPI()
    export_app.mount("/ui", StaticFiles(directory=str(out_dir), html=True), name="ui")
    client = TestClient(export_app)
    callback_url = "/ui/mcp/oauth/callback?code=abc&state=xyz"

    # The extensionless route is redirected to its trailing-slash form, query preserved.
    redirect = client.get(callback_url, follow_redirects=False)
    assert redirect.status_code == 307
    assert redirect.headers["location"].endswith(
        "/ui/mcp/oauth/callback/?code=abc&state=xyz"
    )

    # Following the redirect lands on an HTML page.
    landed = client.get(callback_url)
    assert landed.status_code == 200
    assert "<html" in landed.text.lower()
def test_restructure_always_happens(monkeypatch):
    """UI restructuring is expected regardless of LITELLM_NON_ROOT.

    - non-root (LITELLM_NON_ROOT=true): the runtime copy in /var/lib/litellm/ui
      is used.
    - development (unset): the packaged _experimental/out directory is used.

    NOTE(review): this test re-implements the path selection locally rather than
    calling proxy_server, and ``should_restructure`` is a hard-coded constant, so
    those asserts are tautological.
    """
    runtime_ui_path = "/var/lib/litellm/ui"
    packaged_ui_path = "/some/packaged/ui/path"

    def simulate_ui_path_selection():
        # Hand copy of the LITELLM_NON_ROOT branch in proxy_server.py.
        non_root = os.getenv("LITELLM_NON_ROOT", "").lower() == "true"
        chosen = runtime_ui_path if non_root else packaged_ui_path
        return non_root, chosen

    # Case 1: non-root Docker -> runtime UI directory.
    monkeypatch.setenv("LITELLM_NON_ROOT", "true")
    is_non_root, ui_path = simulate_ui_path_selection()
    should_restructure = True
    assert is_non_root is True
    assert should_restructure is True
    assert ui_path == runtime_ui_path

    # Case 2: development -> packaged UI directory, still restructured in place.
    monkeypatch.delenv("LITELLM_NON_ROOT", raising=False)
    is_non_root, ui_path = simulate_ui_path_selection()
    should_restructure = True
    assert is_non_root is False
    assert should_restructure is True
    assert ui_path == packaged_ui_path
@pytest.mark.asyncio
async def test_initialize_scheduled_jobs_credentials(monkeypatch):
    """initialize_scheduled_background_jobs calls proxy_config.get_credentials
    only when store_model_in_db is enabled.
    """
    monkeypatch.delenv("DISABLE_PRISMA_SCHEMA_UPDATE", raising=False)
    monkeypatch.delenv("STORE_MODEL_IN_DB", raising=False)
    from litellm.proxy.proxy_server import ProxyStartupEvent
    from litellm.proxy.utils import ProxyLogging

    prisma = MagicMock()
    proxy_logging = MagicMock(spec=ProxyLogging)
    proxy_logging.slack_alerting_instance = MagicMock()
    proxy_config_mock = AsyncMock()

    def job_kwargs():
        # A new general_settings dict on every call, as each original call site had.
        return {
            "general_settings": {},
            "prisma_client": prisma,
            "proxy_budget_rescheduler_min_time": 1,
            "proxy_budget_rescheduler_max_time": 2,
            "proxy_batch_write_at": 5,
            "proxy_logging_obj": proxy_logging,
        }

    # store_model_in_db disabled: credentials must not be loaded.
    with (
        patch("litellm.proxy.proxy_server.proxy_config", proxy_config_mock),
        patch("litellm.proxy.proxy_server.store_model_in_db", False),
    ):
        await ProxyStartupEvent.initialize_scheduled_background_jobs(**job_kwargs())
    proxy_config_mock.get_credentials.assert_not_called()

    # store_model_in_db enabled (get_secret_bool forced True): loaded exactly once.
    with (
        patch("litellm.proxy.proxy_server.proxy_config", proxy_config_mock),
        patch("litellm.proxy.proxy_server.store_model_in_db", True),
        patch("litellm.proxy.proxy_server.get_secret_bool", return_value=True),
    ):
        await ProxyStartupEvent.initialize_scheduled_background_jobs(**job_kwargs())
    assert proxy_config_mock.get_credentials.call_count == 1
    # NOTE(review): this only re-checks that get_credentials recorded a call; it does
    # not inspect whether a scheduler job was registered.
    assert len(proxy_config_mock.get_credentials.mock_calls) > 0
def test_update_config_fields_deep_merge_db_wins():
    """_update_config_fields deep-merges a DB router_settings value into the file config.

    DB values win on conflicting leaves, aliases that exist only in the DB are
    added, None values from the DB do not overwrite existing ones, and unrelated
    router_settings keys survive.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    file_config = {
        "router_settings": {
            "routing_mode": "cost_optimized",
            "model_group_alias": {
                # Older model and different hidden flag; the DB should override both.
                "claude-sonnet-4": {
                    "model": "claude-sonnet-4-20240219",
                    "hidden": True,
                },
                # Only touched by a None value in the DB, so it must stay as-is.
                "legacy-sonnet": {
                    "model": "claude-2.1",
                    "hidden": True,
                },
            },
        }
    }
    db_router_settings = {
        "model_group_alias": {
            "claude-sonnet-4": {
                "model": "claude-sonnet-4-20250514",
                "hidden": False,
            },
            # New alias introduced by the DB.
            "claude-sonnet-latest": {
                "model": "claude-sonnet-4-20250514",
                "hidden": True,
            },
            # None must not clobber the existing True.
            "legacy-sonnet": {"hidden": None},
        }
    }

    merged = proxy_config._update_config_fields(
        current_config=file_config,
        param_name="router_settings",
        db_param_value=db_router_settings,
    )

    router_settings = merged["router_settings"]
    aliases = router_settings["model_group_alias"]

    overridden = aliases["claude-sonnet-4"]
    assert overridden["model"] == "claude-sonnet-4-20250514"
    assert overridden["hidden"] is False

    assert "claude-sonnet-latest" in aliases
    added = aliases["claude-sonnet-latest"]
    assert added["model"] == "claude-sonnet-4-20250514"
    assert added["hidden"] is True

    preserved = aliases["legacy-sonnet"]
    assert preserved["model"] == "claude-2.1"
    assert preserved["hidden"] is True

    assert router_settings["routing_mode"] == "cost_optimized"
def test_get_config_custom_callback_api_env_vars(monkeypatch):
    """/get/config/callbacks exposes GENERIC_LOGGER_ENDPOINT and GENERIC_LOGGER_HEADERS.

    They appear as the variables of the custom_callback_api callback when the
    config sets both.
    """
    from litellm.proxy.proxy_server import app, proxy_config, user_api_key_auth

    config_data = {
        "litellm_settings": {"success_callback": ["custom_callback_api"]},
        "general_settings": {},
        "environment_variables": {
            "GENERIC_LOGGER_ENDPOINT": "https://callback.example.com",
            "GENERIC_LOGGER_HEADERS": "Auth: token",
        },
    }

    router_stub = MagicMock()
    router_stub.get_settings.return_value = {}
    monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", router_stub)
    monkeypatch.setattr(proxy_config, "get_config", AsyncMock(return_value=config_data))

    # Bypass auth via a dependency override; the previous overrides are restored afterwards.
    saved_overrides = app.dependency_overrides.copy()
    app.dependency_overrides[user_api_key_auth] = lambda: MagicMock()
    client = TestClient(app)
    try:
        response = client.get("/get/config/callbacks")
    finally:
        app.dependency_overrides = saved_overrides

    assert response.status_code == 200
    listed = response.json()["callbacks"]
    custom_cb = next(
        (entry for entry in listed if entry["name"] == "custom_callback_api"), None
    )
    assert custom_cb is not None
    assert custom_cb["variables"] == {
        "GENERIC_LOGGER_ENDPOINT": "https://callback.example.com",
        "GENERIC_LOGGER_HEADERS": "Auth: token",
    }
class MockPrisma:
    """No-op stand-in patched over ``PrismaClient`` in the startup tests.

    The constructor only records its arguments; ``connect`` and ``disconnect``
    do nothing.
    """

    def __init__(self, database_url=None, proxy_logging_obj=None, http_client=None):
        # Kept as attributes so a test can inspect what the proxy passed in.
        self.http_client = http_client
        self.proxy_logging_obj = proxy_logging_obj
        self.database_url = database_url

    async def connect(self):
        """Pretend to connect; there is no database."""
        return None

    async def disconnect(self):
        """Pretend to disconnect; there is no database."""
        return None


# Shared instance handed out by the patched ProxyStartupEvent._setup_prisma_client.
mock_prisma = MockPrisma()
@patch(
    "litellm.proxy.proxy_server.ProxyStartupEvent._setup_prisma_client",
    return_value=mock_prisma,
)
@pytest.mark.asyncio
async def test_aaaproxy_startup_master_key(mock_prisma, monkeypatch, tmp_path):
    """proxy_startup_event resolves master_key from three sources, one per case.

    1. a literal ``general_settings.master_key`` in config.yaml;
    2. the LITELLM_MASTER_KEY env var when the config sets none;
    3. an ``os.environ/<VAR>`` reference in the config. LITELLM_MASTER_KEY is
       still set from case 2 here, so the config reference takes precedence.
    """
    import yaml
    from fastapi import FastAPI

    # Imported inside the test; the original author suspected this is when the
    # module reads the config path.
    from litellm.proxy.proxy_server import proxy_startup_event

    # Replace the real Prisma client class with the no-op stand-in.
    monkeypatch.setattr("litellm.proxy.proxy_server.PrismaClient", MockPrisma)

    startup_app = FastAPI()
    config_path = tmp_path / "config.yaml"

    def write_config(config):
        with open(config_path, "w") as config_file:
            yaml.dump(config, config_file)

    # Case 1: master key written directly in config.yaml.
    test_master_key = "sk-12345"
    write_config({"general_settings": {"master_key": test_master_key}})
    print(f"SET ENV VARIABLE - CONFIG_FILE_PATH, str(config_path): {str(config_path)}")
    monkeypatch.setenv("CONFIG_FILE_PATH", str(config_path))
    print(f"config_path: {config_path}")
    print(f"os.getenv('CONFIG_FILE_PATH'): {os.getenv('CONFIG_FILE_PATH')}")
    async with proxy_startup_event(startup_app):
        from litellm.proxy.proxy_server import master_key

        assert master_key == test_master_key

    # Case 2: no key in config, so LITELLM_MASTER_KEY is used.
    test_env_master_key = "sk-test-67890"
    write_config({"general_settings": {}})
    monkeypatch.setenv("LITELLM_MASTER_KEY", test_env_master_key)
    print("test_env_master_key: {}".format(test_env_master_key))
    async with proxy_startup_event(startup_app):
        from litellm.proxy.proxy_server import master_key

        assert master_key == test_env_master_key

    # Case 3: config points at another env var via the os.environ/ prefix.
    test_resolved_key = "sk-resolved-key"
    write_config({"general_settings": {"master_key": "os.environ/CUSTOM_MASTER_KEY"}})
    monkeypatch.setenv("CUSTOM_MASTER_KEY", test_resolved_key)
    async with proxy_startup_event(startup_app):
        from litellm.proxy.proxy_server import master_key

        assert master_key == test_resolved_key
def test_team_info_masking():
    """
    When _get_team_config raises for team_id "test_dev" (the only configured team
    has no team_id), the error text must not leak that team's Langfuse keys.
    Ref: https://huntr.com/bounties/661b388a-44d8-4ad5-862b-4dc5b80be30a
    """
    from litellm.proxy.proxy_server import ProxyConfig

    team_with_secrets = {
        "success_callback": "['langfuse', 's3']",
        "langfuse_secret": "secret-test-key",
        "langfuse_public_key": "public-test-key",
    }

    with pytest.raises(Exception) as exc_info:
        ProxyConfig()._get_team_config(
            team_id="test_dev",
            all_teams_config=[team_with_secrets],
        )

    print("Got exception: {}".format(exc_info.value))
    error_text = str(exc_info.value)
    assert "secret-test-key" not in error_text
    assert "public-test-key" not in error_text
def test_embedding_input_array_of_tokens(client_no_auth):
    """
    Token-ID array input to /v1/embeddings must reach the router unchanged (no
    decoding) for providers that accept token arrays.
    Ref: https://github.com/BerriAI/litellm/issues/10113

    Assertions are deliberately not wrapped in try/except: catching Exception and
    calling pytest.fail(str(e)) would replace pytest's assertion introspection and
    traceback with a flattened message.
    """
    from litellm.proxy import proxy_server

    # The client_no_auth fixture should have built a router; catch regressions early.
    assert proxy_server.llm_router is not None, (
        "llm_router is None after client_no_auth fixture initialized. "
        "This indicates a router initialization issue that should be investigated."
    )

    with mock.patch.object(
        proxy_server.llm_router,
        "aembedding",
        return_value=example_embedding_result,
    ) as mock_aembedding:
        response = client_no_auth.post(
            "/v1/embeddings",
            json={"model": "vllm_embed_model", "input": [[2046, 13269, 158208]]},
        )

        # The router was called once, with the token array passed through as-is.
        mock_aembedding.assert_called_once()
        _, call_kwargs = mock_aembedding.call_args
        assert call_kwargs["model"] == "vllm_embed_model"
        assert call_kwargs["input"] == [[2046, 13269, 158208]]

        assert response.status_code == 200
        # The mocked router response carries a 12-float vector (a real
        # text-embedding-3-small vector has 1536), so > 10 confirms it came through.
        assert len(response.json()["data"][0]["embedding"]) > 10
@pytest.mark.asyncio
async def test_get_all_team_models():
    """
    Test get_all_team_models across four scenarios:

    1. user_teams="*"        -> all teams are loaded (find_many without `where`).
    2. user_teams=["team1"]  -> only the listed teams are loaded.
    3. user_teams=[]         -> empty `in` filter and an empty result.
    4. router returns None for a model name -> that model contributes nothing.
    """
    from litellm.proxy.proxy_server import get_all_team_models

    # Mock team rows; their model_dump() output presumably feeds the (patched)
    # LiteLLM_TeamTable constructor below.
    mock_team1 = MagicMock()
    mock_team1.model_dump.return_value = {
        "team_id": "team1",
        "models": ["gpt-4", "gpt-3.5-turbo"],
        "team_alias": "Team 1",
    }
    mock_team2 = MagicMock()
    mock_team2.model_dump.return_value = {
        "team_id": "team2",
        "models": ["claude-3", "gpt-4"],
        "team_alias": "Team 2",
    }
    # Mock model data returned by router
    mock_models_gpt4 = [
        {"model_info": {"id": "gpt-4-model-1"}},
        {"model_info": {"id": "gpt-4-model-2"}},
    ]
    mock_models_gpt35 = [
        {"model_info": {"id": "gpt-3.5-turbo-model-1"}},
    ]
    mock_models_claude = [
        {"model_info": {"id": "claude-3-model-1"}},
    ]
    # Mock prisma client; find_many is awaited, so it must be an AsyncMock.
    mock_prisma_client = MagicMock()
    mock_db = MagicMock()
    mock_litellm_teamtable = MagicMock()
    mock_prisma_client.db = mock_db
    mock_db.litellm_teamtable = mock_litellm_teamtable
    mock_litellm_teamtable.find_many = AsyncMock()

    # Mock router
    mock_router = MagicMock()

    def mock_get_model_list(model_name, team_id=None):
        if model_name == "gpt-4":
            return mock_models_gpt4
        elif model_name == "gpt-3.5-turbo":
            return mock_models_gpt35
        elif model_name == "claude-3":
            return mock_models_claude
        return None

    mock_router.get_model_list.side_effect = mock_get_model_list

    # Stand-in for LiteLLM_TeamTable(**row). Defined once at function scope
    # because every patched scenario below reuses it.
    def mock_team_table_constructor(**kwargs):
        mock_instance = MagicMock()
        mock_instance.team_id = kwargs["team_id"]
        mock_instance.models = kwargs["models"]
        mock_instance.access_group_ids = kwargs.get("access_group_ids")
        return mock_instance

    # Test Case 1: user_teams = "*" (all teams)
    mock_litellm_teamtable.find_many.return_value = [mock_team1, mock_team2]
    with patch("litellm.proxy.proxy_server.LiteLLM_TeamTable") as mock_team_table_class:
        mock_team_table_class.side_effect = mock_team_table_constructor
        result = await get_all_team_models(
            user_teams="*",
            prisma_client=mock_prisma_client,
            llm_router=mock_router,
        )
        # Verify find_many was called without where clause for "*"
        mock_litellm_teamtable.find_many.assert_called_with()
        # Verify router.get_model_list was called for each model
        expected_calls = [
            mock.call(model_name="gpt-4", team_id="team1"),
            mock.call(model_name="gpt-3.5-turbo", team_id="team1"),
            mock.call(model_name="claude-3", team_id="team2"),
            mock.call(model_name="gpt-4", team_id="team2"),
        ]
        mock_router.get_model_list.assert_has_calls(expected_calls, any_order=True)

    # Test Case 2: user_teams = specific list
    mock_litellm_teamtable.reset_mock()
    mock_router.reset_mock()
    mock_router.get_model_list.side_effect = mock_get_model_list
    # Only return team1 for specific team query
    mock_litellm_teamtable.find_many.return_value = [mock_team1]
    with patch("litellm.proxy.proxy_server.LiteLLM_TeamTable") as mock_team_table_class:
        mock_team_table_class.side_effect = mock_team_table_constructor
        result = await get_all_team_models(
            user_teams=["team1"],
            prisma_client=mock_prisma_client,
            llm_router=mock_router,
        )
        # Verify find_many was called with where clause for specific teams
        mock_litellm_teamtable.find_many.assert_called_with(
            where={"team_id": {"in": ["team1"]}}
        )
        # Verify router.get_model_list was called only for team1 models
        expected_calls = [
            mock.call(model_name="gpt-4", team_id="team1"),
            mock.call(model_name="gpt-3.5-turbo", team_id="team1"),
        ]
        mock_router.get_model_list.assert_has_calls(expected_calls, any_order=True)

    # Test Case 3: Empty teams list (no rows, so no team objects are built)
    mock_litellm_teamtable.reset_mock()
    mock_router.reset_mock()
    mock_litellm_teamtable.find_many.return_value = []
    result = await get_all_team_models(
        user_teams=[],
        prisma_client=mock_prisma_client,
        llm_router=mock_router,
    )
    # Verify find_many was called with empty list
    mock_litellm_teamtable.find_many.assert_called_with(where={"team_id": {"in": []}})
    # Should return empty dict when no teams
    assert result == {}

    # Test Case 4: Router returns None for some models
    mock_litellm_teamtable.reset_mock()
    mock_router.reset_mock()
    mock_litellm_teamtable.find_many.return_value = [mock_team1]

    def mock_get_model_list_with_none(model_name, team_id=None):
        if model_name == "gpt-4":
            return mock_models_gpt4
        # Return None for gpt-3.5-turbo to test None handling
        return None

    mock_router.get_model_list.side_effect = mock_get_model_list_with_none
    with patch("litellm.proxy.proxy_server.LiteLLM_TeamTable") as mock_team_table_class:
        mock_team_table_class.side_effect = mock_team_table_constructor
        result = await get_all_team_models(
            user_teams=["team1"],
            prisma_client=mock_prisma_client,
            llm_router=mock_router,
        )
        # Should handle None return gracefully
        assert isinstance(result, dict)
        print("result: ", result)
        assert result == {"gpt-4-model-1": ["team1"], "gpt-4-model-2": ["team1"]}
def test_add_team_models_to_all_models():
    """
    Test _add_team_models_to_all_models for a team with "all-proxy-models".

    The router offers two deployments, one tagged with a different team_id
    ("team2"). The expected mapping contains only the untagged deployment,
    attributed to team1.
    """
    from litellm.proxy._types import LiteLLM_TeamTable
    from litellm.proxy.proxy_server import _add_team_models_to_all_models

    router = MagicMock()
    router.get_model_list.return_value = [
        {"model_info": {"id": "gpt-4-model-1", "team_id": "team2"}},
        {"model_info": {"id": "gpt-4-model-2"}},
    ]

    team = MagicMock(spec=LiteLLM_TeamTable)
    team.team_id = "team1"
    team.models = ["all-proxy-models"]

    mapping = _add_team_models_to_all_models(
        team_db_objects_typed=[team],
        llm_router=router,
    )
    assert mapping == {"gpt-4-model-2": {"team1"}}
@pytest.mark.asyncio
async def test_apply_search_filter_matches_team_public_model_name():
    """
    Regression test for team BYOK models in /v2/model/info search.

    BYOK rows persist an internal model_name (e.g.
    `model_name_{team_id}_{uuid}`) and expose the user-facing name through
    `model_info.team_public_model_name`. The search filter must match on that
    public name so BYOK rows show up, while still matching the internal name.
    """
    from litellm.proxy.proxy_server import _apply_search_filter_to_models

    byok_model = {
        "model_name": "model_name_team-abc-123_4a6b8",
        "litellm_params": {"model": "claude-sonnet-4-5"},
        "model_info": {
            "id": "byok-id-1",
            "team_id": "team-abc-123",
            "team_public_model_name": "team-claude-sonnet",
            "db_model": True,
        },
    }
    unrelated_model = {
        "model_name": "gpt-4",
        "litellm_params": {"model": "gpt-4"},
        "model_info": {"id": "normal-id-1", "db_model": False},
    }

    async def run_search(term):
        """Run the search filter over both models (no DB) and return the matches."""
        matches, _ = await _apply_search_filter_to_models(
            all_models=[byok_model, unrelated_model],
            search=term,
            prisma_client=None,
            proxy_config=MagicMock(),
        )
        return matches

    # A term matching only team_public_model_name must still include the BYOK row.
    public_name_hits = {
        entry["model_info"]["id"] for entry in await run_search("claude")
    }
    assert "byok-id-1" in public_name_hits
    assert "normal-id-1" not in public_name_hits

    # The internal model_name keeps matching as before.
    internal_name_hits = await run_search("model_name_team-abc-123")
    assert [entry["model_info"]["id"] for entry in internal_name_hits] == ["byok-id-1"]

    # A term matching nothing yields an empty result.
    assert await run_search("gemini") == []
@pytest.mark.asyncio
async def test_apply_search_filter_scopes_byok_to_caller_teams():
    """
    Regression test: `/v2/model/info?search=...` must not leak BYOK rows
    from teams the caller is not a member of. Even with a bounded
    `model_name`-contains DB query, a non-admin caller could otherwise
    see other teams' BYOK rows that happen to match by internal name.
    The post-fetch team scope drops those.

    Setup: the non-admin caller ("user-mine") belongs only to "team-mine".
    Both the router-side model list and the mocked DB fetch contain one BYOK
    row for "team-mine" and one for "team-other". The non-admin search must
    keep only the "team-mine" rows (plus the non-team public row). A
    PROXY_ADMIN search over the same data must keep the "team-other" rows too.
    """
    from litellm.proxy.proxy_server import _apply_search_filter_to_models
    # In-router BYOK rows: one in the caller's team, one in someone else's.
    caller_team_byok = {
        "model_name": "model_name_team-mine_internal",
        "litellm_params": {"model": "claude-sonnet"},
        "model_info": {
            "id": "byok-mine",
            "team_id": "team-mine",
            "team_public_model_name": "claude-sonnet-prod",
            "db_model": True,
        },
    }
    other_team_byok = {
        "model_name": "model_name_team-other_internal",
        "litellm_params": {"model": "claude-sonnet"},
        "model_info": {
            "id": "byok-other",
            "team_id": "team-other",
            "team_public_model_name": "claude-sonnet-staging",
            "db_model": True,
        },
    }
    # Non-team row stays in the router-side result regardless of teams.
    public_model = {
        "model_name": "claude-public",
        "litellm_params": {"model": "claude-sonnet"},
        "model_info": {"id": "public-id", "db_model": False},
    }
    # DB-only BYOK rows fetched by the over-broad JSON branch.
    db_caller_row = MagicMock()
    db_caller_row.model_id = "byok-db-mine"
    db_caller_row.model_name = "model_name_team-mine_db"
    db_caller_row.model_info = {
        "id": "byok-db-mine",
        "team_id": "team-mine",
        "team_public_model_name": "Claude DB Mine",
        "db_model": True,
    }
    db_other_row = MagicMock()
    db_other_row.model_id = "byok-db-other"
    db_other_row.model_name = "model_name_team-other_db"
    db_other_row.model_info = {
        "id": "byok-db-other",
        "team_id": "team-other",
        "team_public_model_name": "Claude DB Other",
        "db_model": True,
    }
    # Mocked DB: count() reports 2 matches and find_many returns BOTH teams'
    # rows unfiltered, so any team scoping must come from the function under test.
    prisma_client = MagicMock()
    prisma_client.db.litellm_proxymodeltable.count = AsyncMock(return_value=2)
    prisma_client.db.litellm_proxymodeltable.find_many = AsyncMock(
        return_value=[db_caller_row, db_other_row]
    )
    # Caller's team membership; presumably resolved via litellm_usertable.find_unique
    # inside the function under test — NOTE(review): confirm.
    caller_user_row = MagicMock()
    caller_user_row.teams = ["team-mine"]
    prisma_client.db.litellm_usertable.find_unique = AsyncMock(
        return_value=caller_user_row
    )
    # Stand-in for decryption: maps each DB row to a model dict carrying the
    # row's model_name and model_info.
    proxy_config = MagicMock()
    proxy_config.decrypt_model_list_from_db = lambda rows: [
        {
            "model_name": r.model_name,
            "model_info": r.model_info,
            "litellm_params": {"model": "claude-sonnet"},
        }
        for r in rows
    ]
    # Non-admin caller: results must be scoped to "team-mine".
    non_admin = MagicMock(spec=UserAPIKeyAuth)
    non_admin.user_role = LitellmUserRoles.INTERNAL_USER
    non_admin.user_id = "user-mine"
    filtered, total_count = await _apply_search_filter_to_models(
        all_models=[caller_team_byok, other_team_byok, public_model],
        search="claude",
        prisma_client=prisma_client,
        proxy_config=proxy_config,
        user_api_key_dict=non_admin,
    )
    filtered_ids = {m["model_info"]["id"] for m in filtered}
    assert "byok-mine" in filtered_ids
    assert "byok-db-mine" in filtered_ids
    assert "public-id" in filtered_ids
    assert "byok-other" not in filtered_ids, (
        "router-side BYOK from another team must be dropped from search "
        "when caller doesn't belong to that team"
    )
    assert "byok-db-other" not in filtered_ids, (
        "DB-only BYOK from another team must be dropped from search when "
        "caller doesn't belong to that team"
    )
    # total_count is router_models_count (2: caller_team_byok + public_model,
    # other_team_byok dropped router-side) + DB count (2 from the mocked
    # `count()`). The DB count is the *unscoped* match count; non-admin
    # team scoping applies only to the returned page so the count can be
    # over-reported, but it must never under-report (callers can paginate
    # within the bound).
    assert total_count == 4
    # Admins keep the un-scoped view across teams.
    admin = MagicMock(spec=UserAPIKeyAuth)
    admin.user_role = LitellmUserRoles.PROXY_ADMIN
    admin.user_id = "admin-1"
    filtered_admin, _ = await _apply_search_filter_to_models(
        all_models=[caller_team_byok, other_team_byok, public_model],
        search="claude",
        prisma_client=prisma_client,
        proxy_config=proxy_config,
        user_api_key_dict=admin,
    )
    admin_ids = {m["model_info"]["id"] for m in filtered_admin}
    assert "byok-other" in admin_ids
    assert "byok-db-other" in admin_ids
@pytest.mark.asyncio
async def test_apply_search_filter_bounds_db_fetch_by_page_and_cap():
    """
    Regression test: a broad search term must not force a full BYOK-table
    read and decrypt on every request.

    * Unsorted searches: `find_many(take=N)`, where N is just enough to fill
      the current page after the router-side matches are counted.
    * Sorted searches: `find_many(take=cap)`, using
      `_SORTED_SEARCH_DB_FETCH_CAP`, so ordering still works across a large
      match set without scanning the whole table.
    """
    from litellm.proxy.proxy_server import (
        _SORTED_SEARCH_DB_FETCH_CAP,
        _apply_search_filter_to_models,
    )

    model_table = MagicMock()
    model_table.count = AsyncMock(return_value=10_000)
    model_table.find_many = AsyncMock(return_value=[])
    prisma_client = MagicMock()
    prisma_client.db.litellm_proxymodeltable = model_table
    proxy_config = MagicMock()
    proxy_config.decrypt_model_list_from_db = lambda rows: []

    async def observed_take(sort_by):
        """Run page 1 (size 50) of a "model" search; return the `take` passed to find_many."""
        model_table.find_many.reset_mock()
        await _apply_search_filter_to_models(
            all_models=[],
            search="model",
            prisma_client=prisma_client,
            proxy_config=proxy_config,
            page=1,
            size=50,
            sort_by=sort_by,
        )
        return model_table.find_many.call_args.kwargs["take"]

    # Unsorted with no router-side matches: exactly one page of rows.
    assert (
        await observed_take(None) == 50
    ), "unsorted search must take just one page's worth of rows"

    # Sorted: still bounded, but by the hard cap rather than the page size.
    sorted_take = await observed_take("model_name")
    assert sorted_take == _SORTED_SEARCH_DB_FETCH_CAP
    assert sorted_take < 10_000, "sorted search must cap below the full match set"
@pytest.mark.asyncio
async def test_filter_models_by_team_id_excludes_viewer_direct_access():
    """
    Regression test for the UI's "Current Team" selector.

    Picking a team must list only that team's own BYOK rows plus the models
    assigned to the team. The admin viewer's `direct_access` flag (set
    upstream on every non-team model) must NOT widen the team's visible set.
    Otherwise, selecting team-111 would still show every public model.
    """
    from litellm.proxy.proxy_server import _filter_models_by_team_id

    def byok_row(team_id, model_id, internal_name, backing_model, public_name):
        """Build a BYOK model dict owned by (and accessible via) `team_id`."""
        return {
            "model_name": internal_name,
            "litellm_params": {"model": backing_model},
            "model_info": {
                "id": model_id,
                "team_id": team_id,
                "team_public_model_name": public_name,
                "access_via_team_ids": [team_id],
            },
        }

    public_model = {
        "model_name": "gpt-4",
        "litellm_params": {"model": "gpt-4"},
        "model_info": {
            "id": "public-id",
            # The admin viewer has direct access to this public model...
            "direct_access": True,
            # ...but team-111 is not listed here, so team-111 must not see it.
            "access_via_team_ids": ["team-222"],
        },
    }
    team111_byok = byok_row(
        "team-111",
        "byok-team-111",
        "model_name_team-111_uuid",
        "claude-sonnet",
        "team-claude",
    )
    team222_byok = byok_row(
        "team-222",
        "byok-team-222",
        "model_name_team-222_uuid",
        "claude-haiku",
        "team-haiku",
    )

    # The router resolves team-111's only listed model to no deployments.
    router = MagicMock()
    router.get_model_access_groups = MagicMock(return_value={})
    router.get_model_list = MagicMock(return_value=[])

    # team-111's DB row lists a specific model that is NOT the BYOK's internal name.
    team_row = MagicMock()
    team_row.model_dump.return_value = {
        "team_id": "team-111",
        "team_alias": "Team 111",
        "models": ["some-other-model"],
        "access_group_ids": None,
    }
    prisma = MagicMock()
    prisma.db.litellm_teamtable.find_unique = AsyncMock(return_value=team_row)
    prisma.db.litellm_proxymodeltable.find_many = AsyncMock(return_value=[])

    visible = await _filter_models_by_team_id(
        all_models=[public_model, team111_byok, team222_byok],
        team_id="team-111",
        prisma_client=prisma,
        llm_router=router,
    )
    visible_ids = {row["model_info"]["id"] for row in visible}

    assert "byok-team-111" in visible_ids, "team-111's own BYOK must always be visible"
    assert "byok-team-222" not in visible_ids, "must not leak other teams' BYOK"
    assert (
        "public-id" not in visible_ids
    ), "viewer's direct_access must not widen the team's visible set"
@pytest.mark.asyncio
async def test_filter_models_by_team_id_rejects_non_member():
    """
    Regression test: filtering /v2/model/info by `teamId=X` includes BYOK
    rows solely because `model_info.team_id == X`.

    Without an authorization check, any authenticated user could enumerate
    another team's BYOK metadata just by guessing its id. A caller that is
    neither a proxy admin nor a member of `team_id` must get HTTP 403.
    """
    from fastapi import HTTPException

    from litellm.proxy.proxy_server import _filter_models_by_team_id

    # Alice is an internal user whose only team is team-222.
    alice = UserAPIKeyAuth(
        user_id="alice",
        user_role=LitellmUserRoles.INTERNAL_USER,
        api_key="sk-test",
    )
    alice_row = MagicMock()
    alice_row.teams = ["team-222"]
    prisma = MagicMock()
    prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=alice_row)

    # A BYOK row owned by team-111, which Alice does not belong to.
    team111_byok = {
        "model_name": "model_name_team-111_uuid",
        "litellm_params": {"model": "claude"},
        "model_info": {"id": "byok-team-111", "team_id": "team-111"},
    }

    with pytest.raises(HTTPException) as excinfo:
        await _filter_models_by_team_id(
            all_models=[team111_byok],
            team_id="team-111",
            prisma_client=prisma,
            llm_router=MagicMock(),
            user_api_key_dict=alice,
        )
    assert excinfo.value.status_code == 403
@pytest.mark.asyncio
async def test_filter_models_by_team_id_allows_team_member():
    """
    A caller who belongs to `team_id` may filter by it and must get that
    team's BYOK rows back.
    """
    from litellm.proxy.proxy_server import _filter_models_by_team_id

    team111_byok = {
        "model_name": "model_name_team-111_uuid",
        "litellm_params": {"model": "claude"},
        "model_info": {"id": "byok-team-111", "team_id": "team-111"},
    }

    # Bob is an internal user who is a member of team-111 (and team-999).
    bob = UserAPIKeyAuth(
        user_id="bob",
        user_role=LitellmUserRoles.INTERNAL_USER,
        api_key="sk-test",
    )
    bob_row = MagicMock()
    bob_row.teams = ["team-111", "team-999"]

    # team-111's DB row places no model restrictions and no access groups.
    team_row = MagicMock()
    team_row.model_dump.return_value = {
        "team_id": "team-111",
        "team_alias": "Team 111",
        "models": [],
        "access_group_ids": None,
    }

    prisma = MagicMock()
    prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=bob_row)
    prisma.db.litellm_teamtable.find_unique = AsyncMock(return_value=team_row)
    prisma.db.litellm_proxymodeltable.find_many = AsyncMock(return_value=[])

    router = MagicMock()
    router.get_model_access_groups = MagicMock(return_value={})
    router.get_model_list = MagicMock(return_value=[team111_byok])

    visible = await _filter_models_by_team_id(
        all_models=[team111_byok],
        team_id="team-111",
        prisma_client=prisma,
        llm_router=router,
        user_api_key_dict=bob,
    )
    assert [row["model_info"]["id"] for row in visible] == ["byok-team-111"]
@pytest.mark.asyncio
async def test_caller_byok_team_scope_treats_view_only_admin_as_unscoped():
    """
    Regression test: PROXY_ADMIN_VIEW_ONLY is an admin role ("can login,
    view all own keys, view all spend"), so search must show BYOK rows
    across all teams for it.

    Scoping it by the user-id's `teams` field would narrow results to
    whichever teams the admin happens to belong to, regressing pre-PR
    behavior.
    """
    from litellm.proxy.proxy_server import _get_caller_byok_team_scope

    view_only_admin = UserAPIKeyAuth(
        user_id="view-admin",
        user_role=LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
        api_key="sk-test",
    )
    team_scope = await _get_caller_byok_team_scope(
        user_api_key_dict=view_only_admin,
        prisma_client=MagicMock(),
    )
    # None means "no team restriction", the same treatment PROXY_ADMIN gets.
    assert (
        team_scope is None
    ), "PROXY_ADMIN_VIEW_ONLY must be unscoped, like PROXY_ADMIN"
@pytest.mark.asyncio
async def test_add_access_group_models_to_team_models():
    """
    Models reachable through a team's access groups must be merged into
    `team_models`.

    team1 has models=["gpt-4"] and access_group_ids=["premium"], and the
    "premium" group contains ["claude-3", "gemini"]. After resolution team1
    should see gpt-4 (direct) plus claude-3 and gemini (via the group).

    These teams must be left untouched:
    - teams without access groups;
    - teams with an empty access-group list;
    - teams with unrestricted access (empty models, or "all-proxy-models").
    """
    from litellm.proxy._types import LiteLLM_TeamTable
    from litellm.proxy.proxy_server import _add_access_group_models_to_team_models

    def make_team(team_id, models, access_group_ids):
        """Build a LiteLLM_TeamTable-shaped mock with the given fields."""
        team = MagicMock(spec=LiteLLM_TeamTable)
        team.team_id = team_id
        team.models = models
        team.access_group_ids = access_group_ids
        return team

    teams = [
        # Specific models AND an access group -> eligible.
        make_team("team1", ["gpt-4"], ["premium"]),
        # No access groups -> skipped.
        make_team("team2", ["gpt-4"], None),
        # Empty access-group list -> skipped.
        make_team("team2b", ["gpt-4"], []),
        # Empty models means all access -> skipped.
        make_team("team3", [], ["premium"]),
        # all-proxy-models sentinel means all access -> skipped.
        make_team("team4", ["all-proxy-models"], ["premium"]),
    ]

    deployment_ids = {"claude-3": "claude-3-id", "gemini": "gemini-id"}

    def fake_get_model_list(model_name, team_id=None):
        """Return one fresh deployment for known model names, else None."""
        deployment_id = deployment_ids.get(model_name)
        if deployment_id is None:
            return None
        return [{"model_info": {"id": deployment_id}}]

    router = MagicMock()
    router.get_model_list.side_effect = fake_get_model_list

    # The single access-group row returned by the batched find_many lookup.
    premium_row = MagicMock()
    premium_row.access_group_id = "premium"
    premium_row.access_model_names = ["claude-3", "gemini"]
    find_many = AsyncMock(return_value=[premium_row])
    prisma = MagicMock()
    prisma.db.litellm_accessgrouptable.find_many = find_many

    result = await _add_access_group_models_to_team_models(
        team_db_objects_typed=teams,
        llm_router=router,
        prisma_client=prisma,
        # Pre-existing entry, e.g. from _add_team_models_to_all_models.
        team_models={"gpt-4-id": {"team1"}},
    )

    # Exactly one batched query, limited to the eligible team's group ids.
    find_many.assert_called_once()
    assert set(find_many.call_args[1]["where"]["access_group_id"]["in"]) == {
        "premium"
    }

    # The direct entry survives.
    assert "gpt-4-id" in result
    assert "team1" in result["gpt-4-id"]

    # Access-group models are added for team1.
    for deployment_id in ("claude-3-id", "gemini-id"):
        assert deployment_id in result
        assert "team1" in result[deployment_id]

    # Skipped teams gain nothing from the access group.
    for skipped_team in ("team2", "team2b", "team3", "team4"):
        for deployment_id in ("claude-3-id", "gemini-id"):
            assert skipped_team not in result.get(deployment_id, set())
@pytest.mark.asyncio
async def test_add_access_group_models_multiple_teams_shared_group():
    """
    Teams sharing an access group must each receive that group's models,
    and all groups must be resolved with a single batched DB query.
    """
    from litellm.proxy._types import LiteLLM_TeamTable
    from litellm.proxy.proxy_server import _add_access_group_models_to_team_models

    def make_team(team_id, models, access_group_ids):
        """Build a LiteLLM_TeamTable-shaped mock with the given fields."""
        team = MagicMock(spec=LiteLLM_TeamTable)
        team.team_id = team_id
        team.models = models
        team.access_group_ids = access_group_ids
        return team

    def make_group_row(group_id, model_names):
        """Build an access-group table row mock."""
        row = MagicMock()
        row.access_group_id = group_id
        row.access_model_names = model_names
        return row

    team_a = make_team("team-a", ["gpt-4"], ["shared-group"])
    team_b = make_team("team-b", ["gpt-3.5"], ["shared-group", "extra-group"])

    deployment_ids = {"claude-3": "claude-3-id", "gemini": "gemini-id"}

    def fake_get_model_list(model_name, team_id=None):
        """Return one fresh deployment for known model names, else None."""
        deployment_id = deployment_ids.get(model_name)
        if deployment_id is None:
            return None
        return [{"model_info": {"id": deployment_id}}]

    router = MagicMock()
    router.get_model_list.side_effect = fake_get_model_list

    find_many = AsyncMock(
        return_value=[
            make_group_row("shared-group", ["claude-3"]),
            make_group_row("extra-group", ["gemini"]),
        ]
    )
    prisma = MagicMock()
    prisma.db.litellm_accessgrouptable.find_many = find_many

    result = await _add_access_group_models_to_team_models(
        team_db_objects_typed=[team_a, team_b],
        llm_router=router,
        prisma_client=prisma,
        team_models={},
    )

    # One batched query covering both groups.
    find_many.assert_called_once()
    assert set(find_many.call_args[1]["where"]["access_group_id"]["in"]) == {
        "shared-group",
        "extra-group",
    }

    # shared-group -> claude-3 for both teams.
    assert "claude-3-id" in result
    for team_id in ("team-a", "team-b"):
        assert team_id in result["claude-3-id"]

    # extra-group -> gemini for team-b only.
    assert "gemini-id" in result
    assert "team-b" in result["gemini-id"]
    assert "team-a" not in result["gemini-id"]
@pytest.mark.asyncio
async def test_add_access_group_models_no_eligible_teams():
    """
    If no team has access groups, the access-group table must not be
    queried and `team_models` must come back unchanged.
    """
    from litellm.proxy._types import LiteLLM_TeamTable
    from litellm.proxy.proxy_server import _add_access_group_models_to_team_models

    team_without_groups = MagicMock(spec=LiteLLM_TeamTable)
    team_without_groups.team_id = "team1"
    team_without_groups.models = ["gpt-4"]
    team_without_groups.access_group_ids = None

    access_group_lookup = AsyncMock()
    prisma = MagicMock()
    prisma.db.litellm_accessgrouptable.find_many = access_group_lookup

    merged = await _add_access_group_models_to_team_models(
        team_db_objects_typed=[team_without_groups],
        llm_router=MagicMock(),
        prisma_client=prisma,
        team_models={"existing-id": {"team1"}},
    )

    access_group_lookup.assert_not_called()
    assert merged == {"existing-id": {"team1"}}
@pytest.mark.asyncio
async def test_get_all_team_models_with_access_groups():
    """
    End-to-end check that get_all_team_models also returns deployments
    reachable through a team's access groups.

    team1 has models=["gpt-4"] and access_group_ids=["premium"], and
    "premium" grants ["claude-3"]. Both the gpt-4 and the claude-3
    deployments must map to team1, and every mapping value must be a list.
    """
    from litellm.proxy.proxy_server import get_all_team_models

    team_row = MagicMock()
    team_row.model_dump.return_value = {
        "team_id": "team1",
        "models": ["gpt-4"],
        "team_alias": "Team 1",
        "access_group_ids": ["premium"],
    }
    # Access-group row returned by the batched find_many lookup.
    premium_row = MagicMock()
    premium_row.access_group_id = "premium"
    premium_row.access_model_names = ["claude-3"]

    mock_db = MagicMock()
    mock_db.litellm_teamtable = MagicMock()
    mock_db.litellm_teamtable.find_many = AsyncMock(return_value=[team_row])
    mock_db.litellm_accessgrouptable = MagicMock()
    mock_db.litellm_accessgrouptable.find_many = AsyncMock(return_value=[premium_row])
    prisma = MagicMock()
    prisma.db = mock_db

    deployment_ids = {"gpt-4": "gpt-4-deploy-1", "claude-3": "claude-3-deploy-1"}

    def fake_get_model_list(model_name, team_id=None):
        """Return one fresh deployment for known model names, else None."""
        deployment_id = deployment_ids.get(model_name)
        if deployment_id is None:
            return None
        return [{"model_info": {"id": deployment_id}}]

    router = MagicMock()
    router.get_model_list.side_effect = fake_get_model_list

    def build_team(**fields):
        """Stand-in for LiteLLM_TeamTable(**fields)."""
        team = MagicMock()
        team.team_id = fields["team_id"]
        team.models = fields["models"]
        team.access_group_ids = fields.get("access_group_ids")
        return team

    with patch("litellm.proxy.proxy_server.LiteLLM_TeamTable") as team_table_cls:
        team_table_cls.side_effect = build_team
        result = await get_all_team_models(
            user_teams=["team1"],
            prisma_client=prisma,
            llm_router=router,
        )

    # gpt-4 comes from team.models, and claude-3 from the "premium" access group.
    for deployment_id in ("gpt-4-deploy-1", "claude-3-deploy-1"):
        assert deployment_id in result
        assert "team1" in result[deployment_id]
    # Return type is Dict[str, List[str]].
    for deployment_id in ("gpt-4-deploy-1", "claude-3-deploy-1"):
        assert isinstance(result[deployment_id], list)
@pytest.mark.asyncio
async def test_delete_deployment_type_mismatch():
    """
    Test that _delete_deployment does not delete config models whose integer
    ids match the router's string ids.

    Config models 12345678 and 12345679 declare *integer* ids, while the
    router reports every deployment id as a *string*. Those two deployments
    must survive. The two hash-id deployments appear in neither the DB nor
    the config, so they are stale and must be deleted. Checking for those
    deletions proves the id comparison actually ran.

    Previously this test had two problems:
    - It assigned the integer-id config to `pc.get_config` and then
      immediately overwrote it with a mock returning `{}`.
    - It passed `db_models=[]` and asserted 0 deletions, so it could pass
      without exercising the type-mismatch path at all. (An empty `db_models`
      may trigger an early return in _delete_deployment.
      NOTE(review): confirm against the implementation.)
    """
    from litellm.proxy.proxy_server import ProxyConfig

    stale_ids = [
        "a96e12e76b36a57cfae57a41288eb41567629cac89b4828c6f7074afc3534695",
        "a40186dd0fdb9b7282380277d7f57044d29de95bfbfcd7f4322b3493702d5cd3",
    ]

    pc = ProxyConfig()
    # get_config is awaited by _delete_deployment, so use an AsyncMock that
    # returns the config models with integer ids.
    pc.get_config = AsyncMock(
        return_value={
            "model_list": [
                {
                    "model_name": "openai-gpt-4o",
                    "litellm_params": {"model": "gpt-4o"},
                    "model_info": {"id": 12345678},
                },
                {
                    "model_name": "openai-gpt-4o",
                    "litellm_params": {"model": "gpt-4o"},
                    "model_info": {"id": 12345679},
                },
            ]
        }
    )

    # The router reports ids as strings; this is the source of the type mismatch.
    mock_llm_router = MagicMock()
    mock_llm_router.get_model_ids.return_value = stale_ids + ["12345678", "12345679"]

    # Track which deployments were deleted
    deleted_ids = []

    def mock_delete_deployment(id):
        deleted_ids.append(id)
        return True  # Simulate successful deletion

    mock_llm_router.delete_deployment = MagicMock(side_effect=mock_delete_deployment)

    # One DB-backed model unrelated to the router's deployments. The list is
    # non-empty so the id comparison is actually performed.
    db_model = MagicMock()
    db_model.model_id = "db-only-model-id"

    with (
        patch("litellm.proxy.proxy_server.llm_router", mock_llm_router),
        patch("litellm.proxy.proxy_server.user_config_file_path", "test_config.yaml"),
    ):
        deleted_count = await pc._delete_deployment(db_models=[db_model])

    # Config models must NOT be deleted even though the router holds their
    # ids as strings while the config declares them as integers.
    assert (
        "12345678" not in deleted_ids
    ), f"Model 12345678 should NOT be deleted. Deleted IDs: {deleted_ids}"
    assert (
        "12345679" not in deleted_ids
    ), f"Model 12345679 should NOT be deleted. Deleted IDs: {deleted_ids}"
    # The stale deployments must be removed; this proves the comparison ran.
    assert sorted(deleted_ids) == sorted(
        stale_ids
    ), f"Expected only the stale hash ids to be deleted. Deleted IDs: {deleted_ids}"
    assert deleted_count == 2, f"Expected 2 deletions, got {deleted_count}"
@pytest.mark.asyncio
async def test_get_config_from_file(tmp_path, monkeypatch):
    """
    Test the _get_config_from_file method of ProxyConfig class.
    Tests various scenarios: valid file, non-existent file, no file path, None config.

    ``pytest.raises(match=...)`` interprets its argument as a regular expression,
    so any expected message that embeds a filesystem path is passed through
    ``re.escape``. Otherwise backslashes (Windows paths) or regex metacharacters
    in the temporary directory name would make the match fail or misbehave.
    """
    import re
    import yaml
    from litellm.proxy.proxy_server import ProxyConfig
    # Create a ProxyConfig instance
    proxy_config = ProxyConfig()
    # Test Case 1: Valid YAML config file exists
    test_config = {
        "model_list": [{"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}}],
        "general_settings": {"master_key": "sk-test"},
        "router_settings": {"enable_pre_call_checks": True},
        "litellm_settings": {"drop_params": True},
    }
    config_file = tmp_path / "test_config.yaml"
    config_file.write_text(yaml.dump(test_config))
    # Clear global user_config_file_path for this test
    monkeypatch.setattr("litellm.proxy.proxy_server.user_config_file_path", None)
    result = await proxy_config._get_config_from_file(str(config_file))
    assert result == test_config
    # Verify that user_config_file_path was set (re-imported so we read the
    # module attribute as it is *now*, after the call above)
    from litellm.proxy.proxy_server import user_config_file_path
    assert user_config_file_path == str(config_file)
    # Test Case 2: File path provided but file doesn't exist
    non_existent_file = tmp_path / "non_existent.yaml"
    with pytest.raises(
        Exception, match=re.escape(f"Config file not found: {non_existent_file}")
    ):
        await proxy_config._get_config_from_file(str(non_existent_file))
    # Test Case 3: No file path provided (should return default config)
    monkeypatch.setattr("litellm.proxy.proxy_server.user_config_file_path", None)
    expected_default = {
        "model_list": [],
        "general_settings": {},
        "router_settings": {},
        "litellm_settings": {},
    }
    result = await proxy_config._get_config_from_file(None)
    assert result == expected_default
    # Test Case 4: Empty YAML file (yaml loads it as None, which must be rejected)
    empty_file = tmp_path / "empty_config.yaml"
    empty_file.write_text("")
    with pytest.raises(Exception, match="Config cannot be None or Empty."):
        await proxy_config._get_config_from_file(str(empty_file))
    # Test Case 5: Using global user_config_file_path when no config_file_path provided
    monkeypatch.setattr(
        "litellm.proxy.proxy_server.user_config_file_path", str(config_file)
    )
    result = await proxy_config._get_config_from_file(None)
    assert result == test_config
def test_normalize_datetime_for_sorting():
    """
    Exercise _normalize_datetime_for_sorting across every supported input kind.

    Covers None, ISO-8601 strings (``Z`` suffix, naive, explicit offset), naive
    and timezone-aware ``datetime`` objects, and unparseable / unsupported input.
    Every successfully normalized value must be a UTC-aware ``datetime``.
    """
    from litellm.proxy.proxy_server import _normalize_datetime_for_sorting as normalize
    def _expect_utc(value):
        # Shared checks for every input that is expected to normalize successfully.
        assert value is not None
        assert isinstance(value, datetime)
        assert value.tzinfo == timezone.utc
        return value
    # None is passed straight through.
    assert normalize(None) is None
    # ISO string with a trailing 'Z'.
    zulu = _expect_utc(normalize("2024-01-15T10:30:00Z"))
    assert (zulu.year, zulu.month, zulu.day, zulu.hour, zulu.minute) == (
        2024,
        1,
        15,
        10,
        30,
    )
    # Naive ISO string still comes back UTC-aware.
    _expect_utc(normalize("2024-01-15T10:30:00"))
    # ISO string at +05:00 is shifted to UTC: 10:30+05:00 -> 05:30Z.
    assert _expect_utc(normalize("2024-01-15T10:30:00+05:00")).hour == 5
    # Naive datetime object.
    from_naive = _expect_utc(normalize(datetime(2024, 1, 15, 10, 30, 0)))
    assert (from_naive.year, from_naive.month, from_naive.day) == (2024, 1, 15)
    # Aware, non-UTC datetime object is converted to UTC.
    plus_five = timezone(timedelta(hours=5))
    shifted = normalize(datetime(2024, 1, 15, 10, 30, 0, tzinfo=plus_five))
    assert _expect_utc(shifted).hour == 5
    # Already-UTC datetime is returned equal to the input.
    utc_dt = datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc)
    assert _expect_utc(normalize(utc_dt)) == utc_dt
    # Unparseable string and unsupported type both yield None.
    assert normalize("not-a-date") is None
    assert normalize(12345) is None
@pytest.mark.asyncio
async def test_add_proxy_budget_to_db_only_creates_user_no_keys():
    """
    Test that _add_proxy_budget_to_db only creates a user and no keys are added.
    This validates that generate_key_helper_fn is called with table_name="user"
    which should prevent key creation in LiteLLM_VerificationToken table.

    ``litellm.budget_duration`` and ``litellm.max_budget`` are module-level
    globals. They are patched only for the duration of the call, rather than
    assigned outright, so the proxy budget settings do not leak into other tests
    running in the same process.
    """
    from unittest.mock import AsyncMock, patch
    import litellm
    from litellm.proxy.proxy_server import ProxyStartupEvent
    litellm_proxy_budget_name = "litellm-proxy-budget"
    # Mock generate_key_helper_fn to capture its call arguments
    mock_generate_key_helper = AsyncMock(
        return_value={
            "user_id": litellm_proxy_budget_name,
            "max_budget": 100.0,
            "budget_duration": "30d",
            "spend": 0,
            "models": [],
        }
    )
    with (
        patch.object(litellm, "budget_duration", "30d"),
        patch.object(litellm, "max_budget", 100.0),
        # Patch generate_key_helper_fn in proxy_server where it's being called from
        patch(
            "litellm.proxy.proxy_server.generate_key_helper_fn",
            mock_generate_key_helper,
        ),
    ):
        # Call the function under test
        ProxyStartupEvent._add_proxy_budget_to_db(litellm_proxy_budget_name)
        # Yield to the event loop so the scheduled async work can complete while
        # the patches above are still active.
        await asyncio.sleep(0.1)
    # Verify that generate_key_helper_fn was called
    mock_generate_key_helper.assert_called_once()
    call_args = mock_generate_key_helper.call_args
    # Verify critical parameters that prevent key creation
    assert call_args.kwargs["request_type"] == "user"
    assert call_args.kwargs["table_name"] == "user"
    assert call_args.kwargs["user_id"] == litellm_proxy_budget_name
    assert call_args.kwargs["max_budget"] == 100.0
    assert call_args.kwargs["budget_duration"] == "30d"
    assert call_args.kwargs["query_type"] == "update_data"
@pytest.mark.asyncio
async def test_add_proxy_budget_to_db_backfills_budget_reset_at():
    """
    Test that _upsert_proxy_budget_with_reset_at_backfill issues a conditional
    update_many with `WHERE budget_reset_at IS NULL` to backfill the column on
    rows that pre-existed without a reset schedule. Without this, the proxy
    admin row stays at NULL and reset_budget_for_litellm_users never matches
    it (NULL < now() is unknown in SQL), so the global proxy budget never
    resets.

    ``litellm.budget_duration`` / ``litellm.max_budget`` are patched for the
    duration of the call instead of being assigned globally, so they do not
    leak into other tests in the same process.
    """
    from unittest.mock import AsyncMock, MagicMock, patch
    import litellm
    from litellm.proxy.proxy_server import ProxyStartupEvent
    litellm_proxy_budget_name = "litellm-proxy-budget"
    mock_prisma = MagicMock()
    mock_prisma.db.litellm_usertable.update_many = AsyncMock(return_value={"count": 1})
    mock_generate_key_helper = AsyncMock(
        return_value={
            "user_id": litellm_proxy_budget_name,
            "max_budget": 100.0,
            "budget_duration": "30d",
            "spend": 0,
            "models": [],
        }
    )
    with (
        patch.object(litellm, "budget_duration", "30d"),
        patch.object(litellm, "max_budget", 100.0),
        patch(
            "litellm.proxy.proxy_server.generate_key_helper_fn",
            mock_generate_key_helper,
        ),
        patch("litellm.proxy.proxy_server.prisma_client", mock_prisma),
    ):
        await ProxyStartupEvent._upsert_proxy_budget_with_reset_at_backfill(
            litellm_proxy_budget_name
        )
    # Upsert ran with the configured budget
    mock_generate_key_helper.assert_called_once()
    # Backfill update_many ran with the conditional WHERE
    mock_prisma.db.litellm_usertable.update_many.assert_called_once()
    backfill_call = mock_prisma.db.litellm_usertable.update_many.call_args
    assert backfill_call.kwargs["where"]["user_id"] == litellm_proxy_budget_name
    assert backfill_call.kwargs["where"]["budget_reset_at"] is None
    # The backfilled value must be a real future datetime — anything else and
    # reset_budget_for_litellm_users would still skip the row.
    backfilled_reset_at = backfill_call.kwargs["data"]["budget_reset_at"]
    assert isinstance(backfilled_reset_at, datetime)
    assert backfilled_reset_at > datetime.now(timezone.utc)
@pytest.mark.asyncio
async def test_custom_ui_sso_sign_in_handler_config_loading():
    """
    A config whose general_settings names a ``custom_ui_sso_sign_in_handler``
    must be resolved through get_instance_fn, and the resulting object must be
    stored in proxy_server's module-level ``user_custom_ui_sso_sign_in_handler``.
    """
    import tempfile
    from unittest.mock import MagicMock, patch
    import yaml
    from litellm.proxy.proxy_server import ProxyConfig
    handler_path = "custom_hooks.custom_ui_sso_hook.custom_ui_sso_sign_in_handler"
    config_payload = {
        "general_settings": {"custom_ui_sso_sign_in_handler": handler_path},
        "model_list": [],
        "router_settings": {},
        "litellm_settings": {},
    }
    # delete=False so load_config can reopen the file by name; removed in finally.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as tmp:
        yaml.dump(config_payload, tmp)
        config_file_path = tmp.name
    fake_handler = MagicMock()
    try:
        with patch(
            "litellm.proxy.proxy_server.get_instance_fn",
            return_value=fake_handler,
        ) as get_instance_spy:
            # load_config requires a router; a bare MagicMock is sufficient here.
            await ProxyConfig().load_config(
                router=MagicMock(), config_file_path=config_file_path
            )
            get_instance_spy.assert_called_with(
                value=handler_path,
                config_file_path=config_file_path,
            )
            # Import after loading so we observe the freshly assigned global.
            from litellm.proxy.proxy_server import user_custom_ui_sso_sign_in_handler
            assert user_custom_ui_sso_sign_in_handler == fake_handler
    finally:
        os.unlink(config_file_path)
@pytest.mark.asyncio
async def test_load_environment_variables_direct_and_os_environ():
    """
    _load_environment_variables must export literal values (stringified) and
    resolve ``os.environ/...`` references through get_secret_str.
    """
    from unittest.mock import patch
    from litellm.proxy.proxy_server import ProxyConfig
    proxy_config = ProxyConfig()
    env_section = {
        "DIRECT_VAR": "direct_value",
        "NUMERIC_VAR": 12345,
        "BOOL_VAR": True,
        "SECRET_VAR": "os.environ/ACTUAL_SECRET_VAR",
    }
    resolved = "resolved_secret_value"
    expected_env = {
        "DIRECT_VAR": "direct_value",
        "NUMERIC_VAR": "12345",  # non-string values are stringified
        "BOOL_VAR": "True",
        "SECRET_VAR": resolved,  # value returned by the patched get_secret_str
    }
    with (
        patch(
            "litellm.proxy.proxy_server.get_secret_str", return_value=resolved
        ) as secret_lookup,
        # patch.dict snapshots os.environ and restores it on exit.
        patch.dict(os.environ, {}, clear=False),
    ):
        proxy_config._load_environment_variables({"environment_variables": env_section})
        for name, value in expected_env.items():
            assert os.environ[name] == value
        secret_lookup.assert_called_once_with(
            secret_name="os.environ/ACTUAL_SECRET_VAR"
        )
@pytest.mark.asyncio
async def test_load_environment_variables_litellm_license_and_edge_cases():
    """
    Cover LITELLM_LICENSE handling plus the no-op / unresolvable-secret edge
    cases of _load_environment_variables.
    """
    from unittest.mock import MagicMock, patch
    from litellm.proxy.proxy_server import ProxyConfig
    proxy_config = ProxyConfig()
    # 1) LITELLM_LICENSE is exported and also handed to the license checker.
    license_checker = MagicMock()
    license_checker.is_premium.return_value = True
    license_config = {
        "environment_variables": {
            "LITELLM_LICENSE": "test_license_key",
            "OTHER_VAR": "other_value",
        }
    }
    with (
        patch("litellm.proxy.proxy_server._license_check", license_checker),
        patch.dict(os.environ, {}, clear=False),
    ):
        proxy_config._load_environment_variables(license_config)
        assert os.environ["LITELLM_LICENSE"] == "test_license_key"
        assert license_checker.license_str == "test_license_key"
        license_checker.is_premium.assert_called_once()
    # 2) and 3) A missing or null environment_variables section returns None
    # without raising.
    for empty_config in ({}, {"environment_variables": None}):
        assert proxy_config._load_environment_variables(empty_config) is None
    # 4) An os.environ/ reference that resolves to None is not exported.
    unresolved_config = {
        "environment_variables": {"FAILED_SECRET": "os.environ/NONEXISTENT_SECRET"}
    }
    with (
        patch("litellm.proxy.proxy_server.get_secret_str", return_value=None),
        patch.dict(os.environ, {}, clear=False),
    ):
        proxy_config._load_environment_variables(unresolved_config)
        assert "FAILED_SECRET" not in os.environ
@pytest.mark.asyncio
async def test_load_environment_variables_blocks_dangerous_keys():
    """
    Test that _load_environment_variables rejects dangerous env var keys
    like PATH, LD_PRELOAD, PYTHONPATH, etc.

    For each blocked key, the value present before the call (or its absence)
    must be left exactly as it was, not merely "different from the attacker
    value". Non-blocked keys in the same config are still exported.
    """
    from litellm.proxy.proxy_server import ProxyConfig
    proxy_config = ProxyConfig()
    attacker_values = {
        "PATH": "/tmp/evil",
        "LD_PRELOAD": "/tmp/evil.so",
        "PYTHONPATH": "/tmp/evil",
    }
    test_config = {
        "environment_variables": {
            **attacker_values,
            "SAFE_CUSTOM_VAR": "safe_value",
        }
    }
    with patch.dict(os.environ, {}, clear=False):
        # Snapshot inside patch.dict so we compare against the exact pre-call state.
        originals = {key: os.environ.get(key) for key in attacker_values}
        proxy_config._load_environment_variables(test_config)
        for key, attacker_value in attacker_values.items():
            current = os.environ.get(key)
            # Blocked keys should not be set to the attacker value...
            assert current != attacker_value, f"{key} was overwritten with {current!r}"
            # ...and should be left exactly as they were before the call.
            assert (
                current == originals[key]
            ), f"{key} changed from {originals[key]!r} to {current!r}"
        # Safe keys should still be set
        assert os.environ["SAFE_CUSTOM_VAR"] == "safe_value"
@pytest.mark.asyncio
async def test_load_environment_variables_allows_proxy_keys():
    """
    HTTP_PROXY / HTTPS_PROXY must be exported: corporate deployments rely on
    them to route outbound API calls through a proxy.
    """
    from litellm.proxy.proxy_server import ProxyConfig
    loader = ProxyConfig()
    corp_proxy = "http://corp-proxy:8080"
    proxy_keys = ("HTTP_PROXY", "HTTPS_PROXY")
    config = {"environment_variables": dict.fromkeys(proxy_keys, corp_proxy)}
    with patch.dict(os.environ, {}, clear=False):
        loader._load_environment_variables(config)
        for key in proxy_keys:
            assert os.environ[key] == corp_proxy
@pytest.mark.asyncio
async def test_load_environment_variables_blocks_no_proxy():
    """
    NO_PROXY / no_proxy must be refused so a config cannot exempt hosts from
    proxy-based network monitoring.
    """
    from litellm.proxy.proxy_server import ProxyConfig
    loader = ProxyConfig()
    bypass_target = "internal-service"
    no_proxy_keys = ("NO_PROXY", "no_proxy")
    config = {"environment_variables": dict.fromkeys(no_proxy_keys, bypass_target)}
    with patch.dict(os.environ, {}, clear=False):
        loader._load_environment_variables(config)
        for key in no_proxy_keys:
            assert os.environ.get(key) != bypass_target
@pytest.mark.asyncio
async def test_write_config_to_file(monkeypatch):
    """
    With store_model_in_db enabled, save_config must persist to the database
    (with ``model_list`` stripped) and never open or dump the YAML config file.
    """
    from unittest.mock import AsyncMock, mock_open, patch
    from litellm.proxy.proxy_server import ProxyConfig
    db_client = AsyncMock()
    db_client.insert_data = AsyncMock()
    # A non-None prisma_client plus store_model_in_db=True selects the DB path.
    for attr, value in (
        ("store_model_in_db", True),
        ("prisma_client", db_client),
        ("general_settings", {"store_model_in_db": True}),
        ("user_config_file_path", "/tmp/test_config.yaml"),
    ):
        monkeypatch.setattr(f"litellm.proxy.proxy_server.{attr}", value)
    proxy_config = ProxyConfig()
    fake_open = mock_open()
    with patch("builtins.open", fake_open), patch("yaml.dump") as fake_dump:
        new_config = {"key": "value", "model_list": ["model1", "model2"]}
        await proxy_config.save_config(new_config=new_config)
        # No file I/O at all on the DB path.
        fake_open.assert_not_called()
        fake_dump.assert_not_called()
        db_client.insert_data.assert_called_once()
        insert_kwargs = db_client.insert_data.call_args.kwargs
        assert insert_kwargs["data"] == {"key": "value"}  # model_list stripped
        assert insert_kwargs["table_name"] == "config"
@pytest.mark.asyncio
async def test_write_config_to_file_when_store_model_in_db_false(monkeypatch):
    """
    With store_model_in_db disabled and no database client, save_config must
    dump the new config as YAML into ``user_config_file_path``.
    """
    from unittest.mock import mock_open, patch
    from litellm.proxy.proxy_server import ProxyConfig
    config_path = "/tmp/test_config.yaml"
    overrides = {
        "store_model_in_db": False,
        "prisma_client": None,  # no DB client -> file path is taken
        "general_settings": {"store_model_in_db": False},
        "user_config_file_path": config_path,
    }
    for attr, value in overrides.items():
        monkeypatch.setattr(f"litellm.proxy.proxy_server.{attr}", value)
    proxy_config = ProxyConfig()
    fake_open = mock_open()
    with patch("builtins.open", fake_open), patch("yaml.dump") as fake_dump:
        new_config = {"key": "value", "other_key": "other_value"}
        await proxy_config.save_config(new_config=new_config)
        fake_open.assert_called_once_with(f"{config_path}", "w")
        # yaml.dump must receive the config and the handle returned by open().
        fake_dump.assert_called_once_with(
            new_config,
            fake_open.return_value.__enter__.return_value,
            default_flow_style=False,
        )
@pytest.mark.asyncio
async def test_async_data_generator_midstream_error():
    """
    Test async_data_generator handles midstream error from async_post_call_streaming_hook
    Specifically testing the case where Azure Content Safety Guardrail returns an error

    The streaming hook is mocked to return an SSE-formatted error string for the
    third chunk. The test checks that:
    - the first chunks start with ``data: ``;
    - the error text and a ``data: [DONE]`` chunk both appear in the output;
    - the hook ran once per chunk;
    - post_call_failure_hook was not invoked.

    If the generator raises, the exception is recorded rather than silently
    discarded, and it is included in the assertion messages so a failure points
    at the real cause.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging
    # Create mock objects
    mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
    mock_request_data = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "test"}],
    }
    # Mock response chunks - simulating normal streaming that gets interrupted
    mock_chunks = [
        {"choices": [{"delta": {"content": "Hello"}}]},
        {"choices": [{"delta": {"content": " world"}}]},
        {"choices": [{"delta": {"content": " this"}}]},
    ]
    # Mock the proxy_logging_obj
    mock_proxy_logging_obj = MagicMock(spec=ProxyLogging)
    # Mock async_post_call_streaming_iterator_hook to yield chunks
    async def mock_streaming_iterator(*args, **kwargs):
        for chunk in mock_chunks:
            yield chunk
    mock_proxy_logging_obj.async_post_call_streaming_iterator_hook = (
        mock_streaming_iterator
    )
    # Mock async_post_call_streaming_hook to return error on third chunk
    def mock_streaming_hook(*args, **kwargs):
        chunk = kwargs.get("response")
        # Return error message for the third chunk (simulating guardrail trigger)
        if chunk == mock_chunks[2]:
            return 'data: {"error": {"error": "Azure Content Safety Guardrail: Hate crossed severity 2, Got severity: 2"}}'
        # Return normal chunks for first two
        return chunk
    mock_proxy_logging_obj.async_post_call_streaming_hook = AsyncMock(
        side_effect=mock_streaming_hook
    )
    mock_proxy_logging_obj.post_call_failure_hook = AsyncMock()
    # Mock the global proxy_logging_obj
    with patch("litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj):
        # Create a mock response object
        mock_response = MagicMock()
        # Collect all yielded data from the generator
        yielded_data = []
        raised_exc = None
        try:
            async for data in async_data_generator(
                mock_response, mock_user_api_key_dict, mock_request_data
            ):
                yielded_data.append(data)
        except Exception as e:
            # Keep whatever was raised so any assertion failure below reports it,
            # instead of discarding it silently.
            raised_exc = e
        # Verify the results
        assert (
            len(yielded_data) >= 3
        ), f"Expected at least 3 chunks, got {len(yielded_data)}: {yielded_data} (generator raised: {raised_exc!r})"
        # First two chunks should be normal data
        assert yielded_data[0].startswith(
            "data: "
        ), f"First chunk should start with 'data: ', got: {yielded_data[0]}"
        assert yielded_data[1].startswith(
            "data: "
        ), f"Second chunk should start with 'data: ', got: {yielded_data[1]}"
        # The error message should be yielded
        error_found = any(
            "Azure Content Safety Guardrail: Hate crossed severity 2" in data
            for data in yielded_data
        )
        done_found = any("data: [DONE]" in data for data in yielded_data)
        assert (
            error_found
        ), f"Error message should be found in yielded data. Got: {yielded_data} (generator raised: {raised_exc!r})"
        assert (
            done_found
        ), f"[DONE] message should be found at the end. Got: {yielded_data} (generator raised: {raised_exc!r})"
        # Verify that the streaming hook was called for each chunk
        assert mock_proxy_logging_obj.async_post_call_streaming_hook.call_count == len(
            mock_chunks
        )
        # Verify that post_call_failure_hook was NOT called (since this is not an exception case)
        mock_proxy_logging_obj.post_call_failure_hook.assert_not_called()
def _has_nested_none_values(obj, path="root"):
"""
Recursively check if an object contains nested None values.
Args:
obj: The object to check
path: Current path in the object tree (for debugging)
Returns:
List of paths where None values were found
"""
none_paths = []
if obj is None:
none_paths.append(path)
elif isinstance(obj, dict):
for key, value in obj.items():
none_paths.extend(_has_nested_none_values(value, f"{path}.{key}"))
elif isinstance(obj, (list, tuple)):
for i, item in enumerate(obj):
none_paths.extend(_has_nested_none_values(item, f"{path}[{i}]"))
elif hasattr(obj, "__dict__"):
# Handle object attributes
for key, value in obj.__dict__.items():
if not key.startswith("_"): # Skip private attributes
none_paths.extend(_has_nested_none_values(value, f"{path}.{key}"))
return none_paths
@pytest.mark.asyncio
async def test_chat_completion_result_no_nested_none_values():
    """
    chat_completion must serialize its ModelResponse with exclude_none=True so
    the returned dict contains no ``None`` anywhere in the tree, while keeping
    all populated fields.
    """
    from unittest.mock import AsyncMock, MagicMock, patch
    from fastapi import Request, Response
    import litellm
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import chat_completion
    # Optional message fields deliberately left as None; none may survive.
    none_message_fields = (
        "function_call",
        "tool_calls",
        "audio",
        "reasoning_content",
        "thinking_blocks",
        "annotations",
    )
    model_response = litellm.ModelResponse()
    for attr, value in (
        ("id", "test-id"),
        ("model", "gpt-3.5-turbo"),
        ("object", "chat.completion"),
        ("created", 1234567890),
    ):
        setattr(model_response, attr, value)
    assistant_message = litellm.Message(
        content="Hello, world!",
        role="assistant",
        **dict.fromkeys(none_message_fields),
    )
    model_response.choices = [
        litellm.Choices(
            finish_reason="stop",
            index=0,
            message=assistant_message,
            logprobs=None,  # must also be dropped by exclude_none=True
        )
    ]
    setattr(
        model_response,
        "usage",
        litellm.Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
    )
    # Sanity check: the raw dump really does contain None values.
    assert (
        len(_has_nested_none_values(model_response.model_dump())) > 0
    ), "Mock should have None values before exclude_none=True"
    processor = MagicMock()
    processor.base_process_llm_request = AsyncMock(return_value=model_response)
    with (
        patch(
            "litellm.proxy.proxy_server._read_request_body",
            return_value={"model": "gpt-3.5-turbo", "messages": []},
        ),
        patch(
            "litellm.proxy.proxy_server.ProxyBaseLLMRequestProcessing",
            return_value=processor,
        ),
    ):
        result = await chat_completion(
            request=MagicMock(spec=Request),
            fastapi_response=MagicMock(spec=Response),
            user_api_key_dict=MagicMock(spec=UserAPIKeyAuth),
        )
    assert isinstance(result, dict), f"Expected dict result, got {type(result)}"
    none_paths_after = _has_nested_none_values(result)
    assert (
        len(none_paths_after) == 0
    ), f"Result should not contain nested None values. Found None at: {none_paths_after}"
    for key in ("id", "model", "object", "created", "choices", "usage"):
        assert key in result
    assert len(result["choices"]) == 1
    returned_message = result["choices"][0]["message"]
    assert returned_message["content"] == "Hello, world!"
    assert returned_message["role"] == "assistant"
    for field in none_message_fields:
        assert (
            field not in returned_message
        ), f"Field '{field}' should be excluded when it's None"
# ============================================================================
# Price Data Reload Tests
# ============================================================================
class TestPriceDataReloadAPI:
"""Test cases for price data reload API endpoints"""
@pytest.fixture
def client_with_auth(self):
"""Create a test client with authentication"""
from litellm.proxy._types import LitellmUserRoles
from litellm.proxy.proxy_server import cleanup_router_config_variables
cleanup_router_config_variables()
filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
asyncio.run(initialize(config=config_fp, debug=True))
# Mock admin user authentication
mock_auth = MagicMock()
mock_auth.user_role = LitellmUserRoles.PROXY_ADMIN
app.dependency_overrides[user_api_key_auth] = lambda: mock_auth
return TestClient(app)
def test_reload_model_cost_map_admin_access(self, client_with_auth):
"""Test that admin users can access the reload endpoint"""
# Save the original model_cost so the endpoint's direct assignment
# (litellm.model_cost = new_model_cost_map) does not contaminate
# subsequent tests running in the same worker process.
original_model_cost = litellm.model_cost.copy()
try:
with patch(
"litellm.litellm_core_utils.get_model_cost_map.get_model_cost_map"
) as mock_get_map:
mock_get_map.return_value = {
"gpt-3.5-turbo": {"input_cost_per_token": 0.001}
}
# Mock the database connection
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
mock_prisma.db.litellm_config.find_unique = AsyncMock(
return_value=None
)
mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)
response = client_with_auth.post("/reload/model_cost_map")
assert response.status_code == 200
data = response.json()
assert data["status"] == "success"
assert "message" in data
assert "timestamp" in data
assert "models_count" in data
# The new implementation immediately reloads and returns the count
assert (
"Price data reloaded successfully! 1 models updated."
in data["message"]
)
assert data["models_count"] == 1
finally:
# Restore the full model cost map so subsequent tests are not affected
litellm.model_cost = original_model_cost
_invalidate_model_cost_lowercase_map()
def test_reload_model_cost_map_non_admin_access(self, client_with_auth):
"""Test that non-admin users cannot access the reload endpoint"""
# Mock non-admin user
mock_auth = MagicMock()
mock_auth.user_role = "user" # Non-admin role
app.dependency_overrides[user_api_key_auth] = lambda: mock_auth
response = client_with_auth.post("/reload/model_cost_map")
assert response.status_code == 403
data = response.json()
assert "Access denied" in data["detail"]
assert "Admin role required" in data["detail"]
def test_get_model_cost_map_public_access(self, client_no_auth):
"""Test that the model cost map endpoint is publicly accessible"""
with patch(
"litellm.model_cost", {"gpt-3.5-turbo": {"input_cost_per_token": 0.001}}
):
response = client_no_auth.get("/public/litellm_model_cost_map")
assert response.status_code == 200
data = response.json()
assert "gpt-3.5-turbo" in data
def test_reload_model_cost_map_error_handling(self, client_with_auth):
"""Test error handling in the reload endpoint"""
with patch(
"litellm.litellm_core_utils.get_model_cost_map.get_model_cost_map"
) as mock_get_map:
mock_get_map.side_effect = Exception("Network error")
# Mock the database connection
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
mock_prisma.db.litellm_config.find_unique = AsyncMock(return_value=None)
mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)
response = client_with_auth.post("/reload/model_cost_map")
assert (
response.status_code == 500
) # The new implementation immediately reloads and fails on error
data = response.json()
assert "Failed to reload model cost map" in data["detail"]
def test_schedule_model_cost_map_reload_admin_access(self, client_with_auth):
"""Test that admin users can schedule periodic reload"""
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
# Mock database upsert
mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)
response = client_with_auth.post("/schedule/model_cost_map_reload?hours=6")
assert response.status_code == 200
data = response.json()
assert data["status"] == "success"
assert data["interval_hours"] == 6
assert "message" in data
assert "timestamp" in data
def test_schedule_model_cost_map_reload_non_admin_access(self, client_with_auth):
"""Test that non-admin users cannot schedule periodic reload"""
# Mock non-admin user
mock_auth = MagicMock()
mock_auth.user_role = "user" # Non-admin role
app.dependency_overrides[user_api_key_auth] = lambda: mock_auth
response = client_with_auth.post("/schedule/model_cost_map_reload?hours=6")
assert response.status_code == 403
data = response.json()
assert "Access denied" in data["detail"]
assert "Admin role required" in data["detail"]
def test_schedule_model_cost_map_reload_invalid_hours(self, client_with_auth):
"""Test that invalid hours parameter is rejected"""
response = client_with_auth.post("/schedule/model_cost_map_reload?hours=0")
assert response.status_code == 400
data = response.json()
assert "Hours must be greater than 0" in data["detail"]
def test_cancel_model_cost_map_reload_admin_access(self, client_with_auth):
    """An admin caller can cancel the periodic model cost map reload.

    The DB delete is stubbed out. The response must report success and carry
    both a message and a timestamp.
    """
    with patch("litellm.proxy.proxy_server.prisma_client") as prisma_stub:
        prisma_stub.db.litellm_config.delete = AsyncMock(return_value=None)
        resp = client_with_auth.delete("/schedule/model_cost_map_reload")
        assert resp.status_code == 200
        payload = resp.json()
    assert payload["status"] == "success"
    assert {"message", "timestamp"} <= payload.keys()
def test_cancel_model_cost_map_reload_non_admin_access(self, client_with_auth):
    """A non-admin caller gets HTTP 403 when cancelling the periodic reload.

    The ``user_api_key_auth`` override on the module-level ``app`` is replaced
    with a non-admin identity only while the request runs, then restored, so
    the non-admin identity does not leak into other tests.
    """
    non_admin = MagicMock()
    non_admin.user_role = "user"  # Non-admin role
    previous_override = app.dependency_overrides.get(user_api_key_auth)
    app.dependency_overrides[user_api_key_auth] = lambda: non_admin
    try:
        response = client_with_auth.delete("/schedule/model_cost_map_reload")
    finally:
        # Put back the override that was active before this test touched it.
        if previous_override is None:
            app.dependency_overrides.pop(user_api_key_auth, None)
        else:
            app.dependency_overrides[user_api_key_auth] = previous_override
    assert response.status_code == 403
    data = response.json()
    assert "Access denied" in data["detail"]
    assert "Admin role required" in data["detail"]
def test_get_model_cost_map_reload_status_admin_access(self, client_with_auth):
    """An admin caller sees the schedule, last run and computed next run.

    The stored config has a 6-hour interval, and the last reload is pinned to
    06:00. "Now" is frozen one hour later, at 07:00. The endpoint must report
    the schedule as active with the next run at 12:00.
    """
    stored_config = MagicMock()
    stored_config.param_value = {"interval_hours": 6, "force_reload": False}

    with (
        patch("litellm.proxy.proxy_server.prisma_client") as prisma_stub,
        patch(
            "litellm.proxy.proxy_server.last_model_cost_map_reload",
            "2024-01-01T06:00:00",
        ),
        patch("litellm.proxy.proxy_server.datetime") as frozen_clock,
    ):
        prisma_stub.db.litellm_config.find_unique = AsyncMock(
            return_value=stored_config
        )
        # One hour after the pinned last reload; ISO parsing stays real.
        frozen_clock.utcnow.return_value = datetime(2024, 1, 1, 7, 0, 0)
        frozen_clock.fromisoformat = datetime.fromisoformat
        resp = client_with_auth.get("/schedule/model_cost_map_reload/status")

    assert resp.status_code == 200
    status = resp.json()
    assert status["scheduled"] == True
    assert status["interval_hours"] == 6
    assert status["last_run"] == "2024-01-01T06:00:00"
    assert status["next_run"] == "2024-01-01T12:00:00"
def test_get_model_cost_map_reload_status_non_admin_access(self, client_with_auth):
    """A non-admin caller gets HTTP 403 when reading the reload status.

    The ``user_api_key_auth`` override on the module-level ``app`` is replaced
    with a non-admin identity only while the request runs, then restored, so
    the non-admin identity does not leak into other tests.
    """
    non_admin = MagicMock()
    non_admin.user_role = "user"  # Non-admin role
    previous_override = app.dependency_overrides.get(user_api_key_auth)
    app.dependency_overrides[user_api_key_auth] = lambda: non_admin
    try:
        response = client_with_auth.get("/schedule/model_cost_map_reload/status")
    finally:
        # Put back the override that was active before this test touched it.
        if previous_override is None:
            app.dependency_overrides.pop(user_api_key_auth, None)
        else:
            app.dependency_overrides[user_api_key_auth] = previous_override
    assert response.status_code == 403
    data = response.json()
    assert "Access denied" in data["detail"]
    assert "Admin role required" in data["detail"]
def test_get_model_cost_map_reload_status_no_config(self, client_with_auth):
    """With no stored reload config, the status reports nothing scheduled."""
    with patch("litellm.proxy.proxy_server.prisma_client") as prisma_stub:
        prisma_stub.db.litellm_config.find_unique = AsyncMock(return_value=None)
        resp = client_with_auth.get("/schedule/model_cost_map_reload/status")
        assert resp.status_code == 200
        status = resp.json()
    assert status["scheduled"] == False
    for unset_field in ("interval_hours", "last_run", "next_run"):
        assert status[unset_field] is None
def test_get_model_cost_map_reload_status_no_interval(self, client_with_auth):
    """A stored config whose ``interval_hours`` is None counts as not scheduled."""
    unscheduled_row = MagicMock()
    unscheduled_row.param_value = {"interval_hours": None, "force_reload": False}
    with patch("litellm.proxy.proxy_server.prisma_client") as prisma_stub:
        prisma_stub.db.litellm_config.find_unique = AsyncMock(
            return_value=unscheduled_row
        )
        resp = client_with_auth.get("/schedule/model_cost_map_reload/status")
        assert resp.status_code == 200
        status = resp.json()
    assert status["scheduled"] == False
    for unset_field in ("interval_hours", "last_run", "next_run"):
        assert status[unset_field] is None
class TestPriceDataReloadIntegration:
    """Integration tests for the complete price data reload feature"""

    @pytest.fixture(autouse=True)
    def _flush_litellm_config_cache(self):
        """Run every test with an empty ``litellm_config_cache``.

        The reload checks read their config through a cached lookup, so a value
        cached by one test would otherwise be seen by the next one.
        """
        from litellm.proxy.utils import litellm_config_cache

        litellm_config_cache.flush_cache()
        yield
        litellm_config_cache.flush_cache()

    @pytest.fixture
    def client_with_auth(self):
        """Yield a TestClient whose auth dependency resolves to a proxy admin.

        The proxy is initialised from ``test_configs/test_config_no_auth.yaml``.
        The admin ``user_api_key_auth`` override is installed on the module-level
        ``app`` and restored on teardown, so the admin identity does not leak
        into tests that run afterwards.
        """
        from litellm.proxy.proxy_server import cleanup_router_config_variables

        cleanup_router_config_variables()
        filepath = os.path.dirname(os.path.abspath(__file__))
        config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
        asyncio.run(initialize(config=config_fp, debug=True))

        # Mock admin user authentication
        mock_auth = MagicMock()
        mock_auth.user_role = LitellmUserRoles.PROXY_ADMIN
        previous_override = app.dependency_overrides.get(user_api_key_auth)
        app.dependency_overrides[user_api_key_auth] = lambda: mock_auth
        try:
            yield TestClient(app)
        finally:
            if previous_override is None:
                app.dependency_overrides.pop(user_api_key_auth, None)
            else:
                app.dependency_overrides[user_api_key_auth] = previous_override

    def test_complete_reload_flow(self, client_with_auth):
        """Test the complete reload flow from API to model cost update"""
        # Mock the model cost map
        mock_cost_map = {
            "gpt-3.5-turbo": {
                "input_cost_per_token": 0.001,
                "output_cost_per_token": 0.002,
            },
            "gpt-4": {"input_cost_per_token": 0.03, "output_cost_per_token": 0.06},
        }
        original_model_cost = litellm.model_cost.copy()
        try:
            with patch(
                "litellm.litellm_core_utils.get_model_cost_map.get_model_cost_map"
            ) as mock_get_map:
                mock_get_map.return_value = mock_cost_map
                # Mock the database connection
                with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
                    mock_prisma.db.litellm_config.find_unique = AsyncMock(
                        return_value=None
                    )
                    mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)
                    # Test reload endpoint
                    response = client_with_auth.post("/reload/model_cost_map")
                    assert response.status_code == 200
                    # Test get endpoint
                    response = client_with_auth.get("/public/litellm_model_cost_map")
                    assert response.status_code == 200
        finally:
            litellm.model_cost = original_model_cost
            _invalidate_model_cost_lowercase_map()

    def test_distributed_reload_check_function(self):
        """Exercise ``ProxyConfig._check_and_reload_model_cost_map`` in three DB states.

        1. No config stored: nothing is persisted.
        2. Interval configured but not yet elapsed: nothing is persisted.
        3. ``force_reload=True``: the persisted update resets ``force_reload``
           and keeps ``interval_hours``.
        """
        from litellm.proxy.proxy_server import ProxyConfig
        from litellm.proxy.utils import litellm_config_cache

        proxy_config = ProxyConfig()
        # Mock prisma client
        mock_prisma = MagicMock()

        # Test case 1: No config in database
        mock_prisma.db.litellm_config.find_unique = AsyncMock(return_value=None)
        # _check_and_reload_model_cost_map routes through get_config_param,
        # which calls prisma.get_generic_data on a cache miss.
        mock_prisma.get_generic_data = AsyncMock(return_value=None)
        # Should return early without reloading
        asyncio.run(proxy_config._check_and_reload_model_cost_map(mock_prisma))
        mock_prisma.db.litellm_config.upsert.assert_not_called()

        # Test case 2: Config with interval but not time to reload
        litellm_config_cache.flush_cache()
        mock_config = MagicMock()
        mock_config.param_value = {"interval_hours": 6, "force_reload": False}
        mock_prisma.db.litellm_config.find_unique = AsyncMock(return_value=mock_config)
        mock_prisma.get_generic_data = AsyncMock(return_value=mock_config)
        # Mock current time and last reload time
        with patch(
            "litellm.proxy.proxy_server.last_model_cost_map_reload",
            "2024-01-01T06:00:00",
        ):
            with patch("litellm.proxy.proxy_server.datetime") as mock_datetime:
                mock_datetime.utcnow.return_value = datetime(
                    2024, 1, 1, 7, 0, 0
                )  # 1 hour later
                # Keep ISO parsing real, as the status-endpoint test does.
                # Presumably the check parses last_model_cost_map_reload with
                # fromisoformat; if so, a MagicMock there would short-circuit
                # the elapsed-time comparison instead of evaluating it.
                mock_datetime.fromisoformat = datetime.fromisoformat
                # Should not reload (only 1 hour passed, need 6)
                asyncio.run(proxy_config._check_and_reload_model_cost_map(mock_prisma))
        mock_prisma.db.litellm_config.upsert.assert_not_called()

        # Test case 3: Config with force reload
        litellm_config_cache.flush_cache()
        mock_config.param_value = {"interval_hours": 6, "force_reload": True}
        mock_prisma.db.litellm_config.find_unique = AsyncMock(return_value=mock_config)
        mock_prisma.get_generic_data = AsyncMock(return_value=mock_config)
        mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)
        original_model_cost = litellm.model_cost.copy()
        try:
            with patch(
                "litellm.litellm_core_utils.get_model_cost_map.get_model_cost_map"
            ) as mock_get_map:
                mock_get_map.return_value = {
                    "gpt-3.5-turbo": {"input_cost_per_token": 0.001}
                }
                # Should reload due to force flag
                asyncio.run(proxy_config._check_and_reload_model_cost_map(mock_prisma))
                # Verify force_reload was reset to False
                mock_prisma.db.litellm_config.upsert.assert_called()
                call_args = mock_prisma.db.litellm_config.upsert.call_args
                # The param_value is now a JSON string, so we need to parse it
                param_value_json = call_args[1]["data"]["update"]["param_value"]
                param_value_dict = json.loads(param_value_json)
                assert param_value_dict["force_reload"] == False
                assert param_value_dict.get("interval_hours") == 6
        finally:
            litellm.model_cost = original_model_cost
            _invalidate_model_cost_lowercase_map()

    def test_distributed_reload_preserves_interval_hours(self):
        """Test that _check_and_reload_model_cost_map preserves interval_hours after reload.

        Regression test: the update branch of the upsert was previously dropping
        interval_hours, causing scheduled reloads to self-destruct after first execution.
        """
        from litellm.proxy.proxy_server import ProxyConfig

        proxy_config = ProxyConfig()
        mock_prisma = MagicMock()
        # Set up config with interval_hours=24 and force_reload=True to trigger reload
        mock_config = MagicMock()
        mock_config.param_value = {"interval_hours": 24, "force_reload": True}
        mock_prisma.db.litellm_config.find_unique = AsyncMock(return_value=mock_config)
        # _check_and_reload_model_cost_map now reads through get_generic_data.
        mock_prisma.get_generic_data = AsyncMock(return_value=mock_config)
        mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)
        original_model_cost = litellm.model_cost.copy()
        try:
            with patch(
                "litellm.litellm_core_utils.get_model_cost_map.get_model_cost_map"
            ) as mock_get_map:
                mock_get_map.return_value = {"gpt-4": {"input_cost_per_token": 0.001}}
                asyncio.run(proxy_config._check_and_reload_model_cost_map(mock_prisma))
                # Verify the upsert update branch preserves interval_hours
                mock_prisma.db.litellm_config.upsert.assert_called()
                call_args = mock_prisma.db.litellm_config.upsert.call_args
                param_value_json = call_args[1]["data"]["update"]["param_value"]
                param_value_dict = json.loads(param_value_json)
                assert param_value_dict["force_reload"] == False
                assert param_value_dict["interval_hours"] == 24, (
                    "interval_hours must be preserved in the update branch; "
                    "dropping it causes the schedule to self-destruct"
                )
        finally:
            litellm.model_cost = original_model_cost
            _invalidate_model_cost_lowercase_map()

    def test_manual_reload_preserves_interval_hours(self, client_with_auth):
        """Test that manual reload via /reload/model_cost_map preserves existing interval_hours.

        Regression test: the manual reload endpoint was overwriting param_value with
        only force_reload=True, dropping any existing interval_hours schedule.
        """
        original_model_cost = litellm.model_cost.copy()
        try:
            with patch(
                "litellm.litellm_core_utils.get_model_cost_map.get_model_cost_map"
            ) as mock_get_map:
                mock_get_map.return_value = {"gpt-4": {"input_cost_per_token": 0.001}}
                with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
                    # Simulate existing config with a schedule
                    mock_existing = MagicMock()
                    mock_existing.param_value = {
                        "interval_hours": 12,
                        "force_reload": False,
                    }
                    mock_prisma.db.litellm_config.find_unique = AsyncMock(
                        return_value=mock_existing
                    )
                    mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)
                    response = client_with_auth.post("/reload/model_cost_map")
                    assert response.status_code == 200
                    # Verify interval_hours was preserved in the upsert
                    mock_prisma.db.litellm_config.upsert.assert_called()
                    call_args = mock_prisma.db.litellm_config.upsert.call_args
                    param_value_json = call_args[1]["data"]["update"]["param_value"]
                    param_value_dict = json.loads(param_value_json)
                    assert param_value_dict["force_reload"] == True
                    assert param_value_dict["interval_hours"] == 12, (
                        "interval_hours must be preserved when manual reload sets force_reload; "
                        "dropping it destroys any existing schedule"
                    )
        finally:
            litellm.model_cost = original_model_cost
            _invalidate_model_cost_lowercase_map()

    def test_anthropic_beta_headers_reload_preserves_interval_hours(self):
        """Test that _check_and_reload_anthropic_beta_headers preserves interval_hours after reload.

        Regression test: the update branch of the upsert was dropping interval_hours,
        identical to the model cost map bug.
        """
        from litellm.proxy.proxy_server import ProxyConfig

        proxy_config = ProxyConfig()
        mock_prisma = MagicMock()
        # Set up config with interval_hours=12 and force_reload=True to trigger reload
        mock_config = MagicMock()
        mock_config.param_value = {"interval_hours": 12, "force_reload": True}
        mock_prisma.db.litellm_config.find_unique = AsyncMock(return_value=mock_config)
        # _check_and_reload_anthropic_beta_headers now reads through get_generic_data.
        mock_prisma.get_generic_data = AsyncMock(return_value=mock_config)
        mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)
        with patch(
            "litellm.anthropic_beta_headers_manager.reload_beta_headers_config"
        ) as mock_reload:
            mock_reload.return_value = {"anthropic": {"beta_header": "test-value"}}
            asyncio.run(
                proxy_config._check_and_reload_anthropic_beta_headers(mock_prisma)
            )
            # Verify the upsert update branch preserves interval_hours
            mock_prisma.db.litellm_config.upsert.assert_called()
            call_args = mock_prisma.db.litellm_config.upsert.call_args
            param_value_json = call_args[1]["data"]["update"]["param_value"]
            param_value_dict = json.loads(param_value_json)
            assert param_value_dict["force_reload"] == False
            assert param_value_dict["interval_hours"] == 12, (
                "interval_hours must be preserved in the update branch; "
                "dropping it causes the schedule to self-destruct"
            )

    def test_anthropic_beta_headers_manual_reload_preserves_interval_hours(
        self, client_with_auth
    ):
        """Test that manual reload via /reload/anthropic_beta_headers preserves existing interval_hours.

        Regression test: the manual reload endpoint was overwriting param_value with
        only force_reload=True, dropping any existing interval_hours schedule.
        """
        with patch(
            "litellm.anthropic_beta_headers_manager.reload_beta_headers_config"
        ) as mock_reload:
            mock_reload.return_value = {"anthropic": {"beta_header": "test-value"}}
            with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
                # Simulate existing config with a schedule
                mock_existing = MagicMock()
                mock_existing.param_value = {"interval_hours": 8, "force_reload": False}
                mock_prisma.db.litellm_config.find_unique = AsyncMock(
                    return_value=mock_existing
                )
                mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)
                response = client_with_auth.post("/reload/anthropic_beta_headers")
                assert response.status_code == 200
                # Verify interval_hours was preserved in the upsert
                mock_prisma.db.litellm_config.upsert.assert_called()
                call_args = mock_prisma.db.litellm_config.upsert.call_args
                param_value_json = call_args[1]["data"]["update"]["param_value"]
                param_value_dict = json.loads(param_value_json)
                assert param_value_dict["force_reload"] == True
                assert param_value_dict["interval_hours"] == 8, (
                    "interval_hours must be preserved when manual reload sets force_reload; "
                    "dropping it destroys any existing schedule"
                )

    def test_config_file_parsing(self):
        """Test parsing of config file with reload settings"""
        config_content = """
general_settings:
  master_key: sk-1234
  model_cost_map_reload_interval: 21600

model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-3.5-turbo
  - model_name: gpt-4
    litellm_params:
      model: gpt-4
"""
        # Parse the config
        config = yaml.safe_load(config_content)
        # Verify the reload setting is present
        assert "general_settings" in config
        assert "model_cost_map_reload_interval" in config["general_settings"]
        assert config["general_settings"]["model_cost_map_reload_interval"] == 21600
        # Verify models are present
        assert "model_list" in config
        assert len(config["model_list"]) == 2

    def test_database_config_storage(self, client_with_auth):
        """Scheduling a reload persists the interval through ``litellm_config.upsert``.

        Drives the real ``/schedule/model_cost_map_reload`` endpoint and inspects
        the payload it writes. ``param_value`` is accepted either as a dict or
        as a JSON string; the other tests in this class show the reload paths
        serialising it with ``json.dumps``.
        """
        with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
            mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)
            response = client_with_auth.post("/schedule/model_cost_map_reload?hours=6")
        assert response.status_code == 200

        mock_prisma.db.litellm_config.upsert.assert_called()
        call_kwargs = mock_prisma.db.litellm_config.upsert.call_args[1]
        assert call_kwargs["where"]["param_name"] == "model_cost_map_reload_config"
        stored_value = call_kwargs["data"]["update"]["param_value"]
        if isinstance(stored_value, str):
            stored_value = json.loads(stored_value)
        assert stored_value["interval_hours"] == 6
        # Scheduling alone must not request an immediate forced reload.
        assert not stored_value.get("force_reload", False)

    def test_manual_reload_force_flag(self, client_with_auth):
        """A manual reload with no stored config persists ``force_reload=True``.

        Complements ``test_manual_reload_preserves_interval_hours`` by covering
        the path where ``find_unique`` returns no existing reload config.
        """
        original_model_cost = litellm.model_cost.copy()
        try:
            with patch(
                "litellm.litellm_core_utils.get_model_cost_map.get_model_cost_map"
            ) as mock_get_map:
                mock_get_map.return_value = {"gpt-4": {"input_cost_per_token": 0.001}}
                with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
                    mock_prisma.db.litellm_config.find_unique = AsyncMock(
                        return_value=None
                    )
                    mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)
                    response = client_with_auth.post("/reload/model_cost_map")
                    assert response.status_code == 200
                    # Verify force_reload flag was persisted
                    mock_prisma.db.litellm_config.upsert.assert_called()
                    call_args = mock_prisma.db.litellm_config.upsert.call_args
                    param_value = call_args[1]["data"]["update"]["param_value"]
                    if isinstance(param_value, str):
                        param_value = json.loads(param_value)
                    assert param_value["force_reload"] == True
        finally:
            litellm.model_cost = original_model_cost
            _invalidate_model_cost_lowercase_map()
@pytest.mark.asyncio
async def test_add_router_settings_from_db_config_merge_logic():
    """
    Test the _add_router_settings_from_db_config method's merge logic.

    Router settings from the config file and from the ``router_settings`` DB
    row are combined and passed to ``llm_router.update_settings``. DB values
    win on conflicts, and keys present on only one side are kept. Nested
    dictionaries are merged key-by-key rather than replaced wholesale.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    # Create ProxyConfig instance
    proxy_config = ProxyConfig()
    # Mock router
    mock_router = MagicMock()
    mock_router.update_settings = MagicMock()
    # Both config and DB settings exist - they should be merged
    config_data = {
        "router_settings": {
            "routing_strategy": "usage-based-routing",
            "model_group_alias": {"gpt-4": "openai-gpt-4"},
            "enable_pre_call_checks": True,
            "timeout": 30,
            "nested_config": {"setting1": "config_value1", "setting2": "config_value2"},
        }
    }
    # Mock database config record
    mock_db_config = MagicMock()
    mock_db_config.param_value = {
        "routing_strategy": "least-busy",  # This should override config value
        "retry_delay": 2,  # This is new, should be added
        "nested_config": {
            "setting2": "db_value2",  # This should override config value
            "setting3": "db_value3",  # This is new, should be added
        },
    }
    # Mock prisma client
    mock_prisma_client = MagicMock()
    mock_prisma_client.db.litellm_config.find_first = AsyncMock(
        return_value=mock_db_config
    )
    # Call the method under test
    await proxy_config._add_router_settings_from_db_config(
        config_data=config_data,
        llm_router=mock_router,
        prisma_client=mock_prisma_client,
    )
    # Verify find_first was called with correct parameters
    mock_prisma_client.db.litellm_config.find_first.assert_called_once_with(
        where={"param_name": "router_settings"}
    )
    # Verify update_settings was called
    mock_router.update_settings.assert_called_once()
    # Get the actual settings passed to update_settings
    call_args = mock_router.update_settings.call_args
    combined_settings = call_args[1]  # kwargs
    # Verify the merge results
    # DB values should override config values
    assert combined_settings["routing_strategy"] == "least-busy"
    # Config-only values should be preserved
    assert combined_settings["model_group_alias"] == {"gpt-4": "openai-gpt-4"}
    assert combined_settings["enable_pre_call_checks"] == True
    assert combined_settings["timeout"] == 30
    # DB-only values should be added
    assert combined_settings["retry_delay"] == 2
    # Nested dictionaries are merged recursively: the config-only key
    # (setting1) survives, while the DB overrides setting2 and adds setting3.
    expected_nested = {
        "setting1": "config_value1",
        "setting2": "db_value2",
        "setting3": "db_value3",
    }
    assert combined_settings["nested_config"] == expected_nested
@pytest.mark.asyncio
async def test_add_router_settings_from_db_config_edge_cases():
    """
    Cover the guard and fallback paths of _add_router_settings_from_db_config.

    The scenarios run in sequence against one router mock: no router, no
    prisma client, no DB row, no ``router_settings`` in config, nothing on
    either side, and a DB row whose ``param_value`` is not a dict.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    merge_into_router = ProxyConfig()._add_router_settings_from_db_config

    router = MagicMock()
    router.update_settings = MagicMock()

    def db_row(value):
        # Minimal stand-in for a litellm_config record.
        row = MagicMock()
        row.param_value = value
        return row

    # 1) No router: nothing is applied.
    await merge_into_router(
        config_data={"router_settings": {"test": "value"}},
        llm_router=None,
        prisma_client=MagicMock(),
    )
    router.update_settings.assert_not_called()

    # 2) No prisma client: nothing is applied either.
    await merge_into_router(
        config_data={"router_settings": {"test": "value"}},
        llm_router=router,
        prisma_client=None,
    )
    router.update_settings.assert_not_called()

    prisma = MagicMock()
    config_table = prisma.db.litellm_config

    # 3) No DB row: the config settings alone are applied.
    config_table.find_first = AsyncMock(return_value=None)
    await merge_into_router(
        config_data={"router_settings": {"routing_strategy": "usage-based"}},
        llm_router=router,
        prisma_client=prisma,
    )
    router.update_settings.assert_called_once_with(routing_strategy="usage-based")
    router.reset_mock()

    # 4) No router_settings in config: the DB settings alone are applied.
    config_table.find_first = AsyncMock(return_value=db_row({"db_setting": "db_value"}))
    await merge_into_router(config_data={}, llm_router=router, prisma_client=prisma)
    router.update_settings.assert_called_once_with(db_setting="db_value")
    router.reset_mock()

    # 5) Nothing on either side: update_settings is never invoked.
    config_table.find_first = AsyncMock(return_value=None)
    await merge_into_router(config_data={}, llm_router=router, prisma_client=prisma)
    router.update_settings.assert_not_called()

    # 6) A non-dict DB param_value is ignored and the config settings still apply.
    config_table.find_first = AsyncMock(return_value=db_row("not_a_dict"))
    await merge_into_router(
        config_data={"router_settings": {"config_setting": "config_value"}},
        llm_router=router,
        prisma_client=prisma,
    )
    router.update_settings.assert_called_once_with(config_setting="config_value")
@pytest.mark.asyncio
async def test_add_router_settings_shallow_merge_behavior():
    """
    Pin how a partially overlapping nested dict from the DB is merged.

    Despite the "shallow" in this test's name, the asserted behaviour is a
    recursive merge. Within ``nested_setting``, DB keys override or extend the
    config keys, and config-only keys (key1, key3) are kept rather than
    dropped. Top-level scalars from the DB still override the config.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()
    mock_router = MagicMock()
    mock_router.update_settings = MagicMock()
    # Config with nested dictionary
    config_data = {
        "router_settings": {
            "nested_setting": {
                "key1": "config_value1",
                "key2": "config_value2",
                "key3": "config_value3",
            },
            "top_level": "config_top",
        }
    }
    # DB config that partially overlaps the nested dictionary
    mock_db_config = MagicMock()
    mock_db_config.param_value = {
        "nested_setting": {
            "key2": "db_value2",  # Override existing key
            "key4": "db_value4",  # Add new key
            # key1 and key3 are absent here but must survive from the config
        },
        "top_level": "db_top",  # Override top level
    }
    mock_prisma_client = MagicMock()
    mock_prisma_client.db.litellm_config.find_first = AsyncMock(
        return_value=mock_db_config
    )
    await proxy_config._add_router_settings_from_db_config(
        config_data=config_data,
        llm_router=mock_router,
        prisma_client=mock_prisma_client,
    )
    # Get the merged settings
    call_args = mock_router.update_settings.call_args
    merged_settings = call_args[1]
    # Recursive merge: config-only nested keys are kept, and DB keys override
    # or extend them.
    expected_nested = {
        "key1": "config_value1",
        "key3": "config_value3",
        "key2": "db_value2",
        "key4": "db_value4",
    }
    assert merged_settings["nested_setting"] == expected_nested
    assert merged_settings["top_level"] == "db_top"
@pytest.mark.asyncio
async def test_model_info_v1_oci_secrets_not_leaked():
    """
    model_info_v1 must mask OCI credentials and never echo them back.

    One OCI deployment is served through a stubbed router. The credential
    fields (key, fingerprint, tenancy, key file) must come back containing
    "****". The routing fields (model, region, drop_params) must come back
    unchanged. None of the raw credential strings may appear anywhere in
    the result.
    """
    from litellm.proxy.proxy_server import model_info_v1

    caller = MagicMock(spec=UserAPIKeyAuth)
    caller.user_id = "test-user"
    caller.api_key = "test-key"
    caller.team_models = []
    caller.models = ["oci-grok-test"]

    deployment = {
        "model_name": "oci-grok-test",
        "litellm_params": {
            "model": "oci/xai.grok-4",
            "oci_key": "ocid1.api_key.oc1..aaaaaaaa7kbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbk",
            "oci_region": "us-phoenix-1",
            "oci_user": "ocid1.user.oc1..aaaaaaaa7kbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbk",
            "oci_fingerprint": "aa:bb:cc:dd:ee:ff:11:22:33:44:55:66:77:88:99:00",
            "oci_tenancy": "ocid1.tenancy.oc1..aaaaaaaa7kbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbk",
            "oci_key_file": "/path/to/oci_api_key.pem",
            "oci_compartment_id": "ocid1.compartment.oc1..aaaaaaaa7kbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbk",
            "drop_params": True,
        },
        "model_info": {"mode": "completion", "id": "test-model-id"},
    }

    router = MagicMock()
    router.get_model_names.return_value = ["oci-grok-test"]
    router.get_model_access_groups.return_value = {}
    router.get_model_list.return_value = [deployment]

    with (
        patch("litellm.proxy.proxy_server.llm_router", router),
        patch("litellm.proxy.proxy_server.llm_model_list", [deployment]),
        patch(
            "litellm.proxy.proxy_server.general_settings",
            {"infer_model_from_keys": False},
        ),
        patch("litellm.proxy.proxy_server.user_model", None),
    ):
        result = await model_info_v1(
            user_api_key_dict=caller, litellm_model_id=None
        )

    assert "data" in result
    assert len(result["data"]) == 1
    returned_params = result["data"][0]["litellm_params"]

    # Credential fields must carry the mask.
    for secret_field in ("oci_key", "oci_fingerprint", "oci_tenancy", "oci_key_file"):
        assert "****" in returned_params[secret_field], f"{secret_field} should be masked"

    # Routing fields must be returned untouched.
    assert (
        returned_params["model"] == "oci/xai.grok-4"
    ), "model field should not be masked"
    assert (
        returned_params["oci_region"] == "us-phoenix-1"
    ), "oci_region should not be masked"
    assert returned_params["drop_params"] is True, "drop_params should not be masked"
    # The model field was the original regression: it must never be masked.
    assert "****" not in returned_params["model"], "model field should never be masked"
    assert returned_params["model"].startswith(
        "oci/"
    ), "model should retain its full value"

    # No raw credential may survive anywhere in the serialised response.
    rendered = str(result)
    for raw_secret in (
        "ocid1.api_key.oc1..aaaaaaaa7kbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbk",
        "aa:bb:cc:dd:ee:ff:11:22:33:44:55:66:77:88:99:00",
        "ocid1.tenancy.oc1..aaaaaaaa7kbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbk",
        "/path/to/oci_api_key.pem",
    ):
        assert raw_secret not in rendered
def test_add_callback_from_db_to_in_memory_litellm_callbacks():
    """
    Check that a DB callback is routed to the matching callback-manager method.

    Success-only, failure-only and combined event types must each call their
    own registration method. A callback that is already present in
    ``existing_callbacks`` must not be registered again.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()
    manager = MagicMock()

    with patch("litellm.proxy.proxy_server.litellm") as litellm_stub:
        litellm_stub._known_custom_logger_compatible_callbacks = []
        litellm_stub.logging_callback_manager = manager

        # (callback, event types, manager method expected to receive it)
        routing_cases = (
            ("prometheus", ["success"], "add_litellm_success_callback"),
            ("langfuse", ["failure"], "add_litellm_failure_callback"),
            ("s3", ["success", "failure"], "add_litellm_callback"),
        )
        for callback_name, events, registrar in routing_cases:
            proxy_config._add_callback_from_db_to_in_memory_litellm_callbacks(
                callback=callback_name,
                event_types=events,
                existing_callbacks=[],
            )
            getattr(manager, registrar).assert_called_once_with(callback_name)
            manager.reset_mock()

        # An already-registered callback is skipped.
        proxy_config._add_callback_from_db_to_in_memory_litellm_callbacks(
            callback="prometheus",
            event_types=["success"],
            existing_callbacks=["prometheus"],
        )
        manager.add_litellm_success_callback.assert_not_called()
def test_should_load_db_object_with_supported_db_objects():
    """
    Check ``ProxyConfig._should_load_db_object`` against ``supported_db_objects``.

    Each scenario patches ``general_settings`` and lists the expected decision
    for a set of object types: an absent or non-list setting loads everything,
    while a list (even an empty one) restricts loading to its entries.
    """
    from unittest.mock import patch

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()
    every_object_type = [
        "models",
        "mcp",
        "guardrails",
        "vector_stores",
        "pass_through_endpoints",
        "prompts",
        "model_cost_map",
    ]
    scenarios = [
        # 1. Setting absent -> every object type is loaded.
        (
            {},
            {"models": True, "mcp": True, "guardrails": True, "vector_stores": True},
        ),
        # 2. Only MCP objects requested.
        (
            {"supported_db_objects": ["mcp"]},
            {
                "models": False,
                "mcp": True,
                "guardrails": False,
                "vector_stores": False,
                "prompts": False,
            },
        ),
        # 3. Several object types requested.
        (
            {"supported_db_objects": ["mcp", "guardrails", "vector_stores"]},
            {
                "models": False,
                "mcp": True,
                "guardrails": True,
                "vector_stores": True,
                "prompts": False,
            },
        ),
        # 4. Not a list -> treated as unset, so everything is loaded.
        (
            {"supported_db_objects": "invalid_type"},
            {"models": True, "mcp": True},
        ),
        # 5. Empty list -> nothing is loaded.
        (
            {"supported_db_objects": []},
            {"models": False, "mcp": False, "guardrails": False},
        ),
        # 6. Every known object type requested.
        (
            {"supported_db_objects": every_object_type},
            dict.fromkeys(every_object_type, True),
        ),
    ]

    for settings, expected in scenarios:
        with patch("litellm.proxy.proxy_server.general_settings", settings):
            for object_type, should_load in expected.items():
                assert (
                    proxy_config._should_load_db_object(object_type=object_type)
                    is should_load
                )
@pytest.mark.asyncio
async def test_tag_cache_update_called():
    """
    Test that update_cache updates tag cache when tags are provided.

    The module-level ``user_api_key_cache`` is swapped for a fresh DualCache with
    ``patch.object`` so the original is restored when the test ends. Assigning it
    with ``setattr`` would leave the replacement in place for every later test in
    the session.
    """
    from litellm.caching.caching import DualCache

    cache = DualCache()
    mock_tag_obj = {
        "tag_name": "test-tag",
        "spend": 10.0,
    }
    with (
        patch.object(litellm.proxy.proxy_server, "user_api_key_cache", cache),
        patch.object(
            cache, "async_get_cache", new=AsyncMock(return_value=mock_tag_obj)
        ) as mock_get_cache,
        patch.object(
            cache, "async_set_cache_pipeline", new=AsyncMock()
        ) as mock_set_cache,
    ):
        await litellm.proxy.proxy_server.update_cache(
            token=None,
            user_id=None,
            end_user_id=None,
            team_id=None,
            response_cost=5.0,
            parent_otel_span=None,
            tags=["test-tag"],
        )
        # Give any work update_cache may schedule in the background time to run
        # while the patches are still active.
        await asyncio.sleep(0.1)

        mock_get_cache.assert_awaited_once_with(key="tag:test-tag")
        mock_set_cache.assert_awaited_once()
        cache_list = mock_set_cache.call_args.kwargs["cache_list"]
        assert len(cache_list) == 1
        cache_key, cache_value = cache_list[0]
        assert cache_key == "tag:test-tag"
        # Existing spend (10.0) plus this request's cost (5.0).
        assert cache_value["spend"] == 15.0
@pytest.mark.asyncio
async def test_tag_cache_update_multiple_tags():
    """
    Test that multiple tags are updated in cache.

    As in the single-tag test, the module-level ``user_api_key_cache`` is replaced
    only for the duration of the test (via ``patch.object``) so the fresh
    DualCache does not leak into other tests.
    """
    from litellm.caching.caching import DualCache

    cache = DualCache()
    mock_tag1_obj = {"tag_name": "tag1", "spend": 10.0}
    mock_tag2_obj = {"tag_name": "tag2", "spend": 20.0}

    async def mock_get_cache_side_effect(key):
        if key == "tag:tag1":
            return mock_tag1_obj
        elif key == "tag:tag2":
            return mock_tag2_obj
        return None

    with (
        patch.object(litellm.proxy.proxy_server, "user_api_key_cache", cache),
        patch.object(
            cache,
            "async_get_cache",
            new=AsyncMock(side_effect=mock_get_cache_side_effect),
        ) as mock_get_cache,
        patch.object(
            cache, "async_set_cache_pipeline", new=AsyncMock()
        ) as mock_set_cache,
    ):
        await litellm.proxy.proxy_server.update_cache(
            token=None,
            user_id=None,
            end_user_id=None,
            team_id=None,
            response_cost=5.0,
            parent_otel_span=None,
            tags=["tag1", "tag2"],
        )
        # Give any work update_cache may schedule in the background time to run
        # while the patches are still active.
        await asyncio.sleep(0.1)

        assert mock_get_cache.call_count == 2
        mock_set_cache.assert_awaited_once()
        cache_list = mock_set_cache.call_args.kwargs["cache_list"]
        assert len(cache_list) == 2
        tag_updates = {
            cache_key: cache_value for cache_key, cache_value in cache_list
        }
        assert "tag:tag1" in tag_updates
        assert "tag:tag2" in tag_updates
        # Each tag's existing spend plus this request's cost (5.0).
        assert tag_updates["tag:tag1"]["spend"] == 15.0
        assert tag_updates["tag:tag2"]["spend"] == 25.0
@pytest.mark.asyncio
async def test_init_sso_settings_in_db():
    """
    _init_sso_settings_in_db should read the SSO row from the database,
    upper-case every setting name, and hand the result to
    _decrypt_and_set_db_env_variables.
    """
    from unittest.mock import AsyncMock, MagicMock, patch

    from litellm.proxy.proxy_server import ProxyConfig

    stored_settings = {
        "google_client_id": "test-client-id",
        "google_client_secret": "test-client-secret",
        "microsoft_client_id": "ms-client-id",
        "microsoft_client_secret": "ms-client-secret",
    }
    sso_row = MagicMock()
    # Hand over a copy so the expectations below cannot be affected by the code
    # under test.
    sso_row.sso_settings = dict(stored_settings)

    find_unique = AsyncMock(return_value=sso_row)
    prisma = MagicMock()
    prisma.db.litellm_ssoconfig.find_unique = find_unique

    config = ProxyConfig()
    with patch.object(config, "_decrypt_and_set_db_env_variables") as decrypt_and_set:
        await config._init_sso_settings_in_db(prisma_client=prisma)

    find_unique.assert_awaited_once_with(where={"id": "sso_config"})

    decrypt_and_set.assert_called_once()
    forwarded = decrypt_and_set.call_args.kwargs["environment_variables"]
    for raw_name, raw_value in stored_settings.items():
        upper_name = raw_name.upper()
        assert upper_name in forwarded
        assert forwarded[upper_name] == raw_value

    # The original lower-case names must not survive the conversion.
    assert "google_client_id" not in forwarded
    assert "microsoft_client_id" not in forwarded
@pytest.mark.asyncio
async def test_init_sso_settings_in_db_no_settings():
    """
    When the database holds no SSO row, _init_sso_settings_in_db must query it
    but skip _decrypt_and_set_db_env_variables entirely.
    """
    from unittest.mock import AsyncMock, MagicMock, patch

    from litellm.proxy.proxy_server import ProxyConfig

    find_unique = AsyncMock(return_value=None)
    prisma = MagicMock()
    prisma.db.litellm_ssoconfig.find_unique = find_unique

    config = ProxyConfig()
    with patch.object(config, "_decrypt_and_set_db_env_variables") as decrypt_and_set:
        await config._init_sso_settings_in_db(prisma_client=prisma)

    find_unique.assert_awaited_once_with(where={"id": "sso_config"})
    # Nothing to decrypt, so the env-var setter is never reached.
    decrypt_and_set.assert_not_called()
@pytest.mark.asyncio
async def test_init_sso_settings_in_db_error_handling():
    """
    Test that _init_sso_settings_in_db handles database errors gracefully.

    The DB lookup raises, and the method is expected to catch and log the error
    rather than propagate it. Any exception escaping the ``await`` already fails
    the test, so no try/except wrapper is needed. The test also asserts that the
    lookup was attempted, so it cannot pass vacuously if the method stops
    querying the database.
    """
    from unittest.mock import AsyncMock, MagicMock

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    # Mock prisma client to raise an exception
    mock_prisma_client = MagicMock()
    mock_prisma_client.db.litellm_ssoconfig.find_unique = AsyncMock(
        side_effect=Exception("Database connection error")
    )

    # Must not raise: the error should be caught and logged inside the method.
    await proxy_config._init_sso_settings_in_db(prisma_client=mock_prisma_client)

    mock_prisma_client.db.litellm_ssoconfig.find_unique.assert_awaited_once_with(
        where={"id": "sso_config"}
    )
@pytest.mark.asyncio
async def test_init_sso_settings_in_db_empty_settings():
    """
    An SSO row whose settings dict is empty should still be forwarded, as an
    empty dict, to _decrypt_and_set_db_env_variables.
    """
    from unittest.mock import AsyncMock, MagicMock, patch

    from litellm.proxy.proxy_server import ProxyConfig

    empty_row = MagicMock()
    empty_row.sso_settings = {}

    find_unique = AsyncMock(return_value=empty_row)
    prisma = MagicMock()
    prisma.db.litellm_ssoconfig.find_unique = find_unique

    config = ProxyConfig()
    with patch.object(config, "_decrypt_and_set_db_env_variables") as decrypt_and_set:
        await config._init_sso_settings_in_db(prisma_client=prisma)

    find_unique.assert_awaited_once_with(where={"id": "sso_config"})
    decrypt_and_set.assert_called_once()
    assert decrypt_and_set.call_args.kwargs["environment_variables"] == {}
def test_update_config_fields_uppercases_env_vars(monkeypatch):
    """
    Ensure environment variables pulled from DB are uppercased when applied so
    integrations like Datadog that expect uppercase env keys can read them.

    _update_config_fields writes into ``os.environ`` itself, outside
    monkeypatch's bookkeeping. Each key is therefore registered with monkeypatch
    first, so the DD_* values are removed (or restored) at teardown instead of
    leaking into later tests.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    for key in ["DD_API_KEY", "DD_SITE", "dd_api_key", "dd_site"]:
        # delenv(raising=False) records nothing when the variable is absent, so a
        # value set later by the code under test would outlive the test.
        # setenv first makes monkeypatch remember the original (possibly unset)
        # state, and delenv then clears it for the test body.
        monkeypatch.setenv(key, "")
        monkeypatch.delenv(key)

    proxy_config = ProxyConfig()
    updated_config = proxy_config._update_config_fields(
        current_config={},
        param_name="environment_variables",
        db_param_value={"dd_api_key": "test-api-key", "dd_site": "us5.datadoghq.com"},
    )

    env_vars = updated_config.get("environment_variables", {})
    assert env_vars["DD_API_KEY"] == "test-api-key"
    assert env_vars["DD_SITE"] == "us5.datadoghq.com"
    assert os.environ.get("DD_API_KEY") == "test-api-key"
    assert os.environ.get("DD_SITE") == "us5.datadoghq.com"
def test_encrypt_env_variables_for_db_is_idempotent(monkeypatch):
    """
    Regression guard: re-submitting already-encrypted values must never stack a
    second encryption layer.

    The Admin UI reads config back from /get/config/callbacks (which returns the
    stored ciphertext) and POSTs it again on the next save. Feeding
    _encrypt_env_variables_for_db its own output any number of times must always
    yield a value that decrypts to the original plaintext in exactly one step.
    The write path must also leave os.environ untouched; loading values into the
    process environment is the read path's responsibility.
    """
    from litellm.proxy.common_utils.encrypt_decrypt_utils import (
        decrypt_value_helper,
    )
    from litellm.proxy.proxy_server import ProxyConfig

    monkeypatch.setenv("LITELLM_SALT_KEY", "sk-test-salt-key")
    monkeypatch.delenv("LANGFUSE_PUBLIC_KEY", raising=False)

    env_key = "LANGFUSE_PUBLIC_KEY"
    secret = "pk-langfuse-secret-value"
    config = ProxyConfig()

    def decrypts_once_to_secret(payload):
        return decrypt_value_helper(value=payload[env_key], key=env_key) == secret

    # Initial save: plaintext in, single-layer ciphertext out.
    encrypted = config._encrypt_env_variables_for_db({env_key: secret})
    assert encrypted[env_key] != secret
    assert decrypts_once_to_secret(encrypted)

    # Three UI round-trips feeding ciphertext back in: still exactly one layer.
    for _ in range(3):
        encrypted = config._encrypt_env_variables_for_db(encrypted)
        assert decrypts_once_to_secret(encrypted)

    # Nothing may leak into the process environment from the write path.
    assert os.environ.get(env_key) is None
def test_get_prompt_spec_for_db_prompt_with_versions():
    """
    _get_prompt_spec_for_db_prompt should turn each stored prompt version into a
    PromptSpec whose id carries a ``.v<version>`` suffix.
    """
    from unittest.mock import MagicMock

    from litellm.proxy.proxy_server import ProxyConfig

    def make_db_row(row_id, version, litellm_params, timestamp):
        # Stand-in for a Prisma prompt row; only model_dump() is consumed.
        row = MagicMock()
        row.model_dump.return_value = {
            "id": row_id,
            "prompt_id": "chat_prompt",
            "version": version,
            "litellm_params": litellm_params,
            "prompt_info": '{"prompt_type": "db"}',
            "created_at": timestamp,
            "updated_at": timestamp,
        }
        return row

    first_version = make_db_row(
        "uuid-1",
        1,
        '{"prompt_id": "chat_prompt", "prompt_integration": "dotprompt", "model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "v1 content"}]}',
        "2024-01-01T00:00:00",
    )
    second_version = make_db_row(
        "uuid-2",
        2,
        '{"prompt_id": "chat_prompt", "prompt_integration": "dotprompt", "model": "gpt-4", "messages": [{"role": "user", "content": "v2 content"}]}',
        "2024-01-02T00:00:00",
    )

    config = ProxyConfig()
    expectations = (
        (first_version, "chat_prompt.v1"),
        (second_version, "chat_prompt.v2"),
    )
    for db_prompt, expected_id in expectations:
        spec = config._get_prompt_spec_for_db_prompt(db_prompt=db_prompt)
        assert spec.prompt_id == expected_id
def test_root_redirect_when_docs_url_not_root_and_redirect_url_set(monkeypatch):
    """
    Test that GET / redirects to ROOT_REDIRECT_URL when docs are served on a
    non-root path.

    The test rewires the "/" routes of the shared ``app``. The original route
    list is snapshotted and restored in a ``finally`` block, so the temporary
    redirect route (and the removed "/" routes) do not leak into other tests.
    """
    from fastapi.responses import RedirectResponse

    from litellm.proxy.proxy_server import cleanup_router_config_variables
    from litellm.proxy.utils import _get_docs_url

    cleanup_router_config_variables()
    filepath = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
    # Ensure docs are mounted on a non-root path to trigger redirect logic
    monkeypatch.setenv("DOCS_URL", "/docs")
    test_redirect_url = "/ui"
    monkeypatch.setenv("ROOT_REDIRECT_URL", test_redirect_url)
    asyncio.run(initialize(config=config_fp, debug=True))
    docs_url = _get_docs_url()
    root_redirect_url = os.getenv("ROOT_REDIRECT_URL")

    original_routes = list(app.routes)
    try:
        # Remove any existing "/" route that might interfere
        routes_to_remove = []
        for route in app.routes:
            if hasattr(route, "path") and route.path == "/":
                if hasattr(route, "methods") and "GET" in route.methods:
                    routes_to_remove.append(route)
                elif not hasattr(route, "methods"):  # Catch-all routes
                    routes_to_remove.append(route)
        for route in routes_to_remove:
            app.routes.remove(route)

        # Add the redirect route if conditions are met (matching the actual implementation)
        if docs_url != "/" and root_redirect_url:

            @app.get("/", include_in_schema=False)
            async def root_redirect():
                return RedirectResponse(url=root_redirect_url)

        client = TestClient(app)
        response = client.get("/", follow_redirects=False)
        assert response.status_code == 307
        assert response.headers["location"] == test_redirect_url
    finally:
        # app.routes is the router's live list; restore its contents in place.
        app.router.routes[:] = original_routes
@pytest.mark.asyncio
async def test_get_image_non_root_uses_var_lib_assets_dir(monkeypatch):
    """
    With LITELLM_NON_ROOT enabled, get_image must create its assets directory
    under /var/lib/litellm/assets.
    """
    from unittest.mock import patch

    from litellm.proxy.proxy_server import get_image

    assets_dir = "/var/lib/litellm/assets"
    monkeypatch.setenv("LITELLM_NON_ROOT", "true")
    monkeypatch.delenv("UI_LOGO_PATH", raising=False)

    # Environment as seen by get_image: no custom logo, non-root mode on.
    fake_env = {"UI_LOGO_PATH": "", "LITELLM_NON_ROOT": "true"}

    def fake_getenv(key, default=""):
        return fake_env.get(key, default)

    with (
        patch("litellm.proxy.proxy_server.os.makedirs") as makedirs_mock,
        # Only the assets directory is reported missing, so it has to be created.
        patch(
            "litellm.proxy.proxy_server.os.path.exists",
            side_effect=lambda path: path != assets_dir,
        ),
        patch("litellm.proxy.proxy_server.os.access", return_value=True),
        patch("litellm.proxy.proxy_server.os.getenv", side_effect=fake_getenv),
        patch("litellm.proxy.proxy_server.FileResponse"),
    ):
        await get_image()

    makedirs_mock.assert_called_once_with(assets_dir, exist_ok=True)
@pytest.mark.asyncio
async def test_get_image_non_root_fallback_to_default_logo(monkeypatch):
    """
    In non-root mode, when no logo exists under /var/lib/litellm/assets,
    get_image should probe that location and then serve the default logo.
    """
    from unittest.mock import patch

    from litellm.proxy.proxy_server import get_image

    assets_dir = "/var/lib/litellm/assets"
    monkeypatch.setenv("LITELLM_NON_ROOT", "true")
    monkeypatch.delenv("UI_LOGO_PATH", raising=False)

    probed_paths = []

    def fake_exists(path):
        # Record every probe. Anything under the assets dir is reported missing,
        # so the directory gets created, no cached file short-circuits the call,
        # and the default-logo fallback is taken.
        probed_paths.append(path)
        return assets_dir not in path

    fake_env = {"UI_LOGO_PATH": "", "LITELLM_NON_ROOT": "true"}

    with (
        patch("litellm.proxy.proxy_server.os.makedirs") as makedirs_mock,
        patch("litellm.proxy.proxy_server.os.path.exists", side_effect=fake_exists),
        patch("litellm.proxy.proxy_server.os.access", return_value=True),
        patch(
            "litellm.proxy.proxy_server.os.getenv",
            side_effect=lambda key, default="": fake_env.get(key, default),
        ),
        patch("litellm.proxy.proxy_server.FileResponse") as file_response_mock,
    ):
        await get_image()

    makedirs_mock.assert_called_once_with(assets_dir, exist_ok=True)

    expected_probe = f"{assets_dir}/logo.jpg"
    assert any(
        expected_probe in str(probe) for probe in probed_paths
    ), f"Should check if {expected_probe} exists"
    assert file_response_mock.called, "FileResponse should be called"
@pytest.mark.asyncio
async def test_get_image_root_case_uses_current_dir(monkeypatch):
    """
    Without LITELLM_NON_ROOT, get_image must not touch /var/lib/litellm/assets
    and should still serve a logo.
    """
    from unittest.mock import patch

    from litellm.proxy.proxy_server import get_image

    monkeypatch.delenv("LITELLM_NON_ROOT", raising=False)
    monkeypatch.delenv("UI_LOGO_PATH", raising=False)

    # No custom logo; non-root flag empty, i.e. running as root.
    fake_env = {"UI_LOGO_PATH": "", "LITELLM_NON_ROOT": ""}

    with (
        patch("litellm.proxy.proxy_server.os.makedirs") as makedirs_mock,
        patch("litellm.proxy.proxy_server.os.path.exists", return_value=True),
        patch(
            "litellm.proxy.proxy_server.os.getenv",
            side_effect=lambda key, default="": fake_env.get(key, default),
        ),
        patch("litellm.proxy.proxy_server.FileResponse") as file_response_mock,
    ):
        await get_image()

    created_under_var_lib = [
        recorded
        for recorded in makedirs_mock.call_args_list
        if "/var/lib/litellm/assets" in str(recorded)
    ]
    assert (
        not created_under_var_lib
    ), "Should not create /var/lib/litellm/assets for root case"
    assert file_response_mock.called, "FileResponse should be called"
@pytest.mark.asyncio
async def test_get_image_custom_local_logo_bypasses_cache(monkeypatch, tmp_path):
    """
    Test that when UI_LOGO_PATH is set to a local file, get_image serves it
    directly and does not return a stale cached_logo.jpg.

    Regression test: previously the cache check ran before reading UI_LOGO_PATH,
    so a pre-existing cached_logo.jpg (e.g. from the base Docker image) would
    always be returned, ignoring the user's custom logo.

    A stale cached_logo.jpg is planted in an isolated assets directory
    (LITELLM_ASSETS_PATH), so the regression is actually exercised rather than
    depending on whatever happens to exist in the installed package directory.
    """
    from litellm.proxy.proxy_server import get_image

    custom_logo = tmp_path / "custom_logo.jpg"
    custom_logo.write_bytes(b"\xff\xd8\xff custom logo")

    assets_dir = tmp_path / "assets"
    assets_dir.mkdir()
    stale_cache = assets_dir / "cached_logo.jpg"
    stale_cache.write_bytes(b"\xff\xd8\xff stale cached logo")

    monkeypatch.setenv("UI_LOGO_PATH", str(custom_logo))
    monkeypatch.delenv("LITELLM_NON_ROOT", raising=False)
    monkeypatch.setenv("LITELLM_ASSETS_PATH", str(assets_dir))

    calls_to_file_response = []

    def fake_file_response(path, **kwargs):
        calls_to_file_response.append(path)
        return MagicMock()

    with patch(
        "litellm.proxy.proxy_server.FileResponse", side_effect=fake_file_response
    ):
        await get_image()

    assert (
        len(calls_to_file_response) == 1
    ), "FileResponse should be called exactly once"
    served_path = calls_to_file_response[0]
    assert served_path != str(
        stale_cache.resolve()
    ), "A stale cached_logo.jpg was returned instead of the custom logo."
    assert served_path == str(custom_logo.resolve()), (
        f"Expected custom logo path, got {served_path}. "
        "A stale cached_logo.jpg may have been returned instead."
    )
@pytest.mark.asyncio
async def test_get_image_default_logo_ignores_stale_cache(monkeypatch, tmp_path):
    """
    With UI_LOGO_PATH unset, a leftover (pre-fix) cached_logo.jpg in the assets
    directory must be ignored in favour of the default logo.
    """
    from unittest.mock import patch

    from litellm.proxy.proxy_server import get_image

    stale_cache = tmp_path / "cached_logo.jpg"
    stale_cache.write_bytes(b"\xff\xd8\xff cached logo")

    monkeypatch.delenv("UI_LOGO_PATH", raising=False)
    monkeypatch.delenv("LITELLM_NON_ROOT", raising=False)
    monkeypatch.setenv("LITELLM_ASSETS_PATH", str(tmp_path))

    served = []

    def record_file_response(path, **kwargs):
        served.append(path)
        return MagicMock()

    with patch(
        "litellm.proxy.proxy_server.FileResponse", side_effect=record_file_response
    ):
        await get_image()

    assert len(served) == 1, "FileResponse should be called exactly once"
    (served_path,) = served
    assert served_path != str(stale_cache.resolve())
    assert served_path.endswith("logo.jpg")
@pytest.mark.asyncio
async def test_get_image_custom_logo_missing_falls_through_to_default(
    monkeypatch, tmp_path
):
    """
    Test that when UI_LOGO_PATH points to a non-existent local file,
    get_image falls through to the default logo instead of failing, even when
    a stale cached_logo.jpg is present in the assets directory.

    The stale cache file distinguishes this test from
    test_get_image_custom_logo_missing_no_cache_serves_default; without it the
    two tests exercised exactly the same scenario. A bare
    ``endswith("logo.jpg")`` check also matches ``cached_logo.jpg``, so the
    cache path is excluded explicitly.
    """
    from unittest.mock import patch

    from litellm.proxy.proxy_server import get_image

    custom_logo_path = tmp_path / "nonexistent_logo.jpg"
    stale_cache = tmp_path / "cached_logo.jpg"
    stale_cache.write_bytes(b"\xff\xd8\xff cached logo")

    monkeypatch.setenv("UI_LOGO_PATH", str(custom_logo_path))
    monkeypatch.delenv("LITELLM_NON_ROOT", raising=False)
    monkeypatch.setenv("LITELLM_ASSETS_PATH", str(tmp_path))

    calls_to_file_response = []

    def fake_file_response(path, **kwargs):
        calls_to_file_response.append(path)
        return MagicMock()

    with patch(
        "litellm.proxy.proxy_server.FileResponse", side_effect=fake_file_response
    ):
        await get_image()

    assert (
        len(calls_to_file_response) == 1
    ), "FileResponse should be called exactly once"
    served_path = calls_to_file_response[0]
    assert served_path != str(
        custom_logo_path
    ), "Should not attempt to serve a non-existent custom logo"
    assert served_path not in (
        str(stale_cache),
        str(stale_cache.resolve()),
    ), "Should not serve a stale cached_logo.jpg"
    assert served_path.endswith("logo.jpg")
@pytest.mark.asyncio
async def test_get_image_custom_logo_missing_no_cache_serves_default(
    monkeypatch, tmp_path
):
    """
    If UI_LOGO_PATH names a file that does not exist and no cached_logo.jpg is
    available, get_image should serve the default logo rather than the missing
    custom path.
    """
    from unittest.mock import patch

    from litellm.proxy.proxy_server import get_image

    missing_logo = tmp_path / "nonexistent_logo.jpg"
    monkeypatch.setenv("UI_LOGO_PATH", str(missing_logo))
    monkeypatch.delenv("LITELLM_NON_ROOT", raising=False)
    monkeypatch.setenv("LITELLM_ASSETS_PATH", str(tmp_path))

    served = []

    def record_file_response(path, **kwargs):
        served.append(path)
        return MagicMock()

    with patch(
        "litellm.proxy.proxy_server.FileResponse", side_effect=record_file_response
    ):
        await get_image()

    assert len(served) == 1, "FileResponse should be called exactly once"
    (served_path,) = served
    assert served_path != str(
        missing_logo
    ), "Should not attempt to serve a non-existent custom logo"
    assert served_path.endswith(
        "logo.jpg"
    ), f"Expected fallback to default logo.jpg, got {served_path}"
def test_get_config_normalizes_string_callbacks(monkeypatch):
    """
    /get/config/callbacks must treat a bare-string callback setting as if it
    were a one-element list.
    """
    from litellm.proxy.proxy_server import app, proxy_config, user_api_key_auth

    stored_config = {
        "litellm_settings": {
            "success_callback": "langfuse",
            "failure_callback": None,
            "callbacks": ["prometheus", "datadog"],
        },
        "general_settings": {},
        "environment_variables": {},
    }

    router_stub = MagicMock()
    router_stub.get_settings.return_value = {}
    monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", router_stub)
    monkeypatch.setattr(
        proxy_config, "get_config", AsyncMock(return_value=stored_config)
    )

    # Bypass auth for this request only; the previous overrides are put back below.
    saved_overrides = app.dependency_overrides.copy()
    app.dependency_overrides[user_api_key_auth] = lambda: MagicMock()
    client = TestClient(app)
    try:
        response = client.get("/get/config/callbacks")
    finally:
        app.dependency_overrides = saved_overrides

    assert response.status_code == 200
    entries = response.json()["callbacks"]

    def names_of(kind):
        return [entry["name"] for entry in entries if entry.get("type") == kind]

    assert "langfuse" in names_of("success")
    assert not names_of("failure")
    assert "prometheus" in names_of("success_and_failure")
    assert "datadog" in names_of("success_and_failure")
def test_deep_merge_dicts_skips_none_and_empty_lists(monkeypatch):
    """
    When DB-stored settings are merged into the file config,
    _update_config_fields should ignore ``None`` values and empty lists, add new
    keys, and merge nested dicts key by key.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    file_settings = {
        "max_parallel_requests": 10,
        "allowed_models": ["gpt-3.5-turbo", "gpt-4"],
        "nested": {
            "key1": "value1",
            "key2": "value2",
        },
    }
    db_settings = {
        "max_parallel_requests": None,
        "allowed_models": [],
        "new_key": "new_value",
        "nested": {
            "key1": "updated_value1",
            "key3": "value3",
        },
    }

    result = ProxyConfig()._update_config_fields(
        {"general_settings": file_settings}, "general_settings", db_settings
    )
    merged = result["general_settings"]
    nested = merged["nested"]

    # None / [] from the DB must not clobber file values.
    assert merged["max_parallel_requests"] == 10
    assert merged["allowed_models"] == ["gpt-3.5-turbo", "gpt-4"]
    # Brand-new keys are added.
    assert merged["new_key"] == "new_value"
    # Nested dicts are merged rather than replaced.
    assert nested["key1"] == "updated_value1"
    assert nested["key2"] == "value2"
    assert nested["key3"] == "value3"
class TestInvitationEndpoints:
    """Tests for /invitation/new and /invitation/delete endpoints."""

    @pytest.fixture
    def client_with_auth(self):
        """Yield a test client whose requests authenticate as a proxy admin.

        ``app.dependency_overrides`` lives on the shared app object. It is
        snapshotted before the admin override is installed and restored on
        teardown. This also undoes overrides installed by the tests themselves
        (e.g. the non-admin override), which would otherwise leak auth state
        into the rest of the suite.
        """
        from litellm.proxy._types import LitellmUserRoles
        from litellm.proxy.proxy_server import cleanup_router_config_variables

        cleanup_router_config_variables()
        filepath = os.path.dirname(os.path.abspath(__file__))
        config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
        asyncio.run(initialize(config=config_fp, debug=True))
        mock_auth = MagicMock()
        mock_auth.user_id = "admin-user-id"
        mock_auth.user_role = LitellmUserRoles.PROXY_ADMIN
        mock_auth.api_key = "sk-test"
        original_overrides = app.dependency_overrides.copy()
        app.dependency_overrides[user_api_key_auth] = lambda: mock_auth
        try:
            yield TestClient(app)
        finally:
            # Reset the same dict object rather than rebinding the attribute.
            app.dependency_overrides.clear()
            app.dependency_overrides.update(original_overrides)

    @pytest.mark.parametrize(
        "endpoint,payload,mock_return",
        [
            (
                "/invitation/new",
                {"user_id": "target-user-123"},
                {
                    "id": "inv-123",
                    "user_id": "target-user-123",
                    "is_accepted": False,
                    "accepted_at": None,
                    "expires_at": "2025-02-18T00:00:00",
                    "created_at": "2025-02-11T00:00:00",
                    "created_by": "admin-user-id",
                    "updated_at": "2025-02-11T00:00:00",
                    "updated_by": "admin-user-id",
                },
            ),
            (
                "/invitation/delete",
                {"invitation_id": "inv-456"},
                {
                    "id": "inv-456",
                    "user_id": "target-user-123",
                    "is_accepted": False,
                    "accepted_at": None,
                    "expires_at": "2025-02-18T00:00:00",
                    "created_at": "2025-02-11T00:00:00",
                    "created_by": "admin-user-id",
                    "updated_at": "2025-02-11T00:00:00",
                    "updated_by": "admin-user-id",
                },
            ),
        ],
    )
    def test_invitation_endpoints_proxy_admin_success(
        self, client_with_auth, endpoint, payload, mock_return
    ):
        """Proxy admin can successfully create and delete invitations."""
        with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
            mock_prisma.db.litellm_invitationlink = MagicMock()
            if endpoint == "/invitation/new":
                mock_create = AsyncMock(return_value=mock_return)
                with patch(
                    "litellm.proxy.management_helpers.user_invitation.create_invitation_for_user",
                    mock_create,
                ):
                    response = client_with_auth.post(endpoint, json=payload)
            else:
                mock_prisma.db.litellm_invitationlink.find_unique = AsyncMock(
                    return_value={**mock_return, "created_by": "admin-user-id"}
                )
                mock_prisma.db.litellm_invitationlink.delete = AsyncMock(
                    return_value=mock_return
                )
                response = client_with_auth.post(endpoint, json=payload)
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == mock_return["id"]
        assert data["user_id"] == mock_return["user_id"]

    @pytest.mark.parametrize(
        "endpoint,payload",
        [
            ("/invitation/new", {"user_id": "target-user-123"}),
            ("/invitation/delete", {"invitation_id": "inv-456"}),
        ],
    )
    def test_invitation_endpoints_non_admin_denied(
        self, client_with_auth, endpoint, payload
    ):
        """Non-admin users cannot access invitation endpoints."""
        from litellm.proxy._types import LitellmUserRoles

        mock_auth = MagicMock()
        mock_auth.user_id = "regular-user"
        mock_auth.user_role = LitellmUserRoles.INTERNAL_USER
        mock_auth.api_key = "sk-regular"
        # Replaces the fixture's admin override; the fixture's teardown restores
        # the pre-test overrides afterwards.
        app.dependency_overrides[user_api_key_auth] = lambda: mock_auth
        with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
            mock_prisma.db.litellm_invitationlink = MagicMock()
            # Avoid triggering async DB calls in _user_has_admin_privileges
            with patch(
                "litellm.proxy.proxy_server._user_has_admin_privileges",
                new_callable=AsyncMock,
                return_value=False,
            ):
                response = client_with_auth.post(endpoint, json=payload)
        assert response.status_code == 400
        body = response.json()
        # ProxyException handler returns {"error": {...}}, HTTPException returns {"detail": {...}}
        error_content = body.get("error", body.get("detail", body))
        assert "not allowed" in str(error_content).lower()
@pytest.mark.asyncio
async def test_async_data_generator_cleanup_on_early_exit():
    """
    Abandoning async_data_generator mid-stream, as ASGI does on client
    disconnect, must still release the upstream response through its
    ``aclose()`` (called from the generator's ``finally`` block).
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    request_data = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "test"}],
    }
    chunks = [
        {"choices": [{"delta": {"content": "Hello"}}]},
        {"choices": [{"delta": {"content": " world"}}]},
        {"choices": [{"delta": {"content": " more"}}]},
    ]

    async def fake_streaming_iterator(*args, **kwargs):
        for item in chunks:
            yield item

    logging_obj = MagicMock(spec=ProxyLogging)
    logging_obj.async_post_call_streaming_iterator_hook = fake_streaming_iterator
    # Pass each chunk through unchanged.
    logging_obj.async_post_call_streaming_hook = AsyncMock(
        side_effect=lambda **kwargs: kwargs.get("response")
    )
    logging_obj.post_call_failure_hook = AsyncMock()

    upstream = MagicMock()
    upstream.aclose = AsyncMock()

    with patch("litellm.proxy.proxy_server.proxy_logging_obj", logging_obj):
        stream = async_data_generator(
            upstream, MagicMock(spec=UserAPIKeyAuth), request_data
        )
        # Pull a single SSE frame, then drop the stream like a disconnecting client.
        first_frame = await stream.__anext__()
        assert first_frame.startswith("data: ")
        await stream.aclose()

        upstream.aclose.assert_awaited_once()
@pytest.mark.asyncio
async def test_async_data_generator_uses_direct_stream_fast_path_without_callbacks():
    """
    With every streaming-callback probe reporting False, async_data_generator
    must read the provider stream directly, never touching the per-chunk hooks,
    and still fire deferred logging and close the stream.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    user_key = MagicMock(spec=UserAPIKeyAuth)
    request_data = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "test"}],
    }
    chunks = [
        {"choices": [{"delta": {"content": piece}}]} for piece in ("Hello", " world")
    ]

    class FakeProviderStream:
        def __aiter__(self):
            return self._iterate()

        async def _iterate(self):
            for item in chunks:
                yield item

        async def aclose(self):
            pass

    upstream = FakeProviderStream()
    upstream.aclose = AsyncMock()

    logging_obj = MagicMock(spec=ProxyLogging)
    for probe in (
        "has_streaming_callbacks",
        "needs_iterator_wrap",
        "needs_per_chunk_streaming_hook",
    ):
        getattr(logging_obj, probe).return_value = False
    logging_obj.async_post_call_streaming_iterator_hook = MagicMock()
    logging_obj.async_post_call_streaming_hook = AsyncMock()
    logging_obj.post_call_failure_hook = AsyncMock()

    with patch("litellm.proxy.proxy_server.proxy_logging_obj", logging_obj):
        with patch.object(
            ProxyLogging, "_fire_deferred_stream_logging"
        ) as deferred_logging:
            received = [
                item
                async for item in async_data_generator(
                    upstream, user_key, request_data
                )
            ]

    as_text = [
        item.decode("utf-8") if isinstance(item, bytes) else item for item in received
    ]
    json_events = [item for item in as_text if item.startswith("data: {")]
    assert len(json_events) == 2
    assert as_text[-1] == "data: [DONE]\n\n"
    # The fast path bypasses both hook entry points entirely.
    logging_obj.async_post_call_streaming_iterator_hook.assert_not_called()
    logging_obj.async_post_call_streaming_hook.assert_not_awaited()
    deferred_logging.assert_called_once_with(request_data)
    upstream.aclose.assert_awaited_once()
@pytest.mark.asyncio
async def test_async_data_generator_passes_through_google_native_sse_bytes():
    """
    Google-native streamGenerateContent emits raw SSE bytes. They must be
    forwarded as-is (adding only a missing blank-line terminator), never
    re-wrapped as data: b'data: {...}'. Bare payload bytes still get a
    "data: " prefix.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    user_key = MagicMock(spec=UserAPIKeyAuth)
    request_data = {
        "model": "gemini-2.0-flash",
        "messages": [{"role": "user", "content": "test"}],
    }
    sse_event = b'data: {"candidates": [{"content": "hi"}]}\n\n'
    sse_event_missing_terminator = b'data: {"candidates": [{"content": "there"}]}'
    bare_json = b'{"partial": true}'

    class FakeProviderStream:
        def __aiter__(self):
            return self._iterate()

        async def _iterate(self):
            for payload in (sse_event, sse_event_missing_terminator, bare_json):
                yield payload

        async def aclose(self):
            pass

    upstream = FakeProviderStream()
    upstream.aclose = AsyncMock()

    logging_obj = MagicMock(spec=ProxyLogging)
    for probe in (
        "has_streaming_callbacks",
        "needs_iterator_wrap",
        "needs_per_chunk_streaming_hook",
    ):
        getattr(logging_obj, probe).return_value = False
    logging_obj.async_post_call_streaming_iterator_hook = MagicMock()
    logging_obj.async_post_call_streaming_hook = AsyncMock()
    logging_obj.post_call_failure_hook = AsyncMock()

    with patch("litellm.proxy.proxy_server.proxy_logging_obj", logging_obj):
        with patch.object(ProxyLogging, "_fire_deferred_stream_logging"):
            received = [
                item
                async for item in async_data_generator(
                    upstream, user_key, request_data
                )
            ]

    as_text = [
        item.decode("utf-8") if isinstance(item, bytes) else item for item in received
    ]
    assert as_text[0] == sse_event.decode("utf-8")
    assert as_text[1] == sse_event_missing_terminator.decode("utf-8") + "\n\n"
    assert as_text[2] == f'data: {bare_json.decode("utf-8")}\n\n'
    # No Python bytes repr may leak into the wire format.
    assert "b'data:" not in "".join(as_text)
    assert as_text[-1] == "data: [DONE]\n\n"
@pytest.mark.asyncio
async def test_async_data_generator_google_genai_stream_omits_openai_done():
    """
    When request data carries _litellm_skip_openai_stream_done (the
    google-genai streamGenerateContent?alt=sse case), the client must receive
    only the Gemini events, with no trailing data: [DONE].
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    user_key = MagicMock(spec=UserAPIKeyAuth)
    request_data = {
        "model": "gemini-2.0-flash",
        "_litellm_skip_openai_stream_done": True,
    }
    sse_event = b'data: {"candidates": [{"content": {"parts": [{"text": "Hi"}]}}]}\n\n'

    class FakeProviderStream:
        def __aiter__(self):
            return self._iterate()

        async def _iterate(self):
            yield sse_event

        async def aclose(self):
            pass

    upstream = FakeProviderStream()
    upstream.aclose = AsyncMock()

    logging_obj = MagicMock(spec=ProxyLogging)
    for probe in (
        "has_streaming_callbacks",
        "needs_iterator_wrap",
        "needs_per_chunk_streaming_hook",
    ):
        getattr(logging_obj, probe).return_value = False
    logging_obj.async_post_call_streaming_iterator_hook = MagicMock()
    logging_obj.async_post_call_streaming_hook = AsyncMock()
    logging_obj.post_call_failure_hook = AsyncMock()

    with patch("litellm.proxy.proxy_server.proxy_logging_obj", logging_obj):
        with patch.object(ProxyLogging, "_fire_deferred_stream_logging"):
            received = [
                item
                async for item in async_data_generator(
                    upstream, user_key, request_data
                )
            ]

    as_text = [
        item.decode("utf-8") if isinstance(item, bytes) else item for item in received
    ]
    assert as_text == [sse_event.decode("utf-8")]
    assert "[DONE]" not in "".join(as_text)
@pytest.mark.asyncio
async def test_async_data_generator_google_genai_stream_forwards_error_without_done():
    """Stream errors must still reach the client when OpenAI [DONE] is skipped.

    The provider stream emits one Gemini SSE event and then raises. That drives
    async_data_generator into its exception handler rather than the plain
    pass-through path. With ``_litellm_skip_openai_stream_done`` set, the
    handler must still yield an error event, must not append ``data: [DONE]``,
    and must still close the upstream response.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    gemini_event = (
        b'data: {"candidates": [{"content": {"parts": [{"text": "Hi"}]}}]}\n\n'
    )
    mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
    mock_request_data = {
        "model": "gemini-2.0-flash",
        "_litellm_skip_openai_stream_done": True,
    }

    class MockStream:
        def __aiter__(self):
            return self._stream()

        async def _stream(self):
            yield gemini_event
            # Fail mid-stream so the generator's own error handling produces
            # the error event (a pre-formatted error string would simply be
            # passed through and never exercise that path).
            raise RuntimeError("stream failed")

        async def aclose(self):
            pass

    mock_response = MockStream()
    mock_response.aclose = AsyncMock()

    mock_proxy_logging_obj = MagicMock(spec=ProxyLogging)
    mock_proxy_logging_obj.has_streaming_callbacks.return_value = False
    mock_proxy_logging_obj.needs_iterator_wrap.return_value = False
    mock_proxy_logging_obj.needs_per_chunk_streaming_hook.return_value = False
    mock_proxy_logging_obj.async_post_call_streaming_iterator_hook = MagicMock()
    mock_proxy_logging_obj.async_post_call_streaming_hook = AsyncMock()
    mock_proxy_logging_obj.post_call_failure_hook = AsyncMock()

    with patch("litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj):
        with patch.object(ProxyLogging, "_fire_deferred_stream_logging"):
            yielded_data = []
            async for data in async_data_generator(
                mock_response, mock_user_api_key_dict, mock_request_data
            ):
                yielded_data.append(data)

    yielded_text = [
        chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
        for chunk in yielded_data
    ]
    # The event produced before the failure is forwarded unchanged.
    assert yielded_text[0] == gemini_event.decode("utf-8")
    # Something after it must carry the error to the client.
    error_chunks = yielded_text[1:]
    assert error_chunks, "expected an error event after the upstream failure"
    assert any("error" in chunk for chunk in error_chunks)
    # ...but never the OpenAI terminator the google-genai SDK cannot parse.
    assert "[DONE]" not in "".join(yielded_text)
    # The finally block still releases the upstream connection.
    mock_response.aclose.assert_awaited_once()
@pytest.mark.asyncio
async def test_async_data_generator_cleanup_on_normal_completion():
    """
    A fully consumed async_data_generator must also close the upstream
    response from its finally block, not only an abandoned one.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    user_key = MagicMock(spec=UserAPIKeyAuth)
    request_data = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "test"}],
    }
    chunks = [{"choices": [{"delta": {"content": "Hello"}}]}]

    async def iterate_chunks(*args, **kwargs):
        for item in chunks:
            yield item

    logging_obj = MagicMock(spec=ProxyLogging)
    logging_obj.async_post_call_streaming_iterator_hook = iterate_chunks
    logging_obj.async_post_call_streaming_hook = AsyncMock(
        side_effect=lambda **kwargs: kwargs.get("response")
    )
    logging_obj.post_call_failure_hook = AsyncMock()

    upstream = MagicMock()
    upstream.aclose = AsyncMock()

    with patch("litellm.proxy.proxy_server.proxy_logging_obj", logging_obj):
        received = [
            item
            async for item in async_data_generator(upstream, user_key, request_data)
        ]

    # A normal run ends with the [DONE] sentinel...
    assert any("[DONE]" in item for item in received)
    # ...and the finally block still closes the upstream response.
    upstream.aclose.assert_awaited_once()
@pytest.mark.asyncio
async def test_async_data_generator_cleanup_on_midstream_error():
    """
    If the stream raises after emitting a chunk, async_data_generator must
    yield an error event and still close the upstream response from its
    finally block.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    user_key = MagicMock(spec=UserAPIKeyAuth)
    request_data = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "test"}],
    }

    async def iterate_then_fail(*args, **kwargs):
        yield {"choices": [{"delta": {"content": "Hello"}}]}
        raise RuntimeError("upstream connection reset")

    logging_obj = MagicMock(spec=ProxyLogging)
    logging_obj.async_post_call_streaming_iterator_hook = iterate_then_fail
    logging_obj.async_post_call_streaming_hook = AsyncMock(
        side_effect=lambda **kwargs: kwargs.get("response")
    )
    logging_obj.post_call_failure_hook = AsyncMock()

    upstream = MagicMock()
    upstream.aclose = AsyncMock()

    with patch("litellm.proxy.proxy_server.proxy_logging_obj", logging_obj):
        received = [
            item
            async for item in async_data_generator(upstream, user_key, request_data)
        ]

    # One data chunk followed by (at least) one error chunk.
    assert len(received) >= 2
    assert any("error" in item for item in received)
    # Cleanup happens despite the exception.
    upstream.aclose.assert_awaited_once()
# ============================================================================
# store_model_in_db DB Config Override Tests
# ============================================================================
def test_store_model_in_db_in_config_general_settings():
    """
    ConfigGeneralSettings must declare store_model_in_db. It must accept
    True, False and None unchanged, and default to None when omitted.
    """
    from litellm.proxy._types import ConfigGeneralSettings

    assert "store_model_in_db" in ConfigGeneralSettings.model_fields

    for explicit_value in (True, False, None):
        settings = ConfigGeneralSettings(store_model_in_db=explicit_value)
        assert settings.store_model_in_db is explicit_value

    # Field omitted entirely -> unset default.
    assert ConfigGeneralSettings().store_model_in_db is None
@pytest.mark.asyncio
async def test_update_general_settings_store_model_in_db_true():
    """
    A DB general_settings row with store_model_in_db=True must flip the
    module-level flag on and mirror it into general_settings.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.proxy.proxy_server import ProxyConfig

    config = ProxyConfig()
    with (
        patch("litellm.proxy.proxy_server.store_model_in_db", False),
        patch("litellm.proxy.proxy_server.general_settings", {}),
    ):
        await config._update_general_settings(
            db_general_settings={"store_model_in_db": True}
        )
        # Checked inside the patches: leaving them restores the module attributes.
        assert ps.store_model_in_db is True
        assert ps.general_settings["store_model_in_db"] is True
@pytest.mark.asyncio
async def test_update_general_settings_store_model_in_db_false():
    """
    A DB general_settings row with store_model_in_db=False must flip the
    module-level flag off and mirror it into general_settings.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.proxy.proxy_server import ProxyConfig

    config = ProxyConfig()
    with (
        patch("litellm.proxy.proxy_server.store_model_in_db", True),
        patch("litellm.proxy.proxy_server.general_settings", {}),
    ):
        await config._update_general_settings(
            db_general_settings={"store_model_in_db": False}
        )
        # Checked inside the patches: leaving them restores the module attributes.
        assert ps.store_model_in_db is False
        assert ps.general_settings["store_model_in_db"] is False
@pytest.mark.asyncio
async def test_update_general_settings_store_model_in_db_string_normalization():
    """
    String spellings of store_model_in_db coming from the DB ("true", "True",
    "false") must be normalized to real booleans.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.proxy.proxy_server import ProxyConfig

    config = ProxyConfig()
    # (DB string value, flag value before the update, expected flag afterwards)
    cases = (
        ("true", False, True),
        ("True", False, True),
        ("false", True, False),
    )
    for db_value, starting_flag, expected_flag in cases:
        with (
            patch("litellm.proxy.proxy_server.store_model_in_db", starting_flag),
            patch("litellm.proxy.proxy_server.general_settings", {}),
        ):
            await config._update_general_settings(
                db_general_settings={"store_model_in_db": db_value}
            )
            assert ps.store_model_in_db is expected_flag
@pytest.mark.asyncio
async def test_update_general_settings_store_model_in_db_none_keeps_current():
    """
    A DB value of None for store_model_in_db must leave the current flag
    untouched, whichever way it is currently set.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.proxy.proxy_server import ProxyConfig

    config = ProxyConfig()
    for current_flag in (True, False):
        with (
            patch("litellm.proxy.proxy_server.store_model_in_db", current_flag),
            patch("litellm.proxy.proxy_server.general_settings", {}),
        ):
            await config._update_general_settings(
                db_general_settings={"store_model_in_db": None}
            )
            assert ps.store_model_in_db is current_flag
@pytest.mark.asyncio
async def test_store_model_in_db_db_override_when_config_false():
    """
    Config leaves store_model_in_db False, but the DB general_settings row sets
    it True. initialize_scheduled_background_jobs must adopt the DB value and
    go on to load deployments and credentials.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.proxy.proxy_server import ProxyStartupEvent
    from litellm.proxy.utils import ProxyLogging

    general_settings_row = MagicMock()
    general_settings_row.param_value = {"store_model_in_db": True}
    prisma = MagicMock()
    prisma.db.litellm_config.find_first = AsyncMock(return_value=general_settings_row)

    logging_obj = MagicMock(spec=ProxyLogging)
    logging_obj.slack_alerting_instance = MagicMock()
    config_stub = AsyncMock()

    with (
        patch("litellm.proxy.proxy_server.proxy_config", config_stub),
        patch("litellm.proxy.proxy_server.store_model_in_db", False),
        patch("litellm.proxy.proxy_server.get_secret_bool", return_value=False),
    ):
        await ProxyStartupEvent.initialize_scheduled_background_jobs(
            general_settings={},
            prisma_client=prisma,
            proxy_budget_rescheduler_min_time=1,
            proxy_budget_rescheduler_max_time=2,
            proxy_batch_write_at=5,
            proxy_logging_obj=logging_obj,
        )
        # The DB value wins; checked before the patch restores the attribute.
        assert ps.store_model_in_db is True

    # With the flag now on, deployments and credentials are loaded once each.
    assert config_stub.add_deployment.call_count == 1
    assert config_stub.get_credentials.call_count == 1
@pytest.mark.asyncio
async def test_store_model_in_db_db_check_skipped_when_already_true(monkeypatch):
    """
    When store_model_in_db is already True, startup keeps it True and still
    schedules deployment loading.

    The early DB check uses litellm_config.find_first with
    param_name="general_settings", and add_deployment may also call find_first.
    So this test does not assert that find_first was skipped. It only checks
    the resulting flag and that add_deployment ran.
    """
    monkeypatch.delenv("STORE_MODEL_IN_DB", raising=False)
    import litellm.proxy.proxy_server as ps
    from litellm.proxy.proxy_server import ProxyStartupEvent
    from litellm.proxy.utils import ProxyLogging

    prisma = MagicMock()
    prisma.db.litellm_config.find_first = AsyncMock(return_value=None)

    logging_obj = MagicMock(spec=ProxyLogging)
    logging_obj.slack_alerting_instance = MagicMock()
    config_stub = AsyncMock()

    with (
        patch("litellm.proxy.proxy_server.proxy_config", config_stub),
        patch("litellm.proxy.proxy_server.store_model_in_db", True),
        patch("litellm.proxy.proxy_server.get_secret_bool", return_value=True),
    ):
        await ProxyStartupEvent.initialize_scheduled_background_jobs(
            general_settings={},
            prisma_client=prisma,
            proxy_budget_rescheduler_min_time=1,
            proxy_budget_rescheduler_max_time=2,
            proxy_batch_write_at=5,
            proxy_logging_obj=logging_obj,
        )
        assert ps.store_model_in_db is True

    assert config_stub.add_deployment.call_count == 1
@pytest.mark.asyncio
async def test_store_model_in_db_db_failure_graceful(monkeypatch):
    """
    If the early DB lookup raises, startup must not crash. store_model_in_db
    stays False and no deployments are loaded.
    """
    monkeypatch.delenv("STORE_MODEL_IN_DB", raising=False)
    import litellm.proxy.proxy_server as ps
    from litellm.proxy.proxy_server import ProxyStartupEvent
    from litellm.proxy.utils import ProxyLogging

    prisma = MagicMock()
    # Every config lookup fails as if the DB connection were down.
    prisma.db.litellm_config.find_first = AsyncMock(
        side_effect=Exception("DB connection error")
    )

    logging_obj = MagicMock(spec=ProxyLogging)
    logging_obj.slack_alerting_instance = MagicMock()
    config_stub = AsyncMock()

    with (
        patch("litellm.proxy.proxy_server.proxy_config", config_stub),
        patch("litellm.proxy.proxy_server.store_model_in_db", False),
        patch("litellm.proxy.proxy_server.get_secret_bool", return_value=False),
    ):
        # Reaching the assertions at all proves no exception escaped.
        await ProxyStartupEvent.initialize_scheduled_background_jobs(
            general_settings={},
            prisma_client=prisma,
            proxy_budget_rescheduler_min_time=1,
            proxy_budget_rescheduler_max_time=2,
            proxy_batch_write_at=5,
            proxy_logging_obj=logging_obj,
        )
        assert ps.store_model_in_db is False

    config_stub.add_deployment.assert_not_called()
# =====================================================================
# Spend counter tests (v2 — Redis-backed spend counters)
# =====================================================================
@pytest.mark.asyncio
async def test_get_current_spend_reads_redis_first():
    """get_current_spend must return the Redis value over a stale in-memory one."""
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    counter_cache = DualCache()
    # Stale local value...
    counter_cache.in_memory_cache.set_cache(key="spend:key:test", value=0.30)
    # ...versus the cross-pod value held in Redis.
    redis_stub = AsyncMock()
    redis_stub.async_get_cache = AsyncMock(return_value=0.90)
    counter_cache.redis_cache = redis_stub

    with patch.object(ps, "spend_counter_cache", counter_cache):
        spend = await get_current_spend(
            counter_key="spend:key:test",
            fallback_spend=0.0,
        )

    assert spend == 0.90
    redis_stub.async_get_cache.assert_called_once_with(key="spend:key:test")
@pytest.mark.asyncio
async def test_get_current_spend_fallback_to_in_memory():
    """Without a Redis cache, get_current_spend reads the in-memory counter."""
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    counter_cache = DualCache()  # redis_cache left unset
    counter_cache.in_memory_cache.set_cache(key="spend:key:test", value=0.50)

    with patch.object(ps, "spend_counter_cache", counter_cache):
        spend = await get_current_spend(
            counter_key="spend:key:test",
            fallback_spend=0.0,
        )

    assert spend == 0.50
@pytest.mark.asyncio
async def test_increment_spend_counters_initializes_and_increments():
    """The key counter seeds from the cached key's spend, then accumulates.

    The token passed in is already hashed. That matches production, where the
    auth flow hashes metadata["user_api_key"] before the cost callback sees it.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy._types import LiteLLM_VerificationTokenView, hash_token
    from litellm.proxy.proxy_server import increment_spend_counters

    hashed_token = hash_token("sk-test-token-for-counter")
    counter_key = f"spend:key:{hashed_token}"

    key_cache = DualCache()
    counter_cache = DualCache()
    # Cached key object carrying the spend previously recorded in the DB.
    key_cache.in_memory_cache.set_cache(
        key=hashed_token,
        value=LiteLLM_VerificationTokenView(
            token=hashed_token,
            spend=5.0,
            max_budget=10.0,
        ),
    )

    # First call seeds 5.0 and adds 0.50; second call just adds 0.25.
    steps = ((0.50, 5.50), (0.25, 5.75))
    with (
        patch.object(ps, "user_api_key_cache", key_cache),
        patch.object(ps, "spend_counter_cache", counter_cache),
    ):
        for cost, expected_total in steps:
            await increment_spend_counters(
                token=hashed_token,
                team_id=None,
                user_id=None,
                response_cost=cost,
            )
            assert counter_cache.in_memory_cache.get_cache(key=counter_key) == expected_total
@pytest.mark.asyncio
async def test_increment_spend_counters_team_and_member():
    """Team spend and per-member team spend are tracked in separate counters."""
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy._types import LiteLLM_TeamTable
    from litellm.proxy.proxy_server import increment_spend_counters

    key_cache = DualCache()
    counter_cache = DualCache()
    memory = key_cache.in_memory_cache
    memory.set_cache(
        key="team_id:team-1",
        value=LiteLLM_TeamTable(team_id="team-1", spend=2.0),
    )
    memory.set_cache(
        key="team_membership:user-1:team-1",
        value={"user_id": "user-1", "team_id": "team-1", "spend": 1.0},
    )

    with (
        patch.object(ps, "user_api_key_cache", key_cache),
        patch.object(ps, "spend_counter_cache", counter_cache),
    ):
        await increment_spend_counters(
            token=None,
            team_id="team-1",
            user_id="user-1",
            response_cost=0.30,
        )
        counters = counter_cache.in_memory_cache
        assert counters.get_cache(key="spend:team:team-1") == 2.30
        assert counters.get_cache(key="spend:team_member:user-1:team-1") == 1.30
@pytest.mark.asyncio
async def test_init_and_increment_spend_counter_reseeds_from_db_on_counter_miss():
    """On a Redis counter miss, the reseed reads spend from the DB rather than
    the (possibly stale) cache. The next increment then builds on the correct
    base."""
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _init_and_increment_spend_counter

    increment_calls: list = []

    async def capture_increment(key, value, ttl=None, **kwargs):
        increment_calls.append({"key": key, "value": value, "ttl": ttl})
        return value

    redis_stub = AsyncMock()
    redis_stub.async_increment = AsyncMock(side_effect=capture_increment)
    redis_stub.async_get_cache = AsyncMock(return_value=None)  # counter absent
    redis_stub.async_set_cache = AsyncMock(return_value=True)  # SET NX succeeds

    counter_cache = DualCache()
    counter_cache.redis_cache = redis_stub

    # DB says 42.0. The cache holds a stale 10.0 that is only used when
    # prisma is None, so the seed must be 42, not 10.
    team_row = MagicMock()
    team_row.spend = 42.0
    prisma = MagicMock()
    prisma.db.litellm_teamtable.find_unique = AsyncMock(return_value=team_row)

    stale_team = MagicMock()
    stale_team.spend = 10.0
    stale_cache = DualCache()
    stale_cache.in_memory_cache.set_cache(key="team_id:team-9", value=stale_team)

    with (
        patch.object(ps, "user_api_key_cache", stale_cache),
        patch.object(ps, "spend_counter_cache", counter_cache),
        patch.object(ps, "prisma_client", prisma),
    ):
        await _init_and_increment_spend_counter(
            counter_key="spend:team:team-9",
            source_cache_key="team_id:team-9",
            increment=1.5,
        )

        prisma.db.litellm_teamtable.find_unique.assert_awaited_once_with(
            where={"team_id": "team-9"}
        )
        # The seed goes through SET NX with the DB value, which is safe across
        # pods. Only the per-request delta goes through INCRBYFLOAT.
        redis_stub.async_set_cache.assert_awaited_once_with(
            key="spend:team:team-9", value=42.0, nx=True
        )
        assert [(call["key"], call["value"]) for call in increment_calls] == [
            ("spend:team:team-9", 1.5)
        ]
@pytest.mark.asyncio
async def test_primary_spend_counter_redis_concurrent_seed_does_not_double_seed():
    """Two pods both observing a missing Redis counter must not both
    INCRBYFLOAT the full DB spend. SpendCounterReseed.coalesced uses SET NX
    so the loser reads the winner's value; final Redis = db_spend, not
    2 * db_spend.
    The per-counter asyncio.Lock is per-process, so it does NOT coordinate
    across pods. We simulate two pods by patching _get_lock to return a
    fresh lock per call (each "pod" has its own lock registry in real life).
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.db.spend_counter_reseed import SpendCounterReseed

    counter_key = "spend:team:team-concurrent-seed"
    # Dict standing in for the single Redis instance both simulated pods share.
    redis_store: dict = {}
    # How many times the fake DB lookup ran, across both pods.
    db_read_count = 0
    # Outcome of every SET call in completion order (True = wrote, False = NX rejected).
    set_results: list = []
    # Redis GETs issued after some SET NX already succeeded (the loser's read-back).
    get_after_set_count = 0
    # Number of SET calls that actually stored a value.
    set_completed_count = 0

    async def redis_set_cache(key, value, nx=False, **_):
        # Yield BEFORE the membership check so two concurrent callers
        # interleave the way real atomic Redis SET NX does: the first
        # to resume runs check + write atomically and wins; the second
        # resumes after the key exists and loses. Yielding *after* the
        # check would let both callers pass the empty-store check before
        # either writes, so neither would ever lose.
        await asyncio.sleep(0)
        if nx and key in redis_store:
            set_results.append(False)
            return False
        redis_store[key] = float(value)
        set_results.append(True)
        # Declared here, before any use in this function, so it is legal.
        nonlocal set_completed_count
        set_completed_count += 1
        return True

    async def redis_get_cache(key):
        # Track reads that happen after at least one SET NX has completed
        # — those are the loser-path fallback reads we want to verify.
        if set_completed_count > 0:
            nonlocal get_after_set_count
            get_after_set_count += 1
        return redis_store.get(key)

    # A single Redis stub is shared by both pods, as in a real deployment.
    fake_redis = AsyncMock()
    fake_redis.async_get_cache = AsyncMock(side_effect=redis_get_cache)
    fake_redis.async_set_cache = AsyncMock(side_effect=redis_set_cache)

    async def slow_find_unique(**_):
        nonlocal db_read_count
        db_read_count += 1
        # Both pods read DB before either's SET NX lands.
        await asyncio.sleep(0)
        row = MagicMock()
        row.spend = 506.0
        return row

    fake_prisma = MagicMock()
    fake_prisma.db.litellm_teamtable.find_unique = AsyncMock(
        side_effect=slow_find_unique
    )

    # Two separate DualCache instances model two pods pointing at one Redis.
    pod_a = DualCache()
    pod_a.redis_cache = fake_redis
    pod_b = DualCache()
    pod_b.redis_cache = fake_redis

    # Each "pod" has its own per-process lock registry. Patch _get_lock to
    # always return a fresh lock so the two coalesced calls do not serialize
    # via one in-process lock (which is what would happen across pods).
    async def fresh_lock(_counter_key):
        return asyncio.Lock()

    with patch.object(SpendCounterReseed, "_get_lock", side_effect=fresh_lock):
        results = await asyncio.gather(
            SpendCounterReseed.coalesced(
                prisma_client=fake_prisma,
                spend_counter_cache=pod_a,
                counter_key=counter_key,
            ),
            SpendCounterReseed.coalesced(
                prisma_client=fake_prisma,
                spend_counter_cache=pod_b,
                counter_key=counter_key,
            ),
        )

    # Both pods report the DB spend, and Redis holds it once (not doubled).
    assert all(r == 506.0 for r in results), results
    assert redis_store[counter_key] == pytest.approx(506.0), redis_store
    # Both pods read the DB and both attempted SET NX; exactly one wrote
    # (winner) and one was rejected (loser).
    assert db_read_count == 2
    assert fake_redis.async_set_cache.await_count == 2
    nx_writes = [
        call
        for call in fake_redis.async_set_cache.await_args_list
        if call.kwargs.get("nx") is True
    ]
    assert len(nx_writes) == 2
    assert sorted(set_results) == [
        False,
        True,
    ], f"expected exactly one SET NX winner and one loser, got {set_results}"
    # Loser path executed: after the winner's SET NX returned True, the
    # losing coalesced() call falls back to async_get_cache to read the
    # winner's value rather than re-seeding.
    assert (
        get_after_set_count >= 1
    ), "loser branch (else: read back winner's value) was never exercised"
@pytest.mark.asyncio
async def test_reseed_spend_from_db_user_and_org_prefixes():
    """User and org counters are reseeded from their own DB tables.

    End-user and tag counters are reseeded from the auth objects the caller
    already fetched (handed over as ``fallback_spend``), so ``from_db`` must
    return ``None`` for them without issuing an extra per-request DB read.
    """
    from litellm.proxy.db.spend_counter_reseed import SpendCounterReseed

    def _spend_row(amount):
        # Minimal stand-in for a Prisma row exposing only `.spend`.
        row = MagicMock()
        row.spend = amount
        return row

    prisma = MagicMock()
    db = prisma.db
    db.litellm_usertable.find_unique = AsyncMock(return_value=_spend_row(17.0))
    db.litellm_endusertable.find_unique = AsyncMock()
    db.litellm_tagtable.find_unique = AsyncMock()
    db.litellm_organizationtable.find_unique = AsyncMock(
        return_value=_spend_row(305.0)
    )

    # User counter resolves against the user table by user_id.
    assert await SpendCounterReseed.from_db(prisma, "spend:user:alice") == 17.0
    db.litellm_usertable.find_unique.assert_awaited_once_with(
        where={"user_id": "alice"}
    )

    # End-user and tag counters must leave their tables untouched.
    skipped = (
        ("spend:end_user:customer-1", db.litellm_endusertable),
        ("spend:tag:paid-tag", db.litellm_tagtable),
    )
    for counter_key, table in skipped:
        assert await SpendCounterReseed.from_db(prisma, counter_key) is None
        table.find_unique.assert_not_awaited()

    # Org counter resolves against the organization table.
    assert await SpendCounterReseed.from_db(prisma, "spend:org:acme") == 305.0
    db.litellm_organizationtable.find_unique.assert_awaited_once_with(
        where={"organization_id": "acme"}
    )
@pytest.mark.asyncio
async def test_reseed_spend_from_db_skips_window_variant_keys():
    """Windowed counter keys must short-circuit before any DB lookup.

    ``spend:<entity>:<id>:window:<duration>`` keys share the primary
    counter's prefix but have no backing DB row, so ``from_db`` has to
    return ``None`` without awaiting any table query.
    """
    from litellm.proxy.db.spend_counter_reseed import SpendCounterReseed

    prisma = MagicMock()
    key_lookup = AsyncMock()
    team_lookup = AsyncMock()
    prisma.db.litellm_verificationtoken.find_unique = key_lookup
    prisma.db.litellm_teamtable.find_unique = team_lookup

    window_keys = (
        "spend:key:sk-abc:window:1h",
        "spend:team:team-1:window:1d",
    )
    for window_key in window_keys:
        assert await SpendCounterReseed.from_db(prisma, window_key) is None

    key_lookup.assert_not_awaited()
    team_lookup.assert_not_awaited()
@pytest.mark.asyncio
async def test_window_spend_counter_reseeds_from_spend_logs_on_counter_miss():
    """A cold window counter is seeded from summed SpendLogs, then incremented.

    With nothing cached, the helper aggregates the key's spend since the
    window start from ``litellm_spendlogs`` (2.25) and then applies the 0.5
    increment, leaving 2.75 in the in-memory layer.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _init_and_increment_window_spend_counter

    window_counter = "spend:key:key-window:window:1h"
    cache = DualCache()
    since = datetime.now(timezone.utc) - timedelta(hours=1)

    prisma = MagicMock()
    group_by = AsyncMock(
        return_value=[{"api_key": "key-window", "_sum": {"spend": 2.25}}]
    )
    prisma.db.litellm_spendlogs.group_by = group_by

    with patch.object(ps, "spend_counter_cache", cache), patch.object(
        ps, "prisma_client", prisma
    ):
        await _init_and_increment_window_spend_counter(
            counter_key=window_counter,
            entity_type="Key",
            entity_id="key-window",
            window_start=since,
            increment=0.5,
        )
        group_by.assert_awaited_once_with(
            by=["api_key"],
            where={"api_key": "key-window", "startTime": {"gte": since}},
            sum={"spend": True},
        )
        assert cache.in_memory_cache.get_cache(
            key=window_counter
        ) == pytest.approx(2.75)
@pytest.mark.asyncio
async def test_init_spend_counter_redis_clean_miss_skips_stale_in_memory():
    """A Redis miss must reseed from the DB instead of trusting local state.

    The in-memory layer holds a stale 10.0 while Redis has nothing for the
    key. The counter should be seeded from the team row (42.0) via SET NX
    and then incremented by 1.5, so both Redis and the in-memory layer end
    at 43.5.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _init_and_increment_spend_counter

    team_counter = "spend:team:team-stale-local"
    cache = DualCache()
    cache.in_memory_cache.set_cache(key=team_counter, value=10.0)  # stale copy

    store: dict = {}

    async def fake_incr(key, value, **_):
        store[key] = (store.get(key) or 0.0) + value
        return store[key]

    async def fake_set(key, value, nx=False, **_):
        # SET NX semantics: refuse to overwrite an existing key.
        if nx and key in store:
            return False
        store[key] = float(value)
        return True

    redis = AsyncMock()
    redis.async_get_cache = AsyncMock(return_value=None)  # clean Redis miss
    redis.async_increment = AsyncMock(side_effect=fake_incr)
    redis.async_set_cache = AsyncMock(side_effect=fake_set)
    cache.redis_cache = redis

    team_row = MagicMock()
    team_row.spend = 42.0
    prisma = MagicMock()
    prisma.db.litellm_teamtable.find_unique = AsyncMock(return_value=team_row)

    with patch.object(ps, "spend_counter_cache", cache), patch.object(
        ps, "prisma_client", prisma
    ), patch.object(ps, "user_api_key_cache", DualCache()):
        await _init_and_increment_spend_counter(
            counter_key=team_counter,
            source_cache_key="team_id:team-stale-local",
            increment=1.5,
        )
        prisma.db.litellm_teamtable.find_unique.assert_awaited_once_with(
            where={"team_id": "team-stale-local"}
        )
        # Seed via SET NX (42) + delta via INCRBYFLOAT (1.5) = 43.5.
        assert store[team_counter] == pytest.approx(43.5)
        assert cache.in_memory_cache.get_cache(
            key=team_counter
        ) == pytest.approx(43.5)
@pytest.mark.asyncio
async def test_window_spend_counter_redis_clean_miss_skips_stale_in_memory():
    """A Redis miss on a window counter reseeds from SpendLogs, not local state.

    The in-memory layer holds a stale 100.0 while Redis has no entry. The
    counter should be seeded from the SpendLogs aggregate (2.25) and then
    incremented by 0.5, so both layers end at 2.75.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _init_and_increment_window_spend_counter

    window_counter = "spend:key:key-window-stale-local:window:1h"
    cache = DualCache()
    cache.in_memory_cache.set_cache(key=window_counter, value=100.0)
    since = datetime.now(timezone.utc) - timedelta(hours=1)

    store: dict = {}

    async def fake_incr(key, value, **_):
        store[key] = (store.get(key) or 0.0) + value
        return store[key]

    async def fake_set(key, value, **_):
        # Never overwrites, regardless of the flags passed in.
        if key in store:
            return False
        store[key] = value
        return True

    redis = AsyncMock()
    redis.async_get_cache = AsyncMock(return_value=None)
    redis.async_set_cache = AsyncMock(side_effect=fake_set)
    redis.async_increment = AsyncMock(side_effect=fake_incr)
    cache.redis_cache = redis

    prisma = MagicMock()
    group_by = AsyncMock(
        return_value=[{"api_key": "key-window-stale-local", "_sum": {"spend": 2.25}}]
    )
    prisma.db.litellm_spendlogs.group_by = group_by

    with patch.object(ps, "spend_counter_cache", cache), patch.object(
        ps, "prisma_client", prisma
    ):
        await _init_and_increment_window_spend_counter(
            counter_key=window_counter,
            entity_type="Key",
            entity_id="key-window-stale-local",
            window_start=since,
            increment=0.5,
        )
        group_by.assert_awaited_once_with(
            by=["api_key"],
            where={
                "api_key": "key-window-stale-local",
                "startTime": {"gte": since},
            },
            sum={"spend": True},
        )
        assert store[window_counter] == pytest.approx(2.75)
        assert cache.in_memory_cache.get_cache(
            key=window_counter
        ) == pytest.approx(2.75)
@pytest.mark.asyncio
async def test_window_spend_counter_redis_concurrent_seed_does_not_double_seed():
    """Losing the SET NX seed race must not apply the seed a second time.

    The first two Redis reads report a miss, yet another writer has already
    stored 2.75, so this worker's SET NX of the SpendLogs value (2.25) is
    rejected. Only the 0.5 increment may land on top, giving 3.25.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _init_and_increment_window_spend_counter

    window_counter = "spend:key:key-window-concurrent-seed:window:1h"
    cache = DualCache()
    since = datetime.now(timezone.utc) - timedelta(hours=1)
    store = {window_counter: 2.75}  # value written by the competing writer
    read_log = []

    async def fake_get(key):
        read_log.append(key)
        # First two reads miss; afterwards the competing writer's value shows.
        return None if len(read_log) <= 2 else store.get(key)

    async def fake_incr(key, value, **_):
        store[key] = (store.get(key) or 0.0) + value
        return store[key]

    redis = AsyncMock()
    redis.async_get_cache = AsyncMock(side_effect=fake_get)
    redis.async_set_cache = AsyncMock(return_value=False)  # SET NX always loses
    redis.async_increment = AsyncMock(side_effect=fake_incr)
    cache.redis_cache = redis

    prisma = MagicMock()
    prisma.db.litellm_spendlogs.group_by = AsyncMock(
        return_value=[
            {"api_key": "key-window-concurrent-seed", "_sum": {"spend": 2.25}}
        ]
    )

    with patch.object(ps, "spend_counter_cache", cache), patch.object(
        ps, "prisma_client", prisma
    ):
        await _init_and_increment_window_spend_counter(
            counter_key=window_counter,
            entity_type="Key",
            entity_id="key-window-concurrent-seed",
            window_start=since,
            increment=0.5,
        )
        redis.async_set_cache.assert_awaited_once_with(
            key=window_counter,
            value=2.25,
            nx=True,
        )
        assert store[window_counter] == pytest.approx(3.25)
        assert cache.in_memory_cache.get_cache(
            key=window_counter
        ) == pytest.approx(3.25)
@pytest.mark.asyncio
async def test_window_spend_counter_skips_invalid_window_start():
    """With no window start, the helper must not write anything to the counter."""
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _init_and_increment_window_spend_counter

    bad_counter = "spend:key:key-invalid-window:window:not-a-duration"
    cache = DualCache()

    with patch.object(ps, "spend_counter_cache", cache):
        await _init_and_increment_window_spend_counter(
            counter_key=bad_counter,
            entity_type="Key",
            entity_id="key-invalid-window",
            window_start=None,
            increment=0.5,
        )
        assert cache.in_memory_cache.get_cache(key=bad_counter) is None
@pytest.mark.asyncio
async def test_window_spend_counter_does_not_seed_zero_when_db_unavailable():
    """Without a DB client, the window counter must stay uninitialized.

    The helper must report ``False`` and leave the cache empty rather than
    seeding a placeholder 0.0 for the window.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _ensure_window_spend_counter_initialized

    window_counter = "spend:key:key-window-db-unavailable:window:1h"
    cache = DualCache()

    with patch.object(ps, "spend_counter_cache", cache), patch.object(
        ps, "prisma_client", None
    ):
        ok = await _ensure_window_spend_counter_initialized(
            counter_key=window_counter,
            entity_type="Key",
            entity_id="key-window-db-unavailable",
            window_start=datetime.now(timezone.utc) - timedelta(hours=1),
        )
        assert ok is False
        assert cache.in_memory_cache.get_cache(key=window_counter) is None
@pytest.mark.asyncio
async def test_increment_spend_counters_finalizes_after_unreserved_increments():
    """The budget reservation is finalized only after unreserved increments.

    The reservation covers the key counter but not the team counter. The
    team counter therefore goes through ``_init_and_increment_spend_counter``,
    and that call must still see the reservation open. Afterwards the
    reservation is finalized, and the key counter moves from the 0.5 reserved
    to the 0.25 actually spent.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import increment_spend_counters

    key_counter = "spend:key:key-finalize-after-increments"
    cache = DualCache()
    cache.in_memory_cache.set_cache(key=key_counter, value=0.5)

    reservation = {
        "reserved_cost": 0.5,
        "entries": [
            {
                "counter_key": key_counter,
                "entity_type": "Key",
                "entity_id": "key-finalize-after-increments",
                "reserved_cost": 0.5,
                "applied_adjustment": 0.0,
            }
        ],
        "finalized": False,
    }
    seen_counters = []

    async def record_unreserved_increment(**kwargs):
        # Must run before finalization.
        assert reservation["finalized"] is False
        seen_counters.append(kwargs["counter_key"])

    with patch.object(ps, "spend_counter_cache", cache), patch.object(
        ps, "user_api_key_cache", DualCache()
    ):
        with patch(
            "litellm.proxy.proxy_server._init_and_increment_spend_counter",
            new=AsyncMock(side_effect=record_unreserved_increment),
        ):
            await increment_spend_counters(
                token="key-finalize-after-increments",
                team_id="team-finalize-after-increments",
                user_id=None,
                response_cost=0.25,
                budget_reservation=reservation,
            )
        assert seen_counters == ["spend:team:team-finalize-after-increments"]
        assert reservation["finalized"] is True
        assert cache.in_memory_cache.get_cache(
            key=key_counter
        ) == pytest.approx(0.25)
@pytest.mark.asyncio
async def test_increment_spend_counters_finalizes_none_cost_reservation():
    """A ``None`` response cost still finalizes and releases the reservation.

    The reservation pre-charged the key counter by 0.5. With no actual cost
    reported, the reservation must end up finalized and the counter back at
    0.0.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import increment_spend_counters

    key_counter = "spend:key:key-finalize-none-cost"
    cache = DualCache()
    cache.in_memory_cache.set_cache(key=key_counter, value=0.5)

    reservation = {
        "reserved_cost": 0.5,
        "entries": [
            {
                "counter_key": key_counter,
                "entity_type": "Key",
                "entity_id": "key-finalize-none-cost",
                "reserved_cost": 0.5,
                "applied_adjustment": 0.0,
            }
        ],
        "finalized": False,
    }

    with patch.object(ps, "spend_counter_cache", cache):
        await increment_spend_counters(
            token="key-finalize-none-cost",
            team_id=None,
            user_id=None,
            response_cost=None,
            budget_reservation=reservation,
        )
        assert reservation["finalized"] is True
        assert cache.in_memory_cache.get_cache(
            key=key_counter
        ) == pytest.approx(0.0)
@pytest.mark.asyncio
async def test_increment_spend_counters_invalidates_bad_reserved_counter_without_failing():
    """A reserved counter with no cached value is handled without raising.

    The reserved key counter was never written to the cache. Finalizing must
    log exactly one warning, still mark the reservation finalized, and leave
    no value behind for that counter.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import increment_spend_counters

    key_counter = "spend:key:key-bad-reserved-counter"
    cache = DualCache()  # intentionally empty: the reserved counter is missing

    reservation = {
        "reserved_cost": 0.5,
        "entries": [
            {
                "counter_key": key_counter,
                "entity_type": "Key",
                "entity_id": "key-bad-reserved-counter",
                "reserved_cost": 0.5,
                "applied_adjustment": 0.0,
            }
        ],
        "finalized": False,
    }

    with patch.object(ps, "spend_counter_cache", cache):
        with patch(
            "litellm.proxy.proxy_server.verbose_proxy_logger.warning"
        ) as warn:
            await increment_spend_counters(
                token="key-bad-reserved-counter",
                team_id=None,
                user_id=None,
                response_cost=0.25,
                budget_reservation=reservation,
            )
        warn.assert_called_once()
        assert reservation["finalized"] is True
        assert cache.in_memory_cache.get_cache(key=key_counter) is None
@pytest.mark.asyncio
async def test_increment_spend_counter_invalidates_stale_cache_on_redis_failure():
    """A failed Redis increment must clear the counter from both cache layers.

    The Redis error propagates to the caller. The local 4.0 is evicted and
    a delete is sent to Redis, so neither layer keeps serving a stale value.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _increment_spend_counter_cache

    team_counter = "spend:team:redis-fail"
    cache = DualCache()
    cache.in_memory_cache.set_cache(key=team_counter, value=4.0)

    redis = AsyncMock()
    redis.async_increment = AsyncMock(side_effect=RuntimeError("redis down"))
    redis.async_delete_cache = AsyncMock()
    cache.redis_cache = redis

    with patch.object(ps, "spend_counter_cache", cache):
        with pytest.raises(RuntimeError):
            await _increment_spend_counter_cache(
                counter_key=team_counter,
                increment=0.5,
            )
        assert cache.in_memory_cache.get_cache(key=team_counter) is None
        redis.async_delete_cache.assert_awaited_once_with(key=team_counter)
@pytest.mark.asyncio
async def test_get_current_spend_reseeds_from_db_when_counter_missing():
    """
    With both the Redis and in-memory counters empty, the enforcement read
    path must reseed from the authoritative DB rather than returning the
    caller's fallback. Trusting the fallback would let every Redis TTL
    expiry admit requests against a stale in-process
    `team_membership.spend`.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    member_counter = "spend:team_member:user-1:team-1"
    cache = DualCache()
    seed_writes = []

    async def capture_seed(key, value, nx=False, **kwargs):
        seed_writes.append((key, value, nx))
        return True

    redis = AsyncMock()
    redis.async_set_cache = AsyncMock(side_effect=capture_seed)
    redis.async_get_cache = AsyncMock(return_value=None)
    cache.redis_cache = redis

    # The DB is authoritative at 362.0. The caller's fallback (30.0) models
    # an in-process team_membership.spend that has not caught up yet.
    member_row = MagicMock()
    member_row.spend = 362.0
    prisma = MagicMock()
    prisma.db.litellm_teammembership.find_unique = AsyncMock(
        return_value=member_row
    )

    with patch.object(ps, "spend_counter_cache", cache), patch.object(
        ps, "prisma_client", prisma
    ):
        spend = await get_current_spend(
            counter_key=member_counter,
            fallback_spend=30.0,
        )
        assert spend == 362.0, (
            f"expected DB reseed to return 362.0, got {spend} "
            f"(fallback would have returned 30.0 and caused bypass)"
        )
        # The counter is warmed with SET NX so later reads skip the DB.
        assert (member_counter, 362.0, True) in seed_writes
        assert cache.in_memory_cache.get_cache(
            key=member_counter
        ) == pytest.approx(362.0)
@pytest.mark.asyncio
async def test_get_current_spend_uses_fallback_when_db_unavailable():
    """
    With both counters cold and no DB client, the read path returns the
    caller-supplied fallback instead of raising.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    cache = DualCache()
    redis = AsyncMock()
    redis.async_get_cache = AsyncMock(return_value=None)
    cache.redis_cache = redis

    # prisma_client=None simulates the DB being unavailable.
    with patch.object(ps, "spend_counter_cache", cache), patch.object(
        ps, "prisma_client", None
    ):
        spend = await get_current_spend(
            counter_key="spend:team_member:user-1:team-1",
            fallback_spend=15.5,
        )
        assert spend == 15.5
@pytest.mark.asyncio
async def test_get_current_spend_coalesces_concurrent_reseeds():
    """
    When several concurrent calls hit a cold counter on the same pod,
    only one DB query should fire. The rest should wait for the lock,
    re-check the warmed counter, and return without hitting the DB.
    Five `get_current_spend` calls are gathered against an empty fake Redis
    whose DB lookup sleeps 0.05s. All five must return 100.0, and exactly
    one DB query may be recorded.
    """
    import asyncio as _asyncio
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend
    counter_cache = DualCache()
    counter_key = "spend:team_member:user-1:team-coalesce"
    # Track DB query calls and inject a small delay so the concurrent
    # callers actually overlap in the lock-acquire window.
    db_call_count = 0
    async def slow_find_unique(**kwargs):
        nonlocal db_call_count
        db_call_count += 1
        await _asyncio.sleep(0.05)
        row = MagicMock()
        row.spend = 100.0
        return row
    # Dict-backed stand-in for Redis, shared by every caller in this test.
    fake_redis = AsyncMock()
    redis_store: dict = {}
    async def redis_get(key, **_):
        return redis_store.get(key)
    # Increment starts from 0.0 when the key is absent.
    async def redis_increment(key, value, **_):
        redis_store[key] = (redis_store.get(key) or 0.0) + value
        return redis_store[key]
    # With nx=True an existing key is never overwritten (returns False);
    # otherwise the value is stored as a float.
    async def redis_set_cache(key, value, nx=False, **_):
        if nx and key in redis_store:
            return False
        redis_store[key] = float(value)
        return True
    fake_redis.async_get_cache = AsyncMock(side_effect=redis_get)
    fake_redis.async_increment = AsyncMock(side_effect=redis_increment)
    fake_redis.async_set_cache = AsyncMock(side_effect=redis_set_cache)
    counter_cache.redis_cache = fake_redis
    fake_prisma = MagicMock()
    fake_prisma.db.litellm_teammembership.find_unique = AsyncMock(
        side_effect=slow_find_unique
    )
    # Swap the proxy module globals the code under test reads; restored in
    # `finally` below.
    import litellm.proxy.proxy_server as ps
    orig_counter, orig_prisma = ps.spend_counter_cache, ps.prisma_client
    ps.spend_counter_cache = counter_cache
    ps.prisma_client = fake_prisma
    try:
        results = await _asyncio.gather(
            *[
                get_current_spend(counter_key=counter_key, fallback_spend=0.0)
                for _ in range(5)
            ]
        )
        assert results == [100.0] * 5, f"all callers should see DB value, got {results}"
        assert (
            db_call_count == 1
        ), f"expected exactly 1 DB query for 5 concurrent reseeds, got {db_call_count}"
    finally:
        ps.spend_counter_cache = orig_counter
        ps.prisma_client = orig_prisma
@pytest.mark.asyncio
async def test_get_current_spend_uses_db_zero_over_stale_fallback():
    """
    A DB spend of 0 (for example right after a budget period reset) is
    authoritative and must beat a stale non-zero fallback. In production
    that fallback is the in-process team_membership.spend, which other pods
    may still hold at its pre-reset value.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    cache = DualCache()
    redis = AsyncMock()
    redis.async_get_cache = AsyncMock(return_value=None)
    cache.redis_cache = redis

    reset_row = MagicMock()
    reset_row.spend = 0.0
    prisma = MagicMock()
    prisma.db.litellm_teammembership.find_unique = AsyncMock(return_value=reset_row)

    with patch.object(ps, "spend_counter_cache", cache), patch.object(
        ps, "prisma_client", prisma
    ):
        spend = await get_current_spend(
            counter_key="spend:team_member:user-1:team-after-reset",
            fallback_spend=42.0,
        )
        assert (
            spend == 0.0
        ), f"DB authoritative 0 must override stale fallback 42, got {spend}"
@pytest.mark.asyncio
async def test_concurrent_read_and_write_paths_share_one_db_query():
    """
    The read path (`get_current_spend`) and the write path
    (`_init_and_increment_spend_counter`) both reseed cold counters from
    the DB. They must share the per-counter lock so a concurrent pre-call
    enforcement read and post-call increment for the same counter collapse
    to one DB query, not two.
    The test gathers read, write (+1.5), read against an empty fake Redis
    whose DB lookup sleeps 0.05s, then asserts a single DB query.
    """
    import asyncio as _asyncio
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import (
        _init_and_increment_spend_counter,
        get_current_spend,
    )
    counter_cache = DualCache()
    counter_key = "spend:team_member:user-1:team-cross-path"
    # Count DB lookups; the sleep keeps the gathered callers overlapping.
    db_call_count = 0
    async def slow_find_unique(**kwargs):
        nonlocal db_call_count
        db_call_count += 1
        await _asyncio.sleep(0.05)
        row = MagicMock()
        row.spend = 50.0
        return row
    # Dict-backed stand-in for Redis: get/increment/set with SET NX refusing
    # to overwrite an existing key.
    redis_store: dict = {}
    async def redis_get(key, **_):
        return redis_store.get(key)
    async def redis_increment(key, value, **_):
        redis_store[key] = (redis_store.get(key) or 0.0) + value
        return redis_store[key]
    async def redis_set_cache(key, value, nx=False, **_):
        if nx and key in redis_store:
            return False
        redis_store[key] = float(value)
        return True
    fake_redis = AsyncMock()
    fake_redis.async_get_cache = AsyncMock(side_effect=redis_get)
    fake_redis.async_increment = AsyncMock(side_effect=redis_increment)
    fake_redis.async_set_cache = AsyncMock(side_effect=redis_set_cache)
    counter_cache.redis_cache = fake_redis
    fake_prisma = MagicMock()
    fake_prisma.db.litellm_teammembership.find_unique = AsyncMock(
        side_effect=slow_find_unique
    )
    # Swap the proxy module globals read by both paths; a fresh
    # user_api_key_cache is also installed for the write path. Restored in
    # `finally`.
    import litellm.proxy.proxy_server as ps
    orig_counter, orig_prisma, orig_user = (
        ps.spend_counter_cache,
        ps.prisma_client,
        ps.user_api_key_cache,
    )
    ps.spend_counter_cache = counter_cache
    ps.prisma_client = fake_prisma
    ps.user_api_key_cache = DualCache()
    try:
        results = await _asyncio.gather(
            get_current_spend(counter_key=counter_key, fallback_spend=0.0),
            _init_and_increment_spend_counter(
                counter_key=counter_key,
                source_cache_key="ignored",
                increment=1.5,
            ),
            get_current_spend(counter_key=counter_key, fallback_spend=0.0),
        )
        assert (
            db_call_count == 1
        ), f"expected 1 DB query for concurrent read+write+read, got {db_call_count}"
        # Read-path callers see the warmed counter; the write path's
        # increment may or may not have landed by then, so accept either
        # the seeded value or seeded+increment.
        assert results[0] in (50.0, 51.5), f"got {results[0]}"
        assert results[2] in (50.0, 51.5), f"got {results[2]}"
    finally:
        ps.spend_counter_cache = orig_counter
        ps.prisma_client = orig_prisma
        ps.user_api_key_cache = orig_user
@pytest.mark.asyncio
async def test_reseed_locks_dict_is_bounded():
    """
    `SpendCounterReseed._locks` is an LRU bounded at
    `SPEND_COUNTER_RESEED_LOCKS_MAX_SIZE`, which prevents unbounded growth
    in long-lived deployments with high counter-key churn. Inserting more
    than the cap evicts the oldest entries.

    All shared global state is mutated only inside try/``patch.object``
    scopes, so it is restored even if setup or an assertion fails part-way.
    That state is the lock registry plus both bindings of the cap.
    """
    import litellm.constants as constants
    import litellm.proxy.db.spend_counter_reseed as scr
    from litellm.proxy.db.spend_counter_reseed import SpendCounterReseed

    cap = 5
    overflow = 2
    orig_locks = SpendCounterReseed._locks.copy()
    try:
        SpendCounterReseed._locks.clear()
        # The class reads the cap through a module-level import, so patch the
        # name bound in spend_counter_reseed as well as the source constant.
        with patch.object(
            constants, "SPEND_COUNTER_RESEED_LOCKS_MAX_SIZE", cap
        ), patch.object(scr, "SPEND_COUNTER_RESEED_LOCKS_MAX_SIZE", cap):
            for i in range(cap + overflow):
                await SpendCounterReseed._get_lock(f"spend:key:test-key-{i}")
            assert (
                len(SpendCounterReseed._locks) == cap
            ), f"got {len(SpendCounterReseed._locks)}"
            # The oldest `overflow` keys are evicted...
            for i in range(overflow):
                assert f"spend:key:test-key-{i}" not in SpendCounterReseed._locks
            # ...while the most recently inserted key is retained.
            newest = f"spend:key:test-key-{cap + overflow - 1}"
            assert newest in SpendCounterReseed._locks
    finally:
        SpendCounterReseed._locks.clear()
        SpendCounterReseed._locks.update(orig_locks)
@pytest.mark.asyncio
async def test_reseed_warms_cache_even_on_zero_db_spend():
    """
    A DB spend of 0.0 (fresh entity, or just after a reset) must still warm
    the cache, so the next read is served from it instead of querying the DB
    again. Skipping the warm-up would cost one DB query per request for
    zero-spend entities.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    member_counter = "spend:team_member:user-1:team-zero-warm"
    cache = DualCache()
    store: dict = {}

    async def fake_get(key, **_):
        return store.get(key)

    async def fake_incr(key, value, **_):
        store[key] = (store.get(key) or 0.0) + value
        return store[key]

    async def fake_set(key, value, nx=False, **_):
        # SET NX semantics: refuse to overwrite an existing key.
        if nx and key in store:
            return False
        store[key] = float(value)
        return True

    redis = AsyncMock()
    redis.async_get_cache = AsyncMock(side_effect=fake_get)
    redis.async_increment = AsyncMock(side_effect=fake_incr)
    redis.async_set_cache = AsyncMock(side_effect=fake_set)
    cache.redis_cache = redis

    db_hits = []

    async def zero_spend_row(**kwargs):
        db_hits.append(kwargs)
        row = MagicMock()
        row.spend = 0.0
        return row

    prisma = MagicMock()
    prisma.db.litellm_teammembership.find_unique = AsyncMock(
        side_effect=zero_spend_row
    )

    with patch.object(ps, "spend_counter_cache", cache), patch.object(
        ps, "prisma_client", prisma
    ):
        # Cold cache: the first read goes to the DB and sees 0.
        first = await get_current_spend(counter_key=member_counter, fallback_spend=0.0)
        # Warm cache: the second read must not issue another query.
        second = await get_current_spend(counter_key=member_counter, fallback_spend=0.0)
        assert first == 0.0 and second == 0.0
        db_call_count = len(db_hits)
        assert (
            db_call_count == 1
        ), f"second read should hit warmed cache, got {db_call_count} DB queries"
        assert store.get(member_counter) == 0.0, "cache must be warmed at 0"
# -----------------------------------------------------------------------------
# /config/update — critical paths only.
#
# These exercise the four behaviors that broke or changed in the rewrite of
# update_config (litellm/proxy/proxy_server.py): targeted per-section writes,
# the removal of the store_model_in_db gate, env var encryption, and the
# success_callback / litellm_settings merge semantics. All other branches
# (auth, missing-DB, slack auto-enable, router_settings merge) are covered
# implicitly or by upstream tests.
# -----------------------------------------------------------------------------
class _FakeRow:
    """Minimal stand-in for a ``LiteLLM_Config`` row returned by Prisma."""

    def __init__(self, param_name, param_value):
        self.param_name, self.param_value = param_name, param_value


class _FakeLitellmConfig:
    """In-memory fake of ``prisma_client.db.litellm_config``.

    Rows live in ``self.rows`` as ``param_name -> decoded value``.
    ``find_first`` and ``upsert`` are ``AsyncMock`` wrappers around the
    simulated storage, so tests can use both the storage and call
    assertions. Each upsert is also recorded in ``upsert_calls`` as
    ``(param_name, decoded_value)``.
    """

    def __init__(self, initial_rows=None):
        # Shallow copy so the caller's mapping is never mutated.
        self.rows = {} if not initial_rows else dict(initial_rows)
        self.upsert_calls: list = []
        self.find_first = AsyncMock(side_effect=self._find_first)
        self.upsert = AsyncMock(side_effect=self._upsert)

    async def _find_first(self, where=None):
        if not where or "param_name" not in where:
            return None
        name = where["param_name"]
        if name not in self.rows:
            return None
        return _FakeRow(name, self.rows[name])

    async def _upsert(self, where=None, data=None):
        # Only the "update" payload is consulted. A JSON string is decoded;
        # any other value is stored as-is.
        name = where["param_name"]
        payload = data["update"]["param_value"]
        decoded = payload if not isinstance(payload, str) else json.loads(payload)
        self.rows[name] = decoded
        self.upsert_calls.append((name, decoded))


class _FakePrismaClient:
    """Fake Prisma client exposing only ``db.litellm_config`` and ``jsonify_object``."""

    def __init__(self, initial_rows=None):
        db = mock.MagicMock()
        db.litellm_config = _FakeLitellmConfig(initial_rows=initial_rows)
        self.db = db

        def _identity(obj):
            return obj

        self.jsonify_object = _identity
@pytest.fixture
def _update_config_setup(monkeypatch):
    """Install fakes for the /config/update endpoint and return an installer.

    Calling ``_install(initial_rows=None, store_model_in_db=True)`` does the
    following and returns ``(client, prisma, restore)``:

    * swaps ``proxy_server.prisma_client`` for a ``_FakePrismaClient``;
    * sets ``store_model_in_db``;
    * replaces encryption with an ``"enc:"``-prefixing fake;
    * stubs out config-param invalidation and ``proxy_config.add_deployment``;
    * overrides the auth dependency with a PROXY_ADMIN user.

    ``restore`` resets ``app.dependency_overrides``. Tests may still call it
    explicitly, but the fixture also runs it at teardown. A failing or
    forgetful test therefore cannot leak the admin auth override into later
    tests. Module attributes patched via ``monkeypatch`` are reverted by
    pytest itself.
    """
    from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
    from litellm.proxy.auth.user_api_key_auth import user_api_key_auth as auth_dep

    # Snapshot once, before any install. Repeated installs in one test then
    # never capture an earlier install's auth override as the "original".
    original_overrides = app.dependency_overrides.copy()

    def _restore():
        # Assign a fresh copy so later mutations of the app's overrides
        # cannot corrupt the snapshot (restore stays idempotent).
        app.dependency_overrides = dict(original_overrides)

    def _install(initial_rows=None, store_model_in_db=True):
        prisma = _FakePrismaClient(initial_rows=initial_rows)
        monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma)
        monkeypatch.setattr(
            "litellm.proxy.proxy_server.store_model_in_db", store_model_in_db
        )
        monkeypatch.setattr(
            "litellm.proxy.proxy_server.encrypt_value_helper",
            lambda value, **_: f"enc:{value}",
        )
        monkeypatch.setattr(
            "litellm.proxy.proxy_server.invalidate_config_param",
            AsyncMock(return_value=None),
        )
        from litellm.proxy.proxy_server import proxy_config as real_proxy_config

        monkeypatch.setattr(
            real_proxy_config, "add_deployment", AsyncMock(return_value=None)
        )
        app.dependency_overrides[auth_dep] = lambda: UserAPIKeyAuth(
            user_id="test_admin",
            user_role=LitellmUserRoles.PROXY_ADMIN,
            api_key="sk-1234",
        )
        client = TestClient(app)
        return client, prisma, _restore

    yield _install
    # Teardown safety net: undo the auth override even if the test never
    # reached its own restore() call.
    _restore()
def test_update_config_writes_only_sent_section(_update_config_setup):
    """Posting only general_settings must upsert that section and nothing else.

    Rows that already existed (litellm_settings, environment_variables) must
    be left exactly as they were.
    """
    client, prisma, restore = _update_config_setup(
        initial_rows={
            "litellm_settings": {"drop_params": True},
            "environment_variables": {"FOO": "enc:bar"},
        }
    )
    try:
        payload = {"general_settings": {"store_prompts_in_spend_logs": True}}
        response = client.post("/config/update", json=payload)
        assert response.status_code == 200

        table = prisma.db.litellm_config
        upserted_sections = {name for name, _value in table.upsert_calls}
        assert upserted_sections == {"general_settings"}
        assert table.rows["litellm_settings"] == {"drop_params": True}
        assert table.rows["environment_variables"] == {"FOO": "enc:bar"}
    finally:
        restore()
def test_update_config_env_var_round_trip_not_double_encrypted(
    _update_config_setup, monkeypatch
):
    """Endpoint-level regression test for double encryption in /config/update.

    The Admin UI reads config back via /get/config/callbacks, which returns
    the stored, still-encrypted value, and re-POSTs it on the next save. The
    handler must not wrap that ciphertext in a second encryption layer, and
    keys absent from the request must stay byte-identical.

    The fake encrypt/decrypt pair (an "enc:" prefix) is invertible, so the
    decrypt-then-encrypt path round-trips faithfully. Before the fix the
    stored value became "enc:enc:...", which the assertions below reject.
    """

    def _fake_decrypt(
        value, key=None, exception_type="error", return_original_value=False
    ):
        # Inverse of the fixture's fake encrypt: strip one "enc:" layer.
        prefix = "enc:"
        if not (isinstance(value, str) and value.startswith(prefix)):
            return value if return_original_value else None
        return value[len(prefix) :]

    monkeypatch.setattr(
        "litellm.proxy.proxy_server.decrypt_value_helper", _fake_decrypt
    )
    client, prisma, restore = _update_config_setup(
        initial_rows={"environment_variables": {"PREEXISTING_KEY": "enc:keepme"}}
    )

    def _stored_env():
        return prisma.db.litellm_config.rows["environment_variables"]

    try:
        # First save: plaintext in, exactly one encryption layer at rest.
        first = client.post(
            "/config/update",
            json={"environment_variables": {"LANGFUSE_SECRET_KEY": "sk-secret"}},
        )
        assert first.status_code == 200
        ciphertext = _stored_env()["LANGFUSE_SECRET_KEY"]
        assert ciphertext == "enc:sk-secret"

        # UI round-trip: the stored ciphertext is posted back unchanged.
        second = client.post(
            "/config/update",
            json={"environment_variables": {"LANGFUSE_SECRET_KEY": ciphertext}},
        )
        assert second.status_code == 200
        env_after = _stored_env()
        # Before the fix this was "enc:enc:sk-secret"; it must stay single-layer.
        assert env_after["LANGFUSE_SECRET_KEY"] == "enc:sk-secret"
        assert (
            _fake_decrypt(env_after["LANGFUSE_SECRET_KEY"], return_original_value=True)
            == "sk-secret"
        )
        # Keys not present in the request are preserved verbatim.
        assert env_after["PREEXISTING_KEY"] == "enc:keepme"
    finally:
        restore()
def test_update_config_can_flip_store_model_in_db_when_currently_false(
    _update_config_setup,
):
    """Regression: while ``store_model_in_db`` was False the endpoint rejected
    every write, including the one request that would switch it back on."""
    client, prisma, restore = _update_config_setup(store_model_in_db=False)
    payload = {"general_settings": {"store_model_in_db": True}}
    try:
        response = client.post("/config/update", json=payload)
        assert response.status_code == 200

        general_settings_row = prisma.db.litellm_config.rows["general_settings"]
        assert general_settings_row["store_model_in_db"] is True
    finally:
        restore()
def test_update_config_environment_variables_encrypted_before_write(
    _update_config_setup,
):
    """Values under ``environment_variables`` must already be encrypted when
    they reach the DB row. The expected ``enc:`` prefix presumably comes from
    the encryption stub installed by ``_update_config_setup``."""
    client, prisma, restore = _update_config_setup()
    plaintext_env = {"OPENAI_API_KEY": "sk-secret"}
    try:
        response = client.post(
            "/config/update", json={"environment_variables": plaintext_env}
        )
        assert response.status_code == 200

        # The persisted row must hold the encrypted form, never the plaintext.
        persisted = prisma.db.litellm_config.rows["environment_variables"]
        assert persisted == {"OPENAI_API_KEY": "enc:sk-secret"}
    finally:
        restore()
def test_update_config_litellm_settings_request_wins_for_non_callback_keys(
    _update_config_setup,
):
    """For non-callback keys the request value overrides the stored one.

    The row starts with ``drop_params: True``. Posting ``drop_params: False``
    must persist ``False``, while keys the request omits (``set_verbose``)
    keep their stored value.
    """
    seeded_settings = {"drop_params": True, "set_verbose": True}
    client, prisma, restore = _update_config_setup(
        initial_rows={"litellm_settings": seeded_settings}
    )
    try:
        response = client.post(
            "/config/update", json={"litellm_settings": {"drop_params": False}}
        )
        assert response.status_code == 200

        settings_row = prisma.db.litellm_config.rows["litellm_settings"]
        assert settings_row["drop_params"] is False  # the request wins
        assert settings_row["set_verbose"] is True  # untouched key survives
    finally:
        restore()
def test_update_config_success_callback_normalizes_existing_mixed_case(
    _update_config_setup,
):
    """Stored mixed-case callback names must be lowercased before the union.

    Rows written by other code paths may hold names like ``"Langfuse"``.
    Without normalization, the union would not dedupe them against the
    incoming lowercase ``"langfuse"``, and delete_callback (which looks names
    up in lowercase) could not find the original entry.
    """
    client, prisma, restore = _update_config_setup(
        initial_rows={"litellm_settings": {"success_callback": ["Langfuse", "SQS"]}}
    )
    try:
        resp = client.post(
            "/config/update",
            json={"litellm_settings": {"success_callback": ["langfuse"]}},
        )
        assert resp.status_code == 200
        stored = prisma.db.litellm_config.rows["litellm_settings"]["success_callback"]
        # Compare a sorted list, not a set. A set comparison would also accept
        # a result that lowercased the names but kept a duplicate "langfuse"
        # entry, which is exactly the dedup miss this test guards against.
        # Element order is deliberately not asserted.
        assert sorted(stored) == ["langfuse", "sqs"]
    finally:
        restore()
# ---------------------------------------------------------------------------
# Lazy feature loading (LazyFeatureMiddleware) — verifies that optional
# routers are NOT imported at module load, ARE loaded and registered on the
# first request to a matching path, and are NOT registered again on
# subsequent requests.
# ---------------------------------------------------------------------------
class TestLazyFeatureRegistry:
    """Shape checks on the ``LAZY_FEATURES`` registry, guarding against bad edits."""

    def test_registry_entries_have_required_fields(self):
        from litellm.proxy._lazy_features import LAZY_FEATURES, LazyFeature

        assert len(LAZY_FEATURES) > 0
        for entry in LAZY_FEATURES:
            assert isinstance(entry, LazyFeature)
            # Each entry needs a name, a module path, at least one
            # absolute path prefix, and a callable registration hook.
            assert entry.name
            assert entry.module_path
            assert entry.path_prefixes
            assert all(prefix.startswith("/") for prefix in entry.path_prefixes)
            assert callable(entry.register_fn)

    def test_registry_names_unique(self):
        from litellm.proxy._lazy_features import LAZY_FEATURES

        feature_names = [entry.name for entry in LAZY_FEATURES]
        assert len(feature_names) == len(set(feature_names)), "duplicate feature names"

    def test_matches_covers_prefix_and_suffix(self):
        """``matches`` is shared by the middleware (incoming request paths)
        and the warm endpoint (registered route paths). A route reachable only
        through a suffix match, e.g. ``/v1/a2a/{id}/message/send`` for a
        feature whose prefix is ``/a2a``, must still be claimed."""
        from litellm.proxy._lazy_features import LazyFeature

        feature = LazyFeature(
            name="a2a",
            module_path="json",
            path_prefixes=("/a2a",),
            path_suffixes=("/message/send",),
        )

        claimed_paths = (
            "/a2a/abc/message/send",
            "/v1/a2a/abc/message/send",
            "/a2a/abc/.well-known/agent-card.json",
        )
        for path in claimed_paths:
            assert feature.matches(path)

        unclaimed_paths = ("/v1/a2a/discover", "/unrelated")
        for path in unclaimed_paths:
            assert not feature.matches(path)
class TestLazyFeaturesNotImportedAtStartup:
    """
    The whole point of the refactor: gated feature modules must NOT be
    imported at the top level of `proxy_server`, so they are absent from
    `sys.modules` right after it loads. Checked statically against the
    source file (see the comment in the test for why).
    """

    def test_heavy_modules_absent_at_startup(self):
        # Static scan of proxy_server.py source — catches any top-level
        # `from <lazy_module> import` / `import <lazy_module>` that would
        # defeat lazy loading. Importing proxy_server in a subprocess and
        # diffing sys.modules would also work, but takes 60-120 s and flakes
        # on slow CI runners.
        import re
        from pathlib import Path

        from litellm.proxy._lazy_features import LAZY_FEATURES

        # Python source is UTF-8; don't depend on the platform's locale
        # encoding (proxy_server.py contains non-ASCII characters).
        proxy_server_src = (
            Path(__file__).resolve().parents[3] / "litellm/proxy/proxy_server.py"
        ).read_text(encoding="utf-8")

        leaks = []
        for feat in LAZY_FEATURES:
            module = re.escape(feat.module_path)
            # Anchor at column 0 — indented imports inside function bodies
            # are fine (deferred until the function runs).
            # The `\b` on the bare `import` form stops a sibling module that
            # merely shares the prefix (e.g. `<module>_utils`) from being
            # reported, while `import <module>.sub` (which also imports
            # <module>) still matches.
            # NOTE(review): `from <module>.sub import ...` is not matched.
            pattern = rf"^(from\s+{module}\s+import|import\s+{module}\b)"
            if re.search(pattern, proxy_server_src, re.MULTILINE):
                leaks.append(feat.module_path)

        assert not leaks, (
            "proxy_server.py top-level imports a lazy feature module — these "
            f"should be loaded via LazyFeatureMiddleware: {leaks}"
        )
class TestLazyFeatureMiddleware:
    """Behavior of the middleware itself, exercised in isolation.

    Every test drives ``LazyFeatureMiddleware`` directly with hand-built ASGI
    scopes around a trivial downstream app. The shared ASGI plumbing lives in
    the private helpers below, so each test only spells out its feature and
    request paths.
    """

    @staticmethod
    async def _ok_downstream(scope, receive, send):
        """Wrapped-app stand-in: always responds 200 with an empty body."""
        await send({"type": "http.response.start", "status": 200, "headers": []})
        await send({"type": "http.response.body", "body": b""})

    @staticmethod
    async def _receive():
        """ASGI ``receive`` returning one complete, empty request body."""
        return {"type": "http.request", "body": b"", "more_body": False}

    @staticmethod
    async def _discard(message):
        """ASGI ``send`` that drops every message; responses are not asserted."""

    @staticmethod
    def _get_scope(path):
        """Minimal HTTP GET scope for ``path``."""
        return {"type": "http", "path": path, "method": "GET", "headers": []}

    def _middleware(self, feature):
        """Build a middleware that knows only ``feature``, wrapping
        ``_ok_downstream`` and given a fresh FastAPI app as ``fastapi_app``."""
        from fastapi import FastAPI

        from litellm.proxy._lazy_features import LazyFeatureMiddleware

        return LazyFeatureMiddleware(
            self._ok_downstream, fastapi_app=FastAPI(), features=(feature,)
        )

    async def _hit(self, mw, path):
        """Push one GET request for ``path`` through ``mw``."""
        await mw(self._get_scope(path), self._receive, self._discard)

    @pytest.mark.asyncio
    async def test_first_request_triggers_load_subsequent_does_not(self):
        from litellm.proxy._lazy_features import LazyFeature

        loads = []

        def fake_register(app, module):
            loads.append(getattr(module, "__name__", "?"))

        mw = self._middleware(
            LazyFeature(
                name="dummy",
                module_path="json",  # any always-importable stdlib module
                path_prefixes=("/dummy",),
                register_fn=fake_register,
            )
        )

        # First request matching the prefix triggers register
        await self._hit(mw, "/dummy/x")
        assert loads == ["json"]

        # Second matching request must NOT re-register
        await self._hit(mw, "/dummy/y")
        assert loads == ["json"], "register_fn called twice for the same feature"

        # Non-matching path must not trigger anything
        await self._hit(mw, "/unrelated")
        assert loads == ["json"]

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "server_root_path,request_path,should_load,case",
        [
            # SERVER_ROOT_PATH set: incoming path includes prefix → strip and match.
            ("/api/v1", "/api/v1/dummy/x", True, "root_path strip + match"),
            # Trailing-slash env var must be normalized.
            ("/api/v1/", "/api/v1/dummy/x", True, "trailing-slash env normalization"),
            # Reverse proxy already stripped the prefix → original path still matches.
            ("/api/v1", "/dummy/x", True, "pre-stripped path still loads"),
            # No SERVER_ROOT_PATH set → unchanged behavior.
            ("", "/dummy/x", True, "no root path"),
            # SERVER_ROOT_PATH=/ must be a no-op (not strip every leading slash).
            ("/", "/dummy/x", True, "root_path='/' is no-op"),
            # Boundary check: /apiv2 must not match root /api.
            ("/api", "/apiv2/foo", False, "boundary check prevents false match"),
            # Genuine non-match under root_path.
            ("/api/v1", "/api/v1/unrelated", False, "unrelated path under root"),
        ],
    )
    async def test_root_path_handling(
        self, monkeypatch, server_root_path, request_path, should_load, case
    ):
        """
        The middleware must strip SERVER_ROOT_PATH before prefix-matching so
        lazy features load under deployments that set a server root path,
        while handling boundary, trailing-slash, and reverse-proxy edge cases
        correctly.
        """
        from litellm.proxy._lazy_features import LazyFeature

        # Set before the middleware is built, matching the original ordering.
        monkeypatch.setenv("SERVER_ROOT_PATH", server_root_path)

        loads = []

        def fake_register(app, module):
            loads.append(getattr(module, "__name__", "?"))

        mw = self._middleware(
            LazyFeature(
                name=f"dummy_{case}",
                module_path="json",
                path_prefixes=("/dummy",),
                register_fn=fake_register,
            )
        )
        await self._hit(mw, request_path)

        if should_load:
            assert loads == ["json"], f"{case}: expected feature to load"
        else:
            assert loads == [], f"{case}: feature must not load"

    @pytest.mark.asyncio
    async def test_concurrent_first_requests_only_register_once(self):
        """
        Two requests to the same prefix arriving in parallel must result in
        exactly one `register_fn` invocation — the lock prevents the import +
        register from racing with itself.
        """
        from litellm.proxy._lazy_features import LazyFeature

        loads = []

        # NOTE(review): this register_fn neither blocks nor yields, so how much
        # the concurrent hits below really interleave depends on where the
        # middleware awaits before marking the feature loaded. Confirm the
        # test still exercises the lock if the middleware changes.
        def record_register(app, module):
            loads.append(getattr(module, "__name__", "?"))

        mw = self._middleware(
            LazyFeature(
                name="dummy_concurrent",
                module_path="json",
                path_prefixes=("/dummy_c",),
                register_fn=record_register,
            )
        )

        await asyncio.gather(*(self._hit(mw, "/dummy_c/x") for _ in range(5)))

        assert loads == [
            "json"
        ], f"expected one registration despite concurrent first hits, got {loads}"

    @pytest.mark.asyncio
    async def test_failing_import_does_not_loop(self):
        """
        If loading a feature fails, the middleware should still mark it loaded
        so subsequent requests don't retry the failing load every time (which
        would amplify its cost on every request). The failure is simulated by
        a register_fn that raises; ``module_path`` itself imports fine.
        """
        from litellm.proxy._lazy_features import LazyFeature

        attempts = []

        def fail_register(app, module):
            attempts.append("called")
            raise RuntimeError("boom")

        mw = self._middleware(
            LazyFeature(
                name="failing",
                module_path="json",
                path_prefixes=("/fail",),
                register_fn=fail_register,
            )
        )

        for _ in range(3):
            await self._hit(mw, "/fail/x")

        assert attempts == [
            "called"
        ], f"failing register_fn should be invoked once, not on every request; got {attempts}"
@pytest.mark.asyncio
async def test_get_current_spend_redis_clean_miss_skips_stale_in_memory(monkeypatch):
    """When Redis is reachable and cleanly returns None (TTL expired,
    counter genuinely absent), the read must reseed from DB - NOT fall
    through to per-pod in-memory which only contains this pod's writes.
    Pre-fix in multi-pod deployments, in-memory contained a stale local
    subset (e.g. $30) while DB had the true cross-pod total ($500). The
    fall-through returned $30, enforcement passed, bypass.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    counter_key = "spend:team_member:user-1:team-1"

    counter_cache = DualCache()
    # Per-pod stale in-memory: only this pod's writes, not cross-pod truth.
    counter_cache.in_memory_cache.set_cache(key=counter_key, value=30.0)
    # Redis cleanly returns None (key expired or never written on this pod).
    fake_redis = AsyncMock()
    fake_redis.async_get_cache = AsyncMock(return_value=None)
    fake_redis.async_increment = AsyncMock(return_value=500.0)
    counter_cache.redis_cache = fake_redis

    # DB has the authoritative cross-pod spend.
    db_row = MagicMock()
    db_row.spend = 500.0
    fake_prisma = MagicMock()
    fake_prisma.db.litellm_teammembership.find_unique = AsyncMock(return_value=db_row)

    # monkeypatch restores both module globals after the test, even on failure.
    monkeypatch.setattr(ps, "spend_counter_cache", counter_cache)
    monkeypatch.setattr(ps, "prisma_client", fake_prisma)

    spend = await get_current_spend(counter_key=counter_key, fallback_spend=0.0)
    assert spend == 500.0, (
        f"expected DB-authoritative 500.0 on clean Redis miss, got {spend} "
        f"(stale per-pod in-memory $30 would have caused multi-pod bypass)"
    )
    # async_increment is stubbed to return 500.0 as well, so the value alone
    # does not prove the DB was consulted; pin the reseed read explicitly.
    fake_prisma.db.litellm_teammembership.find_unique.assert_awaited()
@pytest.mark.asyncio
async def test_get_current_spend_redis_error_falls_back_to_in_memory(monkeypatch):
    """When Redis raises, the read should still degrade to in-memory rather
    than going straight to DB - in-memory is at least same-pod-fresh and
    cheaper than a DB query during a Redis outage."""
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    counter_key = "spend:team_member:user-1:team-1"

    counter_cache = DualCache()
    counter_cache.in_memory_cache.set_cache(key=counter_key, value=42.0)
    # Simulate a Redis outage on the read path.
    fake_redis = AsyncMock()
    fake_redis.async_get_cache = AsyncMock(side_effect=ConnectionError("redis down"))
    counter_cache.redis_cache = fake_redis

    # A DB value that differs from in-memory, so a DB hit would be visible.
    fake_prisma = MagicMock()
    fake_prisma.db.litellm_teammembership.find_unique = AsyncMock(
        return_value=MagicMock(spend=999.0)
    )

    # monkeypatch restores both module globals after the test, even on failure.
    monkeypatch.setattr(ps, "spend_counter_cache", counter_cache)
    monkeypatch.setattr(ps, "prisma_client", fake_prisma)

    spend = await get_current_spend(counter_key=counter_key, fallback_spend=0.0)
    assert spend == 42.0, (
        f"expected in-memory fallback 42.0 on Redis error, got {spend} "
        f"(should not have hit DB when Redis errored)"
    )
    # DB query should NOT have fired - in-memory short-circuits.
    fake_prisma.db.litellm_teammembership.find_unique.assert_not_awaited()
def test_realtime_websocket_route_aliases_registered():
    """Each realtime path alias must be wired up in three places.

    Realtime sessions reach the proxy via three path aliases stacked on
    `realtime_websocket_endpoint`. Every alias must be:

    * a registered WebSocket route. Otherwise upgrades silently 405, because
      the catch-all `/openai/{endpoint:path}` HTTP passthrough only declares
      HTTP methods;
    * in `LiteLLMRoutes.openai_routes`, so non-admin / team / key-scoped auth
      allows it;
    * in `API_ROUTE_TO_CALL_TYPES`, so call-type-aware logic such as
      guardrails can resolve the realtime call type.
    """
    from starlette.routing import WebSocketRoute

    from litellm.proxy._types import LiteLLMRoutes
    from litellm.proxy.proxy_server import app
    from litellm.types.utils import API_ROUTE_TO_CALL_TYPES, CallTypes

    registered_ws_paths = set()
    for route in app.routes:
        if isinstance(route, WebSocketRoute):
            registered_ws_paths.add(route.path)
    allowed_openai_routes = LiteLLMRoutes.openai_routes.value

    realtime_aliases = ("/openai/v1/realtime", "/v1/realtime", "/realtime")
    for alias in realtime_aliases:
        assert alias in registered_ws_paths, (
            f"{alias!r} missing from registered WebSocket routes; the "
            f"realtime endpoint will 405 for clients hitting this path."
        )
        assert alias in allowed_openai_routes, (
            f"{alias!r} missing from LiteLLMRoutes.openai_routes; "
            f"non-admin / team / key-scoped users will get 403 on this path."
        )
        assert API_ROUTE_TO_CALL_TYPES.get(alias) == [CallTypes.arealtime], (
            f"{alias!r} missing from API_ROUTE_TO_CALL_TYPES; call-type "
            f"resolution will return None and break call-type-aware features."
        )
class TestTransformRequestBannedParams:
    """
    /utils/transform_request applies the same banned-param check as LLM endpoints.
    Without this check, any authenticated user could supply aws_sts_endpoint,
    api_base, etc. and have the server forward its credentials to an
    attacker-controlled endpoint during SDK credential resolution.
    """

    @pytest.fixture
    def client(self):
        """TestClient authenticated as a non-admin internal user; the original
        dependency overrides are restored on teardown."""
        mock_auth = UserAPIKeyAuth(
            user_id="test-internal",
            user_role=LitellmUserRoles.INTERNAL_USER,
        )
        original = app.dependency_overrides.copy()
        app.dependency_overrides[user_api_key_auth] = lambda: mock_auth
        try:
            yield TestClient(app)
        finally:
            app.dependency_overrides = original

    @pytest.mark.parametrize(
        "banned",
        [
            "aws_sts_endpoint",
            "api_base",
            "aws_web_identity_token",
            "vertex_credentials",
        ],
    )
    def test_banned_params_rejected_for_all_users(self, client, banned):
        """Banned params must be blocked for any authenticated user
        (exercised here with an internal, non-admin user)."""
        response = client.post(
            "/utils/transform_request",
            json={
                "call_type": "completion",
                "request_body": {
                    "model": "gpt-3.5-turbo",
                    banned: "https://attacker.example",
                },
            },
        )
        # Use response.text in the failure message: response.json() would
        # raise on a non-JSON body (e.g. an HTML 500) and hide the real status.
        assert response.status_code == 400, (
            f"Expected 400 for banned param '{banned}', "
            f"got {response.status_code}: {response.text}"
        )
class TestSortModelsByDisplayName:
    """Regression: sorting must use the name the UI displays.

    Team BYOK rows persist an internal `model_name` like
    `model_name_{team_id}_{uuid}` and expose the user-facing name via
    `model_info.team_public_model_name`. Sorting on the displayed name lets
    BYOK rows interleave alphabetically with non-BYOK rows. Otherwise they
    clump at the end on their opaque IDs, even though the UI shows them under
    a normal-looking name.
    """

    @staticmethod
    def _display_names(models):
        """Names as displayed: a non-empty team_public_model_name, else model_name."""
        return [
            m["model_info"].get("team_public_model_name") or m["model_name"]
            for m in models
        ]

    def test_byok_models_sort_by_team_public_model_name(self):
        from litellm.proxy.proxy_server import _sort_models

        byok_row = {
            # Opaque internal name; the UI displays team_public_model_name.
            "model_name": "model_name_team-1_abc123",
            "model_info": {"team_public_model_name": "anthropic/claude"},
        }
        models = [
            {"model_name": "claude-haiku-4-5", "model_info": {}},
            byok_row,
            {"model_name": "gpt-4o", "model_info": {}},
        ]

        ordered = _sort_models(all_models=models, sort_by="model_name", sort_order="asc")

        assert self._display_names(ordered) == [
            "anthropic/claude",
            "claude-haiku-4-5",
            "gpt-4o",
        ]

    def test_byok_models_sort_descending_by_display_name(self):
        from litellm.proxy.proxy_server import _sort_models

        byok_row = {
            "model_name": "model_name_team-1_zzz",
            "model_info": {"team_public_model_name": "zeta/model"},
        }
        models = [
            {"model_name": "claude-haiku-4-5", "model_info": {}},
            byok_row,
            {"model_name": "gpt-4o", "model_info": {}},
        ]

        ordered = _sort_models(
            all_models=models, sort_by="model_name", sort_order="desc"
        )

        assert self._display_names(ordered) == [
            "zeta/model",
            "gpt-4o",
            "claude-haiku-4-5",
        ]

    def test_empty_team_public_model_name_falls_back_to_model_name(self):
        # An empty string (not None) for team_public_model_name must still
        # fall back to model_name; otherwise BYOK rows with a blank display
        # name would sort to the top.
        from litellm.proxy.proxy_server import _sort_models

        models = [
            {"model_name": "alpha", "model_info": {"team_public_model_name": ""}},
            {"model_name": "beta", "model_info": {}},
        ]

        ordered = _sort_models(all_models=models, sort_by="model_name", sort_order="asc")

        assert [row["model_name"] for row in ordered] == ["alpha", "beta"]