Files
litellm/tests/test_litellm/proxy/test_proxy_cli.py
T
Mateo Wang b3297fc2ea feat(proxy): hot-reload .env in dev when running with --reload (#29783)
* feat(proxy): hot-reload .env in dev when running with --reload

The --reload watcher already restarts the worker on *.py and --config YAML
edits, but .env was unwatched, so changing a key there did nothing until a
manual restart. Add .env to the uvicorn reload_includes (and to the
StatReload monkeypatch, which ignores reload_includes) so an edit triggers a
worker restart.

A reloaded worker is a fresh process that inherits the reloader's
environment, so load_dotenv(override=False) would keep serving the stale
inherited value for any key already in the environment. The CLI now exports
LITELLM_DEV_ENV_HOT_RELOAD when --reload is set, and litellm/__init__.py
reads it to load .env with override=True only on that dev path, leaving
normal startup precedence untouched.

* feat(proxy): warn that --reload makes .env override shell env vars

When --reload is active, worker processes re-read .env with override=True, so
.env values win over shell-exported environment variables. Surface this dotenv
precedence change with a startup warning so a developer who relies on a
shell-exported override is not silently surprised.

* fix(proxy): type reload helper paths as Optional[str] to satisfy mypy

* fix(proxy): watch the cwd .env in both reload backends for parity

WatchFiles only watches cwd (and the --config dir) for .env, while the
StatReload fallback used find_dotenv(usecwd=True), which walks up to a
parent-dir .env that WatchFiles never sees. Point StatReload at the same
cwd .env so the two reload backends react to the same file.
2026-06-06 09:39:21 -07:00

1513 lines
56 KiB
Python

import os
import sys
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import click
import fastapi
import pytest
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system-path
import builtins
import types
from litellm.proxy.proxy_cli import ProxyInitializationHelpers
@pytest.mark.xdist_group("proxy_cli")
class TestProxyInitializationHelpers:
@patch("importlib.metadata.version")
@patch("click.echo")
def test_echo_litellm_version(self, mock_echo, mock_version):
    """The version banner echoes the version reported for the ``litellm`` package."""
    # Arrange: pin the version string the helper will look up.
    mock_version.return_value = "1.0.0"
    expected_banner = "\nLiteLLM: Current Version = 1.0.0\n"

    # Act
    ProxyInitializationHelpers._echo_litellm_version()

    # Assert: the right distribution was queried and exactly one banner echoed.
    mock_version.assert_called_once_with("litellm")
    mock_echo.assert_called_once_with(expected_banner)
@patch("httpx.get")
@patch("builtins.print")
@patch("json.dumps")
def test_run_health_check(self, mock_dumps, mock_print, mock_get):
    """The health check GETs ``/health`` and serializes the JSON body with ``indent=4``."""
    # Arrange: a fake HTTP response carrying a minimal health payload.
    health_payload = {"status": "healthy"}
    fake_response = MagicMock()
    fake_response.json.return_value = health_payload
    mock_get.return_value = fake_response
    mock_dumps.return_value = '{"status": "healthy"}'

    # Act
    ProxyInitializationHelpers._run_health_check("localhost", 8000)

    # Assert: one request to the health endpoint, body decoded once and dumped.
    mock_get.assert_called_once_with(url="http://localhost:8000/health")
    fake_response.json.assert_called_once()
    mock_dumps.assert_called_once_with(health_payload, indent=4)
@patch("openai.OpenAI")
@patch("click.echo")
@patch("builtins.print")
def test_run_test_chat_completion(self, mock_print, mock_echo, mock_openai):
    """A non-string ``test`` value is rejected; a URL string builds an OpenAI client."""
    # Arrange: an OpenAI client whose chat completion returns one plain reply
    # followed by one iterable (streamed) reply.
    fake_client = MagicMock()
    mock_openai.return_value = fake_client
    plain_reply = MagicMock()
    fake_client.chat.completions.create.return_value = plain_reply
    streamed_reply = MagicMock()
    streamed_reply.__iter__.return_value = [MagicMock(), MagicMock()]
    fake_client.chat.completions.create.side_effect = [plain_reply, streamed_reply]

    run_completion = ProxyInitializationHelpers._run_test_chat_completion

    # Act + assert: a boolean test value is invalid.
    with pytest.raises(ValueError, match="Invalid test value"):
        run_completion("localhost", 8000, "gpt-3.5-turbo", True)

    # Act: a string test value is used as the client's base URL.
    run_completion("localhost", 8000, "gpt-3.5-turbo", "http://test-url")

    # Assert
    mock_openai.assert_called_once_with(
        api_key="My API Key", base_url="http://test-url"
    )
    fake_client.chat.completions.create.assert_called()
def test_get_default_unvicorn_init_args(self):
    """Walk through each optional argument of the uvicorn init-args builder."""
    build_args = ProxyInitializationHelpers._get_default_unvicorn_init_args

    # Host and port only.
    base_args = build_args("localhost", 8000)
    assert base_args["app"] == "litellm.proxy.proxy_server:app"
    assert base_args["host"] == "localhost"
    assert base_args["port"] == 8000

    # An explicit log_config value is passed through.
    logged_args = build_args("localhost", 8000, "log_config.json")
    assert logged_args["log_config"] == "log_config.json"

    # With litellm.json_logs switched on, log_config becomes a logging config dict.
    with patch("litellm.json_logs", True):
        json_args = build_args("localhost", 8000)
    json_log_config = json_args["log_config"]
    assert json_log_config is not None
    assert isinstance(json_log_config, dict)
    assert "version" in json_log_config
    assert "formatters" in json_log_config

    # keepalive_timeout is exposed as uvicorn's timeout_keep_alive.
    keepalive_args = build_args("localhost", 8000, None, 60)
    assert keepalive_args["timeout_keep_alive"] == 60

    # log_config and keepalive_timeout together.
    combined_args = build_args("localhost", 8000, "log_config.json", 120)
    assert combined_args["log_config"] == "log_config.json"
    assert combined_args["timeout_keep_alive"] == 120

    # Stand-in for uvicorn.Config whose constructor accepts
    # timeout_worker_healthcheck.
    # NOTE(review): presumably the helper inspects uvicorn.Config to decide
    # whether the option is supported - confirm against the implementation.
    class _FakeUvicornConfig:
        def __init__(self, timeout_worker_healthcheck=None):
            pass

    with patch("uvicorn.Config", _FakeUvicornConfig):
        healthcheck_args = build_args(
            "localhost", 8000, timeout_worker_healthcheck=15
        )
    assert healthcheck_args["timeout_worker_healthcheck"] == 15
def test_get_reload_options_no_config_still_watches_env(self):
    """Without a config file the reloader still watches cwd for ``*.py`` and ``.env``."""
    reload_opts = ProxyInitializationHelpers._get_reload_options(None)
    working_dir = os.path.abspath(os.getcwd())

    assert reload_opts["reload"] is True
    assert reload_opts["reload_dirs"] == [working_dir]
    assert reload_opts["reload_includes"] == ["*.py", ".env"]
def test_get_reload_options_with_config_in_cwd(self, tmp_path, monkeypatch):
    """A config file living in cwd adds its name to the includes, not a new dir."""
    # Arrange: a config file in the directory we then make the cwd.
    yaml_path = tmp_path / "config.yaml"
    yaml_path.write_text("model_list: []\n")
    monkeypatch.chdir(tmp_path)

    # Act: the config is referenced by its relative name.
    reload_opts = ProxyInitializationHelpers._get_reload_options("config.yaml")

    # Assert
    assert reload_opts["reload"] is True
    assert reload_opts["reload_dirs"] == [str(tmp_path)]
    assert reload_opts["reload_includes"] == ["*.py", ".env", "config.yaml"]
def test_get_reload_options_with_config_outside_cwd(self, tmp_path, monkeypatch):
    """A config file outside cwd makes its directory a second watched dir."""
    # Arrange: separate working and config directories under tmp_path.
    work_dir = tmp_path / "work"
    config_dir = tmp_path / "configs"
    work_dir.mkdir()
    config_dir.mkdir()
    yaml_path = config_dir / "proxy.yaml"
    yaml_path.write_text("model_list: []\n")
    monkeypatch.chdir(work_dir)

    # Act: the config is referenced by its absolute path.
    reload_opts = ProxyInitializationHelpers._get_reload_options(str(yaml_path))

    # Assert: cwd comes first, then the config's directory.
    assert reload_opts["reload"] is True
    assert reload_opts["reload_dirs"] == [str(work_dir), str(config_dir)]
    assert reload_opts["reload_includes"] == ["*.py", ".env", "proxy.yaml"]
def test_patch_statreload_extra_paths_yields_config_and_py(self, tmp_path):
    """After patching, ``StatReload.iter_py_files`` yields the config and ``.py`` files."""
    from uvicorn.supervisors.statreload import StatReload

    # Reset the paths registered on StatReload by earlier tests, if any.
    if hasattr(StatReload, "_litellm_patched_config_paths"):
        StatReload._litellm_patched_config_paths.clear()

    # Arrange: one YAML config and one Python module in the watched dir.
    yaml_path = tmp_path / "config.yaml"
    yaml_path.write_text("model_list: []\n")
    module_path = tmp_path / "module.py"
    module_path.write_text("x = 1\n")

    # Act
    was_patched = ProxyInitializationHelpers._patch_statreload_extra_paths(
        [str(yaml_path)]
    )

    # Assert: iterate with a minimal stand-in exposing config.reload_dirs.
    assert was_patched is True
    reloader_stub = SimpleNamespace(config=SimpleNamespace(reload_dirs=[tmp_path]))
    seen_paths = {
        Path(entry).resolve() for entry in StatReload.iter_py_files(reloader_stub)
    }
    assert yaml_path.resolve() in seen_paths
    assert module_path.resolve() in seen_paths
def test_patch_statreload_extra_paths_yields_env(self, tmp_path):
    """A registered ``.env`` file is yielded by the patched ``iter_py_files``."""
    from uvicorn.supervisors.statreload import StatReload

    # Reset the paths registered on StatReload by earlier tests, if any.
    if hasattr(StatReload, "_litellm_patched_config_paths"):
        StatReload._litellm_patched_config_paths.clear()

    # Arrange
    dotenv_path = tmp_path / ".env"
    dotenv_path.write_text("FOO=bar\n")

    # Act
    was_patched = ProxyInitializationHelpers._patch_statreload_extra_paths(
        [str(dotenv_path)]
    )

    # Assert: iterate with a minimal stand-in exposing config.reload_dirs.
    assert was_patched is True
    reloader_stub = SimpleNamespace(config=SimpleNamespace(reload_dirs=[tmp_path]))
    seen_paths = {
        Path(entry).resolve() for entry in StatReload.iter_py_files(reloader_stub)
    }
    assert dotenv_path.resolve() in seen_paths
def test_patch_statreload_extra_paths_skips_falsy(self, tmp_path):
    """An empty list, or a list of only falsy entries, reports ``False``."""
    from uvicorn.supervisors.statreload import StatReload

    # Reset the paths registered on StatReload by earlier tests, if any.
    if hasattr(StatReload, "_litellm_patched_config_paths"):
        StatReload._litellm_patched_config_paths.clear()

    patch_paths = ProxyInitializationHelpers._patch_statreload_extra_paths
    assert patch_paths([]) is False
    assert patch_paths([None, ""]) is False
def test_patch_statreload_extra_paths_is_idempotent(self, tmp_path):
    """Patching three times with the same path never yields a file twice."""
    from uvicorn.supervisors.statreload import StatReload

    # Reset the paths registered on StatReload by earlier tests, if any.
    if hasattr(StatReload, "_litellm_patched_config_paths"):
        StatReload._litellm_patched_config_paths.clear()

    # Arrange
    yaml_path = tmp_path / "config.yaml"
    yaml_path.write_text("model_list: []\n")
    module_path = tmp_path / "only.py"
    module_path.write_text("x = 1\n")

    # Act: apply the same patch repeatedly.
    for _attempt in range(3):
        ProxyInitializationHelpers._patch_statreload_extra_paths([str(yaml_path)])

    # Assert: no duplicates, and both files are still reported.
    reloader_stub = SimpleNamespace(config=SimpleNamespace(reload_dirs=[tmp_path]))
    reported = list(StatReload.iter_py_files(reloader_stub))
    unique_names = set(map(str, reported))
    assert len(reported) == len(unique_names)

    resolved = {Path(entry).resolve() for entry in reported}
    assert yaml_path.resolve() in resolved
    assert module_path.resolve() in resolved
def test_configure_dev_reload_watches_env_and_sets_override_flag(
    self, tmp_path, monkeypatch
):
    """``_configure_dev_reload`` enables reload, watches ``.env`` and exports the flag.

    Checks that the helper:
      * sets ``LITELLM_DEV_ENV_HOT_RELOAD`` to ``"True"`` in ``os.environ``,
      * fills ``uvicorn_args`` with ``reload=True`` and a ``.env`` include,
      * logs exactly one warning mentioning ``override`` and ``.env``,
      * registers both the cwd ``.env`` and the config file with StatReload.
    """
    from pathlib import Path

    from uvicorn.supervisors.statreload import StatReload

    # Reset the paths registered on StatReload by earlier tests, if any.
    if hasattr(StatReload, "_litellm_patched_config_paths"):
        StatReload._litellm_patched_config_paths.clear()

    # The helper writes the flag directly into os.environ. A plain
    # delenv(..., raising=False) records nothing for teardown when the
    # variable is not already set, so the flag would leak into every later
    # test. Registering the key with setenv first makes monkeypatch restore
    # the original (unset) state afterwards.
    monkeypatch.setenv("LITELLM_DEV_ENV_HOT_RELOAD", "")
    monkeypatch.delenv("LITELLM_DEV_ENV_HOT_RELOAD")

    config_file = tmp_path / "config.yaml"
    config_file.write_text("model_list: []\n")
    env_file = tmp_path / ".env"
    env_file.write_text("FOO=bar\n")
    monkeypatch.chdir(tmp_path)

    uvicorn_args: dict = {}
    with patch("litellm._logging.verbose_proxy_logger.warning") as mock_warning:
        ProxyInitializationHelpers._configure_dev_reload(
            uvicorn_args, str(config_file)
        )

    assert os.environ["LITELLM_DEV_ENV_HOT_RELOAD"] == "True"
    assert uvicorn_args["reload"] is True
    assert ".env" in uvicorn_args["reload_includes"]

    # The precedence change must be surfaced to the developer exactly once.
    mock_warning.assert_called_once()
    warning_text = mock_warning.call_args.args[0].lower()
    assert "override" in warning_text
    assert ".env" in warning_text

    # Iterate with a minimal stand-in exposing config.reload_dirs.
    fake_self = types.SimpleNamespace(
        config=types.SimpleNamespace(reload_dirs=[tmp_path])
    )
    yielded_paths = {Path(p).resolve() for p in StatReload.iter_py_files(fake_self)}
    assert env_file.resolve() in yielded_paths
    assert config_file.resolve() in yielded_paths
def test_dev_env_hot_reload_enabled_reads_flag(self, monkeypatch):
    """``_dev_env_hot_reload_enabled`` follows ``LITELLM_DEV_ENV_HOT_RELOAD``."""
    import litellm

    flag_name = "LITELLM_DEV_ENV_HOT_RELOAD"

    # Flag explicitly set: "True" enables, "false" disables.
    for raw_value, expected in (("True", True), ("false", False)):
        monkeypatch.setenv(flag_name, raw_value)
        assert litellm._dev_env_hot_reload_enabled() is expected

    # Flag absent: disabled.
    monkeypatch.delenv(flag_name, raising=False)
    assert litellm._dev_env_hot_reload_enabled() is False
@patch("asyncio.run")
@patch("builtins.print")
def test_init_hypercorn_server(self, mock_print, mock_asyncio_run):
    """The hypercorn initializer hands off to ``asyncio.run``, with or without SSL."""
    start_hypercorn = ProxyInitializationHelpers._init_hypercorn_server
    fake_app = MagicMock()

    # Plain HTTP: no certificate, key or ciphers.
    start_hypercorn(fake_app, "localhost", 8000, None, None, None)
    mock_asyncio_run.assert_called_once()

    # SSL settings supplied: only checks that the call completes.
    start_hypercorn(fake_app, "localhost", 8000, "cert.pem", "key.pem", "ECDHE")
def test_init_granian_server(self):
    """Granian is constructed with the proxy ASGI target and then served.

    ``granian`` is optional, so the skip check must run before anything
    imports it. A ``@patch("granian.Granian")`` decorator resolves its target
    before the test body executes, which raises ``ModuleNotFoundError``
    instead of skipping; the patches are therefore applied inside the body,
    after ``pytest.importorskip``.
    """
    pytest.importorskip("granian")

    fake_interfaces = SimpleNamespace(ASGI="asgi")
    with (
        patch("granian.Granian") as mock_granian_cls,
        patch("builtins.print"),
        patch("granian.constants.Interfaces", fake_interfaces),
    ):
        mock_server = MagicMock()
        mock_granian_cls.return_value = mock_server

        ProxyInitializationHelpers._init_granian_server(
            host="0.0.0.0",
            port=4000,
            num_workers=2,
            ssl_certfile_path=None,
            ssl_keyfile_path=None,
            max_requests_before_restart=None,
            ciphers=None,
            granian_runtime_threads=None,
        )

        mock_granian_cls.assert_called_once()
        call_kwargs = mock_granian_cls.call_args.kwargs
        assert call_kwargs["target"] == "litellm.proxy.proxy_server:app"
        assert call_kwargs["address"] == "0.0.0.0"
        assert call_kwargs["port"] == 4000
        assert call_kwargs["workers"] == 2
        assert call_kwargs["interface"] == "asgi"
        assert call_kwargs["websockets"] is True
        # No thread count requested, so the option must not be forwarded.
        assert "runtime_threads" not in call_kwargs
        mock_server.serve.assert_called_once()
def test_init_granian_server_runtime_threads(self):
    """``granian_runtime_threads`` is forwarded to Granian as ``runtime_threads``.

    The patches are applied inside the body, after ``pytest.importorskip``:
    a ``@patch("granian.Granian")`` decorator would import ``granian`` before
    the skip check runs and error out when the package is not installed.
    """
    pytest.importorskip("granian")

    fake_interfaces = SimpleNamespace(ASGI="asgi")
    with (
        patch("granian.Granian") as mock_granian_cls,
        patch("builtins.print"),
        patch("granian.constants.Interfaces", fake_interfaces),
    ):
        mock_granian_cls.return_value = MagicMock()

        ProxyInitializationHelpers._init_granian_server(
            host="0.0.0.0",
            port=4000,
            num_workers=1,
            ssl_certfile_path=None,
            ssl_keyfile_path=None,
            max_requests_before_restart=None,
            ciphers=None,
            granian_runtime_threads=4,
        )

        assert mock_granian_cls.call_args.kwargs["runtime_threads"] == 4
def test_init_granian_server_ssl(self):
    """Certificate and key paths reach Granian as ``Path`` objects.

    The patches are applied inside the body, after ``pytest.importorskip``:
    a ``@patch("granian.Granian")`` decorator would import ``granian`` before
    the skip check runs and error out when the package is not installed.
    """
    pytest.importorskip("granian")

    fake_interfaces = SimpleNamespace(ASGI="asgi")
    with (
        patch("granian.Granian") as mock_granian_cls,
        patch("builtins.print"),
        patch("granian.constants.Interfaces", fake_interfaces),
    ):
        mock_server = MagicMock()
        mock_granian_cls.return_value = mock_server

        ProxyInitializationHelpers._init_granian_server(
            host="0.0.0.0",
            port=4000,
            num_workers=1,
            ssl_certfile_path="/path/to/cert.pem",
            ssl_keyfile_path="/path/to/key.pem",
            max_requests_before_restart=None,
            ciphers=None,
            granian_runtime_threads=None,
        )

        call_kwargs = mock_granian_cls.call_args.kwargs
        assert call_kwargs["ssl_cert"] == Path("/path/to/cert.pem")
        assert call_kwargs["ssl_key"] == Path("/path/to/key.pem")
        mock_server.serve.assert_called_once()
def test_init_granian_server_ssl_requires_cert_and_key(self):
    """A certificate without a key raises ``ClickException`` before Granian is built.

    The patches are applied inside the body, after ``pytest.importorskip``:
    a ``@patch("granian.Granian")`` decorator would import ``granian`` before
    the skip check runs and error out when the package is not installed.
    """
    pytest.importorskip("granian")

    fake_interfaces = SimpleNamespace(ASGI="asgi")
    with (
        patch("granian.Granian") as mock_granian_cls,
        patch("granian.constants.Interfaces", fake_interfaces),
    ):
        with pytest.raises(click.ClickException, match="Both --ssl_certfile_path"):
            ProxyInitializationHelpers._init_granian_server(
                host="0.0.0.0",
                port=4000,
                num_workers=1,
                ssl_certfile_path="/path/to/cert.pem",
                ssl_keyfile_path=None,
                max_requests_before_restart=None,
                ciphers=None,
                granian_runtime_threads=None,
            )

        # Validation failed, so no server object may have been created.
        mock_granian_cls.assert_not_called()
@patch("subprocess.Popen")
def test_run_ollama_serve(self, mock_popen):
    """``_run_ollama_serve`` spawns one subprocess and tolerates a failing spawn."""
    start_ollama = ProxyInitializationHelpers._run_ollama_serve

    # Normal path: exactly one Popen call.
    start_ollama()
    mock_popen.assert_called_once()

    # Failure path: an exception from Popen must not propagate.
    mock_popen.side_effect = Exception("Test exception")
    start_ollama()  # Should not raise
@patch("socket.socket")
def test_is_port_in_use(self, mock_socket):
    """``_is_port_in_use`` is True when ``connect_ex`` returns 0, False otherwise."""
    # The helper uses the socket as a context manager, so stub __enter__.
    fake_sock = MagicMock()
    mock_socket.return_value.__enter__.return_value = fake_sock

    # connect_ex() == 0 -> the connection succeeded, the port is taken.
    fake_sock.connect_ex.return_value = 0
    assert ProxyInitializationHelpers._is_port_in_use(8000) is True

    # Non-zero result -> nothing is listening, the port is free.
    fake_sock.connect_ex.return_value = 1
    assert ProxyInitializationHelpers._is_port_in_use(8000) is False
def test_get_loop_type(self):
    """The loop type is ``None`` on Windows and ``"uvloop"`` on Linux."""
    pick_loop = ProxyInitializationHelpers._get_loop_type

    # Windows
    with patch("sys.platform", "win32"):
        windows_loop = pick_loop()
        assert windows_loop is None

    # Linux
    with patch("sys.platform", "linux"):
        linux_loop = pick_loop()
        assert linux_loop == "uvloop"
@patch.dict(os.environ, {}, clear=True)
def test_database_url_construction_with_special_characters(self):
    """URL-encode DATABASE_* parts into a postgres URL, then append query params.

    This mirrors, rather than calls, the URL construction that run_server
    performs when database_url is None.
    """
    import urllib.parse

    from litellm.proxy.proxy_cli import append_query_params

    # Credentials and a database name containing characters that need escaping.
    db_env = {
        "DATABASE_HOST": "localhost:5432",
        "DATABASE_USERNAME": "user@with+special",
        "DATABASE_PASSWORD": "test-password-special-chars",
        "DATABASE_NAME": "db_name/test",
    }
    with patch.dict(os.environ, db_env):
        encode = urllib.parse.quote_plus
        host = os.environ["DATABASE_HOST"]
        user_enc = encode(os.environ["DATABASE_USERNAME"])
        password_enc = encode(os.environ["DATABASE_PASSWORD"])
        name_enc = encode(os.environ["DATABASE_NAME"])

        # Assemble the URL from the encoded pieces; the host is used verbatim.
        database_url = f"postgresql://{user_enc}:{password_enc}@{host}/{name_enc}"
        expected_url = "postgresql://user%40with%2Bspecial:test-password-special-chars@localhost:5432/db_name%2Ftest"
        assert database_url == expected_url

        # Pool settings are appended as query parameters.
        pool_params = {"connection_limit": 10, "pool_timeout": 60}
        url_with_params = append_query_params(database_url, pool_params)
        assert "connection_limit=10" in url_with_params
        assert "pool_timeout=60" in url_with_params
def test_append_query_params_handles_missing_url(self):
    """``append_query_params`` returns an empty string when the URL is ``None``."""
    from litellm.proxy.proxy_cli import append_query_params

    assert append_query_params(None, {"connection_limit": 10}) == ""
@patch("uvicorn.run")
@patch("atexit.register")  # critical
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
@patch(
    "litellm.proxy.db.prisma_client.should_update_prisma_schema", return_value=False
)
def test_skip_server_startup(
    self, mock_should_update, mock_setup_db, mock_atexit_register, mock_uvicorn_run
):
    """``--skip_server_startup`` exits cleanly without calling ``uvicorn.run``."""
    from click.testing import CliRunner

    from litellm.proxy.proxy_cli import run_server

    cli = CliRunner()
    stub_proxy_server = MagicMock(
        app=MagicMock(),
        ProxyConfig=MagicMock(),
        KeyManagementSettings=MagicMock(),
        save_worker_config=MagicMock(),
    )

    # Remove DATABASE_URL/DIRECT_URL so the CLI doesn't attempt
    # real prisma operations when these are set in CI.
    env_without_db = {
        name: value
        for name, value in os.environ.items()
        if name not in ("DATABASE_URL", "DIRECT_URL")
    }
    stubbed_modules = {
        "proxy_server": stub_proxy_server,
        # Prevent real import of proxy_server inside Click's
        # isolation context (heavy side effects cause stream
        # lifecycle issues with Click 8.2+)
        "litellm.proxy.proxy_server": stub_proxy_server,
    }
    init_args_target = "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"

    with (
        patch.dict(os.environ, env_without_db, clear=True),
        patch.dict("sys.modules", stubbed_modules),
        patch(init_args_target) as mock_get_args,
    ):
        mock_get_args.return_value = {
            "app": "litellm.proxy.proxy_server:app",
            "host": "localhost",
            "port": 8000,
        }

        # --- skip startup ---
        skipped = cli.invoke(run_server, ["--local", "--skip_server_startup"])
        assert (
            skipped.exit_code == 0
        ), f"exit_code={skipped.exit_code}, output={skipped.output}"
        assert "Skipping server startup" in skipped.output
        mock_uvicorn_run.assert_not_called()

        # --- normal startup ---
        mock_uvicorn_run.reset_mock()
        started = cli.invoke(run_server, ["--local"])
        assert (
            started.exit_code == 0
        ), f"exit_code={started.exit_code}, output={started.output}"
        mock_uvicorn_run.assert_called_once()
@pytest.mark.parametrize(
    "timeout_config,expected_timeout",
    [
        ({"database_connection_timeout": 30}, 30),
        ({"database_connection_pool_timeout": 45}, 45),
        (
            {
                "database_connection_timeout": 30,
                "database_connection_pool_timeout": 45,
            },
            30,
        ),
    ],
)
@patch("subprocess.run")
@patch("atexit.register")
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
@patch(
    "litellm.proxy.db.prisma_client.should_update_prisma_schema", return_value=False
)
def test_db_timeout_settings_are_forwarded_to_pool_timeout(
    self,
    mock_should_update,
    mock_setup_db,
    mock_atexit_register,
    mock_subprocess_run,
    timeout_config,
    expected_timeout,
):
    """Either DB timeout setting ends up as ``pool_timeout`` on the database URL.

    When both settings are present, the parametrized expectation is that
    ``database_connection_timeout`` wins.
    """
    from click.testing import CliRunner

    from litellm.proxy.proxy_cli import run_server

    cli = CliRunner()
    mock_subprocess_run.return_value = MagicMock(returncode=0)

    # Stub proxy_server whose config carries the DB settings under test.
    stub_proxy_server = MagicMock(
        app=MagicMock(),
        ProxyConfig=MagicMock(),
        KeyManagementSettings=MagicMock(),
        save_worker_config=MagicMock(),
    )
    general_settings = {
        "database_url": "postgresql://test:test@localhost:5432/test",
        "database_connection_pool_limit": 5,
        **timeout_config,
    }
    stub_proxy_server.ProxyConfig.return_value.get_config = AsyncMock(
        return_value={"general_settings": general_settings}
    )

    # Keep DATABASE_URL/DIRECT_URL from the outer environment out of the run.
    env_without_db = {
        name: value
        for name, value in os.environ.items()
        if name not in ("DATABASE_URL", "DIRECT_URL")
    }
    stubbed_modules = {
        "proxy_server": stub_proxy_server,
        "litellm.proxy.proxy_server": stub_proxy_server,
    }
    init_args_target = "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"

    with (
        patch.dict(os.environ, env_without_db, clear=True),
        patch.dict("sys.modules", stubbed_modules),
        patch(init_args_target) as mock_get_args,
        patch(
            "litellm.proxy.proxy_cli.append_query_params",
            side_effect=lambda url, params: (
                f"{url}?connection_limit={params['connection_limit']}&pool_timeout={params['pool_timeout']}"
            ),
        ) as mock_append_query_params,
    ):
        mock_get_args.return_value = {
            "app": "litellm.proxy.proxy_server:app",
            "host": "localhost",
            "port": 8000,
        }
        outcome = cli.invoke(
            run_server,
            ["--local", "--config", "test-config.yaml", "--skip_server_startup"],
        )
        assert (
            outcome.exit_code == 0
        ), f"exit_code={outcome.exit_code}, output={outcome.output}"

        # Inspect the params dict given to the last append_query_params call.
        mock_append_query_params.assert_called()
        forwarded = mock_append_query_params.call_args.args[1]
        assert forwarded["connection_limit"] == 5
        assert forwarded["pool_timeout"] == expected_timeout
def test_build_db_connection_url_params_defaults(self):
    """With only the two required values, exactly those two keys come back."""
    from litellm.proxy.proxy_cli import _build_db_connection_url_params

    built = _build_db_connection_url_params(connection_limit=10, pool_timeout=60)
    expected = {"connection_limit": 10, "pool_timeout": 60}
    assert built == expected
def test_build_db_connection_url_params_omits_none_timeouts(self):
    """Optional timeouts passed as ``None`` are left out of the result."""
    from litellm.proxy.proxy_cli import _build_db_connection_url_params

    built = _build_db_connection_url_params(
        connection_limit=10,
        pool_timeout=60,
        connect_timeout=None,
        socket_timeout=None,
    )

    for omitted_key in ("connect_timeout", "socket_timeout"):
        assert omitted_key not in built
def test_build_db_connection_url_params_includes_optional_timeouts(self):
    """Optional timeouts with real values appear in the result unchanged."""
    from litellm.proxy.proxy_cli import _build_db_connection_url_params

    built = _build_db_connection_url_params(
        connection_limit=10,
        pool_timeout=60,
        connect_timeout=15,
        socket_timeout=120,
    )

    for key, expected_value in (("connect_timeout", 15), ("socket_timeout", 120)):
        assert built[key] == expected_value
def test_build_db_connection_url_params_extras_override_defaults(self):
    """``extra_params`` are merged in and win over the named arguments."""
    from litellm.proxy.proxy_cli import _build_db_connection_url_params

    extras = {
        "pgbouncer": "true",
        "statement_cache_size": 0,
        "pool_timeout": 5,
    }
    built = _build_db_connection_url_params(
        connection_limit=10,
        pool_timeout=60,
        extra_params=extras,
    )

    assert built["pgbouncer"] == "true"
    assert built["statement_cache_size"] == 0
    # The extra value (5) replaces the pool_timeout argument (60).
    assert built["pool_timeout"] == 5
@patch("subprocess.run")
@patch("atexit.register")
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
@patch(
    "litellm.proxy.db.prisma_client.should_update_prisma_schema", return_value=False
)
def test_db_connection_extra_params_forwarded_to_url(
    self,
    mock_should_update,
    mock_setup_db,
    mock_atexit_register,
    mock_subprocess_run,
):
    """Connect/socket timeouts and extra connection params reach ``append_query_params``."""
    from click.testing import CliRunner

    from litellm.proxy.proxy_cli import run_server

    cli = CliRunner()
    mock_subprocess_run.return_value = MagicMock(returncode=0)

    # Stub proxy_server whose config carries the DB settings under test.
    stub_proxy_server = MagicMock(
        app=MagicMock(),
        ProxyConfig=MagicMock(),
        KeyManagementSettings=MagicMock(),
        save_worker_config=MagicMock(),
    )
    general_settings = {
        "database_url": "postgresql://test:test@localhost:5432/test",
        "database_connect_timeout": 15,
        "database_socket_timeout": 120,
        "database_extra_connection_params": {
            "pgbouncer": "true",
            "statement_cache_size": 0,
        },
    }
    stub_proxy_server.ProxyConfig.return_value.get_config = AsyncMock(
        return_value={"general_settings": general_settings}
    )

    # Keep DATABASE_URL/DIRECT_URL from the outer environment out of the run.
    env_without_db = {
        name: value
        for name, value in os.environ.items()
        if name not in ("DATABASE_URL", "DIRECT_URL")
    }
    stubbed_modules = {
        "proxy_server": stub_proxy_server,
        "litellm.proxy.proxy_server": stub_proxy_server,
    }
    init_args_target = "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"

    with (
        patch.dict(os.environ, env_without_db, clear=True),
        patch.dict("sys.modules", stubbed_modules),
        patch(init_args_target) as mock_get_args,
        patch(
            "litellm.proxy.proxy_cli.append_query_params",
            side_effect=lambda url, params: str(url),
        ) as mock_append_query_params,
    ):
        mock_get_args.return_value = {
            "app": "litellm.proxy.proxy_server:app",
            "host": "localhost",
            "port": 8000,
        }
        outcome = cli.invoke(
            run_server,
            ["--local", "--config", "test-config.yaml", "--skip_server_startup"],
        )
        assert (
            outcome.exit_code == 0
        ), f"exit_code={outcome.exit_code}, output={outcome.output}"

        # Inspect the params dict given to the last append_query_params call.
        mock_append_query_params.assert_called()
        forwarded = mock_append_query_params.call_args.args[1]
        assert forwarded["connect_timeout"] == 15
        assert forwarded["socket_timeout"] == 120
        assert forwarded["pgbouncer"] == "true"
        assert forwarded["statement_cache_size"] == 0
@patch("uvicorn.run")
@patch("atexit.register")
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
@patch(
    "litellm.proxy.db.prisma_client.should_update_prisma_schema", return_value=False
)
def test_proxy_default_api_version_uses_azure_default(
    self, mock_should_update, mock_setup_db, mock_atexit_register, mock_uvicorn_run
):
    """Proxy default api_version should match litellm.AZURE_DEFAULT_API_VERSION for consistency."""
    from click.testing import CliRunner

    import litellm
    from litellm.proxy.proxy_cli import run_server

    cli = CliRunner()
    stub_proxy_server = MagicMock(
        app=MagicMock(),
        ProxyConfig=MagicMock(),
        KeyManagementSettings=MagicMock(),
        save_worker_config=MagicMock(),
    )

    # Keep DATABASE_URL/DIRECT_URL from the outer environment out of the run.
    env_without_db = {
        name: value
        for name, value in os.environ.items()
        if name not in ("DATABASE_URL", "DIRECT_URL")
    }
    stubbed_modules = {
        "proxy_server": stub_proxy_server,
        "litellm.proxy.proxy_server": stub_proxy_server,
    }
    init_args_target = "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"

    with (
        patch.dict(os.environ, env_without_db, clear=True),
        patch.dict("sys.modules", stubbed_modules),
        patch(init_args_target) as mock_get_args,
    ):
        mock_get_args.return_value = {
            "app": "litellm.proxy.proxy_server:app",
            "host": "localhost",
            "port": 8000,
        }
        outcome = cli.invoke(run_server, ["--local", "--skip_server_startup"])
        assert (
            outcome.exit_code == 0
        ), f"exit_code={outcome.exit_code}, output={outcome.output}"

        # No --api_version was given, so the saved worker config must carry
        # the library-wide Azure default.
        save_config = stub_proxy_server.save_worker_config
        save_config.assert_called_once()
        saved_kwargs = save_config.call_args.kwargs
        assert saved_kwargs["api_version"] == litellm.AZURE_DEFAULT_API_VERSION
@patch("uvicorn.run")
@patch("builtins.print")
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
@patch(
    "litellm.proxy.db.prisma_client.should_update_prisma_schema", return_value=False
)
def test_keepalive_timeout_flag(
    self, mock_should_update, mock_setup_db, mock_print, mock_uvicorn_run
):
    """Test that the keepalive_timeout flag is properly passed to uvicorn"""
    from click.testing import CliRunner

    from litellm.proxy.proxy_cli import run_server

    runner = CliRunner()
    mock_app = MagicMock()
    mock_proxy_config = MagicMock()
    mock_key_mgmt = MagicMock()
    mock_save_worker_config = MagicMock()
    # Strip DATABASE_URL/DIRECT_URL so run_server doesn't enter the prisma
    # DB-setup block (un-timeout'd `subprocess.run(["prisma"])` +
    # migrate-deploy retry loop) — same isolation every other run_server
    # test in this file uses.
    clean_env = {
        k: v
        for k, v in os.environ.items()
        if k not in ("DATABASE_URL", "DIRECT_URL")
    }
    with (
        patch.dict(os.environ, clean_env, clear=True),
        patch.dict(
            "sys.modules",
            {
                "proxy_server": MagicMock(
                    app=mock_app,
                    ProxyConfig=mock_proxy_config,
                    KeyManagementSettings=mock_key_mgmt,
                    save_worker_config=mock_save_worker_config,
                )
            },
        ),
        patch(
            "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
        ) as mock_get_args,
        patch(
            "litellm.proxy.proxy_cli.ProxyInitializationHelpers._is_port_in_use",
            return_value=False,
        ),
    ):
        mock_get_args.return_value = {
            "app": "litellm.proxy.proxy_server:app",
            "host": "localhost",
            "port": 8000,
            "timeout_keep_alive": 30,
        }
        result = runner.invoke(run_server, ["--local", "--keepalive_timeout", "30"])
        # Include the CLI output in the failure message, like the sibling
        # run_server tests, so a non-zero exit is diagnosable from the report.
        assert (
            result.exit_code == 0
        ), f"exit_code={result.exit_code}, output={result.output}"
        # The flag value must reach the init-args helper alongside the CLI defaults.
        mock_get_args.assert_called_once_with(
            host="0.0.0.0",
            port=4000,
            log_config=None,
            keepalive_timeout=30,
            timeout_worker_healthcheck=None,
        )
        mock_uvicorn_run.assert_called_once()
        # Check that the uvicorn.run was called with the timeout_keep_alive parameter
        call_args = mock_uvicorn_run.call_args
        assert call_args[1]["timeout_keep_alive"] == 30
@patch("uvicorn.run")
@patch("builtins.print")
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
@patch(
    "litellm.proxy.db.prisma_client.should_update_prisma_schema", return_value=False
)
def test_timeout_worker_healthcheck_flag(
    self, mock_should_update, mock_setup_db, mock_print, mock_uvicorn_run
):
    """Test that the --timeout_worker_healthcheck flag is threaded through to the uvicorn init helper."""
    from click.testing import CliRunner

    from litellm.proxy.proxy_cli import run_server

    runner = CliRunner()
    mock_app = MagicMock()
    mock_proxy_config = MagicMock()
    mock_key_mgmt = MagicMock()
    mock_save_worker_config = MagicMock()
    # Strip DATABASE_URL/DIRECT_URL so run_server doesn't enter the prisma
    # DB-setup block (un-timeout'd `subprocess.run(["prisma"])` +
    # migrate-deploy retry loop) — same isolation every other run_server
    # test in this file uses.
    clean_env = {
        k: v
        for k, v in os.environ.items()
        if k not in ("DATABASE_URL", "DIRECT_URL")
    }
    with (
        patch.dict(os.environ, clean_env, clear=True),
        patch.dict(
            "sys.modules",
            {
                "proxy_server": MagicMock(
                    app=mock_app,
                    ProxyConfig=mock_proxy_config,
                    KeyManagementSettings=mock_key_mgmt,
                    save_worker_config=mock_save_worker_config,
                )
            },
        ),
        patch(
            "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
        ) as mock_get_args,
        patch(
            "litellm.proxy.proxy_cli.ProxyInitializationHelpers._is_port_in_use",
            return_value=False,
        ),
    ):
        mock_get_args.return_value = {
            "app": "litellm.proxy.proxy_server:app",
            "host": "localhost",
            "port": 8000,
        }
        result = runner.invoke(
            run_server, ["--local", "--timeout_worker_healthcheck", "15"]
        )
        # Include the CLI output in the failure message, like the sibling
        # run_server tests, so a non-zero exit is diagnosable from the report.
        assert (
            result.exit_code == 0
        ), f"exit_code={result.exit_code}, output={result.output}"
        # The flag value must reach the init-args helper alongside the CLI defaults.
        mock_get_args.assert_called_once_with(
            host="0.0.0.0",
            port=4000,
            log_config=None,
            keepalive_timeout=None,
            timeout_worker_healthcheck=15,
        )
@patch("uvicorn.run")
@patch("builtins.print")
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
def test_max_requests_before_restart_flag(
    self, mock_setup_db, mock_print, mock_uvicorn_run
):
    """Test that the max_requests_before_restart flag is passed to uvicorn as limit_max_requests"""
    from click.testing import CliRunner

    from litellm.proxy.proxy_cli import run_server

    cli = CliRunner()
    stub_proxy_server = MagicMock(
        app=MagicMock(),
        ProxyConfig=MagicMock(),
        KeyManagementSettings=MagicMock(),
        save_worker_config=MagicMock(),
    )

    # Keep DATABASE_URL/DIRECT_URL from the outer environment out of the run.
    env_without_db = {
        name: value
        for name, value in os.environ.items()
        if name not in ("DATABASE_URL", "DIRECT_URL")
    }
    init_args_target = "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"

    with (
        patch.dict(os.environ, env_without_db, clear=True),
        patch.dict("sys.modules", {"proxy_server": stub_proxy_server}),
        patch(init_args_target) as mock_get_args,
    ):
        mock_get_args.return_value = {
            "app": "litellm.proxy.proxy_server:app",
            "host": "localhost",
            "port": 8000,
        }
        outcome = cli.invoke(
            run_server, ["--local", "--max_requests_before_restart", "123"]
        )
        assert (
            outcome.exit_code == 0
        ), f"exit_code={outcome.exit_code}, output={outcome.output}"

        # The CLI value must arrive at uvicorn.run as limit_max_requests.
        mock_uvicorn_run.assert_called_once()
        uvicorn_kwargs = mock_uvicorn_run.call_args.kwargs
        assert uvicorn_kwargs["limit_max_requests"] == 123
@patch.dict(os.environ, {}, clear=True)
def test_construct_database_url_from_env_vars(self):
    """Test the construct_database_url_from_env_vars function with various scenarios"""
    from litellm.proxy.utils import construct_database_url_from_env_vars

    # (environment, expected URL) pairs for the cases that must yield a URL.
    url_cases = [
        # All required variables present.
        (
            {
                "DATABASE_HOST": "localhost:5432",
                "DATABASE_USERNAME": "testuser",
                "DATABASE_PASSWORD": "testpass",
                "DATABASE_NAME": "testdb",
            },
            "postgresql://testuser:testpass@localhost:5432/testdb",
        ),
        # Special characters that need URL encoding.
        (
            {
                "DATABASE_HOST": "localhost:5432",
                "DATABASE_USERNAME": "user@with+special",
                "DATABASE_PASSWORD": "test-password-special-chars",
                "DATABASE_NAME": "db_name/test",
            },
            "postgresql://user%40with%2Bspecial:test-password-special-chars@localhost:5432/db_name%2Ftest",
        ),
        # No password (should still work).
        (
            {
                "DATABASE_HOST": "localhost:5432",
                "DATABASE_USERNAME": "testuser",
                "DATABASE_NAME": "testdb",
            },
            "postgresql://testuser@localhost:5432/testdb",
        ),
    ]
    for env_vars, expected_url in url_cases:
        with patch.dict(os.environ, env_vars):
            built_url = construct_database_url_from_env_vars()
        assert built_url == expected_url

    # Missing required variable (DATABASE_NAME): no URL can be built.
    incomplete_env = {
        "DATABASE_HOST": "localhost:5432",
        "DATABASE_USERNAME": "testuser",
    }
    with patch.dict(os.environ, incomplete_env):
        built_url = construct_database_url_from_env_vars()
    assert built_url is None

    # Empty environment: no URL either.
    with patch.dict(os.environ, {}, clear=True):
        built_url = construct_database_url_from_env_vars()
    assert built_url is None
@patch("uvicorn.run")
@patch("builtins.print")
def test_run_server_no_config_passed(self, mock_print, mock_uvicorn_run):
    """Test that run_server properly handles the case when no config is passed.

    Two invocations are checked: one that omits ``--config`` and one that
    passes the literal string ``"None"``. Each must exit with code 0 and
    call ``uvicorn.run`` exactly once.

    Args:
        mock_print: Patched ``builtins.print``; not inspected by the test.
        mock_uvicorn_run: Patched ``uvicorn.run`` so no real server starts.
    """
    from click.testing import CliRunner

    from litellm.proxy.proxy_cli import run_server

    runner = CliRunner()

    mock_app = MagicMock()
    mock_proxy_config = MagicMock()
    mock_key_mgmt = MagicMock()
    mock_save_worker_config = MagicMock()

    # Mock the ProxyConfig.get_config method to return a proper async config
    async def mock_get_config(config_file_path=None):
        return {"general_settings": {}, "litellm_settings": {}}

    mock_proxy_config_instance = MagicMock()
    mock_proxy_config_instance.get_config = mock_get_config
    mock_proxy_config.return_value = mock_proxy_config_instance

    mock_proxy_server_module = MagicMock(app=mock_app)

    # Only remove DATABASE_URL and DIRECT_URL to prevent the database setup
    # code path from running. Do NOT use clear=True as it removes PATH, HOME,
    # etc., which causes imports inside run_server to break in CI (the real
    # litellm.proxy.proxy_server import inside proxy_cli.py has heavy side
    # effects that fail without a proper environment).
    env_overrides = {
        "DATABASE_URL": "",
        "DIRECT_URL": "",
        "IAM_TOKEN_DB_AUTH": "",
        "USE_AWS_KMS": "",
    }

    with patch.dict(os.environ, env_overrides):
        # Remove DATABASE_URL entirely so the DB setup block is skipped.
        # patch.dict restores the original environment when the block exits.
        os.environ.pop("DATABASE_URL", None)
        os.environ.pop("DIRECT_URL", None)

        with (
            patch.dict(
                "sys.modules",
                {
                    "proxy_server": MagicMock(
                        app=mock_app,
                        ProxyConfig=mock_proxy_config,
                        KeyManagementSettings=mock_key_mgmt,
                        save_worker_config=mock_save_worker_config,
                    ),
                    # Also mock litellm.proxy.proxy_server to prevent the real
                    # import inside proxy_cli.py, which has heavy side effects
                    # (FastAPI app init, logging setup, etc.)
                    "litellm.proxy.proxy_server": mock_proxy_server_module,
                },
            ),
            patch(
                "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
            ) as mock_get_args,
        ):
            mock_get_args.return_value = {
                "app": "litellm.proxy.proxy_server:app",
                "host": "localhost",
                "port": 8000,
            }

            # First with no config parameter at all, then with an explicit
            # "--config None"; both should behave the same.
            for cli_args in (["--local"], ["--local", "--config", "None"]):
                # Start each invocation from a clean call count.
                mock_uvicorn_run.reset_mock()

                result = runner.invoke(run_server, cli_args)

                assert result.exit_code == 0, (
                    f"run_server failed with exit_code={result.exit_code}, "
                    f"output={result.output}, exception={result.exception}"
                )

                # Verify that uvicorn.run was called for this invocation
                mock_uvicorn_run.assert_called_once()
class TestRunServerDbSetup:
    """Checks how run_server drives PrismaManager.setup_database."""

    # Dotted path of the helper that supplies uvicorn's default init args.
    _UVICORN_ARGS_TARGET = "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
    # Connection string placed in DATABASE_URL for the duration of each test.
    _TEST_DATABASE_URL = "postgresql://test:test@localhost:5432/test"

    @staticmethod
    def _fake_proxy_module():
        """Build the MagicMock that replaces the proxy_server module in sys.modules."""
        return MagicMock(
            app=MagicMock(),
            ProxyConfig=MagicMock(),
            KeyManagementSettings=MagicMock(),
            save_worker_config=MagicMock(),
        )

    @classmethod
    def _env_with_test_database(cls):
        """Copy os.environ, drop DIRECT_URL, and point DATABASE_URL at the test value."""
        env = dict(os.environ)
        env.pop("DIRECT_URL", None)
        env["DATABASE_URL"] = cls._TEST_DATABASE_URL
        return env

    @patch("subprocess.run")
    @patch("atexit.register")
    @patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
    @patch("litellm.proxy.db.check_migration.check_prisma_schema_diff")
    @patch("litellm.proxy.db.prisma_client.should_update_prisma_schema")
    def test_use_prisma_db_push_flag_behavior(
        self,
        mock_should_update_schema,
        mock_check_schema_diff,
        mock_setup_database,
        mock_atexit_register,
        mock_subprocess_run,
    ):
        """--use_prisma_db_push decides the use_migrate argument passed to setup_database."""
        from litellm.proxy.proxy_cli import run_server

        # subprocess.run reports success, simulating prisma being available.
        mock_subprocess_run.return_value = MagicMock(returncode=0)
        # A truthy should_update_prisma_schema lets setup_database be reached.
        mock_should_update_schema.return_value = True

        fake_module = self._fake_proxy_module()

        def run_and_expect(extra_flags, expected_use_migrate):
            """Invoke run_server and check the most recent setup_database call."""
            # standalone_mode=False bypasses Click's CliRunner stream isolation,
            # which causes flaky "I/O operation on closed file" errors in CI
            # environments (Click 8.3.x stream lifecycle issue).
            run_server.main(
                ["--local", "--skip_server_startup", *extra_flags],
                standalone_mode=False,
            )
            mock_setup_database.assert_called_with(
                use_migrate=expected_use_migrate, use_v2_resolver=False
            )

        with (
            patch.dict(os.environ, self._env_with_test_database(), clear=True),
            patch.dict(
                "sys.modules",
                {
                    "proxy_server": fake_module,
                    "litellm.proxy.proxy_server": fake_module,
                },
            ),
            patch(self._UVICORN_ARGS_TARGET) as mock_get_args,
        ):
            mock_get_args.return_value = {
                "app": "litellm.proxy.proxy_server:app",
                "host": "localhost",
                "port": 8000,
            }

            # Default: flag absent -> use_prisma_db_push False -> use_migrate True.
            run_and_expect([], True)

            # Clear recorded calls before the second scenario.
            mock_setup_database.reset_mock()
            mock_should_update_schema.reset_mock()
            mock_should_update_schema.return_value = True

            # Flag present -> use_prisma_db_push True -> use_migrate False.
            run_and_expect(["--use_prisma_db_push"], False)

    @patch("subprocess.run")
    @patch("atexit.register")
    @patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
    @patch("litellm.proxy.db.check_migration.check_prisma_schema_diff")
    @patch("litellm.proxy.db.prisma_client.should_update_prisma_schema")
    def test_startup_fails_when_db_setup_fails(
        self,
        mock_should_update_schema,
        mock_check_schema_diff,
        mock_setup_database,
        mock_atexit_register,
        mock_subprocess_run,
    ):
        """With --enforce_prisma_migration_check, a False setup_database result exits with code 1."""
        from litellm.proxy.proxy_cli import run_server

        mock_subprocess_run.return_value = MagicMock(returncode=0)
        mock_should_update_schema.return_value = True
        # Simulate the database setup reporting failure.
        mock_setup_database.return_value = False

        fake_module = self._fake_proxy_module()
        failing_argv = [
            "--local",
            "--skip_server_startup",
            "--enforce_prisma_migration_check",
        ]

        with (
            patch.dict(os.environ, self._env_with_test_database(), clear=True),
            patch.dict(
                "sys.modules",
                {
                    "proxy_server": fake_module,
                    "litellm.proxy.proxy_server": fake_module,
                },
            ),
            patch(self._UVICORN_ARGS_TARGET) as mock_get_args,
        ):
            mock_get_args.return_value = {
                "app": "litellm.proxy.proxy_server:app",
                "host": "localhost",
                "port": 8000,
            }

            with pytest.raises(SystemExit) as exit_info:
                run_server.main(failing_argv, standalone_mode=False)

            assert exit_info.value.code == 1
            mock_setup_database.assert_called_once_with(
                use_migrate=True, use_v2_resolver=False
            )
# --- Module-level helpers for worker startup hook tests ---
# Flipped to True by _dummy_hook; tests reset it to False before each run.
_dummy_hook_called: bool = False


def _dummy_hook() -> None:
    """Synchronous hook stub for test_should_run_worker_startup_hooks; records that it ran."""
    global _dummy_hook_called
    _dummy_hook_called = True
# Flipped to True by _dummy_async_hook; tests reset it to False before each run.
_dummy_async_hook_called: bool = False


async def _dummy_async_hook() -> None:
    """Coroutine hook stub for test_should_run_async_worker_startup_hook; records that it was awaited."""
    global _dummy_async_hook_called
    _dummy_async_hook_called = True
def _failing_hook() -> None:
    """Hook stub that unconditionally raises; used by test_should_raise_on_failing_worker_startup_hook."""
    failure_message = "Hook failed on purpose"
    raise RuntimeError(failure_message)
class TestWorkerStartupHooks:
"""Tests for the LITELLM_WORKER_STARTUP_HOOKS mechanism in proxy_startup_event."""
@pytest.mark.asyncio
async def test_should_run_worker_startup_hooks(self):
    """proxy_startup_event invokes a synchronous hook named in LITELLM_WORKER_STARTUP_HOOKS."""
    global _dummy_hook_called
    _dummy_hook_called = False

    from litellm.proxy.proxy_server import proxy_startup_event

    # Current environment minus the database URLs (avoids real DB setup),
    # plus the hook under test.
    hook_env = {
        name: value
        for name, value in os.environ.items()
        if name != "DATABASE_URL" and name != "DIRECT_URL"
    }
    hook_env["LITELLM_WORKER_STARTUP_HOOKS"] = (
        "tests.test_litellm.proxy.test_proxy_cli:_dummy_hook"
    )

    with patch.dict(os.environ, hook_env, clear=True):
        try:
            async with proxy_startup_event(app=None):
                pass
        except Exception:
            # Failures after the hook has run (no DB, etc.) are expected here.
            pass

    assert _dummy_hook_called is True, "Sync startup hook was not called"
@pytest.mark.asyncio
async def test_should_run_async_worker_startup_hook(self):
    """proxy_startup_event awaits a coroutine hook named in LITELLM_WORKER_STARTUP_HOOKS."""
    global _dummy_async_hook_called
    _dummy_async_hook_called = False

    from litellm.proxy.proxy_server import proxy_startup_event

    # Current environment minus the database URLs, plus the async hook.
    hook_env = dict(os.environ)
    for db_var in ("DATABASE_URL", "DIRECT_URL"):
        hook_env.pop(db_var, None)
    hook_env["LITELLM_WORKER_STARTUP_HOOKS"] = (
        "tests.test_litellm.proxy.test_proxy_cli:_dummy_async_hook"
    )

    with patch.dict(os.environ, hook_env, clear=True):
        try:
            async with proxy_startup_event(app=None):
                pass
        except Exception:
            # Any exception raised during startup is ignored; only the flag matters.
            pass

    assert _dummy_async_hook_called is True, "Async startup hook was not called"
@pytest.mark.asyncio
async def test_should_raise_on_failing_worker_startup_hook(self):
    """An exception raised by a startup hook escapes proxy_startup_event."""
    from litellm.proxy.proxy_server import proxy_startup_event

    # Current environment minus the database URLs, plus the always-failing hook.
    hook_env = dict(os.environ)
    for db_var in ("DATABASE_URL", "DIRECT_URL"):
        hook_env.pop(db_var, None)
    hook_env["LITELLM_WORKER_STARTUP_HOOKS"] = (
        "tests.test_litellm.proxy.test_proxy_cli:_failing_hook"
    )

    with (
        patch.dict(os.environ, hook_env, clear=True),
        pytest.raises(RuntimeError, match="Hook failed on purpose"),
    ):
        async with proxy_startup_event(app=None):
            pass
@pytest.mark.asyncio
async def test_should_skip_when_no_hooks_set(self):
    """When LITELLM_WORKER_STARTUP_HOOKS is not set, no hooks are executed.

    The previous version only asserted facts about ``os.environ`` that the
    test itself had arranged, so it could never fail. This version runs
    ``proxy_startup_event`` with the variable absent and checks that
    neither dummy hook flag was flipped.
    """
    global _dummy_hook_called, _dummy_async_hook_called
    _dummy_hook_called = False
    _dummy_async_hook_called = False

    from litellm.proxy.proxy_server import proxy_startup_event

    # Same environment preparation as the sibling hook tests (database URLs
    # removed to avoid real DB setup), with the hooks variable removed too.
    excluded = ("DATABASE_URL", "DIRECT_URL", "LITELLM_WORKER_STARTUP_HOOKS")
    clean_env = {k: v for k, v in os.environ.items() if k not in excluded}

    with patch.dict(os.environ, clean_env, clear=True):
        # Precondition: the variable really is absent while startup runs.
        assert "LITELLM_WORKER_STARTUP_HOOKS" not in os.environ
        try:
            async with proxy_startup_event(app=None) as _:
                pass
        except Exception:
            # Unrelated startup failures (no DB, etc.) are tolerated, as in
            # the sibling tests; only the hook flags are under test.
            pass

    # NOTE(review): an empty-string value is not exercised here because how
    # proxy_startup_event treats "" is not visible from this file — confirm
    # and add a case if needed.
    assert _dummy_hook_called is False, "Sync hook ran without being configured"
    assert (
        _dummy_async_hook_called is False
    ), "Async hook ran without being configured"
@pytest.mark.asyncio
async def test_should_run_multiple_hooks(self):
    """Every hook in a comma-separated LITELLM_WORKER_STARTUP_HOOKS value gets called."""
    global _dummy_hook_called, _dummy_async_hook_called
    _dummy_hook_called = _dummy_async_hook_called = False

    from litellm.proxy.proxy_server import proxy_startup_event

    # One sync and one async hook, joined with a comma.
    hook_spec = ",".join(
        [
            "tests.test_litellm.proxy.test_proxy_cli:_dummy_hook",
            "tests.test_litellm.proxy.test_proxy_cli:_dummy_async_hook",
        ]
    )

    # Current environment minus the database URLs, plus both hooks.
    hook_env = dict(os.environ)
    for db_var in ("DATABASE_URL", "DIRECT_URL"):
        hook_env.pop(db_var, None)
    hook_env["LITELLM_WORKER_STARTUP_HOOKS"] = hook_spec

    with patch.dict(os.environ, hook_env, clear=True):
        try:
            async with proxy_startup_event(app=None):
                pass
        except Exception:
            # Any exception raised during startup is ignored; only the flags matter.
            pass

    assert _dummy_hook_called is True, "First hook was not called"
    assert _dummy_async_hook_called is True, "Second hook was not called"