mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 20:48:32 +00:00
38709ba9bb
* feat(proxy): skip disable_background_health_check models on GET /health when flag set Co-authored-by: Cursor <cursoragent@cursor.com> * fix comment * fix greptile comments * Fix health check fallback kwargs * Format health endpoint * Harden direct health check kwargs compatibility for monkeypatched perform_health_check Replace substring-based TypeError detection with unexpected-keyword checks and a short retry chain (full kwargs, instrumentation only, filter only, minimal) so partial stubs work regardless of which optional kwarg fails first. Add proxy unit tests for legacy three-arg stubs and single-kwarg variants. Co-authored-by: Sameer Kankute <Sameerlite@users.noreply.github.com> * fix black --------- Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: Sameer Kankute <Sameerlite@users.noreply.github.com>
2978 lines
100 KiB
Python
2978 lines
100 KiB
Python
import os
|
|
import sys
|
|
import traceback
|
|
from unittest import mock
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
import litellm.proxy
|
|
import litellm.proxy.proxy_server
|
|
|
|
# Populate os.environ from a local .env file (python-dotenv), if one is present.
load_dotenv()
|
|
import io
|
|
import json
|
|
import os
|
|
|
|
# Tests for litellm/proxy.

# Prepend the directory two levels above the *current working directory* to the
# import path (resolved against the CWD, not against this file's location).
_GRANDPARENT_DIR = os.path.abspath("../..")
sys.path.insert(0, _GRANDPARENT_DIR)
|
|
import asyncio
|
|
import logging
|
|
|
|
import pytest
|
|
|
|
import litellm
|
|
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
|
|
|
|
# Verbose root logging for the whole module: DEBUG and above, each record
# prefixed with its timestamp and level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.DEBUG,
)
|
|
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
from fastapi import FastAPI
|
|
|
|
# test /chat/completion request to the proxy
|
|
from fastapi.testclient import TestClient
|
|
|
|
from litellm.integrations.custom_logger import CustomLogger
|
|
from litellm.proxy.proxy_server import ( # Replace with the actual module where your FastAPI router is defined
|
|
app,
|
|
initialize,
|
|
save_worker_config,
|
|
)
|
|
from litellm.proxy.utils import ProxyLogging
|
|
|
|
# Bearer token presented to the proxy by tests that send auth headers.
token = "sk-1234"

headers = {"Authorization": "Bearer " + token}

# Canned llm_router responses handed back by the mock patchers below.
example_completion_result = {
    "choices": [
        {
            "message": {
                "content": "Whispers of the wind carry dreams to me.",
                "role": "assistant",
            }
        }
    ],
}

# Four sample values, repeated three times -> a 12-element embedding vector.
_SAMPLE_EMBEDDING_VALUES = [
    -0.006929283495992422,
    -0.005336422007530928,
    -4.547132266452536e-05,
    -0.024047505110502243,
]

example_embedding_result = {
    "object": "list",
    "data": [
        {
            "object": "embedding",
            "index": 0,
            "embedding": _SAMPLE_EMBEDDING_VALUES * 3,
        }
    ],
    "model": "text-embedding-3-small",
    "usage": {"prompt_tokens": 5, "total_tokens": 5},
}

example_image_generation_result = {
    "created": 1589478378,
    "data": [{"url": "https://..."} for _ in range(2)],
}


def _patch_with_result(target, result):
    """Return an unstarted ``mock.patch`` for *target* whose mock returns *result*."""
    return mock.patch(target, return_value=result)


def mock_patch_acompletion():
    """Patch ``llm_router.acompletion`` so it returns ``example_completion_result``."""
    return _patch_with_result(
        "litellm.proxy.proxy_server.llm_router.acompletion",
        example_completion_result,
    )


def mock_patch_aembedding():
    """Patch ``llm_router.aembedding`` so it returns ``example_embedding_result``."""
    return _patch_with_result(
        "litellm.proxy.proxy_server.llm_router.aembedding",
        example_embedding_result,
    )


def mock_patch_aimage_generation():
    """Patch ``llm_router.aimage_generation`` so it returns ``example_image_generation_result``."""
    return _patch_with_result(
        "litellm.proxy.proxy_server.llm_router.aimage_generation",
        example_image_generation_result,
    )
|
|
|
|
|
|
@pytest.fixture(scope="function")
def fake_env_vars(monkeypatch):
    """Install placeholder provider credentials/endpoints for a single test.

    ``monkeypatch`` reverts every variable when the test finishes.
    """
    placeholder_env = {
        "OPENAI_API_KEY": "fake_openai_api_key",
        "OPENAI_API_BASE": "http://fake-openai-api-base",
        "AZURE_AI_API_BASE": "http://fake-azure-api-base",
        "AZURE_OPENAI_API_KEY": "fake_azure_openai_api_key",
        "AZURE_SWEDEN_API_BASE": "http://fake-azure-sweden-api-base",
        "REDIS_HOST": "localhost",
    }
    for env_name, env_value in placeholder_env.items():
        monkeypatch.setenv(env_name, env_value)
|
|
|
|
|
|
@pytest.fixture(scope="function")
def client_no_auth(fake_env_vars):
    """
    Build a TestClient for the proxy app configured from test_config_no_auth.yaml.

    Router/config globals in ``proxy_server`` are cleared first. ``initialize``
    sets variables on the shared FastAPI app, so tests running in parallel can
    end up seeing each other's values.
    """
    from litellm.proxy.proxy_server import cleanup_router_config_variables

    cleanup_router_config_variables()

    test_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = f"{test_dir}/test_configs/test_config_no_auth.yaml"
    asyncio.run(initialize(config=config_path, debug=True))
    return TestClient(app)
|
|
|
|
|
|
@mock_patch_acompletion()
def test_chat_completion(mock_acompletion, client_no_auth):
    """POST /v1/chat/completions forwards the request body to ``llm_router.acompletion``."""
    request_body = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hi"}],
        "max_tokens": 10,
    }
    try:
        print("testing proxy server with chat completions")
        response = client_no_auth.post("/v1/chat/completions", json=request_body)

        # The user-supplied fields are pinned exactly; proxy-injected kwargs are wildcards.
        mock_acompletion.assert_called_once_with(
            **request_body,
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            specific_deployment=True,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        print(f"response - {response.text}")
        assert response.status_code == 200
        print(f"Received response: {response.json()}")
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
|
|
|
|
|
|
def test_chat_completion_malformed_messages_returns_400(client_no_auth):
    """
    Raw-string entries in ``messages`` must produce HTTP 400, not 500.

    A client sending ``["hi how are you"]`` instead of
    ``[{"role": "user", "content": "hi how are you"}]`` should receive an
    ``invalid_request_error`` (400) rather than an Internal Server Error.
    """
    # Deliberately invalid: each message should be a {"role", "content"} dict.
    bad_request = {"model": "gpt-3.5-turbo", "messages": ["hi how are you"]}
    try:
        print("testing proxy server with malformed messages")
        response = client_no_auth.post(
            "/v1/chat/completions", json=bad_request, headers=headers
        )
        status, body_text = response.status_code, response.text
        print(f"response status: {status}")
        print(f"response text: {body_text}")

        # Should return 400, not 500
        assert status == 400, f"Expected 400, got {status}. Response: {body_text}"

        payload = response.json()
        assert "error" in payload, "Response should contain 'error' key"
        error = payload["error"]

        error_type = error.get("type")
        assert error_type in (
            "invalid_request_error",
            None,
        ), f"Expected invalid_request_error or None, got {error_type}"

        error_code = error.get("code")
        assert error_code in ("400", 400), f"Expected code 400, got {error_code}"

        assert len(error.get("message", "")) > 0, "Error message should not be empty"
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
|
|
|
|
|
|
def test_get_settings_request_timeout(client_no_auth):
    """
    GET /settings reports ``litellm.request_timeout`` when no timeout is configured.

    The test does not override the timeout; it compares the endpoint's value
    with whatever ``litellm.request_timeout`` currently holds.
    """
    response = client_no_auth.get("/settings")
    assert response.status_code == 200

    settings = response.json()
    print("settings", settings)

    assert settings["litellm.request_timeout"] == litellm.request_timeout
|
|
|
|
|
|
@pytest.mark.parametrize(
    "litellm_key_header_name",
    ["x-litellm-key", None],
)
def test_add_headers_to_request(litellm_key_header_name):
    """
    Only ``X-Custom-Header`` and ``anthropic-beta`` survive header cleaning + filtering.

    For both parametrizations, the ``Authorization`` and ``X-Stainless-Header``
    entries must not appear in the forwardable headers.
    """
    from litellm.proxy.litellm_pre_call_utils import (
        LiteLLMProxyRequestSetup,
        clean_headers,
    )

    incoming_headers = {
        "Authorization": "Bearer 1234",
        "X-Custom-Header": "Custom-Value",
        "X-Stainless-Header": "Stainless-Value",
        "anthropic-beta": "beta-value",
    }
    request_headers = clean_headers(incoming_headers, litellm_key_header_name)
    forwarded_headers = LiteLLMProxyRequestSetup._get_forwardable_headers(
        request_headers
    )

    assert forwarded_headers == {
        "X-Custom-Header": "Custom-Value",
        "anthropic-beta": "beta-value",
    }
|
|
|
|
|
|
@pytest.mark.parametrize(
    "litellm_key_header_name",
    ["x-litellm-key", None],
)
@pytest.mark.parametrize(
    "forward_headers",
    [True, False],
)
@mock_patch_acompletion()
def test_chat_completion_forward_headers(
    mock_acompletion, client_no_auth, litellm_key_header_name, forward_headers
):
    """
    Client headers reach ``llm_router.acompletion`` only when forwarding is enabled.

    With ``forward_client_headers_to_llm_api`` on, the custom headers arrive
    lower-cased and the auth header (``Authorization`` or the configured
    ``litellm_key_header_name``) is absent. With it off, no ``headers`` kwarg is
    passed. The ``general_settings`` keys changed here are restored afterwards
    so they do not leak into later tests.
    """
    proxy_module = litellm.proxy.proxy_server
    touched_keys = ("forward_client_headers_to_llm_api", "litellm_key_header_name")
    missing = object()
    initial_settings = getattr(proxy_module, "general_settings")
    previous_values = {key: initial_settings.get(key, missing) for key in touched_keys}

    try:
        if forward_headers:
            gs = getattr(proxy_module, "general_settings")
            gs["forward_client_headers_to_llm_api"] = True
            setattr(proxy_module, "general_settings", gs)
        if litellm_key_header_name is not None:
            gs = getattr(proxy_module, "general_settings")
            gs["litellm_key_header_name"] = litellm_key_header_name
            setattr(proxy_module, "general_settings", gs)

        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {"role": "user", "content": "hi"},
            ],
            "max_tokens": 10,
        }

        headers_to_forward = {
            "X-Custom-Header": "Custom-Value",
            "X-Another-Header": "Another-Value",
        }

        if litellm_key_header_name is not None:
            headers_to_not_forward = {litellm_key_header_name: "Bearer 1234"}
        else:
            headers_to_not_forward = {"Authorization": "Bearer 1234"}

        received_headers = {**headers_to_forward, **headers_to_not_forward}

        print("testing proxy server with chat completions")
        response = client_no_auth.post(
            "/v1/chat/completions", json=test_data, headers=received_headers
        )
        if not forward_headers:
            assert "headers" not in mock_acompletion.call_args.kwargs
        else:
            # Forwarded header names arrive lower-cased; the auth header is dropped.
            assert mock_acompletion.call_args.kwargs["headers"] == {
                "x-custom-header": "Custom-Value",
                "x-another-header": "Another-Value",
            }

        print(f"response - {response.text}")
        assert response.status_code == 200
        result = response.json()
        print(f"Received response: {result}")
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
    finally:
        # Put general_settings back exactly as it was before this test.
        current_settings = getattr(proxy_module, "general_settings")
        for key, old_value in previous_values.items():
            if old_value is missing:
                current_settings.pop(key, None)
            else:
                current_settings[key] = old_value
        setattr(proxy_module, "general_settings", current_settings)
|
|
|
|
|
|
@pytest.mark.parametrize("forward_llm_auth_headers", [True, False])
@mock_patch_acompletion()
def test_chat_completion_forward_llm_provider_auth_headers(
    mock_acompletion, client_no_auth, forward_llm_auth_headers
):
    """
    Test that LLM provider auth headers (x-api-key, x-goog-api-key) are forwarded
    when forward_llm_provider_auth_headers=True.

    This allows clients to send their own LLM provider API keys through the proxy.
    Both ``general_settings`` keys set by this test are restored to their prior
    values afterwards (previously only one was cleaned up, leaking
    ``forward_client_headers_to_llm_api=True`` into later tests).
    """
    proxy_module = litellm.proxy.proxy_server
    touched_keys = (
        "forward_client_headers_to_llm_api",
        "forward_llm_provider_auth_headers",
    )
    missing = object()
    initial_settings = getattr(proxy_module, "general_settings")
    previous_values = {key: initial_settings.get(key, missing) for key in touched_keys}

    try:
        # Configure general settings
        gs = getattr(proxy_module, "general_settings")
        gs["forward_client_headers_to_llm_api"] = True
        gs["forward_llm_provider_auth_headers"] = forward_llm_auth_headers
        setattr(proxy_module, "general_settings", gs)

        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {"role": "user", "content": "hello"},
            ],
            "max_tokens": 10,
        }

        # Headers including LLM provider auth
        request_headers = {
            "Authorization": "Bearer sk-proxy-auth-123",  # Proxy auth (should be stripped)
            "x-api-key": "sk-ant-api03-test-anthropic-key",  # Anthropic API key
            "x-goog-api-key": "google-api-key-123",  # Google API key
            "X-Custom-Header": "custom-value",  # Custom header (should be forwarded)
        }

        response = client_no_auth.post(
            "/v1/chat/completions", json=test_data, headers=request_headers
        )

        assert response.status_code == 200

        forwarded_headers = mock_acompletion.call_args.kwargs.get("headers", {})

        if forward_llm_auth_headers:
            # LLM provider auth headers should be forwarded
            assert "x-api-key" in forwarded_headers
            assert forwarded_headers["x-api-key"] == "sk-ant-api03-test-anthropic-key"
            assert "x-goog-api-key" in forwarded_headers
            assert forwarded_headers["x-goog-api-key"] == "google-api-key-123"
        else:
            # LLM provider auth headers should be stripped
            assert "x-api-key" not in forwarded_headers
            assert "x-goog-api-key" not in forwarded_headers

        # Custom headers should always be forwarded (when forward_client_headers_to_llm_api=True)
        assert "x-custom-header" in forwarded_headers
        assert forwarded_headers["x-custom-header"] == "custom-value"

        # Proxy Authorization should never be forwarded
        assert "authorization" not in forwarded_headers

        print(
            f"✓ Test passed with forward_llm_provider_auth_headers={forward_llm_auth_headers}"
        )
        print(f" Forwarded headers: {list(forwarded_headers.keys())}")

    except Exception as e:
        pytest.fail(
            f"Test failed with forward_llm_auth_headers={forward_llm_auth_headers}: {str(e)}"
        )
    finally:
        # Restore both keys to their pre-test state (removing ones that were absent).
        gs = getattr(proxy_module, "general_settings")
        for key, old_value in previous_values.items():
            if old_value is missing:
                gs.pop(key, None)
            else:
                gs[key] = old_value
        setattr(proxy_module, "general_settings", gs)
|
|
|
|
|
|
@mock_patch_acompletion()
@pytest.mark.asyncio
async def test_team_disable_guardrails(mock_acompletion, client_no_auth):
    """
    A team whose metadata sets ``guardrails.modify_guardrails`` to False gets a 403
    when a request tries to change a guardrail (here ``hide_secrets``).

    The rule is meant to cover `/key/generate` as well; only `/chat/completions`
    is exercised here.
    """
    import time

    from fastapi import Request
    from starlette.datastructures import URL

    from litellm.proxy._types import (
        LiteLLM_TeamTableCachedObj,
        ProxyException,
        UserAPIKeyAuth,
    )
    from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
    from litellm.proxy.proxy_server import hash_token, user_api_key_cache

    _team_id = "1234"
    user_key = "sk-12345678"

    valid_token = UserAPIKeyAuth(
        team_id=_team_id,
        team_blocked=True,
        token=hash_token(user_key),
        last_refreshed_at=time.time(),
    )
    # NOTE(review): the pause makes the team object's last_refreshed_at strictly
    # later than the key's; presumably the auth path compares the two — confirm.
    await asyncio.sleep(1)
    team_obj = LiteLLM_TeamTableCachedObj(
        team_id=_team_id,
        blocked=False,
        last_refreshed_at=time.time(),
        metadata={"guardrails": {"modify_guardrails": False}},
    )
    user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
    user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)

    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "prisma_client", "hello-world")

    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")
    # The request body attempts to switch a guardrail off.
    body = {"metadata": {"guardrails": {"hide_secrets": False}}}
    request._body = json.dumps(body).encode("utf-8")

    with pytest.raises(ProxyException) as exc_info:
        await user_api_key_auth(request=request, api_key="Bearer " + user_key)
    assert exc_info.value.code == str(403)
|
|
|
|
|
|
from test_custom_callback_input import CompletionCustomHandler
|
|
|
|
|
|
@mock_patch_acompletion()
def test_custom_logger_failure_handler(mock_acompletion, client_no_auth):
    """
    A request from a key with ``rpm_limit=0`` is rejected with 429, and the failure
    is reported to the registered CustomLogger's ``async_log_failure_event``.

    ``CompletionCustomHandler`` is registered alongside it and must record no errors.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import hash_token, user_api_key_cache

    api_key = "sk-my-test-key"
    hashed_key = hash_token(api_key)
    # Zero requests-per-minute for this key; the test expects the call to get 429.
    user_api_key_cache.set_cache(
        key=hashed_key, value=UserAPIKeyAuth(token=hashed_key, rpm_limit=0)
    )

    failure_logger = CustomLogger()
    input_checker = CompletionCustomHandler()
    proxy_logging_obj: ProxyLogging = getattr(
        litellm.proxy.proxy_server, "proxy_logging_obj"
    )

    litellm.callbacks = [failure_logger, input_checker]
    proxy_logging_obj._init_litellm_callbacks(llm_router=None)

    proxy_module = litellm.proxy.proxy_server
    for attr_name, attr_value in (
        ("user_api_key_cache", user_api_key_cache),
        ("master_key", "sk-1234"),
        ("prisma_client", "FAKE-VAR"),
        ("proxy_logging_obj", proxy_logging_obj),
    ):
        setattr(proxy_module, attr_name, attr_value)

    with patch.object(
        failure_logger, "async_log_failure_event", new=AsyncMock()
    ) as mock_failed_alert:
        request_body = {
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "hi"}],
            "max_tokens": 10,
        }
        print("testing proxy server with chat completions")
        response = client_no_auth.post(
            "/v1/chat/completions",
            json=request_body,
            headers={"Authorization": f"Bearer {api_key}"},
        )
        assert response.status_code == 429

        # The failure must have been handed to the custom logger.
        mock_failed_alert.assert_called()

        assert len(input_checker.errors) == 0
|
|
|
|
|
|
@mock_patch_acompletion()
def test_engines_model_chat_completions(mock_acompletion, client_no_auth):
    """The ``/engines/<model>/chat/completions`` route forwards the body to ``llm_router.acompletion``."""
    request_body = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hi"}],
        "max_tokens": 10,
    }
    try:
        print("testing proxy server with chat completions")
        response = client_no_auth.post(
            "/engines/gpt-3.5-turbo/chat/completions", json=request_body
        )
        mock_acompletion.assert_called_once_with(
            **request_body,
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            specific_deployment=True,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        print(f"response - {response.text}")
        assert response.status_code == 200
        print(f"Received response: {response.json()}")
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
|
|
|
|
|
|
@mock_patch_acompletion()
def test_chat_completion_azure(mock_acompletion, client_no_auth):
    """An ``azure/...`` model name is passed unchanged to ``llm_router.acompletion``."""
    request_body = {
        "model": "azure/gpt-4.1-mini",
        "messages": [{"role": "user", "content": "write 1 sentence poem"}],
        "max_tokens": 10,
    }
    try:
        print("testing proxy server with Azure Request /chat/completions")
        response = client_no_auth.post("/v1/chat/completions", json=request_body)

        mock_acompletion.assert_called_once_with(
            **request_body,
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            specific_deployment=True,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        assert response.status_code == 200
        result = response.json()
        print(f"Received response: {result}")
        reply_text = result["choices"][0]["message"]["content"]
        assert len(reply_text) > 0
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
|
|
|
|
|
|
# Run the test
|
|
# test_chat_completion_azure()
|
|
|
|
|
|
@mock_patch_acompletion()
def test_openai_deployments_model_chat_completions_azure(
    mock_acompletion, client_no_auth
):
    """The ``/openai/deployments/<deployment>/chat/completions`` route reaches ``llm_router.acompletion``."""
    request_body = {
        "model": "azure/gpt-4.1-mini",
        "messages": [{"role": "user", "content": "write 1 sentence poem"}],
        "max_tokens": 10,
    }
    route = "/openai/deployments/azure/gpt-4.1-mini/chat/completions"
    try:
        print(f"testing proxy server with Azure Request {route}")
        response = client_no_auth.post(route, json=request_body)

        mock_acompletion.assert_called_once_with(
            **request_body,
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            specific_deployment=True,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        assert response.status_code == 200
        result = response.json()
        print(f"Received response: {result}")
        reply_text = result["choices"][0]["message"]["content"]
        assert len(reply_text) > 0
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
|
|
|
|
|
|
# Run the test
|
|
# test_openai_deployments_model_chat_completions_azure()
|
|
|
|
|
|
### EMBEDDING
|
|
@mock_patch_aembedding()
def test_embedding(mock_aembedding, client_no_auth):
    """
    POST /v1/embeddings runs through the proxy hooks and reaches ``llm_router.aembedding``.

    ``pre_call_hook`` is replaced by a side effect that tags the request metadata
    with ``source="unit-test"``. The test checks that the tag reaches
    ``aembedding`` and that the hook was awaited with ``call_type="aembedding"``.
    """
    try:
        test_data = {
            "model": "azure/text-embedding-ada-002",
            "input": ["good morning from litellm"],
        }

        async def _pre_call_hook_side_effect(**kwargs):
            # Act like a real pre-call hook: enrich the request data and return it.
            data = kwargs["data"]
            data["metadata"] = {**(data.get("metadata") or {}), "source": "unit-test"}
            proxy_request = {**(data.get("proxy_server_request") or {})}
            proxy_request["path"] = "/v1/embeddings"
            data["proxy_server_request"] = proxy_request
            return data

        async def _post_call_success_side_effect(**kwargs):
            # Pass the response through unchanged.
            return kwargs["response"]

        proxy_logging_obj = litellm.proxy.proxy_server.proxy_logging_obj
        with (
            patch.object(
                proxy_logging_obj,
                "pre_call_hook",
                new=AsyncMock(side_effect=_pre_call_hook_side_effect),
            ) as mock_pre_call_hook,
            patch.object(
                proxy_logging_obj,
                "during_call_hook",
                new=AsyncMock(return_value=None),
            ),
            patch.object(
                proxy_logging_obj,
                "post_call_success_hook",
                new=AsyncMock(side_effect=_post_call_success_side_effect),
            ),
        ):
            response = client_no_auth.post("/v1/embeddings", json=test_data)

        mock_aembedding.assert_called_once_with(
            model="azure/text-embedding-ada-002",
            input=["good morning from litellm"],
            specific_deployment=True,
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        assert response.status_code == 200
        result = response.json()
        print(len(result["data"][0]["embedding"]))
        # The mocked router returns a 12-value canned embedding.
        assert len(result["data"][0]["embedding"]) > 10

        call_metadata = mock_aembedding.call_args.kwargs["metadata"]
        assert call_metadata.get("source") == "unit-test"

        # Fail with a clear message (not an IndexError) if the hook never ran.
        assert mock_pre_call_hook.await_args_list, "pre_call_hook was never awaited"
        pre_call_kwargs = mock_pre_call_hook.await_args_list[0].kwargs
        assert (
            pre_call_kwargs.get("call_type") == "aembedding"
        ), f"expected pre_call_hook to receive call_type='aembedding', got {pre_call_kwargs.get('call_type')}"

    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
|
|
|
|
|
|
@mock_patch_aembedding()
def test_bedrock_embedding(mock_aembedding, client_no_auth):
    """An embedding request for the ``amazon-embeddings`` alias reaches ``llm_router.aembedding``."""
    try:
        test_data = {
            "model": "amazon-embeddings",
            "input": ["good morning from litellm"],
        }

        response = client_no_auth.post("/v1/embeddings", json=test_data)

        mock_aembedding.assert_called_once_with(
            model="amazon-embeddings",
            input=["good morning from litellm"],
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        # Print before asserting so the response body is visible when the status is wrong.
        print(response.status_code, response.text)
        assert response.status_code == 200
        result = response.json()
        print(len(result["data"][0]["embedding"]))
        # The mocked router returns a 12-value canned embedding.
        assert len(result["data"][0]["embedding"]) > 10
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
|
|
|
|
|
|
@pytest.mark.skip(reason="AWS Suspended Account")
def test_sagemaker_embedding(client_no_auth):
    """Live (unmocked) embedding request against a SageMaker-hosted model alias."""
    try:
        test_data = {
            "model": "GPT-J 6B - Sagemaker Text Embedding (Internal)",
            "input": ["good morning from litellm"],
        }

        response = client_no_auth.post("/v1/embeddings", json=test_data)

        assert response.status_code == 200
        result = response.json()
        print(len(result["data"][0]["embedding"]))
        # Sanity check: a real embedding vector has more than 10 values.
        assert len(result["data"][0]["embedding"]) > 10
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
|
|
|
|
|
|
# Run the test
|
|
# test_embedding()
|
|
#### IMAGE GENERATION
|
|
|
|
|
|
@mock_patch_aimage_generation()
def test_img_gen(mock_aimage_generation, client_no_auth):
    """POST /v1/images/generations forwards its params to ``llm_router.aimage_generation``."""
    try:
        test_data = {
            "model": "dall-e-3",
            "prompt": "A cute baby sea otter",
            "n": 1,
            "size": "1024x1024",
        }

        response = client_no_auth.post("/v1/images/generations", json=test_data)

        mock_aimage_generation.assert_called_once_with(
            model="dall-e-3",
            prompt="A cute baby sea otter",
            n=1,
            size="1024x1024",
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        assert response.status_code == 200
        result = response.json()
        print(len(result["data"][0]["url"]))
        assert len(result["data"][0]["url"]) > 10
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
|
|
|
|
|
|
#### ADDITIONAL
|
|
@pytest.mark.skip(reason="test via docker tests. Requires prisma client.")
def test_add_new_model(client_no_auth):
    """A model added via POST /model/new is listed, with its model_info, by GET /model/info."""
    try:
        test_data = {
            "model_name": "test_openai_models",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
            },
            "model_info": {"description": "this is a test openai model"},
        }
        client_no_auth.post("/model/new", json=test_data, headers=headers)
        response = client_no_auth.get("/model/info", headers=headers)
        assert response.status_code == 200
        result = response.json()
        print(f"response: {result}")

        # Take the last matching entry, as the original overwrite loop did.
        model_info = next(
            (
                m["model_info"]
                for m in reversed(result["data"])
                if m["model_name"] == "test_openai_models"
            ),
            None,
        )
        assert (
            model_info is not None
        ), "test_openai_models was not returned by /model/info"
        assert model_info["description"] == "this is a test openai model"
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
|
|
|
|
|
|
@pytest.mark.xdist_group("proxy_heavy")
def test_health(client_no_auth):
    """GET /health answers 200; proxy logging is raised to DEBUG for this call only."""
    from litellm._logging import verbose_proxy_logger

    previous_level = verbose_proxy_logger.level
    verbose_proxy_logger.setLevel(logging.DEBUG)
    try:
        response = client_no_auth.get("/health")
        assert response.status_code == 200
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
    finally:
        # Restore the shared logger so DEBUG output does not leak into later tests.
        verbose_proxy_logger.setLevel(previous_level)
|
|
|
|
|
|
# test_add_new_model()
|
|
|
|
from litellm.integrations.custom_logger import CustomLogger
|
|
|
|
|
|
class MyCustomHandler(CustomLogger):
    """Callback checking that optional /chat/completions params reach litellm's logging kwargs."""

    def log_pre_api_call(self, model, messages, kwargs):
        print("Pre-API Call")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print("On Success")
        expected_fields = {
            "user": "proxy-user",
            "model": "gpt-3.5-turbo",
            "max_tokens": 10,
        }
        for field_name, expected_value in expected_fields.items():
            assert kwargs[field_name] == expected_value


# Shared instance installed as a litellm callback by test_chat_completion_optional_params.
customHandler = MyCustomHandler()
|
|
|
|
|
|
@mock_patch_acompletion()
def test_chat_completion_optional_params(mock_acompletion, client_no_auth):
    """
    Optional /chat/completions params (here ``user``) must be passed through to litellm.

    [PROXY: PROD TEST] - DO NOT DELETE
    """
    test_data = {
        "model": "gpt-3.5-turbo",
        "messages": [
            {"role": "user", "content": "hi"},
        ],
        "max_tokens": 10,
        "user": "proxy-user",
    }
    # This test swaps in its own callback and verbosity; remember the globals to restore them.
    previous_callbacks = litellm.callbacks
    previous_set_verbose = litellm.set_verbose
    try:
        litellm.set_verbose = True
        litellm.callbacks = [customHandler]
        print("testing proxy server: optional params")
        response = client_no_auth.post("/v1/chat/completions", json=test_data)
        mock_acompletion.assert_called_once_with(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "user", "content": "hi"},
            ],
            max_tokens=10,
            user="proxy-user",
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            specific_deployment=True,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        assert response.status_code == 200
        result = response.json()
        print(f"Received response: {result}")
    except Exception as e:
        # pytest.fail(reason, pytrace=True): the exception text must go into the
        # reason string. Passing it as a second positional argument binds it to
        # ``pytrace`` and drops it from the failure message.
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
    finally:
        litellm.callbacks = previous_callbacks
        litellm.set_verbose = previous_set_verbose
|
|
|
|
|
|
# Run the test
|
|
# test_chat_completion_optional_params()
|
|
|
|
|
|
# Test Reading config.yaml file
|
|
from litellm.proxy.proxy_server import ProxyConfig
|
|
|
|
|
|
@pytest.mark.skip(reason="local variable conflicts. needs to be refactored.")
@mock.patch("litellm.proxy.proxy_server.litellm.Cache")
def test_load_router_config(mock_cache, fake_env_vars):
    """
    ProxyConfig.load_config reads model lists and cache settings from YAML files.

    ``litellm.Cache`` is patched. The cache assertions therefore check how the
    proxy wires up the (mocked) cache object, not a real Redis connection.
    """
    mock_cache.return_value.cache.__dict__ = {"redis_client": None}
    mock_cache.return_value.supported_call_types = [
        "completion",
        "acompletion",
        "embedding",
        "aembedding",
        "atranscription",
        "transcription",
    ]

    try:
        print("testing reading config")
        filepath = os.path.dirname(os.path.abspath(__file__))
        proxy_config = ProxyConfig()

        def _load(config_name):
            # load_config is a coroutine; drive it to completion synchronously.
            return asyncio.run(
                proxy_config.load_config(
                    router=None,
                    config_file_path=f"{filepath}/example_config_yaml/{config_name}",
                )
            )

        # a basic config.yaml with only a model
        result = _load("simple_config.yaml")
        print(result)
        assert len(result[1]) == 1

        # a load balancing config yaml
        result = _load("azure_config.yaml")
        print(result)
        assert len(result[1]) == 2

        # NOTE(review): this section was labelled "config with general settings -
        # custom callbacks" but re-loads azure_config.yaml; confirm the intended file.
        result = _load("azure_config.yaml")
        print(result)
        assert len(result[1]) == 2

        # tests for litellm.cache set from config
        print("testing reading proxy config for cache")
        litellm.cache = None
        _load("cache_no_params.yaml")
        assert litellm.cache is not None
        assert "redis_client" in vars(
            litellm.cache.cache
        )  # it should default to redis on proxy
        assert litellm.cache.supported_call_types == [
            "completion",
            "acompletion",
            "embedding",
            "aembedding",
            "atranscription",
            "transcription",
        ]  # init with all call types

        litellm.disable_cache()

        print("testing reading proxy config for cache with params")
        mock_cache.return_value.supported_call_types = [
            "embedding",
            "aembedding",
        ]
        _load("cache_with_params.yaml")
        assert litellm.cache is not None
        print(litellm.cache)
        print(litellm.cache.supported_call_types)
        print(vars(litellm.cache.cache))
        assert "redis_client" in vars(
            litellm.cache.cache
        )  # it should default to redis on proxy
        # Only the embedding call types configured on the mocked Cache above.
        assert litellm.cache.supported_call_types == [
            "embedding",
            "aembedding",
        ]

    except Exception as e:
        pytest.fail(
            f"Proxy: Got exception reading config: {str(e)}\n{traceback.format_exc()}"
        )
|
|
|
|
|
|
# test_load_router_config()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_team_update_redis():
    """
    Verify that caching a team object writes it through to Redis.

    `_cache_team_object` receives a DualCache backed by a RedisCache whose
    `async_set_cache` is swapped for an AsyncMock; the test passes when that
    mock has been called.
    """
    from litellm.caching.caching import DualCache, RedisCache
    from litellm.proxy._types import LiteLLM_TeamTableCachedObj
    from litellm.proxy.auth.auth_checks import _cache_team_object

    logging_obj: ProxyLogging = getattr(
        litellm.proxy.proxy_server, "proxy_logging_obj"
    )
    redis_backend = RedisCache(host="localhost")
    team_id = "1234"

    with patch.object(
        redis_backend, "async_set_cache", new=AsyncMock()
    ) as redis_set_mock:
        await _cache_team_object(
            team_id=team_id,
            team_table=LiteLLM_TeamTableCachedObj(team_id=team_id),
            user_api_key_cache=DualCache(redis_cache=redis_backend),
            proxy_logging_obj=logging_obj,
        )

    # The mock keeps its call history after the patch is undone.
    redis_set_mock.assert_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_team_redis(client_no_auth):
    """
    Verify that `get_team_object` reads through the Redis layer of a DualCache.

    The Redis read (`async_get_cache`) is replaced with an AsyncMock. The lookup
    may then end in an HTTPException, which is ignored: only the fact that Redis
    was queried exactly once is asserted.
    """
    from fastapi import HTTPException

    from litellm.caching.caching import DualCache, RedisCache
    from litellm.proxy.auth.auth_checks import get_team_object

    logging_obj: ProxyLogging = getattr(
        litellm.proxy.proxy_server, "proxy_logging_obj"
    )
    redis_backend = RedisCache()

    with patch.object(
        redis_backend, "async_get_cache", new=AsyncMock()
    ) as redis_get_mock:
        try:
            await get_team_object(
                team_id="1234",
                user_api_key_cache=DualCache(redis_cache=redis_backend),
                parent_otel_span=None,
                proxy_logging_obj=logging_obj,
                prisma_client=AsyncMock(),
            )
        except HTTPException:
            # A "team not found" style failure is acceptable for this check.
            pass

        redis_get_mock.assert_called_once()
|
|
|
|
|
|
import random
|
|
from litellm._uuid import uuid
|
|
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
|
|
|
|
from litellm.proxy._types import (
|
|
LitellmUserRoles,
|
|
NewUserRequest,
|
|
TeamMemberAddRequest,
|
|
UserAPIKeyAuth,
|
|
)
|
|
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
|
|
from litellm.proxy.management_endpoints.team_endpoints import team_member_add
|
|
from test_key_generate_prisma import prisma_client
|
|
|
|
|
|
@pytest.mark.parametrize(
    "user_role",
    [LitellmUserRoles.INTERNAL_USER.value, LitellmUserRoles.PROXY_ADMIN.value],
)
@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_create_user_default_budget(prisma_client, user_role):
    """
    Check the default budget that `new_user` applies for each user role.

    With `litellm.max_internal_user_budget` and
    `litellm.internal_user_budget_duration` set, an INTERNAL_USER is expected to
    be inserted with those values, and a PROXY_ADMIN with `max_budget` and
    `budget_duration` of None. `prisma_client.insert_data` is mocked, so the
    assertions inspect the `data` kwarg it was called with.

    Both litellm module settings are restored afterwards so they cannot leak
    into other tests.
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")

    saved_max_budget = getattr(litellm, "max_internal_user_budget", None)
    saved_budget_duration = getattr(litellm, "internal_user_budget_duration", None)
    setattr(litellm, "max_internal_user_budget", 10)
    setattr(litellm, "internal_user_budget_duration", "5m")
    try:
        await litellm.proxy.proxy_server.prisma_client.connect()
        user = f"ishaan {uuid.uuid4().hex}"
        # create a user without specifying a budget
        request = NewUserRequest(user_id=user, user_role=user_role)
        with patch.object(
            litellm.proxy.proxy_server.prisma_client, "insert_data", new=AsyncMock()
        ) as mock_client:
            await new_user(request)

            mock_client.assert_called()

            print(f"mock_client.call_args: {mock_client.call_args}")
            print(
                "mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs)
            )

            inserted = mock_client.call_args.kwargs["data"]
            if user_role == LitellmUserRoles.INTERNAL_USER.value:
                assert inserted["max_budget"] == litellm.max_internal_user_budget
                assert (
                    inserted["budget_duration"]
                    == litellm.internal_user_budget_duration
                )
            else:
                assert inserted["max_budget"] is None
                assert inserted["budget_duration"] is None
    finally:
        setattr(litellm, "max_internal_user_budget", saved_max_budget)
        setattr(litellm, "internal_user_budget_duration", saved_budget_duration)
|
|
|
|
|
|
@pytest.mark.parametrize("new_member_method", ["user_id", "user_email"])
@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_create_team_member_add(prisma_client, new_member_method):
    """
    Check that `team_member_add` upserts a new member with the default
    internal-user budget.

    The member is identified either by `user_id` or by `user_email`, depending
    on `new_member_method`. The user table, team lookup and `get_data` are
    mocked. The real `litellm_teamtable` accessor is swapped for a mock and put
    back in a `finally` block, so a failing assertion cannot leak the mock into
    later tests.
    """
    import time

    from litellm.proxy._types import LiteLLM_TeamTableCachedObj, LiteLLM_UserTable
    from litellm.proxy.proxy_server import user_api_key_cache

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm, "max_internal_user_budget", 10)
    setattr(litellm, "internal_user_budget_duration", "5m")
    await litellm.proxy.proxy_server.prisma_client.connect()
    user = f"ishaan {uuid.uuid4().hex}"
    _team_id = "litellm-test-client-id-new"
    team_obj = LiteLLM_TeamTableCachedObj(
        team_id=_team_id,
        blocked=False,
        last_refreshed_at=time.time(),
        metadata={"guardrails": {"modify_guardrails": False}},
    )
    user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)

    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)

    # `new_member_method` names the field ("user_id" or "user_email") that
    # identifies the new member.
    data = {
        "team_id": _team_id,
        "member": [{"role": "user", new_member_method: user}],
    }
    team_member_add_request = TeamMemberAddRequest(**data)

    with (
        patch(
            "litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
            new_callable=AsyncMock,
        ) as mock_litellm_usertable,
        patch(
            "litellm.proxy.auth.auth_checks._get_team_object_from_user_api_key_cache",
            new=AsyncMock(return_value=team_obj),
        ),
        patch(
            "litellm.proxy.proxy_server.prisma_client.get_data",
            new=AsyncMock(return_value=[]),
        ),
    ):
        mock_client = AsyncMock(
            return_value=LiteLLM_UserTable(
                user_id="1234", max_budget=100, user_email="1234"
            )
        )
        mock_litellm_usertable.upsert = mock_client
        mock_litellm_usertable.find_many = AsyncMock(return_value=None)
        # Mock find_first for user_email validation (returns None for new users)
        mock_litellm_usertable.find_first = AsyncMock(return_value=None)
        # Mock find_unique for user_id validation (returns None for new users)
        mock_litellm_usertable.find_unique = AsyncMock(return_value=None)

        team_mock_client = AsyncMock()
        team_mock_client.update = AsyncMock(
            return_value=LiteLLM_TeamTableCachedObj(team_id="1234")
        )

        prisma_db = litellm.proxy.proxy_server.prisma_client.db
        original_val = getattr(prisma_db, "litellm_teamtable")
        prisma_db.litellm_teamtable = team_mock_client
        try:
            print(f"team_member_add_request={team_member_add_request}")
            await team_member_add(
                data=team_member_add_request,
                user_api_key_dict=UserAPIKeyAuth(user_role="proxy_admin"),
            )

            mock_client.assert_called()

            print(f"mock_client.call_args: {mock_client.call_args}")
            print(
                "mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs)
            )

            created = mock_client.call_args.kwargs["data"]["create"]
            assert created["max_budget"] == litellm.max_internal_user_budget
            assert (
                created["budget_duration"] == litellm.internal_user_budget_duration
            )
        finally:
            # Always put the real team table back, even if an assertion failed.
            prisma_db.litellm_teamtable = original_val
|
|
|
|
|
|
@pytest.mark.parametrize("team_member_role", ["admin", "user"])
@pytest.mark.parametrize("team_route", ["/team/member_add", "/team/member_delete"])
@pytest.mark.asyncio
async def test_create_team_member_add_team_admin_user_api_key_auth(
    prisma_client, team_member_role, team_route
):
    """
    `user_api_key_auth` must let a team-scoped key through on team member routes.

    A virtual key bound to a team, with team member role "admin" or "user", is
    seeded into `user_api_key_cache` together with its cached team object. The
    auth call for `/team/member_add` or `/team/member_delete` is then expected to
    complete without raising.
    """
    import json
    import time

    from fastapi import Request
    from starlette.datastructures import URL

    from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member
    from litellm.proxy.proxy_server import (
        ProxyException,
        hash_token,
        user_api_key_auth,
        user_api_key_cache,
    )

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm, "max_internal_user_budget", 10)
    setattr(litellm, "internal_user_budget_duration", "5m")
    await litellm.proxy.proxy_server.prisma_client.connect()

    user = f"ishaan {uuid.uuid4().hex}"
    _team_id = "litellm-test-client-id-new"
    user_key = "sk-12345678"
    hashed_key = hash_token(user_key)

    # Seed the key-auth cache with a key belonging to the team.
    user_api_key_cache.set_cache(
        key=hashed_key,
        value=UserAPIKeyAuth(
            team_id=_team_id,
            token=hashed_key,
            team_member=Member(role=team_member_role, user_id=user),
            last_refreshed_at=time.time(),
        ),
    )

    # Seed the team object so auth does not need to hit the DB for it.
    user_api_key_cache.set_cache(
        key="team_id:{}".format(_team_id),
        value=LiteLLM_TeamTableCachedObj(
            team_id=_team_id,
            blocked=False,
            last_refreshed_at=time.time(),
            metadata={"guardrails": {"modify_guardrails": False}},
        ),
    )
    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)

    # Hand-built request for the route under test, with an empty JSON body.
    request = Request(scope={"type": "http"})
    request._url = URL(url=team_route)
    request._body = json.dumps({}).encode("utf-8")

    # The test passes as long as auth does not raise.
    await user_api_key_auth(request=request, api_key="Bearer " + user_key)
|
|
|
|
|
|
@pytest.mark.parametrize("new_member_method", ["user_id", "user_email"])
@pytest.mark.parametrize("user_role", ["admin", "user"])
@pytest.mark.asyncio
async def test_create_team_member_add_team_admin(
    prisma_client, new_member_method, user_role
):
    """
    Relevant issue - https://github.com/BerriAI/litellm/issues/5300

    Allow team admins to:
    - Add and remove team members
    - raise error if team member not an existing 'internal_user'

    The calling key belongs to a user whose team role is `user_role`.
    - "user": `team_member_add` must be rejected with an HTTP 403.
    - "admin": the new member must be upserted with the default internal-user
      budget.

    The mocked `litellm_teamtable` is always restored in a `finally` block. The
    previous early `return` on the 403 path skipped that restore and leaked the
    mock into later tests.
    """
    import time

    from litellm.proxy._types import (
        LiteLLM_TeamTableCachedObj,
        LiteLLM_UserTable,
        Member,
    )
    from litellm.proxy.proxy_server import (
        HTTPException,
        hash_token,
        user_api_key_cache,
    )

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm, "max_internal_user_budget", 10)
    setattr(litellm, "internal_user_budget_duration", "5m")
    await litellm.proxy.proxy_server.prisma_client.connect()
    user = f"ishaan {uuid.uuid4().hex}"
    _team_id = "litellm-test-client-id-new"
    user_key = "sk-12345678"
    team_admin = f"krrish {uuid.uuid4().hex}"

    valid_token = UserAPIKeyAuth(
        team_id=_team_id,
        user_id=team_admin,
        token=hash_token(user_key),
        last_refreshed_at=time.time(),
    )
    user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)

    team_obj = LiteLLM_TeamTableCachedObj(
        team_id=_team_id,
        blocked=False,
        last_refreshed_at=time.time(),
        members_with_roles=[Member(role=user_role, user_id=team_admin)],
        metadata={"guardrails": {"modify_guardrails": False}},
    )
    user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)

    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)

    # `new_member_method` names the field ("user_id" or "user_email") that
    # identifies the new member.
    data = {
        "team_id": _team_id,
        "member": [{"role": "user", new_member_method: user}],
    }
    team_member_add_request = TeamMemberAddRequest(**data)

    with (
        patch(
            "litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
            new_callable=AsyncMock,
        ) as mock_litellm_usertable,
        patch(
            "litellm.proxy.auth.auth_checks._get_team_object_from_user_api_key_cache",
            new=AsyncMock(return_value=team_obj),
        ),
        patch(
            "litellm.proxy.proxy_server.prisma_client.get_data",
            new=AsyncMock(return_value=[]),
        ),
    ):
        mock_client = AsyncMock(
            return_value=LiteLLM_UserTable(
                user_id="1234", max_budget=100, user_email="1234"
            )
        )
        mock_litellm_usertable.upsert = mock_client
        mock_litellm_usertable.find_many = AsyncMock(return_value=None)
        # Mock find_first for user_email validation (returns None for new users)
        mock_litellm_usertable.find_first = AsyncMock(return_value=None)
        # Mock find_unique for user_id validation (returns None for new users)
        mock_litellm_usertable.find_unique = AsyncMock(return_value=None)

        team_mock_client = AsyncMock()
        team_mock_client.update = AsyncMock(
            return_value=LiteLLM_TeamTableCachedObj(team_id="1234")
        )

        prisma_db = litellm.proxy.proxy_server.prisma_client.db
        original_val = getattr(prisma_db, "litellm_teamtable")
        prisma_db.litellm_teamtable = team_mock_client
        try:
            if user_role == "user":
                # A non-admin team member must not be able to add members.
                with pytest.raises(HTTPException) as exc_info:
                    await team_member_add(
                        data=team_member_add_request,
                        user_api_key_dict=valid_token,
                    )
                assert exc_info.value.status_code == 403
            else:
                await team_member_add(
                    data=team_member_add_request,
                    user_api_key_dict=valid_token,
                )

                mock_client.assert_called()

                print(f"mock_client.call_args: {mock_client.call_args}")
                print(
                    "mock_client.call_args.kwargs: {}".format(
                        mock_client.call_args.kwargs
                    )
                )

                created = mock_client.call_args.kwargs["data"]["create"]
                assert created["max_budget"] == litellm.max_internal_user_budget
                assert (
                    created["budget_duration"]
                    == litellm.internal_user_budget_duration
                )
        finally:
            # Restore on every path so the mock cannot leak into other tests.
            prisma_db.litellm_teamtable = original_val
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_user_info_team_list(prisma_client):
    """Assert that `user_info` for a proxy admin calls `list_team`.

    `list_team` is patched with an AsyncMock. `prisma_client.get_data` is
    patched for the duration of the test (and restored afterwards) to return a
    proxy-admin user row.
    """
    from litellm.proxy._types import LiteLLM_UserTable

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    from litellm.proxy.management_endpoints.internal_user_endpoints import user_info

    admin_row = LiteLLM_UserTable(
        user_role="proxy_admin",
        user_id="default_user_id",
        max_budget=None,
        user_email="",
    )

    with (
        patch(
            "litellm.proxy.management_endpoints.team_endpoints.list_team",
            new_callable=AsyncMock,
        ) as mock_client,
        patch.object(
            prisma_client, "get_data", new=AsyncMock(return_value=admin_row)
        ),
    ):
        try:
            await user_info(
                request=MagicMock(),
                user_id=None,
                user_api_key_dict=UserAPIKeyAuth(
                    api_key="sk-1234", user_id="default_user_id"
                ),
            )
        except Exception:
            # Deliberately best-effort: parts of user_info past the list_team
            # call may fail against these mocks; only that call is asserted.
            pass

        mock_client.assert_called()
|
|
|
|
|
|
@pytest.mark.skip(reason="Local test")
@pytest.mark.asyncio
async def test_add_callback_via_key(prisma_client):
    """
    Test if callback specified in key, is used.

    The key's metadata requests a langfuse success callback. `LangFuseLogger`
    is replaced with a MagicMock, and the kwargs passed to its `log_event` are
    checked to carry the key's `logging` config. Every `*key*` callback var must
    still hold its `os.environ/...` reference.

    Exceptions propagate naturally so pytest reports the real traceback.
    `litellm.set_verbose` is restored on exit.
    """
    import json

    from fastapi import Request, Response
    from starlette.datastructures import URL

    from litellm.proxy.proxy_server import chat_completion

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    original_set_verbose = litellm.set_verbose
    litellm.set_verbose = True
    try:
        test_data = {
            "model": "azure/gpt-4.1-mini",
            "messages": [
                {"role": "user", "content": "write 1 sentence poem"},
            ],
            "max_tokens": 10,
            "mock_response": "Hello world",
            "api_key": "my-fake-key",
        }

        request = Request(scope={"type": "http", "method": "POST", "headers": {}})
        request._url = URL(url="/chat/completions")
        request._body = json.dumps(test_data).encode("utf-8")

        with patch.object(
            litellm.litellm_core_utils.litellm_logging,
            "LangFuseLogger",
            new=MagicMock(),
        ) as mock_client:
            resp = await chat_completion(
                request=request,
                fastapi_response=Response(),
                user_api_key_dict=UserAPIKeyAuth(
                    metadata={
                        "allow_client_mock_response": True,
                        "logging": [
                            {
                                "callback_name": "langfuse",  # 'otel', 'langfuse', 'lunary'
                                "callback_type": "success",  # set, if required by integration - future improvement, have logging tools work for success + failure by default
                                "callback_vars": {
                                    "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
                                    "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
                                    "langfuse_host": "https://us.cloud.langfuse.com",
                                },
                            }
                        ],
                    }
                ),
            )
            print(resp)
            mock_client.assert_called()

            log_event = mock_client.return_value.log_event
            log_event.assert_called()
            _, call_kwargs = log_event.call_args
            logged = call_kwargs["kwargs"]
            metadata = logged["litellm_params"]["metadata"]
            assert "user_api_key_metadata" in metadata
            assert "logging" in metadata["user_api_key_metadata"]

            checked_keys = False
            for item in metadata["user_api_key_metadata"]["logging"]:
                for k, v in item["callback_vars"].items():
                    print("k={}, v={}".format(k, v))
                    if "key" in k:
                        # The logged config keeps the env reference, not the secret.
                        assert "os.environ" in v
                        checked_keys = True

            assert checked_keys
    finally:
        litellm.set_verbose = original_set_verbose
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "callback_type, expected_success_callbacks, expected_failure_callbacks",
    [
        ("success", ["langfuse"], []),
        ("failure", [], ["langfuse"]),
        ("success_and_failure", ["langfuse"], ["langfuse"]),
    ],
)
async def test_add_callback_via_key_litellm_pre_call_utils(
    prisma_client, callback_type, expected_success_callbacks, expected_failure_callbacks
):
    """
    A key-level langfuse `logging` entry must be applied to the request data.

    `add_litellm_data_to_request` is expected to expose the key's langfuse
    public and secret keys as top-level request params. It must also register
    langfuse as a success and/or failure callback, according to
    `callback_type`.
    """
    import json

    from fastapi import Request
    from starlette.datastructures import URL

    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")

    request_payload = {
        "model": "azure/gpt-4.1-mini",
        "messages": [{"role": "user", "content": "write 1 sentence poem"}],
        "max_tokens": 10,
        "mock_response": "Hello world",
        "api_key": "my-fake-key",
    }
    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
    request._url = URL(url="/chat/completions")
    request._body = json.dumps(request_payload).encode("utf-8")

    key_logging_config = [
        {
            "callback_name": "langfuse",
            "callback_type": callback_type,
            "callback_vars": {
                "langfuse_public_key": "my-mock-public-key",
                "langfuse_secret_key": "my-mock-secret-key",
                "langfuse_host": "https://us.cloud.langfuse.com",
            },
        }
    ]
    key_auth = UserAPIKeyAuth(
        token=None,
        key_name=None,
        key_alias=None,
        spend=0.0,
        max_budget=None,
        expires=None,
        models=[],
        aliases={},
        config={},
        user_id=None,
        team_id=None,
        max_parallel_requests=None,
        metadata={
            "allow_client_mock_response": True,
            "logging": key_logging_config,
        },
        tpm_limit=None,
        rpm_limit=None,
        budget_duration=None,
        budget_reset_at=None,
        allowed_cache_controls=[],
        permissions={},
        model_spend={},
        model_max_budget={},
        soft_budget_cooldown=False,
        litellm_budget_table=None,
        org_id=None,
        team_spend=None,
        team_alias=None,
        team_tpm_limit=None,
        team_rpm_limit=None,
        team_max_budget=None,
        team_models=[],
        team_blocked=False,
        soft_budget=None,
        team_model_aliases=None,
        team_member_spend=None,
        team_metadata=None,
        end_user_id=None,
        end_user_tpm_limit=None,
        end_user_rpm_limit=None,
        end_user_max_budget=None,
        last_refreshed_at=None,
        api_key=None,
        user_role=None,
        allowed_model_region=None,
        parent_otel_span=None,
    )

    new_data = await add_litellm_data_to_request(
        data=request_payload,
        request=request,
        user_api_key_dict=key_auth,
        proxy_config=proxy_config,
        general_settings={},
        version="0.0.0",
    )
    print("NEW DATA: {}".format(new_data))

    for param_name, param_value in (
        ("langfuse_public_key", "my-mock-public-key"),
        ("langfuse_secret_key", "my-mock-secret-key"),
    ):
        assert param_name in new_data
        assert new_data[param_name] == param_value

    if expected_success_callbacks:
        assert "success_callback" in new_data
        assert new_data["success_callback"] == expected_success_callbacks

    if expected_failure_callbacks:
        assert "failure_callback" in new_data
        assert new_data["failure_callback"] == expected_failure_callbacks
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("disable_fallbacks_set", [True, False])
async def test_disable_fallbacks_by_key(disable_fallbacks_set):
    """A key's `disable_fallbacks` metadata flag must be copied onto the request."""
    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup

    request_data = {
        "model": "azure/gpt-4.1-mini",
        "messages": [{"role": "user", "content": "write 1 sentence poem"}],
    }
    updated = LiteLLMProxyRequestSetup.add_key_level_controls(
        key_metadata={"disable_fallbacks": disable_fallbacks_set},
        data=request_data,
        _metadata_variable_name="metadata",
    )

    assert updated["disable_fallbacks"] == disable_fallbacks_set
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "callback_type, expected_success_callbacks, expected_failure_callbacks",
    [
        ("success", ["gcs_bucket"], []),
        ("failure", [], ["gcs_bucket"]),
        ("success_and_failure", ["gcs_bucket"], ["gcs_bucket"]),
    ],
)
async def test_add_callback_via_key_litellm_pre_call_utils_gcs_bucket(
    prisma_client, callback_type, expected_success_callbacks, expected_failure_callbacks
):
    """
    A key-level gcs_bucket `logging` entry must be applied to the request data.

    `add_litellm_data_to_request` is expected to copy the key's bucket name and
    service-account path into the request params. It must also register
    gcs_bucket as a success and/or failure callback, according to
    `callback_type`.
    """
    import json

    from fastapi import Request
    from starlette.datastructures import URL

    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")

    request_payload = {
        "model": "azure/gpt-4.1-mini",
        "messages": [{"role": "user", "content": "write 1 sentence poem"}],
        "max_tokens": 10,
        "mock_response": "Hello world",
        "api_key": "my-fake-key",
    }
    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
    request._url = URL(url="/chat/completions")
    request._body = json.dumps(request_payload).encode("utf-8")

    gcs_callback_vars = {
        "gcs_bucket_name": "key-logging-project1",
        "gcs_path_service_account": "pathrise-convert-1606954137718-a956eef1a2a8.json",
    }
    # Snapshot the expected values before the call, in case the call mutates
    # the metadata it is given.
    expected_params = dict(gcs_callback_vars)

    key_auth = UserAPIKeyAuth(
        token=None,
        key_name=None,
        key_alias=None,
        spend=0.0,
        max_budget=None,
        expires=None,
        models=[],
        aliases={},
        config={},
        user_id=None,
        team_id=None,
        max_parallel_requests=None,
        metadata={
            "allow_client_mock_response": True,
            "logging": [
                {
                    "callback_name": "gcs_bucket",
                    "callback_type": callback_type,
                    "callback_vars": gcs_callback_vars,
                }
            ],
        },
        tpm_limit=None,
        rpm_limit=None,
        budget_duration=None,
        budget_reset_at=None,
        allowed_cache_controls=[],
        permissions={},
        model_spend={},
        model_max_budget={},
        soft_budget_cooldown=False,
        litellm_budget_table=None,
        org_id=None,
        team_spend=None,
        team_alias=None,
        team_tpm_limit=None,
        team_rpm_limit=None,
        team_max_budget=None,
        team_models=[],
        team_blocked=False,
        soft_budget=None,
        team_model_aliases=None,
        team_member_spend=None,
        team_metadata=None,
        end_user_id=None,
        end_user_tpm_limit=None,
        end_user_rpm_limit=None,
        end_user_max_budget=None,
        last_refreshed_at=None,
        api_key=None,
        user_role=None,
        allowed_model_region=None,
        parent_otel_span=None,
    )

    new_data = await add_litellm_data_to_request(
        data=request_payload,
        request=request,
        user_api_key_dict=key_auth,
        proxy_config=proxy_config,
        general_settings={},
        version="0.0.0",
    )
    print("NEW DATA: {}".format(new_data))

    for param_name, param_value in expected_params.items():
        assert param_name in new_data
        assert new_data[param_name] == param_value

    if expected_success_callbacks:
        assert "success_callback" in new_data
        assert new_data["success_callback"] == expected_success_callbacks

    if expected_failure_callbacks:
        assert "failure_callback" in new_data
        assert new_data["failure_callback"] == expected_failure_callbacks
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "callback_type, expected_success_callbacks, expected_failure_callbacks",
    [
        ("success", ["langsmith"], []),
        ("failure", [], ["langsmith"]),
        ("success_and_failure", ["langsmith"], ["langsmith"]),
    ],
)
async def test_add_callback_via_key_litellm_pre_call_utils_langsmith(
    prisma_client, callback_type, expected_success_callbacks, expected_failure_callbacks
):
    """
    A key-level langsmith `logging` entry must be applied to the request data.

    `add_litellm_data_to_request` is expected to copy the key's langsmith API
    key, project and base URL into the request params. It must also register
    langsmith as a success and/or failure callback, according to
    `callback_type`.
    """
    import json

    from fastapi import Request
    from starlette.datastructures import URL

    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")

    request_payload = {
        "model": "azure/gpt-4.1-mini",
        "messages": [{"role": "user", "content": "write 1 sentence poem"}],
        "max_tokens": 10,
        "mock_response": "Hello world",
        "api_key": "my-fake-key",
    }
    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
    request._url = URL(url="/chat/completions")
    request._body = json.dumps(request_payload).encode("utf-8")

    langsmith_callback_vars = {
        "langsmith_api_key": "ls-1234",
        "langsmith_project": "pr-brief-resemblance-72",
        "langsmith_base_url": "https://api.smith.langchain.com",
    }
    # Snapshot the expected values before the call, in case the call mutates
    # the metadata it is given.
    expected_params = dict(langsmith_callback_vars)

    key_auth = UserAPIKeyAuth(
        token=None,
        key_name=None,
        key_alias=None,
        spend=0.0,
        max_budget=None,
        expires=None,
        models=[],
        aliases={},
        config={},
        user_id=None,
        team_id=None,
        max_parallel_requests=None,
        metadata={
            "allow_client_mock_response": True,
            "logging": [
                {
                    "callback_name": "langsmith",
                    "callback_type": callback_type,
                    "callback_vars": langsmith_callback_vars,
                }
            ],
        },
        tpm_limit=None,
        rpm_limit=None,
        budget_duration=None,
        budget_reset_at=None,
        allowed_cache_controls=[],
        permissions={},
        model_spend={},
        model_max_budget={},
        soft_budget_cooldown=False,
        litellm_budget_table=None,
        org_id=None,
        team_spend=None,
        team_alias=None,
        team_tpm_limit=None,
        team_rpm_limit=None,
        team_max_budget=None,
        team_models=[],
        team_blocked=False,
        soft_budget=None,
        team_model_aliases=None,
        team_member_spend=None,
        team_metadata=None,
        end_user_id=None,
        end_user_tpm_limit=None,
        end_user_rpm_limit=None,
        end_user_max_budget=None,
        last_refreshed_at=None,
        api_key=None,
        user_role=None,
        allowed_model_region=None,
        parent_otel_span=None,
    )

    new_data = await add_litellm_data_to_request(
        data=request_payload,
        request=request,
        user_api_key_dict=key_auth,
        proxy_config=proxy_config,
        general_settings={},
        version="0.0.0",
    )
    print("NEW DATA: {}".format(new_data))

    for param_name, param_value in expected_params.items():
        assert param_name in new_data
        assert new_data[param_name] == param_value

    if expected_success_callbacks:
        assert "success_callback" in new_data
        assert new_data["success_callback"] == expected_success_callbacks

    if expected_failure_callbacks:
        assert "failure_callback" in new_data
        assert new_data["failure_callback"] == expected_failure_callbacks
|
|
|
|
|
|
@pytest.mark.skipif(
    not (os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")),
    reason="Requires GEMINI_API_KEY or GOOGLE_API_KEY.",
)
@pytest.mark.asyncio
async def test_gemini_pass_through_endpoint():
    """
    Smoke-test the Gemini pass-through route with a countTokens request.

    A raw ASGI request is built by hand, with a JSON body and the
    `key=sk-1234` query string, and sent through `gemini_proxy_route`. The
    response body is printed but not asserted.
    """
    from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
        Request,
        Response,
        gemini_proxy_route,
    )

    body = b"""
    {
        "contents": [{
            "parts":[{
                "text": "The quick brown fox jumps over the lazy dog."
            }]
        }]
    }
    """

    async def async_receive():
        # Deliver the whole body in a single ASGI message.
        return {"type": "http.request", "body": body, "more_body": False}

    asgi_scope = {
        "type": "http",
        "method": "POST",
        "path": "/gemini/v1beta/models/gemini-2.5-flash:countTokens",
        "query_string": b"key=sk-1234",
        "headers": [(b"content-type", b"application/json")],
    }
    request = Request(scope=asgi_scope, receive=async_receive)

    resp = await gemini_proxy_route(
        endpoint="v1beta/models/gemini-2.5-flash:countTokens?key=sk-1234",
        request=request,
        fastapi_response=Response(),
    )

    print(resp.body)
|
|
|
|
|
|
@pytest.mark.parametrize("hidden", [True, False])
@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_proxy_model_group_alias_checks(prisma_client, hidden):
    """
    Check if model group alias is returned on

    `/v1/models`
    `/v1/model/info`
    `/v1/model_group/info`

    A hidden alias must be absent from all three responses, and a visible one
    must be present. The router and model list are patched onto the proxy
    module only for the duration of the test, so they do not leak into other
    tests.
    """
    from litellm.proxy.proxy_server import model_group_info, model_info_v1, model_list

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    _model_list = [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ]
    model_alias = "gpt-4"
    router = litellm.Router(
        model_list=_model_list,
        model_group_alias={model_alias: {"model": "gpt-3.5-turbo", "hidden": hidden}},
    )

    with (
        patch.object(litellm.proxy.proxy_server, "llm_router", router),
        patch.object(litellm.proxy.proxy_server, "llm_model_list", _model_list),
    ):
        # /v1/models: a hidden alias is not listed alongside the real model.
        resp = await model_list(
            user_api_key_dict=UserAPIKeyAuth(models=[]),
        )
        assert len(resp["data"]) == (1 if hidden else 2)
        print(resp)

        # /v1/model/info
        resp = await model_info_v1(
            user_api_key_dict=UserAPIKeyAuth(models=[]),
        )
        alias_in_model_info = any(
            item["model_name"] == model_alias for item in resp["data"]
        )
        assert alias_in_model_info is (not hidden)

        # /v1/model_group/info
        resp = await model_group_info(
            user_api_key_dict=UserAPIKeyAuth(models=[]),
        )
        print(f"resp: {resp}")
        models = resp["data"]
        print(f"model_alias: {model_alias}, models: {models}")
        alias_in_group_info = any(item.model_group == model_alias for item in models)
        assert alias_in_group_info is (not hidden), f"models: {models}"
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_proxy_model_group_info_rerank(prisma_client):
    """
    Check if rerank model is returned on the following endpoints

    `/v1/models`
    `/v1/model/info`
    `/v1/model_group/info`

    and that the ``mode: rerank`` set in ``model_info`` is reported by both
    info endpoints.
    """
    from litellm.proxy.proxy_server import model_group_info, model_info_v1, model_list

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    _model_list = [
        {
            "model_name": "rerank-english-v3.0",
            "litellm_params": {"model": "cohere/rerank-english-v3.0"},
            "model_info": {
                "mode": "rerank",
            },
        }
    ]
    router = litellm.Router(model_list=_model_list)
    setattr(litellm.proxy.proxy_server, "llm_router", router)
    setattr(litellm.proxy.proxy_server, "llm_model_list", _model_list)

    # /v1/models
    resp = await model_list(
        user_api_key_dict=UserAPIKeyAuth(models=[]),
    )
    assert len(resp["data"]) == 1
    print(resp)

    # /v1/model/info: entries are dicts
    resp = await model_info_v1(
        user_api_key_dict=UserAPIKeyAuth(models=[]),
    )
    models = resp["data"]
    assert models[0]["model_info"]["mode"] == "rerank"

    # /v1/model_group/info: entries expose attributes
    resp = await model_group_info(
        user_api_key_dict=UserAPIKeyAuth(models=[]),
    )
    print(resp)
    models = resp["data"]
    assert models[0].mode == "rerank"
|
|
|
|
|
|
# @pytest.mark.asyncio
|
|
# async def test_proxy_team_member_add(prisma_client):
|
|
# """
|
|
# Add 10 people to a team. Confirm all 10 are added.
|
|
# """
|
|
# from litellm.proxy.management_endpoints.team_endpoints import (
|
|
# team_member_add,
|
|
# new_team,
|
|
# )
|
|
# from litellm.proxy._types import TeamMemberAddRequest, Member, NewTeamRequest
|
|
|
|
# setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
|
# setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
|
# try:
|
|
|
|
# async def test():
|
|
# await litellm.proxy.proxy_server.prisma_client.connect()
|
|
# from litellm.proxy.proxy_server import user_api_key_cache
|
|
|
|
# user_api_key_dict = UserAPIKeyAuth(
|
|
# user_role=LitellmUserRoles.PROXY_ADMIN,
|
|
# api_key="sk-1234",
|
|
# user_id="1234",
|
|
# )
|
|
|
|
# new_team()
|
|
# for _ in range(10):
|
|
# request = TeamMemberAddRequest(
|
|
# team_id="1234",
|
|
# member=Member(
|
|
# user_id="1234",
|
|
# user_role=LitellmUserRoles.INTERNAL_USER,
|
|
# ),
|
|
# )
|
|
# key = await team_member_add(
|
|
# request, user_api_key_dict=user_api_key_dict
|
|
# )
|
|
|
|
# print(key)
|
|
# user_id = key.user_id
|
|
|
|
# # check /user/info to verify user_role was set correctly
|
|
# new_user_info = await user_info(
|
|
# user_id=user_id, user_api_key_dict=user_api_key_dict
|
|
# )
|
|
# new_user_info = new_user_info.user_info
|
|
# print("new_user_info=", new_user_info)
|
|
# assert new_user_info["user_role"] == LitellmUserRoles.INTERNAL_USER
|
|
# assert new_user_info["user_id"] == user_id
|
|
|
|
# generated_key = key.key
|
|
# bearer_token = "Bearer " + generated_key
|
|
|
|
# assert generated_key not in user_api_key_cache.in_memory_cache.cache_dict
|
|
|
|
# value_from_prisma = await prisma_client.get_data(
|
|
# token=generated_key,
|
|
# )
|
|
# print("token from prisma", value_from_prisma)
|
|
|
|
# request = Request(
|
|
# {
|
|
# "type": "http",
|
|
# "route": api_route,
|
|
# "path": api_route.path,
|
|
# "headers": [("Authorization", bearer_token)],
|
|
# }
|
|
# )
|
|
|
|
# # use generated key to auth in
|
|
# result = await user_api_key_auth(request=request, api_key=bearer_token)
|
|
# print("result from user auth with new key", result)
|
|
|
|
# asyncio.run(test())
|
|
# except Exception as e:
|
|
# pytest.fail(f"An exception occurred - {str(e)}")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_proxy_server_prisma_setup():
    """
    Verify that ``ProxyStartupEvent._setup_prisma_client`` drives the DB client.

    ``PrismaClient`` is swapped for a ``MagicMock`` so no database is needed.
    The test then checks that connect, the view-existence check, the health
    check and the spend-log row counting were each called exactly once.
    """
    from litellm.proxy.proxy_server import ProxyStartupEvent, proxy_state
    from litellm.proxy.utils import ProxyLogging
    from litellm.caching import DualCache

    cache = DualCache()

    with patch.object(
        litellm.proxy.proxy_server, "PrismaClient", new=MagicMock()
    ) as prisma_cls_mock:
        # Object returned when the setup code instantiates PrismaClient(...)
        instance = prisma_cls_mock.return_value
        instance.connect = AsyncMock()
        instance.check_view_exists = AsyncMock()
        instance.health_check = AsyncMock()
        instance._set_spend_logs_row_count_in_proxy_state = AsyncMock()
        instance.start_db_health_watchdog_task = AsyncMock()

        # ``db.start_token_refresh_task`` is awaited for RDS IAM token refresh.
        db_mock = MagicMock()
        db_mock.start_token_refresh_task = AsyncMock()
        instance.db = db_mock

        await ProxyStartupEvent._setup_prisma_client(
            database_url=os.getenv("DATABASE_URL"),
            proxy_logging_obj=ProxyLogging(user_api_key_cache=cache),
            user_api_key_cache=cache,
        )

        instance.connect.assert_called_once()
        instance.check_view_exists.assert_called_once()

        # Note: This is REALLY IMPORTANT to check that the health check is called.
        # This is how we ensure the DB is ready before proceeding.
        instance.health_check.assert_called_once()

        # The spend-logs row count must have been pushed into proxy state.
        instance._set_spend_logs_row_count_in_proxy_state.assert_called_once()
        assert proxy_state.get_proxy_state_variable("spend_logs_row_count") is not None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_proxy_server_prisma_setup_invalid_db():
    """
    PROD TEST: Test that proxy server startup fails when it's unable to connect to the database

    Think 2-3 times before editing / deleting this test, it's important for PROD

    DATABASE_URL is pointed at an unreachable database for the duration of the
    test. It is always restored afterwards (or removed if it was unset), even
    when an assertion fails, so later tests do not inherit the invalid URL.
    """
    from litellm.proxy.proxy_server import ProxyStartupEvent
    from litellm.proxy.utils import ProxyLogging
    from litellm.caching import DualCache

    user_api_key_cache = DualCache()
    invalid_db_url = "postgresql://invalid:invalid@localhost:5432/nonexistent"

    _old_db_url = os.getenv("DATABASE_URL")
    os.environ["DATABASE_URL"] = invalid_db_url

    try:
        with pytest.raises(Exception) as exc_info:
            await ProxyStartupEvent._setup_prisma_client(
                database_url=invalid_db_url,
                proxy_logging_obj=ProxyLogging(user_api_key_cache=user_api_key_cache),
                user_api_key_cache=user_api_key_cache,
            )
        print("GOT EXCEPTION=", exc_info)

        assert "httpx.ConnectError" in str(exc_info.value)

        # # Verify the error message indicates a database connection issue
        # assert any(x in str(exc_info.value).lower() for x in ["database", "connection", "authentication"])
    finally:
        # Restore the caller's environment exactly; an unset variable stays unset.
        if _old_db_url is not None:
            os.environ["DATABASE_URL"] = _old_db_url
        else:
            os.environ.pop("DATABASE_URL", None)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_ui_settings_spend_logs_threshold():
    """
    Test that get_ui_settings correctly sets DISABLE_EXPENSIVE_DB_QUERIES based on spend_logs_row_count threshold

    Expensive queries are expected to be disabled only when the row count is
    strictly greater than MAX_SPENDLOG_ROWS_TO_QUERY. The proxy-state row count
    is reset to 0 afterwards, even if an assertion fails.
    """
    from litellm.proxy.management_endpoints.ui_sso import get_ui_settings
    from litellm.proxy.proxy_server import proxy_state
    from fastapi import Request
    from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY

    # Create a mock request
    mock_request = Request(
        scope={
            "type": "http",
            "headers": [],
            "method": "GET",
            "scheme": "http",
            "server": ("testserver", 80),
            "path": "/sso/get/ui_settings",
            "query_string": b"",
        }
    )

    # (spend_logs_row_count, expected DISABLE_EXPENSIVE_DB_QUERIES)
    cases = [
        # Test case 1: When spend_logs_row_count > MAX_SPENDLOG_ROWS_TO_QUERY
        (MAX_SPENDLOG_ROWS_TO_QUERY + 1, True),
        # Test case 2: When spend_logs_row_count < MAX_SPENDLOG_ROWS_TO_QUERY
        (MAX_SPENDLOG_ROWS_TO_QUERY - 1, False),
        # Test case 3: Edge case - exactly MAX_SPENDLOG_ROWS_TO_QUERY
        (MAX_SPENDLOG_ROWS_TO_QUERY, False),
    ]

    try:
        for row_count, expected_disabled in cases:
            proxy_state.set_proxy_state_variable("spend_logs_row_count", row_count)
            response = await get_ui_settings(mock_request)
            print("response from get_ui_settings", json.dumps(response, indent=4))
            assert (
                response["DISABLE_EXPENSIVE_DB_QUERIES"] is expected_disabled
            ), f"row_count={row_count}"
            assert response["NUM_SPEND_LOGS_ROWS"] == row_count
    finally:
        # Clean up even on failure so later tests do not see a stale row count.
        proxy_state.set_proxy_state_variable("spend_logs_row_count", 0)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_background_health_check_reflects_llm_model_list(monkeypatch):
    """
    Test that _run_background_health_check reflects changes to llm_model_list in each health check iteration.

    ``asyncio.sleep`` is replaced by a stub that raises ``CancelledError``, so
    each invocation of the background loop is cut short after its health-check
    pass. ``llm_model_list`` is swapped between the two invocations, and the
    recorded inputs must show the new list being picked up.
    """
    import litellm.proxy.proxy_server as proxy_server
    import copy

    models_before = [{"model_name": "model-a"}]
    models_after = [{"model_name": "model-b"}]
    recorded = []

    async def fake_perform_health_check(model_list, details, max_concurrency=None):
        # Snapshot so later mutation of the list cannot change what we recorded.
        recorded.append(copy.deepcopy(model_list))
        return (["healthy"], ["unhealthy"], {})

    async def fake_sleep(interval):
        raise asyncio.CancelledError()

    async def run_one_pass():
        # The CancelledError raised by fake_sleep ends the loop; swallow it.
        try:
            await proxy_server._run_background_health_check()
        except asyncio.CancelledError:
            pass

    monkeypatch.setattr(proxy_server, "health_check_interval", 1)
    monkeypatch.setattr(proxy_server, "health_check_details", None)
    monkeypatch.setattr(proxy_server, "llm_model_list", copy.deepcopy(models_before))
    monkeypatch.setattr(proxy_server, "perform_health_check", fake_perform_health_check)
    monkeypatch.setattr(proxy_server, "health_check_results", {})
    monkeypatch.setattr(asyncio, "sleep", fake_sleep)

    await run_one_pass()

    monkeypatch.setattr(proxy_server, "llm_model_list", copy.deepcopy(models_after))

    await run_one_pass()

    assert len(recorded) >= 2
    assert recorded[0] == models_before
    assert recorded[1] == models_after
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_background_health_check_skip_disabled_models(monkeypatch):
    """Ensure models with disable_background_health_check are skipped.

    Only ``model-a`` may reach ``perform_health_check``. ``model-b`` sets
    ``model_info.disable_background_health_check`` and must be filtered out.
    """
    import litellm.proxy.proxy_server as proxy_server
    import copy

    enabled_model = {"model_name": "model-a"}
    disabled_model = {
        "model_name": "model-b",
        "model_info": {"disable_background_health_check": True},
    }
    calls = []

    async def fake_perform_health_check(
        model_list, details, max_concurrency=None, **kwargs
    ):
        calls.append(copy.deepcopy(model_list))
        return (["healthy"], [], {})

    async def fake_sleep(interval):
        # Stop the background loop after its first pass.
        raise asyncio.CancelledError()

    monkeypatch.setattr(proxy_server, "health_check_interval", 1)
    monkeypatch.setattr(proxy_server, "health_check_details", None)
    monkeypatch.setattr(
        proxy_server, "llm_model_list", copy.deepcopy([enabled_model, disabled_model])
    )
    monkeypatch.setattr(proxy_server, "perform_health_check", fake_perform_health_check)
    monkeypatch.setattr(proxy_server, "health_check_results", {})
    monkeypatch.setattr(asyncio, "sleep", fake_sleep)

    try:
        await proxy_server._run_background_health_check()
    except asyncio.CancelledError:
        pass

    assert calls == [[{"model_name": "model-a"}]]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_direct_health_check_with_instrumentation_legacy_three_arg_stub(
    monkeypatch,
):
    """Monkeypatched perform_health_check with only base kwargs should still run.

    The stub accepts none of the optional keyword arguments, so the call only
    succeeds if the helper drops the extra keywords when invoking it.
    """
    import litellm.proxy.proxy_server as proxy_server

    async def fake_perform_health_check(model_list, details, max_concurrency=None):
        return ([], [], {})

    monkeypatch.setattr(proxy_server, "perform_health_check", fake_perform_health_check)

    instrumentation = {"enabled": True, "source": "test", "cycle_id": "c1"}
    outcome = await proxy_server._run_direct_health_check_with_instrumentation(
        [{"model_name": "m"}], True, 1, instrumentation
    )

    assert outcome == ([], [], {})
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_direct_health_check_with_instrumentation_accepts_instrumentation_only(
    monkeypatch,
):
    """Stub that accepts instrumentation_context but not health_check filter kwargs.

    The stub must be called exactly once and must receive the instrumentation
    context that was handed to the helper.
    """
    import litellm.proxy.proxy_server as proxy_server

    received_contexts: list = []

    async def fake_perform_health_check(
        model_list, details, max_concurrency=None, instrumentation_context=None
    ):
        received_contexts.append(instrumentation_context)
        return ([], [], {})

    monkeypatch.setattr(proxy_server, "perform_health_check", fake_perform_health_check)

    context = {"enabled": True, "source": "test", "cycle_id": "c2"}
    await proxy_server._run_direct_health_check_with_instrumentation(
        [], False, 2, context
    )

    assert len(received_contexts) == 1
    assert received_contexts[0]["cycle_id"] == "c2"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_direct_health_check_with_instrumentation_accepts_filter_only(
    monkeypatch,
):
    """Stub that accepts health_check_skip_disabled_background_models but not instrumentation.

    The stub records the filter flag it receives. The test expects a single
    call carrying ``False``.
    """
    import litellm.proxy.proxy_server as proxy_server

    received_flags: list = []

    async def fake_perform_health_check(
        model_list,
        details,
        max_concurrency=None,
        health_check_skip_disabled_background_models=False,
    ):
        received_flags.append(health_check_skip_disabled_background_models)
        return ([], [], {})

    monkeypatch.setattr(proxy_server, "perform_health_check", fake_perform_health_check)

    disabled_instrumentation = {"enabled": False}
    await proxy_server._run_direct_health_check_with_instrumentation(
        [], True, None, disabled_instrumentation
    )

    assert len(received_flags) == 1
    assert received_flags[0] is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_direct_health_check_with_instrumentation_non_kw_typeerror_reraises(
    monkeypatch,
):
    """A TypeError unrelated to keyword arguments must propagate unchanged.

    The stub raises a TypeError whose message is not about an unexpected
    keyword, so the helper must not treat it as a signature mismatch and
    retry. The original error has to reach the caller.
    """
    import litellm.proxy.proxy_server as proxy_server

    async def fake_perform_health_check(**kwargs):
        raise TypeError("unsupported operand type(s)")

    monkeypatch.setattr(proxy_server, "perform_health_check", fake_perform_health_check)

    call = proxy_server._run_direct_health_check_with_instrumentation(
        [], True, 1, {}
    )
    with pytest.raises(TypeError, match="unsupported operand"):
        await call
|
|
|
|
|
|
def test_get_timeout_from_request():
    """The ``x-litellm-timeout`` header is parsed for integer-like and fractional values."""
    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup

    for header_value, expected_timeout in (("90", 90), ("90.5", 90.5)):
        headers = {"x-litellm-timeout": header_value}
        parsed = LiteLLMProxyRequestSetup._get_timeout_from_request(headers)
        assert parsed == expected_timeout
|
|
|
|
|
|
@pytest.mark.parametrize(
    "ui_exists, ui_has_content",
    [
        (True, True),  # UI path exists and has content
        (True, False),  # UI path exists but is empty
        (False, False),  # UI path doesn't exist
    ],
)
def test_non_root_ui_path_logic(monkeypatch, tmp_path, ui_exists, ui_has_content):
    """
    Test the non-root Docker UI path detection logic.

    Tests that when LITELLM_NON_ROOT is set to "true":
    - If UI path exists and has content, it should be used
    - If UI path doesn't exist or is empty, proper error logging occurs

    NOTE(review): the detection logic is re-implemented inline below because the
    real code is module-level in proxy_server.py. Keep this copy in sync with
    that code, otherwise the test only exercises itself.
    """
    from unittest.mock import MagicMock

    # Create a temporary directory to act as /tmp/litellm_ui
    test_ui_path = tmp_path / "litellm_ui"

    if ui_exists:
        test_ui_path.mkdir(parents=True, exist_ok=True)
        if ui_has_content:
            # Create some dummy files to simulate built UI
            (test_ui_path / "index.html").write_text("<html></html>")
            (test_ui_path / "app.js").write_text("console.log('test');")

    monkeypatch.setenv("LITELLM_NON_ROOT", "true")

    # Create a mock logger to capture log messages
    mock_logger = MagicMock()

    ui_path = None
    non_root_ui_path = str(test_ui_path)

    # Mirror of the LITELLM_NON_ROOT UI path selection in proxy_server.py
    if os.getenv("LITELLM_NON_ROOT", "").lower() == "true":
        if os.path.exists(non_root_ui_path) and os.listdir(non_root_ui_path):
            mock_logger.info(
                f"Using pre-built UI for non-root Docker: {non_root_ui_path}"
            )
            mock_logger.info(
                f"UI files found: {len(os.listdir(non_root_ui_path))} items"
            )
            ui_path = non_root_ui_path
        else:
            mock_logger.error(
                f"UI not found at {non_root_ui_path}. UI will not be available."
            )
            mock_logger.error(
                f"Path exists: {os.path.exists(non_root_ui_path)}, Has content: {os.path.exists(non_root_ui_path) and bool(os.listdir(non_root_ui_path))}"
            )

    # Verify behavior based on test parameters
    if ui_exists and ui_has_content:
        # UI should be found and used
        assert ui_path == non_root_ui_path
        assert mock_logger.info.call_count == 2
        mock_logger.info.assert_any_call(
            f"Using pre-built UI for non-root Docker: {non_root_ui_path}"
        )
        # Verify the second info call mentions the number of items
        info_calls = [call[0][0] for call in mock_logger.info.call_args_list]
        assert any("UI files found:" in call and "items" in call for call in info_calls)
        assert mock_logger.error.call_count == 0
    else:
        # UI should not be found, error should be logged
        assert ui_path is None
        assert mock_logger.error.call_count == 2
        mock_logger.error.assert_any_call(
            f"UI not found at {non_root_ui_path}. UI will not be available."
        )
        # Verify the second error call has path existence info
        error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
        assert any("Path exists:" in call for call in error_calls)
        assert mock_logger.info.call_count == 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_config_callbacks_with_all_types(client_no_auth):
    """
    Test that /get/config/callbacks returns all three callback types:
    - success_callback with type="success"
    - failure_callback with type="failure"
    - callbacks (success_and_failure) with type="success_and_failure"
    """
    # Create a mock config with all three callback types
    mock_config_data = {
        "litellm_settings": {
            "success_callback": ["langfuse", "braintrust"],
            "failure_callback": ["sentry"],
            "callbacks": ["otel", "langsmith"],
        },
        "environment_variables": {
            "LANGFUSE_PUBLIC_KEY": "test-public-key",
            "LANGFUSE_SECRET_KEY": "test-secret-key",
            "LANGFUSE_HOST": "https://test.langfuse.com",
            "BRAINTRUST_API_KEY": "test-braintrust-key",
            "OTEL_EXPORTER": "otlp",
            "OTEL_ENDPOINT": "http://localhost:4317",
            "LANGSMITH_API_KEY": "test-langsmith-key",
        },
        "general_settings": {},
    }

    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")

    # Serve the mocked config instead of the real one for this request.
    with patch.object(
        proxy_config, "get_config", new=AsyncMock(return_value=mock_config_data)
    ):
        response = client_no_auth.get("/get/config/callbacks")

    assert response.status_code == 200
    result = response.json()

    # Verify response structure
    assert "status" in result
    assert result["status"] == "success"
    assert "callbacks" in result

    callbacks = result["callbacks"]

    # Verify we have all 5 callbacks (2 success + 1 failure + 2 success_and_failure)
    assert len(callbacks) == 5

    # Group callbacks by type
    success_callbacks = [cb for cb in callbacks if cb.get("type") == "success"]
    failure_callbacks = [cb for cb in callbacks if cb.get("type") == "failure"]
    success_and_failure_callbacks = [
        cb for cb in callbacks if cb.get("type") == "success_and_failure"
    ]

    # Verify all callbacks have required fields
    for callback in callbacks:
        assert "name" in callback
        assert "variables" in callback
        assert "type" in callback
        assert callback["type"] in ["success", "failure", "success_and_failure"]

    # Verify success callbacks
    assert len(success_callbacks) == 2
    success_names = [cb["name"] for cb in success_callbacks]
    assert "langfuse" in success_names
    assert "braintrust" in success_names

    # Verify failure callbacks
    assert len(failure_callbacks) == 1
    assert failure_callbacks[0]["name"] == "sentry"

    # Verify success_and_failure callbacks
    assert len(success_and_failure_callbacks) == 2
    success_and_failure_names = [cb["name"] for cb in success_and_failure_callbacks]
    assert "otel" in success_and_failure_names
    assert "langsmith" in success_and_failure_names
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_config_callbacks_environment_variables(client_no_auth):
    """
    Test that /get/config/callbacks correctly includes environment variables
    for each callback type. Values are returned as-is from the config (no decryption).
    """
    # Create a mock config with callbacks and their env vars
    mock_config_data = {
        "litellm_settings": {
            "success_callback": ["langfuse"],
            "failure_callback": [],
            "callbacks": ["otel"],
        },
        "environment_variables": {
            "LANGFUSE_PUBLIC_KEY": "test-public-key",
            "LANGFUSE_SECRET_KEY": "test-secret-key",
            "LANGFUSE_HOST": "https://cloud.langfuse.com",
            "OTEL_EXPORTER": "otlp",
            "OTEL_ENDPOINT": "http://localhost:4317",
            "OTEL_HEADERS": "key=value",
        },
        "general_settings": {},
    }

    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")

    # Serve the mocked config instead of the real one for this request.
    with patch.object(
        proxy_config, "get_config", new=AsyncMock(return_value=mock_config_data)
    ):
        response = client_no_auth.get("/get/config/callbacks")

    assert response.status_code == 200
    result = response.json()

    callbacks = result["callbacks"]

    # callback name -> (expected type, expected variables returned as-is)
    expectations = {
        "langfuse": (
            "success",
            {
                "LANGFUSE_PUBLIC_KEY": "test-public-key",
                "LANGFUSE_SECRET_KEY": "test-secret-key",
                "LANGFUSE_HOST": "https://cloud.langfuse.com",
            },
        ),
        "otel": (
            "success_and_failure",
            {
                "OTEL_EXPORTER": "otlp",
                "OTEL_ENDPOINT": "http://localhost:4317",
                "OTEL_HEADERS": "key=value",
            },
        ),
    }

    for callback_name, (expected_type, expected_vars) in expectations.items():
        callback = next((cb for cb in callbacks if cb["name"] == callback_name), None)
        assert callback is not None, f"missing callback {callback_name}"
        assert callback["type"] == expected_type
        assert "variables" in callback

        actual_vars = callback["variables"]
        for var_name, var_value in expected_vars.items():
            assert var_name in actual_vars
            assert actual_vars[var_name] == var_value
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_update_config_success_callback_normalization(monkeypatch):
    """
    Ensure success_callback values are normalized to lowercase when updating config.
    This prevents delete_callback (which searches lowercase) from failing on mixed case inputs like 'SQS'.

    The proxy_server globals replaced here (proxy_logging_obj, prisma_client,
    proxy_config) are patched through ``monkeypatch`` so the real objects are
    restored after the test instead of leaking into later tests.
    """
    import litellm.proxy.proxy_server as proxy_server
    from litellm.proxy._types import ConfigYAML, LitellmUserRoles, UserAPIKeyAuth

    monkeypatch.setattr(proxy_server, "proxy_logging_obj", MagicMock())

    # Stored litellm_settings row returned by the fake DB lookup.
    existing_litellm_settings = {"success_callback": ["langfuse"]}

    class FakeRow:
        def __init__(self, name, value):
            self.param_name = name
            self.param_value = value

    # param_name -> JSON-decoded param_value captured from upsert calls.
    upserted = {}

    async def fake_find_first(where=None):
        if where and where.get("param_name") == "litellm_settings":
            return FakeRow("litellm_settings", existing_litellm_settings)
        return None

    async def fake_upsert(where=None, data=None):
        upserted[where["param_name"]] = json.loads(data["update"]["param_value"])

    class MockPrisma:
        def __init__(self):
            self.db = MagicMock()
            self.db.litellm_config = MagicMock()
            self.db.litellm_config.find_first = AsyncMock(side_effect=fake_find_first)
            self.db.litellm_config.upsert = AsyncMock(side_effect=fake_upsert)

    monkeypatch.setattr(proxy_server, "prisma_client", MockPrisma())

    class MockProxyConfig:
        async def add_deployment(self, prisma_client=None, proxy_logging_obj=None):
            return None

    monkeypatch.setattr(proxy_server, "proxy_config", MockProxyConfig())

    config_update = ConfigYAML(litellm_settings={"success_callback": ["SQS", "sQs"]})
    admin_user = UserAPIKeyAuth(
        user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-test"
    )
    await proxy_server.update_config(config_update, user_api_key_dict=admin_user)

    assert (
        "litellm_settings" in upserted
    ), "litellm_config.upsert was not called for litellm_settings"
    callbacks = upserted["litellm_settings"]["success_callback"]

    # Deduped and normalized
    assert "sqs" in callbacks
    assert "SQS" not in callbacks
    assert "sQs" not in callbacks
    # Existing callback should still be present
    assert "langfuse" in callbacks
|
|
|
|
|
|
@pytest.mark.parametrize(
    "data",
    [
        {
            "model": {
                "model_name": "azure/gpt-4.1-mini",
                "litellm_params": {"model": "azure/gpt-4.1-mini"},
                "model_info": {"base_model": "gpt-4.1-mini"},
            },
            "expected": "gpt-4.1-mini",
        },
        {
            "model": {
                "model_name": "openai/gpt-4.1-mini",
                "litellm_params": {"model": "openai/gpt-4.1-mini"},
            },
            "expected": "openai/gpt-4.1-mini",
        },
        {
            "model": {
                "model_name": "openai/gpt-4.1-mini",
                "litellm_params": {"model": "openai/gpt-4.1-mini"},
                "model_info": {"base_model": "gpt-4.1-mini"},
            },
            "expected": "gpt-4.1-mini",
        },
        {
            "model": {
                "model_name": "claude-sonnet-4-5-20250929",
                "litellm_params": {"model": "anthropic/claude-sonnet-4-5@20250929"},
                "model_info": {"base_model": "anthropic/claude-sonnet-4-5-20250929"},
            },
            "expected": "anthropic/claude-sonnet-4-5-20250929",
        },
        {
            "model": {
                "model_name": "gemini-2.5-flash-001",
                "litellm_params": {"model": "gemini/gemini-2.5-flash@001"},
                "model_info": {"base_model": "gemini-2.5-flash-001"},
            },
            "expected": "gemini-2.5-flash-001",
        },
    ],
)
def test_get_litellm_model_info(data):
    """
    ``get_litellm_model_info`` must call ``litellm.get_model_info`` exactly once
    with the expected lookup name.

    Cases that set ``model_info.base_model`` expect that value. The one case
    without it expects the deployment's ``openai/gpt-4.1-mini`` model string.
    """
    from litellm.proxy.proxy_server import get_litellm_model_info

    model_info_spy = MagicMock()
    with patch("litellm.get_model_info", new=model_info_spy):
        get_litellm_model_info(model=data["model"])
        model_info_spy.assert_called_once_with(data["expected"])
|