# What is this? ## Tests if 'get_end_user_object' works as expected import sys, os, asyncio, time, random, uuid import traceback from dotenv import load_dotenv load_dotenv() import os sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path import pytest, litellm import httpx from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.auth.auth_checks import get_end_user_object from litellm.caching.caching import DualCache from litellm.proxy._types import ( LiteLLM_EndUserTable, LiteLLM_BudgetTable, LiteLLM_UserTable, LiteLLM_TeamTable, Litellm_EntityType, ) from litellm.proxy.utils import PrismaClient from litellm.proxy.auth.auth_checks import ( can_team_access_model, _virtual_key_soft_budget_check, _team_soft_budget_check, ) from litellm.proxy.utils import ProxyLogging from litellm.proxy.utils import CallInfo @pytest.mark.parametrize("customer_spend, customer_budget", [(0, 10), (10, 0)]) @pytest.mark.asyncio async def test_get_end_user_object(customer_spend, customer_budget): """ Scenario 1: normal Scenario 2: user over budget """ end_user_id = "my-test-customer" _budget = LiteLLM_BudgetTable(max_budget=customer_budget) end_user_obj = LiteLLM_EndUserTable( user_id=end_user_id, spend=customer_spend, litellm_budget_table=_budget, blocked=False, ) _cache = DualCache() _key = "end_user_id:{}".format(end_user_id) _cache.set_cache(key=_key, value=end_user_obj.model_dump()) try: await get_end_user_object( end_user_id=end_user_id, prisma_client="RANDOM VALUE", # type: ignore user_api_key_cache=_cache, route="/v1/chat/completions", ) if customer_spend > customer_budget: pytest.fail( "Expected call to fail. Customer Spend={}, Customer Budget={}".format( customer_spend, customer_budget ) ) except Exception as e: if ( isinstance(e, litellm.BudgetExceededError) and customer_spend > customer_budget ): pass else: pytest.fail( "Expected call to work. Customer Spend={}, Customer Budget={}, Error={}".format( customer_spend, customer_budget, str(e) ) ) @pytest.mark.parametrize( "model, expect_to_work", [ ("openai/gpt-4o-mini", True), ("openai/gpt-4o", False), ], ) @pytest.mark.asyncio async def test_can_key_call_model(model, expect_to_work): """ If wildcard model + specific model is used, choose the specific model settings """ from litellm.proxy.auth.auth_checks import can_key_call_model from fastapi import HTTPException llm_model_list = [ { "model_name": "openai/*", "litellm_params": { "model": "openai/*", "api_key": "test-api-key", }, "model_info": { "id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f", "db_model": False, "access_groups": ["public-openai-models"], }, }, { "model_name": "openai/gpt-4o", "litellm_params": { "model": "openai/gpt-4o", "api_key": "test-api-key", }, "model_info": { "id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad", "db_model": False, "access_groups": ["private-openai-models"], }, }, ] router = litellm.Router(model_list=llm_model_list) args = { "model": model, "llm_model_list": llm_model_list, "valid_token": UserAPIKeyAuth( models=["public-openai-models"], ), "llm_router": router, } if expect_to_work: await can_key_call_model(**args) else: with pytest.raises(Exception) as e: await can_key_call_model(**args) print(e) @pytest.mark.parametrize( "model, expect_to_work", [("openai/gpt-4o", False), ("openai/gpt-4o-mini", True)], ) @pytest.mark.asyncio async def test_can_team_call_model(model, expect_to_work): from litellm.proxy.auth.auth_checks import model_in_access_group from fastapi import HTTPException llm_model_list = [ { "model_name": "openai/*", "litellm_params": { "model": "openai/*", "api_key": "test-api-key", }, "model_info": { "id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f", "db_model": False, "access_groups": ["public-openai-models"], }, }, { "model_name": "openai/gpt-4o", "litellm_params": { "model": "openai/gpt-4o", "api_key": "test-api-key", }, "model_info": { "id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad", "db_model": False, "access_groups": ["private-openai-models"], }, }, ] router = litellm.Router(model_list=llm_model_list) args = { "model": model, "team_models": ["public-openai-models"], "llm_router": router, } if expect_to_work: assert model_in_access_group(**args) else: assert not model_in_access_group(**args) @pytest.mark.parametrize( "key_models, model, expect_to_work", [ (["openai/*"], "openai/gpt-4o", True), (["openai/*"], "openai/gpt-4o-mini", True), (["openai/*"], "openaiz/gpt-4o-mini", False), (["bedrock/*"], "bedrock/anthropic.claude-3-5-sonnet-20240620", True), (["bedrock/*"], "bedrockz/anthropic.claude-3-5-sonnet-20240620", False), (["bedrock/us.*"], "bedrock/us.amazon.nova-micro-v1:0", True), ], ) @pytest.mark.asyncio async def test_can_key_call_model_wildcard_access(key_models, model, expect_to_work): from litellm.proxy.auth.auth_checks import can_key_call_model from fastapi import HTTPException llm_model_list = [ { "model_name": "openai/*", "litellm_params": { "model": "openai/*", "api_key": "test-api-key", }, "model_info": { "id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f", "db_model": False, }, }, { "model_name": "bedrock/*", "litellm_params": { "model": "bedrock/*", "api_key": "test-api-key", }, "model_info": { "id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f", "db_model": False, }, }, { "model_name": "openai/gpt-4o", "litellm_params": { "model": "openai/gpt-4o", "api_key": "test-api-key", }, "model_info": { "id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad", "db_model": False, }, }, ] router = litellm.Router(model_list=llm_model_list) user_api_key_object = UserAPIKeyAuth( models=key_models, ) if expect_to_work: await can_key_call_model( model=model, llm_model_list=llm_model_list, valid_token=user_api_key_object, llm_router=router, ) else: with pytest.raises(Exception) as e: await can_key_call_model( model=model, llm_model_list=llm_model_list, valid_token=user_api_key_object, llm_router=router, ) print(e) @pytest.mark.parametrize( "key_models, model, expect_to_work", [ # After a cost-map reload, add_known_models() updates anthropic_models so # the anthropic/* wildcard can match a newly-added Anthropic model. (["anthropic/*"], "claude-brand-new-model-reload-test", True), # Wrong provider wildcard must still be denied even after reload. (["openai/*"], "claude-brand-new-model-reload-test", False), ], ) @pytest.mark.asyncio async def test_wildcard_access_after_cost_map_reload(key_models, model, expect_to_work): """ Regression test: after a cost-map hot-reload, calling add_known_models(model_cost_map=new_map) must update litellm.anthropic_models so that the anthropic/* wildcard correctly grants (or denies) access to newly-added models. Root cause: both reload paths in proxy_server.py only updated litellm.model_cost but never re-ran add_known_models(), so the provider sets stayed stale and wildcard matching failed for new models. Fix: each reload now calls litellm.add_known_models(model_cost_map=new_map) with the fetched map passed explicitly to avoid any reference ambiguity. """ from litellm.proxy.auth.auth_checks import can_key_call_model # Build a new cost map that includes the brand-new model — exactly what # proxy_server.py receives from get_model_cost_map() during a reload. new_cost_map = dict(litellm.model_cost) new_cost_map[model] = { "litellm_provider": "anthropic", "max_tokens": 8192, "input_cost_per_token": 0.000003, "output_cost_per_token": 0.000015, } original_model_cost = litellm.model_cost litellm.model_cost = new_cost_map # Confirm the model is NOT yet in the provider set before reload propagation. assert model not in litellm.anthropic_models # Simulate what proxy_server.py now does after every reload. litellm.add_known_models(model_cost_map=new_cost_map) # After add_known_models(), the model must be in the set. assert model in litellm.anthropic_models llm_model_list = [ { "model_name": "anthropic/*", "litellm_params": {"model": "anthropic/*", "api_key": "test-api-key"}, "model_info": {"id": "test-id-anthropic-wildcard", "db_model": False}, }, { "model_name": "openai/*", "litellm_params": {"model": "openai/*", "api_key": "test-api-key"}, "model_info": {"id": "test-id-openai-wildcard", "db_model": False}, }, ] router = litellm.Router(model_list=llm_model_list) user_api_key_object = UserAPIKeyAuth(models=key_models) try: if expect_to_work: await can_key_call_model( model=model, llm_model_list=llm_model_list, valid_token=user_api_key_object, llm_router=router, ) else: with pytest.raises(Exception): await can_key_call_model( model=model, llm_model_list=llm_model_list, valid_token=user_api_key_object, llm_router=router, ) finally: litellm.model_cost = original_model_cost litellm.anthropic_models.discard(model) @pytest.mark.asyncio async def test_add_known_models_explicit_map_updates_provider_sets(): """ Regression test: after a cost-map hot-reload, calling add_known_models(model_cost_map=new_map) with the new map passed explicitly must add any new provider models to the correct provider sets so that wildcard access checks (anthropic/*, openai/*, …) work immediately. This covers the proxy_server.py fix where both reload paths now call litellm.add_known_models(model_cost_map=new_model_cost_map) instead of relying on the module-level model_cost being up to date. """ fake_new_model = "claude-brand-new-explicit-map-test" # Baseline: the model must not be in the sets before we do anything. assert fake_new_model not in litellm.anthropic_models new_cost_map = dict(litellm.model_cost) new_cost_map[fake_new_model] = { "litellm_provider": "anthropic", "max_tokens": 8192, "input_cost_per_token": 0.000003, "output_cost_per_token": 0.000015, } # Simulate what proxy_server.py does on reload. original_model_cost = litellm.model_cost litellm.model_cost = new_cost_map litellm.add_known_models(model_cost_map=new_cost_map) try: assert fake_new_model in litellm.anthropic_models, ( "add_known_models(model_cost_map=...) did not add the new model to " "litellm.anthropic_models — wildcard access checks would fail." ) finally: # Clean up: restore original state. litellm.model_cost = original_model_cost litellm.anthropic_models.discard(fake_new_model) @pytest.mark.asyncio async def test_is_valid_fallback_model(): from litellm.proxy.auth.auth_checks import is_valid_fallback_model from litellm import Router router = Router( model_list=[ { "model_name": "gpt-3.5-turbo", "litellm_params": {"model": "openai/gpt-3.5-turbo"}, } ] ) try: await is_valid_fallback_model( model="gpt-3.5-turbo", llm_router=router, user_model=None ) except Exception as e: pytest.fail(f"Expected is_valid_fallback_model to work, got exception: {e}") try: await is_valid_fallback_model( model="gpt-4o", llm_router=router, user_model=None ) pytest.fail("Expected is_valid_fallback_model to fail") except Exception as e: assert "Invalid" in str(e) @pytest.mark.parametrize( "token_spend, max_budget, expect_budget_error", [ (5.0, 10.0, False), # Under budget (10.0, 10.0, True), # At budget limit (15.0, 10.0, True), # Over budget ], ) @pytest.mark.asyncio async def test_virtual_key_max_budget_check( token_spend, max_budget, expect_budget_error ): """ Test if virtual key budget checks work as expected: 1. Triggers budget alert for all cases 2. Raises BudgetExceededError when spend >= max_budget """ from litellm.proxy.auth.auth_checks import _virtual_key_max_budget_check from litellm.proxy.utils import ProxyLogging # Setup test data valid_token = UserAPIKeyAuth( token="test-token", spend=token_spend, max_budget=max_budget, user_id="test-user", key_alias="test-key", ) user_obj = LiteLLM_UserTable( user_id="test-user", user_email="test@email.com", max_budget=None, ) proxy_logging_obj = ProxyLogging( user_api_key_cache=None, ) # Track if budget alert was called alert_called = False async def mock_budget_alert(*args, **kwargs): nonlocal alert_called alert_called = True proxy_logging_obj.budget_alerts = mock_budget_alert try: await _virtual_key_max_budget_check( valid_token=valid_token, proxy_logging_obj=proxy_logging_obj, user_obj=user_obj, ) if expect_budget_error: pytest.fail( f"Expected BudgetExceededError for spend={token_spend}, max_budget={max_budget}" ) except litellm.BudgetExceededError as e: if not expect_budget_error: pytest.fail( f"Unexpected BudgetExceededError for spend={token_spend}, max_budget={max_budget}" ) assert e.current_cost == token_spend assert e.max_budget == max_budget await asyncio.sleep(1) # Verify budget alert was triggered assert alert_called, "Budget alert should be triggered" @pytest.mark.parametrize( "model, team_models, expect_to_work", [ ("gpt-4", ["gpt-4"], True), # exact match ("gpt-4", ["all-proxy-models"], True), # all-proxy-models access ("gpt-4", ["*"], True), # wildcard access ("gpt-4", ["openai/*"], True), # openai wildcard access ( "bedrock/anthropic.claude-3-5-sonnet-20240620", ["bedrock/*"], True, ), # wildcard access ( "bedrockz/anthropic.claude-3-5-sonnet-20240620", ["bedrock/*"], False, ), # non-match wildcard access ("bedrock/very_new_model", ["bedrock/*"], True), # bedrock wildcard access ( "bedrock/claude-3-5-sonnet-20240620", ["bedrock/claude-*"], True, ), # match on pattern ( "bedrock/claude-3-6-sonnet-20240620", ["bedrock/claude-3-5-*"], False, ), # don't match on pattern ("openai/gpt-4o", ["openai/*"], True), # openai wildcard access ("gpt-4", ["gpt-3.5-turbo"], False), # model not in allowed list ("claude-3", [], True), # empty model list (allows all) ], ) @pytest.mark.asyncio async def test_can_team_access_model(model, team_models, expect_to_work): """ Test cases for can_team_access_model: 1. Exact model match 2. all-proxy-models access 3. Wildcard (*) access 4. OpenAI wildcard access 5. Model not in allowed list 6. Empty model list 7. None model list """ try: team_object = LiteLLM_TeamTable( team_id="test-team", models=team_models, ) result = await can_team_access_model( model=model, team_object=team_object, llm_router=None, team_model_aliases=None, ) if not expect_to_work: pytest.fail( f"Expected model access check to fail for model={model}, team_models={team_models}" ) except Exception as e: if expect_to_work: pytest.fail( f"Expected model access check to work for model={model}, team_models={team_models}. Got error: {str(e)}" ) @pytest.mark.parametrize( "spend, soft_budget, expect_alert", [ (100, 50, True), # Over soft budget (50, 50, True), # At soft budget (25, 50, False), # Under soft budget (100, None, False), # No soft budget set ], ) @pytest.mark.asyncio async def test_virtual_key_soft_budget_check(spend, soft_budget, expect_alert): """ Test cases for _virtual_key_soft_budget_check: 1. Spend over soft budget 2. Spend at soft budget 3. Spend under soft budget 4. No soft budget set """ alert_triggered = False class MockProxyLogging: async def budget_alerts(self, type, user_info): nonlocal alert_triggered alert_triggered = True assert type == "soft_budget" assert isinstance(user_info, CallInfo) valid_token = UserAPIKeyAuth( token="test-token", spend=spend, soft_budget=soft_budget, user_id="test-user", team_id="test-team", key_alias="test-key", ) proxy_logging_obj = MockProxyLogging() await _virtual_key_soft_budget_check( valid_token=valid_token, proxy_logging_obj=proxy_logging_obj, ) await asyncio.sleep(0.1) # Allow time for the alert task to complete assert ( alert_triggered == expect_alert ), f"Expected alert_triggered to be {expect_alert} for spend={spend}, soft_budget={soft_budget}" @pytest.mark.parametrize( "spend, soft_budget, expect_alert, metadata, expected_alert_emails", [ (100, 50, False, None, None), # Over soft budget, no metadata - no alert_emails configured, so no alert (50, 50, False, None, None), # At soft budget, no metadata - no alert_emails configured, so no alert (25, 50, False, None, None), # Under soft budget (100, None, False, None, None), # No soft budget set (100, 50, True, {"soft_budget_alerting_emails": ["team1@example.com", "team2@example.com"]}, ["team1@example.com", "team2@example.com"]), # Over soft budget with list of emails (100, 50, True, {"soft_budget_alerting_emails": "team1@example.com,team2@example.com"}, ["team1@example.com", "team2@example.com"]), # Over soft budget with comma-separated emails (100, 50, True, {"soft_budget_alerting_emails": ["team1@example.com", "", " ", "team2@example.com"]}, ["team1@example.com", "team2@example.com"]), # Over soft budget with empty strings filtered ], ) @pytest.mark.asyncio async def test_team_soft_budget_check(spend, soft_budget, expect_alert, metadata, expected_alert_emails): """ Test cases for _team_soft_budget_check: 1. Spend over soft budget, no alert_emails configured - should NOT trigger alert (alerts only sent when alert_emails configured) 2. Spend at soft budget, no alert_emails configured - should NOT trigger alert (alerts only sent when alert_emails configured) 3. Spend under soft budget - should not trigger alert 4. No soft budget set - should not trigger alert 5. Team with alert emails in metadata (list) - should include alert_emails in CallInfo 6. Team with alert emails in metadata (comma-separated string) - should parse and include alert_emails 7. Team with alert emails containing empty strings - should filter them out """ alert_triggered = False captured_call_info = None class MockProxyLogging: async def budget_alerts(self, type, user_info): nonlocal alert_triggered, captured_call_info alert_triggered = True captured_call_info = user_info assert type == "soft_budget" assert isinstance(user_info, CallInfo) valid_token = UserAPIKeyAuth( token="test-token", user_id="test-user", team_id="test-team", team_alias="test-team-alias", key_alias="test-key", ) team_object = LiteLLM_TeamTable( team_id="test-team", spend=spend, soft_budget=soft_budget, max_budget=100.0, metadata=metadata, ) proxy_logging_obj = MockProxyLogging() await _team_soft_budget_check( team_object=team_object, valid_token=valid_token, proxy_logging_obj=proxy_logging_obj, ) await asyncio.sleep(0.1) # Allow time for the alert task to complete assert ( alert_triggered == expect_alert ), f"Expected alert_triggered to be {expect_alert} for spend={spend}, soft_budget={soft_budget}" if expect_alert: assert captured_call_info is not None assert captured_call_info.team_id == "test-team" assert captured_call_info.spend == spend assert captured_call_info.soft_budget == soft_budget assert captured_call_info.event_group == Litellm_EntityType.TEAM # Verify alert_emails if expected if expected_alert_emails is not None: assert captured_call_info.alert_emails == expected_alert_emails else: assert captured_call_info.alert_emails is None or captured_call_info.alert_emails == [] @pytest.mark.asyncio async def test_can_user_call_model(): from litellm.proxy.auth.auth_checks import can_user_call_model from litellm.proxy._types import ProxyException from litellm import Router router = Router( model_list=[ { "model_name": "anthropic-claude", "litellm_params": {"model": "anthropic/anthropic-claude"}, }, { "model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "test-api-key"}, }, ] ) args = { "model": "anthropic-claude", "llm_router": router, "user_object": LiteLLM_UserTable( user_id="testuser21@mycompany.com", max_budget=None, spend=0.0042295, model_max_budget={}, model_spend={}, user_email="testuser@mycompany.com", models=["gpt-3.5-turbo"], ), } with pytest.raises(ProxyException) as e: await can_user_call_model(**args) args["model"] = "gpt-3.5-turbo" await can_user_call_model(**args) @pytest.mark.asyncio async def test_can_user_call_model_with_no_default_models(): from litellm.proxy.auth.auth_checks import can_user_call_model from litellm.proxy._types import ProxyException, SpecialModelNames from unittest.mock import MagicMock args = { "model": "anthropic-claude", "llm_router": MagicMock(), "user_object": LiteLLM_UserTable( user_id="testuser21@mycompany.com", max_budget=None, spend=0.0042295, model_max_budget={}, model_spend={}, user_email="testuser@mycompany.com", models=[SpecialModelNames.no_default_models.value], ), } with pytest.raises(ProxyException) as e: await can_user_call_model(**args) @pytest.mark.asyncio async def test_get_fuzzy_user_object(): from litellm.proxy.auth.auth_checks import _get_fuzzy_user_object from litellm.proxy.utils import PrismaClient from unittest.mock import AsyncMock, MagicMock # Setup mock Prisma client mock_prisma = MagicMock() mock_prisma.db = MagicMock() mock_prisma.db.litellm_usertable = MagicMock() # Mock user data test_user = LiteLLM_UserTable( user_id="test_123", sso_user_id="sso_123", user_email="test@example.com", organization_memberships=[], max_budget=None, ) # Test 1: Find user by SSO ID mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=test_user) result = await _get_fuzzy_user_object( prisma_client=mock_prisma, sso_user_id="sso_123", user_email="test@example.com" ) assert result == test_user mock_prisma.db.litellm_usertable.find_unique.assert_called_with( where={"sso_user_id": "sso_123"}, include={"organization_memberships": True} ) # Test 2: SSO ID not found, find by email mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=None) mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=test_user) mock_prisma.db.litellm_usertable.update = AsyncMock() result = await _get_fuzzy_user_object( prisma_client=mock_prisma, sso_user_id="new_sso_456", user_email="test@example.com", ) assert result == test_user mock_prisma.db.litellm_usertable.find_first.assert_called_with( where={"user_email": {"equals": "test@example.com", "mode": "insensitive"}}, include={"organization_memberships": True}, ) # Test 3: Verify background SSO update task when user found by email await asyncio.sleep(0.1) # Allow time for background task mock_prisma.db.litellm_usertable.update.assert_called_with( where={"user_id": "test_123"}, data={"sso_user_id": "new_sso_456"} ) # Test 4: User not found by either method mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=None) mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=None) result = await _get_fuzzy_user_object( prisma_client=mock_prisma, sso_user_id="unknown_sso", user_email="unknown@example.com", ) assert result is None # Test 5: Only email provided (no SSO ID) mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=test_user) result = await _get_fuzzy_user_object( prisma_client=mock_prisma, user_email="test@example.com" ) assert result == test_user mock_prisma.db.litellm_usertable.find_first.assert_called_with( where={"user_email": {"equals": "test@example.com", "mode": "insensitive"}}, include={"organization_memberships": True}, ) # Test 6: Only SSO ID provided (no email) mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=test_user) result = await _get_fuzzy_user_object( prisma_client=mock_prisma, sso_user_id="sso_123" ) assert result == test_user mock_prisma.db.litellm_usertable.find_unique.assert_called_with( where={"sso_user_id": "sso_123"}, include={"organization_memberships": True} ) @pytest.mark.parametrize( "model, alias_map, expect_to_work", [ ("gpt-4", {"gpt-4": "gpt-4-team1"}, True), # model matches alias value ("gpt-5", {"gpt-4": "gpt-4-team1"}, False), ], ) @pytest.mark.asyncio async def test_can_key_call_model_with_aliases(model, alias_map, expect_to_work): """ Test if can_key_call_model correctly handles model aliases in the token """ from litellm.proxy.auth.auth_checks import can_key_call_model llm_model_list = [ { "model_name": "gpt-4-team1", "litellm_params": { "model": "gpt-4", "api_key": "test-api-key", }, } ] router = litellm.Router(model_list=llm_model_list) user_api_key_object = UserAPIKeyAuth( models=[ "gpt-4-team1", ], team_model_aliases=alias_map, ) if expect_to_work: await can_key_call_model( model=model, llm_model_list=llm_model_list, valid_token=user_api_key_object, llm_router=router, ) else: with pytest.raises(Exception) as e: await can_key_call_model( model=model, llm_model_list=llm_model_list, valid_token=user_api_key_object, llm_router=router, ) # --------------------------------------------------------------------------- # Access group cache helpers (_cache_access_object, _delete_cache_access_object) # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_cache_access_object(): """Test _cache_access_object stores access group in cache with correct key.""" from litellm.proxy.auth.auth_checks import _cache_access_object from litellm.proxy._types import LiteLLM_AccessGroupTable cache = DualCache() ag_id = "ag-test-123" ag_table = LiteLLM_AccessGroupTable( access_group_id=ag_id, access_group_name="test-group", access_model_names=["gpt-4"], ) await _cache_access_object( access_group_id=ag_id, access_group_table=ag_table, user_api_key_cache=cache, ) cached = await cache.async_get_cache(key=f"access_group_id:{ag_id}") assert cached is not None if isinstance(cached, dict): assert cached.get("access_group_id") == ag_id assert cached.get("access_group_name") == "test-group" else: assert cached.access_group_id == ag_id assert cached.access_group_name == "test-group" @pytest.mark.asyncio async def test_delete_cache_access_object(): """Test _delete_cache_access_object removes access group from in-memory cache.""" from litellm.proxy.auth.auth_checks import _delete_cache_access_object from litellm.proxy._types import LiteLLM_AccessGroupTable cache = DualCache() ag_id = "ag-delete-test" ag_table = LiteLLM_AccessGroupTable( access_group_id=ag_id, access_group_name="to-delete", ) await cache.async_set_cache(key=f"access_group_id:{ag_id}", value=ag_table, ttl=60) await _delete_cache_access_object(access_group_id=ag_id, user_api_key_cache=cache) cached = await cache.async_get_cache(key=f"access_group_id:{ag_id}") assert cached is None # --------------------------------------------------------------------------- # Access group resource fetchers (_get_models_from_access_groups, _get_agent_ids_from_access_groups) # --------------------------------------------------------------------------- @pytest.mark.parametrize( "resource_field, access_group_data, expected", [ ( "access_model_names", {"access_group_id": "ag-1", "access_model_names": ["gpt-4", "claude-3"]}, ["gpt-4", "claude-3"], ), ( "access_agent_ids", {"access_group_id": "ag-2", "access_agent_ids": ["agent-a", "agent-b"]}, ["agent-a", "agent-b"], ), ( "access_model_names", {"access_group_id": "ag-3", "access_model_names": []}, [], ), ], ) @pytest.mark.asyncio async def test_get_resources_from_access_groups(resource_field, access_group_data, expected): """Test _get_resources_from_access_groups returns correct resource list from access groups.""" from unittest.mock import AsyncMock, MagicMock, patch from litellm.proxy._types import LiteLLM_AccessGroupTable from litellm.proxy.auth.auth_checks import ( _get_agent_ids_from_access_groups, _get_models_from_access_groups, ) ag_table = LiteLLM_AccessGroupTable( access_group_id=access_group_data["access_group_id"], access_group_name="test", access_model_names=access_group_data.get("access_model_names", []), access_agent_ids=access_group_data.get("access_agent_ids", []), ) with patch( "litellm.proxy.auth.auth_checks.get_access_object", new_callable=AsyncMock, return_value=ag_table, ): if resource_field == "access_model_names": result = await _get_models_from_access_groups( access_group_ids=[access_group_data["access_group_id"]], prisma_client=MagicMock(), user_api_key_cache=DualCache(), ) else: result = await _get_agent_ids_from_access_groups( access_group_ids=[access_group_data["access_group_id"]], prisma_client=MagicMock(), user_api_key_cache=DualCache(), ) assert sorted(result) == sorted(expected) @pytest.mark.asyncio async def test_get_models_from_access_groups_empty_ids(): """Test _get_models_from_access_groups returns empty list when access_group_ids is empty.""" from litellm.proxy.auth.auth_checks import _get_models_from_access_groups result = await _get_models_from_access_groups(access_group_ids=[]) assert result == [] # --------------------------------------------------------------------------- # can_team_access_model with access_group_ids fallback # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_can_team_access_model_via_access_group_ids(): """Test can_team_access_model allows access when team has access_group_ids granting model access.""" from unittest.mock import AsyncMock, patch from litellm.proxy.auth.auth_checks import can_team_access_model team_object = LiteLLM_TeamTable( team_id="test-team", models=[], access_group_ids=["ag-with-gpt4"], ) with patch( "litellm.proxy.auth.auth_checks._get_models_from_access_groups", new_callable=AsyncMock, return_value=["gpt-4"], ): result = await can_team_access_model( model="gpt-4", team_object=team_object, llm_router=None, team_model_aliases=None, ) assert result is True @pytest.mark.asyncio async def test_can_team_access_model_access_group_ids_denied(): """Test can_team_access_model denies when neither team models nor access_group_ids grant access.""" from unittest.mock import AsyncMock, patch from litellm.proxy.auth.auth_checks import can_team_access_model from litellm.proxy._types import ProxyException team_object = LiteLLM_TeamTable( team_id="test-team", models=["gpt-3.5-turbo"], access_group_ids=["ag-other"], ) with patch( "litellm.proxy.auth.auth_checks._get_models_from_access_groups", new_callable=AsyncMock, return_value=["claude-3"], ): with pytest.raises(ProxyException): await can_team_access_model( model="gpt-4", team_object=team_object, llm_router=None, team_model_aliases=None, ) # --------------------------------------------------------------------------- # can_key_call_model with access_group_ids fallback # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_can_key_call_model_via_access_group_ids(): """Test can_key_call_model allows access when key has access_group_ids granting model access.""" from unittest.mock import AsyncMock, patch from litellm.proxy.auth.auth_checks import can_key_call_model user_api_key_object = UserAPIKeyAuth( token="test-token", models=[], access_group_ids=["ag-with-gpt4"], ) router = litellm.Router( model_list=[ { "model_name": "gpt-4", "litellm_params": {"model": "openai/gpt-4", "api_key": "test"}, } ] ) with patch( "litellm.proxy.auth.auth_checks._get_models_from_access_groups", new_callable=AsyncMock, return_value=["gpt-4"], ): await can_key_call_model( model="gpt-4", llm_model_list=[], valid_token=user_api_key_object, llm_router=router, )