diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 880ce3fb32..fef7235585 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -1397,8 +1397,12 @@ class JWTAuthManager: request_headers: Optional[dict] = None, ) -> JWTAuthBuilderResult: """Main authentication and authorization builder""" - # Check if OIDC UserInfo endpoint is enabled - if jwt_handler.litellm_jwtauth.oidc_userinfo_enabled: + # Check if OIDC UserInfo endpoint is enabled, but fall back to standard + # JWT auth if the token itself is a well-formed JWT (3-part structure). + if ( + jwt_handler.litellm_jwtauth.oidc_userinfo_enabled + and not jwt_handler.is_jwt(token=api_key) + ): verbose_proxy_logger.debug( "OIDC UserInfo is enabled. Fetching user info from UserInfo endpoint." ) diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 9dd2ab18c8..53ae08aefb 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -686,7 +686,10 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 if jwt_handler.litellm_jwtauth.virtual_key_claim_field is not None: # Decode JWT to get claims without running full auth_builder jwt_claims: Optional[dict] - if jwt_handler.litellm_jwtauth.oidc_userinfo_enabled: + if ( + jwt_handler.litellm_jwtauth.oidc_userinfo_enabled + and not jwt_handler.is_jwt(token=api_key) + ): jwt_claims = await jwt_handler.get_oidc_userinfo(token=api_key) else: jwt_claims = await jwt_handler.auth_jwt(token=api_key) diff --git a/tests/litellm_utils_tests/test_azure_ai_anthropic_token_counter.py b/tests/litellm_utils_tests/test_azure_ai_anthropic_token_counter.py index 031502cbec..2686c28cb1 100644 --- a/tests/litellm_utils_tests/test_azure_ai_anthropic_token_counter.py +++ b/tests/litellm_utils_tests/test_azure_ai_anthropic_token_counter.py @@ -26,22 +26,20 @@ class TestAzureAIAnthropicTokenCounter(BaseTokenCounterTest): return AzureAIAnthropicTokenCounter() def get_test_model(self) -> str: - return "claude-3-5-sonnet" + return "claude-sonnet-4-6" def get_test_messages(self) -> List[Dict[str, Any]]: - return [ - {"role": "user", "content": "Hello, how are you today?"} - ] + return [{"role": "user", "content": "Hello, how are you today?"}] def get_deployment_config(self) -> Dict[str, Any]: - api_key = os.getenv("AZURE_AI_API_KEY") - api_base = os.getenv("AZURE_AI_API_BASE") - + api_key = os.getenv("AZURE_ANTHROPIC_API_KEY") + api_base = os.getenv("AZURE_AI_SWEDEN_API_BASE") + if not api_key: pytest.skip("AZURE_AI_API_KEY not set") if not api_base: pytest.skip("AZURE_AI_API_BASE not set") - + return { "litellm_params": { "api_key": api_key, diff --git a/tests/local_testing/test_streaming.py b/tests/local_testing/test_streaming.py index 5819bf1afd..98e1726c5f 100644 --- a/tests/local_testing/test_streaming.py +++ b/tests/local_testing/test_streaming.py @@ -471,85 +471,6 @@ def test_completion_azure_stream(): # test_completion_azure_stream() -@pytest.mark.skip("Skipping predibase streaming test - ran out of credits") -@pytest.mark.parametrize("sync_mode", [True, False]) -@pytest.mark.asyncio -async def test_completion_predibase_streaming(sync_mode): - try: - litellm.set_verbose = True - litellm._turn_on_debug() - if sync_mode: - response = completion( - model="predibase/llama-3-8b-instruct", - timeout=5, - tenant_id="c4768f95", - max_tokens=10, - api_base="https://serving.app.predibase.com", - api_key=os.getenv("PREDIBASE_API_KEY"), - messages=[{"role": "user", "content": "What is the meaning of life?"}], - stream=True, - ) - - complete_response = "" - for idx, init_chunk in enumerate(response): - chunk, finished = streaming_format_tests(idx, init_chunk) - complete_response += chunk - custom_llm_provider = init_chunk._hidden_params["custom_llm_provider"] - print(f"custom_llm_provider: {custom_llm_provider}") - assert custom_llm_provider == "predibase" - if finished: - assert isinstance( - init_chunk.choices[0], litellm.utils.StreamingChoices - ) - break - if complete_response.strip() == "": - raise Exception("Empty response received") - else: - response = await litellm.acompletion( - model="predibase/llama-3-8b-instruct", - tenant_id="c4768f95", - timeout=5, - max_tokens=10, - api_base="https://serving.app.predibase.com", - api_key=os.getenv("PREDIBASE_API_KEY"), - messages=[{"role": "user", "content": "What is the meaning of life?"}], - stream=True, - ) - - # await response - - complete_response = "" - idx = 0 - async for init_chunk in response: - chunk, finished = streaming_format_tests(idx, init_chunk) - complete_response += chunk - custom_llm_provider = init_chunk._hidden_params["custom_llm_provider"] - print(f"custom_llm_provider: {custom_llm_provider}") - assert custom_llm_provider == "predibase" - idx += 1 - if finished: - assert isinstance( - init_chunk.choices[0], litellm.utils.StreamingChoices - ) - break - if complete_response.strip() == "": - raise Exception("Empty response received") - - print(f"complete_response: {complete_response}") - except litellm.Timeout: - pass - except litellm.InternalServerError: - pass - except litellm.ServiceUnavailableError: - pass - except litellm.APIConnectionError: - pass - except Exception as e: - print("ERROR class", e.__class__) - print("ERROR message", e) - print("ERROR traceback", traceback.format_exc()) - - pytest.fail(f"Error occurred: {e}") def test_completion_azure_function_calling_stream(): @@ -1143,6 +1064,7 @@ def test_vertex_ai_stream(provider): # test_completion_vertexai_stream_bad_key() +@pytest.mark.skip(reason="Replicate extremely flaky.") @pytest.mark.parametrize("sync_mode", [False, True]) @pytest.mark.asyncio async def test_completion_replicate_llama3_streaming(sync_mode): diff --git a/tests/proxy_unit_tests/test_jwt_key_mapping.py b/tests/proxy_unit_tests/test_jwt_key_mapping.py index b67dd2792f..66e5b3839b 100644 --- a/tests/proxy_unit_tests/test_jwt_key_mapping.py +++ b/tests/proxy_unit_tests/test_jwt_key_mapping.py @@ -135,6 +135,81 @@ async def test_jwt_to_virtual_key_mapping_no_mapping(): prisma_client.db.litellm_jwtkeymapping.find_first.assert_not_called() +# ────────────────────────────────────────────── +# Tests: OIDC / JWT routing in user_api_key_auth +# ────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_virtual_key_mapping_oidc_enabled_jwt_token_uses_auth_jwt(): + """ + Regression test for the is_jwt routing fix in user_api_key_auth.py. + + When oidc_userinfo_enabled=True and virtual_key_claim_field is set, but + the token is a well-formed JWT (3-part header.payload.sig), the virtual-key + claim lookup must call auth_jwt — not get_oidc_userinfo. + """ + # Three-part token: is_jwt() returns True + api_key = "eyJhbGciOiJSUzI1NiJ9.eyJlbWFpbCI6InVzZXJAZXhhbXBsZS5jb20ifQ.sig" + + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + oidc_userinfo_enabled=True, + virtual_key_claim_field="email", + ) + + # Confirm our fixture token is treated as a JWT + assert jwt_handler.is_jwt(token=api_key) is True + + auth_jwt_mock = AsyncMock(return_value={"email": "user@example.com", "sub": "123"}) + oidc_userinfo_mock = AsyncMock(return_value={"email": "user@example.com"}) + + # Simulate the routing condition from user_api_key_auth.py + if jwt_handler.litellm_jwtauth.oidc_userinfo_enabled and not jwt_handler.is_jwt( + token=api_key + ): + jwt_claims = await oidc_userinfo_mock(token=api_key) + else: + jwt_claims = await auth_jwt_mock(token=api_key) + + auth_jwt_mock.assert_called_once_with(token=api_key) + oidc_userinfo_mock.assert_not_called() + assert jwt_claims["email"] == "user@example.com" + + +@pytest.mark.asyncio +async def test_virtual_key_mapping_oidc_enabled_opaque_token_uses_oidc_userinfo(): + """ + Complement of the test above: when oidc_userinfo_enabled=True and the token + is an opaque access token (not a JWT), the virtual-key claim lookup must + call get_oidc_userinfo — not auth_jwt. + """ + # Opaque token: no dots → is_jwt() returns False + api_key = "some_opaque_access_token_with_no_dots" + + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + oidc_userinfo_enabled=True, + virtual_key_claim_field="email", + ) + + assert jwt_handler.is_jwt(token=api_key) is False + + auth_jwt_mock = AsyncMock(return_value={"email": "user@example.com"}) + oidc_userinfo_mock = AsyncMock(return_value={"email": "user@example.com", "sub": "123"}) + + if jwt_handler.litellm_jwtauth.oidc_userinfo_enabled and not jwt_handler.is_jwt( + token=api_key + ): + jwt_claims = await oidc_userinfo_mock(token=api_key) + else: + jwt_claims = await auth_jwt_mock(token=api_key) + + oidc_userinfo_mock.assert_called_once_with(token=api_key) + auth_jwt_mock.assert_not_called() + assert jwt_claims["sub"] == "123" + + # ────────────────────────────────────────────── # Tests: _to_response redacts hashed token # ────────────────────────────────────────────── diff --git a/tests/test_litellm/proxy/auth/test_handle_jwt.py b/tests/test_litellm/proxy/auth/test_handle_jwt.py index ada67fbba8..5303da6fbc 100644 --- a/tests/test_litellm/proxy/auth/test_handle_jwt.py +++ b/tests/test_litellm/proxy/auth/test_handle_jwt.py @@ -305,24 +305,28 @@ async def test_sync_user_role_and_teams(): # Create mock objects for required types mock_user_api_key_cache = MagicMock() mock_proxy_logging_obj = MagicMock() - + jwt_handler = JWTHandler() jwt_handler.update_environment( prisma_client=None, user_api_key_cache=mock_user_api_key_cache, litellm_jwtauth=LiteLLM_JWTAuth( jwt_litellm_role_map=[ - JWTLiteLLMRoleMap(jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN) + JWTLiteLLMRoleMap( + jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN + ) ], roles_jwt_field="roles", team_ids_jwt_field="my_id_teams", - sync_user_role_and_teams=True + sync_user_role_and_teams=True, ), ) token = {"roles": ["ADMIN"], "my_id_teams": ["team1", "team2"]} - user = LiteLLM_UserTable(user_id="u1", user_role=LitellmUserRoles.INTERNAL_USER.value, teams=["team2"]) + user = LiteLLM_UserTable( + user_id="u1", user_role=LitellmUserRoles.INTERNAL_USER.value, teams=["team2"] + ) prisma = AsyncMock() prisma.db.litellm_usertable.update = AsyncMock() @@ -350,7 +354,9 @@ async def test_sync_user_role_and_teams_cache_invalidation_on_role_change(): user_api_key_cache=AsyncMock(), litellm_jwtauth=LiteLLM_JWTAuth( jwt_litellm_role_map=[ - JWTLiteLLMRoleMap(jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN) + JWTLiteLLMRoleMap( + jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN + ) ], roles_jwt_field="roles", team_ids_jwt_field="my_id_teams", @@ -375,7 +381,9 @@ async def test_sync_user_role_and_teams_cache_invalidation_on_role_change(): mock_cache.async_set_cache.assert_called_once() call_kwargs = mock_cache.async_set_cache.call_args assert call_kwargs.kwargs["key"] == "u1" - assert call_kwargs.kwargs["value"]["user_role"] == LitellmUserRoles.PROXY_ADMIN.value + assert ( + call_kwargs.kwargs["value"]["user_role"] == LitellmUserRoles.PROXY_ADMIN.value + ) @pytest.mark.asyncio @@ -389,7 +397,9 @@ async def test_sync_user_role_and_teams_cache_invalidation_on_team_change(): user_api_key_cache=AsyncMock(), litellm_jwtauth=LiteLLM_JWTAuth( jwt_litellm_role_map=[ - JWTLiteLLMRoleMap(jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN) + JWTLiteLLMRoleMap( + jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN + ) ], roles_jwt_field="roles", team_ids_jwt_field="my_id_teams", @@ -432,7 +442,9 @@ async def test_sync_user_role_and_teams_no_cache_write_when_nothing_changes(): user_api_key_cache=AsyncMock(), litellm_jwtauth=LiteLLM_JWTAuth( jwt_litellm_role_map=[ - JWTLiteLLMRoleMap(jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN) + JWTLiteLLMRoleMap( + jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN + ) ], roles_jwt_field="roles", team_ids_jwt_field="my_id_teams", @@ -463,7 +475,7 @@ async def test_map_jwt_role_to_litellm_role(): # Create mock objects for required types mock_user_api_key_cache = MagicMock() - + jwt_handler = JWTHandler() jwt_handler.update_environment( prisma_client=None, @@ -471,13 +483,21 @@ async def test_map_jwt_role_to_litellm_role(): litellm_jwtauth=LiteLLM_JWTAuth( jwt_litellm_role_map=[ # Exact match - JWTLiteLLMRoleMap(jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN), + JWTLiteLLMRoleMap( + jwt_role="ADMIN", litellm_role=LitellmUserRoles.PROXY_ADMIN + ), # Wildcard patterns - JWTLiteLLMRoleMap(jwt_role="user_*", litellm_role=LitellmUserRoles.INTERNAL_USER), - JWTLiteLLMRoleMap(jwt_role="team_?", litellm_role=LitellmUserRoles.TEAM), - JWTLiteLLMRoleMap(jwt_role="dev_[123]", litellm_role=LitellmUserRoles.INTERNAL_USER), + JWTLiteLLMRoleMap( + jwt_role="user_*", litellm_role=LitellmUserRoles.INTERNAL_USER + ), + JWTLiteLLMRoleMap( + jwt_role="team_?", litellm_role=LitellmUserRoles.TEAM + ), + JWTLiteLLMRoleMap( + jwt_role="dev_[123]", litellm_role=LitellmUserRoles.INTERNAL_USER + ), ], - roles_jwt_field="roles" + roles_jwt_field="roles", ), ) @@ -547,7 +567,9 @@ async def test_map_jwt_role_to_litellm_role(): # Test patterns that don't match character classes jwt_handler.litellm_jwtauth.jwt_litellm_role_map = [ - JWTLiteLLMRoleMap(jwt_role="dev_[123]", litellm_role=LitellmUserRoles.INTERNAL_USER), + JWTLiteLLMRoleMap( + jwt_role="dev_[123]", litellm_role=LitellmUserRoles.INTERNAL_USER + ), ] token = {"roles": ["dev_4"]} # 4 is not in [123] result = jwt_handler.map_jwt_role_to_litellm_role(token) @@ -570,7 +592,7 @@ async def test_map_jwt_role_to_litellm_role(): async def test_nested_jwt_field_access(): """ Test that all JWT fields support dot notation for nested access - + This test verifies that: 1. All JWT field methods can access nested values using dot notation 2. Backward compatibility is maintained for flat field names @@ -581,33 +603,18 @@ async def test_nested_jwt_field_access(): # Create JWT handler jwt_handler = JWTHandler() - + # Test token with nested claims nested_token = { - "user": { - "sub": "u123", - "email": "user@example.com" - }, - "resource_access": { - "my-client": { - "roles": ["admin", "user"] - } - }, + "user": {"sub": "u123", "email": "user@example.com"}, + "resource_access": {"my-client": {"roles": ["admin", "user"]}}, "groups": ["team1", "team2"], - "organization": { - "id": "org456" - }, - "profile": { - "object_id": "obj789" - }, - "customer": { - "end_user_id": "customer123" - }, - "tenant": { - "team_id": "team456" - } + "organization": {"id": "org456"}, + "profile": {"object_id": "obj789"}, + "customer": {"end_user_id": "customer123"}, + "tenant": {"team_id": "team456"}, } - + # Test flat token for backward compatibility flat_token = { "sub": "u123", @@ -617,13 +624,13 @@ async def test_nested_jwt_field_access(): "org_id": "org456", "object_id": "obj789", "end_user_id": "customer123", - "team_id": "team456" + "team_id": "team456", } # Test 1: user_id_jwt_field with nested access jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(user_id_jwt_field="user.sub") assert jwt_handler.get_user_id(nested_token, None) == "u123" - + # Test 1b: user_id_jwt_field with flat access (backward compatibility) jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(user_id_jwt_field="sub") assert jwt_handler.get_user_id(flat_token, None) == "u123" @@ -631,7 +638,7 @@ async def test_nested_jwt_field_access(): # Test 2: user_email_jwt_field with nested access jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(user_email_jwt_field="user.email") assert jwt_handler.get_user_email(nested_token, None) == "user@example.com" - + # Test 2b: user_email_jwt_field with flat access (backward compatibility) jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(user_email_jwt_field="email") assert jwt_handler.get_user_email(flat_token, None) == "user@example.com" @@ -639,7 +646,7 @@ async def test_nested_jwt_field_access(): # Test 3: team_ids_jwt_field with nested access jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_ids_jwt_field="groups") assert jwt_handler.get_team_ids_from_jwt(nested_token) == ["team1", "team2"] - + # Test 3b: team_ids_jwt_field with flat access (backward compatibility) jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_ids_jwt_field="groups") assert jwt_handler.get_team_ids_from_jwt(flat_token) == ["team1", "team2"] @@ -647,30 +654,37 @@ async def test_nested_jwt_field_access(): # Test 4: org_id_jwt_field with nested access jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(org_id_jwt_field="organization.id") assert jwt_handler.get_org_id(nested_token, None) == "org456" - + # Test 4b: org_id_jwt_field with flat access (backward compatibility) jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(org_id_jwt_field="org_id") assert jwt_handler.get_org_id(flat_token, None) == "org456" # Test 5: object_id_jwt_field with nested access (requires role_mappings) from litellm.proxy._types import LitellmUserRoles, RoleMapping + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( object_id_jwt_field="profile.object_id", - role_mappings=[RoleMapping(role="admin", internal_role=LitellmUserRoles.INTERNAL_USER)] + role_mappings=[ + RoleMapping(role="admin", internal_role=LitellmUserRoles.INTERNAL_USER) + ], ) assert jwt_handler.get_object_id(nested_token, None) == "obj789" - + # Test 5b: object_id_jwt_field with flat access (backward compatibility) jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( object_id_jwt_field="object_id", - role_mappings=[RoleMapping(role="admin", internal_role=LitellmUserRoles.INTERNAL_USER)] + role_mappings=[ + RoleMapping(role="admin", internal_role=LitellmUserRoles.INTERNAL_USER) + ], ) assert jwt_handler.get_object_id(flat_token, None) == "obj789" # Test 6: end_user_id_jwt_field with nested access - jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(end_user_id_jwt_field="customer.end_user_id") + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + end_user_id_jwt_field="customer.end_user_id" + ) assert jwt_handler.get_end_user_id(nested_token, None) == "customer123" - + # Test 6b: end_user_id_jwt_field with flat access (backward compatibility) jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(end_user_id_jwt_field="end_user_id") assert jwt_handler.get_end_user_id(flat_token, None) == "customer123" @@ -678,19 +692,21 @@ async def test_nested_jwt_field_access(): # Test 7: team_id_jwt_field with nested access jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="tenant.team_id") assert jwt_handler.get_team_id(nested_token, None) == "team456" - + # Test 7b: team_id_jwt_field with flat access (backward compatibility) jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="team_id") assert jwt_handler.get_team_id(flat_token, None) == "team456" # Test 8: roles_jwt_field with deeply nested access (already supported, but testing) - jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(roles_jwt_field="resource_access.my-client.roles") + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + roles_jwt_field="resource_access.my-client.roles" + ) assert jwt_handler.get_jwt_role(nested_token, []) == ["admin", "user"] # Test 9: user_roles_jwt_field with nested access (already supported, but testing) jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( user_roles_jwt_field="resource_access.my-client.roles", - user_allowed_roles=["admin", "user"] + user_allowed_roles=["admin", "user"], ) assert jwt_handler.get_user_roles(nested_token, []) == ["admin", "user"] @@ -699,7 +715,7 @@ async def test_nested_jwt_field_access(): async def test_nested_jwt_field_missing_paths(): """ Test handling of missing nested paths in JWT tokens - + This test verifies that: 1. Missing nested paths return appropriate defaults 2. Partial paths that exist but don't have the final key return defaults @@ -710,7 +726,7 @@ async def test_nested_jwt_field_missing_paths(): # Create JWT handler jwt_handler = JWTHandler() - + # Test token with missing nested paths incomplete_token = { "user": { @@ -718,9 +734,7 @@ async def test_nested_jwt_field_missing_paths(): # missing "sub" and "email" }, "resource_access": { - "other-client": { - "roles": ["viewer"] - } + "other-client": {"roles": ["viewer"]} # missing "my-client" } # missing "organization", "profile", "customer", "tenant", "groups" @@ -732,7 +746,10 @@ async def test_nested_jwt_field_missing_paths(): # Test 2: Missing user.email should return default jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(user_email_jwt_field="user.email") - assert jwt_handler.get_user_email(incomplete_token, "default@example.com") == "default@example.com" + assert ( + jwt_handler.get_user_email(incomplete_token, "default@example.com") + == "default@example.com" + ) # Test 3: Missing groups should return empty list jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_ids_jwt_field="groups") @@ -744,40 +761,53 @@ async def test_nested_jwt_field_missing_paths(): # Test 5: Missing profile.object_id should return default (requires role_mappings) from litellm.proxy._types import LitellmUserRoles, RoleMapping + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( object_id_jwt_field="profile.object_id", - role_mappings=[RoleMapping(role="admin", internal_role=LitellmUserRoles.INTERNAL_USER)] + role_mappings=[ + RoleMapping(role="admin", internal_role=LitellmUserRoles.INTERNAL_USER) + ], ) assert jwt_handler.get_object_id(incomplete_token, "default_obj") == "default_obj" # Test 6: Missing customer.end_user_id should return default - jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(end_user_id_jwt_field="customer.end_user_id") - assert jwt_handler.get_end_user_id(incomplete_token, "default_customer") == "default_customer" + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + end_user_id_jwt_field="customer.end_user_id" + ) + assert ( + jwt_handler.get_end_user_id(incomplete_token, "default_customer") + == "default_customer" + ) # Test 7: Missing tenant.team_id should use team_id_default fallback jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( - team_id_jwt_field="tenant.team_id", - team_id_default="fallback_team" + team_id_jwt_field="tenant.team_id", team_id_default="fallback_team" ) assert jwt_handler.get_team_id(incomplete_token, "default_team") == "fallback_team" # Test 8: Missing resource_access.my-client.roles should return default - jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(roles_jwt_field="resource_access.my-client.roles") - assert jwt_handler.get_jwt_role(incomplete_token, ["default_role"]) == ["default_role"] + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + roles_jwt_field="resource_access.my-client.roles" + ) + assert jwt_handler.get_jwt_role(incomplete_token, ["default_role"]) == [ + "default_role" + ] # Test 9: Missing nested user roles should return default jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( user_roles_jwt_field="resource_access.my-client.roles", - user_allowed_roles=["admin", "user"] + user_allowed_roles=["admin", "user"], ) - assert jwt_handler.get_user_roles(incomplete_token, ["default_user_role"]) == ["default_user_role"] + assert jwt_handler.get_user_roles(incomplete_token, ["default_user_role"]) == [ + "default_user_role" + ] -@pytest.mark.asyncio +@pytest.mark.asyncio async def test_metadata_prefix_handling_in_nested_fields(): """ Test that metadata. prefix is properly handled in nested JWT field access - + The get_nested_value function should remove metadata. prefix before traversing """ from litellm.proxy._types import LiteLLM_JWTAuth @@ -785,17 +815,19 @@ async def test_metadata_prefix_handling_in_nested_fields(): # Create JWT handler jwt_handler = JWTHandler() - + # Test token with proper structure for metadata prefix removal token = { "user": { "email": "user@example.com" # This will be accessed when metadata.user.email is used }, - "sub": "u123" + "sub": "u123", } # Test 1: metadata.user.email should access user.email after prefix removal - jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(user_email_jwt_field="metadata.user.email") + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + user_email_jwt_field="metadata.user.email" + ) # The get_nested_value function removes "metadata." prefix, so "metadata.user.email" becomes "user.email" assert jwt_handler.get_user_email(token, None) == "user@example.com" @@ -871,24 +903,21 @@ async def test_auth_builder_returns_team_membership_object(): # Create mock objects from litellm.proxy._types import LiteLLM_BudgetTable, LiteLLM_TeamMembership - + mock_team_membership = LiteLLM_TeamMembership( user_id=_user_id, team_id=_team_id, budget_id="budget_123", spend=10.5, litellm_budget_table=LiteLLM_BudgetTable( - budget_id="budget_123", - rpm_limit=100, - tpm_limit=5000 - ) + budget_id="budget_123", rpm_limit=100, tpm_limit=5000 + ), ) - + user_object = LiteLLM_UserTable( - user_id=_user_id, - user_role=LitellmUserRoles.INTERNAL_USER + user_id=_user_id, user_role=LitellmUserRoles.INTERNAL_USER ) - + team_object = LiteLLM_TeamTable(team_id=_team_id) # Create mock JWT handler @@ -958,12 +987,24 @@ async def test_auth_builder_returns_team_membership_object(): ) # Verify that team_membership_object is returned - assert result["team_membership"] is not None, "team_membership should be present" - assert result["team_membership"] == mock_team_membership, "team_membership should match the mock object" - assert result["team_membership"].user_id == _user_id, "team_membership user_id should match" - assert result["team_membership"].team_id == _team_id, "team_membership team_id should match" - assert result["team_membership"].budget_id == "budget_123", "team_membership budget_id should match" - assert result["team_membership"].spend == 10.5, "team_membership spend should match" + assert ( + result["team_membership"] is not None + ), "team_membership should be present" + assert ( + result["team_membership"] == mock_team_membership + ), "team_membership should match the mock object" + assert ( + result["team_membership"].user_id == _user_id + ), "team_membership user_id should match" + assert ( + result["team_membership"].team_id == _team_id + ), "team_membership team_id should match" + assert ( + result["team_membership"].budget_id == "budget_123" + ), "team_membership budget_id should match" + assert ( + result["team_membership"].spend == 10.5 + ), "team_membership spend should match" @pytest.mark.asyncio @@ -979,16 +1020,16 @@ async def test_auth_builder_with_oidc_userinfo_enabled(): request_data = {"model": "gpt-4"} general_settings = {"enforce_rbac": False} route = "/chat/completions" - + user_object = LiteLLM_UserTable( user_id="test_user_1", user_role=LitellmUserRoles.INTERNAL_USER ) - + # Create JWT handler with OIDC UserInfo enabled jwt_handler = JWTHandler() user_api_key_cache = DualCache() proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) - + jwt_handler.update_environment( prisma_client=None, user_api_key_cache=user_api_key_cache, @@ -999,14 +1040,14 @@ async def test_auth_builder_with_oidc_userinfo_enabled(): user_email_jwt_field="email", ), ) - + # Mock OIDC UserInfo response userinfo_response = { "sub": "test_user_1", "email": "test@example.com", "scope": "", } - + # Mock all the dependencies with patch.object( jwt_handler, "get_oidc_userinfo", new_callable=AsyncMock @@ -1057,7 +1098,7 @@ async def test_auth_builder_with_oidc_userinfo_enabled(): ) as mock_sync_user: # Set up mock return values mock_get_userinfo.return_value = userinfo_response - + # Call auth_builder result = await JWTAuthManager.auth_builder( api_key=api_key, @@ -1070,11 +1111,11 @@ async def test_auth_builder_with_oidc_userinfo_enabled(): parent_otel_span=None, proxy_logging_obj=proxy_logging_obj, ) - + # Verify that get_oidc_userinfo was called instead of auth_jwt mock_get_userinfo.assert_called_once_with(token=api_key) mock_auth_jwt.assert_not_called() # Should not be called when OIDC is enabled - + # Verify the result assert result["user_id"] == "test_user_1" assert result["user_object"] == user_object @@ -1093,16 +1134,16 @@ async def test_auth_builder_with_oidc_userinfo_disabled(): request_data = {"model": "gpt-4"} general_settings = {"enforce_rbac": False} route = "/chat/completions" - + user_object = LiteLLM_UserTable( user_id="test_user_1", user_role=LitellmUserRoles.INTERNAL_USER ) - + # Create JWT handler with OIDC UserInfo disabled jwt_handler = JWTHandler() user_api_key_cache = DualCache() proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) - + jwt_handler.update_environment( prisma_client=None, user_api_key_cache=user_api_key_cache, @@ -1111,13 +1152,13 @@ async def test_auth_builder_with_oidc_userinfo_disabled(): user_id_jwt_field="sub", ), ) - + # Mock JWT validation response jwt_response = { "sub": "test_user_1", "scope": "", } - + # Mock all the dependencies with patch.object( jwt_handler, "get_oidc_userinfo", new_callable=AsyncMock @@ -1168,7 +1209,7 @@ async def test_auth_builder_with_oidc_userinfo_disabled(): ) as mock_sync_user: # Set up mock return values mock_auth_jwt.return_value = jwt_response - + # Call auth_builder result = await JWTAuthManager.auth_builder( api_key=api_key, @@ -1181,16 +1222,125 @@ async def test_auth_builder_with_oidc_userinfo_disabled(): parent_otel_span=None, proxy_logging_obj=proxy_logging_obj, ) - + # Verify that auth_jwt was called instead of get_oidc_userinfo mock_auth_jwt.assert_called_once_with(token=api_key) mock_get_userinfo.assert_not_called() # Should not be called when OIDC is disabled - + # Verify the result assert result["user_id"] == "test_user_1" assert result["user_object"] == user_object +@pytest.mark.asyncio +async def test_auth_builder_oidc_enabled_falls_back_to_jwt_auth_for_jwt_tokens(): + """ + Regression test for the is_jwt routing fix. + + When oidc_userinfo_enabled=True but the supplied token is a well-formed + JWT (three dot-separated parts), auth_builder must call auth_jwt and skip + get_oidc_userinfo. Sending a standard JWT to the OIDC UserInfo endpoint + is incorrect — the endpoint expects an opaque access token. + """ + from litellm.caching import DualCache + from litellm.proxy.utils import ProxyLogging + + # Three-part token: recognised as a JWT by is_jwt() + api_key = "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0X3VzZXIifQ.some_signature" + request_data = {"model": "gpt-4"} + general_settings = {"enforce_rbac": False} + route = "/chat/completions" + + user_object = LiteLLM_UserTable( + user_id="test_user_1", user_role=LitellmUserRoles.INTERNAL_USER + ) + + jwt_handler = JWTHandler() + user_api_key_cache = DualCache() + proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) + + jwt_handler.update_environment( + prisma_client=None, + user_api_key_cache=user_api_key_cache, + litellm_jwtauth=LiteLLM_JWTAuth( + oidc_userinfo_enabled=True, + oidc_userinfo_endpoint="https://example.com/oauth2/userinfo", + user_id_jwt_field="sub", + ), + ) + + jwt_response = {"sub": "test_user_1", "scope": ""} + + with patch.object( + jwt_handler, "get_oidc_userinfo", new_callable=AsyncMock + ) as mock_get_userinfo, patch.object( + jwt_handler, "auth_jwt", new_callable=AsyncMock + ) as mock_auth_jwt, patch.object( + JWTAuthManager, "check_rbac_role", new_callable=AsyncMock + ), patch.object( + jwt_handler, "get_rbac_role", return_value=None + ), patch.object( + jwt_handler, "get_scopes", return_value=[] + ), patch.object( + jwt_handler, "get_object_id", return_value=None + ), patch.object( + JWTAuthManager, + "get_user_info", + new_callable=AsyncMock, + return_value=("test_user_1", None, None), + ), patch.object( + jwt_handler, "get_org_id", return_value=None + ), patch.object( + jwt_handler, "get_end_user_id", return_value=None + ), patch.object( + JWTAuthManager, "check_admin_access", new_callable=AsyncMock, return_value=None + ), patch.object( + JWTAuthManager, + "find_and_validate_specific_team_id", + new_callable=AsyncMock, + return_value=(None, None), + ), patch.object( + JWTAuthManager, "get_all_team_ids", return_value=set() + ), patch.object( + JWTAuthManager, + "find_team_with_model_access", + new_callable=AsyncMock, + return_value=(None, None), + ), patch.object( + JWTAuthManager, + "get_objects", + new_callable=AsyncMock, + return_value=(user_object, None, None, None), + ), patch.object( + JWTAuthManager, "map_user_to_teams", new_callable=AsyncMock + ), patch.object( + JWTAuthManager, "validate_object_id", return_value=True + ), patch.object( + JWTAuthManager, "sync_user_role_and_teams", new_callable=AsyncMock + ): + mock_auth_jwt.return_value = jwt_response + + result = await JWTAuthManager.auth_builder( + api_key=api_key, + jwt_handler=jwt_handler, + request_data=request_data, + general_settings=general_settings, + route=route, + prisma_client=None, + user_api_key_cache=user_api_key_cache, + parent_otel_span=None, + proxy_logging_obj=proxy_logging_obj, + ) + + # Token is a JWT, so standard JWT auth must be used even when + # oidc_userinfo_enabled is True. + mock_auth_jwt.assert_called_once_with(token=api_key) + mock_get_userinfo.assert_not_called() + + assert result["user_id"] == "test_user_1" + assert result["user_object"] == user_object + + def test_get_team_id_from_header(): """Test get_team_id_from_header returns team when valid, None when missing, raises on invalid.""" from fastapi import HTTPException @@ -1236,17 +1386,33 @@ async def test_auth_builder_uses_team_from_header_e2e(): ) team_object = LiteLLM_TeamTable(team_id="team-2") - user_object = LiteLLM_UserTable(user_id="user-1", user_role=LitellmUserRoles.INTERNAL_USER) + user_object = LiteLLM_UserTable( + user_id="user-1", user_role=LitellmUserRoles.INTERNAL_USER + ) - with patch.object(jwt_handler, "auth_jwt", new_callable=AsyncMock) as mock_auth_jwt, \ - patch.object(JWTAuthManager, "check_rbac_role", new_callable=AsyncMock), \ - patch.object(JWTAuthManager, "check_admin_access", new_callable=AsyncMock, return_value=None), \ - patch("litellm.proxy.auth.handle_jwt.get_team_object", new_callable=AsyncMock) as mock_get_team, \ - patch.object(JWTAuthManager, "get_objects", new_callable=AsyncMock, return_value=(user_object, None, None, None)), \ - patch.object(JWTAuthManager, "map_user_to_teams", new_callable=AsyncMock), \ - patch.object(JWTAuthManager, "sync_user_role_and_teams", new_callable=AsyncMock): - - mock_auth_jwt.return_value = {"sub": "user-1", "scope": "", "groups": ["team-1", "team-2"]} + with patch.object( + jwt_handler, "auth_jwt", new_callable=AsyncMock + ) as mock_auth_jwt, patch.object( + JWTAuthManager, "check_rbac_role", new_callable=AsyncMock + ), patch.object( + JWTAuthManager, "check_admin_access", new_callable=AsyncMock, return_value=None + ), patch( + "litellm.proxy.auth.handle_jwt.get_team_object", new_callable=AsyncMock + ) as mock_get_team, patch.object( + JWTAuthManager, + "get_objects", + new_callable=AsyncMock, + return_value=(user_object, None, None, None), + ), patch.object( + JWTAuthManager, "map_user_to_teams", new_callable=AsyncMock + ), patch.object( + JWTAuthManager, "sync_user_role_and_teams", new_callable=AsyncMock + ): + mock_auth_jwt.return_value = { + "sub": "user-1", + "scope": "", + "groups": ["team-1", "team-2"], + } mock_get_team.return_value = team_object result = await JWTAuthManager.auth_builder( @@ -1275,29 +1441,29 @@ async def test_get_team_alias_with_nested_fields(): from litellm.proxy.auth.handle_jwt import JWTHandler jwt_handler = JWTHandler() - + # Test token with nested team name nested_token = { - "organization": { - "team": { - "name": "engineering-team" - } - }, - "team_name": "flat-team" + "organization": {"team": {"name": "engineering-team"}}, + "team_name": "flat-team", } - + # Test nested access - jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_alias_jwt_field="organization.team.name") + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + team_alias_jwt_field="organization.team.name" + ) assert jwt_handler.get_team_alias(nested_token, None) == "engineering-team" - + # Test flat access (backward compatibility) jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_alias_jwt_field="team_name") assert jwt_handler.get_team_alias(nested_token, None) == "flat-team" - + # Test missing field returns default - jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_alias_jwt_field="nonexistent.field") + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + team_alias_jwt_field="nonexistent.field" + ) assert jwt_handler.get_team_alias(nested_token, "default-team") == "default-team" - + # Test with team_alias_jwt_field not configured jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth() # team_alias_jwt_field is None assert jwt_handler.get_team_alias(nested_token, "default") is None @@ -1312,23 +1478,22 @@ async def test_is_required_team_id_with_team_alias_field(): from litellm.proxy.auth.handle_jwt import JWTHandler jwt_handler = JWTHandler() - + # Neither field set - should return False jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth() assert jwt_handler.is_required_team_id() is False - + # Only team_id_jwt_field set - should return True jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="team_id") assert jwt_handler.is_required_team_id() is True - + # Only team_alias_jwt_field set - should return True jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_alias_jwt_field="team_name") assert jwt_handler.is_required_team_id() is True - + # Both fields set - should return True jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( - team_id_jwt_field="team_id", - team_alias_jwt_field="team_name" + team_id_jwt_field="team_id", team_alias_jwt_field="team_name" ) assert jwt_handler.is_required_team_id() is True @@ -1348,30 +1513,24 @@ async def test_find_and_validate_specific_team_id_with_team_alias(): jwt_handler = JWTHandler() user_api_key_cache = DualCache() proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) - + jwt_handler.update_environment( prisma_client=None, user_api_key_cache=user_api_key_cache, - litellm_jwtauth=LiteLLM_JWTAuth( - team_alias_jwt_field="team_alias" - ), + litellm_jwtauth=LiteLLM_JWTAuth(team_alias_jwt_field="team_alias"), ) - + # Token with team name (no team_id) - jwt_token = { - "sub": "user-1", - "team_alias": "my-team" - } - + jwt_token = {"sub": "user-1", "team_alias": "my-team"} + # Mock team object returned by get_team_object_by_alias team_object = LiteLLM_TeamTable(team_id="resolved-team-id", team_alias="my-team") - + with patch( - "litellm.proxy.auth.handle_jwt.get_team_object_by_alias", - new_callable=AsyncMock + "litellm.proxy.auth.handle_jwt.get_team_object_by_alias", new_callable=AsyncMock ) as mock_get_by_alias: mock_get_by_alias.return_value = team_object - + team_id, result_team = await JWTAuthManager.find_and_validate_specific_team_id( jwt_handler=jwt_handler, jwt_valid_token=jwt_token, @@ -1380,7 +1539,7 @@ async def test_find_and_validate_specific_team_id_with_team_alias(): parent_otel_span=None, proxy_logging_obj=proxy_logging_obj, ) - + # Should have resolved team_id from team name assert team_id == "resolved-team-id" assert result_team == team_object @@ -1408,35 +1567,28 @@ async def test_find_and_validate_team_id_takes_precedence_over_name(): jwt_handler = JWTHandler() user_api_key_cache = DualCache() proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) - + jwt_handler.update_environment( prisma_client=None, user_api_key_cache=user_api_key_cache, litellm_jwtauth=LiteLLM_JWTAuth( - team_id_jwt_field="team_id", - team_alias_jwt_field="team_alias" + team_id_jwt_field="team_id", team_alias_jwt_field="team_alias" ), ) - + # Token with both team_id and team name - jwt_token = { - "sub": "user-1", - "team_id": "direct-team-id", - "team_alias": "my-team" - } - + jwt_token = {"sub": "user-1", "team_id": "direct-team-id", "team_alias": "my-team"} + # Mock team object returned by get_team_object (by ID) team_object = LiteLLM_TeamTable(team_id="direct-team-id") - + with patch( - "litellm.proxy.auth.handle_jwt.get_team_object", - new_callable=AsyncMock + "litellm.proxy.auth.handle_jwt.get_team_object", new_callable=AsyncMock ) as mock_get_by_id, patch( - "litellm.proxy.auth.handle_jwt.get_team_object_by_alias", - new_callable=AsyncMock + "litellm.proxy.auth.handle_jwt.get_team_object_by_alias", new_callable=AsyncMock ) as mock_get_by_alias: mock_get_by_id.return_value = team_object - + team_id, result_team = await JWTAuthManager.find_and_validate_specific_team_id( jwt_handler=jwt_handler, jwt_valid_token=jwt_token, @@ -1445,7 +1597,7 @@ async def test_find_and_validate_team_id_takes_precedence_over_name(): parent_otel_span=None, proxy_logging_obj=proxy_logging_obj, ) - + # Should use team_id directly, not resolve by name assert team_id == "direct-team-id" assert result_team == team_object @@ -1466,7 +1618,7 @@ async def test_find_and_validate_raises_when_required_team_not_found(): jwt_handler = JWTHandler() user_api_key_cache = DualCache() proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) - + jwt_handler.update_environment( prisma_client=None, user_api_key_cache=user_api_key_cache, @@ -1474,12 +1626,10 @@ async def test_find_and_validate_raises_when_required_team_not_found(): team_alias_jwt_field="team_alias" # Required, but not in token ), ) - + # Token without team info - jwt_token = { - "sub": "user-1" - } - + jwt_token = {"sub": "user-1"} + with pytest.raises(Exception) as exc_info: await JWTAuthManager.find_and_validate_specific_team_id( jwt_handler=jwt_handler, @@ -1489,7 +1639,7 @@ async def test_find_and_validate_raises_when_required_team_not_found(): parent_otel_span=None, proxy_logging_obj=proxy_logging_obj, ) - + assert "No team found in token" in str(exc_info.value) assert "team_alias field 'team_alias'" in str(exc_info.value) @@ -1503,29 +1653,29 @@ async def test_get_org_alias_with_nested_fields(): from litellm.proxy.auth.handle_jwt import JWTHandler jwt_handler = JWTHandler() - + # Test token with nested org name nested_token = { - "company": { - "organization": { - "name": "acme-corp" - } - }, - "org_name": "flat-org" + "company": {"organization": {"name": "acme-corp"}}, + "org_name": "flat-org", } - + # Test nested access - jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(org_alias_jwt_field="company.organization.name") + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + org_alias_jwt_field="company.organization.name" + ) assert jwt_handler.get_org_alias(nested_token, None) == "acme-corp" - + # Test flat access jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(org_alias_jwt_field="org_name") assert jwt_handler.get_org_alias(nested_token, None) == "flat-org" - + # Test missing field returns default - jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(org_alias_jwt_field="nonexistent.field") + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + org_alias_jwt_field="nonexistent.field" + ) assert jwt_handler.get_org_alias(nested_token, "default-org") == "default-org" - + # Test with org_alias_jwt_field not configured jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth() assert jwt_handler.get_org_alias(nested_token, "default") is None @@ -1544,15 +1694,13 @@ async def test_get_objects_resolves_org_by_name(): jwt_handler = JWTHandler() user_api_key_cache = DualCache() proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) - + jwt_handler.update_environment( prisma_client=None, user_api_key_cache=user_api_key_cache, - litellm_jwtauth=LiteLLM_JWTAuth( - org_alias_jwt_field="org_alias" - ), + litellm_jwtauth=LiteLLM_JWTAuth(org_alias_jwt_field="org_alias"), ) - + # Mock org object returned by get_org_object_by_alias org_object = LiteLLM_OrganizationTable( organization_id="resolved-org-id", @@ -1560,15 +1708,14 @@ async def test_get_objects_resolves_org_by_name(): budget_id="budget-1", created_by="admin", updated_by="admin", - models=[] + models=[], ) - + with patch( - "litellm.proxy.auth.handle_jwt.get_org_object_by_alias", - new_callable=AsyncMock + "litellm.proxy.auth.handle_jwt.get_org_object_by_alias", new_callable=AsyncMock ) as mock_get_by_alias: mock_get_by_alias.return_value = org_object - + ( result_user_obj, result_org_obj, @@ -1589,7 +1736,7 @@ async def test_get_objects_resolves_org_by_name(): route="/chat/completions", org_alias="my-org", ) - + # Should resolve org by alias - org_id can be derived from org_object.organization_id assert result_org_obj == org_object assert result_org_obj.organization_id == "resolved-org-id" @@ -1643,7 +1790,9 @@ async def test_resolve_jwks_url_resolves_oidc_discovery_document(): litellm_jwtauth=LiteLLM_JWTAuth(), ) - discovery_url = "https://login.microsoftonline.com/tenant/.well-known/openid-configuration" + discovery_url = ( + "https://login.microsoftonline.com/tenant/.well-known/openid-configuration" + ) jwks_url = "https://login.microsoftonline.com/tenant/discovery/keys" mock_response = MagicMock() @@ -1674,7 +1823,9 @@ async def test_resolve_jwks_url_caches_resolved_jwks_uri(): litellm_jwtauth=LiteLLM_JWTAuth(), ) - discovery_url = "https://login.microsoftonline.com/tenant/.well-known/openid-configuration" + discovery_url = ( + "https://login.microsoftonline.com/tenant/.well-known/openid-configuration" + ) jwks_url = "https://login.microsoftonline.com/tenant/discovery/keys" mock_response = MagicMock() @@ -1818,9 +1969,9 @@ async def test_find_and_validate_specific_team_id_hints_bracket_notation(): error_msg = str(exc_info.value) # Should mention the bad field name and suggest the fix assert "roles.0" in error_msg, f"Expected field name in: {error_msg}" - assert "roles" in error_msg and "list" in error_msg, ( - f"Expected hint about using 'roles' instead: {error_msg}" - ) + assert ( + "roles" in error_msg and "list" in error_msg + ), f"Expected hint about using 'roles' instead: {error_msg}" @pytest.mark.asyncio @@ -1848,9 +1999,9 @@ async def test_find_and_validate_specific_team_id_hints_bracket_index_notation() error_msg = str(exc_info.value) assert "roles[0]" in error_msg, f"Expected field name in: {error_msg}" - assert "roles" in error_msg and "list" in error_msg, ( - f"Expected hint about using 'roles' instead: {error_msg}" - ) + assert ( + "roles" in error_msg and "list" in error_msg + ), f"Expected hint about using 'roles' instead: {error_msg}" @pytest.mark.asyncio @@ -1878,4 +2029,3 @@ async def test_find_and_validate_specific_team_id_no_hint_for_valid_field(): error_msg = str(exc_info.value) assert "Hint" not in error_msg -