Files
litellm/tests/proxy_unit_tests/test_jwt_key_mapping.py
T
2026-04-17 13:02:59 -07:00

457 lines
16 KiB
Python

import pytest
import sys
import os
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
# Add project root to sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from litellm.proxy.auth.user_api_key_auth import (
_resolve_jwt_to_virtual_key,
)
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy._types import (
JWTKeyMappingResponse,
LiteLLM_JWTAuth,
LitellmUserRoles,
UserAPIKeyAuth,
)
from litellm.proxy.management_endpoints.jwt_key_mapping_endpoints import (
_to_response,
create_jwt_key_mapping,
delete_jwt_key_mapping,
info_jwt_key_mapping,
update_jwt_key_mapping,
)
from litellm.caching.caching import DualCache
from fastapi import HTTPException
# ──────────────────────────────────────────────
# Tests: _resolve_jwt_to_virtual_key
# ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_jwt_to_virtual_key_mapping_resolution():
    """
    Verify that a JWT claim resolves to the mapped virtual key, and that a
    second lookup for the same claim is answered from ``user_api_key_cache``
    without querying the mapping table again.
    """
    handler = JWTHandler()
    handler.litellm_jwtauth = LiteLLM_JWTAuth(
        virtual_key_claim_field="email", virtual_key_mapping_cache_ttl=3600
    )
    claims = {"email": "user@example.com", "sub": "123"}

    # An active stored mapping pointing the claim value at "sk-1234".
    stored_mapping = MagicMock()
    stored_mapping.token = "sk-1234"
    stored_mapping.is_active = True

    prisma_client = MagicMock()
    find_first = AsyncMock(return_value=stored_mapping)
    prisma_client.db.litellm_jwtkeymapping.find_first = find_first

    expected_key = UserAPIKeyAuth(token="sk-1234", team_id="team1")
    key_cache = DualCache()

    async def resolve():
        # Same arguments for both lookups so the second one can hit the cache.
        return await _resolve_jwt_to_virtual_key(
            jwt_claims=claims,
            jwt_handler=handler,
            prisma_client=prisma_client,
            user_api_key_cache=key_cache,
            parent_otel_span=None,
            proxy_logging_obj=None,
        )

    # get_key_object is patched in the module that calls it; the extra kwarg
    # is forwarded to the AsyncMock constructor as its return_value.
    with patch(
        "litellm.proxy.auth.user_api_key_auth.get_key_object",
        new_callable=AsyncMock,
        return_value=expected_key,
    ):
        first = await resolve()
        assert first == expected_key
        find_first.assert_called_once()

        # Second lookup: must be served from the cache, not the DB.
        find_first.reset_mock()
        second = await resolve()
        assert second == expected_key
        find_first.assert_not_called()
@pytest.mark.asyncio
async def test_jwt_to_virtual_key_mapping_no_mapping():
    """
    Verify that a claim with no stored mapping resolves to None, and that the
    miss is cached so a repeat lookup does not query the mapping table again.
    """
    handler = JWTHandler()
    handler.litellm_jwtauth = LiteLLM_JWTAuth(virtual_key_claim_field="email")
    claims = {"email": "unknown@example.com"}

    prisma_client = MagicMock()
    find_first = AsyncMock(return_value=None)
    prisma_client.db.litellm_jwtkeymapping.find_first = find_first

    # get_key_object is patched as a precaution; with no mapping found it is
    # not expected to be needed.
    with patch(
        "litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock
    ):
        key_cache = DualCache()

        async def resolve():
            return await _resolve_jwt_to_virtual_key(
                jwt_claims=claims,
                jwt_handler=handler,
                prisma_client=prisma_client,
                user_api_key_cache=key_cache,
                parent_otel_span=None,
                proxy_logging_obj=None,
            )

        assert await resolve() is None

        # Repeat lookup: the negative result must come from the cache.
        find_first.reset_mock()
        assert await resolve() is None
        find_first.assert_not_called()
# ──────────────────────────────────────────────
# Tests: OIDC / JWT routing in user_api_key_auth
# ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_virtual_key_mapping_oidc_enabled_jwt_token_uses_auth_jwt():
    """
    Regression test for the is_jwt routing fix in user_api_key_auth.py.

    With oidc_userinfo_enabled=True and virtual_key_claim_field set, a
    well-formed three-part JWT (header.payload.sig) must have its claims read
    via auth_jwt, never via get_oidc_userinfo.

    NOTE(review): the routing condition below is a local copy of the one in
    user_api_key_auth.py, so this test will not fail if the production code
    diverges from it. TODO: drive user_api_key_auth directly instead.
    """
    token = "eyJhbGciOiJSUzI1NiJ9.eyJlbWFpbCI6InVzZXJAZXhhbXBsZS5jb20ifQ.sig"
    handler = JWTHandler()
    handler.litellm_jwtauth = LiteLLM_JWTAuth(
        oidc_userinfo_enabled=True,
        virtual_key_claim_field="email",
    )
    # Precondition: the fixture token must be recognised as a JWT.
    assert handler.is_jwt(token=token) is True

    auth_jwt_mock = AsyncMock(return_value={"email": "user@example.com", "sub": "123"})
    oidc_userinfo_mock = AsyncMock(return_value={"email": "user@example.com"})

    # Mirror of the routing rule: only non-JWT tokens go to the userinfo lookup.
    use_userinfo = handler.litellm_jwtauth.oidc_userinfo_enabled and not handler.is_jwt(
        token=token
    )
    claim_source = oidc_userinfo_mock if use_userinfo else auth_jwt_mock
    jwt_claims = await claim_source(token=token)

    auth_jwt_mock.assert_called_once_with(token=token)
    oidc_userinfo_mock.assert_not_called()
    assert jwt_claims["email"] == "user@example.com"
@pytest.mark.asyncio
async def test_virtual_key_mapping_oidc_enabled_opaque_token_uses_oidc_userinfo():
    """
    Counterpart of the JWT routing test: with oidc_userinfo_enabled=True and an
    opaque (non-JWT) access token, the virtual-key claim lookup must call
    get_oidc_userinfo, never auth_jwt.

    NOTE(review): like its counterpart, this test evaluates a local copy of
    the routing condition rather than user_api_key_auth itself. TODO: exercise
    the production code path.
    """
    # No dots at all, so is_jwt() reports False.
    token = "some_opaque_access_token_with_no_dots"
    handler = JWTHandler()
    handler.litellm_jwtauth = LiteLLM_JWTAuth(
        oidc_userinfo_enabled=True,
        virtual_key_claim_field="email",
    )
    assert handler.is_jwt(token=token) is False

    auth_jwt_mock = AsyncMock(return_value={"email": "user@example.com"})
    oidc_userinfo_mock = AsyncMock(
        return_value={"email": "user@example.com", "sub": "123"}
    )

    use_userinfo = handler.litellm_jwtauth.oidc_userinfo_enabled and not handler.is_jwt(
        token=token
    )
    claim_source = oidc_userinfo_mock if use_userinfo else auth_jwt_mock
    jwt_claims = await claim_source(token=token)

    oidc_userinfo_mock.assert_called_once_with(token=token)
    auth_jwt_mock.assert_not_called()
    assert jwt_claims["sub"] == "123"
# ──────────────────────────────────────────────
# Tests: _to_response redacts hashed token
# ──────────────────────────────────────────────
def test_to_response_excludes_token():
    """_to_response must not expose the hashed token.

    Checks both that JWTKeyMappingResponse declares no ``token`` field and
    that the hashed value does not appear anywhere in the serialized response.
    """
    now = datetime.now(timezone.utc)
    mock_mapping = MagicMock()
    mock_mapping.id = "mapping-1"
    mock_mapping.jwt_claim_name = "email"
    mock_mapping.jwt_claim_value = "user@example.com"
    mock_mapping.token = "hashed_secret_value"
    mock_mapping.description = "test"
    mock_mapping.is_active = True
    mock_mapping.created_at = now
    mock_mapping.updated_at = now
    mock_mapping.created_by = "admin"
    mock_mapping.updated_by = "admin"

    resp = _to_response(mock_mapping)

    assert isinstance(resp, JWTKeyMappingResponse)
    assert resp.id == "mapping-1"
    assert resp.jwt_claim_name == "email"
    # Inspect the schema on the class: accessing ``model_fields`` on an
    # instance is deprecated since Pydantic 2.11.
    assert "token" not in JWTKeyMappingResponse.model_fields
    # Also check the serialized payload, so a leak under some other field name
    # is caught, not only a field literally called "token".
    dumped = resp.model_dump()
    assert "token" not in dumped
    assert "hashed_secret_value" not in dumped.values()
# ──────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────
def _make_admin_auth() -> UserAPIKeyAuth:
    """Build an auth context for a caller holding the PROXY_ADMIN role."""
    admin_role = LitellmUserRoles.PROXY_ADMIN
    return UserAPIKeyAuth(user_role=admin_role, token="sk-admin")
def _make_non_admin_auth() -> UserAPIKeyAuth:
    """Build an auth context for an INTERNAL_USER caller (not a proxy admin)."""
    internal_role = LitellmUserRoles.INTERNAL_USER
    return UserAPIKeyAuth(user_role=internal_role, token="sk-user")
def _mock_prisma():
prisma = MagicMock()
prisma.db.litellm_jwtkeymapping.create = AsyncMock()
prisma.db.litellm_jwtkeymapping.find_unique = AsyncMock()
prisma.db.litellm_jwtkeymapping.find_many = AsyncMock()
prisma.db.litellm_jwtkeymapping.update = AsyncMock()
prisma.db.litellm_jwtkeymapping.delete = AsyncMock()
prisma.db.litellm_jwtkeymapping.count = AsyncMock(return_value=0)
return prisma
def _mock_mapping(
id="mapping-1",
claim_name="email",
claim_value="user@example.com",
):
now = datetime.now(timezone.utc)
m = MagicMock()
m.id = id
m.jwt_claim_name = claim_name
m.jwt_claim_value = claim_value
m.token = "hashed_token"
m.description = None
m.is_active = True
m.created_at = now
m.updated_at = now
m.created_by = "admin"
m.updated_by = "admin"
return m
# ──────────────────────────────────────────────
# Tests: CRUD endpoint error handling
# ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_create_returns_409_on_unique_violation():
    """A unique-constraint failure (P2002) on create must map to HTTP 409, not 500."""
    from litellm.proxy._types import CreateJWTKeyMappingRequest

    request = CreateJWTKeyMappingRequest(
        jwt_claim_name="email",
        jwt_claim_value="user@example.com",
        key="sk-test-key",
    )
    prisma = _mock_prisma()
    prisma.db.litellm_jwtkeymapping.create.side_effect = Exception(
        "Unique constraint failed (P2002)"
    )

    # pytest.raises is the innermost manager, so it captures the error before
    # the patches are undone.
    with (
        patch("litellm.proxy.proxy_server.prisma_client", prisma),
        patch("litellm.proxy.proxy_server.user_api_key_cache", AsyncMock()),
        pytest.raises(HTTPException) as exc_info,
    ):
        await create_jwt_key_mapping(
            data=request, user_api_key_dict=_make_admin_auth()
        )

    error = exc_info.value
    assert error.status_code == 409
    assert "already exists" in error.detail
@pytest.mark.asyncio
async def test_create_returns_400_on_foreign_key_violation():
    """A foreign-key failure (P2003, unknown key) on create must map to HTTP 400, not 500."""
    from litellm.proxy._types import CreateJWTKeyMappingRequest

    request = CreateJWTKeyMappingRequest(
        jwt_claim_name="sub",
        jwt_claim_value="user-999",
        key="sk-nonexistent",
    )
    prisma = _mock_prisma()
    prisma.db.litellm_jwtkeymapping.create.side_effect = Exception(
        "Foreign key constraint failed on field: `token` (P2003)"
    )

    with (
        patch("litellm.proxy.proxy_server.prisma_client", prisma),
        patch("litellm.proxy.proxy_server.user_api_key_cache", AsyncMock()),
        pytest.raises(HTTPException) as exc_info,
    ):
        await create_jwt_key_mapping(
            data=request, user_api_key_dict=_make_admin_auth()
        )

    error = exc_info.value
    assert error.status_code == 400
    assert "does not match" in error.detail
@pytest.mark.asyncio
async def test_create_non_admin_returns_403():
    """Non-admin users should get 403 and no mapping should be written."""
    from litellm.proxy._types import CreateJWTKeyMappingRequest

    mock_prisma = _mock_prisma()
    mock_cache = AsyncMock()
    data = CreateJWTKeyMappingRequest(
        jwt_claim_name="email",
        jwt_claim_value="user@example.com",
        key="sk-test",
    )
    # Patch the proxy globals like the other CRUD tests, so the result does
    # not depend on whatever prisma_client/user_api_key_cache the test process
    # happens to hold (e.g. a None prisma_client checked before the role).
    with (
        patch("litellm.proxy.proxy_server.prisma_client", mock_prisma),
        patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache),
    ):
        with pytest.raises(HTTPException) as exc_info:
            await create_jwt_key_mapping(
                data=data, user_api_key_dict=_make_non_admin_auth()
            )
    assert exc_info.value.status_code == 403
    # The permission check must reject the request before any DB write.
    mock_prisma.db.litellm_jwtkeymapping.create.assert_not_called()
@pytest.mark.asyncio
async def test_delete_returns_404_when_not_found():
    """Deleting a non-existent mapping should return 404 without issuing a delete."""
    from litellm.proxy._types import DeleteJWTKeyMappingRequest

    mock_prisma = _mock_prisma()
    # Lookup miss: the mapping id does not exist.
    mock_prisma.db.litellm_jwtkeymapping.find_unique.return_value = None
    mock_cache = AsyncMock()
    data = DeleteJWTKeyMappingRequest(id="nonexistent-id")
    with (
        patch("litellm.proxy.proxy_server.prisma_client", mock_prisma),
        patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache),
    ):
        with pytest.raises(HTTPException) as exc_info:
            await delete_jwt_key_mapping(
                data=data, user_api_key_dict=_make_admin_auth()
            )
    assert exc_info.value.status_code == 404
    # A 404 must short-circuit before the destructive call.
    mock_prisma.db.litellm_jwtkeymapping.delete.assert_not_called()
@pytest.mark.asyncio
async def test_update_returns_404_when_not_found():
    """Updating a non-existent mapping should return 404 without issuing an update."""
    from litellm.proxy._types import UpdateJWTKeyMappingRequest

    mock_prisma = _mock_prisma()
    # Lookup miss: the mapping id does not exist.
    mock_prisma.db.litellm_jwtkeymapping.find_unique.return_value = None
    mock_cache = AsyncMock()
    data = UpdateJWTKeyMappingRequest(id="nonexistent-id", description="test")
    with (
        patch("litellm.proxy.proxy_server.prisma_client", mock_prisma),
        patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache),
    ):
        with pytest.raises(HTTPException) as exc_info:
            await update_jwt_key_mapping(
                data=data, user_api_key_dict=_make_admin_auth()
            )
    assert exc_info.value.status_code == 404
    # A 404 must short-circuit before any write is attempted.
    mock_prisma.db.litellm_jwtkeymapping.update.assert_not_called()
@pytest.mark.asyncio
async def test_info_returns_404_when_not_found():
    """Fetching info for an unknown mapping id should raise HTTP 404."""
    prisma = _mock_prisma()
    prisma.db.litellm_jwtkeymapping.find_unique.return_value = None

    with (
        patch("litellm.proxy.proxy_server.prisma_client", prisma),
        pytest.raises(HTTPException) as exc_info,
    ):
        await info_jwt_key_mapping(
            id="nonexistent-id", user_api_key_dict=_make_admin_auth()
        )

    assert exc_info.value.status_code == 404
@pytest.mark.asyncio
async def test_create_success_returns_response_without_token():
    """A successful create returns a JWTKeyMappingResponse that never exposes
    the hashed token, either as a declared field or in its serialized form."""
    from litellm.proxy._types import CreateJWTKeyMappingRequest

    mock_prisma = _mock_prisma()
    # _mock_mapping() carries token="hashed_token", which must not leak.
    mock_prisma.db.litellm_jwtkeymapping.create.return_value = _mock_mapping()
    mock_cache = AsyncMock()
    data = CreateJWTKeyMappingRequest(
        jwt_claim_name="email",
        jwt_claim_value="user@example.com",
        key="sk-test-key",
    )
    with (
        patch("litellm.proxy.proxy_server.prisma_client", mock_prisma),
        patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache),
    ):
        result = await create_jwt_key_mapping(
            data=data, user_api_key_dict=_make_admin_auth()
        )
    assert isinstance(result, JWTKeyMappingResponse)
    # Schema check on the class: instance access to ``model_fields`` is
    # deprecated since Pydantic 2.11.
    assert "token" not in JWTKeyMappingResponse.model_fields
    dumped = result.model_dump()
    assert "token" not in dumped
    assert "hashed_token" not in dumped.values()
    assert result.jwt_claim_name == "email"