Files
litellm/tests/proxy_unit_tests/test_check_batch_cost.py
T
Ishaan Jaff 1b96064600 fix(proxy): prevent OOM/Prisma connection loss from unbounded managed-object poll (#23472)
* fix(proxy): cap managed-object poll size + expire stale rows + kill-switch flag to prevent OOM/Prisma connection loss

* fix(constants): simplify PROXY_BATCH_POLLING_ENABLED readability

* docs+test: document new polling env vars, add pagination+stale-cleanup tests

* fix: exclude stale_expired from batch poll queries; fix update_many assertions in tests

* fix: scope stale cleanup to file_purpose, fix file_object mocks, add CheckBatchCost tests

* fix: avoid duplicate cost logging in fallback path; guard integer constants against zero/negative values

* fix: cache _has_batch_processed_column; guard cleanup from aborting poll; narrow fallback except

* fix: add complete/completed to primary query not_in; fix vacuous test assertion

- Primary find_many was missing "complete" and "completed" in its not_in
  filter, creating asymmetry with the fallback query. A job whose status
  was set to "complete" but whose batch_processed flag update failed would
  be silently re-fetched and re-processed every cycle, emitting duplicate
  cost logs.

- test_fallback_completion_update_omits_batch_processed patched
  _is_base64_encoded_unified_file_id to return None, causing an immediate
  continue — so update() was never called and the assertion looped over an
  empty list (vacuously true). Rewrote the test to mock the full
  completion pipeline, verify update() is called exactly once, and assert
  batch_processed is absent from the update data.

- Added symmetric test (primary path) proving batch_processed IS included
  when the column exists.

Made-with: Cursor
2026-03-13 11:01:40 -07:00

341 lines
14 KiB
Python

"""
Unit tests for CheckBatchCost class.
Covers: stale-row cleanup (file_purpose scoping), paginated find_many,
and the batch_processed-column fallback query.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
class TestCheckBatchCost:
    """Test suite for CheckBatchCost class.

    Every test calls ``check_batch_cost()`` on an instance built from mocked
    collaborators and then inspects the calls recorded on the mocked
    ``db.litellm_managedobjecttable`` table.
    """

    @pytest.fixture
    def mock_prisma_client(self):
        """Prisma client double exposing the two tables these tests touch."""
        client = MagicMock()
        client.db = MagicMock()
        client.db.litellm_managedobjecttable = MagicMock()
        client.db.litellm_usertable = MagicMock()
        return client

    @pytest.fixture
    def mock_proxy_logging_obj(self):
        """Stand-in for the proxy logging object passed to CheckBatchCost."""
        return MagicMock()

    @pytest.fixture
    def mock_llm_router(self):
        """Stand-in for the LLM router passed to CheckBatchCost."""
        return MagicMock()

    @pytest.fixture
    def check_batch_cost_instance(
        self, mock_proxy_logging_obj, mock_prisma_client, mock_llm_router
    ):
        """CheckBatchCost wired to the mocked logging object, Prisma client and router."""
        # Imported inside the fixture rather than at module level.
        from litellm_enterprise.proxy.common_utils.check_batch_cost import CheckBatchCost

        return CheckBatchCost(
            proxy_logging_obj=mock_proxy_logging_obj,
            prisma_client=mock_prisma_client,
            llm_router=mock_llm_router,
        )

    @staticmethod
    async def _run_cycle_with_completed_job(
        check_batch_cost_instance, mock_prisma_client, mock_llm_router, job_id
    ):
        """Run one poll cycle in which a single job reports status "completed".

        Shared arrange/act step for the two completion-update tests. It mocks
        the full completion pipeline (batch retrieval, deployment lookup, file
        download, cost calculation, provider resolution and logging) and then
        awaits ``check_batch_cost()`` once.

        Args:
            check_batch_cost_instance: The CheckBatchCost instance under test.
            mock_prisma_client: The mocked Prisma client it was built with.
            mock_llm_router: The mocked router it was built with.
            job_id: Value assigned to the mocked job's ``id`` attribute.

        Returns:
            The ``AsyncMock`` installed as ``litellm_managedobjecttable.update``,
            so callers can assert on how it was invoked.
        """
        table = mock_prisma_client.db.litellm_managedobjecttable
        table.update_many = AsyncMock(return_value=0)
        table.update = AsyncMock()
        mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock(
            return_value=None
        )

        mock_job = MagicMock()
        mock_job.id = job_id
        mock_job.unified_object_id = "dW5pZmllZF9iYXRjaF9pZA=="
        mock_job.created_by = "user-1"
        table.find_many = AsyncMock(return_value=[mock_job])

        # Build a fake batch response whose status triggers the completion branch
        mock_response = MagicMock()
        mock_response.status = "completed"
        mock_response.output_file_id = "file-output-123"
        mock_response.model_dump_json.return_value = '{"id":"batch-1","status":"completed"}'
        mock_llm_router.aretrieve_batch = AsyncMock(return_value=mock_response)
        mock_llm_router.get_deployment_credentials_with_provider = MagicMock(
            return_value={"api_key": "sk-test"}
        )

        mock_deployment = MagicMock()
        mock_deployment.litellm_params.custom_llm_provider = "openai"
        mock_deployment.litellm_params.model = "gpt-4"
        mock_deployment.model_info.model_dump.return_value = {}
        mock_llm_router.get_deployment = MagicMock(return_value=mock_deployment)

        mock_file_content = MagicMock()
        mock_file_content.content = b'{"id":"req-1"}'
        decoded_id = "llm_model_id,model-123;llm_batch_id,batch-456;"

        with (
            patch(
                "litellm.proxy.openai_files_endpoints.common_utils._is_base64_encoded_unified_file_id",
                # First lookup yields the decoded id, the second yields None.
                side_effect=[decoded_id, None],
            ),
            patch(
                "litellm.proxy.openai_files_endpoints.common_utils.get_model_id_from_unified_batch_id",
                return_value="model-123",
            ),
            patch(
                "litellm.proxy.openai_files_endpoints.common_utils.get_batch_id_from_unified_batch_id",
                return_value="batch-456",
            ),
            patch(
                "litellm.files.main.afile_content",
                new_callable=AsyncMock,
                return_value=mock_file_content,
            ),
            patch(
                "litellm.batches.batch_utils._get_file_content_as_dictionary",
                return_value=[{"id": "req-1"}],
            ),
            patch(
                "litellm.batches.batch_utils.calculate_batch_cost_and_usage",
                new_callable=AsyncMock,
                return_value=(0.01, {"prompt_tokens": 10, "completion_tokens": 5}, ["gpt-4"]),
            ),
            patch(
                "litellm.litellm_core_utils.get_llm_provider_logic.get_llm_provider",
                return_value=("gpt-4", "openai", None, None),
            ),
            patch(
                "litellm.litellm_core_utils.litellm_logging.Logging"
            ) as mock_logging_cls,
        ):
            mock_logging_obj = MagicMock()
            mock_logging_obj.async_success_handler = AsyncMock()
            mock_logging_cls.return_value = mock_logging_obj
            await check_batch_cost_instance.check_batch_cost()

        return table.update

    @pytest.mark.asyncio
    async def test_cleanup_scoped_to_batch_file_purpose(
        self, check_batch_cost_instance, mock_prisma_client
    ):
        """_cleanup_stale_managed_objects scopes its update to file_purpose='batch' only."""
        table = mock_prisma_client.db.litellm_managedobjecttable
        table.update_many = AsyncMock(return_value=0)
        # Return empty so the main poll loop exits immediately
        table.find_many = AsyncMock(return_value=[])

        await check_batch_cost_instance.check_batch_cost()

        calls = table.update_many.call_args_list
        # Fail with a readable message instead of an IndexError on calls[0].
        assert calls, "Expected update_many() to be called for stale-row cleanup"
        stale_call = calls[0]
        assert stale_call.kwargs["data"] == {"status": "stale_expired"}
        where = stale_call.kwargs["where"]
        assert where["file_purpose"] == "batch"
        assert "stale_expired" in where["status"]["not_in"]
        assert "created_at" in where

    @pytest.mark.asyncio
    async def test_find_many_uses_pagination_and_excludes_stale(
        self, check_batch_cost_instance, mock_prisma_client
    ):
        """find_many is called with take, order, and all terminal statuses excluded."""
        from litellm.constants import MAX_OBJECTS_PER_POLL_CYCLE

        table = mock_prisma_client.db.litellm_managedobjecttable
        table.update_many = AsyncMock(return_value=0)
        table.find_many = AsyncMock(return_value=[])

        await check_batch_cost_instance.check_batch_cost()

        find_call = table.find_many.call_args
        # call_args is None when the mock was never awaited.
        assert find_call is not None, "Expected find_many() to be called"
        assert find_call.kwargs["take"] == MAX_OBJECTS_PER_POLL_CYCLE
        assert find_call.kwargs["order"] == {"created_at": "asc"}
        not_in = find_call.kwargs["where"]["status"]["not_in"]
        assert "stale_expired" in not_in
        assert "complete" in not_in
        assert "completed" in not_in

    @pytest.mark.asyncio
    async def test_fallback_query_used_when_batch_processed_missing(
        self, check_batch_cost_instance, mock_prisma_client
    ):
        """Falls back to query without batch_processed when primary query raises."""
        from litellm.constants import MAX_OBJECTS_PER_POLL_CYCLE

        table = mock_prisma_client.db.litellm_managedobjecttable
        table.update_many = AsyncMock(return_value=0)
        # First find_many (primary query) raises with a schema error; second (fallback) returns empty
        table.find_many = AsyncMock(
            side_effect=[Exception("column batch_processed does not exist"), []]
        )

        await check_batch_cost_instance.check_batch_cost()

        calls = table.find_many.call_args_list
        assert len(calls) == 2
        fallback_call = calls[1]
        fallback_where = fallback_call.kwargs["where"]
        assert "batch_processed" not in fallback_where
        assert "stale_expired" in fallback_where["status"]["not_in"]
        assert fallback_call.kwargs["take"] == MAX_OBJECTS_PER_POLL_CYCLE
        # Column absence is now cached — next call should go straight to fallback
        assert check_batch_cost_instance._has_batch_processed_column is False

    @pytest.mark.asyncio
    async def test_column_absence_cached_across_cycles(
        self, check_batch_cost_instance, mock_prisma_client
    ):
        """After column absence is discovered, subsequent cycles skip the primary query entirely."""
        table = mock_prisma_client.db.litellm_managedobjecttable
        table.update_many = AsyncMock(return_value=0)
        # Simulate column already known absent from a previous cycle
        check_batch_cost_instance._has_batch_processed_column = False
        table.find_many = AsyncMock(return_value=[])

        await check_batch_cost_instance.check_batch_cost()

        # Only one find_many call — the fallback directly, no primary query attempt
        assert table.find_many.call_count == 1
        fallback_where = table.find_many.call_args.kwargs["where"]
        assert "batch_processed" not in fallback_where

    @pytest.mark.asyncio
    async def test_fallback_completion_update_omits_batch_processed(
        self, check_batch_cost_instance, mock_prisma_client, mock_llm_router
    ):
        """When batch_processed column is absent, completion update must not include it.

        If it did, the update would fail silently, the job would never be marked done,
        and every subsequent poll cycle would re-log the cost (duplicate billing).
        """
        # Simulate column already known absent (e.g. discovered on a previous cycle)
        check_batch_cost_instance._has_batch_processed_column = False

        update_mock = await self._run_cycle_with_completed_job(
            check_batch_cost_instance,
            mock_prisma_client,
            mock_llm_router,
            job_id="job-fallback-1",
        )

        # The update must have been called — this is the core assertion.
        assert update_mock.call_count == 1, (
            "Expected update() to be called exactly once for the completed job"
        )
        update_data = update_mock.call_args.kwargs["data"]
        assert "batch_processed" not in update_data, (
            "update() must NOT include batch_processed when column is absent"
        )
        assert update_data["status"] == "complete"

    @pytest.mark.asyncio
    async def test_primary_path_completion_update_includes_batch_processed(
        self, check_batch_cost_instance, mock_prisma_client, mock_llm_router
    ):
        """When batch_processed column IS present, completion update must set it to True.

        This is the symmetric counterpart to test_fallback_completion_update_omits_batch_processed
        and proves the conditional on _has_batch_processed_column governs the update data.
        """
        # Precondition: a fresh instance is expected to start on the primary path.
        assert check_batch_cost_instance._has_batch_processed_column is True

        update_mock = await self._run_cycle_with_completed_job(
            check_batch_cost_instance,
            mock_prisma_client,
            mock_llm_router,
            job_id="job-primary-1",
        )

        assert update_mock.call_count == 1, (
            "Expected update() to be called exactly once for the completed job"
        )
        update_data = update_mock.call_args.kwargs["data"]
        assert update_data["batch_processed"] is True, (
            "update() must include batch_processed=True when column is present"
        )
        assert update_data["status"] == "complete"