mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 00:48:01 +00:00
481 lines
17 KiB
Python
481 lines
17 KiB
Python
"""
|
|
Unit tests for CheckResponsesCost class
|
|
"""
|
|
|
|
import asyncio
|
|
from datetime import datetime
|
|
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
|
|
|
import pytest
|
|
|
|
from litellm.constants import MAX_OBJECTS_PER_POLL_CYCLE
|
|
from litellm.types.llms.openai import ResponseAPIUsage, ResponsesAPIResponse
|
|
|
|
|
|
class TestCheckResponsesCost:
|
|
"""Test suite for CheckResponsesCost class"""
|
|
|
|
@pytest.fixture
|
|
def mock_prisma_client(self):
|
|
"""Create a mock Prisma client"""
|
|
client = MagicMock()
|
|
client.db = MagicMock()
|
|
client.db.litellm_managedobjecttable = MagicMock()
|
|
return client
|
|
|
|
@pytest.fixture
|
|
def mock_proxy_logging_obj(self):
|
|
"""Create a mock ProxyLogging object"""
|
|
logging_obj = MagicMock()
|
|
logging_obj.get_proxy_hook = MagicMock(return_value=None)
|
|
return logging_obj
|
|
|
|
@pytest.fixture
|
|
def mock_llm_router(self):
|
|
"""Create a mock LLM Router"""
|
|
router = MagicMock()
|
|
router.aget_responses = AsyncMock()
|
|
router.get_deployment = MagicMock()
|
|
return router
|
|
|
|
@pytest.fixture
|
|
def check_responses_cost_instance(
|
|
self, mock_proxy_logging_obj, mock_prisma_client, mock_llm_router
|
|
):
|
|
"""Create a CheckResponsesCost instance with mocked dependencies"""
|
|
from litellm_enterprise.proxy.common_utils.check_responses_cost import (
|
|
CheckResponsesCost,
|
|
)
|
|
|
|
instance = CheckResponsesCost(
|
|
proxy_logging_obj=mock_proxy_logging_obj,
|
|
prisma_client=mock_prisma_client,
|
|
llm_router=mock_llm_router,
|
|
)
|
|
# Mock _expire_stale_rows (raw SQL) so _cleanup_stale_managed_objects
|
|
# succeeds without a real DB. Individual tests can override this.
|
|
instance._expire_stale_rows = AsyncMock(return_value=0)
|
|
return instance
|
|
|
|
def test_initialization(self, check_responses_cost_instance):
|
|
"""Test that CheckResponsesCost initializes correctly"""
|
|
assert check_responses_cost_instance.proxy_logging_obj is not None
|
|
assert check_responses_cost_instance.prisma_client is not None
|
|
assert check_responses_cost_instance.llm_router is not None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_check_responses_cost_no_jobs(
|
|
self, check_responses_cost_instance, mock_prisma_client
|
|
):
|
|
"""Test check_responses_cost when there are no jobs to process"""
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many = AsyncMock(
|
|
return_value=[]
|
|
)
|
|
|
|
await check_responses_cost_instance.check_responses_cost()
|
|
|
|
# Verify find_many was called with pagination params
|
|
find_many_call = (
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many.call_args
|
|
)
|
|
assert find_many_call[1]["where"] == {
|
|
"status": {"in": ["queued", "in_progress"]},
|
|
"file_purpose": "response",
|
|
}
|
|
assert find_many_call[1]["take"] == MAX_OBJECTS_PER_POLL_CYCLE
|
|
assert find_many_call[1]["order"] == {"created_at": "asc"}
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_cleanup_stale_managed_objects(
|
|
self, check_responses_cost_instance, mock_prisma_client
|
|
):
|
|
"""Stale rows are expired via _expire_stale_rows before polling."""
|
|
from litellm.constants import STALE_OBJECT_CLEANUP_BATCH_SIZE
|
|
|
|
check_responses_cost_instance._expire_stale_rows = AsyncMock(return_value=5)
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many = AsyncMock(
|
|
return_value=[]
|
|
)
|
|
|
|
await check_responses_cost_instance.check_responses_cost()
|
|
|
|
# _expire_stale_rows should have been called with a cutoff datetime and batch size
|
|
check_responses_cost_instance._expire_stale_rows.assert_called_once()
|
|
call_args = check_responses_cost_instance._expire_stale_rows.call_args
|
|
assert call_args[0][1] == STALE_OBJECT_CLEANUP_BATCH_SIZE
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_check_responses_cost_with_completed_response(
|
|
self, check_responses_cost_instance, mock_prisma_client, mock_llm_router
|
|
):
|
|
"""Test check_responses_cost with a completed response"""
|
|
# Mock job with response ID
|
|
mock_job = MagicMock()
|
|
mock_job.unified_object_id = "resp_test_123"
|
|
mock_job.created_by = "test-user"
|
|
mock_job.id = "job-123"
|
|
mock_job.file_object = {"model": "gpt-4o", "id": "resp_test_123"}
|
|
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many = AsyncMock(
|
|
return_value=[mock_job]
|
|
)
|
|
|
|
# Mock completed response
|
|
mock_response = ResponsesAPIResponse(
|
|
id="resp_123",
|
|
object="response",
|
|
status="completed",
|
|
created_at=int(datetime.now().timestamp()),
|
|
output=[],
|
|
usage=ResponseAPIUsage(
|
|
input_tokens=100,
|
|
output_tokens=50,
|
|
total_tokens=150,
|
|
),
|
|
)
|
|
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many = AsyncMock(
|
|
return_value=0
|
|
)
|
|
|
|
# Run the check with mocked litellm.aget_responses
|
|
with patch("litellm.aget_responses", new_callable=AsyncMock) as mock_aget:
|
|
mock_aget.return_value = mock_response
|
|
|
|
await check_responses_cost_instance.check_responses_cost()
|
|
|
|
# update_many should only contain the job completion call
|
|
calls = (
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many.call_args_list
|
|
)
|
|
assert len(calls) == 1
|
|
completion_call = calls[0]
|
|
assert completion_call[1]["data"]["status"] == "completed"
|
|
assert completion_call[1]["where"]["id"]["in"] == ["job-123"]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_check_responses_cost_with_failed_response(
|
|
self, check_responses_cost_instance, mock_prisma_client, mock_llm_router
|
|
):
|
|
"""Test check_responses_cost with a failed response"""
|
|
# Mock job
|
|
mock_job = MagicMock()
|
|
mock_job.unified_object_id = "resp_test_456"
|
|
mock_job.created_by = "test-user"
|
|
mock_job.id = "job-456"
|
|
mock_job.file_object = {"model": "gpt-4o", "id": "resp_test_456"}
|
|
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many = AsyncMock(
|
|
return_value=[mock_job]
|
|
)
|
|
|
|
# Mock failed response
|
|
mock_response = ResponsesAPIResponse(
|
|
id="resp_456",
|
|
object="response",
|
|
status="failed",
|
|
created_at=int(datetime.now().timestamp()),
|
|
output=[],
|
|
usage=None,
|
|
)
|
|
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many = AsyncMock(
|
|
return_value=0
|
|
)
|
|
|
|
# Run the check
|
|
with patch("litellm.aget_responses", new_callable=AsyncMock) as mock_aget:
|
|
mock_aget.return_value = mock_response
|
|
|
|
await check_responses_cost_instance.check_responses_cost()
|
|
|
|
# update_many should only contain the job completion call
|
|
calls = (
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many.call_args_list
|
|
)
|
|
assert len(calls) == 1
|
|
assert calls[0][1]["data"]["status"] == "completed"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_check_responses_cost_with_cancelled_response(
|
|
self, check_responses_cost_instance, mock_prisma_client
|
|
):
|
|
"""Test check_responses_cost with a cancelled response"""
|
|
# Mock job
|
|
mock_job = MagicMock()
|
|
mock_job.unified_object_id = "resp_test_789"
|
|
mock_job.created_by = "test-user"
|
|
mock_job.id = "job-789"
|
|
mock_job.file_object = {"model": "gpt-4o", "id": "resp_test_789"}
|
|
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many = AsyncMock(
|
|
return_value=[mock_job]
|
|
)
|
|
|
|
# Mock cancelled response
|
|
mock_response = ResponsesAPIResponse(
|
|
id="resp_789",
|
|
object="response",
|
|
status="cancelled",
|
|
created_at=int(datetime.now().timestamp()),
|
|
output=[],
|
|
usage=None,
|
|
)
|
|
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many = AsyncMock(
|
|
return_value=0
|
|
)
|
|
|
|
# Run the check
|
|
with patch("litellm.aget_responses", new_callable=AsyncMock) as mock_aget:
|
|
mock_aget.return_value = mock_response
|
|
|
|
await check_responses_cost_instance.check_responses_cost()
|
|
|
|
# update_many should only contain the job completion call
|
|
calls = (
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many.call_args_list
|
|
)
|
|
assert len(calls) == 1
|
|
assert calls[0][1]["data"]["status"] == "completed"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_check_responses_cost_with_in_progress_response(
|
|
self, check_responses_cost_instance, mock_prisma_client
|
|
):
|
|
"""Test check_responses_cost with a response still in progress"""
|
|
# Mock job
|
|
mock_job = MagicMock()
|
|
mock_job.unified_object_id = "resp_test_in_progress"
|
|
mock_job.created_by = "test-user"
|
|
mock_job.id = "job-in-progress"
|
|
mock_job.file_object = {"model": "gpt-4o", "id": "resp_test_in_progress"}
|
|
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many = AsyncMock(
|
|
return_value=[mock_job]
|
|
)
|
|
|
|
# Mock in-progress response
|
|
mock_response = ResponsesAPIResponse(
|
|
id="resp_in_progress",
|
|
object="response",
|
|
status="in_progress",
|
|
created_at=int(datetime.now().timestamp()),
|
|
output=[],
|
|
usage=None,
|
|
)
|
|
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many = AsyncMock(
|
|
return_value=0
|
|
)
|
|
|
|
# Run the check
|
|
with patch("litellm.aget_responses", new_callable=AsyncMock) as mock_aget:
|
|
mock_aget.return_value = mock_response
|
|
|
|
await check_responses_cost_instance.check_responses_cost()
|
|
|
|
# No job completion update_many — response is still in progress
|
|
calls = (
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many.call_args_list
|
|
)
|
|
assert len(calls) == 0
|
|
# Stale cleanup still ran via _expire_stale_rows
|
|
check_responses_cost_instance._expire_stale_rows.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_check_responses_cost_with_queued_response(
|
|
self, check_responses_cost_instance, mock_prisma_client
|
|
):
|
|
"""Test check_responses_cost with a queued response"""
|
|
# Mock job
|
|
mock_job = MagicMock()
|
|
mock_job.unified_object_id = "resp_test_queued"
|
|
mock_job.created_by = "test-user"
|
|
mock_job.id = "job-queued"
|
|
mock_job.file_object = {"model": "gpt-4o", "id": "resp_test_queued"}
|
|
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many = AsyncMock(
|
|
return_value=[mock_job]
|
|
)
|
|
|
|
# Mock queued response
|
|
mock_response = ResponsesAPIResponse(
|
|
id="resp_queued",
|
|
object="response",
|
|
status="queued",
|
|
created_at=int(datetime.now().timestamp()),
|
|
output=[],
|
|
usage=None,
|
|
)
|
|
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many = AsyncMock(
|
|
return_value=0
|
|
)
|
|
|
|
# Run the check
|
|
with patch("litellm.aget_responses", new_callable=AsyncMock) as mock_aget:
|
|
mock_aget.return_value = mock_response
|
|
|
|
await check_responses_cost_instance.check_responses_cost()
|
|
|
|
# No job completion update_many — response is still queued
|
|
calls = (
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many.call_args_list
|
|
)
|
|
assert len(calls) == 0
|
|
# Stale cleanup still ran via _expire_stale_rows
|
|
check_responses_cost_instance._expire_stale_rows.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_check_responses_cost_with_exception(
|
|
self, check_responses_cost_instance, mock_prisma_client
|
|
):
|
|
"""Test check_responses_cost handles exceptions gracefully"""
|
|
# Mock job
|
|
mock_job = MagicMock()
|
|
mock_job.unified_object_id = "resp_test_error"
|
|
mock_job.created_by = "test-user"
|
|
mock_job.id = "job-error"
|
|
mock_job.file_object = {"model": "gpt-4o", "id": "resp_test_error"}
|
|
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many = AsyncMock(
|
|
return_value=[mock_job]
|
|
)
|
|
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many = AsyncMock(
|
|
return_value=0
|
|
)
|
|
|
|
# Run the check with mocked exception
|
|
with patch(
|
|
"litellm.aget_responses",
|
|
new_callable=AsyncMock,
|
|
side_effect=Exception("Provider error"),
|
|
):
|
|
# Should not raise, just skip the job
|
|
await check_responses_cost_instance.check_responses_cost()
|
|
|
|
# No job completion update_many — exception skipped the job
|
|
calls = (
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many.call_args_list
|
|
)
|
|
assert len(calls) == 0
|
|
# Stale cleanup still ran via _expire_stale_rows
|
|
check_responses_cost_instance._expire_stale_rows.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_check_responses_cost_multiple_jobs(
|
|
self, check_responses_cost_instance, mock_prisma_client
|
|
):
|
|
"""Test check_responses_cost with multiple jobs"""
|
|
# Mock multiple jobs
|
|
mock_job1 = MagicMock()
|
|
mock_job1.unified_object_id = "resp_test_1"
|
|
mock_job1.created_by = "user1"
|
|
mock_job1.id = "job-1"
|
|
mock_job1.file_object = {"model": "gpt-4o", "id": "resp_test_1"}
|
|
|
|
mock_job2 = MagicMock()
|
|
mock_job2.unified_object_id = "resp_test_2"
|
|
mock_job2.created_by = "user2"
|
|
mock_job2.id = "job-2"
|
|
mock_job2.file_object = {"model": "gpt-4o", "id": "resp_test_2"}
|
|
|
|
mock_job3 = MagicMock()
|
|
mock_job3.unified_object_id = "resp_test_3"
|
|
mock_job3.created_by = "user3"
|
|
mock_job3.id = "job-3"
|
|
mock_job3.file_object = {"model": "gpt-4o", "id": "resp_test_3"}
|
|
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many = AsyncMock(
|
|
return_value=[mock_job1, mock_job2, mock_job3]
|
|
)
|
|
|
|
# Mock responses - 2 completed, 1 in progress
|
|
mock_response1 = ResponsesAPIResponse(
|
|
id="resp_1",
|
|
object="response",
|
|
status="completed",
|
|
created_at=int(datetime.now().timestamp()),
|
|
output=[],
|
|
usage=ResponseAPIUsage(
|
|
input_tokens=100,
|
|
output_tokens=50,
|
|
total_tokens=150,
|
|
),
|
|
)
|
|
|
|
mock_response2 = ResponsesAPIResponse(
|
|
id="resp_2",
|
|
object="response",
|
|
status="in_progress",
|
|
created_at=int(datetime.now().timestamp()),
|
|
output=[],
|
|
usage=None,
|
|
)
|
|
|
|
mock_response3 = ResponsesAPIResponse(
|
|
id="resp_3",
|
|
object="response",
|
|
status="completed",
|
|
created_at=int(datetime.now().timestamp()),
|
|
output=[],
|
|
usage=ResponseAPIUsage(
|
|
input_tokens=200,
|
|
output_tokens=100,
|
|
total_tokens=300,
|
|
),
|
|
)
|
|
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many = AsyncMock(
|
|
return_value=0
|
|
)
|
|
|
|
# Run the check
|
|
with patch("litellm.aget_responses", new_callable=AsyncMock) as mock_aget:
|
|
mock_aget.side_effect = [mock_response1, mock_response2, mock_response3]
|
|
|
|
await check_responses_cost_instance.check_responses_cost()
|
|
|
|
# update_many should only contain the job completion call
|
|
calls = (
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many.call_args_list
|
|
)
|
|
assert len(calls) == 1
|
|
completion_call = calls[0]
|
|
assert len(completion_call[1]["where"]["id"]["in"]) == 2
|
|
assert "job-1" in completion_call[1]["where"]["id"]["in"]
|
|
assert "job-3" in completion_call[1]["where"]["id"]["in"]
|
|
assert "job-2" not in completion_call[1]["where"]["id"]["in"]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_check_responses_cost_no_model_in_file_object(
|
|
self, check_responses_cost_instance, mock_prisma_client
|
|
):
|
|
"""When file_object has no 'model' key, model_name is None and metadata skips model fields."""
|
|
mock_job = MagicMock()
|
|
mock_job.unified_object_id = "resp_test_no_model"
|
|
mock_job.created_by = "test-user"
|
|
mock_job.id = "job-no-model"
|
|
mock_job.file_object = {} # no "model" key → model_name=None branch
|
|
|
|
mock_prisma_client.db.litellm_managedobjecttable.find_many = AsyncMock(
|
|
return_value=[mock_job]
|
|
)
|
|
mock_prisma_client.db.litellm_managedobjecttable.update_many = AsyncMock(
|
|
return_value=0
|
|
)
|
|
|
|
mock_response = MagicMock()
|
|
mock_response.status = "completed"
|
|
|
|
with patch("litellm.aget_responses", new_callable=AsyncMock) as mock_aget:
|
|
mock_aget.return_value = mock_response
|
|
await check_responses_cost_instance.check_responses_cost()
|
|
|
|
# aget_responses should be called without model metadata
|
|
call_kwargs = mock_aget.call_args[1]
|
|
assert "model" not in call_kwargs.get("litellm_metadata", {})
|
|
assert "model_group" not in call_kwargs.get("litellm_metadata", {})
|