diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 45b0752d4c..fc79a4d359 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -190,7 +190,18 @@ async def common_checks( general_settings.get("enforce_user_param", None) is not None and general_settings["enforce_user_param"] is True ): - if RouteChecks.is_llm_api_route(route=route) and "user" not in request_body: + # Get HTTP method from request + http_method = request.method if hasattr(request, 'method') else None + + # Check if it's a POST request and if it's an OpenAI route but not MCP + is_post_method = http_method and http_method.upper() == "POST" + is_openai_route = RouteChecks.is_llm_api_route(route=route) + is_mcp_route = route in LiteLLMRoutes.mcp_routes.value or RouteChecks.check_route_access( + route=route, allowed_routes=LiteLLMRoutes.mcp_routes.value + ) + + # Enforce user param only for POST requests on OpenAI routes (excluding MCP routes) + if is_post_method and is_openai_route and not is_mcp_route and "user" not in request_body: raise Exception( f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}" ) diff --git a/tests/test_litellm/proxy/test_enforce_user_param.py b/tests/test_litellm/proxy/test_enforce_user_param.py new file mode 100644 index 0000000000..5d9369d61b --- /dev/null +++ b/tests/test_litellm/proxy/test_enforce_user_param.py @@ -0,0 +1,438 @@ +""" +Tests for enforce_user_param feature with POST/GET method filtering and MCP route exclusion. + +Tests verify that: +1. enforce_user_param only applies to POST requests +2. GET requests like /v1/models are not affected +3. MCP routes are excluded from enforcement +4. POST requests to completion endpoints still require user param when enforced +""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch + +import pytest +from fastapi import Request + +from litellm.proxy._types import LiteLLMRoutes, UserAPIKeyAuth +from litellm.proxy.auth.auth_checks import common_checks +from litellm.proxy.auth.route_checks import RouteChecks + + +class MockRequest: + """Mock FastAPI Request object""" + def __init__(self, method: str = "POST"): + self.method = method + + +def get_mock_user_token(): + """Create a mock UserAPIKeyAuth token for testing""" + return UserAPIKeyAuth( + api_key="test-key", + user_id="test-user", + team_id="test-team", + org_id="test-org", + models=["*"], + metadata={} + ) + + +class TestEnforceUserParamPostGetFiltering: + """Test POST/GET method filtering for enforce_user_param""" + + @pytest.mark.asyncio + async def test_post_completion_without_user_param_should_fail(self): + """POST to /v1/chat/completions without user param should raise error""" + request = MockRequest(method="POST") + general_settings = {"enforce_user_param": True} + request_body = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}] + } + + with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True): + with pytest.raises(Exception) as exc_info: + await common_checks( + request_body=request_body, + team_object=None, + user_object=None, + end_user_object=None, + global_proxy_spend=None, + general_settings=general_settings, + route="/v1/chat/completions", + llm_router=None, + proxy_logging_obj=MagicMock(), + valid_token=get_mock_user_token(), + request=request, + ) + + assert "user" in str(exc_info.value).lower() + + @pytest.mark.asyncio + async def test_post_completion_with_user_param_should_pass(self): + """POST to /v1/chat/completions with user param should pass""" + request = MockRequest(method="POST") + general_settings = {"enforce_user_param": True} + request_body = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}], + "user": "user123" + } + + with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True): + # Should not raise exception + result = await common_checks( + request_body=request_body, + team_object=None, + user_object=None, + end_user_object=None, + global_proxy_spend=None, + general_settings=general_settings, + route="/v1/chat/completions", + llm_router=None, + proxy_logging_obj=MagicMock(), + valid_token=get_mock_user_token(), + request=request, + ) + + assert result is True + + @pytest.mark.asyncio + async def test_get_models_without_user_param_should_pass(self): + """GET to /v1/models without user param should NOT raise error""" + request = MockRequest(method="GET") + general_settings = {"enforce_user_param": True} + request_body = {} # GET requests typically don't have body + + with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True): + # Should not raise exception + result = await common_checks( + request_body=request_body, + team_object=None, + user_object=None, + end_user_object=None, + global_proxy_spend=None, + general_settings=general_settings, + route="/v1/models", + llm_router=None, + proxy_logging_obj=MagicMock(), + valid_token=get_mock_user_token(), + request=request, + ) + + assert result is True + + @pytest.mark.asyncio + async def test_get_files_without_user_param_should_pass(self): + """GET to /v1/files without user param should NOT raise error""" + request = MockRequest(method="GET") + general_settings = {"enforce_user_param": True} + request_body = {} + + with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True): + result = await common_checks( + request_body=request_body, + team_object=None, + user_object=None, + end_user_object=None, + global_proxy_spend=None, + general_settings=general_settings, + route="/v1/files", + llm_router=None, + proxy_logging_obj=MagicMock(), + valid_token=get_mock_user_token(), + request=request, + ) + + assert result is True + + @pytest.mark.asyncio + async def test_post_embeddings_without_user_param_should_fail(self): + """POST to /v1/embeddings without user param should raise error""" + request = MockRequest(method="POST") + general_settings = {"enforce_user_param": True} + request_body = { + "model": "text-embedding-ada-002", + "input": "test" + } + + with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True): + with pytest.raises(Exception) as exc_info: + await common_checks( + request_body=request_body, + team_object=None, + user_object=None, + end_user_object=None, + global_proxy_spend=None, + general_settings=general_settings, + route="/v1/embeddings", + llm_router=None, + proxy_logging_obj=MagicMock(), + valid_token=get_mock_user_token(), + request=request, + ) + + assert "user" in str(exc_info.value).lower() + + @pytest.mark.asyncio + async def test_post_embeddings_with_user_param_should_pass(self): + """POST to /v1/embeddings with user param should pass""" + request = MockRequest(method="POST") + general_settings = {"enforce_user_param": True} + request_body = { + "model": "text-embedding-ada-002", + "input": "test", + "user": "user123" + } + + with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True): + result = await common_checks( + request_body=request_body, + team_object=None, + user_object=None, + end_user_object=None, + global_proxy_spend=None, + general_settings=general_settings, + route="/v1/embeddings", + llm_router=None, + proxy_logging_obj=MagicMock(), + valid_token=get_mock_user_token(), + request=request, + ) + + assert result is True + + +class TestEnforceUserParamMCPExclusion: + """Test MCP route exclusion from enforce_user_param""" + + @pytest.mark.asyncio + async def test_mcp_route_without_user_param_should_pass(self): + """POST to MCP route without user param should NOT raise error""" + request = MockRequest(method="POST") + general_settings = {"enforce_user_param": True} + request_body = {"action": "list_tools"} + + with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True): + # Should not raise exception for MCP routes + result = await common_checks( + request_body=request_body, + team_object=None, + user_object=None, + end_user_object=None, + global_proxy_spend=None, + general_settings=general_settings, + route="/mcp/tools/list", + llm_router=None, + proxy_logging_obj=MagicMock(), + valid_token=get_mock_user_token(), + request=request, + ) + + assert result is True + + @pytest.mark.asyncio + async def test_mcp_root_route_without_user_param_should_pass(self): + """POST to /mcp without user param should NOT raise error""" + request = MockRequest(method="POST") + general_settings = {"enforce_user_param": True} + request_body = {"data": "test"} + + with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True): + result = await common_checks( + request_body=request_body, + team_object=None, + user_object=None, + end_user_object=None, + global_proxy_spend=None, + general_settings=general_settings, + route="/mcp", + llm_router=None, + proxy_logging_obj=MagicMock(), + valid_token=get_mock_user_token(), + request=request, + ) + + assert result is True + + +class TestEnforceUserParamDisabled: + """Test behavior when enforce_user_param is disabled""" + + @pytest.mark.asyncio + async def test_post_without_user_param_when_disabled_should_pass(self): + """POST without user param when enforce_user_param=False should pass""" + request = MockRequest(method="POST") + general_settings = {"enforce_user_param": False} + request_body = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}] + } + + with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True): + result = await common_checks( + request_body=request_body, + team_object=None, + user_object=None, + end_user_object=None, + global_proxy_spend=None, + general_settings=general_settings, + route="/v1/chat/completions", + llm_router=None, + proxy_logging_obj=MagicMock(), + valid_token=get_mock_user_token(), + request=request, + ) + + assert result is True + + @pytest.mark.asyncio + async def test_post_without_user_param_when_not_set_should_pass(self): + """POST without user param when enforce_user_param not set should pass""" + request = MockRequest(method="POST") + general_settings = {} # enforce_user_param not set + request_body = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}] + } + + with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True): + result = await common_checks( + request_body=request_body, + team_object=None, + user_object=None, + end_user_object=None, + global_proxy_spend=None, + general_settings=general_settings, + route="/v1/chat/completions", + llm_router=None, + proxy_logging_obj=MagicMock(), + valid_token=get_mock_user_token(), + request=request, + ) + + assert result is True + + +class TestEnforceUserParamEdgeCases: + """Test edge cases for enforce_user_param""" + + @pytest.mark.asyncio + async def test_request_without_method_attribute_should_pass(self): + """Request without method attribute should not raise error""" + request = MagicMock() + del request.method # Remove method attribute + request.__hasattr__ = MagicMock(return_value=False) + + general_settings = {"enforce_user_param": True} + request_body = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}] + } + + with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True): + # Should not raise error even without method + result = await common_checks( + request_body=request_body, + team_object=None, + user_object=None, + end_user_object=None, + global_proxy_spend=None, + general_settings=general_settings, + route="/v1/chat/completions", + llm_router=None, + proxy_logging_obj=MagicMock(), + valid_token=get_mock_user_token(), + request=request, + ) + + assert result is True + + @pytest.mark.asyncio + async def test_case_insensitive_http_method(self): + """HTTP method comparison should be case-insensitive""" + request = MockRequest(method="post") # lowercase + general_settings = {"enforce_user_param": True} + request_body = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}] + } + + with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True): + with pytest.raises(Exception) as exc_info: + await common_checks( + request_body=request_body, + team_object=None, + user_object=None, + end_user_object=None, + global_proxy_spend=None, + general_settings=general_settings, + route="/v1/chat/completions", + llm_router=None, + proxy_logging_obj=MagicMock(), + valid_token=get_mock_user_token(), + request=request, + ) + + assert "user" in str(exc_info.value).lower() + + @pytest.mark.asyncio + async def test_put_method_should_not_enforce_user_param(self): + """PUT method should not enforce user param (only POST)""" + request = MockRequest(method="PUT") + general_settings = {"enforce_user_param": True} + request_body = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}] + } + + with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True): + # Should not raise for PUT method + result = await common_checks( + request_body=request_body, + team_object=None, + user_object=None, + end_user_object=None, + global_proxy_spend=None, + general_settings=general_settings, + route="/v1/chat/completions", + llm_router=None, + proxy_logging_obj=MagicMock(), + valid_token=get_mock_user_token(), + request=request, + ) + + assert result is True + + @pytest.mark.asyncio + async def test_patch_method_should_not_enforce_user_param(self): + """PATCH method should not enforce user param (only POST)""" + request = MockRequest(method="PATCH") + general_settings = {"enforce_user_param": True} + request_body = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}] + } + + with patch('litellm.proxy.auth.auth_checks._is_api_route_allowed', new_callable=AsyncMock, return_value=True): + # Should not raise for PATCH method + result = await common_checks( + request_body=request_body, + team_object=None, + user_object=None, + end_user_object=None, + global_proxy_spend=None, + general_settings=general_settings, + route="/v1/chat/completions", + llm_router=None, + proxy_logging_obj=MagicMock(), + valid_token=get_mock_user_token(), + request=request, + ) + + assert result is True + + +if __name__ == "__main__": + # Run tests with: pytest tests/test_litellm/proxy/test_enforce_user_param.py -v + pytest.main([__file__, "-v"])