perf: reduce per-request and per-chunk overhead across Anthropic streaming hot paths (#28289)

* perf: reduce per-request and per-chunk overhead across Anthropic streaming hot paths

- Introduce pure-text fast-path in `_build_complete_streaming_response` that collapses O(N) `content_block_delta` events into a single equivalent SSE event before conversion, eliminating per-output-token Pydantic `ModelResponseStream` construction; non-text streams (tool_use, thinking, citations) fall back to the unchanged legacy path
- Skip agentic streaming wrapper entirely when no callback overrides `async_should_run_agentic_loop`; the wrapper buffered every chunk and rebuilt the SSE response only to call hooks that all return `(False, {})` — a pure no-op for the default config
- Serialize request body once (`json.dumps`) for both the pre-call log input and the wire, instead of twice; avoids a full O(payload) scan per request, significant for long-context Claude Code histories
- Add fast path in `async_streaming_data_generator` that bypasses the per-chunk `async_post_call_streaming_hook` coroutine await, response-string materialization, and cost-injection call when no callback/guardrail/cost-injection is active (the default config)
- Resolve `_DD_STREAMING_TRACE_ENABLED` once at import time; eliminate per-chunk `NullSpan` context manager allocation when Datadog tracing is disabled (the default)
- Memoize `get_type_hints(AnthropicMessagesRequestOptionalParams)` with `@lru_cache(maxsize=1)` — resolves once per process instead of once per `/v1/messages` request (~80µs each)
- Hoist `cost_injection_active` out of the per-chunk loop in `chunk_processor`; eliminates repeated `getattr` + endpoint-type checks on every streamed byte chunk
- Extract `_build_passthrough_logging_result` from `_route_streaming_logging_to_handler` as a standalone static method to facilitate future off-loop dispatch
- Convert `async_sse_data_generator` from an `async for: yield` trampoline to a direct return of the underlying generator, removing one async-generator layer per streamed chunk
- Skip redundant `strip_empty_text_blocks_from_anthropic_messages` scan in `anthropic_messages_handler` when the async wrapper already sanitized (signalled via `_litellm_messages_presanitized` sentinel, popped before reaching provider params)
- Avoid formatting debug-log messages at non-debug log levels: gate the per-chunk log in the streaming generator behind `isEnabledFor(DEBUG)` and switch the transformation layer from an f-string to lazy `%s` arguments, so entire message payloads are no longer serialized on every request
- Add benchmark script (`scripts/benchmark_anthropic_messages_perf.py`) with a local mock Anthropic SSE provider for reproducible TTFT and TPM measurement across commits/branches
- Add parity tests asserting fast-path and legacy-path produce byte-identical logged/billed payloads, plus unit tests for agentic hook detection, pre-serialized body reuse, and memoized key resolution
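
The pure-text fast path from the first bullet can be illustrated with a much smaller sketch. This is not the code from this commit: `sse` and `collapse_text_deltas` are hypothetical names, the input is assumed to be a list of `data: {...}` strings with one event each, and the bare `event:` lines and `ping` handling of the real implementation are left out.

```python
import json
from typing import List, Optional


def sse(payload: dict) -> str:
    """Hypothetical helper: render one event as a 'data: ...' line."""
    return "data: " + json.dumps(payload)


def collapse_text_deltas(events: List[str]) -> Optional[List[str]]:
    """Merge each contiguous run of text_delta events into a single event.

    Returns None when the stream is not pure text, which tells the caller
    to fall back to the unchanged slow path.
    """
    out: List[str] = []
    pending: List[str] = []
    pending_index: Optional[int] = None

    def flush() -> None:
        nonlocal pending, pending_index
        if pending:
            out.append(
                sse(
                    {
                        "type": "content_block_delta",
                        "index": pending_index,
                        "delta": {"type": "text_delta", "text": "".join(pending)},
                    }
                )
            )
            pending, pending_index = [], None

    for ev in events:
        data = json.loads(ev[len("data: "):])
        if data.get("type") != "content_block_delta":
            flush()  # any non-delta event ends the current text run
            out.append(ev)
            continue
        delta = data.get("delta") or {}
        if delta.get("type") != "text_delta":
            return None  # tool_use / thinking / citations: do not collapse
        if pending and data.get("index") != pending_index:
            return None  # interleaved blocks: bail rather than merge text
        pending_index = data.get("index")
        pending.append(delta.get("text") or "")
    flush()
    return out
```

Because the merged event has the same shape as the events it replaces, whatever converter consumes the list downstream needs no changes; it simply sees far fewer deltas.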

* perf: address greptile review for anthropic streaming hot path

- Bail to legacy in `_collapse_pure_text_chunks` when content_block_delta
  events from different block indexes are observed without an intervening
  flush. Anthropic sends blocks strictly sequentially, but defensive bail
  prevents silent text-merging if the protocol ever interleaves.
- Replace leaf-class `__dict__` check for `async_post_call_streaming_hook`
  in `_callback_capabilities` with a function-identity comparison against
  the `CustomLogger` base implementation, which resolves through the MRO.
  A registered class may inherit the override from a vendor base class
  and add nothing of its own. Because the hook was invoked unconditionally
  before this PR, missing an inherited override would silently drop the
  hook on the streaming fast path.
- Add unit tests for both behaviors.
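
The difference between the leaf-class `__dict__` check and the function-identity comparison can be shown with a toy hierarchy. All class and function names here are hypothetical stand-ins, not LiteLLM classes; the `__func__` unwrapping mirrors the defensive pattern used in the diff.

```python
class BaseLogger:
    async def on_chunk(self, chunk: str) -> str:
        return chunk  # default: pass the chunk through unchanged


class VendorBase(BaseLogger):
    async def on_chunk(self, chunk: str) -> str:  # the override lives here
        return chunk.upper()


class Registered(VendorBase):
    pass  # the registered class adds nothing of its own


class Plain(BaseLogger):
    pass  # no override anywhere in the hierarchy


def overrides_by_leaf_dict(cls: type) -> bool:
    # Only inspects attributes defined directly on cls.
    return "on_chunk" in cls.__dict__


def overrides_by_identity(cls: type) -> bool:
    # getattr resolves through the MRO, so inherited overrides are found.
    base = BaseLogger.on_chunk
    found = getattr(cls, "on_chunk", base)
    return getattr(found, "__func__", found) is not getattr(base, "__func__", base)
```

Here `overrides_by_leaf_dict(Registered)` reports no override even though `Registered` inherits one from `VendorBase`, while `overrides_by_identity(Registered)` detects it and still returns `False` for `Plain`.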

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(mypy): narrow model_name to str in cost-injection branch

The hoisted `cost_injection_active` flag in `chunk_processor` encodes the
`bool(model_name)` requirement, but mypy cannot track that invariant
through the local, so the per-chunk
`_process_chunk_with_cost_injection(chunk, model_name)` calls were flagged
for passing `Optional[str]` where `str` is expected. Pin a typed non-None
local inside the cost-injection branch so mypy narrows correctly without
changing runtime behavior.
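
A stripped-down version of the narrowing problem, as a sketch with hypothetical function names (`inject_cost`, `process`): the boolean flag carries the "model name is set" invariant at runtime, but a type checker does not propagate it back to the `Optional[str]` variable, so the branch re-establishes it with an `assert` and a typed local.

```python
from typing import List, Optional


def inject_cost(chunk: bytes, model_name: str) -> bytes:
    """Stand-in for a per-chunk call that requires a real str."""
    return chunk + b"|" + model_name.encode()


def process(chunks: List[bytes], model_name: Optional[str], enabled: bool) -> List[bytes]:
    # Hoisted once per stream; implies model_name is truthy when True.
    active = enabled and bool(model_name)
    if not active:
        return list(chunks)
    # The checker cannot infer this from `active`, so state it explicitly.
    assert model_name is not None
    resolved: str = model_name
    return [inject_cost(chunk, resolved) for chunk in chunks]
```

The `assert` never fires when the flag is computed as above, so runtime behavior is unchanged; it only gives the checker a narrowing point.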

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Yassin Kortam <yassinkortam@g.ucla.edu>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Yassin Kortam
2026-05-23 12:15:59 -07:00
committed by GitHub
parent 3b2ce201d8
commit 2eab9ee2c0
15 changed files with 1978 additions and 91 deletions
@@ -293,6 +293,12 @@ async def anthropic_messages(
api_base=api_base,
client=client,
custom_llm_provider=custom_llm_provider,
# messages were already empty-text-block sanitized at the top of this
# function and are NOT reassigned before this dispatch, so the handler
# can skip its (otherwise redundant) second full-messages scan. Passed
# explicitly (not via **kwargs) so it only affects this direct
# dispatch -- interceptor / sync entry points still sanitize.
_litellm_messages_presanitized=True,
**kwargs,
)
ctx = contextvars.copy_context()
@@ -351,10 +357,14 @@ def anthropic_messages_handler(
"""
from litellm.types.utils import LlmProviders
# Sanitize empty text blocks here too so the sync entry point
# Sanitize empty text blocks so the sync entry point
# (litellm.messages.create -> anthropic_messages_handler) gets the same
# protection as the async wrapper. Idempotent when called twice.
messages = strip_empty_text_blocks_from_anthropic_messages(messages)
# protection as the async wrapper. The async wrapper already sanitized and
# does not reassign messages before dispatch, so it sets
# ``_litellm_messages_presanitized`` to skip this redundant second
# full-messages scan. Pop it so it never leaks into provider params.
if not kwargs.pop("_litellm_messages_presanitized", False):
messages = strip_empty_text_blocks_from_anthropic_messages(messages)
metadata = validate_anthropic_api_metadata(metadata)
@@ -312,7 +312,10 @@ class AnthropicMessagesConfig(BaseAnthropicMessagesConfig):
)
####### get required params for all anthropic messages requests ######
verbose_logger.debug(f"TRANSFORMATION DEBUG - Messages: {messages}")
# Lazy %s: the f-string previously stringified the entire messages
# payload on every request regardless of log level (a full scan of the
# request body on the hot path). Defer it to when DEBUG is enabled.
verbose_logger.debug("TRANSFORMATION DEBUG - Messages: %s", messages)
# Auto-strip advisor blocks from history if advisor tool is absent.
# Prevents Anthropic 400: advisor_tool_result in history requires advisor tool.
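
The effect of the lazy `%s` change in the hunk above can be observed with the standard `logging` module alone. In this sketch the `Expensive` class and the logger name are made up for illustration; the class counts how often it is stringified.

```python
import logging


class Expensive:
    """Counts how many times it is converted to a string."""

    def __init__(self) -> None:
        self.str_calls = 0

    def __str__(self) -> str:
        self.str_calls += 1
        return "payload"


logger = logging.getLogger("lazy_logging_demo")
logger.setLevel(logging.INFO)  # DEBUG records are discarded

payload = Expensive()

# Eager: the f-string is built before logger.debug is even called.
logger.debug(f"messages: {payload}")
eager_calls = payload.str_calls

# Lazy: the argument is only formatted if the record is actually emitted.
logger.debug("messages: %s", payload)
lazy_calls = payload.str_calls - eager_calls
```

With the logger at `INFO`, the f-string variant stringifies the payload anyway, while the `%s` variant leaves it untouched.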
@@ -1,4 +1,5 @@
from typing import Any, Dict, List, cast, get_type_hints
from functools import lru_cache
from typing import Any, Dict, FrozenSet, List, cast, get_type_hints
from litellm.types.llms.anthropic import AnthropicMessagesRequestOptionalParams
from litellm.types.llms.anthropic_messages.anthropic_response import (
@@ -6,6 +7,18 @@ from litellm.types.llms.anthropic_messages.anthropic_response import (
)
@lru_cache(maxsize=1)
def _anthropic_messages_optional_param_keys() -> FrozenSet[str]:
"""
Valid AnthropicMessagesRequestOptionalParams keys.
``typing.get_type_hints`` is ~80us/call and this TypedDict is static, so
resolving it once per process instead of once per request removes a fixed
full-pass cost from the /v1/messages request-parse path.
"""
return frozenset(get_type_hints(AnthropicMessagesRequestOptionalParams).keys())
class AnthropicMessagesRequestUtils:
@staticmethod
def get_requested_anthropic_messages_optional_param(
@@ -20,7 +33,7 @@ class AnthropicMessagesRequestUtils:
Returns:
AnthropicMessagesRequestOptionalParams instance with only the valid parameters
"""
valid_keys = get_type_hints(AnthropicMessagesRequestOptionalParams).keys()
valid_keys = _anthropic_messages_optional_param_keys()
filtered_params = {
k: v for k, v in params.items() if k in valid_keys and v is not None
}
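
The memoization in this hunk follows a common pattern for static type metadata. A minimal standalone sketch, assuming a hypothetical `OptionalParams` TypedDict in place of `AnthropicMessagesRequestOptionalParams`:

```python
from functools import lru_cache
from typing import FrozenSet, Optional, TypedDict, get_type_hints


class OptionalParams(TypedDict, total=False):
    """Hypothetical stand-in for the real optional-params TypedDict."""

    temperature: Optional[float]
    top_p: Optional[float]
    max_tokens: int


@lru_cache(maxsize=1)
def optional_param_keys() -> FrozenSet[str]:
    # The TypedDict never changes at runtime, so resolve its hints once.
    return frozenset(get_type_hints(OptionalParams))


def filter_params(params: dict) -> dict:
    valid = optional_param_keys()
    return {k: v for k, v in params.items() if k in valid and v is not None}
```

Returning a `frozenset` keeps the cached value immutable, so no caller can accidentally alter the shared key set between requests.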
+70 -3
@@ -1886,7 +1886,9 @@ class BaseLLMHTTPHandler:
async_httpx_client: AsyncHTTPHandler,
request_url: str,
headers: dict,
signed_json_body: Optional[bytes],
# str when the caller passes a pre-serialized (unsigned) body to avoid
# re-dumping; bytes when a provider signed the request (e.g. Bedrock).
signed_json_body: Optional[Union[str, bytes]],
request_body: dict,
stream: bool,
logging_obj: LiteLLMLoggingObj,
@@ -2077,8 +2079,18 @@ class BaseLLMHTTPHandler:
model=model,
)
# The request body was serialized once for the pre-call log input and
# again for the wire (json.dumps is O(payload), large for long-context
# Claude Code history). Serialize once and reuse for both. Only when
# the provider didn't sign the request (sign_request no-op for the
# native anthropic path -> signed_json_body is None); signed providers
# (e.g. Bedrock) keep their signed body untouched. The HTTP-error
# retry path mutates + re-signs the body, so it still re-serializes
# internally -- this only deduplicates the success path.
request_body_json = json.dumps(request_body)
logging_obj.pre_call(
input=[{"role": "user", "content": json.dumps(request_body)}],
input=[{"role": "user", "content": request_body_json}],
api_key="",
additional_args={
"complete_input_dict": request_body,
@@ -2091,7 +2103,9 @@ class BaseLLMHTTPHandler:
async_httpx_client=async_httpx_client,
request_url=request_url,
headers=headers,
signed_json_body=signed_json_body,
signed_json_body=(
signed_json_body if signed_json_body is not None else request_body_json
),
request_body=request_body,
stream=stream or False,
logging_obj=logging_obj,
@@ -2113,6 +2127,14 @@ class BaseLLMHTTPHandler:
litellm_logging_obj=logging_obj,
)
if not self._has_agentic_completion_hook(logging_obj):
# No callback overrides async_should_run_agentic_loop, so the
# agentic wrapper's only effect would be buffering every chunk
# and rebuilding the response from SSE at end-of-stream to call
# hooks that all return (False, {}). Stream through directly and
# skip that per-chunk + end-of-stream overhead.
return completion_stream
from litellm.llms.anthropic.experimental_pass_through.messages.agentic_streaming_iterator import (
AgenticAnthropicStreamingIterator,
)
@@ -4620,6 +4642,51 @@ class BaseLLMHTTPHandler:
fingerprints = list(kwargs.get("_agentic_loop_fingerprints", []) or [])
return depth, max(max_loops, 1), fingerprints
@staticmethod
def _has_agentic_completion_hook(logging_obj: Any) -> bool:
"""
True if any registered callback actually overrides
``async_should_run_agentic_loop`` (the gate every agentic hook goes
through). The base ``CustomLogger`` implementation returns
``(False, {})``, so when nothing overrides it the agentic
post-processing is a guaranteed no-op and the streaming wrapper that
buffers + rebuilds the whole response from SSE just to call it can be
skipped entirely.
Function-identity comparison (not a leaf ``__dict__`` check) so an
override inherited through any intermediate class is still detected --
a false negative here would silently disable agentic features.
String entries in ``litellm.callbacks`` (e.g. ``"datadog"``) are
resolved to their ``CustomLogger`` instance via
``get_custom_logger_compatible_class`` -- same pattern as
``ProxyLogging._callback_capabilities`` -- so a string-registered
agentic callback is detected too.
"""
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.litellm_logging import (
get_custom_logger_compatible_class,
)
base_func = CustomLogger.async_should_run_agentic_loop
callbacks = litellm.callbacks + (
getattr(logging_obj, "dynamic_success_callbacks", None) or []
)
for cb in callbacks:
if isinstance(cb, str):
resolved = get_custom_logger_compatible_class(cb) # type: ignore[arg-type]
if resolved is None:
continue
cb = resolved
if not isinstance(cb, CustomLogger):
continue
cb_func = getattr(type(cb), "async_should_run_agentic_loop", base_func)
if getattr(cb_func, "__func__", cb_func) is not getattr(
base_func, "__func__", base_func
):
return True
return False
@staticmethod
def _check_agentic_loop_safety(
tool_calls: Any,
+52 -10
@@ -32,7 +32,7 @@ from litellm.constants import (
STREAM_SSE_DATA_PREFIX,
)
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.litellm_core_utils.dd_tracing import tracer
from litellm.litellm_core_utils.dd_tracing import NullTracer, tracer
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.llm_response_utils.get_headers import (
get_response_headers,
@@ -65,6 +65,13 @@ else:
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.types.utils import ModelResponse, ModelResponseStream, Usage
# Datadog streaming spans are a no-op when ddtrace is not enabled, but the
# ``with tracer.trace(...)`` context manager still allocates a NullSpan and
# runs __enter__/__exit__ for every streamed chunk. Resolve once at import so
# the per-chunk hot path can skip the context manager entirely when tracing
# is off (the default).
_DD_STREAMING_TRACE_ENABLED = not isinstance(tracer, NullTracer)
def _serialize_http_exception_detail(
detail: Any,
@@ -231,7 +238,7 @@ def _extract_error_from_sse_chunk(event_line: Union[str, bytes]) -> dict:
return default_error
async def create_response(
async def create_response( # noqa: PLR0915
generator: AsyncGenerator[str, None],
media_type: str,
headers: dict,
@@ -336,6 +343,13 @@ async def create_response(
)
async def combined_generator() -> AsyncGenerator[str, None]:
if not _DD_STREAMING_TRACE_ENABLED:
# Fast path: no per-chunk span object / context-manager overhead.
if first_chunk_value is not None:
yield first_chunk_value
async for chunk in generator:
yield chunk
return
if first_chunk_value is not None:
with tracer.trace(DD_TRACER_STREAMING_CHUNK_YIELD_RESOURCE):
yield first_chunk_value
@@ -1900,6 +1914,23 @@ class ProxyBaseLLMRequestProcessing:
failure hook and yields via serialize_error. Use for SSE or NDJSON.
"""
verbose_proxy_logger.debug("inside generator")
# Resolve per-stream (not per-chunk) whether the heavy per-chunk path
# is needed. When no callback overrides ``async_post_call_streaming_hook``,
# no CustomGuardrail is active, and cost injection is disabled, the
# per-chunk hook returns the chunk unchanged, ``str_so_far`` is never
# consumed, and cost injection is a no-op -- so the per-chunk coroutine
# await, response-string materialization, and cost-injection call are
# pure overhead on the streaming hot path (the default config).
caps = ProxyLogging._callback_capabilities()
cost_injection_enabled = bool(
getattr(litellm, "include_cost_in_streaming_usage", False)
)
fast_path = (
not caps.has_streaming_chunk_override
and not caps.has_guardrail
and not cost_injection_enabled
)
debug_enabled = verbose_proxy_logger.isEnabledFor(logging.DEBUG)
try:
str_so_far = ""
async for (
@@ -1909,9 +1940,17 @@ class ProxyBaseLLMRequestProcessing:
response=response,
request_data=request_data,
):
verbose_proxy_logger.debug(
"async_data_generator: received streaming chunk - {}".format(chunk)
)
# ``.format(chunk)`` was previously evaluated for every chunk
# regardless of log level; gate it behind the level check.
if debug_enabled:
verbose_proxy_logger.debug(
"async_data_generator: received streaming chunk - %s", chunk
)
if fast_path:
yield serialize_chunk(chunk)
continue
chunk = await proxy_logging_obj.async_post_call_streaming_hook(
user_api_key_dict=user_api_key_dict,
response=chunk,
@@ -1969,7 +2008,7 @@ class ProxyBaseLLMRequestProcessing:
yield serialize_error(proxy_exception)
@staticmethod
async def async_sse_data_generator(
def async_sse_data_generator(
response: Any,
user_api_key_dict: UserAPIKeyAuth,
request_data: dict,
@@ -1977,17 +2016,20 @@ class ProxyBaseLLMRequestProcessing:
) -> AsyncGenerator[str, None]:
"""
Anthropic /messages and Google /generateContent streaming data generator require SSE events.
Delegates to async_streaming_data_generator with SSE serializers.
Returns the underlying ``async_streaming_data_generator`` configured with
SSE serializers directly (rather than re-wrapping it in another
``async for: yield`` trampoline), so a streamed chunk traverses one
fewer async-generator layer / coroutine resume on the hot path.
"""
async for chunk in ProxyBaseLLMRequestProcessing.async_streaming_data_generator(
return ProxyBaseLLMRequestProcessing.async_streaming_data_generator(
response=response,
user_api_key_dict=user_api_key_dict,
request_data=request_data,
proxy_logging_obj=proxy_logging_obj,
serialize_chunk=ProxyBaseLLMRequestProcessing.return_sse_chunk,
serialize_error=lambda proxy_exc: f"{STREAM_SSE_DATA_PREFIX}{json.dumps({'error': proxy_exc.to_dict()})}\n\n",
):
yield chunk
)
@staticmethod
def _process_chunk_with_cost_injection(chunk: Any, model_name: str) -> Any:
@@ -266,7 +266,174 @@ class AnthropicPassthroughLoggingHandler:
model: str,
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
"""
Builds complete response from raw Anthropic chunks
Builds complete response from raw Anthropic chunks.
Fast path: for the dominant case of a pure-text streaming response
(no tool_use / thinking / non-text content blocks), the long run of
``content_block_delta`` text deltas is collapsed into a single
equivalent SSE event before conversion. ``chunk_parser`` and
``stream_chunk_builder`` remain the single source of truth for chunk
shape, usage math and finish-reason mapping, so the rebuilt response
(and therefore the logged/billed payload) is identical -- this is
asserted by a parity test. Anything non-trivial falls back to the
unchanged legacy reconstruction.
Per-event Pydantic ``ModelResponseStream`` construction dominated
event-loop CPU under concurrent streaming; collapsing the homogeneous
text run removes O(num_output_tokens) of it.
"""
collapsed = AnthropicPassthroughLoggingHandler._collapse_pure_text_chunks(
all_chunks
)
if collapsed is not None:
return AnthropicPassthroughLoggingHandler._build_complete_streaming_response_legacy(
all_chunks=collapsed,
litellm_logging_obj=litellm_logging_obj,
model=model,
)
return AnthropicPassthroughLoggingHandler._build_complete_streaming_response_legacy(
all_chunks=all_chunks,
litellm_logging_obj=litellm_logging_obj,
model=model,
)
# Anthropic SSE block/delta types that the fast path is NOT allowed to
# collapse -- their presence forces the unchanged legacy path so tool
# calls, thinking, citations, etc. keep byte-identical reconstruction.
_FAST_PATH_DISALLOWED_DELTA_TYPES = frozenset(
{
"input_json_delta",
"thinking_delta",
"signature_delta",
"citations_delta",
}
)
@staticmethod
def _collapse_pure_text_chunks( # noqa: PLR0915
all_chunks: Sequence[Union[str, bytes]],
) -> Optional[List[str]]:
"""
Return a new chunk list with the contiguous run of text-only
``content_block_delta`` events replaced by a single equivalent event,
or ``None`` if the stream is not a pure single-text-block response
(in which case the caller uses the legacy path unchanged).
Only ``message_start`` / ``content_block_start(text)`` /
``content_block_delta(text_delta)`` / ``content_block_stop`` /
``message_delta`` / ``message_stop`` / ``ping`` events are accepted.
Any other content-block type or delta type returns ``None``.
"""
normalized: List[str] = []
for raw in all_chunks:
line = raw.decode("utf-8") if isinstance(raw, bytes) else raw
for ev in line.split("\n\n"):
ev = ev.strip()
if ev:
normalized.append(ev)
text_block_indexes: set = set()
out: List[str] = []
pending_text: List[str] = []
pending_index: Optional[int] = None
saw_any_text_delta = False
def flush() -> None:
nonlocal pending_text, pending_index
if pending_text:
merged = {
"type": "content_block_delta",
"index": pending_index if pending_index is not None else 0,
"delta": {"type": "text_delta", "text": "".join(pending_text)},
}
out.append("data: " + json.dumps(merged))
pending_text = []
pending_index = None
for ev in normalized:
idx = ev.find("data:")
if idx == -1:
# Bare "event: <name>" line. The legacy converter turns this
# into an empty ModelResponseStream that contributes nothing
# to stream_chunk_builder. Drop the high-frequency interior
# markers (content_block_delta / ping); keep every other
# bare event line verbatim so chunk ordering and the
# load-bearing chunks[0] (event: message_start) are retained.
name = ev[len("event:") :].strip() if ev.startswith("event:") else ""
if name in ("content_block_delta", "ping"):
continue
flush()
out.append(ev)
continue
json_str = ev[idx + len("data:") :].strip()
try:
data = json.loads(json_str)
except (json.JSONDecodeError, ValueError):
return None
etype = data.get("type")
if etype == "content_block_start":
block = data.get("content_block") or {}
if block.get("type") != "text":
return None
text_block_indexes.add(data.get("index"))
flush()
out.append(ev)
elif etype == "content_block_delta":
delta = data.get("delta") or {}
dtype = delta.get("type")
if (
dtype
in AnthropicPassthroughLoggingHandler._FAST_PATH_DISALLOWED_DELTA_TYPES
):
return None
if dtype != "text_delta":
return None
cur_index = data.get("index")
if cur_index not in text_block_indexes:
return None
# Defensive: Anthropic sends blocks strictly sequentially
# (start/deltas/stop, then next block), so pending_text from
# block N must be flushed by content_block_stop before block
# N+1's deltas arrive. If we ever see a delta whose index
# disagrees with the current pending buffer, the stream is
# interleaved -- fall back to legacy rather than risk merging
# text from different blocks under a single index.
if (
pending_text
and pending_index is not None
and cur_index != pending_index
):
return None
saw_any_text_delta = True
pending_index = cur_index
pending_text.append(delta.get("text") or "")
elif etype == "ping":
# Interior no-op; legacy maps it to an empty chunk.
continue
else:
# message_start / content_block_stop / message_delta /
# message_stop / error: pass through unchanged.
flush()
out.append(ev)
flush()
if not saw_any_text_delta:
return None
return out
@staticmethod
def _build_complete_streaming_response_legacy(
all_chunks: Sequence[Union[str, bytes]],
litellm_logging_obj: LiteLLMLoggingObj,
model: str,
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
"""
Original reconstruction: convert every SSE event to a generic chunk
and assemble via stream_chunk_builder. Kept verbatim as the fallback
/ source of truth for the fast path's parity test.
- Splits multi-event chunks into individual SSE events
- Converts str chunks to generic chunks
@@ -1,6 +1,6 @@
import asyncio
from datetime import datetime
from typing import List, Optional
from typing import List, Optional, Tuple
import httpx
@@ -45,28 +45,44 @@ class PassThroughStreamingHandler:
litellm_logging_obj=litellm_logging_obj,
)
# Resolve once per stream rather than re-reading the global +
# re-branching on every chunk. ``include_cost_in_streaming_usage`` is
# set at config load and stable for the process, matching how the
# proxy-level streaming fast path resolves it.
cost_injection_active = (
bool(getattr(litellm, "include_cost_in_streaming_usage", False))
and bool(model_name)
and endpoint_type in (EndpointType.VERTEX_AI, EndpointType.ANTHROPIC)
)
try:
async for chunk in response.aiter_bytes():
raw_bytes.append(chunk)
if (
getattr(litellm, "include_cost_in_streaming_usage", False)
and model_name
):
if not cost_injection_active:
# Hot path: just buffer for end-of-stream logging and forward.
async for chunk in response.aiter_bytes():
raw_bytes.append(chunk)
yield chunk
else:
# ``cost_injection_active`` already requires ``model_name`` to
# be truthy; pin to a typed local so mypy narrows ``Optional[str]``
# -> ``str`` for the per-chunk call site.
assert model_name is not None
resolved_model_name: str = model_name
async for chunk in response.aiter_bytes():
raw_bytes.append(chunk)
if endpoint_type == EndpointType.VERTEX_AI:
if "streamRawPredict" in url_route or "rawPredict" in url_route:
modified_chunk = ProxyBaseLLMRequestProcessing._process_chunk_with_cost_injection(
chunk, model_name
chunk, resolved_model_name
)
if modified_chunk is not None:
chunk = modified_chunk
elif endpoint_type == EndpointType.ANTHROPIC:
else: # EndpointType.ANTHROPIC
modified_chunk = ProxyBaseLLMRequestProcessing._process_chunk_with_cost_injection(
chunk, model_name
chunk, resolved_model_name
)
if modified_chunk is not None:
chunk = modified_chunk
yield chunk
yield chunk
except Exception as e:
verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}")
raise
@@ -115,64 +131,20 @@ class PassThroughStreamingHandler:
- OpenAI
"""
try:
all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(
raw_bytes
(
standard_logging_response_object,
kwargs,
) = PassThroughStreamingHandler._build_passthrough_logging_result(
litellm_logging_obj=litellm_logging_obj,
passthrough_success_handler_obj=passthrough_success_handler_obj,
url_route=url_route,
request_body=request_body,
endpoint_type=endpoint_type,
start_time=start_time,
raw_bytes=raw_bytes,
end_time=end_time,
model=model,
)
standard_logging_response_object: Optional[
PassThroughEndpointLoggingResultValues
] = None
kwargs: dict = {}
if endpoint_type == EndpointType.ANTHROPIC:
anthropic_passthrough_logging_handler_result = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
litellm_logging_obj=litellm_logging_obj,
passthrough_success_handler_obj=passthrough_success_handler_obj,
url_route=url_route,
request_body=request_body,
endpoint_type=endpoint_type,
start_time=start_time,
all_chunks=all_chunks,
end_time=end_time,
)
standard_logging_response_object = (
anthropic_passthrough_logging_handler_result["result"]
)
kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
elif endpoint_type == EndpointType.VERTEX_AI:
vertex_passthrough_logging_handler_result = VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks(
litellm_logging_obj=litellm_logging_obj,
passthrough_success_handler_obj=passthrough_success_handler_obj,
url_route=url_route,
request_body=request_body,
endpoint_type=endpoint_type,
start_time=start_time,
all_chunks=all_chunks,
end_time=end_time,
model=model,
)
standard_logging_response_object = (
vertex_passthrough_logging_handler_result["result"]
)
kwargs = vertex_passthrough_logging_handler_result["kwargs"]
elif endpoint_type == EndpointType.OPENAI:
openai_passthrough_logging_handler_result = OpenAIPassthroughLoggingHandler._handle_logging_openai_collected_chunks(
litellm_logging_obj=litellm_logging_obj,
passthrough_success_handler_obj=passthrough_success_handler_obj,
url_route=url_route,
request_body=request_body,
endpoint_type=endpoint_type,
start_time=start_time,
all_chunks=all_chunks,
end_time=end_time,
)
standard_logging_response_object = (
openai_passthrough_logging_handler_result["result"]
)
kwargs = openai_passthrough_logging_handler_result["kwargs"]
if standard_logging_response_object is None:
standard_logging_response_object = StandardPassThroughResponseObject(
response=f"cannot parse chunks to standard response object. Chunks={all_chunks}"
)
await litellm_logging_obj.async_success_handler(
result=standard_logging_response_object,
start_time=start_time,
@@ -199,6 +171,89 @@ class PassThroughStreamingHandler:
f"Error in _route_streaming_logging_to_handler: {str(e)}"
)
@staticmethod
def _build_passthrough_logging_result(
litellm_logging_obj: LiteLLMLoggingObj,
passthrough_success_handler_obj: PassThroughEndpointLogging,
url_route: str,
request_body: dict,
endpoint_type: EndpointType,
start_time: datetime,
raw_bytes: List[bytes],
end_time: datetime,
model: Optional[str],
) -> Tuple[PassThroughEndpointLoggingResultValues, dict]:
"""
Synchronous, CPU-bound reconstruction of the standard logging payload
from collected raw SSE bytes. Extracted from
_route_streaming_logging_to_handler so the per-endpoint dispatch can
be unit-tested in isolation. Still invoked synchronously on the event
loop; an off-loop dispatch is a future change, not part of this PR.
"""
all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(
raw_bytes
)
standard_logging_response_object: Optional[
PassThroughEndpointLoggingResultValues
] = None
kwargs: dict = {}
if endpoint_type == EndpointType.ANTHROPIC:
anthropic_passthrough_logging_handler_result = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
litellm_logging_obj=litellm_logging_obj,
passthrough_success_handler_obj=passthrough_success_handler_obj,
url_route=url_route,
request_body=request_body,
endpoint_type=endpoint_type,
start_time=start_time,
all_chunks=all_chunks,
end_time=end_time,
)
standard_logging_response_object = (
anthropic_passthrough_logging_handler_result["result"]
)
kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
elif endpoint_type == EndpointType.VERTEX_AI:
vertex_passthrough_logging_handler_result = (
VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks(
litellm_logging_obj=litellm_logging_obj,
passthrough_success_handler_obj=passthrough_success_handler_obj,
url_route=url_route,
request_body=request_body,
endpoint_type=endpoint_type,
start_time=start_time,
all_chunks=all_chunks,
end_time=end_time,
model=model,
)
)
standard_logging_response_object = (
vertex_passthrough_logging_handler_result["result"]
)
kwargs = vertex_passthrough_logging_handler_result["kwargs"]
elif endpoint_type == EndpointType.OPENAI:
openai_passthrough_logging_handler_result = (
OpenAIPassthroughLoggingHandler._handle_logging_openai_collected_chunks(
litellm_logging_obj=litellm_logging_obj,
passthrough_success_handler_obj=passthrough_success_handler_obj,
url_route=url_route,
request_body=request_body,
endpoint_type=endpoint_type,
start_time=start_time,
all_chunks=all_chunks,
end_time=end_time,
)
)
standard_logging_response_object = (
openai_passthrough_logging_handler_result["result"]
)
kwargs = openai_passthrough_logging_handler_result["kwargs"]
if standard_logging_response_object is None:
standard_logging_response_object = StandardPassThroughResponseObject(
response=f"cannot parse chunks to standard response object. Chunks={all_chunks}"
)
return standard_logging_response_object, kwargs
@staticmethod
def _extract_model_for_cost_injection(
request_body: Optional[dict],
+16 -1
@@ -1596,7 +1596,22 @@ class ProxyLogging:
iterator_overrides.append((resolved, "override"))
elif "apply_guardrail" in cls_attrs:
iterator_overrides.append((resolved, "apply_guardrail"))
if "async_post_call_streaming_hook" in cls_attrs:
# Walk the MRO for ``async_post_call_streaming_hook`` rather than
# using the leaf-class ``__dict__`` check used by the other flags:
# before this PR the hook was unconditionally invoked, so a
# callback that inherits an override from an intermediate parent
# (e.g. a vendor base class providing the override, with the
# registered class adding nothing else) MUST still be detected.
# A leaf-class miss here would silently drop the inherited hook.
base_streaming_hook = CustomLogger.async_post_call_streaming_hook
cls_streaming_hook = getattr(
cls,
"async_post_call_streaming_hook",
base_streaming_hook,
)
if getattr(
cls_streaming_hook, "__func__", cls_streaming_hook
) is not getattr(base_streaming_hook, "__func__", base_streaming_hook):
has_streaming_chunk_override = True
if "async_pre_call_hook" in cls_attrs:
has_pre_call_override = True
@@ -0,0 +1,624 @@
#!/usr/bin/env python3
"""Benchmark LiteLLM proxy /v1/messages (Anthropic Messages API) streaming.
Measures the two metrics that matter for an interactive streaming proxy:
* TTFT - time to first streamed token (first ``content_block_delta``)
* TPM - sustained output token throughput (tokens / second) once the
full stream is consumed, plus request throughput (RPS)
It boots a local mock Anthropic provider that speaks the real Anthropic
streaming SSE wire format (``message_start`` -> ``content_block_delta`` ->
``message_stop``) and a LiteLLM proxy from any checkout, so commits/branches
can be compared without depending on real provider latency.
Example:
uv run python scripts/benchmark_anthropic_messages_perf.py \
--label baseline --proxy-command ".venv/bin/litellm"
Compare an already-running proxy:
uv run python scripts/benchmark_anthropic_messages_perf.py \
--no-start-proxy --label current
"""
from __future__ import annotations
import argparse
import asyncio
import json
import os
import shlex
import signal
import statistics
import subprocess
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
import aiohttp
from aiohttp import web
DEFAULT_MODEL = "claude-perf-test"
DEFAULT_API_KEY = "sk-1234"
@dataclass
class StreamSample:
success: bool
ttft_ms: float
total_ms: float
output_tokens: int
status_code: int
error: str = ""
@dataclass
class SummaryStats:
requests: int
failures: int
rps: float
ttft_mean_ms: float
ttft_p50_ms: float
ttft_p95_ms: float
ttft_p99_ms: float
total_p50_ms: float
total_p95_ms: float
tokens_per_sec: float
class MockAnthropicProvider:
"""Minimal Anthropic Messages API server (real streaming SSE format)."""
def __init__(
self,
host: str,
port: int,
first_token_delay_ms: float,
stream_content_chunks: int,
) -> None:
self.host = host
self.port = port
self.first_token_delay_ms = first_token_delay_ms
self.stream_content_chunks = stream_content_chunks
self.runner: Optional[web.AppRunner] = None
@property
def base_url(self) -> str:
return f"http://{self.host}:{self.port}"
async def start(self) -> None:
app = web.Application()
app.router.add_post("/v1/messages", self.handle_messages)
self.runner = web.AppRunner(app, access_log=None)
await self.runner.setup()
site = web.TCPSite(self.runner, self.host, self.port)
await site.start()
async def stop(self) -> None:
if self.runner is not None:
await self.runner.cleanup()
async def handle_messages(self, request: web.Request) -> web.StreamResponse:
body = await request.json()
if body.get("stream"):
return await self._streaming_response(request, body)
return self._json_response(body)
def _json_response(self, body: dict[str, Any]) -> web.Response:
payload = {
"id": "msg_perf",
"type": "message",
"role": "assistant",
"model": body.get("model", DEFAULT_MODEL),
"content": [{"type": "text", "text": "hello"}],
"stop_reason": "end_turn",
"stop_sequence": None,
"usage": {"input_tokens": 8, "output_tokens": 1},
}
return web.json_response(payload)
@staticmethod
def _sse(event: str, data: dict[str, Any]) -> bytes:
return f"event: {event}\ndata: {json.dumps(data)}\n\n".encode()
async def _streaming_response(
self, request: web.Request, body: dict[str, Any]
) -> web.StreamResponse:
model = body.get("model", DEFAULT_MODEL)
response = web.StreamResponse(
status=200,
headers={
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
},
)
await response.prepare(request)
await response.write(
self._sse(
"message_start",
{
"type": "message_start",
"message": {
"id": "msg_perf",
"type": "message",
"role": "assistant",
"model": model,
"content": [],
"stop_reason": None,
"stop_sequence": None,
"usage": {"input_tokens": 8, "output_tokens": 0},
},
},
)
)
await response.write(
self._sse(
"content_block_start",
{
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
},
)
)
if self.first_token_delay_ms > 0:
await asyncio.sleep(self.first_token_delay_ms / 1000)
for _ in range(self.stream_content_chunks):
await response.write(
self._sse(
"content_block_delta",
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": "hello "},
},
)
)
await response.write(
self._sse("content_block_stop", {"type": "content_block_stop", "index": 0})
)
await response.write(
self._sse(
"message_delta",
{
"type": "message_delta",
"delta": {"stop_reason": "end_turn", "stop_sequence": None},
"usage": {"output_tokens": self.stream_content_chunks},
},
)
)
await response.write(self._sse("message_stop", {"type": "message_stop"}))
await response.write_eof()
return response
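def percentile(values: list[float], pct: float) -> float:
    """Return the ``pct``-th percentile of ``values`` (0.0 when empty).

    Picks the element at index ``floor(n * pct / 100)`` of the sorted values,
    clamped to the last element; no interpolation is performed.
    """
    if not values:
        return 0.0
    sorted_values = sorted(values)
    index = min(int(len(sorted_values) * pct / 100), len(sorted_values) - 1)
    return sorted_values[index]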
def summarize(samples: list[StreamSample], wall_time_s: float) -> SummaryStats:
ok = [s for s in samples if s.success]
ttfts = [s.ttft_ms for s in ok]
totals = [s.total_ms for s in ok]
total_tokens = sum(s.output_tokens for s in ok)
return SummaryStats(
requests=len(samples),
failures=len(samples) - len(ok),
rps=(len(ok) / wall_time_s) if wall_time_s > 0 else 0.0,
ttft_mean_ms=statistics.mean(ttfts) if ttfts else 0.0,
ttft_p50_ms=percentile(ttfts, 50),
ttft_p95_ms=percentile(ttfts, 95),
ttft_p99_ms=percentile(ttfts, 99),
total_p50_ms=percentile(totals, 50),
total_p95_ms=percentile(totals, 95),
# Aggregate output-token throughput: total tokens delivered across all
# successful requests divided by wall-clock time. This is the true
# server TPM and (unlike tokens / summed-per-request-latency) scales
# correctly with concurrency.
tokens_per_sec=(total_tokens / wall_time_s) if wall_time_s > 0 else 0.0,
)
def get_git_revision(litellm_dir: Path) -> str:
try:
result = subprocess.run(
["git", "rev-parse", "--short", "HEAD"],
cwd=litellm_dir,
check=True,
capture_output=True,
text=True,
)
return result.stdout.strip()
except Exception:
return "unknown"
def write_proxy_config(config_path: Path, provider_base_url: str, api_key: str) -> None:
config_path.write_text(
f"""model_list:
- model_name: {DEFAULT_MODEL}
litellm_params:
model: anthropic/{DEFAULT_MODEL}
api_key: fake-provider-key
api_base: {provider_base_url}
general_settings:
master_key: {api_key}
litellm_settings:
telemetry: false
""",
encoding="utf-8",
)
async def wait_for_proxy(base_url: str, timeout_s: float) -> None:
deadline = time.perf_counter() + timeout_s
last_error = ""
async with aiohttp.ClientSession() as session:
while time.perf_counter() < deadline:
try:
async with session.get(f"{base_url}/health/liveliness") as response:
if response.status < 500:
return
last_error = f"HTTP {response.status}"
except Exception as exc:
last_error = str(exc)
await asyncio.sleep(0.5)
raise TimeoutError(f"Timed out waiting for proxy at {base_url}: {last_error}")
def start_proxy_process(
litellm_dir: Path,
proxy_command: str,
config_path: Path,
port: int,
log_path: Path,
) -> subprocess.Popen:
command = shlex.split(proxy_command) + [
"--config",
str(config_path),
"--port",
str(port),
]
env = {
**os.environ,
"LITELLM_TELEMETRY": "False",
"PYTHONUNBUFFERED": "1",
}
log_file = log_path.open("w", encoding="utf-8")
return subprocess.Popen(
command,
cwd=litellm_dir,
env=env,
stdout=log_file,
stderr=subprocess.STDOUT,
start_new_session=True,
)
def stop_proxy_process(process: subprocess.Popen) -> None:
if process.poll() is not None:
return
try:
os.killpg(process.pid, signal.SIGTERM)
process.wait(timeout=10)
except Exception:
try:
os.killpg(process.pid, signal.SIGKILL)
except Exception:
pass
async def measure_stream(
session: aiohttp.ClientSession,
url: str,
headers: dict[str, str],
payload: dict[str, Any],
) -> StreamSample:
start = time.perf_counter()
ttft_ms = 0.0
output_tokens = 0
try:
async with session.post(url, headers=headers, json=payload) as response:
if response.status != 200:
body = await response.read()
return StreamSample(
success=False,
ttft_ms=0.0,
total_ms=(time.perf_counter() - start) * 1000,
output_tokens=0,
status_code=response.status,
error=body.decode("utf-8", errors="ignore")[:200],
)
async for raw_line in response.content:
line = raw_line.strip()
if not line.startswith(b"data:"):
continue
data = line[5:].strip()
if data == b"[DONE]":
break
try:
event = json.loads(data)
except json.JSONDecodeError:
continue
etype = event.get("type")
if etype == "content_block_delta":
if ttft_ms == 0.0:
ttft_ms = (time.perf_counter() - start) * 1000
# Each ``content_block_delta`` event is counted as one output token.
output_tokens += 1
elif etype == "message_stop":
break
total_ms = (time.perf_counter() - start) * 1000
if ttft_ms == 0.0:
return StreamSample(
success=False,
ttft_ms=0.0,
total_ms=total_ms,
output_tokens=0,
status_code=response.status,
error="stream ended before a content token",
)
return StreamSample(
success=True,
ttft_ms=ttft_ms,
total_ms=total_ms,
output_tokens=output_tokens,
status_code=response.status,
)
except Exception as exc:
return StreamSample(
success=False,
ttft_ms=0.0,
total_ms=(time.perf_counter() - start) * 1000,
output_tokens=0,
status_code=0,
error=str(exc)[:200],
)
async def run_benchmark(
url: str,
headers: dict[str, str],
payload: dict[str, Any],
requests: int,
concurrency: int,
warmup: int,
timeout_s: float,
) -> SummaryStats:
timeout = aiohttp.ClientTimeout(total=timeout_s)
connector = aiohttp.TCPConnector(
limit=max(concurrency * 2, 10),
limit_per_host=max(concurrency, 10),
force_close=False,
)
async def worker(
session: aiohttp.ClientSession,
counter: list[int],
budget: int,
sink: list[StreamSample],
) -> None:
# Steady-state load: exactly `concurrency` workers, each pulling the
# next request slot as soon as its previous one finishes. Keeps
# in-flight concurrency constant (vs. a gather-all + semaphore burst)
# which removes the thundering-herd variance that otherwise swamps a
# 10% signal.
while True:
idx = counter[0]
if idx >= budget:
return
counter[0] = idx + 1
sink.append(await measure_stream(session, url, headers, payload))
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
if warmup > 0:
wcounter = [0]
await asyncio.gather(
*[worker(session, wcounter, warmup, []) for _ in range(concurrency)]
)
samples: list[StreamSample] = []
counter = [0]
wall_start = time.perf_counter()
await asyncio.gather(
*[worker(session, counter, requests, samples) for _ in range(concurrency)]
)
wall_time_s = time.perf_counter() - wall_start
return summarize(samples, wall_time_s)
def stats_to_dict(stats: SummaryStats) -> dict[str, Any]:
return {
"requests": stats.requests,
"failures": stats.failures,
"rps": stats.rps,
"ttft_mean_ms": stats.ttft_mean_ms,
"ttft_p50_ms": stats.ttft_p50_ms,
"ttft_p95_ms": stats.ttft_p95_ms,
"ttft_p99_ms": stats.ttft_p99_ms,
"total_p50_ms": stats.total_p50_ms,
"total_p95_ms": stats.total_p95_ms,
"tokens_per_sec": stats.tokens_per_sec,
}
def print_summary(label: str, revision: str, stats: SummaryStats) -> None:
print("\n=== Anthropic /v1/messages streaming benchmark ===")
print(f"Label: {label}")
print(f"Revision: {revision}")
print(f"Requests: {stats.requests} Failures: {stats.failures}")
print(f"TTFT mean: {stats.ttft_mean_ms:.2f} ms")
print(f"TTFT p50: {stats.ttft_p50_ms:.2f} ms")
print(f"TTFT p95: {stats.ttft_p95_ms:.2f} ms")
print(f"TTFT p99: {stats.ttft_p99_ms:.2f} ms")
print(f"Full p50: {stats.total_p50_ms:.2f} ms")
print(f"Full p95: {stats.total_p95_ms:.2f} ms")
print(f"Throughput: {stats.rps:.2f} req/s")
print(f"TPM: {stats.tokens_per_sec:.1f} output tokens/s")
print("\nMarkdown row:")
print(
"| "
+ " | ".join(
[
label,
revision,
f"{stats.ttft_p50_ms:.2f}",
f"{stats.ttft_p95_ms:.2f}",
f"{stats.tokens_per_sec:.1f}",
f"{stats.rps:.2f}",
]
)
+ " |"
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--label", default="current")
parser.add_argument("--litellm-dir", default=str(Path.cwd()))
parser.add_argument("--proxy-command", default="uv run litellm")
parser.add_argument("--proxy-host", default="127.0.0.1")
parser.add_argument("--proxy-port", type=int, default=4000)
parser.add_argument("--provider-host", default="127.0.0.1")
parser.add_argument("--provider-port", type=int, default=8098)
parser.add_argument("--api-key", default=DEFAULT_API_KEY)
parser.add_argument("--requests", type=int, default=300)
parser.add_argument("--concurrency", type=int, default=20)
parser.add_argument("--warmup", type=int, default=30)
parser.add_argument("--timeout", type=float, default=30)
parser.add_argument("--proxy-start-timeout", type=float, default=90)
parser.add_argument("--provider-first-token-delay-ms", type=float, default=0)
parser.add_argument(
"--provider-stream-content-chunks",
type=int,
default=64,
help="Number of text delta chunks the mock emits (default 64).",
)
parser.add_argument(
"--repeats",
type=int,
default=1,
help="Run the suite N times against the same proxy; report the median run.",
)
parser.add_argument(
"--no-start-proxy",
action="store_true",
help="Benchmark an already-running proxy at --proxy-host/--proxy-port",
)
parser.add_argument(
"--provider-url",
help="Use an already-running Anthropic-compatible provider",
)
parser.add_argument("--output-json", help="Write machine-readable results")
return parser.parse_args()
async def async_main() -> None:
args = parse_args()
litellm_dir = Path(args.litellm_dir).resolve()
revision = get_git_revision(litellm_dir)
proxy_base_url = f"http://{args.proxy_host}:{args.proxy_port}"
proxy_url = f"{proxy_base_url}/v1/messages"
headers = {
"Authorization": f"Bearer {args.api_key}",
"Content-Type": "application/json",
}
stream_payload = {
"model": DEFAULT_MODEL,
"max_tokens": 256,
"messages": [{"role": "user", "content": "hi"}],
"stream": True,
}
provider: Optional[MockAnthropicProvider] = None
proxy_process: Optional[subprocess.Popen] = None
with tempfile.TemporaryDirectory(prefix="litellm-anthropic-perf-") as tmp_dir_name:
tmp_dir = Path(tmp_dir_name)
proxy_log_path = tmp_dir / "proxy.log"
if args.provider_url:
provider_base_url = args.provider_url.rstrip("/")
else:
provider = MockAnthropicProvider(
host=args.provider_host,
port=args.provider_port,
first_token_delay_ms=args.provider_first_token_delay_ms,
stream_content_chunks=args.provider_stream_content_chunks,
)
await provider.start()
provider_base_url = provider.base_url
config_path = tmp_dir / "config.yaml"
write_proxy_config(config_path, provider_base_url, args.api_key)
try:
if not args.no_start_proxy:
proxy_process = start_proxy_process(
litellm_dir=litellm_dir,
proxy_command=args.proxy_command,
config_path=config_path,
port=args.proxy_port,
log_path=proxy_log_path,
)
await wait_for_proxy(proxy_base_url, args.proxy_start_timeout)
runs: list[SummaryStats] = []
for run_idx in range(max(1, args.repeats)):
if args.repeats > 1:
print(f"\n--- Run {run_idx + 1}/{args.repeats} ---")
stats = await run_benchmark(
url=proxy_url,
headers=headers,
payload=stream_payload,
requests=args.requests,
concurrency=args.concurrency,
warmup=args.warmup,
timeout_s=args.timeout,
)
runs.append(stats)
if args.repeats > 1:
print(
f" run {run_idx + 1}: TTFT p50={stats.ttft_p50_ms:.2f}ms "
f"TPM={stats.tokens_per_sec:.1f} tok/s RPS={stats.rps:.2f}"
)
stats = sorted(runs, key=lambda s: s.ttft_p50_ms)[len(runs) // 2]
finally:
if proxy_process is not None:
stop_proxy_process(proxy_process)
if provider is not None:
await provider.stop()
print_summary(args.label, revision, stats)
if args.output_json:
Path(args.output_json).write_text(
json.dumps(
{
"label": args.label,
"revision": revision,
"proxy_streaming": stats_to_dict(stats),
"proxy_log_path": str(proxy_log_path),
},
indent=2,
sort_keys=True,
),
encoding="utf-8",
)
def main() -> None:
asyncio.run(async_main())
if __name__ == "__main__":
main()
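The steady-state load pattern used by `run_benchmark` (a fixed set of workers pulling slots from a shared counter) can be sketched in isolation as below. This is a simplified illustration, not the benchmark itself: `run_pool` and `fake_request` are hypothetical names, and `asyncio.sleep(0)` stands in for the HTTP round trip.

```python
import asyncio


async def run_pool(budget: int, concurrency: int) -> tuple[int, int]:
    """Run `budget` fake requests with at most `concurrency` in flight."""
    counter = [0]  # next request slot, shared by all workers
    in_flight = 0
    peak = 0
    done = 0

    async def fake_request() -> None:
        nonlocal in_flight, peak, done
        in_flight += 1
        peak = max(peak, in_flight)
        await asyncio.sleep(0)  # stand-in for the HTTP round trip
        in_flight -= 1
        done += 1

    async def worker() -> None:
        while True:
            idx = counter[0]
            if idx >= budget:
                return
            # No await between the read and the write, so on a single
            # event loop two workers cannot claim the same slot.
            counter[0] = idx + 1
            await fake_request()

    await asyncio.gather(*[worker() for _ in range(concurrency)])
    return done, peak


completed, peak_in_flight = asyncio.run(run_pool(budget=25, concurrency=4))
```

Because each worker starts its next request only after the previous one finishes, the number of in-flight requests never exceeds the worker count, which is the property the comment in `worker` relies on.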
@@ -544,3 +544,132 @@ class TestThinkingSummaryPreservation:
assert result == {
"reasoning_effort": {"effort": "medium", "summary": "concise"}
}
# ---------------------------------------------------------------------------
# Parity tests: redundant empty-text-block sanitization scan removal.
# The async wrapper sanitizes once and tells the handler to skip its second
# (redundant) full-messages scan; the sync entry point still sanitizes.
# ---------------------------------------------------------------------------
def _empty_block_msgs():
return [
{
"role": "assistant",
"content": [
{"type": "text", "text": " "}, # whitespace-only -> stripped
{"type": "tool_use", "id": "t", "name": "B", "input": {}},
],
}
]
def test_handler_strips_when_no_presanitized_flag():
"""Sync entry point (no async wrapper): handler must still sanitize."""
from litellm.llms.anthropic.experimental_pass_through.messages import handler
with patch.object(
handler,
"strip_empty_text_blocks_from_anthropic_messages",
wraps=handler.strip_empty_text_blocks_from_anthropic_messages,
) as spy:
result = handler.anthropic_messages_handler(
max_tokens=10,
messages=_empty_block_msgs(),
model="anthropic/claude-3-5-sonnet-20241022",
custom_llm_provider="anthropic",
mock_response="hi there",
)
assert spy.call_count == 1 # sanitized exactly once here
assert result is not None
def test_handler_skips_strip_when_presanitized():
"""Async wrapper already sanitized -> handler must NOT rescan."""
from litellm.llms.anthropic.experimental_pass_through.messages import handler
with patch.object(
handler,
"strip_empty_text_blocks_from_anthropic_messages",
wraps=handler.strip_empty_text_blocks_from_anthropic_messages,
) as spy:
result = handler.anthropic_messages_handler(
max_tokens=10,
messages=_empty_block_msgs(),
model="anthropic/claude-3-5-sonnet-20241022",
custom_llm_provider="anthropic",
mock_response="hi there",
_litellm_messages_presanitized=True,
)
assert spy.call_count == 0 # skipped the redundant scan
assert result is not None
def test_presanitized_flag_not_leaked_to_provider_params():
"""The private sentinel must be popped, never forwarded as a request param."""
from litellm.llms.anthropic.experimental_pass_through.messages import handler
captured = {}
def fake_base_handler(*args, **kwargs):
captured.update(kwargs)
captured["optional"] = kwargs.get(
"anthropic_messages_optional_request_params", {}
)
return "stub"
with patch.object(
handler.base_llm_http_handler,
"anthropic_messages_handler",
side_effect=fake_base_handler,
):
handler.anthropic_messages_handler(
max_tokens=10,
messages=[{"role": "user", "content": "hi"}],
model="anthropic/claude-3-5-sonnet-20241022",
custom_llm_provider="anthropic",
_litellm_messages_presanitized=True,
)
assert "_litellm_messages_presanitized" not in captured.get("optional", {})
assert "_litellm_messages_presanitized" not in captured.get("kwargs", {})
@pytest.mark.asyncio
async def test_async_wrapper_sets_presanitized_and_sanitizes_once():
"""End-to-end: wrapper sanitizes (once) AND signals the handler to skip."""
from litellm.llms.anthropic.experimental_pass_through.messages import handler
captured = {}
def fake_handler(*args, **kwargs):
captured["messages"] = kwargs.get("messages")
captured["presanitized"] = kwargs.get("_litellm_messages_presanitized")
return "stub"
fake_loop = MagicMock()
fake_loop.run_in_executor = lambda _e, func: _async_return(func())
with (
patch.object(handler, "anthropic_messages_handler", side_effect=fake_handler),
patch("asyncio.get_event_loop", return_value=fake_loop),
patch.object(
handler,
"strip_empty_text_blocks_from_anthropic_messages",
wraps=handler.strip_empty_text_blocks_from_anthropic_messages,
) as spy,
):
await handler.anthropic_messages(
max_tokens=100,
messages=_empty_block_msgs(),
model="anthropic/claude-sonnet-4-5-20250929",
custom_llm_provider="anthropic",
api_key="k",
)
# Wrapper stripped exactly once (the handler is faked, so its skipped
# call never runs anyway -- the point is the wrapper still sanitizes).
assert spy.call_count == 1
assert captured["presanitized"] is True
assert [b["type"] for b in captured["messages"][0]["content"]] == ["tool_use"]
@@ -0,0 +1,56 @@
"""
Regression tests for the /v1/messages request-parse fast paths:
- get_requested_anthropic_messages_optional_param must still filter to the
valid AnthropicMessagesRequestOptionalParams keys and drop None values,
while resolving the (static) type hints only once per process.
"""
from litellm.llms.anthropic.experimental_pass_through.messages.utils import (
AnthropicMessagesRequestUtils,
_anthropic_messages_optional_param_keys,
)
def test_optional_param_filtering_unchanged():
params = {
"temperature": 0.5,
"top_p": None, # None dropped
"tools": [{"name": "x"}],
"not_a_real_param": "drop me", # invalid key dropped
"stream": True,
}
result = (
AnthropicMessagesRequestUtils.get_requested_anthropic_messages_optional_param(
params
)
)
assert result == {"temperature": 0.5, "tools": [{"name": "x"}], "stream": True}
assert "top_p" not in result
assert "not_a_real_param" not in result
def test_valid_keys_are_memoized():
_anthropic_messages_optional_param_keys.cache_clear()
first = _anthropic_messages_optional_param_keys()
for _ in range(50):
AnthropicMessagesRequestUtils.get_requested_anthropic_messages_optional_param(
{"temperature": 0.1}
)
info = _anthropic_messages_optional_param_keys.cache_info()
# Resolved exactly once despite many calls.
assert info.misses == 1
assert info.hits >= 50
# Stable identity (frozenset) returned each call.
assert _anthropic_messages_optional_param_keys() is first
assert isinstance(first, frozenset)
assert "temperature" in first and "tools" in first
def test_empty_params():
assert (
AnthropicMessagesRequestUtils.get_requested_anthropic_messages_optional_param(
{}
)
== {}
)
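The memoization these regression tests exercise can be approximated with the standard library alone. The sketch below assumes a small hypothetical `OptionalParams` TypedDict in place of `AnthropicMessagesRequestOptionalParams`; `optional_param_keys` and `filter_optional_params` are illustrative names.

```python
from functools import lru_cache
from typing import List, Optional, TypedDict, get_type_hints


class OptionalParams(TypedDict, total=False):
    """Hypothetical stand-in for the real optional-params TypedDict."""

    temperature: Optional[float]
    top_p: Optional[float]
    tools: Optional[List[dict]]
    stream: Optional[bool]


@lru_cache(maxsize=1)
def optional_param_keys() -> frozenset:
    # The annotations are static, so resolving them once per process suffices.
    return frozenset(get_type_hints(OptionalParams).keys())


def filter_optional_params(params: dict) -> dict:
    # Keep only known keys whose value is not None.
    valid = optional_param_keys()
    return {k: v for k, v in params.items() if k in valid and v is not None}


filtered = filter_optional_params(
    {"temperature": 0.5, "top_p": None, "not_a_real_param": 1, "stream": True}
)
for _ in range(10):
    filter_optional_params({"temperature": 0.1})
info = optional_param_keys.cache_info()
```

Returning a `frozenset` keeps the cached value immutable, so callers cannot accidentally alter the shared key set.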
@@ -101,6 +101,73 @@ def test_get_agentic_loop_settings_defaults_and_overrides():
assert fingerprints == ["fp-1", "fp-2"]
def test_has_agentic_completion_hook_detection(monkeypatch):
"""The streaming path skips the agentic wrapper only when no callback
overrides async_should_run_agentic_loop. Verify both directions."""
from litellm.integrations.custom_logger import CustomLogger
handler = BaseLLMHTTPHandler()
logging_obj = Mock()
logging_obj.dynamic_success_callbacks = []
# No callbacks at all -> no agentic hook.
monkeypatch.setattr(litellm, "callbacks", [])
assert handler._has_agentic_completion_hook(logging_obj) is False
# A plain CustomLogger that does NOT override the gate -> still no hook
# (so the wrapper is safely skipped).
class _PlainLogger(CustomLogger):
pass
monkeypatch.setattr(litellm, "callbacks", [_PlainLogger()])
assert handler._has_agentic_completion_hook(logging_obj) is False
# A logger that overrides the gate (directly) -> hook present.
class _AgenticLogger(CustomLogger):
async def async_should_run_agentic_loop(
self, response, model, messages, tools, stream, custom_llm_provider, kwargs
):
return True, {}
monkeypatch.setattr(litellm, "callbacks", [_AgenticLogger()])
assert handler._has_agentic_completion_hook(logging_obj) is True
# Override inherited through an intermediate class is still detected
# (function-identity check, not a leaf __dict__ check).
class _DerivedAgenticLogger(_AgenticLogger):
pass
monkeypatch.setattr(litellm, "callbacks", [_DerivedAgenticLogger()])
assert handler._has_agentic_completion_hook(logging_obj) is True
# Hook supplied via logging_obj.dynamic_success_callbacks is detected too.
monkeypatch.setattr(litellm, "callbacks", [])
logging_obj.dynamic_success_callbacks = [_AgenticLogger()]
assert handler._has_agentic_completion_hook(logging_obj) is True
# String-named callback entry (e.g. "datadog") must be resolved to its
# CustomLogger instance via get_custom_logger_compatible_class -- the same
# way ProxyLogging._callback_capabilities handles them. Without that
# resolution a string-registered agentic callback would be silently
# skipped and the buffering wrapper would never fire.
logging_obj.dynamic_success_callbacks = []
agentic_via_string = _AgenticLogger()
monkeypatch.setattr(litellm, "callbacks", ["fake_string_callback"])
monkeypatch.setattr(
"litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class",
lambda name: agentic_via_string if name == "fake_string_callback" else None,
)
assert handler._has_agentic_completion_hook(logging_obj) is True
# Unresolvable string (returns None) is skipped, no false positive.
monkeypatch.setattr(litellm, "callbacks", ["unknown_callback"])
monkeypatch.setattr(
"litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class",
lambda name: None,
)
assert handler._has_agentic_completion_hook(logging_obj) is False
def test_fingerprint_agentic_tools_is_deterministic():
handler = BaseLLMHTTPHandler()
tools_a = {"tool_calls": [{"id": "1", "input": {"q": "abc"}, "name": "web_search"}]}
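The gate covered by `test_has_agentic_completion_hook_detection` combines two callback lists and resolves string-named entries before checking for an override. A rough sketch under stated assumptions: `CustomLogger` here is a local stand-in class, and `REGISTRY`, `resolve`, and `has_agentic_hook` are hypothetical names rather than LiteLLM APIs.

```python
class CustomLogger:
    """Local stand-in whose default gate never requests an agentic loop."""

    async def async_should_run_agentic_loop(self, response):
        return False, {}


class AgenticLogger(CustomLogger):
    async def async_should_run_agentic_loop(self, response):
        return True, {}


# Hypothetical name-to-instance registry for string-registered callbacks.
REGISTRY = {"agentic_by_name": AgenticLogger()}


def resolve(callback):
    # Strings are looked up; unknown names resolve to None and are skipped.
    if isinstance(callback, str):
        return REGISTRY.get(callback)
    return callback


def has_agentic_hook(global_callbacks, dynamic_callbacks):
    base_fn = CustomLogger.async_should_run_agentic_loop
    for callback in [*global_callbacks, *dynamic_callbacks]:
        resolved = resolve(callback)
        if not isinstance(resolved, CustomLogger):
            continue
        # Attribute lookup on the class walks the MRO, so inherited
        # overrides count as well.
        if type(resolved).async_should_run_agentic_loop is not base_fn:
            return True
    return False
```

When `has_agentic_hook` returns `False`, a caller could skip the buffering wrapper entirely, since every registered gate would answer `(False, {})`.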
@@ -422,3 +489,143 @@ def test_sync_delete_responses_omits_body_for_azure():
assert captured["url"].endswith(
"/openai/responses/resp_xyz?api-version=2025-03-01-preview"
)
# ---------------------------------------------------------------------------
# Parity tests: request-body is serialized once and reused for the wire.
# (_async_post_anthropic_messages_with_http_error_retry)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_anthropic_post_uses_prebuilt_body_without_redumping():
"""When the caller passes a pre-serialized (unsigned) body, attempt 0 must
send exactly those bytes -- no second json.dumps of request_body."""
import json as _json
handler = BaseLLMHTTPHandler()
request_body = {"model": "claude", "messages": [{"role": "user", "content": "hi"}]}
prebuilt = _json.dumps(request_body)
ok_resp = Mock()
ok_resp.raise_for_status = Mock(return_value=None)
http_client = Mock()
http_client.post = AsyncMock(return_value=ok_resp)
provider_config = Mock()
provider_config.max_retry_on_anthropic_messages_http_error = 2
logging_obj = Mock()
logging_obj.model_call_details = {}
out = await handler._async_post_anthropic_messages_with_http_error_retry(
async_httpx_client=http_client,
request_url="http://x/v1/messages",
headers={},
signed_json_body=prebuilt,
request_body=request_body,
stream=False,
logging_obj=logging_obj,
provider_config=provider_config,
litellm_params=GenericLiteLLMParams(),
api_key="k",
model="claude",
)
assert out is ok_resp
http_client.post.assert_awaited_once()
sent = http_client.post.await_args.kwargs["data"]
# Byte-identical to the legacy wire serialization, and the SAME object the
# caller already used for the pre-call log (no re-serialization).
assert sent == prebuilt
assert sent is prebuilt
@pytest.mark.asyncio
async def test_anthropic_post_falls_back_to_json_dumps_when_unsigned_none():
"""signed_json_body=None keeps the exact legacy behavior."""
import json as _json
handler = BaseLLMHTTPHandler()
request_body = {"model": "claude", "messages": [{"role": "user", "content": "yo"}]}
ok_resp = Mock()
ok_resp.raise_for_status = Mock(return_value=None)
http_client = Mock()
http_client.post = AsyncMock(return_value=ok_resp)
provider_config = Mock()
provider_config.max_retry_on_anthropic_messages_http_error = 1
logging_obj = Mock()
logging_obj.model_call_details = {}
await handler._async_post_anthropic_messages_with_http_error_retry(
async_httpx_client=http_client,
request_url="http://x/v1/messages",
headers={},
signed_json_body=None,
request_body=request_body,
stream=False,
logging_obj=logging_obj,
provider_config=provider_config,
litellm_params=GenericLiteLLMParams(),
api_key="k",
model="claude",
)
sent = http_client.post.await_args.kwargs["data"]
assert sent == _json.dumps(request_body)
@pytest.mark.asyncio
async def test_anthropic_post_retry_reserializes_mutated_body():
"""On a retryable HTTP error the body is mutated + re-signed; the prebuilt
body must NOT be reused -- attempt 1 sends the freshly serialized body."""
import json as _json
handler = BaseLLMHTTPHandler()
request_body = {"model": "claude", "messages": [{"role": "user", "content": "a"}]}
prebuilt = _json.dumps(request_body)
err_resp = Mock()
http_error = httpx.HTTPStatusError(
"bad", request=Mock(), response=Mock(status_code=400)
)
err_resp.raise_for_status = Mock(side_effect=http_error)
ok_resp = Mock()
ok_resp.raise_for_status = Mock(return_value=None)
http_client = Mock()
http_client.post = AsyncMock(side_effect=[err_resp, ok_resp])
def _mutate(e, request_data):
request_data["messages"][0]["content"] = "MUTATED"
provider_config = Mock()
provider_config.max_retry_on_anthropic_messages_http_error = 2
provider_config.should_retry_anthropic_messages_on_http_error = Mock(
return_value=True
)
provider_config.transform_anthropic_messages_request_on_http_error = _mutate
# Re-sign returns no signed body (native anthropic path) -> must re-dump.
provider_config.sign_request = Mock(return_value=({}, None))
logging_obj = Mock()
logging_obj.model_call_details = {}
await handler._async_post_anthropic_messages_with_http_error_retry(
async_httpx_client=http_client,
request_url="http://x/v1/messages",
headers={},
signed_json_body=prebuilt,
request_body=request_body,
stream=False,
logging_obj=logging_obj,
provider_config=provider_config,
litellm_params=GenericLiteLLMParams(),
api_key="k",
model="claude",
)
assert http_client.post.await_count == 2
first_sent = http_client.post.await_args_list[0].kwargs["data"]
second_sent = http_client.post.await_args_list[1].kwargs["data"]
assert first_sent == prebuilt # attempt 0 used prebuilt
assert second_sent == _json.dumps(request_body) # attempt 1 re-serialized
assert "MUTATED" in second_sent # ... the mutated body
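The three tests above describe one rule: reuse the pre-serialized body on the first attempt, and re-serialize only after a retry has mutated the request. A minimal sketch of that rule, with hypothetical names (`post_with_retry`, `RetryableError`, `flaky_send`) and a fake transport instead of an HTTP client:

```python
import json


class RetryableError(Exception):
    """Hypothetical stand-in for a retryable HTTP error."""


def post_with_retry(send, request_body, prebuilt_body=None, max_attempts=2, mutate=None):
    """Send prebuilt_body on attempt 0; re-serialize after any mutation."""
    body = prebuilt_body if prebuilt_body is not None else json.dumps(request_body)
    for attempt in range(max_attempts):
        try:
            return send(body)
        except RetryableError:
            if attempt == max_attempts - 1:
                raise
            if mutate is not None:
                mutate(request_body)
            # The dict may have changed, so the old string is stale.
            body = json.dumps(request_body)


sent = []


def flaky_send(body):
    # Fails on the first call only, to force exactly one retry.
    sent.append(body)
    if len(sent) == 1:
        raise RetryableError("bad request")
    return "ok"


def mutate_body(request_data):
    request_data["messages"][0]["content"] = "MUTATED"


request_body = {"model": "claude", "messages": [{"role": "user", "content": "a"}]}
prebuilt = json.dumps(request_body)  # also usable as the pre-call log input
result = post_with_retry(flaky_send, request_body, prebuilt, mutate=mutate_body)
```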
@@ -684,3 +684,362 @@ class TestAnthropicBatchPassthroughCostTracking:
mock_proxy_logging_obj.get_proxy_hook.assert_called_once_with(
"managed_files"
)
class TestPureTextFastPathParity:
"""
The pure-text fast path in _build_complete_streaming_response must produce
a response (and downstream logging/cost payload) identical to the legacy
stream_chunk_builder path, apart from the non-deterministic id / created
fields. Anything non-text must fall back.
"""
@staticmethod
def _sse(event, data):
return f"event: {event}\ndata: {json.dumps(data)}\n\n".encode()
@staticmethod
def _to_all_chunks(raw_frames):
# Mirror production: raw bytes -> _convert_raw_bytes_to_str_lines.
from litellm.proxy.pass_through_endpoints.streaming_handler import (
PassThroughStreamingHandler,
)
return PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(raw_frames)
@staticmethod
def _norm(resp):
if resp is None:
return None
d = resp.model_dump()
# id / created are non-deterministic even between two legacy runs.
d.pop("id", None)
d.pop("created", None)
return d
def _text_stream(
self,
texts,
*,
input_tokens=12,
cache_creation=0,
cache_read=0,
stop_reason="end_turn",
with_ping=True,
blocks=1,
):
frames = [
self._sse(
"message_start",
{
"type": "message_start",
"message": {
"id": "msg_abc",
"type": "message",
"role": "assistant",
"model": "claude-3-5-sonnet-20241022",
"content": [],
"stop_reason": None,
"stop_sequence": None,
"usage": {
"input_tokens": input_tokens,
"output_tokens": 0,
"cache_creation_input_tokens": cache_creation,
"cache_read_input_tokens": cache_read,
},
},
},
)
]
per_block = max(1, len(texts) // blocks)
ti = 0
for b in range(blocks):
frames.append(
self._sse(
"content_block_start",
{
"type": "content_block_start",
"index": b,
"content_block": {"type": "text", "text": ""},
},
)
)
if with_ping:
frames.append(self._sse("ping", {"type": "ping"}))
chunk_texts = texts[ti : ti + per_block] if b < blocks - 1 else texts[ti:]
ti += per_block
for t in chunk_texts:
frames.append(
self._sse(
"content_block_delta",
{
"type": "content_block_delta",
"index": b,
"delta": {"type": "text_delta", "text": t},
},
)
)
frames.append(
self._sse(
"content_block_stop", {"type": "content_block_stop", "index": b}
)
)
frames.append(
self._sse(
"message_delta",
{
"type": "message_delta",
"delta": {"stop_reason": stop_reason, "stop_sequence": None},
"usage": {"output_tokens": len(texts)},
},
)
)
frames.append(self._sse("message_stop", {"type": "message_stop"}))
return frames
def _assert_parity(self, raw_frames):
all_chunks = self._to_all_chunks(raw_frames)
lo1 = MagicMock()
lo1.model_call_details = {}
lo2 = MagicMock()
lo2.model_call_details = {}
legacy = AnthropicPassthroughLoggingHandler._build_complete_streaming_response_legacy(
all_chunks=list(all_chunks),
litellm_logging_obj=lo1,
model="claude-3-5-sonnet-20241022",
)
fast = AnthropicPassthroughLoggingHandler._build_complete_streaming_response(
all_chunks=list(all_chunks),
litellm_logging_obj=lo2,
model="claude-3-5-sonnet-20241022",
)
assert self._norm(fast) == self._norm(legacy)
# Downstream logged/billed payload must also match.
start = datetime.now()
end = datetime.now()
k_legacy = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
litellm_model_response=legacy,
model="claude-3-5-sonnet-20241022",
kwargs={},
start_time=start,
end_time=end,
logging_obj=lo1,
)
k_fast = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
litellm_model_response=fast,
model="claude-3-5-sonnet-20241022",
kwargs={},
start_time=start,
end_time=end,
logging_obj=lo2,
)
# Usage drives cost; it must be byte-identical between paths.
assert getattr(fast, "usage", None) == getattr(legacy, "usage", None)
# The logged payload's usage must match too. The response id is
# non-deterministic, so _scrub compares only the usage it extracts.
def _scrub(p):
d = dict(p)
r = d.get("complete_streaming_response_in_db") or d.get(
"complete_streaming_response"
)
return d, getattr(r, "usage", None)
assert _scrub(k_fast)[1] == _scrub(k_legacy)[1]
def test_parity_simple_text(self):
self._assert_parity(self._text_stream(["Hello", " ", "world", "!"]))
def test_parity_single_delta(self):
self._assert_parity(self._text_stream(["Just one piece of text."]))
def test_parity_cache_tokens(self):
self._assert_parity(
self._text_stream(
["a", "b", "c"], input_tokens=20, cache_creation=5, cache_read=7
)
)
def test_parity_max_tokens_stop(self):
self._assert_parity(self._text_stream(["tok"] * 8, stop_reason="max_tokens"))
def test_parity_no_ping(self):
self._assert_parity(self._text_stream(["x", "y"], with_ping=False))
def test_parity_empty_text_deltas(self):
self._assert_parity(self._text_stream(["", "hi", "", "there"]))
def test_parity_multi_text_block(self):
self._assert_parity(self._text_stream(["p1", "p2", "p3", "p4"], blocks=2))
def test_parity_multibyte_batched_frames(self):
# Several SSE events delivered in one network chunk.
frames = self._text_stream(["alpha", "beta", "gamma"])
merged = b"".join(frames)
self._assert_parity([merged])
def test_collapse_returns_none_for_tool_use(self):
frames = [
self._sse(
"message_start",
{
"type": "message_start",
"message": {
"id": "m",
"model": "x",
"role": "assistant",
"type": "message",
"content": [],
"usage": {"input_tokens": 1, "output_tokens": 0},
},
},
),
self._sse(
"content_block_start",
{
"type": "content_block_start",
"index": 0,
"content_block": {
"type": "tool_use",
"id": "t1",
"name": "get_weather",
"input": {},
},
},
),
self._sse(
"content_block_delta",
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "input_json_delta", "partial_json": "{}"},
},
),
self._sse("content_block_stop", {"type": "content_block_stop", "index": 0}),
self._sse("message_stop", {"type": "message_stop"}),
]
all_chunks = self._to_all_chunks(frames)
assert (
AnthropicPassthroughLoggingHandler._collapse_pure_text_chunks(
list(all_chunks)
)
is None
)
def test_collapse_returns_none_for_thinking(self):
frames = [
self._sse(
"message_start",
{
"type": "message_start",
"message": {
"id": "m",
"model": "x",
"role": "assistant",
"type": "message",
"content": [],
"usage": {"input_tokens": 1, "output_tokens": 0},
},
},
),
self._sse(
"content_block_start",
{
"type": "content_block_start",
"index": 0,
"content_block": {"type": "thinking", "thinking": ""},
},
),
self._sse(
"content_block_delta",
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "thinking_delta", "thinking": "hmm"},
},
),
self._sse("message_stop", {"type": "message_stop"}),
]
all_chunks = self._to_all_chunks(frames)
assert (
AnthropicPassthroughLoggingHandler._collapse_pure_text_chunks(
list(all_chunks)
)
is None
)
def test_collapse_actually_shrinks_chunk_count(self):
frames = self._text_stream(["a"] * 50)
all_chunks = list(self._to_all_chunks(frames))
collapsed = AnthropicPassthroughLoggingHandler._collapse_pure_text_chunks(
all_chunks
)
assert collapsed is not None
# 50 text deltas + 50 event markers + 1 ping collapse to far fewer.
assert len(collapsed) < len(all_chunks) / 2
def test_collapse_returns_none_for_interleaved_block_indexes(self):
"""
Anthropic sends content blocks strictly sequentially (start/deltas/stop
for one, then the next). If a stream ever interleaves deltas across
block indexes, the fast path must bail to legacy rather than merge text
from different blocks under a single index.
"""
frames = [
self._sse(
"message_start",
{
"type": "message_start",
"message": {
"id": "msg_abc",
"type": "message",
"role": "assistant",
"model": "claude-3-5-sonnet-20241022",
"content": [],
"stop_reason": None,
"stop_sequence": None,
"usage": {"input_tokens": 1, "output_tokens": 0},
},
},
),
self._sse(
"content_block_start",
{
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
},
),
self._sse(
"content_block_start",
{
"type": "content_block_start",
"index": 1,
"content_block": {"type": "text", "text": ""},
},
),
# Interleave: delta for block 0, then delta for block 1, with no
# content_block_stop between them.
self._sse(
"content_block_delta",
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": "hello "},
},
),
self._sse(
"content_block_delta",
{
"type": "content_block_delta",
"index": 1,
"delta": {"type": "text_delta", "text": "world"},
},
),
self._sse("message_stop", {"type": "message_stop"}),
]
all_chunks = list(self._to_all_chunks(frames))
assert (
AnthropicPassthroughLoggingHandler._collapse_pure_text_chunks(all_chunks)
is None
)
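The bail-out rules these tests pin down (pure text only, no interleaving across block indexes) can be illustrated with a minimal sketch. This is not the PR's `_collapse_pure_text_chunks`: the helper name `collapse_text_deltas` is hypothetical, and it is assumed to work on already-parsed event dicts rather than raw SSE byte chunks.

```python
from typing import List, Optional


def collapse_text_deltas(events: List[dict]) -> Optional[List[dict]]:
    """Merge each text block's deltas into one event, or return None to bail.

    Hypothetical helper for illustration; it takes parsed event dicts rather
    than the raw SSE byte chunks the real handler receives.
    """
    out: List[dict] = []
    merged: Optional[dict] = None  # delta event currently accumulating text

    for ev in events:
        kind = ev.get("type")
        if kind == "content_block_start":
            if ev.get("content_block", {}).get("type") != "text":
                return None  # tool_use, thinking, etc. are not pure text
            out.append(ev)
        elif kind == "content_block_delta":
            delta = ev.get("delta", {})
            if delta.get("type") != "text_delta":
                return None
            if merged is None:
                # First delta of a block: start a fresh merged event (a copy,
                # so the caller's input is never mutated).
                merged = {
                    "type": "content_block_delta",
                    "index": ev.get("index"),
                    "delta": {"type": "text_delta", "text": delta.get("text", "")},
                }
                out.append(merged)
            elif ev.get("index") != merged["index"]:
                return None  # deltas interleaved across blocks
            else:
                merged["delta"]["text"] += delta.get("text", "")
        elif kind == "content_block_stop":
            if merged is not None and ev.get("index") != merged["index"]:
                return None
            merged = None
            out.append(ev)
        else:
            out.append(ev)  # message_start, ping, message_delta, message_stop
    return out
```

Returning `None` instead of raising keeps the caller's fallback to the legacy path a plain `if collapsed is None` check.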
@@ -10,6 +10,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
import litellm
from litellm._uuid import uuid
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.opentelemetry import UserAPIKeyAuth
from litellm.proxy.common_request_processing import (
ProxyBaseLLMRequestProcessing,
@@ -1316,8 +1317,17 @@ class TestCommonRequestProcessingHelpers:
yield 'data: {"content": "chunk 3"}\n\n'
yield "data: [DONE]\n\n"
# Patch the tracer in the common_request_processing module. The
# per-chunk span is gated on _DD_STREAMING_TRACE_ENABLED (resolved at
# import from the real tracer, a NullTracer by default), so enable it
# explicitly to exercise the tracing path.
with (
patch("litellm.proxy.common_request_processing.tracer", mock_tracer),
patch(
"litellm.proxy.common_request_processing._DD_STREAMING_TRACE_ENABLED",
True,
),
):
response = await create_response(mock_generator(), "text/event-stream", {})
assert response.status_code == 200
@@ -1345,6 +1355,40 @@ class TestCommonRequestProcessingHelpers:
args[0] == "streaming.chunk.yield"
), f"Call {i} should have operation name 'streaming.chunk.yield', got {args[0]}"
async def test_create_streaming_response_skips_dd_trace_when_disabled(self):
"""When DD tracing is disabled (the default), the per-chunk span
context manager is skipped entirely but all chunks still stream."""
from unittest.mock import patch
mock_tracer = MagicMock()
async def mock_generator():
yield 'data: {"content": "chunk 1"}\n\n'
yield 'data: {"content": "chunk 2"}\n\n'
yield "data: [DONE]\n\n"
with (
patch("litellm.proxy.common_request_processing.tracer", mock_tracer),
patch(
"litellm.proxy.common_request_processing._DD_STREAMING_TRACE_ENABLED",
False,
),
):
response = await create_response(mock_generator(), "text/event-stream", {})
assert response.status_code == 200
content = await self.consume_stream(response)
# All chunks stream through unchanged ...
assert content == [
'data: {"content": "chunk 1"}\n\n',
'data: {"content": "chunk 2"}\n\n',
"data: [DONE]\n\n",
]
# ... but no per-chunk span was created.
assert mock_tracer.trace.call_count == 0
async def test_create_streaming_response_dd_trace_with_error_chunk(self):
"""
Test that when the first chunk contains an error, JSONResponse is returned
@@ -2199,3 +2243,77 @@ class TestHandleLLMApiExceptionDictDetail:
proxy_exc = await self._invoke(exc)
assert proxy_exc.message == "Content blocked by guardrail"
assert proxy_exc.provider_specific_fields is None
class TestAsyncStreamingDataGeneratorFastPath:
"""Fast/slow path branching in async_streaming_data_generator."""
@staticmethod
async def _aiter(items):
for item in items:
yield item
@pytest.mark.asyncio
async def test_fast_path_skips_per_chunk_hook(self, monkeypatch):
"""With no callbacks/guardrails/cost-injection, chunks pass through
unchanged and the per-chunk hook is NOT awaited."""
monkeypatch.setattr(litellm, "callbacks", [])
ProxyLogging._callback_capabilities_cache.clear()
proxy_logging_obj = ProxyLogging(user_api_key_cache=MagicMock())
hook_spy = AsyncMock(side_effect=lambda **kw: kw["response"])
monkeypatch.setattr(
proxy_logging_obj, "async_post_call_streaming_hook", hook_spy
)
chunks = [b"event: a\ndata: {}\n\n", b"event: b\ndata: {}\n\n"]
out = [
c
async for c in ProxyBaseLLMRequestProcessing.async_streaming_data_generator(
response=self._aiter(chunks),
user_api_key_dict=MagicMock(spec=UserAPIKeyAuth),
request_data={"model": "claude-x"},
proxy_logging_obj=proxy_logging_obj,
serialize_chunk=ProxyBaseLLMRequestProcessing.return_sse_chunk,
serialize_error=lambda e: "data: error\n\n",
)
]
assert out == chunks # bytes pass through return_sse_chunk untouched
hook_spy.assert_not_awaited()
@pytest.mark.asyncio
async def test_slow_path_runs_per_chunk_hook(self, monkeypatch):
"""A callback that overrides async_post_call_streaming_hook forces the
slow path and the per-chunk hook is invoked."""
class _StreamingCb(CustomLogger):
async def async_post_call_streaming_hook(self, user_api_key_dict, response):
return response
cb = _StreamingCb()
monkeypatch.setattr(litellm, "callbacks", [cb])
ProxyLogging._callback_capabilities_cache.clear()
proxy_logging_obj = ProxyLogging(user_api_key_cache=MagicMock())
hook_spy = AsyncMock(side_effect=lambda **kw: kw["response"])
monkeypatch.setattr(
proxy_logging_obj, "async_post_call_streaming_hook", hook_spy
)
out = [
c
async for c in ProxyBaseLLMRequestProcessing.async_streaming_data_generator(
response=self._aiter([{"type": "message_stop"}]),
user_api_key_dict=MagicMock(spec=UserAPIKeyAuth),
request_data={"model": "claude-x"},
proxy_logging_obj=proxy_logging_obj,
serialize_chunk=ProxyBaseLLMRequestProcessing.return_sse_chunk,
serialize_error=lambda e: "data: error\n\n",
)
]
assert len(out) == 1
hook_spy.assert_awaited_once()
ProxyLogging._callback_capabilities_cache.clear()
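The fast/slow branching exercised by these two tests can be sketched in isolation. This is a simplified illustration, not the body of `async_streaming_data_generator`: `stream_chunks`, `aiter_of`, `collect`, and `upper_hook` are hypothetical names, and the "is any hook active" decision is reduced to a single optional argument.

```python
import asyncio
from typing import AsyncIterator, Awaitable, Callable, List, Optional

Hook = Callable[[bytes], Awaitable[bytes]]


async def stream_chunks(
    source: AsyncIterator[bytes], hook: Optional[Hook]
) -> AsyncIterator[bytes]:
    """Yield chunks, awaiting the per-chunk hook only when one is active."""
    if hook is None:
        # Fast path: the decision is made once, before the loop, so no
        # coroutine is created or awaited per chunk.
        async for chunk in source:
            yield chunk
        return
    # Slow path: every chunk goes through the hook.
    async for chunk in source:
        yield await hook(chunk)


async def aiter_of(items: List[bytes]) -> AsyncIterator[bytes]:
    for item in items:
        yield item


async def collect(agen: AsyncIterator[bytes]) -> List[bytes]:
    return [c async for c in agen]


calls: List[bytes] = []


async def upper_hook(chunk: bytes) -> bytes:
    calls.append(chunk)  # record that the hook actually ran
    return chunk.upper()
```

For example, `asyncio.run(collect(stream_chunks(aiter_of([b"a"]), None)))` passes the bytes through without touching `upper_hook`, while supplying `upper_hook` routes every chunk through it.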
@@ -111,6 +111,28 @@ def test_callback_capabilities_captures_iterator_override(monkeypatch):
assert kind == "override"
def test_callback_capabilities_detects_inherited_streaming_chunk_override(monkeypatch):
"""
``async_post_call_streaming_hook`` must be detected even when the override
lives on an intermediate parent class: a vendor base class can carry the
override while the registered leaf class adds nothing of its own. Before this
PR the hook was invoked unconditionally; now that invocation depends on this
detection, a lookup limited to the leaf class's ``__dict__`` would silently
drop the inherited hook.
"""
ProxyLogging._callback_capabilities_cache.clear()
class _StreamingBase(CustomLogger):
async def async_post_call_streaming_hook(self, *args, **kwargs): # type: ignore[override]
return kwargs.get("response")
class _LeafWithoutOverride(_StreamingBase):
pass
monkeypatch.setattr(litellm, "callbacks", [_LeafWithoutOverride()])
caps = ProxyLogging._callback_capabilities()
assert caps.has_streaming_chunk_override is True
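The difference between a leaf-only lookup and an inheritance-aware one can be shown with plain classes. This is a sketch of the general Python technique only; it does not claim to reproduce how `ProxyLogging._callback_capabilities` is implemented, and `BaseLogger`, `VendorBase`, `Leaf`, `Plain`, and both helper functions are hypothetical.

```python
class BaseLogger:
    async def on_chunk(self, chunk):
        return chunk


class VendorBase(BaseLogger):
    # The override lives on an intermediate parent class.
    async def on_chunk(self, chunk):
        return chunk


class Leaf(VendorBase):
    pass  # registered class adds nothing of its own


class Plain(BaseLogger):
    pass  # never overrides the hook anywhere in its hierarchy


def overrides_leaf_only(cb: object, name: str) -> bool:
    """Checks only the instance's own class; misses inherited overrides."""
    return name in type(cb).__dict__


def overrides_via_mro(cb: object, name: str, base: type) -> bool:
    """Resolves the attribute through the MRO and compares it to the base's."""
    return getattr(type(cb), name) is not getattr(base, name)
```

With these definitions, `overrides_leaf_only(Leaf(), "on_chunk")` is `False` even though `Leaf` inherits an override, whereas `overrides_via_mro(Leaf(), "on_chunk", BaseLogger)` is `True`.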
def test_callback_capabilities_cache_invalidates_on_list_change(monkeypatch):
"""The cache key includes (length, id-of-each-callback). Mutating the
callback list must produce a fresh capability snapshot."""