Files
litellm/tests/test_litellm/proxy/test_litellm_pre_call_utils.py
T
2025-07-29 15:49:20 -07:00

726 lines
24 KiB
Python

import asyncio
import copy
import json
import os
import sys
from unittest.mock import MagicMock, patch
import pytest
from fastapi import Request
from litellm.proxy._types import TeamCallbackMetadata, UserAPIKeyAuth
from litellm.proxy.litellm_pre_call_utils import (
KeyAndTeamLoggingSettings,
LiteLLMProxyRequestSetup,
_get_dynamic_logging_metadata,
_get_enforced_params,
add_litellm_data_to_request,
check_if_token_is_service_account,
)
# Prepend the directory three levels above the current working directory to
# sys.path (os.path.abspath resolves the relative "../../.." against the cwd).
# NOTE(review): this statement executes after the litellm imports above, so it
# cannot influence how those imports are resolved — confirm this is intended.
sys.path.insert(
    0, os.path.abspath("../../..")
)
def test_check_if_token_is_service_account():
    """
    Test that only keys with `service_account_id` in metadata are considered service accounts.

    Identity checks (`is True` / `is False`) are used instead of `== True` /
    `== False` so the test also pins the return value to a real bool.
    """
    # Test case 1: Service account token
    service_account_token = UserAPIKeyAuth(
        api_key="test-key", metadata={"service_account_id": "test-service-account"}
    )
    assert check_if_token_is_service_account(service_account_token) is True

    # Test case 2: Regular user token (empty metadata)
    regular_token = UserAPIKeyAuth(api_key="test-key", metadata={})
    assert check_if_token_is_service_account(regular_token) is False

    # Test case 3: Token with other, unrelated metadata
    other_metadata_token = UserAPIKeyAuth(
        api_key="test-key", metadata={"user_id": "test-user"}
    )
    assert check_if_token_is_service_account(other_metadata_token) is False
def test_get_enforced_params_for_service_account_settings():
    """
    Service-account enforced params apply to service account keys only.

    A key without `service_account_id` is expected to get just the
    `enforced_params` listed in its own metadata.
    """
    settings = {
        "service_account_settings": {"enforced_params": ["metadata.service"]},
    }

    # Service account key: picks up the service-account enforced params.
    service_key = UserAPIKeyAuth(
        api_key="test-key", metadata={"service_account_id": "test-service-account"}
    )
    service_key_params = _get_enforced_params(
        general_settings=settings,
        user_api_key_dict=service_key,
    )
    assert service_key_params == ["metadata.service"]

    # Regular key: only its own metadata-level enforced params are returned.
    regular_key = UserAPIKeyAuth(
        api_key="test-key", metadata={"enforced_params": ["user"]}
    )
    regular_key_params = _get_enforced_params(
        general_settings=settings,
        user_api_key_dict=regular_key,
    )
    assert regular_key_params == ["user"]
@pytest.mark.parametrize(
    "general_settings, user_api_key_dict, expected_enforced_params",
    [
        # Top-level enforced params with a regular key.
        (
            {"enforced_params": ["param1", "param2"]},
            UserAPIKeyAuth(
                api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
            ),
            ["param1", "param2"],
        ),
        # Service-account enforced params with a service account key.
        (
            {"service_account_settings": {"enforced_params": ["param1", "param2"]}},
            UserAPIKeyAuth(
                api_key="test_api_key",
                user_id="test_user_id",
                org_id="test_org_id",
                metadata={"service_account_id": "test_service_account_id"},
            ),
            ["param1", "param2"],
        ),
        # Service-account enforced params followed by key-level enforced params.
        (
            {"service_account_settings": {"enforced_params": ["param1", "param2"]}},
            UserAPIKeyAuth(
                api_key="test_api_key",
                metadata={
                    "enforced_params": ["param3", "param4"],
                    "service_account_id": "test_service_account_id",
                },
            ),
            ["param1", "param2", "param3", "param4"],
        ),
    ],
)
def test_get_enforced_params(
    general_settings, user_api_key_dict, expected_enforced_params
):
    """
    Check `_get_enforced_params` against each (settings, key, expected) combination.
    """
    from litellm.proxy.litellm_pre_call_utils import _get_enforced_params

    assert (
        _get_enforced_params(general_settings, user_api_key_dict)
        == expected_enforced_params
    )
@pytest.mark.asyncio
async def test_add_litellm_data_to_request_parses_string_metadata():
    """
    Test that a JSON-stringified `metadata` field in the request data is parsed into a dict.
    """
    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

    # Setup
    request_mock = MagicMock(spec=Request)
    # Create the url mock first and configure it afterwards; assigning a fresh
    # MagicMock to `url` after setting `url.path` would discard the path.
    request_mock.url = MagicMock()
    request_mock.url.path = "/v1/completions"
    request_mock.url.__str__.return_value = "http://localhost/v1/completions"
    request_mock.method = "POST"
    request_mock.query_params = {}
    request_mock.headers = {"Content-Type": "application/json"}
    request_mock.client = MagicMock()
    request_mock.client.host = "127.0.0.1"

    # Simulate data with stringified metadata
    fake_metadata = {"generation_name": "gen123"}
    data = {"metadata": json.dumps(fake_metadata), "model": "gpt-3.5-turbo"}

    user_api_key_dict = UserAPIKeyAuth(
        api_key="hashed-key",
        metadata={},
        team_metadata={},
        spend=0.0,
        max_budget=100.0,
        model_max_budget={},  # this one can be a dict
        team_spend=0.0,
        team_max_budget=200.0,
    )

    # Call
    updated_data = await add_litellm_data_to_request(
        data=data,
        request=request_mock,
        user_api_key_dict=user_api_key_dict,
        proxy_config=MagicMock(),
        general_settings={},
        version="test-version",
    )

    # Assert: the string was decoded and the original key survived
    parsed_metadata = updated_data.get("metadata", {})
    assert isinstance(parsed_metadata, dict)
    assert updated_data["metadata"]["generation_name"] == "gen123"
@pytest.mark.asyncio
async def test_add_litellm_data_to_request_audio_transcription_multipart():
    """
    Test that stringified `metadata` sent with an audio transcription request
    (as a multipart form field would be) is parsed into a dict with its tags intact.
    """
    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

    # Setup request mock for /v1/audio/transcriptions
    request_mock = MagicMock(spec=Request)
    # Create the url mock first and configure it afterwards; assigning a fresh
    # MagicMock to `url` after setting `url.path` would discard the path.
    request_mock.url = MagicMock()
    request_mock.url.path = "/v1/audio/transcriptions"
    request_mock.url.__str__.return_value = "http://localhost/v1/audio/transcriptions"
    request_mock.method = "POST"
    request_mock.query_params = {}
    request_mock.headers = {
        "Content-Type": "multipart/form-data",
        "Authorization": "Bearer sk-1234",
    }
    request_mock.client = MagicMock()
    request_mock.client.host = "127.0.0.1"

    # Simulate multipart data (metadata as string)
    metadata_dict = {
        "tags": ["jobID:214590dsff09fds", "taskName:run_page_classification"]
    }
    stringified_metadata = json.dumps(metadata_dict)
    data = {
        "model": "fake-openai-endpoint",
        "metadata": stringified_metadata,  # Simulating multipart-form field
        "file": b"Fake audio bytes",
    }

    user_api_key_dict = UserAPIKeyAuth(
        api_key="hashed-key",
        metadata={},
        team_metadata={},
        spend=0.0,
        max_budget=100.0,
        model_max_budget={},
        team_spend=0.0,
        team_max_budget=200.0,
    )

    updated_data = await add_litellm_data_to_request(
        data=data,
        request=request_mock,
        user_api_key_dict=user_api_key_dict,
        proxy_config=MagicMock(),
        general_settings={},
        version="test-version",
    )

    # Assert metadata was parsed correctly
    metadata_field = updated_data.get("metadata", {})
    assert isinstance(metadata_field, dict)
    assert "tags" in metadata_field
    assert metadata_field["tags"] == [
        "jobID:214590dsff09fds",
        "taskName:run_page_classification",
    ]
@pytest.mark.asyncio
async def test_add_litellm_data_to_request_disabled_callbacks():
    """
    Test that litellm_disabled_callbacks from key metadata is properly added to the request data.
    """
    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

    # Setup mock request
    request_mock = MagicMock(spec=Request)
    # Create the url mock first and configure it afterwards; assigning a fresh
    # MagicMock to `url` after setting `url.path` would discard the path.
    request_mock.url = MagicMock()
    request_mock.url.path = "/chat/completions"
    request_mock.url.__str__.return_value = "http://localhost/chat/completions"
    request_mock.method = "POST"
    request_mock.query_params = {}
    request_mock.headers = {"Content-Type": "application/json"}
    request_mock.client = MagicMock()
    request_mock.client.host = "127.0.0.1"

    # Setup user API key with disabled callbacks in metadata
    user_api_key_dict = UserAPIKeyAuth(
        api_key="test_api_key",
        user_id="test_user_id",
        org_id="test_org_id",
        metadata={"litellm_disabled_callbacks": ["langfuse", "langsmith", "datadog"]},
    )

    # Setup request data
    data = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello"}],
    }

    # Setup proxy config
    proxy_config = MagicMock()

    # Call add_litellm_data_to_request
    result = await add_litellm_data_to_request(
        data=data,
        request=request_mock,
        user_api_key_dict=user_api_key_dict,
        proxy_config=proxy_config,
    )

    # Verify that litellm_disabled_callbacks was added to the request data
    assert "litellm_disabled_callbacks" in result
    assert result["litellm_disabled_callbacks"] == ["langfuse", "langsmith", "datadog"]

    # Verify that other data is still present
    assert "model" in result
    assert result["model"] == "gpt-3.5-turbo"
    assert "messages" in result
@pytest.mark.asyncio
async def test_add_litellm_data_to_request_disabled_callbacks_empty():
    """
    Test that litellm_disabled_callbacks is not added when it's empty.
    """
    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

    # Setup mock request
    request_mock = MagicMock(spec=Request)
    # Create the url mock first and configure it afterwards; assigning a fresh
    # MagicMock to `url` after setting `url.path` would discard the path.
    request_mock.url = MagicMock()
    request_mock.url.path = "/chat/completions"
    request_mock.url.__str__.return_value = "http://localhost/chat/completions"
    request_mock.method = "POST"
    request_mock.query_params = {}
    request_mock.headers = {"Content-Type": "application/json"}
    request_mock.client = MagicMock()
    request_mock.client.host = "127.0.0.1"

    # Setup user API key with empty disabled callbacks
    user_api_key_dict = UserAPIKeyAuth(
        api_key="test_api_key",
        user_id="test_user_id",
        org_id="test_org_id",
        metadata={"litellm_disabled_callbacks": []},
    )

    # Setup request data
    data = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello"}],
    }

    # Setup proxy config
    proxy_config = MagicMock()

    # Call add_litellm_data_to_request
    result = await add_litellm_data_to_request(
        data=data,
        request=request_mock,
        user_api_key_dict=user_api_key_dict,
        proxy_config=proxy_config,
    )

    # Verify that litellm_disabled_callbacks is not added when empty
    assert "litellm_disabled_callbacks" not in result

    # Verify that other data is still present
    assert "model" in result
    assert result["model"] == "gpt-3.5-turbo"
    assert "messages" in result
@pytest.mark.asyncio
async def test_add_litellm_data_to_request_disabled_callbacks_not_present():
    """
    Test that litellm_disabled_callbacks is not added when it's not present in metadata.
    """
    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

    # Setup mock request
    request_mock = MagicMock(spec=Request)
    # Create the url mock first and configure it afterwards; assigning a fresh
    # MagicMock to `url` after setting `url.path` would discard the path.
    request_mock.url = MagicMock()
    request_mock.url.path = "/chat/completions"
    request_mock.url.__str__.return_value = "http://localhost/chat/completions"
    request_mock.method = "POST"
    request_mock.query_params = {}
    request_mock.headers = {"Content-Type": "application/json"}
    request_mock.client = MagicMock()
    request_mock.client.host = "127.0.0.1"

    # Setup user API key without disabled callbacks in metadata
    user_api_key_dict = UserAPIKeyAuth(
        api_key="test_api_key",
        user_id="test_user_id",
        org_id="test_org_id",
        metadata={},  # No litellm_disabled_callbacks
    )

    # Setup request data
    data = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello"}],
    }

    # Setup proxy config
    proxy_config = MagicMock()

    # Call add_litellm_data_to_request
    result = await add_litellm_data_to_request(
        data=data,
        request=request_mock,
        user_api_key_dict=user_api_key_dict,
        proxy_config=proxy_config,
    )

    # Verify that litellm_disabled_callbacks is not added when not present
    assert "litellm_disabled_callbacks" not in result

    # Verify that other data is still present
    assert "model" in result
    assert result["model"] == "gpt-3.5-turbo"
    assert "messages" in result
@pytest.mark.asyncio
async def test_add_litellm_data_to_request_disabled_callbacks_invalid_type():
    """
    Test that litellm_disabled_callbacks is not added when it's not a list.
    """
    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

    # Setup mock request
    request_mock = MagicMock(spec=Request)
    # Create the url mock first and configure it afterwards; assigning a fresh
    # MagicMock to `url` after setting `url.path` would discard the path.
    request_mock.url = MagicMock()
    request_mock.url.path = "/chat/completions"
    request_mock.url.__str__.return_value = "http://localhost/chat/completions"
    request_mock.method = "POST"
    request_mock.query_params = {}
    request_mock.headers = {"Content-Type": "application/json"}
    request_mock.client = MagicMock()
    request_mock.client.host = "127.0.0.1"

    # Setup user API key with invalid disabled callbacks type
    user_api_key_dict = UserAPIKeyAuth(
        api_key="test_api_key",
        user_id="test_user_id",
        org_id="test_org_id",
        metadata={"litellm_disabled_callbacks": "not_a_list"},  # Should be a list
    )

    # Setup request data
    data = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello"}],
    }

    # Setup proxy config
    proxy_config = MagicMock()

    # Call add_litellm_data_to_request
    result = await add_litellm_data_to_request(
        data=data,
        request=request_mock,
        user_api_key_dict=user_api_key_dict,
        proxy_config=proxy_config,
    )

    # Verify that litellm_disabled_callbacks is not added when invalid type
    assert "litellm_disabled_callbacks" not in result

    # Verify that other data is still present
    assert "model" in result
    assert result["model"] == "gpt-3.5-turbo"
    assert "messages" in result
@pytest.mark.asyncio
async def test_add_litellm_data_to_request_disabled_callbacks_with_logging_settings():
    """
    Test that litellm_disabled_callbacks works correctly alongside logging settings.
    """
    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

    # Setup mock request
    request_mock = MagicMock(spec=Request)
    # Create the url mock first and configure it afterwards; assigning a fresh
    # MagicMock to `url` after setting `url.path` would discard the path.
    request_mock.url = MagicMock()
    request_mock.url.path = "/chat/completions"
    request_mock.url.__str__.return_value = "http://localhost/chat/completions"
    request_mock.method = "POST"
    request_mock.query_params = {}
    request_mock.headers = {"Content-Type": "application/json"}
    request_mock.client = MagicMock()
    request_mock.client.host = "127.0.0.1"

    # Setup user API key with both logging settings and disabled callbacks
    user_api_key_dict = UserAPIKeyAuth(
        api_key="test_api_key",
        user_id="test_user_id",
        org_id="test_org_id",
        metadata={
            "logging": [
                {
                    "callback_name": "langfuse",
                    "callback_type": "success",
                    "callback_vars": {},
                }
            ],
            "litellm_disabled_callbacks": ["langsmith", "datadog"],
        },
    )

    # Setup request data
    data = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello"}],
    }

    # Setup proxy config
    proxy_config = MagicMock()

    # Call add_litellm_data_to_request
    result = await add_litellm_data_to_request(
        data=data,
        request=request_mock,
        user_api_key_dict=user_api_key_dict,
        proxy_config=proxy_config,
    )

    # Verify that the disabled callbacks are still propagated when logging
    # settings are also configured on the key
    assert "litellm_disabled_callbacks" in result
    assert result["litellm_disabled_callbacks"] == ["langsmith", "datadog"]

    # Verify that other data is still present
    assert "model" in result
    assert result["model"] == "gpt-3.5-turbo"
    assert "messages" in result
def test_key_dynamic_logging_settings():
    """
    KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings should return the
    key-level `logging` metadata (arize, langfuse) and None when none is set.
    """
    # Keys that carry a `logging` entry in their own metadata.
    for callback_name in ("arize", "langfuse"):
        key_with_logging = UserAPIKeyAuth(
            api_key="test-key",
            metadata={
                "logging": [
                    {"callback_name": callback_name, "callback_type": "success"}
                ]
            },
            team_metadata={},
        )
        key_settings = KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(
            key_with_logging
        )
        assert key_settings == [
            {"callback_name": callback_name, "callback_type": "success"}
        ]

    # A key with no logging metadata yields None.
    plain_key = UserAPIKeyAuth(api_key="test-key", metadata={}, team_metadata={})
    assert (
        KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(plain_key) is None
    )
def test_team_dynamic_logging_settings():
    """
    KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings should return the
    team-level `logging` metadata (arize, langfuse) and None when none is set.
    """
    # (callback name, callback type) pairs stored in team metadata.
    team_logging_cases = (
        ("arize", "failure"),
        ("langfuse", "success"),
    )
    for callback_name, callback_type in team_logging_cases:
        key_with_team_logging = UserAPIKeyAuth(
            api_key="test-key",
            metadata={},
            team_metadata={
                "logging": [
                    {"callback_name": callback_name, "callback_type": callback_type}
                ]
            },
        )
        team_settings = KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(
            key_with_team_logging
        )
        assert team_settings == [
            {"callback_name": callback_name, "callback_type": callback_type}
        ]

    # No team logging metadata yields None.
    plain_key = UserAPIKeyAuth(api_key="test-key", metadata={}, team_metadata={})
    assert (
        KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(plain_key) is None
    )
def test_get_dynamic_logging_metadata_with_arize_team_logging():
    """
    _get_dynamic_logging_metadata should turn arize team logging metadata,
    including its callback_vars, into a TeamCallbackMetadata.
    """
    arize_team_logging = {
        "callback_name": "arize",
        "callback_type": "success",
        "callback_vars": {
            "arize_api_key": "test_arize_api_key",
            "arize_space_id": "test_arize_space_id",
        },
    }
    team_key = UserAPIKeyAuth(
        api_key="test-key",
        metadata={},
        team_metadata={"logging": [arize_team_logging]},
    )

    # proxy_config is passed as a bare mock; the team metadata above carries
    # the logging configuration this test asserts on.
    callback_metadata = _get_dynamic_logging_metadata(
        user_api_key_dict=team_key, proxy_config=MagicMock()
    )

    assert callback_metadata is not None
    assert isinstance(callback_metadata, TeamCallbackMetadata)
    assert callback_metadata.success_callback == ["arize"]

    resolved_vars = callback_metadata.callback_vars
    assert resolved_vars is not None
    assert resolved_vars["arize_api_key"] == "test_arize_api_key"
    assert resolved_vars["arize_space_id"] == "test_arize_space_id"
def test_get_num_retries_from_request():
    """
    Exercise LiteLLMProxyRequestSetup._get_num_retries_from_request with present,
    absent and malformed `x-litellm-num-retries` header values.
    """
    get_num_retries = LiteLLMProxyRequestSetup._get_num_retries_from_request

    # Cases 1-6: (headers, expected); None means the header is absent.
    parsed_cases = [
        ({"x-litellm-num-retries": "3"}, 3),  # valid integer string
        ({"Content-Type": "application/json"}, None),  # header not present
        ({}, None),  # empty headers dictionary
        ({"x-litellm-num-retries": "0"}, 0),  # zero value
        ({"x-litellm-num-retries": "100"}, 100),  # large number
        (
            {
                "Content-Type": "application/json",
                "x-litellm-num-retries": "5",
                "Authorization": "Bearer token",
            },
            5,
        ),  # retries header among several others
    ]
    for headers, expected in parsed_cases:
        num_retries = get_num_retries(headers)
        if expected is None:
            assert num_retries is None
        else:
            assert num_retries == expected

    # Cases 7-8: non-integer strings are expected to raise ValueError.
    for bad_value in ("invalid", "3.5"):
        with pytest.raises(ValueError):
            get_num_retries({"x-litellm-num-retries": bad_value})

    # Case 9: a negative number is returned as-is.
    assert get_num_retries({"x-litellm-num-retries": "-1"}) == -1
def test_add_user_api_key_auth_to_request_metadata():
    """
    add_user_api_key_auth_to_request_metadata should copy the key's auth details
    into the request's metadata dict and leave the rest of the request untouched.
    """
    metadata_key = "litellm_metadata"
    messages = [{"role": "user", "content": "Hello"}]
    request_data = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello"}],
        metadata_key: {},  # the metadata dict the auth details are written into
    }
    auth = UserAPIKeyAuth(
        api_key="hashed-test-key-123",
        user_id="test-user-123",
        org_id="test-org-456",
        team_id="test-team-789",
        key_alias="test-key-alias",
        user_email="test@example.com",
        team_alias="test-team-alias",
        end_user_id="test-end-user-123",
        request_route="/chat/completions",
        end_user_max_budget=500.0,
    )

    result = LiteLLMProxyRequestSetup.add_user_api_key_auth_to_request_metadata(
        data=request_data,
        user_api_key_dict=auth,
        _metadata_variable_name=metadata_key,
    )

    # Every auth-derived field expected in the request metadata.
    expected_metadata = {
        "user_api_key_hash": "hashed-test-key-123",
        "user_api_key_alias": "test-key-alias",
        "user_api_key_team_id": "test-team-789",
        "user_api_key_user_id": "test-user-123",
        "user_api_key_org_id": "test-org-456",
        "user_api_key_team_alias": "test-team-alias",
        "user_api_key_end_user_id": "test-end-user-123",
        "user_api_key_user_email": "test@example.com",
        "user_api_key_request_route": "/chat/completions",
        "user_api_key": "hashed-test-key-123",
        "user_api_end_user_max_budget": 500.0,
    }
    metadata = result[metadata_key]
    for field, expected_value in expected_metadata.items():
        assert metadata[field] == expected_value

    # The original request payload is preserved.
    assert result["model"] == "gpt-3.5-turbo"
    assert result["messages"] == messages