mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 20:48:32 +00:00
cbdc70d544
* fix(managed_batches): convert raw output_file_id to managed ID in CheckBatchCost poller CheckBatchCost bypasses async_post_call_success_hook, causing raw provider output_file_ids to be persisted in LiteLLM_ManagedObjectTable. This fix converts output_file_id and error_file_id to managed base64 IDs before the DB write. Co-authored-by: Cursor <cursoragent@cursor.com> * fix(check_batch_cost): persist managed file before mutating response and propagate team_id - Move setattr after store_unified_file_id so the response only receives the managed ID once the DB record is successfully written. Avoids serializing an orphaned managed ID into file_object when the store call fails. - Populate team_id on the minimal UserAPIKeyAuth from job.team_id so the managed file record is created with the correct team ownership, allowing other team members to access the batch output file via /files/{id}/content. Co-authored-by: Yassin Kortam <yassin@berri.ai> * test(managed_batches): extend test to cover error_file_id conversion Co-authored-by: Cursor <cursoragent@cursor.com> * fix managed file test --------- Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: Yassin Kortam <yassin@berri.ai>
515 lines
20 KiB
Python
515 lines
20 KiB
Python
"""
|
|
Unit tests for CheckBatchCost class.
|
|
Covers: stale-row cleanup (file_purpose scoping), paginated find_many,
|
|
and the batch_processed-column fallback query.
|
|
"""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
|
|
class TestCheckBatchCost:
|
|
"""Test suite for CheckBatchCost class"""
|
|
|
|
@pytest.fixture
|
|
def mock_prisma_client(self):
|
|
client = MagicMock()
|
|
client.db = MagicMock()
|
|
client.db.litellm_managedobjecttable = MagicMock()
|
|
client.db.litellm_usertable = MagicMock()
|
|
return client
|
|
|
|
@pytest.fixture
|
|
def mock_proxy_logging_obj(self):
|
|
mock = MagicMock()
|
|
mock.get_proxy_hook.return_value = None
|
|
return mock
|
|
|
|
@pytest.fixture
|
|
def mock_llm_router(self):
|
|
return MagicMock()
|
|
|
|
@pytest.fixture
|
|
def check_batch_cost_instance(
|
|
self, mock_proxy_logging_obj, mock_prisma_client, mock_llm_router
|
|
):
|
|
from litellm_enterprise.proxy.common_utils.check_batch_cost import (
|
|
CheckBatchCost,
|
|
)
|
|
|
|
return CheckBatchCost(
|
|
proxy_logging_obj=mock_proxy_logging_obj,
|
|
prisma_client=mock_prisma_client,
|
|
llm_router=mock_llm_router,
|
|
)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_cleanup_scoped_to_batch_file_purpose(
|
|
self, check_batch_cost_instance, mock_prisma_client
|
|
):
|
|
"""_cleanup_stale_managed_objects scopes its update to file_purpose='batch' only."""
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many = AsyncMock(
|
|
return_value=0
|
|
)
|
|
# Return empty so the main poll loop exits immediately
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many = AsyncMock(
|
|
return_value=[]
|
|
)
|
|
|
|
await check_batch_cost_instance.check_batch_cost()
|
|
|
|
calls = (
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many.call_args_list
|
|
)
|
|
stale_call = calls[0]
|
|
assert stale_call[1]["data"] == {"status": "stale_expired"}
|
|
where = stale_call[1]["where"]
|
|
assert where["file_purpose"] == "batch"
|
|
assert "stale_expired" in where["status"]["not_in"]
|
|
assert "created_at" in where
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_find_many_uses_pagination_and_excludes_stale(
|
|
self, check_batch_cost_instance, mock_prisma_client
|
|
):
|
|
"""find_many is called with take, order, and all terminal statuses excluded."""
|
|
from litellm.constants import MAX_OBJECTS_PER_POLL_CYCLE
|
|
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many = AsyncMock(
|
|
return_value=0
|
|
)
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many = AsyncMock(
|
|
return_value=[]
|
|
)
|
|
|
|
await check_batch_cost_instance.check_batch_cost()
|
|
|
|
find_call = mock_prisma_client.db.litellm_managedobjecttable.find_many.call_args
|
|
assert find_call[1]["take"] == MAX_OBJECTS_PER_POLL_CYCLE
|
|
assert find_call[1]["order"] == {"created_at": "asc"}
|
|
not_in = find_call[1]["where"]["status"]["not_in"]
|
|
assert "stale_expired" in not_in
|
|
# "complete"/"completed" are intentionally NOT excluded from the
|
|
# primary query — the batch_processed=False filter is sufficient.
|
|
# This allows CheckBatchCost to pick up batches that were
|
|
# transitioned to "complete" by the retrieve_batch endpoint
|
|
# before CheckBatchCost had a chance to process them.
|
|
assert "complete" not in not_in
|
|
assert "completed" not in not_in
|
|
assert find_call[1]["where"]["batch_processed"] is False
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_fallback_query_used_when_batch_processed_missing(
|
|
self, check_batch_cost_instance, mock_prisma_client
|
|
):
|
|
"""Falls back to query without batch_processed when primary query raises."""
|
|
from litellm.constants import MAX_OBJECTS_PER_POLL_CYCLE
|
|
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many = AsyncMock(
|
|
return_value=0
|
|
)
|
|
# First find_many (primary query) raises with a schema error; second (fallback) returns empty
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many = AsyncMock(
|
|
side_effect=[Exception("column batch_processed does not exist"), []]
|
|
)
|
|
|
|
await check_batch_cost_instance.check_batch_cost()
|
|
|
|
calls = (
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many.call_args_list
|
|
)
|
|
assert len(calls) == 2
|
|
fallback_where = calls[1][1]["where"]
|
|
assert "batch_processed" not in fallback_where
|
|
assert "stale_expired" in fallback_where["status"]["not_in"]
|
|
assert calls[1][1]["take"] == MAX_OBJECTS_PER_POLL_CYCLE
|
|
# Column absence is now cached — next call should go straight to fallback
|
|
assert check_batch_cost_instance._has_batch_processed_column is False
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_column_absence_cached_across_cycles(
|
|
self, check_batch_cost_instance, mock_prisma_client
|
|
):
|
|
"""After column absence is discovered, subsequent cycles skip the primary query entirely."""
|
|
from litellm.constants import MAX_OBJECTS_PER_POLL_CYCLE
|
|
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many = AsyncMock(
|
|
return_value=0
|
|
)
|
|
# Simulate column already known absent from a previous cycle
|
|
check_batch_cost_instance._has_batch_processed_column = False
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many = AsyncMock(
|
|
return_value=[]
|
|
)
|
|
|
|
await check_batch_cost_instance.check_batch_cost()
|
|
|
|
# Only one find_many call — the fallback directly, no primary query attempt
|
|
assert (
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many.call_count == 1
|
|
)
|
|
fallback_where = (
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many.call_args[1][
|
|
"where"
|
|
]
|
|
)
|
|
assert "batch_processed" not in fallback_where
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_fallback_completion_update_omits_batch_processed(
|
|
self, check_batch_cost_instance, mock_prisma_client, mock_llm_router
|
|
):
|
|
"""When batch_processed column is absent, completion update must not include it.
|
|
|
|
If it did, the update would fail silently, the job would never be marked done,
|
|
and every subsequent poll cycle would re-log the cost (duplicate billing).
|
|
"""
|
|
from unittest.mock import patch
|
|
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many = AsyncMock(
|
|
return_value=0
|
|
)
|
|
mock_prisma_client.db.litellm_managedobjecttable.update = AsyncMock()
|
|
mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock(
|
|
return_value=None
|
|
)
|
|
|
|
mock_job = MagicMock()
|
|
mock_job.id = "job-fallback-1"
|
|
mock_job.unified_object_id = "dW5pZmllZF9iYXRjaF9pZA=="
|
|
mock_job.created_by = "user-1"
|
|
|
|
# Simulate column already known absent (e.g. discovered on a previous cycle)
|
|
check_batch_cost_instance._has_batch_processed_column = False
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many = AsyncMock(
|
|
return_value=[mock_job]
|
|
)
|
|
|
|
# Build a fake batch response whose status triggers the completion branch
|
|
mock_response = MagicMock()
|
|
mock_response.status = "completed"
|
|
mock_response.output_file_id = "file-output-123"
|
|
mock_response.model_dump_json.return_value = (
|
|
'{"id":"batch-1","status":"completed"}'
|
|
)
|
|
|
|
mock_llm_router.aretrieve_batch = AsyncMock(return_value=mock_response)
|
|
mock_llm_router.get_deployment_credentials_with_provider = MagicMock(
|
|
return_value={"api_key": "sk-test"}
|
|
)
|
|
|
|
mock_deployment = MagicMock()
|
|
mock_deployment.litellm_params.custom_llm_provider = "openai"
|
|
mock_deployment.litellm_params.model = "gpt-4"
|
|
mock_deployment.model_info.model_dump.return_value = {}
|
|
mock_llm_router.get_deployment = MagicMock(return_value=mock_deployment)
|
|
|
|
mock_file_content = MagicMock()
|
|
mock_file_content.content = b'{"id":"req-1"}'
|
|
|
|
decoded_id = "llm_model_id,model-123;llm_batch_id,batch-456;"
|
|
|
|
with (
|
|
patch(
|
|
"litellm.proxy.openai_files_endpoints.common_utils._is_base64_encoded_unified_file_id",
|
|
side_effect=[decoded_id, None],
|
|
),
|
|
patch(
|
|
"litellm.proxy.openai_files_endpoints.common_utils.get_model_id_from_unified_batch_id",
|
|
return_value="model-123",
|
|
),
|
|
patch(
|
|
"litellm.proxy.openai_files_endpoints.common_utils.get_batch_id_from_unified_batch_id",
|
|
return_value="batch-456",
|
|
),
|
|
patch(
|
|
"litellm.files.main.afile_content",
|
|
new_callable=AsyncMock,
|
|
return_value=mock_file_content,
|
|
),
|
|
patch(
|
|
"litellm.batches.batch_utils._get_file_content_as_dictionary",
|
|
return_value=[{"id": "req-1"}],
|
|
),
|
|
patch(
|
|
"litellm.batches.batch_utils.calculate_batch_cost_and_usage",
|
|
new_callable=AsyncMock,
|
|
return_value=(
|
|
0.01,
|
|
{"prompt_tokens": 10, "completion_tokens": 5},
|
|
["gpt-4"],
|
|
),
|
|
),
|
|
patch(
|
|
"litellm.litellm_core_utils.get_llm_provider_logic.get_llm_provider",
|
|
return_value=("gpt-4", "openai", None, None),
|
|
),
|
|
patch(
|
|
"litellm.litellm_core_utils.litellm_logging.Logging"
|
|
) as mock_logging_cls,
|
|
):
|
|
mock_logging_obj = MagicMock()
|
|
mock_logging_obj.async_success_handler = AsyncMock()
|
|
mock_logging_cls.return_value = mock_logging_obj
|
|
|
|
await check_batch_cost_instance.check_batch_cost()
|
|
|
|
# The update must have been called — this is the core assertion.
|
|
assert (
|
|
mock_prisma_client.db.litellm_managedobjecttable.update.call_count == 1
|
|
), "Expected update() to be called exactly once for the completed job"
|
|
update_data = mock_prisma_client.db.litellm_managedobjecttable.update.call_args[
|
|
1
|
|
]["data"]
|
|
assert (
|
|
"batch_processed" not in update_data
|
|
), "update() must NOT include batch_processed when column is absent"
|
|
assert update_data["status"] == "complete"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_primary_path_completion_update_includes_batch_processed(
|
|
self, check_batch_cost_instance, mock_prisma_client, mock_llm_router
|
|
):
|
|
"""When batch_processed column IS present, completion update must set it to True.
|
|
|
|
This is the symmetric counterpart to test_fallback_completion_update_omits_batch_processed
|
|
and proves the conditional on _has_batch_processed_column governs the update data.
|
|
"""
|
|
from unittest.mock import patch
|
|
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many = AsyncMock(
|
|
return_value=0
|
|
)
|
|
mock_prisma_client.db.litellm_managedobjecttable.update = AsyncMock()
|
|
mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock(
|
|
return_value=None
|
|
)
|
|
|
|
mock_job = MagicMock()
|
|
mock_job.id = "job-primary-1"
|
|
mock_job.unified_object_id = "dW5pZmllZF9iYXRjaF9pZA=="
|
|
mock_job.created_by = "user-1"
|
|
|
|
assert check_batch_cost_instance._has_batch_processed_column is True
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many = AsyncMock(
|
|
return_value=[mock_job]
|
|
)
|
|
|
|
mock_response = MagicMock()
|
|
mock_response.status = "completed"
|
|
mock_response.output_file_id = "file-output-123"
|
|
mock_response.model_dump_json.return_value = (
|
|
'{"id":"batch-1","status":"completed"}'
|
|
)
|
|
|
|
mock_llm_router.aretrieve_batch = AsyncMock(return_value=mock_response)
|
|
mock_llm_router.get_deployment_credentials_with_provider = MagicMock(
|
|
return_value={"api_key": "sk-test"}
|
|
)
|
|
|
|
mock_deployment = MagicMock()
|
|
mock_deployment.litellm_params.custom_llm_provider = "openai"
|
|
mock_deployment.litellm_params.model = "gpt-4"
|
|
mock_deployment.model_info.model_dump.return_value = {}
|
|
mock_llm_router.get_deployment = MagicMock(return_value=mock_deployment)
|
|
|
|
mock_file_content = MagicMock()
|
|
mock_file_content.content = b'{"id":"req-1"}'
|
|
|
|
decoded_id = "llm_model_id,model-123;llm_batch_id,batch-456;"
|
|
|
|
with (
|
|
patch(
|
|
"litellm.proxy.openai_files_endpoints.common_utils._is_base64_encoded_unified_file_id",
|
|
side_effect=[decoded_id, None],
|
|
),
|
|
patch(
|
|
"litellm.proxy.openai_files_endpoints.common_utils.get_model_id_from_unified_batch_id",
|
|
return_value="model-123",
|
|
),
|
|
patch(
|
|
"litellm.proxy.openai_files_endpoints.common_utils.get_batch_id_from_unified_batch_id",
|
|
return_value="batch-456",
|
|
),
|
|
patch(
|
|
"litellm.files.main.afile_content",
|
|
new_callable=AsyncMock,
|
|
return_value=mock_file_content,
|
|
),
|
|
patch(
|
|
"litellm.batches.batch_utils._get_file_content_as_dictionary",
|
|
return_value=[{"id": "req-1"}],
|
|
),
|
|
patch(
|
|
"litellm.batches.batch_utils.calculate_batch_cost_and_usage",
|
|
new_callable=AsyncMock,
|
|
return_value=(
|
|
0.01,
|
|
{"prompt_tokens": 10, "completion_tokens": 5},
|
|
["gpt-4"],
|
|
),
|
|
),
|
|
patch(
|
|
"litellm.litellm_core_utils.get_llm_provider_logic.get_llm_provider",
|
|
return_value=("gpt-4", "openai", None, None),
|
|
),
|
|
patch(
|
|
"litellm.litellm_core_utils.litellm_logging.Logging"
|
|
) as mock_logging_cls,
|
|
):
|
|
mock_logging_obj = MagicMock()
|
|
mock_logging_obj.async_success_handler = AsyncMock()
|
|
mock_logging_cls.return_value = mock_logging_obj
|
|
|
|
await check_batch_cost_instance.check_batch_cost()
|
|
|
|
assert (
|
|
mock_prisma_client.db.litellm_managedobjecttable.update.call_count == 1
|
|
), "Expected update() to be called exactly once for the completed job"
|
|
update_data = mock_prisma_client.db.litellm_managedobjecttable.update.call_args[
|
|
1
|
|
]["data"]
|
|
assert (
|
|
update_data["batch_processed"] is True
|
|
), "update() must include batch_processed=True when column is present"
|
|
assert update_data["status"] == "complete"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_raw_output_file_id_converted_to_managed_id(
|
|
self, check_batch_cost_instance, mock_prisma_client, mock_llm_router
|
|
):
|
|
"""CheckBatchCost must convert a raw provider output_file_id to a managed base64 ID.
|
|
|
|
Without this, GET /batches/{id} returns a raw file ID that cannot be routed
|
|
through the proxy, causing API_KEY errors when clients call GET /files/{id}/content.
|
|
"""
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many = AsyncMock(
|
|
return_value=0
|
|
)
|
|
mock_prisma_client.db.litellm_managedobjecttable.update = AsyncMock()
|
|
mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock(
|
|
return_value=None
|
|
)
|
|
|
|
mock_job = MagicMock()
|
|
mock_job.id = "job-raw-file-1"
|
|
mock_job.unified_object_id = "dW5pZmllZF9iYXRjaF9pZA=="
|
|
mock_job.created_by = "user-1"
|
|
mock_job.team_id = None
|
|
|
|
check_batch_cost_instance._has_batch_processed_column = True
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many = AsyncMock(
|
|
return_value=[mock_job]
|
|
)
|
|
|
|
raw_output_file_id = "file-batch-output-abc123"
|
|
raw_error_file_id = "file-batch-error-xyz456"
|
|
fake_managed_output_id = "bGl0ZWxsbV9wcm94eTo6b3V0cHV0"
|
|
fake_managed_error_id = "bGl0ZWxsbV9wcm94eTo6ZXJyb3I="
|
|
|
|
mock_response = MagicMock()
|
|
mock_response.status = "completed"
|
|
mock_response.output_file_id = raw_output_file_id
|
|
mock_response.error_file_id = raw_error_file_id
|
|
mock_response.model_dump_json.return_value = (
|
|
'{"id":"batch-1","status":"completed"}'
|
|
)
|
|
|
|
mock_llm_router.aretrieve_batch = AsyncMock(return_value=mock_response)
|
|
mock_llm_router.get_deployment_credentials_with_provider = MagicMock(
|
|
return_value={"api_key": "sk-test"}
|
|
)
|
|
|
|
mock_deployment = MagicMock()
|
|
mock_deployment.litellm_params.custom_llm_provider = "azure"
|
|
mock_deployment.litellm_params.model = "azure/gpt-5-mini"
|
|
mock_deployment.model_name = "gpt-5-batch"
|
|
mock_deployment.model_info.model_dump.return_value = {}
|
|
mock_llm_router.get_deployment = MagicMock(return_value=mock_deployment)
|
|
|
|
mock_hook = MagicMock()
|
|
mock_hook.get_unified_output_file_id.side_effect = [
|
|
fake_managed_output_id,
|
|
fake_managed_error_id,
|
|
]
|
|
mock_hook.store_unified_file_id = AsyncMock()
|
|
check_batch_cost_instance.proxy_logging_obj.get_proxy_hook.return_value = (
|
|
mock_hook
|
|
)
|
|
|
|
mock_file_content = MagicMock()
|
|
mock_file_content.content = b'{"id":"req-1"}'
|
|
decoded_id = "llm_model_id,model-123;llm_batch_id,batch-456;"
|
|
|
|
with (
|
|
patch(
|
|
"litellm.proxy.openai_files_endpoints.common_utils._is_base64_encoded_unified_file_id",
|
|
# call 1: job unified_object_id decode, call 2: existing raw check for output_file_id,
|
|
# call 3: fix guard for output_file_id, call 4: fix guard for error_file_id
|
|
side_effect=[decoded_id, None, None, None],
|
|
),
|
|
patch(
|
|
"litellm.proxy.openai_files_endpoints.common_utils.get_model_id_from_unified_batch_id",
|
|
return_value="model-123",
|
|
),
|
|
patch(
|
|
"litellm.proxy.openai_files_endpoints.common_utils.get_batch_id_from_unified_batch_id",
|
|
return_value="batch-456",
|
|
),
|
|
patch(
|
|
"litellm.files.main.afile_content",
|
|
new_callable=AsyncMock,
|
|
return_value=mock_file_content,
|
|
),
|
|
patch(
|
|
"litellm.batches.batch_utils._get_file_content_as_dictionary",
|
|
return_value=[{"id": "req-1"}],
|
|
),
|
|
patch(
|
|
"litellm.batches.batch_utils.calculate_batch_cost_and_usage",
|
|
new_callable=AsyncMock,
|
|
return_value=(
|
|
0.01,
|
|
{"prompt_tokens": 10, "completion_tokens": 5},
|
|
["gpt-4"],
|
|
),
|
|
),
|
|
patch(
|
|
"litellm.litellm_core_utils.get_llm_provider_logic.get_llm_provider",
|
|
return_value=("gpt-5-mini", "azure", None, None),
|
|
),
|
|
patch(
|
|
"litellm.litellm_core_utils.litellm_logging.Logging"
|
|
) as mock_logging_cls,
|
|
):
|
|
mock_logging_obj = MagicMock()
|
|
mock_logging_obj.async_success_handler = AsyncMock()
|
|
mock_logging_cls.return_value = mock_logging_obj
|
|
|
|
await check_batch_cost_instance.check_batch_cost()
|
|
|
|
assert mock_hook.get_unified_output_file_id.call_count == 2
|
|
mock_hook.get_unified_output_file_id.assert_any_call(
|
|
output_file_id=raw_output_file_id,
|
|
model_id="model-123",
|
|
model_name="gpt-5-mini",
|
|
)
|
|
mock_hook.get_unified_output_file_id.assert_any_call(
|
|
output_file_id=raw_error_file_id,
|
|
model_id="model-123",
|
|
model_name="gpt-5-mini",
|
|
)
|
|
assert mock_hook.store_unified_file_id.await_count == 2
|
|
# {raw_file_id: managed_file_id} for each store call
|
|
stored = {
|
|
next(iter(c[1]["model_mappings"].values())): c[1]["file_id"]
|
|
for c in mock_hook.store_unified_file_id.call_args_list
|
|
}
|
|
assert stored == {
|
|
raw_output_file_id: fake_managed_output_id,
|
|
raw_error_file_id: fake_managed_error_id,
|
|
}
|
|
assert mock_response.output_file_id == fake_managed_output_id
|
|
assert mock_response.error_file_id == fake_managed_error_id
|