mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 18:48:36 +00:00
1b96064600
* fix(proxy): cap managed-object poll size + expire stale rows + kill-switch flag to prevent OOM/Prisma connection loss * fix(constants): simplify PROXY_BATCH_POLLING_ENABLED readability * docs+test: document new polling env vars, add pagination+stale-cleanup tests * fix: exclude stale_expired from batch poll queries; fix update_many assertions in tests * fix: scope stale cleanup to file_purpose, fix file_object mocks, add CheckBatchCost tests * fix: avoid duplicate cost logging in fallback path; guard integer constants against zero/negative values * fix: cache _has_batch_processed_column; guard cleanup from aborting poll; narrow fallback except * fix: add complete/completed to primary query not_in; fix vacuous test assertion - Primary find_many was missing "complete" and "completed" in its not_in filter, creating asymmetry with the fallback query. A job whose status was set to "complete" but whose batch_processed flag update failed would be silently re-fetched and re-processed every cycle, emitting duplicate cost logs. - test_fallback_completion_update_omits_batch_processed patched _is_base64_encoded_unified_file_id to return None, causing an immediate continue — so update() was never called and the assertion looped over an empty list (vacuously true). Rewrote the test to mock the full completion pipeline, verify update() is called exactly once, and assert batch_processed is absent from the update data. - Added symmetric test (primary path) proving batch_processed IS included when the column exists. Made-with: Cursor
465 lines
17 KiB
Python
465 lines
17 KiB
Python
"""
|
|
Unit tests for CheckResponsesCost class
|
|
"""
|
|
|
|
import asyncio
|
|
from datetime import datetime
|
|
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
|
|
|
import pytest
|
|
|
|
from litellm.constants import MAX_OBJECTS_PER_POLL_CYCLE
|
|
from litellm.types.llms.openai import ResponseAPIUsage, ResponsesAPIResponse
|
|
|
|
|
|
class TestCheckResponsesCost:
    """Test suite for CheckResponsesCost class.

    Each async test calls ``check_responses_cost()`` on an instance built from
    mocked dependencies, with ``litellm.aget_responses`` patched, and then
    inspects the ``find_many`` / ``update_many`` calls recorded on the mocked
    ``litellm_managedobjecttable``.
    """

    # ------------------------------------------------------------------
    # Private helpers (underscore-prefixed, so pytest does not collect them)
    # ------------------------------------------------------------------

    @staticmethod
    def _make_job(job_id, response_id, created_by="test-user", file_object=None):
        """Build a mock managed-object row describing one response job.

        Args:
            job_id: Value assigned to the row's ``id`` attribute.
            response_id: Value assigned to ``unified_object_id``.
            created_by: Value assigned to ``created_by``.
            file_object: Dict assigned to ``file_object``. When ``None`` it
                defaults to ``{"model": "gpt-4o", "id": response_id}``; pass
                ``{}`` explicitly to build a row without a model.

        Returns:
            A ``MagicMock`` carrying the four attributes above.
        """
        job = MagicMock()
        job.unified_object_id = response_id
        job.created_by = created_by
        job.id = job_id
        # ``is None`` (not truthiness) so an explicit empty dict is preserved.
        job.file_object = (
            {"model": "gpt-4o", "id": response_id}
            if file_object is None
            else file_object
        )
        return job

    @staticmethod
    def _make_response(response_id, status, usage=None):
        """Build a ``ResponsesAPIResponse`` with the given id, status and usage."""
        return ResponsesAPIResponse(
            id=response_id,
            object="response",
            status=status,
            created_at=int(datetime.now().timestamp()),
            output=[],
            usage=usage,
        )

    @staticmethod
    def _mock_table(prisma_client, jobs, update_many_result=0):
        """Install async ``find_many`` / ``update_many`` mocks on the table.

        Args:
            prisma_client: The mocked Prisma client fixture.
            jobs: List returned by ``find_many``.
            update_many_result: Value returned by ``update_many``.

        Returns:
            The mocked ``litellm_managedobjecttable`` for later inspection.
        """
        table = prisma_client.db.litellm_managedobjecttable
        table.find_many = AsyncMock(return_value=jobs)
        table.update_many = AsyncMock(return_value=update_many_result)
        return table

    @staticmethod
    def _assert_only_stale_cleanup(table):
        """Assert ``update_many`` ran exactly once, as the stale-row cleanup."""
        calls = table.update_many.call_args_list
        assert len(calls) == 1
        assert calls[0].kwargs["data"] == {"status": "stale_expired"}

    @staticmethod
    def _completion_kwargs(table):
        """Assert two ``update_many`` calls were made; return the second's kwargs.

        The tests expect calls[0] to be the stale cleanup and calls[1] to be
        the job-completion update.
        """
        calls = table.update_many.call_args_list
        assert len(calls) == 2
        return calls[1].kwargs

    # ------------------------------------------------------------------
    # Fixtures
    # ------------------------------------------------------------------

    @pytest.fixture
    def mock_prisma_client(self):
        """Create a mock Prisma client"""
        client = MagicMock()
        client.db = MagicMock()
        client.db.litellm_managedobjecttable = MagicMock()
        return client

    @pytest.fixture
    def mock_proxy_logging_obj(self):
        """Create a mock ProxyLogging object whose ``get_proxy_hook`` returns None"""
        logging_obj = MagicMock()
        logging_obj.get_proxy_hook = MagicMock(return_value=None)
        return logging_obj

    @pytest.fixture
    def mock_llm_router(self):
        """Create a mock LLM Router"""
        router = MagicMock()
        router.aget_responses = AsyncMock()
        router.get_deployment = MagicMock()
        return router

    @pytest.fixture
    def check_responses_cost_instance(
        self, mock_proxy_logging_obj, mock_prisma_client, mock_llm_router
    ):
        """Create a CheckResponsesCost instance with mocked dependencies"""
        # Imported inside the fixture rather than at module level, as in the
        # original file.
        from litellm_enterprise.proxy.common_utils.check_responses_cost import (
            CheckResponsesCost,
        )

        return CheckResponsesCost(
            proxy_logging_obj=mock_proxy_logging_obj,
            prisma_client=mock_prisma_client,
            llm_router=mock_llm_router,
        )

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_initialization(self, check_responses_cost_instance):
        """Test that CheckResponsesCost initializes correctly"""
        assert check_responses_cost_instance.proxy_logging_obj is not None
        assert check_responses_cost_instance.prisma_client is not None
        assert check_responses_cost_instance.llm_router is not None

    @pytest.mark.asyncio
    async def test_check_responses_cost_no_jobs(
        self, check_responses_cost_instance, mock_prisma_client
    ):
        """Test check_responses_cost when there are no jobs to process"""
        table = self._mock_table(mock_prisma_client, jobs=[])

        await check_responses_cost_instance.check_responses_cost()

        # Verify find_many was called with pagination params
        query = table.find_many.call_args.kwargs
        assert query["where"] == {
            "status": {"in": ["queued", "in_progress"]},
            "file_purpose": "response",
        }
        assert query["take"] == MAX_OBJECTS_PER_POLL_CYCLE
        assert query["order"] == {"created_at": "asc"}

    @pytest.mark.asyncio
    async def test_cleanup_stale_managed_objects(
        self, check_responses_cost_instance, mock_prisma_client
    ):
        """Stale rows (older than cutoff) are bulk-updated to stale_expired before polling."""
        table = self._mock_table(mock_prisma_client, jobs=[], update_many_result=5)

        await check_responses_cost_instance.check_responses_cost()

        # The first update_many call should be the stale-row cleanup scoped to "response"
        calls = table.update_many.call_args_list
        # Fail with a readable message instead of an IndexError on calls[0].
        assert calls, "expected update_many to be called for stale-row cleanup"
        stale_kwargs = calls[0].kwargs
        assert stale_kwargs["data"] == {"status": "stale_expired"}
        where = stale_kwargs["where"]
        assert where["file_purpose"] == "response"
        assert "stale_expired" in where["status"]["not_in"]
        assert "created_at" in where

    @pytest.mark.asyncio
    async def test_check_responses_cost_with_completed_response(
        self, check_responses_cost_instance, mock_prisma_client, mock_llm_router
    ):
        """Test check_responses_cost with a completed response"""
        # mock_llm_router is requested as a fixture but not referenced below.
        job = self._make_job("job-123", "resp_test_123")
        table = self._mock_table(mock_prisma_client, jobs=[job])
        response = self._make_response(
            "resp_123",
            "completed",
            usage=ResponseAPIUsage(
                input_tokens=100,
                output_tokens=50,
                total_tokens=150,
            ),
        )

        # Run the check with mocked litellm.aget_responses
        with patch(
            "litellm.aget_responses", new_callable=AsyncMock, return_value=response
        ):
            await check_responses_cost_instance.check_responses_cost()

        # calls[0] = stale cleanup, calls[1] = job completion
        completion = self._completion_kwargs(table)
        assert completion["data"]["status"] == "completed"
        assert completion["where"]["id"]["in"] == ["job-123"]

    @pytest.mark.asyncio
    async def test_check_responses_cost_with_failed_response(
        self, check_responses_cost_instance, mock_prisma_client, mock_llm_router
    ):
        """Test check_responses_cost with a failed response"""
        # mock_llm_router is requested as a fixture but not referenced below.
        job = self._make_job("job-456", "resp_test_456")
        table = self._mock_table(mock_prisma_client, jobs=[job])
        response = self._make_response("resp_456", "failed")

        with patch(
            "litellm.aget_responses", new_callable=AsyncMock, return_value=response
        ):
            await check_responses_cost_instance.check_responses_cost()

        # calls[0] = stale cleanup, calls[1] = job completion; the test expects
        # the row of a failed response to be written with status "completed".
        completion = self._completion_kwargs(table)
        assert completion["data"]["status"] == "completed"

    @pytest.mark.asyncio
    async def test_check_responses_cost_with_cancelled_response(
        self, check_responses_cost_instance, mock_prisma_client
    ):
        """Test check_responses_cost with a cancelled response"""
        job = self._make_job("job-789", "resp_test_789")
        table = self._mock_table(mock_prisma_client, jobs=[job])
        response = self._make_response("resp_789", "cancelled")

        with patch(
            "litellm.aget_responses", new_callable=AsyncMock, return_value=response
        ):
            await check_responses_cost_instance.check_responses_cost()

        # calls[0] = stale cleanup, calls[1] = job completion; the test expects
        # the row of a cancelled response to be written with status "completed".
        completion = self._completion_kwargs(table)
        assert completion["data"]["status"] == "completed"

    @pytest.mark.asyncio
    async def test_check_responses_cost_with_in_progress_response(
        self, check_responses_cost_instance, mock_prisma_client
    ):
        """Test check_responses_cost with a response still in progress"""
        job = self._make_job("job-in-progress", "resp_test_in_progress")
        table = self._mock_table(mock_prisma_client, jobs=[job])
        response = self._make_response("resp_in_progress", "in_progress")

        with patch(
            "litellm.aget_responses", new_callable=AsyncMock, return_value=response
        ):
            await check_responses_cost_instance.check_responses_cost()

        # Only the stale-cleanup call should have fired — no completion update
        self._assert_only_stale_cleanup(table)

    @pytest.mark.asyncio
    async def test_check_responses_cost_with_queued_response(
        self, check_responses_cost_instance, mock_prisma_client
    ):
        """Test check_responses_cost with a queued response"""
        job = self._make_job("job-queued", "resp_test_queued")
        table = self._mock_table(mock_prisma_client, jobs=[job])
        response = self._make_response("resp_queued", "queued")

        with patch(
            "litellm.aget_responses", new_callable=AsyncMock, return_value=response
        ):
            await check_responses_cost_instance.check_responses_cost()

        # Only the stale-cleanup call should have fired — no completion update
        self._assert_only_stale_cleanup(table)

    @pytest.mark.asyncio
    async def test_check_responses_cost_with_exception(
        self, check_responses_cost_instance, mock_prisma_client
    ):
        """Test check_responses_cost handles exceptions gracefully"""
        job = self._make_job("job-error", "resp_test_error")
        table = self._mock_table(mock_prisma_client, jobs=[job])

        # Run the check with mocked exception
        with patch(
            "litellm.aget_responses",
            new_callable=AsyncMock,
            side_effect=Exception("Provider error"),
        ):
            # Should not raise, just skip the job
            await check_responses_cost_instance.check_responses_cost()

        # Only the stale-cleanup call should have fired — no completion update
        self._assert_only_stale_cleanup(table)

    @pytest.mark.asyncio
    async def test_check_responses_cost_multiple_jobs(
        self, check_responses_cost_instance, mock_prisma_client
    ):
        """Test check_responses_cost with multiple jobs"""
        jobs = [
            self._make_job("job-1", "resp_test_1", created_by="user1"),
            self._make_job("job-2", "resp_test_2", created_by="user2"),
            self._make_job("job-3", "resp_test_3", created_by="user3"),
        ]
        table = self._mock_table(mock_prisma_client, jobs=jobs)

        # Mock responses - 2 completed, 1 in progress (returned in job order)
        responses = [
            self._make_response(
                "resp_1",
                "completed",
                usage=ResponseAPIUsage(
                    input_tokens=100,
                    output_tokens=50,
                    total_tokens=150,
                ),
            ),
            self._make_response("resp_2", "in_progress"),
            self._make_response(
                "resp_3",
                "completed",
                usage=ResponseAPIUsage(
                    input_tokens=200,
                    output_tokens=100,
                    total_tokens=300,
                ),
            ),
        ]

        with patch(
            "litellm.aget_responses", new_callable=AsyncMock, side_effect=responses
        ):
            await check_responses_cost_instance.check_responses_cost()

        # calls[0] = stale cleanup, calls[1] = completion of 2 finished jobs
        completed_ids = self._completion_kwargs(table)["where"]["id"]["in"]
        assert len(completed_ids) == 2
        assert "job-1" in completed_ids
        assert "job-3" in completed_ids
        assert "job-2" not in completed_ids

    @pytest.mark.asyncio
    async def test_check_responses_cost_no_model_in_file_object(
        self, check_responses_cost_instance, mock_prisma_client
    ):
        """When file_object has no 'model' key, model_name is None and metadata skips model fields."""
        # Empty file_object: no "model" key → model_name=None branch
        job = self._make_job("job-no-model", "resp_test_no_model", file_object={})
        self._mock_table(mock_prisma_client, jobs=[job])

        mock_response = MagicMock()
        mock_response.status = "completed"

        with patch(
            "litellm.aget_responses",
            new_callable=AsyncMock,
            return_value=mock_response,
        ) as mock_aget:
            await check_responses_cost_instance.check_responses_cost()

        # Fail clearly if the job was skipped before the provider was polled,
        # instead of hitting an opaque error when reading call_args below.
        mock_aget.assert_called()

        # aget_responses should be called without model metadata.
        # ``or {}`` also covers litellm_metadata being passed explicitly as None.
        metadata = mock_aget.call_args.kwargs.get("litellm_metadata") or {}
        assert "model" not in metadata
        assert "model_group" not in metadata