Files
litellm/tests/proxy_unit_tests/test_search_api_logging.py
2026-04-17 13:02:59 -07:00

207 lines
7.1 KiB
Python

"""
Test search API logging and cost tracking in proxy.
Tests that search API requests are properly logged to LiteLLM_SpendLogs
with correct fields populated (call_type, model, custom_llm_provider,
model_group, spend, etc.)
"""
import asyncio
import os
import sys
import time
from datetime import datetime
from unittest.mock import AsyncMock, patch
import pytest
sys.path.insert(0, os.path.abspath("../.."))
import litellm
from litellm import Router
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.proxy_track_cost_callback import _ProxyDBLogger
from litellm.proxy.spend_tracking.spend_management_endpoints import view_spend_logs
from litellm.proxy.utils import ProxyLogging, hash_token, update_spend
from litellm.llms.base_llm.search.transformation import SearchResponse, SearchResult
@pytest.fixture
def prisma_client(monkeypatch):
    """Build a ``PrismaClient`` pointed at the database in ``DATABASE_URL``.

    Skips the requesting test when ``DATABASE_URL`` is not set. Connection-pool
    query params are appended to the URL, and the proxy-server globals used by
    key generation are overridden. Every one of these changes goes through
    ``monkeypatch`` so it is reverted when the test finishes instead of leaking
    into later tests in the same session.

    Returns:
        An unconnected ``PrismaClient``; the test is responsible for calling
        ``connect()`` / ``disconnect()``.
    """
    from litellm.proxy import proxy_server
    from litellm.proxy.proxy_cli import append_query_params
    from litellm.proxy.utils import PrismaClient

    database_url = os.getenv("DATABASE_URL")
    if database_url is None:
        pytest.skip("DATABASE_URL not set")

    params = {"connection_limit": 100, "pool_timeout": 60}
    modified_url = append_query_params(database_url, params)
    # Scoped to this test: a bare os.environ assignment would persist, and a
    # later use of this fixture would re-process the already-modified URL.
    monkeypatch.setenv("DATABASE_URL", modified_url)

    proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
    client = PrismaClient(
        database_url=modified_url, proxy_logging_obj=proxy_logging_obj
    )

    # Unique budget name per run; both globals are restored on teardown.
    monkeypatch.setattr(
        proxy_server,
        "litellm_proxy_budget_name",
        f"litellm-proxy-budget-{time.time()}",
        raising=False,
    )
    monkeypatch.setattr(
        proxy_server, "user_custom_key_generate", None, raising=False
    )
    return client
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
@pytest.mark.asyncio
async def test_search_api_logging_and_cost_tracking(prisma_client, monkeypatch):
    """
    Test that search API requests are logged with correct fields and cost tracking.

    Drives ``_ProxyDBLogger._PROXY_track_cost_callback`` directly with the
    kwargs a search endpoint would produce, flushes spend to the DB, and reads
    the resulting row back through ``view_spend_logs``.

    Verifies:
    1. Search request creates a spend log entry
    2. call_type is set to "asearch"
    3. model is set to search_tool_name
    4. custom_llm_provider is set correctly
    5. model_group is set to search_tool_name
    6. spend is calculated and logged

    All proxy_server globals are set via ``monkeypatch`` so they are restored
    after the test, and the DB connection is always closed.
    """
    import uuid

    from litellm.proxy import proxy_server
    from litellm.proxy._types import GenerateKeyRequest, LitellmUserRoles
    from litellm.proxy.management_endpoints.key_management_endpoints import (
        generate_key_fn,
    )

    search_tool_name = "tavily-search"
    search_provider = "tavily"
    expected_spend = 0.008  # Mock cost for tavily search
    # Unique per run: with a fixed id, a row left behind by an earlier run in
    # a persistent database could satisfy the lookup below (or collide with
    # the new insert) and hide a real failure.
    request_id = f"search_test_{uuid.uuid4().hex}"

    monkeypatch.setattr(proxy_server, "prisma_client", prisma_client, raising=False)
    monkeypatch.setattr(proxy_server, "master_key", "sk-1234", raising=False)

    await prisma_client.connect()
    try:
        # Router exposing a single search tool.
        router = Router(model_list=[])
        router.search_tools = [
            {
                "search_tool_name": search_tool_name,
                "litellm_params": {"search_provider": search_provider},
            }
        ]
        monkeypatch.setattr(proxy_server, "llm_router", router, raising=False)

        # Generate a test API key as the proxy admin.
        key_response = await generate_key_fn(
            data=GenerateKeyRequest(models=[], duration=None),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="test_user",
            ),
        )
        generated_key = key_response.key
        user_id = key_response.user_id

        proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
        monkeypatch.setattr(
            proxy_server, "proxy_logging_obj", proxy_logging_obj, raising=False
        )

        # Response object handed to the callback; its id ties it to the request.
        search_response = SearchResponse(
            object="search",
            results=[
                SearchResult(
                    title="Test Result",
                    url="https://example.com",
                    snippet="Test snippet",
                )
            ],
        )
        search_response.id = request_id

        # Simulate the kwargs that the search endpoint would pass. The same
        # metadata appears in two places; each gets its own copy so a
        # mutation of one by the logger cannot affect the other.
        request_metadata = {
            "user_api_key": hash_token(generated_key),
            "user_api_key_user_id": user_id,
            "model_group": search_tool_name,
        }
        kwargs = {
            "call_type": "asearch",
            "model": search_tool_name,
            "custom_llm_provider": search_provider,
            "litellm_call_id": request_id,  # becomes the spend-log request_id
            "litellm_params": {"metadata": dict(request_metadata)},
            "metadata": dict(request_metadata),
            "response_cost": expected_spend,
        }

        now = datetime.now()
        await _ProxyDBLogger()._PROXY_track_cost_callback(
            kwargs=kwargs,
            completion_response=search_response,
            start_time=now,
            end_time=now,
        )

        # Give the async logging path time to enqueue, then flush to the DB.
        await asyncio.sleep(2)
        await update_spend(
            prisma_client=prisma_client,
            db_writer_client=None,
            proxy_logging_obj=proxy_logging_obj,
        )

        spend_logs = await view_spend_logs(
            request_id=request_id,
            user_api_key_dict=UserAPIKeyAuth(api_key=generated_key),
        )
        assert len(spend_logs) == 1, f"Expected 1 spend log, got {len(spend_logs)}"
        spend_log = spend_logs[0]

        assert spend_log.request_id == request_id
        assert spend_log.call_type == "asearch"
        assert spend_log.model == search_tool_name
        assert spend_log.custom_llm_provider == search_provider
        assert spend_log.model_group == search_tool_name
        # Spend round-trips through a float column; compare approximately.
        assert spend_log.spend == pytest.approx(expected_spend)
        # API key should be hashed (either the generated key or the one from metadata)
        assert spend_log.api_key != ""
        # The user field may be empty if not set in the request, but the
        # user_id should then be present in metadata.
        assert (
            spend_log.metadata.get("user_api_key_user_id") == user_id
            or spend_log.user == user_id
        )

        print("✅ Search API logging test passed!")
        print(f" - call_type: {spend_log.call_type}")
        print(f" - model: {spend_log.model}")
        print(f" - custom_llm_provider: {spend_log.custom_llm_provider}")
        print(f" - model_group: {spend_log.model_group}")
        print(f" - spend: {spend_log.spend}")
    finally:
        await prisma_client.disconnect()