Files
litellm/tests/proxy_unit_tests/test_jwt_key_mapping.py
T
Shivam Rawat 3bd89f209e Litellm jwt mapping virtualkeys (#28510)
* restore an explicit no-match policy

* fix(jwt): fix AUTO_REGISTER sentinel bypass, race condition, and inline import comment

- AUTO_REGISTER now evicts stale __NO_MAPPING__ sentinel instead of silently
  returning None when cached under a prior fallback_team_mapping config
- Race condition in _auto_register_jwt_mapping: catch P2002 unique-constraint
  violation on concurrent creates, fetch the winning mapping, proceed cleanly
- Added comment on inline generate_key_helper_fn import explaining the circular
  dependency (key_management_endpoints imports user_api_key_auth at line 51)
- 3 new tests: stale sentinel eviction, race condition winner fallback, and the
  existing auto_register happy path

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(jwt): cache __NO_MAPPING__ sentinel before raising 403 in REJECT mode

REJECT mode was raising HTTPException immediately on a DB miss without writing
the __NO_MAPPING__ sentinel, causing every subsequent rejected request to
re-query the DB. Write the sentinel first so repeated rejections are served
from cache within virtual_key_mapping_cache_ttl.

Adds test asserting DB is not hit on the second reject after a cache-warm miss.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(jwt): enforce no-match policy when prisma_client is None

The early `if prisma_client is None: return None` guard ran before the
no-match policy check, silently bypassing REJECT and AUTO_REGISTER — every
JWT client fell through to team auth regardless of configuration.

Fix: treat prisma_client=None as a definitive DB miss and fall through to the
same policy block as a real miss. REJECT now raises 403, AUTO_REGISTER raises
500 with a clear message (can't create keys without a DB), FALLBACK_TEAM_MAPPING
returns None unchanged.

Adds three tests: REJECT/403 with no DB, FALLBACK returns None with no DB,
AUTO_REGISTER/500 with no DB.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(jwt): consistent AUTO_REGISTER on cached sentinel; clean up race orphans

Addresses Greptile review on PR #25570 cherry-pick.

1. Inconsistent AUTO_REGISTER when __NO_MAPPING__ sentinel is cached:
   The cached-sentinel branch silently returned None when prisma_client was
   None, while the fresh path raised HTTP 500 under the same config. Same
   request, different access-control outcome depending on cache state. Both
   paths now raise the same 500.

2. Orphaned virtual keys from race-condition losers:
   On unique-constraint conflict, generate_key_helper_fn had already persisted
   an unrestricted virtual key in LiteLLM_VerificationToken with the cleartext
   in request memory. Under sustained concurrency these accumulated
   indefinitely. The loser now deletes its orphan before falling back to the
   winner's mapping; failure to delete is logged but does not fail the request.

Also corrects a latent FK bug surfaced while fixing #2: the mapping row was
storing the plaintext key in LiteLLM_JWTKeyMapping.token, but that column FKs
to the hashed LiteLLM_VerificationToken.token — now hashed at the call site.

Tests:
- updated test_auto_register_creates_key_and_mapping to assert the hashed
  token is stored, not the plaintext
- updated test_auto_register_race_condition_unique_conflict to assert the
  orphan is deleted with the correct hashed token
- added test_auto_register_raises_500_when_sentinel_cached_and_no_db
- added test_auto_register_race_conflict_tolerates_delete_failure

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(jwt): close REJECT bypass when JWT omits the configured claim field

A JWT presented without the configured `virtual_key_claim_field` previously
returned None at the `claim_value is None` guard before the
`unregistered_jwt_client_behavior` check ran. A caller who knows the configured
claim-field name could bypass REJECT by simply omitting that field and falling
through to team-based JWT auth.

Apply the no-match policy on a missing claim:
  - REJECT          → 403
  - AUTO_REGISTER   → 403 (no stable identity to map; refuse rather than
                     create a sentinel-keyed record)
  - FALLBACK_TEAM_MAPPING → return None (unchanged, backward-compatible)

Adds three tests covering each branch of the missing-claim path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(jwt): AUTO_REGISTER inherits team_id so keys are bounded by team limits

Auto-registered virtual keys were created with no team, model, route, rate, or
budget constraints — broader access than the standard team-based JWT auth path
the same client would have taken. Under AUTO_REGISTER, resolve the team_id
from the JWT (via the operator-configured team_id_jwt_field / team_id_default)
and stamp it on the new key. Downstream auth then applies the team's
budget/models/tpm/rpm/allowed_routes via the existing virtual-key flow.

Policy when team_id_jwt_field is configured:
  - JWT carries team claim → stamp resolved team_id
  - JWT lacks claim + team_id_default set → stamp default
  - JWT lacks claim + no default → 403 (refuse to create an unbounded key)

When neither team_id_jwt_field nor team_id_default is configured, the
operator has explicitly opted out of team-based limits — the auto-created
key has no team_id (matches what team-auth would do in the same config).

Adds 4 tests covering each branch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(jwt): make AUTO_REGISTER functional in prod; raise on missing winner

Two correctness fixes flagged by Greptile on the AUTO_REGISTER path:

1. generate_key_helper_fn was called without table_name="key". Without that,
   the helper falls into the user-upsert branch (table_name in (None, "user"))
   and tries to insert into LiteLLM_UserTable with user_id=None, which hits
   the NOT NULL @id constraint. AUTO_REGISTER would never have succeeded in
   production. Now passes table_name="key" explicitly, matching the
   /key/generate caller.

2. When the race loser refetches the winner's mapping and gets None (winner
   row concurrently deleted), the previous code returned None — and the
   caller in _resolve_jwt_to_virtual_key then fell through to less-
   restrictive team-based JWT auth, silently bypassing the configured
   AUTO_REGISTER policy. Now raises HTTP 503 so the caller retries against
   a stable state rather than getting unintended fallback access.

Adds one test for the 503 winner-vanishes path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(jwt): defer AUTO_REGISTER until JWT policy is enforced by auth_builder

Closes the JWT policy bypass on the AUTO_REGISTER path flagged by veria-ai.

Before: when unregistered_jwt_client_behavior=auto_register and the JWT's
claim was unmapped, _resolve_jwt_to_virtual_key validated the JWT signature
and then immediately created a virtual key + mapping. JWTAuthManager.auth_builder
never ran for the first request (the new key short-circuited the team-auth
path), and every subsequent request hit the cached mapping — so custom_validate,
RBAC, scope_mappings, and user_allowed_email_domain were never enforced for
auto-registered clients.

After: _resolve_jwt_to_virtual_key returns a _PendingAutoRegister signal
instead of creating the key. The caller in _user_api_key_auth_builder runs
JWTAuthManager.auth_builder, then — only on a validated, policy-passing
result — calls _auto_register_jwt_mapping with the team_id / user_id from
that result. The created key inherits team + user limits from the validated
identity, and future cache hits load that already-policy-checked key.

Also drops the interim _resolve_inherited_team_id helper that pulled team_id
from raw JWT claims — same bypass risk; team_id now comes exclusively from
auth_builder.

Tests:
  - Rewrote two existing tests to assert _resolve_jwt_to_virtual_key returns
    _PendingAutoRegister (no key created yet) for both the fresh-DB-miss
    and stale-sentinel branches
  - Added a contract test that _auto_register_jwt_mapping stamps the
    validated team_id/user_id onto generate_key_helper_fn
  - Removed four stale team-binding tests that exercised the prior
    raw-claim helper

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Update user_api_key_auth.py

* fix(jwt): cache proxy-admin AUTO_REGISTER path to avoid repeated DB lookups

Cache-miss regression introduced by the deferred-auto-register refactor:
when a JWT under AUTO_REGISTER resolved to a proxy admin, the is_proxy_admin
early-return in _user_api_key_auth_builder ran *before* the pending
auto-register cache-write block. Result: no cache entry, so every
subsequent proxy-admin request re-queried get_jwt_key_mapping_object
indefinitely.

Fix: write a __JWT_PROXY_ADMIN__ sentinel to user_api_key_cache before the
early return when a pending auto-register existed. _resolve_jwt_to_virtual_key
treats that sentinel as "skip mapping, fall through to auth_builder", so
future requests from the same JWT identity hit the cache instead of the DB.
auth_builder still runs full JWT policy on every request — only the
mapping DB lookup is short-circuited.

Adds one test asserting the sentinel cache-hit returns None without
hitting prisma_client.db.litellm_jwtkeymapping.find_first.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(proxy): stamp org context on JWT auto-registered keys

AUTO_REGISTER keys were created with team_id and user_id only, so org budget checks were skipped after switching to the key-scoped path.

Co-authored-by: Cursor <cursoragent@cursor.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-04 19:00:36 -07:00

1309 lines
49 KiB
Python

import pytest
import sys
import os
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
# Add project root to sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from litellm.proxy.auth.user_api_key_auth import (
_resolve_jwt_to_virtual_key,
)
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy._types import (
JWTKeyMappingResponse,
LiteLLM_JWTAuth,
LitellmUserRoles,
UserAPIKeyAuth,
)
from litellm.proxy.management_endpoints.jwt_key_mapping_endpoints import (
_to_response,
create_jwt_key_mapping,
delete_jwt_key_mapping,
info_jwt_key_mapping,
update_jwt_key_mapping,
)
from litellm.caching.caching import DualCache
from fastapi import HTTPException
# ──────────────────────────────────────────────
# Tests: _resolve_jwt_to_virtual_key
# ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_jwt_to_virtual_key_mapping_resolution():
"""
Test that a JWT claim is correctly resolved to a virtual key token.
"""
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="email", virtual_key_mapping_cache_ttl=3600
)
jwt_claims = {"email": "user@example.com", "sub": "123"}
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock()
# Mock finding a mapping
mock_mapping = MagicMock()
mock_mapping.token = "sk-1234"
mock_mapping.is_active = True
prisma_client.db.litellm_jwtkeymapping.find_first.return_value = mock_mapping
# Mock getting the key object
mock_key_obj = UserAPIKeyAuth(token="sk-1234", team_id="team1")
user_api_key_cache = DualCache()
# Use patch to mock get_key_object in the module where it's used
with patch(
"litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock
) as mock_get_key:
mock_get_key.return_value = mock_key_obj
result = await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert result == mock_key_obj
prisma_client.db.litellm_jwtkeymapping.find_first.assert_called_once()
# Test Cache hit
prisma_client.db.litellm_jwtkeymapping.find_first.reset_mock()
result_cached = await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert result_cached == mock_key_obj
prisma_client.db.litellm_jwtkeymapping.find_first.assert_not_called()
@pytest.mark.asyncio
async def test_jwt_to_virtual_key_mapping_no_mapping():
"""
Test that when no mapping exists, resolve returns None.
"""
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(virtual_key_claim_field="email")
jwt_claims = {"email": "unknown@example.com"}
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock()
prisma_client.db.litellm_jwtkeymapping.find_first.return_value = None
# Mock get_key_object just in case
with patch(
"litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock
):
user_api_key_cache = DualCache()
result = await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert result is None
# Test Negative Cache hit
prisma_client.db.litellm_jwtkeymapping.find_first.reset_mock()
result_cached = await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert result_cached is None
prisma_client.db.litellm_jwtkeymapping.find_first.assert_not_called()
# ──────────────────────────────────────────────
# Tests: OIDC / JWT routing in user_api_key_auth
# ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_virtual_key_mapping_oidc_enabled_jwt_token_uses_auth_jwt():
"""
Regression test for the is_jwt routing fix in user_api_key_auth.py.
When oidc_userinfo_enabled=True and virtual_key_claim_field is set, but
the token is a well-formed JWT (3-part header.payload.sig), the virtual-key
claim lookup must call auth_jwt — not get_oidc_userinfo.
"""
# Three-part token: is_jwt() returns True
api_key = "eyJhbGciOiJSUzI1NiJ9.eyJlbWFpbCI6InVzZXJAZXhhbXBsZS5jb20ifQ.sig"
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
oidc_userinfo_enabled=True,
virtual_key_claim_field="email",
)
# Confirm our fixture token is treated as a JWT
assert jwt_handler.is_jwt(token=api_key) is True
auth_jwt_mock = AsyncMock(return_value={"email": "user@example.com", "sub": "123"})
oidc_userinfo_mock = AsyncMock(return_value={"email": "user@example.com"})
# Simulate the routing condition from user_api_key_auth.py
if jwt_handler.litellm_jwtauth.oidc_userinfo_enabled and not jwt_handler.is_jwt(
token=api_key
):
jwt_claims = await oidc_userinfo_mock(token=api_key)
else:
jwt_claims = await auth_jwt_mock(token=api_key)
auth_jwt_mock.assert_called_once_with(token=api_key)
oidc_userinfo_mock.assert_not_called()
assert jwt_claims["email"] == "user@example.com"
@pytest.mark.asyncio
async def test_virtual_key_mapping_oidc_enabled_opaque_token_uses_oidc_userinfo():
"""
Complement of the test above: when oidc_userinfo_enabled=True and the token
is an opaque access token (not a JWT), the virtual-key claim lookup must
call get_oidc_userinfo — not auth_jwt.
"""
# Opaque token: no dots → is_jwt() returns False
api_key = "some_opaque_access_token_with_no_dots"
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
oidc_userinfo_enabled=True,
virtual_key_claim_field="email",
)
assert jwt_handler.is_jwt(token=api_key) is False
auth_jwt_mock = AsyncMock(return_value={"email": "user@example.com"})
oidc_userinfo_mock = AsyncMock(
return_value={"email": "user@example.com", "sub": "123"}
)
if jwt_handler.litellm_jwtauth.oidc_userinfo_enabled and not jwt_handler.is_jwt(
token=api_key
):
jwt_claims = await oidc_userinfo_mock(token=api_key)
else:
jwt_claims = await auth_jwt_mock(token=api_key)
oidc_userinfo_mock.assert_called_once_with(token=api_key)
auth_jwt_mock.assert_not_called()
assert jwt_claims["sub"] == "123"
# ──────────────────────────────────────────────
# Tests: _to_response redacts hashed token
# ──────────────────────────────────────────────
def test_to_response_excludes_token():
"""_to_response should not expose the hashed token field."""
now = datetime.now(timezone.utc)
mock_mapping = MagicMock()
mock_mapping.id = "mapping-1"
mock_mapping.jwt_claim_name = "email"
mock_mapping.jwt_claim_value = "user@example.com"
mock_mapping.token = "hashed_secret_value"
mock_mapping.description = "test"
mock_mapping.is_active = True
mock_mapping.created_at = now
mock_mapping.updated_at = now
mock_mapping.created_by = "admin"
mock_mapping.updated_by = "admin"
resp = _to_response(mock_mapping)
assert isinstance(resp, JWTKeyMappingResponse)
assert resp.id == "mapping-1"
assert resp.jwt_claim_name == "email"
assert "token" not in resp.model_fields
# ──────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────
def _make_admin_auth() -> UserAPIKeyAuth:
return UserAPIKeyAuth(
token="sk-admin",
user_role=LitellmUserRoles.PROXY_ADMIN,
)
def _make_non_admin_auth() -> UserAPIKeyAuth:
return UserAPIKeyAuth(
token="sk-user",
user_role=LitellmUserRoles.INTERNAL_USER,
)
def _mock_prisma():
prisma = MagicMock()
prisma.db.litellm_jwtkeymapping.create = AsyncMock()
prisma.db.litellm_jwtkeymapping.find_unique = AsyncMock()
prisma.db.litellm_jwtkeymapping.find_many = AsyncMock()
prisma.db.litellm_jwtkeymapping.update = AsyncMock()
prisma.db.litellm_jwtkeymapping.delete = AsyncMock()
prisma.db.litellm_jwtkeymapping.count = AsyncMock(return_value=0)
return prisma
def _mock_mapping(
id="mapping-1",
claim_name="email",
claim_value="user@example.com",
):
now = datetime.now(timezone.utc)
m = MagicMock()
m.id = id
m.jwt_claim_name = claim_name
m.jwt_claim_value = claim_value
m.token = "hashed_token"
m.description = None
m.is_active = True
m.created_at = now
m.updated_at = now
m.created_by = "admin"
m.updated_by = "admin"
return m
# ──────────────────────────────────────────────
# Tests: CRUD endpoint error handling
# ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_create_returns_409_on_unique_violation():
"""Duplicate mapping should return 409, not 500."""
from litellm.proxy._types import CreateJWTKeyMappingRequest
mock_prisma = _mock_prisma()
mock_prisma.db.litellm_jwtkeymapping.create.side_effect = Exception(
"Unique constraint failed (P2002)"
)
mock_cache = AsyncMock()
data = CreateJWTKeyMappingRequest(
jwt_claim_name="email",
jwt_claim_value="user@example.com",
key="sk-test-key",
)
with (
patch("litellm.proxy.proxy_server.prisma_client", mock_prisma),
patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache),
):
with pytest.raises(HTTPException) as exc_info:
await create_jwt_key_mapping(
data=data, user_api_key_dict=_make_admin_auth()
)
assert exc_info.value.status_code == 409
assert "already exists" in exc_info.value.detail
@pytest.mark.asyncio
async def test_create_returns_400_on_foreign_key_violation():
"""Non-existent key should return 400, not 500."""
from litellm.proxy._types import CreateJWTKeyMappingRequest
mock_prisma = _mock_prisma()
mock_prisma.db.litellm_jwtkeymapping.create.side_effect = Exception(
"Foreign key constraint failed on field: `token` (P2003)"
)
mock_cache = AsyncMock()
data = CreateJWTKeyMappingRequest(
jwt_claim_name="sub",
jwt_claim_value="user-999",
key="sk-nonexistent",
)
with (
patch("litellm.proxy.proxy_server.prisma_client", mock_prisma),
patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache),
):
with pytest.raises(HTTPException) as exc_info:
await create_jwt_key_mapping(
data=data, user_api_key_dict=_make_admin_auth()
)
assert exc_info.value.status_code == 400
assert "does not match" in exc_info.value.detail
@pytest.mark.asyncio
async def test_create_non_admin_returns_403():
"""Non-admin users should get 403."""
from litellm.proxy._types import CreateJWTKeyMappingRequest
data = CreateJWTKeyMappingRequest(
jwt_claim_name="email",
jwt_claim_value="user@example.com",
key="sk-test",
)
with pytest.raises(HTTPException) as exc_info:
await create_jwt_key_mapping(
data=data, user_api_key_dict=_make_non_admin_auth()
)
assert exc_info.value.status_code == 403
@pytest.mark.asyncio
async def test_delete_returns_404_when_not_found():
"""Deleting non-existent mapping should return 404."""
from litellm.proxy._types import DeleteJWTKeyMappingRequest
mock_prisma = _mock_prisma()
mock_prisma.db.litellm_jwtkeymapping.find_unique.return_value = None
mock_cache = AsyncMock()
data = DeleteJWTKeyMappingRequest(id="nonexistent-id")
with (
patch("litellm.proxy.proxy_server.prisma_client", mock_prisma),
patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache),
):
with pytest.raises(HTTPException) as exc_info:
await delete_jwt_key_mapping(
data=data, user_api_key_dict=_make_admin_auth()
)
assert exc_info.value.status_code == 404
@pytest.mark.asyncio
async def test_update_returns_404_when_not_found():
"""Updating non-existent mapping should return 404."""
from litellm.proxy._types import UpdateJWTKeyMappingRequest
mock_prisma = _mock_prisma()
mock_prisma.db.litellm_jwtkeymapping.find_unique.return_value = None
mock_cache = AsyncMock()
data = UpdateJWTKeyMappingRequest(id="nonexistent-id", description="test")
with (
patch("litellm.proxy.proxy_server.prisma_client", mock_prisma),
patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache),
):
with pytest.raises(HTTPException) as exc_info:
await update_jwt_key_mapping(
data=data, user_api_key_dict=_make_admin_auth()
)
assert exc_info.value.status_code == 404
@pytest.mark.asyncio
async def test_info_returns_404_when_not_found():
"""Getting info for non-existent mapping should return 404."""
mock_prisma = _mock_prisma()
mock_prisma.db.litellm_jwtkeymapping.find_unique.return_value = None
with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma):
with pytest.raises(HTTPException) as exc_info:
await info_jwt_key_mapping(
id="nonexistent-id", user_api_key_dict=_make_admin_auth()
)
assert exc_info.value.status_code == 404
@pytest.mark.asyncio
async def test_create_success_returns_response_without_token():
"""Successful create should return JWTKeyMappingResponse without hashed token."""
from litellm.proxy._types import CreateJWTKeyMappingRequest
mock_prisma = _mock_prisma()
mock_prisma.db.litellm_jwtkeymapping.create.return_value = _mock_mapping()
mock_cache = AsyncMock()
data = CreateJWTKeyMappingRequest(
jwt_claim_name="email",
jwt_claim_value="user@example.com",
key="sk-test-key",
)
with (
patch("litellm.proxy.proxy_server.prisma_client", mock_prisma),
patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache),
):
result = await create_jwt_key_mapping(
data=data, user_api_key_dict=_make_admin_auth()
)
assert isinstance(result, JWTKeyMappingResponse)
assert "token" not in result.model_fields
assert result.jwt_claim_name == "email"
# ──────────────────────────────────────────────
# Tests: unregistered_jwt_client_behavior
# ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_reject_behavior_raises_403_on_no_mapping():
"""
When unregistered_jwt_client_behavior='reject' and no mapping exists,
_resolve_jwt_to_virtual_key must raise HTTP 403.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="email",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.REJECT,
)
jwt_claims = {"email": "unknown@example.com"}
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
user_api_key_cache = DualCache()
with patch(
"litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock
):
with pytest.raises(HTTPException) as exc_info:
await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert exc_info.value.status_code == 403
assert "unknown@example.com" in exc_info.value.detail
@pytest.mark.asyncio
async def test_reject_behavior_caches_sentinel_after_db_miss():
"""
On a fresh DB miss with REJECT, the __NO_MAPPING__ sentinel must be written
to cache so that subsequent rejected requests are served from cache and do
not re-query the DB.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="email",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.REJECT,
virtual_key_mapping_cache_ttl=300,
)
jwt_claims = {"email": "unknown@example.com"}
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
user_api_key_cache = DualCache()
with patch(
"litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock
):
# First call — DB miss, should raise 403 and write sentinel
with pytest.raises(HTTPException) as exc_info:
await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert exc_info.value.status_code == 403
# Sentinel must now be in cache
cached = await user_api_key_cache.async_get_cache(
"jwt_key_mapping:email:unknown@example.com"
)
assert cached == "__NO_MAPPING__"
# Second call — must raise 403 from cache, no additional DB hit
prisma_client.db.litellm_jwtkeymapping.find_first.reset_mock()
with pytest.raises(HTTPException) as exc_info2:
await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert exc_info2.value.status_code == 403
prisma_client.db.litellm_jwtkeymapping.find_first.assert_not_called()
@pytest.mark.asyncio
async def test_reject_behavior_raises_403_on_cached_no_mapping():
"""
When the negative-cache sentinel __NO_MAPPING__ is present and behavior is
'reject', the function must also raise HTTP 403 (not return None silently).
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="email",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.REJECT,
)
jwt_claims = {"email": "unknown@example.com"}
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
# Pre-populate the negative cache so the DB is not hit
user_api_key_cache = DualCache()
cache_key = "jwt_key_mapping:email:unknown@example.com"
await user_api_key_cache.async_set_cache(cache_key, "__NO_MAPPING__")
with patch(
"litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock
):
with pytest.raises(HTTPException) as exc_info:
await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert exc_info.value.status_code == 403
# DB must NOT have been hit (sentinel served from cache)
prisma_client.db.litellm_jwtkeymapping.find_first.assert_not_called()
@pytest.mark.asyncio
async def test_auto_register_returns_pending_signal_without_creating_key():
"""
Security: when unregistered_jwt_client_behavior='auto_register' and no
mapping exists, _resolve_jwt_to_virtual_key must NOT create the key yet.
It returns a _PendingAutoRegister signal so the caller can run
JWTAuthManager.auth_builder (enforcing RBAC, scope mappings,
custom_validate, user_allowed_email_domain) FIRST. Creating the key here
would bypass every JWT policy beyond signature verification.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
from litellm.proxy.auth.user_api_key_auth import _PendingAutoRegister
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
virtual_key_mapping_cache_ttl=300,
)
jwt_claims = {"sub": "new-user-42"}
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock()
user_api_key_cache = DualCache()
with patch(
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
new_callable=AsyncMock,
) as mock_gen_key:
result = await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert isinstance(result, _PendingAutoRegister)
assert result.claim_field == "sub"
assert result.claim_value == "new-user-42"
assert result.cache_key == "jwt_key_mapping:sub:new-user-42"
# CRITICAL: no key was created — that must wait until after auth_builder
mock_gen_key.assert_not_called()
prisma_client.db.litellm_jwtkeymapping.create.assert_not_called()
@pytest.mark.asyncio
async def test_auto_register_creates_key_and_mapping_when_helper_invoked():
"""
When the caller invokes _auto_register_jwt_mapping directly (after
auth_builder validation), the helper creates the key + mapping row and
returns a UserAPIKeyAuth. The mapping row stores the hashed token (FK to
LiteLLM_VerificationToken), not the plaintext key.
"""
from litellm.proxy._types import hash_token
from litellm.proxy.auth.user_api_key_auth import _auto_register_jwt_mapping
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
virtual_key_mapping_cache_ttl=300,
)
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock()
user_api_key_cache = DualCache()
plaintext_key = "sk-auto-key"
expected_hash = hash_token(plaintext_key)
mock_key_obj = UserAPIKeyAuth(token=expected_hash, team_id="validated-team")
with (
patch(
"litellm.proxy.auth.user_api_key_auth.get_key_object",
new_callable=AsyncMock,
) as mock_get_key,
patch(
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
new_callable=AsyncMock,
) as mock_gen_key,
):
mock_gen_key.return_value = {"token": plaintext_key, "key": plaintext_key}
mock_get_key.return_value = mock_key_obj
result = await _auto_register_jwt_mapping(
virtual_key_claim_field="sub",
claim_value="new-user-42",
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
cache_key="jwt_key_mapping:sub:new-user-42",
team_id="validated-team",
user_id="validated-user",
)
assert result == mock_key_obj
# generate_key_helper_fn was passed table_name="key" (not user-upsert path)
# and the validated team_id + user_id from auth_builder
assert mock_gen_key.call_args.kwargs["table_name"] == "key"
assert mock_gen_key.call_args.kwargs["team_id"] == "validated-team"
assert mock_gen_key.call_args.kwargs["user_id"] == "validated-user"
# Mapping row was created with the hashed token (FK target)
call_data = prisma_client.db.litellm_jwtkeymapping.create.call_args[1]["data"]
assert call_data["jwt_claim_name"] == "sub"
assert call_data["jwt_claim_value"] == "new-user-42"
assert call_data["token"] == expected_hash
cached = await user_api_key_cache.async_get_cache("jwt_key_mapping:sub:new-user-42")
assert cached == expected_hash
@pytest.mark.asyncio
async def test_auto_register_returns_pending_signal_on_stale_no_mapping_sentinel():
"""
If the cache holds a stale __NO_MAPPING__ sentinel (written under a prior
fallback_team_mapping config) and behavior is now AUTO_REGISTER, the
resolver must evict the sentinel and return _PendingAutoRegister (so the
caller can run auth_builder before creating the key) — not silently return
None and not create the key on the spot.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
from litellm.proxy.auth.user_api_key_auth import _PendingAutoRegister
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="email",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
virtual_key_mapping_cache_ttl=300,
)
jwt_claims = {"email": "alice@corp.com"}
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock()
user_api_key_cache = DualCache()
await user_api_key_cache.async_set_cache(
"jwt_key_mapping:email:alice@corp.com", "__NO_MAPPING__"
)
with patch(
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
new_callable=AsyncMock,
) as mock_gen_key:
result = await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert isinstance(result, _PendingAutoRegister)
# Stale sentinel must be evicted so the deferred auto-register actually
# runs after auth_builder validates the JWT
cached_after = await user_api_key_cache.async_get_cache(
"jwt_key_mapping:email:alice@corp.com"
)
assert cached_after is None
mock_gen_key.assert_not_called()
prisma_client.db.litellm_jwtkeymapping.create.assert_not_called()
@pytest.mark.asyncio
async def test_auto_register_race_condition_unique_conflict():
"""
If two concurrent requests both call _auto_register_jwt_mapping and the
second hits a unique-constraint violation on create, it must:
1) delete the orphaned virtual key it just created (so orphans don't
accumulate in LiteLLM_VerificationToken under sustained concurrency),
2) fall back to the winner's mapping,
3) not surface an error.
"""
from litellm.proxy.auth.user_api_key_auth import _auto_register_jwt_mapping
from litellm.proxy._types import UnregisteredJWTClientBehavior, hash_token
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
virtual_key_mapping_cache_ttl=300,
)
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock(
side_effect=Exception("Unique constraint failed (P2002)")
)
prisma_client.db.litellm_verificationtoken.delete = AsyncMock()
# Simulate the winner's mapping already in DB after the conflict
winner_mapping = MagicMock()
winner_mapping.token = "winner_token_hash"
winner_mapping.is_active = True
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(
return_value=winner_mapping
)
user_api_key_cache = DualCache()
loser_plaintext = "sk-loser"
loser_hash = hash_token(loser_plaintext)
mock_key_obj = UserAPIKeyAuth(token="winner_token_hash", team_id=None)
with (
patch(
"litellm.proxy.auth.user_api_key_auth.get_key_object",
new_callable=AsyncMock,
) as mock_get_key,
patch(
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
new_callable=AsyncMock,
return_value={"token": loser_plaintext, "key": loser_plaintext},
),
):
mock_get_key.return_value = mock_key_obj
result = await _auto_register_jwt_mapping(
virtual_key_claim_field="sub",
claim_value="user-42",
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
cache_key="jwt_key_mapping:sub:user-42",
)
assert result == mock_key_obj
# The orphaned loser key must be deleted from LiteLLM_VerificationToken
prisma_client.db.litellm_verificationtoken.delete.assert_called_once_with(
where={"token": loser_hash}
)
# Cache should hold the winner's token, not the loser's
cached = await user_api_key_cache.async_get_cache("jwt_key_mapping:sub:user-42")
assert cached == "winner_token_hash"
mock_get_key.assert_called_once_with(
hashed_token="winner_token_hash",
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
# ──────────────────────────────────────────────
# Tests: prisma_client=None does not bypass no-match policy
# ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_reject_behavior_enforced_when_prisma_client_is_none():
"""
When prisma_client is None and behavior is REJECT, a 403 must be raised —
not silently fallen through to team auth.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="email",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.REJECT,
)
jwt_claims = {"email": "unknown@example.com"}
user_api_key_cache = DualCache()
with pytest.raises(HTTPException) as exc_info:
await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=None, # no DB
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert exc_info.value.status_code == 403
assert "unknown@example.com" in exc_info.value.detail
@pytest.mark.asyncio
async def test_reject_raises_403_when_claim_field_missing_from_jwt():
"""
Security: a JWT that omits the configured virtual_key_claim_field must NOT
bypass the REJECT policy. Previously the early `if claim_value is None:
return None` branch ran before the policy check, letting a caller who knows
the configured claim-field name silently fall through to team-based auth.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.REJECT,
)
# JWT does NOT contain "sub"
jwt_claims = {"email": "user@example.com"}
with pytest.raises(HTTPException) as exc_info:
await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=MagicMock(),
user_api_key_cache=DualCache(),
parent_otel_span=None,
proxy_logging_obj=None,
)
assert exc_info.value.status_code == 403
assert "'sub'" in exc_info.value.detail
assert "missing from the JWT" in exc_info.value.detail
@pytest.mark.asyncio
async def test_auto_register_raises_403_when_claim_field_missing_from_jwt():
"""
AUTO_REGISTER cannot create a mapping without a stable identity. When the
configured claim field is missing from the JWT, return 403 rather than
silently falling through (which would bypass the unregistered-client policy)
or creating a sentinel-keyed record.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
)
jwt_claims = {"email": "user@example.com"}
with pytest.raises(HTTPException) as exc_info:
await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=MagicMock(),
user_api_key_cache=DualCache(),
parent_otel_span=None,
proxy_logging_obj=None,
)
assert exc_info.value.status_code == 403
assert "missing from the JWT" in exc_info.value.detail
@pytest.mark.asyncio
async def test_fallback_team_mapping_returns_none_when_claim_field_missing_from_jwt():
"""
Under FALLBACK_TEAM_MAPPING (the default, backward-compatible mode), a JWT
without the configured claim field must still fall through to team-based
JWT auth — not raise. This preserves the pre-existing contract.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.FALLBACK_TEAM_MAPPING,
)
jwt_claims = {"email": "user@example.com"}
result = await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=MagicMock(),
user_api_key_cache=DualCache(),
parent_otel_span=None,
proxy_logging_obj=None,
)
assert result is None
@pytest.mark.asyncio
async def test_fallback_team_mapping_returns_none_when_prisma_client_is_none():
"""
When prisma_client is None and behavior is FALLBACK_TEAM_MAPPING, the
function must return None (fall through to team auth) — not raise.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="email",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.FALLBACK_TEAM_MAPPING,
)
jwt_claims = {"email": "anyone@example.com"}
result = await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=None,
user_api_key_cache=DualCache(),
parent_otel_span=None,
proxy_logging_obj=None,
)
assert result is None
@pytest.mark.asyncio
async def test_auto_register_raises_500_when_prisma_client_is_none():
"""
AUTO_REGISTER without a DB connection must raise HTTP 500 with a clear
message — it cannot create keys without a database.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
)
jwt_claims = {"sub": "new-user-42"}
with pytest.raises(HTTPException) as exc_info:
await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=None,
user_api_key_cache=DualCache(),
parent_otel_span=None,
proxy_logging_obj=None,
)
assert exc_info.value.status_code == 500
assert "AUTO_REGISTER requires a database" in exc_info.value.detail
@pytest.mark.asyncio
async def test_auto_register_raises_500_when_sentinel_cached_and_no_db():
"""
AUTO_REGISTER + cached __NO_MAPPING__ sentinel + prisma_client is None must
raise HTTP 500, matching the fresh-path behavior. Previously this path
silently returned None and let the request fall through to team auth,
creating different access-control outcomes under identical configuration.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
virtual_key_mapping_cache_ttl=300,
)
jwt_claims = {"sub": "user-42"}
user_api_key_cache = DualCache()
# Stale sentinel written under a prior fallback_team_mapping config
await user_api_key_cache.async_set_cache(
"jwt_key_mapping:sub:user-42", "__NO_MAPPING__"
)
with pytest.raises(HTTPException) as exc_info:
await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=None,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert exc_info.value.status_code == 500
assert "AUTO_REGISTER requires a database" in exc_info.value.detail
@pytest.mark.asyncio
async def test_auto_register_race_conflict_tolerates_delete_failure():
"""
If deleting the orphaned virtual key after a race-condition conflict fails
(e.g. transient DB error), the request must still succeed by returning the
winner's mapping — the orphan is unmapped and inert.
"""
from litellm.proxy.auth.user_api_key_auth import _auto_register_jwt_mapping
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
virtual_key_mapping_cache_ttl=300,
)
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock(
side_effect=Exception("Unique constraint failed (P2002)")
)
prisma_client.db.litellm_verificationtoken.delete = AsyncMock(
side_effect=Exception("transient DB error")
)
winner_mapping = MagicMock()
winner_mapping.token = "winner_token_hash"
winner_mapping.is_active = True
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(
return_value=winner_mapping
)
user_api_key_cache = DualCache()
mock_key_obj = UserAPIKeyAuth(token="winner_token_hash", team_id=None)
with (
patch(
"litellm.proxy.auth.user_api_key_auth.get_key_object",
new_callable=AsyncMock,
) as mock_get_key,
patch(
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
new_callable=AsyncMock,
return_value={"token": "sk-loser", "key": "sk-loser"},
),
):
mock_get_key.return_value = mock_key_obj
result = await _auto_register_jwt_mapping(
virtual_key_claim_field="sub",
claim_value="user-42",
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
cache_key="jwt_key_mapping:sub:user-42",
)
# Caller still receives the winner's mapping even when cleanup fails
assert result == mock_key_obj
prisma_client.db.litellm_verificationtoken.delete.assert_called_once()
@pytest.mark.asyncio
async def test_auto_register_raises_503_when_winner_mapping_vanishes():
"""
Race edge case: this request loses the unique-constraint race, deletes its
orphan, then refetches the winner's mapping — but the winner's row was
concurrently deleted. Previously this returned None, silently falling
through to less-restrictive team-based JWT auth (bypassing the configured
AUTO_REGISTER policy). Must now raise HTTP 503 so the caller retries
rather than getting unintended fallback access.
"""
from litellm.proxy.auth.user_api_key_auth import _auto_register_jwt_mapping
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
virtual_key_mapping_cache_ttl=300,
)
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock(
side_effect=Exception("Unique constraint failed (P2002)")
)
prisma_client.db.litellm_verificationtoken.delete = AsyncMock()
# Winner row no longer exists by the time we refetch
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
user_api_key_cache = DualCache()
with (
patch(
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
new_callable=AsyncMock,
return_value={"token": "sk-loser", "key": "sk-loser"},
),
pytest.raises(HTTPException) as exc_info,
):
await _auto_register_jwt_mapping(
virtual_key_claim_field="sub",
claim_value="user-42",
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
cache_key="jwt_key_mapping:sub:user-42",
)
assert exc_info.value.status_code == 503
assert "concurrently removed" in exc_info.value.detail
@pytest.mark.asyncio
async def test_proxy_admin_sentinel_skips_db_lookup_on_cache_hit():
"""
When the cache holds the proxy-admin sentinel (written after a prior
request's is_proxy_admin early-return), _resolve_jwt_to_virtual_key must
return None *without* hitting the DB. Caller proceeds to auth_builder.
Without this, every subsequent proxy-admin request under AUTO_REGISTER
would re-query get_jwt_key_mapping_object — a cache-miss regression
introduced by the deferred-auto-register refactor.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
virtual_key_mapping_cache_ttl=300,
)
jwt_claims = {"sub": "admin-user"}
prisma_client = MagicMock()
# Will fail the test if accessed — proves the sentinel short-circuits DB
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(
side_effect=AssertionError("DB must not be hit when sentinel is cached")
)
user_api_key_cache = DualCache()
await user_api_key_cache.async_set_cache(
"jwt_key_mapping:sub:admin-user", "__JWT_PROXY_ADMIN__"
)
result = await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert result is None
prisma_client.db.litellm_jwtkeymapping.find_first.assert_not_called()
# ──────────────────────────────────────────────
# Tests: AUTO_REGISTER stamps validated identity from auth_builder
# ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_auto_register_helper_stamps_validated_identity_context():
"""
The deferred-auto-register contract: _auto_register_jwt_mapping is called
with identity fields from JWTAuthManager.auth_builder's *validated*
result (after RBAC, scope mappings, custom_validate, email-domain policy).
These must be passed to generate_key_helper_fn so the created key carries
them — the cached future-request path then inherits the same team/user/org
limits the auth_builder path would have applied.
"""
from litellm.proxy.auth.user_api_key_auth import _auto_register_jwt_mapping
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
virtual_key_mapping_cache_ttl=300,
)
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock()
mock_key_obj = UserAPIKeyAuth(
token="hashed", team_id="validated-team", user_id="validated-user"
)
with (
patch(
"litellm.proxy.auth.user_api_key_auth.get_key_object",
new_callable=AsyncMock,
) as mock_get_key,
patch(
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
new_callable=AsyncMock,
) as mock_gen_key,
):
mock_gen_key.return_value = {"token": "sk-newkey", "key": "sk-newkey"}
mock_get_key.return_value = mock_key_obj
result = await _auto_register_jwt_mapping(
virtual_key_claim_field="sub",
claim_value="new-user",
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=DualCache(),
parent_otel_span=None,
proxy_logging_obj=None,
cache_key="jwt_key_mapping:sub:new-user",
team_id="validated-team",
user_id="validated-user",
org_id="validated-org",
end_user_id="validated-end-user",
)
assert result == mock_key_obj
assert mock_gen_key.call_args.kwargs["team_id"] == "validated-team"
assert mock_gen_key.call_args.kwargs["user_id"] == "validated-user"
assert mock_gen_key.call_args.kwargs["organization_id"] == "validated-org"
assert result.org_id == "validated-org"
assert result.end_user_id == "validated-end-user"
# ──────────────────────────────────────────────
# Tests: backward-compat alias jwt_client_id_field
# ──────────────────────────────────────────────
def test_jwt_client_id_field_alias_maps_to_virtual_key_claim_field():
"""
jwt_client_id_field (old doc name) must silently alias to virtual_key_claim_field.
"""
auth = LiteLLM_JWTAuth(jwt_client_id_field="azp")
assert auth.virtual_key_claim_field == "azp"
def test_jwt_client_id_field_does_not_raise_on_duplicate():
"""
If both jwt_client_id_field and virtual_key_claim_field are supplied,
virtual_key_claim_field takes precedence and no error is raised.
"""
auth = LiteLLM_JWTAuth(
jwt_client_id_field="old_field",
virtual_key_claim_field="new_field",
)
assert auth.virtual_key_claim_field == "new_field"