mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 14:48:44 +00:00
c908505e6a
* fix(proxy): omit OpenAI [DONE] on google-genai streamGenerateContent google-genai SDK uses ?alt=sse and cannot parse the proxy's trailing data: [DONE] chunk. Skip that terminator for agenerate_content_stream. Co-authored-by: Cursor <cursoragent@cursor.com> * fix(proxy): address Greptile review on google-genai stream fix Always yield stream error_message; only gate data: [DONE] on the skip flag. Set _litellm_skip_openai_stream_done in google_endpoints instead of common_request_processing. Co-authored-by: Cursor <cursoragent@cursor.com> --------- Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: mateo-berri <277851410+mateo-berri@users.noreply.github.com>
7799 lines
283 KiB
Python
7799 lines
283 KiB
Python
import asyncio
|
||
import importlib
|
||
import json
|
||
import os
|
||
import socket
|
||
import subprocess
|
||
import sys
|
||
from datetime import datetime, timedelta, timezone
|
||
from pathlib import Path
|
||
from unittest import mock
|
||
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
|
||
|
||
import click
|
||
import httpx
|
||
import pytest
|
||
import yaml
|
||
from fastapi import FastAPI
|
||
from fastapi.staticfiles import StaticFiles
|
||
from fastapi.testclient import TestClient
|
||
|
||
sys.path.insert(
|
||
0, os.path.abspath("../../..")
|
||
) # Adds the parent directory to the system-path
|
||
|
||
import litellm
|
||
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
|
||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||
from litellm.proxy.proxy_server import app, initialize
|
||
from litellm.utils import _invalidate_model_cost_lowercase_map
|
||
|
||
# Canned OpenAI-style embedding response returned by the patched router in
# ``mock_patch_aembedding``. The 12-value vector is one 4-float pattern repeated
# three times.
example_embedding_result = {
    "object": "list",
    "data": [
        {
            "object": "embedding",
            "index": 0,
            "embedding": [
                -0.006929283495992422,
                -0.005336422007530928,
                -4.547132266452536e-05,
                -0.024047505110502243,
            ]
            * 3,
        }
    ],
    "model": "text-embedding-3-small",
    "usage": {"prompt_tokens": 5, "total_tokens": 5},
}
|
||
|
||
|
||
def mock_patch_aembedding():
    """Return an unstarted ``mock.patch`` for the proxy router's ``aembedding``.

    The replacement mock is configured with
    ``return_value=example_embedding_result``. The caller is responsible for
    entering/starting the patcher.
    """
    target = "litellm.proxy.proxy_server.llm_router.aembedding"
    patcher = mock.patch(target, return_value=example_embedding_result)
    return patcher
|
||
|
||
|
||
@pytest.fixture(scope="function")
def client_no_auth():
    """Return a ``TestClient`` for the proxy app initialized with the no-auth config.

    ``cleanup_router_config_variables()`` runs first (presumably resetting
    router/config globals left over from earlier tests). The proxy is then
    initialized from ``test_configs/test_config_no_auth.yaml``, next to this file.
    """
    from litellm.proxy.proxy_server import cleanup_router_config_variables

    cleanup_router_config_variables()

    this_dir = os.path.dirname(os.path.abspath(__file__))
    no_auth_config = f"{this_dir}/test_configs/test_config_no_auth.yaml"
    # ``initialize`` sets module-level variables used by the shared FastAPI app.
    # When tests run it in parallel, one test can observe another's settings.
    asyncio.run(initialize(config=no_auth_config, debug=True))
    return TestClient(app)
|
||
|
||
|
||
def test_login_v2_returns_redirect_url_and_sets_cookie(monkeypatch):
    """Successful /v2/login returns the UI redirect URL and JWT, and sets a ``token`` cookie.

    Also verifies the collaborators are called with the expected arguments:
    authenticate_user, create_ui_token_object, and jwt.encode (HS256 with the master key).
    """
    login_result = {"user_id": "test-user"}
    prisma_stub = MagicMock()
    authenticate_stub = AsyncMock(return_value=login_result)
    token_object_stub = MagicMock(return_value={"user_id": "test-user"})
    encode_stub = MagicMock(return_value="signed-token")

    overrides = {
        "litellm.proxy.auth.login_utils.authenticate_user": authenticate_stub,
        "litellm.proxy.auth.login_utils.create_ui_token_object": token_object_stub,
        "jwt.encode": encode_stub,
        "litellm.proxy.proxy_server.master_key": "test-master-key",
        "litellm.proxy.proxy_server.general_settings": {},
        "litellm.proxy.proxy_server.premium_user": False,
        "litellm.proxy.proxy_server.prisma_client": prisma_stub,
        "litellm.proxy.utils.get_server_root_path": lambda: "",
        "litellm.proxy.utils.get_proxy_base_url": lambda: None,
    }
    for target, replacement in overrides.items():
        monkeypatch.setattr(target, replacement)

    response = TestClient(app).post(
        "/v2/login", json={"username": "alice", "password": "secret"}
    )

    assert response.status_code == 200
    assert response.json() == {
        "redirect_url": "http://testserver/ui/?login=success",
        "token": "signed-token",
    }
    assert response.cookies.get("token") == "signed-token"

    authenticate_stub.assert_awaited_once_with(
        username="alice",
        password="secret",
        master_key="test-master-key",
        prisma_client=prisma_stub,
    )
    token_object_stub.assert_called_once_with(
        login_result=login_result, general_settings={}, premium_user=False
    )
    encode_stub.assert_called_once_with(
        {"user_id": "test-user"}, "test-master-key", algorithm="HS256"
    )
|
||
|
||
|
||
def test_login_v2_returns_json_on_proxy_exception(monkeypatch):
    """/v2/login turns a ProxyException from authentication into a JSON error with its status code."""
    from litellm.proxy._types import ProxyErrorTypes, ProxyException

    rejection = ProxyException(
        message="Invalid credentials",
        type=ProxyErrorTypes.auth_error,
        param="password",
        code=401,
    )
    prisma_stub = MagicMock()
    for target, replacement in (
        (
            "litellm.proxy.auth.login_utils.authenticate_user",
            AsyncMock(side_effect=rejection),
        ),
        ("litellm.proxy.proxy_server.master_key", "test-master-key"),
        ("litellm.proxy.proxy_server.prisma_client", prisma_stub),
    ):
        monkeypatch.setattr(target, replacement)

    response = TestClient(app).post(
        "/v2/login", json={"username": "alice", "password": "wrong"}
    )

    assert response.status_code == 401
    assert response.headers["content-type"] == "application/json"
    body = response.json()
    assert "error" in body
    error = body["error"]
    assert error["message"] == "Invalid credentials"
    assert error["type"] == "auth_error"
|
||
|
||
|
||
def test_login_v2_returns_json_on_http_exception(monkeypatch):
    """An HTTPException raised during authentication is re-shaped by /v2/login into a JSON error object."""
    from fastapi import HTTPException

    prisma_stub = MagicMock()
    failing_auth = AsyncMock(
        side_effect=HTTPException(status_code=401, detail="Unauthorized")
    )
    for target, replacement in (
        ("litellm.proxy.auth.login_utils.authenticate_user", failing_auth),
        ("litellm.proxy.proxy_server.master_key", "test-master-key"),
        ("litellm.proxy.proxy_server.prisma_client", prisma_stub),
    ):
        monkeypatch.setattr(target, replacement)

    response = TestClient(app).post(
        "/v2/login", json={"username": "alice", "password": "secret"}
    )

    assert response.status_code == 401
    assert response.headers["content-type"] == "application/json"
    body = response.json()
    assert "error" in body
    assert isinstance(body["error"], dict)
|
||
|
||
|
||
def test_login_v2_returns_json_on_unexpected_exception(monkeypatch):
    """An arbitrary exception during /v2/login becomes a JSON 500 that carries the original message."""
    prisma_stub = MagicMock()
    exploding_auth = AsyncMock(side_effect=ValueError("Unexpected error"))
    for target, replacement in (
        ("litellm.proxy.auth.login_utils.authenticate_user", exploding_auth),
        ("litellm.proxy.proxy_server.master_key", "test-master-key"),
        ("litellm.proxy.proxy_server.prisma_client", prisma_stub),
    ):
        monkeypatch.setattr(target, replacement)

    response = TestClient(app).post(
        "/v2/login", json={"username": "alice", "password": "secret"}
    )

    assert response.status_code == 500
    assert response.headers["content-type"] == "application/json"
    body = response.json()
    assert "error" in body
    error = body["error"]
    assert isinstance(error, dict)
    assert "Unexpected error" in error["message"]
|
||
|
||
|
||
def test_login_v2_returns_json_on_invalid_json_body(monkeypatch):
    """A request body that is not parseable JSON still gets a JSON-formatted 500 from /v2/login."""
    monkeypatch.setattr("litellm.proxy.proxy_server.master_key", "test-master-key")

    response = TestClient(app).post(
        "/v2/login",
        content="invalid json",
        headers={"Content-Type": "application/json"},
    )

    assert response.status_code == 500
    assert response.headers["content-type"] == "application/json"
    body = response.json()
    assert "error" in body
    assert isinstance(body["error"], dict)
|
||
|
||
|
||
def test_login_v3_rejected_without_control_plane_url(monkeypatch):
    """/v3/login answers 404 (mentioning control_plane_url) when that setting is absent."""
    prisma_stub = MagicMock()
    for target, replacement in (
        ("litellm.proxy.proxy_server.master_key", "test-master-key"),
        ("litellm.proxy.proxy_server.general_settings", {}),
        ("litellm.proxy.proxy_server.prisma_client", prisma_stub),
    ):
        monkeypatch.setattr(target, replacement)

    response = TestClient(app).post(
        "/v3/login", json={"username": "alice", "password": "secret"}
    )

    assert response.status_code == 404
    assert "control_plane_url" in response.json()["error"]["message"]
|
||
|
||
|
||
def test_login_v3_returns_code(monkeypatch):
    """/v3/login responds with an opaque ``code`` (expires_in == 60) and never the JWT itself."""
    prisma_stub = MagicMock()
    config_stub = MagicMock()
    config_stub.worker_registry = []

    overrides = [
        (
            "litellm.proxy.auth.login_utils.authenticate_user",
            AsyncMock(return_value={"user_id": "test-user"}),
        ),
        (
            "litellm.proxy.auth.login_utils.create_ui_token_object",
            MagicMock(return_value={"user_id": "test-user"}),
        ),
        ("jwt.encode", MagicMock(return_value="signed-token")),
        ("litellm.proxy.proxy_server.master_key", "test-master-key"),
        (
            "litellm.proxy.proxy_server.general_settings",
            {"control_plane_url": "https://cp.example.com"},
        ),
        ("litellm.proxy.proxy_server.premium_user", False),
        ("litellm.proxy.proxy_server.prisma_client", prisma_stub),
        ("litellm.proxy.proxy_server.proxy_config", config_stub),
        ("litellm.proxy.utils.get_server_root_path", lambda: ""),
        ("litellm.proxy.utils.get_proxy_base_url", lambda: None),
    ]
    for target, replacement in overrides:
        monkeypatch.setattr(target, replacement)

    response = TestClient(app).post(
        "/v3/login", json={"username": "alice", "password": "secret"}
    )

    assert response.status_code == 200
    payload = response.json()
    assert "code" in payload
    assert payload["expires_in"] == 60
    assert "token" not in payload
|
||
|
||
|
||
def test_login_v3_exchange_happy_path(monkeypatch):
    """End to end: /v3/login issues a code that /v3/login/exchange trades for the signed JWT."""
    prisma_stub = MagicMock()
    config_stub = MagicMock()
    config_stub.worker_registry = []

    overrides = [
        (
            "litellm.proxy.auth.login_utils.authenticate_user",
            AsyncMock(return_value={"user_id": "test-user"}),
        ),
        (
            "litellm.proxy.auth.login_utils.create_ui_token_object",
            MagicMock(return_value={"user_id": "test-user"}),
        ),
        ("jwt.encode", MagicMock(return_value="signed-token")),
        ("litellm.proxy.proxy_server.master_key", "test-master-key"),
        (
            "litellm.proxy.proxy_server.general_settings",
            {"control_plane_url": "https://cp.example.com"},
        ),
        ("litellm.proxy.proxy_server.premium_user", False),
        ("litellm.proxy.proxy_server.prisma_client", prisma_stub),
        ("litellm.proxy.proxy_server.proxy_config", config_stub),
        ("litellm.proxy.utils.get_server_root_path", lambda: ""),
        ("litellm.proxy.utils.get_proxy_base_url", lambda: None),
    ]
    for target, replacement in overrides:
        monkeypatch.setattr(target, replacement)

    client = TestClient(app)

    # Step 1: authenticate and receive the code.
    login = client.post("/v3/login", json={"username": "alice", "password": "secret"})
    assert login.status_code == 200
    issued_code = login.json()["code"]

    # Step 2: redeem the code for the JWT (returned in the body and as a cookie).
    exchange = client.post("/v3/login/exchange", json={"code": issued_code})
    assert exchange.status_code == 200
    exchanged = exchange.json()
    assert exchanged["token"] == "signed-token"
    assert "redirect_url" in exchanged
    assert exchange.cookies.get("token") == "signed-token"
|
||
|
||
|
||
def test_login_v3_exchange_single_use(monkeypatch):
    """A /v3/login code redeems once; a second /v3/login/exchange with it gets 401."""
    prisma_stub = MagicMock()
    config_stub = MagicMock()
    config_stub.worker_registry = []

    overrides = [
        (
            "litellm.proxy.auth.login_utils.authenticate_user",
            AsyncMock(return_value={"user_id": "test-user"}),
        ),
        (
            "litellm.proxy.auth.login_utils.create_ui_token_object",
            MagicMock(return_value={"user_id": "test-user"}),
        ),
        ("jwt.encode", MagicMock(return_value="signed-token")),
        ("litellm.proxy.proxy_server.master_key", "test-master-key"),
        (
            "litellm.proxy.proxy_server.general_settings",
            {"control_plane_url": "https://cp.example.com"},
        ),
        ("litellm.proxy.proxy_server.premium_user", False),
        ("litellm.proxy.proxy_server.prisma_client", prisma_stub),
        ("litellm.proxy.proxy_server.proxy_config", config_stub),
        ("litellm.proxy.utils.get_server_root_path", lambda: ""),
        ("litellm.proxy.utils.get_proxy_base_url", lambda: None),
    ]
    for target, replacement in overrides:
        monkeypatch.setattr(target, replacement)

    client = TestClient(app)
    issued_code = client.post(
        "/v3/login", json={"username": "alice", "password": "secret"}
    ).json()["code"]

    def redeem():
        return client.post("/v3/login/exchange", json={"code": issued_code})

    # The first redemption succeeds...
    assert redeem().status_code == 200
    # ...and replaying the same code is rejected.
    assert redeem().status_code == 401
|
||
|
||
|
||
def test_login_v3_exchange_invalid_code(monkeypatch):
    """A code that was never issued is rejected by /v3/login/exchange with 401."""
    cp_settings = {"control_plane_url": "https://cp.example.com"}
    monkeypatch.setattr("litellm.proxy.proxy_server.general_settings", cp_settings)

    outcome = TestClient(app).post(
        "/v3/login/exchange", json={"code": "nonexistent-code"}
    )
    assert outcome.status_code == 401
|
||
|
||
|
||
def test_login_v3_exchange_rejected_without_control_plane_url(monkeypatch):
    """/v3/login/exchange answers 404 (mentioning control_plane_url) when that setting is absent."""
    monkeypatch.setattr("litellm.proxy.proxy_server.general_settings", {})

    outcome = TestClient(app).post("/v3/login/exchange", json={"code": "some-code"})

    assert outcome.status_code == 404
    error_message = outcome.json()["error"]["message"]
    assert "control_plane_url" in error_message
|
||
|
||
|
||
def test_login_v3_returns_json_on_proxy_exception(monkeypatch):
    """/v3/login turns a ProxyException from authentication into a JSON error with its status code."""
    from litellm.proxy._types import ProxyErrorTypes, ProxyException

    rejection = ProxyException(
        message="Invalid credentials",
        type=ProxyErrorTypes.auth_error,
        param="password",
        code=401,
    )
    prisma_stub = MagicMock()
    for target, replacement in (
        (
            "litellm.proxy.auth.login_utils.authenticate_user",
            AsyncMock(side_effect=rejection),
        ),
        ("litellm.proxy.proxy_server.master_key", "test-master-key"),
        (
            "litellm.proxy.proxy_server.general_settings",
            {"control_plane_url": "https://cp.example.com"},
        ),
        ("litellm.proxy.proxy_server.prisma_client", prisma_stub),
    ):
        monkeypatch.setattr(target, replacement)

    response = TestClient(app).post(
        "/v3/login", json={"username": "alice", "password": "wrong"}
    )

    assert response.status_code == 401
    assert response.headers["content-type"] == "application/json"
    body = response.json()
    assert "error" in body
    error = body["error"]
    assert error["message"] == "Invalid credentials"
    assert error["type"] == "auth_error"
|
||
|
||
|
||
def test_fallback_login_has_no_deprecation_banner(client_no_auth):
    """/fallback/login renders the login form without any deprecation banner."""
    page = client_no_auth.get("/fallback/login")
    assert page.status_code == 200

    markup = page.text
    for forbidden in ('<div class="deprecation-banner">', "Deprecated:"):
        assert forbidden not in markup
    assert "<form" in markup
|
||
|
||
|
||
@pytest.mark.parametrize(
    "ui_logo_path",
    [
        "/etc/litellm/secret-config.json",
        "/var/secrets/admin.key",
        "/proc/self/environ",
        "relative/path/logo.png",
    ],
)
def test_get_logo_url_does_not_disclose_local_paths(
    client_no_auth, monkeypatch, ui_logo_path
):
    """Local (non-HTTP(S)) ``UI_LOGO_PATH`` values are never echoed by ``/get_logo_url``.

    The endpoint is unauthenticated, so echoing a filesystem path would expose
    admin-only configuration to any caller. Only browser-loadable HTTP(S) URLs
    are returned; for local paths the dashboard falls back to ``/get_image``.
    """
    monkeypatch.setenv("UI_LOGO_PATH", ui_logo_path)

    reply = client_no_auth.get("/get_logo_url")

    assert reply.status_code == 200
    assert reply.json() == {"logo_url": ""}
|
||
|
||
|
||
def test_get_logo_url_returns_https_url(client_no_auth, monkeypatch):
    """An HTTPS ``UI_LOGO_PATH`` is passed through unchanged by ``/get_logo_url``."""
    logo = "https://cdn.public.example/logo.png"
    monkeypatch.setenv("UI_LOGO_PATH", logo)

    reply = client_no_auth.get("/get_logo_url")

    assert reply.status_code == 200
    assert reply.json() == {"logo_url": logo}
|
||
|
||
|
||
def test_get_logo_url_returns_http_url(client_no_auth, monkeypatch):
    """Plain-HTTP logo URLs (e.g. an internal CDN) are still returned.

    The browser is expected to load them directly.
    """
    logo = "http://internal-cdn.corp:8080/logo.png"
    monkeypatch.setenv("UI_LOGO_PATH", logo)

    reply = client_no_auth.get("/get_logo_url")

    assert reply.status_code == 200
    assert reply.json() == {"logo_url": logo}
|
||
|
||
|
||
def test_get_logo_url_returns_empty_when_unset(client_no_auth, monkeypatch):
    """With no ``UI_LOGO_PATH`` configured, ``/get_logo_url`` reports an empty URL."""
    monkeypatch.delenv("UI_LOGO_PATH", raising=False)

    reply = client_no_auth.get("/get_logo_url")

    assert reply.status_code == 200
    assert reply.json() == {"logo_url": ""}
|
||
|
||
|
||
def test_sso_key_generate_shows_deprecation_banner(client_no_auth, monkeypatch):
    """/sso/key/generate renders its HTML form, including the deprecation banner."""
    overrides = [
        # Stub the SSO plumbing so the route renders the form instead of redirecting.
        (
            "litellm.proxy.management_endpoints.ui_sso.show_missing_vars_in_env",
            lambda: None,
        ),
        (
            "litellm.proxy.management_endpoints.ui_sso.SSOAuthenticationHandler.get_redirect_url_for_sso",
            lambda *args, **kwargs: "http://test/redirect",
        ),
        (
            "litellm.proxy.management_endpoints.ui_sso.SSOAuthenticationHandler._get_cli_state",
            lambda *args, **kwargs: None,
        ),
        (
            "litellm.proxy.management_endpoints.ui_sso.SSOAuthenticationHandler.should_use_sso_handler",
            lambda *args, **kwargs: False,
        ),
        # premium_user=True bypasses the enterprise check that would otherwise 403.
        ("litellm.proxy.proxy_server.premium_user", True),
    ]
    for target, replacement in overrides:
        monkeypatch.setattr(target, replacement)
    monkeypatch.setenv("UI_USERNAME", "admin")

    page = client_no_auth.get("/sso/key/generate")

    assert page.status_code == 200
    markup = page.text
    assert '<div class="deprecation-banner">' in markup
    assert "Deprecated:" in markup
|
||
|
||
|
||
def test_restructure_ui_html_files_handles_nested_routes(tmp_path):
    """``_restructure_ui_html_files`` turns ``<route>.html`` into ``<route>/index.html``.

    This applies at any depth. Existing ``index.html`` files and anything under
    ``_next`` / ``litellm-asset-prefix`` are left alone. The function runs in both
    development and non-root Docker environments.
    """
    from litellm.proxy import proxy_server

    ui_root = tmp_path / "ui"
    ui_root.mkdir()

    seeded_files = {
        ("home.html",): "home",
        ("mcp", "oauth", "callback.html"): "callback",
        ("existing", "index.html"): "keep",
        ("_next", "ignore.html"): "asset",
        ("litellm-asset-prefix", "ignore.html"): "asset",
    }
    for parts, content in seeded_files.items():
        destination = ui_root.joinpath(*parts)
        destination.parent.mkdir(parents=True, exist_ok=True)
        destination.write_text(content)

    proxy_server._restructure_ui_html_files(str(ui_root))

    # Flat and nested route pages are moved into <route>/index.html.
    assert not (ui_root / "home.html").exists()
    assert (ui_root / "home" / "index.html").read_text() == "home"
    assert not (ui_root / "mcp" / "oauth" / "callback.html").exists()
    callback_page = ui_root / "mcp" / "oauth" / "callback" / "index.html"
    assert callback_page.read_text() == "callback"

    # Already-correct pages and asset directories are untouched.
    assert (ui_root / "existing" / "index.html").read_text() == "keep"
    assert (ui_root / "_next" / "ignore.html").read_text() == "asset"
    assert (ui_root / "litellm-asset-prefix" / "ignore.html").read_text() == "asset"
|
||
|
||
|
||
def test_ui_extensionless_route_requires_restructure(tmp_path):
    """Regression for the non-root fallback: ``/ui/login`` needs ``login/index.html``.

    Before restructuring, StaticFiles only serves ``/ui/login.html``. Afterwards,
    the extension-less route resolves too. Restructuring runs in both
    development and non-root Docker environments.
    """
    from litellm.proxy import proxy_server

    ui_root = tmp_path / "ui"
    ui_root.mkdir()
    for page_name in ("index", "login"):
        (ui_root / f"{page_name}.html").write_text(page_name)

    ui_app = FastAPI()
    ui_app.mount("/ui", StaticFiles(directory=str(ui_root), html=True), name="ui")
    client = TestClient(ui_app)

    assert client.get("/ui/login.html").status_code == 200
    assert client.get("/ui/login").status_code == 404

    proxy_server._restructure_ui_html_files(str(ui_root))

    after = client.get("/ui/login")
    assert after.status_code == 200
    assert "login" in after.text
|
||
|
||
|
||
def test_admin_ui_export_serves_nested_extensionless_routes():
    """The packaged admin UI export names every nested page ``<route>/index.html``.

    It also checks that the MCP OAuth callback page exists. Under
    ``StaticFiles(html=True)``, the bare callback path 307-redirects to the
    trailing-slash form with the query string preserved, and that form serves HTML.
    """
    export_root = Path(litellm.__file__).parent / "proxy" / "_experimental" / "out"
    assert export_root.is_dir(), f"missing UI export at {export_root}"

    asset_dirs = {"_next", "litellm-asset-prefix"}
    nested_html_offenders = [
        page.relative_to(export_root).as_posix()
        for page in export_root.rglob("*.html")
        if page.parent != export_root
        and page.name != "index.html"
        and asset_dirs.isdisjoint(page.parts)
    ]
    assert not nested_html_offenders, (
        f"Nested routes must be named index.html. Offenders: {nested_html_offenders}"
    )

    callback_index = export_root / "mcp" / "oauth" / "callback" / "index.html"
    assert callback_index.is_file(), (
        f"MCP OAuth callback page must exist at {callback_index}; "
        "without it /ui/mcp/oauth/callback 404s after Linear redirects back."
    )

    ui_app = FastAPI()
    ui_app.mount(
        "/ui", StaticFiles(directory=str(export_root), html=True), name="ui"
    )
    client = TestClient(ui_app)
    callback_url = "/ui/mcp/oauth/callback?code=abc&state=xyz"

    redirect = client.get(callback_url, follow_redirects=False)
    assert redirect.status_code == 307
    assert redirect.headers["location"].endswith(
        "/ui/mcp/oauth/callback/?code=abc&state=xyz"
    )

    landed = client.get(callback_url)
    assert landed.status_code == 200
    assert "<html" in landed.text.lower()
|
||
|
||
|
||
def test_restructure_always_happens(monkeypatch):
    """UI restructuring is unconditional; only the target directory varies.

    - ``LITELLM_NON_ROOT=true``: restructure the runtime copy in /var/lib/litellm/ui.
    - ``LITELLM_NON_ROOT`` unset: restructure the packaged UI (_experimental/out) in place.

    NOTE(review): this test re-implements the path-selection logic locally and
    never calls into ``proxy_server``, so it cannot detect regressions there.
    """
    runtime_ui_path = "/var/lib/litellm/ui"
    packaged_ui_path = "/some/packaged/ui/path"

    def select_ui_path():
        # Local mirror of the selection logic in proxy_server.py.
        non_root = os.getenv("LITELLM_NON_ROOT", "").lower() == "true"
        return non_root, runtime_ui_path if non_root else packaged_ui_path

    # Restructuring no longer depends on which path was selected.
    should_restructure = True

    # Case 1: non-root Docker -> runtime UI directory.
    monkeypatch.setenv("LITELLM_NON_ROOT", "true")
    is_non_root, ui_path = select_ui_path()
    assert is_non_root is True
    assert should_restructure is True
    assert ui_path == runtime_ui_path

    # Case 2: development -> packaged UI directory, restructured in place.
    monkeypatch.delenv("LITELLM_NON_ROOT", raising=False)
    is_non_root, ui_path = select_ui_path()
    assert is_non_root is False
    assert should_restructure is True
    assert ui_path == packaged_ui_path
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_initialize_scheduled_jobs_credentials(monkeypatch):
    """
    Test that get_credentials is only called when store_model_in_db is True.

    ``initialize_scheduled_background_jobs`` runs twice against the same
    ``proxy_config`` mock: first with ``store_model_in_db=False``, where no
    credentials fetch is expected, then with ``store_model_in_db=True``, where
    exactly one direct fetch is expected.
    """
    monkeypatch.delenv("DISABLE_PRISMA_SCHEMA_UPDATE", raising=False)
    monkeypatch.delenv("STORE_MODEL_IN_DB", raising=False)
    from litellm.proxy.proxy_server import ProxyStartupEvent
    from litellm.proxy.utils import ProxyLogging

    # Mock dependencies
    mock_prisma_client = MagicMock()
    mock_proxy_logging = MagicMock(spec=ProxyLogging)
    mock_proxy_logging.slack_alerting_instance = MagicMock()
    mock_proxy_config = AsyncMock()

    async def run_background_job_setup():
        # A fresh general_settings dict per run, so one run cannot leak
        # mutations into the next.
        await ProxyStartupEvent.initialize_scheduled_background_jobs(
            general_settings={},
            prisma_client=mock_prisma_client,
            proxy_budget_rescheduler_min_time=1,
            proxy_budget_rescheduler_max_time=2,
            proxy_batch_write_at=5,
            proxy_logging_obj=mock_proxy_logging,
        )

    # Case 1: store_model_in_db is False -> credentials must not be fetched.
    with (
        patch("litellm.proxy.proxy_server.proxy_config", mock_proxy_config),
        patch("litellm.proxy.proxy_server.store_model_in_db", False),
    ):
        await run_background_job_setup()

    mock_proxy_config.get_credentials.assert_not_called()

    # Case 2: store_model_in_db is True. get_secret_bool is forced truthy;
    # presumably it gates a DB-related startup flag — TODO confirm which one.
    with (
        patch("litellm.proxy.proxy_server.proxy_config", mock_proxy_config),
        patch("litellm.proxy.proxy_server.store_model_in_db", True),
        patch("litellm.proxy.proxy_server.get_secret_bool", return_value=True),
    ):
        await run_background_job_setup()

    # Exactly one call across both runs, i.e. the direct fetch in case 2.
    mock_proxy_config.get_credentials.assert_called_once()

    # NOTE(review): registering get_credentials as a periodic scheduler job
    # does not invoke the mock, so get_credentials.mock_calls cannot prove a
    # job was scheduled. Asserting that requires patching the scheduler used
    # by initialize_scheduled_background_jobs — TODO.
|
||
|
||
|
||
def test_update_config_fields_deep_merge_db_wins():
    """``ProxyConfig._update_config_fields`` deep-merges DB router settings over file config.

    - On conflicting nested keys, the DB value wins.
    - New DB entries are added.
    - ``None`` values from the DB do not overwrite existing values.
    - Unrelated keys in the section are preserved.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    file_aliases = {
        # Existing alias with an older model and a different hidden flag.
        "claude-sonnet-4": {"model": "claude-sonnet-4-20240219", "hidden": True},
        # Extra alias that should survive unless the DB overrides it.
        "legacy-sonnet": {"model": "claude-2.1", "hidden": True},
    }
    current_config = {
        "router_settings": {
            "routing_mode": "cost_optimized",
            "model_group_alias": file_aliases,
        }
    }

    db_aliases = {
        # Conflict: DB should win on both 'model' and 'hidden'.
        "claude-sonnet-4": {"model": "claude-sonnet-4-20250514", "hidden": False},
        # Brand-new alias introduced by the DB.
        "claude-sonnet-latest": {"model": "claude-sonnet-4-20250514", "hidden": True},
        # None must be skipped, leaving the existing True in place.
        "legacy-sonnet": {"hidden": None},
    }
    db_param_value = {"model_group_alias": db_aliases}

    merged = ProxyConfig()._update_config_fields(
        current_config=current_config,
        param_name="router_settings",
        db_param_value=db_param_value,
    )

    router_settings = merged["router_settings"]
    aliases = router_settings["model_group_alias"]

    sonnet_4 = aliases["claude-sonnet-4"]
    assert sonnet_4["model"] == "claude-sonnet-4-20250514"
    assert sonnet_4["hidden"] is False

    assert "claude-sonnet-latest" in aliases
    latest = aliases["claude-sonnet-latest"]
    assert latest["model"] == "claude-sonnet-4-20250514"
    assert latest["hidden"] is True

    legacy = aliases["legacy-sonnet"]
    assert legacy["model"] == "claude-2.1"
    assert legacy["hidden"] is True

    assert router_settings["routing_mode"] == "cost_optimized"
|
||
|
||
|
||
def test_get_config_custom_callback_api_env_vars(monkeypatch):
    """``/get/config/callbacks`` reports the GENERIC_LOGGER_* variables for ``custom_callback_api``.

    Both the endpoint and headers variables are configured, and both must
    appear in the callback's ``variables``.
    """
    from litellm.proxy.proxy_server import app, proxy_config, user_api_key_auth

    expected_vars = {
        "GENERIC_LOGGER_ENDPOINT": "https://callback.example.com",
        "GENERIC_LOGGER_HEADERS": "Auth: token",
    }
    config_data = {
        "litellm_settings": {"success_callback": ["custom_callback_api"]},
        "general_settings": {},
        "environment_variables": dict(expected_vars),
    }

    # Stub the router settings and the config loader.
    router_stub = MagicMock()
    router_stub.get_settings.return_value = {}
    monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", router_stub)
    monkeypatch.setattr(proxy_config, "get_config", AsyncMock(return_value=config_data))

    # Bypass auth for this request, then restore whatever overrides existed.
    saved_overrides = app.dependency_overrides.copy()
    app.dependency_overrides[user_api_key_auth] = lambda: MagicMock()
    try:
        response = TestClient(app).get("/get/config/callbacks")
    finally:
        app.dependency_overrides = saved_overrides

    assert response.status_code == 200
    custom_cb = next(
        filter(
            lambda cb: cb["name"] == "custom_callback_api",
            response.json()["callbacks"],
        ),
        None,
    )
    assert custom_cb is not None
    assert custom_cb["variables"] == expected_vars
|
||
|
||
|
||
class MockPrisma:
    """Test double standing in for ``PrismaClient``.

    The constructor just stores its arguments, and ``connect``/``disconnect``
    are coroutines that do nothing.
    """

    def __init__(self, database_url=None, proxy_logging_obj=None, http_client=None):
        self.database_url, self.proxy_logging_obj, self.http_client = (
            database_url,
            proxy_logging_obj,
            http_client,
        )

    async def connect(self):
        """No-op: there is no real database behind this double."""

    async def disconnect(self):
        """No-op counterpart to ``connect``."""


# Module-level instance, used as the return value of the patched
# ``ProxyStartupEvent._setup_prisma_client`` in the startup test below.
mock_prisma = MockPrisma()
|
||
|
||
|
||
@patch(
    "litellm.proxy.proxy_server.ProxyStartupEvent._setup_prisma_client",
    return_value=mock_prisma,
)
@pytest.mark.asyncio
async def test_aaaproxy_startup_master_key(
    mock_setup_prisma_client, monkeypatch, tmp_path
):
    """
    Test that master_key is correctly resolved at proxy startup from:

    1. ``general_settings.master_key`` in config.yaml,
    2. the ``LITELLM_MASTER_KEY`` environment variable when the config has none,
    3. an ``os.environ/<VAR>`` reference inside config.yaml.

    ``mock_setup_prisma_client`` is the patched
    ``ProxyStartupEvent._setup_prisma_client`` (returning the module-level
    ``mock_prisma``). Its name deliberately differs from that global so the
    global is not shadowed inside this test.
    """
    from litellm.proxy.proxy_server import proxy_startup_event

    # Replace the real Prisma client class with the no-op MockPrisma.
    monkeypatch.setattr("litellm.proxy.proxy_server.PrismaClient", MockPrisma)

    # Throwaway app for the lifespan context; named so it does not shadow the
    # module-level proxy `app`.
    test_app = FastAPI()

    config_path = tmp_path / "config.yaml"
    # All three cases rewrite this same file, so CONFIG_FILE_PATH is set once.
    monkeypatch.setenv("CONFIG_FILE_PATH", str(config_path))

    def write_config(config: dict) -> None:
        """Serialize ``config`` as YAML into the shared temporary config file."""
        config_path.write_text(yaml.dump(config))

    # Case 1: master key given directly in config.yaml.
    test_master_key = "sk-12345"
    write_config({"general_settings": {"master_key": test_master_key}})
    async with proxy_startup_event(test_app):
        from litellm.proxy.proxy_server import master_key

        assert master_key == test_master_key

    # Case 2: config has no master key, so it comes from the environment.
    test_env_master_key = "sk-test-67890"
    write_config({"general_settings": {}})
    monkeypatch.setenv("LITELLM_MASTER_KEY", test_env_master_key)
    async with proxy_startup_event(test_app):
        from litellm.proxy.proxy_server import master_key

        assert master_key == test_env_master_key

    # Case 3: config references an env var via the `os.environ/` prefix.
    # LITELLM_MASTER_KEY from case 2 is still set, so this case also expects
    # the config-referenced value to be the one that ends up as master_key.
    test_resolved_key = "sk-resolved-key"
    write_config({"general_settings": {"master_key": "os.environ/CUSTOM_MASTER_KEY"}})
    monkeypatch.setenv("CUSTOM_MASTER_KEY", test_resolved_key)
    async with proxy_startup_event(test_app):
        from litellm.proxy.proxy_server import master_key

        assert master_key == test_resolved_key
|
||
|
||
|
||
def test_team_info_masking():
    """
    Secrets carried on a team config must not appear in the error raised by
    `ProxyConfig._get_team_config` for an unknown team.

    Ref: https://huntr.com/bounties/661b388a-44d8-4ad5-862b-4dc5b80be30a
    """
    from litellm.proxy.proxy_server import ProxyConfig

    secret_value, public_value = "secret-test-key", "public-test-key"
    # A team entry carrying sensitive callback credentials.
    team_entry = dict(
        success_callback="['langfuse', 's3']",
        langfuse_secret=secret_value,
        langfuse_public_key=public_value,
    )

    cfg = ProxyConfig()
    with pytest.raises(Exception) as raised:
        cfg._get_team_config(team_id="test_dev", all_teams_config=[team_entry])

    print("Got exception: {}".format(raised.value))
    error_text = str(raised.value)
    for leaked in (secret_value, public_value):
        assert leaked not in error_text
|
||
|
||
|
||
def test_embedding_input_array_of_tokens(client_no_auth):
    """
    Token-id arrays passed as embedding input must reach the router unchanged.
    The proxy must not try to decode them into text for the selected providers.

    Ref: https://github.com/BerriAI/litellm/issues/10113
    """
    from litellm.proxy import proxy_server

    # The client_no_auth fixture should initialize the router. Check it here
    # so router-initialization regressions are reported as such.
    assert proxy_server.llm_router is not None, (
        "llm_router is None after client_no_auth fixture initialized. "
        "This indicates a router initialization issue that should be investigated."
    )

    # Deliberately no try/except -> pytest.fail wrapper. Letting failures
    # propagate keeps pytest's assertion introspection and the real traceback.
    with mock.patch.object(
        proxy_server.llm_router,
        "aembedding",
        return_value=example_embedding_result,
    ) as mock_aembedding:
        test_data = {
            "model": "vllm_embed_model",
            "input": [[2046, 13269, 158208]],
        }

        response = client_no_auth.post("/v1/embeddings", json=test_data)

        # aembedding must be called once, with the token input left untouched.
        mock_aembedding.assert_called_once()
        call_kwargs = mock_aembedding.call_args.kwargs
        assert call_kwargs["model"] == "vllm_embed_model"
        assert call_kwargs["input"] == [[2046, 13269, 158208]]

        assert response.status_code == 200
        embedding = response.json()["data"][0]["embedding"]
        # The mocked embedding is expected to be a full-length vector
        # (typically 1536 dims); > 10 guards against an empty or truncated payload.
        assert len(embedding) > 10
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_all_team_models():
    """
    Test get_all_team_models with user_teams="*", a specific team list, an
    empty team list, and a router that returns None for some model names.

    The result maps each deployment id (``model_info.id``) to the list of
    team ids that can reach that deployment.
    """
    from litellm.proxy.proxy_server import get_all_team_models

    # Mock team rows as returned by prisma; model_dump() feeds LiteLLM_TeamTable.
    mock_team1 = MagicMock()
    mock_team1.model_dump.return_value = {
        "team_id": "team1",
        "models": ["gpt-4", "gpt-3.5-turbo"],
        "team_alias": "Team 1",
    }

    mock_team2 = MagicMock()
    mock_team2.model_dump.return_value = {
        "team_id": "team2",
        "models": ["claude-3", "gpt-4"],
        "team_alias": "Team 2",
    }

    # Mock deployments returned by the router for each model name.
    mock_models_gpt4 = [
        {"model_info": {"id": "gpt-4-model-1"}},
        {"model_info": {"id": "gpt-4-model-2"}},
    ]
    mock_models_gpt35 = [
        {"model_info": {"id": "gpt-3.5-turbo-model-1"}},
    ]
    mock_models_claude = [
        {"model_info": {"id": "claude-3-model-1"}},
    ]

    # Mock prisma client; find_many must be awaitable.
    mock_prisma_client = MagicMock()
    mock_db = MagicMock()
    mock_litellm_teamtable = MagicMock()
    mock_prisma_client.db = mock_db
    mock_db.litellm_teamtable = mock_litellm_teamtable
    mock_litellm_teamtable.find_many = AsyncMock()

    def mock_get_model_list(model_name, team_id=None):
        if model_name == "gpt-4":
            return mock_models_gpt4
        elif model_name == "gpt-3.5-turbo":
            return mock_models_gpt35
        elif model_name == "claude-3":
            return mock_models_claude
        return None

    def mock_team_table_constructor(**kwargs):
        """Lightweight stand-in for LiteLLM_TeamTable(**team.model_dump())."""
        mock_instance = MagicMock()
        mock_instance.team_id = kwargs["team_id"]
        mock_instance.models = kwargs["models"]
        mock_instance.access_group_ids = kwargs.get("access_group_ids")
        return mock_instance

    def normalize(result):
        """Sort each team-id list so assertions don't depend on list ordering."""
        return {model_id: sorted(team_ids) for model_id, team_ids in result.items()}

    mock_router = MagicMock()
    mock_router.get_model_list.side_effect = mock_get_model_list

    # Test Case 1: user_teams = "*" (all teams)
    mock_litellm_teamtable.find_many.return_value = [mock_team1, mock_team2]

    with patch("litellm.proxy.proxy_server.LiteLLM_TeamTable") as mock_team_table_class:
        mock_team_table_class.side_effect = mock_team_table_constructor

        result = await get_all_team_models(
            user_teams="*",
            prisma_client=mock_prisma_client,
            llm_router=mock_router,
        )

    # find_many is called without a where clause for "*".
    mock_litellm_teamtable.find_many.assert_called_with()

    # router.get_model_list is called for each (model, team) pair.
    expected_calls = [
        mock.call(model_name="gpt-4", team_id="team1"),
        mock.call(model_name="gpt-3.5-turbo", team_id="team1"),
        mock.call(model_name="claude-3", team_id="team2"),
        mock.call(model_name="gpt-4", team_id="team2"),
    ]
    mock_router.get_model_list.assert_has_calls(expected_calls, any_order=True)

    # The result must attach each deployment to every team that resolves it.
    assert normalize(result) == {
        "gpt-4-model-1": ["team1", "team2"],
        "gpt-4-model-2": ["team1", "team2"],
        "gpt-3.5-turbo-model-1": ["team1"],
        "claude-3-model-1": ["team2"],
    }

    # Test Case 2: user_teams = specific list
    mock_litellm_teamtable.reset_mock()
    mock_router.reset_mock()
    mock_router.get_model_list.side_effect = mock_get_model_list

    # Only team1 comes back for the specific-team query.
    mock_litellm_teamtable.find_many.return_value = [mock_team1]

    with patch("litellm.proxy.proxy_server.LiteLLM_TeamTable") as mock_team_table_class:
        mock_team_table_class.side_effect = mock_team_table_constructor

        result = await get_all_team_models(
            user_teams=["team1"],
            prisma_client=mock_prisma_client,
            llm_router=mock_router,
        )

    # find_many is called with a where clause for the specific teams.
    mock_litellm_teamtable.find_many.assert_called_with(
        where={"team_id": {"in": ["team1"]}}
    )

    # router.get_model_list is called only for team1's models.
    expected_calls = [
        mock.call(model_name="gpt-4", team_id="team1"),
        mock.call(model_name="gpt-3.5-turbo", team_id="team1"),
    ]
    mock_router.get_model_list.assert_has_calls(expected_calls, any_order=True)

    assert normalize(result) == {
        "gpt-4-model-1": ["team1"],
        "gpt-4-model-2": ["team1"],
        "gpt-3.5-turbo-model-1": ["team1"],
    }

    # Test Case 3: Empty teams list
    mock_litellm_teamtable.reset_mock()
    mock_router.reset_mock()
    mock_litellm_teamtable.find_many.return_value = []

    result = await get_all_team_models(
        user_teams=[],
        prisma_client=mock_prisma_client,
        llm_router=mock_router,
    )

    # find_many is still called, with an empty "in" list.
    mock_litellm_teamtable.find_many.assert_called_with(where={"team_id": {"in": []}})

    # No teams -> empty mapping.
    assert result == {}

    # Test Case 4: Router returns None for some models
    mock_litellm_teamtable.reset_mock()
    mock_router.reset_mock()
    mock_litellm_teamtable.find_many.return_value = [mock_team1]

    def mock_get_model_list_with_none(model_name, team_id=None):
        if model_name == "gpt-4":
            return mock_models_gpt4
        # None for gpt-3.5-turbo exercises the None-handling path.
        return None

    mock_router.get_model_list.side_effect = mock_get_model_list_with_none

    with patch("litellm.proxy.proxy_server.LiteLLM_TeamTable") as mock_team_table_class:
        mock_team_table_class.side_effect = mock_team_table_constructor

        result = await get_all_team_models(
            user_teams=["team1"],
            prisma_client=mock_prisma_client,
            llm_router=mock_router,
        )

    # None from the router is skipped, not treated as an error.
    assert isinstance(result, dict)
    print("result: ", result)
    assert result == {"gpt-4-model-1": ["team1"], "gpt-4-model-2": ["team1"]}
|
||
|
||
|
||
def test_add_team_models_to_all_models():
    """
    A team granted "all-proxy-models" gets attached to every deployment
    returned by the router except ones that carry another team's team_id.
    """
    from litellm.proxy._types import LiteLLM_TeamTable
    from litellm.proxy.proxy_server import _add_team_models_to_all_models

    all_access_team = MagicMock(spec=LiteLLM_TeamTable)
    all_access_team.team_id = "team1"
    all_access_team.models = ["all-proxy-models"]

    router_stub = MagicMock()
    router_stub.get_model_list.return_value = [
        # Owned by team2 -> must not be granted to team1.
        {"model_info": {"id": "gpt-4-model-1", "team_id": "team2"}},
        # Unowned -> granted to team1.
        {"model_info": {"id": "gpt-4-model-2"}},
    ]

    mapping = _add_team_models_to_all_models(
        team_db_objects_typed=[all_access_team],
        llm_router=router_stub,
    )
    assert mapping == {"gpt-4-model-2": {"team1"}}
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_apply_search_filter_matches_team_public_model_name():
    """
    Regression test: team BYOK deployments are stored under an internal
    model_name (e.g. `model_name_{team_id}_{uuid}`). The user-facing name
    lives in `model_info.team_public_model_name`.

    The /v2/model/info search filter must match on that public name so BYOK
    rows show up. It must also keep matching on the internal model_name.
    """
    from litellm.proxy.proxy_server import _apply_search_filter_to_models

    byok_model = {
        "model_name": "model_name_team-abc-123_4a6b8",
        "litellm_params": {"model": "claude-sonnet-4-5"},
        "model_info": {
            "id": "byok-id-1",
            "team_id": "team-abc-123",
            "team_public_model_name": "team-claude-sonnet",
            "db_model": True,
        },
    }
    unrelated_model = {
        "model_name": "gpt-4",
        "litellm_params": {"model": "gpt-4"},
        "model_info": {"id": "normal-id-1", "db_model": False},
    }

    async def run_search(term):
        """Apply the search filter for ``term`` with no DB and return the matches."""
        matched, _ = await _apply_search_filter_to_models(
            all_models=[byok_model, unrelated_model],
            search=term,
            prisma_client=None,
            proxy_config=MagicMock(),
        )
        return matched

    # Matching only via team_public_model_name must still include the BYOK row.
    public_hits = {entry["model_info"]["id"] for entry in await run_search("claude")}
    assert "byok-id-1" in public_hits
    assert "normal-id-1" not in public_hits

    # Searching by the internal model_name keeps working as before.
    internal_hits = await run_search("model_name_team-abc-123")
    assert [entry["model_info"]["id"] for entry in internal_hits] == ["byok-id-1"]

    # A term matching nothing yields an empty result.
    assert await run_search("gemini") == []
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_apply_search_filter_scopes_byok_to_caller_teams():
    """
    Regression test: `/v2/model/info?search=...` must not leak BYOK rows that
    belong to teams the caller is not a member of.

    Even with a bounded `model_name`-contains DB query, a non-admin caller
    could otherwise see other teams' BYOK rows whose internal names happen to
    match. A post-fetch team scope drops those rows. Admins keep the
    unscoped view.
    """
    from litellm.proxy.proxy_server import _apply_search_filter_to_models

    def make_router_byok(model_id, internal_name, team_id, public_name):
        """Build an in-router BYOK deployment dict owned by ``team_id``."""
        return {
            "model_name": internal_name,
            "litellm_params": {"model": "claude-sonnet"},
            "model_info": {
                "id": model_id,
                "team_id": team_id,
                "team_public_model_name": public_name,
                "db_model": True,
            },
        }

    def make_db_row(model_id, internal_name, team_id, public_name):
        """Build a mock proxy-model DB row for a BYOK deployment."""
        row = MagicMock()
        row.model_id = model_id
        row.model_name = internal_name
        row.model_info = {
            "id": model_id,
            "team_id": team_id,
            "team_public_model_name": public_name,
            "db_model": True,
        }
        return row

    def make_caller(role, user_id):
        """Build a mock auth object with the given role and user id."""
        caller = MagicMock(spec=UserAPIKeyAuth)
        caller.user_role = role
        caller.user_id = user_id
        return caller

    # Router-side BYOK rows: one in the caller's team, one in another team.
    caller_team_byok = make_router_byok(
        "byok-mine", "model_name_team-mine_internal", "team-mine", "claude-sonnet-prod"
    )
    other_team_byok = make_router_byok(
        "byok-other",
        "model_name_team-other_internal",
        "team-other",
        "claude-sonnet-staging",
    )
    # A non-team row is kept router-side regardless of the caller's teams.
    public_model = {
        "model_name": "claude-public",
        "litellm_params": {"model": "claude-sonnet"},
        "model_info": {"id": "public-id", "db_model": False},
    }

    # DB-only BYOK rows returned by the over-broad JSON branch.
    db_rows = [
        make_db_row(
            "byok-db-mine", "model_name_team-mine_db", "team-mine", "Claude DB Mine"
        ),
        make_db_row(
            "byok-db-other", "model_name_team-other_db", "team-other", "Claude DB Other"
        ),
    ]

    membership = MagicMock()
    membership.teams = ["team-mine"]

    prisma_client = MagicMock()
    model_table = prisma_client.db.litellm_proxymodeltable
    model_table.count = AsyncMock(return_value=2)
    model_table.find_many = AsyncMock(return_value=db_rows)
    prisma_client.db.litellm_usertable.find_unique = AsyncMock(return_value=membership)

    def fake_decrypt(rows):
        return [
            {
                "model_name": row.model_name,
                "model_info": row.model_info,
                "litellm_params": {"model": "claude-sonnet"},
            }
            for row in rows
        ]

    proxy_config = MagicMock()
    proxy_config.decrypt_model_list_from_db = fake_decrypt

    async def search_as(caller):
        return await _apply_search_filter_to_models(
            all_models=[caller_team_byok, other_team_byok, public_model],
            search="claude",
            prisma_client=prisma_client,
            proxy_config=proxy_config,
            user_api_key_dict=caller,
        )

    # Non-admin caller: only their own team's BYOK rows plus public rows.
    scoped, total_count = await search_as(
        make_caller(LitellmUserRoles.INTERNAL_USER, "user-mine")
    )
    scoped_ids = {entry["model_info"]["id"] for entry in scoped}
    for expected in ("byok-mine", "byok-db-mine", "public-id"):
        assert expected in scoped_ids
    assert "byok-other" not in scoped_ids, (
        "router-side BYOK from another team must be dropped from search "
        "when caller doesn't belong to that team"
    )
    assert "byok-db-other" not in scoped_ids, (
        "DB-only BYOK from another team must be dropped from search when "
        "caller doesn't belong to that team"
    )
    # total_count = router-side matches (2: caller_team_byok + public_model,
    # other_team_byok dropped router-side) + the mocked DB count (2). The DB
    # count is the *unscoped* match count. Non-admin team scoping only applies
    # to the returned page, so the total may over-report but must never
    # under-report (callers can paginate within the bound).
    assert total_count == 4

    # Admins keep the unscoped view across teams.
    unscoped, _ = await search_as(make_caller(LitellmUserRoles.PROXY_ADMIN, "admin-1"))
    unscoped_ids = {entry["model_info"]["id"] for entry in unscoped}
    assert "byok-other" in unscoped_ids
    assert "byok-db-other" in unscoped_ids
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_apply_search_filter_bounds_db_fetch_by_page_and_cap():
    """
    Regression test: a broad search term must not force a full read and
    decrypt of the BYOK table on every request.

    * Unsorted searches call `find_many(take=N)`, where N is just enough rows
      to fill the current page after router-side matches are counted.
    * Sorted searches call `find_many(take=cap)`, with the cap set to
      `_SORTED_SEARCH_DB_FETCH_CAP`. Ordering still works across a large
      match set without scanning the whole table.
    """
    from litellm.proxy.proxy_server import (
        _SORTED_SEARCH_DB_FETCH_CAP,
        _apply_search_filter_to_models,
    )

    prisma_client = MagicMock()
    model_table = prisma_client.db.litellm_proxymodeltable
    model_table.count = AsyncMock(return_value=10_000)
    model_table.find_many = AsyncMock(return_value=[])

    def decrypt_nothing(rows):
        return []

    proxy_config = MagicMock()
    proxy_config.decrypt_model_list_from_db = decrypt_nothing

    async def take_for(sort_by):
        """Run one page-1/size-50 search and return the `take` passed to find_many."""
        await _apply_search_filter_to_models(
            all_models=[],
            search="model",
            prisma_client=prisma_client,
            proxy_config=proxy_config,
            page=1,
            size=50,
            sort_by=sort_by,
        )
        return model_table.find_many.call_args.kwargs["take"]

    # Unsorted, with no router-side matches: exactly one page of rows.
    assert (
        await take_for(None) == 50
    ), "unsorted search must take just one page's worth of rows"

    # Sorted: still bounded, but by the hard cap rather than by the page.
    model_table.find_many.reset_mock()
    sorted_take = await take_for("model_name")
    assert sorted_take == _SORTED_SEARCH_DB_FETCH_CAP
    assert sorted_take < 10_000, "sorted search must cap below the full match set"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_filter_models_by_team_id_excludes_viewer_direct_access():
    """
    Regression test: picking a specific team in the UI's Current Team
    selector must list only that team's BYOK rows plus the models assigned
    to the team.

    Upstream, the admin viewer's `direct_access` flag is set on every
    non-team model. That flag must NOT widen the team's visible set;
    otherwise selecting team-111 would still show every public model.
    """
    from litellm.proxy.proxy_server import _filter_models_by_team_id

    def team_byok(team_id, backing_model, public_name):
        """Build a BYOK deployment owned by, and only accessible to, ``team_id``."""
        return {
            "model_name": f"model_name_{team_id}_uuid",
            "litellm_params": {"model": backing_model},
            "model_info": {
                "id": f"byok-{team_id}",
                "team_id": team_id,
                "team_public_model_name": public_name,
                "access_via_team_ids": [team_id],
            },
        }

    public_model = {
        "model_name": "gpt-4",
        "litellm_params": {"model": "gpt-4"},
        "model_info": {
            "id": "public-id",
            # The admin viewer has direct access to this public model...
            "direct_access": True,
            # ...but only team-222 is granted it, so team-111 must not see it.
            "access_via_team_ids": ["team-222"],
        },
    }
    byok_111 = team_byok("team-111", "claude-sonnet", "team-claude")
    byok_222 = team_byok("team-222", "claude-haiku", "team-haiku")

    team_row = MagicMock()
    team_row.model_dump.return_value = {
        "team_id": "team-111",
        "team_alias": "Team 111",
        # A specific models list that does not include the BYOK's internal name.
        "models": ["some-other-model"],
        "access_group_ids": None,
    }

    prisma = MagicMock()
    prisma.db.litellm_teamtable.find_unique = AsyncMock(return_value=team_row)
    prisma.db.litellm_proxymodeltable.find_many = AsyncMock(return_value=[])

    router = MagicMock()
    router.get_model_access_groups = MagicMock(return_value={})
    # team-111 only resolves "some-other-model", which has no deployments.
    router.get_model_list = MagicMock(return_value=[])

    kept = await _filter_models_by_team_id(
        all_models=[public_model, byok_111, byok_222],
        team_id="team-111",
        prisma_client=prisma,
        llm_router=router,
    )
    visible_ids = {entry["model_info"]["id"] for entry in kept}

    assert "byok-team-111" in visible_ids, "team-111's own BYOK must always be visible"
    assert "byok-team-222" not in visible_ids, "must not leak other teams' BYOK"
    assert (
        "public-id" not in visible_ids
    ), "viewer's direct_access must not widen the team's visible set"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_filter_models_by_team_id_rejects_non_member():
    """
    Regression test: /v2/model/info?teamId=X includes BYOK rows based solely
    on `model_info.team_id == X`. Without an auth check, any authenticated
    user could enumerate another team's BYOK metadata by guessing its id.

    A caller that is neither a proxy admin nor a member of `team_id` must
    get a 403.
    """
    from fastapi import HTTPException

    from litellm.proxy.proxy_server import _filter_models_by_team_id

    foreign_byok = {
        "model_name": "model_name_team-111_uuid",
        "litellm_params": {"model": "claude"},
        "model_info": {"id": "byok-team-111", "team_id": "team-111"},
    }

    # The caller only belongs to team-222.
    outsider_row = MagicMock()
    outsider_row.teams = ["team-222"]

    prisma = MagicMock()
    prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=outsider_row)

    alice = UserAPIKeyAuth(
        user_id="alice",
        user_role=LitellmUserRoles.INTERNAL_USER,
        api_key="sk-test",
    )

    with pytest.raises(HTTPException) as rejected:
        await _filter_models_by_team_id(
            all_models=[foreign_byok],
            team_id="team-111",
            prisma_client=prisma,
            llm_router=MagicMock(),
            user_api_key_dict=alice,
        )
    assert rejected.value.status_code == 403
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_filter_models_by_team_id_allows_team_member():
    """
    A caller who belongs to `team_id` may filter by it and must see that
    team's BYOK rows.
    """
    from litellm.proxy.proxy_server import _filter_models_by_team_id

    team_byok = {
        "model_name": "model_name_team-111_uuid",
        "litellm_params": {"model": "claude"},
        "model_info": {"id": "byok-team-111", "team_id": "team-111"},
    }

    # The caller is a member of team-111 (and one other team).
    member_row = MagicMock()
    member_row.teams = ["team-111", "team-999"]

    team_row = MagicMock()
    team_row.model_dump.return_value = {
        "team_id": "team-111",
        "team_alias": "Team 111",
        "models": [],
        "access_group_ids": None,
    }

    prisma = MagicMock()
    prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=member_row)
    prisma.db.litellm_teamtable.find_unique = AsyncMock(return_value=team_row)
    prisma.db.litellm_proxymodeltable.find_many = AsyncMock(return_value=[])

    router = MagicMock()
    router.get_model_access_groups = MagicMock(return_value={})
    router.get_model_list = MagicMock(return_value=[team_byok])

    bob = UserAPIKeyAuth(
        user_id="bob",
        user_role=LitellmUserRoles.INTERNAL_USER,
        api_key="sk-test",
    )

    visible = await _filter_models_by_team_id(
        all_models=[team_byok],
        team_id="team-111",
        prisma_client=prisma,
        llm_router=router,
        user_api_key_dict=bob,
    )
    assert [entry["model_info"]["id"] for entry in visible] == ["byok-team-111"]
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_caller_byok_team_scope_treats_view_only_admin_as_unscoped():
    """
    Regression test: `PROXY_ADMIN_VIEW_ONLY` is an admin role ("can login,
    view all own keys, view all spend"). Search results for this role must
    include BYOK rows across all teams, exactly like PROXY_ADMIN.

    They must not be narrowed to the teams listed on the user's row. That
    would silently limit results to whatever teams the admin happens to
    belong to, a regression from the earlier behavior.
    """
    from litellm.proxy.proxy_server import _get_caller_byok_team_scope

    view_only_admin = UserAPIKeyAuth(
        user_id="view-admin",
        user_role=LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
        api_key="sk-test",
    )

    team_scope = await _get_caller_byok_team_scope(
        user_api_key_dict=view_only_admin,
        prisma_client=MagicMock(),
    )
    # None means "no team scoping applied".
    assert (
        team_scope is None
    ), "PROXY_ADMIN_VIEW_ONLY must be unscoped, like PROXY_ADMIN"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_add_access_group_models_to_team_models():
    """
    Models reachable through a team's access groups must be added to
    team_models.

    Scenario: a team has models=["gpt-4"] and access_group_ids=["premium"],
    and the "premium" group contains ["claude-3", "gemini"]. After resolution
    the team sees gpt-4 (direct) plus claude-3 and gemini (via the group).

    Teams with no access groups, or with unrestricted model access, are
    skipped.
    """
    from litellm.proxy._types import LiteLLM_TeamTable
    from litellm.proxy.proxy_server import _add_access_group_models_to_team_models

    def make_team(team_id, models, access_group_ids):
        """Build a LiteLLM_TeamTable-spec mock with the given fields."""
        team = MagicMock(spec=LiteLLM_TeamTable)
        team.team_id = team_id
        team.models = models
        team.access_group_ids = access_group_ids
        return team

    teams = [
        # Specific models (non-empty list) AND an access group -> resolved.
        make_team("team1", ["gpt-4"], ["premium"]),
        # access_group_ids is None -> skipped.
        make_team("team2", ["gpt-4"], None),
        # access_group_ids is an empty list -> skipped.
        make_team("team2b", ["gpt-4"], []),
        # Empty models means all access -> skipped.
        make_team("team3", [], ["premium"]),
        # The all-proxy-models sentinel means all access -> skipped.
        make_team("team4", ["all-proxy-models"], ["premium"]),
    ]

    deployments_by_name = {
        "claude-3": [{"model_info": {"id": "claude-3-id"}}],
        "gemini": [{"model_info": {"id": "gemini-id"}}],
    }

    def fake_get_model_list(model_name, team_id=None):
        # A fresh list per call; unknown names resolve to None.
        found = deployments_by_name.get(model_name)
        return None if found is None else [dict(d) for d in found]

    router = MagicMock()
    router.get_model_list.side_effect = fake_get_model_list

    # The "premium" access-group row returned by the batched DB lookup.
    premium_row = MagicMock()
    premium_row.access_group_id = "premium"
    premium_row.access_model_names = ["claude-3", "gemini"]

    prisma = MagicMock()
    access_group_table = prisma.db.litellm_accessgrouptable
    access_group_table.find_many = AsyncMock(return_value=[premium_row])

    merged = await _add_access_group_models_to_team_models(
        team_db_objects_typed=teams,
        llm_router=router,
        prisma_client=prisma,
        # Pre-existing mapping, e.g. from _add_team_models_to_all_models.
        team_models={"gpt-4-id": {"team1"}},
    )

    # Exactly one batched query, covering only the eligible team's groups.
    access_group_table.find_many.assert_called_once()
    queried = access_group_table.find_many.call_args.kwargs["where"][
        "access_group_id"
    ]["in"]
    assert set(queried) == {"premium"}

    # The pre-existing entry survives.
    assert "gpt-4-id" in merged
    assert "team1" in merged["gpt-4-id"]

    # Access-group models are added for team1.
    for model_id in ("claude-3-id", "gemini-id"):
        assert model_id in merged
        assert "team1" in merged[model_id]

    # Skipped teams must not pick up the access-group models.
    for skipped_team in ["team2", "team2b", "team3", "team4"]:
        assert skipped_team not in merged.get("claude-3-id", set())
        assert skipped_team not in merged.get("gemini-id", set())
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_add_access_group_models_multiple_teams_shared_group():
    """
    Teams that share an access group must each receive that group's models,
    and all referenced groups must be fetched with a single batched
    ``find_many`` call.
    """
    from litellm.proxy._types import LiteLLM_TeamTable
    from litellm.proxy.proxy_server import _add_access_group_models_to_team_models

    def make_team(team_id, models, access_group_ids):
        team = MagicMock(spec=LiteLLM_TeamTable)
        team.team_id = team_id
        team.models = models
        team.access_group_ids = access_group_ids
        return team

    def make_access_group_row(access_group_id, access_model_names):
        row = MagicMock()
        row.access_group_id = access_group_id
        row.access_model_names = access_model_names
        return row

    team_a = make_team("team-a", ["gpt-4"], ["shared-group"])
    team_b = make_team("team-b", ["gpt-3.5"], ["shared-group", "extra-group"])

    # Router stub: only claude-3 and gemini resolve to a deployment.
    deployment_id_by_model = {"claude-3": "claude-3-id", "gemini": "gemini-id"}

    def mock_get_model_list(model_name, team_id=None):
        deployment_id = deployment_id_by_model.get(model_name)
        if deployment_id is None:
            return None
        return [{"model_info": {"id": deployment_id}}]

    mock_router = MagicMock()
    mock_router.get_model_list.side_effect = mock_get_model_list

    find_many = AsyncMock(
        return_value=[
            make_access_group_row("shared-group", ["claude-3"]),
            make_access_group_row("extra-group", ["gemini"]),
        ]
    )
    mock_prisma_client = MagicMock()
    mock_prisma_client.db.litellm_accessgrouptable.find_many = find_many

    result = await _add_access_group_models_to_team_models(
        team_db_objects_typed=[team_a, team_b],
        llm_router=mock_router,
        prisma_client=mock_prisma_client,
        team_models={},
    )

    # One batched query covering both groups.
    find_many.assert_called_once()
    where_clause = find_many.call_args[1]["where"]
    assert set(where_clause["access_group_id"]["in"]) == {
        "shared-group",
        "extra-group",
    }

    # shared-group -> claude-3 for both teams.
    assert "claude-3-id" in result
    for team_id in ("team-a", "team-b"):
        assert team_id in result["claude-3-id"]

    # extra-group -> gemini for team-b only.
    assert "gemini-id" in result
    assert "team-b" in result["gemini-id"]
    assert "team-a" not in result["gemini-id"]
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_add_access_group_models_no_eligible_teams():
    """
    If no team references an access group, the access-group table must not be
    queried at all and ``team_models`` must come back unchanged.
    """
    from litellm.proxy._types import LiteLLM_TeamTable
    from litellm.proxy.proxy_server import _add_access_group_models_to_team_models

    team_without_groups = MagicMock(spec=LiteLLM_TeamTable)
    team_without_groups.team_id = "team1"
    team_without_groups.models = ["gpt-4"]
    team_without_groups.access_group_ids = None

    access_group_find_many = AsyncMock()
    prisma = MagicMock()
    prisma.db.litellm_accessgrouptable.find_many = access_group_find_many

    result = await _add_access_group_models_to_team_models(
        team_db_objects_typed=[team_without_groups],
        llm_router=MagicMock(),
        prisma_client=prisma,
        team_models={"existing-id": {"team1"}},
    )

    # Short-circuit: nothing eligible, so no DB round-trip.
    access_group_find_many.assert_not_called()

    # Compare against a fresh literal so in-place mutation would be caught.
    assert result == {"existing-id": {"team1"}}
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_all_team_models_with_access_groups():
    """
    End-to-end check that get_all_team_models merges access-group models into
    the per-deployment team mapping.

    team1 lists models=["gpt-4"] directly and belongs to the "premium" access
    group, which grants ["claude-3"]. Deployments for both models must be
    attributed to team1, and every mapping value must be a list.
    """
    from litellm.proxy.proxy_server import get_all_team_models

    team_row = MagicMock()
    team_row.model_dump.return_value = {
        "team_id": "team1",
        "models": ["gpt-4"],
        "team_alias": "Team 1",
        "access_group_ids": ["premium"],
    }

    premium_group = MagicMock()
    premium_group.access_group_id = "premium"
    premium_group.access_model_names = ["claude-3"]

    mock_db = MagicMock()
    mock_db.litellm_teamtable = MagicMock()
    mock_db.litellm_teamtable.find_many = AsyncMock(return_value=[team_row])
    mock_db.litellm_accessgrouptable = MagicMock()
    mock_db.litellm_accessgrouptable.find_many = AsyncMock(
        return_value=[premium_group]
    )

    mock_prisma_client = MagicMock()
    mock_prisma_client.db = mock_db

    deployment_by_model = {
        "gpt-4": "gpt-4-deploy-1",
        "claude-3": "claude-3-deploy-1",
    }

    def mock_get_model_list(model_name, team_id=None):
        if model_name not in deployment_by_model:
            return None
        return [{"model_info": {"id": deployment_by_model[model_name]}}]

    mock_router = MagicMock()
    mock_router.get_model_list.side_effect = mock_get_model_list

    def build_team_table(**kwargs):
        # Stand-in for LiteLLM_TeamTable that copies only the fields this
        # test's data provides.
        team = MagicMock()
        team.team_id = kwargs["team_id"]
        team.models = kwargs["models"]
        team.access_group_ids = kwargs.get("access_group_ids")
        return team

    with patch("litellm.proxy.proxy_server.LiteLLM_TeamTable") as team_table_cls:
        team_table_cls.side_effect = build_team_table
        result = await get_all_team_models(
            user_teams=["team1"],
            prisma_client=mock_prisma_client,
            llm_router=mock_router,
        )

    # gpt-4 comes from team.models; claude-3 comes from the access group.
    for deployment_id in ("gpt-4-deploy-1", "claude-3-deploy-1"):
        assert deployment_id in result
        assert "team1" in result[deployment_id]
        # Return type is Dict[str, List[str]].
        assert isinstance(result[deployment_id], list)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_delete_deployment_type_mismatch():
    """
    Test that ProxyConfig._delete_deployment does not delete deployments
    "12345678" and "12345679".

    The router reports these IDs as strings, while model IDs collected from
    config may be integers. A type mismatch in the membership check would make
    valid models look missing and get them deleted.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    pc = ProxyConfig()

    # The router reports every deployment ID as a string (the source of the
    # int-vs-str mismatch).
    mock_llm_router = MagicMock()
    mock_llm_router.get_model_ids.return_value = [
        "a96e12e76b36a57cfae57a41288eb41567629cac89b4828c6f7074afc3534695",
        "a40186dd0fdb9b7282380277d7f57044d29de95bfbfcd7f4322b3493702d5cd3",
        "12345678",  # String ID
        "12345679",  # String ID
    ]

    # Record every ID the router is asked to delete.
    deleted_ids = []

    # Keep the parameter named ``id`` (despite shadowing the builtin).
    # Presumably the code under test calls delete_deployment(id=...).
    def mock_delete_deployment(id):
        deleted_ids.append(id)
        return True  # Simulate successful deletion

    mock_llm_router.delete_deployment = MagicMock(side_effect=mock_delete_deployment)

    # get_config returns a coroutine, i.e. it is awaited by the code under
    # test. An empty config means there are no config-defined models.
    async def mock_get_config(config_file_path):
        return {}

    pc.get_config = MagicMock(side_effect=mock_get_config)

    with (
        patch("litellm.proxy.proxy_server.llm_router", mock_llm_router),
        patch("litellm.proxy.proxy_server.user_config_file_path", "test_config.yaml"),
    ):
        deleted_count = await pc._delete_deployment(db_models=[])

    # NOTE(review): db_models is empty and the config has no models, so this
    # may be satisfied by an early return rather than by the int/str ID
    # comparison described above. Confirm this test exercises that path.
    assert deleted_count == 0, f"Expected 0 deletions, got {deleted_count}"

    # The string-ID deployments must never be passed to delete_deployment.
    for model_id in ("12345678", "12345679"):
        assert (
            model_id not in deleted_ids
        ), f"Model {model_id} should NOT be deleted. Deleted IDs: {deleted_ids}"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_config_from_file(tmp_path, monkeypatch):
    """
    Test the _get_config_from_file method of ProxyConfig.

    Scenarios covered:
    - A valid YAML file, which also records its path in the module global
      ``user_config_file_path``.
    - A missing file.
    - No path at all, which returns the default empty config.
    - An empty file.
    - Falling back to the global path when no path is passed.
    """
    import re

    import yaml

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    # Test Case 1: Valid YAML config file exists
    test_config = {
        "model_list": [{"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}}],
        "general_settings": {"master_key": "sk-test"},
        "router_settings": {"enable_pre_call_checks": True},
        "litellm_settings": {"drop_params": True},
    }

    config_file = tmp_path / "test_config.yaml"
    with open(config_file, "w") as f:
        yaml.dump(test_config, f)

    # Clear global user_config_file_path for this test
    monkeypatch.setattr("litellm.proxy.proxy_server.user_config_file_path", None)

    result = await proxy_config._get_config_from_file(str(config_file))
    assert result == test_config

    # Verify that user_config_file_path was set (import after the call so we
    # read the freshly assigned module global).
    from litellm.proxy.proxy_server import user_config_file_path

    assert user_config_file_path == str(config_file)

    # Test Case 2: File path provided but file doesn't exist.
    # pytest.raises(match=...) performs a regex search, so the path must be
    # escaped: temp paths may contain regex metacharacters (or backslashes on
    # Windows) that would otherwise make the match fail or error.
    non_existent_file = tmp_path / "non_existent.yaml"

    with pytest.raises(
        Exception, match=re.escape(f"Config file not found: {non_existent_file}")
    ):
        await proxy_config._get_config_from_file(str(non_existent_file))

    # Test Case 3: No file path provided (should return default config)
    monkeypatch.setattr("litellm.proxy.proxy_server.user_config_file_path", None)

    expected_default = {
        "model_list": [],
        "general_settings": {},
        "router_settings": {},
        "litellm_settings": {},
    }

    result = await proxy_config._get_config_from_file(None)
    assert result == expected_default

    # Test Case 4: Empty YAML file (should raise exception for None config)
    empty_file = tmp_path / "empty_config.yaml"
    with open(empty_file, "w") as f:
        f.write("")  # Empty content loads as None

    with pytest.raises(Exception, match=re.escape("Config cannot be None or Empty.")):
        await proxy_config._get_config_from_file(str(empty_file))

    # Test Case 5: Using global user_config_file_path when no config_file_path provided
    monkeypatch.setattr(
        "litellm.proxy.proxy_server.user_config_file_path", str(config_file)
    )

    result = await proxy_config._get_config_from_file(None)
    assert result == test_config
|
||
|
||
|
||
def test_normalize_datetime_for_sorting():
    """
    Test the _normalize_datetime_for_sorting function.

    Inputs covered:
    - None.
    - ISO-8601 strings with a ``Z`` suffix, with no offset, and with an
      explicit offset.
    - Naive and timezone-aware ``datetime`` objects.
    - Invalid inputs.

    Valid inputs must come back as UTC-aware datetimes, with offset inputs
    converted to UTC wall-clock time. Invalid inputs must yield None.
    """
    from litellm.proxy.proxy_server import _normalize_datetime_for_sorting

    # Test Case 1: None value
    assert _normalize_datetime_for_sorting(None) is None

    # Test Case 2: ISO format string with 'Z' suffix
    result = _normalize_datetime_for_sorting("2024-01-15T10:30:00Z")
    assert result is not None
    assert isinstance(result, datetime)
    assert result.tzinfo == timezone.utc
    assert result.year == 2024
    assert result.month == 1
    assert result.day == 15
    assert result.hour == 10
    assert result.minute == 30

    # Test Case 3: ISO format string without 'Z' suffix (naive)
    result = _normalize_datetime_for_sorting("2024-01-15T10:30:00")
    assert result is not None
    assert isinstance(result, datetime)
    assert result.tzinfo == timezone.utc

    # Test Case 4: ISO format string with timezone offset
    result = _normalize_datetime_for_sorting("2024-01-15T10:30:00+05:00")
    assert result is not None
    assert isinstance(result, datetime)
    assert result.tzinfo == timezone.utc
    # 10:30 at +05:00 is 05:30 UTC; check both fields so a dropped or
    # mangled minute component is caught too.
    assert result.hour == 5
    assert result.minute == 30

    # Test Case 5: Naive datetime object
    result = _normalize_datetime_for_sorting(datetime(2024, 1, 15, 10, 30, 0))
    assert result is not None
    assert isinstance(result, datetime)
    assert result.tzinfo == timezone.utc
    assert result.year == 2024
    assert result.month == 1
    assert result.day == 15

    # Test Case 6: Timezone-aware datetime object (non-UTC)
    aware_dt = datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone(timedelta(hours=5)))
    result = _normalize_datetime_for_sorting(aware_dt)
    assert result is not None
    assert isinstance(result, datetime)
    assert result.tzinfo == timezone.utc
    # Converted from +05:00 to UTC: 05:30.
    assert result.hour == 5
    assert result.minute == 30

    # Test Case 7: UTC-aware datetime object
    utc_dt = datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc)
    result = _normalize_datetime_for_sorting(utc_dt)
    assert result is not None
    assert isinstance(result, datetime)
    assert result.tzinfo == timezone.utc
    assert result == utc_dt

    # Test Case 8: Invalid string format
    assert _normalize_datetime_for_sorting("not-a-date") is None

    # Test Case 9: Invalid type (should return None)
    assert _normalize_datetime_for_sorting(12345) is None
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_add_proxy_budget_to_db_only_creates_user_no_keys():
    """
    Test that _add_proxy_budget_to_db only creates a user and adds no keys.

    This validates that generate_key_helper_fn is called with
    table_name="user", which should prevent key creation in the
    LiteLLM_VerificationToken table.
    """
    import asyncio

    import litellm
    from litellm.proxy.proxy_server import ProxyStartupEvent

    litellm_proxy_budget_name = "litellm-proxy-budget"

    # Mock generate_key_helper_fn to capture its call arguments
    mock_generate_key_helper = AsyncMock(
        return_value={
            "user_id": litellm_proxy_budget_name,
            "max_budget": 100.0,
            "budget_duration": "30d",
            "spend": 0,
            "models": [],
        }
    )

    # Patch the module-level budget settings instead of assigning them, so the
    # previous values are restored afterwards and cannot leak into other tests.
    with (
        patch.object(litellm, "budget_duration", "30d"),
        patch.object(litellm, "max_budget", 100.0),
        patch(
            "litellm.proxy.proxy_server.generate_key_helper_fn",
            mock_generate_key_helper,
        ),
    ):
        # Call the function under test
        ProxyStartupEvent._add_proxy_budget_to_db(litellm_proxy_budget_name)

        # Allow the background async task to complete while the patches (and
        # the budget settings it may read) are still active.
        await asyncio.sleep(0.1)

    # Verify that generate_key_helper_fn was called
    mock_generate_key_helper.assert_called_once()
    call_args = mock_generate_key_helper.call_args

    # Verify critical parameters that prevent key creation
    assert call_args.kwargs["request_type"] == "user"
    assert call_args.kwargs["table_name"] == "user"
    assert call_args.kwargs["user_id"] == litellm_proxy_budget_name
    assert call_args.kwargs["max_budget"] == 100.0
    assert call_args.kwargs["budget_duration"] == "30d"
    assert call_args.kwargs["query_type"] == "update_data"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_add_proxy_budget_to_db_backfills_budget_reset_at():
    """
    Test that _upsert_proxy_budget_with_reset_at_backfill issues a conditional
    update_many with `WHERE budget_reset_at IS NULL`.

    The update backfills the column on rows that pre-existed without a reset
    schedule. Without it, the proxy admin row stays at NULL and
    reset_budget_for_litellm_users never matches it (NULL < now() is unknown
    in SQL), so the global proxy budget never resets.
    """
    import litellm
    from litellm.proxy.proxy_server import ProxyStartupEvent

    litellm_proxy_budget_name = "litellm-proxy-budget"

    mock_prisma = MagicMock()
    mock_prisma.db.litellm_usertable.update_many = AsyncMock(return_value={"count": 1})

    mock_generate_key_helper = AsyncMock(
        return_value={
            "user_id": litellm_proxy_budget_name,
            "max_budget": 100.0,
            "budget_duration": "30d",
            "spend": 0,
            "models": [],
        }
    )

    # patch.object (rather than plain assignment) restores the module-level
    # budget settings afterwards so they cannot leak into other tests.
    with (
        patch.object(litellm, "budget_duration", "30d"),
        patch.object(litellm, "max_budget", 100.0),
        patch(
            "litellm.proxy.proxy_server.generate_key_helper_fn",
            mock_generate_key_helper,
        ),
        patch("litellm.proxy.proxy_server.prisma_client", mock_prisma),
    ):
        await ProxyStartupEvent._upsert_proxy_budget_with_reset_at_backfill(
            litellm_proxy_budget_name
        )

    # Upsert ran with the configured budget
    mock_generate_key_helper.assert_called_once()

    # Backfill update_many ran with the conditional WHERE
    mock_prisma.db.litellm_usertable.update_many.assert_called_once()
    backfill_call = mock_prisma.db.litellm_usertable.update_many.call_args
    assert backfill_call.kwargs["where"]["user_id"] == litellm_proxy_budget_name
    assert backfill_call.kwargs["where"]["budget_reset_at"] is None

    # The backfilled value must be a real future datetime; anything else and
    # reset_budget_for_litellm_users would still skip the row.
    backfilled_reset_at = backfill_call.kwargs["data"]["budget_reset_at"]
    assert isinstance(backfilled_reset_at, datetime)
    assert backfilled_reset_at > datetime.now(timezone.utc)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_custom_ui_sso_sign_in_handler_config_loading():
    """
    A ``custom_ui_sso_sign_in_handler`` entry in general_settings must be
    resolved through get_instance_fn, and the result must be stored in the
    proxy_server module global ``user_custom_ui_sso_sign_in_handler``.
    """
    import tempfile

    import yaml

    from litellm.proxy.proxy_server import ProxyConfig

    handler_path = "custom_hooks.custom_ui_sso_hook.custom_ui_sso_sign_in_handler"
    config_contents = {
        "general_settings": {"custom_ui_sso_sign_in_handler": handler_path},
        "model_list": [],
        "router_settings": {},
        "litellm_settings": {},
    }

    # delete=False so load_config can reopen the file by name; it is removed
    # explicitly in the finally clause below.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as tmp:
        yaml.dump(config_contents, tmp)
        config_file_path = tmp.name

    resolved_handler = MagicMock()

    try:
        with patch(
            "litellm.proxy.proxy_server.get_instance_fn",
            return_value=resolved_handler,
        ) as mock_get_instance:
            # load_config requires a router; a bare MagicMock is enough here.
            await ProxyConfig().load_config(
                router=MagicMock(), config_file_path=config_file_path
            )

            mock_get_instance.assert_called_with(
                value=handler_path,
                config_file_path=config_file_path,
            )

            # Import after loading so the freshly assigned global is read.
            from litellm.proxy.proxy_server import user_custom_ui_sso_sign_in_handler

            assert user_custom_ui_sso_sign_in_handler == resolved_handler
    finally:
        os.unlink(config_file_path)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_load_environment_variables_direct_and_os_environ():
    """
    _load_environment_variables must export literal values as strings and
    resolve ``os.environ/``-prefixed values through get_secret_str.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    resolved_secret = "resolved_secret_value"
    config = {
        "environment_variables": {
            "DIRECT_VAR": "direct_value",
            "NUMERIC_VAR": 12345,
            "BOOL_VAR": True,
            "SECRET_VAR": "os.environ/ACTUAL_SECRET_VAR",
        }
    }

    # patch.dict(clear=False) keeps the current environment but restores it on
    # exit, so whatever the method exports does not outlive this test.
    with patch(
        "litellm.proxy.proxy_server.get_secret_str", return_value=resolved_secret
    ) as mock_get_secret, patch.dict(os.environ, {}, clear=False):
        proxy_config._load_environment_variables(config)

        # Non-string values are stringified; the os.environ/ entry is resolved.
        expected_env = {
            "DIRECT_VAR": "direct_value",
            "NUMERIC_VAR": "12345",
            "BOOL_VAR": "True",
            "SECRET_VAR": resolved_secret,
        }
        for name, value in expected_env.items():
            assert os.environ[name] == value

        mock_get_secret.assert_called_once_with(
            secret_name="os.environ/ACTUAL_SECRET_VAR"
        )
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_load_environment_variables_litellm_license_and_edge_cases():
    """
    Edge cases of _load_environment_variables:

    * LITELLM_LICENSE is exported, copied onto the proxy's license checker,
      and is_premium() is re-evaluated once.
    * A missing or None ``environment_variables`` section is a no-op that
      returns None.
    * An ``os.environ/`` reference that resolves to None is not exported.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    # --- LITELLM_LICENSE handling ---
    license_checker = MagicMock()
    license_checker.is_premium.return_value = True
    license_config = {
        "environment_variables": {
            "LITELLM_LICENSE": "test_license_key",
            "OTHER_VAR": "other_value",
        }
    }

    with patch(
        "litellm.proxy.proxy_server._license_check", license_checker
    ), patch.dict(os.environ, {}, clear=False):
        proxy_config._load_environment_variables(license_config)

        assert os.environ["LITELLM_LICENSE"] == "test_license_key"
        assert license_checker.license_str == "test_license_key"
        license_checker.is_premium.assert_called_once()

    # --- No-op configs: section absent, or explicitly None ---
    for empty_config in ({}, {"environment_variables": None}):
        assert proxy_config._load_environment_variables(empty_config) is None

    # --- os.environ/ reference that cannot be resolved ---
    unresolvable_config = {
        "environment_variables": {"FAILED_SECRET": "os.environ/NONEXISTENT_SECRET"}
    }

    with patch(
        "litellm.proxy.proxy_server.get_secret_str", return_value=None
    ), patch.dict(os.environ, {}, clear=False):
        proxy_config._load_environment_variables(unresolvable_config)

        assert "FAILED_SECRET" not in os.environ
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_load_environment_variables_blocks_dangerous_keys():
    """
    Test that _load_environment_variables rejects dangerous env var keys such
    as PATH, LD_PRELOAD and PYTHONPATH, while still exporting ordinary keys
    from the same config.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    blocked_keys = ("PATH", "LD_PRELOAD", "PYTHONPATH")
    test_config = {
        "environment_variables": {
            "PATH": "/tmp/evil",
            "LD_PRELOAD": "/tmp/evil.so",
            "PYTHONPATH": "/tmp/evil",
            "SAFE_CUSTOM_VAR": "safe_value",
        }
    }

    with patch.dict(os.environ, {}, clear=False):
        # Snapshot the blocked keys (None when unset) so we can assert they are
        # left exactly as they were, not merely "not the attacker value".
        originals = {key: os.environ.get(key) for key in blocked_keys}

        proxy_config._load_environment_variables(test_config)

        for key in blocked_keys:
            assert os.environ.get(key) == originals[key], (
                f"{key} should not be modified by config; "
                f"got {os.environ.get(key)!r}"
            )

        # Safe keys should still be set
        assert os.environ["SAFE_CUSTOM_VAR"] == "safe_value"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_load_environment_variables_allows_proxy_keys():
    """
    HTTP_PROXY / HTTPS_PROXY must be exported: corporate deployments commonly
    rely on them to route outbound API calls.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    corp_proxy = "http://corp-proxy:8080"
    proxy_keys = ("HTTP_PROXY", "HTTPS_PROXY")
    config = {"environment_variables": {key: corp_proxy for key in proxy_keys}}

    with patch.dict(os.environ, {}, clear=False):
        proxy_config._load_environment_variables(config)

        for key in proxy_keys:
            assert os.environ[key] == corp_proxy
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_load_environment_variables_blocks_no_proxy():
    """
    NO_PROXY / no_proxy must not be settable from config, since they could be
    used to bypass proxy-based network monitoring.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    bypass_target = "internal-service"
    bypass_keys = ("NO_PROXY", "no_proxy")
    config = {"environment_variables": {key: bypass_target for key in bypass_keys}}

    with patch.dict(os.environ, {}, clear=False):
        proxy_config._load_environment_variables(config)

        for key in bypass_keys:
            assert os.environ.get(key) != bypass_target
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_write_config_to_file(monkeypatch):
    """
    With store_model_in_db enabled, save_config must persist to the DB via
    prisma_client.insert_data (with model_list stripped) and must not write
    the config file to disk.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    # Non-None DB client so the DB branch is taken.
    mock_prisma_client = AsyncMock()
    mock_prisma_client.insert_data = AsyncMock()

    module_overrides = (
        ("store_model_in_db", True),
        ("prisma_client", mock_prisma_client),
        ("general_settings", {"store_model_in_db": True}),
        ("user_config_file_path", "/tmp/test_config.yaml"),
    )
    for attr, value in module_overrides:
        monkeypatch.setattr(f"litellm.proxy.proxy_server.{attr}", value)

    proxy_config = ProxyConfig()

    # Stub out file writing so any attempt to touch disk is recorded.
    fake_open = mock_open()

    with patch("builtins.open", fake_open), patch("yaml.dump") as fake_yaml_dump:
        await proxy_config.save_config(
            new_config={"key": "value", "model_list": ["model1", "model2"]}
        )

        # DB mode: nothing is written to disk.
        fake_open.assert_not_called()
        fake_yaml_dump.assert_not_called()

    mock_prisma_client.insert_data.assert_called_once()
    insert_kwargs = mock_prisma_client.insert_data.call_args.kwargs
    # model_list is popped before the config is stored.
    assert insert_kwargs["data"] == {"key": "value"}
    assert insert_kwargs["table_name"] == "config"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_write_config_to_file_when_store_model_in_db_false(monkeypatch):
    """
    With store_model_in_db disabled and no DB client, save_config must dump the
    config as block-style YAML (default_flow_style=False) to
    user_config_file_path.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    config_path = "/tmp/test_config.yaml"
    module_overrides = (
        ("store_model_in_db", False),
        ("prisma_client", None),  # forces the file-writing branch
        ("general_settings", {"store_model_in_db": False}),
        ("user_config_file_path", config_path),
    )
    for attr, value in module_overrides:
        monkeypatch.setattr(f"litellm.proxy.proxy_server.{attr}", value)

    proxy_config = ProxyConfig()

    fake_open = mock_open()
    new_config = {"key": "value", "other_key": "other_value"}

    with patch("builtins.open", fake_open), patch("yaml.dump") as fake_yaml_dump:
        await proxy_config.save_config(new_config=new_config)

        # File mode: the config path is opened for writing exactly once...
        fake_open.assert_called_once_with(f"{config_path}", "w")

        # ...and the config is dumped into that open handle.
        fake_yaml_dump.assert_called_once_with(
            new_config,
            fake_open.return_value.__enter__.return_value,
            default_flow_style=False,
        )
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_async_data_generator_midstream_error():
    """
    Test async_data_generator handles midstream error from async_post_call_streaming_hook

    Specifically testing the case where Azure Content Safety Guardrail returns an error.
    The per-chunk hook replaces the third chunk with an SSE error string; the generator
    must still emit that string, and a ``data: [DONE]`` line must appear in the output,
    without routing through ``post_call_failure_hook``.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    guardrail_error = 'data: {"error": {"error": "Azure Content Safety Guardrail: Hate crossed severity 2, Got severity: 2"}}'

    # Three ordinary delta chunks; the last one is the one the fake guardrail rejects.
    stream_chunks = [
        {"choices": [{"delta": {"content": text}}]}
        for text in ("Hello", " world", " this")
    ]
    blocked_chunk = stream_chunks[-1]

    async def fake_iterator_hook(*args, **kwargs):
        # Replays the canned chunks as the post-call streaming iterator.
        for piece in stream_chunks:
            yield piece

    def fake_chunk_hook(*args, **kwargs):
        # Passes chunks through untouched, except the blocked one, which is
        # swapped for an SSE-formatted error string (guardrail simulation).
        incoming = kwargs.get("response")
        return guardrail_error if incoming == blocked_chunk else incoming

    logging_stub = MagicMock(spec=ProxyLogging)
    logging_stub.async_post_call_streaming_iterator_hook = fake_iterator_hook
    logging_stub.async_post_call_streaming_hook = AsyncMock(side_effect=fake_chunk_hook)
    logging_stub.post_call_failure_hook = AsyncMock()

    api_key_auth = MagicMock(spec=UserAPIKeyAuth)
    request_body = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "test"}],
    }

    collected = []
    with patch("litellm.proxy.proxy_server.proxy_logging_obj", logging_stub):
        upstream_response = MagicMock()
        try:
            async for sse_line in async_data_generator(
                upstream_response, api_key_auth, request_body
            ):
                collected.append(sse_line)
        except Exception:
            # A raised error is tolerated; the assertions below judge whatever
            # output was produced before it.
            pass

    assert (
        len(collected) >= 3
    ), f"Expected at least 3 chunks, got {len(collected)}: {collected}"

    # The chunks before the guardrail trip must be ordinary SSE data lines.
    for position, ordinal in enumerate(("First", "Second")):
        assert collected[position].startswith(
            "data: "
        ), f"{ordinal} chunk should start with 'data: ', got: {collected[position]}"

    error_found = any(
        "Azure Content Safety Guardrail: Hate crossed severity 2" in line
        for line in collected
    )
    done_found = any("data: [DONE]" in line for line in collected)

    assert (
        error_found
    ), f"Error message should be found in yielded data. Got: {collected}"
    assert done_found, f"[DONE] message should be found at the end. Got: {collected}"

    # Every chunk went through the per-chunk hook.
    assert logging_stub.async_post_call_streaming_hook.call_count == len(stream_chunks)

    # The guardrail error is a normal yield, not an exception path.
    logging_stub.post_call_failure_hook.assert_not_called()
|
||
|
||
|
||
def _has_nested_none_values(obj, path="root"):
|
||
"""
|
||
Recursively check if an object contains nested None values.
|
||
|
||
Args:
|
||
obj: The object to check
|
||
path: Current path in the object tree (for debugging)
|
||
|
||
Returns:
|
||
List of paths where None values were found
|
||
"""
|
||
none_paths = []
|
||
|
||
if obj is None:
|
||
none_paths.append(path)
|
||
elif isinstance(obj, dict):
|
||
for key, value in obj.items():
|
||
none_paths.extend(_has_nested_none_values(value, f"{path}.{key}"))
|
||
elif isinstance(obj, (list, tuple)):
|
||
for i, item in enumerate(obj):
|
||
none_paths.extend(_has_nested_none_values(item, f"{path}[{i}]"))
|
||
elif hasattr(obj, "__dict__"):
|
||
# Handle object attributes
|
||
for key, value in obj.__dict__.items():
|
||
if not key.startswith("_"): # Skip private attributes
|
||
none_paths.extend(_has_nested_none_values(value, f"{path}.{key}"))
|
||
|
||
return none_paths
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_chat_completion_result_no_nested_none_values():
    """
    Test that chat_completion result doesn't have nested None values when using exclude_none=True

    A ``ModelResponse`` whose message and choice carry explicit ``None`` fields
    is returned from a mocked ``ProxyBaseLLMRequestProcessing``. The dict that
    ``chat_completion`` returns must contain no ``None`` anywhere and must omit
    the ``None``-valued message fields entirely.
    """
    # AsyncMock/MagicMock/patch, litellm and UserAPIKeyAuth are imported at module
    # level; only names not available there are imported here.
    from fastapi import Request, Response

    from litellm.proxy.proxy_server import chat_completion

    # Create a mock ModelResponse with nested None values
    mock_model_response = litellm.ModelResponse()
    mock_model_response.id = "test-id"
    mock_model_response.model = "gpt-3.5-turbo"
    mock_model_response.object = "chat.completion"
    mock_model_response.created = 1234567890

    # Create message with None values that should be excluded
    mock_message = litellm.Message(
        content="Hello, world!",
        role="assistant",
        function_call=None,  # This should be excluded
        tool_calls=None,  # This should be excluded
        audio=None,  # This should be excluded
        reasoning_content=None,  # This should be excluded
        thinking_blocks=None,  # This should be excluded
        annotations=None,  # This should be excluded
    )

    # Create choice with potential None values
    mock_choice = litellm.Choices(
        finish_reason="stop",
        index=0,
        message=mock_message,
        logprobs=None,  # This should be excluded when exclude_none=True
    )

    mock_model_response.choices = [mock_choice]
    setattr(
        mock_model_response,
        "usage",
        litellm.Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
    )

    # Sanity check: a plain model_dump() still contains None values, so the
    # assertion on the endpoint result below is meaningful.
    raw_dict = mock_model_response.model_dump()
    none_paths_before = _has_nested_none_values(raw_dict)
    assert (
        len(none_paths_before) > 0
    ), "Mock should have None values before exclude_none=True"

    # Mock the request processing to return our mock response
    mock_base_processor = MagicMock()
    mock_base_processor.base_process_llm_request = AsyncMock(
        return_value=mock_model_response
    )

    # Mock other dependencies
    mock_request = MagicMock(spec=Request)
    mock_response = MagicMock(spec=Response)
    mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)

    with (
        patch(
            "litellm.proxy.proxy_server._read_request_body",
            return_value={"model": "gpt-3.5-turbo", "messages": []},
        ),
        patch(
            "litellm.proxy.proxy_server.ProxyBaseLLMRequestProcessing",
            return_value=mock_base_processor,
        ),
    ):
        # Call the chat_completion function
        result = await chat_completion(
            request=mock_request,
            fastapi_response=mock_response,
            user_api_key_dict=mock_user_api_key_dict,
        )

    # chat_completion is expected to serialize the pydantic response to a plain dict
    assert isinstance(result, dict), f"Expected dict result, got {type(result)}"

    # Check that there are no nested None values in the result
    none_paths_after = _has_nested_none_values(result)
    assert (
        len(none_paths_after) == 0
    ), f"Result should not contain nested None values. Found None at: {none_paths_after}"

    # Verify essential fields are present
    assert "id" in result
    assert "model" in result
    assert "object" in result
    assert "created" in result
    assert "choices" in result
    assert "usage" in result

    # Verify that the choices contain the expected message content
    assert len(result["choices"]) == 1
    assert result["choices"][0]["message"]["content"] == "Hello, world!"
    assert result["choices"][0]["message"]["role"] == "assistant"

    # Verify that None fields were excluded (should not be present in the dict)
    message = result["choices"][0]["message"]
    excluded_fields = [
        "function_call",
        "tool_calls",
        "audio",
        "reasoning_content",
        "thinking_blocks",
        "annotations",
    ]
    for field in excluded_fields:
        assert (
            field not in message
        ), f"Field '{field}' should be excluded when it's None"
|
||
|
||
|
||
# ============================================================================
|
||
# Price Data Reload Tests
|
||
# ============================================================================
|
||
|
||
|
||
class TestPriceDataReloadAPI:
    """Test cases for price data reload API endpoints"""

    @pytest.fixture
    def client_with_auth(self):
        """Create a test client authenticated as a proxy admin.

        Initializes the proxy from ``test_configs/test_config_no_auth.yaml`` and
        overrides the ``user_api_key_auth`` dependency with a PROXY_ADMIN mock.
        ``app`` is the module-level proxy app shared by every test, so the
        override is reverted on teardown. This also discards the non-admin
        overrides that some tests in this class install on top of it.
        """
        from litellm.proxy._types import LitellmUserRoles
        from litellm.proxy.proxy_server import cleanup_router_config_variables

        cleanup_router_config_variables()
        filepath = os.path.dirname(os.path.abspath(__file__))
        config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
        asyncio.run(initialize(config=config_fp, debug=True))

        # Mock admin user authentication
        mock_auth = MagicMock()
        mock_auth.user_role = LitellmUserRoles.PROXY_ADMIN

        # Remember the pre-existing override (if any) so teardown can restore it.
        had_override = user_api_key_auth in app.dependency_overrides
        previous_override = app.dependency_overrides.get(user_api_key_auth)
        app.dependency_overrides[user_api_key_auth] = lambda: mock_auth

        yield TestClient(app)

        if had_override:
            app.dependency_overrides[user_api_key_auth] = previous_override
        else:
            app.dependency_overrides.pop(user_api_key_auth, None)

    def test_reload_model_cost_map_admin_access(self, client_with_auth):
        """Test that admin users can access the reload endpoint"""
        # Save the original model_cost so the endpoint's direct assignment
        # (litellm.model_cost = new_model_cost_map) does not contaminate
        # subsequent tests running in the same worker process.
        original_model_cost = litellm.model_cost.copy()
        try:
            with patch(
                "litellm.litellm_core_utils.get_model_cost_map.get_model_cost_map"
            ) as mock_get_map:
                mock_get_map.return_value = {
                    "gpt-3.5-turbo": {"input_cost_per_token": 0.001}
                }
                # Mock the database connection
                with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
                    mock_prisma.db.litellm_config.find_unique = AsyncMock(
                        return_value=None
                    )
                    mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)

                    response = client_with_auth.post("/reload/model_cost_map")

                    assert response.status_code == 200
                    data = response.json()
                    assert data["status"] == "success"
                    assert "message" in data
                    assert "timestamp" in data
                    assert "models_count" in data
                    # The new implementation immediately reloads and returns the count
                    assert (
                        "Price data reloaded successfully! 1 models updated."
                        in data["message"]
                    )
                    assert data["models_count"] == 1
        finally:
            # Restore the full model cost map so subsequent tests are not affected
            litellm.model_cost = original_model_cost
            _invalidate_model_cost_lowercase_map()

    def test_reload_model_cost_map_non_admin_access(self, client_with_auth):
        """Test that non-admin users cannot access the reload endpoint"""
        # Mock non-admin user (reverted by the fixture's teardown)
        mock_auth = MagicMock()
        mock_auth.user_role = "user"  # Non-admin role
        app.dependency_overrides[user_api_key_auth] = lambda: mock_auth

        response = client_with_auth.post("/reload/model_cost_map")

        assert response.status_code == 403
        data = response.json()
        assert "Access denied" in data["detail"]
        assert "Admin role required" in data["detail"]

    def test_get_model_cost_map_public_access(self, client_no_auth):
        """Test that the model cost map endpoint is publicly accessible"""
        with patch(
            "litellm.model_cost", {"gpt-3.5-turbo": {"input_cost_per_token": 0.001}}
        ):
            response = client_no_auth.get("/public/litellm_model_cost_map")

            assert response.status_code == 200
            data = response.json()
            assert "gpt-3.5-turbo" in data

    def test_reload_model_cost_map_error_handling(self, client_with_auth):
        """Test error handling in the reload endpoint"""
        with patch(
            "litellm.litellm_core_utils.get_model_cost_map.get_model_cost_map"
        ) as mock_get_map:
            mock_get_map.side_effect = Exception("Network error")

            # Mock the database connection
            with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
                mock_prisma.db.litellm_config.find_unique = AsyncMock(return_value=None)
                mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)

                response = client_with_auth.post("/reload/model_cost_map")

                assert (
                    response.status_code == 500
                )  # The new implementation immediately reloads and fails on error
                data = response.json()
                assert "Failed to reload model cost map" in data["detail"]

    def test_schedule_model_cost_map_reload_admin_access(self, client_with_auth):
        """Test that admin users can schedule periodic reload"""
        with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
            # Mock database upsert
            mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)

            response = client_with_auth.post("/schedule/model_cost_map_reload?hours=6")

            assert response.status_code == 200
            data = response.json()
            assert data["status"] == "success"
            assert data["interval_hours"] == 6
            assert "message" in data
            assert "timestamp" in data

    def test_schedule_model_cost_map_reload_non_admin_access(self, client_with_auth):
        """Test that non-admin users cannot schedule periodic reload"""
        # Mock non-admin user (reverted by the fixture's teardown)
        mock_auth = MagicMock()
        mock_auth.user_role = "user"  # Non-admin role
        app.dependency_overrides[user_api_key_auth] = lambda: mock_auth

        response = client_with_auth.post("/schedule/model_cost_map_reload?hours=6")

        assert response.status_code == 403
        data = response.json()
        assert "Access denied" in data["detail"]
        assert "Admin role required" in data["detail"]

    def test_schedule_model_cost_map_reload_invalid_hours(self, client_with_auth):
        """Test that invalid hours parameter is rejected"""
        response = client_with_auth.post("/schedule/model_cost_map_reload?hours=0")

        assert response.status_code == 400
        data = response.json()
        assert "Hours must be greater than 0" in data["detail"]

    def test_cancel_model_cost_map_reload_admin_access(self, client_with_auth):
        """Test that admin users can cancel periodic reload"""
        with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
            # Mock database delete
            mock_prisma.db.litellm_config.delete = AsyncMock(return_value=None)

            response = client_with_auth.delete("/schedule/model_cost_map_reload")

            assert response.status_code == 200
            data = response.json()
            assert data["status"] == "success"
            assert "message" in data
            assert "timestamp" in data

    def test_cancel_model_cost_map_reload_non_admin_access(self, client_with_auth):
        """Test that non-admin users cannot cancel periodic reload"""
        # Mock non-admin user (reverted by the fixture's teardown)
        mock_auth = MagicMock()
        mock_auth.user_role = "user"  # Non-admin role
        app.dependency_overrides[user_api_key_auth] = lambda: mock_auth

        response = client_with_auth.delete("/schedule/model_cost_map_reload")

        assert response.status_code == 403
        data = response.json()
        assert "Access denied" in data["detail"]
        assert "Admin role required" in data["detail"]

    def test_get_model_cost_map_reload_status_admin_access(self, client_with_auth):
        """Test that admin users can get reload status"""
        with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
            # Mock database config record
            mock_config = MagicMock()
            mock_config.param_value = {"interval_hours": 6, "force_reload": False}
            mock_prisma.db.litellm_config.find_unique = AsyncMock(
                return_value=mock_config
            )

            # Mock the last reload time and current time
            with patch(
                "litellm.proxy.proxy_server.last_model_cost_map_reload",
                "2024-01-01T06:00:00",
            ):
                with patch("litellm.proxy.proxy_server.datetime") as mock_datetime:
                    # Mock current time to be 1 hour after last reload
                    mock_datetime.utcnow.return_value = datetime(2024, 1, 1, 7, 0, 0)
                    mock_datetime.fromisoformat = datetime.fromisoformat

                    response = client_with_auth.get(
                        "/schedule/model_cost_map_reload/status"
                    )

                    assert response.status_code == 200
                    data = response.json()
                    assert data["scheduled"] is True
                    assert data["interval_hours"] == 6
                    assert data["last_run"] == "2024-01-01T06:00:00"
                    assert data["next_run"] == "2024-01-01T12:00:00"

    def test_get_model_cost_map_reload_status_non_admin_access(self, client_with_auth):
        """Test that non-admin users cannot get reload status"""
        # Mock non-admin user (reverted by the fixture's teardown)
        mock_auth = MagicMock()
        mock_auth.user_role = "user"  # Non-admin role
        app.dependency_overrides[user_api_key_auth] = lambda: mock_auth

        response = client_with_auth.get("/schedule/model_cost_map_reload/status")

        assert response.status_code == 403
        data = response.json()
        assert "Access denied" in data["detail"]
        assert "Admin role required" in data["detail"]

    def test_get_model_cost_map_reload_status_no_config(self, client_with_auth):
        """Test that status returns not scheduled when no config exists"""
        with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
            mock_prisma.db.litellm_config.find_unique = AsyncMock(return_value=None)

            response = client_with_auth.get("/schedule/model_cost_map_reload/status")

            assert response.status_code == 200
            data = response.json()
            assert data["scheduled"] is False
            assert data["interval_hours"] is None
            assert data["last_run"] is None
            assert data["next_run"] is None

    def test_get_model_cost_map_reload_status_no_interval(self, client_with_auth):
        """Test that status returns not scheduled when no interval is configured"""
        with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
            # Mock config with no interval
            mock_config = MagicMock()
            mock_config.param_value = {"interval_hours": None, "force_reload": False}
            mock_prisma.db.litellm_config.find_unique = AsyncMock(
                return_value=mock_config
            )

            response = client_with_auth.get("/schedule/model_cost_map_reload/status")

            assert response.status_code == 200
            data = response.json()
            assert data["scheduled"] is False
            assert data["interval_hours"] is None
            assert data["last_run"] is None
            assert data["next_run"] is None
|
||
|
||
|
||
class TestPriceDataReloadIntegration:
|
||
"""Integration tests for the complete price data reload feature"""
|
||
|
||
@pytest.fixture(autouse=True)
def _flush_litellm_config_cache(self):
    """Empty ``litellm_config_cache`` before and after every test in this class.

    Tests here stub stored config on a mocked prisma client; flushing the cache
    at both ends keeps a value cached by one test from being seen by another.
    """
    from litellm.proxy.utils import litellm_config_cache as config_cache

    config_cache.flush_cache()
    yield
    config_cache.flush_cache()
|
||
|
||
@pytest.fixture
def client_with_auth(self):
    """Create a test client authenticated as a proxy admin.

    Initializes the proxy from ``test_configs/test_config_no_auth.yaml`` and
    overrides the ``user_api_key_auth`` dependency with a PROXY_ADMIN mock.
    ``app`` is the module-level proxy app shared by every test, so the override
    is reverted on teardown instead of leaking into unrelated tests.
    """
    from litellm.proxy._types import LitellmUserRoles
    from litellm.proxy.proxy_server import cleanup_router_config_variables

    cleanup_router_config_variables()
    filepath = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
    asyncio.run(initialize(config=config_fp, debug=True))

    # Mock admin user authentication
    mock_auth = MagicMock()
    mock_auth.user_role = LitellmUserRoles.PROXY_ADMIN

    # Remember the pre-existing override (if any) so teardown can restore it.
    had_override = user_api_key_auth in app.dependency_overrides
    previous_override = app.dependency_overrides.get(user_api_key_auth)
    app.dependency_overrides[user_api_key_auth] = lambda: mock_auth

    yield TestClient(app)

    if had_override:
        app.dependency_overrides[user_api_key_auth] = previous_override
    else:
        app.dependency_overrides.pop(user_api_key_auth, None)
|
||
|
||
def test_complete_reload_flow(self, client_with_auth):
    """Test the complete reload flow from API to model cost update.

    Triggers ``/reload/model_cost_map`` against a stubbed cost-map fetcher and
    stubbed prisma client, then reads ``/public/litellm_model_cost_map``. The
    global ``litellm.model_cost`` is restored afterwards.
    """
    stub_cost_map = {
        "gpt-3.5-turbo": {
            "input_cost_per_token": 0.001,
            "output_cost_per_token": 0.002,
        },
        "gpt-4": {"input_cost_per_token": 0.03, "output_cost_per_token": 0.06},
    }

    saved_model_cost = litellm.model_cost.copy()
    try:
        with patch(
            "litellm.litellm_core_utils.get_model_cost_map.get_model_cost_map",
            return_value=stub_cost_map,
        ), patch("litellm.proxy.proxy_server.prisma_client") as prisma_stub:
            config_table = prisma_stub.db.litellm_config
            config_table.find_unique = AsyncMock(return_value=None)
            config_table.upsert = AsyncMock(return_value=None)

            reload_resp = client_with_auth.post("/reload/model_cost_map")
            assert reload_resp.status_code == 200

            public_resp = client_with_auth.get("/public/litellm_model_cost_map")
            assert public_resp.status_code == 200
    finally:
        litellm.model_cost = saved_model_cost
        _invalidate_model_cost_lowercase_map()
|
||
|
||
def test_distributed_reload_check_function(self):
    """Test the _check_and_reload_model_cost_map function.

    Runs three scenarios against one mocked prisma client: no stored config,
    a 6h schedule where only 1h has elapsed, and a config with ``force_reload``
    set. Only the last scenario is asserted on: the write-back must clear
    ``force_reload`` while keeping ``interval_hours``.
    """
    from litellm.proxy.proxy_server import ProxyConfig
    from litellm.proxy.utils import litellm_config_cache

    proxy_config = ProxyConfig()
    prisma_stub = MagicMock()
    config_table = prisma_stub.db.litellm_config

    def serve(record):
        # _check_and_reload_model_cost_map routes through get_config_param,
        # which calls prisma.get_generic_data on a cache miss; find_unique is
        # stubbed with the same record as well.
        config_table.find_unique = AsyncMock(return_value=record)
        prisma_stub.get_generic_data = AsyncMock(return_value=record)

    def run_check():
        asyncio.run(proxy_config._check_and_reload_model_cost_map(prisma_stub))

    # Case 1: nothing stored -> returns early without reloading
    serve(None)
    run_check()

    # Case 2: 6h interval but only 1h since the last reload -> no reload
    litellm_config_cache.flush_cache()
    stored = MagicMock()
    stored.param_value = {"interval_hours": 6, "force_reload": False}
    serve(stored)
    with patch(
        "litellm.proxy.proxy_server.last_model_cost_map_reload",
        "2024-01-01T06:00:00",
    ), patch("litellm.proxy.proxy_server.datetime") as frozen_clock:
        frozen_clock.utcnow.return_value = datetime(2024, 1, 1, 7, 0, 0)
        run_check()

    # Case 3: force_reload set -> reload happens and the flag is reset
    litellm_config_cache.flush_cache()
    stored.param_value = {"interval_hours": 6, "force_reload": True}
    serve(stored)
    config_table.upsert = AsyncMock(return_value=None)

    saved_model_cost = litellm.model_cost.copy()
    try:
        with patch(
            "litellm.litellm_core_utils.get_model_cost_map.get_model_cost_map",
            return_value={"gpt-3.5-turbo": {"input_cost_per_token": 0.001}},
        ):
            run_check()

            config_table.upsert.assert_called()
            # param_value is written back as a JSON string
            written = json.loads(
                config_table.upsert.call_args[1]["data"]["update"]["param_value"]
            )
            assert written["force_reload"] == False
            assert written.get("interval_hours") == 6
    finally:
        litellm.model_cost = saved_model_cost
        _invalidate_model_cost_lowercase_map()
|
||
|
||
def test_distributed_reload_preserves_interval_hours(self):
    """Test that _check_and_reload_model_cost_map preserves interval_hours after reload.

    Regression test: the update branch of the upsert was previously dropping
    interval_hours, causing scheduled reloads to self-destruct after first execution.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()
    prisma_stub = MagicMock()
    config_table = prisma_stub.db.litellm_config

    # force_reload=True guarantees the reload branch runs; interval_hours=24 must survive it.
    scheduled_record = MagicMock()
    scheduled_record.param_value = {"interval_hours": 24, "force_reload": True}
    config_table.find_unique = AsyncMock(return_value=scheduled_record)
    # The reload check reads stored config through get_generic_data.
    prisma_stub.get_generic_data = AsyncMock(return_value=scheduled_record)
    config_table.upsert = AsyncMock(return_value=None)

    saved_model_cost = litellm.model_cost.copy()
    try:
        with patch(
            "litellm.litellm_core_utils.get_model_cost_map.get_model_cost_map",
            return_value={"gpt-4": {"input_cost_per_token": 0.001}},
        ):
            asyncio.run(proxy_config._check_and_reload_model_cost_map(prisma_stub))

            config_table.upsert.assert_called()
            update_kwargs = config_table.upsert.call_args[1]
            written = json.loads(update_kwargs["data"]["update"]["param_value"])
            assert written["force_reload"] == False
            assert written["interval_hours"] == 24, (
                "interval_hours must be preserved in the update branch; "
                "dropping it causes the schedule to self-destruct"
            )
    finally:
        litellm.model_cost = saved_model_cost
        _invalidate_model_cost_lowercase_map()
|
||
|
||
def test_manual_reload_preserves_interval_hours(self):
    """Test that manual reload via /reload/model_cost_map preserves existing interval_hours.

    Regression test: the manual reload endpoint was overwriting param_value with
    only force_reload=True, dropping any existing interval_hours schedule.

    The admin dependency override installed here is removed (or the previous
    override restored) in ``finally``. ``app`` is the module-level proxy app, so
    otherwise the override would leak into later tests.
    """
    from litellm.proxy._types import LitellmUserRoles
    from litellm.proxy.proxy_server import cleanup_router_config_variables

    cleanup_router_config_variables()
    filepath = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
    asyncio.run(initialize(config=config_fp, debug=True))

    mock_auth = MagicMock()
    mock_auth.user_role = LitellmUserRoles.PROXY_ADMIN
    had_override = user_api_key_auth in app.dependency_overrides
    previous_override = app.dependency_overrides.get(user_api_key_auth)
    app.dependency_overrides[user_api_key_auth] = lambda: mock_auth
    client = TestClient(app)

    original_model_cost = litellm.model_cost.copy()
    try:
        with patch(
            "litellm.litellm_core_utils.get_model_cost_map.get_model_cost_map"
        ) as mock_get_map:
            mock_get_map.return_value = {"gpt-4": {"input_cost_per_token": 0.001}}

            with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
                # Simulate existing config with a schedule
                mock_existing = MagicMock()
                mock_existing.param_value = {
                    "interval_hours": 12,
                    "force_reload": False,
                }
                mock_prisma.db.litellm_config.find_unique = AsyncMock(
                    return_value=mock_existing
                )
                mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)

                response = client.post("/reload/model_cost_map")
                assert response.status_code == 200

                # Verify interval_hours was preserved in the upsert
                mock_prisma.db.litellm_config.upsert.assert_called()
                call_args = mock_prisma.db.litellm_config.upsert.call_args
                param_value_json = call_args[1]["data"]["update"]["param_value"]
                param_value_dict = json.loads(param_value_json)
                assert param_value_dict["force_reload"] == True
                assert param_value_dict["interval_hours"] == 12, (
                    "interval_hours must be preserved when manual reload sets force_reload; "
                    "dropping it destroys any existing schedule"
                )
    finally:
        litellm.model_cost = original_model_cost
        _invalidate_model_cost_lowercase_map()
        if had_override:
            app.dependency_overrides[user_api_key_auth] = previous_override
        else:
            app.dependency_overrides.pop(user_api_key_auth, None)
|
||
|
||
def test_anthropic_beta_headers_reload_preserves_interval_hours(self):
    """Test that _check_and_reload_anthropic_beta_headers preserves interval_hours after reload.

    Regression test: the update branch of the upsert was dropping interval_hours,
    identical to the model cost map bug.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()
    prisma_stub = MagicMock()
    config_table = prisma_stub.db.litellm_config

    # force_reload=True guarantees the reload branch runs; interval_hours=12 must survive it.
    scheduled_record = MagicMock()
    scheduled_record.param_value = {"interval_hours": 12, "force_reload": True}
    config_table.find_unique = AsyncMock(return_value=scheduled_record)
    # The beta-headers check reads stored config through get_generic_data.
    prisma_stub.get_generic_data = AsyncMock(return_value=scheduled_record)
    config_table.upsert = AsyncMock(return_value=None)

    with patch(
        "litellm.anthropic_beta_headers_manager.reload_beta_headers_config",
        return_value={"anthropic": {"beta_header": "test-value"}},
    ):
        asyncio.run(proxy_config._check_and_reload_anthropic_beta_headers(prisma_stub))

    config_table.upsert.assert_called()
    update_kwargs = config_table.upsert.call_args[1]
    written = json.loads(update_kwargs["data"]["update"]["param_value"])
    assert written["force_reload"] == False
    assert written["interval_hours"] == 12, (
        "interval_hours must be preserved in the update branch; "
        "dropping it causes the schedule to self-destruct"
    )
|
||
|
||
def test_anthropic_beta_headers_manual_reload_preserves_interval_hours(self):
    """Test that manual reload via /reload/anthropic_beta_headers preserves existing interval_hours.

    Regression test: the manual reload endpoint was overwriting param_value with
    only force_reload=True, dropping any existing interval_hours schedule.

    The admin dependency override installed here is removed (or the previous
    override restored) in ``finally``. ``app`` is the module-level proxy app, so
    otherwise the override would leak into later tests.
    """
    from litellm.proxy._types import LitellmUserRoles
    from litellm.proxy.proxy_server import cleanup_router_config_variables

    cleanup_router_config_variables()
    filepath = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
    asyncio.run(initialize(config=config_fp, debug=True))

    mock_auth = MagicMock()
    mock_auth.user_role = LitellmUserRoles.PROXY_ADMIN
    had_override = user_api_key_auth in app.dependency_overrides
    previous_override = app.dependency_overrides.get(user_api_key_auth)
    app.dependency_overrides[user_api_key_auth] = lambda: mock_auth
    client = TestClient(app)

    try:
        with patch(
            "litellm.anthropic_beta_headers_manager.reload_beta_headers_config"
        ) as mock_reload:
            mock_reload.return_value = {"anthropic": {"beta_header": "test-value"}}

            with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
                # Simulate existing config with a schedule
                mock_existing = MagicMock()
                mock_existing.param_value = {"interval_hours": 8, "force_reload": False}
                mock_prisma.db.litellm_config.find_unique = AsyncMock(
                    return_value=mock_existing
                )
                mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)

                response = client.post("/reload/anthropic_beta_headers")
                assert response.status_code == 200

                # Verify interval_hours was preserved in the upsert
                mock_prisma.db.litellm_config.upsert.assert_called()
                call_args = mock_prisma.db.litellm_config.upsert.call_args
                param_value_json = call_args[1]["data"]["update"]["param_value"]
                param_value_dict = json.loads(param_value_json)
                assert param_value_dict["force_reload"] == True
                assert param_value_dict["interval_hours"] == 8, (
                    "interval_hours must be preserved when manual reload sets force_reload; "
                    "dropping it destroys any existing schedule"
                )
    finally:
        if had_override:
            app.dependency_overrides[user_api_key_auth] = previous_override
        else:
            app.dependency_overrides.pop(user_api_key_auth, None)
|
||
|
||
def test_config_file_parsing(self):
    """Load an inline proxy config and check that the reload interval and models are parsed."""
    raw_yaml = """
    general_settings:
      master_key: sk-1234
      model_cost_map_reload_interval: 21600

    model_list:
      - model_name: gpt-3.5-turbo
        litellm_params:
          model: gpt-3.5-turbo
      - model_name: gpt-4
        litellm_params:
          model: gpt-4
    """
    parsed = yaml.safe_load(raw_yaml)

    # The reload interval lives under general_settings.
    assert "general_settings" in parsed
    general_settings = parsed["general_settings"]
    assert "model_cost_map_reload_interval" in general_settings
    assert general_settings["model_cost_map_reload_interval"] == 21600

    # Both model entries are present.
    assert "model_list" in parsed
    assert len(parsed["model_list"]) == 2
|
||
|
||
def test_database_config_storage(self):
    """Record a schedule-style upsert on a mocked Prisma client and inspect the recorded call.

    NOTE(review): the upsert is invoked directly on the mock, so this only checks
    how AsyncMock records the call; it does not exercise the schedule endpoint.
    """
    fake_prisma = MagicMock()
    upsert = AsyncMock(return_value=None)
    fake_prisma.db.litellm_config.upsert = upsert

    # Same payload shape the schedule endpoint is expected to send.
    asyncio.run(
        upsert(
            where={"param_name": "model_cost_map_reload_config"},
            data={
                "create": {
                    "param_name": "model_cost_map_reload_config",
                    "param_value": dict(interval_hours=6, force_reload=False),
                },
                "update": {"param_value": dict(interval_hours=6, force_reload=False)},
            },
        )
    )

    upsert.assert_called_once()
    recorded = upsert.call_args.kwargs
    created_value = recorded["data"]["create"]["param_value"]
    assert recorded["where"]["param_name"] == "model_cost_map_reload_config"
    assert created_value["interval_hours"] == 6
    assert created_value["force_reload"] is False
|
||
|
||
def test_manual_reload_force_flag(self):
    """Record a manual-reload-style upsert on a mocked Prisma client and check force_reload.

    NOTE(review): upsert is called on the mock directly, so this verifies only call
    recording. Its simulated update payload carries just force_reload; it does not
    show whether the real endpoint preserves an existing interval_hours.
    """
    fake_prisma = MagicMock()
    upsert = AsyncMock(return_value=None)
    fake_prisma.db.litellm_config.upsert = upsert

    asyncio.run(
        upsert(
            where={"param_name": "model_cost_map_reload_config"},
            data={
                "create": {
                    "param_name": "model_cost_map_reload_config",
                    "param_value": {"interval_hours": None, "force_reload": True},
                },
                "update": {"param_value": {"force_reload": True}},
            },
        )
    )

    upsert.assert_called_once()
    update_payload = upsert.call_args.kwargs["data"]["update"]
    assert update_payload["param_value"]["force_reload"] is True
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_add_router_settings_from_db_config_merge_logic():
    """
    Test how _add_router_settings_from_db_config combines router_settings from
    the config file with the router_settings row stored in the DB.

    Precedence pinned by the assertions below:
    - DB values override config values for the same top-level key.
    - Config-only and DB-only top-level keys are both kept.
    - A nested dict present in both sources is merged key by key: config-only
      keys survive, DB values win on shared keys, and DB-only keys are added.
    """
    from unittest.mock import AsyncMock, MagicMock

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    mock_router = MagicMock()
    mock_router.update_settings = MagicMock()

    # router_settings as they would come from the config file.
    config_data = {
        "router_settings": {
            "routing_strategy": "usage-based-routing",
            "model_group_alias": {"gpt-4": "openai-gpt-4"},
            "enable_pre_call_checks": True,
            "timeout": 30,
            "nested_config": {"setting1": "config_value1", "setting2": "config_value2"},
        }
    }

    # router_settings row as stored in the DB.
    mock_db_config = MagicMock()
    mock_db_config.param_value = {
        "routing_strategy": "least-busy",  # overrides the config value
        "retry_delay": 2,  # DB-only, should be added
        "nested_config": {
            "setting2": "db_value2",  # overrides the config value
            "setting3": "db_value3",  # DB-only, should be added
        },
    }

    mock_prisma_client = MagicMock()
    mock_prisma_client.db.litellm_config.find_first = AsyncMock(
        return_value=mock_db_config
    )

    await proxy_config._add_router_settings_from_db_config(
        config_data=config_data,
        llm_router=mock_router,
        prisma_client=mock_prisma_client,
    )

    # The DB lookup targets the router_settings row.
    mock_prisma_client.db.litellm_config.find_first.assert_called_once_with(
        where={"param_name": "router_settings"}
    )

    # The combined settings are applied to the router once, as keyword arguments.
    mock_router.update_settings.assert_called_once()
    combined_settings = mock_router.update_settings.call_args.kwargs

    # DB values override config values.
    assert combined_settings["routing_strategy"] == "least-busy"

    # Config-only values are preserved.
    assert combined_settings["model_group_alias"] == {"gpt-4": "openai-gpt-4"}
    assert combined_settings["enable_pre_call_checks"] is True
    assert combined_settings["timeout"] == 30

    # DB-only values are added.
    assert combined_settings["retry_delay"] == 2

    # Nested dicts are merged key by key, not replaced wholesale.
    expected_nested = {
        "setting1": "config_value1",
        "setting2": "db_value2",
        "setting3": "db_value3",
    }
    assert combined_settings["nested_config"] == expected_nested
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_add_router_settings_from_db_config_edge_cases():
    """
    Exercise the guard paths and fallbacks of _add_router_settings_from_db_config:
    a missing router, a missing prisma client, a missing or non-dict DB row, and
    a config without router_settings.
    """
    from unittest.mock import AsyncMock, MagicMock

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()
    router = MagicMock()
    router.update_settings = MagicMock()
    apply_settings = proxy_config._add_router_settings_from_db_config

    # 1) llm_router is None.
    # NOTE(review): `router` is not passed to this call, so the assertion below
    # only confirms it was left untouched; it does not inspect the prisma mock.
    await apply_settings(
        config_data={"router_settings": {"test": "value"}},
        llm_router=None,
        prisma_client=MagicMock(),
    )
    router.update_settings.assert_not_called()

    # 2) prisma_client is None -> nothing applied to the router.
    await apply_settings(
        config_data={"router_settings": {"test": "value"}},
        llm_router=router,
        prisma_client=None,
    )
    router.update_settings.assert_not_called()

    # 3) No router_settings row in the DB -> config settings alone are applied.
    prisma_stub = MagicMock()
    prisma_stub.db.litellm_config.find_first = AsyncMock(return_value=None)
    await apply_settings(
        config_data={"router_settings": {"routing_strategy": "usage-based"}},
        llm_router=router,
        prisma_client=prisma_stub,
    )
    router.update_settings.assert_called_once_with(routing_strategy="usage-based")
    router.reset_mock()

    # 4) Config lacks router_settings -> DB settings alone are applied.
    db_row = MagicMock()
    db_row.param_value = {"db_setting": "db_value"}
    prisma_stub.db.litellm_config.find_first = AsyncMock(return_value=db_row)
    await apply_settings(
        config_data={},
        llm_router=router,
        prisma_client=prisma_stub,
    )
    router.update_settings.assert_called_once_with(db_setting="db_value")
    router.reset_mock()

    # 5) Neither source has settings -> update_settings is never called.
    prisma_stub.db.litellm_config.find_first = AsyncMock(return_value=None)
    await apply_settings(
        config_data={}, llm_router=router, prisma_client=prisma_stub
    )
    router.update_settings.assert_not_called()

    # 6) DB param_value is not a dict -> it is ignored and config settings are used.
    non_dict_row = MagicMock()
    non_dict_row.param_value = "not_a_dict"
    prisma_stub.db.litellm_config.find_first = AsyncMock(return_value=non_dict_row)
    await apply_settings(
        config_data={"router_settings": {"config_setting": "config_value"}},
        llm_router=router,
        prisma_client=prisma_stub,
    )
    router.update_settings.assert_called_once_with(config_setting="config_value")
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_add_router_settings_shallow_merge_behavior():
    """
    Pin how a nested router_settings dict is combined when it appears both in
    the config file and in the DB row.

    NOTE: despite this test's name, the asserted result is NOT a shallow
    replacement of the nested dict. Keys only the config-file dict has
    (key1, key3) are kept. The DB value wins on a shared key (key2), and a
    DB-only key (key4) is added. A top-level scalar (top_level) is overridden
    by the DB value.
    """
    from unittest.mock import AsyncMock, MagicMock

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()
    mock_router = MagicMock()
    mock_router.update_settings = MagicMock()

    # Config-file router_settings with a nested dictionary.
    config_data = {
        "router_settings": {
            "nested_setting": {
                "key1": "config_value1",
                "key2": "config_value2",
                "key3": "config_value3",
            },
            "top_level": "config_top",
        }
    }

    # DB row that partially overlaps the nested dictionary.
    mock_db_config = MagicMock()
    mock_db_config.param_value = {
        "nested_setting": {
            "key2": "db_value2",  # overrides an existing key
            "key4": "db_value4",  # adds a new key
            # key1 and key3 are absent here but are expected to survive from config
        },
        "top_level": "db_top",  # overrides the top-level value
    }

    mock_prisma_client = MagicMock()
    mock_prisma_client.db.litellm_config.find_first = AsyncMock(
        return_value=mock_db_config
    )

    await proxy_config._add_router_settings_from_db_config(
        config_data=config_data,
        llm_router=mock_router,
        prisma_client=mock_prisma_client,
    )

    merged_settings = mock_router.update_settings.call_args[1]

    # The nested dict is merged key by key: config-only keys are retained,
    # DB values take precedence on overlap, and DB-only keys are added.
    expected_nested = {
        "key1": "config_value1",
        "key3": "config_value3",
        "key2": "db_value2",
        "key4": "db_value4",
    }

    assert merged_settings["nested_setting"] == expected_nested
    assert merged_settings["top_level"] == "db_top"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_model_info_v1_oci_secrets_not_leaked():
    """
    model_info_v1 must mask OCI credential fields in litellm_params, leave
    non-secret fields (model, oci_region, drop_params) readable, and never
    echo the raw secret values anywhere in its response.
    """
    from unittest.mock import MagicMock, patch

    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import model_info_v1

    caller = MagicMock(spec=UserAPIKeyAuth)
    caller.user_id = "test-user"
    caller.api_key = "test-key"
    caller.team_models = []
    caller.models = ["oci-grok-test"]

    # A deployment whose litellm_params carry OCI credentials.
    deployment = {
        "model_name": "oci-grok-test",
        "litellm_params": {
            "model": "oci/xai.grok-4",
            "oci_key": "ocid1.api_key.oc1..aaaaaaaa7kbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbk",
            "oci_region": "us-phoenix-1",
            "oci_user": "ocid1.user.oc1..aaaaaaaa7kbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbk",
            "oci_fingerprint": "aa:bb:cc:dd:ee:ff:11:22:33:44:55:66:77:88:99:00",
            "oci_tenancy": "ocid1.tenancy.oc1..aaaaaaaa7kbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbk",
            "oci_key_file": "/path/to/oci_api_key.pem",
            "oci_compartment_id": "ocid1.compartment.oc1..aaaaaaaa7kbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbk",
            "drop_params": True,
        },
        "model_info": {"mode": "completion", "id": "test-model-id"},
    }

    router = MagicMock()
    router.get_model_names.return_value = ["oci-grok-test"]
    router.get_model_access_groups.return_value = {}
    router.get_model_list.return_value = [deployment]

    with (
        patch("litellm.proxy.proxy_server.llm_router", router),
        patch("litellm.proxy.proxy_server.llm_model_list", [deployment]),
        patch(
            "litellm.proxy.proxy_server.general_settings",
            {"infer_model_from_keys": False},
        ),
        patch("litellm.proxy.proxy_server.user_model", None),
    ):
        response = await model_info_v1(
            user_api_key_dict=caller, litellm_model_id=None
        )

    assert "data" in response
    assert len(response["data"]) == 1
    returned_params = response["data"][0]["litellm_params"]

    # Credential fields come back masked.
    for field in ("oci_key", "oci_fingerprint", "oci_tenancy", "oci_key_file"):
        assert "****" in returned_params[field], f"{field} should be masked"

    # Non-secret fields come back untouched.
    assert (
        returned_params["model"] == "oci/xai.grok-4"
    ), "model field should not be masked"
    assert (
        returned_params["oci_region"] == "us-phoenix-1"
    ), "oci_region should not be masked"
    assert returned_params["drop_params"] is True, "drop_params should not be masked"

    # The model field in particular must never be masked.
    assert "****" not in returned_params["model"], "model field should never be masked"
    assert returned_params["model"].startswith("oci/"), "model should retain its full value"

    # None of the raw secrets may appear anywhere in the serialized response.
    rendered = str(response)
    leaked_candidates = (
        "ocid1.api_key.oc1..aaaaaaaa7kbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbk",
        "aa:bb:cc:dd:ee:ff:11:22:33:44:55:66:77:88:99:00",
        "ocid1.tenancy.oc1..aaaaaaaa7kbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbk",
        "/path/to/oci_api_key.pem",
    )
    for secret in leaked_candidates:
        assert secret not in rendered
|
||
|
||
|
||
def test_add_callback_from_db_to_in_memory_litellm_callbacks():
    """
    _add_callback_from_db_to_in_memory_litellm_callbacks should register a DB
    callback through the logging_callback_manager method that matches its
    event types, and skip callbacks already present in existing_callbacks.
    """
    from unittest.mock import MagicMock, patch

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()
    manager = MagicMock()
    register = proxy_config._add_callback_from_db_to_in_memory_litellm_callbacks

    with patch("litellm.proxy.proxy_server.litellm") as patched_litellm:
        patched_litellm._known_custom_logger_compatible_callbacks = []
        patched_litellm.logging_callback_manager = manager

        # Success-only -> success callback list.
        register(
            callback="prometheus",
            event_types=["success"],
            existing_callbacks=[],
        )
        manager.add_litellm_success_callback.assert_called_once_with("prometheus")
        manager.reset_mock()

        # Failure-only -> failure callback list.
        register(
            callback="langfuse",
            event_types=["failure"],
            existing_callbacks=[],
        )
        manager.add_litellm_failure_callback.assert_called_once_with("langfuse")
        manager.reset_mock()

        # Success and failure -> combined callback list.
        register(
            callback="s3",
            event_types=["success", "failure"],
            existing_callbacks=[],
        )
        manager.add_litellm_callback.assert_called_once_with("s3")
        manager.reset_mock()

        # Already registered -> nothing is added.
        register(
            callback="prometheus",
            event_types=["success"],
            existing_callbacks=["prometheus"],
        )
        manager.add_litellm_success_callback.assert_not_called()
|
||
|
||
|
||
def test_should_load_db_object_with_supported_db_objects():
    """
    _should_load_db_object honours general_settings["supported_db_objects"]:
    when it is absent, or not a list, every checked object type is loaded;
    when it is a list (including an empty one), only listed types are loaded.
    """
    from unittest.mock import patch

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    all_object_types = [
        "models",
        "mcp",
        "guardrails",
        "vector_stores",
        "pass_through_endpoints",
        "prompts",
        "model_cost_map",
    ]

    # Each entry: (general_settings to install, {object_type: expected result}).
    scenarios = [
        # Not configured -> everything loads.
        (
            {},
            {"models": True, "mcp": True, "guardrails": True, "vector_stores": True},
        ),
        # Only MCP listed.
        (
            {"supported_db_objects": ["mcp"]},
            {
                "models": False,
                "mcp": True,
                "guardrails": False,
                "vector_stores": False,
                "prompts": False,
            },
        ),
        # Several types listed.
        (
            {"supported_db_objects": ["mcp", "guardrails", "vector_stores"]},
            {
                "models": False,
                "mcp": True,
                "guardrails": True,
                "vector_stores": True,
                "prompts": False,
            },
        ),
        # Non-list value -> falls back to loading everything.
        (
            {"supported_db_objects": "invalid_type"},
            {"models": True, "mcp": True},
        ),
        # Empty list -> nothing loads.
        (
            {"supported_db_objects": []},
            {"models": False, "mcp": False, "guardrails": False},
        ),
        # Every type listed -> every type loads.
        (
            {"supported_db_objects": list(all_object_types)},
            dict.fromkeys(all_object_types, True),
        ),
    ]

    for settings, expectations in scenarios:
        with patch("litellm.proxy.proxy_server.general_settings", settings):
            for object_type, expected in expectations.items():
                actual = proxy_config._should_load_db_object(object_type=object_type)
                assert actual is expected, (settings, object_type)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_tag_cache_update_called():
    """
    update_cache with tags=["test-tag"] should read "tag:test-tag" from the
    proxy's user_api_key_cache and write it back with response_cost added to
    its spend (10.0 + 5.0 -> 15.0).
    """
    from litellm.caching.caching import DualCache
    from litellm.proxy import proxy_server

    cache = DualCache()
    mock_tag_obj = {
        "tag_name": "test-tag",
        "spend": 10.0,
    }

    # Install the fresh cache only for the duration of this test. patch.object
    # restores the module's original user_api_key_cache on exit, whereas a bare
    # setattr() would leave this DualCache in place for every later test.
    with (
        patch.object(proxy_server, "user_api_key_cache", cache),
        patch.object(
            cache, "async_get_cache", new=AsyncMock(return_value=mock_tag_obj)
        ) as mock_get_cache,
        patch.object(
            cache, "async_set_cache_pipeline", new=AsyncMock()
        ) as mock_set_cache,
    ):
        await proxy_server.update_cache(
            token=None,
            user_id=None,
            end_user_id=None,
            team_id=None,
            response_cost=5.0,
            parent_otel_span=None,
            tags=["test-tag"],
        )

        # Presumably update_cache defers the cache write to a background task;
        # yield long enough for it to run while the patches are still active.
        await asyncio.sleep(0.1)

        mock_get_cache.assert_awaited_once_with(key="tag:test-tag")
        mock_set_cache.assert_awaited_once()

        cache_list = mock_set_cache.call_args.kwargs["cache_list"]
        assert len(cache_list) == 1
        cache_key, cache_value = cache_list[0]
        assert cache_key == "tag:test-tag"
        assert cache_value["spend"] == 15.0
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_tag_cache_update_multiple_tags():
    """
    update_cache with several tags should read each "tag:<name>" entry and
    write all of them back in one pipeline call, each with response_cost added
    to its spend.
    """
    from litellm.caching.caching import DualCache
    from litellm.proxy import proxy_server

    cache = DualCache()

    mock_tag1_obj = {"tag_name": "tag1", "spend": 10.0}
    mock_tag2_obj = {"tag_name": "tag2", "spend": 20.0}

    async def mock_get_cache_side_effect(key):
        if key == "tag:tag1":
            return mock_tag1_obj
        elif key == "tag:tag2":
            return mock_tag2_obj
        return None

    # Install the fresh cache only for the duration of this test. patch.object
    # restores the module's original user_api_key_cache on exit, whereas a bare
    # setattr() would leave this DualCache in place for every later test.
    with (
        patch.object(proxy_server, "user_api_key_cache", cache),
        patch.object(
            cache,
            "async_get_cache",
            new=AsyncMock(side_effect=mock_get_cache_side_effect),
        ) as mock_get_cache,
        patch.object(
            cache, "async_set_cache_pipeline", new=AsyncMock()
        ) as mock_set_cache,
    ):
        await proxy_server.update_cache(
            token=None,
            user_id=None,
            end_user_id=None,
            team_id=None,
            response_cost=5.0,
            parent_otel_span=None,
            tags=["tag1", "tag2"],
        )

        # Presumably update_cache defers the cache write to a background task;
        # yield long enough for it to run while the patches are still active.
        await asyncio.sleep(0.1)

        assert mock_get_cache.call_count == 2
        mock_set_cache.assert_awaited_once()

        cache_list = mock_set_cache.call_args.kwargs["cache_list"]
        assert len(cache_list) == 2

        tag_updates = dict(cache_list)
        assert "tag:tag1" in tag_updates
        assert "tag:tag2" in tag_updates
        assert tag_updates["tag:tag1"]["spend"] == 15.0
        assert tag_updates["tag:tag2"]["spend"] == 25.0
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_init_sso_settings_in_db():
    """
    _init_sso_settings_in_db should fetch the "sso_config" row and pass its
    settings to _decrypt_and_set_db_env_variables with every key uppercased
    and every value unchanged.
    """
    from unittest.mock import AsyncMock, MagicMock, patch

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    sso_row = MagicMock()
    sso_row.sso_settings = {
        "google_client_id": "test-client-id",
        "google_client_secret": "test-client-secret",
        "microsoft_client_id": "ms-client-id",
        "microsoft_client_secret": "ms-client-secret",
    }
    prisma_stub = MagicMock()
    prisma_stub.db.litellm_ssoconfig.find_unique = AsyncMock(return_value=sso_row)

    expected_env = {
        "GOOGLE_CLIENT_ID": "test-client-id",
        "GOOGLE_CLIENT_SECRET": "test-client-secret",
        "MICROSOFT_CLIENT_ID": "ms-client-id",
        "MICROSOFT_CLIENT_SECRET": "ms-client-secret",
    }

    with patch.object(
        proxy_config, "_decrypt_and_set_db_env_variables"
    ) as decrypt_mock:
        await proxy_config._init_sso_settings_in_db(prisma_client=prisma_stub)

        prisma_stub.db.litellm_ssoconfig.find_unique.assert_awaited_once_with(
            where={"id": "sso_config"}
        )

        decrypt_mock.assert_called_once()
        env_vars = decrypt_mock.call_args.kwargs["environment_variables"]

        # Every key is present in uppercase form ...
        for env_name in expected_env:
            assert env_name in env_vars
        # ... with its original value.
        for env_name, env_value in expected_env.items():
            assert env_vars[env_name] == env_value

        # The lowercase originals are not forwarded.
        assert "google_client_id" not in env_vars
        assert "microsoft_client_id" not in env_vars
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_init_sso_settings_in_db_no_settings():
    """
    When the "sso_config" row is missing, _init_sso_settings_in_db should still
    query for it but must not call _decrypt_and_set_db_env_variables.
    """
    from unittest.mock import AsyncMock, MagicMock, patch

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    # find_unique yields no row.
    prisma_stub = MagicMock()
    lookup = AsyncMock(return_value=None)
    prisma_stub.db.litellm_ssoconfig.find_unique = lookup

    with patch.object(
        proxy_config, "_decrypt_and_set_db_env_variables"
    ) as decrypt_mock:
        await proxy_config._init_sso_settings_in_db(prisma_client=prisma_stub)

        lookup.assert_awaited_once_with(where={"id": "sso_config"})
        decrypt_mock.assert_not_called()
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_init_sso_settings_in_db_error_handling():
    """
    A database error raised while reading SSO settings must be handled inside
    _init_sso_settings_in_db rather than propagated to the caller.
    """
    from unittest.mock import AsyncMock, MagicMock

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    # The DB lookup raises.
    mock_prisma_client = MagicMock()
    mock_prisma_client.db.litellm_ssoconfig.find_unique = AsyncMock(
        side_effect=Exception("Database connection error")
    )

    try:
        await proxy_config._init_sso_settings_in_db(prisma_client=mock_prisma_client)
    except Exception as e:
        # The exception should be handled inside the method, not propagated.
        pytest.fail(
            f"Exception should have been caught and logged, but was raised: {e}"
        )

    # Confirm the failing lookup was actually reached, so the test cannot pass
    # vacuously by returning before the DB is queried.
    mock_prisma_client.db.litellm_ssoconfig.find_unique.assert_awaited_once_with(
        where={"id": "sso_config"}
    )
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_init_sso_settings_in_db_empty_settings():
    """
    An "sso_config" row whose settings dict is empty should still be forwarded
    to _decrypt_and_set_db_env_variables, as an empty dict.
    """
    from unittest.mock import AsyncMock, MagicMock, patch

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    empty_row = MagicMock()
    empty_row.sso_settings = {}
    prisma_stub = MagicMock()
    lookup = AsyncMock(return_value=empty_row)
    prisma_stub.db.litellm_ssoconfig.find_unique = lookup

    with patch.object(
        proxy_config, "_decrypt_and_set_db_env_variables"
    ) as decrypt_mock:
        await proxy_config._init_sso_settings_in_db(prisma_client=prisma_stub)

        lookup.assert_awaited_once_with(where={"id": "sso_config"})

        decrypt_mock.assert_called_once()
        forwarded = decrypt_mock.call_args.kwargs["environment_variables"]
        assert forwarded == {}
|
||
|
||
|
||
def test_update_config_fields_uppercases_env_vars(monkeypatch):
    """
    Ensure environment variables pulled from the DB are uppercased when applied,
    so integrations like Datadog that expect uppercase env keys can read them.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    for key in ["DD_API_KEY", "DD_SITE", "dd_api_key", "dd_site"]:
        # setenv-then-delenv makes monkeypatch record each variable, so whatever
        # _update_config_fields writes into os.environ below is removed again at
        # teardown. A bare delenv(raising=False) on an absent key records
        # nothing, which let the DD_* values leak into later tests.
        monkeypatch.setenv(key, "")
        monkeypatch.delenv(key)

    proxy_config = ProxyConfig()
    updated_config = proxy_config._update_config_fields(
        current_config={},
        param_name="environment_variables",
        db_param_value={"dd_api_key": "test-api-key", "dd_site": "us5.datadoghq.com"},
    )

    env_vars = updated_config.get("environment_variables", {})
    assert env_vars["DD_API_KEY"] == "test-api-key"
    assert env_vars["DD_SITE"] == "us5.datadoghq.com"
    assert os.environ.get("DD_API_KEY") == "test-api-key"
    assert os.environ.get("DD_SITE") == "us5.datadoghq.com"
|
||
|
||
|
||
def test_encrypt_env_variables_for_db_is_idempotent(monkeypatch):
|
||
"""
|
||
Regression: /config/update and save_config must not stack a second
|
||
encryption layer when a caller re-submits a value that is already
|
||
ciphertext (the Admin UI reads config back from /get/config/callbacks —
|
||
which returns the stored, still-encrypted value — and re-POSTs it on the
|
||
next save). _encrypt_env_variables_for_db must yield a value that decrypts
|
||
to the original plaintext in exactly ONE layer, no matter how many times
|
||
its own output is fed back in. It must also not mutate os.environ (write
|
||
path — loading into the process env is the read path's job).
|
||
"""
|
||
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||
decrypt_value_helper,
|
||
)
|
||
from litellm.proxy.proxy_server import ProxyConfig
|
||
|
||
monkeypatch.setenv("LITELLM_SALT_KEY", "sk-test-salt-key")
|
||
monkeypatch.delenv("LANGFUSE_PUBLIC_KEY", raising=False)
|
||
|
||
proxy_config = ProxyConfig()
|
||
plaintext = "pk-langfuse-secret-value"
|
||
|
||
# First write: plaintext in -> single-encrypted out.
|
||
enc1 = proxy_config._encrypt_env_variables_for_db(
|
||
{"LANGFUSE_PUBLIC_KEY": plaintext}
|
||
)
|
||
assert enc1["LANGFUSE_PUBLIC_KEY"] != plaintext
|
||
assert (
|
||
decrypt_value_helper(
|
||
value=enc1["LANGFUSE_PUBLIC_KEY"], key="LANGFUSE_PUBLIC_KEY"
|
||
)
|
||
== plaintext
|
||
)
|
||
|
||
# UI round-trip: feed the ciphertext back in. Must NOT double-encrypt.
|
||
enc2 = proxy_config._encrypt_env_variables_for_db(enc1)
|
||
assert (
|
||
decrypt_value_helper(
|
||
value=enc2["LANGFUSE_PUBLIC_KEY"], key="LANGFUSE_PUBLIC_KEY"
|
||
)
|
||
== plaintext
|
||
)
|
||
|
||
# And again, ×3 total ciphertext re-feeds — still exactly one layer,
|
||
# never stacked, no matter how many times the UI re-saves.
|
||
enc3 = proxy_config._encrypt_env_variables_for_db(enc2)
|
||
enc4 = proxy_config._encrypt_env_variables_for_db(enc3)
|
||
for stacked in (enc3, enc4):
|
||
assert (
|
||
decrypt_value_helper(
|
||
value=stacked["LANGFUSE_PUBLIC_KEY"], key="LANGFUSE_PUBLIC_KEY"
|
||
)
|
||
== plaintext
|
||
)
|
||
|
||
# Write path must not leak the value into the process environment.
|
||
assert os.environ.get("LANGFUSE_PUBLIC_KEY") is None
|
||
|
||
|
||
def test_get_prompt_spec_for_db_prompt_with_versions():
|
||
"""
|
||
Test that _get_prompt_spec_for_db_prompt correctly converts database prompts
|
||
to PromptSpec with versioned naming convention.
|
||
"""
|
||
from unittest.mock import MagicMock
|
||
|
||
from litellm.proxy.proxy_server import ProxyConfig
|
||
|
||
proxy_config = ProxyConfig()
|
||
|
||
# Mock database prompt version 1
|
||
mock_prompt_v1 = MagicMock()
|
||
mock_prompt_v1.model_dump.return_value = {
|
||
"id": "uuid-1",
|
||
"prompt_id": "chat_prompt",
|
||
"version": 1,
|
||
"litellm_params": '{"prompt_id": "chat_prompt", "prompt_integration": "dotprompt", "model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "v1 content"}]}',
|
||
"prompt_info": '{"prompt_type": "db"}',
|
||
"created_at": "2024-01-01T00:00:00",
|
||
"updated_at": "2024-01-01T00:00:00",
|
||
}
|
||
|
||
# Mock database prompt version 2
|
||
mock_prompt_v2 = MagicMock()
|
||
mock_prompt_v2.model_dump.return_value = {
|
||
"id": "uuid-2",
|
||
"prompt_id": "chat_prompt",
|
||
"version": 2,
|
||
"litellm_params": '{"prompt_id": "chat_prompt", "prompt_integration": "dotprompt", "model": "gpt-4", "messages": [{"role": "user", "content": "v2 content"}]}',
|
||
"prompt_info": '{"prompt_type": "db"}',
|
||
"created_at": "2024-01-02T00:00:00",
|
||
"updated_at": "2024-01-02T00:00:00",
|
||
}
|
||
|
||
# Test version 1
|
||
prompt_spec_v1 = proxy_config._get_prompt_spec_for_db_prompt(
|
||
db_prompt=mock_prompt_v1
|
||
)
|
||
assert prompt_spec_v1.prompt_id == "chat_prompt.v1"
|
||
|
||
# Test version 2
|
||
prompt_spec_v2 = proxy_config._get_prompt_spec_for_db_prompt(
|
||
db_prompt=mock_prompt_v2
|
||
)
|
||
assert prompt_spec_v2.prompt_id == "chat_prompt.v2"
|
||
|
||
|
||
def test_root_redirect_when_docs_url_not_root_and_redirect_url_set(monkeypatch):
|
||
from fastapi.responses import RedirectResponse
|
||
|
||
from litellm.proxy.proxy_server import cleanup_router_config_variables
|
||
from litellm.proxy.utils import _get_docs_url
|
||
|
||
cleanup_router_config_variables()
|
||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
|
||
# Ensure docs are mounted on a non-root path to trigger redirect logic
|
||
monkeypatch.setenv("DOCS_URL", "/docs")
|
||
|
||
test_redirect_url = "/ui"
|
||
monkeypatch.setenv("ROOT_REDIRECT_URL", test_redirect_url)
|
||
|
||
asyncio.run(initialize(config=config_fp, debug=True))
|
||
|
||
docs_url = _get_docs_url()
|
||
root_redirect_url = os.getenv("ROOT_REDIRECT_URL")
|
||
|
||
# Remove any existing "/" route that might interfere
|
||
routes_to_remove = []
|
||
for route in app.routes:
|
||
if hasattr(route, "path") and route.path == "/":
|
||
if hasattr(route, "methods") and "GET" in route.methods:
|
||
routes_to_remove.append(route)
|
||
elif not hasattr(route, "methods"): # Catch-all routes
|
||
routes_to_remove.append(route)
|
||
|
||
for route in routes_to_remove:
|
||
app.routes.remove(route)
|
||
|
||
# Add the redirect route if conditions are met (matching the actual implementation)
|
||
if docs_url != "/" and root_redirect_url:
|
||
|
||
@app.get("/", include_in_schema=False)
|
||
async def root_redirect():
|
||
return RedirectResponse(url=root_redirect_url)
|
||
|
||
client = TestClient(app)
|
||
response = client.get("/", follow_redirects=False)
|
||
assert response.status_code == 307
|
||
assert response.headers["location"] == test_redirect_url
|
||
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_get_image_non_root_uses_var_lib_assets_dir(monkeypatch):
|
||
"""
|
||
Test that get_image uses /var/lib/litellm/assets when LITELLM_NON_ROOT is true.
|
||
"""
|
||
from unittest.mock import patch
|
||
|
||
from litellm.proxy.proxy_server import get_image
|
||
|
||
# Set LITELLM_NON_ROOT to true
|
||
monkeypatch.setenv("LITELLM_NON_ROOT", "true")
|
||
monkeypatch.delenv("UI_LOGO_PATH", raising=False)
|
||
|
||
# Mock os.path operations - exists=False for assets_dir so makedirs gets called
|
||
def exists_side_effect(path):
|
||
return False if path == "/var/lib/litellm/assets" else True
|
||
|
||
with (
|
||
patch("litellm.proxy.proxy_server.os.makedirs") as mock_makedirs,
|
||
patch(
|
||
"litellm.proxy.proxy_server.os.path.exists", side_effect=exists_side_effect
|
||
),
|
||
patch("litellm.proxy.proxy_server.os.access", return_value=True),
|
||
patch("litellm.proxy.proxy_server.os.getenv") as mock_getenv,
|
||
patch("litellm.proxy.proxy_server.FileResponse") as mock_file_response,
|
||
):
|
||
# Setup mock_getenv to return empty string for UI_LOGO_PATH
|
||
def getenv_side_effect(key, default=""):
|
||
if key == "UI_LOGO_PATH":
|
||
return ""
|
||
elif key == "LITELLM_NON_ROOT":
|
||
return "true"
|
||
return default
|
||
|
||
mock_getenv.side_effect = getenv_side_effect
|
||
|
||
# Call the function
|
||
await get_image()
|
||
|
||
# Verify makedirs was called with /var/lib/litellm/assets
|
||
mock_makedirs.assert_called_once_with("/var/lib/litellm/assets", exist_ok=True)
|
||
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_get_image_non_root_fallback_to_default_logo(monkeypatch):
|
||
"""
|
||
Test that get_image falls back to default_site_logo when logo doesn't exist
|
||
in /var/lib/litellm/assets for non-root case.
|
||
"""
|
||
from unittest.mock import patch
|
||
|
||
from litellm.proxy.proxy_server import get_image
|
||
|
||
# Set LITELLM_NON_ROOT to true
|
||
monkeypatch.setenv("LITELLM_NON_ROOT", "true")
|
||
monkeypatch.delenv("UI_LOGO_PATH", raising=False)
|
||
|
||
# Track path.exists calls to verify it checks /var/lib/litellm/assets/logo.jpg
|
||
exists_calls = []
|
||
|
||
def exists_side_effect(path):
|
||
exists_calls.append(path)
|
||
# Return False for /var/lib/litellm/assets* so: makedirs is called, logo fallback
|
||
# triggers, and we don't return early with cached file
|
||
if "/var/lib/litellm/assets" in path:
|
||
return False
|
||
return True
|
||
|
||
# Mock os.path operations
|
||
with (
|
||
patch("litellm.proxy.proxy_server.os.makedirs") as mock_makedirs,
|
||
patch(
|
||
"litellm.proxy.proxy_server.os.path.exists", side_effect=exists_side_effect
|
||
),
|
||
patch("litellm.proxy.proxy_server.os.access", return_value=True),
|
||
patch("litellm.proxy.proxy_server.os.getenv") as mock_getenv,
|
||
patch("litellm.proxy.proxy_server.FileResponse") as mock_file_response,
|
||
):
|
||
# Setup mock_getenv
|
||
def getenv_side_effect(key, default=""):
|
||
if key == "UI_LOGO_PATH":
|
||
return ""
|
||
elif key == "LITELLM_NON_ROOT":
|
||
return "true"
|
||
return default
|
||
|
||
mock_getenv.side_effect = getenv_side_effect
|
||
|
||
# Call the function
|
||
await get_image()
|
||
|
||
# Verify makedirs was called with /var/lib/litellm/assets
|
||
mock_makedirs.assert_called_once_with("/var/lib/litellm/assets", exist_ok=True)
|
||
|
||
# Verify that exists was called to check /var/lib/litellm/assets/logo.jpg
|
||
assets_logo_path = "/var/lib/litellm/assets/logo.jpg"
|
||
assert any(
|
||
assets_logo_path in str(call) for call in exists_calls
|
||
), f"Should check if {assets_logo_path} exists"
|
||
|
||
# Verify FileResponse was called (with fallback logo)
|
||
assert mock_file_response.called, "FileResponse should be called"
|
||
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_get_image_root_case_uses_current_dir(monkeypatch):
|
||
"""
|
||
Test that get_image uses current_dir when LITELLM_NON_ROOT is not true.
|
||
"""
|
||
from unittest.mock import patch
|
||
|
||
from litellm.proxy.proxy_server import get_image
|
||
|
||
# Don't set LITELLM_NON_ROOT (or set it to false)
|
||
monkeypatch.delenv("LITELLM_NON_ROOT", raising=False)
|
||
monkeypatch.delenv("UI_LOGO_PATH", raising=False)
|
||
|
||
# Mock os.path operations
|
||
with (
|
||
patch("litellm.proxy.proxy_server.os.makedirs") as mock_makedirs,
|
||
patch("litellm.proxy.proxy_server.os.path.exists", return_value=True),
|
||
patch("litellm.proxy.proxy_server.os.getenv") as mock_getenv,
|
||
patch("litellm.proxy.proxy_server.FileResponse") as mock_file_response,
|
||
):
|
||
# Setup mock_getenv
|
||
def getenv_side_effect(key, default=""):
|
||
if key == "UI_LOGO_PATH":
|
||
return ""
|
||
elif key == "LITELLM_NON_ROOT":
|
||
return "" # Not set or empty
|
||
return default
|
||
|
||
mock_getenv.side_effect = getenv_side_effect
|
||
|
||
# Call the function
|
||
await get_image()
|
||
|
||
# Verify makedirs was NOT called with /var/lib/litellm/assets (should not create it for root case)
|
||
var_lib_assets_calls = [
|
||
call
|
||
for call in mock_makedirs.call_args_list
|
||
if "/var/lib/litellm/assets" in str(call)
|
||
]
|
||
assert (
|
||
len(var_lib_assets_calls) == 0
|
||
), "Should not create /var/lib/litellm/assets for root case"
|
||
|
||
# Verify FileResponse was called
|
||
assert mock_file_response.called, "FileResponse should be called"
|
||
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_get_image_custom_local_logo_bypasses_cache(monkeypatch, tmp_path):
|
||
"""
|
||
Test that when UI_LOGO_PATH is set to a local file, get_image serves it
|
||
directly and does not return a stale cached_logo.jpg.
|
||
|
||
Regression test: previously the cache check ran before reading UI_LOGO_PATH,
|
||
so a pre-existing cached_logo.jpg (e.g. from the base Docker image) would
|
||
always be returned, ignoring the user's custom logo.
|
||
"""
|
||
from litellm.proxy.proxy_server import get_image
|
||
|
||
custom_logo = tmp_path / "custom_logo.jpg"
|
||
custom_logo.write_bytes(b"\xff\xd8\xff custom logo")
|
||
monkeypatch.setenv("UI_LOGO_PATH", str(custom_logo))
|
||
monkeypatch.delenv("LITELLM_NON_ROOT", raising=False)
|
||
monkeypatch.delenv("LITELLM_ASSETS_PATH", raising=False)
|
||
|
||
calls_to_file_response = []
|
||
|
||
def fake_file_response(path, **kwargs):
|
||
calls_to_file_response.append(path)
|
||
return MagicMock()
|
||
|
||
with (
|
||
patch(
|
||
"litellm.proxy.proxy_server.FileResponse", side_effect=fake_file_response
|
||
),
|
||
):
|
||
await get_image()
|
||
|
||
assert (
|
||
len(calls_to_file_response) == 1
|
||
), "FileResponse should be called exactly once"
|
||
assert calls_to_file_response[0] == str(custom_logo.resolve()), (
|
||
f"Expected custom logo path, got {calls_to_file_response[0]}. "
|
||
"A stale cached_logo.jpg may have been returned instead."
|
||
)
|
||
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_get_image_default_logo_ignores_stale_cache(monkeypatch, tmp_path):
|
||
"""
|
||
Test that when UI_LOGO_PATH is NOT set, stale pre-fix cached_logo.jpg
|
||
files are ignored and the default logo is served.
|
||
"""
|
||
from unittest.mock import patch
|
||
|
||
from litellm.proxy.proxy_server import get_image
|
||
|
||
cache_path = tmp_path / "cached_logo.jpg"
|
||
cache_path.write_bytes(b"\xff\xd8\xff cached logo")
|
||
monkeypatch.delenv("UI_LOGO_PATH", raising=False)
|
||
monkeypatch.delenv("LITELLM_NON_ROOT", raising=False)
|
||
monkeypatch.setenv("LITELLM_ASSETS_PATH", str(tmp_path))
|
||
|
||
calls_to_file_response = []
|
||
|
||
def fake_file_response(path, **kwargs):
|
||
calls_to_file_response.append(path)
|
||
return MagicMock()
|
||
|
||
with (
|
||
patch(
|
||
"litellm.proxy.proxy_server.FileResponse", side_effect=fake_file_response
|
||
),
|
||
):
|
||
await get_image()
|
||
|
||
assert (
|
||
len(calls_to_file_response) == 1
|
||
), "FileResponse should be called exactly once"
|
||
served_path = calls_to_file_response[0]
|
||
assert served_path != str(cache_path.resolve())
|
||
assert served_path.endswith("logo.jpg")
|
||
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_get_image_custom_logo_missing_falls_through_to_default(
|
||
monkeypatch, tmp_path
|
||
):
|
||
"""
|
||
Test that when UI_LOGO_PATH points to a non-existent local file,
|
||
get_image falls through to the default logo instead of failing.
|
||
"""
|
||
from unittest.mock import patch
|
||
|
||
from litellm.proxy.proxy_server import get_image
|
||
|
||
custom_logo_path = tmp_path / "nonexistent_logo.jpg"
|
||
monkeypatch.setenv("UI_LOGO_PATH", str(custom_logo_path))
|
||
monkeypatch.delenv("LITELLM_NON_ROOT", raising=False)
|
||
monkeypatch.setenv("LITELLM_ASSETS_PATH", str(tmp_path))
|
||
|
||
calls_to_file_response = []
|
||
|
||
def fake_file_response(path, **kwargs):
|
||
calls_to_file_response.append(path)
|
||
return MagicMock()
|
||
|
||
with (
|
||
patch(
|
||
"litellm.proxy.proxy_server.FileResponse", side_effect=fake_file_response
|
||
),
|
||
):
|
||
await get_image()
|
||
|
||
assert (
|
||
len(calls_to_file_response) == 1
|
||
), "FileResponse should be called exactly once"
|
||
served_path = calls_to_file_response[0]
|
||
assert served_path != str(
|
||
custom_logo_path
|
||
), "Should not attempt to serve a non-existent custom logo"
|
||
assert served_path.endswith("logo.jpg")
|
||
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_get_image_custom_logo_missing_no_cache_serves_default(
|
||
monkeypatch, tmp_path
|
||
):
|
||
"""
|
||
Test that when UI_LOGO_PATH points to a non-existent file AND there is no
|
||
cached_logo.jpg, get_image serves the default logo instead of the non-existent
|
||
custom path.
|
||
"""
|
||
from unittest.mock import patch
|
||
|
||
from litellm.proxy.proxy_server import get_image
|
||
|
||
custom_logo_path = tmp_path / "nonexistent_logo.jpg"
|
||
monkeypatch.setenv("UI_LOGO_PATH", str(custom_logo_path))
|
||
monkeypatch.delenv("LITELLM_NON_ROOT", raising=False)
|
||
monkeypatch.setenv("LITELLM_ASSETS_PATH", str(tmp_path))
|
||
|
||
calls_to_file_response = []
|
||
|
||
def fake_file_response(path, **kwargs):
|
||
calls_to_file_response.append(path)
|
||
return MagicMock()
|
||
|
||
with (
|
||
patch(
|
||
"litellm.proxy.proxy_server.FileResponse", side_effect=fake_file_response
|
||
),
|
||
):
|
||
await get_image()
|
||
|
||
assert (
|
||
len(calls_to_file_response) == 1
|
||
), "FileResponse should be called exactly once"
|
||
served_path = calls_to_file_response[0]
|
||
assert served_path != str(
|
||
custom_logo_path
|
||
), "Should not attempt to serve a non-existent custom logo"
|
||
assert served_path.endswith(
|
||
"logo.jpg"
|
||
), f"Expected fallback to default logo.jpg, got {served_path}"
|
||
|
||
|
||
def test_get_config_normalizes_string_callbacks(monkeypatch):
|
||
"""
|
||
Test that /get/config/callbacks normalizes string callbacks to lists.
|
||
"""
|
||
from litellm.proxy.proxy_server import app, proxy_config, user_api_key_auth
|
||
|
||
config_data = {
|
||
"litellm_settings": {
|
||
"success_callback": "langfuse",
|
||
"failure_callback": None,
|
||
"callbacks": ["prometheus", "datadog"],
|
||
},
|
||
"general_settings": {},
|
||
"environment_variables": {},
|
||
}
|
||
|
||
mock_router = MagicMock()
|
||
mock_router.get_settings.return_value = {}
|
||
monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", mock_router)
|
||
monkeypatch.setattr(proxy_config, "get_config", AsyncMock(return_value=config_data))
|
||
|
||
original_overrides = app.dependency_overrides.copy()
|
||
app.dependency_overrides[user_api_key_auth] = lambda: MagicMock()
|
||
|
||
client = TestClient(app)
|
||
try:
|
||
response = client.get("/get/config/callbacks")
|
||
finally:
|
||
app.dependency_overrides = original_overrides
|
||
|
||
assert response.status_code == 200
|
||
callbacks = response.json()["callbacks"]
|
||
|
||
success_callbacks = [cb["name"] for cb in callbacks if cb.get("type") == "success"]
|
||
failure_callbacks = [cb["name"] for cb in callbacks if cb.get("type") == "failure"]
|
||
success_and_failure_callbacks = [
|
||
cb["name"] for cb in callbacks if cb.get("type") == "success_and_failure"
|
||
]
|
||
|
||
assert "langfuse" in success_callbacks
|
||
assert len(failure_callbacks) == 0
|
||
assert "prometheus" in success_and_failure_callbacks
|
||
assert "datadog" in success_and_failure_callbacks
|
||
|
||
|
||
def test_deep_merge_dicts_skips_none_and_empty_lists(monkeypatch):
|
||
"""
|
||
Test that _update_config_fields deep merge skips None values and empty lists.
|
||
"""
|
||
from litellm.proxy.proxy_server import ProxyConfig
|
||
|
||
proxy_config = ProxyConfig()
|
||
|
||
current_config = {
|
||
"general_settings": {
|
||
"max_parallel_requests": 10,
|
||
"allowed_models": ["gpt-3.5-turbo", "gpt-4"],
|
||
"nested": {
|
||
"key1": "value1",
|
||
"key2": "value2",
|
||
},
|
||
}
|
||
}
|
||
|
||
db_param_value = {
|
||
"max_parallel_requests": None,
|
||
"allowed_models": [],
|
||
"new_key": "new_value",
|
||
"nested": {
|
||
"key1": "updated_value1",
|
||
"key3": "value3",
|
||
},
|
||
}
|
||
|
||
result = proxy_config._update_config_fields(
|
||
current_config, "general_settings", db_param_value
|
||
)
|
||
|
||
assert result["general_settings"]["max_parallel_requests"] == 10
|
||
assert result["general_settings"]["allowed_models"] == ["gpt-3.5-turbo", "gpt-4"]
|
||
assert result["general_settings"]["new_key"] == "new_value"
|
||
assert result["general_settings"]["nested"]["key1"] == "updated_value1"
|
||
assert result["general_settings"]["nested"]["key2"] == "value2"
|
||
assert result["general_settings"]["nested"]["key3"] == "value3"
|
||
|
||
|
||
class TestInvitationEndpoints:
|
||
"""Tests for /invitation/new and /invitation/delete endpoints."""
|
||
|
||
@pytest.fixture
|
||
def client_with_auth(self):
|
||
"""Create a test client with admin authentication."""
|
||
from litellm.proxy._types import LitellmUserRoles
|
||
from litellm.proxy.proxy_server import cleanup_router_config_variables
|
||
|
||
cleanup_router_config_variables()
|
||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
|
||
asyncio.run(initialize(config=config_fp, debug=True))
|
||
|
||
mock_auth = MagicMock()
|
||
mock_auth.user_id = "admin-user-id"
|
||
mock_auth.user_role = LitellmUserRoles.PROXY_ADMIN
|
||
mock_auth.api_key = "sk-test"
|
||
app.dependency_overrides[user_api_key_auth] = lambda: mock_auth
|
||
|
||
return TestClient(app)
|
||
|
||
@pytest.mark.parametrize(
|
||
"endpoint,payload,mock_return",
|
||
[
|
||
(
|
||
"/invitation/new",
|
||
{"user_id": "target-user-123"},
|
||
{
|
||
"id": "inv-123",
|
||
"user_id": "target-user-123",
|
||
"is_accepted": False,
|
||
"accepted_at": None,
|
||
"expires_at": "2025-02-18T00:00:00",
|
||
"created_at": "2025-02-11T00:00:00",
|
||
"created_by": "admin-user-id",
|
||
"updated_at": "2025-02-11T00:00:00",
|
||
"updated_by": "admin-user-id",
|
||
},
|
||
),
|
||
(
|
||
"/invitation/delete",
|
||
{"invitation_id": "inv-456"},
|
||
{
|
||
"id": "inv-456",
|
||
"user_id": "target-user-123",
|
||
"is_accepted": False,
|
||
"accepted_at": None,
|
||
"expires_at": "2025-02-18T00:00:00",
|
||
"created_at": "2025-02-11T00:00:00",
|
||
"created_by": "admin-user-id",
|
||
"updated_at": "2025-02-11T00:00:00",
|
||
"updated_by": "admin-user-id",
|
||
},
|
||
),
|
||
],
|
||
)
|
||
def test_invitation_endpoints_proxy_admin_success(
|
||
self, client_with_auth, endpoint, payload, mock_return
|
||
):
|
||
"""Proxy admin can successfully create and delete invitations."""
|
||
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
|
||
mock_prisma.db.litellm_invitationlink = MagicMock()
|
||
if endpoint == "/invitation/new":
|
||
mock_create = AsyncMock(return_value=mock_return)
|
||
with patch(
|
||
"litellm.proxy.management_helpers.user_invitation.create_invitation_for_user",
|
||
mock_create,
|
||
):
|
||
response = client_with_auth.post(endpoint, json=payload)
|
||
else:
|
||
mock_prisma.db.litellm_invitationlink.find_unique = AsyncMock(
|
||
return_value={**mock_return, "created_by": "admin-user-id"}
|
||
)
|
||
mock_prisma.db.litellm_invitationlink.delete = AsyncMock(
|
||
return_value=mock_return
|
||
)
|
||
response = client_with_auth.post(endpoint, json=payload)
|
||
|
||
assert response.status_code == 200
|
||
data = response.json()
|
||
assert data["id"] == mock_return["id"]
|
||
assert data["user_id"] == mock_return["user_id"]
|
||
|
||
@pytest.mark.parametrize(
|
||
"endpoint,payload",
|
||
[
|
||
("/invitation/new", {"user_id": "target-user-123"}),
|
||
("/invitation/delete", {"invitation_id": "inv-456"}),
|
||
],
|
||
)
|
||
def test_invitation_endpoints_non_admin_denied(
|
||
self, client_with_auth, endpoint, payload
|
||
):
|
||
"""Non-admin users cannot access invitation endpoints."""
|
||
from litellm.proxy._types import LitellmUserRoles
|
||
|
||
mock_auth = MagicMock()
|
||
mock_auth.user_id = "regular-user"
|
||
mock_auth.user_role = LitellmUserRoles.INTERNAL_USER
|
||
mock_auth.api_key = "sk-regular"
|
||
app.dependency_overrides[user_api_key_auth] = lambda: mock_auth
|
||
|
||
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
|
||
mock_prisma.db.litellm_invitationlink = MagicMock()
|
||
# Avoid triggering async DB calls in _user_has_admin_privileges
|
||
with patch(
|
||
"litellm.proxy.proxy_server._user_has_admin_privileges",
|
||
new_callable=AsyncMock,
|
||
return_value=False,
|
||
):
|
||
response = client_with_auth.post(endpoint, json=payload)
|
||
|
||
assert response.status_code == 400
|
||
body = response.json()
|
||
# ProxyException handler returns {"error": {...}}, HTTPException returns {"detail": {...}}
|
||
error_content = body.get("error", body.get("detail", body))
|
||
assert "not allowed" in str(error_content).lower()
|
||
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_async_data_generator_cleanup_on_early_exit():
|
||
"""
|
||
Test that async_data_generator calls response.aclose() in the finally block
|
||
when the generator is abandoned mid-stream (client disconnect).
|
||
"""
|
||
from litellm.proxy._types import UserAPIKeyAuth
|
||
from litellm.proxy.proxy_server import async_data_generator
|
||
from litellm.proxy.utils import ProxyLogging
|
||
|
||
mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
|
||
mock_request_data = {
|
||
"model": "gpt-3.5-turbo",
|
||
"messages": [{"role": "user", "content": "test"}],
|
||
}
|
||
|
||
mock_chunks = [
|
||
{"choices": [{"delta": {"content": "Hello"}}]},
|
||
{"choices": [{"delta": {"content": " world"}}]},
|
||
{"choices": [{"delta": {"content": " more"}}]},
|
||
]
|
||
|
||
mock_proxy_logging_obj = MagicMock(spec=ProxyLogging)
|
||
|
||
async def mock_streaming_iterator(*args, **kwargs):
|
||
for chunk in mock_chunks:
|
||
yield chunk
|
||
|
||
mock_proxy_logging_obj.async_post_call_streaming_iterator_hook = (
|
||
mock_streaming_iterator
|
||
)
|
||
mock_proxy_logging_obj.async_post_call_streaming_hook = AsyncMock(
|
||
side_effect=lambda **kwargs: kwargs.get("response")
|
||
)
|
||
mock_proxy_logging_obj.post_call_failure_hook = AsyncMock()
|
||
|
||
# Create a mock response with aclose
|
||
mock_response = MagicMock()
|
||
mock_response.aclose = AsyncMock()
|
||
|
||
with patch("litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj):
|
||
# Consume only the first chunk then abandon the generator (simulates client disconnect)
|
||
gen = async_data_generator(
|
||
mock_response, mock_user_api_key_dict, mock_request_data
|
||
)
|
||
first_chunk = await gen.__anext__()
|
||
assert first_chunk.startswith("data: ")
|
||
|
||
# Close the generator early (simulates what ASGI does on client disconnect)
|
||
await gen.aclose()
|
||
|
||
# Verify aclose was called on the response to release the HTTP connection
|
||
mock_response.aclose.assert_awaited_once()
|
||
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_async_data_generator_uses_direct_stream_fast_path_without_callbacks():
|
||
"""
|
||
When there are no streaming callbacks, async_data_generator should avoid
|
||
per-chunk hook machinery and iterate the provider stream directly.
|
||
"""
|
||
from litellm.proxy._types import UserAPIKeyAuth
|
||
from litellm.proxy.proxy_server import async_data_generator
|
||
from litellm.proxy.utils import ProxyLogging
|
||
|
||
mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
|
||
mock_request_data = {
|
||
"model": "gpt-3.5-turbo",
|
||
"messages": [{"role": "user", "content": "test"}],
|
||
}
|
||
mock_chunks = [
|
||
{"choices": [{"delta": {"content": "Hello"}}]},
|
||
{"choices": [{"delta": {"content": " world"}}]},
|
||
]
|
||
|
||
class MockStream:
|
||
def __aiter__(self):
|
||
return self._stream()
|
||
|
||
async def _stream(self):
|
||
for chunk in mock_chunks:
|
||
yield chunk
|
||
|
||
async def aclose(self):
|
||
pass
|
||
|
||
mock_response = MockStream()
|
||
mock_response.aclose = AsyncMock()
|
||
mock_proxy_logging_obj = MagicMock(spec=ProxyLogging)
|
||
mock_proxy_logging_obj.has_streaming_callbacks.return_value = False
|
||
mock_proxy_logging_obj.needs_iterator_wrap.return_value = False
|
||
mock_proxy_logging_obj.needs_per_chunk_streaming_hook.return_value = False
|
||
mock_proxy_logging_obj.async_post_call_streaming_iterator_hook = MagicMock()
|
||
mock_proxy_logging_obj.async_post_call_streaming_hook = AsyncMock()
|
||
mock_proxy_logging_obj.post_call_failure_hook = AsyncMock()
|
||
|
||
with patch("litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj):
|
||
with patch.object(
|
||
ProxyLogging, "_fire_deferred_stream_logging"
|
||
) as mock_deferred_logging:
|
||
yielded_data = []
|
||
async for data in async_data_generator(
|
||
mock_response, mock_user_api_key_dict, mock_request_data
|
||
):
|
||
yielded_data.append(data)
|
||
|
||
yielded_text = [
|
||
chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
|
||
for chunk in yielded_data
|
||
]
|
||
assert len([chunk for chunk in yielded_text if chunk.startswith("data: {")]) == 2
|
||
assert yielded_text[-1] == "data: [DONE]\n\n"
|
||
mock_proxy_logging_obj.async_post_call_streaming_iterator_hook.assert_not_called()
|
||
mock_proxy_logging_obj.async_post_call_streaming_hook.assert_not_awaited()
|
||
mock_deferred_logging.assert_called_once_with(mock_request_data)
|
||
mock_response.aclose.assert_awaited_once()
|
||
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_async_data_generator_passes_through_google_native_sse_bytes():
|
||
"""
|
||
Google-native streamGenerateContent yields raw SSE bytes; they must not be
|
||
re-wrapped as data: b'data: {...}'.
|
||
"""
|
||
from litellm.proxy._types import UserAPIKeyAuth
|
||
from litellm.proxy.proxy_server import async_data_generator
|
||
from litellm.proxy.utils import ProxyLogging
|
||
|
||
mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
|
||
mock_request_data = {
|
||
"model": "gemini-2.0-flash",
|
||
"messages": [{"role": "user", "content": "test"}],
|
||
}
|
||
gemini_event = b'data: {"candidates": [{"content": "hi"}]}\n\n'
|
||
gemini_event_without_terminator = b'data: {"candidates": [{"content": "there"}]}'
|
||
raw_payload = b'{"partial": true}'
|
||
|
||
class MockStream:
|
||
def __aiter__(self):
|
||
return self._stream()
|
||
|
||
async def _stream(self):
|
||
yield gemini_event
|
||
yield gemini_event_without_terminator
|
||
yield raw_payload
|
||
|
||
async def aclose(self):
|
||
pass
|
||
|
||
mock_response = MockStream()
|
||
mock_response.aclose = AsyncMock()
|
||
mock_proxy_logging_obj = MagicMock(spec=ProxyLogging)
|
||
mock_proxy_logging_obj.has_streaming_callbacks.return_value = False
|
||
mock_proxy_logging_obj.needs_iterator_wrap.return_value = False
|
||
mock_proxy_logging_obj.needs_per_chunk_streaming_hook.return_value = False
|
||
mock_proxy_logging_obj.async_post_call_streaming_iterator_hook = MagicMock()
|
||
mock_proxy_logging_obj.async_post_call_streaming_hook = AsyncMock()
|
||
mock_proxy_logging_obj.post_call_failure_hook = AsyncMock()
|
||
|
||
with patch("litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj):
|
||
with patch.object(ProxyLogging, "_fire_deferred_stream_logging"):
|
||
yielded_data = []
|
||
async for data in async_data_generator(
|
||
mock_response, mock_user_api_key_dict, mock_request_data
|
||
):
|
||
yielded_data.append(data)
|
||
|
||
yielded_text = [
|
||
chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
|
||
for chunk in yielded_data
|
||
]
|
||
assert yielded_text[0] == gemini_event.decode("utf-8")
|
||
assert yielded_text[1] == gemini_event_without_terminator.decode("utf-8") + "\n\n"
|
||
assert yielded_text[2] == f'data: {raw_payload.decode("utf-8")}\n\n'
|
||
assert "b'data:" not in "".join(yielded_text)
|
||
assert yielded_text[-1] == "data: [DONE]\n\n"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_async_data_generator_google_genai_stream_omits_openai_done():
    """
    Streams flagged with ``_litellm_skip_openai_stream_done`` (google-genai
    ``streamGenerateContent?alt=sse``) must be forwarded verbatim and must
    never end with the OpenAI ``data: [DONE]`` terminator, which the
    google-genai SDK cannot parse.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    auth = MagicMock(spec=UserAPIKeyAuth)
    request_payload = {
        "model": "gemini-2.0-flash",
        "_litellm_skip_openai_stream_done": True,
    }
    event_bytes = (
        b'data: {"candidates": [{"content": {"parts": [{"text": "Hi"}]}}]}\n\n'
    )

    class SingleEventStream:
        """Async-iterable upstream response that emits exactly one SSE event."""

        def __aiter__(self):
            return self._events()

        async def _events(self):
            yield event_bytes

        async def aclose(self):
            pass

    upstream = SingleEventStream()
    # Instance-level AsyncMock shadows the no-op method so closes are awaitable.
    upstream.aclose = AsyncMock()

    # Switch off every hook path so chunks flow straight from `upstream`.
    logging_stub = MagicMock(spec=ProxyLogging)
    for feature_check in (
        "has_streaming_callbacks",
        "needs_iterator_wrap",
        "needs_per_chunk_streaming_hook",
    ):
        getattr(logging_stub, feature_check).return_value = False
    logging_stub.async_post_call_streaming_iterator_hook = MagicMock()
    logging_stub.async_post_call_streaming_hook = AsyncMock()
    logging_stub.post_call_failure_hook = AsyncMock()

    with patch(
        "litellm.proxy.proxy_server.proxy_logging_obj", logging_stub
    ), patch.object(ProxyLogging, "_fire_deferred_stream_logging"):
        emitted = [
            item
            async for item in async_data_generator(upstream, auth, request_payload)
        ]

    as_text = [
        item.decode("utf-8") if isinstance(item, bytes) else item
        for item in emitted
    ]
    # Only the upstream event comes through; no OpenAI terminator is appended.
    assert as_text == [event_bytes.decode("utf-8")]
    assert "[DONE]" not in "".join(as_text)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_async_data_generator_google_genai_stream_forwards_error_without_done():
    """Error SSE chunks must still reach the client when the OpenAI [DONE]
    terminator is skipped: the error line is forwarded unchanged and nothing
    else (in particular no ``data: [DONE]``) follows it."""
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    error_line = 'data: {"error": {"message": "stream failed"}}\n\n'
    auth = MagicMock(spec=UserAPIKeyAuth)
    request_payload = {
        "model": "gemini-2.0-flash",
        "_litellm_skip_openai_stream_done": True,
    }

    class ErrorOnlyStream:
        """Async-iterable upstream response whose only chunk is an error line."""

        def __aiter__(self):
            return self._events()

        async def _events(self):
            yield error_line

        async def aclose(self):
            pass

    upstream = ErrorOnlyStream()
    # Instance-level AsyncMock shadows the no-op method so closes are awaitable.
    upstream.aclose = AsyncMock()

    # Switch off every hook path so chunks flow straight from `upstream`.
    logging_stub = MagicMock(spec=ProxyLogging)
    for feature_check in (
        "has_streaming_callbacks",
        "needs_iterator_wrap",
        "needs_per_chunk_streaming_hook",
    ):
        getattr(logging_stub, feature_check).return_value = False
    logging_stub.async_post_call_streaming_iterator_hook = MagicMock()
    logging_stub.async_post_call_streaming_hook = AsyncMock()
    logging_stub.post_call_failure_hook = AsyncMock()

    with patch(
        "litellm.proxy.proxy_server.proxy_logging_obj", logging_stub
    ), patch.object(ProxyLogging, "_fire_deferred_stream_logging"):
        emitted = [
            item
            async for item in async_data_generator(upstream, auth, request_payload)
        ]

    as_text = [
        item.decode("utf-8") if isinstance(item, bytes) else item
        for item in emitted
    ]
    assert as_text == [error_line]
    assert "[DONE]" not in "".join(as_text)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_async_data_generator_cleanup_on_normal_completion():
    """
    async_data_generator must await response.aclose() exactly once even when
    the stream finishes without error (cleanup runs from its finally block).
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    auth = MagicMock(spec=UserAPIKeyAuth)
    request_payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "test"}],
    }
    scripted_chunks = [
        {"choices": [{"delta": {"content": "Hello"}}]},
    ]

    logging_stub = MagicMock(spec=ProxyLogging)

    async def replay_chunks(*args, **kwargs):
        # Stand-in for the iterator hook: replays the scripted chunks as-is.
        for piece in scripted_chunks:
            yield piece

    logging_stub.async_post_call_streaming_iterator_hook = replay_chunks
    # Per-chunk hook hands back whatever was passed as `response`.
    logging_stub.async_post_call_streaming_hook = AsyncMock(
        side_effect=lambda **kwargs: kwargs.get("response")
    )
    logging_stub.post_call_failure_hook = AsyncMock()

    upstream = MagicMock()
    upstream.aclose = AsyncMock()

    with patch("litellm.proxy.proxy_server.proxy_logging_obj", logging_stub):
        emitted = [
            item
            async for item in async_data_generator(upstream, auth, request_payload)
        ]

    # A clean finish ends with the OpenAI [DONE] terminator...
    assert any("[DONE]" in item for item in emitted)
    # ...and the upstream response was still closed exactly once.
    upstream.aclose.assert_awaited_once()
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_async_data_generator_cleanup_on_midstream_error():
    """
    When the stream raises after yielding a chunk, async_data_generator must
    emit an error chunk and still await response.aclose() exactly once from
    its finally block.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    auth = MagicMock(spec=UserAPIKeyAuth)
    request_payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "test"}],
    }

    logging_stub = MagicMock(spec=ProxyLogging)

    async def fail_after_first_chunk(*args, **kwargs):
        # One good chunk, then a simulated upstream disconnect.
        yield {"choices": [{"delta": {"content": "Hello"}}]}
        raise RuntimeError("upstream connection reset")

    logging_stub.async_post_call_streaming_iterator_hook = fail_after_first_chunk
    # Per-chunk hook hands back whatever was passed as `response`.
    logging_stub.async_post_call_streaming_hook = AsyncMock(
        side_effect=lambda **kwargs: kwargs.get("response")
    )
    logging_stub.post_call_failure_hook = AsyncMock()

    upstream = MagicMock()
    upstream.aclose = AsyncMock()

    with patch("litellm.proxy.proxy_server.proxy_logging_obj", logging_stub):
        emitted = [
            item
            async for item in async_data_generator(upstream, auth, request_payload)
        ]

    # The data chunk plus an error chunk were both delivered...
    assert len(emitted) >= 2
    assert any("error" in item for item in emitted)
    # ...and cleanup still ran despite the exception.
    upstream.aclose.assert_awaited_once()
|
||
|
||
|
||
# ============================================================================
|
||
# store_model_in_db DB Config Override Tests
|
||
# ============================================================================
|
||
|
||
|
||
def test_store_model_in_db_in_config_general_settings():
    """
    ConfigGeneralSettings declares a ``store_model_in_db`` field that accepts
    True, False and None, and defaults to None when omitted.
    """
    from litellm.proxy._types import ConfigGeneralSettings

    assert "store_model_in_db" in ConfigGeneralSettings.model_fields

    # Each explicit value must round-trip unchanged.
    for explicit in (True, False, None):
        settings = ConfigGeneralSettings(store_model_in_db=explicit)
        assert settings.store_model_in_db is explicit

    # Omitting the field falls back to None.
    assert ConfigGeneralSettings().store_model_in_db is None
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_update_general_settings_store_model_in_db_true():
    """
    Verify _update_general_settings sets global store_model_in_db to True
    when DB general_settings has store_model_in_db=True.

    Both the module-level flag and the module-level general_settings entry
    are checked.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    import litellm.proxy.proxy_server as ps

    proxy_config = ProxyConfig()

    # No ``as`` targets: the values under test are read back through the
    # module (``ps``), because _update_general_settings updates the
    # module-level globals rather than the objects handed to patch().
    with (
        patch("litellm.proxy.proxy_server.store_model_in_db", False),
        patch("litellm.proxy.proxy_server.general_settings", {}),
    ):
        await proxy_config._update_general_settings(
            db_general_settings={"store_model_in_db": True}
        )

        # Assert while the patches are still active: patch() restores the
        # original module attributes on exit.
        assert ps.store_model_in_db is True
        assert ps.general_settings["store_model_in_db"] is True
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_update_general_settings_store_model_in_db_false():
    """
    A DB general_settings value of store_model_in_db=False must switch the
    module-level store_model_in_db flag off and be mirrored into the
    module-level general_settings dict.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    import litellm.proxy.proxy_server as proxy_server_module

    config = ProxyConfig()
    db_settings = {"store_model_in_db": False}

    with (
        patch("litellm.proxy.proxy_server.store_model_in_db", True),
        patch("litellm.proxy.proxy_server.general_settings", {}),
    ):
        await config._update_general_settings(db_general_settings=db_settings)

        # Checked inside the `with`: patch() restores the originals on exit.
        assert proxy_server_module.store_model_in_db is False
        assert proxy_server_module.general_settings["store_model_in_db"] is False
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_update_general_settings_store_model_in_db_string_normalization():
    """
    String values stored in the DB ("true", "True", "false") must be coerced
    to real booleans by _update_general_settings.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    import litellm.proxy.proxy_server as ps

    proxy_config = ProxyConfig()

    # (flag value before the update, raw DB value, expected flag afterwards)
    cases = (
        (False, "true", True),
        (False, "True", True),
        (True, "false", False),
    )
    for starting_value, db_value, expected in cases:
        with (
            patch("litellm.proxy.proxy_server.store_model_in_db", starting_value),
            patch("litellm.proxy.proxy_server.general_settings", {}),
        ):
            await proxy_config._update_general_settings(
                db_general_settings={"store_model_in_db": db_value}
            )
            assert ps.store_model_in_db is expected, db_value
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_update_general_settings_store_model_in_db_none_keeps_current():
    """
    A None store_model_in_db value coming from the DB must leave the current
    module-level flag untouched, whether it is currently True or False.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    import litellm.proxy.proxy_server as ps

    proxy_config = ProxyConfig()

    for current in (True, False):
        with (
            patch("litellm.proxy.proxy_server.store_model_in_db", current),
            patch("litellm.proxy.proxy_server.general_settings", {}),
        ):
            await proxy_config._update_general_settings(
                db_general_settings={"store_model_in_db": None}
            )
            assert ps.store_model_in_db is current
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_store_model_in_db_db_override_when_config_false():
    """
    When the config leaves store_model_in_db off but the DB general_settings
    row turns it on, the early check in initialize_scheduled_background_jobs
    must flip the flag to True, after which add_deployment and
    get_credentials each run once.
    """
    from litellm.proxy.proxy_server import ProxyStartupEvent
    from litellm.proxy.utils import ProxyLogging

    prisma = MagicMock()
    # litellm_config lookup yields a general_settings row enabling the flag.
    settings_row = MagicMock()
    settings_row.param_value = {"store_model_in_db": True}
    prisma.db.litellm_config.find_first = AsyncMock(return_value=settings_row)

    logging_stub = MagicMock(spec=ProxyLogging)
    logging_stub.slack_alerting_instance = MagicMock()
    config_stub = AsyncMock()

    with (
        patch("litellm.proxy.proxy_server.proxy_config", config_stub),
        patch("litellm.proxy.proxy_server.store_model_in_db", False),
        patch("litellm.proxy.proxy_server.get_secret_bool", return_value=False),
    ):
        await ProxyStartupEvent.initialize_scheduled_background_jobs(
            general_settings={},
            prisma_client=prisma,
            proxy_budget_rescheduler_min_time=1,
            proxy_budget_rescheduler_max_time=2,
            proxy_batch_write_at=5,
            proxy_logging_obj=logging_stub,
        )

        import litellm.proxy.proxy_server as ps

        # The DB value won over the config value...
        assert ps.store_model_in_db is True
        # ...so the DB-dependent loaders ran exactly once each.
        assert config_stub.add_deployment.call_count == 1
        assert config_stub.get_credentials.call_count == 1
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_store_model_in_db_db_check_skipped_when_already_true(monkeypatch):
    """
    With store_model_in_db already True, the early DB override check has
    nothing to do: the flag must stay True and add_deployment must still run
    exactly once.
    """
    monkeypatch.delenv("STORE_MODEL_IN_DB", raising=False)
    from litellm.proxy.proxy_server import ProxyStartupEvent
    from litellm.proxy.utils import ProxyLogging

    prisma = MagicMock()
    prisma.db.litellm_config.find_first = AsyncMock(return_value=None)

    logging_stub = MagicMock(spec=ProxyLogging)
    logging_stub.slack_alerting_instance = MagicMock()
    config_stub = AsyncMock()

    with (
        patch("litellm.proxy.proxy_server.proxy_config", config_stub),
        patch("litellm.proxy.proxy_server.store_model_in_db", True),
        patch("litellm.proxy.proxy_server.get_secret_bool", return_value=True),
    ):
        await ProxyStartupEvent.initialize_scheduled_background_jobs(
            general_settings={},
            prisma_client=prisma,
            proxy_budget_rescheduler_min_time=1,
            proxy_budget_rescheduler_max_time=2,
            proxy_batch_write_at=5,
            proxy_logging_obj=logging_stub,
        )

        # find_first may also be reached from other startup paths, so the
        # skipped early check is not asserted directly. Only outcomes are
        # checked: the flag stays True and deployments are loaded once.
        import litellm.proxy.proxy_server as ps

        assert ps.store_model_in_db is True
        assert config_stub.add_deployment.call_count == 1
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_store_model_in_db_db_failure_graceful(monkeypatch):
    """
    A failing DB lookup during the early store_model_in_db check must not
    abort initialize_scheduled_background_jobs. The flag stays False and
    add_deployment is never called.
    """
    monkeypatch.delenv("STORE_MODEL_IN_DB", raising=False)
    from litellm.proxy.proxy_server import ProxyStartupEvent
    from litellm.proxy.utils import ProxyLogging

    prisma = MagicMock()
    # Every litellm_config lookup raises, simulating a broken DB connection.
    prisma.db.litellm_config.find_first = AsyncMock(
        side_effect=Exception("DB connection error")
    )

    logging_stub = MagicMock(spec=ProxyLogging)
    logging_stub.slack_alerting_instance = MagicMock()
    config_stub = AsyncMock()

    with (
        patch("litellm.proxy.proxy_server.proxy_config", config_stub),
        patch("litellm.proxy.proxy_server.store_model_in_db", False),
        patch("litellm.proxy.proxy_server.get_secret_bool", return_value=False),
    ):
        # Must complete without propagating the DB error.
        await ProxyStartupEvent.initialize_scheduled_background_jobs(
            general_settings={},
            prisma_client=prisma,
            proxy_budget_rescheduler_min_time=1,
            proxy_budget_rescheduler_max_time=2,
            proxy_batch_write_at=5,
            proxy_logging_obj=logging_stub,
        )

        import litellm.proxy.proxy_server as ps

        assert ps.store_model_in_db is False
        # With the flag still off, DB deployments are never loaded.
        config_stub.add_deployment.assert_not_called()
|
||
|
||
|
||
# =====================================================================
|
||
# Spend counter tests (v2 — Redis-backed spend counters)
|
||
# =====================================================================
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_current_spend_reads_redis_first():
    """get_current_spend must return the Redis value, even when the pod-local
    in-memory cache holds a different (stale) value for the same key."""
    from litellm.caching.dual_cache import DualCache

    counter_key = "spend:key:test"
    shared_cache = DualCache()

    # Pod-local copy is stale...
    shared_cache.in_memory_cache.set_cache(key=counter_key, value=0.30)

    # ...while Redis carries the authoritative cross-pod total.
    redis_stub = AsyncMock()
    redis_stub.async_get_cache = AsyncMock(return_value=0.90)
    shared_cache.redis_cache = redis_stub

    import litellm.proxy.proxy_server as ps

    with patch.object(ps, "spend_counter_cache", shared_cache):
        from litellm.proxy.proxy_server import get_current_spend

        spend = await get_current_spend(
            counter_key=counter_key,
            fallback_spend=0.0,
        )
        assert spend == 0.90
        redis_stub.async_get_cache.assert_called_once_with(key=counter_key)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_current_spend_fallback_to_in_memory():
    """With no Redis cache configured, get_current_spend reads the in-memory
    counter."""
    from litellm.caching.dual_cache import DualCache

    counter_key = "spend:key:test"
    local_only_cache = DualCache()  # no redis_cache configured
    local_only_cache.in_memory_cache.set_cache(key=counter_key, value=0.50)

    import litellm.proxy.proxy_server as ps

    with patch.object(ps, "spend_counter_cache", local_only_cache):
        from litellm.proxy.proxy_server import get_current_spend

        spend = await get_current_spend(
            counter_key=counter_key,
            fallback_spend=0.0,
        )
        assert spend == 0.50
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_increment_spend_counters_initializes_and_increments():
    """The key counter is seeded from the cached key's spend, then incremented.

    The token is passed in already hashed, as in production: the auth flow
    hashes metadata["user_api_key"] before the cost callback sees it.
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy._types import LiteLLM_VerificationTokenView, hash_token

    key_cache = DualCache()
    counter_cache = DualCache()

    hashed_token = hash_token("sk-test-token-for-counter")
    counter_key = f"spend:key:{hashed_token}"

    # Cached key object carrying the spend already recorded in the DB.
    key_cache.in_memory_cache.set_cache(
        key=hashed_token,
        value=LiteLLM_VerificationTokenView(
            token=hashed_token,
            spend=5.0,
            max_budget=10.0,
        ),
    )

    import litellm.proxy.proxy_server as ps

    with patch.multiple(
        ps, user_api_key_cache=key_cache, spend_counter_cache=counter_cache
    ):
        from litellm.proxy.proxy_server import increment_spend_counters

        # First call seeds from the cached spend (5.0) and adds 0.50; the
        # second finds the counter already present and only adds its cost.
        for cost, expected_total in ((0.50, 5.50), (0.25, 5.75)):
            await increment_spend_counters(
                token=hashed_token,
                team_id=None,
                user_id=None,
                response_cost=cost,
            )
            running_total = counter_cache.in_memory_cache.get_cache(key=counter_key)
            assert running_total == expected_total
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_increment_spend_counters_team_and_member():
    """Counter should track team and team member spend separately.

    The team counter is seeded from the cached team object (spend 2.0) and the
    member counter from the cached membership (spend 1.0); each then gains the
    same response cost.
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy._types import LiteLLM_TeamTable

    key_cache = DualCache()
    counter_cache = DualCache()

    # Cached team object
    key_cache.in_memory_cache.set_cache(
        key="team_id:team-1",
        value=LiteLLM_TeamTable(team_id="team-1", spend=2.0),
    )

    # Cached team membership
    key_cache.in_memory_cache.set_cache(
        key="team_membership:user-1:team-1",
        value={"user_id": "user-1", "team_id": "team-1", "spend": 1.0},
    )

    import litellm.proxy.proxy_server as ps

    # patch.multiple restores both module globals even if an assertion fails.
    with patch.multiple(
        ps, user_api_key_cache=key_cache, spend_counter_cache=counter_cache
    ):
        from litellm.proxy.proxy_server import increment_spend_counters

        await increment_spend_counters(
            token=None,
            team_id="team-1",
            user_id="user-1",
            response_cost=0.30,
        )

        # 0.30 has no exact binary representation, so compare the float sums
        # with a tolerance instead of exact equality.
        team_counter = counter_cache.in_memory_cache.get_cache(key="spend:team:team-1")
        assert team_counter == pytest.approx(2.30)

        member_counter = counter_cache.in_memory_cache.get_cache(
            key="spend:team_member:user-1:team-1"
        )
        assert member_counter == pytest.approx(1.30)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_init_and_increment_spend_counter_reseeds_from_db_on_counter_miss():
    """On a Redis counter miss the reseed reads the authoritative spend from
    the DB rather than a stale cached object. It seeds Redis with that value
    via SET NX and sends only the request's delta through INCRBYFLOAT."""
    from litellm.caching.dual_cache import DualCache

    counter_key = "spend:team:team-9"
    counter_cache = DualCache()
    increments_seen: list = []

    async def capture_increment(key, value, ttl=None, **kwargs):
        increments_seen.append({"key": key, "value": value, "ttl": ttl})
        return value

    redis_stub = AsyncMock()
    redis_stub.async_increment = AsyncMock(side_effect=capture_increment)
    redis_stub.async_get_cache = AsyncMock(return_value=None)  # counter missing
    redis_stub.async_set_cache = AsyncMock(return_value=True)  # SET NX wins
    counter_cache.redis_cache = redis_stub

    # The DB reports 42.0. The cached team object says 10.0, and per the
    # original intent it would only be consulted if no prisma client were set.
    # The seed must use 42.0.
    db_row = MagicMock()
    db_row.spend = 42.0
    prisma_stub = MagicMock()
    prisma_stub.db.litellm_teamtable.find_unique = AsyncMock(return_value=db_row)

    stale_team = MagicMock()
    stale_team.spend = 10.0
    stale_cache = DualCache()
    stale_cache.in_memory_cache.set_cache(key="team_id:team-9", value=stale_team)

    import litellm.proxy.proxy_server as ps
    from litellm.proxy.proxy_server import _init_and_increment_spend_counter

    with patch.multiple(
        ps,
        user_api_key_cache=stale_cache,
        spend_counter_cache=counter_cache,
        prisma_client=prisma_stub,
    ):
        await _init_and_increment_spend_counter(
            counter_key=counter_key,
            source_cache_key="team_id:team-9",
            increment=1.5,
        )

        prisma_stub.db.litellm_teamtable.find_unique.assert_awaited_once_with(
            where={"team_id": "team-9"}
        )
        # Seeding is a single cross-pod-safe SET NX of the DB value...
        redis_stub.async_set_cache.assert_awaited_once_with(
            key=counter_key, value=42.0, nx=True
        )
        # ...and only the per-request delta goes through INCRBYFLOAT.
        assert [(entry["key"], entry["value"]) for entry in increments_seen] == [
            (counter_key, 1.5)
        ]
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_primary_spend_counter_redis_concurrent_seed_does_not_double_seed():
    """Two pods both observing a missing Redis counter must not both
    INCRBYFLOAT the full DB spend. SpendCounterReseed.coalesced uses SET NX
    so the loser reads the winner's value; final Redis = db_spend, not
    2 * db_spend.

    The per-counter asyncio.Lock is per-process, so it does NOT coordinate
    across pods. We simulate two pods by patching _get_lock to return a
    fresh lock per call (each "pod" has its own lock registry in real life).
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.db.spend_counter_reseed import SpendCounterReseed

    counter_key = "spend:team:team-concurrent-seed"
    # Dict standing in for the Redis keyspace shared by both simulated pods.
    redis_store: dict = {}
    # Number of DB reads (find_unique calls) across both pods.
    db_read_count = 0
    # Outcome of every SET attempt in completion order (True = value written).
    set_results: list = []
    # Redis GETs issued after at least one SET has written a value.
    get_after_set_count = 0
    # Number of SET calls that actually wrote a value.
    set_completed_count = 0

    async def redis_set_cache(key, value, nx=False, **_):
        # Yield BEFORE the membership check so two concurrent callers
        # interleave the way real atomic Redis SET NX does: the first
        # to resume runs check + write atomically and wins; the second
        # resumes after the key exists and loses. Yielding *after* the
        # check would let both callers pass the empty-store check before
        # either writes, so neither would ever lose.
        await asyncio.sleep(0)
        if nx and key in redis_store:
            set_results.append(False)
            return False
        redis_store[key] = float(value)
        set_results.append(True)
        nonlocal set_completed_count
        set_completed_count += 1
        return True

    async def redis_get_cache(key):
        # Track reads that happen after at least one SET NX has completed
        # — those are the loser-path fallback reads we want to verify.
        if set_completed_count > 0:
            nonlocal get_after_set_count
            get_after_set_count += 1
        return redis_store.get(key)

    # One fake Redis client, backed by redis_store, shared by both pods.
    fake_redis = AsyncMock()
    fake_redis.async_get_cache = AsyncMock(side_effect=redis_get_cache)
    fake_redis.async_set_cache = AsyncMock(side_effect=redis_set_cache)

    async def slow_find_unique(**_):
        nonlocal db_read_count
        db_read_count += 1
        # Both pods read DB before either's SET NX lands.
        await asyncio.sleep(0)
        # Every DB read reports the same authoritative spend (506.0).
        row = MagicMock()
        row.spend = 506.0
        return row

    fake_prisma = MagicMock()
    fake_prisma.db.litellm_teamtable.find_unique = AsyncMock(
        side_effect=slow_find_unique
    )

    # Two separate DualCache instances (one per simulated pod) that talk to
    # the same fake Redis.
    pod_a = DualCache()
    pod_a.redis_cache = fake_redis
    pod_b = DualCache()
    pod_b.redis_cache = fake_redis

    # Each "pod" has its own per-process lock registry. Patch _get_lock to
    # always return a fresh lock so the two coalesced calls do not serialize
    # via one in-process lock (which is what would happen across pods).
    async def fresh_lock(_counter_key):
        return asyncio.Lock()

    with patch.object(SpendCounterReseed, "_get_lock", side_effect=fresh_lock):
        # Run both pods' reseeds concurrently on the same event loop.
        results = await asyncio.gather(
            SpendCounterReseed.coalesced(
                prisma_client=fake_prisma,
                spend_counter_cache=pod_a,
                counter_key=counter_key,
            ),
            SpendCounterReseed.coalesced(
                prisma_client=fake_prisma,
                spend_counter_cache=pod_b,
                counter_key=counter_key,
            ),
        )

    # Both pods report the DB value and Redis holds it once (not 2x).
    assert all(r == 506.0 for r in results), results
    assert redis_store[counter_key] == pytest.approx(506.0), redis_store
    # Both pods read the DB and both attempted SET NX; exactly one wrote
    # (winner) and one was rejected (loser).
    assert db_read_count == 2
    assert fake_redis.async_set_cache.await_count == 2
    nx_writes = [
        call
        for call in fake_redis.async_set_cache.await_args_list
        if call.kwargs.get("nx") is True
    ]
    assert len(nx_writes) == 2
    assert sorted(set_results) == [
        False,
        True,
    ], f"expected exactly one SET NX winner and one loser, got {set_results}"
    # Loser path executed: after the winner's SET NX returned True, the
    # losing coalesced() call falls back to async_get_cache to read the
    # winner's value rather than re-seeding.
    assert (
        get_after_set_count >= 1
    ), "loser branch (else: read back winner's value) was never exercised"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_reseed_spend_from_db_user_and_org_prefixes():
    """User and org counters reseed from their own DB tables.

    End-user and tag counters rely on the already-fetched auth objects passed
    as fallback_spend. This reseed helper must therefore return None for them
    without adding extra per-request DB reads.
    """
    from litellm.proxy.db.spend_counter_reseed import SpendCounterReseed

    def spend_row(amount):
        row = MagicMock()
        row.spend = amount
        return row

    prisma = MagicMock()
    tables = prisma.db
    tables.litellm_usertable.find_unique = AsyncMock(return_value=spend_row(17.0))
    tables.litellm_endusertable.find_unique = AsyncMock()
    tables.litellm_tagtable.find_unique = AsyncMock()
    tables.litellm_organizationtable.find_unique = AsyncMock(
        return_value=spend_row(305.0)
    )

    # spend:user:<id> -> litellm_usertable looked up by user_id.
    assert await SpendCounterReseed.from_db(prisma, "spend:user:alice") == 17.0
    tables.litellm_usertable.find_unique.assert_awaited_once_with(
        where={"user_id": "alice"}
    )

    # End-user and tag prefixes yield None without touching their tables.
    for counter_key, table in (
        ("spend:end_user:customer-1", tables.litellm_endusertable),
        ("spend:tag:paid-tag", tables.litellm_tagtable),
    ):
        assert await SpendCounterReseed.from_db(prisma, counter_key) is None
        table.find_unique.assert_not_awaited()

    # spend:org:<id> -> litellm_organizationtable looked up by organization_id.
    assert await SpendCounterReseed.from_db(prisma, "spend:org:acme") == 305.0
    tables.litellm_organizationtable.find_unique.assert_awaited_once_with(
        where={"organization_id": "acme"}
    )
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_reseed_spend_from_db_skips_window_variant_keys():
    """Window counters (``spend:<entity>:<id>:window:<duration>``) share
    prefixes with the primary counters but have no matching DB row. from_db
    must return None for them without querying the DB."""
    from litellm.proxy.db.spend_counter_reseed import SpendCounterReseed

    prisma = MagicMock()
    key_table = prisma.db.litellm_verificationtoken
    team_table = prisma.db.litellm_teamtable
    key_table.find_unique = AsyncMock()
    team_table.find_unique = AsyncMock()

    for window_key in (
        "spend:key:sk-abc:window:1h",
        "spend:team:team-1:window:1d",
    ):
        assert await SpendCounterReseed.from_db(prisma, window_key) is None

    key_table.find_unique.assert_not_awaited()
    team_table.find_unique.assert_not_awaited()
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_window_spend_counter_reseeds_from_spend_logs_on_counter_miss():
    """A missing window counter is rebuilt from SpendLogs, then incremented.

    The reseed sums ``spend`` over the key's SpendLogs rows with
    ``startTime >= window_start`` (2.25 here). The request's increment (0.5)
    is then added on top, giving 2.75.
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _init_and_increment_window_spend_counter

    counter_key = "spend:key:key-window:window:1h"
    counter_cache = DualCache()
    window_start = datetime.now(timezone.utc) - timedelta(hours=1)

    prisma_stub = MagicMock()
    prisma_stub.db.litellm_spendlogs.group_by = AsyncMock(
        return_value=[{"api_key": "key-window", "_sum": {"spend": 2.25}}]
    )

    import litellm.proxy.proxy_server as ps

    with patch.multiple(
        ps, spend_counter_cache=counter_cache, prisma_client=prisma_stub
    ):
        await _init_and_increment_window_spend_counter(
            counter_key=counter_key,
            entity_type="Key",
            entity_id="key-window",
            window_start=window_start,
            increment=0.5,
        )

        prisma_stub.db.litellm_spendlogs.group_by.assert_awaited_once_with(
            by=["api_key"],
            where={"api_key": "key-window", "startTime": {"gte": window_start}},
            sum={"spend": True},
        )
        assert counter_cache.in_memory_cache.get_cache(
            key=counter_key
        ) == pytest.approx(2.75)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_init_spend_counter_redis_clean_miss_skips_stale_in_memory():
    """A Redis miss wins over a stale in-memory counter.

    The in-memory layer still holds 10.0, but Redis has nothing for the key.
    The increment path must treat the counter as cold and proceed in three steps:
    - reseed it from the team row in the DB (42.0) via SET NX;
    - apply the 1.5 increment;
    - end at 43.5 in both layers, rather than building on the stale 10.0.
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _init_and_increment_spend_counter

    team_key = "spend:team:team-stale-local"
    cache = DualCache()
    cache.in_memory_cache.set_cache(key=team_key, value=10.0)

    store: dict = {}

    async def fake_incr(key, value, **_):
        store[key] = (store.get(key) or 0.0) + value
        return store[key]

    async def fake_set(key, value, nx=False, **_):
        if nx and key in store:
            return False
        store[key] = float(value)
        return True

    redis = AsyncMock()
    redis.async_get_cache = AsyncMock(return_value=None)
    redis.async_increment = AsyncMock(side_effect=fake_incr)
    redis.async_set_cache = AsyncMock(side_effect=fake_set)
    cache.redis_cache = redis

    team_row = MagicMock()
    team_row.spend = 42.0
    prisma = MagicMock()
    prisma.db.litellm_teamtable.find_unique = AsyncMock(return_value=team_row)

    import litellm.proxy.proxy_server as ps

    with patch.multiple(
        ps,
        spend_counter_cache=cache,
        prisma_client=prisma,
        user_api_key_cache=DualCache(),
    ):
        await _init_and_increment_spend_counter(
            counter_key=team_key,
            source_cache_key="team_id:team-stale-local",
            increment=1.5,
        )

        prisma.db.litellm_teamtable.find_unique.assert_awaited_once_with(
            where={"team_id": "team-stale-local"}
        )
        # SET NX seeds 42.0, then INCRBYFLOAT adds 1.5 -> 43.5.
        assert store[team_key] == pytest.approx(43.5)
        assert cache.in_memory_cache.get_cache(key=team_key) == pytest.approx(43.5)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_window_spend_counter_redis_clean_miss_skips_stale_in_memory():
    """A Redis miss on a window counter ignores the stale in-memory value.

    In-memory still holds 100.0 for the window key while Redis is empty.
    The window path must rebuild the counter from SpendLogs (2.25), add the
    0.5 increment, and end at 2.75 in both layers.
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _init_and_increment_window_spend_counter

    window_key = "spend:key:key-window-stale-local:window:1h"
    key_id = "key-window-stale-local"
    cache = DualCache()
    cache.in_memory_cache.set_cache(key=window_key, value=100.0)
    start = datetime.now(timezone.utc) - timedelta(hours=1)

    store: dict = {}

    async def fake_incr(key, value, **_):
        store[key] = (store.get(key) or 0.0) + value
        return store[key]

    async def fake_set(key, value, **_):
        # First write wins regardless of the nx flag, mirroring SET NX.
        if key in store:
            return False
        store[key] = value
        return True

    redis = AsyncMock()
    redis.async_get_cache = AsyncMock(return_value=None)
    redis.async_set_cache = AsyncMock(side_effect=fake_set)
    redis.async_increment = AsyncMock(side_effect=fake_incr)
    cache.redis_cache = redis

    prisma = MagicMock()
    prisma.db.litellm_spendlogs.group_by = AsyncMock(
        return_value=[{"api_key": "key-window-stale-local", "_sum": {"spend": 2.25}}]
    )

    import litellm.proxy.proxy_server as ps

    with patch.multiple(ps, spend_counter_cache=cache, prisma_client=prisma):
        await _init_and_increment_window_spend_counter(
            counter_key=window_key,
            entity_type="Key",
            entity_id=key_id,
            window_start=start,
            increment=0.5,
        )

        prisma.db.litellm_spendlogs.group_by.assert_awaited_once_with(
            by=["api_key"],
            where={"api_key": key_id, "startTime": {"gte": start}},
            sum={"spend": True},
        )
        assert store[window_key] == pytest.approx(2.75)
        assert cache.in_memory_cache.get_cache(key=window_key) == pytest.approx(2.75)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_window_spend_counter_redis_concurrent_seed_does_not_double_seed():
    """Losing the SET NX race must not add our seed on top of the winner's.

    Redis already holds 2.75 for the window key, simulating a concurrent
    seeder. The first two reads are forced to miss, so this caller takes the
    reseed path:
    - it computes 2.25 from SpendLogs;
    - its SET NX is rejected;
    - only its 0.5 increment is applied on top of the existing 2.75, giving 3.25.
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _init_and_increment_window_spend_counter

    counter_cache = DualCache()
    counter_key = "spend:key:key-window-concurrent-seed:window:1h"
    window_start = datetime.now(timezone.utc) - timedelta(hours=1)
    redis_store = {counter_key: 2.75}
    redis_reads = 0

    async def redis_get_cache(key, **_):
        # Accept and ignore extra keyword arguments, like the other fake Redis
        # helpers in this module, so the fake does not break if the caller
        # forwards additional kwargs.
        nonlocal redis_reads
        redis_reads += 1
        # Report the key as absent for the first two reads so the caller goes
        # down the reseed path even though Redis already holds a value.
        if redis_reads <= 2:
            return None
        return redis_store.get(key)

    async def redis_increment(key, value, **_):
        redis_store[key] = (redis_store.get(key) or 0.0) + value
        return redis_store[key]

    fake_redis = AsyncMock()
    fake_redis.async_get_cache = AsyncMock(side_effect=redis_get_cache)
    # SET NX always loses: another writer already seeded the key.
    fake_redis.async_set_cache = AsyncMock(return_value=False)
    fake_redis.async_increment = AsyncMock(side_effect=redis_increment)
    counter_cache.redis_cache = fake_redis

    fake_prisma = MagicMock()
    fake_prisma.db.litellm_spendlogs.group_by = AsyncMock(
        return_value=[
            {"api_key": "key-window-concurrent-seed", "_sum": {"spend": 2.25}}
        ]
    )

    import litellm.proxy.proxy_server as ps

    orig_counter, orig_prisma = ps.spend_counter_cache, ps.prisma_client
    ps.spend_counter_cache = counter_cache
    ps.prisma_client = fake_prisma
    try:
        await _init_and_increment_window_spend_counter(
            counter_key=counter_key,
            entity_type="Key",
            entity_id="key-window-concurrent-seed",
            window_start=window_start,
            increment=0.5,
        )

        fake_redis.async_set_cache.assert_awaited_once_with(
            key=counter_key,
            value=2.25,
            nx=True,
        )
        assert redis_store[counter_key] == pytest.approx(3.25)
        assert counter_cache.in_memory_cache.get_cache(
            key=counter_key
        ) == pytest.approx(3.25)
    finally:
        ps.spend_counter_cache = orig_counter
        ps.prisma_client = orig_prisma
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_window_spend_counter_skips_invalid_window_start():
    """With no window start (``None``), the window counter is left untouched."""
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _init_and_increment_window_spend_counter

    bad_key = "spend:key:key-invalid-window:window:not-a-duration"
    cache = DualCache()

    import litellm.proxy.proxy_server as ps

    with patch.object(ps, "spend_counter_cache", cache):
        await _init_and_increment_window_spend_counter(
            counter_key=bad_key,
            entity_type="Key",
            entity_id="key-invalid-window",
            window_start=None,
            increment=0.5,
        )

        assert cache.in_memory_cache.get_cache(key=bad_key) is None
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_window_spend_counter_does_not_seed_zero_when_db_unavailable():
    """No DB means no seed: the helper reports failure and caches nothing.

    With ``prisma_client`` unset, initialising the window counter must return
    ``False`` and leave the in-memory layer empty for the key.
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _ensure_window_spend_counter_initialized

    window_key = "spend:key:key-window-db-unavailable:window:1h"
    cache = DualCache()

    import litellm.proxy.proxy_server as ps

    with patch.multiple(ps, spend_counter_cache=cache, prisma_client=None):
        ok = await _ensure_window_spend_counter_initialized(
            counter_key=window_key,
            entity_type="Key",
            entity_id="key-window-db-unavailable",
            window_start=datetime.now(timezone.utc) - timedelta(hours=1),
        )

        assert ok is False
        assert cache.in_memory_cache.get_cache(key=window_key) is None
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_increment_spend_counters_finalizes_after_unreserved_increments():
    """The reservation is finalised only after the plain increments run.

    The key counter was reserved up front (0.5); the team counter was not.
    ``increment_spend_counters`` must increment the unreserved team counter
    while the reservation is still open. It then finalises the reservation,
    after which the key counter holds the actual cost (0.25).
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import increment_spend_counters

    key_counter = "spend:key:key-finalize-after-increments"
    cache = DualCache()
    cache.in_memory_cache.set_cache(key=key_counter, value=0.5)

    reservation = {
        "reserved_cost": 0.5,
        "entries": [
            {
                "counter_key": key_counter,
                "entity_type": "Key",
                "entity_id": "key-finalize-after-increments",
                "reserved_cost": 0.5,
                "applied_adjustment": 0.0,
            }
        ],
        "finalized": False,
    }
    seen_keys = []

    async def record_while_open(**kwargs):
        # Every plain increment must happen before finalisation.
        assert reservation["finalized"] is False
        seen_keys.append(kwargs["counter_key"])

    import litellm.proxy.proxy_server as ps

    with patch.multiple(
        ps, spend_counter_cache=cache, user_api_key_cache=DualCache()
    ):
        with patch(
            "litellm.proxy.proxy_server._init_and_increment_spend_counter",
            new=AsyncMock(side_effect=record_while_open),
        ):
            await increment_spend_counters(
                token="key-finalize-after-increments",
                team_id="team-finalize-after-increments",
                user_id=None,
                response_cost=0.25,
                budget_reservation=reservation,
            )

        assert seen_keys == ["spend:team:team-finalize-after-increments"]
        assert reservation["finalized"] is True
        assert cache.in_memory_cache.get_cache(key=key_counter) == pytest.approx(0.25)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_increment_spend_counters_finalizes_none_cost_reservation():
    """A ``None`` response cost still finalises the reservation.

    The 0.5 reserved on the key counter is released, leaving the counter at 0.0.
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import increment_spend_counters

    key_counter = "spend:key:key-finalize-none-cost"
    cache = DualCache()
    cache.in_memory_cache.set_cache(key=key_counter, value=0.5)

    reservation = {
        "reserved_cost": 0.5,
        "entries": [
            {
                "counter_key": key_counter,
                "entity_type": "Key",
                "entity_id": "key-finalize-none-cost",
                "reserved_cost": 0.5,
                "applied_adjustment": 0.0,
            }
        ],
        "finalized": False,
    }

    import litellm.proxy.proxy_server as ps

    with patch.object(ps, "spend_counter_cache", cache):
        await increment_spend_counters(
            token="key-finalize-none-cost",
            team_id=None,
            user_id=None,
            response_cost=None,
            budget_reservation=reservation,
        )

        assert reservation["finalized"] is True
        assert cache.in_memory_cache.get_cache(key=key_counter) == pytest.approx(0.0)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_increment_spend_counters_invalidates_bad_reserved_counter_without_failing():
    """A reserved counter that cannot be settled is dropped, not fatal.

    Here the reserved key counter is absent from the cache. Finalising must
    do three things rather than raise:
    - log exactly one warning;
    - mark the reservation finalised;
    - leave the counter absent.
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import increment_spend_counters

    key_counter = "spend:key:key-bad-reserved-counter"
    cache = DualCache()
    reservation = {
        "reserved_cost": 0.5,
        "entries": [
            {
                "counter_key": key_counter,
                "entity_type": "Key",
                "entity_id": "key-bad-reserved-counter",
                "reserved_cost": 0.5,
                "applied_adjustment": 0.0,
            }
        ],
        "finalized": False,
    }

    import litellm.proxy.proxy_server as ps

    with patch.object(ps, "spend_counter_cache", cache):
        with patch(
            "litellm.proxy.proxy_server.verbose_proxy_logger.warning"
        ) as warn:
            await increment_spend_counters(
                token="key-bad-reserved-counter",
                team_id=None,
                user_id=None,
                response_cost=0.25,
                budget_reservation=reservation,
            )

        warn.assert_called_once()
        assert reservation["finalized"] is True
        assert cache.in_memory_cache.get_cache(key=key_counter) is None
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_increment_spend_counter_invalidates_stale_cache_on_redis_failure():
    """A failed Redis increment clears the counter from both cache layers.

    The ``RuntimeError`` still propagates to the caller. Before that, the
    in-memory value (4.0) is removed and a Redis delete is awaited for the
    same key, so no stale value is left behind.
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _increment_spend_counter_cache

    team_key = "spend:team:redis-fail"
    cache = DualCache()
    cache.in_memory_cache.set_cache(key=team_key, value=4.0)

    redis = AsyncMock()
    redis.async_increment = AsyncMock(side_effect=RuntimeError("redis down"))
    redis.async_delete_cache = AsyncMock()
    cache.redis_cache = redis

    import litellm.proxy.proxy_server as ps

    with patch.object(ps, "spend_counter_cache", cache):
        with pytest.raises(RuntimeError):
            await _increment_spend_counter_cache(counter_key=team_key, increment=0.5)

        assert cache.in_memory_cache.get_cache(key=team_key) is None
        redis.async_delete_cache.assert_awaited_once_with(key=team_key)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_current_spend_reseeds_from_db_when_counter_missing():
    """
    With both the Redis and in-memory counters missing, the enforcement read
    path must reseed from the DB instead of trusting the caller's fallback.
    Returning the fallback would let every Redis TTL expiry admit a request
    against a stale in-process `team_membership.spend`.
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    member_key = "spend:team_member:user-1:team-1"
    seeds: list = []

    async def capture_seed(key, value, nx=False, **kwargs):
        seeds.append({"key": key, "value": value, "nx": nx})
        return True

    cache = DualCache()
    redis = AsyncMock()
    redis.async_set_cache = AsyncMock(side_effect=capture_seed)
    redis.async_get_cache = AsyncMock(return_value=None)
    cache.redis_cache = redis

    # The DB row is authoritative (362.0). The caller-supplied fallback (30.0)
    # plays the part of a stale in-process team_membership.spend.
    membership_row = MagicMock()
    membership_row.spend = 362.0
    prisma = MagicMock()
    prisma.db.litellm_teammembership.find_unique = AsyncMock(
        return_value=membership_row
    )

    import litellm.proxy.proxy_server as ps

    with patch.multiple(ps, spend_counter_cache=cache, prisma_client=prisma):
        spend = await get_current_spend(counter_key=member_key, fallback_spend=30.0)
        assert spend == 362.0, (
            f"expected DB reseed to return 362.0, got {spend} "
            f"(fallback would have returned 30.0 and caused bypass)"
        )

        # The reseed warms Redis via SET NX so later reads skip the DB.
        seeded = [(s["key"], s["value"], s["nx"]) for s in seeds]
        assert (member_key, 362.0, True) in seeded
        assert cache.in_memory_cache.get_cache(key=member_key) == pytest.approx(362.0)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_current_spend_uses_fallback_when_db_unavailable():
    """
    When prisma is unavailable and both counter layers are empty, the read
    path degrades to the caller-supplied fallback instead of raising.
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    cache = DualCache()
    redis = AsyncMock()
    redis.async_get_cache = AsyncMock(return_value=None)
    cache.redis_cache = redis

    import litellm.proxy.proxy_server as ps

    # prisma_client=None simulates the DB being unavailable.
    with patch.multiple(ps, spend_counter_cache=cache, prisma_client=None):
        spend = await get_current_spend(
            counter_key="spend:team_member:user-1:team-1",
            fallback_spend=15.5,
        )
        assert spend == 15.5
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_current_spend_coalesces_concurrent_reseeds():
    """
    Concurrent reads of the same cold counter on one pod must cost a single
    DB query. One caller reseeds under the lock; the others wait, re-check
    the now-warm counter, and return without touching the DB.
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    member_key = "spend:team_member:user-1:team-coalesce"
    db_queries = 0

    async def delayed_membership_lookup(**kwargs):
        nonlocal db_queries
        db_queries += 1
        # A small delay makes the concurrent callers genuinely overlap while
        # they wait to acquire the reseed lock.
        await asyncio.sleep(0.05)
        row = MagicMock()
        row.spend = 100.0
        return row

    store: dict = {}

    async def fake_get(key, **_):
        return store.get(key)

    async def fake_incr(key, value, **_):
        store[key] = (store.get(key) or 0.0) + value
        return store[key]

    async def fake_set(key, value, nx=False, **_):
        if nx and key in store:
            return False
        store[key] = float(value)
        return True

    cache = DualCache()
    redis = AsyncMock()
    redis.async_get_cache = AsyncMock(side_effect=fake_get)
    redis.async_increment = AsyncMock(side_effect=fake_incr)
    redis.async_set_cache = AsyncMock(side_effect=fake_set)
    cache.redis_cache = redis

    prisma = MagicMock()
    prisma.db.litellm_teammembership.find_unique = AsyncMock(
        side_effect=delayed_membership_lookup
    )

    import litellm.proxy.proxy_server as ps

    with patch.multiple(ps, spend_counter_cache=cache, prisma_client=prisma):
        results = await asyncio.gather(
            *(
                get_current_spend(counter_key=member_key, fallback_spend=0.0)
                for _ in range(5)
            )
        )
        assert results == [100.0] * 5, f"all callers should see DB value, got {results}"
        assert (
            db_queries == 1
        ), f"expected exactly 1 DB query for 5 concurrent reseeds, got {db_queries}"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_current_spend_uses_db_zero_over_stale_fallback():
    """
    A DB spend of 0 (e.g. right after a budget period reset) is still
    authoritative and must beat a stale non-zero fallback. In production the
    fallback is the in-process team_membership.spend, which other pods may
    still hold at its pre-reset value.
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    cache = DualCache()
    redis = AsyncMock()
    redis.async_get_cache = AsyncMock(return_value=None)
    cache.redis_cache = redis

    reset_row = MagicMock()
    reset_row.spend = 0.0
    prisma = MagicMock()
    prisma.db.litellm_teammembership.find_unique = AsyncMock(return_value=reset_row)

    import litellm.proxy.proxy_server as ps

    with patch.multiple(ps, spend_counter_cache=cache, prisma_client=prisma):
        spend = await get_current_spend(
            counter_key="spend:team_member:user-1:team-after-reset",
            fallback_spend=42.0,
        )
        assert (
            spend == 0.0
        ), f"DB authoritative 0 must override stale fallback 42, got {spend}"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_concurrent_read_and_write_paths_share_one_db_query():
    """
    Both the read path (`get_current_spend`) and the write path
    (`_init_and_increment_spend_counter`) reseed a cold counter from the DB.
    They share the per-counter lock. So a pre-call enforcement read racing a
    post-call increment on the same counter must produce one DB query, not two.
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import (
        _init_and_increment_spend_counter,
        get_current_spend,
    )

    member_key = "spend:team_member:user-1:team-cross-path"
    db_queries = 0

    async def delayed_membership_lookup(**kwargs):
        nonlocal db_queries
        db_queries += 1
        # A small delay makes the three coroutines overlap while they contend
        # for the reseed lock.
        await asyncio.sleep(0.05)
        row = MagicMock()
        row.spend = 50.0
        return row

    store: dict = {}

    async def fake_get(key, **_):
        return store.get(key)

    async def fake_incr(key, value, **_):
        store[key] = (store.get(key) or 0.0) + value
        return store[key]

    async def fake_set(key, value, nx=False, **_):
        if nx and key in store:
            return False
        store[key] = float(value)
        return True

    cache = DualCache()
    redis = AsyncMock()
    redis.async_get_cache = AsyncMock(side_effect=fake_get)
    redis.async_increment = AsyncMock(side_effect=fake_incr)
    redis.async_set_cache = AsyncMock(side_effect=fake_set)
    cache.redis_cache = redis

    prisma = MagicMock()
    prisma.db.litellm_teammembership.find_unique = AsyncMock(
        side_effect=delayed_membership_lookup
    )

    import litellm.proxy.proxy_server as ps

    with patch.multiple(
        ps,
        spend_counter_cache=cache,
        prisma_client=prisma,
        user_api_key_cache=DualCache(),
    ):
        first_read, _, second_read = await asyncio.gather(
            get_current_spend(counter_key=member_key, fallback_spend=0.0),
            _init_and_increment_spend_counter(
                counter_key=member_key,
                source_cache_key="ignored",
                increment=1.5,
            ),
            get_current_spend(counter_key=member_key, fallback_spend=0.0),
        )

        assert (
            db_queries == 1
        ), f"expected 1 DB query for concurrent read+write+read, got {db_queries}"
        # The write path's increment may or may not have landed before each
        # read returned, so accept either the seed or seed + increment.
        assert first_read in (50.0, 51.5), f"got {first_read}"
        assert second_read in (50.0, 51.5), f"got {second_read}"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_reseed_locks_dict_is_bounded():
    """
    `SpendCounterReseed._locks` is an LRU capped at
    `SPEND_COUNTER_RESEED_LOCKS_MAX_SIZE`, which prevents unbounded growth in
    long-lived deployments with high counter-key churn. Inserting more keys
    than the cap must evict the oldest entries.
    """
    import litellm.constants as constants
    import litellm.proxy.db.spend_counter_reseed as scr
    from litellm.proxy.db.spend_counter_reseed import SpendCounterReseed

    cap = 5

    # Capture every piece of shared state *before* mutating any of it. The
    # finally-block can then always restore it, even if a later setup step
    # raises; previously the locks were cleared and the constant patched
    # outside the try, so such a failure leaked into other tests.
    orig_locks = SpendCounterReseed._locks.copy()
    orig_max = constants.SPEND_COUNTER_RESEED_LOCKS_MAX_SIZE
    # The class reads the name bound at import time inside
    # spend_counter_reseed, so that module-level binding is patched too.
    orig_module_max = scr.SPEND_COUNTER_RESEED_LOCKS_MAX_SIZE
    try:
        SpendCounterReseed._locks.clear()
        constants.SPEND_COUNTER_RESEED_LOCKS_MAX_SIZE = cap
        scr.SPEND_COUNTER_RESEED_LOCKS_MAX_SIZE = cap

        for i in range(cap + 2):
            await SpendCounterReseed._get_lock(f"spend:key:test-key-{i}")

        assert (
            len(SpendCounterReseed._locks) == cap
        ), f"got {len(SpendCounterReseed._locks)}"
        # The two oldest keys are evicted...
        assert "spend:key:test-key-0" not in SpendCounterReseed._locks
        assert "spend:key:test-key-1" not in SpendCounterReseed._locks
        # ...and the most recent one is retained.
        assert "spend:key:test-key-6" in SpendCounterReseed._locks
    finally:
        constants.SPEND_COUNTER_RESEED_LOCKS_MAX_SIZE = orig_max
        scr.SPEND_COUNTER_RESEED_LOCKS_MAX_SIZE = orig_module_max
        SpendCounterReseed._locks.clear()
        SpendCounterReseed._locks.update(orig_locks)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_reseed_warms_cache_even_on_zero_db_spend():
    """
    A DB spend of 0.0 (a fresh entity, or one just reset) must still warm
    the cache, so the next read hits the cache instead of querying the DB
    again. Skipping the warm would cost one DB query per request for every
    zero-spend entity.
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    member_key = "spend:team_member:user-1:team-zero-warm"
    store: dict = {}

    async def fake_get(key, **_):
        return store.get(key)

    async def fake_incr(key, value, **_):
        store[key] = (store.get(key) or 0.0) + value
        return store[key]

    async def fake_set(key, value, nx=False, **_):
        if nx and key in store:
            return False
        store[key] = float(value)
        return True

    cache = DualCache()
    redis = AsyncMock()
    redis.async_get_cache = AsyncMock(side_effect=fake_get)
    redis.async_increment = AsyncMock(side_effect=fake_incr)
    redis.async_set_cache = AsyncMock(side_effect=fake_set)
    cache.redis_cache = redis

    db_queries = 0

    async def zero_spend_lookup(**kwargs):
        nonlocal db_queries
        db_queries += 1
        row = MagicMock()
        row.spend = 0.0
        return row

    prisma = MagicMock()
    prisma.db.litellm_teammembership.find_unique = AsyncMock(
        side_effect=zero_spend_lookup
    )

    import litellm.proxy.proxy_server as ps

    with patch.multiple(ps, spend_counter_cache=cache, prisma_client=prisma):
        # Cold read: goes to the DB and gets 0.
        cold = await get_current_spend(counter_key=member_key, fallback_spend=0.0)
        # Warm read: must be served from the cache seeded at 0.
        warm = await get_current_spend(counter_key=member_key, fallback_spend=0.0)

        assert cold == 0.0 and warm == 0.0
        assert (
            db_queries == 1
        ), f"second read should hit warmed cache, got {db_queries} DB queries"
        assert store.get(member_key) == 0.0, "cache must be warmed at 0"
|
||
|
||
|
||
# -----------------------------------------------------------------------------
# /config/update — critical paths only.
#
# These tests cover the four behaviors that broke or changed when
# update_config (litellm/proxy/proxy_server.py) was rewritten:
#   1. targeted per-section writes;
#   2. removal of the store_model_in_db gate;
#   3. encryption of environment variables;
#   4. the success_callback / litellm_settings merge semantics.
# The remaining branches (auth, missing DB, Slack auto-enable, router_settings
# merge) are covered implicitly or by upstream tests.
# -----------------------------------------------------------------------------
|
||
|
||
|
||
class _FakeRow:
    """Stand-in for a ``LiteLLM_Config`` row returned by ``find_first``.

    It carries only the two attributes the config code reads:
    ``param_name`` and ``param_value``.
    """

    def __init__(self, param_name, param_value):
        self.param_name, self.param_value = param_name, param_value
|
||
|
||
|
||
class _FakeLitellmConfig:
    """In-memory stand-in for ``prisma_client.db.litellm_config``.

    ``rows`` maps ``param_name`` to its decoded value. ``find_first`` and
    ``upsert`` are ``AsyncMock`` wrappers, so tests can both inspect calls
    and rely on the simulated storage. ``upsert_calls`` records every
    ``(name, value)`` written, in order.
    """

    def __init__(self, initial_rows=None):
        self.rows = dict(initial_rows or {})
        self.upsert_calls: list = []
        self.find_first = AsyncMock(side_effect=self._find_first)
        self.upsert = AsyncMock(side_effect=self._upsert)

    async def _find_first(self, where=None):
        # Only lookups by param_name are supported; anything else misses.
        if not where or "param_name" not in where:
            return None
        name = where["param_name"]
        if name not in self.rows:
            return None
        return _FakeRow(name, self.rows[name])

    async def _upsert(self, where=None, data=None):
        name = where["param_name"]
        # Only the "update" payload is stored. The production code is assumed
        # to send the same param_value in "create" — TODO confirm.
        payload = data["update"]["param_value"]
        # Values arrive JSON-encoded; decode so tests can compare plain dicts.
        if isinstance(payload, str):
            payload = json.loads(payload)
        self.rows[name] = payload
        self.upsert_calls.append((name, payload))
|
||
|
||
|
||
class _FakePrismaClient:
    """Minimal ``prisma_client`` double for the /config/update tests.

    ``db`` is a ``MagicMock`` whose ``litellm_config`` table is replaced by an
    in-memory ``_FakeLitellmConfig``. ``jsonify_object`` is an identity
    function, so values pass through unchanged.
    """

    def __init__(self, initial_rows=None):
        db = mock.MagicMock()
        db.litellm_config = _FakeLitellmConfig(initial_rows=initial_rows)
        self.db = db
        self.jsonify_object = lambda obj: obj
|
||
|
||
|
||
@pytest.fixture
def _update_config_setup(monkeypatch):
    """Factory fixture that wires fakes into the /config/update endpoint.

    The fixture yields ``_install(initial_rows=None, store_model_in_db=True)``.
    Calling it:

    * replaces ``proxy_server.prisma_client`` with a ``_FakePrismaClient``
      seeded with ``initial_rows``;
    * sets ``proxy_server.store_model_in_db``;
    * stubs ``encrypt_value_helper`` to return ``"enc:<value>"``;
    * stubs ``invalidate_config_param`` and ``proxy_config.add_deployment``
      with no-op ``AsyncMock`` objects;
    * overrides the ``user_api_key_auth`` dependency with a PROXY_ADMIN user.

    ``_install`` returns ``(client, prisma, restore)``:

    * ``client`` is a ``TestClient`` for the app;
    * ``prisma`` is the fake client;
    * ``restore()`` puts ``app.dependency_overrides`` back to what it was
      before that ``_install`` call.

    Attributes set through ``monkeypatch`` are undone automatically. As a
    safety net, the dependency overrides are also restored at fixture
    teardown, so a test that fails before calling ``restore()`` cannot leak
    the admin auth override into later tests.
    """
    from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
    from litellm.proxy.auth.user_api_key_auth import user_api_key_auth as auth_dep

    # Snapshot taken before any _install call; used by the teardown below.
    overrides_at_entry = app.dependency_overrides.copy()

    def _install(initial_rows=None, store_model_in_db=True):
        prisma = _FakePrismaClient(initial_rows=initial_rows)
        monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma)
        monkeypatch.setattr(
            "litellm.proxy.proxy_server.store_model_in_db", store_model_in_db
        )
        monkeypatch.setattr(
            "litellm.proxy.proxy_server.encrypt_value_helper",
            lambda value, **_: f"enc:{value}",
        )
        monkeypatch.setattr(
            "litellm.proxy.proxy_server.invalidate_config_param",
            AsyncMock(return_value=None),
        )
        from litellm.proxy.proxy_server import proxy_config as real_proxy_config

        monkeypatch.setattr(
            real_proxy_config, "add_deployment", AsyncMock(return_value=None)
        )

        original_overrides = app.dependency_overrides.copy()
        app.dependency_overrides[auth_dep] = lambda: UserAPIKeyAuth(
            user_id="test_admin",
            user_role=LitellmUserRoles.PROXY_ADMIN,
            api_key="sk-1234",
        )
        client = TestClient(app)

        def _restore():
            app.dependency_overrides = original_overrides

        return client, prisma, _restore

    yield _install

    # Teardown: undo the auth override even if the test never called restore().
    app.dependency_overrides = overrides_at_entry
|
||
|
||
|
||
def test_update_config_writes_only_sent_section(_update_config_setup):
    """Posting only ``general_settings`` must upsert exactly that one row and
    leave every pre-seeded row exactly as it was."""
    client, prisma, restore = _update_config_setup(
        initial_rows={
            "litellm_settings": {"drop_params": True},
            "environment_variables": {"FOO": "enc:bar"},
        }
    )
    table = prisma.db.litellm_config
    try:
        response = client.post(
            "/config/update",
            json={"general_settings": {"store_prompts_in_spend_logs": True}},
        )
        assert response.status_code == 200

        touched = {row_name for row_name, _value in table.upsert_calls}
        assert touched == {"general_settings"}

        # Rows the request never mentioned are untouched.
        assert table.rows["litellm_settings"] == {"drop_params": True}
        assert table.rows["environment_variables"] == {"FOO": "enc:bar"}
    finally:
        restore()
|
||
|
||
|
||
def test_update_config_env_var_round_trip_not_double_encrypted(
    _update_config_setup, monkeypatch
):
    """/config/update must not re-encrypt an env var that the UI echoes back.

    The Admin UI loads config through /get/config/callbacks, which returns the
    stored (still encrypted) value, and posts it again on the next save. The
    handler has to recognise that ciphertext rather than wrap it in a second
    encryption layer, and keys that were not submitted must stay unchanged.

    Encryption is faked as an invertible ``"enc:"`` prefix, so a
    decrypt-then-encrypt path yields a single prefix. Before the fix, the
    second save stored ``"enc:enc:..."`` and the assertions below failed.
    """

    def _fake_decrypt(
        value, key=None, exception_type="error", return_original_value=False
    ):
        # Peel one "enc:" layer; anything else counts as "not decryptable".
        if not (isinstance(value, str) and value.startswith("enc:")):
            return value if return_original_value else None
        return value[len("enc:") :]

    monkeypatch.setattr(
        "litellm.proxy.proxy_server.decrypt_value_helper", _fake_decrypt
    )

    client, prisma, restore = _update_config_setup(
        initial_rows={"environment_variables": {"PREEXISTING_KEY": "enc:keepme"}}
    )
    table = prisma.db.litellm_config

    def save_env(env_vars):
        return client.post("/config/update", json={"environment_variables": env_vars})

    try:
        # Save 1: plaintext goes in; exactly one encryption layer at rest.
        first = save_env({"LANGFUSE_SECRET_KEY": "sk-secret"})
        assert first.status_code == 200
        ciphertext = table.rows["environment_variables"]["LANGFUSE_SECRET_KEY"]
        assert ciphertext == "enc:sk-secret"

        # Save 2: the UI round-trips the stored ciphertext without edits.
        second = save_env({"LANGFUSE_SECRET_KEY": ciphertext})
        assert second.status_code == 200
        env_row = table.rows["environment_variables"]

        # Regression guard: the buggy handler produced "enc:enc:sk-secret".
        assert env_row["LANGFUSE_SECRET_KEY"] == "enc:sk-secret"
        plaintext = _fake_decrypt(
            env_row["LANGFUSE_SECRET_KEY"], return_original_value=True
        )
        assert plaintext == "sk-secret"

        # Keys that were not submitted survive unchanged.
        assert env_row["PREEXISTING_KEY"] == "enc:keepme"
    finally:
        restore()
|
||
|
||
|
||
def test_update_config_can_flip_store_model_in_db_when_currently_false(
    _update_config_setup,
):
    """With store_model_in_db=False the endpoint once rejected every write,
    including the one request that would turn the flag on. That request must
    now succeed and persist the new value."""
    client, prisma, restore = _update_config_setup(store_model_in_db=False)
    try:
        response = client.post(
            "/config/update", json={"general_settings": {"store_model_in_db": True}}
        )
        assert response.status_code == 200
        general = prisma.db.litellm_config.rows["general_settings"]
        assert general["store_model_in_db"] is True
    finally:
        restore()
|
||
|
||
|
||
def test_update_config_environment_variables_encrypted_before_write(
    _update_config_setup,
):
    """Plaintext env var values must already be encrypted when they reach
    the DB row."""
    client, prisma, restore = _update_config_setup()
    request_body = {"environment_variables": {"OPENAI_API_KEY": "sk-secret"}}
    try:
        response = client.post("/config/update", json=request_body)
        assert response.status_code == 200
        env_row = prisma.db.litellm_config.rows["environment_variables"]
        assert env_row == {"OPENAI_API_KEY": "enc:sk-secret"}
    finally:
        restore()
|
||
|
||
|
||
def test_update_config_litellm_settings_request_wins_for_non_callback_keys(
    _update_config_setup,
):
    """If the row has drop_params=True and the request sends drop_params=False,
    False must be persisted (the request wins). Keys the request omits keep
    their stored values."""
    client, prisma, restore = _update_config_setup(
        initial_rows={
            "litellm_settings": {"drop_params": True, "set_verbose": True},
        }
    )
    try:
        response = client.post(
            "/config/update", json={"litellm_settings": {"drop_params": False}}
        )
        assert response.status_code == 200
        settings = prisma.db.litellm_config.rows["litellm_settings"]
        # Sent key: the request value replaces the stored one.
        assert settings["drop_params"] is False
        # Omitted key: the stored value survives the merge.
        assert settings["set_verbose"] is True
    finally:
        restore()
|
||
|
||
|
||
def test_update_config_success_callback_normalizes_existing_mixed_case(
    _update_config_setup,
):
    """Existing mixed-case callback names (written elsewhere) must be
    lowercased before the union with the incoming list.

    Otherwise two things go wrong: the union's dedup misses against the
    lowercase incoming entry, and delete_callback (which looks names up in
    lowercase) cannot find the original.
    """
    client, prisma, restore = _update_config_setup(
        initial_rows={"litellm_settings": {"success_callback": ["Langfuse", "SQS"]}}
    )
    try:
        resp = client.post(
            "/config/update",
            json={"litellm_settings": {"success_callback": ["langfuse"]}},
        )
        assert resp.status_code == 200
        stored = prisma.db.litellm_config.rows["litellm_settings"]["success_callback"]
        # Compare as a sorted list rather than a set. A set would hide a
        # duplicated "langfuse" entry, which is exactly the failed-dedup
        # symptom this test is meant to catch.
        assert sorted(stored) == ["langfuse", "sqs"]
    finally:
        restore()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Lazy feature loading (LazyFeatureMiddleware) — verifies that optional
|
||
# routers are NOT imported at module load and ARE imported on first request
|
||
# to a matching path prefix. The same module isn't re-imported on subsequent
|
||
# requests.
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestLazyFeatureRegistry:
    """Structural invariants of ``LAZY_FEATURES``; guards against bad edits."""

    def test_registry_entries_have_required_fields(self):
        from litellm.proxy._lazy_features import LAZY_FEATURES, LazyFeature

        assert len(LAZY_FEATURES) > 0
        for feature in LAZY_FEATURES:
            assert isinstance(feature, LazyFeature)
            assert feature.name
            assert feature.module_path
            assert feature.path_prefixes
            # Every prefix must be an absolute path.
            for prefix in feature.path_prefixes:
                assert prefix.startswith("/")
            assert callable(feature.register_fn)

    def test_registry_names_unique(self):
        from litellm.proxy._lazy_features import LAZY_FEATURES

        feature_names = [feature.name for feature in LAZY_FEATURES]
        assert len(set(feature_names)) == len(feature_names), "duplicate feature names"

    def test_matches_covers_prefix_and_suffix(self):
        """The middleware (request paths) and the warm endpoint (registered
        route paths) share the single ``matches`` matcher. A route that only
        matches by suffix, such as ``/v1/a2a/{id}/message/send`` against the
        ``/a2a`` prefix, must therefore still be claimed by the feature."""
        from litellm.proxy._lazy_features import LazyFeature

        feature = LazyFeature(
            name="a2a",
            module_path="json",
            path_prefixes=("/a2a",),
            path_suffixes=("/message/send",),
        )

        claimed_paths = (
            "/a2a/abc/message/send",
            "/v1/a2a/abc/message/send",
            "/a2a/abc/.well-known/agent-card.json",
        )
        for path in claimed_paths:
            assert feature.matches(path), path

        for path in ("/v1/a2a/discover", "/unrelated"):
            assert not feature.matches(path), path
||
|
||
|
||
class TestLazyFeaturesNotImportedAtStartup:
|
||
"""
|
||
The whole point of the refactor: gated feature modules must NOT be
|
||
present in `sys.modules` immediately after `proxy_server` imports.
|
||
"""
|
||
|
||
def test_heavy_modules_absent_at_startup(self):
|
||
# Static scan of proxy_server.py source — catches any top-level
|
||
# `from <lazy_module> import` that would defeat lazy loading.
|
||
# Importing proxy_server in a subprocess and diffing sys.modules
|
||
# would also work, but takes 60-120 s and flakes on slow CI runners.
|
||
import re
|
||
from pathlib import Path
|
||
|
||
from litellm.proxy._lazy_features import LAZY_FEATURES
|
||
|
||
proxy_server_src = (
|
||
Path(__file__).resolve().parents[3] / "litellm/proxy/proxy_server.py"
|
||
).read_text()
|
||
|
||
leaks = []
|
||
for feat in LAZY_FEATURES:
|
||
# Anchor at column 0 — indented imports inside function bodies
|
||
# are fine (deferred until the function runs).
|
||
pattern = (
|
||
rf"^(from\s+{re.escape(feat.module_path)}\s+import|"
|
||
rf"import\s+{re.escape(feat.module_path)})"
|
||
)
|
||
if re.search(pattern, proxy_server_src, re.MULTILINE):
|
||
leaks.append(feat.module_path)
|
||
|
||
assert not leaks, (
|
||
"proxy_server.py top-level imports a lazy feature module — these "
|
||
f"should be loaded via LazyFeatureMiddleware: {leaks}"
|
||
)
|
||
|
||
|
||
class TestLazyFeatureMiddleware:
    """Drives LazyFeatureMiddleware directly with hand-built ASGI messages."""

    @staticmethod
    async def _ok_app(scope, receive, send):
        # Terminal ASGI app: answer every request with an empty 200.
        await send({"type": "http.response.start", "status": 200, "headers": []})
        await send({"type": "http.response.body", "body": b""})

    @staticmethod
    async def _receive():
        # One complete, empty request body.
        return {"type": "http.request", "body": b"", "more_body": False}

    @staticmethod
    async def _drop(message):
        # Responses are irrelevant to these tests; discard them.
        return None

    @staticmethod
    def _get(path):
        # Minimal HTTP GET scope for ``path``.
        return {"type": "http", "path": path, "method": "GET", "headers": []}

    def _build(self, feature):
        # Wrap the terminal app in a middleware that knows only ``feature``.
        from fastapi import FastAPI

        from litellm.proxy._lazy_features import LazyFeatureMiddleware

        return LazyFeatureMiddleware(
            self._ok_app, fastapi_app=FastAPI(), features=(feature,)
        )

    @pytest.mark.asyncio
    async def test_first_request_triggers_load_subsequent_does_not(self):
        from litellm.proxy._lazy_features import LazyFeature

        registered = []

        def record(app, module):
            registered.append(getattr(module, "__name__", "?"))

        feature = LazyFeature(
            name="dummy",
            module_path="json",  # any always-importable stdlib module
            path_prefixes=("/dummy",),
            register_fn=record,
        )
        middleware = self._build(feature)

        # The first hit under the prefix imports and registers the feature.
        await middleware(self._get("/dummy/x"), self._receive, self._drop)
        assert registered == ["json"]

        # A second hit under the same prefix must not register again.
        await middleware(self._get("/dummy/y"), self._receive, self._drop)
        assert registered == ["json"], "register_fn called twice for the same feature"

        # A path outside the prefix never triggers a load.
        await middleware(self._get("/unrelated"), self._receive, self._drop)
        assert registered == ["json"]

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "server_root_path,request_path,should_load,case",
        [
            # SERVER_ROOT_PATH set: incoming path includes prefix → strip and match.
            ("/api/v1", "/api/v1/dummy/x", True, "root_path strip + match"),
            # Trailing-slash env var must be normalized.
            ("/api/v1/", "/api/v1/dummy/x", True, "trailing-slash env normalization"),
            # Reverse proxy already stripped the prefix → original path still matches.
            ("/api/v1", "/dummy/x", True, "pre-stripped path still loads"),
            # No SERVER_ROOT_PATH set → unchanged behavior.
            ("", "/dummy/x", True, "no root path"),
            # SERVER_ROOT_PATH=/ must be a no-op (not strip every leading slash).
            ("/", "/dummy/x", True, "root_path='/' is no-op"),
            # Boundary check: /apiv2 must not match root /api.
            ("/api", "/apiv2/foo", False, "boundary check prevents false match"),
            # Genuine non-match under root_path.
            ("/api/v1", "/api/v1/unrelated", False, "unrelated path under root"),
        ],
    )
    async def test_root_path_handling(
        self, monkeypatch, server_root_path, request_path, should_load, case
    ):
        """
        SERVER_ROOT_PATH must be stripped before prefix matching, so lazy
        features still load when a deployment sets a server root path. The
        boundary, trailing-slash and already-stripped (reverse proxy) cases
        must all behave correctly.
        """
        from litellm.proxy._lazy_features import LazyFeature

        # Set before building the middleware, in case it reads the env eagerly.
        monkeypatch.setenv("SERVER_ROOT_PATH", server_root_path)

        registered = []

        def record(app, module):
            registered.append(getattr(module, "__name__", "?"))

        feature = LazyFeature(
            name=f"dummy_{case}",
            module_path="json",
            path_prefixes=("/dummy",),
            register_fn=record,
        )
        middleware = self._build(feature)

        await middleware(self._get(request_path), self._receive, self._drop)

        if should_load:
            assert registered == ["json"], f"{case}: expected feature to load"
        else:
            assert registered == [], f"{case}: feature must not load"

    @pytest.mark.asyncio
    async def test_concurrent_first_requests_only_register_once(self):
        """
        Several parallel first requests to the same prefix must produce
        exactly one ``register_fn`` call; the lock keeps import + register
        from racing with itself.
        """
        from litellm.proxy._lazy_features import LazyFeature

        registered = []

        def record(app, module):
            registered.append(getattr(module, "__name__", "?"))

        feature = LazyFeature(
            name="dummy_concurrent",
            module_path="json",
            path_prefixes=("/dummy_c",),
            register_fn=record,
        )
        middleware = self._build(feature)

        async def hit():
            await middleware(self._get("/dummy_c/x"), self._receive, self._drop)

        await asyncio.gather(*(hit() for _ in range(5)))
        assert registered == [
            "json"
        ], f"expected one registration despite concurrent first hits, got {registered}"

    @pytest.mark.asyncio
    async def test_failing_import_does_not_loop(self):
        """
        If loading a feature fails (here: its ``register_fn`` raises), the
        middleware must still mark the feature as loaded. Otherwise every
        later request would retry the failing load and pay its cost again.
        """
        from litellm.proxy._lazy_features import LazyFeature

        attempts = []

        def explode(app, module):
            attempts.append("called")
            raise RuntimeError("boom")

        feature = LazyFeature(
            name="failing",
            module_path="json",
            path_prefixes=("/fail",),
            register_fn=explode,
        )
        middleware = self._build(feature)

        for _ in range(3):
            await middleware(self._get("/fail/x"), self._receive, self._drop)

        assert attempts == [
            "called"
        ], f"failing register_fn should be invoked once, not on every request; got {attempts}"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_current_spend_redis_clean_miss_skips_stale_in_memory():
    """A clean Redis miss (reachable, returns None because the TTL expired or
    the counter never existed) must reseed from the DB. It must NOT fall back
    to the per-pod in-memory counter, which holds only this pod's writes.

    Before the fix, in a multi-pod deployment in-memory held a stale local
    subset (e.g. $30) while the DB had the true cross-pod total ($500).
    Falling through returned $30, so enforcement passed and the limit was
    bypassed.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    key = "spend:team_member:user-1:team-1"

    cache = DualCache()
    # This pod's own view: a stale, partial subset of the real spend.
    cache.in_memory_cache.set_cache(key=key, value=30.0)

    # Redis is healthy but the key is simply absent.
    healthy_redis = AsyncMock()
    healthy_redis.async_get_cache = AsyncMock(return_value=None)
    healthy_redis.async_increment = AsyncMock(return_value=500.0)
    cache.redis_cache = healthy_redis

    # The DB holds the authoritative cross-pod total.
    membership = MagicMock()
    membership.spend = 500.0
    prisma = MagicMock()
    prisma.db.litellm_teammembership.find_unique = AsyncMock(return_value=membership)

    saved = (ps.spend_counter_cache, ps.prisma_client)
    ps.spend_counter_cache = cache
    ps.prisma_client = prisma
    try:
        spend = await get_current_spend(counter_key=key, fallback_spend=0.0)
        assert spend == 500.0, (
            f"expected DB-authoritative 500.0 on clean Redis miss, got {spend} "
            f"(stale per-pod in-memory $30 would have caused multi-pod bypass)"
        )
    finally:
        ps.spend_counter_cache, ps.prisma_client = saved
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_current_spend_redis_error_falls_back_to_in_memory():
    """A Redis *error* (unlike a clean miss) should degrade to this pod's
    in-memory counter instead of going straight to the DB. In-memory is at
    least fresh for this pod and cheaper than a DB query during an outage."""
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    key = "spend:team_member:user-1:team-1"

    cache = DualCache()
    cache.in_memory_cache.set_cache(key=key, value=42.0)

    broken_redis = AsyncMock()
    broken_redis.async_get_cache = AsyncMock(side_effect=ConnectionError("redis down"))
    cache.redis_cache = broken_redis

    find_unique = AsyncMock(return_value=MagicMock(spend=999.0))
    prisma = MagicMock()
    prisma.db.litellm_teammembership.find_unique = find_unique

    saved = (ps.spend_counter_cache, ps.prisma_client)
    ps.spend_counter_cache = cache
    ps.prisma_client = prisma
    try:
        spend = await get_current_spend(counter_key=key, fallback_spend=0.0)
        assert spend == 42.0, (
            f"expected in-memory fallback 42.0 on Redis error, got {spend} "
            f"(should not have hit DB when Redis errored)"
        )
        # The in-memory hit must short-circuit before any DB lookup.
        find_unique.assert_not_awaited()
    finally:
        ps.spend_counter_cache, ps.prisma_client = saved
|
||
|
||
|
||
def test_realtime_websocket_route_aliases_registered():
    """Realtime sessions reach the proxy through three path aliases stacked on
    ``realtime_websocket_endpoint``.

    Each alias must meet three conditions:

    * It is a registered WebSocket route. Dropping one silently 405s
      WebSocket upgrades, because the catch-all ``/openai/{endpoint:path}``
      HTTP passthrough declares only HTTP methods.
    * It is listed in ``LiteLLMRoutes.openai_routes``, so non-admin, team and
      key-scoped auth allows it.
    * It maps to the realtime call type in ``API_ROUTE_TO_CALL_TYPES``, so
      call-type-aware logic such as guardrails can resolve it.
    """
    from starlette.routing import WebSocketRoute

    from litellm.proxy._types import LiteLLMRoutes
    from litellm.proxy.proxy_server import app
    from litellm.types.utils import API_ROUTE_TO_CALL_TYPES, CallTypes

    ws_paths = set()
    for route in app.routes:
        if isinstance(route, WebSocketRoute):
            ws_paths.add(route.path)

    allowed_routes = LiteLLMRoutes.openai_routes.value

    for alias in ("/openai/v1/realtime", "/v1/realtime", "/realtime"):
        assert alias in ws_paths, (
            f"{alias!r} missing from registered WebSocket routes; the "
            f"realtime endpoint will 405 for clients hitting this path."
        )
        assert alias in allowed_routes, (
            f"{alias!r} missing from LiteLLMRoutes.openai_routes; "
            f"non-admin / team / key-scoped users will get 403 on this path."
        )
        resolved_call_types = API_ROUTE_TO_CALL_TYPES.get(alias)
        assert resolved_call_types == [CallTypes.arealtime], (
            f"{alias!r} missing from API_ROUTE_TO_CALL_TYPES; call-type "
            f"resolution will return None and break call-type-aware features."
        )
|
||
|
||
|
||
class TestTransformRequestBannedParams:
    """
    /utils/transform_request applies the same banned-param check as the LLM
    endpoints.

    Without this check, any authenticated user could supply aws_sts_endpoint,
    api_base, etc. and make the server forward its credentials to an
    attacker-controlled endpoint during SDK credential resolution.
    """

    @pytest.fixture
    def client(self):
        # Authenticate every request as a plain internal user; restore the
        # previous overrides afterwards.
        mock_auth = UserAPIKeyAuth(
            user_id="test-internal",
            user_role=LitellmUserRoles.INTERNAL_USER,
        )
        original = app.dependency_overrides.copy()
        app.dependency_overrides[user_api_key_auth] = lambda: mock_auth
        try:
            yield TestClient(app)
        finally:
            app.dependency_overrides = original

    @pytest.mark.parametrize(
        "banned",
        [
            "aws_sts_endpoint",
            "api_base",
            "aws_web_identity_token",
            "vertex_credentials",
        ],
    )
    def test_banned_params_rejected_for_all_users(self, client, banned):
        """Banned params must be rejected for any authenticated user."""
        response = client.post(
            "/utils/transform_request",
            json={
                "call_type": "completion",
                "request_body": {
                    "model": "gpt-3.5-turbo",
                    banned: "https://attacker.example",
                },
            },
        )
        # Use the raw body in the failure message: response.json() would
        # raise on a non-JSON error body and mask the real status-code
        # mismatch.
        assert response.status_code == 400, (
            f"Expected 400 for banned param '{banned}', "
            f"got {response.status_code}: {response.text}"
        )
|
||
|
||
|
||
class TestSortModelsByDisplayName:
    """Regression: team BYOK rows store an internal ``model_name`` such as
    ``model_name_{team_id}_{uuid}`` and expose the user-facing name through
    ``model_info.team_public_model_name``.

    Sorting must use that displayed name, so BYOK rows interleave
    alphabetically with non-BYOK rows. Otherwise they cluster at the end by
    their opaque IDs, even though the UI lists them under normal-looking
    names.
    """

    @staticmethod
    def _displayed(models):
        # The name the UI shows: team_public_model_name when truthy,
        # otherwise model_name.
        return [
            row["model_info"].get("team_public_model_name") or row["model_name"]
            for row in models
        ]

    def test_byok_models_sort_by_team_public_model_name(self):
        from litellm.proxy.proxy_server import _sort_models

        byok_row = {
            # Opaque internal name; the UI shows team_public_model_name.
            "model_name": "model_name_team-1_abc123",
            "model_info": {"team_public_model_name": "anthropic/claude"},
        }
        models = [
            {"model_name": "claude-haiku-4-5", "model_info": {}},
            byok_row,
            {"model_name": "gpt-4o", "model_info": {}},
        ]

        ordered = _sort_models(all_models=models, sort_by="model_name", sort_order="asc")

        assert self._displayed(ordered) == [
            "anthropic/claude",
            "claude-haiku-4-5",
            "gpt-4o",
        ]

    def test_byok_models_sort_descending_by_display_name(self):
        from litellm.proxy.proxy_server import _sort_models

        byok_row = {
            "model_name": "model_name_team-1_zzz",
            "model_info": {"team_public_model_name": "zeta/model"},
        }
        models = [
            {"model_name": "claude-haiku-4-5", "model_info": {}},
            byok_row,
            {"model_name": "gpt-4o", "model_info": {}},
        ]

        ordered = _sort_models(all_models=models, sort_by="model_name", sort_order="desc")

        assert self._displayed(ordered) == [
            "zeta/model",
            "gpt-4o",
            "claude-haiku-4-5",
        ]
|
||
|
||
def test_empty_team_public_model_name_falls_back_to_model_name(self):
|
||
# Empty string for team_public_model_name (not None) must still
|
||
# fall back to model_name — otherwise BYOK rows with a blank
|
||
# display name would sort to the top.
|
||
from litellm.proxy.proxy_server import _sort_models
|
||
|
||
models = [
|
||
{"model_name": "alpha", "model_info": {"team_public_model_name": ""}},
|
||
{"model_name": "beta", "model_info": {}},
|
||
]
|
||
|
||
sorted_models = _sort_models(
|
||
all_models=models, sort_by="model_name", sort_order="asc"
|
||
)
|
||
assert [m["model_name"] for m in sorted_models] == ["alpha", "beta"]
|