mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 22:48:35 +00:00
428 lines
16 KiB
Python
428 lines
16 KiB
Python
import pytest
|
|
import sys
|
|
import os
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
# Add project root to sys.path
|
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
|
|
|
|
from litellm.proxy.auth.user_api_key_auth import (
|
|
_resolve_jwt_to_virtual_key,
|
|
)
|
|
from litellm.proxy.auth.handle_jwt import JWTHandler
|
|
from litellm.proxy._types import (
|
|
JWTKeyMappingResponse,
|
|
LiteLLM_JWTAuth,
|
|
LitellmUserRoles,
|
|
UserAPIKeyAuth,
|
|
)
|
|
from litellm.proxy.management_endpoints.jwt_key_mapping_endpoints import (
|
|
_to_response,
|
|
create_jwt_key_mapping,
|
|
delete_jwt_key_mapping,
|
|
info_jwt_key_mapping,
|
|
update_jwt_key_mapping,
|
|
)
|
|
from litellm.caching.caching import DualCache
|
|
from fastapi import HTTPException
|
|
|
|
|
|
# ──────────────────────────────────────────────
|
|
# Tests: _resolve_jwt_to_virtual_key
|
|
# ──────────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_jwt_to_virtual_key_mapping_resolution():
|
|
"""
|
|
Test that a JWT claim is correctly resolved to a virtual key token.
|
|
"""
|
|
jwt_handler = JWTHandler()
|
|
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
|
virtual_key_claim_field="email", virtual_key_mapping_cache_ttl=3600
|
|
)
|
|
|
|
jwt_claims = {"email": "user@example.com", "sub": "123"}
|
|
|
|
prisma_client = MagicMock()
|
|
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock()
|
|
|
|
# Mock finding a mapping
|
|
mock_mapping = MagicMock()
|
|
mock_mapping.token = "sk-1234"
|
|
mock_mapping.is_active = True
|
|
prisma_client.db.litellm_jwtkeymapping.find_first.return_value = mock_mapping
|
|
|
|
# Mock getting the key object
|
|
mock_key_obj = UserAPIKeyAuth(token="sk-1234", team_id="team1")
|
|
|
|
user_api_key_cache = DualCache()
|
|
|
|
# Use patch to mock get_key_object in the module where it's used
|
|
with patch(
|
|
"litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock
|
|
) as mock_get_key:
|
|
mock_get_key.return_value = mock_key_obj
|
|
|
|
result = await _resolve_jwt_to_virtual_key(
|
|
jwt_claims=jwt_claims,
|
|
jwt_handler=jwt_handler,
|
|
prisma_client=prisma_client,
|
|
user_api_key_cache=user_api_key_cache,
|
|
parent_otel_span=None,
|
|
proxy_logging_obj=None,
|
|
)
|
|
|
|
assert result == mock_key_obj
|
|
prisma_client.db.litellm_jwtkeymapping.find_first.assert_called_once()
|
|
|
|
# Test Cache hit
|
|
prisma_client.db.litellm_jwtkeymapping.find_first.reset_mock()
|
|
result_cached = await _resolve_jwt_to_virtual_key(
|
|
jwt_claims=jwt_claims,
|
|
jwt_handler=jwt_handler,
|
|
prisma_client=prisma_client,
|
|
user_api_key_cache=user_api_key_cache,
|
|
parent_otel_span=None,
|
|
proxy_logging_obj=None,
|
|
)
|
|
assert result_cached == mock_key_obj
|
|
prisma_client.db.litellm_jwtkeymapping.find_first.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_jwt_to_virtual_key_mapping_no_mapping():
|
|
"""
|
|
Test that when no mapping exists, resolve returns None.
|
|
"""
|
|
jwt_handler = JWTHandler()
|
|
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(virtual_key_claim_field="email")
|
|
jwt_claims = {"email": "unknown@example.com"}
|
|
|
|
prisma_client = MagicMock()
|
|
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock()
|
|
prisma_client.db.litellm_jwtkeymapping.find_first.return_value = None
|
|
|
|
# Mock get_key_object just in case
|
|
with patch(
|
|
"litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock
|
|
):
|
|
user_api_key_cache = DualCache()
|
|
|
|
result = await _resolve_jwt_to_virtual_key(
|
|
jwt_claims=jwt_claims,
|
|
jwt_handler=jwt_handler,
|
|
prisma_client=prisma_client,
|
|
user_api_key_cache=user_api_key_cache,
|
|
parent_otel_span=None,
|
|
proxy_logging_obj=None,
|
|
)
|
|
|
|
assert result is None
|
|
|
|
# Test Negative Cache hit
|
|
prisma_client.db.litellm_jwtkeymapping.find_first.reset_mock()
|
|
result_cached = await _resolve_jwt_to_virtual_key(
|
|
jwt_claims=jwt_claims,
|
|
jwt_handler=jwt_handler,
|
|
prisma_client=prisma_client,
|
|
user_api_key_cache=user_api_key_cache,
|
|
parent_otel_span=None,
|
|
proxy_logging_obj=None,
|
|
)
|
|
assert result_cached is None
|
|
prisma_client.db.litellm_jwtkeymapping.find_first.assert_not_called()
|
|
|
|
|
|
# ──────────────────────────────────────────────
|
|
# Tests: OIDC / JWT routing in user_api_key_auth
|
|
# ──────────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_virtual_key_mapping_oidc_enabled_jwt_token_uses_auth_jwt():
|
|
"""
|
|
Regression test for the is_jwt routing fix in user_api_key_auth.py.
|
|
|
|
When oidc_userinfo_enabled=True and virtual_key_claim_field is set, but
|
|
the token is a well-formed JWT (3-part header.payload.sig), the virtual-key
|
|
claim lookup must call auth_jwt — not get_oidc_userinfo.
|
|
"""
|
|
# Three-part token: is_jwt() returns True
|
|
api_key = "eyJhbGciOiJSUzI1NiJ9.eyJlbWFpbCI6InVzZXJAZXhhbXBsZS5jb20ifQ.sig"
|
|
|
|
jwt_handler = JWTHandler()
|
|
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
|
oidc_userinfo_enabled=True,
|
|
virtual_key_claim_field="email",
|
|
)
|
|
|
|
# Confirm our fixture token is treated as a JWT
|
|
assert jwt_handler.is_jwt(token=api_key) is True
|
|
|
|
auth_jwt_mock = AsyncMock(return_value={"email": "user@example.com", "sub": "123"})
|
|
oidc_userinfo_mock = AsyncMock(return_value={"email": "user@example.com"})
|
|
|
|
# Simulate the routing condition from user_api_key_auth.py
|
|
if jwt_handler.litellm_jwtauth.oidc_userinfo_enabled and not jwt_handler.is_jwt(
|
|
token=api_key
|
|
):
|
|
jwt_claims = await oidc_userinfo_mock(token=api_key)
|
|
else:
|
|
jwt_claims = await auth_jwt_mock(token=api_key)
|
|
|
|
auth_jwt_mock.assert_called_once_with(token=api_key)
|
|
oidc_userinfo_mock.assert_not_called()
|
|
assert jwt_claims["email"] == "user@example.com"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_virtual_key_mapping_oidc_enabled_opaque_token_uses_oidc_userinfo():
|
|
"""
|
|
Complement of the test above: when oidc_userinfo_enabled=True and the token
|
|
is an opaque access token (not a JWT), the virtual-key claim lookup must
|
|
call get_oidc_userinfo — not auth_jwt.
|
|
"""
|
|
# Opaque token: no dots → is_jwt() returns False
|
|
api_key = "some_opaque_access_token_with_no_dots"
|
|
|
|
jwt_handler = JWTHandler()
|
|
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
|
oidc_userinfo_enabled=True,
|
|
virtual_key_claim_field="email",
|
|
)
|
|
|
|
assert jwt_handler.is_jwt(token=api_key) is False
|
|
|
|
auth_jwt_mock = AsyncMock(return_value={"email": "user@example.com"})
|
|
oidc_userinfo_mock = AsyncMock(return_value={"email": "user@example.com", "sub": "123"})
|
|
|
|
if jwt_handler.litellm_jwtauth.oidc_userinfo_enabled and not jwt_handler.is_jwt(
|
|
token=api_key
|
|
):
|
|
jwt_claims = await oidc_userinfo_mock(token=api_key)
|
|
else:
|
|
jwt_claims = await auth_jwt_mock(token=api_key)
|
|
|
|
oidc_userinfo_mock.assert_called_once_with(token=api_key)
|
|
auth_jwt_mock.assert_not_called()
|
|
assert jwt_claims["sub"] == "123"
|
|
|
|
|
|
# ──────────────────────────────────────────────
|
|
# Tests: _to_response redacts hashed token
|
|
# ──────────────────────────────────────────────
|
|
|
|
|
|
def test_to_response_excludes_token():
|
|
"""_to_response should not expose the hashed token field."""
|
|
now = datetime.now(timezone.utc)
|
|
mock_mapping = MagicMock()
|
|
mock_mapping.id = "mapping-1"
|
|
mock_mapping.jwt_claim_name = "email"
|
|
mock_mapping.jwt_claim_value = "user@example.com"
|
|
mock_mapping.token = "hashed_secret_value"
|
|
mock_mapping.description = "test"
|
|
mock_mapping.is_active = True
|
|
mock_mapping.created_at = now
|
|
mock_mapping.updated_at = now
|
|
mock_mapping.created_by = "admin"
|
|
mock_mapping.updated_by = "admin"
|
|
|
|
resp = _to_response(mock_mapping)
|
|
|
|
assert isinstance(resp, JWTKeyMappingResponse)
|
|
assert resp.id == "mapping-1"
|
|
assert resp.jwt_claim_name == "email"
|
|
assert "token" not in resp.model_fields
|
|
|
|
|
|
# ──────────────────────────────────────────────
|
|
# Helpers
|
|
# ──────────────────────────────────────────────
|
|
|
|
|
|
def _make_admin_auth() -> UserAPIKeyAuth:
|
|
return UserAPIKeyAuth(
|
|
token="sk-admin",
|
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
|
)
|
|
|
|
|
|
def _make_non_admin_auth() -> UserAPIKeyAuth:
|
|
return UserAPIKeyAuth(
|
|
token="sk-user",
|
|
user_role=LitellmUserRoles.INTERNAL_USER,
|
|
)
|
|
|
|
|
|
def _mock_prisma():
|
|
prisma = MagicMock()
|
|
prisma.db.litellm_jwtkeymapping.create = AsyncMock()
|
|
prisma.db.litellm_jwtkeymapping.find_unique = AsyncMock()
|
|
prisma.db.litellm_jwtkeymapping.find_many = AsyncMock()
|
|
prisma.db.litellm_jwtkeymapping.update = AsyncMock()
|
|
prisma.db.litellm_jwtkeymapping.delete = AsyncMock()
|
|
prisma.db.litellm_jwtkeymapping.count = AsyncMock(return_value=0)
|
|
return prisma
|
|
|
|
|
|
def _mock_mapping(
|
|
id="mapping-1",
|
|
claim_name="email",
|
|
claim_value="user@example.com",
|
|
):
|
|
now = datetime.now(timezone.utc)
|
|
m = MagicMock()
|
|
m.id = id
|
|
m.jwt_claim_name = claim_name
|
|
m.jwt_claim_value = claim_value
|
|
m.token = "hashed_token"
|
|
m.description = None
|
|
m.is_active = True
|
|
m.created_at = now
|
|
m.updated_at = now
|
|
m.created_by = "admin"
|
|
m.updated_by = "admin"
|
|
return m
|
|
|
|
|
|
# ──────────────────────────────────────────────
|
|
# Tests: CRUD endpoint error handling
|
|
# ──────────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_create_returns_409_on_unique_violation():
|
|
"""Duplicate mapping should return 409, not 500."""
|
|
from litellm.proxy._types import CreateJWTKeyMappingRequest
|
|
|
|
mock_prisma = _mock_prisma()
|
|
mock_prisma.db.litellm_jwtkeymapping.create.side_effect = Exception(
|
|
"Unique constraint failed (P2002)"
|
|
)
|
|
mock_cache = AsyncMock()
|
|
|
|
data = CreateJWTKeyMappingRequest(
|
|
jwt_claim_name="email", jwt_claim_value="user@example.com", key="sk-test-key",
|
|
)
|
|
|
|
with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma), patch(
|
|
"litellm.proxy.proxy_server.user_api_key_cache", mock_cache
|
|
):
|
|
with pytest.raises(HTTPException) as exc_info:
|
|
await create_jwt_key_mapping(data=data, user_api_key_dict=_make_admin_auth())
|
|
assert exc_info.value.status_code == 409
|
|
assert "already exists" in exc_info.value.detail
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_create_returns_400_on_foreign_key_violation():
|
|
"""Non-existent key should return 400, not 500."""
|
|
from litellm.proxy._types import CreateJWTKeyMappingRequest
|
|
|
|
mock_prisma = _mock_prisma()
|
|
mock_prisma.db.litellm_jwtkeymapping.create.side_effect = Exception(
|
|
"Foreign key constraint failed on field: `token` (P2003)"
|
|
)
|
|
mock_cache = AsyncMock()
|
|
|
|
data = CreateJWTKeyMappingRequest(
|
|
jwt_claim_name="sub", jwt_claim_value="user-999", key="sk-nonexistent",
|
|
)
|
|
|
|
with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma), patch(
|
|
"litellm.proxy.proxy_server.user_api_key_cache", mock_cache
|
|
):
|
|
with pytest.raises(HTTPException) as exc_info:
|
|
await create_jwt_key_mapping(data=data, user_api_key_dict=_make_admin_auth())
|
|
assert exc_info.value.status_code == 400
|
|
assert "does not match" in exc_info.value.detail
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_create_non_admin_returns_403():
|
|
"""Non-admin users should get 403."""
|
|
from litellm.proxy._types import CreateJWTKeyMappingRequest
|
|
|
|
data = CreateJWTKeyMappingRequest(
|
|
jwt_claim_name="email", jwt_claim_value="user@example.com", key="sk-test",
|
|
)
|
|
|
|
with pytest.raises(HTTPException) as exc_info:
|
|
await create_jwt_key_mapping(data=data, user_api_key_dict=_make_non_admin_auth())
|
|
assert exc_info.value.status_code == 403
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_delete_returns_404_when_not_found():
|
|
"""Deleting non-existent mapping should return 404."""
|
|
from litellm.proxy._types import DeleteJWTKeyMappingRequest
|
|
|
|
mock_prisma = _mock_prisma()
|
|
mock_prisma.db.litellm_jwtkeymapping.find_unique.return_value = None
|
|
mock_cache = AsyncMock()
|
|
|
|
data = DeleteJWTKeyMappingRequest(id="nonexistent-id")
|
|
|
|
with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma), patch(
|
|
"litellm.proxy.proxy_server.user_api_key_cache", mock_cache
|
|
):
|
|
with pytest.raises(HTTPException) as exc_info:
|
|
await delete_jwt_key_mapping(data=data, user_api_key_dict=_make_admin_auth())
|
|
assert exc_info.value.status_code == 404
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_update_returns_404_when_not_found():
|
|
"""Updating non-existent mapping should return 404."""
|
|
from litellm.proxy._types import UpdateJWTKeyMappingRequest
|
|
|
|
mock_prisma = _mock_prisma()
|
|
mock_prisma.db.litellm_jwtkeymapping.find_unique.return_value = None
|
|
mock_cache = AsyncMock()
|
|
|
|
data = UpdateJWTKeyMappingRequest(id="nonexistent-id", description="test")
|
|
|
|
with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma), patch(
|
|
"litellm.proxy.proxy_server.user_api_key_cache", mock_cache
|
|
):
|
|
with pytest.raises(HTTPException) as exc_info:
|
|
await update_jwt_key_mapping(data=data, user_api_key_dict=_make_admin_auth())
|
|
assert exc_info.value.status_code == 404
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_info_returns_404_when_not_found():
|
|
"""Getting info for non-existent mapping should return 404."""
|
|
mock_prisma = _mock_prisma()
|
|
mock_prisma.db.litellm_jwtkeymapping.find_unique.return_value = None
|
|
|
|
with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma):
|
|
with pytest.raises(HTTPException) as exc_info:
|
|
await info_jwt_key_mapping(id="nonexistent-id", user_api_key_dict=_make_admin_auth())
|
|
assert exc_info.value.status_code == 404
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_create_success_returns_response_without_token():
|
|
"""Successful create should return JWTKeyMappingResponse without hashed token."""
|
|
from litellm.proxy._types import CreateJWTKeyMappingRequest
|
|
|
|
mock_prisma = _mock_prisma()
|
|
mock_prisma.db.litellm_jwtkeymapping.create.return_value = _mock_mapping()
|
|
mock_cache = AsyncMock()
|
|
|
|
data = CreateJWTKeyMappingRequest(
|
|
jwt_claim_name="email", jwt_claim_value="user@example.com", key="sk-test-key",
|
|
)
|
|
|
|
with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma), patch(
|
|
"litellm.proxy.proxy_server.user_api_key_cache", mock_cache
|
|
):
|
|
result = await create_jwt_key_mapping(data=data, user_api_key_dict=_make_admin_auth())
|
|
assert isinstance(result, JWTKeyMappingResponse)
|
|
assert "token" not in result.model_fields
|
|
assert result.jwt_claim_name == "email"
|