Files
litellm/tests/proxy_unit_tests/test_check_responses_cost.py
T
Ishaan Jaff 1b96064600 fix(proxy): prevent OOM/Prisma connection loss from unbounded managed-object poll (#23472)
* fix(proxy): cap managed-object poll size + expire stale rows + kill-switch flag to prevent OOM/Prisma connection loss

* fix(constants): simplify PROXY_BATCH_POLLING_ENABLED readability

* docs+test: document new polling env vars, add pagination+stale-cleanup tests

* fix: exclude stale_expired from batch poll queries; fix update_many assertions in tests

* fix: scope stale cleanup to file_purpose, fix file_object mocks, add CheckBatchCost tests

* fix: avoid duplicate cost logging in fallback path; guard integer constants against zero/negative values

* fix: cache _has_batch_processed_column; guard cleanup from aborting poll; narrow fallback except

* fix: add complete/completed to primary query not_in; fix vacuous test assertion

- Primary find_many was missing "complete" and "completed" in its not_in
  filter, creating asymmetry with the fallback query. A job whose status
  was set to "complete" but whose batch_processed flag update failed would
  be silently re-fetched and re-processed every cycle, emitting duplicate
  cost logs.

- test_fallback_completion_update_omits_batch_processed patched
  _is_base64_encoded_unified_file_id to return None, causing an immediate
  continue — so update() was never called and the assertion looped over an
  empty list (vacuously true). Rewrote the test to mock the full
  completion pipeline, verify update() is called exactly once, and assert
  batch_processed is absent from the update data.

- Added symmetric test (primary path) proving batch_processed IS included
  when the column exists.

Made-with: Cursor
2026-03-13 11:01:40 -07:00

465 lines
17 KiB
Python

"""
Unit tests for CheckResponsesCost class
"""
import asyncio
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from litellm.constants import MAX_OBJECTS_PER_POLL_CYCLE
from litellm.types.llms.openai import ResponseAPIUsage, ResponsesAPIResponse
class TestCheckResponsesCost:
    """Test suite for CheckResponsesCost class"""

    # ------------------------------------------------------------------
    # Helpers (not collected by pytest: names do not start with ``test``)
    # ------------------------------------------------------------------

    @staticmethod
    def _make_job(job_id, response_id, created_by="test-user", model="gpt-4o"):
        """Build a mock managed-object row as returned by ``find_many``.

        Args:
            job_id: value for the row's ``id``.
            response_id: value for ``unified_object_id`` (and ``file_object["id"]``).
            created_by: value for ``created_by``.
            model: model stored in ``file_object``; ``None`` yields an empty
                ``file_object`` (no ``model`` key at all).
        """
        job = MagicMock()
        job.unified_object_id = response_id
        job.created_by = created_by
        job.id = job_id
        job.file_object = {} if model is None else {"model": model, "id": response_id}
        return job

    @staticmethod
    def _make_response(response_id, status, usage=None):
        """Build a ``ResponsesAPIResponse`` with the given status and optional usage."""
        return ResponsesAPIResponse(
            id=response_id,
            object="response",
            status=status,
            created_at=int(datetime.now().timestamp()),
            output=[],
            usage=usage,
        )

    @staticmethod
    def _mock_table(prisma_client, jobs, update_many_return=0):
        """Stub ``find_many``/``update_many`` on the managed-object table.

        Returns the table mock so tests can inspect the recorded calls.
        """
        table = prisma_client.db.litellm_managedobjecttable
        table.find_many = AsyncMock(return_value=jobs)
        table.update_many = AsyncMock(return_value=update_many_return)
        return table

    @staticmethod
    async def _run_check(instance, **aget_kwargs):
        """Run one poll cycle with ``litellm.aget_responses`` patched.

        ``aget_kwargs`` (e.g. ``return_value=`` / ``side_effect=``) are forwarded
        to the ``AsyncMock`` that replaces ``aget_responses``. Returns that mock
        so callers can assert on how it was awaited.
        """
        with patch(
            "litellm.aget_responses", new_callable=AsyncMock, **aget_kwargs
        ) as mock_aget:
            await instance.check_responses_cost()
        return mock_aget

    @staticmethod
    def _assert_only_stale_cleanup(table):
        """Assert update_many fired exactly once: the stale cleanup, no job completion."""
        calls = table.update_many.call_args_list
        assert len(calls) == 1
        assert calls[0][1]["data"] == {"status": "stale_expired"}

    # ------------------------------------------------------------------
    # Fixtures
    # ------------------------------------------------------------------

    @pytest.fixture
    def mock_prisma_client(self):
        """Create a mock Prisma client"""
        client = MagicMock()
        client.db = MagicMock()
        client.db.litellm_managedobjecttable = MagicMock()
        return client

    @pytest.fixture
    def mock_proxy_logging_obj(self):
        """Create a mock ProxyLogging object"""
        logging_obj = MagicMock()
        logging_obj.get_proxy_hook = MagicMock(return_value=None)
        return logging_obj

    @pytest.fixture
    def mock_llm_router(self):
        """Create a mock LLM Router"""
        router = MagicMock()
        router.aget_responses = AsyncMock()
        router.get_deployment = MagicMock()
        return router

    @pytest.fixture
    def check_responses_cost_instance(
        self, mock_proxy_logging_obj, mock_prisma_client, mock_llm_router
    ):
        """Create a CheckResponsesCost instance with mocked dependencies"""
        from litellm_enterprise.proxy.common_utils.check_responses_cost import (
            CheckResponsesCost,
        )

        return CheckResponsesCost(
            proxy_logging_obj=mock_proxy_logging_obj,
            prisma_client=mock_prisma_client,
            llm_router=mock_llm_router,
        )

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_initialization(self, check_responses_cost_instance):
        """Test that CheckResponsesCost initializes correctly"""
        assert check_responses_cost_instance.proxy_logging_obj is not None
        assert check_responses_cost_instance.prisma_client is not None
        assert check_responses_cost_instance.llm_router is not None

    @pytest.mark.asyncio
    async def test_check_responses_cost_no_jobs(
        self, check_responses_cost_instance, mock_prisma_client
    ):
        """Test check_responses_cost when there are no jobs to process"""
        table = self._mock_table(mock_prisma_client, jobs=[])

        await check_responses_cost_instance.check_responses_cost()

        # find_many must be paginated and scoped to pending "response" rows
        find_many_kwargs = table.find_many.call_args[1]
        assert find_many_kwargs["where"] == {
            "status": {"in": ["queued", "in_progress"]},
            "file_purpose": "response",
        }
        assert find_many_kwargs["take"] == MAX_OBJECTS_PER_POLL_CYCLE
        assert find_many_kwargs["order"] == {"created_at": "asc"}
        # With nothing to poll, only the stale cleanup may touch update_many
        self._assert_only_stale_cleanup(table)

    @pytest.mark.asyncio
    async def test_cleanup_stale_managed_objects(
        self, check_responses_cost_instance, mock_prisma_client
    ):
        """Stale rows (older than cutoff) are bulk-updated to stale_expired before polling."""
        table = self._mock_table(mock_prisma_client, jobs=[], update_many_return=5)

        await check_responses_cost_instance.check_responses_cost()

        # The first update_many call should be the stale-row cleanup scoped to "response"
        calls = table.update_many.call_args_list
        assert calls, "expected the stale cleanup to call update_many"
        stale_kwargs = calls[0][1]
        assert stale_kwargs["data"] == {"status": "stale_expired"}
        where = stale_kwargs["where"]
        assert where["file_purpose"] == "response"
        assert "stale_expired" in where["status"]["not_in"]
        assert "created_at" in where

    @pytest.mark.asyncio
    async def test_check_responses_cost_with_completed_response(
        self, check_responses_cost_instance, mock_prisma_client
    ):
        """Test check_responses_cost with a completed response"""
        table = self._mock_table(
            mock_prisma_client, jobs=[self._make_job("job-123", "resp_test_123")]
        )
        mock_response = self._make_response(
            "resp_123",
            "completed",
            usage=ResponseAPIUsage(
                input_tokens=100,
                output_tokens=50,
                total_tokens=150,
            ),
        )

        mock_aget = await self._run_check(
            check_responses_cost_instance, return_value=mock_response
        )

        mock_aget.assert_awaited()
        # calls[0] = stale cleanup, calls[1] = job completion
        calls = table.update_many.call_args_list
        assert len(calls) == 2
        completion_kwargs = calls[1][1]
        assert completion_kwargs["data"]["status"] == "completed"
        assert completion_kwargs["where"]["id"]["in"] == ["job-123"]

    @pytest.mark.asyncio
    async def test_check_responses_cost_with_failed_response(
        self, check_responses_cost_instance, mock_prisma_client
    ):
        """Test check_responses_cost with a failed response"""
        table = self._mock_table(
            mock_prisma_client, jobs=[self._make_job("job-456", "resp_test_456")]
        )

        mock_aget = await self._run_check(
            check_responses_cost_instance,
            return_value=self._make_response("resp_456", "failed"),
        )

        mock_aget.assert_awaited()
        # A failed response is terminal: calls[0] = stale cleanup, calls[1] = completion
        calls = table.update_many.call_args_list
        assert len(calls) == 2
        assert calls[1][1]["data"]["status"] == "completed"

    @pytest.mark.asyncio
    async def test_check_responses_cost_with_cancelled_response(
        self, check_responses_cost_instance, mock_prisma_client
    ):
        """Test check_responses_cost with a cancelled response"""
        table = self._mock_table(
            mock_prisma_client, jobs=[self._make_job("job-789", "resp_test_789")]
        )

        mock_aget = await self._run_check(
            check_responses_cost_instance,
            return_value=self._make_response("resp_789", "cancelled"),
        )

        mock_aget.assert_awaited()
        # A cancelled response is terminal: calls[0] = stale cleanup, calls[1] = completion
        calls = table.update_many.call_args_list
        assert len(calls) == 2
        assert calls[1][1]["data"]["status"] == "completed"

    @pytest.mark.asyncio
    async def test_check_responses_cost_with_in_progress_response(
        self, check_responses_cost_instance, mock_prisma_client
    ):
        """Test check_responses_cost with a response still in progress"""
        table = self._mock_table(
            mock_prisma_client,
            jobs=[self._make_job("job-in-progress", "resp_test_in_progress")],
        )

        mock_aget = await self._run_check(
            check_responses_cost_instance,
            return_value=self._make_response("resp_in_progress", "in_progress"),
        )

        # The job must actually be polled; otherwise the check below is vacuous
        mock_aget.assert_awaited()
        # A non-terminal status must not mark the job complete
        self._assert_only_stale_cleanup(table)

    @pytest.mark.asyncio
    async def test_check_responses_cost_with_queued_response(
        self, check_responses_cost_instance, mock_prisma_client
    ):
        """Test check_responses_cost with a queued response"""
        table = self._mock_table(
            mock_prisma_client, jobs=[self._make_job("job-queued", "resp_test_queued")]
        )

        mock_aget = await self._run_check(
            check_responses_cost_instance,
            return_value=self._make_response("resp_queued", "queued"),
        )

        # The job must actually be polled; otherwise the check below is vacuous
        mock_aget.assert_awaited()
        # A non-terminal status must not mark the job complete
        self._assert_only_stale_cleanup(table)

    @pytest.mark.asyncio
    async def test_check_responses_cost_with_exception(
        self, check_responses_cost_instance, mock_prisma_client
    ):
        """Test check_responses_cost handles exceptions gracefully"""
        table = self._mock_table(
            mock_prisma_client, jobs=[self._make_job("job-error", "resp_test_error")]
        )

        # Should not raise, just skip the job
        mock_aget = await self._run_check(
            check_responses_cost_instance, side_effect=Exception("Provider error")
        )

        # Confirms the provider error was actually raised (and swallowed)
        mock_aget.assert_awaited()
        self._assert_only_stale_cleanup(table)

    @pytest.mark.asyncio
    async def test_check_responses_cost_multiple_jobs(
        self, check_responses_cost_instance, mock_prisma_client
    ):
        """Test check_responses_cost with multiple jobs"""
        jobs = [
            self._make_job(f"job-{i}", f"resp_test_{i}", created_by=f"user{i}")
            for i in (1, 2, 3)
        ]
        table = self._mock_table(mock_prisma_client, jobs=jobs)
        # Responses, one per job in order: 2 completed, 1 in progress
        responses = [
            self._make_response(
                "resp_1",
                "completed",
                usage=ResponseAPIUsage(
                    input_tokens=100,
                    output_tokens=50,
                    total_tokens=150,
                ),
            ),
            self._make_response("resp_2", "in_progress"),
            self._make_response(
                "resp_3",
                "completed",
                usage=ResponseAPIUsage(
                    input_tokens=200,
                    output_tokens=100,
                    total_tokens=300,
                ),
            ),
        ]

        mock_aget = await self._run_check(
            check_responses_cost_instance, side_effect=responses
        )

        # Every job is polled exactly once
        assert mock_aget.await_count == 3
        # calls[0] = stale cleanup, calls[1] = completion of the 2 finished jobs
        calls = table.update_many.call_args_list
        assert len(calls) == 2
        completed_ids = calls[1][1]["where"]["id"]["in"]
        assert sorted(completed_ids) == ["job-1", "job-3"]

    @pytest.mark.asyncio
    async def test_check_responses_cost_no_model_in_file_object(
        self, check_responses_cost_instance, mock_prisma_client
    ):
        """When file_object has no 'model' key, model_name is None and metadata skips model fields."""
        # model=None → empty file_object → model_name=None branch
        self._mock_table(
            mock_prisma_client,
            jobs=[self._make_job("job-no-model", "resp_test_no_model", model=None)],
        )
        mock_response = MagicMock()
        mock_response.status = "completed"

        mock_aget = await self._run_check(
            check_responses_cost_instance, return_value=mock_response
        )

        mock_aget.assert_awaited()
        # aget_responses should be called without model metadata; tolerate the
        # kwarg being absent or explicitly None
        metadata = mock_aget.call_args[1].get("litellm_metadata") or {}
        assert "model" not in metadata
        assert "model_group" not in metadata