mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 16:48:54 +00:00
207 lines
7.1 KiB
Python
207 lines
7.1 KiB
Python
"""
|
|
Test search API logging and cost tracking in proxy.
|
|
|
|
Tests that search API requests are properly logged to LiteLLM_SpendLogs
|
|
with correct fields populated (call_type, model, custom_llm_provider,
|
|
model_group, spend, etc.)
|
|
"""
|
|
|
|
import asyncio
|
|
import os
|
|
import sys
|
|
import time
|
|
from datetime import datetime
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
|
|
sys.path.insert(0, os.path.abspath("../.."))
|
|
import litellm
|
|
from litellm import Router
|
|
from litellm.caching import DualCache
|
|
from litellm.proxy._types import UserAPIKeyAuth
|
|
from litellm.proxy.hooks.proxy_track_cost_callback import _ProxyDBLogger
|
|
from litellm.proxy.spend_tracking.spend_management_endpoints import view_spend_logs
|
|
from litellm.proxy.utils import ProxyLogging, hash_token, update_spend
|
|
from litellm.llms.base_llm.search.transformation import SearchResponse, SearchResult
|
|
|
|
|
|
@pytest.fixture
def prisma_client():
    """
    Build a ``PrismaClient`` pointed at ``DATABASE_URL`` for spend-log tests.

    Skips the requesting test when ``DATABASE_URL`` is not set. Otherwise:
    - appends connection-pool query params to the URL;
    - creates a client wired to a fresh ``ProxyLogging`` / ``DualCache``;
    - resets the proxy-server budget name and custom key generator.

    The client is returned unconnected; the test is responsible for calling
    ``connect()``.

    Process-global state changed here (``os.environ["DATABASE_URL"]``,
    ``proxy_server.litellm_proxy_budget_name`` and
    ``proxy_server.user_custom_key_generate``) is restored on teardown, so it
    does not leak into other tests in the same session.

    Yields:
        PrismaClient: the configured, not-yet-connected client.
    """
    # Check the precondition before importing the proxy modules, so an
    # unconfigured environment skips without paying for those imports.
    database_url = os.getenv("DATABASE_URL")
    if database_url is None:
        pytest.skip("DATABASE_URL not set")

    from litellm.proxy import proxy_server
    from litellm.proxy.proxy_cli import append_query_params
    from litellm.proxy.utils import PrismaClient

    params = {"connection_limit": 100, "pool_timeout": 60}
    modified_url = append_query_params(database_url, params)

    # Snapshot the globals this fixture overwrites so teardown can restore them.
    original_budget_name = proxy_server.litellm_proxy_budget_name
    original_key_generate = proxy_server.user_custom_key_generate

    os.environ["DATABASE_URL"] = modified_url

    user_api_key_cache = DualCache()
    proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)

    client = PrismaClient(
        database_url=modified_url, proxy_logging_obj=proxy_logging_obj
    )

    # Timestamped budget name so repeated runs don't reuse a budget row.
    proxy_server.litellm_proxy_budget_name = f"litellm-proxy-budget-{time.time()}"
    proxy_server.user_custom_key_generate = None

    try:
        yield client
    finally:
        os.environ["DATABASE_URL"] = database_url
        proxy_server.litellm_proxy_budget_name = original_budget_name
        proxy_server.user_custom_key_generate = original_key_generate
|
|
|
|
|
|
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
@pytest.mark.asyncio
async def test_search_api_logging_and_cost_tracking(prisma_client):
    """
    Test that search API requests are logged with correct fields and cost tracking.

    Verifies:
    1. Search request creates a spend log entry
    2. call_type is set to "asearch"
    3. model is set to search_tool_name
    4. custom_llm_provider is set correctly
    5. model_group is set to search_tool_name
    6. spend is calculated and logged

    The search endpoint itself is not called. The test builds the kwargs the
    endpoint would hand to ``_ProxyDBLogger._PROXY_track_cost_callback``,
    invokes that callback directly, flushes queued spend with
    ``update_spend``, then reads the row back through ``view_spend_logs``.
    """
    import uuid

    from litellm.proxy._types import GenerateKeyRequest, LitellmUserRoles
    from litellm.proxy.management_endpoints.key_management_endpoints import (
        generate_key_fn,
    )

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    try:
        # Setup router with search tool
        search_tool_name = "tavily-search"
        search_provider = "tavily"

        router = Router(model_list=[])
        router.search_tools = [
            {
                "search_tool_name": search_tool_name,
                "litellm_params": {
                    "search_provider": search_provider,
                },
            }
        ]
        setattr(litellm.proxy.proxy_server, "llm_router", router)

        # Generate a test API key as the proxy admin.
        user_api_key_dict = UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            api_key="sk-1234",
            user_id="test_user",
        )
        key_request = GenerateKeyRequest(models=[], duration=None)
        key_response = await generate_key_fn(
            data=key_request, user_api_key_dict=user_api_key_dict
        )
        generated_key = key_response.key
        user_id = key_response.user_id
        hashed_key = hash_token(generated_key)

        # Create mock search response
        mock_search_result = SearchResult(
            title="Test Result",
            url="https://example.com",
            snippet="Test snippet",
        )
        mock_search_response = SearchResponse(
            object="search",
            results=[mock_search_result],
        )

        # The callback is driven directly below, so the patched ``asearch`` is
        # not called on this path. Presumably the patch only guards against an
        # accidental real provider call.
        with patch(
            "litellm.search.main.asearch", new_callable=AsyncMock
        ) as mock_asearch:
            mock_asearch.return_value = mock_search_response

            # Setup proxy logging
            user_api_key_cache = DualCache()
            proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
            setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)

            # Call the track_cost_callback directly to simulate what happens
            # after a search.
            proxy_db_logger = _ProxyDBLogger()

            # Unique per run: a fixed id would collide with a row left behind
            # by an earlier run against the same database and break the
            # "exactly one spend log" assertion below.
            request_id = f"search_test_{uuid.uuid4().hex}"

            # Simulate the kwargs that would be passed from the search endpoint.
            # The same metadata is given both nested under litellm_params and
            # at the top level.
            metadata = {
                "user_api_key": hashed_key,
                "user_api_key_user_id": user_id,
                "model_group": search_tool_name,
            }
            kwargs = {
                "call_type": "asearch",
                "model": search_tool_name,
                "custom_llm_provider": search_provider,
                "litellm_call_id": request_id,  # Set request_id in kwargs
                "litellm_params": {"metadata": dict(metadata)},
                "metadata": dict(metadata),
                "response_cost": 0.008,  # Mock cost for tavily search
            }

            # Set id on the response object
            mock_search_response.id = request_id

            await proxy_db_logger._PROXY_track_cost_callback(
                kwargs=kwargs,
                completion_response=mock_search_response,
                start_time=datetime.now(),
                end_time=datetime.now(),
            )

            # Give any async work started by the callback time to run before
            # flushing spend to the DB.
            await asyncio.sleep(2)
            await update_spend(
                prisma_client=prisma_client,
                db_writer_client=None,
                proxy_logging_obj=proxy_logging_obj,
            )

            # Query spend logs
            spend_logs = await view_spend_logs(
                request_id=request_id,
                user_api_key_dict=UserAPIKeyAuth(api_key=generated_key),
            )

            # Verify spend log was created
            assert len(spend_logs) == 1, f"Expected 1 spend log, got {len(spend_logs)}"
            spend_log = spend_logs[0]

            # Verify all fields are populated correctly
            assert spend_log.request_id == request_id
            assert spend_log.call_type == "asearch"
            assert spend_log.model == search_tool_name
            assert spend_log.custom_llm_provider == search_provider
            assert spend_log.model_group == search_tool_name
            # Spend round-trips through the DB as a float; compare with tolerance.
            assert spend_log.spend == pytest.approx(0.008)
            # API key should be hashed (either the generated key or the one from metadata)
            assert spend_log.api_key != ""  # Should be populated
            # Note: user field may be empty if not set in the request, but
            # user_id should be in metadata.
            assert (
                spend_log.metadata.get("user_api_key_user_id") == user_id
                or spend_log.user == user_id
            )

            print("✅ Search API logging test passed!")
            print(f" - call_type: {spend_log.call_type}")
            print(f" - model: {spend_log.model}")
            print(f" - custom_llm_provider: {spend_log.custom_llm_provider}")
            print(f" - model_group: {spend_log.model_group}")
            print(f" - spend: {spend_log.spend}")
    finally:
        # Release the DB connection opened above, even when an assertion fails.
        await prisma_client.disconnect()
|