Files
litellm/tests/proxy_unit_tests/test_jwt_key_mapping.py
T
Harshit28j 2f15686ea2 fix: address greptile feedback - redact hashed tokens, proper error codes, add tests
- Remove token field from JWTKeyMappingResponse to prevent hashed key exposure
- Use _to_response() helper on all CRUD endpoints to control returned fields
- Return 409 for unique constraint violations, 400 for FK violations, 404 for not found
- Add response_model to endpoint decorators
- Add 8 new unit tests covering error handling and token redaction

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 03:46:03 +05:30

353 lines
13 KiB
Python

import pytest
import sys
import os
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
# Add project root to sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from litellm.proxy.auth.user_api_key_auth import (
_resolve_jwt_to_virtual_key,
)
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy._types import (
JWTKeyMappingResponse,
LiteLLM_JWTAuth,
LitellmUserRoles,
UserAPIKeyAuth,
)
from litellm.proxy.management_endpoints.jwt_key_mapping_endpoints import (
_to_response,
create_jwt_key_mapping,
delete_jwt_key_mapping,
info_jwt_key_mapping,
update_jwt_key_mapping,
)
from litellm.caching.caching import DualCache
from fastapi import HTTPException
# ──────────────────────────────────────────────
# Tests: _resolve_jwt_to_virtual_key
# ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_jwt_to_virtual_key_mapping_resolution():
    """
    A JWT claim resolves to its mapped virtual key; a second lookup with the
    same claims is served from ``user_api_key_cache`` without querying the
    mapping table again.
    """
    handler = JWTHandler()
    handler.litellm_jwtauth = LiteLLM_JWTAuth(
        virtual_key_claim_field="email", virtual_key_mapping_cache_ttl=3600
    )
    claims = {"email": "user@example.com", "sub": "123"}

    # DB stub: the mapping table yields one active mapping pointing at sk-1234.
    mapping_row = MagicMock()
    mapping_row.token = "sk-1234"
    mapping_row.is_active = True
    db_client = MagicMock()
    find_first = AsyncMock(return_value=mapping_row)
    db_client.db.litellm_jwtkeymapping.find_first = find_first

    expected_key = UserAPIKeyAuth(token="sk-1234", team_id="team1")
    cache = DualCache()
    resolve_kwargs = dict(
        jwt_claims=claims,
        jwt_handler=handler,
        prisma_client=db_client,
        user_api_key_cache=cache,
        parent_otel_span=None,
        proxy_logging_obj=None,
    )

    # get_key_object is patched in the module that looks it up.
    with patch(
        "litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock
    ) as get_key_stub:
        get_key_stub.return_value = expected_key

        # First call: nothing cached yet, so the mapping table is queried.
        first = await _resolve_jwt_to_virtual_key(**resolve_kwargs)
        assert first == expected_key
        find_first.assert_called_once()

        # Second call: the mapping should now come from the cache.
        find_first.reset_mock()
        second = await _resolve_jwt_to_virtual_key(**resolve_kwargs)
        assert second == expected_key
        find_first.assert_not_called()
@pytest.mark.asyncio
async def test_jwt_to_virtual_key_mapping_no_mapping():
    """
    With no mapping row for the claim value, resolution yields None; a repeat
    lookup is answered from the (negative) cache without another DB query.
    """
    handler = JWTHandler()
    handler.litellm_jwtauth = LiteLLM_JWTAuth(virtual_key_claim_field="email")
    claims = {"email": "unknown@example.com"}

    # DB stub: the mapping table finds nothing.
    db_client = MagicMock()
    find_first = AsyncMock(return_value=None)
    db_client.db.litellm_jwtkeymapping.find_first = find_first

    # get_key_object is stubbed defensively so no real key lookup can run.
    with patch(
        "litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock
    ):
        cache = DualCache()
        resolve_kwargs = dict(
            jwt_claims=claims,
            jwt_handler=handler,
            prisma_client=db_client,
            user_api_key_cache=cache,
            parent_otel_span=None,
            proxy_logging_obj=None,
        )

        first = await _resolve_jwt_to_virtual_key(**resolve_kwargs)
        assert first is None

        # Negative-cache check: the miss must not trigger a second DB query.
        find_first.reset_mock()
        second = await _resolve_jwt_to_virtual_key(**resolve_kwargs)
        assert second is None
        find_first.assert_not_called()
# ──────────────────────────────────────────────
# Tests: _to_response redacts hashed token
# ──────────────────────────────────────────────
def test_to_response_excludes_token():
    """_to_response must not expose the hashed token, in the schema or the payload.

    Checks the model's declared fields on the class, because accessing
    ``model_fields`` on an instance is deprecated since Pydantic 2.11. It also
    checks the serialized output, so a leak via extra or computed fields is
    caught too.
    """
    now = datetime.now(timezone.utc)
    mock_mapping = MagicMock()
    mock_mapping.id = "mapping-1"
    mock_mapping.jwt_claim_name = "email"
    mock_mapping.jwt_claim_value = "user@example.com"
    mock_mapping.token = "hashed_secret_value"
    mock_mapping.description = "test"
    mock_mapping.is_active = True
    mock_mapping.created_at = now
    mock_mapping.updated_at = now
    mock_mapping.created_by = "admin"
    mock_mapping.updated_by = "admin"

    resp = _to_response(mock_mapping)

    assert isinstance(resp, JWTKeyMappingResponse)
    assert resp.id == "mapping-1"
    assert resp.jwt_claim_name == "email"
    assert resp.jwt_claim_value == "user@example.com"
    # Schema: no declared `token` field.
    assert "token" not in JWTKeyMappingResponse.model_fields
    # Payload: neither the key nor the hashed value appears in the dumped data.
    payload = resp.model_dump()
    assert "token" not in payload
    assert "hashed_secret_value" not in payload.values()
# ──────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────
def _make_admin_auth() -> UserAPIKeyAuth:
    """Return a ``UserAPIKeyAuth`` for key ``sk-admin`` with the PROXY_ADMIN role."""
    admin_role = LitellmUserRoles.PROXY_ADMIN
    return UserAPIKeyAuth(user_role=admin_role, token="sk-admin")
def _make_non_admin_auth() -> UserAPIKeyAuth:
    """Return a ``UserAPIKeyAuth`` for key ``sk-user`` with the INTERNAL_USER role."""
    user_role = LitellmUserRoles.INTERNAL_USER
    return UserAPIKeyAuth(user_role=user_role, token="sk-user")
def _mock_prisma():
    """Return a MagicMock Prisma client with an awaitable ``litellm_jwtkeymapping`` table.

    ``create``, ``find_unique``, ``find_many``, ``update`` and ``delete`` are
    plain ``AsyncMock`` objects that tests configure as needed. ``count``
    resolves to 0.
    """
    client = MagicMock()
    table = client.db.litellm_jwtkeymapping
    for method_name in ("create", "find_unique", "find_many", "update", "delete"):
        setattr(table, method_name, AsyncMock())
    table.count = AsyncMock(return_value=0)
    return client
def _mock_mapping(
    id="mapping-1",  # shadows the builtin, but callers may pass it by keyword
    claim_name="email",
    claim_value="user@example.com",
):
    """Return a MagicMock shaped like a JWT key-mapping DB row.

    The identifying fields come from the arguments. ``token`` is fixed to
    ``"hashed_token"``, ``description`` is None and ``is_active`` is True.
    Both audit users are ``"admin"``, and both timestamps share one UTC "now".
    """
    timestamp = datetime.now(timezone.utc)
    row = MagicMock()
    field_values = {
        "id": id,
        "jwt_claim_name": claim_name,
        "jwt_claim_value": claim_value,
        "token": "hashed_token",
        "description": None,
        "is_active": True,
        "created_at": timestamp,
        "updated_at": timestamp,
        "created_by": "admin",
        "updated_by": "admin",
    }
    for attr_name, value in field_values.items():
        setattr(row, attr_name, value)
    return row
# ──────────────────────────────────────────────
# Tests: CRUD endpoint error handling
# ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_create_returns_409_on_unique_violation():
    """A Prisma unique-constraint error (P2002) on create surfaces as 409, not 500."""
    from litellm.proxy._types import CreateJWTKeyMappingRequest

    prisma = _mock_prisma()
    # Simulate the DB rejecting a duplicate mapping.
    prisma.db.litellm_jwtkeymapping.create.side_effect = Exception(
        "Unique constraint failed (P2002)"
    )
    cache_stub = AsyncMock()
    request_body = CreateJWTKeyMappingRequest(
        jwt_claim_name="email",
        jwt_claim_value="user@example.com",
        key="sk-test-key",
    )

    with patch("litellm.proxy.proxy_server.prisma_client", prisma):
        with patch("litellm.proxy.proxy_server.user_api_key_cache", cache_stub):
            with pytest.raises(HTTPException) as raised:
                await create_jwt_key_mapping(
                    data=request_body, user_api_key_dict=_make_admin_auth()
                )

    error = raised.value
    assert error.status_code == 409
    assert "already exists" in error.detail
@pytest.mark.asyncio
async def test_create_returns_400_on_foreign_key_violation():
    """A Prisma foreign-key error (P2003) for an unknown key surfaces as 400, not 500."""
    from litellm.proxy._types import CreateJWTKeyMappingRequest

    prisma = _mock_prisma()
    # Simulate the referenced virtual key not existing in the DB.
    prisma.db.litellm_jwtkeymapping.create.side_effect = Exception(
        "Foreign key constraint failed on field: `token` (P2003)"
    )
    cache_stub = AsyncMock()
    request_body = CreateJWTKeyMappingRequest(
        jwt_claim_name="sub",
        jwt_claim_value="user-999",
        key="sk-nonexistent",
    )

    with patch("litellm.proxy.proxy_server.prisma_client", prisma):
        with patch("litellm.proxy.proxy_server.user_api_key_cache", cache_stub):
            with pytest.raises(HTTPException) as raised:
                await create_jwt_key_mapping(
                    data=request_body, user_api_key_dict=_make_admin_auth()
                )

    error = raised.value
    assert error.status_code == 400
    assert "does not match" in error.detail
@pytest.mark.asyncio
async def test_create_non_admin_returns_403():
    """A caller without the PROXY_ADMIN role is rejected with 403."""
    from litellm.proxy._types import CreateJWTKeyMappingRequest

    request_body = CreateJWTKeyMappingRequest(
        jwt_claim_name="email",
        jwt_claim_value="user@example.com",
        key="sk-test",
    )
    caller = _make_non_admin_auth()

    with pytest.raises(HTTPException) as raised:
        await create_jwt_key_mapping(data=request_body, user_api_key_dict=caller)

    assert raised.value.status_code == 403
@pytest.mark.asyncio
async def test_delete_returns_404_when_not_found():
    """Deleting a non-existent mapping returns 404 and issues no delete.

    The lookup (``find_unique``) must run, and its empty result must stop the
    handler before any ``delete`` call reaches the DB.
    """
    from litellm.proxy._types import DeleteJWTKeyMappingRequest

    mock_prisma = _mock_prisma()
    mapping_table = mock_prisma.db.litellm_jwtkeymapping
    mapping_table.find_unique.return_value = None
    mock_cache = AsyncMock()
    data = DeleteJWTKeyMappingRequest(id="nonexistent-id")

    with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma), patch(
        "litellm.proxy.proxy_server.user_api_key_cache", mock_cache
    ):
        with pytest.raises(HTTPException) as exc_info:
            await delete_jwt_key_mapping(data=data, user_api_key_dict=_make_admin_auth())

    assert exc_info.value.status_code == 404
    # The 404 must come from the lookup, not from some other failure path...
    mapping_table.find_unique.assert_awaited()
    # ...and a missing row must never be followed by a delete.
    mapping_table.delete.assert_not_awaited()
@pytest.mark.asyncio
async def test_update_returns_404_when_not_found():
    """Updating a non-existent mapping returns 404 and issues no update.

    The lookup (``find_unique``) must run, and its empty result must stop the
    handler before any ``update`` call reaches the DB.
    """
    from litellm.proxy._types import UpdateJWTKeyMappingRequest

    mock_prisma = _mock_prisma()
    mapping_table = mock_prisma.db.litellm_jwtkeymapping
    mapping_table.find_unique.return_value = None
    mock_cache = AsyncMock()
    data = UpdateJWTKeyMappingRequest(id="nonexistent-id", description="test")

    with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma), patch(
        "litellm.proxy.proxy_server.user_api_key_cache", mock_cache
    ):
        with pytest.raises(HTTPException) as exc_info:
            await update_jwt_key_mapping(data=data, user_api_key_dict=_make_admin_auth())

    assert exc_info.value.status_code == 404
    # The 404 must come from the lookup, not from some other failure path...
    mapping_table.find_unique.assert_awaited()
    # ...and a missing row must never be followed by an update.
    mapping_table.update.assert_not_awaited()
@pytest.mark.asyncio
async def test_info_returns_404_when_not_found():
    """Requesting info for a mapping id with no DB row returns 404."""
    prisma = _mock_prisma()
    prisma.db.litellm_jwtkeymapping.find_unique.return_value = None
    caller = _make_admin_auth()

    with patch("litellm.proxy.proxy_server.prisma_client", prisma):
        with pytest.raises(HTTPException) as raised:
            await info_jwt_key_mapping(id="nonexistent-id", user_api_key_dict=caller)

    assert raised.value.status_code == 404
@pytest.mark.asyncio
async def test_create_success_returns_response_without_token():
    """A successful create returns a JWTKeyMappingResponse that never carries the hash.

    Checks the declared schema on the class, because instance access to
    ``model_fields`` is deprecated since Pydantic 2.11. It also checks the
    serialized payload, and that exactly one DB insert was made.
    """
    from litellm.proxy._types import CreateJWTKeyMappingRequest

    mock_prisma = _mock_prisma()
    mapping_table = mock_prisma.db.litellm_jwtkeymapping
    # _mock_mapping() stores "hashed_token" as the row's token.
    mapping_table.create.return_value = _mock_mapping()
    mock_cache = AsyncMock()
    data = CreateJWTKeyMappingRequest(
        jwt_claim_name="email", jwt_claim_value="user@example.com", key="sk-test-key",
    )

    with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma), patch(
        "litellm.proxy.proxy_server.user_api_key_cache", mock_cache
    ):
        result = await create_jwt_key_mapping(data=data, user_api_key_dict=_make_admin_auth())

    assert isinstance(result, JWTKeyMappingResponse)
    assert result.jwt_claim_name == "email"
    # Schema: no declared `token` field.
    assert "token" not in JWTKeyMappingResponse.model_fields
    # Payload: neither the key nor the stored hash is present.
    payload = result.model_dump()
    assert "token" not in payload
    assert "hashed_token" not in payload.values()
    mapping_table.create.assert_awaited_once()