mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 14:48:44 +00:00
b3297fc2ea
* feat(proxy): hot-reload .env in dev when running with --reload The --reload watcher already restarts the worker on *.py and --config YAML edits, but .env was unwatched, so changing a key there did nothing until a manual restart. Add .env to the uvicorn reload_includes (and to the StatReload monkeypatch, which ignores reload_includes) so an edit triggers a worker restart. A reloaded worker is a fresh process that inherits the reloader's environment, so load_dotenv(override=False) would keep serving the stale inherited value for any key already in the environment. The CLI now exports LITELLM_DEV_ENV_HOT_RELOAD when --reload is set, and litellm/__init__.py reads it to load .env with override=True only on that dev path, leaving normal startup precedence untouched. * feat(proxy): warn that --reload makes .env override shell env vars When --reload is active, worker processes re-read .env with override=True, so .env values win over shell-exported environment variables. Surface this dotenv precedence change with a startup warning so a developer who relies on a shell-exported override is not silently surprised. * fix(proxy): type reload helper paths as Optional[str] to satisfy mypy * fix(proxy): watch the cwd .env in both reload backends for parity WatchFiles only watches cwd (and the --config dir) for .env, while the StatReload fallback used find_dotenv(usecwd=True), which walks up to a parent-dir .env that WatchFiles never sees. Point StatReload at the same cwd .env so the two reload backends react to the same file.
1513 lines
56 KiB
Python
1513 lines
56 KiB
Python
import os
|
|
import sys
|
|
from pathlib import Path
|
|
from types import SimpleNamespace
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import click
|
|
import fastapi
|
|
import pytest
|
|
|
|
# Put the directory three levels above the current working directory at the
# front of the import search path.
sys.path.insert(0, os.path.abspath("../../.."))
|
|
|
|
import builtins
|
|
import types
|
|
|
|
from litellm.proxy.proxy_cli import ProxyInitializationHelpers
|
|
|
|
|
|
@pytest.mark.xdist_group("proxy_cli")
|
|
class TestProxyInitializationHelpers:
|
|
@patch("importlib.metadata.version")
@patch("click.echo")
def test_echo_litellm_version(self, mock_echo, mock_version):
    """The version banner is echoed once, using the looked-up ``litellm`` version."""
    # Arrange: the metadata lookup reports a fixed version string.
    mock_version.return_value = "1.0.0"

    # Act
    ProxyInitializationHelpers._echo_litellm_version()

    # Assert: the lookup targeted "litellm" and the exact banner was echoed.
    expected_banner = "\nLiteLLM: Current Version = 1.0.0\n"
    mock_version.assert_called_once_with("litellm")
    mock_echo.assert_called_once_with(expected_banner)
|
|
|
|
@patch("httpx.get")
@patch("builtins.print")
@patch("json.dumps")
def test_run_health_check(self, mock_dumps, mock_print, mock_get):
    """The health check GETs ``/health`` and serializes the JSON body with indent=4."""
    # Arrange: a canned HTTP response whose JSON body reports a healthy proxy.
    fake_response = MagicMock()
    fake_response.json.return_value = {"status": "healthy"}
    mock_get.return_value = fake_response
    mock_dumps.return_value = '{"status": "healthy"}'

    # Act
    ProxyInitializationHelpers._run_health_check("localhost", 8000)

    # Assert: one request to the health endpoint, one JSON decode, one dump.
    mock_get.assert_called_once_with(url="http://localhost:8000/health")
    fake_response.json.assert_called_once()
    mock_dumps.assert_called_once_with({"status": "healthy"}, indent=4)
|
|
|
|
@patch("openai.OpenAI")
@patch("click.echo")
@patch("builtins.print")
def test_run_test_chat_completion(self, mock_print, mock_echo, mock_openai):
    """A boolean ``test`` value is rejected; a URL string drives the OpenAI client."""
    # Arrange: a fake OpenAI client whose successive ``create`` calls hand back
    # a plain response first and an iterable (stream-like) response second.
    fake_client = MagicMock()
    mock_openai.return_value = fake_client

    plain_response = MagicMock()
    fake_client.chat.completions.create.return_value = plain_response

    streamed_response = MagicMock()
    streamed_response.__iter__.return_value = [MagicMock(), MagicMock()]
    fake_client.chat.completions.create.side_effect = [
        plain_response,
        streamed_response,
    ]

    common_args = ("localhost", 8000, "gpt-3.5-turbo")

    # Act + Assert: ``True`` is not an accepted test value.
    with pytest.raises(ValueError, match="Invalid test value"):
        ProxyInitializationHelpers._run_test_chat_completion(*common_args, True)

    # Act: a string test value is used as the base URL.
    ProxyInitializationHelpers._run_test_chat_completion(
        *common_args, "http://test-url"
    )

    # Assert: the client was built once against that URL and was used.
    mock_openai.assert_called_once_with(
        api_key="My API Key", base_url="http://test-url"
    )
    fake_client.chat.completions.create.assert_called()
|
|
|
|
def test_get_default_unvicorn_init_args(self):
    """Check the uvicorn kwargs produced for each combination of optional inputs."""
    build_args = ProxyInitializationHelpers._get_default_unvicorn_init_args

    # Host and port only.
    basic = build_args("localhost", 8000)
    assert basic["app"] == "litellm.proxy.proxy_server:app"
    assert basic["host"] == "localhost"
    assert basic["port"] == 8000

    # An explicit log_config is passed through unchanged.
    with_log_config = build_args("localhost", 8000, "log_config.json")
    assert with_log_config["log_config"] == "log_config.json"

    # With litellm.json_logs enabled, log_config becomes a logging config dict.
    with patch("litellm.json_logs", True):
        with_json_logs = build_args("localhost", 8000)
        json_log_config = with_json_logs["log_config"]
        assert json_log_config is not None
        assert isinstance(json_log_config, dict)
        assert "version" in json_log_config
        assert "formatters" in json_log_config

    # keepalive_timeout is exposed under uvicorn's ``timeout_keep_alive`` key.
    with_keepalive = build_args("localhost", 8000, None, 60)
    assert with_keepalive["timeout_keep_alive"] == 60

    # Both optional values together.
    with_both = build_args("localhost", 8000, "log_config.json", 120)
    assert with_both["log_config"] == "log_config.json"
    assert with_both["timeout_keep_alive"] == 120

    class _FakeUvicornConfig:
        """Stand-in for ``uvicorn.Config`` that accepts ``timeout_worker_healthcheck``."""

        def __init__(self, timeout_worker_healthcheck=None):
            pass

    # With a Config class that accepts the kwarg, the value is forwarded.
    with patch("uvicorn.Config", _FakeUvicornConfig):
        with_healthcheck = build_args(
            "localhost", 8000, timeout_worker_healthcheck=15
        )
        assert with_healthcheck["timeout_worker_healthcheck"] == 15
|
|
|
|
def test_get_reload_options_no_config_still_watches_env(self):
    """Without a config path, reload watches the cwd for ``*.py`` and ``.env``."""
    reload_opts = ProxyInitializationHelpers._get_reload_options(None)

    expected_dirs = [os.path.abspath(os.getcwd())]
    assert reload_opts["reload"] is True
    assert reload_opts["reload_dirs"] == expected_dirs
    assert reload_opts["reload_includes"] == ["*.py", ".env"]
|
|
|
|
def test_get_reload_options_with_config_in_cwd(self, tmp_path, monkeypatch):
    """A config file in the cwd adds its filename to the includes, not a new dir."""
    # Arrange: a config file inside the directory we then make the cwd.
    yaml_path = tmp_path / "config.yaml"
    yaml_path.write_text("model_list: []\n")
    monkeypatch.chdir(tmp_path)

    # Act: the config is referenced by its relative name.
    reload_opts = ProxyInitializationHelpers._get_reload_options("config.yaml")

    # Assert
    assert reload_opts["reload"] is True
    assert reload_opts["reload_dirs"] == [str(tmp_path)]
    assert reload_opts["reload_includes"] == ["*.py", ".env", "config.yaml"]
|
|
|
|
def test_get_reload_options_with_config_outside_cwd(self, tmp_path, monkeypatch):
    """A config file outside the cwd adds its directory to the watched dirs."""
    # Arrange: separate working and config directories under tmp_path.
    work_dir = tmp_path / "work"
    work_dir.mkdir()
    config_dir = tmp_path / "configs"
    config_dir.mkdir()
    yaml_path = config_dir / "proxy.yaml"
    yaml_path.write_text("model_list: []\n")
    monkeypatch.chdir(work_dir)

    # Act: the config is referenced by its absolute path.
    reload_opts = ProxyInitializationHelpers._get_reload_options(str(yaml_path))

    # Assert: cwd first, then the config's directory; includes use the basename.
    assert reload_opts["reload"] is True
    assert reload_opts["reload_dirs"] == [str(work_dir), str(config_dir)]
    assert reload_opts["reload_includes"] == ["*.py", ".env", "proxy.yaml"]
|
|
|
|
def test_patch_statreload_extra_paths_yields_config_and_py(self, tmp_path):
    """After patching, ``StatReload.iter_py_files`` yields the config and ``.py`` files."""
    from uvicorn.supervisors.statreload import StatReload

    # Start from an empty registry of extra paths if an earlier patch left one.
    if hasattr(StatReload, "_litellm_patched_config_paths"):
        StatReload._litellm_patched_config_paths.clear()

    # Arrange: one config file and one Python module in the watched directory.
    yaml_path = tmp_path / "config.yaml"
    yaml_path.write_text("model_list: []\n")
    module_path = tmp_path / "module.py"
    module_path.write_text("x = 1\n")

    # Act
    was_applied = ProxyInitializationHelpers._patch_statreload_extra_paths(
        [str(yaml_path)]
    )
    assert was_applied is True

    # iter_py_files is called unbound with a minimal stand-in for the reloader.
    reloader_stub = SimpleNamespace(config=SimpleNamespace(reload_dirs=[tmp_path]))
    seen = {
        Path(entry).resolve() for entry in StatReload.iter_py_files(reloader_stub)
    }

    # Assert
    assert yaml_path.resolve() in seen
    assert module_path.resolve() in seen
|
|
|
|
def test_patch_statreload_extra_paths_yields_env(self, tmp_path):
    """After patching with a ``.env`` path, ``iter_py_files`` yields that file."""
    from uvicorn.supervisors.statreload import StatReload

    # Start from an empty registry of extra paths if an earlier patch left one.
    if hasattr(StatReload, "_litellm_patched_config_paths"):
        StatReload._litellm_patched_config_paths.clear()

    # Arrange
    dotenv_path = tmp_path / ".env"
    dotenv_path.write_text("FOO=bar\n")

    # Act
    was_applied = ProxyInitializationHelpers._patch_statreload_extra_paths(
        [str(dotenv_path)]
    )
    assert was_applied is True

    # iter_py_files is called unbound with a minimal stand-in for the reloader.
    reloader_stub = SimpleNamespace(config=SimpleNamespace(reload_dirs=[tmp_path]))
    seen = {
        Path(entry).resolve() for entry in StatReload.iter_py_files(reloader_stub)
    }

    # Assert
    assert dotenv_path.resolve() in seen
|
|
|
|
def test_patch_statreload_extra_paths_skips_falsy(self, tmp_path):
    """An empty list, or a list of only falsy entries, reports ``False``."""
    from uvicorn.supervisors.statreload import StatReload

    # Start from an empty registry of extra paths if an earlier patch left one.
    if hasattr(StatReload, "_litellm_patched_config_paths"):
        StatReload._litellm_patched_config_paths.clear()

    patch_paths = ProxyInitializationHelpers._patch_statreload_extra_paths

    assert patch_paths([]) is False
    assert patch_paths([None, ""]) is False
|
|
|
|
def test_patch_statreload_extra_paths_is_idempotent(self, tmp_path):
    """Patching three times with the same path yields no duplicate entries."""
    from uvicorn.supervisors.statreload import StatReload

    # Start from an empty registry of extra paths if an earlier patch left one.
    if hasattr(StatReload, "_litellm_patched_config_paths"):
        StatReload._litellm_patched_config_paths.clear()

    # Arrange
    yaml_path = tmp_path / "config.yaml"
    yaml_path.write_text("model_list: []\n")
    module_path = tmp_path / "only.py"
    module_path.write_text("x = 1\n")

    # Act: apply the same patch repeatedly.
    for _ in range(3):
        ProxyInitializationHelpers._patch_statreload_extra_paths([str(yaml_path)])

    reloader_stub = SimpleNamespace(config=SimpleNamespace(reload_dirs=[tmp_path]))
    yielded = list(StatReload.iter_py_files(reloader_stub))

    # Assert: every yielded path is distinct, and both files are present.
    distinct_names = {str(entry) for entry in yielded}
    assert len(yielded) == len(distinct_names)

    seen = {Path(entry).resolve() for entry in yielded}
    assert yaml_path.resolve() in seen
    assert module_path.resolve() in seen
|
|
|
|
def test_configure_dev_reload_watches_env_and_sets_override_flag(
    self, tmp_path, monkeypatch
):
    """``_configure_dev_reload`` enables reload, watches ``.env`` and sets the flag.

    Checks that the call:
    * sets ``LITELLM_DEV_ENV_HOT_RELOAD`` to ``"True"`` in ``os.environ``;
    * fills ``uvicorn_args`` with ``reload=True`` and a ``.env`` include;
    * logs exactly one warning mentioning "override" and ".env";
    * makes ``StatReload.iter_py_files`` yield the cwd ``.env`` and the config.
    """
    from uvicorn.supervisors.statreload import StatReload

    # Start from an empty registry of extra paths if an earlier patch left one.
    if hasattr(StatReload, "_litellm_patched_config_paths"):
        StatReload._litellm_patched_config_paths.clear()

    # The code under test writes the flag straight into os.environ.
    # ``delenv(..., raising=False)`` on an absent variable records nothing to
    # undo, so the flag would outlive this test. Setting it first registers
    # the original state with monkeypatch, which teardown then restores.
    monkeypatch.setenv("LITELLM_DEV_ENV_HOT_RELOAD", "")
    monkeypatch.delenv("LITELLM_DEV_ENV_HOT_RELOAD")

    # Arrange: a config file and a .env file in the directory used as cwd.
    config_file = tmp_path / "config.yaml"
    config_file.write_text("model_list: []\n")
    env_file = tmp_path / ".env"
    env_file.write_text("FOO=bar\n")
    monkeypatch.chdir(tmp_path)

    # Act
    uvicorn_args: dict = {}
    with patch("litellm._logging.verbose_proxy_logger.warning") as mock_warning:
        ProxyInitializationHelpers._configure_dev_reload(
            uvicorn_args, str(config_file)
        )

    # Assert: flag exported and uvicorn reload options populated.
    assert os.environ["LITELLM_DEV_ENV_HOT_RELOAD"] == "True"
    assert uvicorn_args["reload"] is True
    assert ".env" in uvicorn_args["reload_includes"]

    # Assert: a single warning about .env overriding the environment.
    mock_warning.assert_called_once()
    warning_text = mock_warning.call_args.args[0].lower()
    assert "override" in warning_text
    assert ".env" in warning_text

    # Assert: the StatReload backend sees both watched files.
    fake_self = types.SimpleNamespace(
        config=types.SimpleNamespace(reload_dirs=[tmp_path])
    )
    yielded_paths = {Path(p).resolve() for p in StatReload.iter_py_files(fake_self)}
    assert env_file.resolve() in yielded_paths
    assert config_file.resolve() in yielded_paths
|
|
|
|
def test_dev_env_hot_reload_enabled_reads_flag(self, monkeypatch):
    """The helper is True for "True", and False for "false" or an unset flag."""
    import litellm

    flag_name = "LITELLM_DEV_ENV_HOT_RELOAD"

    # Flag explicitly enabled.
    monkeypatch.setenv(flag_name, "True")
    assert litellm._dev_env_hot_reload_enabled() is True

    # Flag explicitly disabled.
    monkeypatch.setenv(flag_name, "false")
    assert litellm._dev_env_hot_reload_enabled() is False

    # Flag absent.
    monkeypatch.delenv(flag_name, raising=False)
    assert litellm._dev_env_hot_reload_enabled() is False
|
|
|
|
@patch("asyncio.run")
@patch("builtins.print")
def test_init_hypercorn_server(self, mock_print, mock_asyncio_run):
    """The hypercorn initializer calls ``asyncio.run`` and accepts SSL arguments."""
    fake_app = MagicMock()

    # Plain startup: no certificate, key or ciphers.
    ProxyInitializationHelpers._init_hypercorn_server(
        fake_app, "localhost", 8000, None, None, None
    )
    mock_asyncio_run.assert_called_once()

    # SSL startup: this call is only checked for not raising.
    ProxyInitializationHelpers._init_hypercorn_server(
        fake_app, "localhost", 8000, "cert.pem", "key.pem", "ECDHE"
    )
|
|
|
|
@patch("builtins.print")
def test_init_granian_server(self, mock_print):
    """The Granian initializer builds the server with the expected kwargs and serves.

    ``granian.Granian`` is patched inside the body, after ``importorskip``.
    A ``@patch("granian.Granian")`` decorator imports its target when the test
    is called, so a missing granian would raise before the skip could run.
    """
    pytest.importorskip("granian")

    mock_server = MagicMock()
    fake_interfaces = SimpleNamespace(ASGI="asgi")
    with (
        patch("granian.Granian") as mock_granian_cls,
        patch("granian.constants.Interfaces", fake_interfaces),
    ):
        mock_granian_cls.return_value = mock_server
        ProxyInitializationHelpers._init_granian_server(
            host="0.0.0.0",
            port=4000,
            num_workers=2,
            ssl_certfile_path=None,
            ssl_keyfile_path=None,
            max_requests_before_restart=None,
            ciphers=None,
            granian_runtime_threads=None,
        )

    # The mocks keep their call records after the patches are undone.
    mock_granian_cls.assert_called_once()
    call_kwargs = mock_granian_cls.call_args.kwargs
    assert call_kwargs["target"] == "litellm.proxy.proxy_server:app"
    assert call_kwargs["address"] == "0.0.0.0"
    assert call_kwargs["port"] == 4000
    assert call_kwargs["workers"] == 2
    assert call_kwargs["interface"] == "asgi"
    assert call_kwargs["websockets"] is True
    # No runtime thread count was requested, so none may be forwarded.
    assert "runtime_threads" not in call_kwargs
    mock_server.serve.assert_called_once()
|
|
|
|
@patch("builtins.print")
def test_init_granian_server_runtime_threads(self, mock_print):
    """``granian_runtime_threads`` is forwarded to Granian as ``runtime_threads``.

    ``granian.Granian`` is patched inside the body, after ``importorskip``.
    A ``@patch("granian.Granian")`` decorator imports its target when the test
    is called, so a missing granian would raise before the skip could run.
    """
    pytest.importorskip("granian")

    mock_server = MagicMock()
    fake_interfaces = SimpleNamespace(ASGI="asgi")
    with (
        patch("granian.Granian") as mock_granian_cls,
        patch("granian.constants.Interfaces", fake_interfaces),
    ):
        mock_granian_cls.return_value = mock_server
        ProxyInitializationHelpers._init_granian_server(
            host="0.0.0.0",
            port=4000,
            num_workers=1,
            ssl_certfile_path=None,
            ssl_keyfile_path=None,
            max_requests_before_restart=None,
            ciphers=None,
            granian_runtime_threads=4,
        )

    assert mock_granian_cls.call_args.kwargs["runtime_threads"] == 4
|
|
|
|
@patch("builtins.print")
def test_init_granian_server_ssl(self, mock_print):
    """Certificate and key paths reach Granian as ``Path`` objects.

    ``granian.Granian`` is patched inside the body, after ``importorskip``.
    A ``@patch("granian.Granian")`` decorator imports its target when the test
    is called, so a missing granian would raise before the skip could run.
    """
    pytest.importorskip("granian")

    mock_server = MagicMock()
    fake_interfaces = SimpleNamespace(ASGI="asgi")
    with (
        patch("granian.Granian") as mock_granian_cls,
        patch("granian.constants.Interfaces", fake_interfaces),
    ):
        mock_granian_cls.return_value = mock_server
        ProxyInitializationHelpers._init_granian_server(
            host="0.0.0.0",
            port=4000,
            num_workers=1,
            ssl_certfile_path="/path/to/cert.pem",
            ssl_keyfile_path="/path/to/key.pem",
            max_requests_before_restart=None,
            ciphers=None,
            granian_runtime_threads=None,
        )

    call_kwargs = mock_granian_cls.call_args.kwargs
    assert call_kwargs["ssl_cert"] == Path("/path/to/cert.pem")
    assert call_kwargs["ssl_key"] == Path("/path/to/key.pem")
    mock_server.serve.assert_called_once()
|
|
|
|
def test_init_granian_server_ssl_requires_cert_and_key(self):
    """A certificate without a key raises ``ClickException`` and builds no server.

    ``granian.Granian`` is patched inside the body, after ``importorskip``.
    A ``@patch("granian.Granian")`` decorator imports its target when the test
    is called, so a missing granian would raise before the skip could run.
    """
    pytest.importorskip("granian")

    fake_interfaces = SimpleNamespace(ASGI="asgi")
    with (
        patch("granian.Granian") as mock_granian_cls,
        patch("granian.constants.Interfaces", fake_interfaces),
    ):
        with pytest.raises(click.ClickException, match="Both --ssl_certfile_path"):
            ProxyInitializationHelpers._init_granian_server(
                host="0.0.0.0",
                port=4000,
                num_workers=1,
                ssl_certfile_path="/path/to/cert.pem",
                ssl_keyfile_path=None,
                max_requests_before_restart=None,
                ciphers=None,
                granian_runtime_threads=None,
            )

    # Validation must fail before the Granian server is constructed.
    mock_granian_cls.assert_not_called()
|
|
|
|
@patch("subprocess.Popen")
def test_run_ollama_serve(self, mock_popen):
    """The helper spawns one process and does not propagate a ``Popen`` failure."""
    # Happy path: exactly one process is spawned.
    ProxyInitializationHelpers._run_ollama_serve()
    mock_popen.assert_called_once()

    # Failure path: a raising Popen must not escape the helper.
    mock_popen.side_effect = Exception("Test exception")
    ProxyInitializationHelpers._run_ollama_serve()  # Should not raise
|
|
|
|
@patch("socket.socket")
def test_is_port_in_use(self, mock_socket):
    """``connect_ex`` returning 0 means in use; a non-zero result means free."""
    # Arrange: the socket context manager yields a controllable fake socket.
    fake_sock = MagicMock()
    mock_socket.return_value.__enter__.return_value = fake_sock

    # A successful connect (0) means something is listening on the port.
    fake_sock.connect_ex.return_value = 0
    assert ProxyInitializationHelpers._is_port_in_use(8000) is True

    # A failed connect (non-zero) means the port is free.
    fake_sock.connect_ex.return_value = 1
    assert ProxyInitializationHelpers._is_port_in_use(8000) is False
|
|
|
|
def test_get_loop_type(self):
    """The loop type is ``None`` on win32 and ``"uvloop"`` on linux."""
    # Windows
    with patch("sys.platform", "win32"):
        windows_loop = ProxyInitializationHelpers._get_loop_type()
        assert windows_loop is None

    # Linux
    with patch("sys.platform", "linux"):
        linux_loop = ProxyInitializationHelpers._get_loop_type()
        assert linux_loop == "uvloop"
|
|
|
|
@patch.dict(os.environ, {}, clear=True)
def test_database_url_construction_with_special_characters(self):
    """URL-quote the DATABASE_* parts into a postgres URL, then append pool params.

    The URL is assembled here in the test, mirroring the construction that
    the comments below attribute to ``run_server``; only
    ``append_query_params`` is production code.
    """
    # Credentials and database name contain characters that need escaping.
    db_env = {
        "DATABASE_HOST": "localhost:5432",
        "DATABASE_USERNAME": "user@with+special",
        "DATABASE_PASSWORD": "test-password-special-chars",
        "DATABASE_NAME": "db_name/test",
    }

    with patch.dict(os.environ, db_env):
        # This simulates what happens in run_server when database_url is None.
        import urllib.parse

        from litellm.proxy.proxy_cli import append_query_params

        host = os.environ["DATABASE_HOST"]
        raw_user = os.environ["DATABASE_USERNAME"]
        raw_password = os.environ["DATABASE_PASSWORD"]
        raw_name = os.environ["DATABASE_NAME"]

        # Percent-encode everything except the host.
        user_enc = urllib.parse.quote_plus(raw_user)
        password_enc = urllib.parse.quote_plus(raw_password)
        name_enc = urllib.parse.quote_plus(raw_name)

        database_url = f"postgresql://{user_enc}:{password_enc}@{host}/{name_enc}"

        # "@", "+" and "/" must appear in their escaped forms.
        expected_url = "postgresql://user%40with%2Bspecial:test-password-special-chars@localhost:5432/db_name%2Ftest"
        assert database_url == expected_url

        # Pool settings are appended as query parameters.
        pool_params = {"connection_limit": 10, "pool_timeout": 60}
        url_with_params = append_query_params(database_url, pool_params)
        assert "connection_limit=10" in url_with_params
        assert "pool_timeout=60" in url_with_params
|
|
|
|
def test_append_query_params_handles_missing_url(self):
    """A ``None`` URL yields an empty string."""
    from litellm.proxy.proxy_cli import append_query_params

    result_url = append_query_params(None, {"connection_limit": 10})

    assert result_url == ""
|
|
|
|
@patch("uvicorn.run")
@patch("atexit.register")  # critical
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
@patch(
    "litellm.proxy.db.prisma_client.should_update_prisma_schema", return_value=False
)
def test_skip_server_startup(
    self, mock_should_update, mock_setup_db, mock_atexit_register, mock_uvicorn_run
):
    """``--skip_server_startup`` suppresses ``uvicorn.run``; a normal run calls it once."""
    from click.testing import CliRunner

    from litellm.proxy.proxy_cli import run_server

    cli = CliRunner()

    fake_proxy_server = MagicMock(
        app=MagicMock(),
        ProxyConfig=MagicMock(),
        KeyManagementSettings=MagicMock(),
        save_worker_config=MagicMock(),
    )

    # Remove DATABASE_URL/DIRECT_URL so the CLI doesn't attempt
    # real prisma operations when these are set in CI.
    clean_env = dict(os.environ)
    for db_var in ("DATABASE_URL", "DIRECT_URL"):
        clean_env.pop(db_var, None)

    fake_modules = {
        "proxy_server": fake_proxy_server,
        # Prevent real import of proxy_server inside Click's
        # isolation context (heavy side effects cause stream
        # lifecycle issues with Click 8.2+)
        "litellm.proxy.proxy_server": fake_proxy_server,
    }

    with (
        patch.dict(os.environ, clean_env, clear=True),
        patch.dict("sys.modules", fake_modules),
        patch(
            "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
        ) as mock_get_args,
    ):
        mock_get_args.return_value = {
            "app": "litellm.proxy.proxy_server:app",
            "host": "localhost",
            "port": 8000,
        }

        # --- skip startup ---
        skipped = cli.invoke(run_server, ["--local", "--skip_server_startup"])

        assert (
            skipped.exit_code == 0
        ), f"exit_code={skipped.exit_code}, output={skipped.output}"
        assert "Skipping server startup" in skipped.output
        mock_uvicorn_run.assert_not_called()

        # --- normal startup ---
        mock_uvicorn_run.reset_mock()

        started = cli.invoke(run_server, ["--local"])

        assert (
            started.exit_code == 0
        ), f"exit_code={started.exit_code}, output={started.output}"
        mock_uvicorn_run.assert_called_once()
|
|
|
|
@pytest.mark.parametrize(
    "timeout_config,expected_timeout",
    [
        ({"database_connection_timeout": 30}, 30),
        ({"database_connection_pool_timeout": 45}, 45),
        (
            {
                "database_connection_timeout": 30,
                "database_connection_pool_timeout": 45,
            },
            30,
        ),
    ],
)
@patch("subprocess.run")
@patch("atexit.register")
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
@patch(
    "litellm.proxy.db.prisma_client.should_update_prisma_schema", return_value=False
)
def test_db_timeout_settings_are_forwarded_to_pool_timeout(
    self,
    mock_should_update,
    mock_setup_db,
    mock_atexit_register,
    mock_subprocess_run,
    timeout_config,
    expected_timeout,
):
    """The configured DB timeout reaches ``append_query_params`` as ``pool_timeout``.

    The third parametrized case expects ``database_connection_timeout`` to win
    when both timeout settings are present.
    """
    from click.testing import CliRunner

    from litellm.proxy.proxy_cli import run_server

    cli = CliRunner()
    mock_subprocess_run.return_value = MagicMock(returncode=0)

    fake_proxy_server = MagicMock(
        app=MagicMock(),
        ProxyConfig=MagicMock(),
        KeyManagementSettings=MagicMock(),
        save_worker_config=MagicMock(),
    )
    # The proxy config loader returns general_settings carrying the DB options.
    general_settings = {
        "database_url": "postgresql://test:test@localhost:5432/test",
        "database_connection_pool_limit": 5,
        **timeout_config,
    }
    fake_proxy_server.ProxyConfig.return_value.get_config = AsyncMock(
        return_value={"general_settings": general_settings}
    )

    # Drop DATABASE_URL/DIRECT_URL from the environment used for this run.
    clean_env = dict(os.environ)
    for db_var in ("DATABASE_URL", "DIRECT_URL"):
        clean_env.pop(db_var, None)

    fake_modules = {
        "proxy_server": fake_proxy_server,
        "litellm.proxy.proxy_server": fake_proxy_server,
    }

    def _fake_append(url, params):
        """Render only the two pool parameters onto the URL."""
        return f"{url}?connection_limit={params['connection_limit']}&pool_timeout={params['pool_timeout']}"

    with (
        patch.dict(os.environ, clean_env, clear=True),
        patch.dict("sys.modules", fake_modules),
        patch(
            "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
        ) as mock_get_args,
        patch(
            "litellm.proxy.proxy_cli.append_query_params",
            side_effect=_fake_append,
        ) as mock_append_query_params,
    ):
        mock_get_args.return_value = {
            "app": "litellm.proxy.proxy_server:app",
            "host": "localhost",
            "port": 8000,
        }

        result = cli.invoke(
            run_server,
            ["--local", "--config", "test-config.yaml", "--skip_server_startup"],
        )

        assert (
            result.exit_code == 0
        ), f"exit_code={result.exit_code}, output={result.output}"
        mock_append_query_params.assert_called()
        # Second positional argument of the most recent call is the params dict.
        forwarded = mock_append_query_params.call_args.args[1]
        assert forwarded["connection_limit"] == 5
        assert forwarded["pool_timeout"] == expected_timeout
|
|
|
|
def test_build_db_connection_url_params_defaults(self):
    """With only the two required values, exactly those two keys are returned."""
    from litellm.proxy.proxy_cli import _build_db_connection_url_params

    built = _build_db_connection_url_params(connection_limit=10, pool_timeout=60)

    assert built == {"connection_limit": 10, "pool_timeout": 60}
|
|
|
|
def test_build_db_connection_url_params_omits_none_timeouts(self):
    """Optional timeouts passed as ``None`` are left out of the result."""
    from litellm.proxy.proxy_cli import _build_db_connection_url_params

    built = _build_db_connection_url_params(
        connection_limit=10,
        pool_timeout=60,
        connect_timeout=None,
        socket_timeout=None,
    )

    assert "connect_timeout" not in built
    assert "socket_timeout" not in built
|
|
|
|
def test_build_db_connection_url_params_includes_optional_timeouts(self):
    """Optional timeouts given as numbers appear in the result unchanged."""
    from litellm.proxy.proxy_cli import _build_db_connection_url_params

    built = _build_db_connection_url_params(
        connection_limit=10,
        pool_timeout=60,
        connect_timeout=15,
        socket_timeout=120,
    )

    assert built["connect_timeout"] == 15
    assert built["socket_timeout"] == 120
|
|
|
|
def test_build_db_connection_url_params_extras_override_defaults(self):
    """``extra_params`` entries are included and override a same-named default."""
    from litellm.proxy.proxy_cli import _build_db_connection_url_params

    extras = {
        "pgbouncer": "true",
        "statement_cache_size": 0,
        "pool_timeout": 5,
    }

    built = _build_db_connection_url_params(
        connection_limit=10,
        pool_timeout=60,
        extra_params=extras,
    )

    assert built["pgbouncer"] == "true"
    assert built["statement_cache_size"] == 0
    # The extra value (5) replaces the explicit pool_timeout argument (60).
    assert built["pool_timeout"] == 5
|
|
|
|
@patch("subprocess.run")
@patch("atexit.register")
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
@patch(
    "litellm.proxy.db.prisma_client.should_update_prisma_schema", return_value=False
)
def test_db_connection_extra_params_forwarded_to_url(
    self,
    mock_should_update,
    mock_setup_db,
    mock_atexit_register,
    mock_subprocess_run,
):
    """Connect/socket timeouts and extra connection params reach ``append_query_params``."""
    from click.testing import CliRunner

    from litellm.proxy.proxy_cli import run_server

    cli = CliRunner()
    mock_subprocess_run.return_value = MagicMock(returncode=0)

    fake_proxy_server = MagicMock(
        app=MagicMock(),
        ProxyConfig=MagicMock(),
        KeyManagementSettings=MagicMock(),
        save_worker_config=MagicMock(),
    )
    # The proxy config loader returns general_settings carrying the DB options.
    general_settings = {
        "database_url": "postgresql://test:test@localhost:5432/test",
        "database_connect_timeout": 15,
        "database_socket_timeout": 120,
        "database_extra_connection_params": {
            "pgbouncer": "true",
            "statement_cache_size": 0,
        },
    }
    fake_proxy_server.ProxyConfig.return_value.get_config = AsyncMock(
        return_value={"general_settings": general_settings}
    )

    # Drop DATABASE_URL/DIRECT_URL from the environment used for this run.
    clean_env = dict(os.environ)
    for db_var in ("DATABASE_URL", "DIRECT_URL"):
        clean_env.pop(db_var, None)

    fake_modules = {
        "proxy_server": fake_proxy_server,
        "litellm.proxy.proxy_server": fake_proxy_server,
    }

    def _fake_append(url, params):
        """Return the URL unchanged, as a string."""
        return str(url)

    with (
        patch.dict(os.environ, clean_env, clear=True),
        patch.dict("sys.modules", fake_modules),
        patch(
            "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
        ) as mock_get_args,
        patch(
            "litellm.proxy.proxy_cli.append_query_params",
            side_effect=_fake_append,
        ) as mock_append_query_params,
    ):
        mock_get_args.return_value = {
            "app": "litellm.proxy.proxy_server:app",
            "host": "localhost",
            "port": 8000,
        }

        result = cli.invoke(
            run_server,
            ["--local", "--config", "test-config.yaml", "--skip_server_startup"],
        )

        assert (
            result.exit_code == 0
        ), f"exit_code={result.exit_code}, output={result.output}"
        mock_append_query_params.assert_called()
        # Second positional argument of the most recent call is the params dict.
        forwarded = mock_append_query_params.call_args.args[1]
        assert forwarded["connect_timeout"] == 15
        assert forwarded["socket_timeout"] == 120
        assert forwarded["pgbouncer"] == "true"
        assert forwarded["statement_cache_size"] == 0
|
|
|
|
@patch("uvicorn.run")
@patch("atexit.register")
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
@patch(
    "litellm.proxy.db.prisma_client.should_update_prisma_schema", return_value=False
)
def test_proxy_default_api_version_uses_azure_default(
    self, mock_should_update, mock_setup_db, mock_atexit_register, mock_uvicorn_run
):
    """With no api_version flag, save_worker_config gets litellm.AZURE_DEFAULT_API_VERSION."""
    from click.testing import CliRunner

    import litellm
    from litellm.proxy.proxy_cli import run_server

    cli = CliRunner()
    fake_proxy_server = MagicMock(
        app=MagicMock(),
        ProxyConfig=MagicMock(),
        KeyManagementSettings=MagicMock(),
        save_worker_config=MagicMock(),
    )

    # Drop DATABASE_URL/DIRECT_URL from the environment used for this run.
    clean_env = dict(os.environ)
    for db_var in ("DATABASE_URL", "DIRECT_URL"):
        clean_env.pop(db_var, None)

    fake_modules = {
        "proxy_server": fake_proxy_server,
        "litellm.proxy.proxy_server": fake_proxy_server,
    }

    with (
        patch.dict(os.environ, clean_env, clear=True),
        patch.dict("sys.modules", fake_modules),
        patch(
            "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
        ) as mock_get_args,
    ):
        mock_get_args.return_value = {
            "app": "litellm.proxy.proxy_server:app",
            "host": "localhost",
            "port": 8000,
        }
        result = cli.invoke(run_server, ["--local", "--skip_server_startup"])
        assert (
            result.exit_code == 0
        ), f"exit_code={result.exit_code}, output={result.output}"

        save_config = fake_proxy_server.save_worker_config
        save_config.assert_called_once()
        forwarded_kwargs = save_config.call_args.kwargs
        assert forwarded_kwargs["api_version"] == litellm.AZURE_DEFAULT_API_VERSION
|
|
|
|
@patch("uvicorn.run")
@patch("builtins.print")
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
@patch(
    "litellm.proxy.db.prisma_client.should_update_prisma_schema", return_value=False
)
def test_keepalive_timeout_flag(
    self, mock_should_update, mock_setup_db, mock_print, mock_uvicorn_run
):
    """Test that the keepalive_timeout flag is properly passed to uvicorn.

    Checks that ``--keepalive_timeout 30`` reaches the uvicorn-args helper as
    ``keepalive_timeout=30`` and that ``uvicorn.run`` is called with
    ``timeout_keep_alive=30``.
    """
    from click.testing import CliRunner

    from litellm.proxy.proxy_cli import run_server

    runner = CliRunner()

    mock_app = MagicMock()
    mock_proxy_config = MagicMock()
    mock_key_mgmt = MagicMock()
    mock_save_worker_config = MagicMock()

    # Strip DATABASE_URL/DIRECT_URL so run_server doesn't enter the prisma
    # DB-setup block (un-timeout'd `subprocess.run(["prisma"])` +
    # migrate-deploy retry loop) — same isolation every other run_server
    # test in this file uses.
    clean_env = {
        k: v
        for k, v in os.environ.items()
        if k not in ("DATABASE_URL", "DIRECT_URL")
    }

    with (
        patch.dict(os.environ, clean_env, clear=True),
        patch.dict(
            "sys.modules",
            {
                "proxy_server": MagicMock(
                    app=mock_app,
                    ProxyConfig=mock_proxy_config,
                    KeyManagementSettings=mock_key_mgmt,
                    save_worker_config=mock_save_worker_config,
                )
            },
        ),
        patch(
            "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
        ) as mock_get_args,
        patch(
            "litellm.proxy.proxy_cli.ProxyInitializationHelpers._is_port_in_use",
            return_value=False,
        ),
    ):
        mock_get_args.return_value = {
            "app": "litellm.proxy.proxy_server:app",
            "host": "localhost",
            "port": 8000,
            "timeout_keep_alive": 30,
        }

        result = runner.invoke(run_server, ["--local", "--keepalive_timeout", "30"])

        # Include the CLI output on failure, as the sibling run_server tests
        # do; a bare exit-code assertion hides why the command failed.
        assert (
            result.exit_code == 0
        ), f"exit_code={result.exit_code}, output={result.output}"
        mock_get_args.assert_called_once_with(
            host="0.0.0.0",
            port=4000,
            log_config=None,
            keepalive_timeout=30,
            timeout_worker_healthcheck=None,
        )
        mock_uvicorn_run.assert_called_once()

        # Check that the uvicorn.run was called with the timeout_keep_alive parameter
        call_args = mock_uvicorn_run.call_args
        assert call_args[1]["timeout_keep_alive"] == 30
|
|
|
|
@patch("uvicorn.run")
@patch("builtins.print")
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
@patch(
    "litellm.proxy.db.prisma_client.should_update_prisma_schema", return_value=False
)
def test_timeout_worker_healthcheck_flag(
    self, mock_should_update, mock_setup_db, mock_print, mock_uvicorn_run
):
    """Check that ``--timeout_worker_healthcheck`` reaches the uvicorn init-args helper.

    The CLI is invoked with ``--timeout_worker_healthcheck 15``; the patched
    ``_get_default_unvicorn_init_args`` must then be called exactly once with
    ``timeout_worker_healthcheck=15`` and ``keepalive_timeout=None``.
    """
    from click.testing import CliRunner

    from litellm.proxy.proxy_cli import run_server

    cli_runner = CliRunner()

    # Stand-in registered under ``proxy_server`` in ``sys.modules``.
    proxy_server_stub = MagicMock(
        app=MagicMock(),
        ProxyConfig=MagicMock(),
        KeyManagementSettings=MagicMock(),
        save_worker_config=MagicMock(),
    )

    # Run without DATABASE_URL/DIRECT_URL so run_server stays out of the
    # prisma DB-setup path (un-timeout'd prisma subprocess call plus the
    # migrate-deploy retry loop), matching the other run_server tests here.
    env_without_db = dict(os.environ)
    for db_var in ("DATABASE_URL", "DIRECT_URL"):
        env_without_db.pop(db_var, None)

    helpers_path = "litellm.proxy.proxy_cli.ProxyInitializationHelpers"

    with (
        patch.dict(os.environ, env_without_db, clear=True),
        patch.dict("sys.modules", {"proxy_server": proxy_server_stub}),
        patch(f"{helpers_path}._get_default_unvicorn_init_args") as mock_get_args,
        patch(f"{helpers_path}._is_port_in_use", return_value=False),
    ):
        mock_get_args.return_value = {
            "app": "litellm.proxy.proxy_server:app",
            "host": "localhost",
            "port": 8000,
        }

        cli_args = ["--local", "--timeout_worker_healthcheck", "15"]
        result = cli_runner.invoke(run_server, cli_args)

        assert result.exit_code == 0

        # The flag value must arrive as an int alongside the CLI defaults.
        expected_kwargs = {
            "host": "0.0.0.0",
            "port": 4000,
            "log_config": None,
            "keepalive_timeout": None,
            "timeout_worker_healthcheck": 15,
        }
        mock_get_args.assert_called_once_with(**expected_kwargs)
|
|
|
|
@patch("uvicorn.run")
@patch("builtins.print")
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
def test_max_requests_before_restart_flag(
    self, mock_setup_db, mock_print, mock_uvicorn_run
):
    """Check that ``--max_requests_before_restart`` becomes uvicorn's ``limit_max_requests``.

    Invokes the CLI with ``--max_requests_before_restart 123`` and inspects the
    keyword arguments of the single ``uvicorn.run`` call.
    """
    from click.testing import CliRunner

    from litellm.proxy.proxy_cli import run_server

    cli_runner = CliRunner()

    # Stand-in registered under ``proxy_server`` in ``sys.modules``.
    proxy_server_stub = MagicMock(
        app=MagicMock(),
        ProxyConfig=MagicMock(),
        KeyManagementSettings=MagicMock(),
        save_worker_config=MagicMock(),
    )

    # Same environment as the caller, minus the database connection variables.
    env_without_db = dict(os.environ)
    for db_var in ("DATABASE_URL", "DIRECT_URL"):
        env_without_db.pop(db_var, None)

    init_args_target = (
        "litellm.proxy.proxy_cli.ProxyInitializationHelpers"
        "._get_default_unvicorn_init_args"
    )

    with (
        patch.dict(os.environ, env_without_db, clear=True),
        patch.dict("sys.modules", {"proxy_server": proxy_server_stub}),
        patch(init_args_target) as mock_get_args,
    ):
        mock_get_args.return_value = {
            "app": "litellm.proxy.proxy_server:app",
            "host": "localhost",
            "port": 8000,
        }

        cli_args = ["--local", "--max_requests_before_restart", "123"]
        result = cli_runner.invoke(run_server, cli_args)

        assert (
            result.exit_code == 0
        ), f"exit_code={result.exit_code}, output={result.output}"
        mock_uvicorn_run.assert_called_once()

        # uvicorn.run must have received the restart limit as a keyword argument.
        uvicorn_kwargs = mock_uvicorn_run.call_args.kwargs
        assert uvicorn_kwargs["limit_max_requests"] == 123
|
|
|
|
@patch.dict(os.environ, {}, clear=True)
def test_construct_database_url_from_env_vars(self):
    """Exercise ``construct_database_url_from_env_vars`` across several environments.

    Covers a fully specified environment, values that require URL encoding,
    a missing password, a missing required variable, and an empty environment.
    """
    from litellm.proxy.utils import construct_database_url_from_env_vars

    # (environment, expected URL) pairs that must yield a connection string.
    url_cases = [
        # All variables present.
        (
            {
                "DATABASE_HOST": "localhost:5432",
                "DATABASE_USERNAME": "testuser",
                "DATABASE_PASSWORD": "testpass",
                "DATABASE_NAME": "testdb",
            },
            "postgresql://testuser:testpass@localhost:5432/testdb",
        ),
        # Special characters in username / database name get percent-encoded.
        (
            {
                "DATABASE_HOST": "localhost:5432",
                "DATABASE_USERNAME": "user@with+special",
                "DATABASE_PASSWORD": "test-password-special-chars",
                "DATABASE_NAME": "db_name/test",
            },
            "postgresql://user%40with%2Bspecial:test-password-special-chars@localhost:5432/db_name%2Ftest",
        ),
        # No password: the URL is still built, without the ":password" part.
        (
            {
                "DATABASE_HOST": "localhost:5432",
                "DATABASE_USERNAME": "testuser",
                "DATABASE_NAME": "testdb",
            },
            "postgresql://testuser@localhost:5432/testdb",
        ),
    ]
    for env_vars, expected_url in url_cases:
        with patch.dict(os.environ, env_vars):
            assert construct_database_url_from_env_vars() == expected_url

    # DATABASE_NAME is absent, so no URL can be constructed.
    incomplete_env = {
        "DATABASE_HOST": "localhost:5432",
        "DATABASE_USERNAME": "testuser",
    }
    with patch.dict(os.environ, incomplete_env):
        assert construct_database_url_from_env_vars() is None

    # Nothing set at all.
    with patch.dict(os.environ, {}, clear=True):
        assert construct_database_url_from_env_vars() is None
|
|
|
|
@patch("uvicorn.run")
@patch("builtins.print")
def test_run_server_no_config_passed(self, mock_print, mock_uvicorn_run):
    """Test that run_server properly handles the case when no config is passed.

    run_server is invoked twice, once with no ``--config`` option and once
    with ``--config None``. Each invocation must exit with code 0 and call
    ``uvicorn.run`` exactly once.
    """
    from click.testing import CliRunner

    from litellm.proxy.proxy_cli import run_server

    runner = CliRunner()

    mock_app = MagicMock()
    mock_proxy_config = MagicMock()
    mock_key_mgmt = MagicMock()
    mock_save_worker_config = MagicMock()

    # ProxyConfig().get_config must be awaitable and return a config dict.
    async def mock_get_config(config_file_path=None):
        return {"general_settings": {}, "litellm_settings": {}}

    mock_proxy_config_instance = MagicMock()
    mock_proxy_config_instance.get_config = mock_get_config
    mock_proxy_config.return_value = mock_proxy_config_instance

    mock_proxy_server_module = MagicMock(app=mock_app)

    # Deliberately NOT clear=True with only these overrides: wiping PATH,
    # HOME, etc. breaks the imports done inside run_server in CI.
    env_overrides = {
        "IAM_TOKEN_DB_AUTH": "",
        "USE_AWS_KMS": "",
    }
    with patch.dict(os.environ, env_overrides):
        # Remove the DB variables entirely so the DB setup block is skipped.
        # patch.dict restores the original environment on exit.
        os.environ.pop("DATABASE_URL", None)
        os.environ.pop("DIRECT_URL", None)

        with (
            patch.dict(
                "sys.modules",
                {
                    "proxy_server": MagicMock(
                        app=mock_app,
                        ProxyConfig=mock_proxy_config,
                        KeyManagementSettings=mock_key_mgmt,
                        save_worker_config=mock_save_worker_config,
                    ),
                    # Also stub litellm.proxy.proxy_server so proxy_cli does
                    # not trigger the real import, which has heavy side
                    # effects (FastAPI app init, logging setup, etc.).
                    "litellm.proxy.proxy_server": mock_proxy_server_module,
                },
            ),
            patch(
                "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
            ) as mock_get_args,
        ):
            mock_get_args.return_value = {
                "app": "litellm.proxy.proxy_server:app",
                "host": "localhost",
                "port": 8000,
            }

            # First without any --config option, then with an explicit
            # "--config None"; both must behave the same.
            for cli_args in (["--local"], ["--local", "--config", "None"]):
                result = runner.invoke(run_server, cli_args)

                assert result.exit_code == 0, (
                    f"run_server failed with exit_code={result.exit_code}, "
                    f"output={result.output}, exception={result.exception}"
                )
                mock_uvicorn_run.assert_called_once()

                # Start the next invocation with a clean call record.
                mock_uvicorn_run.reset_mock()
|
|
|
|
|
|
class TestRunServerDbSetup:
    """Tests for how run_server drives ``PrismaManager.setup_database``."""

    @staticmethod
    def _proxy_server_stub():
        """Build the MagicMock installed in place of the proxy_server module."""
        return MagicMock(
            app=MagicMock(),
            ProxyConfig=MagicMock(),
            KeyManagementSettings=MagicMock(),
            save_worker_config=MagicMock(),
        )

    @staticmethod
    def _env_with_test_database():
        """Copy the current environment, dropping DIRECT_URL and pinning DATABASE_URL."""
        env = dict(os.environ)
        for db_var in ("DATABASE_URL", "DIRECT_URL"):
            env.pop(db_var, None)
        env["DATABASE_URL"] = "postgresql://test:test@localhost:5432/test"
        return env

    @patch("subprocess.run")
    @patch("atexit.register")
    @patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
    @patch("litellm.proxy.db.check_migration.check_prisma_schema_diff")
    @patch("litellm.proxy.db.prisma_client.should_update_prisma_schema")
    def test_use_prisma_db_push_flag_behavior(
        self,
        mock_should_update_schema,
        mock_check_schema_diff,
        mock_setup_database,
        mock_atexit_register,
        mock_subprocess_run,
    ):
        """``--use_prisma_db_push`` must flip the ``use_migrate`` argument of setup_database.

        Without the flag ``setup_database`` is expected to be called with
        ``use_migrate=True``; with the flag, ``use_migrate=False``.
        """
        from litellm.proxy.proxy_cli import run_server

        # Pretend the prisma executable ran successfully.
        mock_subprocess_run.return_value = MagicMock(returncode=0)

        proxy_stub = self._proxy_server_stub()
        db_env = self._env_with_test_database()

        with (
            patch.dict(os.environ, db_env, clear=True),
            patch.dict(
                "sys.modules",
                {
                    "proxy_server": proxy_stub,
                    "litellm.proxy.proxy_server": proxy_stub,
                },
            ),
            patch(
                "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
            ) as mock_get_args,
        ):
            mock_get_args.return_value = {
                "app": "litellm.proxy.proxy_server:app",
                "host": "localhost",
                "port": 8000,
            }

            # run_server.main(..., standalone_mode=False) is used instead of
            # CliRunner: CliRunner's stream isolation causes flaky
            # "I/O operation on closed file" errors in CI (Click 8.3.x).
            scenarios = (
                # (extra CLI flags, expected use_migrate)
                ([], True),
                (["--use_prisma_db_push"], False),
            )
            for extra_flags, expected_use_migrate in scenarios:
                # Fresh call records, and make sure setup_database is reached.
                mock_setup_database.reset_mock()
                mock_should_update_schema.reset_mock()
                mock_should_update_schema.return_value = True

                run_server.main(
                    ["--local", "--skip_server_startup", *extra_flags],
                    standalone_mode=False,
                )
                mock_setup_database.assert_called_with(
                    use_migrate=expected_use_migrate, use_v2_resolver=False
                )

    @patch("subprocess.run")
    @patch("atexit.register")
    @patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
    @patch("litellm.proxy.db.check_migration.check_prisma_schema_diff")
    @patch("litellm.proxy.db.prisma_client.should_update_prisma_schema")
    def test_startup_fails_when_db_setup_fails(
        self,
        mock_should_update_schema,
        mock_check_schema_diff,
        mock_setup_database,
        mock_atexit_register,
        mock_subprocess_run,
    ):
        """A failed ``setup_database`` must exit with code 1 under ``--enforce_prisma_migration_check``."""
        from litellm.proxy.proxy_cli import run_server

        mock_subprocess_run.return_value = MagicMock(returncode=0)
        mock_should_update_schema.return_value = True
        # Simulate the database setup reporting failure.
        mock_setup_database.return_value = False

        proxy_stub = self._proxy_server_stub()
        db_env = self._env_with_test_database()

        with (
            patch.dict(os.environ, db_env, clear=True),
            patch.dict(
                "sys.modules",
                {
                    "proxy_server": proxy_stub,
                    "litellm.proxy.proxy_server": proxy_stub,
                },
            ),
            patch(
                "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
            ) as mock_get_args,
        ):
            mock_get_args.return_value = {
                "app": "litellm.proxy.proxy_server:app",
                "host": "localhost",
                "port": 8000,
            }

            cli_args = [
                "--local",
                "--skip_server_startup",
                "--enforce_prisma_migration_check",
            ]
            with pytest.raises(SystemExit) as exc_info:
                run_server.main(cli_args, standalone_mode=False)

            assert exc_info.value.code == 1
            mock_setup_database.assert_called_once_with(
                use_migrate=True, use_v2_resolver=False
            )
|
|
|
|
|
|
# --- Module-level helpers for worker startup hook tests ---
|
|
|
|
# Set to True by _dummy_hook; tests reset it to False before each run.
_dummy_hook_called = False


def _dummy_hook() -> None:
    """Synchronous startup hook that records that it ran.

    Referenced by import path from test_should_run_worker_startup_hooks and
    test_should_run_multiple_hooks.
    """
    global _dummy_hook_called
    _dummy_hook_called = True
|
|
|
|
|
|
# Set to True by _dummy_async_hook; tests reset it to False before each run.
_dummy_async_hook_called = False


async def _dummy_async_hook() -> None:
    """Coroutine startup hook that records that it was awaited.

    Referenced by import path from test_should_run_async_worker_startup_hook
    and test_should_run_multiple_hooks.
    """
    global _dummy_async_hook_called
    _dummy_async_hook_called = True
|
|
|
|
|
|
def _failing_hook() -> None:
    """Startup hook that unconditionally raises RuntimeError.

    Referenced by import path from
    test_should_raise_on_failing_worker_startup_hook.
    """
    raise RuntimeError("Hook failed on purpose")
|
|
|
|
|
|
class TestWorkerStartupHooks:
|
|
"""Tests for the LITELLM_WORKER_STARTUP_HOOKS mechanism in proxy_startup_event."""
|
|
|
|
@pytest.mark.asyncio
async def test_should_run_worker_startup_hooks(self):
    """A synchronous hook named in LITELLM_WORKER_STARTUP_HOOKS runs during proxy_startup_event."""
    global _dummy_hook_called
    _dummy_hook_called = False

    from litellm.proxy.proxy_server import proxy_startup_event

    # Current environment without the DB variables (avoids real DB setup),
    # plus the hook under test.
    startup_env = {
        name: value
        for name, value in os.environ.items()
        if name not in ("DATABASE_URL", "DIRECT_URL")
    }
    startup_env["LITELLM_WORKER_STARTUP_HOOKS"] = (
        "tests.test_litellm.proxy.test_proxy_cli:_dummy_hook"
    )

    with patch.dict(os.environ, startup_env, clear=True):
        try:
            async with proxy_startup_event(app=None):
                pass
        except Exception:
            # Failures after the hook has run (no DB, etc.) are expected
            # and irrelevant here.
            pass

    assert _dummy_hook_called is True, "Sync startup hook was not called"
|
|
|
|
@pytest.mark.asyncio
async def test_should_run_async_worker_startup_hook(self):
    """A coroutine hook named in LITELLM_WORKER_STARTUP_HOOKS is awaited during proxy_startup_event."""
    global _dummy_async_hook_called
    _dummy_async_hook_called = False

    from litellm.proxy.proxy_server import proxy_startup_event

    # Current environment without the DB variables, plus the hook under test.
    startup_env = {
        name: value
        for name, value in os.environ.items()
        if name not in ("DATABASE_URL", "DIRECT_URL")
    }
    startup_env["LITELLM_WORKER_STARTUP_HOOKS"] = (
        "tests.test_litellm.proxy.test_proxy_cli:_dummy_async_hook"
    )

    with patch.dict(os.environ, startup_env, clear=True):
        try:
            async with proxy_startup_event(app=None):
                pass
        except Exception:
            # Any startup error is ignored; only the hook flag is checked.
            pass

    assert _dummy_async_hook_called is True, "Async startup hook was not called"
|
|
|
|
@pytest.mark.asyncio
async def test_should_raise_on_failing_worker_startup_hook(self):
    """An exception raised by a startup hook must surface from proxy_startup_event."""
    from litellm.proxy.proxy_server import proxy_startup_event

    # Current environment without the DB variables, plus the failing hook.
    startup_env = {
        name: value
        for name, value in os.environ.items()
        if name not in ("DATABASE_URL", "DIRECT_URL")
    }
    startup_env["LITELLM_WORKER_STARTUP_HOOKS"] = (
        "tests.test_litellm.proxy.test_proxy_cli:_failing_hook"
    )

    with (
        patch.dict(os.environ, startup_env, clear=True),
        pytest.raises(RuntimeError, match="Hook failed on purpose"),
    ):
        async with proxy_startup_event(app=None):
            pass
|
|
|
|
@pytest.mark.asyncio
async def test_should_skip_when_no_hooks_set(self):
    """When LITELLM_WORKER_STARTUP_HOOKS is not set, no hooks are executed.

    Runs proxy_startup_event with the variable removed from the environment
    and checks that the module-level ``_dummy_hook_called`` flag, reset to
    False beforehand, is still False afterwards.
    """
    global _dummy_hook_called
    _dummy_hook_called = False

    from litellm.proxy.proxy_server import proxy_startup_event

    # Same isolation as the sibling hook tests (no DB variables), and in
    # addition make sure no hook is configured at all.
    clean_env = {
        k: v
        for k, v in os.environ.items()
        if k not in ("DATABASE_URL", "DIRECT_URL", "LITELLM_WORKER_STARTUP_HOOKS")
    }

    with patch.dict(os.environ, clean_env, clear=True):
        # Precondition: the variable really is absent for this run.
        assert "LITELLM_WORKER_STARTUP_HOOKS" not in os.environ

        try:
            async with proxy_startup_event(app=None) as _:
                pass
        except Exception:
            pass  # Startup may fail for other reasons (no DB, etc.)

    assert (
        _dummy_hook_called is False
    ), "Startup hook ran even though none was configured"
|
|
|
|
@pytest.mark.asyncio
async def test_should_run_multiple_hooks(self):
    """Every hook in a comma-separated LITELLM_WORKER_STARTUP_HOOKS value is run."""
    global _dummy_hook_called, _dummy_async_hook_called
    _dummy_hook_called = False
    _dummy_async_hook_called = False

    from litellm.proxy.proxy_server import proxy_startup_event

    # One sync and one async hook, joined with a comma and no whitespace.
    hook_spec = ",".join(
        [
            "tests.test_litellm.proxy.test_proxy_cli:_dummy_hook",
            "tests.test_litellm.proxy.test_proxy_cli:_dummy_async_hook",
        ]
    )

    # Current environment without the DB variables, plus both hooks.
    startup_env = {
        name: value
        for name, value in os.environ.items()
        if name not in ("DATABASE_URL", "DIRECT_URL")
    }
    startup_env["LITELLM_WORKER_STARTUP_HOOKS"] = hook_spec

    with patch.dict(os.environ, startup_env, clear=True):
        try:
            async with proxy_startup_event(app=None):
                pass
        except Exception:
            # Any startup error is ignored; only the hook flags are checked.
            pass

    assert _dummy_hook_called is True, "First hook was not called"
    assert _dummy_async_hook_called is True, "Second hook was not called"
|