mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 07:33:58 +00:00
074455c138
* fix(auth): expand all-team-models sentinel in can_key_call_model Keys with models=["all-team-models"] were denied during batch JSONL model validation because can_key_call_model matched the literal string against the model name. Add _resolve_key_models_for_auth_check to expand the sentinel to team_models before the check, consistent with get_key_models in model_checks.py and the completion-route bypass. Co-authored-by: Cursor <cursoragent@cursor.com> * docs(auth): document empty team_models unrestricted access behavior; add regression test Adds a docstring note to _resolve_key_models_for_auth_check explaining that when team_models is empty, all-team-models resolves to [] which is treated as unrestricted access (consistent with get_key_models behavior on other auth paths). Adds a test to lock in this behavior. * fix(auth): deny all-team-models access when key has no team_id A key configured with models=["all-team-models"] but no team_id could previously resolve to an empty allowlist, which _check_model_access_helper treats as unrestricted access. Now the sentinel is only expanded when team_id is set; otherwise the unresolved sentinel stays in the model list and causes a deny (no real model name matches it). Same fix applied to get_key_models in model_checks.py for consistency across batch and non-batch auth paths. * style: black format model_checks.py * Fix batch all-team-models auth * style: black format batch_rate_limiter.py * fix(test): add tool_use_system_prompt_tokens to model prices schema validator * fix(batch): catch get_team_object errors to avoid 404 escaping batch auth * fix(batch): apply per-member model scope check after team auth in batch validation * Fail closed on batch team auth fetch errors * test(batch): cover team_object grant and member-scope denial in batch auth --------- Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: mateo-berri <277851410+mateo-berri@users.noreply.github.com>
1298 lines
46 KiB
Python
1298 lines
46 KiB
Python
"""
VERIA-39 regression tests:

- The batch input-file token counter must measure embeddings (`input`)
  and text-completion (`prompt`) payloads, not only chat (`messages`).
- The batch rate-limiter pre-call hook must reject batch files that name
  models the caller is not authorized to use.
"""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from fastapi import HTTPException
|
|
|
|
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
|
|
|
|
# ---------------------------------------------------------------------------
# Token counter — covers all three batch payload shapes
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_token_counter_counts_chat_messages():
    """A chat-style (``messages``) batch line must yield a non-zero prompt count."""
    from litellm.batches.batch_utils import _get_batch_job_input_file_usage

    chat_body = {
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": "hello"}],
    }
    counted = _get_batch_job_input_file_usage(
        file_content_dictionary=[{"body": chat_body}]
    )
    assert counted.prompt_tokens > 0
|
|
|
|
|
|
def test_token_counter_counts_text_completion_prompt():
    """A text-completion (``prompt``) batch line must be counted.

    Before the fix the counter only looked at ``messages`` and reported 0
    tokens here, so ``prompt``-style batches could slip past TPM limits.
    """
    from litellm.batches.batch_utils import _get_batch_job_input_file_usage

    batch_lines = [
        {"body": {"model": "gpt-3.5-turbo-instruct", "prompt": "hello world"}}
    ]
    counted = _get_batch_job_input_file_usage(file_content_dictionary=batch_lines)
    assert counted.prompt_tokens > 0
|
|
|
|
|
|
def test_token_counter_counts_embedding_input_string():
    """An embeddings line whose ``input`` is a single string must be counted."""
    from litellm.batches.batch_utils import _get_batch_job_input_file_usage

    batch_lines = [
        {"body": {"model": "text-embedding-3-small", "input": "hello world"}}
    ]
    counted = _get_batch_job_input_file_usage(file_content_dictionary=batch_lines)
    assert counted.prompt_tokens > 0
|
|
|
|
|
|
def test_token_counter_counts_embedding_input_list():
    """An embeddings line whose ``input`` is a list of strings must be counted."""
    from litellm.batches.batch_utils import _get_batch_job_input_file_usage

    embedding_body = {
        "model": "text-embedding-3-small",
        "input": ["hello", "world"],
    }
    counted = _get_batch_job_input_file_usage(
        file_content_dictionary=[{"body": embedding_body}]
    )
    assert counted.prompt_tokens > 0
|
|
|
|
|
|
def test_token_counter_counts_text_completion_prompt_list():
    """A text-completion line whose ``prompt`` is a list of strings must be counted."""
    from litellm.batches.batch_utils import _get_batch_job_input_file_usage

    completion_body = {
        "model": "gpt-3.5-turbo-instruct",
        "prompt": ["alpha", "beta"],
    }
    counted = _get_batch_job_input_file_usage(
        file_content_dictionary=[{"body": completion_body}]
    )
    assert counted.prompt_tokens > 0
|
|
|
|
|
|
def test_token_counter_counts_pre_tokenized_prompt_int_list():
    """A single pre-tokenized prompt (list of ints) counts one token per int.

    OpenAI's text-completion API accepts this shape; before the fix it was
    silently counted as zero, which left a TPM bypass.
    """
    from litellm.batches.batch_utils import _get_batch_job_input_file_usage

    token_ids = [1, 2, 3, 4, 5]
    completion_body = {"model": "gpt-3.5-turbo-instruct", "prompt": token_ids}
    counted = _get_batch_job_input_file_usage(
        file_content_dictionary=[{"body": completion_body}]
    )
    assert counted.prompt_tokens == 5
|
|
|
|
|
|
def test_token_counter_counts_pre_tokenized_prompt_list_of_int_lists():
    """Several pre-tokenized prompts (``list[list[int]]``) are summed.

    This is the most important bypass shape: a batch carrying 1000 token ids
    must report 1000 tokens, not zero.
    """
    from litellm.batches.batch_utils import _get_batch_job_input_file_usage

    # 250 + 250 + 500 token ids across three prompts.
    tokenized_prompts = [[1] * 250, [2] * 250, [3] * 500]
    completion_body = {
        "model": "gpt-3.5-turbo-instruct",
        "prompt": tokenized_prompts,
    }
    counted = _get_batch_job_input_file_usage(
        file_content_dictionary=[{"body": completion_body}]
    )
    assert counted.prompt_tokens == 1000
|
|
|
|
|
|
def test_token_counter_counts_pre_tokenized_input_for_embeddings():
    """Pre-tokenized ``input`` for embeddings is counted one token per int."""
    from litellm.batches.batch_utils import _get_batch_job_input_file_usage

    embedding_body = {
        "model": "text-embedding-3-small",
        "input": [[1, 2, 3], [4, 5, 6]],
    }
    counted = _get_batch_job_input_file_usage(
        file_content_dictionary=[{"body": embedding_body}]
    )
    assert counted.prompt_tokens == 6
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Model extractor
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_model_extractor_returns_distinct_models():
    """Repeated models collapse to one entry and lines without a model are ignored."""
    from litellm.batches.batch_utils import _get_models_from_batch_input_file_content

    batch_lines = [
        {"body": {"model": "gpt-4o", "messages": []}},
        {"body": {"model": "gpt-4o", "messages": []}},  # duplicate
        {"body": {"model": "gpt-4o-mini", "messages": []}},
        {"body": {}},  # missing model
    ]
    extracted = _get_models_from_batch_input_file_content(batch_lines)
    assert extracted == ["gpt-4o", "gpt-4o-mini"]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Pre-call hook model validation
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_pre_call_rejects_unauthorized_model_in_batch_file():
    """A model inside the JSONL that the caller cannot use must produce a 403.

    Before the fix the hook validated only the outer ``model`` parameter and
    forwarded the file unchanged.
    """
    from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter

    limiter = _PROXY_BatchRateLimiter(
        internal_usage_cache=MagicMock(),
        parallel_request_limiter=MagicMock(),
    )

    # Decoded batch file naming gpt-4o, while the key below is limited to
    # gpt-3.5-turbo.
    batch_lines = [
        {"body": {"model": "gpt-4o", "messages": [{"role": "user", "content": "x"}]}}
    ]
    caller = UserAPIKeyAuth(
        api_key="sk-restricted",
        user_id="alice",
        models=["gpt-3.5-turbo"],
        user_role=LitellmUserRoles.INTERNAL_USER.value,
    )

    # Stand-in for `can_key_call_model` rejecting the model: it raises a plain
    # Exception carrying the key's allowlist.
    async def _deny_model(**kwargs):
        raise Exception(
            f"Key not allowed to access model. This key only has access to models={kwargs['valid_token'].models}"
        )

    denying_check = AsyncMock(side_effect=_deny_model)

    with (
        patch("litellm.proxy.auth.auth_checks.can_key_call_model", new=denying_check),
        patch("litellm.proxy.proxy_server.llm_router", None),
        pytest.raises(HTTPException) as exc,
    ):
        await limiter._enforce_batch_file_model_access(
            user_api_key_dict=caller,
            file_content_as_dict=batch_lines,
        )

    assert exc.value.status_code == 403
    assert "gpt-4o" in str(exc.value.detail)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_pre_call_allows_all_team_models_key_when_model_in_team_allowlist():
    """An ``all-team-models`` key inherits the team allowlist for JSONL models.

    The batch line names a model present in ``team_models``, so the access
    check must complete without raising.
    """
    from litellm.proxy._types import SpecialModelNames
    from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter

    limiter = _PROXY_BatchRateLimiter(
        internal_usage_cache=MagicMock(),
        parallel_request_limiter=MagicMock(),
    )
    proxy_alias = "openai/openai/gpt-5.5-batch"
    request_body = {
        "model": proxy_alias,
        "messages": [{"role": "user", "content": "x"}],
    }
    caller = UserAPIKeyAuth(
        api_key="sk-team",
        user_id="alice",
        team_id="team-123",
        models=[SpecialModelNames.all_team_models.value],
        team_models=[proxy_alias],
        user_role=LitellmUserRoles.INTERNAL_USER.value,
    )

    with patch("litellm.proxy.proxy_server.llm_router", None):
        await limiter._enforce_batch_file_model_access(
            user_api_key_dict=caller,
            file_content_as_dict=[{"body": request_body}],
        )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_pre_call_uses_current_team_allowlist_for_all_team_models_key():
    """The freshly fetched team allowlist wins over the key's cached ``team_models``.

    The key still lists ``stale-model`` in ``team_models`` but the fetched team
    object only allows ``current-model``; a batch naming the stale model must
    be rejected with a 403 after exactly one team fetch.
    """
    from litellm.proxy._types import LiteLLM_TeamTable, SpecialModelNames
    from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter

    limiter = _PROXY_BatchRateLimiter(
        internal_usage_cache=MagicMock(),
        parallel_request_limiter=MagicMock(),
    )
    stale_model = "stale-model"
    current_model = "current-model"
    request_body = {
        "model": stale_model,
        "messages": [{"role": "user", "content": "x"}],
    }
    caller = UserAPIKeyAuth(
        api_key="sk-team",
        user_id="alice",
        team_id="team-123",
        models=[SpecialModelNames.all_team_models.value],
        team_models=[stale_model],
        user_role=LitellmUserRoles.INTERNAL_USER.value,
    )
    fresh_team = LiteLLM_TeamTable(team_id="team-123", models=[current_model])
    team_fetch = AsyncMock(return_value=fresh_team)

    with (
        patch("litellm.proxy.proxy_server.prisma_client", MagicMock()),
        patch("litellm.proxy.proxy_server.llm_router", None),
        patch("litellm.proxy.auth.auth_checks.get_team_object", new=team_fetch),
        pytest.raises(HTTPException) as exc_info,
    ):
        await limiter._enforce_batch_file_model_access(
            user_api_key_dict=caller,
            file_content_as_dict=[{"body": request_body}],
        )

    assert exc_info.value.status_code == 403
    team_fetch.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_pre_call_allows_all_team_models_key_via_current_team_object():
    """Happy path of the team_object branch.

    With a DB client present, an ``all-team-models`` key whose batch model is
    on the *current* team allowlist must be authorized through the freshly
    fetched team object rather than the cached ``team_models`` fallback, and
    the key-level check must not be consulted.
    """
    from litellm.proxy._types import LiteLLM_TeamTable, SpecialModelNames
    from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter

    limiter = _PROXY_BatchRateLimiter(
        internal_usage_cache=MagicMock(),
        parallel_request_limiter=MagicMock(),
    )
    current_model = "current-model"
    request_body = {
        "model": current_model,
        "messages": [{"role": "user", "content": "x"}],
    }
    caller = UserAPIKeyAuth(
        api_key="sk-team",
        user_id="alice",
        team_id="team-123",
        models=[SpecialModelNames.all_team_models.value],
        team_models=["stale-model"],
        user_role=LitellmUserRoles.INTERNAL_USER.value,
    )
    fresh_team = LiteLLM_TeamTable(team_id="team-123", models=[current_model])
    team_fetch = AsyncMock(return_value=fresh_team)
    key_check = AsyncMock(return_value=True)

    with (
        patch("litellm.proxy.proxy_server.prisma_client", MagicMock()),
        patch("litellm.proxy.proxy_server.llm_router", None),
        patch("litellm.proxy.auth.auth_checks.get_team_object", new=team_fetch),
        patch(
            "litellm.proxy.auth.auth_checks.get_team_membership",
            new=AsyncMock(return_value=None),
        ),
        patch("litellm.proxy.auth.auth_checks.can_key_call_model", new=key_check),
    ):
        await limiter._enforce_batch_file_model_access(
            user_api_key_dict=caller,
            file_content_as_dict=[{"body": request_body}],
        )

    team_fetch.assert_awaited_once()
    key_check.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_pre_call_denies_all_team_models_key_via_member_scope():
    """The team_object branch also enforces the per-member model scope.

    The model is on the team allowlist but not in the member's
    ``allowed_models``, so the check must raise a 403 naming that model.
    """
    from litellm.proxy._types import (
        LiteLLM_BudgetTable,
        LiteLLM_TeamMembership,
        LiteLLM_TeamTable,
        SpecialModelNames,
    )
    from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter

    limiter = _PROXY_BatchRateLimiter(
        internal_usage_cache=MagicMock(),
        parallel_request_limiter=MagicMock(),
    )
    team_model = "team-model"
    request_body = {
        "model": team_model,
        "messages": [{"role": "user", "content": "x"}],
    }
    caller = UserAPIKeyAuth(
        api_key="sk-team",
        user_id="alice",
        team_id="team-123",
        models=[SpecialModelNames.all_team_models.value],
        team_models=[team_model],
        user_role=LitellmUserRoles.INTERNAL_USER.value,
    )
    fresh_team = LiteLLM_TeamTable(team_id="team-123", models=[team_model])
    # The member's budget table restricts them to a different model.
    member_budget = LiteLLM_BudgetTable(allowed_models=["other-model"])
    member_record = LiteLLM_TeamMembership(
        user_id="alice",
        team_id="team-123",
        litellm_budget_table=member_budget,
    )

    with (
        patch("litellm.proxy.proxy_server.prisma_client", MagicMock()),
        patch("litellm.proxy.proxy_server.llm_router", None),
        patch(
            "litellm.proxy.auth.auth_checks.get_team_object",
            new=AsyncMock(return_value=fresh_team),
        ),
        patch(
            "litellm.proxy.auth.auth_checks.get_team_membership",
            new=AsyncMock(return_value=member_record),
        ),
        pytest.raises(HTTPException) as exc_info,
    ):
        await limiter._enforce_batch_file_model_access(
            user_api_key_dict=caller,
            file_content_as_dict=[{"body": request_body}],
        )

    assert exc_info.value.status_code == 403
    assert team_model in str(exc_info.value.detail)
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("team_fetch_error", "expected_status"),
    [
        (HTTPException(status_code=404, detail="team not found"), 404),
        (Exception("team fetch failed"), 403),
    ],
)
@pytest.mark.asyncio
async def test_pre_call_fails_closed_when_current_team_fetch_fails_for_all_team_models_key(
    team_fetch_error, expected_status
):
    """A failed team fetch must deny the batch instead of falling back.

    An ``HTTPException`` from the fetch is expected to surface with its own
    status code, any other error as a 403. In both cases the key-level check
    must never be awaited, even though it would have returned True.
    """
    from litellm.proxy._types import SpecialModelNames
    from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter

    limiter = _PROXY_BatchRateLimiter(
        internal_usage_cache=MagicMock(),
        parallel_request_limiter=MagicMock(),
    )
    stale_model = "stale-model"
    request_body = {
        "model": stale_model,
        "messages": [{"role": "user", "content": "x"}],
    }
    caller = UserAPIKeyAuth(
        api_key="sk-team",
        user_id="alice",
        team_id="team-123",
        models=[SpecialModelNames.all_team_models.value],
        team_models=[stale_model],
        user_role=LitellmUserRoles.INTERNAL_USER.value,
    )
    failing_team_fetch = AsyncMock(side_effect=team_fetch_error)
    key_check = AsyncMock(return_value=True)

    with (
        patch("litellm.proxy.proxy_server.prisma_client", MagicMock()),
        patch("litellm.proxy.proxy_server.llm_router", None),
        patch(
            "litellm.proxy.auth.auth_checks.get_team_object",
            new=failing_team_fetch,
        ),
        patch("litellm.proxy.auth.auth_checks.can_key_call_model", new=key_check),
        pytest.raises(HTTPException) as exc_info,
    ):
        await limiter._enforce_batch_file_model_access(
            user_api_key_dict=caller,
            file_content_as_dict=[{"body": request_body}],
        )

    assert exc_info.value.status_code == expected_status
    failing_team_fetch.assert_awaited_once()
    key_check.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_pre_call_allows_authorized_model_in_batch_file():
    """When every JSONL model is on the caller's allowlist the hook must not raise."""
    from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter

    limiter = _PROXY_BatchRateLimiter(
        internal_usage_cache=MagicMock(),
        parallel_request_limiter=MagicMock(),
    )
    request_body = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "x"}],
    }
    caller = UserAPIKeyAuth(
        api_key="sk-ok",
        user_id="alice",
        models=["gpt-3.5-turbo"],
        user_role=LitellmUserRoles.INTERNAL_USER.value,
    )
    approving_check = AsyncMock(return_value=True)

    with (
        patch(
            "litellm.proxy.auth.auth_checks.can_key_call_model",
            new=approving_check,
        ),
        patch("litellm.proxy.proxy_server.llm_router", None),
    ):
        # Completing without an exception is the assertion.
        await limiter._enforce_batch_file_model_access(
            user_api_key_dict=caller,
            file_content_as_dict=[{"body": request_body}],
        )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_pre_call_skips_file_fetch_when_disabled_in_general_settings():
    """``disable_batch_input_file_rate_limiting`` makes the hook a pass-through.

    The request data comes back unchanged and no rate-limit descriptors are
    built.
    """
    from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter

    limiter = _PROXY_BatchRateLimiter(
        internal_usage_cache=MagicMock(),
        parallel_request_limiter=MagicMock(),
    )
    caller = UserAPIKeyAuth(api_key="sk-ok", user_id="alice", models=["*"])
    settings_override = {"disable_batch_input_file_rate_limiting": True}

    with patch("litellm.proxy.proxy_server.general_settings", settings_override):
        outcome = await limiter.async_pre_call_hook(
            user_api_key_dict=caller,
            cache=MagicMock(),
            data={"input_file_id": "file-abc123"},
            call_type="acreate_batch",
        )

    assert outcome == {"input_file_id": "file-abc123"}
    limiter.parallel_request_limiter._create_rate_limit_descriptors.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_pre_call_skips_file_fetch_for_configured_provider():
    """A skip-listed provider short-circuits the hook before any file work.

    The deployment credentials resolve to ``hosted_vllm``, which is listed in
    ``skip_batch_input_file_rate_limiting_for_providers``.
    """
    from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter

    limiter = _PROXY_BatchRateLimiter(
        internal_usage_cache=MagicMock(),
        parallel_request_limiter=MagicMock(),
    )
    caller = UserAPIKeyAuth(api_key="sk-ok", user_id="alice", models=["*"])
    request_data = {"input_file_id": "file-abc123", "model": "my-vllm-model"}
    file_download = AsyncMock()

    with (
        patch(
            "litellm.proxy.proxy_server.general_settings",
            {"skip_batch_input_file_rate_limiting_for_providers": ["hosted_vllm"]},
        ),
        patch("litellm.proxy.proxy_server.llm_router", MagicMock()),
        patch(
            "litellm.proxy.openai_files_endpoints.common_utils.get_credentials_for_model",
            return_value={"custom_llm_provider": "hosted_vllm"},
        ),
        patch("litellm.afile_content", new=file_download),
    ):
        outcome = await limiter.async_pre_call_hook(
            user_api_key_dict=caller,
            cache=MagicMock(),
            data=request_data,
            call_type="acreate_batch",
        )

    assert outcome == request_data
    # The hook's error-recovery path also returns the data unchanged, so the
    # equality above is not enough: confirm that neither the download nor the
    # descriptor construction ever happened.
    file_download.assert_not_awaited()
    limiter.parallel_request_limiter._create_rate_limit_descriptors.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_pre_call_does_not_skip_for_spoofed_provider():
    """A client-supplied ``custom_llm_provider`` must not trigger the provider skip.

    The skip is resolved from trusted deployment credentials (``openai`` here),
    so a request claiming ``hosted_vllm`` still has its input file fetched and
    its rate-limit counters incremented.
    """
    from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter

    limiter = _PROXY_BatchRateLimiter(
        internal_usage_cache=MagicMock(),
        parallel_request_limiter=MagicMock(),
    )
    # Give the key an applicable rate limit so the no-limits shortcut cannot
    # fire; the provider skip is then the only thing that could stop the
    # download. Honoring the spoofed provider would leave afile_content
    # un-awaited.
    limiter.parallel_request_limiter._create_rate_limit_descriptors.return_value = [
        {"rate_limit": {"requests_per_unit": 100}}
    ]
    limiter.parallel_request_limiter.atomic_check_and_increment_by_n = AsyncMock(
        return_value={"overall_code": "OK", "statuses": []}
    )
    caller = UserAPIKeyAuth(api_key="sk-ok", user_id="alice", models=["*"])

    router_stub = MagicMock()
    router_stub.model_list = []
    router_stub.resolve_model_name_from_model_id.return_value = "my-openai-model"

    downloaded_file = MagicMock()
    downloaded_file.content = (
        b'{"body": {"model": "my-openai-model", '
        b'"messages": [{"role": "user", "content": "hi"}]}}\n'
    )
    file_download = AsyncMock(return_value=downloaded_file)

    request_data = {
        "input_file_id": "file-abc123",
        "model": "my-openai-model",
        "custom_llm_provider": "hosted_vllm",
    }

    with (
        patch(
            "litellm.proxy.proxy_server.general_settings",
            {"skip_batch_input_file_rate_limiting_for_providers": ["hosted_vllm"]},
        ),
        patch("litellm.proxy.proxy_server.llm_router", router_stub),
        patch(
            "litellm.proxy.openai_files_endpoints.common_utils.get_credentials_for_model",
            return_value={"custom_llm_provider": "openai"},
        ),
        patch("litellm.afile_content", new=file_download),
    ):
        await limiter.async_pre_call_hook(
            user_api_key_dict=caller,
            cache=MagicMock(),
            data=request_data,
            call_type="acreate_batch",
        )

    # No short-circuit happened: the file was downloaded and the counters
    # were bumped.
    file_download.assert_awaited_once()
    limiter.parallel_request_limiter.atomic_check_and_increment_by_n.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_count_input_file_usage_decodes_model_embedded_file_id():
    """A model-embedded file id is unwrapped before the file is downloaded.

    ``afile_content`` must be awaited once with the original provider file id
    and with the provider from the model's credentials (``hosted_vllm``), not
    the ``openai`` value passed by the caller.
    """
    import base64

    from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter

    original_file_id = "file-provider-xyz"
    raw_payload = f"litellm:{original_file_id};model,my-vllm-batch".encode()
    # URL-safe base64 with the trailing "=" padding stripped.
    encoded_payload = base64.urlsafe_b64encode(raw_payload).decode().rstrip("=")
    encoded_file_id = f"file-{encoded_payload}"

    limiter = _PROXY_BatchRateLimiter(
        internal_usage_cache=MagicMock(),
        parallel_request_limiter=MagicMock(),
    )

    downloaded_file = MagicMock()
    downloaded_file.content = b'{"custom_id": "1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "my-vllm-batch", "messages": [{"role": "user", "content": "hi"}]}}\n'
    file_download = AsyncMock(return_value=downloaded_file)

    deployment_credentials = {
        "api_key": "test-key",
        "api_base": "http://vllm:8000/v1",
        "custom_llm_provider": "hosted_vllm",
    }

    with (
        patch("litellm.afile_content", new=file_download),
        patch("litellm.proxy.proxy_server.llm_router", MagicMock()),
        patch(
            "litellm.proxy.openai_files_endpoints.common_utils.get_credentials_for_model",
            return_value=deployment_credentials,
        ),
    ):
        await limiter.count_input_file_usage(
            file_id=encoded_file_id,
            custom_llm_provider="openai",
            user_api_key_dict=UserAPIKeyAuth(api_key="sk-ok", user_id="alice"),
            data={},
        )

    file_download.assert_awaited_once()
    download_kwargs = file_download.await_args.kwargs
    assert download_kwargs["file_id"] == original_file_id
    assert download_kwargs["custom_llm_provider"] == "hosted_vllm"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_pre_call_allows_stripped_provider_model_when_key_has_proxy_alias():
    """Auth is checked against the proxy model name, not the stripped provider id.

    After replace_model_in_jsonl, ``body.model`` is the provider id (e.g.
    gpt-5.5). The key was granted the proxy alias, so the key-level check must
    be called with the alias the router resolves.
    """
    from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter

    limiter = _PROXY_BatchRateLimiter(
        internal_usage_cache=MagicMock(),
        parallel_request_limiter=MagicMock(),
    )
    proxy_alias = "openai/openai/gpt-5.5-batch"
    batch_lines = [
        {"body": {"model": "gpt-5.5", "messages": [{"role": "user", "content": "x"}]}}
    ]
    caller = UserAPIKeyAuth(
        api_key="sk-ok",
        user_id="alice",
        models=[proxy_alias],
        user_role=LitellmUserRoles.INTERNAL_USER.value,
    )
    router_stub = MagicMock()
    router_stub.model_list = []
    router_stub.resolve_model_name_from_model_id.return_value = proxy_alias
    key_check = AsyncMock(return_value=True)

    with (
        patch("litellm.proxy.auth.auth_checks.can_key_call_model", new=key_check),
        patch("litellm.proxy.proxy_server.llm_router", router_stub),
    ):
        await limiter._enforce_batch_file_model_access(
            user_api_key_dict=caller,
            file_content_as_dict=batch_lines,
        )

    key_check.assert_awaited_once()
    assert key_check.await_args.kwargs["model"] == proxy_alias
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_pre_call_skips_check_when_no_models_present():
    """Files without any ``body.model`` (corrupt or empty) must not cause a 500.

    The rate limiter logs a warning elsewhere and proceeds.
    """
    from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter

    limiter = _PROXY_BatchRateLimiter(
        internal_usage_cache=MagicMock(),
        parallel_request_limiter=MagicMock(),
    )
    caller = UserAPIKeyAuth(api_key="sk-ok", user_id="alice")

    # `can_key_call_model` is left unpatched here (and would fail); neither
    # call may raise, because the early return on an empty model list keeps
    # the check from running at all.
    await limiter._enforce_batch_file_model_access(
        user_api_key_dict=caller,
        file_content_as_dict=[],
    )
    await limiter._enforce_batch_file_model_access(
        user_api_key_dict=caller,
        file_content_as_dict=[{"body": {}}],
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Skip-path helpers
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_rate_limiter():
    """Build a ``_PROXY_BatchRateLimiter`` whose collaborators are ``MagicMock``s."""
    from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter

    usage_cache = MagicMock()
    parallel_limiter = MagicMock()
    return _PROXY_BatchRateLimiter(
        internal_usage_cache=usage_cache,
        parallel_request_limiter=parallel_limiter,
    )
|
|
|
|
|
|
def test_get_batch_routing_model_uses_request_model_for_plain_file():
    """Without a file id, the routing model is the request's top-level ``model``."""
    limiter = _make_rate_limiter()
    routed = limiter._get_batch_routing_model({"model": "gpt-4o-mini"})
    assert routed == "gpt-4o-mini"
|
|
|
|
|
|
def test_get_batch_routing_model_prefers_file_bound_over_request_model():
    """The model bound into the file id outranks the top-level ``model``.

    ``create_batch`` routes a model-embedded file id on its bound model and
    ignores the top-level ``model``. The skip decision has to follow the same
    precedence; otherwise a caller could aim ``model`` at a skip-listed
    provider while the file routes to a rate-limited one.
    """
    import base64

    limiter = _make_rate_limiter()
    raw_payload = b"litellm:file-xyz;model,vllm-batch"
    encoded = base64.urlsafe_b64encode(raw_payload).decode().rstrip("=")
    request_data = {"input_file_id": f"file-{encoded}", "model": "gpt-4o-mini"}

    routed = limiter._get_batch_routing_model(request_data)
    assert routed == "vllm-batch"
|
|
|
|
|
|
def test_get_batch_routing_model_returns_none_without_model_or_file():
    """Empty request data, or an empty ``input_file_id``, yields no routing model."""
    limiter = _make_rate_limiter()
    for request_data in ({}, {"input_file_id": ""}):
        assert limiter._get_batch_routing_model(request_data) is None
|
|
|
|
|
|
def test_get_batch_routing_model_decodes_model_embedded_file_id():
    """With only a model-embedded file id, the embedded model is returned."""
    import base64

    limiter = _make_rate_limiter()
    raw_payload = b"litellm:file-xyz;model,vllm-batch"
    encoded = base64.urlsafe_b64encode(raw_payload).decode().rstrip("=")

    routed = limiter._get_batch_routing_model({"input_file_id": f"file-{encoded}"})
    assert routed == "vllm-batch"
|
|
|
|
|
|
def test_get_batch_routing_model_uses_unified_file_id_target():
    """For a unified file id, a target model from that id is used.

    The helpers are stubbed so the id carries no embedded model but resolves
    to ``["model-a", "model-b"]``; the expected routing model is ``model-a``.
    """
    limiter = _make_rate_limiter()
    common_utils = "litellm.proxy.openai_files_endpoints.common_utils"

    with (
        patch(f"{common_utils}.decode_model_from_file_id", return_value=None),
        patch(
            f"{common_utils}._is_base64_encoded_unified_file_id",
            return_value="unified-id",
        ),
        patch(
            f"{common_utils}.get_models_from_unified_file_id",
            return_value=["model-a", "model-b"],
        ),
    ):
        routed = limiter._get_batch_routing_model({"input_file_id": "file-managed"})
        assert routed == "model-a"
|
|
|
|
|
|
def test_key_requires_batch_model_access_check_branches():
    """Walk each branch of ``_key_requires_batch_model_access_check``.

    Every case builds a key with ``api_key="sk"`` plus the listed fields and
    compares the result to the expected boolean by identity.
    """
    from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter

    check = _PROXY_BatchRateLimiter._key_requires_batch_model_access_check

    cases = [
        ({"models": ["*"]}, False),
        ({"models": ["all-proxy-models"]}, False),
        ({"models": [], "access_group_ids": ["grp"]}, True),
        ({"models": []}, False),
        ({"models": ["gpt-4o-mini"]}, True),
        # Wildcard / all-proxy-models grant access to every model, so
        # can_key_call_model passes any model regardless of access groups
        # (which only ever widen access). Such keys must not be forced to
        # download and validate the JSONL even when access_group_ids are also
        # present.
        ({"models": ["*"], "access_group_ids": ["grp"]}, False),
        ({"models": ["all-proxy-models"], "access_group_ids": ["grp"]}, False),
        # A concrete model allowlist is still a subset even with access groups.
        ({"models": ["gpt-4o-mini"], "access_group_ids": ["grp"]}, True),
    ]
    for key_fields, expected in cases:
        assert check(UserAPIKeyAuth(api_key="sk", **key_fields)) is expected
|
|
|
|
|
|
def test_has_applicable_batch_rate_limits():
    """Descriptors count as applicable only when a ``rate_limit`` carries a limit."""
    from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter

    has_limits = _PROXY_BatchRateLimiter._has_applicable_batch_rate_limits

    cases = [
        ([{"rate_limit": {"tokens_per_unit": 100}}], True),
        ([{"rate_limit": {"requests_per_unit": 5}}], True),
        ([{"rate_limit": {"max_parallel_requests": 2}}], True),
        # An empty rate_limit and a descriptor with no rate_limit at all.
        ([{"rate_limit": {}}, {}], False),
    ]
    for descriptor_list, expected in cases:
        assert has_limits(descriptor_list) is expected
|
|
|
|
|
|
def test_should_skip_returns_false_when_key_needs_model_access_check():
    """A key with a concrete model allowlist is never skipped.

    The skip helper must return ``(False, None)`` for such a key.
    """
    limiter = _make_rate_limiter()
    restricted_key = UserAPIKeyAuth(api_key="sk", models=["gpt-4o-mini"])
    request_body = {"input_file_id": "file-abc"}

    should_skip, descriptors = limiter._should_skip_batch_input_file_processing(
        data=request_body,
        user_api_key_dict=restricted_key,
    )

    assert should_skip is False
    assert descriptors is None
|
|
|
|
|
|
def test_should_skip_ignores_client_supplied_metadata_flag():
    """Client-supplied skip flags must not disable batch input-file processing.

    The request body sets
    ``litellm_metadata.skip_batch_input_file_rate_limiting`` to ``True``, but
    the skip decision is server-controlled only: with an applicable rate limit
    configured, the JSONL must still be processed.
    """
    limiter = _make_rate_limiter()
    limiter.parallel_request_limiter._create_rate_limit_descriptors.return_value = [
        {"rate_limit": {"requests_per_unit": 5}}
    ]
    wildcard_key = UserAPIKeyAuth(api_key="sk", models=["*"])
    request_body = {
        "input_file_id": "file-abc",
        "litellm_metadata": {"skip_batch_input_file_rate_limiting": True},
    }

    with patch("litellm.proxy.proxy_server.general_settings", {}):
        should_skip, _ = limiter._should_skip_batch_input_file_processing(
            data=request_body,
            user_api_key_dict=wildcard_key,
        )

    assert should_skip is False
|
|
|
|
|
|
def test_should_not_skip_for_forged_model_embedded_file_id():
    """A forged model-embedded file id must not trigger a per-model skip.

    The model name inside a ``file-<base64>`` id is unsigned and fully under
    the caller's control, so any accessible provider file id can be re-encoded
    with a skip-listed model while the JSONL ``body.model`` entries still
    route to rate-limited models. The per-model skip therefore must never
    fire: with applicable rate limits, such an id still falls through to file
    processing and counter enforcement.
    """
    import base64

    limiter = _make_rate_limiter()
    limiter.parallel_request_limiter._create_rate_limit_descriptors.return_value = [
        {"rate_limit": {"requests_per_unit": 5}}
    ]
    wildcard_key = UserAPIKeyAuth(api_key="sk", models=["*"])

    # Unpadded URL-safe base64 of a payload naming the skip-listed model.
    raw_payload = b"litellm:file-xyz;model,gpt-4o-mini"
    token = base64.urlsafe_b64encode(raw_payload).decode().rstrip("=")
    forged_file_id = f"file-{token}"

    server_settings = {
        "skip_batch_input_file_rate_limiting_for_models": ["gpt-4o-mini"]
    }
    with patch("litellm.proxy.proxy_server.general_settings", server_settings):
        should_skip, descriptors = limiter._should_skip_batch_input_file_processing(
            data={"input_file_id": forged_file_id},
            user_api_key_dict=wildcard_key,
        )

    assert should_skip is False
    assert descriptors is not None
|
|
|
|
|
|
def test_should_not_skip_for_skip_listed_top_level_model():
    """A skip-listed top-level ``model`` must not bypass batch rate limits.

    A caller could name a skip-listed model at the top level while routing a
    different model through the JSONL ``body.model`` entries. No per-model
    skip exists, so a skip-listed model over a plain file is still processed.
    """
    limiter = _make_rate_limiter()
    limiter.parallel_request_limiter._create_rate_limit_descriptors.return_value = [
        {"rate_limit": {"requests_per_unit": 5}}
    ]
    wildcard_key = UserAPIKeyAuth(api_key="sk", models=["*"])
    request_body = {"model": "gpt-4o-mini", "input_file_id": "file-abc"}

    server_settings = {
        "skip_batch_input_file_rate_limiting_for_models": ["gpt-4o-mini"]
    }
    with patch("litellm.proxy.proxy_server.general_settings", server_settings):
        should_skip, _ = limiter._should_skip_batch_input_file_processing(
            data=request_body,
            user_api_key_dict=wildcard_key,
        )

    assert should_skip is False
|
|
|
|
|
|
def test_should_not_skip_when_file_bound_provider_is_rate_limited():
    """The provider skip must follow the file-bound model, not the top-level one.

    A caller could point the top-level ``model`` at a skip-listed provider
    while the model-embedded ``input_file_id`` routes to a rate-limited
    provider. ``create_batch`` runs the batch on the file-bound model, so the
    skip decision has to resolve the provider from that model and still
    process the file when its provider is not skip-listed.
    """
    import base64

    limiter = _make_rate_limiter()
    limiter.parallel_request_limiter._create_rate_limit_descriptors.return_value = [
        {"rate_limit": {"requests_per_unit": 5}}
    ]
    wildcard_key = UserAPIKeyAuth(api_key="sk", models=["*"])

    # File id whose embedded model is "vllm-batch" (unpadded URL-safe base64).
    raw_payload = b"litellm:file-orig;model,vllm-batch"
    token = base64.urlsafe_b64encode(raw_payload).decode().rstrip("=")
    request_body = {"input_file_id": f"file-{token}", "model": "gpt-skip"}

    def _creds(model_id, **kwargs):
        # Only the file-bound model maps to hosted_vllm; anything else is openai.
        if model_id == "vllm-batch":
            return {"custom_llm_provider": "hosted_vllm"}
        return {"custom_llm_provider": "openai"}

    server_settings = {
        "skip_batch_input_file_rate_limiting_for_providers": ["openai"]
    }
    with (
        patch("litellm.proxy.proxy_server.general_settings", server_settings),
        patch("litellm.proxy.proxy_server.llm_router", MagicMock()),
        patch(
            "litellm.proxy.openai_files_endpoints.common_utils.get_credentials_for_model",
            side_effect=_creds,
        ),
    ):
        should_skip, descriptors = limiter._should_skip_batch_input_file_processing(
            data=request_body,
            user_api_key_dict=wildcard_key,
        )

    assert should_skip is False
    assert descriptors is not None
|
|
|
|
|
|
def test_should_skip_when_file_bound_provider_is_skip_listed():
    """The provider skip fires when the file-bound model's provider is listed.

    The batch actually runs on the file-bound model, so the skip must apply
    when that model resolves to a skip-listed provider, even though the
    top-level ``model`` resolves to a different, non-skipped provider.
    """
    import base64

    limiter = _make_rate_limiter()
    limiter.parallel_request_limiter._create_rate_limit_descriptors.return_value = [
        {"rate_limit": {"requests_per_unit": 5}}
    ]
    wildcard_key = UserAPIKeyAuth(api_key="sk", models=["*"])

    # File id whose embedded model is "vllm-batch" (unpadded URL-safe base64).
    raw_payload = b"litellm:file-orig;model,vllm-batch"
    token = base64.urlsafe_b64encode(raw_payload).decode().rstrip("=")
    request_body = {"input_file_id": f"file-{token}", "model": "gpt-skip"}

    def _creds(model_id, **kwargs):
        # Only the file-bound model maps to hosted_vllm; anything else is openai.
        if model_id == "vllm-batch":
            return {"custom_llm_provider": "hosted_vllm"}
        return {"custom_llm_provider": "openai"}

    server_settings = {
        "skip_batch_input_file_rate_limiting_for_providers": ["hosted_vllm"]
    }
    with (
        patch("litellm.proxy.proxy_server.general_settings", server_settings),
        patch("litellm.proxy.proxy_server.llm_router", MagicMock()),
        patch(
            "litellm.proxy.openai_files_endpoints.common_utils.get_credentials_for_model",
            side_effect=_creds,
        ),
    ):
        should_skip, _ = limiter._should_skip_batch_input_file_processing(
            data=request_body,
            user_api_key_dict=wildcard_key,
        )

    assert should_skip is True
|
|
|
|
|
|
def test_warns_once_for_unsupported_model_skip_setting():
    """The no-op per-model skip setting produces exactly one warning.

    Operators who configure the unsupported key get a single warning, so a
    misconfigured deployment does not silently leave batch limits unenforced.
    The skip helper is invoked three times to check the warning is not
    repeated.
    """
    limiter = _make_rate_limiter()
    limiter.parallel_request_limiter._create_rate_limit_descriptors.return_value = [
        {"rate_limit": {"requests_per_unit": 5}}
    ]
    wildcard_key = UserAPIKeyAuth(api_key="sk", models=["*"])
    request_body = {"model": "gpt-4o-mini", "input_file_id": "file-abc"}

    server_settings = {
        "skip_batch_input_file_rate_limiting_for_models": ["gpt-4o-mini"]
    }
    with (
        patch("litellm.proxy.proxy_server.general_settings", server_settings),
        patch(
            "litellm.proxy.hooks.batch_rate_limiter.verbose_proxy_logger"
        ) as mock_logger,
    ):
        for _ in range(3):
            limiter._should_skip_batch_input_file_processing(
                data=request_body,
                user_api_key_dict=wildcard_key,
            )

    assert mock_logger.warning.call_count == 1
    # The first positional argument of the warning call is the message text.
    warning_message = mock_logger.warning.call_args[0][0]
    assert "skip_batch_input_file_rate_limiting_for_models" in warning_message
|
|
|
|
|
|
def test_no_warning_when_model_skip_setting_absent():
    """No warning is logged when only the per-provider skip setting is present."""
    limiter = _make_rate_limiter()
    limiter.parallel_request_limiter._create_rate_limit_descriptors.return_value = [
        {"rate_limit": {"requests_per_unit": 5}}
    ]
    wildcard_key = UserAPIKeyAuth(api_key="sk", models=["*"])
    request_body = {"model": "gpt-4o-mini", "input_file_id": "file-abc"}

    server_settings = {
        "skip_batch_input_file_rate_limiting_for_providers": ["openai"]
    }
    with (
        patch("litellm.proxy.proxy_server.general_settings", server_settings),
        patch(
            "litellm.proxy.hooks.batch_rate_limiter.verbose_proxy_logger"
        ) as mock_logger,
    ):
        limiter._should_skip_batch_input_file_processing(
            data=request_body,
            user_api_key_dict=wildcard_key,
        )

    mock_logger.warning.assert_not_called()
|
|
|
|
|
|
def test_should_skip_when_no_rate_limits_configured():
    """With only an empty ``rate_limit`` descriptor the helper returns ``(True, None)``."""
    limiter = _make_rate_limiter()
    limiter.parallel_request_limiter._create_rate_limit_descriptors.return_value = [
        {"rate_limit": {}}
    ]
    wildcard_key = UserAPIKeyAuth(api_key="sk", models=["*"])
    request_body = {"model": "gpt-4o-mini", "input_file_id": "file-abc"}

    with patch("litellm.proxy.proxy_server.general_settings", {}):
        should_skip, descriptors = limiter._should_skip_batch_input_file_processing(
            data=request_body,
            user_api_key_dict=wildcard_key,
        )

    assert should_skip is True
    assert descriptors is None
|
|
|
|
|
|
def test_should_not_skip_and_reuses_descriptors_when_limits_present():
    """With a populated limit, the helper returns the very same descriptor list.

    The identity check (``is``) confirms the list produced by
    ``_create_rate_limit_descriptors`` is handed back rather than a copy.
    """
    limiter = _make_rate_limiter()
    expected_descriptors = [{"rate_limit": {"tokens_per_unit": 100}}]
    limiter.parallel_request_limiter._create_rate_limit_descriptors.return_value = (
        expected_descriptors
    )
    wildcard_key = UserAPIKeyAuth(api_key="sk", models=["*"])
    request_body = {"model": "gpt-4o-mini", "input_file_id": "file-abc"}

    with patch("litellm.proxy.proxy_server.general_settings", {}):
        should_skip, returned = limiter._should_skip_batch_input_file_processing(
            data=request_body,
            user_api_key_dict=wildcard_key,
        )

    assert should_skip is False
    assert returned is expected_descriptors
|
|
|
|
|
|
def test_resolve_fetch_params_uses_request_model_credentials():
    """Credentials resolved for the request ``model`` drive the file fetch.

    For a plain file id the id is returned unchanged, while the fetch kwargs
    carry the request model plus the provider and API base from the patched
    credential lookup (overriding the ``custom_llm_provider`` argument).
    """
    limiter = _make_rate_limiter()
    resolved_credentials = {
        "api_key": "k",
        "api_base": "http://vllm:8000/v1",
        "custom_llm_provider": "hosted_vllm",
    }

    with (
        patch("litellm.proxy.proxy_server.llm_router", MagicMock()),
        patch(
            "litellm.proxy.openai_files_endpoints.common_utils.get_credentials_for_model",
            return_value=resolved_credentials,
        ),
    ):
        provider_file_id, fetch_kwargs = limiter._resolve_batch_input_file_fetch_params(
            file_id="file-plain-openai",
            custom_llm_provider="openai",
            data={"model": "my-vllm-batch"},
        )

    assert provider_file_id == "file-plain-openai"
    assert fetch_kwargs["model"] == "my-vllm-batch"
    assert fetch_kwargs["custom_llm_provider"] == "hosted_vllm"
    assert fetch_kwargs["api_base"] == "http://vllm:8000/v1"
|
|
|
|
|
|
def test_resolve_fetch_params_fails_open_on_credential_lookup_error():
    """A failing credential lookup falls back to the caller-supplied provider.

    When the patched lookup raises ``HTTPException``, the file id is returned
    unchanged and the fetch kwargs contain only the original
    ``custom_llm_provider``.
    """
    limiter = _make_rate_limiter()
    lookup_failure = HTTPException(status_code=404, detail="no creds")

    with (
        patch("litellm.proxy.proxy_server.llm_router", MagicMock()),
        patch(
            "litellm.proxy.openai_files_endpoints.common_utils.get_credentials_for_model",
            side_effect=lookup_failure,
        ),
    ):
        provider_file_id, fetch_kwargs = limiter._resolve_batch_input_file_fetch_params(
            file_id="file-plain-openai",
            custom_llm_provider="openai",
            data={"model": "my-vllm-batch"},
        )

    assert provider_file_id == "file-plain-openai"
    assert fetch_kwargs == {"custom_llm_provider": "openai"}
|
|
|
|
|
|
def test_resolve_fetch_params_model_embedded_fails_open_on_credential_error():
    """A model-embedded file id still decodes when the credential lookup fails.

    The lookup is attempted exactly once and raises ``HTTPException``. The
    helper must then return the decoded provider file id together with fetch
    kwargs holding only the caller-supplied ``custom_llm_provider``.
    """
    import base64

    limiter = _make_rate_limiter()

    # File id embedding "file-orig" and model "vllm-batch" (unpadded base64).
    raw_payload = b"litellm:file-orig;model,vllm-batch"
    token = base64.urlsafe_b64encode(raw_payload).decode().rstrip("=")
    encoded_file_id = f"file-{token}"

    failing_lookup = MagicMock(
        side_effect=HTTPException(status_code=404, detail="no creds")
    )
    with (
        patch("litellm.proxy.proxy_server.llm_router", MagicMock()),
        patch(
            "litellm.proxy.openai_files_endpoints.common_utils.get_credentials_for_model",
            failing_lookup,
        ),
    ):
        provider_file_id, fetch_kwargs = limiter._resolve_batch_input_file_fetch_params(
            file_id=encoded_file_id,
            custom_llm_provider="openai",
            data={},
        )

    failing_lookup.assert_called_once()
    assert provider_file_id == "file-orig"
    assert fetch_kwargs == {"custom_llm_provider": "openai"}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_check_and_increment_computes_descriptors_when_not_passed():
    """Passing ``descriptors=None`` makes the limiter build descriptors itself.

    The test asserts ``_create_rate_limit_descriptors`` is called exactly once
    while the counters are checked and incremented.
    """
    from litellm.proxy.hooks.batch_rate_limiter import (
        BatchFileUsage,
        _PROXY_BatchRateLimiter,
    )

    limiter_backend = MagicMock()
    limiter_backend._create_rate_limit_descriptors.return_value = [
        {"rate_limit": {"tokens_per_unit": 100}}
    ]
    # The atomic increment is awaited by the code under test, hence AsyncMock.
    limiter_backend.atomic_check_and_increment_by_n = AsyncMock(
        return_value={"overall_code": "OK", "statuses": []}
    )
    batch_limiter = _PROXY_BatchRateLimiter(
        internal_usage_cache=MagicMock(),
        parallel_request_limiter=limiter_backend,
    )
    wildcard_key = UserAPIKeyAuth(api_key="sk", models=["*"])
    usage = BatchFileUsage(total_tokens=10, request_count=1)

    await batch_limiter._check_and_increment_batch_counters(
        user_api_key_dict=wildcard_key,
        data={"model": "gpt-4o-mini"},
        batch_usage=usage,
        descriptors=None,
    )

    limiter_backend._create_rate_limit_descriptors.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_count_input_file_usage_raises_on_non_bytes_content():
    """``count_input_file_usage`` rejects file content that is not bytes.

    ``litellm.afile_content`` is patched to return an object whose
    ``content`` attribute is a ``str``; the call must raise ``ValueError``
    matching "Expected bytes content".
    """
    from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter

    batch_limiter = _PROXY_BatchRateLimiter(
        internal_usage_cache=MagicMock(),
        parallel_request_limiter=MagicMock(),
    )
    wildcard_key = UserAPIKeyAuth(api_key="sk", models=["*"])

    str_payload_response = MagicMock()
    str_payload_response.content = "not-bytes"
    fake_afile_content = AsyncMock(return_value=str_payload_response)

    with (
        patch("litellm.afile_content", new=fake_afile_content),
        pytest.raises(ValueError, match="Expected bytes content"),
    ):
        await batch_limiter.count_input_file_usage(
            file_id="file-plain",
            custom_llm_provider="openai",
            user_api_key_dict=wildcard_key,
            data={},
        )
|