mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 03:31:23 +00:00
3eeb14bf1a
* fix redis cluster startup_nodes check order * add tests for redis cluster startup_nodes fix
290 lines
12 KiB
Python
290 lines
12 KiB
Python
from litellm._redis import (
|
|
get_redis_url_from_environment,
|
|
_get_redis_cluster_kwargs,
|
|
get_redis_async_client,
|
|
get_redis_client,
|
|
get_redis_connection_pool,
|
|
)
|
|
import json
|
|
import os
|
|
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
import redis
|
|
import redis.asyncio as async_redis
|
|
|
|
def test_get_redis_url_from_environment_single_url(monkeypatch):
    """An explicitly configured REDIS_URL is returned unchanged."""
    expected_url = "redis://redis-server:6379/0"
    monkeypatch.setenv("REDIS_URL", expected_url)

    assert get_redis_url_from_environment() == expected_url
|
|
|
|
def test_get_redis_url_from_environment_host_port(monkeypatch):
    """Test when REDIS_HOST and REDIS_PORT are provided.

    REDIS_URL is cleared first: when present it takes precedence over the
    host/port variables, so an ambient REDIS_URL (e.g. in CI) would mask the
    code path under test.
    """
    monkeypatch.delenv("REDIS_URL", raising=False)
    monkeypatch.setenv("REDIS_HOST", "redis-server")
    monkeypatch.setenv("REDIS_PORT", "6379")
    # Ensure authentication and TLS variables are not set
    monkeypatch.delenv("REDIS_USERNAME", raising=False)
    monkeypatch.delenv("REDIS_PASSWORD", raising=False)
    monkeypatch.delenv("REDIS_SSL", raising=False)

    redis_url = get_redis_url_from_environment()

    # Plain scheme, no credentials
    assert redis_url == "redis://redis-server:6379"
|
|
|
|
def test_get_redis_url_from_environment_with_ssl(monkeypatch):
    """Test when SSL is enabled via REDIS_SSL=true.

    REDIS_URL is cleared first because it takes precedence over the
    host/port/SSL variables and would otherwise be returned verbatim.
    """
    monkeypatch.delenv("REDIS_URL", raising=False)
    monkeypatch.setenv("REDIS_HOST", "redis-server")
    monkeypatch.setenv("REDIS_PORT", "6379")
    monkeypatch.setenv("REDIS_SSL", "true")
    # Ensure authentication variables are not set
    monkeypatch.delenv("REDIS_USERNAME", raising=False)
    monkeypatch.delenv("REDIS_PASSWORD", raising=False)

    redis_url = get_redis_url_from_environment()

    # SSL switches the scheme to rediss://
    assert redis_url == "rediss://redis-server:6379"
|
|
|
|
def test_get_redis_url_from_environment_with_username_password(monkeypatch):
    """Test when username and password are provided.

    REDIS_URL is cleared because it takes precedence over the host/port
    variables. REDIS_SSL is cleared because an ambient REDIS_SSL=true would
    switch the expected scheme from redis:// to rediss://.
    """
    monkeypatch.delenv("REDIS_URL", raising=False)
    monkeypatch.delenv("REDIS_SSL", raising=False)
    monkeypatch.setenv("REDIS_HOST", "redis-server")
    monkeypatch.setenv("REDIS_PORT", "6379")
    monkeypatch.setenv("REDIS_USERNAME", "user")
    monkeypatch.setenv("REDIS_PASSWORD", "password")

    redis_url = get_redis_url_from_environment()

    # Credentials are embedded as username:password@
    assert redis_url == "redis://user:password@redis-server:6379"
|
|
|
|
def test_get_redis_url_from_environment_with_password_only(monkeypatch):
    """Test when only a password is provided (no username).

    REDIS_URL is cleared because it takes precedence over the host/port
    variables and would otherwise be returned verbatim.
    """
    monkeypatch.delenv("REDIS_URL", raising=False)
    monkeypatch.setenv("REDIS_HOST", "redis-server")
    monkeypatch.setenv("REDIS_PORT", "6379")
    monkeypatch.setenv("REDIS_PASSWORD", "password")
    # Ensure username and TLS are not set
    monkeypatch.delenv("REDIS_USERNAME", raising=False)
    monkeypatch.delenv("REDIS_SSL", raising=False)

    redis_url = get_redis_url_from_environment()

    # The password is emitted as "password@" with no leading colon.
    # NOTE(review): in standard redis:// URL syntax a bare "password@" userinfo
    # is parsed as a *username*; ":password@" is the usual password-only form.
    # This assertion pins current behavior; confirm it is intended.
    assert redis_url == "redis://password@redis-server:6379"
|
|
|
|
def test_get_redis_url_from_environment_with_all_options(monkeypatch):
    """Test when host, port, username, password and SSL are all provided.

    REDIS_URL is cleared because it takes precedence over the individual
    variables and would otherwise be returned verbatim.
    """
    monkeypatch.delenv("REDIS_URL", raising=False)
    monkeypatch.setenv("REDIS_HOST", "redis-server")
    monkeypatch.setenv("REDIS_PORT", "6379")
    monkeypatch.setenv("REDIS_USERNAME", "user")
    monkeypatch.setenv("REDIS_PASSWORD", "password")
    monkeypatch.setenv("REDIS_SSL", "true")

    redis_url = get_redis_url_from_environment()

    # TLS scheme plus username:password@ credentials
    assert redis_url == "rediss://user:password@redis-server:6379"
|
|
|
|
def test_get_redis_url_from_environment_missing_host_port(monkeypatch):
    """With no REDIS_URL, REDIS_HOST or REDIS_PORT set, a ValueError is raised."""
    for env_name in ("REDIS_URL", "REDIS_HOST", "REDIS_PORT"):
        monkeypatch.delenv(env_name, raising=False)

    with pytest.raises(ValueError) as raised:
        get_redis_url_from_environment()

    message = str(raised.value)
    assert "Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified" in message
|
|
|
|
def test_get_redis_url_from_environment_missing_port(monkeypatch):
    """REDIS_HOST alone (no REDIS_PORT, no REDIS_URL) is rejected with a ValueError."""
    for env_name in ("REDIS_URL", "REDIS_PORT"):
        monkeypatch.delenv(env_name, raising=False)
    monkeypatch.setenv("REDIS_HOST", "redis-server")

    with pytest.raises(ValueError) as raised:
        get_redis_url_from_environment()

    message = str(raised.value)
    assert "Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified" in message
|
|
|
|
def test_max_connections_in_cluster_kwargs():
    """The set of accepted Redis cluster kwargs must include max_connections."""
    accepted_cluster_kwargs = _get_redis_cluster_kwargs()
    assert (
        "max_connections" in accepted_cluster_kwargs
    ), "max_connections should be in available Redis cluster kwargs"
|
|
|
|
def test_get_redis_async_client_with_connection_pool():
    """Test that connection_pool parameter is properly passed to Redis client.

    _get_redis_client_logic is patched to return only host/port/db kwargs, and
    async_redis.Redis is patched so its constructor call can be inspected.
    """
    mock_pool = MagicMock(spec=async_redis.BlockingConnectionPool)

    with patch("litellm._redis.async_redis.Redis") as mock_redis, patch(
        "litellm._redis._get_redis_client_logic"
    ) as mock_logic:
        mock_logic.return_value = {"host": "localhost", "port": 6379, "db": 0}

        get_redis_async_client(connection_pool=mock_pool)

        # Fail clearly if Redis was never constructed, rather than with a
        # TypeError from subscripting a None call_args.
        mock_redis.assert_called_once()
        call_kwargs = mock_redis.call_args.kwargs
        assert "connection_pool" in call_kwargs, "connection_pool should be passed to Redis client"
        # Identity, not equality: the exact pool object must be forwarded.
        assert call_kwargs["connection_pool"] is mock_pool, "connection_pool should match the provided pool"
|
|
|
|
def test_get_redis_async_client_without_connection_pool():
    """Test that Redis client works without connection_pool parameter.

    When no pool is supplied, the async Redis constructor must not receive a
    connection_pool keyword at all.
    """
    with patch("litellm._redis.async_redis.Redis") as mock_redis, patch(
        "litellm._redis._get_redis_client_logic"
    ) as mock_logic:
        mock_logic.return_value = {"host": "localhost", "port": 6379, "db": 0}

        get_redis_async_client()

        # Fail clearly if Redis was never constructed, rather than with a
        # TypeError from subscripting a None call_args.
        mock_redis.assert_called_once()
        call_kwargs = mock_redis.call_args.kwargs
        assert "connection_pool" not in call_kwargs, "connection_pool should not be in kwargs when not provided"
|
|
|
|
@patch("litellm._redis.init_redis_cluster")
def test_sync_client_prefers_cluster_over_url(mock_init_cluster, monkeypatch):
    """
    Test get_redis_client returns RedisCluster when startup_nodes is present even if
    REDIS_URL is also set.

    Verifies both that init_redis_cluster receives startup_nodes and that the
    client handed back is the one init_redis_cluster produced.
    """
    monkeypatch.setenv("REDIS_URL", "redis://fallback-host:6379")
    cluster_client = MagicMock(spec=redis.RedisCluster)
    mock_init_cluster.return_value = cluster_client

    startup_nodes = [{"host": "cluster-node.example.com", "port": 6379}]
    client = get_redis_client(startup_nodes=startup_nodes)

    mock_init_cluster.assert_called_once()
    call_kwargs = mock_init_cluster.call_args[0][0]
    assert (
        "startup_nodes" in call_kwargs
    ), "startup_nodes must be forwarded to init_redis_cluster"
    assert client is cluster_client, "get_redis_client must return the cluster client"
|
|
|
|
@patch("litellm._redis.async_redis.RedisCluster")
def test_async_client_prefers_cluster_over_url(mock_cluster_cls, monkeypatch):
    """
    Test (1) get_redis_async_client returns async RedisCluster when startup_nodes is present
    even if REDIS_URL is also set and (2) startup_nodes is forwarded to RedisCluster.
    """
    monkeypatch.setenv("REDIS_URL", "redis://fallback-host:6379")

    startup_nodes = [{"host": "cluster-node.example.com", "port": 6379}]
    client = get_redis_async_client(startup_nodes=startup_nodes)

    mock_cluster_cls.assert_called_once()
    call_kwargs = mock_cluster_cls.call_args[1]
    assert "startup_nodes" in call_kwargs, "startup_nodes must be forwarded to async RedisCluster"
    # Only the count is checked; the forwarded node objects may be converted
    # from the input dicts.
    assert len(call_kwargs["startup_nodes"]) == 1, "should forward exactly 1 cluster node"
    # Claim (1): the returned client is the RedisCluster instance.
    assert client is mock_cluster_cls.return_value, "get_redis_async_client must return the cluster client"
|
|
|
|
|
|
@patch("litellm._redis.async_redis.RedisCluster")
def test_async_client_prefers_cluster_over_url_via_env_var(mock_cluster_cls, monkeypatch):
    """
    Test get_redis_async_client returns async RedisCluster when REDIS_CLUSTER_NODES is set
    even if REDIS_URL is also set.
    """
    monkeypatch.setenv("REDIS_URL", "redis://fallback-host:6379")
    monkeypatch.setenv(
        "REDIS_CLUSTER_NODES",
        json.dumps([{"host": "cluster-node.example.com", "port": 6379}]),
    )

    get_redis_async_client()

    mock_cluster_cls.assert_called_once()
    call_kwargs = mock_cluster_cls.call_args[1]
    assert "startup_nodes" in call_kwargs, "startup_nodes must be forwarded to async RedisCluster"
    # The env var holds exactly one node; make sure it is not dropped or duplicated.
    assert len(call_kwargs["startup_nodes"]) == 1, "should forward exactly 1 cluster node"
|
|
|
|
@patch("litellm._redis.init_redis_cluster")
def test_sync_client_prefers_cluster_over_url_via_env_var(mock_init_cluster, monkeypatch):
    """
    With REDIS_CLUSTER_NODES in the environment, get_redis_client must take the
    cluster path even though REDIS_URL is set too.
    """
    single_node = [{"host": "cluster-node.example.com", "port": 6379}]
    env_values = {
        "REDIS_URL": "redis://fallback-host:6379",
        "REDIS_CLUSTER_NODES": json.dumps(single_node),
    }
    for env_name, env_value in env_values.items():
        monkeypatch.setenv(env_name, env_value)
    mock_init_cluster.return_value = MagicMock(spec=redis.RedisCluster)

    get_redis_client()

    mock_init_cluster.assert_called_once()
    positional_args = mock_init_cluster.call_args[0]
    forwarded_kwargs = positional_args[0]
    assert "startup_nodes" in forwarded_kwargs, "startup_nodes must be forwarded to init_redis_cluster"
    assert len(forwarded_kwargs["startup_nodes"]) == 1
|
|
|
|
@patch("litellm._redis.init_redis_cluster")
def test_sync_client_preserves_password_for_cluster_when_url_also_set(mock_init_cluster, monkeypatch):
    """
    When startup_nodes is given, the password from REDIS_PASSWORD must still reach
    init_redis_cluster, even though REDIS_URL is also present in the environment.
    """
    for env_name, env_value in (
        ("REDIS_URL", "redis://fallback-host:6379"),
        ("REDIS_PASSWORD", "secret"),
    ):
        monkeypatch.setenv(env_name, env_value)
    mock_init_cluster.return_value = MagicMock(spec=redis.RedisCluster)

    get_redis_client(startup_nodes=[{"host": "cluster-node.example.com", "port": 6379}])

    mock_init_cluster.assert_called_once()
    forwarded_kwargs = mock_init_cluster.call_args[0][0]
    assert "password" in forwarded_kwargs, "password must not be stripped when routing to cluster"
    assert forwarded_kwargs["password"] == "secret"
|
|
|
|
|
|
def test_connection_pool_returns_none_for_cluster(monkeypatch):
    """Cluster mode (startup_nodes supplied) yields no connection pool, even with REDIS_URL set."""
    monkeypatch.setenv("REDIS_URL", "redis://fallback-host:6379")

    pool = get_redis_connection_pool(
        startup_nodes=[{"host": "cluster-node.example.com", "port": 6379}]
    )

    assert pool is None, "connection pool must be None for cluster mode"
|
|
|
|
|
|
@patch("litellm._redis.redis.Redis.from_url")
def test_sync_client_url_used_when_no_cluster(mock_from_url, monkeypatch):
    """
    Test get_redis_client defaults to using the URL path when no startup_nodes are provided.

    REDIS_CLUSTER_NODES is removed so that no cluster configuration can leak in
    from the environment.
    """
    redis_url = "redis://plain-host:6379"
    monkeypatch.setenv("REDIS_URL", redis_url)
    monkeypatch.delenv("REDIS_CLUSTER_NODES", raising=False)

    get_redis_client()

    mock_from_url.assert_called_once()
    # The configured URL itself must be what reaches Redis.from_url.
    assert mock_from_url.call_args.kwargs.get("url") == redis_url
|