mirror of
https://github.com/tiennm99/litellm.git
synced 2026-07-05 13:07:08 +00:00
Merge pull request #17407 from BerriAI/litellm_enforce_enforce_user_param
Enforce support of enforce_user_param to openai post endpoints
This commit is contained in:
@@ -190,7 +190,18 @@ async def common_checks(
|
||||
general_settings.get("enforce_user_param", None) is not None
|
||||
and general_settings["enforce_user_param"] is True
|
||||
):
|
||||
if RouteChecks.is_llm_api_route(route=route) and "user" not in request_body:
|
||||
# Get HTTP method from request
|
||||
http_method = request.method if hasattr(request, 'method') else None
|
||||
|
||||
# Check if it's a POST request and if it's an OpenAI route but not MCP
|
||||
is_post_method = http_method and http_method.upper() == "POST"
|
||||
is_openai_route = RouteChecks.is_llm_api_route(route=route)
|
||||
is_mcp_route = route in LiteLLMRoutes.mcp_routes.value or RouteChecks.check_route_access(
|
||||
route=route, allowed_routes=LiteLLMRoutes.mcp_routes.value
|
||||
)
|
||||
|
||||
# Enforce user param only for POST requests on OpenAI routes (excluding MCP routes)
|
||||
if is_post_method and is_openai_route and not is_mcp_route and "user" not in request_body:
|
||||
raise Exception(
|
||||
f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,438 @@
|
||||
"""
|
||||
Tests for enforce_user_param feature with POST/GET method filtering and MCP route exclusion.
|
||||
|
||||
Tests verify that:
|
||||
1. enforce_user_param only applies to POST requests
|
||||
2. GET requests like /v1/models are not affected
|
||||
3. MCP routes are excluded from enforcement
|
||||
4. POST requests to completion endpoints still require user param when enforced
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
|
||||
from litellm.proxy._types import LiteLLMRoutes, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.auth_checks import common_checks
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
|
||||
|
||||
class MockRequest:
|
||||
"""Mock FastAPI Request object"""
|
||||
def __init__(self, method: str = "POST"):
|
||||
self.method = method
|
||||
|
||||
|
||||
def get_mock_user_token():
|
||||
"""Create a mock UserAPIKeyAuth token for testing"""
|
||||
return UserAPIKeyAuth(
|
||||
api_key="test-key",
|
||||
user_id="test-user",
|
||||
team_id="test-team",
|
||||
org_id="test-org",
|
||||
models=["*"],
|
||||
metadata={}
|
||||
)
|
||||
|
||||
|
||||
class TestEnforceUserParamPostGetFiltering:
|
||||
"""Test POST/GET method filtering for enforce_user_param"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_completion_without_user_param_should_fail(self):
|
||||
"""POST to /v1/chat/completions without user param should raise error"""
|
||||
request = MockRequest(method="POST")
|
||||
general_settings = {"enforce_user_param": True}
|
||||
request_body = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}
|
||||
|
||||
with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await common_checks(
|
||||
request_body=request_body,
|
||||
team_object=None,
|
||||
user_object=None,
|
||||
end_user_object=None,
|
||||
global_proxy_spend=None,
|
||||
general_settings=general_settings,
|
||||
route="/v1/chat/completions",
|
||||
llm_router=None,
|
||||
proxy_logging_obj=MagicMock(),
|
||||
valid_token=get_mock_user_token(),
|
||||
request=request,
|
||||
)
|
||||
|
||||
assert "user" in str(exc_info.value).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_completion_with_user_param_should_pass(self):
|
||||
"""POST to /v1/chat/completions with user param should pass"""
|
||||
request = MockRequest(method="POST")
|
||||
general_settings = {"enforce_user_param": True}
|
||||
request_body = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"user": "user123"
|
||||
}
|
||||
|
||||
with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True):
|
||||
# Should not raise exception
|
||||
result = await common_checks(
|
||||
request_body=request_body,
|
||||
team_object=None,
|
||||
user_object=None,
|
||||
end_user_object=None,
|
||||
global_proxy_spend=None,
|
||||
general_settings=general_settings,
|
||||
route="/v1/chat/completions",
|
||||
llm_router=None,
|
||||
proxy_logging_obj=MagicMock(),
|
||||
valid_token=get_mock_user_token(),
|
||||
request=request,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_models_without_user_param_should_pass(self):
|
||||
"""GET to /v1/models without user param should NOT raise error"""
|
||||
request = MockRequest(method="GET")
|
||||
general_settings = {"enforce_user_param": True}
|
||||
request_body = {} # GET requests typically don't have body
|
||||
|
||||
with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True):
|
||||
# Should not raise exception
|
||||
result = await common_checks(
|
||||
request_body=request_body,
|
||||
team_object=None,
|
||||
user_object=None,
|
||||
end_user_object=None,
|
||||
global_proxy_spend=None,
|
||||
general_settings=general_settings,
|
||||
route="/v1/models",
|
||||
llm_router=None,
|
||||
proxy_logging_obj=MagicMock(),
|
||||
valid_token=get_mock_user_token(),
|
||||
request=request,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_files_without_user_param_should_pass(self):
|
||||
"""GET to /v1/files without user param should NOT raise error"""
|
||||
request = MockRequest(method="GET")
|
||||
general_settings = {"enforce_user_param": True}
|
||||
request_body = {}
|
||||
|
||||
with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True):
|
||||
result = await common_checks(
|
||||
request_body=request_body,
|
||||
team_object=None,
|
||||
user_object=None,
|
||||
end_user_object=None,
|
||||
global_proxy_spend=None,
|
||||
general_settings=general_settings,
|
||||
route="/v1/files",
|
||||
llm_router=None,
|
||||
proxy_logging_obj=MagicMock(),
|
||||
valid_token=get_mock_user_token(),
|
||||
request=request,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_embeddings_without_user_param_should_fail(self):
|
||||
"""POST to /v1/embeddings without user param should raise error"""
|
||||
request = MockRequest(method="POST")
|
||||
general_settings = {"enforce_user_param": True}
|
||||
request_body = {
|
||||
"model": "text-embedding-ada-002",
|
||||
"input": "test"
|
||||
}
|
||||
|
||||
with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await common_checks(
|
||||
request_body=request_body,
|
||||
team_object=None,
|
||||
user_object=None,
|
||||
end_user_object=None,
|
||||
global_proxy_spend=None,
|
||||
general_settings=general_settings,
|
||||
route="/v1/embeddings",
|
||||
llm_router=None,
|
||||
proxy_logging_obj=MagicMock(),
|
||||
valid_token=get_mock_user_token(),
|
||||
request=request,
|
||||
)
|
||||
|
||||
assert "user" in str(exc_info.value).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_embeddings_with_user_param_should_pass(self):
|
||||
"""POST to /v1/embeddings with user param should pass"""
|
||||
request = MockRequest(method="POST")
|
||||
general_settings = {"enforce_user_param": True}
|
||||
request_body = {
|
||||
"model": "text-embedding-ada-002",
|
||||
"input": "test",
|
||||
"user": "user123"
|
||||
}
|
||||
|
||||
with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True):
|
||||
result = await common_checks(
|
||||
request_body=request_body,
|
||||
team_object=None,
|
||||
user_object=None,
|
||||
end_user_object=None,
|
||||
global_proxy_spend=None,
|
||||
general_settings=general_settings,
|
||||
route="/v1/embeddings",
|
||||
llm_router=None,
|
||||
proxy_logging_obj=MagicMock(),
|
||||
valid_token=get_mock_user_token(),
|
||||
request=request,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestEnforceUserParamMCPExclusion:
|
||||
"""Test MCP route exclusion from enforce_user_param"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_route_without_user_param_should_pass(self):
|
||||
"""POST to MCP route without user param should NOT raise error"""
|
||||
request = MockRequest(method="POST")
|
||||
general_settings = {"enforce_user_param": True}
|
||||
request_body = {"action": "list_tools"}
|
||||
|
||||
with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True):
|
||||
# Should not raise exception for MCP routes
|
||||
result = await common_checks(
|
||||
request_body=request_body,
|
||||
team_object=None,
|
||||
user_object=None,
|
||||
end_user_object=None,
|
||||
global_proxy_spend=None,
|
||||
general_settings=general_settings,
|
||||
route="/mcp/tools/list",
|
||||
llm_router=None,
|
||||
proxy_logging_obj=MagicMock(),
|
||||
valid_token=get_mock_user_token(),
|
||||
request=request,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_root_route_without_user_param_should_pass(self):
|
||||
"""POST to /mcp without user param should NOT raise error"""
|
||||
request = MockRequest(method="POST")
|
||||
general_settings = {"enforce_user_param": True}
|
||||
request_body = {"data": "test"}
|
||||
|
||||
with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True):
|
||||
result = await common_checks(
|
||||
request_body=request_body,
|
||||
team_object=None,
|
||||
user_object=None,
|
||||
end_user_object=None,
|
||||
global_proxy_spend=None,
|
||||
general_settings=general_settings,
|
||||
route="/mcp",
|
||||
llm_router=None,
|
||||
proxy_logging_obj=MagicMock(),
|
||||
valid_token=get_mock_user_token(),
|
||||
request=request,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestEnforceUserParamDisabled:
|
||||
"""Test behavior when enforce_user_param is disabled"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_without_user_param_when_disabled_should_pass(self):
|
||||
"""POST without user param when enforce_user_param=False should pass"""
|
||||
request = MockRequest(method="POST")
|
||||
general_settings = {"enforce_user_param": False}
|
||||
request_body = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}
|
||||
|
||||
with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True):
|
||||
result = await common_checks(
|
||||
request_body=request_body,
|
||||
team_object=None,
|
||||
user_object=None,
|
||||
end_user_object=None,
|
||||
global_proxy_spend=None,
|
||||
general_settings=general_settings,
|
||||
route="/v1/chat/completions",
|
||||
llm_router=None,
|
||||
proxy_logging_obj=MagicMock(),
|
||||
valid_token=get_mock_user_token(),
|
||||
request=request,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_without_user_param_when_not_set_should_pass(self):
|
||||
"""POST without user param when enforce_user_param not set should pass"""
|
||||
request = MockRequest(method="POST")
|
||||
general_settings = {} # enforce_user_param not set
|
||||
request_body = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}
|
||||
|
||||
with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True):
|
||||
result = await common_checks(
|
||||
request_body=request_body,
|
||||
team_object=None,
|
||||
user_object=None,
|
||||
end_user_object=None,
|
||||
global_proxy_spend=None,
|
||||
general_settings=general_settings,
|
||||
route="/v1/chat/completions",
|
||||
llm_router=None,
|
||||
proxy_logging_obj=MagicMock(),
|
||||
valid_token=get_mock_user_token(),
|
||||
request=request,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestEnforceUserParamEdgeCases:
|
||||
"""Test edge cases for enforce_user_param"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_without_method_attribute_should_pass(self):
|
||||
"""Request without method attribute should not raise error"""
|
||||
request = MagicMock()
|
||||
del request.method # Remove method attribute
|
||||
request.__hasattr__ = MagicMock(return_value=False)
|
||||
|
||||
general_settings = {"enforce_user_param": True}
|
||||
request_body = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}
|
||||
|
||||
with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True):
|
||||
# Should not raise error even without method
|
||||
result = await common_checks(
|
||||
request_body=request_body,
|
||||
team_object=None,
|
||||
user_object=None,
|
||||
end_user_object=None,
|
||||
global_proxy_spend=None,
|
||||
general_settings=general_settings,
|
||||
route="/v1/chat/completions",
|
||||
llm_router=None,
|
||||
proxy_logging_obj=MagicMock(),
|
||||
valid_token=get_mock_user_token(),
|
||||
request=request,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_case_insensitive_http_method(self):
|
||||
"""HTTP method comparison should be case-insensitive"""
|
||||
request = MockRequest(method="post") # lowercase
|
||||
general_settings = {"enforce_user_param": True}
|
||||
request_body = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}
|
||||
|
||||
with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await common_checks(
|
||||
request_body=request_body,
|
||||
team_object=None,
|
||||
user_object=None,
|
||||
end_user_object=None,
|
||||
global_proxy_spend=None,
|
||||
general_settings=general_settings,
|
||||
route="/v1/chat/completions",
|
||||
llm_router=None,
|
||||
proxy_logging_obj=MagicMock(),
|
||||
valid_token=get_mock_user_token(),
|
||||
request=request,
|
||||
)
|
||||
|
||||
assert "user" in str(exc_info.value).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_put_method_should_not_enforce_user_param(self):
|
||||
"""PUT method should not enforce user param (only POST)"""
|
||||
request = MockRequest(method="PUT")
|
||||
general_settings = {"enforce_user_param": True}
|
||||
request_body = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}
|
||||
|
||||
with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True):
|
||||
# Should not raise for PUT method
|
||||
result = await common_checks(
|
||||
request_body=request_body,
|
||||
team_object=None,
|
||||
user_object=None,
|
||||
end_user_object=None,
|
||||
global_proxy_spend=None,
|
||||
general_settings=general_settings,
|
||||
route="/v1/chat/completions",
|
||||
llm_router=None,
|
||||
proxy_logging_obj=MagicMock(),
|
||||
valid_token=get_mock_user_token(),
|
||||
request=request,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_method_should_not_enforce_user_param(self):
|
||||
"""PATCH method should not enforce user param (only POST)"""
|
||||
request = MockRequest(method="PATCH")
|
||||
general_settings = {"enforce_user_param": True}
|
||||
request_body = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}
|
||||
|
||||
with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True):
|
||||
# Should not raise for PATCH method
|
||||
result = await common_checks(
|
||||
request_body=request_body,
|
||||
team_object=None,
|
||||
user_object=None,
|
||||
end_user_object=None,
|
||||
global_proxy_spend=None,
|
||||
general_settings=general_settings,
|
||||
route="/v1/chat/completions",
|
||||
llm_router=None,
|
||||
proxy_logging_obj=MagicMock(),
|
||||
valid_token=get_mock_user_token(),
|
||||
request=request,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run tests with: pytest tests/test_litellm/proxy/test_enforce_user_param.py -v
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user