mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-29 09:08:27 +00:00
9495f4e941
* auth_with_role_name add region_name arg for cross-account sts * update tests to include case with aws_region_name for _auth_with_aws_role * Only pass region_name to STS client when aws_region_name is set * Add optional aws_sts_endpoint to _auth_with_aws_role * Parametrize ambient-credentials test for no opts, region_name, and aws_sts_endpoint * consistently passing region and endpoint args into explicit credentials irsa * fix env var leakage * fix: bedrock openai-compatible imported-model should also have model arn encoded * feat: show proxy url in ModelHub (#21660) * fix(bedrock): correct modelInput format for Converse API batch models (#21656) * fix(proxy): add model_ids param to access group endpoints for precise deployment tagging (#21655) POST /access_group/new and PUT /access_group/{name}/update now accept an optional model_ids list that targets specific deployments by their unique model_id, instead of tagging every deployment that shares a model_name. When model_ids is provided it takes priority over model_names, giving API callers the same single-deployment precision that the UI already has via PATCH /model/{model_id}/update. Backward compatible: model_names continues to work as before. Closes #21544 * feat(proxy): add custom favicon support Add ability to configure a custom favicon for the litellm proxy UI. - Add favicon_url field to UIThemeConfig model - Add LITELLM_FAVICON_URL env var support - Add /get_favicon endpoint to serve custom favicons - Update ThemeContext to dynamically set favicon - Add favicon URL input to UI theme settings page - Add comprehensive tests Closes #8323 (#21653) * fix(bedrock): prevent double UUID in create_file S3 key (#21650) In create_file for Bedrock, get_complete_file_url is called twice: once in the sync handler (generating UUID-1 for api_base) and once inside transform_create_file_request (generating UUID-2 for the actual S3 upload).
The Bedrock provider correctly writes UUID-2 into litellm_params["upload_url"], but the sync handler unconditionally overwrites it with api_base (UUID-1). This causes the returned file_id to point to a non-existent S3 key. Fix: only set upload_url to api_base when transform_create_file_request has not already set it, preserving the Bedrock provider's value. Closes #21546 * feat(semantic-cache): support configurable vector dimensions for Qdrant (#21649) Add vector_size parameter to QdrantSemanticCache and expose it through the Cache facade as qdrant_semantic_cache_vector_size. This allows users to use embedding models with dimensions other than the default 1536, enabling cheaper/stronger models like Stella (1024d), bge-en-icl (4096d), voyage, cohere, etc. The parameter defaults to QDRANT_VECTOR_SIZE (env var or 1536) for backward compatibility. When creating new collections, the configured vector_size is used instead of the hardcoded constant. Closes #9377 * fix(utils): normalize camelCase thinking param keys to snake_case (#21762) Clients like OpenCode's @ai-sdk/openai-compatible send budgetTokens (camelCase) instead of budget_tokens in the thinking parameter, causing validation errors. Add early normalization in completion(). * feat: add optional digest mode for Slack alert types (#21683) Adds per-alert-type digest mode that aggregates duplicate alerts within a configurable time window and emits a single summary message with count, start/end timestamps. 
Configuration via general_settings.alert_type_config: alert_type_config: llm_requests_hanging: digest: true digest_interval: 86400 Digest key: (alert_type, request_model, api_base) Default interval: 24 hours Window type: fixed interval Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * feat: add blog_posts.json and local backup * feat: add GetBlogPosts utility with GitHub fetch and local fallback Adds GetBlogPosts class that fetches blog posts from GitHub with a 1-hour in-process TTL cache, validates the response, and falls back to the bundled blog_posts_backup.json on any network or validation failure. * test: add cache reset fixture and LITELLM_LOCAL_BLOG_POSTS test Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * feat: add GET /public/litellm_blog_posts endpoint Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix: log fallback warning in blog posts endpoint and tighten test * feat: add disable_show_blog to UISettings Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * feat: add useUISettings and useDisableShowBlog hooks * fix: rename useUISettings to useUISettingsFlags to avoid naming collision * fix: use existing useUISettings hook in useDisableShowBlog to avoid cache duplication Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * feat: add BlogDropdown component with react-query and error/retry state Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix: enforce 5-post limit in BlogDropdown and add cap test * fix: add retry, stable post key, enabled guard in BlogDropdown Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * feat: add BlogDropdown to navbar after Docs link * feat: add network_mock transport for benchmarking proxy overhead without real API calls Intercepts at httpx transport layer so the full proxy path (auth, routing, OpenAI SDK, response transformation) is exercised with zero-latency responses. Activated via `litellm_settings: { network_mock: true }` in proxy config. 
* Litellm dev 02 19 2026 p2 (#21871) * feat(ui/): new guardrails monitor 'demo mock representation of what guardrails monitor looks like * fix: ui updates * style(ui/): fix styling * feat: enable running ai monitor on individual guardrails * feat: add backend logic for guardrail monitoring * fix(guardrails/usage_endpoints.py): fix usage dashboard * fix(budget): fix timezone config lookup and replace hardcoded timezone map with ZoneInfo (#21754) * fix(budget): fix timezone config lookup and replace hardcoded timezone map with ZoneInfo * fix(budget): update stale docstring on get_budget_reset_time * fix: add missing return type annotations to iterator protocol methods in streaming_handler (#21750) * fix: add return type annotations to iterator protocol methods in streaming_handler Add missing return type annotations to __iter__, __aiter__, __next__, and __anext__ methods in CustomStreamWrapper and related classes. - __iter__(self) -> Iterator["ModelResponseStream"] - __aiter__(self) -> AsyncIterator["ModelResponseStream"] - __next__(self) -> "ModelResponseStream" - __anext__(self) -> "ModelResponseStream" Also adds AsyncIterator and Iterator to typing imports. Fixes issue with PLR0915 noqa comments and ensures proper type checking support. Related to: BerriAI/litellm#8304 * fix: add ruff PLR0915 noqa for files with too many statements * Add gollem Go agent framework cookbook example (#21747) Show how to use gollem, a production Go agent framework, with LiteLLM proxy for multi-provider LLM access including tool use and streaming. 
* fix: avoid mutating caller-owned dicts in SpendUpdateQueue aggregation (#21742) * fix(vertex_ai): enable context-1m-2025-08-07 beta header (#21870) * server root path regression doc * fixing syntax * fix: replace Zapier webhook with Google Form for survey submission (#21621) * Replace Zapier webhook with Google Form for survey submission * Add back error logging for survey submission debugging --------- Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com> * Revert "Merge pull request #21140 from BerriAI/litellm_perf_user_api_key_auth" This reverts commit 0e1db3f7e4, reversing changes made to 7e2d6f2355. * test_vertex_ai_gemini_2_5_pro_streaming * UI new build * fix rendering * ui new build * docs fix * docs fix * docs fix * docs fix * docs fix * docs fix * docs fix * docs fix * release note docs * docs * adding image * fix(vertex_ai): enable context-1m-2025-08-07 beta header The `context-1m-2025-08-07` Anthropic beta header was set to `null` for vertex_ai, causing it to be filtered out when users set `extra_headers: {anthropic-beta: context-1m-2025-08-07}`. This prevented using Claude's 1M context window feature via Vertex AI, resulting in `prompt is too long: 460500 tokens > 200000 maximum` errors. Fixes #21861 --------- Co-authored-by: yuneng-jiang <yuneng.jiang@gmail.com> Co-authored-by: milan-berri <milan@berri.ai> Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com> * Revert "fix(vertex_ai): enable context-1m-2025-08-07 beta header (#21870)" (#21876) This reverts commit bce078a796. * docs(ui): add pre-PR checklist to UI contributing guide Add testing and build verification steps per maintainer feedback from @yjiang-litellm. Contributors should run their related tests per-file and ensure npm run build passes before opening PRs.
* Fix entries with fast and us/ * Add tests for fast and us * Add support for Priority PayGo for vertex ai and gemini * Add model pricing * fix: ensure arrival_time is set before calculating queue time * Fix: Anthropic model wildcard access issue * Add incident report * Add ability to see which model cost map is getting used * Fix name of title * Readd tpm limit * State management fixes for CheckBatchCost * Fix PR review comments * State management fixes for CheckBatchCost - Address greptile comments * fix mypy issues: * Add Noma guardrails v2 based on custom guardrails (#21400) * Fix code qa issues * Fix mypy issues * Fix mypy issues * Fix test_aaamodel_prices_and_context_window_json_is_valid * fix: update calendly on repo * fix(tests): use counter-based mock for time.time in prisma self-heal test The test used a fixed side_effect list for time.time(), but the number of calls varies by Python version, causing StopIteration on 3.12 and AssertionError on 3.14. Replace with an infinite counter-based callable and assert the timestamp was updated rather than checking for an exact value. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix(tests): use absolute path for model_prices JSON in validation test The test used a relative path 'litellm/model_prices_and_context_window.json' which only works when pytest runs from a specific working directory. Use os.path based on __file__ to resolve the path reliably. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Update tests/test_litellm/test_utils.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> * fix(tests): use os.path instead of Path to avoid NameError Path is not imported at module level. Use os.path.join which is already available. 
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * clean up mock transport: remove streaming, add defensive parsing * docs: add Google GenAI SDK tutorial (JS & Python) (#21885) * docs: add Google GenAI SDK tutorial for JS and Python Add tutorial for using Google's official GenAI SDK (@google/genai for JS, google-genai for Python) with LiteLLM proxy. Covers pass-through and native router endpoints, streaming, multi-turn chat, and multi-provider routing via model_group_alias. Also updates pass-through docs to use the new SDK replacing the deprecated @google/generative-ai. * fix(docs): correct Python SDK env var name in GenAI tutorial GOOGLE_GENAI_API_KEY does not exist in the google-genai SDK. The correct env var is GEMINI_API_KEY (or GOOGLE_API_KEY). Also note that the Python SDK has no base URL env var. * fix(docs): replace non-existent GOOGLE_GENAI_BASE_URL env var in interactions.md The Python google-genai SDK does not read GOOGLE_GENAI_BASE_URL. Use http_options={"base_url": "..."} in code instead. * docs: add network mock benchmarking section * docs: tweak benchmarks wording * fix: add auth headers and empty latencies guard to benchmark script * refactor: use method-level import for MockOpenAITransport * fix: guard print_aggregate against empty latencies * fix: add INCOMPLETE status to Interactions API enum and test Google added INCOMPLETE to the Interactions API OpenAPI spec status enum. Update both the Status3 enum in the SDK types and the test's expected values to match. 
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Guardrail Monitor - measure guardrail reliability in prod (#21944) * fix: fix log viewer for guardrail monitoring * feat(ui/): fix rendering logs per guardrail * fix: fix viewing logs on overview tab of guardrail * fix: log viewer * fix: fix naming to align with metric * docs: add performance & reliability section to v1.81.14 release notes * fix(tests): make RPM limit test sequential to avoid race condition Concurrent requests via run_in_executor + asyncio.gather caused a race condition where more requests slipped through the rate limiter than expected, leading to flaky test failures (e.g. 3 successes instead of 2 with rpm_limit=2). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat: Singapore guardrail policies (PDPA + MAS AI Risk Management) (#21948) * feat: Singapore PDPA PII protection guardrail policy template Add Singapore Personal Data Protection Act (PDPA) guardrail support: Regex patterns (patterns.json): - sg_nric: NRIC/FIN detection ([STFGM] + 7 digits + checksum letter) - sg_phone: Singapore phone numbers (+65/0065/65 prefix) - sg_postal_code: 6-digit postal codes (contextual) - passport_singapore: Passport numbers (E/K + 7 digits, contextual) - sg_uen: Unique Entity Numbers (3 formats) - sg_bank_account: Bank account numbers (dash format, contextual) YAML policy templates (5 sub-guardrails): - sg_pdpa_personal_identifiers: s.13 Consent - sg_pdpa_sensitive_data: Advisory Guidelines - sg_pdpa_do_not_call: Part IX DNC Registry - sg_pdpa_data_transfer: s.26 overseas transfers - sg_pdpa_profiling_automated_decisions: Model AI Governance Framework Policy template entry in policy_templates.json with 9 guardrail definitions (4 regex-based + 5 YAML conditional keyword matching). 
Tests: - test_sg_patterns.py: regex pattern unit tests - test_sg_pdpa_guardrails.py: conditional keyword matching tests (100+ cases) * feat: MAS AI Risk Management Guidelines guardrail policy template Add Monetary Authority of Singapore (MAS) AI Risk Management Guidelines guardrail support for financial institutions: YAML policy templates (5 sub-guardrails): - sg_mas_fairness_bias: Blocks discriminatory financial AI (credit/loans/insurance by protected attributes) - sg_mas_transparency_explainability: Blocks opaque/unexplainable AI for consequential financial decisions - sg_mas_human_oversight: Blocks fully automated financial decisions without human-in-the-loop - sg_mas_data_governance: Blocks unauthorized sharing/mishandling of financial customer data - sg_mas_model_security: Blocks adversarial attacks, model poisoning, inversion on financial AI Policy template entry in policy_templates.json with 5 guardrail definitions. Aligned with MAS FEAT Principles, Project MindForge, and NIST AI RMF. Tests: - test_sg_mas_ai_guardrails.py: conditional keyword matching tests (100+ cases) * fix: address SG pattern review feedback - Update NRIC lowercase test for IGNORECASE runtime behavior - Add keyword context guard to sg_uen pattern to reduce false positives * docs: clarify MAS AIRM timeline references - Explicitly mark MAS AIRM as Nov 2025 consultation draft - Add 2018 qualifier for FEAT principles in MAS policy descriptions - Update MAS guardrail wording to avoid release-year ambiguity * chore: commit resolved MAS policy conflicts * test: * chore: * Add OpenAI Agents SDK tutorial with LiteLLM Proxy to docs (#21221) * Add OpenAI Agents SDK tutorial to docs * Update OpenAI Agents SDK tutorial to use LiteLLM environment variables * Enhance OpenAI Agents SDK tutorial with built-in LiteLLM extension details and updated configuration steps. Adjust section headings for clarity and improve the flow of information regarding model setup and usage. 
* adjust blog posts to fetch from github first * feat(videos): add variant parameter to video content download (#21955) openai videos models support the features to download variants. See more details here: https://developers.openai.com/api/docs/guides/video-generation#use-image-references. Plumb variant (e.g. "thumbnail", "spritesheet") through the full video content download chain: avideo_content → video_content → video_content_handler → transform_video_content_request. OpenAI appends ?variant=<value> to the GET URL; other providers accept the parameter in their signature but ignore it. * fixing path * adjust blog post path * Revert duplicate issue checker to text-based matching, remove duplicate PR workflow Remove the Claude Code-powered duplicate PR detection workflow and revert the duplicate issue checker back to wow-actions/potential-duplicates with text similarity matching. * ui changes * adding tests * adjust default aggregation threshold * fix(videos): pass api_key from litellm_params to video remix handlers (#21965) video_remix_handler and async_video_remix_handler were not falling back to litellm_params.api_key when the api_key parameter was None, causing Authorization: Bearer None to be sent to the provider. This matches the pattern already used by async_video_generation_handler. * adding testing coverage + fixing flaky tests * fix(ollama): thread api_base through get_model_info and add graceful fallback When users pass api_base to litellm.completion() for Ollama, the model info fetch (context window, function_calling support) was ignoring the user's api_base and only reading OLLAMA_API_BASE env var or defaulting to localhost:11434. This caused confusing errors in logs when Ollama runs on a remote server. Thread api_base from litellm_params through the get_model_info call chain so OllamaConfig.get_model_info() uses the correct server. Also return safe defaults instead of raising when the server is unreachable. 
Fixes #21967 --------- Co-authored-by: An Tang <ta@stripe.com> Co-authored-by: janfrederickk <75388864+janfrederickk@users.noreply.github.com> Co-authored-by: Zhenting Huang <3061613175@qq.com> Co-authored-by: Darien Kindlund <darien@kindlund.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: yuneng-jiang <yuneng.jiang@gmail.com> Co-authored-by: Ryan Crabbe <rcrabbe@berkeley.edu> Co-authored-by: Krish Dholakia <krrishdholakia@gmail.com> Co-authored-by: LeeJuOh <56071126+LeeJuOh@users.noreply.github.com> Co-authored-by: Monesh Ram <31161039+WhoisMonesh@users.noreply.github.com> Co-authored-by: Trevor Prater <trevor.prater@gmail.com> Co-authored-by: The Mavik <179817126+themavik@users.noreply.github.com> Co-authored-by: Edwin Isac <33712823+edwiniac@users.noreply.github.com> Co-authored-by: milan-berri <milan@berri.ai> Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com> Co-authored-by: Sameer Kankute <sameer@berri.ai> Co-authored-by: Harshit Jain <harshitjain0562@gmail.com> Co-authored-by: Harshit Jain <48647625+Harshit28j@users.noreply.github.com> Co-authored-by: Ephrim Stanley <ephrim.stanley@point72.com> Co-authored-by: TomAlon <tom@noma.security> Co-authored-by: Julio Quinteros Pro <jquinter@gmail.com> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Co-authored-by: ryan-crabbe <128659760+ryan-crabbe@users.noreply.github.com> Co-authored-by: Ron Zhong <ron-zhong@hotmail.com> Co-authored-by: Arindam Majumder <109217591+Arindam200@users.noreply.github.com> Co-authored-by: Lei Nie <lenie@quora.com>
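The access-group change in the commit message above gives model_ids priority over model_names. A minimal sketch of that selection rule, assuming deployments are dicts shaped like router model_list entries (a model_name plus model_info.id); `select_deployments` is a hypothetical helper, not the proxy's endpoint code.

```python
from typing import Dict, List, Optional


def select_deployments(
    deployments: List[Dict],
    model_names: Optional[List[str]] = None,
    model_ids: Optional[List[str]] = None,
) -> List[Dict]:
    """Pick deployments to tag: model_ids (one deployment each) win over model_names."""
    if model_ids:
        wanted_ids = set(model_ids)
        return [
            d for d in deployments if d.get("model_info", {}).get("id") in wanted_ids
        ]
    wanted_names = set(model_names or [])
    # Name-based selection tags every deployment that shares the model_name.
    return [d for d in deployments if d.get("model_name") in wanted_names]
```

Passing both lists behaves like passing model_ids alone, which keeps the model_names path backward compatible.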
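The thinking-parameter fix in the commit message above normalizes camelCase keys such as budgetTokens to budget_tokens before validation. A sketch of that rewrite, assuming only top-level keys of the thinking dict need converting and that an explicit snake_case key should win when both spellings are present; `normalize_thinking_keys` is a hypothetical name, not litellm's function.

```python
import re
from typing import Any, Dict

# Zero-width match before every uppercase letter except at the start.
_CAMEL_BOUNDARY = re.compile(r"(?<!^)(?=[A-Z])")


def normalize_thinking_keys(thinking: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `thinking` with camelCase keys converted to snake_case."""
    normalized: Dict[str, Any] = {}
    for key, value in thinking.items():
        snake = _CAMEL_BOUNDARY.sub("_", key).lower()
        if snake != key and snake in thinking:
            continue  # keep the caller's explicit snake_case value
        normalized[snake] = value
    return normalized
```

For example, `{"type": "enabled", "budgetTokens": 1024}` becomes `{"type": "enabled", "budget_tokens": 1024}`.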
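The Slack digest mode in the commit message above groups duplicate alerts by (alert_type, request_model, api_base) inside a fixed interval and emits one summary with a count and start/end timestamps. A minimal in-memory sketch, assuming the caller supplies timestamps and decides when to flush; `AlertDigest` and its methods are hypothetical, not litellm's classes.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

DigestKey = Tuple[str, str, str]  # (alert_type, request_model, api_base)


@dataclass
class _DigestEntry:
    count: int
    start: float  # timestamp of the first alert in the window
    end: float  # timestamp of the most recent alert in the window


class AlertDigest:
    """Hypothetical fixed-window aggregator for duplicate alerts."""

    def __init__(self, interval_seconds: float = 86400.0) -> None:
        self.interval = interval_seconds  # 24 hours by default
        self._entries: Dict[DigestKey, _DigestEntry] = {}

    def record(self, key: DigestKey, now: float) -> None:
        """Count one alert instead of sending it immediately."""
        entry = self._entries.get(key)
        if entry is None:
            self._entries[key] = _DigestEntry(count=1, start=now, end=now)
        else:
            entry.count += 1
            entry.end = now

    def flush_due(self, now: float) -> List[str]:
        """Return one summary per key whose fixed window has elapsed."""
        summaries: List[str] = []
        for key in list(self._entries):
            entry = self._entries[key]
            if now - entry.start >= self.interval:
                alert_type, model, api_base = key
                summaries.append(
                    f"{alert_type} x{entry.count} for {model} at {api_base} "
                    f"(start={entry.start}, end={entry.end})"
                )
                del self._entries[key]
        return summaries
```

A real implementation would also need a timer that calls `flush_due` and a sender that posts the summaries to Slack; both are outside this sketch.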
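GetBlogPosts is described above as fetching from GitHub with a 1-hour in-process TTL cache, validating the response, and falling back to a bundled backup on any network or validation failure. A sketch of that pattern, assuming the fetcher and the backup loader are injected as callables (so no real URL or file path is implied) and that a valid post is a dict with a "title" key; `TTLFallbackCache` is a hypothetical name.

```python
import time
from typing import Any, Callable, List, Optional


class TTLFallbackCache:
    """Fetch with a TTL cache and fall back to a local copy on failure."""

    def __init__(
        self,
        fetch: Callable[[], Any],
        load_backup: Callable[[], List[dict]],
        ttl_seconds: float = 3600.0,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._fetch = fetch
        self._load_backup = load_backup
        self._ttl = ttl_seconds
        self._clock = clock
        self._value: Optional[List[dict]] = None
        self._fetched_at = 0.0

    @staticmethod
    def _validate(data: Any) -> List[dict]:
        # Assumed schema: a list of dicts that each carry a "title".
        if not isinstance(data, list) or not all(
            isinstance(post, dict) and "title" in post for post in data
        ):
            raise ValueError("unexpected blog post payload")
        return data

    def get(self) -> List[dict]:
        now = self._clock()
        if self._value is not None and now - self._fetched_at < self._ttl:
            return self._value  # still fresh
        try:
            self._value = self._validate(self._fetch())
            self._fetched_at = now
            return self._value
        except Exception:
            # Any network or validation failure serves the bundled copy.
            return self._load_backup()
```

The fallback result is deliberately not cached, so the next call retries the remote fetch.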
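The prisma self-heal test fix above replaces a fixed side_effect list for time.time() with an infinite counter, so the mock never raises StopIteration however many times the code under test reads the clock. A sketch with unittest.mock, where `code_that_reads_the_clock` is a hypothetical stand-in for the code under test.

```python
import itertools
import time
from typing import Callable
from unittest import mock


def code_that_reads_the_clock(reads: int) -> float:
    """Hypothetical stand-in: reads time.time() a variable number of times."""
    last = 0.0
    for _ in range(reads):
        last = time.time()
    return last


def make_counter_clock(start: float = 1_000.0, step: float = 1.0) -> Callable[[], float]:
    """Return an infinite, strictly increasing replacement for time.time()."""
    counter = itertools.count()
    return lambda: start + step * next(counter)


with mock.patch("time.time", side_effect=make_counter_clock()):
    before = time.time()
    after = code_that_reads_the_clock(reads=7)  # any count works

# Assert the timestamp moved forward instead of pinning an exact value.
assert after > before
```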
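The Singapore PDPA policy above describes NRIC/FIN detection as a [STFGM] prefix, seven digits and a checksum letter, matched case-insensitively at runtime. A shape-only sketch of such a pattern; it does not verify the checksum letter and is not the pattern shipped in patterns.json.

```python
import re
from typing import List

# Prefix letter S/T/F/G/M, seven digits, one trailing letter.
# IGNORECASE mirrors the runtime behavior mentioned above; the checksum
# letter itself is NOT validated.
SG_NRIC_SHAPE = re.compile(r"\b[STFGM]\d{7}[A-Z]\b", re.IGNORECASE)


def find_nric_like(text: str) -> List[str]:
    """Return substrings shaped like an NRIC/FIN (hypothetical helper)."""
    return SG_NRIC_SHAPE.findall(text)
```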
1105 lines
37 KiB
Python
# What is this?
## Tests if 'get_end_user_object' works as expected

import sys, os, asyncio, time, random, uuid
import traceback
from dotenv import load_dotenv

load_dotenv()

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest, litellm
import httpx
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.auth_checks import get_end_user_object
from litellm.caching.caching import DualCache
from litellm.proxy._types import (
    LiteLLM_EndUserTable,
    LiteLLM_BudgetTable,
    LiteLLM_UserTable,
    LiteLLM_TeamTable,
    Litellm_EntityType,
)
from litellm.proxy.utils import PrismaClient
from litellm.proxy.auth.auth_checks import (
    can_team_access_model,
    _virtual_key_soft_budget_check,
    _team_soft_budget_check,
)
from litellm.proxy.utils import ProxyLogging
from litellm.proxy.utils import CallInfo

@pytest.mark.parametrize("customer_spend, customer_budget", [(0, 10), (10, 0)])
@pytest.mark.asyncio
async def test_get_end_user_object(customer_spend, customer_budget):
    """
    Scenario 1: spend under budget, so the cached end-user lookup succeeds.
    Scenario 2: spend over budget, so BudgetExceededError is raised.
    """
    end_user_id = "my-test-customer"
    _budget = LiteLLM_BudgetTable(max_budget=customer_budget)
    end_user_obj = LiteLLM_EndUserTable(
        user_id=end_user_id,
        spend=customer_spend,
        litellm_budget_table=_budget,
        blocked=False,
    )
    _cache = DualCache()
    _key = "end_user_id:{}".format(end_user_id)
    _cache.set_cache(key=_key, value=end_user_obj.model_dump())
    try:
        # The end user is pre-seeded in _cache, so the placeholder
        # prisma_client is not expected to be queried.
        await get_end_user_object(
            end_user_id=end_user_id,
            prisma_client="RANDOM VALUE",  # type: ignore
            user_api_key_cache=_cache,
            route="/v1/chat/completions",
        )
        if customer_spend > customer_budget:
            pytest.fail(
                "Expected call to fail. Customer Spend={}, Customer Budget={}".format(
                    customer_spend, customer_budget
                )
            )
    except Exception as e:
        if (
            isinstance(e, litellm.BudgetExceededError)
            and customer_spend > customer_budget
        ):
            pass
        else:
            pytest.fail(
                "Expected call to work. Customer Spend={}, Customer Budget={}, Error={}".format(
                    customer_spend, customer_budget, str(e)
                )
            )

@pytest.mark.parametrize(
    "model, expect_to_work",
    [
        ("openai/gpt-4o-mini", True),
        ("openai/gpt-4o", False),
    ],
)
@pytest.mark.asyncio
async def test_can_key_call_model(model, expect_to_work):
    """
    If a model matches both a wildcard deployment (openai/*) and a specific
    deployment (openai/gpt-4o), the specific deployment's settings, including
    its access groups, take precedence.
    """
    from litellm.proxy.auth.auth_checks import can_key_call_model

    llm_model_list = [
        {
            "model_name": "openai/*",
            "litellm_params": {
                "model": "openai/*",
                "api_key": "test-api-key",
            },
            "model_info": {
                "id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
                "db_model": False,
                "access_groups": ["public-openai-models"],
            },
        },
        {
            "model_name": "openai/gpt-4o",
            "litellm_params": {
                "model": "openai/gpt-4o",
                "api_key": "test-api-key",
            },
            "model_info": {
                "id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad",
                "db_model": False,
                "access_groups": ["private-openai-models"],
            },
        },
    ]
    router = litellm.Router(model_list=llm_model_list)
    args = {
        "model": model,
        "llm_model_list": llm_model_list,
        "valid_token": UserAPIKeyAuth(
            models=["public-openai-models"],
        ),
        "llm_router": router,
    }
    if expect_to_work:
        await can_key_call_model(**args)
    else:
        with pytest.raises(Exception) as e:
            await can_key_call_model(**args)

        print(e)

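
# Illustrative sketch (hypothetical helper, not litellm's implementation) of the
# rule in test_can_key_call_model's docstring: when a requested model matches
# both a specific deployment and a wildcard deployment, the specific
# deployment's settings (here, its access groups) take precedence. litellm's
# real resolution happens inside its Router and auth checks and may differ.
import fnmatch
from typing import Dict, List


def _sketch_access_groups_for(model: str, model_list: List[Dict]) -> List[str]:
    """Return the access groups that apply to `model` (hypothetical helper)."""
    # Exact model_name matches win outright.
    candidates = [d for d in model_list if d["model_name"] == model]
    if not candidates:
        # Only consult wildcard deployments when no specific one exists.
        candidates = [
            d
            for d in model_list
            if "*" in d["model_name"] and fnmatch.fnmatchcase(model, d["model_name"])
        ]
    groups: List[str] = []
    for deployment in candidates:
        groups.extend(deployment.get("model_info", {}).get("access_groups", []))
    return groups


# With the deployments used in test_can_key_call_model, a key limited to
# "public-openai-models" gets ["public-openai-models"] for "openai/gpt-4o-mini"
# (wildcard only) but ["private-openai-models"] for "openai/gpt-4o", matching
# the expected results there.
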
@pytest.mark.parametrize(
    "model, expect_to_work",
    [("openai/gpt-4o", False), ("openai/gpt-4o-mini", True)],
)
@pytest.mark.asyncio
async def test_can_team_call_model(model, expect_to_work):
    from litellm.proxy.auth.auth_checks import model_in_access_group

    llm_model_list = [
        {
            "model_name": "openai/*",
            "litellm_params": {
                "model": "openai/*",
                "api_key": "test-api-key",
            },
            "model_info": {
                "id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
                "db_model": False,
                "access_groups": ["public-openai-models"],
            },
        },
        {
            "model_name": "openai/gpt-4o",
            "litellm_params": {
                "model": "openai/gpt-4o",
                "api_key": "test-api-key",
            },
            "model_info": {
                "id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad",
                "db_model": False,
                "access_groups": ["private-openai-models"],
            },
        },
    ]
    router = litellm.Router(model_list=llm_model_list)

    args = {
        "model": model,
        "team_models": ["public-openai-models"],
        "llm_router": router,
    }
    if expect_to_work:
        assert model_in_access_group(**args)
    else:
        assert not model_in_access_group(**args)

@pytest.mark.parametrize(
    "key_models, model, expect_to_work",
    [
        (["openai/*"], "openai/gpt-4o", True),
        (["openai/*"], "openai/gpt-4o-mini", True),
        (["openai/*"], "openaiz/gpt-4o-mini", False),
        (["bedrock/*"], "bedrock/anthropic.claude-3-5-sonnet-20240620", True),
        (["bedrock/*"], "bedrockz/anthropic.claude-3-5-sonnet-20240620", False),
        (["bedrock/us.*"], "bedrock/us.amazon.nova-micro-v1:0", True),
    ],
)
@pytest.mark.asyncio
async def test_can_key_call_model_wildcard_access(key_models, model, expect_to_work):
    from litellm.proxy.auth.auth_checks import can_key_call_model

    llm_model_list = [
        {
            "model_name": "openai/*",
            "litellm_params": {
                "model": "openai/*",
                "api_key": "test-api-key",
            },
            "model_info": {
                "id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
                "db_model": False,
            },
        },
        {
            "model_name": "bedrock/*",
            "litellm_params": {
                "model": "bedrock/*",
                "api_key": "test-api-key",
            },
            "model_info": {
                "id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
                "db_model": False,
            },
        },
        {
            "model_name": "openai/gpt-4o",
            "litellm_params": {
                "model": "openai/gpt-4o",
                "api_key": "test-api-key",
            },
            "model_info": {
                "id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad",
                "db_model": False,
            },
        },
    ]
    router = litellm.Router(model_list=llm_model_list)

    user_api_key_object = UserAPIKeyAuth(
        models=key_models,
    )

    if expect_to_work:
        await can_key_call_model(
            model=model,
            llm_model_list=llm_model_list,
            valid_token=user_api_key_object,
            llm_router=router,
        )
    else:
        with pytest.raises(Exception) as e:
            await can_key_call_model(
                model=model,
                llm_model_list=llm_model_list,
                valid_token=user_api_key_object,
                llm_router=router,
            )

        print(e)

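
# Illustrative sketch (hypothetical helper, not litellm's matcher) of the cases
# in test_can_key_call_model_wildcard_access: they behave like case-sensitive,
# shell-style glob matching of the requested model against each key pattern,
# so "openai/*" covers "openai/gpt-4o" but not "openaiz/gpt-4o-mini". litellm's
# own check also consults provider model sets, which plain globbing cannot do.
import fnmatch
from typing import Iterable


def _sketch_key_allows_model(key_models: Iterable[str], model: str) -> bool:
    """Return True if any key pattern glob-matches `model`."""
    return any(fnmatch.fnmatchcase(model, pattern) for pattern in key_models)
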
@pytest.mark.parametrize(
    "key_models, model, expect_to_work",
    [
        # After a cost-map reload, add_known_models() updates anthropic_models so
        # the anthropic/* wildcard can match a newly-added Anthropic model.
        (["anthropic/*"], "claude-brand-new-model-reload-test", True),
        # Wrong provider wildcard must still be denied even after reload.
        (["openai/*"], "claude-brand-new-model-reload-test", False),
    ],
)
@pytest.mark.asyncio
async def test_wildcard_access_after_cost_map_reload(key_models, model, expect_to_work):
    """
    Regression test: after a cost-map hot-reload, calling
    add_known_models(model_cost_map=new_map) must update litellm.anthropic_models
    so that the anthropic/* wildcard correctly grants (or denies) access to
    newly-added models.

    Root cause: both reload paths in proxy_server.py only updated
    litellm.model_cost but never re-ran add_known_models(), so the provider sets
    stayed stale and wildcard matching failed for new models.

    Fix: each reload now calls litellm.add_known_models(model_cost_map=new_map)
    with the fetched map passed explicitly to avoid any reference ambiguity.
    """
    from litellm.proxy.auth.auth_checks import can_key_call_model

    # Build a new cost map that includes the brand-new model, exactly what
    # proxy_server.py receives from get_model_cost_map() during a reload.
    new_cost_map = dict(litellm.model_cost)
    new_cost_map[model] = {
        "litellm_provider": "anthropic",
        "max_tokens": 8192,
        "input_cost_per_token": 0.000003,
        "output_cost_per_token": 0.000015,
    }

    original_model_cost = litellm.model_cost
    litellm.model_cost = new_cost_map

    # Everything after the global mutation runs inside try/finally so the
    # original state is restored even if one of the assertions fails.
    try:
        # Confirm the model is NOT yet in the provider set before reload propagation.
        assert model not in litellm.anthropic_models

        # Simulate what proxy_server.py now does after every reload.
        litellm.add_known_models(model_cost_map=new_cost_map)

        # After add_known_models(), the model must be in the set.
        assert model in litellm.anthropic_models

        llm_model_list = [
            {
                "model_name": "anthropic/*",
                "litellm_params": {"model": "anthropic/*", "api_key": "test-api-key"},
                "model_info": {"id": "test-id-anthropic-wildcard", "db_model": False},
            },
            {
                "model_name": "openai/*",
                "litellm_params": {"model": "openai/*", "api_key": "test-api-key"},
                "model_info": {"id": "test-id-openai-wildcard", "db_model": False},
            },
        ]
        router = litellm.Router(model_list=llm_model_list)
        user_api_key_object = UserAPIKeyAuth(models=key_models)

        if expect_to_work:
            await can_key_call_model(
                model=model,
                llm_model_list=llm_model_list,
                valid_token=user_api_key_object,
                llm_router=router,
            )
        else:
            with pytest.raises(Exception):
                await can_key_call_model(
                    model=model,
                    llm_model_list=llm_model_list,
                    valid_token=user_api_key_object,
                    llm_router=router,
                )
    finally:
        litellm.model_cost = original_model_cost
        litellm.anthropic_models.discard(model)

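
# Illustrative sketch (hypothetical helpers, not litellm's add_known_models) of
# the stale-set bug described in the docstring of
# test_wildcard_access_after_cost_map_reload. Grouping cost-map keys by their
# "litellm_provider" field yields per-provider sets; a bare model name such as
# "claude-brand-new-model-reload-test" is then covered by "anthropic/*" only if
# it appears in the anthropic set, so the sets must be rebuilt from the
# reloaded map.
from typing import Dict, Iterable, Mapping, Set


def _sketch_provider_sets(cost_map: Mapping[str, Mapping]) -> Dict[str, Set[str]]:
    """Group model names by the provider recorded in each cost-map entry."""
    sets: Dict[str, Set[str]] = {}
    for name, info in cost_map.items():
        provider = info.get("litellm_provider")
        if provider:
            sets.setdefault(provider, set()).add(name)
    return sets


def _sketch_wildcard_grants(
    key_models: Iterable[str], model: str, provider_sets: Dict[str, Set[str]]
) -> bool:
    """Return True if some "<provider>/*" pattern covers `model` via provider_sets."""
    for pattern in key_models:
        if pattern.endswith("/*"):
            provider = pattern[: -len("/*")]
            if model in provider_sets.get(provider, set()):
                return True
    return False


# Sets computed from the old map deny the new model; recomputing them from the
# reloaded map grants it, while "openai/*" stays denied.
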
@pytest.mark.asyncio
async def test_add_known_models_explicit_map_updates_provider_sets():
    """
    Regression test: after a cost-map hot-reload, calling
    add_known_models(model_cost_map=new_map) with the new map passed explicitly
    must add any new provider models to the correct provider sets so that
    wildcard access checks (anthropic/*, openai/*, …) work immediately.

    This covers the proxy_server.py fix where both reload paths now call
    litellm.add_known_models(model_cost_map=new_model_cost_map) instead of
    relying on the module-level model_cost being up to date.
    """
    fake_new_model = "claude-brand-new-explicit-map-test"

    # Baseline: the model must not be in the sets before we do anything.
    assert fake_new_model not in litellm.anthropic_models

    new_cost_map = dict(litellm.model_cost)
    new_cost_map[fake_new_model] = {
        "litellm_provider": "anthropic",
        "max_tokens": 8192,
        "input_cost_per_token": 0.000003,
        "output_cost_per_token": 0.000015,
    }

    # Simulate what proxy_server.py does on reload.
    original_model_cost = litellm.model_cost
    litellm.model_cost = new_cost_map

    try:
        litellm.add_known_models(model_cost_map=new_cost_map)
        assert fake_new_model in litellm.anthropic_models, (
            "add_known_models(model_cost_map=...) did not add the new model to "
            "litellm.anthropic_models; wildcard access checks would fail."
        )
    finally:
        # Clean up: restore original state.
        litellm.model_cost = original_model_cost
        litellm.anthropic_models.discard(fake_new_model)

@pytest.mark.asyncio
|
|
async def test_is_valid_fallback_model():
|
|
from litellm.proxy.auth.auth_checks import is_valid_fallback_model
|
|
from litellm import Router
|
|
|
|
router = Router(
|
|
model_list=[
|
|
{
|
|
"model_name": "gpt-3.5-turbo",
|
|
"litellm_params": {"model": "openai/gpt-3.5-turbo"},
|
|
}
|
|
]
|
|
)
|
|
|
|
try:
|
|
await is_valid_fallback_model(
|
|
model="gpt-3.5-turbo", llm_router=router, user_model=None
|
|
)
|
|
except Exception as e:
|
|
pytest.fail(f"Expected is_valid_fallback_model to work, got exception: {e}")
|
|
|
|
try:
|
|
await is_valid_fallback_model(
|
|
model="gpt-4o", llm_router=router, user_model=None
|
|
)
|
|
pytest.fail("Expected is_valid_fallback_model to fail")
|
|
except Exception as e:
|
|
assert "Invalid" in str(e)


@pytest.mark.parametrize(
    "token_spend, max_budget, expect_budget_error",
    [
        (5.0, 10.0, False),  # Under budget
        (10.0, 10.0, True),  # At budget limit
        (15.0, 10.0, True),  # Over budget
    ],
)
@pytest.mark.asyncio
async def test_virtual_key_max_budget_check(
    token_spend, max_budget, expect_budget_error
):
    """
    Test if virtual key budget checks work as expected:
    1. Triggers budget alert for all cases
    2. Raises BudgetExceededError when spend >= max_budget
    """
    from litellm.proxy.auth.auth_checks import _virtual_key_max_budget_check
    from litellm.proxy.utils import ProxyLogging

    # Setup test data
    valid_token = UserAPIKeyAuth(
        token="test-token",
        spend=token_spend,
        max_budget=max_budget,
        user_id="test-user",
        key_alias="test-key",
    )

    user_obj = LiteLLM_UserTable(
        user_id="test-user",
        user_email="test@email.com",
        max_budget=None,
    )

    proxy_logging_obj = ProxyLogging(
        user_api_key_cache=None,
    )

    # Track if budget alert was called
    alert_called = False

    async def mock_budget_alert(*args, **kwargs):
        nonlocal alert_called
        alert_called = True

    proxy_logging_obj.budget_alerts = mock_budget_alert

    try:
        await _virtual_key_max_budget_check(
            valid_token=valid_token,
            proxy_logging_obj=proxy_logging_obj,
            user_obj=user_obj,
        )
        if expect_budget_error:
            pytest.fail(
                f"Expected BudgetExceededError for spend={token_spend}, max_budget={max_budget}"
            )
    except litellm.BudgetExceededError as e:
        if not expect_budget_error:
            pytest.fail(
                f"Unexpected BudgetExceededError for spend={token_spend}, max_budget={max_budget}"
            )
        assert e.current_cost == token_spend
        assert e.max_budget == max_budget

    await asyncio.sleep(1)

    # Verify budget alert was triggered
    assert alert_called, "Budget alert should be triggered"
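

# Illustrative sketch, not the litellm implementation: the parametrized cases
# above encode an inclusive hard limit, so a key is blocked once
# spend >= max_budget (10.0 at a limit of 10.0 is already blocked). This
# sketch additionally assumes a missing max_budget means "no limit", and it
# raises a plain ValueError in place of litellm.BudgetExceededError.
# _sketch_check_max_budget is a hypothetical name.
from typing import Optional


def _sketch_check_max_budget(spend: float, max_budget: Optional[float]) -> None:
    """Raise if spend has reached or exceeded max_budget (hypothetical helper)."""
    if max_budget is None:
        return  # no hard limit configured in this sketch
    if spend >= max_budget:
        raise ValueError(
            f"Budget exceeded: current_cost={spend}, max_budget={max_budget}"
        )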


@pytest.mark.parametrize(
    "model, team_models, expect_to_work",
    [
        ("gpt-4", ["gpt-4"], True),  # exact match
        ("gpt-4", ["all-proxy-models"], True),  # all-proxy-models access
        ("gpt-4", ["*"], True),  # wildcard access
        ("gpt-4", ["openai/*"], True),  # openai wildcard access
        (
            "bedrock/anthropic.claude-3-5-sonnet-20240620",
            ["bedrock/*"],
            True,
        ),  # wildcard access
        (
            "bedrockz/anthropic.claude-3-5-sonnet-20240620",
            ["bedrock/*"],
            False,
        ),  # non-match wildcard access
        ("bedrock/very_new_model", ["bedrock/*"], True),  # bedrock wildcard access
        (
            "bedrock/claude-3-5-sonnet-20240620",
            ["bedrock/claude-*"],
            True,
        ),  # match on pattern
        (
            "bedrock/claude-3-6-sonnet-20240620",
            ["bedrock/claude-3-5-*"],
            False,
        ),  # don't match on pattern
        ("openai/gpt-4o", ["openai/*"], True),  # openai wildcard access
        ("gpt-4", ["gpt-3.5-turbo"], False),  # model not in allowed list
        ("claude-3", [], True),  # empty model list (allows all)
    ],
)
@pytest.mark.asyncio
async def test_can_team_access_model(model, team_models, expect_to_work):
    """
    Test cases for can_team_access_model:
    1. Exact model match
    2. all-proxy-models access
    3. Wildcard (*) access
    4. Provider wildcards (openai/*, bedrock/*), including a lookalike
       prefix ("bedrockz/") that must not match
    5. Partial patterns (bedrock/claude-*) that match or do not match
    6. Model not in allowed list
    7. Empty model list (allows all)
    """
    try:
        team_object = LiteLLM_TeamTable(
            team_id="test-team",
            models=team_models,
        )
        result = await can_team_access_model(
            model=model,
            team_object=team_object,
            llm_router=None,
            team_model_aliases=None,
        )
        if not expect_to_work:
            pytest.fail(
                f"Expected model access check to fail for model={model}, team_models={team_models}"
            )
    except Exception as e:
        if expect_to_work:
            pytest.fail(
                f"Expected model access check to work for model={model}, team_models={team_models}. Got error: {str(e)}"
            )
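

# Illustrative sketch, not litellm's matching code (which may use a different
# mechanism): the pattern cases above behave like shell-style globs anchored at
# the start of the string, so "bedrock/*" accepts "bedrock/very_new_model" but
# rejects "bedrockz/...", and "bedrock/claude-3-5-*" rejects
# "bedrock/claude-3-6-sonnet-20240620". The stdlib fnmatch.fnmatchcase
# reproduces those cases. Bare names under a provider wildcard (e.g. "gpt-4"
# with "openai/*") need a provider lookup and are out of scope for this sketch.
# _sketch_team_can_access is a hypothetical name.
from fnmatch import fnmatchcase
from typing import List


def _sketch_team_can_access(model: str, team_models: List[str]) -> bool:
    """Return True if model is allowed by team_models (hypothetical helper)."""
    if not team_models:
        return True  # an empty list allows every model
    for allowed in team_models:
        if allowed in ("*", "all-proxy-models") or allowed == model:
            return True
        if "*" in allowed and fnmatchcase(model, allowed):
            return True
    return False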


@pytest.mark.parametrize(
    "spend, soft_budget, expect_alert",
    [
        (100, 50, True),  # Over soft budget
        (50, 50, True),  # At soft budget
        (25, 50, False),  # Under soft budget
        (100, None, False),  # No soft budget set
    ],
)
@pytest.mark.asyncio
async def test_virtual_key_soft_budget_check(spend, soft_budget, expect_alert):
    """
    Test cases for _virtual_key_soft_budget_check:
    1. Spend over soft budget
    2. Spend at soft budget
    3. Spend under soft budget
    4. No soft budget set
    """
    alert_triggered = False

    class MockProxyLogging:
        async def budget_alerts(self, type, user_info):
            nonlocal alert_triggered
            alert_triggered = True
            assert type == "soft_budget"
            assert isinstance(user_info, CallInfo)

    valid_token = UserAPIKeyAuth(
        token="test-token",
        spend=spend,
        soft_budget=soft_budget,
        user_id="test-user",
        team_id="test-team",
        key_alias="test-key",
    )

    proxy_logging_obj = MockProxyLogging()

    await _virtual_key_soft_budget_check(
        valid_token=valid_token,
        proxy_logging_obj=proxy_logging_obj,
    )

    await asyncio.sleep(0.1)  # Allow time for the alert task to complete

    assert (
        alert_triggered == expect_alert
    ), f"Expected alert_triggered to be {expect_alert} for spend={spend}, soft_budget={soft_budget}"


@pytest.mark.parametrize(
    "spend, soft_budget, expect_alert, metadata, expected_alert_emails",
    [
        (100, 50, False, None, None),  # Over soft budget, no metadata - no alert_emails configured, so no alert
        (50, 50, False, None, None),  # At soft budget, no metadata - no alert_emails configured, so no alert
        (25, 50, False, None, None),  # Under soft budget
        (100, None, False, None, None),  # No soft budget set
        (100, 50, True, {"soft_budget_alerting_emails": ["team1@example.com", "team2@example.com"]}, ["team1@example.com", "team2@example.com"]),  # Over soft budget with list of emails
        (100, 50, True, {"soft_budget_alerting_emails": "team1@example.com,team2@example.com"}, ["team1@example.com", "team2@example.com"]),  # Over soft budget with comma-separated emails
        (100, 50, True, {"soft_budget_alerting_emails": ["team1@example.com", "", " ", "team2@example.com"]}, ["team1@example.com", "team2@example.com"]),  # Over soft budget with empty strings filtered
    ],
)
@pytest.mark.asyncio
async def test_team_soft_budget_check(spend, soft_budget, expect_alert, metadata, expected_alert_emails):
    """
    Test cases for _team_soft_budget_check:
    1. Spend over soft budget, no alert_emails configured - should NOT trigger alert (alerts only sent when alert_emails configured)
    2. Spend at soft budget, no alert_emails configured - should NOT trigger alert (alerts only sent when alert_emails configured)
    3. Spend under soft budget - should not trigger alert
    4. No soft budget set - should not trigger alert
    5. Team with alert emails in metadata (list) - should include alert_emails in CallInfo
    6. Team with alert emails in metadata (comma-separated string) - should parse and include alert_emails
    7. Team with alert emails containing empty strings - should filter them out
    """
    alert_triggered = False
    captured_call_info = None

    class MockProxyLogging:
        async def budget_alerts(self, type, user_info):
            nonlocal alert_triggered, captured_call_info
            alert_triggered = True
            captured_call_info = user_info
            assert type == "soft_budget"
            assert isinstance(user_info, CallInfo)

    valid_token = UserAPIKeyAuth(
        token="test-token",
        user_id="test-user",
        team_id="test-team",
        team_alias="test-team-alias",
        key_alias="test-key",
    )

    team_object = LiteLLM_TeamTable(
        team_id="test-team",
        spend=spend,
        soft_budget=soft_budget,
        max_budget=100.0,
        metadata=metadata,
    )

    proxy_logging_obj = MockProxyLogging()

    await _team_soft_budget_check(
        team_object=team_object,
        valid_token=valid_token,
        proxy_logging_obj=proxy_logging_obj,
    )

    await asyncio.sleep(0.1)  # Allow time for the alert task to complete

    assert (
        alert_triggered == expect_alert
    ), f"Expected alert_triggered to be {expect_alert} for spend={spend}, soft_budget={soft_budget}"

    if expect_alert:
        assert captured_call_info is not None
        assert captured_call_info.team_id == "test-team"
        assert captured_call_info.spend == spend
        assert captured_call_info.soft_budget == soft_budget
        assert captured_call_info.event_group == Litellm_EntityType.TEAM
        # Verify alert_emails if expected
        if expected_alert_emails is not None:
            assert captured_call_info.alert_emails == expected_alert_emails
        else:
            assert captured_call_info.alert_emails is None or captured_call_info.alert_emails == []
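

# Illustrative sketch, not the litellm implementation: per the cases above, a
# team soft-budget alert fires only when spend has reached soft_budget AND
# metadata["soft_budget_alerting_emails"] yields at least one address. The
# setting may be a list or a comma-separated string, and blank or
# whitespace-only entries are dropped. Stripping surrounding spaces from each
# address is an assumption of this sketch. Both _sketch_* names are hypothetical.
from typing import List, Optional, Union


def _sketch_parse_alert_emails(raw: Union[str, List[str], None]) -> List[str]:
    """Normalize a list or comma-separated string into clean addresses."""
    if raw is None:
        return []
    items = raw.split(",") if isinstance(raw, str) else raw
    return [item.strip() for item in items if item and item.strip()]


def _sketch_team_soft_budget_recipients(
    spend: float, soft_budget: Optional[float], metadata: Optional[dict]
) -> Optional[List[str]]:
    """Return alert recipients if an alert should be sent, else None."""
    if soft_budget is None or spend < soft_budget:
        return None  # no soft budget configured, or still under it
    raw = (metadata or {}).get("soft_budget_alerting_emails")
    emails = _sketch_parse_alert_emails(raw)
    return emails or None  # no configured recipients means no alert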


@pytest.mark.asyncio
async def test_can_user_call_model():
    from litellm.proxy.auth.auth_checks import can_user_call_model
    from litellm.proxy._types import ProxyException
    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "anthropic-claude",
                "litellm_params": {"model": "anthropic/anthropic-claude"},
            },
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "test-api-key"},
            },
        ]
    )

    args = {
        "model": "anthropic-claude",
        "llm_router": router,
        "user_object": LiteLLM_UserTable(
            user_id="testuser21@mycompany.com",
            max_budget=None,
            spend=0.0042295,
            model_max_budget={},
            model_spend={},
            user_email="testuser@mycompany.com",
            models=["gpt-3.5-turbo"],
        ),
    }

    with pytest.raises(ProxyException) as e:
        await can_user_call_model(**args)

    args["model"] = "gpt-3.5-turbo"
    await can_user_call_model(**args)


@pytest.mark.asyncio
async def test_can_user_call_model_with_no_default_models():
    from litellm.proxy.auth.auth_checks import can_user_call_model
    from litellm.proxy._types import ProxyException, SpecialModelNames
    from unittest.mock import MagicMock

    args = {
        "model": "anthropic-claude",
        "llm_router": MagicMock(),
        "user_object": LiteLLM_UserTable(
            user_id="testuser21@mycompany.com",
            max_budget=None,
            spend=0.0042295,
            model_max_budget={},
            model_spend={},
            user_email="testuser@mycompany.com",
            models=[SpecialModelNames.no_default_models.value],
        ),
    }

    with pytest.raises(ProxyException) as e:
        await can_user_call_model(**args)


@pytest.mark.asyncio
async def test_get_fuzzy_user_object():
    from litellm.proxy.auth.auth_checks import _get_fuzzy_user_object
    from litellm.proxy.utils import PrismaClient
    from unittest.mock import AsyncMock, MagicMock

    # Setup mock Prisma client
    mock_prisma = MagicMock()
    mock_prisma.db = MagicMock()
    mock_prisma.db.litellm_usertable = MagicMock()

    # Mock user data
    test_user = LiteLLM_UserTable(
        user_id="test_123",
        sso_user_id="sso_123",
        user_email="test@example.com",
        organization_memberships=[],
        max_budget=None,
    )

    # Test 1: Find user by SSO ID
    mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=test_user)
    result = await _get_fuzzy_user_object(
        prisma_client=mock_prisma, sso_user_id="sso_123", user_email="test@example.com"
    )
    assert result == test_user
    mock_prisma.db.litellm_usertable.find_unique.assert_called_with(
        where={"sso_user_id": "sso_123"}, include={"organization_memberships": True}
    )

    # Test 2: SSO ID not found, find by email
    mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=None)
    mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=test_user)
    mock_prisma.db.litellm_usertable.update = AsyncMock()

    result = await _get_fuzzy_user_object(
        prisma_client=mock_prisma,
        sso_user_id="new_sso_456",
        user_email="test@example.com",
    )
    assert result == test_user
    mock_prisma.db.litellm_usertable.find_first.assert_called_with(
        where={"user_email": {"equals": "test@example.com", "mode": "insensitive"}},
        include={"organization_memberships": True},
    )

    # Test 3: Verify background SSO update task when user found by email
    await asyncio.sleep(0.1)  # Allow time for background task
    mock_prisma.db.litellm_usertable.update.assert_called_with(
        where={"user_id": "test_123"}, data={"sso_user_id": "new_sso_456"}
    )

    # Test 4: User not found by either method
    mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=None)
    mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=None)

    result = await _get_fuzzy_user_object(
        prisma_client=mock_prisma,
        sso_user_id="unknown_sso",
        user_email="unknown@example.com",
    )
    assert result is None

    # Test 5: Only email provided (no SSO ID)
    mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=test_user)
    result = await _get_fuzzy_user_object(
        prisma_client=mock_prisma, user_email="test@example.com"
    )
    assert result == test_user
    mock_prisma.db.litellm_usertable.find_first.assert_called_with(
        where={"user_email": {"equals": "test@example.com", "mode": "insensitive"}},
        include={"organization_memberships": True},
    )

    # Test 6: Only SSO ID provided (no email)
    mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=test_user)
    result = await _get_fuzzy_user_object(
        prisma_client=mock_prisma, sso_user_id="sso_123"
    )
    assert result == test_user
    mock_prisma.db.litellm_usertable.find_unique.assert_called_with(
        where={"sso_user_id": "sso_123"}, include={"organization_memberships": True}
    )
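

# Illustrative sketch, in-memory and hypothetical (the real helper queries
# Prisma): test_get_fuzzy_user_object above pins down a lookup order of SSO ID
# first, then a case-insensitive email match. When a user is found by email
# under a new SSO ID, that SSO ID is written back to the user (the test shows
# the real code does this in a background task). str.lower() stands in for
# Prisma's "insensitive" mode here. _sketch_fuzzy_find_user is a hypothetical name.
from typing import List, Optional


def _sketch_fuzzy_find_user(
    users: List[dict],
    sso_user_id: Optional[str] = None,
    user_email: Optional[str] = None,
) -> Optional[dict]:
    """Find a user by SSO ID, falling back to case-insensitive email."""
    if sso_user_id is not None:
        for user in users:
            if user.get("sso_user_id") == sso_user_id:
                return user
    if user_email is not None:
        for user in users:
            email = user.get("user_email")
            if email is not None and email.lower() == user_email.lower():
                if sso_user_id is not None:
                    user["sso_user_id"] = sso_user_id  # link the new SSO ID
                return user
    return None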


@pytest.mark.parametrize(
    "model, alias_map, expect_to_work",
    [
        ("gpt-4", {"gpt-4": "gpt-4-team1"}, True),  # model is an alias key resolving to an allowed model
        ("gpt-5", {"gpt-4": "gpt-4-team1"}, False),  # no alias and not in the key's models
    ],
)
@pytest.mark.asyncio
async def test_can_key_call_model_with_aliases(model, alias_map, expect_to_work):
    """
    Test that can_key_call_model correctly handles model aliases in the token.
    """
    from litellm.proxy.auth.auth_checks import can_key_call_model

    llm_model_list = [
        {
            "model_name": "gpt-4-team1",
            "litellm_params": {
                "model": "gpt-4",
                "api_key": "test-api-key",
            },
        }
    ]
    router = litellm.Router(model_list=llm_model_list)

    user_api_key_object = UserAPIKeyAuth(
        models=[
            "gpt-4-team1",
        ],
        team_model_aliases=alias_map,
    )

    if expect_to_work:
        await can_key_call_model(
            model=model,
            llm_model_list=llm_model_list,
            valid_token=user_api_key_object,
            llm_router=router,
        )
    else:
        with pytest.raises(Exception) as e:
            await can_key_call_model(
                model=model,
                llm_model_list=llm_model_list,
                valid_token=user_api_key_object,
                llm_router=router,
            )
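

# Illustrative sketch, hypothetical and deliberately reduced (the real
# can_key_call_model also consults the router, wildcards, and access groups):
# in the cases above, team_model_aliases maps the requested name to a
# deployment name ("gpt-4" -> "gpt-4-team1"), and access is granted when the
# resolved name is in the key's models list. "gpt-5" has no alias and is not
# listed, so it is rejected. _sketch_key_allows_model is a hypothetical name.
from typing import Dict, List, Optional


def _sketch_key_allows_model(
    model: str, key_models: List[str], aliases: Optional[Dict[str, str]]
) -> bool:
    """Resolve model through aliases, then check the key's allow-list."""
    resolved = (aliases or {}).get(model, model)
    return model in key_models or resolved in key_models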


# ---------------------------------------------------------------------------
# Access group cache helpers (_cache_access_object, _delete_cache_access_object)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_cache_access_object():
    """Test _cache_access_object stores access group in cache with correct key."""
    from litellm.proxy.auth.auth_checks import _cache_access_object
    from litellm.proxy._types import LiteLLM_AccessGroupTable

    cache = DualCache()
    ag_id = "ag-test-123"
    ag_table = LiteLLM_AccessGroupTable(
        access_group_id=ag_id,
        access_group_name="test-group",
        access_model_names=["gpt-4"],
    )
    await _cache_access_object(
        access_group_id=ag_id,
        access_group_table=ag_table,
        user_api_key_cache=cache,
    )
    cached = await cache.async_get_cache(key=f"access_group_id:{ag_id}")
    assert cached is not None
    if isinstance(cached, dict):
        assert cached.get("access_group_id") == ag_id
        assert cached.get("access_group_name") == "test-group"
    else:
        assert cached.access_group_id == ag_id
        assert cached.access_group_name == "test-group"


@pytest.mark.asyncio
async def test_delete_cache_access_object():
    """Test _delete_cache_access_object removes access group from in-memory cache."""
    from litellm.proxy.auth.auth_checks import _delete_cache_access_object
    from litellm.proxy._types import LiteLLM_AccessGroupTable

    cache = DualCache()
    ag_id = "ag-delete-test"
    ag_table = LiteLLM_AccessGroupTable(
        access_group_id=ag_id,
        access_group_name="to-delete",
    )
    await cache.async_set_cache(key=f"access_group_id:{ag_id}", value=ag_table, ttl=60)
    await _delete_cache_access_object(access_group_id=ag_id, user_api_key_cache=cache)
    cached = await cache.async_get_cache(key=f"access_group_id:{ag_id}")
    assert cached is None


# ---------------------------------------------------------------------------
# Access group resource fetchers (_get_models_from_access_groups, _get_agent_ids_from_access_groups)
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
    "resource_field, access_group_data, expected",
    [
        (
            "access_model_names",
            {"access_group_id": "ag-1", "access_model_names": ["gpt-4", "claude-3"]},
            ["gpt-4", "claude-3"],
        ),
        (
            "access_agent_ids",
            {"access_group_id": "ag-2", "access_agent_ids": ["agent-a", "agent-b"]},
            ["agent-a", "agent-b"],
        ),
        (
            "access_model_names",
            {"access_group_id": "ag-3", "access_model_names": []},
            [],
        ),
    ],
)
@pytest.mark.asyncio
async def test_get_resources_from_access_groups(resource_field, access_group_data, expected):
    """Test _get_resources_from_access_groups returns correct resource list from access groups."""
    from unittest.mock import AsyncMock, MagicMock, patch

    from litellm.proxy._types import LiteLLM_AccessGroupTable
    from litellm.proxy.auth.auth_checks import (
        _get_agent_ids_from_access_groups,
        _get_models_from_access_groups,
    )

    ag_table = LiteLLM_AccessGroupTable(
        access_group_id=access_group_data["access_group_id"],
        access_group_name="test",
        access_model_names=access_group_data.get("access_model_names", []),
        access_agent_ids=access_group_data.get("access_agent_ids", []),
    )

    with patch(
        "litellm.proxy.auth.auth_checks.get_access_object",
        new_callable=AsyncMock,
        return_value=ag_table,
    ):
        if resource_field == "access_model_names":
            result = await _get_models_from_access_groups(
                access_group_ids=[access_group_data["access_group_id"]],
                prisma_client=MagicMock(),
                user_api_key_cache=DualCache(),
            )
        else:
            result = await _get_agent_ids_from_access_groups(
                access_group_ids=[access_group_data["access_group_id"]],
                prisma_client=MagicMock(),
                user_api_key_cache=DualCache(),
            )
    assert sorted(result) == sorted(expected)
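

# Illustrative sketch, hypothetical helper: _get_models_from_access_groups and
# _get_agent_ids_from_access_groups each take a list of access_group_ids and
# return a single resource list (an empty ID list yields []). Assuming each
# group is looked up independently, merging the per-group lists while dropping
# duplicates and keeping first-seen order could look like this. Skipping
# unknown group IDs is also an assumption. _sketch_collect_resources is a
# hypothetical name.
from typing import Dict, Iterable, List


def _sketch_collect_resources(
    access_group_ids: Iterable[str],
    groups: Dict[str, Dict[str, List[str]]],
    resource_field: str,
) -> List[str]:
    """Merge resource_field lists across groups, de-duplicated."""
    seen = set()
    merged: List[str] = []
    for group_id in access_group_ids:
        group = groups.get(group_id)
        if group is None:
            continue  # unknown group IDs contribute nothing in this sketch
        for resource in group.get(resource_field, []):
            if resource not in seen:
                seen.add(resource)
                merged.append(resource)
    return merged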


@pytest.mark.asyncio
async def test_get_models_from_access_groups_empty_ids():
    """Test _get_models_from_access_groups returns empty list when access_group_ids is empty."""
    from litellm.proxy.auth.auth_checks import _get_models_from_access_groups

    result = await _get_models_from_access_groups(access_group_ids=[])
    assert result == []


# ---------------------------------------------------------------------------
# can_team_access_model with access_group_ids fallback
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_can_team_access_model_via_access_group_ids():
    """Test can_team_access_model allows access when team has access_group_ids granting model access."""
    from unittest.mock import AsyncMock, patch

    from litellm.proxy.auth.auth_checks import can_team_access_model

    team_object = LiteLLM_TeamTable(
        team_id="test-team",
        models=[],
        access_group_ids=["ag-with-gpt4"],
    )

    with patch(
        "litellm.proxy.auth.auth_checks._get_models_from_access_groups",
        new_callable=AsyncMock,
        return_value=["gpt-4"],
    ):
        result = await can_team_access_model(
            model="gpt-4",
            team_object=team_object,
            llm_router=None,
            team_model_aliases=None,
        )
        assert result is True


@pytest.mark.asyncio
async def test_can_team_access_model_access_group_ids_denied():
    """Test can_team_access_model denies when neither team models nor access_group_ids grant access."""
    from unittest.mock import AsyncMock, patch

    from litellm.proxy.auth.auth_checks import can_team_access_model
    from litellm.proxy._types import ProxyException

    team_object = LiteLLM_TeamTable(
        team_id="test-team",
        models=["gpt-3.5-turbo"],
        access_group_ids=["ag-other"],
    )

    with patch(
        "litellm.proxy.auth.auth_checks._get_models_from_access_groups",
        new_callable=AsyncMock,
        return_value=["claude-3"],
    ):
        with pytest.raises(ProxyException):
            await can_team_access_model(
                model="gpt-4",
                team_object=team_object,
                llm_router=None,
                team_model_aliases=None,
            )


# ---------------------------------------------------------------------------
# can_key_call_model with access_group_ids fallback
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_can_key_call_model_via_access_group_ids():
    """Test can_key_call_model allows access when key has access_group_ids granting model access."""
    from unittest.mock import AsyncMock, patch

    from litellm.proxy.auth.auth_checks import can_key_call_model

    user_api_key_object = UserAPIKeyAuth(
        token="test-token",
        models=[],
        access_group_ids=["ag-with-gpt4"],
    )
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "openai/gpt-4", "api_key": "test"},
            }
        ]
    )

    with patch(
        "litellm.proxy.auth.auth_checks._get_models_from_access_groups",
        new_callable=AsyncMock,
        return_value=["gpt-4"],
    ):
        await can_key_call_model(
            model="gpt-4",
            llm_model_list=[],
            valid_token=user_api_key_object,
            llm_router=router,
        )