mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 19:31:57 +00:00
9a338e1b6b
Several tests parametrized over (model, api_key, ...) tuples or raw token strings, causing pytest to embed those values in the test ID and print them in CI logs. Refactored each affected test to keep the same coverage without putting key material into parametrize. - audio_tests/test_audio_speech.py: split env-var keys into separate azure/openai test functions sharing a helper; sync_mode parametrize preserved. - audio_tests/test_whisper.py: split into openai_whisper / azure_whisper functions sharing a helper; response_format parametrize preserved. - local_testing/test_embedding.py: single-case parametrize inlined. - proxy_unit_tests/test_user_api_key_auth.py: 5 header parametrize cases split into 5 named tests sharing an _assert helper. - proxy_unit_tests/test_proxy_utils.py: 4 api_key_value cases split into 4 named tests. - test_litellm/proxy/auth/test_user_api_key_auth.py: 5 key-prefix cases (Bearer / Basic / lowercase bearer / raw / AWS SigV4) split into 5 named tests. Verified: black clean; 14 refactored unit tests pass; pytest collects audio/embedding tests with safe IDs (no key material in test IDs).
2782 lines
87 KiB
Python
2782 lines
87 KiB
Python
import asyncio
|
|
import json
|
|
import os
|
|
import sys
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional, Union
|
|
from unittest.mock import Mock
|
|
|
|
import pytest
|
|
from fastapi import Request
|
|
from starlette.datastructures import State
|
|
|
|
from litellm.proxy.utils import _get_docs_url, _get_openapi_url, _get_redoc_url
|
|
|
|
sys.path.insert(
|
|
0, os.path.abspath("../..")
|
|
) # Adds the parent directory to the system path
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import litellm
|
|
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
|
|
from litellm.proxy.auth.auth_utils import (
|
|
check_complete_credentials,
|
|
is_request_body_safe,
|
|
)
|
|
from litellm.proxy.litellm_pre_call_utils import (
|
|
_get_dynamic_logging_metadata,
|
|
add_litellm_data_to_request,
|
|
)
|
|
|
|
pytestmark = pytest.mark.xdist_group("proxy_heavy")
|
|
|
|
|
|
@pytest.fixture
def mock_request(monkeypatch):
    """Provide a FastAPI ``Request`` mock pre-populated for pre-call tests.

    The mock carries empty query params, a single ``traceparent`` header and a
    real Starlette ``State`` instance.
    """
    request = Mock(spec=Request)
    request.query_params = {}  # Set mock query_params to an empty dictionary
    request.headers = {"traceparent": "test_traceparent"}
    # Real State so _safe_get_request_headers caching works
    request.state = State()
    # NOTE(review): this swaps the module attribute for the request mock
    # itself; tests here call the name imported at module load - confirm intent.
    monkeypatch.setattr(
        "litellm.proxy.litellm_pre_call_utils.add_litellm_data_to_request",
        request,
    )
    return request
|
|
|
|
|
|
@pytest.mark.parametrize("endpoint", ["/v1/threads", "/v1/thread/123"])
@pytest.mark.asyncio
async def test_add_litellm_data_to_request_thread_endpoint(endpoint, mock_request):
    """Thread endpoints get ``litellm_metadata`` and no ``metadata`` key."""
    mock_request.url.path = endpoint
    auth = UserAPIKeyAuth(
        api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
    )
    config_stub = Mock()

    # The payload is mutated in place by the call under test.
    payload = {}
    await add_litellm_data_to_request(payload, mock_request, auth, config_stub)

    print("DATA: ", payload)

    assert "litellm_metadata" in payload
    assert "metadata" not in payload
|
|
|
|
|
|
@pytest.mark.parametrize(
    "endpoint", ["/chat/completions", "/v1/completions", "/completions"]
)
@pytest.mark.asyncio
async def test_add_litellm_data_to_request_non_thread_endpoint(endpoint, mock_request):
    """Non-thread endpoints get ``metadata`` and no ``litellm_metadata`` key."""
    mock_request.url.path = endpoint
    auth = UserAPIKeyAuth(
        api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
    )
    config_stub = Mock()

    # The payload is mutated in place by the call under test.
    payload = {}
    await add_litellm_data_to_request(payload, mock_request, auth, config_stub)

    print("DATA: ", payload)

    assert "metadata" in payload
    assert "litellm_metadata" not in payload
|
|
|
|
|
|
# test adding traceparent
|
|
|
|
|
|
@pytest.mark.parametrize(
    "endpoint", ["/chat/completions", "/v1/completions", "/completions"]
)
@pytest.mark.asyncio
async def test_traceparent_not_added_by_default(endpoint, mock_request):
    """
    This tests that traceparent is not forwarded in the extra_headers

    We had an incident where bedrock calls were failing because traceparent was forwarded
    """
    from litellm.integrations.opentelemetry import OpenTelemetry

    otel_logger = OpenTelemetry()
    setattr(litellm.proxy.proxy_server, "open_telemetry_logger", otel_logger)

    # Reset the global logger in ``finally`` so a failing assertion cannot
    # leak the OpenTelemetry logger into later tests.
    try:
        mock_request.url.path = endpoint
        user_api_key_dict = UserAPIKeyAuth(
            api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
        )
        proxy_config = Mock()

        data = {}
        await add_litellm_data_to_request(
            data, mock_request, user_api_key_dict, proxy_config
        )

        print("DATA: ", data)

        # ``extra_headers`` may be absent or None; treat both as empty.
        _extra_headers = data.get("extra_headers") or {}
        assert "traceparent" not in _extra_headers
    finally:
        setattr(litellm.proxy.proxy_server, "open_telemetry_logger", None)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "request_tags", [None, ["request_tag1", "request_tag2", "request_tag3"]]
)
@pytest.mark.parametrize(
    "request_sl_metadata", [None, {"request_key": "request_value"}]
)
@pytest.mark.parametrize("key_tags", [None, ["key_tag1", "key_tag2", "key_tag3"]])
@pytest.mark.parametrize("key_sl_metadata", [None, {"key_key": "key_value"}])
@pytest.mark.parametrize("team_tags", [None, ["team_tag1", "team_tag2", "team_tag3"]])
@pytest.mark.parametrize("team_sl_metadata", [None, {"team_key": "team_value"}])
@pytest.mark.asyncio
async def test_add_key_or_team_level_spend_logs_metadata_to_request(
    mock_request,
    request_tags,
    request_sl_metadata,
    team_tags,
    key_sl_metadata,
    team_sl_metadata,
    key_tags,
):
    """Request, key and team tags / spend_logs_metadata end up merged in ``metadata``.

    Expected tags are the request, key and team tags concatenated in that
    order; expected spend-logs metadata is the three dicts layered in the
    same order.
    """
    ## COMPLETE LIST OF TAGS
    all_tags = []
    for message, tags in (
        ("Request Tags - {}", request_tags),
        ("Key Tags - {}", key_tags),
        ("Team Tags - {}", team_tags),
    ):
        if tags is None:
            continue
        print(message.format(tags))
        all_tags.extend(tags)

    ## COMPLETE SPEND_LOGS METADATA
    all_sl_metadata = {}
    for layer in (request_sl_metadata, key_sl_metadata, team_sl_metadata):
        if layer is not None:
            all_sl_metadata.update(layer)

    print(f"team_sl_metadata: {team_sl_metadata}")
    mock_request.url.path = "/chat/completions"

    # Opt the key into client-supplied tags so request_tags are preserved
    # and merged with admin-configured key/team tags. Without this flag,
    # request_tags would be stripped by add_litellm_data_to_request.
    user_api_key_dict = UserAPIKeyAuth(
        api_key="test_api_key",
        user_id="test_user_id",
        org_id="test_org_id",
        metadata={
            "tags": key_tags,
            "spend_logs_metadata": key_sl_metadata,
            "allow_client_tags": True,
        },
        team_metadata={
            "tags": team_tags,
            "spend_logs_metadata": team_sl_metadata,
        },
    )
    proxy_config = Mock()

    # Only include the request-level entries that were actually supplied.
    request_metadata = {}
    if request_tags is not None:
        request_metadata["tags"] = request_tags
    if request_sl_metadata is not None:
        request_metadata["spend_logs_metadata"] = request_sl_metadata
    data = {"metadata": request_metadata}

    print(data)
    new_data = await add_litellm_data_to_request(
        data, mock_request, user_api_key_dict, proxy_config
    )

    print("New Data: {}".format(new_data))
    print("all_tags: {}".format(all_tags))
    assert "metadata" in new_data
    merged = new_data["metadata"]

    if not all_tags:
        assert "tags" not in merged, "Expected=No tags. Got={}".format(
            merged["tags"]
        )
    else:
        assert merged["tags"] == all_tags, "Expected={}. Got={}".format(
            all_tags, merged.get("tags", None)
        )

    if not all_sl_metadata:
        assert (
            "spend_logs_metadata" not in merged
        ), "Expected=No spend logs metadata. Got={}".format(
            merged["spend_logs_metadata"]
        )
    else:
        assert (
            merged["spend_logs_metadata"] == all_sl_metadata
        ), "Expected={}. Got={}".format(
            all_sl_metadata, merged["spend_logs_metadata"]
        )
|
|
|
|
|
|
@pytest.mark.parametrize(
    "callback_vars",
    [
        {
            "langfuse_host": "https://us.cloud.langfuse.com",
            "langfuse_public_key": "pk-lf-9636b7a6-c066",
            "langfuse_secret_key": "sk-lf-7cc8b620",
        }
    ],
)
def test_dynamic_logging_metadata_key_and_team_metadata(callback_vars):
    """Key-level ``logging`` metadata yields callback vars without ``os.environ`` refs."""
    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    # Key metadata registering a langfuse success callback with the
    # parametrized credentials.
    key_metadata = {
        "logging": [
            {
                "callback_name": "langfuse",
                "callback_type": "success",
                "callback_vars": callback_vars,
            }
        ]
    }
    auth_fields = dict(
        token="sk-test-mock-token-789",
        key_name="sk-...63Fg",
        key_alias=None,
        spend=0.000111,
        max_budget=None,
        expires=None,
        models=[],
        aliases={},
        config={},
        user_id=None,
        team_id="ishaan-special-team_e02dd54f-f790-4755-9f93-73734f415898",
        max_parallel_requests=None,
        metadata=key_metadata,
        tpm_limit=None,
        rpm_limit=None,
        budget_duration=None,
        budget_reset_at=None,
        allowed_cache_controls=[],
        permissions={},
        model_spend={},
        model_max_budget={},
        soft_budget_cooldown=False,
        litellm_budget_table=None,
        org_id=None,
        team_spend=0.000132,
        team_alias=None,
        team_tpm_limit=None,
        team_rpm_limit=None,
        team_max_budget=None,
        team_models=[],
        team_blocked=False,
        soft_budget=None,
        team_model_aliases=None,
        team_member_spend=None,
        team_member=None,
        team_metadata={},
        end_user_id=None,
        end_user_tpm_limit=None,
        end_user_rpm_limit=None,
        end_user_max_budget=None,
        last_refreshed_at=1726101560.967527,
        api_key="sk-test-mock-api-key-202",
        user_role=LitellmUserRoles.INTERNAL_USER,
        allowed_model_region=None,
        parent_otel_span=None,
        rpm_limit_per_model=None,
        tpm_limit_per_model=None,
    )
    user_api_key_dict = UserAPIKeyAuth(**auth_fields)

    callbacks = _get_dynamic_logging_metadata(
        user_api_key_dict=user_api_key_dict, proxy_config=proxy_config
    )

    assert callbacks is not None

    # No resolved value may still carry an environment-variable reference.
    for resolved_value in callbacks.callback_vars.values():
        assert "os.environ" not in resolved_value
|
|
|
|
|
|
def test_dynamic_logging_metadata_ignores_env_references_from_key_metadata(
    monkeypatch,
):
    """An ``os.environ/`` reference in key metadata produces no callbacks.

    ``get_secret`` is replaced with a function that fails the test, so the
    server-side environment value must never be looked up.
    """

    def _fail_on_lookup(*args, **kwargs):
        return pytest.fail("get_secret should not be called")

    monkeypatch.setenv("LANGFUSE_SECRET_KEY_TEMP", "server-side-secret")
    monkeypatch.setattr(litellm.utils, "get_secret", _fail_on_lookup)
    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()
    langfuse_callback = {
        "callback_name": "langfuse",
        "callback_type": "success",
        "callback_vars": {
            "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY_TEMP",
        },
    }
    user_api_key_dict = UserAPIKeyAuth(
        api_key="test-key",
        metadata={"logging": [langfuse_callback]},
        team_metadata={},
    )

    callbacks = _get_dynamic_logging_metadata(
        user_api_key_dict=user_api_key_dict, proxy_config=proxy_config
    )

    assert callbacks is None
|
|
|
|
|
|
@pytest.mark.parametrize(
    "callback_vars",
    [
        {
            "turn_off_message_logging": True,
        },
        {
            "turn_off_message_logging": False,
        },
    ],
)
def test_dynamic_turn_off_message_logging(callback_vars):
    """``turn_off_message_logging`` in key metadata is passed through unchanged."""
    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    # Key metadata registering a datadog callback with the parametrized flag.
    key_metadata = {
        "logging": [
            {
                "callback_name": "datadog",
                "callback_vars": callback_vars,
            }
        ]
    }
    auth_fields = dict(
        token="sk-test-mock-token-789",
        key_name="sk-...63Fg",
        key_alias=None,
        spend=0.000111,
        max_budget=None,
        expires=None,
        models=[],
        aliases={},
        config={},
        user_id=None,
        team_id="ishaan-special-team_e02dd54f-f790-4755-9f93-73734f415898",
        max_parallel_requests=None,
        metadata=key_metadata,
        tpm_limit=None,
        rpm_limit=None,
        budget_duration=None,
        budget_reset_at=None,
        allowed_cache_controls=[],
        permissions={},
        model_spend={},
        model_max_budget={},
        soft_budget_cooldown=False,
        litellm_budget_table=None,
        org_id=None,
        team_spend=0.000132,
        team_alias=None,
        team_tpm_limit=None,
        team_rpm_limit=None,
        team_max_budget=None,
        team_models=[],
        team_blocked=False,
        soft_budget=None,
        team_model_aliases=None,
        team_member_spend=None,
        team_member=None,
        team_metadata={},
        end_user_id=None,
        end_user_tpm_limit=None,
        end_user_rpm_limit=None,
        end_user_max_budget=None,
        last_refreshed_at=1726101560.967527,
        api_key="sk-test-mock-api-key-202",
        user_role=LitellmUserRoles.INTERNAL_USER,
        allowed_model_region=None,
        parent_otel_span=None,
        rpm_limit_per_model=None,
        tpm_limit_per_model=None,
    )
    user_api_key_dict = UserAPIKeyAuth(**auth_fields)

    callbacks = _get_dynamic_logging_metadata(
        user_api_key_dict=user_api_key_dict, proxy_config=proxy_config
    )

    assert callbacks is not None
    expected_flag = callback_vars["turn_off_message_logging"]
    assert callbacks.callback_vars["turn_off_message_logging"] == expected_flag
|
|
|
|
|
|
@pytest.mark.parametrize(
    "allow_client_side_credentials, expect_error", [(True, False), (False, True)]
)
def test_is_request_body_safe_global_enabled(
    allow_client_side_credentials, expect_error
):
    """A client-supplied ``api_base`` is only accepted when globally allowed.

    With ``allow_client_side_credentials`` set in ``general_settings`` the
    request body passes; without it ``is_request_body_safe`` must raise.
    """
    from litellm import Router

    llm_router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
            }
        ]
    )
    check_kwargs = dict(
        request_body={"api_base": "hello-world"},
        general_settings={
            "allow_client_side_credentials": allow_client_side_credentials
        },
        llm_router=llm_router,
        model="gpt-3.5-turbo",
    )

    if expect_error:
        with pytest.raises(Exception) as exc_info:
            is_request_body_safe(**check_kwargs)
        print(exc_info.value)
    else:
        # Let an unexpected error propagate so its traceback is reported,
        # instead of collapsing it into a bare boolean mismatch.
        is_request_body_safe(**check_kwargs)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "allow_client_side_credentials, expect_error", [(True, False), (False, True)]
)
def test_is_request_body_safe_model_enabled(
    allow_client_side_credentials, expect_error
):
    """A client-supplied ``api_base`` is only accepted when the model allows it.

    The wildcard deployment lists ``api_base`` under
    ``configurable_clientside_auth_params`` only in the allowed case; otherwise
    ``is_request_body_safe`` must raise.
    """
    from litellm import Router

    llm_router = Router(
        model_list=[
            {
                "model_name": "fireworks_ai/*",
                "litellm_params": {
                    "model": "fireworks_ai/*",
                    "api_key": os.getenv("FIREWORKS_API_KEY"),
                    "configurable_clientside_auth_params": (
                        ["api_base"] if allow_client_side_credentials else []
                    ),
                },
            }
        ]
    )
    check_kwargs = dict(
        request_body={"api_base": "hello-world"},
        general_settings={},
        llm_router=llm_router,
        model="fireworks_ai/my-new-model",
    )

    if expect_error:
        with pytest.raises(Exception) as exc_info:
            is_request_body_safe(**check_kwargs)
        print(exc_info.value)
    else:
        # Let an unexpected error propagate so its traceback is reported,
        # instead of collapsing it into a bare boolean mismatch.
        is_request_body_safe(**check_kwargs)
|
|
|
|
|
|
def _assert_check_complete_credentials(api_key_value, expect_complete):
    """Assert ``check_complete_credentials`` returns ``expect_complete``.

    The request body always targets ``gpt-3.5-turbo`` and carries
    ``api_key_value`` as its ``api_key``.
    """
    outcome = check_complete_credentials(
        request_body={"model": "gpt-3.5-turbo", "api_key": api_key_value}
    )
    assert outcome == expect_complete
|
|
|
|
|
|
def test_check_complete_credentials_with_real_key():
    """A non-empty ``sk-`` style key counts as complete credentials."""
    # Built at runtime so no key-shaped literal appears in the source.
    placeholder_key = "sk-" + "x" * 8
    _assert_check_complete_credentials(placeholder_key, True)
|
|
|
|
|
|
def test_check_complete_credentials_with_empty_string():
    """An empty ``api_key`` string is not complete credentials."""
    empty_key = ""
    _assert_check_complete_credentials(empty_key, False)
|
|
|
|
|
|
def test_check_complete_credentials_with_none():
    """A ``None`` ``api_key`` is not complete credentials."""
    missing_key = None
    _assert_check_complete_credentials(missing_key, False)
|
|
|
|
|
|
def test_check_complete_credentials_with_whitespace():
    """A whitespace-only ``api_key`` is not complete credentials."""
    blank_key = " "
    _assert_check_complete_credentials(blank_key, False)
|
|
|
|
|
|
def test_reading_openai_org_id_from_headers():
    """The ``OpenAI-Organization`` header value is returned as the org id."""
    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup

    expected_org = "test_org_id"
    extracted = LiteLLMProxyRequestSetup.get_openai_org_id_from_headers(
        {"OpenAI-Organization": expected_org}
    )
    assert extracted == expected_org
|
|
|
|
|
|
@pytest.mark.parametrize(
    "headers, general_settings, expected_data",
    [
        # header name matches the configured name exactly
        (
            {"X-OpenWebUI-User-Id": "ishaan3"},
            {"user_header_name": "X-OpenWebUI-User-Id"},
            "ishaan3",
        ),
        # header supplied in lowercase still matches
        (
            {"x-openwebui-user-id": "ishaan3"},
            {"user_header_name": "X-OpenWebUI-User-Id"},
            "ishaan3",
        ),
        # no user_header_name configured
        ({"X-OpenWebUI-User-Id": "ishaan3"}, {}, None),
        # no headers and no settings at all
        ({}, None, None),
    ],
)
def test_add_litellm_data_for_backend_llm_call(
    headers, general_settings, expected_data
):
    """``get_user_from_headers`` reads the user from the configured header.

    The header is named by ``general_settings["user_header_name"]``; when that
    setting is missing (or settings are ``None``) the result is ``None``.
    """
    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup

    data = LiteLLMProxyRequestSetup.get_user_from_headers(
        headers=headers,
        general_settings=general_settings,
    )

    assert json.dumps(data, sort_keys=True) == json.dumps(expected_data, sort_keys=True)
|
|
|
|
|
|
def test_foward_litellm_user_info_to_backend_llm_call():
    """User/key information is emitted as ``x-litellm-*`` headers when enabled."""
    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup

    # Remember the global flag so it can be restored; leaving it enabled would
    # change the behavior of every test that runs afterwards.
    previous_flag = getattr(litellm, "add_user_information_to_llm_headers", None)
    litellm.add_user_information_to_llm_headers = True
    try:
        user_api_key_dict = UserAPIKeyAuth(
            api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
        )

        data = LiteLLMProxyRequestSetup.add_headers_to_llm_call(
            headers={},
            user_api_key_dict=user_api_key_dict,
        )

        expected_data = {
            "x-litellm-user_api_key_user_id": "test_user_id",
            "x-litellm-user_api_key_org_id": "test_org_id",
            "x-litellm-user_api_key_hash": "test_api_key",
            "x-litellm-user_api_key_spend": 0.0,
            "x-litellm-user_api_key_auth_metadata": {},
        }

        assert json.dumps(data, sort_keys=True) == json.dumps(
            expected_data, sort_keys=True
        )
    finally:
        litellm.add_user_information_to_llm_headers = previous_flag
|
|
|
|
|
|
def test_update_internal_user_params():
    """Default internal-user params fill in a new ``internal_user`` request."""
    from litellm.proxy._types import NewUserRequest
    from litellm.proxy.management_endpoints.internal_user_endpoints import (
        _update_internal_new_user_params,
    )

    # Restore the global defaults afterwards so other tests are unaffected.
    previous_defaults = getattr(litellm, "default_internal_user_params", None)
    litellm.default_internal_user_params = {
        "max_budget": 100,
        "budget_duration": "30d",
        "models": ["gpt-3.5-turbo"],
    }
    try:
        data = NewUserRequest(user_role="internal_user", user_email="krrish3@berri.ai")
        data_json = data.model_dump()
        updated_data_json = _update_internal_new_user_params(data_json, data)

        defaults = litellm.default_internal_user_params
        assert updated_data_json["models"] == defaults["models"]
        assert updated_data_json["max_budget"] == defaults["max_budget"]
        assert updated_data_json["budget_duration"] == defaults["budget_duration"]
    finally:
        litellm.default_internal_user_params = previous_defaults
|
|
|
|
|
|
def test_update_internal_new_user_params_with_no_initial_role_set():
    """Default internal-user params also apply when the request has no role."""
    from litellm.proxy._types import NewUserRequest
    from litellm.proxy.management_endpoints.internal_user_endpoints import (
        _update_internal_new_user_params,
    )

    # Restore the global defaults afterwards so other tests are unaffected.
    previous_defaults = getattr(litellm, "default_internal_user_params", None)
    litellm.default_internal_user_params = {
        "max_budget": 100,
        "budget_duration": "30d",
        "models": ["gpt-3.5-turbo"],
    }
    try:
        data = NewUserRequest(user_email="krrish3@berri.ai")
        data_json = data.model_dump()
        updated_data_json = _update_internal_new_user_params(data_json, data)

        defaults = litellm.default_internal_user_params
        assert updated_data_json["models"] == defaults["models"]
        assert updated_data_json["max_budget"] == defaults["max_budget"]
        assert updated_data_json["budget_duration"] == defaults["budget_duration"]
    finally:
        litellm.default_internal_user_params = previous_defaults
|
|
|
|
|
|
def test_update_internal_new_user_params_with_user_defined_values():
    """Values given on the request win over defaults; unset ones use defaults.

    ``max_budget`` and ``budget_duration`` come from the request, while
    ``user_role`` (not set on the request) is taken from the defaults.
    """
    from litellm.proxy._types import NewUserRequest
    from litellm.proxy.management_endpoints.internal_user_endpoints import (
        _update_internal_new_user_params,
    )

    # Restore the global defaults afterwards so other tests are unaffected.
    previous_defaults = getattr(litellm, "default_internal_user_params", None)
    litellm.default_internal_user_params = {
        "max_budget": 100,
        "budget_duration": "30d",
        "models": ["gpt-3.5-turbo"],
        "user_role": "proxy_admin",
    }
    try:
        data = NewUserRequest(
            user_email="krrish3@berri.ai", max_budget=1000, budget_duration="1mo"
        )
        data_json = data.model_dump()
        updated_data_json = _update_internal_new_user_params(data_json, data)

        assert updated_data_json["user_email"] == "krrish3@berri.ai"
        assert updated_data_json["user_role"] == "proxy_admin"
        assert updated_data_json["max_budget"] == 1000
        assert updated_data_json["budget_duration"] == "1mo"
    finally:
        litellm.default_internal_user_params = previous_defaults
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_proxy_config_update_from_db():
    """``litellm_settings`` returned by the DB are merged into the given config."""
    from pydantic import BaseModel

    from litellm.proxy.proxy_server import ProxyConfig

    class ReturnValue(BaseModel):
        # Shape of the object the mocked ``get_generic_data`` returns.
        param_name: str
        param_value: dict

    proxy_config = ProxyConfig()
    pc = AsyncMock()

    test_config = {"litellm_settings": {"callbacks": ["prometheus", "otel"]}}
    expected_config = {
        "litellm_settings": {
            "callbacks": ["prometheus", "otel"],
            "success_callback": "langfuse",
        }
    }
    db_row = ReturnValue(
        param_name="litellm_settings",
        param_value={"success_callback": "langfuse"},
    )
    fake_lookup = AsyncMock(return_value=db_row)

    with patch.object(pc, "get_generic_data", new=fake_lookup):
        new_config = await proxy_config._update_config_from_db(
            prisma_client=pc,
            config=test_config,
            store_model_in_db=True,
        )

    assert new_config == expected_config
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_prepare_key_update_data():
    """``prepare_key_update_data`` handles duration and metadata updates."""
    from litellm.proxy._types import LiteLLM_VerificationToken, UpdateKeyRequest
    from litellm.proxy.management_endpoints.key_management_endpoints import (
        prepare_key_update_data,
    )

    key_row = MagicMock(spec=LiteLLM_VerificationToken)
    key_row.metadata = {}

    # A duration results in an "expires" entry.
    request = UpdateKeyRequest(key="test_key", models=["gpt-4"], duration="120s")
    prepared = await prepare_key_update_data(request, key_row)
    assert "expires" in prepared

    # Empty metadata is kept as an empty dict.
    request = UpdateKeyRequest(key="test_key", metadata={})
    prepared = await prepare_key_update_data(request, key_row)
    assert prepared["metadata"] == {}

    # Explicit None metadata is kept as None.
    request = UpdateKeyRequest(key="test_key", metadata=None)
    prepared = await prepare_key_update_data(request, key_row)
    assert prepared["metadata"] is None

    # Test duration "-1" sets expires to None (never expires)
    request = UpdateKeyRequest(key="test_key", duration="-1")
    prepared = await prepare_key_update_data(request, key_row)
    assert prepared["expires"] is None
|
|
|
|
|
|
@pytest.mark.parametrize(
    "env_vars, expected_url",
    [
        ({}, "/redoc"),  # default case
        ({"REDOC_URL": "/custom-redoc"}, "/custom-redoc"),  # custom URL
        (
            {"REDOC_URL": "https://example.com/redoc"},
            "https://example.com/redoc",
        ),  # full URL
        ({"NO_REDOC": "True"}, None),  # Redoc disabled
    ],
)
def test_get_redoc_url(env_vars, expected_url):
    """``_get_redoc_url`` honors ``REDOC_URL`` / ``NO_REDOC`` from the environment."""
    # patch.dict snapshots os.environ and restores it on exit, so neither the
    # cleared keys nor the values set here leak into later tests.
    with patch.dict(os.environ):
        # Clear relevant environment variables
        for key in ("REDOC_URL", "NO_REDOC"):
            os.environ.pop(key, None)

        # Set test environment variables
        os.environ.update(env_vars)

        result = _get_redoc_url()

    assert result == expected_url
|
|
|
|
|
|
@pytest.mark.parametrize(
    "env_vars, expected_url",
    [
        ({}, "/"),  # default case
        ({"DOCS_URL": "/custom-docs"}, "/custom-docs"),  # custom URL
        (
            {"DOCS_URL": "https://example.com/docs"},
            "https://example.com/docs",
        ),  # full URL
        ({"NO_DOCS": "True"}, None),  # docs disabled
    ],
)
def test_get_docs_url(env_vars, expected_url):
    """``_get_docs_url`` honors ``DOCS_URL`` / ``NO_DOCS`` from the environment."""
    # patch.dict snapshots os.environ and restores it on exit, so neither the
    # cleared keys nor the values set here leak into later tests.
    with patch.dict(os.environ):
        # Clear relevant environment variables
        for key in ("DOCS_URL", "NO_DOCS"):
            os.environ.pop(key, None)

        # Set test environment variables
        os.environ.update(env_vars)

        result = _get_docs_url()

    assert result == expected_url
|
|
|
|
|
|
@pytest.mark.parametrize(
    "env_vars, expected_url",
    [
        ({}, "/openapi.json"),  # default case
        ({"OPENAPI_URL": "/custom-openapi.json"}, "/custom-openapi.json"),  # custom URL
        (
            {"OPENAPI_URL": "https://example.com/openapi.json"},
            "https://example.com/openapi.json",
        ),  # full URL
        ({"NO_OPENAPI": "True"}, None),  # openapi disabled
    ],
)
def test_get_openapi_url(env_vars, expected_url):
    """``_get_openapi_url`` honors ``OPENAPI_URL`` / ``NO_OPENAPI`` from the environment."""
    # patch.dict snapshots os.environ and restores it on exit, so neither the
    # cleared keys nor the values set here leak into later tests.
    with patch.dict(os.environ):
        # Clear relevant environment variables
        for key in ("OPENAPI_URL", "NO_OPENAPI"):
            os.environ.pop(key, None)

        # Set test environment variables
        os.environ.update(env_vars)

        result = _get_openapi_url()

    assert result == expected_url
|
|
|
|
|
|
@pytest.mark.parametrize(
    "request_tags, tags_to_add, expected_tags",
    [
        (None, None, []),  # both None
        (["tag1", "tag2"], None, ["tag1", "tag2"]),  # tags_to_add is None
        (None, ["tag3", "tag4"], ["tag3", "tag4"]),  # request_tags is None
        (
            ["tag1", "tag2"],
            ["tag3", "tag4"],
            ["tag1", "tag2", "tag3", "tag4"],
        ),  # both have unique tags
        (
            ["tag1", "tag2"],
            ["tag2", "tag3"],
            ["tag1", "tag2", "tag3"],
        ),  # overlapping tags
        ([], [], []),  # both empty lists
        ("not_a_list", ["tag1"], ["tag1"]),  # request_tags invalid type
        (["tag1"], "not_a_list", ["tag1"]),  # tags_to_add invalid type
        (
            ["tag1"],
            ["tag1", "tag2"],
            ["tag1", "tag2"],
        ),  # duplicate tags in inputs
    ],
)
def test_merge_tags(request_tags, tags_to_add, expected_tags):
    """``_merge_tags`` returns a list holding the expected tags, in any order."""
    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup

    merged = LiteLLMProxyRequestSetup._merge_tags(
        request_tags=request_tags, tags_to_add=tags_to_add
    )

    assert isinstance(merged, list)
    # Order is not part of the contract being checked here.
    assert sorted(merged) == sorted(expected_tags)
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "key_tags, request_tags, expected_tags",
    [
        # key and request carry exactly the same tags
        (["tag1", "tag2", "tag3"], ["tag1", "tag2", "tag3"], ["tag1", "tag2", "tag3"]),
        # two tags shared between key and request
        (
            ["tag1", "tag2", "tag3"],
            ["tag2", "tag3", "tag4"],
            ["tag1", "tag2", "tag3", "tag4"],
        ),
        # no overlap between key and request
        (["tag1", "tag2"], ["tag3", "tag4"], ["tag1", "tag2", "tag3", "tag4"]),
        # a single tag shared between key and request
        (["tag1", "tag2"], ["tag2", "tag3", "tag4"], ["tag1", "tag2", "tag3", "tag4"]),
        # tags differing only by case are all kept
        (["Tag1", "TAG2"], ["tag1", "tag2"], ["Tag1", "TAG2", "tag1", "tag2"]),
    ],
)
async def test_add_litellm_data_to_request_duplicate_tags(
    key_tags, request_tags, expected_tags
):
    """
    Test to verify duplicate tags between request and key metadata are handled correctly

    Aggregation logic when checking spend can be impacted if duplicate tags are not handled correctly.

    User feedback:
    "If I register my key with tag1 and
    also pass the same tag1 when using the key
    then I see tag1 twice in the
    LiteLLM_SpendLogs table request_tags column. This can mess up aggregation logic"
    """
    request = Mock(spec=Request)
    request.url.path = "/chat/completions"
    request.query_params = {}
    request.headers = {}
    request.state = State()

    # Setup key with tags in metadata. Opt into client-supplied tags so the
    # request_tags are preserved for the merge under test.
    auth = UserAPIKeyAuth(
        api_key="test_api_key",
        user_id="test_user_id",
        org_id="test_org_id",
        metadata={"tags": key_tags, "allow_client_tags": True},
    )

    # Request payload carrying its own tags.
    payload = {"metadata": {"tags": request_tags}}

    result = await add_litellm_data_to_request(
        data=payload,
        request=request,
        user_api_key_dict=auth,
        proxy_config=Mock(),
    )

    # Merged tags must match the expectation regardless of ordering.
    assert "metadata" in result
    assert "tags" in result["metadata"]
    assert sorted(result["metadata"]["tags"]) == sorted(
        expected_tags
    ), f"Expected {expected_tags}, got {result['metadata']['tags']}"
|
|
|
|
|
|
@pytest.mark.parametrize(
    "general_settings, user_api_key_dict, request_body, expected_error",
    [
        # globally enforced params missing from the body
        (
            {"enforced_params": ["param1", "param2"]},
            UserAPIKeyAuth(
                api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
            ),
            {},
            True,
        ),
        # service-account enforcement, key has no service_account_id
        (
            {"service_account_settings": {"enforced_params": ["user"]}},
            UserAPIKeyAuth(
                api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
            ),
            {},
            False,
        ),
        # service-account enforcement, key has a service_account_id
        (
            {"service_account_settings": {"enforced_params": ["user"]}},
            UserAPIKeyAuth(
                api_key="test_api_key",
                user_id="test_user_id",
                org_id="test_org_id",
                metadata={"service_account_id": "test_service_account_id"},
            ),
            {},
            True,
        ),
        # key-level enforced param missing from the body
        (
            {},
            UserAPIKeyAuth(
                api_key="test_api_key",
                metadata={"enforced_params": ["user"]},
            ),
            {},
            True,
        ),
        # key-level enforced param present in the body
        (
            {},
            UserAPIKeyAuth(
                api_key="test_api_key",
                metadata={"enforced_params": ["user"]},
            ),
            {"user": "test_user"},
            False,
        ),
        # globally enforced param present in the body
        (
            {"enforced_params": ["user"]},
            UserAPIKeyAuth(
                api_key="test_api_key",
            ),
            {"user": "test_user"},
            False,
        ),
        # service-account enforced param present in the body
        (
            {"service_account_settings": {"enforced_params": ["user"]}},
            UserAPIKeyAuth(
                api_key="test_api_key",
                metadata={"service_account_id": "test_service_account_id"},
            ),
            {"user": "test_user"},
            False,
        ),
        # nested (dotted) enforced param missing
        (
            {"enforced_params": ["metadata.generation_name"]},
            UserAPIKeyAuth(
                api_key="test_api_key",
            ),
            {"metadata": {}},
            True,
        ),
        # nested (dotted) enforced param present
        (
            {"enforced_params": ["metadata.generation_name"]},
            UserAPIKeyAuth(
                api_key="test_api_key",
            ),
            {"metadata": {"generation_name": "test_generation_name"}},
            False,
        ),
    ],
)
def test_enforced_params_check(
    general_settings, user_api_key_dict, request_body, expected_error
):
    """``_enforced_params_check`` raises ``ValueError`` exactly when expected."""
    from litellm.proxy.litellm_pre_call_utils import _enforced_params_check

    call_kwargs = dict(
        request_body=request_body,
        general_settings=general_settings,
        user_api_key_dict=user_api_key_dict,
        premium_user=True,
    )

    if not expected_error:
        # Must complete without raising.
        _enforced_params_check(**call_kwargs)
        return

    with pytest.raises(ValueError):
        _enforced_params_check(**call_kwargs)
|
|
|
|
|
|
def test_get_key_models():
    """A key whose models name an access group resolves to that group's models."""
    from collections import defaultdict

    from litellm.proxy.auth.model_checks import get_key_models

    # "default" holds three models; "team2" is an unrelated group.
    access_groups = defaultdict(
        list,
        {
            "default": ["gpt-4o", "gpt-3.5-turbo", "gpt-4o-mini"],
            "team2": ["gpt-3.5-turbo"],
        },
    )
    key_auth = UserAPIKeyAuth(
        api_key="test_api_key",
        user_id="test_user_id",
        org_id="test_org_id",
        models=["default"],
    )

    resolved = get_key_models(
        user_api_key_dict=key_auth,
        proxy_model_list=["gpt-4o", "gpt-3.5-turbo"],
        model_access_groups=access_groups,
    )
    assert resolved == ["gpt-4o", "gpt-3.5-turbo", "gpt-4o-mini"]
|
|
|
|
|
|
def test_get_team_models():
    """Team models naming an access group resolve to that group's models."""
    from collections import defaultdict

    from litellm.proxy.auth.model_checks import get_team_models

    # "default" holds three models; "team2" is an unrelated group.
    access_groups = defaultdict(
        list,
        {
            "default": ["gpt-4o", "gpt-3.5-turbo", "gpt-4o-mini"],
            "team2": ["gpt-3.5-turbo"],
        },
    )
    key_auth = UserAPIKeyAuth(
        api_key="test_api_key",
        user_id="test_user_id",
        org_id="test_org_id",
        models=[],
        team_models=["default"],
    )

    resolved = get_team_models(
        team_models=key_auth.team_models,
        proxy_model_list=["gpt-4o", "gpt-3.5-turbo"],
        model_access_groups=access_groups,
    )
    assert resolved == ["gpt-4o", "gpt-3.5-turbo", "gpt-4o-mini"]
|
|
|
|
|
|
def test_update_config_fields():
    """Team settings from the current config survive a DB update that sets them to None.

    The DB value for ``litellm_settings`` carries ``default_team_settings: None``;
    after ``_update_config_fields`` the team entry from the current config must
    still be retrievable with its langfuse credentials intact.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    team_id = "c91e32bb-0f2a-4aa1-86c4-307ca2e03ea3"
    existing_config = {
        "litellm_settings": {
            "default_team_settings": [
                {
                    "team_id": team_id,
                    "success_callback": ["langfuse"],
                    "failure_callback": ["langfuse"],
                    "langfuse_public_key": "my-fake-key",
                    "langfuse_secret": "my-fake-secret",
                }
            ]
        },
    }
    db_litellm_settings = {
        "telemetry": False,
        "drop_params": True,
        "num_retries": 5,
        "request_timeout": 600,
        "success_callback": ["langfuse"],
        "default_team_settings": None,
        "context_window_fallbacks": [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}],
    }

    proxy_config = ProxyConfig()
    updated_config = proxy_config._update_config_fields(
        current_config=existing_config,
        param_name="litellm_settings",
        db_param_value=db_litellm_settings,
    )
    print("updated_config", updated_config)

    all_team_config = updated_config["litellm_settings"]["default_team_settings"]
    print("all_team_config", all_team_config)

    # The team entry must still be resolvable by its id.
    team_config = proxy_config._get_team_config(
        team_id=team_id, all_teams_config=all_team_config
    )
    print("team_config", team_config)

    assert team_config["langfuse_public_key"] == "my-fake-key"
    assert team_config["langfuse_secret"] == "my-fake-secret"
|
|
|
|
|
|
def test_update_config_fields_default_internal_user_params(monkeypatch):
    """``default_internal_user_params`` from the DB is written onto the ``litellm`` module."""
    from litellm.proxy.proxy_server import ProxyConfig

    expected_params = {
        "user_role": "proxy_admin",
        "max_budget": 1000,
        "budget_duration": "1mo",
    }

    proxy_config = ProxyConfig()
    # Start from a known-empty module attribute.
    monkeypatch.setattr(litellm, "default_internal_user_params", None)

    proxy_config._update_config_fields(
        current_config={},
        param_name="litellm_settings",
        # Pass a separate copy so the comparison below is against an
        # independent dict, not the object handed to the function.
        db_param_value={"default_internal_user_params": dict(expected_params)},
    )

    assert litellm.default_internal_user_params == expected_params

    # reset to default
    monkeypatch.setattr(litellm, "default_internal_user_params", None)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "proxy_model_list,model_list,provider",
    [
        (
            [f"{name}/*"],
            [{"model_name": f"{name}/*", "litellm_params": {"model": f"{name}/*"}}],
            name,
        )
        for name in ("openai", "bedrock", "anthropic", "cohere")
    ],
)
def test_get_complete_model_list(proxy_model_list, model_list, provider):
    """
    A wildcard entry such as 'openai/*' must yield a non-empty model list in
    which every returned model name contains the provider string.
    """
    from litellm import Router
    from litellm.proxy.auth.model_checks import get_complete_model_list

    router = Router(model_list=model_list)

    complete_list = get_complete_model_list(
        proxy_model_list=proxy_model_list,
        key_models=[],
        team_models=[],
        user_model=None,
        infer_model_from_keys=False,
        llm_router=router,
    )

    # A non-empty list must come back.
    assert len(complete_list) > 0

    print("complete_list", json.dumps(complete_list, indent=4))

    for model_name in complete_list:
        assert provider in model_name
|
|
|
|
|
|
def test_team_callback_metadata_all_none_values():
    """Passing None for every field yields empty containers on the model."""
    from litellm.proxy._types import TeamCallbackMetadata

    # dict.fromkeys maps each field name to None.
    all_none = dict.fromkeys(
        ("success_callback", "failure_callback", "callback_vars")
    )
    resp = TeamCallbackMetadata(**all_none)

    assert resp.success_callback == []
    assert resp.failure_callback == []
    assert resp.callback_vars == {}
|
|
|
|
|
|
@pytest.mark.parametrize(
    "none_key",
    [
        "success_callback",
        "failure_callback",
        "callback_vars",
    ],
)
def test_team_callback_metadata_none_values(none_key):
    """Constructing ``TeamCallbackMetadata`` with some fields set to None must succeed."""
    from litellm.proxy._types import TeamCallbackMetadata

    # Constructor kwargs for each parametrized case. ``callback_vars`` is None
    # in all three; the callback lists differ.
    args_by_case = {
        "success_callback": {
            "success_callback": None,
            "failure_callback": ["test"],
            "callback_vars": None,
        },
        "failure_callback": {
            "success_callback": ["test"],
            "failure_callback": None,
            "callback_vars": None,
        },
        "callback_vars": {
            "success_callback": ["test"],
            "failure_callback": ["test"],
            "callback_vars": None,
        },
    }

    resp = TeamCallbackMetadata(**args_by_case[none_key])

    # NOTE(review): membership is evaluated against the model instance itself,
    # not against a dict of its fields -- confirm this checks what is intended.
    assert none_key not in resp
|
|
|
|
|
|
def test_proxy_config_state_post_init_callback_call(monkeypatch):
    """
    Ensures team_id is still present in the config after team callbacks are resolved.

    Addresses issue: https://github.com/BerriAI/litellm/issues/6787

    In that issue, team_id was being popped from the config after the callback
    was called.

    Note: environment variables are set via monkeypatch to avoid validation
    errors in parallel execution, where the env vars may not be set.
    """
    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
    from litellm.proxy.proxy_server import ProxyConfig

    # Provide the values the "os.environ/..." references below resolve to.
    for env_name, env_value in (
        ("LANGFUSE_PUBLIC_KEY", "test_public_key"),
        ("LANGFUSE_SECRET_KEY", "test_secret_key"),
    ):
        monkeypatch.setenv(env_name, env_value)

    team_settings = {
        "team_id": "test",
        "success_callback": ["langfuse"],
        "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
        "langfuse_secret": "os.environ/LANGFUSE_SECRET_KEY",
    }

    pc = ProxyConfig()
    pc.update_config_state(
        config={"litellm_settings": {"default_team_settings": [team_settings]}}
    )

    callback_metadata = LiteLLMProxyRequestSetup.add_team_based_callbacks_from_config(
        team_id="test",
        proxy_config=pc,
    )

    assert callback_metadata is not None
    resolved_vars = callback_metadata.callback_vars
    assert resolved_vars is not None
    assert resolved_vars["langfuse_public_key"] == "test_public_key"
    assert resolved_vars["langfuse_secret"] == "test_secret_key"

    # The stored config must still carry the team_id after the lookup above.
    config = pc.get_config_state()
    assert config["litellm_settings"]["default_team_settings"][0]["team_id"] == "test"
|
|
|
|
|
|
def test_proxy_config_state_get_config_state_error():
    """
    ``get_config_state`` must return an empty dict, rather than raise, when the
    stored config contains a value that cannot be deep-copied.
    """
    import threading

    from litellm.proxy.proxy_server import ProxyConfig

    # The RLock entry is what makes the deep copy fail.
    uncopyable_callback = {
        "lock": threading.RLock(),
        "name": "test_callback",
    }

    pc = ProxyConfig()
    pc.config = {
        "callback_list": [uncopyable_callback],
        "model_list": ["gpt-4", "claude-3"],
    }

    assert pc.get_config_state() == {}
|
|
|
|
|
|
@pytest.mark.parametrize(
    "associated_budget_table, expected_user_api_key_auth_key, expected_user_api_key_auth_value",
    [
        (
            {
                "litellm_budget_table_max_budget": None,
                "litellm_budget_table_tpm_limit": None,
                "litellm_budget_table_rpm_limit": 1,
                "litellm_budget_table_model_max_budget": None,
            },
            "rpm_limit",
            1,
        ),
        (
            {},
            None,
            None,
        ),
        (
            {
                "litellm_budget_table_max_budget": None,
                "litellm_budget_table_tpm_limit": None,
                "litellm_budget_table_rpm_limit": None,
                "litellm_budget_table_model_max_budget": {"gpt-4o": 100},
            },
            "model_max_budget",
            {"gpt-4o": 100},
        ),
    ],
)
def test_litellm_verification_token_view_response_with_budget_table(
    associated_budget_table,
    expected_user_api_key_auth_key,
    expected_user_api_key_auth_value,
):
    """Budget-table columns merged into a token row show up on the view's own fields.

    Each case merges ``associated_budget_table`` into an otherwise fixed row,
    builds the view, and -- when an attribute name is given -- checks that
    attribute against the expected value. The empty-table case only verifies
    that construction succeeds.
    """
    from litellm.proxy._types import LiteLLM_VerificationTokenView

    # Columns describing the key itself.
    key_columns: Dict[str, Any] = {
        "token": "sk-test-mock-token-303",
        "key_name": "sk-...if_g",
        "key_alias": None,
        "soft_budget_cooldown": False,
        "spend": 0.011441999999999997,
        "expires": None,
        "models": [],
        "aliases": {},
        "config": {},
        "user_id": None,
        "team_id": "test",
        "permissions": {},
        "max_parallel_requests": None,
        "metadata": {},
        "blocked": None,
        "tpm_limit": None,
        "rpm_limit": None,
        "max_budget": None,
        "budget_duration": None,
        "budget_reset_at": None,
        "allowed_cache_controls": [],
        "model_spend": {},
        "model_max_budget": {},
        "budget_id": "my-test-tier",
        "created_at": "2024-12-26T02:28:52.615+00:00",
        "updated_at": "2024-12-26T03:01:51.159+00:00",
    }
    # Columns describing the key's team.
    team_columns: Dict[str, Any] = {
        "team_spend": 0.012134999999999998,
        "team_max_budget": None,
        "team_tpm_limit": None,
        "team_rpm_limit": None,
        "team_models": [],
        "team_metadata": {},
        "team_blocked": False,
        "team_alias": None,
        "team_members_with_roles": [{"role": "admin", "user_id": "default_user_id"}],
        "team_member_spend": None,
        "team_model_aliases": None,
        "team_member": None,
    }
    args: Dict[str, Any] = {**key_columns, **team_columns, **associated_budget_table}

    resp = LiteLLM_VerificationTokenView(**args)

    if expected_user_api_key_auth_key is None:
        return
    actual_value = getattr(resp, expected_user_api_key_auth_key)
    assert actual_value == expected_user_api_key_auth_value
|
|
|
|
|
|
def test_litellm_verification_token_view_budget_does_not_override_key_model_max_budget():
    """
    A non-empty ``model_max_budget`` on the key wins over the budget table's value.

    Regression test for per-model budgets: the budget table's model_max_budget
    should only be applied when the key's own model_max_budget is empty.
    """
    from litellm.proxy._types import LiteLLM_VerificationTokenView

    key_model_max_budget = {"gpt-4": {"max_budget": 50.0, "budget_duration": "1d"}}
    budget_table_model_max_budget = {
        "gpt-4o": {"max_budget": 100.0, "budget_duration": "1d"}
    }

    # Columns describing the key itself.
    key_columns = {
        "token": "sk-test-mock-token-303",
        "key_name": "sk-...if_g",
        "key_alias": None,
        "soft_budget_cooldown": False,
        "spend": 0.0,
        "expires": None,
        "models": [],
        "aliases": {},
        "config": {},
        "user_id": None,
        "team_id": "test",
        "permissions": {},
        "max_parallel_requests": None,
        "metadata": {},
        "blocked": None,
        "tpm_limit": None,
        "rpm_limit": None,
        "max_budget": None,
        "budget_duration": None,
        "budget_reset_at": None,
        "allowed_cache_controls": [],
        "model_spend": {},
        "model_max_budget": key_model_max_budget,
        "budget_id": "my-test-tier",
        "created_at": "2024-12-26T02:28:52.615+00:00",
        "updated_at": "2024-12-26T03:01:51.159+00:00",
    }
    # Columns describing the key's team.
    team_columns = {
        "team_spend": None,
        "team_max_budget": None,
        "team_tpm_limit": None,
        "team_rpm_limit": None,
        "team_models": [],
        "team_metadata": {},
        "team_blocked": False,
        "team_alias": None,
        "team_members_with_roles": [],
        "team_member_spend": None,
        "team_model_aliases": None,
        "team_member": None,
    }
    args = {
        **key_columns,
        **team_columns,
        "litellm_budget_table_model_max_budget": budget_table_model_max_budget,
    }

    resp = LiteLLM_VerificationTokenView(**args)

    assert resp.model_max_budget == key_model_max_budget
|
|
|
|
|
|
def test_is_allowed_to_make_key_request():
    """Both a proxy admin and a dashboard-team internal user may make the key request."""
    from litellm.proxy._types import LitellmUserRoles
    from litellm.proxy.management_endpoints.key_management_endpoints import (
        _is_allowed_to_make_key_request,
    )

    proxy_admin = UserAPIKeyAuth(
        user_id="test_user_id", user_role=LitellmUserRoles.PROXY_ADMIN
    )
    dashboard_internal_user = UserAPIKeyAuth(
        user_id="test_user_id",
        user_role=LitellmUserRoles.INTERNAL_USER,
        team_id="litellm-dashboard",
    )

    # The same target user/team is used for both callers.
    for caller in (proxy_admin, dashboard_internal_user):
        allowed = _is_allowed_to_make_key_request(
            user_api_key_dict=caller,
            user_id="test_user_id",
            team_id="test_team_id",
        )
        assert allowed is True
|
|
|
|
|
|
def test_get_model_group_info():
    """Requesting one model group out of two deployments returns exactly one entry."""
    from litellm import Router
    from litellm.proxy.proxy_server import _get_model_group_info

    deployed_models = ["openai/tts-1", "openai/gpt-3.5-turbo"]
    router = Router(
        model_list=[
            {
                "model_name": name,
                "litellm_params": {"model": name, "api_key": "sk-1234"},
            }
            for name in deployed_models
        ]
    )

    model_list = _get_model_group_info(
        llm_router=router,
        all_models_str=list(deployed_models),
        model_group="openai/tts-1",
    )

    assert len(model_list) == 1
|
|
|
|
|
|
@pytest.fixture
def mock_team_data():
    """Two team rows, ``team1`` and ``team2``, each with a matching display name."""
    return [
        {"team_id": f"team{index}", "team_name": f"Test Team {index}"}
        for index in (1, 2)
    ]
|
|
|
|
|
|
@pytest.fixture
def mock_key_data():
    """Three key rows: one without a team, one on ``team1``, one on ``litellm-dashboard``."""
    # (token, key_name, team_id, spend)
    rows = (
        ("test_token_1", "key1", None, 0),
        ("test_token_2", "key2", "team1", 100),
        ("test_token_3", "key3", "litellm-dashboard", 50),
    )
    return [
        {"token": token, "key_name": key_name, "team_id": team_id, "spend": spend}
        for token, key_name, team_id, spend in rows
    ]
|
|
|
|
|
|
class MockDb:
    """In-memory stand-in for a DB handle exposing only ``query_raw``."""

    def __init__(self, mock_team_data, mock_key_data):
        self.mock_team_data = mock_team_data
        self.mock_key_data = mock_key_data

    async def query_raw(self, query: str, *args):
        """Return one row holding all teams and every key not on the dashboard team.

        ``query`` and ``args`` are ignored; the result depends only on the data
        given at construction time.
        """
        # Keys with team_id None pass this filter too: None != "litellm-dashboard".
        visible_keys = []
        for key_row in self.mock_key_data:
            if key_row["team_id"] != "litellm-dashboard":
                visible_keys.append(key_row)

        return [{"teams": self.mock_team_data, "keys": visible_keys}]
|
|
|
|
|
|
class MockPrismaClientDB:
    """Prisma-client stand-in: a ``MockDb`` under ``.db`` plus a canned ``get_data``."""

    def __init__(
        self,
        mock_team_data,
        mock_key_data,
    ):
        self.db = MockDb(mock_team_data, mock_key_data)

    async def get_data(
        self,
        token: Optional[Union[str, list]] = None,
        user_id: Optional[str] = None,
        user_id_list: Optional[list] = None,
        team_id: Optional[str] = None,
        team_id_list: Optional[list] = None,
        key_val: Optional[dict] = None,
        table_name: Optional[str] = None,
        query_type: str = "find_unique",
        expires: Optional[datetime] = None,
        reset_at: Optional[datetime] = None,
        offset: Optional[int] = None,
        limit: Optional[int] = None,
    ):
        """Return a proxy-admin user record for ``user_id``, or None when it is falsy.

        Every parameter other than ``user_id`` is accepted but ignored.
        """
        from litellm.proxy._types import LiteLLM_UserTable

        if not user_id:
            return None

        # Any user id queried is reported as a proxy admin with zero spend.
        return LiteLLM_UserTable(
            user_id=user_id,
            user_role="proxy_admin",
            spend=0.0,
            max_budget=None,
        )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_user_info_for_proxy_admin(mock_team_data, mock_key_data):
    """The admin user-info helper returns the filtered keys plus the admin's own record."""
    from litellm.proxy._types import UserAPIKeyAuth, UserInfoResponse

    admin_id = "admin_user_123"
    admin_auth = UserAPIKeyAuth(
        user_id=admin_id,
        user_role="proxy_admin",
    )
    fake_prisma = MockPrismaClientDB(mock_team_data, mock_key_data)

    # Swap the proxy server's prisma client for the in-memory mock.
    with patch("litellm.proxy.proxy_server.prisma_client", fake_prisma):
        from litellm.proxy.management_endpoints.internal_user_endpoints import (
            _get_user_info_for_proxy_admin,
        )

        result = await _get_user_info_for_proxy_admin(user_api_key_dict=admin_auth)

        # Result structure: two of the three fixture keys remain after the
        # mock drops the "litellm-dashboard" one.
        assert isinstance(result, UserInfoResponse)
        assert len(result.keys) == 2
        # The admin's own id is populated.
        assert result.user_id == "admin_user_123"
        # The admin's user_info is populated.
        assert result.user_info is not None
        assert result.user_info["user_id"] == "admin_user_123"
|
|
|
|
|
|
def test_custom_openid_response():
    """Team ids are read from the JWT field configured via ``team_ids_jwt_field``."""
    from litellm.caching import DualCache
    from litellm.proxy._types import LiteLLM_JWTAuth
    from litellm.proxy.management_endpoints.ui_sso import (
        JWTHandler,
        generic_response_convertor,
    )

    # Configure the handler to read team ids from the "department" claim.
    jwt_auth_settings = LiteLLM_JWTAuth(
        team_ids_jwt_field="department",
    )
    jwt_handler = JWTHandler()
    jwt_handler.update_environment(
        prisma_client={},
        user_api_key_cache=DualCache(),
        litellm_jwtauth=jwt_auth_settings,
    )

    openid_payload = {
        "sub": "3f196e06-7484-451e-be5a-ea6c6bb86c5b",
        "email_verified": True,
        "name": "Krish Dholakia",
        "preferred_username": "krrishd",
        "given_name": "Krish",
        "department": ["/test-group"],
        "family_name": "Dholakia",
        "email": "krrishdholakia@gmail.com",
    }

    resp = generic_response_convertor(
        response=openid_payload,
        jwt_handler=jwt_handler,
    )

    assert resp.team_ids == ["/test-group"]
|
|
|
|
|
|
def test_update_key_request_validation():
    """
    ``temp_budget_increase`` and ``temp_budget_expiry`` must be supplied together:
    either one alone is rejected, both together are accepted.
    """
    from litellm.proxy._types import UpdateKeyRequest

    # Each field on its own must fail validation.
    incomplete_cases = (
        {"temp_budget_increase": 100},
        {"temp_budget_expiry": "2024-01-20T00:00:00Z"},
    )
    for partial_fields in incomplete_cases:
        with pytest.raises(Exception):
            UpdateKeyRequest(key="test_key", **partial_fields)

    # Both fields together must construct without raising.
    UpdateKeyRequest(
        key="test_key",
        temp_budget_increase=100,
        temp_budget_expiry="2024-01-20T00:00:00Z",
    )
|
|
|
|
|
|
def test_get_temp_budget_increase():
    """A temp budget increase whose expiry is one day ahead is returned as-is."""
    from datetime import datetime, timedelta

    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.auth.user_api_key_auth import _get_temp_budget_increase

    # Expiry is stored in the key metadata as an ISO-8601 string.
    tomorrow_iso = (datetime.now() + timedelta(days=1)).isoformat()
    temp_budget_metadata = {
        "temp_budget_increase": 100,
        "temp_budget_expiry": tomorrow_iso,
    }

    valid_token = UserAPIKeyAuth(
        max_budget=100,
        spend=0,
        metadata=temp_budget_metadata,
    )

    assert _get_temp_budget_increase(valid_token) == 100
|
|
|
|
|
|
def test_update_key_budget_with_temp_budget_increase():
    """An unexpired temp increase of 100 on a 100 budget yields ``max_budget == 200``."""
    from datetime import datetime, timedelta

    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.auth.user_api_key_auth import (
        _update_key_budget_with_temp_budget_increase,
    )

    # Expiry is stored in the key metadata as an ISO-8601 string.
    tomorrow_iso = (datetime.now() + timedelta(days=1)).isoformat()
    temp_budget_metadata = {
        "temp_budget_increase": 100,
        "temp_budget_expiry": tomorrow_iso,
    }

    valid_token = UserAPIKeyAuth(
        max_budget=100,
        spend=0,
        metadata=temp_budget_metadata,
    )

    updated_token = _update_key_budget_with_temp_budget_increase(valid_token)
    assert updated_token.max_budget == 200
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_health_check_not_called_when_disabled(monkeypatch):
    """With DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP=true, prisma setup skips ``health_check``."""
    from litellm.proxy.proxy_server import ProxyStartupEvent

    monkeypatch.setenv("DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP", "true")

    # Mock prisma client; every listed method is given an AsyncMock.
    mock_prisma = MagicMock()
    for async_method in (
        "connect",
        "health_check",
        "check_view_exists",
        "_set_spend_logs_row_count_in_proxy_state",
        "start_db_health_watchdog_task",
    ):
        setattr(mock_prisma, async_method, AsyncMock())
    # The db attribute carries start_token_refresh_task for RDS IAM token refresh.
    mock_prisma.db = MagicMock(start_token_refresh_task=AsyncMock())

    # Make the PrismaClient constructor hand back the mock.
    monkeypatch.setattr(
        "litellm.proxy.proxy_server.PrismaClient", lambda **kwargs: mock_prisma
    )

    await ProxyStartupEvent._setup_prisma_client(
        database_url="mock_url",
        proxy_logging_obj=MagicMock(),
        user_api_key_cache=MagicMock(),
    )

    mock_prisma.health_check.assert_not_called()
|
|
|
|
|
|
@patch(
    "litellm.proxy.proxy_server.get_openapi_schema",
    return_value={"paths": {"/new/route": {"get": {"summary": "New"}}}},
)
def test_custom_openapi(mock_get_openapi_schema):
    """``custom_openapi`` returns a schema when ``get_openapi_schema`` is patched."""
    from litellm.proxy.proxy_server import custom_openapi

    # The patched schema source returns a single-route schema.
    schema = custom_openapi()

    assert schema is not None
|
|
|
|
|
|
from litellm.proxy.utils import ProxyUpdateSpend
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_end_user_transactions_reset():
    """``update_end_user_spend`` raises when the DB transaction fails and retries are 0."""
    # Prisma client whose ``db.tx`` always fails.
    failing_client = MagicMock()
    failing_client.db.tx = AsyncMock(side_effect=Exception("DB Error"))

    pending_transactions = {"1": 10.0}  # Bad log

    # With no retries allowed, the error must reach the caller.
    with pytest.raises(Exception):
        await ProxyUpdateSpend.update_end_user_spend(
            n_retry_times=0,
            prisma_client=failing_client,
            proxy_logging_obj=MagicMock(),
            end_user_list_transactions=pending_transactions,
        )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_spend_logs_cleanup_after_error():
    """Queued spend logs are dropped from the client even when the DB write fails."""
    import asyncio

    failing_client = MagicMock()
    failing_client.spend_log_transactions = [
        {"id": log_id, "amount": amount}
        for log_id, amount in ((1, 10.0), (2, 20.0), (3, 30.0))
    ]
    # Lock guarding spend_log_transactions (matches real PrismaClient).
    failing_client._spend_log_transactions_lock = asyncio.Lock()
    # Every bulk insert fails.
    failing_client.db.litellm_spendlogs.create_many = AsyncMock(
        side_effect=Exception("DB Error")
    )

    queued_before = list(failing_client.spend_log_transactions)

    # With no retries allowed, the error must reach the caller.
    with pytest.raises(Exception):
        await ProxyUpdateSpend.update_spend_logs(
            n_retry_times=0,
            prisma_client=failing_client,
            db_writer_client=None,  # Test DB write path
            proxy_logging_obj=MagicMock(),
        )

    # With only three queued logs this slice is empty, so the queue must be
    # fully drained. NOTE(review): 100 is presumably the batch size -- confirm.
    remaining_expected = queued_before[100:]
    assert (
        failing_client.spend_log_transactions == remaining_expected
    ), "Should remove processed logs even after error"
|
|
|
|
|
|
def test_provider_specific_header():
    """Test that provider_specific_header is set correctly for Anthropic headers."""
    from litellm.proxy.litellm_pre_call_utils import (
        add_provider_specific_headers_to_request,
    )

    def build_messages():
        # Fresh objects on every call so the top-level request and the nested
        # body do not share the same list.
        return [
            {
                "role": "user",
                "content": [{"type": "text", "text": "Tell me a joke"}],
            }
        ]

    headers = {
        "content-type": "application/json",
        "anthropic-beta": "prompt-caching-2024-07-31",
        "user-agent": "PostmanRuntime/7.32.3",
        "accept": "*/*",
        "postman-token": "81cccd87-c91d-4b2f-b252-c0fe0ca82529",
        "host": "0.0.0.0:4000",
        "accept-encoding": "gzip, deflate, br",
        "connection": "keep-alive",
        "content-length": "240",
    }

    data = {
        "model": "gemini-1.5-flash",
        "messages": build_messages(),
        "stream": True,
        "proxy_server_request": {
            "url": "http://0.0.0.0:4000/v1/chat/completions",
            "method": "POST",
            # Independent copy: same content as ``headers``, separate object.
            "headers": dict(headers),
            "body": {
                "model": "gemini-1.5-flash",
                "messages": build_messages(),
                "stream": True,
            },
        },
    }

    add_provider_specific_headers_to_request(
        data=data,
        headers=headers,
    )

    # Verify multi-provider support: the anthropic header is attached for
    # several providers, and only the anthropic-* header is forwarded.
    expected_header = {
        "custom_llm_provider": "anthropic,bedrock,vertex_ai",
        "extra_headers": {
            "anthropic-beta": "prompt-caching-2024-07-31",
        },
    }
    assert data["provider_specific_header"] == expected_header
|
|
|
|
|
|
def test_provider_specific_header_multi_provider():
    """Test that provider_specific_header supports multiple providers for Anthropic headers."""
    from litellm.proxy.litellm_pre_call_utils import (
        add_provider_specific_headers_to_request,
    )

    def build_messages():
        # Fresh objects on every call so the top-level request and the nested
        # body do not share the same list.
        return [
            {
                "role": "user",
                "content": [{"type": "text", "text": "Tell me a joke"}],
            }
        ]

    headers = {
        "content-type": "application/json",
        "anthropic-beta": "context-1m-2025-08-07",
        "anthropic-version": "2023-06-01",
        "user-agent": "PostmanRuntime/7.32.3",
        "accept": "*/*",
        "postman-token": "81cccd87-c91d-4b2f-b252-c0fe0ca82529",
        "host": "0.0.0.0:4000",
        "accept-encoding": "gzip, deflate, br",
        "connection": "keep-alive",
        "content-length": "240",
    }

    data = {
        "model": "gemini-1.5-flash",
        "messages": build_messages(),
        "stream": True,
        "proxy_server_request": {
            "url": "http://0.0.0.0:4000/v1/chat/completions",
            "method": "POST",
            # Independent copy: same content as ``headers``, separate object.
            "headers": dict(headers),
            "body": {
                "model": "gemini-1.5-flash",
                "messages": build_messages(),
                "stream": True,
            },
        },
    }

    add_provider_specific_headers_to_request(
        data=data,
        headers=headers,
    )

    # provider_specific_header must list the providers comma-separated and
    # carry both anthropic-* headers, and nothing else.
    assert "provider_specific_header" in data
    provider_header = data["provider_specific_header"]
    assert provider_header["custom_llm_provider"] == "anthropic,bedrock,vertex_ai"
    assert provider_header["extra_headers"] == {
        "anthropic-beta": "context-1m-2025-08-07",
        "anthropic-version": "2023-06-01",
    }
|
|
|
|
|
|
# @pytest.mark.parametrize(
|
|
# "custom_llm_provider, expected_result",
|
|
# [
|
|
# ("anthropic", {"anthropic-beta": "test"}),
|
|
# ("bedrock", {"anthropic-beta": "test"}),
|
|
# ("vertex_ai", {"anthropic-beta": "test"}),
|
|
# ],
|
|
# )
|
|
# def test_provider_specific_header_in_request(custom_llm_provider, expected_result):
|
|
# from litellm.types.utils import ProviderSpecificHeader
|
|
# from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
|
# from unittest.mock import patch
|
|
|
|
# litellm.set_verbose = True
|
|
# client = HTTPHandler()
|
|
# with patch.object(client, "post", return_value=MagicMock()) as mock_post:
|
|
# try:
|
|
# litellm.completion(
|
|
# model="anthropic/claude-3-5-sonnet-v2@20241022",
|
|
# messages=[{"role": "user", "content": "Hello world"}],
|
|
# provider_specific_header=ProviderSpecificHeader(
|
|
# custom_llm_provider="anthropic",
|
|
# extra_headers={"anthropic-beta": "test"},
|
|
# ),
|
|
# client=client,
|
|
# )
|
|
# except Exception as e:
|
|
# print(f"Error: {e}")
|
|
|
|
# mock_post.assert_called_once()
|
|
# print(mock_post.call_args.kwargs["headers"])
|
|
# assert "anthropic-beta" in mock_post.call_args.kwargs["headers"]
|
|
|
|
|
|
from litellm.proxy._types import LiteLLM_UserTable
|
|
|
|
|
|
@pytest.mark.parametrize(
    "wildcard_model, litellm_params, expected_models",
    [
        (
            "anthropic/*",
            {"model": "anthropic/*"},
            ["anthropic/claude-haiku-4-5-20251001", "anthropic/claude-opus-4-6"],
        ),
        (
            "vertex_ai/gemini-*",
            {"model": "vertex_ai/gemini-*"},
            ["vertex_ai/gemini-2.5-flash", "vertex_ai/gemini-2.5-pro"],
        ),
        (
            "foo/*",
            {"model": "openai/*"},
            ["foo/gpt-4o", "foo/gpt-4o-mini"],
        ),
    ],
)
def test_get_known_models_from_wildcard(
    wildcard_model, litellm_params, expected_models
):
    """Every expected model must appear in the expansion of the wildcard.

    The last case uses a custom prefix ("foo/*") mapped onto "openai/*" and
    expects the results to carry the custom prefix.
    """
    from litellm.proxy.auth.model_checks import get_known_models_from_wildcard
    from litellm.types.router import LiteLLM_Params

    params = LiteLLM_Params(**litellm_params)
    wildcard_models = get_known_models_from_wildcard(
        wildcard_model=wildcard_model, litellm_params=params
    )
    print(f"wildcard_models: {wildcard_models}\n")

    # Report each absent model before failing, to ease debugging.
    missing = [model for model in expected_models if model not in wildcard_models]
    for model in missing:
        print(f"Missing expected model: {model}")

    assert not missing
|
|
|
|
|
|
def test_get_known_models_from_wildcard_without_litellm_params():
    """
    Wildcard expansion must work with ``litellm_params=None`` (BYOK case: a team
    has openai/* but there is no matching deployment in the router config).
    """
    from litellm.proxy.auth.model_checks import get_known_models_from_wildcard

    expanded = get_known_models_from_wildcard(
        wildcard_model="openai/*", litellm_params=None
    )

    # Expansion is non-empty and every entry keeps the provider prefix.
    assert len(expanded) > 0
    assert all(name.startswith("openai/") for name in expanded)

    # At least one common OpenAI model is present once the prefix is stripped.
    bare_ids = [name.split("/", 1)[1] for name in expanded]
    assert any(common in bare_ids for common in ("gpt-4o", "gpt-3.5-turbo"))
|
|
|
|
|
|
@pytest.mark.parametrize(
    "data, user_api_key_dict, expected_model",
    [
        # Requested model has a team alias -> rewritten to the alias target
        (
            {"model": "gpt-4o"},
            UserAPIKeyAuth(
                api_key="test_key", team_model_aliases={"gpt-4o": "gpt-4o-team-1"}
            ),
            "gpt-4o-team-1",
        ),
        # Team aliases exist but none for the requested model -> unchanged
        (
            {"model": "gpt-4o"},
            UserAPIKeyAuth(
                api_key="test_key", team_model_aliases={"claude-3": "claude-3-team-1"}
            ),
            "gpt-4o",
        ),
        # No team aliases at all -> unchanged
        (
            {"model": "gpt-4o"},
            UserAPIKeyAuth(api_key="test_key", team_model_aliases=None),
            "gpt-4o",
        ),
        # Request carries no model -> still no model afterwards
        (
            {"messages": []},
            UserAPIKeyAuth(
                api_key="test_key", team_model_aliases={"gpt-4o": "gpt-4o-team-1"}
            ),
            None,
        ),
    ],
)
def test_update_model_if_team_alias_exists(data, user_api_key_dict, expected_model):
    """The request's ``model`` is rewritten in place only when a team alias matches it."""
    from litellm.proxy.litellm_pre_call_utils import _update_model_if_team_alias_exists

    # Work on a shallow copy so the shared parametrize dict is not mutated.
    request_payload = dict(data)

    _update_model_if_team_alias_exists(
        data=request_payload, user_api_key_dict=user_api_key_dict
    )

    assert request_payload.get("model") == expected_model
|
|
|
|
|
|
def test_team_alias_stale_bypass_disabled_by_default(monkeypatch):
    """Without the opt-in env var, the team alias is applied to the request model.

    The team maps "gpt-4o" to a legacy alias target while the mocked router
    also has an entry for ("team-1", "gpt-4o"). With
    LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS unset, the request model must be
    rewritten to the alias target.
    """
    monkeypatch.delenv("LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS", raising=False)
    import litellm.proxy.litellm_pre_call_utils as pre_call_utils
    from litellm.proxy.litellm_pre_call_utils import _update_model_if_team_alias_exists

    # Reset the module-level cache for test isolation. Going through
    # monkeypatch (instead of assigning directly) also restores the previous
    # value at teardown, so whatever this test leaves in the cache cannot leak
    # into later tests. raising=False keeps this working if the attribute is
    # not defined yet.
    monkeypatch.setattr(
        pre_call_utils, "_ENABLE_TEAM_STALE_ALIAS_BYPASS", None, raising=False
    )

    class _MockRouter:
        # NOTE(review): presumably maps (team_id, model) to deployment indices
        # in the real router -- confirm against the Router implementation.
        team_model_to_deployment_indices = {("team-1", "gpt-4o"): [0]}

    test_data = {"model": "gpt-4o"}
    user_api_key_dict = UserAPIKeyAuth(
        api_key="test_key",
        team_id="team-1",
        team_model_aliases={"gpt-4o": "model_name_team-1_legacy-uuid"},
    )

    with patch("litellm.proxy.proxy_server.llm_router", _MockRouter()):
        _update_model_if_team_alias_exists(
            data=test_data, user_api_key_dict=user_api_key_dict
        )

    assert test_data.get("model") == "model_name_team-1_legacy-uuid"
|
|
|
|
|
|
def test_team_alias_stale_bypass_enabled_by_flag(monkeypatch):
    """With the bypass env var set to "true", the requested model is kept.

    The mocked router lists ``("team-1", "gpt-4o")`` in
    ``team_model_to_deployment_indices``; the test expects the alias in
    ``team_model_aliases`` to be ignored and ``data["model"]`` to stay
    ``"gpt-4o"``.
    """
    import litellm.proxy.litellm_pre_call_utils as pre_call_utils
    from litellm.proxy.litellm_pre_call_utils import _update_model_if_team_alias_exists

    # Reset the module-level cached flag through monkeypatch instead of a bare
    # assignment: monkeypatch already undoes the env var below, and this makes
    # it undo the cached flag too, so an "enabled" value cannot outlive the
    # test. raising=False keeps the old "create the attribute if absent"
    # behaviour.
    monkeypatch.setattr(
        pre_call_utils, "_ENABLE_TEAM_STALE_ALIAS_BYPASS", None, raising=False
    )

    class _MockRouter:
        team_model_to_deployment_indices = {("team-1", "gpt-4o"): [0]}

    test_data = {"model": "gpt-4o"}
    user_api_key_dict = UserAPIKeyAuth(
        api_key="test_key",
        team_id="team-1",
        team_model_aliases={"gpt-4o": "model_name_team-1_legacy-uuid"},
    )
    monkeypatch.setenv("LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS", "true")

    with patch("litellm.proxy.proxy_server.llm_router", _MockRouter()):
        _update_model_if_team_alias_exists(
            data=test_data, user_api_key_dict=user_api_key_dict
        )

    # Flag is on -> the model name from the request is left untouched.
    assert test_data.get("model") == "gpt-4o"
|
|
|
|
|
|
@pytest.fixture
def mock_prisma_client():
    """Provide a stand-in Prisma client.

    The returned ``MagicMock`` exposes ``db`` (a ``MagicMock``) whose
    ``litellm_teamtable`` attribute is an ``AsyncMock``, so calls made on the
    team table can be awaited by the code under test.
    """
    database = MagicMock()
    database.litellm_teamtable = AsyncMock()

    prisma = MagicMock()
    prisma.db = database
    return prisma
|
|
|
|
|
|
def _admin_team_ids_user_row(role: str) -> LiteLLM_UserTable:
    """Build a fresh user row for the ``test_get_admin_team_ids`` cases.

    Every case that supplies a user uses the same row apart from ``user_role``.
    """
    return LiteLLM_UserTable(
        teams=["team1", "team2"],
        user_id="user1",
        max_budget=100,
        spend=0,
        user_email="user1@example.com",
        user_role=role,
    )


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "test_id, user_info, user_role, mock_teams, expected_teams, should_query_db",
    [
        # No user row at all: empty result, team table never queried.
        ("no_user_info", None, "proxy_admin", None, [], False),
        # The team table query returns None: empty result.
        (
            "no_teams_found",
            _admin_team_ids_user_row("proxy_admin"),
            "proxy_admin",
            None,
            [],
            True,
        ),
        # Both team rows list user1 with role "admin" in members_with_roles.
        (
            "admin_user_with_teams",
            _admin_team_ids_user_row("proxy_admin"),
            "proxy_admin",
            [
                MagicMock(
                    model_dump=lambda: {
                        "team_id": "team1",
                        "members_with_roles": [{"role": "admin", "user_id": "user1"}],
                    }
                ),
                MagicMock(
                    model_dump=lambda: {
                        "team_id": "team2",
                        "members_with_roles": [
                            {"role": "admin", "user_id": "user1"},
                            {"role": "user", "user_id": "user2"},
                        ],
                    }
                ),
            ],
            ["team1", "team2"],
            True,
        ),
        # Team rows carry only a plain "members" list: no admin teams expected.
        (
            "non_admin_user",
            _admin_team_ids_user_row("internal_user"),
            "internal_user",
            [
                MagicMock(
                    model_dump=lambda: {"team_id": "team1", "members": ["user1"]}
                ),
                MagicMock(
                    model_dump=lambda: {
                        "team_id": "team2",
                        "members": ["user1", "user2"],
                    }
                ),
            ],
            [],
            True,
        ),
    ],
)
async def test_get_admin_team_ids(
    test_id: str,
    user_info: Optional[LiteLLM_UserTable],
    user_role: str,
    mock_teams: Optional[List[MagicMock]],
    expected_teams: List[str],
    should_query_db: bool,
    mock_prisma_client,
):
    """Run ``get_admin_team_ids`` against a mocked team table.

    Each case provides the user row, the caller's role, the rows that
    ``find_many`` returns, the team ids expected back, and whether
    ``find_many`` is expected to be called at all.
    """
    from litellm.proxy.management_endpoints.key_management_endpoints import (
        get_admin_team_ids,
    )

    # Arrange: stub the team-table lookup and build the caller's auth object.
    find_many = mock_prisma_client.db.litellm_teamtable.find_many
    find_many.return_value = mock_teams

    if user_info:
        caller_id = user_info.user_id
    else:
        caller_id = None
    caller = UserAPIKeyAuth(user_role=user_role, user_id=caller_id)

    # Act
    admin_team_ids = await get_admin_team_ids(
        complete_user_info=user_info,
        user_api_key_dict=caller,
        prisma_client=mock_prisma_client,
    )

    # Assert
    assert (
        admin_team_ids == expected_teams
    ), f"Expected {expected_teams}, but got {admin_team_ids}"

    if not should_query_db:
        find_many.assert_not_called()
        return

    find_many.assert_called_once_with(where={"team_id": {"in": user_info.teams}})
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_post_call_failure_hook_auth_error_key_info_route():
    """
    An auth failure on the /key/info route must NOT be forwarded to
    _handle_logging_proxy_only_error, since /key/info is not an LLM API route.
    """
    from unittest.mock import AsyncMock, patch

    from fastapi import HTTPException

    from litellm.caching.caching import DualCache
    from litellm.proxy._types import ProxyErrorTypes
    from litellm.proxy.utils import ProxyLogging

    proxy_logging = ProxyLogging(user_api_key_cache=DualCache())

    # Inputs for the hook: a 401 raised while serving /key/info.
    key_info_route = "/key/info"
    unauthorized = HTTPException(
        status_code=401, detail="Authentication Error: invalid user key"
    )
    payload = {
        "route": key_info_route,
        "model": "gpt-4",
        "messages": [{"role": "user", "content": "test"}],
        "litellm_call_id": "test_call_id_123",
    }
    caller = UserAPIKeyAuth(
        api_key="test_key", user_id="test_user", token="test_token"
    )

    # Replace the proxy-only error logger with an AsyncMock to observe calls.
    with patch.object(
        proxy_logging, "_handle_logging_proxy_only_error", new_callable=AsyncMock
    ) as proxy_only_logger:
        await proxy_logging.post_call_failure_hook(
            request_data=payload,
            original_exception=unauthorized,
            user_api_key_dict=caller,
            error_type=ProxyErrorTypes.auth_error,
            route=key_info_route,
        )

        # /key/info is a management route, so nothing should have been logged.
        proxy_only_logger.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_post_call_failure_hook_auth_error_llm_api_route():
    """
    An auth failure on /v1/chat/completions MUST be forwarded to
    _handle_logging_proxy_only_error, since that route is an LLM API route.
    """
    from unittest.mock import AsyncMock, patch

    from fastapi import HTTPException

    from litellm.caching.caching import DualCache
    from litellm.proxy._types import ProxyErrorTypes
    from litellm.proxy.utils import ProxyLogging

    proxy_logging = ProxyLogging(user_api_key_cache=DualCache())

    # Inputs for the hook: a 401 raised while serving a chat completion.
    chat_route = "/v1/chat/completions"
    unauthorized = HTTPException(
        status_code=401, detail="Authentication Error: invalid user key"
    )
    payload = {
        "route": chat_route,
        "model": "gpt-4",
        "messages": [{"role": "user", "content": "test"}],
        "litellm_call_id": "test_call_id_123",
    }
    caller = UserAPIKeyAuth(
        api_key="test_key",
        user_id="test_user",
        token="test_token",
        request_route=chat_route,
    )

    # Replace the proxy-only error logger with an AsyncMock to observe calls.
    with patch.object(
        proxy_logging, "_handle_logging_proxy_only_error", new_callable=AsyncMock
    ) as proxy_only_logger:
        await proxy_logging.post_call_failure_hook(
            request_data=payload,
            original_exception=unauthorized,
            user_api_key_dict=caller,
            error_type=ProxyErrorTypes.auth_error,
            route=chat_route,
        )

        # An LLM API route: the failure must have been logged exactly once.
        proxy_only_logger.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "request_data, route, expected_call_type",
    [
        (
            {"model": "bad-model", "messages": [{"role": "user", "content": "hello"}]},
            "/v1/chat/completions",
            "acompletion",
        ),
        (
            {"model": "bad-model", "prompt": "hello"},
            "/v1/completions",
            "atext_completion",
        ),
        (
            {"model": "bad-model", "input": ["hello"]},
            "/v1/embeddings",
            "aembedding",
        ),
    ],
)
async def test_handle_logging_proxy_only_error_syncs_normalized_call_type(
    request_data, route, expected_call_type
):
    """Check the call type recorded on the logging object for each route.

    ``function_setup`` is wrapped so the logging object it returns can be
    captured; after ``_handle_logging_proxy_only_error`` runs, both
    ``call_type`` and ``model_call_details["call_type"]`` on that object must
    equal the call type expected for the route.
    """
    from fastapi import HTTPException

    from litellm.caching.caching import DualCache
    from litellm.litellm_core_utils.litellm_logging import Logging
    from litellm.proxy.utils import ProxyLogging

    proxy_logging = ProxyLogging(user_api_key_cache=DualCache())

    # Grab the real implementation before it is patched below.
    real_function_setup = litellm.utils.function_setup
    captured = {}

    def _recording_function_setup(*args, **kwargs):
        # Delegate to the real function, remembering the logging object.
        setup_result = real_function_setup(*args, **kwargs)
        logging_obj, data = setup_result
        captured["logging_obj"] = logging_obj
        return logging_obj, data

    caller = UserAPIKeyAuth(
        api_key="test_key",
        user_id="test_user",
        token="test_token",
        request_route=route,
    )
    bad_request = HTTPException(status_code=400, detail="bad request")

    with (
        patch(
            "litellm.proxy.utils.litellm.utils.function_setup",
            side_effect=_recording_function_setup,
        ),
        patch.object(
            Logging, "async_failure_handler", new=AsyncMock(return_value=None)
        ),
        patch.object(Logging, "failure_handler", return_value=None),
        patch("litellm.proxy.utils.threading.Thread") as thread_cls,
    ):
        # Threads created through the patched class must not really start.
        thread_cls.return_value.start = Mock()

        await proxy_logging._handle_logging_proxy_only_error(
            request_data=request_data,
            user_api_key_dict=caller,
            route=route,
            original_exception=bad_request,
        )

    recorded = captured["logging_obj"]
    assert recorded.call_type == expected_call_type
    assert recorded.model_call_details["call_type"] == expected_call_type
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_during_call_hook_parallel_execution():
    """
    Test that multiple guardrails in during_call_hook are executed in parallel.

    Three guardrails each sleep 0.1s inside ``async_moderation_hook``. The test
    checks that all three record their start before the first one records its
    end, and that the whole hook call finishes in under 0.2s.
    """
    from litellm.caching.caching import DualCache
    from litellm.integrations.custom_guardrail import CustomGuardrail
    from litellm.proxy.utils import ProxyLogging
    from litellm.types.guardrails import GuardrailEventHooks

    cache = DualCache()
    proxy_logging = ProxyLogging(user_api_key_cache=cache)
    execution_order = []

    class TestGuardrail(CustomGuardrail):
        def __init__(self, name):
            super().__init__(
                guardrail_name=name,
                event_hook=GuardrailEventHooks.during_call,
                default_on=True,
            )
            self.name = name

        async def async_moderation_hook(self, data, user_api_key_dict, call_type):
            # Record start/end around a sleep so overlap is observable.
            execution_order.append(f"{self.name}_start")
            await asyncio.sleep(0.1)
            execution_order.append(f"{self.name}_end")
            return data

    original_callbacks = litellm.callbacks.copy() if litellm.callbacks else []

    try:
        litellm.callbacks = [TestGuardrail(f"g{i}") for i in range(3)]

        # Inside a coroutine, get_running_loop() is the documented way to
        # reach the loop whose monotonic clock is used for timing.
        loop = asyncio.get_running_loop()
        start_time = loop.time()
        result = await proxy_logging.during_call_hook(
            data={"model": "gpt-4", "messages": [{"role": "user", "content": "test"}]},
            user_api_key_dict=UserAPIKeyAuth(api_key="test_key", user_id="test_user"),
            call_type="completion",
        )
        execution_time = loop.time() - start_time

        # Verify parallel execution: all start before any end.
        # The default keeps a missing "end" entry from raising StopIteration
        # (which would surface as an opaque RuntimeError inside a coroutine);
        # the assertion below then reports the real problem.
        first_end_idx = next(
            (i for i, item in enumerate(execution_order) if "end" in item),
            len(execution_order),
        )
        starts_before_end = sum(
            1 for item in execution_order[:first_end_idx] if "start" in item
        )
        assert (
            starts_before_end == 3
        ), f"Expected 3 starts before first end, got {starts_before_end}"

        # Verify timing: parallel ~0.1s vs sequential ~0.3s
        assert (
            execution_time < 0.2
        ), f"Parallel execution took {execution_time}s, expected < 0.2s"
        assert result["model"] == "gpt-4"
    finally:
        # Always put the global callback list back, even on failure.
        litellm.callbacks = original_callbacks
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_during_call_hook_parallel_execution_with_error():
    """
    A guardrail whose moderation hook raises must have its exception surface
    from during_call_hook.
    """
    from litellm.caching.caching import DualCache
    from litellm.integrations.custom_guardrail import CustomGuardrail
    from litellm.proxy.utils import ProxyLogging
    from litellm.types.guardrails import GuardrailEventHooks

    proxy_logging = ProxyLogging(user_api_key_cache=DualCache())

    class FailingGuardrail(CustomGuardrail):
        def __init__(self):
            super().__init__(
                guardrail_name="failing_guardrail",
                event_hook=GuardrailEventHooks.during_call,
                default_on=True,
            )

        async def async_moderation_hook(self, data, user_api_key_dict, call_type):
            # Always reject the request.
            raise ValueError("Guardrail violation detected!")

    # Remember the global callbacks so they can be put back afterwards.
    if litellm.callbacks:
        saved_callbacks = litellm.callbacks.copy()
    else:
        saved_callbacks = []

    try:
        litellm.callbacks = [FailingGuardrail()]

        with pytest.raises(ValueError) as raised:
            await proxy_logging.during_call_hook(
                data={
                    "model": "gpt-4",
                    "messages": [{"role": "user", "content": "test"}],
                },
                user_api_key_dict=UserAPIKeyAuth(
                    api_key="test_key", user_id="test_user"
                ),
                call_type="completion",
            )

        assert "Guardrail violation detected!" in str(raised.value)
    finally:
        litellm.callbacks = saved_callbacks
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_logging_proxy_only_error_preserves_pass_through_call_type():
    """A logging object already marked pass_through_endpoint keeps that call_type.

    _handle_logging_proxy_only_error must not overwrite ``call_type`` on the
    logging object supplied through ``request_data["litellm_logging_obj"]``.
    """
    from litellm.caching.caching import DualCache
    from litellm.litellm_core_utils.litellm_logging import Logging
    from litellm.proxy.utils import ProxyLogging
    from litellm.types.utils import CallTypes

    user_message = {"role": "user", "content": "test"}

    logging_obj = Logging(
        model="unknown",
        messages=[user_message],
        stream=False,
        call_type="pass_through_endpoint",
        start_time=datetime.now(),
        litellm_call_id="test-call-id",
        function_id="test-function-id",
    )

    request_data = {
        "litellm_logging_obj": logging_obj,
        "messages": [{"role": "user", "content": "test"}],
        "model": "claude-3-5-sonnet",
    }

    proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
    caller = UserAPIKeyAuth(api_key="test_key", token="test_token")

    # Stub both failure handlers on this logging object for the call.
    with patch.object(
        logging_obj, "async_failure_handler", new_callable=AsyncMock
    ), patch.object(logging_obj, "failure_handler"):
        await proxy_logging._handle_logging_proxy_only_error(
            request_data=request_data,
            user_api_key_dict=caller,
            original_exception=Exception("test error"),
        )

    assert logging_obj.call_type == CallTypes.pass_through.value
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_litellm_logging_obj_excluded_from_optional_params():
    """The logging object must not end up inside its own model_call_details.

    Guards against circular references: after
    _handle_logging_proxy_only_error runs, the ``litellm_logging_obj`` key from
    the request data must be absent from ``model_call_details``.
    """
    from litellm.caching.caching import DualCache
    from litellm.litellm_core_utils.litellm_logging import Logging
    from litellm.proxy.utils import ProxyLogging

    user_message = {"role": "user", "content": "test"}

    logging_obj = Logging(
        model="unknown",
        messages=[user_message],
        stream=False,
        call_type="pass_through_endpoint",
        start_time=datetime.now(),
        litellm_call_id="test-call-id",
        function_id="test-function-id",
    )

    request_data = {
        "litellm_logging_obj": logging_obj,
        "messages": [{"role": "user", "content": "test"}],
        "model": "claude-3-5-sonnet",
    }

    proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
    caller = UserAPIKeyAuth(api_key="test_key", token="test_token")

    # Stub both failure handlers on this logging object for the call.
    with patch.object(
        logging_obj, "async_failure_handler", new_callable=AsyncMock
    ), patch.object(logging_obj, "failure_handler"):
        await proxy_logging._handle_logging_proxy_only_error(
            request_data=request_data,
            user_api_key_dict=caller,
            original_exception=Exception("test error"),
        )

    assert "litellm_logging_obj" not in logging_obj.model_call_details
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_logging_proxy_only_error_skips_handlers_for_pass_through():
    """Pass-through errors must not trigger either failure handler.

    For a logging object marked pass_through_endpoint,
    _handle_logging_proxy_only_error is expected to skip both
    async_failure_handler and failure_handler, so only
    async_post_call_failure_hook fires (avoiding duplicate logs).

    Regression test for duplicate Datadog/Arize logs on pass-through failures.
    """
    from litellm.caching.caching import DualCache
    from litellm.litellm_core_utils.litellm_logging import Logging
    from litellm.proxy.utils import ProxyLogging
    from litellm.types.utils import CallTypes

    user_message = {"role": "user", "content": "test"}

    logging_obj = Logging(
        model="unknown",
        messages=[user_message],
        stream=False,
        call_type="pass_through_endpoint",
        start_time=datetime.now(),
        litellm_call_id="test-call-id",
        function_id="test-function-id",
    )

    proxy_logging = ProxyLogging(user_api_key_cache=DualCache())

    request_data = {
        "litellm_logging_obj": logging_obj,
        "messages": [{"role": "user", "content": "test"}],
        "model": "claude-3-5-sonnet",
    }
    caller = UserAPIKeyAuth(api_key="test_key", token="test_token")

    # Swap in mocks for both handlers so any invocation is recorded.
    with patch.object(
        logging_obj, "async_failure_handler", new_callable=AsyncMock
    ) as async_handler, patch.object(logging_obj, "failure_handler") as sync_handler:
        await proxy_logging._handle_logging_proxy_only_error(
            request_data=request_data,
            user_api_key_dict=caller,
            original_exception=Exception("test error"),
        )

    # Neither handler should fire for pass-through requests
    async_handler.assert_not_called()
    sync_handler.assert_not_called()
    assert logging_obj.call_type == CallTypes.pass_through.value
|
|
|
|
|
|
def test_handle_exception_on_proxy_preserves_status_code():
    """
    A RateLimitError must come back from handle_exception_on_proxy as 429.

    OpenAI batch creation answers rate limits with 429, which LiteLLM wraps as
    a RateLimitError carrying status_code=429. The proxy handler has to pass
    that code through rather than hardcoding 500.
    """
    from litellm.proxy.utils import handle_exception_on_proxy

    proxy_exception = handle_exception_on_proxy(
        litellm.RateLimitError(
            message="Rate limit exceeded: batch creation limit of 2000/hour hit",
            llm_provider="openai",
            model="gpt-4o",
        )
    )

    status = int(proxy_exception.code)
    assert status == 429, f"Expected 429, got {proxy_exception.code}"
|
|
|
|
|
|
def test_handle_exception_on_proxy_defaults_to_500_for_unknown_exceptions():
    """
    A plain Exception, which carries no status_code, must still map to 500.
    """
    from litellm.proxy.utils import handle_exception_on_proxy

    generic_error = Exception("something went wrong")
    proxy_exception = handle_exception_on_proxy(generic_error)

    status = int(proxy_exception.code)
    assert status == 500, f"Expected 500, got {proxy_exception.code}"
|
|
|
|
|
|
def test_handle_exception_on_proxy_preserves_auth_error_status_code():
    """
    An AuthenticationError must come back from handle_exception_on_proxy as 401.
    """
    from litellm.proxy.utils import handle_exception_on_proxy

    proxy_exception = handle_exception_on_proxy(
        litellm.AuthenticationError(
            message="Invalid API key",
            llm_provider="openai",
            model="gpt-4o",
        )
    )

    status = int(proxy_exception.code)
    assert status == 401, f"Expected 401, got {proxy_exception.code}"
|