Files
litellm/tests/test_litellm/proxy/test_batch_expiry.py
T
Ryan Crabbe 52ec73c07d fix(proxy): improve team expiry enforcement validation
- Change status codes from 400 to 500 for team metadata misconfig errors
  (callers can't fix admin-set config, 400 is misleading)
- Add anchor value validation to batch endpoint (matching files endpoint)
- Coerce seconds to int to handle string values from metadata
- Add error-path tests: missing keys, invalid anchor, status code assertions
- Add happy-path test: team injects expiry when caller sends nothing
2026-03-03 17:29:39 -08:00

263 lines
8.2 KiB
Python

"""
Tests for batch output_expires_after passthrough and team-level expiry enforcement.
"""
import os
import sys
from unittest.mock import MagicMock, patch
import pytest
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
import litellm
from litellm.caching.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.proxy_server import app
from litellm.proxy.utils import ProxyLogging
from litellm.router import Router
from litellm.types.utils import LiteLLMBatch
from fastapi.testclient import TestClient
# Shared in-process HTTP client bound to the proxy FastAPI app; the endpoint
# tests below send their POST /v1/batches requests through it.
client = TestClient(app)
# Expiry policy the tests place in team metadata under
# "enforced_batch_output_expires_after" (3600 seconds, anchored at created_at).
TEAM_EXPIRY = {"anchor": "created_at", "seconds": 3600}
# Expiry policy the tests send as the caller's own "output_expires_after"
# (86400 seconds, anchored at created_at); differs from TEAM_EXPIRY so the
# assertions can tell which one reached the batch-creation call.
CALLER_EXPIRY = {"anchor": "created_at", "seconds": 86400}
@pytest.fixture
def llm_router() -> Router:
    """Provide a Router configured with a single ``gpt-3.5-turbo`` deployment.

    The deployment points at ``openai/gpt-3.5-turbo`` and carries a
    placeholder API key ("test-key").
    """
    deployment = {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "openai/gpt-3.5-turbo",
            "api_key": "test-key",
        },
        "model_info": {"id": "gpt-3.5-turbo-id"},
    }
    return Router(model_list=[deployment])
def _setup_proxy(monkeypatch, llm_router: Router):
    """Point the proxy server module at a test router and a fresh logging object.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture; the patches are applied
            through it.
        llm_router: Router installed as ``litellm.proxy.proxy_server.llm_router``.
    """
    key_cache = DualCache(default_in_memory_ttl=1)
    logging_obj = ProxyLogging(user_api_key_cache=key_cache)

    # Same targets and same order as patching them one by one.
    replacements = (
        ("litellm.proxy.proxy_server.llm_router", llm_router),
        ("litellm.proxy.proxy_server.proxy_logging_obj", logging_obj),
    )
    for target, value in replacements:
        monkeypatch.setattr(target, value)
def _make_batch_response() -> LiteLLMBatch:
    """Build the fixed ``LiteLLMBatch`` that the mocked batch creation returns."""
    fields = {
        "id": "batch_abc123",
        "completion_window": "24h",
        "created_at": 1234567890,
        "endpoint": "/v1/chat/completions",
        "input_file_id": "file-abc123",
        "object": "batch",
        "status": "validating",
    }
    return LiteLLMBatch(**fields)
def test_output_expires_after_passthrough():
    """``output_expires_after`` given to ``litellm.create_batch`` is forwarded.

    The OpenAI batches instance is patched out; the test records the keyword
    arguments its ``create_batch`` receives and checks that the
    ``create_batch_data`` payload carries the caller's expiry unchanged.
    """
    seen_kwargs = {}

    def record_call(**call_kwargs):
        # Remember what the provider layer was called with and hand back a
        # stand-in batch object.
        seen_kwargs.update(call_kwargs)
        fake_batch = MagicMock()
        fake_batch.id = "batch_123"
        return fake_batch

    with patch("litellm.batches.main.openai_batches_instance") as batches_mock:
        batches_mock.create_batch.side_effect = record_call
        litellm.create_batch(
            completion_window="24h",
            endpoint="/v1/chat/completions",
            input_file_id="file-abc123",
            output_expires_after=CALLER_EXPIRY,
            custom_llm_provider="openai",
        )

    forwarded_payload = seen_kwargs["create_batch_data"]
    assert forwarded_payload["output_expires_after"] == CALLER_EXPIRY
class TestBatchEndpointTeamOverride:
    """Verify team-level enforced_batch_output_expires_after in the proxy endpoint."""

    # Minimal valid request body; tests add "output_expires_after" when the
    # caller is supposed to send its own expiry.
    _BATCH_BODY = {
        "input_file_id": "file-abc123",
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h",
    }

    def _post_batch(
        self,
        monkeypatch,
        llm_router: Router,
        team_metadata: dict,
        request_body: dict,
    ) -> dict:
        """POST /v1/batches with given team_metadata and body, return captured kwargs.

        Args:
            monkeypatch: pytest's ``monkeypatch`` fixture.
            llm_router: Router installed on the proxy server for this request.
            team_metadata: Metadata attached to the authenticated key's team.
            request_body: JSON body sent to the endpoint.

        Returns:
            The keyword arguments the endpoint passed to the mocked
            ``litellm.acreate_batch``.

        Raises:
            AssertionError: If the endpoint does not answer with HTTP 200.
        """
        _setup_proxy(monkeypatch, llm_router)
        user_key = UserAPIKeyAuth(
            api_key="test-key",
            team_metadata=team_metadata,
        )
        captured_kwargs: dict = {}

        async def mock_acreate_batch(**kwargs):
            captured_kwargs.update(kwargs)
            return _make_batch_response()

        monkeypatch.setattr(litellm, "acreate_batch", mock_acreate_batch)
        try:
            # Install the auth override inside the try so the finally below
            # always removes it; nothing can fail between install and cleanup
            # and leave the override on the shared app for later tests.
            app.dependency_overrides[user_api_key_auth] = lambda: user_key
            response = client.post(
                "/v1/batches",
                json=request_body,
                headers={"Authorization": "Bearer test-key"},
            )
            # Include the body so a failure shows the proxy's error message.
            assert response.status_code == 200, response.text
        finally:
            app.dependency_overrides.clear()
        return captured_kwargs

    def test_team_override_overrides_caller(self, monkeypatch, llm_router):
        """Team enforcement wins over caller-provided value."""
        kwargs = self._post_batch(
            monkeypatch,
            llm_router,
            team_metadata={
                "enforced_batch_output_expires_after": TEAM_EXPIRY,
            },
            request_body={
                **self._BATCH_BODY,
                "output_expires_after": CALLER_EXPIRY,
            },
        )
        assert kwargs["output_expires_after"] == TEAM_EXPIRY

    def test_no_team_setting_preserves_caller(self, monkeypatch, llm_router):
        """No team setting = caller value passes through."""
        kwargs = self._post_batch(
            monkeypatch,
            llm_router,
            team_metadata={},
            request_body={
                **self._BATCH_BODY,
                "output_expires_after": CALLER_EXPIRY,
            },
        )
        assert kwargs["output_expires_after"] == CALLER_EXPIRY

    def test_team_injects_when_caller_sends_nothing(self, monkeypatch, llm_router):
        """Team enforcement applies even when caller sends no expiry."""
        kwargs = self._post_batch(
            monkeypatch,
            llm_router,
            team_metadata={
                "enforced_batch_output_expires_after": TEAM_EXPIRY,
            },
            request_body=dict(self._BATCH_BODY),
        )
        assert kwargs["output_expires_after"] == TEAM_EXPIRY
class TestBatchEndpointTeamValidation:
    """Verify validation errors for malformed team metadata on batch endpoint."""

    # Valid request body with no caller-supplied expiry; every test here varies
    # only the team metadata.
    _BATCH_BODY = {
        "input_file_id": "file-abc123",
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h",
    }

    def _post_batch_raw(
        self,
        monkeypatch,
        llm_router: Router,
        team_metadata: dict,
        request_body: dict,
    ):
        """POST /v1/batches and return the raw response (no status assertion).

        Args:
            monkeypatch: pytest's ``monkeypatch`` fixture.
            llm_router: Router installed on the proxy server for this request.
            team_metadata: Metadata attached to the authenticated key's team.
            request_body: JSON body sent to the endpoint.

        Returns:
            The test client's response object, whatever its status code.
        """
        _setup_proxy(monkeypatch, llm_router)
        user_key = UserAPIKeyAuth(
            api_key="test-key",
            team_metadata=team_metadata,
        )

        async def mock_acreate_batch(**kwargs):
            return _make_batch_response()

        monkeypatch.setattr(litellm, "acreate_batch", mock_acreate_batch)
        try:
            # Install the auth override inside the try so the finally below
            # always removes it; nothing can fail between install and cleanup
            # and leave the override on the shared app for later tests.
            app.dependency_overrides[user_api_key_auth] = lambda: user_key
            response = client.post(
                "/v1/batches",
                json=request_body,
                headers={"Authorization": "Bearer test-key"},
            )
        finally:
            app.dependency_overrides.clear()
        return response

    def test_missing_anchor_key_returns_500(self, monkeypatch, llm_router):
        """Missing 'anchor' key in team metadata returns 500."""
        response = self._post_batch_raw(
            monkeypatch,
            llm_router,
            team_metadata={
                "enforced_batch_output_expires_after": {"seconds": 3600},
            },
            request_body=self._BATCH_BODY,
        )
        assert response.status_code == 500
        assert "malformed" in response.json()["error"]["message"]

    def test_missing_seconds_key_returns_500(self, monkeypatch, llm_router):
        """Missing 'seconds' key in team metadata returns 500."""
        response = self._post_batch_raw(
            monkeypatch,
            llm_router,
            team_metadata={
                "enforced_batch_output_expires_after": {"anchor": "created_at"},
            },
            request_body=self._BATCH_BODY,
        )
        assert response.status_code == 500
        assert "malformed" in response.json()["error"]["message"]

    def test_invalid_anchor_returns_500(self, monkeypatch, llm_router):
        """Invalid anchor value in team metadata returns 500."""
        response = self._post_batch_raw(
            monkeypatch,
            llm_router,
            team_metadata={
                "enforced_batch_output_expires_after": {
                    "anchor": "last_active_at",
                    "seconds": 3600,
                },
            },
            request_body=self._BATCH_BODY,
        )
        assert response.status_code == 500
        # The error message is expected to name the accepted anchor value.
        assert "created_at" in response.json()["error"]["message"]