mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 20:48:32 +00:00
51876292a0
* feat(router): integrate allowed_fails_policy into health check failures (#24988)

* feat(router): integrate allowed_fails_policy into health check failures

Health check failures now increment the same per-deployment failure counters used by allowed_fails_policy, so users can control how many health check failures of each error type are required before a deployment enters cooldown.

- ahealth_check() preserves the original exception in its return dict
- run_with_timeout() returns a litellm.Timeout on health check timeout
- _perform_health_check() propagates exceptions to unhealthy endpoints
- _write_health_state_to_router_cache() calls _set_cooldown_deployments for each unhealthy endpoint that has an exception
- When allowed_fails_policy is set, the binary health check filter is bypassed so cooldown is the sole routing exclusion mechanism
- Safety net: if all deployments are in cooldown with enable_health_check_routing=True, the cooldown filter is bypassed

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(router): add health_check_ignore_transient_errors flag

When enabled, health check failures with 429 (rate limit) or 408 (timeout) status codes are skipped from the cooldown pipeline. These are transient load issues, not broken deployments. Auth errors (401), 404, and 5xx errors still increment counters and trigger cooldown as before.

Config (general_settings):
  health_check_ignore_transient_errors: true

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(router): also exclude 429/408 from health state cache when ignore_transient_errors set

The previous fix only skipped cooldown counter increments. The health state cache was still marking 429/408 endpoints as is_healthy=False, causing the binary health check filter to exclude them from routing.

Now, when health_check_ignore_transient_errors=True, 429/408 endpoints are also excluded from the unhealthy list passed to build_deployment_health_states(), so the binary filter treats them as unaffected (not unhealthy).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* docs(router): add health check driven routing guide

New standalone page covering the full health check routing feature: allowed_fails_policy integration, health_check_ignore_transient_errors, architecture SVG, step-by-step setup, and gotchas (TTL, AllowedFails semantics).

Replaces the inline section in health.md with a link to the new page. Added to the Routing & Load Balancing sidebar.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(health-check-routing): fix three CI failures

- Add "exception" to ILLEGAL_DISPLAY_PARAMS in health_check.py so the exception object is stripped before the health endpoint serializes results to JSON (fixes TypeError: 'URL' object is not iterable)
- Add allowed_fails_policy = None to FakeRouter stubs in test_router_health_check_routing.py (fixes AttributeError)
- Add health_check_ignore_transient_errors to config_settings.md router settings reference table (fixes documentation test)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Fix litellm/tests/proxy_unit_tests/test_proxy_server.py

* fix(router): address greptile review comments

- Narrow cooldown safety-net bypass: only fires when allowed_fails_policy is set (cooldown is health-check driven). Without a policy, cooldowns are from real request failures and must not be bypassed.
- Restore cooldown deployments DEBUG log that was accidentally removed.
- Fix test_health TypeError: move exception extraction to a separate exceptions_by_model_id dict returned alongside endpoints, so exception objects never appear in the endpoint dicts that get JSON-serialized by the /health response.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(health-check-routing): properly isolate exceptions from health response

Return exceptions_by_model_id as a separate third value from _perform_health_check / perform_health_check so exception objects (which contain non-JSON-serializable httpx URL types) never appear in the endpoint dicts that get serialized by the /health response.

Callers updated: _health_endpoints.py, shared_health_check_manager.py, proxy_server.py background loop. All use the exceptions dict only for cooldown integration, not for display.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(shared-health-check): fix remaining 2-value return sites and update type annotation

* fix(health-check-routing): fix P0 cooldown integration never firing

The cooldown loop was reading endpoint.get("exception") which is always None because exceptions are now returned via exceptions_by_model_id, not stored in endpoint dicts. Fixed to use _exceptions.get(model_id).

Also fixes the transient-error filter to use _exceptions instead of endpoint.get("exception"), and fixes all remaining 2-value return sites in shared_health_check_manager.py.

Tests updated to pass exceptions via exceptions_by_model_id parameter instead of endpoint dicts.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(health-check-routing): fix P1 transient-error filter broken on cache hits

When SharedHealthCheckManager returns cached results, exceptions_by_model_id is always {} so the transient-error filter defaulted to status 500 for all endpoints, incorrectly marking 429/408 endpoints as unhealthy.

Fix: store integer exception_status on each unhealthy endpoint dict in _perform_health_check. _get_endpoint_exception_status() uses the live exception object when available (direct path) and falls back to the stored integer (cache-hit path). The integer is JSON-serializable and survives the shared cache round-trip.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(health-check-routing): gate cooldown loop behind allowed_fails_policy

Without the policy, cooldown is not the routing exclusion mechanism. Firing _set_cooldown_deployments for all enable_health_check_routing users was a backwards-incompatible change — 401s would immediately cooldown deployments that the binary filter would have recovered on the next cycle.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* revert: undo allowed_fails_policy gate on cooldown loop

Cooldown integration via health checks is intentional for all enable_health_check_routing users, not just those with allowed_fails_policy.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(docs+tests): fix health_check_ignore_transient_errors doc section and test coverage

- Move health_check_ignore_transient_errors from router_settings to general_settings in config_settings.md (code reads it from general_settings)
- Remove duplicate enable_health_check_routing / health_check_staleness_threshold entries that were incorrectly listed under router_settings
- Replace TestHealthCheckEndpointExceptionPropagation tests with ones that exercise the real _perform_health_check code path via mocked ahealth_check, verifying exceptions appear in exceptions_by_model_id and NOT in endpoint dicts

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(tests+docs): fix tuple unpacking and docs test failures

- Update test mocks that return (healthy, unhealthy) to return (healthy, unhealthy, {}) to match the new 3-value signature
- Update test unpackings of perform_shared_health_check to use healthy, unhealthy, _ = ...
- Add health_check_ignore_transient_errors to router_settings section in config_settings.md (it is a Router constructor param, so the doc test requires it there; it also lives in general_settings for proxy use)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Fix CodeQL errors

* fix(tests): fix 2-value unpackings of _perform_health_check in test_health_check.py

* fix(tests): fix mock _perform_health_check returning 2-tuple instead of 3

* fix team routing

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: add distributed lock for key rotation job (#23364)

* fix: add distributed lock for key rotation job

* fix: address Greptile review feedback on key rotation lock (#23834)

* fix: address Greptile review feedback on key rotation lock

* fix req changes greptile

* feat(proxy): Optional on_error for guardrail pipeline (API / technical failures) (#24831)

* guardrails fallback

* docs

* docs: add LITELLM_KEY_ROTATION_LOCK_TTL_SECONDS to environment variables reference

* fix(mypy): accept Union[Dict, Any] in _get_deployment_order and use typed list to fix min() type error

* fix(mypy): use Optional[str] for api_base in PydanticAI provider to match superclass signature

---------

Co-authored-by: Sameer Kankute <sameer@berri.ai>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Harshit Jain <48647625+Harshit28j@users.noreply.github.com>
Co-authored-by: Shivam Rawat <shivam@berri.ai>
Co-authored-by: yuneng-jiang <yuneng@berri.ai>
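The transient-error handling described in this commit can be illustrated with a small, self-contained Python sketch. It is not LiteLLM's implementation: split_health_results, endpoint_exception_status, drop_transient_failures, FakeHealthCheckError and the "model_id" key are hypothetical names chosen for the example. The parts taken from the commit message are the three-value (healthy, unhealthy, exceptions_by_model_id) result, the integer exception_status stored on unhealthy endpoint dicts, the live-exception-first lookup with a stored-integer fallback for cache hits, the 500 default, and the 408/429 set that health_check_ignore_transient_errors skips.

```python
import json
from typing import Any, Dict, List, Optional, Tuple

# Hypothetical sketch: all names below are illustrative, not LiteLLM APIs.

# Status codes the commit message treats as transient load issues.
TRANSIENT_STATUS_CODES = {408, 429}
# Fallback when no status can be recovered (the pre-fix cache-hit behavior).
DEFAULT_STATUS = 500

Endpoint = Dict[str, Any]


class FakeHealthCheckError(Exception):
    """Hypothetical provider error that carries an HTTP status code."""

    def __init__(self, status_code: int) -> None:
        super().__init__(f"health check failed with status {status_code}")
        self.status_code = status_code


def split_health_results(
    results: List[Tuple[Endpoint, Optional[Exception]]],
) -> Tuple[List[Endpoint], List[Endpoint], Dict[str, Exception]]:
    """Return (healthy, unhealthy, exceptions_by_model_id).

    Exception objects only go into the third value, so the endpoint dicts
    stay JSON-serializable. Unhealthy endpoints keep an integer
    "exception_status" that survives a round-trip through a shared cache.
    """
    healthy: List[Endpoint] = []
    unhealthy: List[Endpoint] = []
    exceptions_by_model_id: Dict[str, Exception] = {}
    for endpoint, exc in results:
        if exc is None:
            healthy.append(dict(endpoint))
            continue
        status = int(getattr(exc, "status_code", DEFAULT_STATUS))
        unhealthy.append({**endpoint, "exception_status": status})
        exceptions_by_model_id[endpoint["model_id"]] = exc
    return healthy, unhealthy, exceptions_by_model_id


def endpoint_exception_status(
    endpoint: Endpoint, exceptions_by_model_id: Dict[str, Exception]
) -> int:
    """Prefer the live exception (direct path), else the stored integer (cache hit)."""
    exc = exceptions_by_model_id.get(endpoint["model_id"])
    if exc is not None:
        return int(getattr(exc, "status_code", DEFAULT_STATUS))
    return int(endpoint.get("exception_status", DEFAULT_STATUS))


def drop_transient_failures(
    unhealthy: List[Endpoint],
    exceptions_by_model_id: Dict[str, Exception],
    ignore_transient_errors: bool,
) -> List[Endpoint]:
    """Keep only the unhealthy endpoints that should count toward cooldown."""
    if not ignore_transient_errors:
        return list(unhealthy)
    return [
        ep
        for ep in unhealthy
        if endpoint_exception_status(ep, exceptions_by_model_id)
        not in TRANSIENT_STATUS_CODES
    ]


if __name__ == "__main__":
    _, unhealthy, _excs = split_health_results(
        [
            ({"model_id": "a"}, None),
            ({"model_id": "b"}, FakeHealthCheckError(429)),
            ({"model_id": "c"}, FakeHealthCheckError(401)),
        ]
    )
    cached = json.loads(json.dumps(unhealthy))  # simulate a shared-cache round-trip
    # Cache hit: no live exceptions, so the stored integers are used.
    print([ep["model_id"] for ep in drop_transient_failures(cached, {}, True)])  # prints ['c']
```

Running the module simulates the cache-hit path described in the P1 fix: the exceptions dict is empty, yet the 429 endpoint is still dropped from the cooldown list because its status survives as an integer, while the 401 endpoint is kept.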
2862 lines
96 KiB
Python
import os
import sys
import traceback
from unittest import mock

from dotenv import load_dotenv

import litellm.proxy
import litellm.proxy.proxy_server

load_dotenv()
import io
import json

# this file is to test litellm/proxy

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import asyncio
import logging

import pytest

import litellm
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding

# Configure logging
logging.basicConfig(
    level=logging.DEBUG,  # Set the desired logging level
    format="%(asctime)s - %(levelname)s - %(message)s",
)

from unittest.mock import AsyncMock, patch

from fastapi import FastAPI

# test /chat/completion request to the proxy
from fastapi.testclient import TestClient

from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.proxy_server import (  # the proxy's FastAPI app and its startup helpers
    app,
    initialize,
    save_worker_config,
)
from litellm.proxy.utils import ProxyLogging


# Your bearer token
token = "sk-1234"

headers = {"Authorization": f"Bearer {token}"}

example_completion_result = {
    "choices": [
        {
            "message": {
                "content": "Whispers of the wind carry dreams to me.",
                "role": "assistant",
            }
        }
    ],
}
example_embedding_result = {
    "object": "list",
    "data": [
        {
            "object": "embedding",
            "index": 0,
            "embedding": [
                -0.006929283495992422,
                -0.005336422007530928,
                -4.547132266452536e-05,
                -0.024047505110502243,
                -0.006929283495992422,
                -0.005336422007530928,
                -4.547132266452536e-05,
                -0.024047505110502243,
                -0.006929283495992422,
                -0.005336422007530928,
                -4.547132266452536e-05,
                -0.024047505110502243,
            ],
        }
    ],
    "model": "text-embedding-3-small",
    "usage": {"prompt_tokens": 5, "total_tokens": 5},
}
example_image_generation_result = {
    "created": 1589478378,
    "data": [{"url": "https://..."}, {"url": "https://..."}],
}


def mock_patch_acompletion():
    return mock.patch(
        "litellm.proxy.proxy_server.llm_router.acompletion",
        return_value=example_completion_result,
    )


def mock_patch_aembedding():
    return mock.patch(
        "litellm.proxy.proxy_server.llm_router.aembedding",
        return_value=example_embedding_result,
    )


def mock_patch_aimage_generation():
    return mock.patch(
        "litellm.proxy.proxy_server.llm_router.aimage_generation",
        return_value=example_image_generation_result,
    )


@pytest.fixture(scope="function")
def fake_env_vars(monkeypatch):
    # Set some fake environment variables
    monkeypatch.setenv("OPENAI_API_KEY", "fake_openai_api_key")
    monkeypatch.setenv("OPENAI_API_BASE", "http://fake-openai-api-base")
    monkeypatch.setenv("AZURE_AI_API_BASE", "http://fake-azure-api-base")
    monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake_azure_openai_api_key")
    monkeypatch.setenv("AZURE_SWEDEN_API_BASE", "http://fake-azure-sweden-api-base")
    monkeypatch.setenv("REDIS_HOST", "localhost")


@pytest.fixture(scope="function")
def client_no_auth(fake_env_vars):
    # Reset the router/config globals on litellm.proxy.proxy_server left over from earlier tests
    from litellm.proxy.proxy_server import cleanup_router_config_variables

    cleanup_router_config_variables()
    filepath = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
    # initialize() sets module-level variables for the FastAPI app; because it can run
    # in parallel, different tests may otherwise end up using the wrong variables
    asyncio.run(initialize(config=config_fp, debug=True))
    return TestClient(app)


@mock_patch_acompletion()
def test_chat_completion(mock_acompletion, client_no_auth):
    global headers
    try:
        # Your test data
        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {"role": "user", "content": "hi"},
            ],
            "max_tokens": 10,
        }

        print("testing proxy server with chat completions")
        response = client_no_auth.post("/v1/chat/completions", json=test_data)
        mock_acompletion.assert_called_once_with(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "user", "content": "hi"},
            ],
            max_tokens=10,
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            specific_deployment=True,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        print(f"response - {response.text}")
        assert response.status_code == 200
        result = response.json()
        print(f"Received response: {result}")
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


def test_chat_completion_malformed_messages_returns_400(client_no_auth):
    """
    Test that malformed messages (strings instead of dicts) return 400 instead of 500.

    This test verifies that when a client sends messages as raw strings instead of
    {role, content} objects, LiteLLM returns a 400 invalid_request_error instead
    of a 500 Internal Server Error.
    """
    global headers
    try:
        # Test data with malformed messages (string instead of dict)
        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": [
                "hi how are you"
            ],  # Invalid: should be [{"role": "user", "content": "hi how are you"}]
        }

        print("testing proxy server with malformed messages")
        response = client_no_auth.post(
            "/v1/chat/completions", json=test_data, headers=headers
        )

        print(f"response status: {response.status_code}")
        print(f"response text: {response.text}")

        # Should return 400, not 500
        assert (
            response.status_code == 400
        ), f"Expected 400, got {response.status_code}. Response: {response.text}"

        # Verify error format
        result = response.json()
        assert "error" in result, "Response should contain 'error' key"
        error = result["error"]

        # Verify error type and message
        assert (
            error.get("type") == "invalid_request_error" or error.get("type") is None
        ), f"Expected invalid_request_error or None, got {error.get('type')}"
        assert (
            error.get("code") == "400" or error.get("code") == 400
        ), f"Expected code 400, got {error.get('code')}"

        # Error message should indicate invalid request format
        error_message = error.get("message", "")
        assert len(error_message) > 0, "Error message should not be empty"

    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


def test_get_settings_request_timeout(client_no_auth):
    """
    When no timeout is set, /settings should report the litellm.request_timeout value.
    """
    # Compare against the current module-level litellm.request_timeout
    import litellm

    # Make a GET request to /settings
    response = client_no_auth.get("/settings")

    # Check if the request was successful
    assert response.status_code == 200

    # Parse the JSON response
    settings = response.json()
    print("settings", settings)

    assert settings["litellm.request_timeout"] == litellm.request_timeout


@pytest.mark.parametrize(
    "litellm_key_header_name",
    ["x-litellm-key", None],
)
def test_add_headers_to_request(litellm_key_header_name):
    from fastapi import Request
    from starlette.datastructures import URL
    import json
    from litellm.proxy.litellm_pre_call_utils import (
        clean_headers,
        LiteLLMProxyRequestSetup,
    )

    headers = {
        "Authorization": "Bearer 1234",
        "X-Custom-Header": "Custom-Value",
        "X-Stainless-Header": "Stainless-Value",
        "anthropic-beta": "beta-value",
    }
    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")
    request._body = json.dumps({"model": "gpt-3.5-turbo"}).encode("utf-8")
    request_headers = clean_headers(headers, litellm_key_header_name)
    forwarded_headers = LiteLLMProxyRequestSetup._get_forwardable_headers(
        request_headers
    )
    assert forwarded_headers == {
        "X-Custom-Header": "Custom-Value",
        "anthropic-beta": "beta-value",
    }


@pytest.mark.parametrize(
    "litellm_key_header_name",
    ["x-litellm-key", None],
)
@pytest.mark.parametrize(
    "forward_headers",
    [True, False],
)
@mock_patch_acompletion()
def test_chat_completion_forward_headers(
    mock_acompletion, client_no_auth, litellm_key_header_name, forward_headers
):
    global headers
    try:
        if forward_headers:
            gs = getattr(litellm.proxy.proxy_server, "general_settings")
            gs["forward_client_headers_to_llm_api"] = True
            setattr(litellm.proxy.proxy_server, "general_settings", gs)
        if litellm_key_header_name is not None:
            gs = getattr(litellm.proxy.proxy_server, "general_settings")
            gs["litellm_key_header_name"] = litellm_key_header_name
            setattr(litellm.proxy.proxy_server, "general_settings", gs)
        # Your test data
        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {"role": "user", "content": "hi"},
            ],
            "max_tokens": 10,
        }

        headers_to_forward = {
            "X-Custom-Header": "Custom-Value",
            "X-Another-Header": "Another-Value",
        }

        if litellm_key_header_name is not None:
            headers_to_not_forward = {litellm_key_header_name: "Bearer 1234"}
        else:
            headers_to_not_forward = {"Authorization": "Bearer 1234"}

        received_headers = {**headers_to_forward, **headers_to_not_forward}

        print("testing proxy server with chat completions")
        response = client_no_auth.post(
            "/v1/chat/completions", json=test_data, headers=received_headers
        )
        if not forward_headers:
            assert "headers" not in mock_acompletion.call_args.kwargs
        else:
            assert mock_acompletion.call_args.kwargs["headers"] == {
                "x-custom-header": "Custom-Value",
                "x-another-header": "Another-Value",
            }

        print(f"response - {response.text}")
        assert response.status_code == 200
        result = response.json()
        print(f"Received response: {result}")
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


@pytest.mark.parametrize("forward_llm_auth_headers", [True, False])
@mock_patch_acompletion()
def test_chat_completion_forward_llm_provider_auth_headers(
    mock_acompletion, client_no_auth, forward_llm_auth_headers
):
    """
    Test that LLM provider auth headers (x-api-key, x-goog-api-key) are forwarded
    when forward_llm_provider_auth_headers=True.

    This allows clients to send their own LLM provider API keys through the proxy.
    """
    try:
        # Configure general settings
        gs = getattr(litellm.proxy.proxy_server, "general_settings")
        gs["forward_client_headers_to_llm_api"] = True
        gs["forward_llm_provider_auth_headers"] = forward_llm_auth_headers
        setattr(litellm.proxy.proxy_server, "general_settings", gs)

        # Test data
        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {"role": "user", "content": "hello"},
            ],
            "max_tokens": 10,
        }

        # Headers including LLM provider auth
        request_headers = {
            "Authorization": "Bearer sk-proxy-auth-123",  # Proxy auth (should be stripped)
            "x-api-key": "sk-ant-api03-test-anthropic-key",  # Anthropic API key
            "x-goog-api-key": "google-api-key-123",  # Google API key
            "X-Custom-Header": "custom-value",  # Custom header (should be forwarded)
        }

        # Make request
        response = client_no_auth.post(
            "/v1/chat/completions", json=test_data, headers=request_headers
        )

        assert response.status_code == 200

        # Check forwarded headers
        forwarded_headers = mock_acompletion.call_args.kwargs.get("headers", {})

        if forward_llm_auth_headers:
            # LLM provider auth headers should be forwarded
            assert "x-api-key" in forwarded_headers
            assert forwarded_headers["x-api-key"] == "sk-ant-api03-test-anthropic-key"
            assert "x-goog-api-key" in forwarded_headers
            assert forwarded_headers["x-goog-api-key"] == "google-api-key-123"
        else:
            # LLM provider auth headers should be stripped
            assert "x-api-key" not in forwarded_headers
            assert "x-goog-api-key" not in forwarded_headers

        # Custom headers should always be forwarded (when forward_client_headers_to_llm_api=True)
        assert "x-custom-header" in forwarded_headers
        assert forwarded_headers["x-custom-header"] == "custom-value"

        # Proxy Authorization should never be forwarded
        assert "authorization" not in forwarded_headers

        print(
            f"✓ Test passed with forward_llm_provider_auth_headers={forward_llm_auth_headers}"
        )
        print(f" Forwarded headers: {list(forwarded_headers.keys())}")

    except Exception as e:
        pytest.fail(
            f"Test failed with forward_llm_auth_headers={forward_llm_auth_headers}: {str(e)}"
        )
    finally:
        # Clean up
        gs = getattr(litellm.proxy.proxy_server, "general_settings")
        gs.pop("forward_llm_provider_auth_headers", None)
        setattr(litellm.proxy.proxy_server, "general_settings", gs)


@mock_patch_acompletion()
@pytest.mark.asyncio
async def test_team_disable_guardrails(mock_acompletion, client_no_auth):
    """
    If a team is not allowed to modify guardrails, a request from that team that
    tries to turn guardrails on or off (on `/key/generate` or `/chat/completions`)
    should raise a 403 Forbidden error.
    """
    import asyncio
    import json
    import time

    from fastapi import HTTPException, Request
    from starlette.datastructures import URL

    from litellm.proxy._types import (
        LiteLLM_TeamTable,
        LiteLLM_TeamTableCachedObj,
        ProxyException,
        UserAPIKeyAuth,
    )
    from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
    from litellm.proxy.proxy_server import hash_token, user_api_key_cache

    _team_id = "1234"
    user_key = "sk-12345678"

    valid_token = UserAPIKeyAuth(
        team_id=_team_id,
        team_blocked=True,
        token=hash_token(user_key),
        last_refreshed_at=time.time(),
    )
    await asyncio.sleep(1)
    team_obj = LiteLLM_TeamTableCachedObj(
        team_id=_team_id,
        blocked=False,
        last_refreshed_at=time.time(),
        metadata={"guardrails": {"modify_guardrails": False}},
    )
    user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
    user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)

    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "prisma_client", "hello-world")

    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")

    body = {"metadata": {"guardrails": {"hide_secrets": False}}}
    json_bytes = json.dumps(body).encode("utf-8")

    request._body = json_bytes

    try:
        await user_api_key_auth(request=request, api_key="Bearer " + user_key)
        pytest.fail("Expected to raise 403 forbidden error.")
    except ProxyException as e:
        assert e.code == str(403)


from test_custom_callback_input import CompletionCustomHandler


@mock_patch_acompletion()
def test_custom_logger_failure_handler(mock_acompletion, client_no_auth):
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import hash_token, user_api_key_cache

    # An rpm_limit of 0 means any request from this key exceeds its rate limit,
    # so the proxy is expected to reject it with a 429
    rpm_limit = 0

    mock_api_key = "sk-my-test-key"
    cache_value = UserAPIKeyAuth(token=hash_token(mock_api_key), rpm_limit=rpm_limit)

    user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value)

    mock_logger = CustomLogger()
    mock_logger_unit_tests = CompletionCustomHandler()
    proxy_logging_obj: ProxyLogging = getattr(
        litellm.proxy.proxy_server, "proxy_logging_obj"
    )

    litellm.callbacks = [mock_logger, mock_logger_unit_tests]
    proxy_logging_obj._init_litellm_callbacks(llm_router=None)

    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR")
    setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)

    with patch.object(
        mock_logger, "async_log_failure_event", new=AsyncMock()
    ) as mock_failed_alert:
        # Your test data
        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {"role": "user", "content": "hi"},
            ],
            "max_tokens": 10,
        }

        print("testing proxy server with chat completions")
        response = client_no_auth.post(
            "/v1/chat/completions",
            json=test_data,
            headers={"Authorization": "Bearer {}".format(mock_api_key)},
        )
        assert response.status_code == 429

        # confirm async_log_failure_event is called
        mock_failed_alert.assert_called()

        assert len(mock_logger_unit_tests.errors) == 0


@mock_patch_acompletion()
def test_engines_model_chat_completions(mock_acompletion, client_no_auth):
    global headers
    try:
        # Your test data
        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {"role": "user", "content": "hi"},
            ],
            "max_tokens": 10,
        }

        print("testing proxy server with chat completions")
        response = client_no_auth.post(
            "/engines/gpt-3.5-turbo/chat/completions", json=test_data
        )
        mock_acompletion.assert_called_once_with(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "user", "content": "hi"},
            ],
            max_tokens=10,
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            specific_deployment=True,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        print(f"response - {response.text}")
        assert response.status_code == 200
        result = response.json()
        print(f"Received response: {result}")
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


@mock_patch_acompletion()
def test_chat_completion_azure(mock_acompletion, client_no_auth):
    global headers
    try:
        # Your test data
        test_data = {
            "model": "azure/gpt-4.1-mini",
            "messages": [
                {"role": "user", "content": "write 1 sentence poem"},
            ],
            "max_tokens": 10,
        }

        print("testing proxy server with Azure Request /chat/completions")
        response = client_no_auth.post("/v1/chat/completions", json=test_data)

        mock_acompletion.assert_called_once_with(
            model="azure/gpt-4.1-mini",
            messages=[
                {"role": "user", "content": "write 1 sentence poem"},
            ],
            max_tokens=10,
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            specific_deployment=True,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        assert response.status_code == 200
        result = response.json()
        print(f"Received response: {result}")
        assert len(result["choices"][0]["message"]["content"]) > 0
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


# Run the test
# test_chat_completion_azure()


@mock_patch_acompletion()
def test_openai_deployments_model_chat_completions_azure(
    mock_acompletion, client_no_auth
):
    global headers
    try:
        # Your test data
        test_data = {
            "model": "azure/gpt-4.1-mini",
            "messages": [
                {"role": "user", "content": "write 1 sentence poem"},
            ],
            "max_tokens": 10,
        }

        url = "/openai/deployments/azure/gpt-4.1-mini/chat/completions"
        print(f"testing proxy server with Azure Request {url}")
        response = client_no_auth.post(url, json=test_data)

        mock_acompletion.assert_called_once_with(
            model="azure/gpt-4.1-mini",
            messages=[
                {"role": "user", "content": "write 1 sentence poem"},
            ],
            max_tokens=10,
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            specific_deployment=True,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        assert response.status_code == 200
        result = response.json()
        print(f"Received response: {result}")
        assert len(result["choices"][0]["message"]["content"]) > 0
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


# Run the test
# test_openai_deployments_model_chat_completions_azure()


### EMBEDDING
|
|
@mock_patch_aembedding()
|
|
def test_embedding(mock_aembedding, client_no_auth):
    global headers
    from litellm.proxy.proxy_server import user_custom_auth

    try:
        test_data = {
            "model": "azure/text-embedding-ada-002",
            "input": ["good morning from litellm"],
        }

        async def _pre_call_hook_side_effect(**kwargs):
            data = kwargs["data"]
            metadata = {**(data.get("metadata") or {}), "source": "unit-test"}
            data["metadata"] = metadata
            proxy_request = {**(data.get("proxy_server_request") or {})}
            proxy_request["path"] = "/v1/embeddings"
            data["proxy_server_request"] = proxy_request
            return data

        async def _post_call_success_side_effect(**kwargs):
            return kwargs["response"]

        with patch.object(
            litellm.proxy.proxy_server.proxy_logging_obj,
            "pre_call_hook",
            new=AsyncMock(side_effect=_pre_call_hook_side_effect),
        ) as mock_pre_call_hook, patch.object(
            litellm.proxy.proxy_server.proxy_logging_obj,
            "during_call_hook",
            new=AsyncMock(return_value=None),
        ) as mock_during_hook, patch.object(
            litellm.proxy.proxy_server.proxy_logging_obj,
            "post_call_success_hook",
            new=AsyncMock(side_effect=_post_call_success_side_effect),
        ):
            response = client_no_auth.post("/v1/embeddings", json=test_data)

        mock_aembedding.assert_called_once_with(
            model="azure/text-embedding-ada-002",
            input=["good morning from litellm"],
            specific_deployment=True,
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        assert response.status_code == 200
        result = response.json()
        print(len(result["data"][0]["embedding"]))
        assert len(result["data"][0]["embedding"]) > 10  # this usually has len==1536

        call_metadata = mock_aembedding.call_args.kwargs["metadata"]
        assert call_metadata.get("source") == "unit-test"

        pre_call_kwargs = mock_pre_call_hook.await_args_list[0].kwargs
        assert (
            pre_call_kwargs.get("call_type") == "aembedding"
        ), f"expected pre_call_hook to receive call_type='aembedding', got {pre_call_kwargs.get('call_type')}"

    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


@mock_patch_aembedding()
def test_bedrock_embedding(mock_aembedding, client_no_auth):
    global headers
    from litellm.proxy.proxy_server import user_custom_auth

    try:
        test_data = {
            "model": "amazon-embeddings",
            "input": ["good morning from litellm"],
        }

        response = client_no_auth.post("/v1/embeddings", json=test_data)

        mock_aembedding.assert_called_once_with(
            model="amazon-embeddings",
            input=["good morning from litellm"],
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        assert response.status_code == 200
        print(response.status_code, response.text)
        result = response.json()
        print(len(result["data"][0]["embedding"]))
        assert len(result["data"][0]["embedding"]) > 10  # this usually has len==1536
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


@pytest.mark.skip(reason="AWS Suspended Account")
def test_sagemaker_embedding(client_no_auth):
    global headers
    from litellm.proxy.proxy_server import user_custom_auth

    try:
        test_data = {
            "model": "GPT-J 6B - Sagemaker Text Embedding (Internal)",
            "input": ["good morning from litellm"],
        }

        response = client_no_auth.post("/v1/embeddings", json=test_data)

        assert response.status_code == 200
        result = response.json()
        print(len(result["data"][0]["embedding"]))
        assert len(result["data"][0]["embedding"]) > 10  # this usually has len==1536
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


# Run the test
# test_embedding()
#### IMAGE GENERATION


@mock_patch_aimage_generation()
def test_img_gen(mock_aimage_generation, client_no_auth):
    global headers
    from litellm.proxy.proxy_server import user_custom_auth

    try:
        test_data = {
            "model": "dall-e-3",
            "prompt": "A cute baby sea otter",
            "n": 1,
            "size": "1024x1024",
        }

        response = client_no_auth.post("/v1/images/generations", json=test_data)

        mock_aimage_generation.assert_called_once_with(
            model="dall-e-3",
            prompt="A cute baby sea otter",
            n=1,
            size="1024x1024",
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        assert response.status_code == 200
        result = response.json()
        print(len(result["data"][0]["url"]))
        assert len(result["data"][0]["url"]) > 10
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


#### ADDITIONAL
@pytest.mark.skip(reason="test via docker tests. Requires prisma client.")
def test_add_new_model(client_no_auth):
    global headers
    try:
        test_data = {
            "model_name": "test_openai_models",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
            },
            "model_info": {"description": "this is a test openai model"},
        }
        client_no_auth.post("/model/new", json=test_data, headers=headers)
        response = client_no_auth.get("/model/info", headers=headers)
        assert response.status_code == 200
        result = response.json()
        print(f"response: {result}")
        model_info = None
        for m in result["data"]:
            if m["model_name"] == "test_openai_models":
                model_info = m["model_info"]
        assert model_info["description"] == "this is a test openai model"
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")


@pytest.mark.xdist_group("proxy_heavy")
def test_health(client_no_auth):
    global headers
    import logging
    import time

    from litellm._logging import verbose_logger, verbose_proxy_logger

    verbose_proxy_logger.setLevel(logging.DEBUG)

    try:
        response = client_no_auth.get("/health")
        assert response.status_code == 200
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


# test_add_new_model()

from litellm.integrations.custom_logger import CustomLogger


class MyCustomHandler(CustomLogger):
    def log_pre_api_call(self, model, messages, kwargs):
        print("Pre-API Call")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print("On Success")
        assert kwargs["user"] == "proxy-user"
        assert kwargs["model"] == "gpt-3.5-turbo"
        assert kwargs["max_tokens"] == 10


customHandler = MyCustomHandler()


@mock_patch_acompletion()
def test_chat_completion_optional_params(mock_acompletion, client_no_auth):
    # [PROXY: PROD TEST] - DO NOT DELETE
    # This tests if all the /chat/completions params are passed to litellm
    try:
        # Your test data
        litellm.set_verbose = True
        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {"role": "user", "content": "hi"},
            ],
            "max_tokens": 10,
            "user": "proxy-user",
        }

        litellm.callbacks = [customHandler]
        print("testing proxy server: optional params")
        response = client_no_auth.post("/v1/chat/completions", json=test_data)
        mock_acompletion.assert_called_once_with(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "user", "content": "hi"},
            ],
            max_tokens=10,
            user="proxy-user",
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            specific_deployment=True,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        assert response.status_code == 200
        result = response.json()
        print(f"Received response: {result}")
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


# Run the test
# test_chat_completion_optional_params()


# Test Reading config.yaml file
from litellm.proxy.proxy_server import ProxyConfig


@pytest.mark.skip(reason="local variable conflicts. needs to be refactored.")
@mock.patch("litellm.proxy.proxy_server.litellm.Cache")
def test_load_router_config(mock_cache, fake_env_vars):
    mock_cache.return_value.cache.__dict__ = {"redis_client": None}
    mock_cache.return_value.supported_call_types = [
        "completion",
        "acompletion",
        "embedding",
        "aembedding",
        "atranscription",
        "transcription",
    ]

    try:
        import asyncio

        print("testing reading config")
        # this is a basic config.yaml with only a model
        filepath = os.path.dirname(os.path.abspath(__file__))
        proxy_config = ProxyConfig()
        result = asyncio.run(
            proxy_config.load_config(
                router=None,
                config_file_path=f"{filepath}/example_config_yaml/simple_config.yaml",
            )
        )
        print(result)
        assert len(result[1]) == 1

        # this is a load balancing config yaml
        result = asyncio.run(
            proxy_config.load_config(
                router=None,
                config_file_path=f"{filepath}/example_config_yaml/azure_config.yaml",
            )
        )
        print(result)
        assert len(result[1]) == 2

        # config with general settings - custom callbacks
        result = asyncio.run(
            proxy_config.load_config(
                router=None,
                config_file_path=f"{filepath}/example_config_yaml/azure_config.yaml",
            )
        )
        print(result)
        assert len(result[1]) == 2

        # tests for litellm.cache set from config
        print("testing reading proxy config for cache")
        litellm.cache = None
        asyncio.run(
            proxy_config.load_config(
                router=None,
                config_file_path=f"{filepath}/example_config_yaml/cache_no_params.yaml",
            )
        )
        assert litellm.cache is not None
        assert "redis_client" in vars(
            litellm.cache.cache
        )  # it should default to redis on proxy
        assert litellm.cache.supported_call_types == [
            "completion",
            "acompletion",
            "embedding",
            "aembedding",
            "atranscription",
            "transcription",
        ]  # init with all call types

        litellm.disable_cache()

        print("testing reading proxy config for cache with params")
        mock_cache.return_value.supported_call_types = [
            "embedding",
            "aembedding",
        ]
        asyncio.run(
            proxy_config.load_config(
                router=None,
                config_file_path=f"{filepath}/example_config_yaml/cache_with_params.yaml",
            )
        )
        assert litellm.cache is not None
        print(litellm.cache)
        print(litellm.cache.supported_call_types)
        print(vars(litellm.cache.cache))
        assert "redis_client" in vars(
            litellm.cache.cache
        )  # it should default to redis on proxy
        assert litellm.cache.supported_call_types == [
            "embedding",
            "aembedding",
        ]  # init with only the configured call types

    except Exception as e:
        pytest.fail(
            f"Proxy: Got exception reading config: {str(e)}\n{traceback.format_exc()}"
        )


# test_load_router_config()


@pytest.mark.asyncio
async def test_team_update_redis():
    """
    Tests that a team update also updates the Redis cache, if one is set.
    """
    from litellm.caching.caching import DualCache, RedisCache
    from litellm.proxy._types import LiteLLM_TeamTableCachedObj
    from litellm.proxy.auth.auth_checks import _cache_team_object

    proxy_logging_obj: ProxyLogging = getattr(
        litellm.proxy.proxy_server, "proxy_logging_obj"
    )

    redis_cache = RedisCache(host="localhost")

    with patch.object(
        redis_cache,
        "async_set_cache",
        new=AsyncMock(),
    ) as mock_client:
        await _cache_team_object(
            team_id="1234",
            team_table=LiteLLM_TeamTableCachedObj(team_id="1234"),
            user_api_key_cache=DualCache(redis_cache=redis_cache),
            proxy_logging_obj=proxy_logging_obj,
        )

        mock_client.assert_called()


@pytest.mark.asyncio
async def test_get_team_redis(client_no_auth):
    """
    Tests that get_team_object reads the value from the Redis cache, if one is set.
    """
    from litellm.caching.caching import DualCache, RedisCache
    from litellm.proxy.auth.auth_checks import get_team_object

    proxy_logging_obj: ProxyLogging = getattr(
        litellm.proxy.proxy_server, "proxy_logging_obj"
    )

    redis_cache = RedisCache()

    from fastapi import HTTPException

    with patch.object(
        redis_cache,
        "async_get_cache",
        new=AsyncMock(),
    ) as mock_client:
        try:
            await get_team_object(
                team_id="1234",
                user_api_key_cache=DualCache(redis_cache=redis_cache),
                parent_otel_span=None,
                proxy_logging_obj=proxy_logging_obj,
                prisma_client=AsyncMock(),
            )
        except HTTPException:
            pass

        mock_client.assert_called_once()


import random
from litellm._uuid import uuid
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch

from litellm.proxy._types import (
    LitellmUserRoles,
    NewUserRequest,
    TeamMemberAddRequest,
    UserAPIKeyAuth,
)
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
from litellm.proxy.management_endpoints.team_endpoints import team_member_add
from test_key_generate_prisma import prisma_client


@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
@pytest.mark.parametrize(
    "user_role",
    [LitellmUserRoles.INTERNAL_USER.value, LitellmUserRoles.PROXY_ADMIN.value],
)
@pytest.mark.asyncio
async def test_create_user_default_budget(prisma_client, user_role):
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm, "max_internal_user_budget", 10)
    setattr(litellm, "internal_user_budget_duration", "5m")
    await litellm.proxy.proxy_server.prisma_client.connect()
    user = f"ishaan {uuid.uuid4().hex}"
    request = NewUserRequest(
        user_id=user, user_role=user_role
    )  # create a key with no budget
    with patch.object(
        litellm.proxy.proxy_server.prisma_client, "insert_data", new=AsyncMock()
    ) as mock_client:
        await new_user(
            request,
        )

    mock_client.assert_called()

    print(f"mock_client.call_args: {mock_client.call_args}")
    print("mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs))

    if user_role == LitellmUserRoles.INTERNAL_USER.value:
        assert (
            mock_client.call_args.kwargs["data"]["max_budget"]
            == litellm.max_internal_user_budget
        )
        assert (
            mock_client.call_args.kwargs["data"]["budget_duration"]
            == litellm.internal_user_budget_duration
        )

    else:
        assert mock_client.call_args.kwargs["data"]["max_budget"] is None
        assert mock_client.call_args.kwargs["data"]["budget_duration"] is None


@pytest.mark.parametrize("new_member_method", ["user_id", "user_email"])
@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_create_team_member_add(prisma_client, new_member_method):
    import time

    from fastapi import Request

    from litellm.proxy._types import LiteLLM_TeamTableCachedObj, LiteLLM_UserTable
    from litellm.proxy.proxy_server import hash_token, user_api_key_cache

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm, "max_internal_user_budget", 10)
    setattr(litellm, "internal_user_budget_duration", "5m")
    await litellm.proxy.proxy_server.prisma_client.connect()
    user = f"ishaan {uuid.uuid4().hex}"
    _team_id = "litellm-test-client-id-new"
    team_obj = LiteLLM_TeamTableCachedObj(
        team_id=_team_id,
        blocked=False,
        last_refreshed_at=time.time(),
        metadata={"guardrails": {"modify_guardrails": False}},
    )
    # user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
    user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)

    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
    if new_member_method == "user_id":
        data = {
            "team_id": _team_id,
            "member": [{"role": "user", "user_id": user}],
        }
    elif new_member_method == "user_email":
        data = {
            "team_id": _team_id,
            "member": [{"role": "user", "user_email": user}],
        }
    team_member_add_request = TeamMemberAddRequest(**data)

    with patch(
        "litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
        new_callable=AsyncMock,
    ) as mock_litellm_usertable, patch(
        "litellm.proxy.auth.auth_checks._get_team_object_from_user_api_key_cache",
        new=AsyncMock(return_value=team_obj),
    ) as mock_team_obj, patch(
        "litellm.proxy.proxy_server.prisma_client.get_data",
        new=AsyncMock(return_value=[]),
    ) as mock_get_data:
        mock_client = AsyncMock(
            return_value=LiteLLM_UserTable(
                user_id="1234", max_budget=100, user_email="1234"
            )
        )
        mock_litellm_usertable.upsert = mock_client
        mock_litellm_usertable.find_many = AsyncMock(return_value=None)
        # Mock find_first for user_email validation (returns None for new users)
        mock_litellm_usertable.find_first = AsyncMock(return_value=None)
        # Mock find_unique for user_id validation (returns None for new users)
        mock_litellm_usertable.find_unique = AsyncMock(return_value=None)
        team_mock_client = AsyncMock()
        original_val = getattr(
            litellm.proxy.proxy_server.prisma_client.db, "litellm_teamtable"
        )
        litellm.proxy.proxy_server.prisma_client.db.litellm_teamtable = team_mock_client

        team_mock_client.update = AsyncMock(
            return_value=LiteLLM_TeamTableCachedObj(team_id="1234")
        )

        print(f"team_member_add_request={team_member_add_request}")
        await team_member_add(
            data=team_member_add_request,
            user_api_key_dict=UserAPIKeyAuth(user_role="proxy_admin"),
        )

        mock_client.assert_called()

        print(f"mock_client.call_args: {mock_client.call_args}")
        print("mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs))

        assert (
            mock_client.call_args.kwargs["data"]["create"]["max_budget"]
            == litellm.max_internal_user_budget
        )
        assert (
            mock_client.call_args.kwargs["data"]["create"]["budget_duration"]
            == litellm.internal_user_budget_duration
        )

        litellm.proxy.proxy_server.prisma_client.db.litellm_teamtable = original_val


@pytest.mark.parametrize("team_member_role", ["admin", "user"])
@pytest.mark.parametrize("team_route", ["/team/member_add", "/team/member_delete"])
@pytest.mark.asyncio
async def test_create_team_member_add_team_admin_user_api_key_auth(
    prisma_client, team_member_role, team_route
):
    import time

    from fastapi import Request

    from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member
    from litellm.proxy.proxy_server import (
        ProxyException,
        hash_token,
        user_api_key_auth,
        user_api_key_cache,
    )

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm, "max_internal_user_budget", 10)
    setattr(litellm, "internal_user_budget_duration", "5m")
    await litellm.proxy.proxy_server.prisma_client.connect()
    user = f"ishaan {uuid.uuid4().hex}"
    _team_id = "litellm-test-client-id-new"
    user_key = "sk-12345678"

    valid_token = UserAPIKeyAuth(
        team_id=_team_id,
        token=hash_token(user_key),
        team_member=Member(role=team_member_role, user_id=user),
        last_refreshed_at=time.time(),
    )
    user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)

    team_obj = LiteLLM_TeamTableCachedObj(
        team_id=_team_id,
        blocked=False,
        last_refreshed_at=time.time(),
        metadata={"guardrails": {"modify_guardrails": False}},
    )

    user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)

    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)

    ## TEST IF TEAM ADMIN ALLOWED TO CALL /MEMBER_ADD ENDPOINT
    import json

    from starlette.datastructures import URL

    request = Request(scope={"type": "http"})
    request._url = URL(url=team_route)

    body = {}
    json_bytes = json.dumps(body).encode("utf-8")

    request._body = json_bytes

    ## ALLOWED BY USER_API_KEY_AUTH
    await user_api_key_auth(request=request, api_key="Bearer " + user_key)


@pytest.mark.parametrize("new_member_method", ["user_id", "user_email"])
@pytest.mark.parametrize("user_role", ["admin", "user"])
@pytest.mark.asyncio
async def test_create_team_member_add_team_admin(
    prisma_client, new_member_method, user_role
):
    """
    Relevant issue - https://github.com/BerriAI/litellm/issues/5300

    Allow team admins to:
    - Add and remove team members
    - Raise an error if the team member is not an existing 'internal_user'
    """
    import time

    from fastapi import Request

    from litellm.proxy._types import (
        LiteLLM_TeamTableCachedObj,
        LiteLLM_UserTable,
        Member,
    )
    from litellm.proxy.proxy_server import (
        HTTPException,
        ProxyException,
        hash_token,
        user_api_key_auth,
        user_api_key_cache,
    )

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm, "max_internal_user_budget", 10)
    setattr(litellm, "internal_user_budget_duration", "5m")
    await litellm.proxy.proxy_server.prisma_client.connect()
    user = f"ishaan {uuid.uuid4().hex}"
    _team_id = "litellm-test-client-id-new"
    user_key = "sk-12345678"
    team_admin = f"krrish {uuid.uuid4().hex}"

    valid_token = UserAPIKeyAuth(
        team_id=_team_id,
        user_id=team_admin,
        token=hash_token(user_key),
        last_refreshed_at=time.time(),
    )
    user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)

    team_obj = LiteLLM_TeamTableCachedObj(
        team_id=_team_id,
        blocked=False,
        last_refreshed_at=time.time(),
        members_with_roles=[Member(role=user_role, user_id=team_admin)],
        metadata={"guardrails": {"modify_guardrails": False}},
    )

    user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)

    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
    if new_member_method == "user_id":
        data = {
            "team_id": _team_id,
            "member": [{"role": "user", "user_id": user}],
        }
    elif new_member_method == "user_email":
        data = {
            "team_id": _team_id,
            "member": [{"role": "user", "user_email": user}],
        }
    team_member_add_request = TeamMemberAddRequest(**data)

    with patch(
        "litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
        new_callable=AsyncMock,
    ) as mock_litellm_usertable, patch(
        "litellm.proxy.auth.auth_checks._get_team_object_from_user_api_key_cache",
        new=AsyncMock(return_value=team_obj),
    ) as mock_team_obj, patch(
        "litellm.proxy.proxy_server.prisma_client.get_data",
        new=AsyncMock(return_value=[]),
    ) as mock_get_data:
        mock_client = AsyncMock(
            return_value=LiteLLM_UserTable(
                user_id="1234", max_budget=100, user_email="1234"
            )
        )
        mock_litellm_usertable.upsert = mock_client
        mock_litellm_usertable.find_many = AsyncMock(return_value=None)
        # Mock find_first for user_email validation (returns None for new users)
        mock_litellm_usertable.find_first = AsyncMock(return_value=None)
        # Mock find_unique for user_id validation (returns None for new users)
        mock_litellm_usertable.find_unique = AsyncMock(return_value=None)

        team_mock_client = AsyncMock()
        original_val = getattr(
            litellm.proxy.proxy_server.prisma_client.db, "litellm_teamtable"
        )
        litellm.proxy.proxy_server.prisma_client.db.litellm_teamtable = team_mock_client

        team_mock_client.update = AsyncMock(
            return_value=LiteLLM_TeamTableCachedObj(team_id="1234")
        )

        try:
            await team_member_add(
                data=team_member_add_request,
                user_api_key_dict=valid_token,
            )
        except HTTPException as e:
            if user_role == "user":
                assert e.status_code == 403
                # restore the patched team table before exiting early
                litellm.proxy.proxy_server.prisma_client.db.litellm_teamtable = (
                    original_val
                )
                return
            else:
                raise e

        mock_client.assert_called()

        print(f"mock_client.call_args: {mock_client.call_args}")
        print("mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs))

        assert (
            mock_client.call_args.kwargs["data"]["create"]["max_budget"]
            == litellm.max_internal_user_budget
        )
        assert (
            mock_client.call_args.kwargs["data"]["create"]["budget_duration"]
            == litellm.internal_user_budget_duration
        )

        litellm.proxy.proxy_server.prisma_client.db.litellm_teamtable = original_val


@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_user_info_team_list(prisma_client):
    """Assert that user_info, when called by an admin, calls the list_team function"""
    from litellm.proxy._types import LiteLLM_UserTable

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    from litellm.proxy.management_endpoints.internal_user_endpoints import user_info

    with patch(
        "litellm.proxy.management_endpoints.team_endpoints.list_team",
        new_callable=AsyncMock,
    ) as mock_client:
        prisma_client.get_data = AsyncMock(
            return_value=LiteLLM_UserTable(
                user_role="proxy_admin",
                user_id="default_user_id",
                max_budget=None,
                user_email="",
            )
        )

        try:
            await user_info(
                request=MagicMock(),
                user_id=None,
                user_api_key_dict=UserAPIKeyAuth(
                    api_key="sk-1234", user_id="default_user_id"
                ),
            )
        except Exception:
            pass

        mock_client.assert_called()


@pytest.mark.skip(reason="Local test")
@pytest.mark.asyncio
async def test_add_callback_via_key(prisma_client):
    """
    Test that a callback specified in the key's metadata is used.
    """
    global headers
    import json

    from fastapi import HTTPException, Request, Response
    from starlette.datastructures import URL

    from litellm.proxy.proxy_server import chat_completion

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    litellm.set_verbose = True

    try:
        # Your test data
        test_data = {
            "model": "azure/gpt-4.1-mini",
            "messages": [
                {"role": "user", "content": "write 1 sentence poem"},
            ],
            "max_tokens": 10,
            "mock_response": "Hello world",
            "api_key": "my-fake-key",
        }

        request = Request(scope={"type": "http", "method": "POST", "headers": {}})
        request._url = URL(url="/chat/completions")

        json_bytes = json.dumps(test_data).encode("utf-8")

        request._body = json_bytes

        with patch.object(
            litellm.litellm_core_utils.litellm_logging,
            "LangFuseLogger",
            new=MagicMock(),
        ) as mock_client:
            resp = await chat_completion(
                request=request,
                fastapi_response=Response(),
                user_api_key_dict=UserAPIKeyAuth(
                    metadata={
                        "logging": [
                            {
                                "callback_name": "langfuse",  # 'otel', 'langfuse', 'lunary'
                                "callback_type": "success",  # set if required by the integration; future improvement: have logging tools work for success + failure by default
                                "callback_vars": {
                                    "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
                                    "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
                                    "langfuse_host": "https://us.cloud.langfuse.com",
                                },
                            }
                        ]
                    }
                ),
            )
            print(resp)
            mock_client.assert_called()
            mock_client.return_value.log_event.assert_called()
            args, kwargs = mock_client.return_value.log_event.call_args
            kwargs = kwargs["kwargs"]
            assert "user_api_key_metadata" in kwargs["litellm_params"]["metadata"]
            assert (
                "logging"
                in kwargs["litellm_params"]["metadata"]["user_api_key_metadata"]
            )
            checked_keys = False
            for item in kwargs["litellm_params"]["metadata"]["user_api_key_metadata"][
                "logging"
            ]:
                for k, v in item["callback_vars"].items():
                    print("k={}, v={}".format(k, v))
                    if "key" in k:
                        assert "os.environ" in v
                        checked_keys = True

            assert checked_keys
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "callback_type, expected_success_callbacks, expected_failure_callbacks",
    [
        ("success", ["langfuse"], []),
        ("failure", [], ["langfuse"]),
        ("success_and_failure", ["langfuse"], ["langfuse"]),
    ],
)
async def test_add_callback_via_key_litellm_pre_call_utils(
    prisma_client, callback_type, expected_success_callbacks, expected_failure_callbacks
):
    import json

    from fastapi import HTTPException, Request, Response
    from starlette.datastructures import URL

    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")

    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
    request._url = URL(url="/chat/completions")

    test_data = {
        "model": "azure/gpt-4.1-mini",
        "messages": [
            {"role": "user", "content": "write 1 sentence poem"},
        ],
        "max_tokens": 10,
        "mock_response": "Hello world",
        "api_key": "my-fake-key",
    }

    json_bytes = json.dumps(test_data).encode("utf-8")

    request._body = json_bytes

    data = {
        "data": {
            "model": "azure/gpt-4.1-mini",
            "messages": [{"role": "user", "content": "write 1 sentence poem"}],
            "max_tokens": 10,
            "mock_response": "Hello world",
            "api_key": "my-fake-key",
        },
        "request": request,
        "user_api_key_dict": UserAPIKeyAuth(
            token=None,
            key_name=None,
            key_alias=None,
            spend=0.0,
            max_budget=None,
            expires=None,
            models=[],
            aliases={},
            config={},
            user_id=None,
            team_id=None,
            max_parallel_requests=None,
            metadata={
                "logging": [
                    {
                        "callback_name": "langfuse",
                        "callback_type": callback_type,
                        "callback_vars": {
                            "langfuse_public_key": "my-mock-public-key",
                            "langfuse_secret_key": "my-mock-secret-key",
                            "langfuse_host": "https://us.cloud.langfuse.com",
                        },
                    }
                ]
            },
            tpm_limit=None,
            rpm_limit=None,
            budget_duration=None,
            budget_reset_at=None,
            allowed_cache_controls=[],
            permissions={},
            model_spend={},
            model_max_budget={},
            soft_budget_cooldown=False,
            litellm_budget_table=None,
            org_id=None,
            team_spend=None,
            team_alias=None,
            team_tpm_limit=None,
            team_rpm_limit=None,
            team_max_budget=None,
            team_models=[],
            team_blocked=False,
            soft_budget=None,
            team_model_aliases=None,
            team_member_spend=None,
            team_metadata=None,
            end_user_id=None,
            end_user_tpm_limit=None,
            end_user_rpm_limit=None,
            end_user_max_budget=None,
            last_refreshed_at=None,
            api_key=None,
            user_role=None,
            allowed_model_region=None,
            parent_otel_span=None,
        ),
        "proxy_config": proxy_config,
        "general_settings": {},
        "version": "0.0.0",
    }

    new_data = await add_litellm_data_to_request(**data)
    print("NEW DATA: {}".format(new_data))

    assert "langfuse_public_key" in new_data
    assert new_data["langfuse_public_key"] == "my-mock-public-key"
    assert "langfuse_secret_key" in new_data
    assert new_data["langfuse_secret_key"] == "my-mock-secret-key"

    if expected_success_callbacks:
        assert "success_callback" in new_data
        assert new_data["success_callback"] == expected_success_callbacks

    if expected_failure_callbacks:
        assert "failure_callback" in new_data
        assert new_data["failure_callback"] == expected_failure_callbacks


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "disable_fallbacks_set",
    [
        True,
        False,
    ],
)
async def test_disable_fallbacks_by_key(disable_fallbacks_set):
    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup

    key_metadata = {"disable_fallbacks": disable_fallbacks_set}
    existing_data = {
        "model": "azure/gpt-4.1-mini",
        "messages": [{"role": "user", "content": "write 1 sentence poem"}],
    }
    data = LiteLLMProxyRequestSetup.add_key_level_controls(
        key_metadata=key_metadata,
        data=existing_data,
        _metadata_variable_name="metadata",
    )

    assert data["disable_fallbacks"] == disable_fallbacks_set
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "callback_type, expected_success_callbacks, expected_failure_callbacks",
    [
        ("success", ["gcs_bucket"], []),
        ("failure", [], ["gcs_bucket"]),
        ("success_and_failure", ["gcs_bucket"], ["gcs_bucket"]),
    ],
)
async def test_add_callback_via_key_litellm_pre_call_utils_gcs_bucket(
    prisma_client, callback_type, expected_success_callbacks, expected_failure_callbacks
):
    import json

    from fastapi import HTTPException, Request, Response
    from starlette.datastructures import URL

    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")

    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
    request._url = URL(url="/chat/completions")

    test_data = {
        "model": "azure/gpt-4.1-mini",
        "messages": [
            {"role": "user", "content": "write 1 sentence poem"},
        ],
        "max_tokens": 10,
        "mock_response": "Hello world",
        "api_key": "my-fake-key",
    }

    json_bytes = json.dumps(test_data).encode("utf-8")

    request._body = json_bytes

    data = {
        "data": {
            "model": "azure/gpt-4.1-mini",
            "messages": [{"role": "user", "content": "write 1 sentence poem"}],
            "max_tokens": 10,
            "mock_response": "Hello world",
            "api_key": "my-fake-key",
        },
        "request": request,
        "user_api_key_dict": UserAPIKeyAuth(
            token=None,
            key_name=None,
            key_alias=None,
            spend=0.0,
            max_budget=None,
            expires=None,
            models=[],
            aliases={},
            config={},
            user_id=None,
            team_id=None,
            max_parallel_requests=None,
            metadata={
                "logging": [
                    {
                        "callback_name": "gcs_bucket",
                        "callback_type": callback_type,
                        "callback_vars": {
                            "gcs_bucket_name": "key-logging-project1",
                            "gcs_path_service_account": "pathrise-convert-1606954137718-a956eef1a2a8.json",
                        },
                    }
                ]
            },
            tpm_limit=None,
            rpm_limit=None,
            budget_duration=None,
            budget_reset_at=None,
            allowed_cache_controls=[],
            permissions={},
            model_spend={},
            model_max_budget={},
            soft_budget_cooldown=False,
            litellm_budget_table=None,
            org_id=None,
            team_spend=None,
            team_alias=None,
            team_tpm_limit=None,
            team_rpm_limit=None,
            team_max_budget=None,
            team_models=[],
            team_blocked=False,
            soft_budget=None,
            team_model_aliases=None,
            team_member_spend=None,
            team_metadata=None,
            end_user_id=None,
            end_user_tpm_limit=None,
            end_user_rpm_limit=None,
            end_user_max_budget=None,
            last_refreshed_at=None,
            api_key=None,
            user_role=None,
            allowed_model_region=None,
            parent_otel_span=None,
        ),
        "proxy_config": proxy_config,
        "general_settings": {},
        "version": "0.0.0",
    }

    new_data = await add_litellm_data_to_request(**data)
    print("NEW DATA: {}".format(new_data))

    assert "gcs_bucket_name" in new_data
    assert new_data["gcs_bucket_name"] == "key-logging-project1"
    assert "gcs_path_service_account" in new_data
    assert (
        new_data["gcs_path_service_account"]
        == "pathrise-convert-1606954137718-a956eef1a2a8.json"
    )

    if expected_success_callbacks:
        assert "success_callback" in new_data
        assert new_data["success_callback"] == expected_success_callbacks

    if expected_failure_callbacks:
        assert "failure_callback" in new_data
        assert new_data["failure_callback"] == expected_failure_callbacks


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "callback_type, expected_success_callbacks, expected_failure_callbacks",
    [
        ("success", ["langsmith"], []),
        ("failure", [], ["langsmith"]),
        ("success_and_failure", ["langsmith"], ["langsmith"]),
    ],
)
async def test_add_callback_via_key_litellm_pre_call_utils_langsmith(
    prisma_client, callback_type, expected_success_callbacks, expected_failure_callbacks
):
    import json

    from fastapi import HTTPException, Request, Response
    from starlette.datastructures import URL

    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")

    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
    request._url = URL(url="/chat/completions")

    test_data = {
        "model": "azure/gpt-4.1-mini",
        "messages": [
            {"role": "user", "content": "write 1 sentence poem"},
        ],
        "max_tokens": 10,
        "mock_response": "Hello world",
        "api_key": "my-fake-key",
    }

    json_bytes = json.dumps(test_data).encode("utf-8")

    request._body = json_bytes

    data = {
        "data": {
            "model": "azure/gpt-4.1-mini",
            "messages": [{"role": "user", "content": "write 1 sentence poem"}],
            "max_tokens": 10,
            "mock_response": "Hello world",
            "api_key": "my-fake-key",
        },
        "request": request,
        "user_api_key_dict": UserAPIKeyAuth(
            token=None,
            key_name=None,
            key_alias=None,
            spend=0.0,
            max_budget=None,
            expires=None,
            models=[],
            aliases={},
            config={},
            user_id=None,
            team_id=None,
            max_parallel_requests=None,
            metadata={
                "logging": [
                    {
                        "callback_name": "langsmith",
                        "callback_type": callback_type,
                        "callback_vars": {
                            "langsmith_api_key": "ls-1234",
                            "langsmith_project": "pr-brief-resemblance-72",
                            "langsmith_base_url": "https://api.smith.langchain.com",
                        },
                    }
                ]
            },
            tpm_limit=None,
            rpm_limit=None,
            budget_duration=None,
            budget_reset_at=None,
            allowed_cache_controls=[],
            permissions={},
            model_spend={},
            model_max_budget={},
            soft_budget_cooldown=False,
            litellm_budget_table=None,
            org_id=None,
            team_spend=None,
            team_alias=None,
            team_tpm_limit=None,
            team_rpm_limit=None,
            team_max_budget=None,
            team_models=[],
            team_blocked=False,
            soft_budget=None,
            team_model_aliases=None,
            team_member_spend=None,
            team_metadata=None,
            end_user_id=None,
            end_user_tpm_limit=None,
            end_user_rpm_limit=None,
            end_user_max_budget=None,
            last_refreshed_at=None,
            api_key=None,
            user_role=None,
            allowed_model_region=None,
            parent_otel_span=None,
        ),
        "proxy_config": proxy_config,
        "general_settings": {},
        "version": "0.0.0",
    }

    new_data = await add_litellm_data_to_request(**data)
    print("NEW DATA: {}".format(new_data))

    assert "langsmith_api_key" in new_data
    assert new_data["langsmith_api_key"] == "ls-1234"
    assert "langsmith_project" in new_data
    assert new_data["langsmith_project"] == "pr-brief-resemblance-72"
    assert "langsmith_base_url" in new_data
    assert new_data["langsmith_base_url"] == "https://api.smith.langchain.com"

    if expected_success_callbacks:
        assert "success_callback" in new_data
        assert new_data["success_callback"] == expected_success_callbacks

    if expected_failure_callbacks:
        assert "failure_callback" in new_data
        assert new_data["failure_callback"] == expected_failure_callbacks


@pytest.mark.skipif(
    not os.getenv("GEMINI_API_KEY") and not os.getenv("GOOGLE_API_KEY"),
    reason="Requires GEMINI_API_KEY or GOOGLE_API_KEY.",
)
@pytest.mark.asyncio
async def test_gemini_pass_through_endpoint():
    from starlette.datastructures import URL

    from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
        Request,
        Response,
        gemini_proxy_route,
    )

    body = b"""
        {
            "contents": [{
                "parts":[{
                    "text": "The quick brown fox jumps over the lazy dog."
                }]
            }]
        }
        """

    # Construct the scope dictionary
    scope = {
        "type": "http",
        "method": "POST",
        "path": "/gemini/v1beta/models/gemini-2.5-flash:countTokens",
        "query_string": b"key=sk-1234",
        "headers": [
            (b"content-type", b"application/json"),
        ],
    }

    # Create a new Request object whose receive() yields the JSON body above
    async def async_receive():
        return {"type": "http.request", "body": body, "more_body": False}

    request = Request(
        scope=scope,
        receive=async_receive,
    )

    resp = await gemini_proxy_route(
        endpoint="v1beta/models/gemini-2.5-flash:countTokens?key=sk-1234",
        request=request,
        fastapi_response=Response(),
    )

    print(resp.body)


@pytest.mark.parametrize("hidden", [True, False])
@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_proxy_model_group_alias_checks(prisma_client, hidden):
    """
    Check if model group alias is returned on

    `/v1/models`
    `/v1/model/info`
    `/v1/model_group/info`
    """
    import json

    from fastapi import HTTPException, Request, Response
    from starlette.datastructures import URL

    from litellm.proxy.proxy_server import model_group_info, model_info_v1, model_list

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")

    _model_list = [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ]
    model_alias = "gpt-4"
    router = litellm.Router(
        model_list=_model_list,
        model_group_alias={model_alias: {"model": "gpt-3.5-turbo", "hidden": hidden}},
    )
    setattr(litellm.proxy.proxy_server, "llm_router", router)
    setattr(litellm.proxy.proxy_server, "llm_model_list", _model_list)

    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
    request._url = URL(url="/v1/models")

    resp = await model_list(
        user_api_key_dict=UserAPIKeyAuth(models=[]),
    )

    if hidden:
        assert len(resp["data"]) == 1
    else:
        assert len(resp["data"]) == 2
    print(resp)

    resp = await model_info_v1(
        user_api_key_dict=UserAPIKeyAuth(models=[]),
    )
    models = resp["data"]
    is_model_alias_in_list = False
    for item in models:
        if model_alias == item["model_name"]:
            is_model_alias_in_list = True

    if hidden:
        assert is_model_alias_in_list is False
    else:
        assert is_model_alias_in_list

    resp = await model_group_info(
        user_api_key_dict=UserAPIKeyAuth(models=[]),
    )
    print(f"resp: {resp}")
    models = resp["data"]
    is_model_alias_in_list = False
    print(f"model_alias: {model_alias}, models: {models}")
    for item in models:
        if model_alias == item.model_group:
            is_model_alias_in_list = True

    if hidden:
        assert is_model_alias_in_list is False
    else:
        assert is_model_alias_in_list, f"models: {models}"


@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_proxy_model_group_info_rerank(prisma_client):
    """
    Check if rerank model is returned on the following endpoints

    `/v1/models`
    `/v1/model/info`
    `/v1/model_group/info`
    """
    import json

    from fastapi import HTTPException, Request, Response
    from starlette.datastructures import URL

    from litellm.proxy.proxy_server import model_group_info, model_info_v1, model_list

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")

    _model_list = [
        {
            "model_name": "rerank-english-v3.0",
            "litellm_params": {"model": "cohere/rerank-english-v3.0"},
            "model_info": {
                "mode": "rerank",
            },
        }
    ]
    router = litellm.Router(model_list=_model_list)
    setattr(litellm.proxy.proxy_server, "llm_router", router)
    setattr(litellm.proxy.proxy_server, "llm_model_list", _model_list)

    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
    request._url = URL(url="/v1/models")

    resp = await model_list(
        user_api_key_dict=UserAPIKeyAuth(models=[]),
    )

    assert len(resp["data"]) == 1
    print(resp)

    resp = await model_info_v1(
        user_api_key_dict=UserAPIKeyAuth(models=[]),
    )
    models = resp["data"]
    assert models[0]["model_info"]["mode"] == "rerank"
    resp = await model_group_info(
        user_api_key_dict=UserAPIKeyAuth(models=[]),
    )

    print(resp)
    models = resp["data"]
    assert models[0].mode == "rerank"


# @pytest.mark.asyncio
# async def test_proxy_team_member_add(prisma_client):
#     """
#     Add 10 people to a team. Confirm all 10 are added.
#     """
#     from litellm.proxy.management_endpoints.team_endpoints import (
#         team_member_add,
#         new_team,
#     )
#     from litellm.proxy._types import TeamMemberAddRequest, Member, NewTeamRequest

#     setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
#     setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
#     try:

#         async def test():
#             await litellm.proxy.proxy_server.prisma_client.connect()
#             from litellm.proxy.proxy_server import user_api_key_cache

#             user_api_key_dict = UserAPIKeyAuth(
#                 user_role=LitellmUserRoles.PROXY_ADMIN,
#                 api_key="sk-1234",
#                 user_id="1234",
#             )

#             new_team()
#             for _ in range(10):
#                 request = TeamMemberAddRequest(
#                     team_id="1234",
#                     member=Member(
#                         user_id="1234",
#                         user_role=LitellmUserRoles.INTERNAL_USER,
#                     ),
#                 )
#                 key = await team_member_add(
#                     request, user_api_key_dict=user_api_key_dict
#                 )

#             print(key)
#             user_id = key.user_id

#             # check /user/info to verify user_role was set correctly
#             new_user_info = await user_info(
#                 user_id=user_id, user_api_key_dict=user_api_key_dict
#             )
#             new_user_info = new_user_info.user_info
#             print("new_user_info=", new_user_info)
#             assert new_user_info["user_role"] == LitellmUserRoles.INTERNAL_USER
#             assert new_user_info["user_id"] == user_id

#             generated_key = key.key
#             bearer_token = "Bearer " + generated_key

#             assert generated_key not in user_api_key_cache.in_memory_cache.cache_dict

#             value_from_prisma = await prisma_client.get_data(
#                 token=generated_key,
#             )
#             print("token from prisma", value_from_prisma)

#             request = Request(
#                 {
#                     "type": "http",
#                     "route": api_route,
#                     "path": api_route.path,
#                     "headers": [("Authorization", bearer_token)],
#                 }
#             )

#             # use generated key to auth in
#             result = await user_api_key_auth(request=request, api_key=bearer_token)
#             print("result from user auth with new key", result)

#         asyncio.run(test())
#     except Exception as e:
#         pytest.fail(f"An exception occurred - {str(e)}")


@pytest.mark.asyncio
async def test_proxy_server_prisma_setup():
    from litellm.proxy.proxy_server import ProxyStartupEvent, proxy_state
    from litellm.proxy.utils import ProxyLogging
    from litellm.caching import DualCache

    user_api_key_cache = DualCache()

    with patch.object(
        litellm.proxy.proxy_server, "PrismaClient", new=MagicMock()
    ) as mock_prisma_client:
        mock_client = mock_prisma_client.return_value  # This is the mocked instance
        mock_client.connect = AsyncMock()  # Mock the connect method
        mock_client.check_view_exists = AsyncMock()  # Mock the check_view_exists method
        mock_client.health_check = AsyncMock()  # Mock the health_check method
        mock_client._set_spend_logs_row_count_in_proxy_state = (
            AsyncMock()
        )  # Mock the _set_spend_logs_row_count_in_proxy_state method
        mock_client.start_db_health_watchdog_task = AsyncMock()
        # Mock the db attribute with start_token_refresh_task for RDS IAM token refresh
        mock_db = MagicMock()
        mock_db.start_token_refresh_task = AsyncMock()
        mock_client.db = mock_db

        await ProxyStartupEvent._setup_prisma_client(
            database_url=os.getenv("DATABASE_URL"),
            proxy_logging_obj=ProxyLogging(user_api_key_cache=user_api_key_cache),
            user_api_key_cache=user_api_key_cache,
        )

        # Verify our mocked methods were called
        mock_client.connect.assert_called_once()
        mock_client.check_view_exists.assert_called_once()

        # Note: This is REALLY IMPORTANT to check that the health check is called
        # This is how we ensure the DB is ready before proceeding
        mock_client.health_check.assert_called_once()

        # check that the spend logs row count is set in proxy state
        mock_client._set_spend_logs_row_count_in_proxy_state.assert_called_once()
        assert proxy_state.get_proxy_state_variable("spend_logs_row_count") is not None


@pytest.mark.asyncio
async def test_proxy_server_prisma_setup_invalid_db():
    """
    PROD TEST: Test that proxy server startup fails when it's unable to connect to the database

    Think 2-3 times before editing / deleting this test, it's important for PROD
    """
    from litellm.proxy.proxy_server import ProxyStartupEvent
    from litellm.proxy.utils import ProxyLogging
    from litellm.caching import DualCache

    user_api_key_cache = DualCache()
    invalid_db_url = "postgresql://invalid:invalid@localhost:5432/nonexistent"

    _old_db_url = os.getenv("DATABASE_URL")
    os.environ["DATABASE_URL"] = invalid_db_url

    try:
        with pytest.raises(Exception) as exc_info:
            await ProxyStartupEvent._setup_prisma_client(
                database_url=invalid_db_url,
                proxy_logging_obj=ProxyLogging(user_api_key_cache=user_api_key_cache),
                user_api_key_cache=user_api_key_cache,
            )
        print("GOT EXCEPTION=", exc_info)

        assert "httpx.ConnectError" in str(exc_info.value)

        # # Verify the error message indicates a database connection issue
        # assert any(x in str(exc_info.value).lower() for x in ["database", "connection", "authentication"])
    finally:
        # Restore DATABASE_URL even if an assertion above fails, and remove the
        # invalid value entirely if DATABASE_URL was unset before this test ran
        if _old_db_url is not None:
            os.environ["DATABASE_URL"] = _old_db_url
        else:
            os.environ.pop("DATABASE_URL", None)


@pytest.mark.asyncio
async def test_get_ui_settings_spend_logs_threshold():
    """
    Test that get_ui_settings correctly sets DISABLE_EXPENSIVE_DB_QUERIES based on spend_logs_row_count threshold
    """
    from litellm.proxy.management_endpoints.ui_sso import get_ui_settings
    from litellm.proxy.proxy_server import proxy_state
    from fastapi import Request
    from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY

    # Create a mock request
    mock_request = Request(
        scope={
            "type": "http",
            "headers": [],
            "method": "GET",
            "scheme": "http",
            "server": ("testserver", 80),
            "path": "/sso/get/ui_settings",
            "query_string": b"",
        }
    )

    # Test case 1: When spend_logs_row_count > MAX_SPENDLOG_ROWS_TO_QUERY
    proxy_state.set_proxy_state_variable(
        "spend_logs_row_count", MAX_SPENDLOG_ROWS_TO_QUERY + 1
    )
    response = await get_ui_settings(mock_request)
    print("response from get_ui_settings", json.dumps(response, indent=4))
    assert response["DISABLE_EXPENSIVE_DB_QUERIES"] is True
    assert response["NUM_SPEND_LOGS_ROWS"] == MAX_SPENDLOG_ROWS_TO_QUERY + 1

    # Test case 2: When spend_logs_row_count < MAX_SPENDLOG_ROWS_TO_QUERY
    proxy_state.set_proxy_state_variable(
        "spend_logs_row_count", MAX_SPENDLOG_ROWS_TO_QUERY - 1
    )
    response = await get_ui_settings(mock_request)
    print("response from get_ui_settings", json.dumps(response, indent=4))
    assert response["DISABLE_EXPENSIVE_DB_QUERIES"] is False
    assert response["NUM_SPEND_LOGS_ROWS"] == MAX_SPENDLOG_ROWS_TO_QUERY - 1

    # Test case 3: Edge case - exactly MAX_SPENDLOG_ROWS_TO_QUERY
    proxy_state.set_proxy_state_variable(
        "spend_logs_row_count", MAX_SPENDLOG_ROWS_TO_QUERY
    )
    response = await get_ui_settings(mock_request)
    print("response from get_ui_settings", json.dumps(response, indent=4))
    assert response["DISABLE_EXPENSIVE_DB_QUERIES"] is False
    assert response["NUM_SPEND_LOGS_ROWS"] == MAX_SPENDLOG_ROWS_TO_QUERY

    # Clean up
    proxy_state.set_proxy_state_variable("spend_logs_row_count", 0)


@pytest.mark.asyncio
async def test_run_background_health_check_reflects_llm_model_list(monkeypatch):
    """
    Test that _run_background_health_check reflects changes to llm_model_list in each health check iteration.
    """
    import litellm.proxy.proxy_server as proxy_server
    import copy

    test_model_list_1 = [{"model_name": "model-a"}]
    test_model_list_2 = [{"model_name": "model-b"}]
    called_model_lists = []

    async def fake_perform_health_check(model_list, details, max_concurrency=None):
        called_model_lists.append(copy.deepcopy(model_list))
        return (["healthy"], ["unhealthy"], {})

    monkeypatch.setattr(proxy_server, "health_check_interval", 1)
    monkeypatch.setattr(proxy_server, "health_check_details", None)
    monkeypatch.setattr(
        proxy_server, "llm_model_list", copy.deepcopy(test_model_list_1)
    )
    monkeypatch.setattr(proxy_server, "perform_health_check", fake_perform_health_check)
    monkeypatch.setattr(proxy_server, "health_check_results", {})

    async def fake_sleep(interval):
        raise asyncio.CancelledError()

    monkeypatch.setattr(asyncio, "sleep", fake_sleep)

    try:
        await proxy_server._run_background_health_check()
    except asyncio.CancelledError:
        pass

    monkeypatch.setattr(
        proxy_server, "llm_model_list", copy.deepcopy(test_model_list_2)
    )

    try:
        await proxy_server._run_background_health_check()
    except asyncio.CancelledError:
        pass

    assert len(called_model_lists) >= 2
    assert called_model_lists[0] == test_model_list_1
    assert called_model_lists[1] == test_model_list_2


@pytest.mark.asyncio
async def test_background_health_check_skip_disabled_models(monkeypatch):
    """Ensure models with disable_background_health_check are skipped."""
    import litellm.proxy.proxy_server as proxy_server
    import copy

    test_model_list = [
        {"model_name": "model-a"},
        {
            "model_name": "model-b",
            "model_info": {"disable_background_health_check": True},
        },
    ]
    called_model_lists = []

    async def fake_perform_health_check(model_list, details, max_concurrency=None):
        called_model_lists.append(copy.deepcopy(model_list))
        return (["healthy"], [], {})

    monkeypatch.setattr(proxy_server, "health_check_interval", 1)
    monkeypatch.setattr(proxy_server, "health_check_details", None)
    monkeypatch.setattr(proxy_server, "llm_model_list", copy.deepcopy(test_model_list))
    monkeypatch.setattr(proxy_server, "perform_health_check", fake_perform_health_check)
    monkeypatch.setattr(proxy_server, "health_check_results", {})

    async def fake_sleep(interval):
        raise asyncio.CancelledError()

    monkeypatch.setattr(asyncio, "sleep", fake_sleep)

    try:
        await proxy_server._run_background_health_check()
    except asyncio.CancelledError:
        pass

    assert called_model_lists == [[{"model_name": "model-a"}]]


def test_get_timeout_from_request():
    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup

    headers = {
        "x-litellm-timeout": "90",
    }
    timeout = LiteLLMProxyRequestSetup._get_timeout_from_request(headers)
    assert timeout == 90

    headers = {
        "x-litellm-timeout": "90.5",
    }
    timeout = LiteLLMProxyRequestSetup._get_timeout_from_request(headers)
    assert timeout == 90.5


@pytest.mark.parametrize(
    "ui_exists, ui_has_content",
    [
        (True, True),  # UI path exists and has content
        (True, False),  # UI path exists but is empty
        (False, False),  # UI path doesn't exist
    ],
)
def test_non_root_ui_path_logic(monkeypatch, tmp_path, ui_exists, ui_has_content):
    """
    Test the non-root Docker UI path detection logic.

    Tests that when LITELLM_NON_ROOT is set to "true":
    - If UI path exists and has content, it should be used
    - If UI path doesn't exist or is empty, proper error logging occurs
    """
    from unittest.mock import MagicMock

    # Create a temporary directory to act as /tmp/litellm_ui
    test_ui_path = tmp_path / "litellm_ui"

    if ui_exists:
        test_ui_path.mkdir(parents=True, exist_ok=True)
        if ui_has_content:
            # Create some dummy files to simulate built UI
            (test_ui_path / "index.html").write_text("<html></html>")
            (test_ui_path / "app.js").write_text("console.log('test');")

    # Enable the non-root code path via the environment variable
    monkeypatch.setenv("LITELLM_NON_ROOT", "true")

    # Create a mock logger to capture log messages
    mock_logger = MagicMock()

    # The real check runs at module import time in proxy_server.py, so rather than
    # reloading that module, replicate its logic here and test it directly
    ui_path = None
    non_root_ui_path = str(test_ui_path)

    # Simulate the logic from proxy_server.py lines 909-920
    if os.getenv("LITELLM_NON_ROOT", "").lower() == "true":
        if os.path.exists(non_root_ui_path) and os.listdir(non_root_ui_path):
            mock_logger.info(
                f"Using pre-built UI for non-root Docker: {non_root_ui_path}"
            )
            mock_logger.info(
                f"UI files found: {len(os.listdir(non_root_ui_path))} items"
            )
            ui_path = non_root_ui_path
        else:
            mock_logger.error(
                f"UI not found at {non_root_ui_path}. UI will not be available."
            )
            mock_logger.error(
                f"Path exists: {os.path.exists(non_root_ui_path)}, Has content: {os.path.exists(non_root_ui_path) and bool(os.listdir(non_root_ui_path))}"
            )

    # Verify behavior based on test parameters
    if ui_exists and ui_has_content:
        # UI should be found and used
        assert ui_path == non_root_ui_path
        assert mock_logger.info.call_count == 2
        mock_logger.info.assert_any_call(
            f"Using pre-built UI for non-root Docker: {non_root_ui_path}"
        )
        # Verify the second info call mentions the number of items
        info_calls = [call[0][0] for call in mock_logger.info.call_args_list]
        assert any("UI files found:" in call and "items" in call for call in info_calls)
        assert mock_logger.error.call_count == 0
    else:
        # UI should not be found, error should be logged
        assert ui_path is None
        assert mock_logger.error.call_count == 2
        mock_logger.error.assert_any_call(
            f"UI not found at {non_root_ui_path}. UI will not be available."
        )
        # Verify the second error call has path existence info
        error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
        assert any("Path exists:" in call for call in error_calls)
        assert mock_logger.info.call_count == 0


@pytest.mark.asyncio
async def test_get_config_callbacks_with_all_types(client_no_auth):
    """
    Test that /get/config/callbacks returns all three callback types:
    - success_callback with type="success"
    - failure_callback with type="failure"
    - callbacks (success_and_failure) with type="success_and_failure"
    """
    from litellm.proxy.proxy_server import ProxyConfig

    # Create a mock config with all three callback types
    mock_config_data = {
        "litellm_settings": {
            "success_callback": ["langfuse", "braintrust"],
            "failure_callback": ["sentry"],
            "callbacks": ["otel", "langsmith"],
        },
        "environment_variables": {
            "LANGFUSE_PUBLIC_KEY": "test-public-key",
            "LANGFUSE_SECRET_KEY": "test-secret-key",
            "LANGFUSE_HOST": "https://test.langfuse.com",
            "BRAINTRUST_API_KEY": "test-braintrust-key",
            "OTEL_EXPORTER": "otlp",
            "OTEL_ENDPOINT": "http://localhost:4317",
            "LANGSMITH_API_KEY": "test-langsmith-key",
        },
        "general_settings": {},
    }

    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")

    with patch.object(
        proxy_config, "get_config", new=AsyncMock(return_value=mock_config_data)
    ):
        response = client_no_auth.get("/get/config/callbacks")

    assert response.status_code == 200
    result = response.json()

    # Verify response structure
    assert "status" in result
    assert result["status"] == "success"
    assert "callbacks" in result

    callbacks = result["callbacks"]

    # Verify we have all 5 callbacks (2 success + 1 failure + 2 success_and_failure)
    assert len(callbacks) == 5

    # Group callbacks by type
    success_callbacks = [cb for cb in callbacks if cb.get("type") == "success"]
    failure_callbacks = [cb for cb in callbacks if cb.get("type") == "failure"]
    success_and_failure_callbacks = [
        cb for cb in callbacks if cb.get("type") == "success_and_failure"
    ]

    # Verify all callbacks have required fields
    for callback in callbacks:
        assert "name" in callback
        assert "variables" in callback
        assert "type" in callback
        assert callback["type"] in ["success", "failure", "success_and_failure"]

    # Verify success callbacks
    assert len(success_callbacks) == 2
    success_names = [cb["name"] for cb in success_callbacks]
    assert "langfuse" in success_names
    assert "braintrust" in success_names

    # Verify failure callbacks
    assert len(failure_callbacks) == 1
    assert failure_callbacks[0]["name"] == "sentry"

    # Verify success_and_failure callbacks
    assert len(success_and_failure_callbacks) == 2
    success_and_failure_names = [cb["name"] for cb in success_and_failure_callbacks]
    assert "otel" in success_and_failure_names
    assert "langsmith" in success_and_failure_names


@pytest.mark.asyncio
|
|
async def test_get_config_callbacks_environment_variables(client_no_auth):
|
|
"""
|
|
Test that /get/config/callbacks correctly includes environment variables
|
|
for each callback type. Values are returned as-is from the config (no decryption).
|
|
"""
|
|
from litellm.proxy.proxy_server import ProxyConfig
|
|
|
|
# Create a mock config with callbacks and their env vars
|
|
mock_config_data = {
|
|
"litellm_settings": {
            "success_callback": ["langfuse"],
            "failure_callback": [],
            "callbacks": ["otel"],
        },
        "environment_variables": {
            "LANGFUSE_PUBLIC_KEY": "test-public-key",
            "LANGFUSE_SECRET_KEY": "test-secret-key",
            "LANGFUSE_HOST": "https://cloud.langfuse.com",
            "OTEL_EXPORTER": "otlp",
            "OTEL_ENDPOINT": "http://localhost:4317",
            "OTEL_HEADERS": "key=value",
        },
        "general_settings": {},
    }

    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")

    with patch.object(
        proxy_config, "get_config", new=AsyncMock(return_value=mock_config_data)
    ):
        response = client_no_auth.get("/get/config/callbacks")

    assert response.status_code == 200
    result = response.json()

    callbacks = result["callbacks"]

    # Find langfuse callback (success type)
    langfuse_callback = next(
        (cb for cb in callbacks if cb["name"] == "langfuse"), None
    )
    assert langfuse_callback is not None
    assert langfuse_callback["type"] == "success"
    assert "variables" in langfuse_callback

    # Verify langfuse env vars are present (values returned as-is, no decryption)
    langfuse_vars = langfuse_callback["variables"]
    assert "LANGFUSE_PUBLIC_KEY" in langfuse_vars
    assert langfuse_vars["LANGFUSE_PUBLIC_KEY"] == "test-public-key"
    assert "LANGFUSE_SECRET_KEY" in langfuse_vars
    assert langfuse_vars["LANGFUSE_SECRET_KEY"] == "test-secret-key"
    assert "LANGFUSE_HOST" in langfuse_vars
    assert langfuse_vars["LANGFUSE_HOST"] == "https://cloud.langfuse.com"

    # Find otel callback (success_and_failure type)
    otel_callback = next((cb for cb in callbacks if cb["name"] == "otel"), None)
    assert otel_callback is not None
    assert otel_callback["type"] == "success_and_failure"
    assert "variables" in otel_callback

    # Verify otel env vars are present
    otel_vars = otel_callback["variables"]
    assert "OTEL_EXPORTER" in otel_vars
    assert otel_vars["OTEL_EXPORTER"] == "otlp"
    assert "OTEL_ENDPOINT" in otel_vars
    assert otel_vars["OTEL_ENDPOINT"] == "http://localhost:4317"
    assert "OTEL_HEADERS" in otel_vars
    assert otel_vars["OTEL_HEADERS"] == "key=value"


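# Illustrative sketch only (hypothetical helper, not litellm's implementation):
# the test above expects each callback's "type" to follow from the
# litellm_settings list that configures it ("success_callback" -> "success",
# "callbacks" -> "success_and_failure"), and its "variables" to be the
# environment variables belonging to that integration. The name-to-prefix
# table is an assumption; failure_callback is left out because the test does
# not exercise it.
_SKETCH_CALLBACK_TYPES = {
    "success_callback": "success",
    "callbacks": "success_and_failure",
}
_SKETCH_ENV_PREFIXES = {"langfuse": "LANGFUSE_", "otel": "OTEL_"}  # assumed


def _sketch_describe_callbacks(litellm_settings: dict, env_vars: dict) -> list:
    """Return one {"name", "type", "variables"} dict per configured callback."""
    described = []
    for setting_key, callback_type in _SKETCH_CALLBACK_TYPES.items():
        for name in litellm_settings.get(setting_key, []):
            prefix = _SKETCH_ENV_PREFIXES.get(name)
            # Unknown integrations get no variables rather than every env var.
            variables = (
                {k: v for k, v in env_vars.items() if k.startswith(prefix)}
                if prefix
                else {}
            )
            described.append(
                {"name": name, "type": callback_type, "variables": variables}
            )
    return described

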
@pytest.mark.asyncio
async def test_update_config_success_callback_normalization():
    """
    Ensure success_callback values are normalized to lowercase when updating config.
    This prevents delete_callback (which searches lowercase) from failing on mixed case inputs like 'SQS'.
    """
    import litellm.proxy.proxy_server as proxy_server
    from litellm.proxy._types import ConfigYAML, LitellmUserRoles, UserAPIKeyAuth

    # Ensure feature is enabled and prisma_client is set
    setattr(proxy_server, "store_model_in_db", True)
    setattr(proxy_server, "proxy_logging_obj", MagicMock())

    class MockPrisma:
        def __init__(self):
            self.db = MagicMock()
            self.db.litellm_config = MagicMock()
            self.db.litellm_config.upsert = AsyncMock()

        # proxy_server.update_config expects this to be sync returning a dict
        def jsonify_object(self, obj):
            return obj

    setattr(proxy_server, "prisma_client", MockPrisma())

    class MockProxyConfig:
        def __init__(self):
            self.saved_config = None

        async def get_config(self):
            # Existing config has one lowercase callback already
            return {"litellm_settings": {"success_callback": ["langfuse"]}}

        async def save_config(self, new_config: dict):
            self.saved_config = new_config

        async def add_deployment(self, prisma_client=None, proxy_logging_obj=None):
            return None

    mock_proxy_config = MockProxyConfig()
    setattr(proxy_server, "proxy_config", mock_proxy_config)

    # Update config with mixed-case callbacks - expect normalization to lowercase
    config_update = ConfigYAML(litellm_settings={"success_callback": ["SQS", "sQs"]})
    admin_user = UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-test")
    await proxy_server.update_config(config_update, user_api_key_dict=admin_user)

    saved = mock_proxy_config.saved_config
    assert saved is not None, "save_config was not called"
    callbacks = saved["litellm_settings"]["success_callback"]

    # Deduped and normalized
    assert "sqs" in callbacks
    assert "SQS" not in callbacks
    assert "sQs" not in callbacks
    # Existing callback should still be present
    assert "langfuse" in callbacks


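# Illustrative sketch only (hypothetical helper, not litellm's update_config
# code): merge the existing success_callback list with incoming names,
# lowercasing the incoming ones and de-duplicating while keeping first-seen
# order. With ["SQS", "sQs"] both variants collapse to "sqs", which is the
# form a lowercase lookup such as the one delete_callback performs can find.
def _sketch_merge_success_callbacks(existing: list, incoming: list) -> list:
    normalized = [name.lower() for name in incoming]
    # dict.fromkeys keeps insertion order and drops repeated keys.
    return list(dict.fromkeys(list(existing) + normalized))

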
@pytest.mark.parametrize(
    "data",
    [
        {
            "model": {
                "model_name": "azure/gpt-4.1-mini",
                "litellm_params": {"model": "azure/gpt-4.1-mini"},
                "model_info": {"base_model": "gpt-4.1-mini"},
            },
            "expected": "gpt-4.1-mini",
        },
        {
            "model": {
                "model_name": "openai/gpt-4.1-mini",
                "litellm_params": {"model": "openai/gpt-4.1-mini"},
            },
            "expected": "openai/gpt-4.1-mini",
        },
        {
            "model": {
                "model_name": "openai/gpt-4.1-mini",
                "litellm_params": {"model": "openai/gpt-4.1-mini"},
                "model_info": {"base_model": "gpt-4.1-mini"},
            },
            "expected": "gpt-4.1-mini",
        },
        {
            "model": {
                "model_name": "claude-sonnet-4-5-20250929",
                "litellm_params": {"model": "anthropic/claude-sonnet-4-5@20250929"},
                "model_info": {"base_model": "anthropic/claude-sonnet-4-5-20250929"},
            },
            "expected": "anthropic/claude-sonnet-4-5-20250929",
        },
        {
            "model": {
                "model_name": "gemini-2.5-flash-001",
                "litellm_params": {"model": "gemini/gemini-2.5-flash@001"},
                "model_info": {"base_model": "gemini-2.5-flash-001"},
            },
            "expected": "gemini-2.5-flash-001",
        },
    ],
)
def test_get_litellm_model_info(data):
    from litellm.proxy.proxy_server import get_litellm_model_info

    model = data["model"]
    get_info_mock = MagicMock()

    with mock.patch(
        "litellm.get_model_info",
        new=get_info_mock,
    ):
        get_litellm_model_info(model=model)
        get_info_mock.assert_called_once_with(data["expected"])
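

# Illustrative sketch only (hypothetical helper, inferred solely from the
# parametrized cases above, not from litellm's source): the name passed to
# litellm.get_model_info is model_info["base_model"] when it is set, and
# otherwise falls back to litellm_params["model"].
def _sketch_model_info_lookup_key(model: dict) -> str:
    base_model = (model.get("model_info") or {}).get("base_model")
    return base_model or model["litellm_params"]["model"]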