From 2eab9ee2c0caf66b6ed51c3f3cb9b41d59cd1001 Mon Sep 17 00:00:00 2001 From: Yassin Kortam Date: Sat, 23 May 2026 12:15:59 -0700 Subject: [PATCH] perf: reduce per-request and per-chunk overhead across Anthropic streaming hot paths (#28289) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * perf: reduce per-request and per-chunk overhead across Anthropic streaming hot paths - Introduce pure-text fast-path in `_build_complete_streaming_response` that collapses O(N) `content_block_delta` events into a single equivalent SSE event before conversion, eliminating per-output-token Pydantic `ModelResponseStream` construction; non-text streams (tool_use, thinking, citations) fall back to the unchanged legacy path - Skip agentic streaming wrapper entirely when no callback overrides `async_should_run_agentic_loop`; the wrapper buffered every chunk and rebuilt the SSE response only to call hooks that all return `(False, {})` — a pure no-op for the default config - Serialize request body once (`json.dumps`) for both the pre-call log input and the wire, instead of twice; avoids a full O(payload) scan per request, significant for long-context Claude Code histories - Add fast path in `async_streaming_data_generator` that bypasses the per-chunk `async_post_call_streaming_hook` coroutine await, response-string materialization, and cost-injection call when no callback/guardrail/cost-injection is active (the default config) - Resolve `_DD_STREAMING_TRACE_ENABLED` once at import time; eliminate per-chunk `NullSpan` context manager allocation when Datadog tracing is disabled (the default) - Memoize `get_type_hints(AnthropicMessagesRequestOptionalParams)` with `@lru_cache(maxsize=1)` — resolves once per process instead of once per `/v1/messages` request (~80µs each) - Hoist `cost_injection_active` out of the per-chunk loop in `chunk_processor`; eliminates repeated `getattr` + endpoint-type checks on every streamed byte chunk - Extract `_build_passthrough_logging_result` from `_route_streaming_logging_to_handler` as a standalone static method to facilitate future off-loop dispatch - Convert `async_sse_data_generator` from an `async for: yield` trampoline to a direct return of the underlying generator, removing one async-generator layer per streamed chunk - Skip redundant `strip_empty_text_blocks_from_anthropic_messages` scan in `anthropic_messages_handler` when the async wrapper already sanitized (signalled via `_litellm_messages_presanitized` sentinel, popped before reaching provider params) - Gate debug log `f-string` evaluation behind `isEnabledFor(DEBUG)` in both the streaming generator and the transformation layer to avoid serializing entire message payloads on every request at non-debug log levels - Add benchmark script (`scripts/benchmark_anthropic_messages_perf.py`) with a local mock Anthropic SSE provider for reproducible TTFT and TPM measurement across commits/branches - Add parity tests asserting fast-path and legacy-path produce byte-identical logged/billed payloads, plus unit tests for agentic hook detection, pre-serialized body reuse, and memoized key resolution * perf: address greptile review for anthropic streaming hot path - Bail to legacy in `_collapse_pure_text_chunks` when content_block_delta events from different block indexes are observed without an intervening flush. Anthropic sends blocks strictly sequentially, but defensive bail prevents silent text-merging if the protocol ever interleaves. - Replace leaf-class `__dict__` check for `async_post_call_streaming_hook` in `_callback_capabilities` with a function-identity comparison that walks the MRO. A vendor base class can carry the override and the registered class can add nothing else; before this PR the hook was unconditionally invoked, so an inherited-override miss would silently drop the hook on the streaming path. - Add unit tests for both behaviors. Co-Authored-By: Claude Opus 4.7 (1M context) * fix(mypy): narrow model_name to str in cost-injection branch The hoisted cost_injection_active flag in chunk_processor encodes the `bool(model_name)` requirement but mypy can't track that invariant through the local, so the per-chunk `_process_chunk_with_cost_injection( chunk, model_name)` calls flagged Optional[str] vs str. Pin a typed non-None local inside the cost-injection branch so mypy narrows correctly without changing runtime behavior. Co-Authored-By: Claude Opus 4.7 (1M context) --------- Co-authored-by: Yassin Kortam Co-authored-by: Claude Opus 4.7 (1M context) --- .../messages/handler.py | 16 +- .../messages/transformation.py | 5 +- .../messages/utils.py | 17 +- litellm/llms/custom_httpx/llm_http_handler.py | 73 +- litellm/proxy/common_request_processing.py | 62 +- .../anthropic_passthrough_logging_handler.py | 169 ++++- .../streaming_handler.py | 191 ++++-- litellm/proxy/utils.py | 17 +- scripts/benchmark_anthropic_messages_perf.py | 624 ++++++++++++++++++ ...erimental_pass_through_messages_handler.py | 129 ++++ .../test_request_optional_param_utils.py | 56 ++ .../custom_httpx/test_llm_http_handler.py | 207 ++++++ ...t_anthropic_passthrough_logging_handler.py | 359 ++++++++++ .../proxy/test_common_request_processing.py | 122 +++- .../test_proxy_logging_hook_detection.py | 22 + 15 files changed, 1978 insertions(+), 91 deletions(-) create mode 100644 scripts/benchmark_anthropic_messages_perf.py create mode 100644 tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_request_optional_param_utils.py diff --git a/litellm/llms/anthropic/experimental_pass_through/messages/handler.py b/litellm/llms/anthropic/experimental_pass_through/messages/handler.py index 009ba6ef30..14e06e047e 100644 --- a/litellm/llms/anthropic/experimental_pass_through/messages/handler.py +++ b/litellm/llms/anthropic/experimental_pass_through/messages/handler.py @@ -293,6 +293,12 @@ async def anthropic_messages( api_base=api_base, client=client, custom_llm_provider=custom_llm_provider, + # messages were already empty-text-block sanitized at the top of this + # function and are NOT reassigned before this dispatch, so the handler + # can skip its (otherwise redundant) second full-messages scan. Passed + # explicitly (not via **kwargs) so it only affects this direct + # dispatch -- interceptor / sync entry points still sanitize. + _litellm_messages_presanitized=True, **kwargs, ) ctx = contextvars.copy_context() @@ -351,10 +357,14 @@ def anthropic_messages_handler( """ from litellm.types.utils import LlmProviders - # Sanitize empty text blocks here too so the sync entry point + # Sanitize empty text blocks so the sync entry point # (litellm.messages.create -> anthropic_messages_handler) gets the same - # protection as the async wrapper. Idempotent when called twice. - messages = strip_empty_text_blocks_from_anthropic_messages(messages) + # protection as the async wrapper. The async wrapper already sanitized and + # does not reassign messages before dispatch, so it sets + # ``_litellm_messages_presanitized`` to skip this redundant second + # full-messages scan. Pop it so it never leaks into provider params. + if not kwargs.pop("_litellm_messages_presanitized", False): + messages = strip_empty_text_blocks_from_anthropic_messages(messages) metadata = validate_anthropic_api_metadata(metadata) diff --git a/litellm/llms/anthropic/experimental_pass_through/messages/transformation.py b/litellm/llms/anthropic/experimental_pass_through/messages/transformation.py index 35495d5961..15f404d3f5 100644 --- a/litellm/llms/anthropic/experimental_pass_through/messages/transformation.py +++ b/litellm/llms/anthropic/experimental_pass_through/messages/transformation.py @@ -312,7 +312,10 @@ class AnthropicMessagesConfig(BaseAnthropicMessagesConfig): ) ####### get required params for all anthropic messages requests ###### - verbose_logger.debug(f"TRANSFORMATION DEBUG - Messages: {messages}") + # Lazy %s: the f-string previously stringified the entire messages + # payload on every request regardless of log level (a full scan of the + # request body on the hot path). Defer it to when DEBUG is enabled. + verbose_logger.debug("TRANSFORMATION DEBUG - Messages: %s", messages) # Auto-strip advisor blocks from history if advisor tool is absent. # Prevents Anthropic 400: advisor_tool_result in history requires advisor tool. diff --git a/litellm/llms/anthropic/experimental_pass_through/messages/utils.py b/litellm/llms/anthropic/experimental_pass_through/messages/utils.py index fa951ebd2e..88832fb3f6 100644 --- a/litellm/llms/anthropic/experimental_pass_through/messages/utils.py +++ b/litellm/llms/anthropic/experimental_pass_through/messages/utils.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, List, cast, get_type_hints +from functools import lru_cache +from typing import Any, Dict, FrozenSet, List, cast, get_type_hints from litellm.types.llms.anthropic import AnthropicMessagesRequestOptionalParams from litellm.types.llms.anthropic_messages.anthropic_response import ( @@ -6,6 +7,18 @@ from litellm.types.llms.anthropic_messages.anthropic_response import ( ) +@lru_cache(maxsize=1) +def _anthropic_messages_optional_param_keys() -> FrozenSet[str]: + """ + Valid AnthropicMessagesRequestOptionalParams keys. + + ``typing.get_type_hints`` is ~80us/call and this TypedDict is static, so + resolving it once per process instead of once per request removes a fixed + full-pass cost from the /v1/messages request-parse path. + """ + return frozenset(get_type_hints(AnthropicMessagesRequestOptionalParams).keys()) + + class AnthropicMessagesRequestUtils: @staticmethod def get_requested_anthropic_messages_optional_param( @@ -20,7 +33,7 @@ class AnthropicMessagesRequestUtils: Returns: AnthropicMessagesRequestOptionalParams instance with only the valid parameters """ - valid_keys = get_type_hints(AnthropicMessagesRequestOptionalParams).keys() + valid_keys = _anthropic_messages_optional_param_keys() filtered_params = { k: v for k, v in params.items() if k in valid_keys and v is not None } diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index 090aac187b..74e55b9deb 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -1886,7 +1886,9 @@ class BaseLLMHTTPHandler: async_httpx_client: AsyncHTTPHandler, request_url: str, headers: dict, - signed_json_body: Optional[bytes], + # str when the caller passes a pre-serialized (unsigned) body to avoid + # re-dumping; bytes when a provider signed the request (e.g. Bedrock). + signed_json_body: Optional[Union[str, bytes]], request_body: dict, stream: bool, logging_obj: LiteLLMLoggingObj, @@ -2077,8 +2079,18 @@ class BaseLLMHTTPHandler: model=model, ) + # The request body was serialized once for the pre-call log input and + # again for the wire (json.dumps is O(payload), large for long-context + # Claude Code history). Serialize once and reuse for both. Only when + # the provider didn't sign the request (sign_request no-op for the + # native anthropic path -> signed_json_body is None); signed providers + # (e.g. Bedrock) keep their signed body untouched. The HTTP-error + # retry path mutates + re-signs the body, so it still re-serializes + # internally -- this only deduplicates the success path. + request_body_json = json.dumps(request_body) + logging_obj.pre_call( - input=[{"role": "user", "content": json.dumps(request_body)}], + input=[{"role": "user", "content": request_body_json}], api_key="", additional_args={ "complete_input_dict": request_body, @@ -2091,7 +2103,9 @@ class BaseLLMHTTPHandler: async_httpx_client=async_httpx_client, request_url=request_url, headers=headers, - signed_json_body=signed_json_body, + signed_json_body=( + signed_json_body if signed_json_body is not None else request_body_json + ), request_body=request_body, stream=stream or False, logging_obj=logging_obj, @@ -2113,6 +2127,14 @@ class BaseLLMHTTPHandler: litellm_logging_obj=logging_obj, ) + if not self._has_agentic_completion_hook(logging_obj): + # No callback overrides async_should_run_agentic_loop, so the + # agentic wrapper's only effect would be buffering every chunk + # and rebuilding the response from SSE at end-of-stream to call + # hooks that all return (False, {}). Stream through directly and + # skip that per-chunk + end-of-stream overhead. + return completion_stream + from litellm.llms.anthropic.experimental_pass_through.messages.agentic_streaming_iterator import ( AgenticAnthropicStreamingIterator, ) @@ -4620,6 +4642,51 @@ class BaseLLMHTTPHandler: fingerprints = list(kwargs.get("_agentic_loop_fingerprints", []) or []) return depth, max(max_loops, 1), fingerprints + @staticmethod + def _has_agentic_completion_hook(logging_obj: Any) -> bool: + """ + True if any registered callback actually overrides + ``async_should_run_agentic_loop`` (the gate every agentic hook goes + through). The base ``CustomLogger`` implementation returns + ``(False, {})``, so when nothing overrides it the agentic + post-processing is a guaranteed no-op and the streaming wrapper that + buffers + rebuilds the whole response from SSE just to call it can be + skipped entirely. + + Function-identity comparison (not a leaf ``__dict__`` check) so an + override inherited through any intermediate class is still detected -- + a false negative here would silently disable agentic features. + + String entries in ``litellm.callbacks`` (e.g. ``"datadog"``) are + resolved to their ``CustomLogger`` instance via + ``get_custom_logger_compatible_class`` -- same pattern as + ``ProxyLogging._callback_capabilities`` -- so a string-registered + agentic callback is detected too. + """ + from litellm.integrations.custom_logger import CustomLogger + from litellm.litellm_core_utils.litellm_logging import ( + get_custom_logger_compatible_class, + ) + + base_func = CustomLogger.async_should_run_agentic_loop + callbacks = litellm.callbacks + ( + getattr(logging_obj, "dynamic_success_callbacks", None) or [] + ) + for cb in callbacks: + if isinstance(cb, str): + resolved = get_custom_logger_compatible_class(cb) # type: ignore[arg-type] + if resolved is None: + continue + cb = resolved + if not isinstance(cb, CustomLogger): + continue + cb_func = getattr(type(cb), "async_should_run_agentic_loop", base_func) + if getattr(cb_func, "__func__", cb_func) is not getattr( + base_func, "__func__", base_func + ): + return True + return False + @staticmethod def _check_agentic_loop_safety( tool_calls: Any, diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index 7d2954fd2d..ef1d64335b 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -32,7 +32,7 @@ from litellm.constants import ( STREAM_SSE_DATA_PREFIX, ) from litellm.integrations.custom_guardrail import CustomGuardrail -from litellm.litellm_core_utils.dd_tracing import tracer +from litellm.litellm_core_utils.dd_tracing import NullTracer, tracer from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.litellm_core_utils.llm_response_utils.get_headers import ( get_response_headers, @@ -65,6 +65,13 @@ else: from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request from litellm.types.utils import ModelResponse, ModelResponseStream, Usage +# Datadog streaming spans are a no-op when ddtrace is not enabled, but the +# ``with tracer.trace(...)`` context manager still allocates a NullSpan and +# runs __enter__/__exit__ for every streamed chunk. Resolve once at import so +# the per-chunk hot path can skip the context manager entirely when tracing +# is off (the default). +_DD_STREAMING_TRACE_ENABLED = not isinstance(tracer, NullTracer) + def _serialize_http_exception_detail( detail: Any, @@ -231,7 +238,7 @@ def _extract_error_from_sse_chunk(event_line: Union[str, bytes]) -> dict: return default_error -async def create_response( +async def create_response( # noqa: PLR0915 generator: AsyncGenerator[str, None], media_type: str, headers: dict, @@ -336,6 +343,13 @@ async def create_response( ) async def combined_generator() -> AsyncGenerator[str, None]: + if not _DD_STREAMING_TRACE_ENABLED: + # Fast path: no per-chunk span object / context-manager overhead. + if first_chunk_value is not None: + yield first_chunk_value + async for chunk in generator: + yield chunk + return if first_chunk_value is not None: with tracer.trace(DD_TRACER_STREAMING_CHUNK_YIELD_RESOURCE): yield first_chunk_value @@ -1900,6 +1914,23 @@ class ProxyBaseLLMRequestProcessing: failure hook and yields via serialize_error. Use for SSE or NDJSON. """ verbose_proxy_logger.debug("inside generator") + # Resolve per-stream (not per-chunk) whether the heavy per-chunk path + # is needed. When no callback overrides ``async_post_call_streaming_hook``, + # no CustomGuardrail is active, and cost injection is disabled, the + # per-chunk hook returns the chunk unchanged, ``str_so_far`` is never + # consumed, and cost injection is a no-op -- so the per-chunk coroutine + # await, response-string materialization, and cost-injection call are + # pure overhead on the streaming hot path (the default config). + caps = ProxyLogging._callback_capabilities() + cost_injection_enabled = bool( + getattr(litellm, "include_cost_in_streaming_usage", False) + ) + fast_path = ( + not caps.has_streaming_chunk_override + and not caps.has_guardrail + and not cost_injection_enabled + ) + debug_enabled = verbose_proxy_logger.isEnabledFor(logging.DEBUG) try: str_so_far = "" async for ( @@ -1909,9 +1940,17 @@ class ProxyBaseLLMRequestProcessing: response=response, request_data=request_data, ): - verbose_proxy_logger.debug( - "async_data_generator: received streaming chunk - {}".format(chunk) - ) + # ``.format(chunk)`` was previously evaluated for every chunk + # regardless of log level; gate it behind the level check. + if debug_enabled: + verbose_proxy_logger.debug( + "async_data_generator: received streaming chunk - %s", chunk + ) + + if fast_path: + yield serialize_chunk(chunk) + continue + chunk = await proxy_logging_obj.async_post_call_streaming_hook( user_api_key_dict=user_api_key_dict, response=chunk, @@ -1969,7 +2008,7 @@ class ProxyBaseLLMRequestProcessing: yield serialize_error(proxy_exception) @staticmethod - async def async_sse_data_generator( + def async_sse_data_generator( response: Any, user_api_key_dict: UserAPIKeyAuth, request_data: dict, @@ -1977,17 +2016,20 @@ class ProxyBaseLLMRequestProcessing: ) -> AsyncGenerator[str, None]: """ Anthropic /messages and Google /generateContent streaming data generator require SSE events. - Delegates to async_streaming_data_generator with SSE serializers. + + Returns the underlying ``async_streaming_data_generator`` configured with + SSE serializers directly (rather than re-wrapping it in another + ``async for: yield`` trampoline), so a streamed chunk traverses one + fewer async-generator layer / coroutine resume on the hot path. """ - async for chunk in ProxyBaseLLMRequestProcessing.async_streaming_data_generator( + return ProxyBaseLLMRequestProcessing.async_streaming_data_generator( response=response, user_api_key_dict=user_api_key_dict, request_data=request_data, proxy_logging_obj=proxy_logging_obj, serialize_chunk=ProxyBaseLLMRequestProcessing.return_sse_chunk, serialize_error=lambda proxy_exc: f"{STREAM_SSE_DATA_PREFIX}{json.dumps({'error': proxy_exc.to_dict()})}\n\n", - ): - yield chunk + ) @staticmethod def _process_chunk_with_cost_injection(chunk: Any, model_name: str) -> Any: diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py index c42faa59cf..3be26eb572 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py @@ -266,7 +266,174 @@ class AnthropicPassthroughLoggingHandler: model: str, ) -> Optional[Union[ModelResponse, TextCompletionResponse]]: """ - Builds complete response from raw Anthropic chunks + Builds complete response from raw Anthropic chunks. + + Fast path: for the dominant case of a pure-text streaming response + (no tool_use / thinking / non-text content blocks), the long run of + ``content_block_delta`` text deltas is collapsed into a single + equivalent SSE event before conversion. ``chunk_parser`` and + ``stream_chunk_builder`` remain the single source of truth for chunk + shape, usage math and finish-reason mapping, so the rebuilt response + (and therefore the logged/billed payload) is identical -- this is + asserted by a parity test. Anything non-trivial falls back to the + unchanged legacy reconstruction. + + Per-event Pydantic ``ModelResponseStream`` construction dominated + event-loop CPU under concurrent streaming; collapsing the homogeneous + text run removes O(num_output_tokens) of it. + """ + collapsed = AnthropicPassthroughLoggingHandler._collapse_pure_text_chunks( + all_chunks + ) + if collapsed is not None: + return AnthropicPassthroughLoggingHandler._build_complete_streaming_response_legacy( + all_chunks=collapsed, + litellm_logging_obj=litellm_logging_obj, + model=model, + ) + return AnthropicPassthroughLoggingHandler._build_complete_streaming_response_legacy( + all_chunks=all_chunks, + litellm_logging_obj=litellm_logging_obj, + model=model, + ) + + # Anthropic SSE block/delta types that the fast path is NOT allowed to + # collapse -- their presence forces the unchanged legacy path so tool + # calls, thinking, citations, etc. keep byte-identical reconstruction. + _FAST_PATH_DISALLOWED_DELTA_TYPES = frozenset( + { + "input_json_delta", + "thinking_delta", + "signature_delta", + "citations_delta", + } + ) + + @staticmethod + def _collapse_pure_text_chunks( # noqa: PLR0915 + all_chunks: Sequence[Union[str, bytes]], + ) -> Optional[List[str]]: + """ + Return a new chunk list with the contiguous run of text-only + ``content_block_delta`` events replaced by a single equivalent event, + or ``None`` if the stream is not a pure single-text-block response + (in which case the caller uses the legacy path unchanged). + + Only ``message_start`` / ``content_block_start(text)`` / + ``content_block_delta(text_delta)`` / ``content_block_stop`` / + ``message_delta`` / ``message_stop`` / ``ping`` events are accepted. + Any other content-block type or delta type returns ``None``. + """ + normalized: List[str] = [] + for raw in all_chunks: + line = raw.decode("utf-8") if isinstance(raw, bytes) else raw + for ev in line.split("\n\n"): + ev = ev.strip() + if ev: + normalized.append(ev) + + text_block_indexes: set = set() + out: List[str] = [] + pending_text: List[str] = [] + pending_index: Optional[int] = None + saw_any_text_delta = False + + def flush() -> None: + nonlocal pending_text, pending_index + if pending_text: + merged = { + "type": "content_block_delta", + "index": pending_index if pending_index is not None else 0, + "delta": {"type": "text_delta", "text": "".join(pending_text)}, + } + out.append("data: " + json.dumps(merged)) + pending_text = [] + pending_index = None + + for ev in normalized: + idx = ev.find("data:") + if idx == -1: + # Bare "event: " line. The legacy converter turns this + # into an empty ModelResponseStream that contributes nothing + # to stream_chunk_builder. Drop the high-frequency interior + # markers (content_block_delta / ping); keep every other + # bare event line verbatim so chunk ordering and the + # load-bearing chunks[0] (event: message_start) are retained. + name = ev[len("event:") :].strip() if ev.startswith("event:") else "" + if name in ("content_block_delta", "ping"): + continue + flush() + out.append(ev) + continue + + json_str = ev[idx + len("data:") :].strip() + try: + data = json.loads(json_str) + except (json.JSONDecodeError, ValueError): + return None + + etype = data.get("type") + if etype == "content_block_start": + block = data.get("content_block") or {} + if block.get("type") != "text": + return None + text_block_indexes.add(data.get("index")) + flush() + out.append(ev) + elif etype == "content_block_delta": + delta = data.get("delta") or {} + dtype = delta.get("type") + if ( + dtype + in AnthropicPassthroughLoggingHandler._FAST_PATH_DISALLOWED_DELTA_TYPES + ): + return None + if dtype != "text_delta": + return None + cur_index = data.get("index") + if cur_index not in text_block_indexes: + return None + # Defensive: Anthropic sends blocks strictly sequentially + # (start/deltas/stop, then next block), so pending_text from + # block N must be flushed by content_block_stop before block + # N+1's deltas arrive. If we ever see a delta whose index + # disagrees with the current pending buffer, the stream is + # interleaved -- fall back to legacy rather than risk merging + # text from different blocks under a single index. + if ( + pending_text + and pending_index is not None + and cur_index != pending_index + ): + return None + saw_any_text_delta = True + pending_index = cur_index + pending_text.append(delta.get("text") or "") + elif etype == "ping": + # Interior no-op; legacy maps it to an empty chunk. + continue + else: + # message_start / content_block_stop / message_delta / + # message_stop / error: pass through unchanged. + flush() + out.append(ev) + + flush() + + if not saw_any_text_delta: + return None + return out + + @staticmethod + def _build_complete_streaming_response_legacy( + all_chunks: Sequence[Union[str, bytes]], + litellm_logging_obj: LiteLLMLoggingObj, + model: str, + ) -> Optional[Union[ModelResponse, TextCompletionResponse]]: + """ + Original reconstruction: convert every SSE event to a generic chunk + and assemble via stream_chunk_builder. Kept verbatim as the fallback + / source of truth for the fast path's parity test. - Splits multi-event chunks into individual SSE events - Converts str chunks to generic chunks diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py index cbfcd34c43..235a38b75f 100644 --- a/litellm/proxy/pass_through_endpoints/streaming_handler.py +++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py @@ -1,6 +1,6 @@ import asyncio from datetime import datetime -from typing import List, Optional +from typing import List, Optional, Tuple import httpx @@ -45,28 +45,44 @@ class PassThroughStreamingHandler: litellm_logging_obj=litellm_logging_obj, ) + # Resolve once per stream rather than re-reading the global + + # re-branching on every chunk. ``include_cost_in_streaming_usage`` is + # set at config load and stable for the process, matching how the + # proxy-level streaming fast path resolves it. + cost_injection_active = ( + bool(getattr(litellm, "include_cost_in_streaming_usage", False)) + and bool(model_name) + and endpoint_type in (EndpointType.VERTEX_AI, EndpointType.ANTHROPIC) + ) try: - async for chunk in response.aiter_bytes(): - raw_bytes.append(chunk) - if ( - getattr(litellm, "include_cost_in_streaming_usage", False) - and model_name - ): + if not cost_injection_active: + # Hot path: just buffer for end-of-stream logging and forward. + async for chunk in response.aiter_bytes(): + raw_bytes.append(chunk) + yield chunk + else: + # ``cost_injection_active`` already requires ``model_name`` to + # be truthy; pin to a typed local so mypy narrows ``Optional[str]`` + # -> ``str`` for the per-chunk call site. + assert model_name is not None + resolved_model_name: str = model_name + async for chunk in response.aiter_bytes(): + raw_bytes.append(chunk) if endpoint_type == EndpointType.VERTEX_AI: if "streamRawPredict" in url_route or "rawPredict" in url_route: modified_chunk = ProxyBaseLLMRequestProcessing._process_chunk_with_cost_injection( - chunk, model_name + chunk, resolved_model_name ) if modified_chunk is not None: chunk = modified_chunk - elif endpoint_type == EndpointType.ANTHROPIC: + else: # EndpointType.ANTHROPIC modified_chunk = ProxyBaseLLMRequestProcessing._process_chunk_with_cost_injection( - chunk, model_name + chunk, resolved_model_name ) if modified_chunk is not None: chunk = modified_chunk - yield chunk + yield chunk except Exception as e: verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}") raise @@ -115,64 +131,20 @@ class PassThroughStreamingHandler: - OpenAI """ try: - all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines( - raw_bytes + ( + standard_logging_response_object, + kwargs, + ) = PassThroughStreamingHandler._build_passthrough_logging_result( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body, + endpoint_type=endpoint_type, + start_time=start_time, + raw_bytes=raw_bytes, + end_time=end_time, + model=model, ) - standard_logging_response_object: Optional[ - PassThroughEndpointLoggingResultValues - ] = None - kwargs: dict = {} - if endpoint_type == EndpointType.ANTHROPIC: - anthropic_passthrough_logging_handler_result = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body, - endpoint_type=endpoint_type, - start_time=start_time, - all_chunks=all_chunks, - end_time=end_time, - ) - standard_logging_response_object = ( - anthropic_passthrough_logging_handler_result["result"] - ) - kwargs = anthropic_passthrough_logging_handler_result["kwargs"] - elif endpoint_type == EndpointType.VERTEX_AI: - vertex_passthrough_logging_handler_result = VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body, - endpoint_type=endpoint_type, - start_time=start_time, - all_chunks=all_chunks, - end_time=end_time, - model=model, - ) - standard_logging_response_object = ( - vertex_passthrough_logging_handler_result["result"] - ) - kwargs = vertex_passthrough_logging_handler_result["kwargs"] - elif endpoint_type == EndpointType.OPENAI: - openai_passthrough_logging_handler_result = OpenAIPassthroughLoggingHandler._handle_logging_openai_collected_chunks( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body, - endpoint_type=endpoint_type, - start_time=start_time, - all_chunks=all_chunks, - end_time=end_time, - ) - standard_logging_response_object = ( - openai_passthrough_logging_handler_result["result"] - ) - kwargs = openai_passthrough_logging_handler_result["kwargs"] - - if standard_logging_response_object is None: - standard_logging_response_object = StandardPassThroughResponseObject( - response=f"cannot parse chunks to standard response object. Chunks={all_chunks}" - ) await litellm_logging_obj.async_success_handler( result=standard_logging_response_object, start_time=start_time, @@ -199,6 +171,89 @@ class PassThroughStreamingHandler: f"Error in _route_streaming_logging_to_handler: {str(e)}" ) + @staticmethod + def _build_passthrough_logging_result( + litellm_logging_obj: LiteLLMLoggingObj, + passthrough_success_handler_obj: PassThroughEndpointLogging, + url_route: str, + request_body: dict, + endpoint_type: EndpointType, + start_time: datetime, + raw_bytes: List[bytes], + end_time: datetime, + model: Optional[str], + ) -> Tuple[PassThroughEndpointLoggingResultValues, dict]: + """ + Synchronous, CPU-bound reconstruction of the standard logging payload + from collected raw SSE bytes. Extracted from + _route_streaming_logging_to_handler so the per-endpoint dispatch can + be unit-tested in isolation. Still invoked synchronously on the event + loop; an off-loop dispatch is a future change, not part of this PR. + """ + all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines( + raw_bytes + ) + standard_logging_response_object: Optional[ + PassThroughEndpointLoggingResultValues + ] = None + kwargs: dict = {} + if endpoint_type == EndpointType.ANTHROPIC: + anthropic_passthrough_logging_handler_result = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body, + endpoint_type=endpoint_type, + start_time=start_time, + all_chunks=all_chunks, + end_time=end_time, + ) + standard_logging_response_object = ( + anthropic_passthrough_logging_handler_result["result"] + ) + kwargs = anthropic_passthrough_logging_handler_result["kwargs"] + elif endpoint_type == EndpointType.VERTEX_AI: + vertex_passthrough_logging_handler_result = ( + VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body, + endpoint_type=endpoint_type, + start_time=start_time, + all_chunks=all_chunks, + end_time=end_time, + model=model, + ) + ) + standard_logging_response_object = ( + vertex_passthrough_logging_handler_result["result"] + ) + kwargs = vertex_passthrough_logging_handler_result["kwargs"] + elif endpoint_type == EndpointType.OPENAI: + openai_passthrough_logging_handler_result = ( + OpenAIPassthroughLoggingHandler._handle_logging_openai_collected_chunks( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body, + endpoint_type=endpoint_type, + start_time=start_time, + all_chunks=all_chunks, + end_time=end_time, + ) + ) + standard_logging_response_object = ( + openai_passthrough_logging_handler_result["result"] + ) + kwargs = openai_passthrough_logging_handler_result["kwargs"] + + if standard_logging_response_object is None: + standard_logging_response_object = StandardPassThroughResponseObject( + response=f"cannot parse chunks to standard response object. Chunks={all_chunks}" + ) + return standard_logging_response_object, kwargs + @staticmethod def _extract_model_for_cost_injection( request_body: Optional[dict], diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 032ab6c63b..14f7f411e4 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1596,7 +1596,22 @@ class ProxyLogging: iterator_overrides.append((resolved, "override")) elif "apply_guardrail" in cls_attrs: iterator_overrides.append((resolved, "apply_guardrail")) - if "async_post_call_streaming_hook" in cls_attrs: + # Walk the MRO for ``async_post_call_streaming_hook`` rather than + # using the leaf-class ``__dict__`` check used by the other flags: + # before this PR the hook was unconditionally invoked, so a + # callback that inherits an override from an intermediate parent + # (e.g. a vendor base class providing the override, with the + # registered class adding nothing else) MUST still be detected. + # A leaf-class miss here would silently drop the inherited hook. + base_streaming_hook = CustomLogger.async_post_call_streaming_hook + cls_streaming_hook = getattr( + cls, + "async_post_call_streaming_hook", + base_streaming_hook, + ) + if getattr( + cls_streaming_hook, "__func__", cls_streaming_hook + ) is not getattr(base_streaming_hook, "__func__", base_streaming_hook): has_streaming_chunk_override = True if "async_pre_call_hook" in cls_attrs: has_pre_call_override = True diff --git a/scripts/benchmark_anthropic_messages_perf.py b/scripts/benchmark_anthropic_messages_perf.py new file mode 100644 index 0000000000..3c8a22f0cc --- /dev/null +++ b/scripts/benchmark_anthropic_messages_perf.py @@ -0,0 +1,624 @@ +#!/usr/bin/env python3 +"""Benchmark LiteLLM proxy /v1/messages (Anthropic Messages API) streaming. + +Measures the two metrics that matter for an interactive streaming proxy: + + * TTFT - time to first streamed token (first ``content_block_delta``) + * TPM - sustained output token throughput (tokens / second) once the + full stream is consumed, plus request throughput (RPS) + +It boots a local mock Anthropic provider that speaks the real Anthropic +streaming SSE wire format (``message_start`` -> ``content_block_delta`` -> +``message_stop``) and a LiteLLM proxy from any checkout, so commits/branches +can be compared without depending on real provider latency. + +Example: + uv run python scripts/benchmark_anthropic_messages_perf.py \ + --label baseline --proxy-command ".venv/bin/litellm" + +Compare an already-running proxy: + uv run python scripts/benchmark_anthropic_messages_perf.py \ + --no-start-proxy --label current +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import os +import shlex +import signal +import statistics +import subprocess +import tempfile +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional + +import aiohttp +from aiohttp import web + +DEFAULT_MODEL = "claude-perf-test" +DEFAULT_API_KEY = "sk-1234" + + +@dataclass +class StreamSample: + success: bool + ttft_ms: float + total_ms: float + output_tokens: int + status_code: int + error: str = "" + + +@dataclass +class SummaryStats: + requests: int + failures: int + rps: float + ttft_mean_ms: float + ttft_p50_ms: float + ttft_p95_ms: float + ttft_p99_ms: float + total_p50_ms: float + total_p95_ms: float + tokens_per_sec: float + + +class MockAnthropicProvider: + """Minimal Anthropic Messages API server (real streaming SSE format).""" + + def __init__( + self, + host: str, + port: int, + first_token_delay_ms: float, + stream_content_chunks: int, + ) -> None: + self.host = host + self.port = port + self.first_token_delay_ms = first_token_delay_ms + self.stream_content_chunks = stream_content_chunks + self.runner: Optional[web.AppRunner] = None + + @property + def base_url(self) -> str: + return f"http://{self.host}:{self.port}" + + async def start(self) -> None: + app = web.Application() + app.router.add_post("/v1/messages", self.handle_messages) + self.runner = web.AppRunner(app, access_log=None) + await self.runner.setup() + site = web.TCPSite(self.runner, self.host, self.port) + await site.start() + + async def stop(self) -> None: + if self.runner is not None: + await self.runner.cleanup() + + async def handle_messages(self, request: web.Request) -> web.StreamResponse: + body = await request.json() + if body.get("stream"): + return await self._streaming_response(request, body) + return self._json_response(body) + + def _json_response(self, body: dict[str, Any]) -> web.Response: + payload = { + "id": "msg_perf", + "type": "message", + "role": "assistant", + "model": body.get("model", DEFAULT_MODEL), + "content": [{"type": "text", "text": "hello"}], + "stop_reason": "end_turn", + "stop_sequence": None, + "usage": {"input_tokens": 8, "output_tokens": 1}, + } + return web.json_response(payload) + + @staticmethod + def _sse(event: str, data: dict[str, Any]) -> bytes: + return f"event: {event}\ndata: {json.dumps(data)}\n\n".encode() + + async def _streaming_response( + self, request: web.Request, body: dict[str, Any] + ) -> web.StreamResponse: + model = body.get("model", DEFAULT_MODEL) + response = web.StreamResponse( + status=200, + headers={ + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + }, + ) + await response.prepare(request) + + await response.write( + self._sse( + "message_start", + { + "type": "message_start", + "message": { + "id": "msg_perf", + "type": "message", + "role": "assistant", + "model": model, + "content": [], + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": 8, "output_tokens": 0}, + }, + }, + ) + ) + await response.write( + self._sse( + "content_block_start", + { + "type": "content_block_start", + "index": 0, + "content_block": {"type": "text", "text": ""}, + }, + ) + ) + + if self.first_token_delay_ms > 0: + await asyncio.sleep(self.first_token_delay_ms / 1000) + + for _ in range(self.stream_content_chunks): + await response.write( + self._sse( + "content_block_delta", + { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": "hello "}, + }, + ) + ) + + await response.write( + self._sse("content_block_stop", {"type": "content_block_stop", "index": 0}) + ) + await response.write( + self._sse( + "message_delta", + { + "type": "message_delta", + "delta": {"stop_reason": "end_turn", "stop_sequence": None}, + "usage": {"output_tokens": self.stream_content_chunks}, + }, + ) + ) + await response.write(self._sse("message_stop", {"type": "message_stop"})) + await response.write_eof() + return response + + +def percentile(values: list[float], pct: float) -> float: + if not values: + return 0.0 + sorted_values = sorted(values) + index = min(int(len(sorted_values) * pct / 100), len(sorted_values) - 1) + return sorted_values[index] + + +def summarize(samples: list[StreamSample], wall_time_s: float) -> SummaryStats: + ok = [s for s in samples if s.success] + ttfts = [s.ttft_ms for s in ok] + totals = [s.total_ms for s in ok] + total_tokens = sum(s.output_tokens for s in ok) + return SummaryStats( + requests=len(samples), + failures=len(samples) - len(ok), + rps=(len(ok) / wall_time_s) if wall_time_s > 0 else 0.0, + ttft_mean_ms=statistics.mean(ttfts) if ttfts else 0.0, + ttft_p50_ms=percentile(ttfts, 50), + ttft_p95_ms=percentile(ttfts, 95), + ttft_p99_ms=percentile(ttfts, 99), + total_p50_ms=percentile(totals, 50), + total_p95_ms=percentile(totals, 95), + # Aggregate output-token throughput: total tokens delivered across all + # successful requests divided by wall-clock time. This is the true + # server TPM and (unlike tokens / summed-per-request-latency) scales + # correctly with concurrency. + tokens_per_sec=(total_tokens / wall_time_s) if wall_time_s > 0 else 0.0, + ) + + +def get_git_revision(litellm_dir: Path) -> str: + try: + result = subprocess.run( + ["git", "rev-parse", "--short", "HEAD"], + cwd=litellm_dir, + check=True, + capture_output=True, + text=True, + ) + return result.stdout.strip() + except Exception: + return "unknown" + + +def write_proxy_config(config_path: Path, provider_base_url: str, api_key: str) -> None: + config_path.write_text( + f"""model_list: + - model_name: {DEFAULT_MODEL} + litellm_params: + model: anthropic/{DEFAULT_MODEL} + api_key: fake-provider-key + api_base: {provider_base_url} + +general_settings: + master_key: {api_key} + +litellm_settings: + telemetry: false +""", + encoding="utf-8", + ) + + +async def wait_for_proxy(base_url: str, timeout_s: float) -> None: + deadline = time.perf_counter() + timeout_s + last_error = "" + async with aiohttp.ClientSession() as session: + while time.perf_counter() < deadline: + try: + async with session.get(f"{base_url}/health/liveliness") as response: + if response.status < 500: + return + last_error = f"HTTP {response.status}" + except Exception as exc: + last_error = str(exc) + await asyncio.sleep(0.5) + raise TimeoutError(f"Timed out waiting for proxy at {base_url}: {last_error}") + + +def start_proxy_process( + litellm_dir: Path, + proxy_command: str, + config_path: Path, + port: int, + log_path: Path, +) -> subprocess.Popen: + command = shlex.split(proxy_command) + [ + "--config", + str(config_path), + "--port", + str(port), + ] + env = { + **os.environ, + "LITELLM_TELEMETRY": "False", + "PYTHONUNBUFFERED": "1", + } + log_file = log_path.open("w", encoding="utf-8") + return subprocess.Popen( + command, + cwd=litellm_dir, + env=env, + stdout=log_file, + stderr=subprocess.STDOUT, + start_new_session=True, + ) + + +def stop_proxy_process(process: subprocess.Popen) -> None: + if process.poll() is not None: + return + try: + os.killpg(process.pid, signal.SIGTERM) + process.wait(timeout=10) + except Exception: + try: + os.killpg(process.pid, signal.SIGKILL) + except Exception: + pass + + +async def measure_stream( + session: aiohttp.ClientSession, + url: str, + headers: dict[str, str], + payload: dict[str, Any], +) -> StreamSample: + start = time.perf_counter() + ttft_ms = 0.0 + output_tokens = 0 + try: + async with session.post(url, headers=headers, json=payload) as response: + if response.status != 200: + body = await response.read() + return StreamSample( + success=False, + ttft_ms=0.0, + total_ms=(time.perf_counter() - start) * 1000, + output_tokens=0, + status_code=response.status, + error=body.decode("utf-8", errors="ignore")[:200], + ) + async for raw_line in response.content: + line = raw_line.strip() + if not line.startswith(b"data:"): + continue + data = line[5:].strip() + if data == b"[DONE]": + break + try: + event = json.loads(data) + except json.JSONDecodeError: + continue + etype = event.get("type") + if etype == "content_block_delta": + if ttft_ms == 0.0: + ttft_ms = (time.perf_counter() - start) * 1000 + output_tokens += 1 + elif etype == "message_stop": + break + total_ms = (time.perf_counter() - start) * 1000 + if ttft_ms == 0.0: + return StreamSample( + success=False, + ttft_ms=0.0, + total_ms=total_ms, + output_tokens=0, + status_code=response.status, + error="stream ended before a content token", + ) + return StreamSample( + success=True, + ttft_ms=ttft_ms, + total_ms=total_ms, + output_tokens=output_tokens, + status_code=response.status, + ) + except Exception as exc: + return StreamSample( + success=False, + ttft_ms=0.0, + total_ms=(time.perf_counter() - start) * 1000, + output_tokens=0, + status_code=0, + error=str(exc)[:200], + ) + + +async def run_benchmark( + url: str, + headers: dict[str, str], + payload: dict[str, Any], + requests: int, + concurrency: int, + warmup: int, + timeout_s: float, +) -> SummaryStats: + timeout = aiohttp.ClientTimeout(total=timeout_s) + connector = aiohttp.TCPConnector( + limit=max(concurrency * 2, 10), + limit_per_host=max(concurrency, 10), + force_close=False, + ) + + async def worker( + session: aiohttp.ClientSession, + counter: list[int], + budget: int, + sink: list[StreamSample], + ) -> None: + # Steady-state load: exactly `concurrency` workers, each pulling the + # next request slot as soon as its previous one finishes. Keeps + # in-flight concurrency constant (vs. a gather-all + semaphore burst) + # which removes the thundering-herd variance that otherwise swamps a + # 10% signal. + while True: + idx = counter[0] + if idx >= budget: + return + counter[0] = idx + 1 + sink.append(await measure_stream(session, url, headers, payload)) + + async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: + if warmup > 0: + wcounter = [0] + await asyncio.gather( + *[worker(session, wcounter, warmup, []) for _ in range(concurrency)] + ) + samples: list[StreamSample] = [] + counter = [0] + wall_start = time.perf_counter() + await asyncio.gather( + *[worker(session, counter, requests, samples) for _ in range(concurrency)] + ) + wall_time_s = time.perf_counter() - wall_start + return summarize(samples, wall_time_s) + + +def stats_to_dict(stats: SummaryStats) -> dict[str, Any]: + return { + "requests": stats.requests, + "failures": stats.failures, + "rps": stats.rps, + "ttft_mean_ms": stats.ttft_mean_ms, + "ttft_p50_ms": stats.ttft_p50_ms, + "ttft_p95_ms": stats.ttft_p95_ms, + "ttft_p99_ms": stats.ttft_p99_ms, + "total_p50_ms": stats.total_p50_ms, + "total_p95_ms": stats.total_p95_ms, + "tokens_per_sec": stats.tokens_per_sec, + } + + +def print_summary(label: str, revision: str, stats: SummaryStats) -> None: + print("\n=== Anthropic /v1/messages streaming benchmark ===") + print(f"Label: {label}") + print(f"Revision: {revision}") + print(f"Requests: {stats.requests} Failures: {stats.failures}") + print(f"TTFT mean: {stats.ttft_mean_ms:.2f} ms") + print(f"TTFT p50: {stats.ttft_p50_ms:.2f} ms") + print(f"TTFT p95: {stats.ttft_p95_ms:.2f} ms") + print(f"TTFT p99: {stats.ttft_p99_ms:.2f} ms") + print(f"Full p50: {stats.total_p50_ms:.2f} ms") + print(f"Full p95: {stats.total_p95_ms:.2f} ms") + print(f"Throughput: {stats.rps:.2f} req/s") + print(f"TPM: {stats.tokens_per_sec:.1f} output tokens/s") + print("\nMarkdown row:") + print( + "| " + + " | ".join( + [ + label, + revision, + f"{stats.ttft_p50_ms:.2f}", + f"{stats.ttft_p95_ms:.2f}", + f"{stats.tokens_per_sec:.1f}", + f"{stats.rps:.2f}", + ] + ) + + " |" + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--label", default="current") + parser.add_argument("--litellm-dir", default=str(Path.cwd())) + parser.add_argument("--proxy-command", default="uv run litellm") + parser.add_argument("--proxy-host", default="127.0.0.1") + parser.add_argument("--proxy-port", type=int, default=4000) + parser.add_argument("--provider-host", default="127.0.0.1") + parser.add_argument("--provider-port", type=int, default=8098) + parser.add_argument("--api-key", default=DEFAULT_API_KEY) + parser.add_argument("--requests", type=int, default=300) + parser.add_argument("--concurrency", type=int, default=20) + parser.add_argument("--warmup", type=int, default=30) + parser.add_argument("--timeout", type=float, default=30) + parser.add_argument("--proxy-start-timeout", type=float, default=90) + parser.add_argument("--provider-first-token-delay-ms", type=float, default=0) + parser.add_argument( + "--provider-stream-content-chunks", + type=int, + default=64, + help="Number of text delta chunks the mock emits (default 64).", + ) + parser.add_argument( + "--repeats", + type=int, + default=1, + help="Run the suite N times against the same proxy; report the median run.", + ) + parser.add_argument( + "--no-start-proxy", + action="store_true", + help="Benchmark an already-running proxy at --proxy-host/--proxy-port", + ) + parser.add_argument( + "--provider-url", + help="Use an already-running Anthropic-compatible provider", + ) + parser.add_argument("--output-json", help="Write machine-readable results") + return parser.parse_args() + + +async def async_main() -> None: + args = parse_args() + litellm_dir = Path(args.litellm_dir).resolve() + revision = get_git_revision(litellm_dir) + proxy_base_url = f"http://{args.proxy_host}:{args.proxy_port}" + proxy_url = f"{proxy_base_url}/v1/messages" + headers = { + "Authorization": f"Bearer {args.api_key}", + "Content-Type": "application/json", + } + stream_payload = { + "model": DEFAULT_MODEL, + "max_tokens": 256, + "messages": [{"role": "user", "content": "hi"}], + "stream": True, + } + + provider: Optional[MockAnthropicProvider] = None + proxy_process: Optional[subprocess.Popen] = None + with tempfile.TemporaryDirectory(prefix="litellm-anthropic-perf-") as tmp_dir_name: + tmp_dir = Path(tmp_dir_name) + proxy_log_path = tmp_dir / "proxy.log" + if args.provider_url: + provider_base_url = args.provider_url.rstrip("/") + else: + provider = MockAnthropicProvider( + host=args.provider_host, + port=args.provider_port, + first_token_delay_ms=args.provider_first_token_delay_ms, + stream_content_chunks=args.provider_stream_content_chunks, + ) + await provider.start() + provider_base_url = provider.base_url + + config_path = tmp_dir / "config.yaml" + write_proxy_config(config_path, provider_base_url, args.api_key) + + try: + if not args.no_start_proxy: + proxy_process = start_proxy_process( + litellm_dir=litellm_dir, + proxy_command=args.proxy_command, + config_path=config_path, + port=args.proxy_port, + log_path=proxy_log_path, + ) + await wait_for_proxy(proxy_base_url, args.proxy_start_timeout) + + runs: list[SummaryStats] = [] + for run_idx in range(max(1, args.repeats)): + if args.repeats > 1: + print(f"\n--- Run {run_idx + 1}/{args.repeats} ---") + stats = await run_benchmark( + url=proxy_url, + headers=headers, + payload=stream_payload, + requests=args.requests, + concurrency=args.concurrency, + warmup=args.warmup, + timeout_s=args.timeout, + ) + runs.append(stats) + if args.repeats > 1: + print( + f" run {run_idx + 1}: TTFT p50={stats.ttft_p50_ms:.2f}ms " + f"TPM={stats.tokens_per_sec:.1f} tok/s RPS={stats.rps:.2f}" + ) + + stats = sorted(runs, key=lambda s: s.ttft_p50_ms)[len(runs) // 2] + finally: + if proxy_process is not None: + stop_proxy_process(proxy_process) + if provider is not None: + await provider.stop() + + print_summary(args.label, revision, stats) + + if args.output_json: + Path(args.output_json).write_text( + json.dumps( + { + "label": args.label, + "revision": revision, + "proxy_streaming": stats_to_dict(stats), + "proxy_log_path": str(proxy_log_path), + }, + indent=2, + sort_keys=True, + ), + encoding="utf-8", + ) + + +def main() -> None: + asyncio.run(async_main()) + + +if __name__ == "__main__": + main() diff --git a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py index bd1fe75f36..b1e1d789d7 100644 --- a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py +++ b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py @@ -544,3 +544,132 @@ class TestThinkingSummaryPreservation: assert result == { "reasoning_effort": {"effort": "medium", "summary": "concise"} } + + +# --------------------------------------------------------------------------- +# Parity tests: redundant empty-text-block sanitization scan removal. +# The async wrapper sanitizes once and tells the handler to skip its second +# (redundant) full-messages scan; the sync entry point still sanitizes. +# --------------------------------------------------------------------------- + + +def _empty_block_msgs(): + return [ + { + "role": "assistant", + "content": [ + {"type": "text", "text": " "}, # whitespace-only -> stripped + {"type": "tool_use", "id": "t", "name": "B", "input": {}}, + ], + } + ] + + +def test_handler_strips_when_no_presanitized_flag(): + """Sync entry point (no async wrapper): handler must still sanitize.""" + from litellm.llms.anthropic.experimental_pass_through.messages import handler + + with patch.object( + handler, + "strip_empty_text_blocks_from_anthropic_messages", + wraps=handler.strip_empty_text_blocks_from_anthropic_messages, + ) as spy: + result = handler.anthropic_messages_handler( + max_tokens=10, + messages=_empty_block_msgs(), + model="anthropic/claude-3-5-sonnet-20241022", + custom_llm_provider="anthropic", + mock_response="hi there", + ) + assert spy.call_count == 1 # sanitized exactly once here + assert result is not None + + +def test_handler_skips_strip_when_presanitized(): + """Async wrapper already sanitized -> handler must NOT rescan.""" + from litellm.llms.anthropic.experimental_pass_through.messages import handler + + with patch.object( + handler, + "strip_empty_text_blocks_from_anthropic_messages", + wraps=handler.strip_empty_text_blocks_from_anthropic_messages, + ) as spy: + result = handler.anthropic_messages_handler( + max_tokens=10, + messages=_empty_block_msgs(), + model="anthropic/claude-3-5-sonnet-20241022", + custom_llm_provider="anthropic", + mock_response="hi there", + _litellm_messages_presanitized=True, + ) + assert spy.call_count == 0 # skipped the redundant scan + assert result is not None + + +def test_presanitized_flag_not_leaked_to_provider_params(): + """The private sentinel must be popped, never forwarded as a request param.""" + from litellm.llms.anthropic.experimental_pass_through.messages import handler + + captured = {} + + def fake_base_handler(*args, **kwargs): + captured.update(kwargs) + captured["optional"] = kwargs.get( + "anthropic_messages_optional_request_params", {} + ) + return "stub" + + with patch.object( + handler.base_llm_http_handler, + "anthropic_messages_handler", + side_effect=fake_base_handler, + ): + handler.anthropic_messages_handler( + max_tokens=10, + messages=[{"role": "user", "content": "hi"}], + model="anthropic/claude-3-5-sonnet-20241022", + custom_llm_provider="anthropic", + _litellm_messages_presanitized=True, + ) + + assert "_litellm_messages_presanitized" not in captured.get("optional", {}) + assert "_litellm_messages_presanitized" not in captured.get("kwargs", {}) + + +@pytest.mark.asyncio +async def test_async_wrapper_sets_presanitized_and_sanitizes_once(): + """End-to-end: wrapper sanitizes (once) AND signals the handler to skip.""" + from litellm.llms.anthropic.experimental_pass_through.messages import handler + + captured = {} + + def fake_handler(*args, **kwargs): + captured["messages"] = kwargs.get("messages") + captured["presanitized"] = kwargs.get("_litellm_messages_presanitized") + return "stub" + + fake_loop = MagicMock() + fake_loop.run_in_executor = lambda _e, func: _async_return(func()) + + with ( + patch.object(handler, "anthropic_messages_handler", side_effect=fake_handler), + patch("asyncio.get_event_loop", return_value=fake_loop), + patch.object( + handler, + "strip_empty_text_blocks_from_anthropic_messages", + wraps=handler.strip_empty_text_blocks_from_anthropic_messages, + ) as spy, + ): + await handler.anthropic_messages( + max_tokens=100, + messages=_empty_block_msgs(), + model="anthropic/claude-sonnet-4-5-20250929", + custom_llm_provider="anthropic", + api_key="k", + ) + + # Wrapper stripped exactly once (the handler is faked, so its skipped + # call never runs anyway -- the point is the wrapper still sanitizes). + assert spy.call_count == 1 + assert captured["presanitized"] is True + assert [b["type"] for b in captured["messages"][0]["content"]] == ["tool_use"] diff --git a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_request_optional_param_utils.py b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_request_optional_param_utils.py new file mode 100644 index 0000000000..3ce076640e --- /dev/null +++ b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_request_optional_param_utils.py @@ -0,0 +1,56 @@ +""" +Regression tests for the /v1/messages request-parse fast paths: + +- get_requested_anthropic_messages_optional_param must still filter to the + valid AnthropicMessagesRequestOptionalParams keys and drop None values, + while resolving the (static) type hints only once per process. +""" + +from litellm.llms.anthropic.experimental_pass_through.messages.utils import ( + AnthropicMessagesRequestUtils, + _anthropic_messages_optional_param_keys, +) + + +def test_optional_param_filtering_unchanged(): + params = { + "temperature": 0.5, + "top_p": None, # None dropped + "tools": [{"name": "x"}], + "not_a_real_param": "drop me", # invalid key dropped + "stream": True, + } + result = ( + AnthropicMessagesRequestUtils.get_requested_anthropic_messages_optional_param( + params + ) + ) + assert result == {"temperature": 0.5, "tools": [{"name": "x"}], "stream": True} + assert "top_p" not in result + assert "not_a_real_param" not in result + + +def test_valid_keys_are_memoized(): + _anthropic_messages_optional_param_keys.cache_clear() + first = _anthropic_messages_optional_param_keys() + for _ in range(50): + AnthropicMessagesRequestUtils.get_requested_anthropic_messages_optional_param( + {"temperature": 0.1} + ) + info = _anthropic_messages_optional_param_keys.cache_info() + # Resolved exactly once despite many calls. + assert info.misses == 1 + assert info.hits >= 50 + # Stable identity (frozenset) returned each call. + assert _anthropic_messages_optional_param_keys() is first + assert isinstance(first, frozenset) + assert "temperature" in first and "tools" in first + + +def test_empty_params(): + assert ( + AnthropicMessagesRequestUtils.get_requested_anthropic_messages_optional_param( + {} + ) + == {} + ) diff --git a/tests/test_litellm/llms/custom_httpx/test_llm_http_handler.py b/tests/test_litellm/llms/custom_httpx/test_llm_http_handler.py index b846cd600f..07a61c9c10 100644 --- a/tests/test_litellm/llms/custom_httpx/test_llm_http_handler.py +++ b/tests/test_litellm/llms/custom_httpx/test_llm_http_handler.py @@ -101,6 +101,73 @@ def test_get_agentic_loop_settings_defaults_and_overrides(): assert fingerprints == ["fp-1", "fp-2"] +def test_has_agentic_completion_hook_detection(monkeypatch): + """The streaming path skips the agentic wrapper only when no callback + overrides async_should_run_agentic_loop. Verify both directions.""" + from litellm.integrations.custom_logger import CustomLogger + + handler = BaseLLMHTTPHandler() + logging_obj = Mock() + logging_obj.dynamic_success_callbacks = [] + + # No callbacks at all -> no agentic hook. + monkeypatch.setattr(litellm, "callbacks", []) + assert handler._has_agentic_completion_hook(logging_obj) is False + + # A plain CustomLogger that does NOT override the gate -> still no hook + # (so the wrapper is safely skipped). + class _PlainLogger(CustomLogger): + pass + + monkeypatch.setattr(litellm, "callbacks", [_PlainLogger()]) + assert handler._has_agentic_completion_hook(logging_obj) is False + + # A logger that overrides the gate (directly) -> hook present. + class _AgenticLogger(CustomLogger): + async def async_should_run_agentic_loop( + self, response, model, messages, tools, stream, custom_llm_provider, kwargs + ): + return True, {} + + monkeypatch.setattr(litellm, "callbacks", [_AgenticLogger()]) + assert handler._has_agentic_completion_hook(logging_obj) is True + + # Override inherited through an intermediate class is still detected + # (function-identity check, not a leaf __dict__ check). + class _DerivedAgenticLogger(_AgenticLogger): + pass + + monkeypatch.setattr(litellm, "callbacks", [_DerivedAgenticLogger()]) + assert handler._has_agentic_completion_hook(logging_obj) is True + + # Hook supplied via logging_obj.dynamic_success_callbacks is detected too. + monkeypatch.setattr(litellm, "callbacks", []) + logging_obj.dynamic_success_callbacks = [_AgenticLogger()] + assert handler._has_agentic_completion_hook(logging_obj) is True + + # String-named callback entry (e.g. "datadog") must be resolved to its + # CustomLogger instance via get_custom_logger_compatible_class -- the same + # way ProxyLogging._callback_capabilities handles them. Without that + # resolution a string-registered agentic callback would be silently + # skipped and the buffering wrapper would never fire. + logging_obj.dynamic_success_callbacks = [] + agentic_via_string = _AgenticLogger() + monkeypatch.setattr(litellm, "callbacks", ["fake_string_callback"]) + monkeypatch.setattr( + "litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class", + lambda name: agentic_via_string if name == "fake_string_callback" else None, + ) + assert handler._has_agentic_completion_hook(logging_obj) is True + + # Unresolvable string (returns None) is skipped, no false positive. + monkeypatch.setattr(litellm, "callbacks", ["unknown_callback"]) + monkeypatch.setattr( + "litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class", + lambda name: None, + ) + assert handler._has_agentic_completion_hook(logging_obj) is False + + def test_fingerprint_agentic_tools_is_deterministic(): handler = BaseLLMHTTPHandler() tools_a = {"tool_calls": [{"id": "1", "input": {"q": "abc"}, "name": "web_search"}]} @@ -422,3 +489,143 @@ def test_sync_delete_responses_omits_body_for_azure(): assert captured["url"].endswith( "/openai/responses/resp_xyz?api-version=2025-03-01-preview" ) + + +# --------------------------------------------------------------------------- +# Parity tests: request-body is serialized once and reused for the wire. +# (_async_post_anthropic_messages_with_http_error_retry) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_anthropic_post_uses_prebuilt_body_without_redumping(): + """When the caller passes a pre-serialized (unsigned) body, attempt 0 must + send exactly those bytes -- no second json.dumps of request_body.""" + import json as _json + + handler = BaseLLMHTTPHandler() + request_body = {"model": "claude", "messages": [{"role": "user", "content": "hi"}]} + prebuilt = _json.dumps(request_body) + + ok_resp = Mock() + ok_resp.raise_for_status = Mock(return_value=None) + http_client = Mock() + http_client.post = AsyncMock(return_value=ok_resp) + + provider_config = Mock() + provider_config.max_retry_on_anthropic_messages_http_error = 2 + + logging_obj = Mock() + logging_obj.model_call_details = {} + + out = await handler._async_post_anthropic_messages_with_http_error_retry( + async_httpx_client=http_client, + request_url="http://x/v1/messages", + headers={}, + signed_json_body=prebuilt, + request_body=request_body, + stream=False, + logging_obj=logging_obj, + provider_config=provider_config, + litellm_params=GenericLiteLLMParams(), + api_key="k", + model="claude", + ) + assert out is ok_resp + http_client.post.assert_awaited_once() + sent = http_client.post.await_args.kwargs["data"] + # Byte-identical to the legacy wire serialization, and the SAME object the + # caller already used for the pre-call log (no re-serialization). + assert sent == prebuilt + assert sent is prebuilt + + +@pytest.mark.asyncio +async def test_anthropic_post_falls_back_to_json_dumps_when_unsigned_none(): + """signed_json_body=None keeps the exact legacy behavior.""" + import json as _json + + handler = BaseLLMHTTPHandler() + request_body = {"model": "claude", "messages": [{"role": "user", "content": "yo"}]} + + ok_resp = Mock() + ok_resp.raise_for_status = Mock(return_value=None) + http_client = Mock() + http_client.post = AsyncMock(return_value=ok_resp) + + provider_config = Mock() + provider_config.max_retry_on_anthropic_messages_http_error = 1 + logging_obj = Mock() + logging_obj.model_call_details = {} + + await handler._async_post_anthropic_messages_with_http_error_retry( + async_httpx_client=http_client, + request_url="http://x/v1/messages", + headers={}, + signed_json_body=None, + request_body=request_body, + stream=False, + logging_obj=logging_obj, + provider_config=provider_config, + litellm_params=GenericLiteLLMParams(), + api_key="k", + model="claude", + ) + sent = http_client.post.await_args.kwargs["data"] + assert sent == _json.dumps(request_body) + + +@pytest.mark.asyncio +async def test_anthropic_post_retry_reserializes_mutated_body(): + """On a retryable HTTP error the body is mutated + re-signed; the prebuilt + body must NOT be reused -- attempt 1 sends the freshly serialized body.""" + import json as _json + + handler = BaseLLMHTTPHandler() + request_body = {"model": "claude", "messages": [{"role": "user", "content": "a"}]} + prebuilt = _json.dumps(request_body) + + err_resp = Mock() + http_error = httpx.HTTPStatusError( + "bad", request=Mock(), response=Mock(status_code=400) + ) + err_resp.raise_for_status = Mock(side_effect=http_error) + ok_resp = Mock() + ok_resp.raise_for_status = Mock(return_value=None) + http_client = Mock() + http_client.post = AsyncMock(side_effect=[err_resp, ok_resp]) + + def _mutate(e, request_data): + request_data["messages"][0]["content"] = "MUTATED" + + provider_config = Mock() + provider_config.max_retry_on_anthropic_messages_http_error = 2 + provider_config.should_retry_anthropic_messages_on_http_error = Mock( + return_value=True + ) + provider_config.transform_anthropic_messages_request_on_http_error = _mutate + # Re-sign returns no signed body (native anthropic path) -> must re-dump. + provider_config.sign_request = Mock(return_value=({}, None)) + + logging_obj = Mock() + logging_obj.model_call_details = {} + + await handler._async_post_anthropic_messages_with_http_error_retry( + async_httpx_client=http_client, + request_url="http://x/v1/messages", + headers={}, + signed_json_body=prebuilt, + request_body=request_body, + stream=False, + logging_obj=logging_obj, + provider_config=provider_config, + litellm_params=GenericLiteLLMParams(), + api_key="k", + model="claude", + ) + assert http_client.post.await_count == 2 + first_sent = http_client.post.await_args_list[0].kwargs["data"] + second_sent = http_client.post.await_args_list[1].kwargs["data"] + assert first_sent == prebuilt # attempt 0 used prebuilt + assert second_sent == _json.dumps(request_body) # attempt 1 re-serialized + assert "MUTATED" in second_sent # ... the mutated body diff --git a/tests/test_litellm/proxy/pass_through_endpoints/llm_provider_handlers/test_anthropic_passthrough_logging_handler.py b/tests/test_litellm/proxy/pass_through_endpoints/llm_provider_handlers/test_anthropic_passthrough_logging_handler.py index c16c42decc..0a9e303103 100644 --- a/tests/test_litellm/proxy/pass_through_endpoints/llm_provider_handlers/test_anthropic_passthrough_logging_handler.py +++ b/tests/test_litellm/proxy/pass_through_endpoints/llm_provider_handlers/test_anthropic_passthrough_logging_handler.py @@ -684,3 +684,362 @@ class TestAnthropicBatchPassthroughCostTracking: mock_proxy_logging_obj.get_proxy_hook.assert_called_once_with( "managed_files" ) + + +class TestPureTextFastPathParity: + """ + The pure-text fast path in _build_complete_streaming_response must produce + a response (and downstream logging/cost payload) byte-identical to the + legacy stream_chunk_builder path. Anything non-text must fall back. + """ + + @staticmethod + def _sse(event, data): + return f"event: {event}\ndata: {json.dumps(data)}\n\n".encode() + + @staticmethod + def _to_all_chunks(raw_frames): + # Mirror production: raw bytes -> _convert_raw_bytes_to_str_lines. + from litellm.proxy.pass_through_endpoints.streaming_handler import ( + PassThroughStreamingHandler, + ) + + return PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(raw_frames) + + @staticmethod + def _norm(resp): + if resp is None: + return None + d = resp.model_dump() + # id / created are non-deterministic even between two legacy runs. + d.pop("id", None) + d.pop("created", None) + return d + + def _text_stream( + self, + texts, + *, + input_tokens=12, + cache_creation=0, + cache_read=0, + stop_reason="end_turn", + with_ping=True, + blocks=1, + ): + frames = [ + self._sse( + "message_start", + { + "type": "message_start", + "message": { + "id": "msg_abc", + "type": "message", + "role": "assistant", + "model": "claude-3-5-sonnet-20241022", + "content": [], + "stop_reason": None, + "stop_sequence": None, + "usage": { + "input_tokens": input_tokens, + "output_tokens": 0, + "cache_creation_input_tokens": cache_creation, + "cache_read_input_tokens": cache_read, + }, + }, + }, + ) + ] + per_block = max(1, len(texts) // blocks) + ti = 0 + for b in range(blocks): + frames.append( + self._sse( + "content_block_start", + { + "type": "content_block_start", + "index": b, + "content_block": {"type": "text", "text": ""}, + }, + ) + ) + if with_ping: + frames.append(self._sse("ping", {"type": "ping"})) + chunk_texts = texts[ti : ti + per_block] if b < blocks - 1 else texts[ti:] + ti += per_block + for t in chunk_texts: + frames.append( + self._sse( + "content_block_delta", + { + "type": "content_block_delta", + "index": b, + "delta": {"type": "text_delta", "text": t}, + }, + ) + ) + frames.append( + self._sse( + "content_block_stop", {"type": "content_block_stop", "index": b} + ) + ) + frames.append( + self._sse( + "message_delta", + { + "type": "message_delta", + "delta": {"stop_reason": stop_reason, "stop_sequence": None}, + "usage": {"output_tokens": len(texts)}, + }, + ) + ) + frames.append(self._sse("message_stop", {"type": "message_stop"})) + return frames + + def _assert_parity(self, raw_frames): + all_chunks = self._to_all_chunks(raw_frames) + lo1 = MagicMock() + lo1.model_call_details = {} + lo2 = MagicMock() + lo2.model_call_details = {} + + legacy = AnthropicPassthroughLoggingHandler._build_complete_streaming_response_legacy( + all_chunks=list(all_chunks), + litellm_logging_obj=lo1, + model="claude-3-5-sonnet-20241022", + ) + fast = AnthropicPassthroughLoggingHandler._build_complete_streaming_response( + all_chunks=list(all_chunks), + litellm_logging_obj=lo2, + model="claude-3-5-sonnet-20241022", + ) + assert self._norm(fast) == self._norm(legacy) + + # Downstream logged/billed payload must also match. + start = datetime.now() + end = datetime.now() + k_legacy = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload( + litellm_model_response=legacy, + model="claude-3-5-sonnet-20241022", + kwargs={}, + start_time=start, + end_time=end, + logging_obj=lo1, + ) + k_fast = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload( + litellm_model_response=fast, + model="claude-3-5-sonnet-20241022", + kwargs={}, + start_time=start, + end_time=end, + logging_obj=lo2, + ) + # Usage drives cost; it must be byte-identical between paths. + assert getattr(fast, "usage", None) == getattr(legacy, "usage", None) + + # And the full logged payload (sans non-deterministic response id). + def _scrub(p): + d = dict(p) + r = d.get("complete_streaming_response_in_db") or d.get( + "complete_streaming_response" + ) + return d, getattr(r, "usage", None) + + assert _scrub(k_fast)[1] == _scrub(k_legacy)[1] + + def test_parity_simple_text(self): + self._assert_parity(self._text_stream(["Hello", " ", "world", "!"])) + + def test_parity_single_delta(self): + self._assert_parity(self._text_stream(["Just one piece of text."])) + + def test_parity_cache_tokens(self): + self._assert_parity( + self._text_stream( + ["a", "b", "c"], input_tokens=20, cache_creation=5, cache_read=7 + ) + ) + + def test_parity_max_tokens_stop(self): + self._assert_parity(self._text_stream(["tok"] * 8, stop_reason="max_tokens")) + + def test_parity_no_ping(self): + self._assert_parity(self._text_stream(["x", "y"], with_ping=False)) + + def test_parity_empty_text_deltas(self): + self._assert_parity(self._text_stream(["", "hi", "", "there"])) + + def test_parity_multi_text_block(self): + self._assert_parity(self._text_stream(["p1", "p2", "p3", "p4"], blocks=2)) + + def test_parity_multibyte_batched_frames(self): + # Several SSE events delivered in one network chunk. + frames = self._text_stream(["alpha", "beta", "gamma"]) + merged = b"".join(frames) + self._assert_parity([merged]) + + def test_collapse_returns_none_for_tool_use(self): + frames = [ + self._sse( + "message_start", + { + "type": "message_start", + "message": { + "id": "m", + "model": "x", + "role": "assistant", + "type": "message", + "content": [], + "usage": {"input_tokens": 1, "output_tokens": 0}, + }, + }, + ), + self._sse( + "content_block_start", + { + "type": "content_block_start", + "index": 0, + "content_block": { + "type": "tool_use", + "id": "t1", + "name": "get_weather", + "input": {}, + }, + }, + ), + self._sse( + "content_block_delta", + { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "input_json_delta", "partial_json": "{}"}, + }, + ), + self._sse("content_block_stop", {"type": "content_block_stop", "index": 0}), + self._sse("message_stop", {"type": "message_stop"}), + ] + all_chunks = self._to_all_chunks(frames) + assert ( + AnthropicPassthroughLoggingHandler._collapse_pure_text_chunks( + list(all_chunks) + ) + is None + ) + + def test_collapse_returns_none_for_thinking(self): + frames = [ + self._sse( + "message_start", + { + "type": "message_start", + "message": { + "id": "m", + "model": "x", + "role": "assistant", + "type": "message", + "content": [], + "usage": {"input_tokens": 1, "output_tokens": 0}, + }, + }, + ), + self._sse( + "content_block_start", + { + "type": "content_block_start", + "index": 0, + "content_block": {"type": "thinking", "thinking": ""}, + }, + ), + self._sse( + "content_block_delta", + { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "thinking_delta", "thinking": "hmm"}, + }, + ), + self._sse("message_stop", {"type": "message_stop"}), + ] + all_chunks = self._to_all_chunks(frames) + assert ( + AnthropicPassthroughLoggingHandler._collapse_pure_text_chunks( + list(all_chunks) + ) + is None + ) + + def test_collapse_actually_shrinks_chunk_count(self): + frames = self._text_stream(["a"] * 50) + all_chunks = list(self._to_all_chunks(frames)) + collapsed = AnthropicPassthroughLoggingHandler._collapse_pure_text_chunks( + all_chunks + ) + assert collapsed is not None + # 50 text deltas + 50 event markers + 1 ping collapse to far fewer. + assert len(collapsed) < len(all_chunks) / 2 + + def test_collapse_returns_none_for_interleaved_block_indexes(self): + """ + Anthropic sends content blocks strictly sequentially (start/deltas/stop + for one, then the next). If a stream ever interleaves deltas across + block indexes, the fast path must bail to legacy rather than merge text + from different blocks under a single index. + """ + frames = [ + self._sse( + "message_start", + { + "type": "message_start", + "message": { + "id": "msg_abc", + "type": "message", + "role": "assistant", + "model": "claude-3-5-sonnet-20241022", + "content": [], + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": 1, "output_tokens": 0}, + }, + }, + ), + self._sse( + "content_block_start", + { + "type": "content_block_start", + "index": 0, + "content_block": {"type": "text", "text": ""}, + }, + ), + self._sse( + "content_block_start", + { + "type": "content_block_start", + "index": 1, + "content_block": {"type": "text", "text": ""}, + }, + ), + # Interleave: delta for block 0, then delta for block 1, with no + # content_block_stop between them. + self._sse( + "content_block_delta", + { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": "hello "}, + }, + ), + self._sse( + "content_block_delta", + { + "type": "content_block_delta", + "index": 1, + "delta": {"type": "text_delta", "text": "world"}, + }, + ), + self._sse("message_stop", {"type": "message_stop"}), + ] + all_chunks = list(self._to_all_chunks(frames)) + assert ( + AnthropicPassthroughLoggingHandler._collapse_pure_text_chunks(all_chunks) + is None + ) diff --git a/tests/test_litellm/proxy/test_common_request_processing.py b/tests/test_litellm/proxy/test_common_request_processing.py index 31a25916e2..265a82d4a4 100644 --- a/tests/test_litellm/proxy/test_common_request_processing.py +++ b/tests/test_litellm/proxy/test_common_request_processing.py @@ -10,6 +10,7 @@ from fastapi.responses import JSONResponse, StreamingResponse import litellm from litellm._uuid import uuid +from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.opentelemetry import UserAPIKeyAuth from litellm.proxy.common_request_processing import ( ProxyBaseLLMRequestProcessing, @@ -1316,8 +1317,17 @@ class TestCommonRequestProcessingHelpers: yield 'data: {"content": "chunk 3"}\n\n' yield "data: [DONE]\n\n" - # Patch the tracer in the common_request_processing module - with patch("litellm.proxy.common_request_processing.tracer", mock_tracer): + # Patch the tracer in the common_request_processing module. The + # per-chunk span is gated on _DD_STREAMING_TRACE_ENABLED (resolved at + # import from the real tracer, a NullTracer by default), so enable it + # explicitly to exercise the tracing path. + with ( + patch("litellm.proxy.common_request_processing.tracer", mock_tracer), + patch( + "litellm.proxy.common_request_processing._DD_STREAMING_TRACE_ENABLED", + True, + ), + ): response = await create_response(mock_generator(), "text/event-stream", {}) assert response.status_code == 200 @@ -1345,6 +1355,40 @@ class TestCommonRequestProcessingHelpers: args[0] == "streaming.chunk.yield" ), f"Call {i} should have operation name 'streaming.chunk.yield', got {args[0]}" + async def test_create_streaming_response_skips_dd_trace_when_disabled(self): + """When DD tracing is disabled (the default), the per-chunk span + context manager is skipped entirely but all chunks still stream.""" + from unittest.mock import patch + + mock_tracer = MagicMock() + + async def mock_generator(): + yield 'data: {"content": "chunk 1"}\n\n' + yield 'data: {"content": "chunk 2"}\n\n' + yield "data: [DONE]\n\n" + + with ( + patch("litellm.proxy.common_request_processing.tracer", mock_tracer), + patch( + "litellm.proxy.common_request_processing._DD_STREAMING_TRACE_ENABLED", + False, + ), + ): + response = await create_response(mock_generator(), "text/event-stream", {}) + + assert response.status_code == 200 + + content = await self.consume_stream(response) + + # All chunks stream through unchanged ... + assert content == [ + 'data: {"content": "chunk 1"}\n\n', + 'data: {"content": "chunk 2"}\n\n', + "data: [DONE]\n\n", + ] + # ... but no per-chunk span was created. + assert mock_tracer.trace.call_count == 0 + async def test_create_streaming_response_dd_trace_with_error_chunk(self): """ Test that when the first chunk contains an error, JSONResponse is returned @@ -2199,3 +2243,77 @@ class TestHandleLLMApiExceptionDictDetail: proxy_exc = await self._invoke(exc) assert proxy_exc.message == "Content blocked by guardrail" assert proxy_exc.provider_specific_fields is None + + +class TestAsyncStreamingDataGeneratorFastPath: + """Fast/slow path branching in async_streaming_data_generator.""" + + @staticmethod + async def _aiter(items): + for item in items: + yield item + + @pytest.mark.asyncio + async def test_fast_path_skips_per_chunk_hook(self, monkeypatch): + """With no callbacks/guardrails/cost-injection, chunks pass through + unchanged and the per-chunk hook is NOT awaited.""" + monkeypatch.setattr(litellm, "callbacks", []) + ProxyLogging._callback_capabilities_cache.clear() + + proxy_logging_obj = ProxyLogging(user_api_key_cache=MagicMock()) + hook_spy = AsyncMock(side_effect=lambda **kw: kw["response"]) + monkeypatch.setattr( + proxy_logging_obj, "async_post_call_streaming_hook", hook_spy + ) + + chunks = [b"event: a\ndata: {}\n\n", b"event: b\ndata: {}\n\n"] + out = [ + c + async for c in ProxyBaseLLMRequestProcessing.async_streaming_data_generator( + response=self._aiter(chunks), + user_api_key_dict=MagicMock(spec=UserAPIKeyAuth), + request_data={"model": "claude-x"}, + proxy_logging_obj=proxy_logging_obj, + serialize_chunk=ProxyBaseLLMRequestProcessing.return_sse_chunk, + serialize_error=lambda e: "data: error\n\n", + ) + ] + + assert out == chunks # bytes pass through return_sse_chunk untouched + hook_spy.assert_not_awaited() + + @pytest.mark.asyncio + async def test_slow_path_runs_per_chunk_hook(self, monkeypatch): + """A callback that overrides async_post_call_streaming_hook forces the + slow path and the per-chunk hook is invoked.""" + + class _StreamingCb(CustomLogger): + async def async_post_call_streaming_hook(self, user_api_key_dict, response): + return response + + cb = _StreamingCb() + monkeypatch.setattr(litellm, "callbacks", [cb]) + ProxyLogging._callback_capabilities_cache.clear() + + proxy_logging_obj = ProxyLogging(user_api_key_cache=MagicMock()) + hook_spy = AsyncMock(side_effect=lambda **kw: kw["response"]) + monkeypatch.setattr( + proxy_logging_obj, "async_post_call_streaming_hook", hook_spy + ) + + out = [ + c + async for c in ProxyBaseLLMRequestProcessing.async_streaming_data_generator( + response=self._aiter([{"type": "message_stop"}]), + user_api_key_dict=MagicMock(spec=UserAPIKeyAuth), + request_data={"model": "claude-x"}, + proxy_logging_obj=proxy_logging_obj, + serialize_chunk=ProxyBaseLLMRequestProcessing.return_sse_chunk, + serialize_error=lambda e: "data: error\n\n", + ) + ] + + assert len(out) == 1 + hook_spy.assert_awaited_once() + + ProxyLogging._callback_capabilities_cache.clear() diff --git a/tests/test_litellm/proxy/test_proxy_logging_hook_detection.py b/tests/test_litellm/proxy/test_proxy_logging_hook_detection.py index 4aebcf40aa..f596703056 100644 --- a/tests/test_litellm/proxy/test_proxy_logging_hook_detection.py +++ b/tests/test_litellm/proxy/test_proxy_logging_hook_detection.py @@ -111,6 +111,28 @@ def test_callback_capabilities_captures_iterator_override(monkeypatch): assert kind == "override" +def test_callback_capabilities_detects_inherited_streaming_chunk_override(monkeypatch): + """ + ``async_post_call_streaming_hook`` must be detected even when the override + lives on an intermediate parent class — a vendor base class can carry the + override and the registered class can add nothing else. Before this PR the + hook was unconditionally invoked, so a leaf-class ``__dict__`` miss here + would silently drop the inherited hook. + """ + ProxyLogging._callback_capabilities_cache.clear() + + class _StreamingBase(CustomLogger): + async def async_post_call_streaming_hook(self, *args, **kwargs): # type: ignore[override] + return kwargs.get("response") + + class _LeafWithoutOverride(_StreamingBase): + pass + + monkeypatch.setattr(litellm, "callbacks", [_LeafWithoutOverride()]) + caps = ProxyLogging._callback_capabilities() + assert caps.has_streaming_chunk_override is True + + def test_callback_capabilities_cache_invalidates_on_list_change(monkeypatch): """The cache key includes (length, id-of-each-callback). Mutating the callback list must produce a fresh capability snapshot."""