Files
litellm/scripts/benchmark_model_response_creator.py
T
2026-05-28 11:31:37 -07:00

192 lines
5.5 KiB
Python

#!/usr/bin/env python3
"""Tight microbenchmark for CustomStreamWrapper.model_response_creator.
Calls model_response_creator() in a tight loop on a pre-built wrapper to
isolate per-call cost. Driving the full wrapper adds threadpool logging,
gc, and other noise that swamps microsecond-scale changes here.
Example:
uv run python scripts/benchmark_model_response_creator.py --label baseline
uv run python scripts/benchmark_model_response_creator.py --label optimized
"""
from __future__ import annotations
import argparse
import gc
import json
import logging
import os
import statistics
import time
from dataclasses import asdict, dataclass
from typing import List
from unittest.mock import MagicMock
os.environ.setdefault("LITELLM_LOG", "ERROR")
logging.getLogger("LiteLLM").setLevel(logging.ERROR)
import litellm # noqa: E402
litellm.suppress_debug_info = True
from litellm.litellm_core_utils.streaming_handler import (
CustomStreamWrapper,
) # noqa: E402
def _make_logging_obj(provider: str) -> MagicMock:
logging_obj = MagicMock()
logging_obj.model_call_details = {
"custom_llm_provider": provider,
"litellm_params": {},
}
logging_obj.call_type = "completion"
logging_obj.stream_options = None
logging_obj.messages = [{"role": "user", "content": "hi"}]
logging_obj.completion_start_time = None
logging_obj._llm_caching_handler = None
return logging_obj
def _make_wrapper(provider: str, model: str) -> CustomStreamWrapper:
return CustomStreamWrapper(
completion_stream=iter([]),
model=model,
logging_obj=_make_logging_obj(provider),
custom_llm_provider=provider,
)
@dataclass
class Result:
label: str
scenario: str
iterations: int
elapsed_min_s: float
elapsed_median_s: float
per_call_us: float
calls_per_sec: float
SCENARIOS = {
"no_chunk": {
"description": "model_response_creator() — no chunk arg (most common path)",
"chunk_factory": lambda i: None,
},
"text_chunk": {
"description": "model_response_creator(chunk={'text': '...'}) — text delta path",
"chunk_factory": lambda i: {"text": f"token{i}"},
},
"rich_chunk": {
"description": "model_response_creator(chunk={...}) — full chunk dict path",
"chunk_factory": lambda i: {
"id": f"id-{i}",
"object": "chat.completion.chunk",
"created": 1234567890,
},
},
}
def bench_no_chunk(wrapper: CustomStreamWrapper, iterations: int) -> float:
gc.collect()
gc.disable()
try:
start = time.perf_counter()
for _ in range(iterations):
wrapper.model_response_creator()
elapsed = time.perf_counter() - start
finally:
gc.enable()
return elapsed
def bench_with_chunk(wrapper: CustomStreamWrapper, factory, iterations: int) -> float:
# Pre-build chunks so we don't measure their construction cost.
chunks = [factory(i) for i in range(iterations)]
gc.collect()
gc.disable()
try:
start = time.perf_counter()
for chunk in chunks:
wrapper.model_response_creator(chunk=dict(chunk)) # copy because mutated
elapsed = time.perf_counter() - start
finally:
gc.enable()
return elapsed
def run_scenario(
label: str,
scenario_key: str,
iterations: int,
repeats: int,
warmup: int,
) -> Result:
spec = SCENARIOS[scenario_key]
wrapper = _make_wrapper(provider="anthropic", model="claude-3-5-sonnet")
if scenario_key == "no_chunk":
runner = lambda: bench_no_chunk(wrapper, iterations) # noqa: E731
else:
runner = lambda: bench_with_chunk(
wrapper, spec["chunk_factory"], iterations
) # noqa: E731
for _ in range(warmup):
runner()
samples = [runner() for _ in range(repeats)]
elapsed_min = min(samples)
elapsed_median = statistics.median(samples)
per_call_us = (elapsed_min * 1_000_000) / iterations
calls_per_sec = iterations / elapsed_min if elapsed_min > 0 else 0.0
return Result(
label=label,
scenario=scenario_key,
iterations=iterations,
elapsed_min_s=elapsed_min,
elapsed_median_s=elapsed_median,
per_call_us=per_call_us,
calls_per_sec=calls_per_sec,
)
def main() -> None:
ap = argparse.ArgumentParser(description=__doc__.splitlines()[0])
ap.add_argument("--label", required=True)
ap.add_argument("--iterations", type=int, default=200_000)
ap.add_argument("--warmup", type=int, default=2)
ap.add_argument("--repeats", type=int, default=8)
ap.add_argument("--json", dest="json_out")
args = ap.parse_args()
print(
f"\n=== label={args.label} iterations={args.iterations:,} "
f"warmup={args.warmup} repeats={args.repeats} (min reported) ==="
)
results: List[Result] = []
for scenario in SCENARIOS:
r = run_scenario(
args.label, scenario, args.iterations, args.repeats, args.warmup
)
results.append(r)
print(
f" {r.scenario:12s}: "
f"min={r.elapsed_min_s*1000:8.2f} ms "
f"median={r.elapsed_median_s*1000:8.2f} ms "
f"per-call={r.per_call_us:7.3f} μs "
f"calls/s={r.calls_per_sec:>12,.0f}"
)
if args.json_out:
with open(args.json_out, "w", encoding="utf-8") as f:
json.dump([asdict(r) for r in results], f, indent=2)
print(f"\nWrote {len(results)} results to {args.json_out}")
if __name__ == "__main__":
main()