Files
litellm/tests/logging_callback_tests/test_token_counting.py
T
Ishaan Jaff 86cdb8382b [Feat] Use aiohttp transport by default - 97% lower median latency (#11097)
* fix: add flag for disabling use_aiohttp_transport

* feat: add _create_async_transport

* feat: fixes for transport

* add httpx-aiohttp

* feat: fixes for transport

* refactor: fixes for transport

* build: fix deps

* fixes: test fixes

* fix: ensure aiohttp does not auto set content type

* test: test fixes

* feat: add LiteLLMAiohttpTransport

* fix: fixes for responses API handling

* test: fixes for responses API handling

* test: fixes for responses API handling

* feat: fixes for transport

* fix: base embedding handler

* test: test_async_http_handler_force_ipv4

* test: fix failing deepeval test

* fix: add YARL for bedrock urls

* fix: issues with transport

* fix: comment out linting issues

* test fix

* test: XAI is unstable

* test: fixes for using respx

* test: XAI fixes

* test: XAI fixes

* test: infinity testing fixes

* docs(config_settings.md): document param

* test: test_openai_image_edit_litellm_sdk

* test: remove deprecated test

* bump respx==0.22.0

* test: test_xai_message_name_filtering

* test: fix anthropic test after bumping httpx

* use n 4 for mapped tests (#11109)

* fix: use 1 session per event loop

* test: test_client_session_helper

* fix: linting error

* fix: resolving GET requests on httpx 0.28.1

* test: fixes for proxy unit tests

* fix: add ssl verify settings

* fix: proxy unit tests

* fix: refactor

* tests: basic unit tests for aiohttp transports

* tests: fixes xai

---------

Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>
2025-05-23 22:55:35 -07:00

249 lines
7.9 KiB
Python

import os
import sys
import traceback
import uuid
import pytest
from dotenv import load_dotenv
from fastapi import Request
from fastapi.routing import APIRoute
load_dotenv()
import io
import os
import time
import json
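# put the directory two levels above the current working directory first on sys.path (presumably the repo root, so the local litellm source is imported)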
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
import asyncio
from typing import Optional
from litellm.types.utils import StandardLoggingPayload, Usage, ModelInfoBase
from litellm.integrations.custom_logger import CustomLogger
class TestCustomLogger(CustomLogger):
    """Test helper callback that captures what litellm hands to loggers.

    After a successful async call, ``standard_logging_payload`` holds the
    ``standard_logging_object`` found in the callback kwargs, and
    ``recorded_usage`` holds a ``Usage`` built from that payload's
    ``prompt_tokens`` / ``completion_tokens`` / ``total_tokens`` values.
    """

    # The "Test" prefix would otherwise make pytest try (and fail, because of
    # ``__init__``) to collect this helper as a test class.
    __test__ = False

    def __init__(self):
        super().__init__()
        self.recorded_usage: Optional[Usage] = None
        self.standard_logging_payload: Optional[StandardLoggingPayload] = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Store the standard logging payload and the usage derived from it.

        Args:
            kwargs: callback kwargs; only ``"standard_logging_object"`` is read.
            response_obj, start_time, end_time: unused.
        """
        standard_logging_payload = kwargs.get("standard_logging_object")
        self.standard_logging_payload = standard_logging_payload
        print(
            "standard_logging_payload",
            json.dumps(standard_logging_payload, indent=4, default=str),
        )
        if standard_logging_payload is None:
            # Nothing to derive usage from. Leave recorded_usage untouched so
            # the calling test's own assertions report the missing payload.
            return
        self.recorded_usage = Usage(
            prompt_tokens=standard_logging_payload.get("prompt_tokens"),
            completion_tokens=standard_logging_payload.get("completion_tokens"),
            total_tokens=standard_logging_payload.get("total_tokens"),
        )
async def _assert_stream_usage_matches_logged_usage(**completion_kwargs):
    """Stream a completion and assert the logged usage equals the streamed usage.

    Registers a fresh ``TestCustomLogger`` callback and calls
    ``litellm.acompletion(**completion_kwargs)``. It consumes the whole stream,
    keeping the last non-None ``usage`` block seen, then compares its
    prompt/completion/total token counts with what the callback recorded.

    Raises:
        AssertionError: if no usage block was streamed, if the callback
            recorded nothing, or if any token count differs.
    """
    custom_logger = TestCustomLogger()
    litellm.logging_callback_manager.add_litellm_callback(custom_logger)
    response = await litellm.acompletion(**completion_kwargs)

    actual_usage = None
    async for chunk in response:
        # Skip chunks whose usage is None so such a chunk cannot overwrite an
        # earlier real usage block.
        if "usage" in chunk and chunk["usage"] is not None:
            actual_usage = chunk["usage"]
            print("chunk.usage", json.dumps(actual_usage, indent=4, default=str))

    # The logger's success hook is async; give it time to record usage before
    # reading it.
    await asyncio.sleep(2)
    recorded_usage = custom_logger.recorded_usage
    print("\n\n\n\n\n")
    print("recorded_usage", json.dumps(recorded_usage, indent=4, default=str))
    print("\n\n\n\n\n")

    assert actual_usage is not None, "no usage block was received in the stream"
    assert recorded_usage is not None, "logging callback did not record any usage"
    assert actual_usage.prompt_tokens == recorded_usage.prompt_tokens
    assert actual_usage.completion_tokens == recorded_usage.completion_tokens
    assert actual_usage.total_tokens == recorded_usage.total_tokens


@pytest.mark.asyncio
async def test_stream_token_counting_gpt_4o():
    """
    When stream_options={"include_usage": True} logging callback tracks Usage == Usage from llm API
    """
    await _assert_stream_usage_matches_logged_usage(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hello, how are you?" * 100}],
        stream=True,
        stream_options={"include_usage": True},
    )


@pytest.mark.asyncio
async def test_stream_token_counting_without_include_usage():
    """
    When stream_options={"include_usage": True} is not passed, the usage tracked == usage from llm api chunk
    by default, litellm passes `include_usage=True` for OpenAI API
    """
    await _assert_stream_usage_matches_logged_usage(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hello, how are you?" * 100}],
        stream=True,
    )


@pytest.mark.asyncio
async def test_stream_token_counting_with_redaction():
    """
    When litellm.turn_off_message_logging=True is used, the usage tracked == usage from llm api chunk
    """
    # The flag is module-global; restore it afterwards so redaction does not
    # leak into other tests in the same session.
    previous_turn_off_message_logging = litellm.turn_off_message_logging
    litellm.turn_off_message_logging = True
    try:
        await _assert_stream_usage_matches_logged_usage(
            model="gpt-4o",
            messages=[{"role": "user", "content": "Hello, how are you?" * 100}],
            stream=True,
        )
    finally:
        litellm.turn_off_message_logging = previous_turn_off_message_logging
@pytest.mark.asyncio
async def test_stream_token_counting_anthropic_with_include_usage():
    """
    Usage recorded by the logging callback for a streamed litellm Anthropic
    call matches the usage reported by a direct Anthropic SDK stream.

    Input tokens must match exactly. Output tokens may differ by at most 10,
    because two separate calls are not guaranteed to produce the same reply.
    """
    from anthropic import Anthropic

    anthropic_client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
    litellm._turn_on_debug()
    custom_logger = TestCustomLogger()
    litellm.logging_callback_manager.add_litellm_callback(custom_logger)
    input_text = "Respond in just 1 word. Say ping"

    response = await litellm.acompletion(
        model="claude-3-5-sonnet-20240620",
        messages=[{"role": "user", "content": input_text}],
        max_tokens=4096,
        stream=True,
    )
    output_text = ""
    async for chunk in response:
        # Skip chunks with an empty choices list instead of raising IndexError.
        if chunk["choices"]:
            output_text += chunk["choices"][0]["delta"]["content"] or ""
    print("litellm output_text", output_text)

    # The logger's success hook is async; give it time to record usage.
    await asyncio.sleep(1)
    recorded_usage = custom_logger.recorded_usage
    print("\n\n\n\n\n")
    print("recorded_usage", json.dumps(recorded_usage, indent=4, default=str))
    print("\n\n\n\n\n")
    # Fail with a clear message rather than an AttributeError on None below.
    assert recorded_usage is not None, "logging callback did not record any usage"

    # Make the same request directly with the Anthropic client.
    anthropic_response = anthropic_client.messages.create(
        model="claude-3-5-sonnet-20240620",
        max_tokens=4096,
        messages=[{"role": "user", "content": input_text}],
        stream=True,
    )
    all_anthropic_usage_chunks = []
    for chunk in anthropic_response:
        print("chunk", json.dumps(chunk, indent=4, default=str))
        # Usage is exposed either nested under `.message` or as a top-level
        # `.usage` attribute, depending on the event.
        if hasattr(chunk, "message"):
            if chunk.message.usage:
                print(
                    "USAGE BLOCK",
                    json.dumps(chunk.message.usage, indent=4, default=str),
                )
                all_anthropic_usage_chunks.append(chunk.message.usage)
        elif hasattr(chunk, "usage"):
            print("USAGE BLOCK", json.dumps(chunk.usage, indent=4, default=str))
            all_anthropic_usage_chunks.append(chunk.usage)
    print(
        "all_anthropic_usage_chunks",
        json.dumps(all_anthropic_usage_chunks, indent=4, default=str),
    )

    # Missing or None token fields count as 0.
    input_tokens_anthropic_api = sum(
        getattr(chunk_usage, "input_tokens", 0) or 0
        for chunk_usage in all_anthropic_usage_chunks
    )
    output_tokens_anthropic_api = sum(
        getattr(chunk_usage, "output_tokens", 0) or 0
        for chunk_usage in all_anthropic_usage_chunks
    )
    print("input_tokens_anthropic_api", input_tokens_anthropic_api)
    print("output_tokens_anthropic_api", output_tokens_anthropic_api)
    print("input_tokens_litellm", recorded_usage.prompt_tokens)
    print("output_tokens_litellm", recorded_usage.completion_tokens)

    ## Assert Accuracy of token counting
    # input tokens should be exactly the same
    assert input_tokens_anthropic_api == recorded_usage.prompt_tokens
    # output tokens can have at max abs diff of 10. We can't guarantee the response from two api calls will be exactly the same
    assert abs(output_tokens_anthropic_api - recorded_usage.completion_tokens) <= 10